edsl 0.1.39.dev1__py3-none-any.whl → 0.1.39.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +332 -332
- edsl/BaseDiff.py +260 -260
- edsl/TemplateLoader.py +24 -24
- edsl/__init__.py +49 -49
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +867 -867
- edsl/agents/AgentList.py +413 -413
- edsl/agents/Invigilator.py +233 -233
- edsl/agents/InvigilatorBase.py +270 -265
- edsl/agents/PromptConstructor.py +354 -354
- edsl/agents/__init__.py +3 -3
- edsl/agents/descriptors.py +99 -99
- edsl/agents/prompt_helpers.py +129 -129
- edsl/auto/AutoStudy.py +117 -117
- edsl/auto/StageBase.py +230 -230
- edsl/auto/StageGenerateSurvey.py +178 -178
- edsl/auto/StageLabelQuestions.py +125 -125
- edsl/auto/StagePersona.py +61 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
- edsl/auto/StagePersonaDimensionValues.py +74 -74
- edsl/auto/StagePersonaDimensions.py +69 -69
- edsl/auto/StageQuestions.py +73 -73
- edsl/auto/SurveyCreatorPipeline.py +21 -21
- edsl/auto/utilities.py +224 -224
- edsl/base/Base.py +279 -279
- edsl/config.py +157 -157
- edsl/conversation/Conversation.py +290 -290
- edsl/conversation/car_buying.py +58 -58
- edsl/conversation/chips.py +95 -95
- edsl/conversation/mug_negotiation.py +81 -81
- edsl/conversation/next_speaker_utilities.py +93 -93
- edsl/coop/PriceFetcher.py +54 -54
- edsl/coop/__init__.py +2 -2
- edsl/coop/coop.py +1028 -1028
- edsl/coop/utils.py +131 -131
- edsl/data/Cache.py +555 -555
- edsl/data/CacheEntry.py +233 -233
- edsl/data/CacheHandler.py +149 -149
- edsl/data/RemoteCacheSync.py +78 -78
- edsl/data/SQLiteDict.py +292 -292
- edsl/data/__init__.py +4 -4
- edsl/data/orm.py +10 -10
- edsl/data_transfer_models.py +73 -73
- edsl/enums.py +175 -175
- edsl/exceptions/BaseException.py +21 -21
- edsl/exceptions/__init__.py +54 -54
- edsl/exceptions/agents.py +42 -42
- edsl/exceptions/cache.py +5 -5
- edsl/exceptions/configuration.py +16 -16
- edsl/exceptions/coop.py +10 -10
- edsl/exceptions/data.py +14 -14
- edsl/exceptions/general.py +34 -34
- edsl/exceptions/jobs.py +33 -33
- edsl/exceptions/language_models.py +63 -63
- edsl/exceptions/prompts.py +15 -15
- edsl/exceptions/questions.py +91 -91
- edsl/exceptions/results.py +29 -29
- edsl/exceptions/scenarios.py +22 -22
- edsl/exceptions/surveys.py +37 -37
- edsl/inference_services/AnthropicService.py +87 -87
- edsl/inference_services/AwsBedrock.py +120 -120
- edsl/inference_services/AzureAI.py +217 -217
- edsl/inference_services/DeepInfraService.py +18 -18
- edsl/inference_services/GoogleService.py +148 -148
- edsl/inference_services/GroqService.py +20 -20
- edsl/inference_services/InferenceServiceABC.py +147 -147
- edsl/inference_services/InferenceServicesCollection.py +97 -97
- edsl/inference_services/MistralAIService.py +123 -123
- edsl/inference_services/OllamaService.py +18 -18
- edsl/inference_services/OpenAIService.py +224 -224
- edsl/inference_services/PerplexityService.py +163 -163
- edsl/inference_services/TestService.py +89 -89
- edsl/inference_services/TogetherAIService.py +170 -170
- edsl/inference_services/models_available_cache.py +118 -118
- edsl/inference_services/rate_limits_cache.py +25 -25
- edsl/inference_services/registry.py +41 -41
- edsl/inference_services/write_available.py +10 -10
- edsl/jobs/Answers.py +56 -56
- edsl/jobs/Jobs.py +898 -898
- edsl/jobs/JobsChecks.py +147 -147
- edsl/jobs/JobsPrompts.py +268 -268
- edsl/jobs/JobsRemoteInferenceHandler.py +239 -239
- edsl/jobs/__init__.py +1 -1
- edsl/jobs/buckets/BucketCollection.py +63 -63
- edsl/jobs/buckets/ModelBuckets.py +65 -65
- edsl/jobs/buckets/TokenBucket.py +251 -251
- edsl/jobs/interviews/Interview.py +661 -661
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
- edsl/jobs/interviews/InterviewExceptionEntry.py +186 -186
- edsl/jobs/interviews/InterviewStatistic.py +63 -63
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
- edsl/jobs/interviews/InterviewStatusLog.py +92 -92
- edsl/jobs/interviews/ReportErrors.py +66 -66
- edsl/jobs/interviews/interview_status_enum.py +9 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +466 -466
- edsl/jobs/runners/JobsRunnerStatus.py +330 -330
- edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
- edsl/jobs/tasks/TaskCreators.py +64 -64
- edsl/jobs/tasks/TaskHistory.py +450 -450
- edsl/jobs/tasks/TaskStatusLog.py +23 -23
- edsl/jobs/tasks/task_status_enum.py +163 -163
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
- edsl/jobs/tokens/TokenUsage.py +34 -34
- edsl/language_models/KeyLookup.py +30 -30
- edsl/language_models/LanguageModel.py +668 -668
- edsl/language_models/ModelList.py +155 -155
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
- edsl/language_models/__init__.py +3 -3
- edsl/language_models/fake_openai_call.py +15 -15
- edsl/language_models/fake_openai_service.py +61 -61
- edsl/language_models/registry.py +190 -190
- edsl/language_models/repair.py +156 -156
- edsl/language_models/unused/ReplicateBase.py +83 -83
- edsl/language_models/utilities.py +64 -64
- edsl/notebooks/Notebook.py +258 -258
- edsl/notebooks/__init__.py +1 -1
- edsl/prompts/Prompt.py +362 -362
- edsl/prompts/__init__.py +2 -2
- edsl/questions/AnswerValidatorMixin.py +289 -289
- edsl/questions/QuestionBase.py +664 -664
- edsl/questions/QuestionBaseGenMixin.py +161 -161
- edsl/questions/QuestionBasePromptsMixin.py +217 -217
- edsl/questions/QuestionBudget.py +227 -227
- edsl/questions/QuestionCheckBox.py +359 -359
- edsl/questions/QuestionExtract.py +182 -182
- edsl/questions/QuestionFreeText.py +114 -114
- edsl/questions/QuestionFunctional.py +166 -166
- edsl/questions/QuestionList.py +231 -231
- edsl/questions/QuestionMultipleChoice.py +286 -286
- edsl/questions/QuestionNumerical.py +153 -153
- edsl/questions/QuestionRank.py +324 -324
- edsl/questions/Quick.py +41 -41
- edsl/questions/RegisterQuestionsMeta.py +71 -71
- edsl/questions/ResponseValidatorABC.py +174 -174
- edsl/questions/SimpleAskMixin.py +73 -73
- edsl/questions/__init__.py +26 -26
- edsl/questions/compose_questions.py +98 -98
- edsl/questions/decorators.py +21 -21
- edsl/questions/derived/QuestionLikertFive.py +76 -76
- edsl/questions/derived/QuestionLinearScale.py +87 -87
- edsl/questions/derived/QuestionTopK.py +93 -93
- edsl/questions/derived/QuestionYesNo.py +82 -82
- edsl/questions/descriptors.py +413 -413
- edsl/questions/prompt_templates/question_budget.jinja +13 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
- edsl/questions/prompt_templates/question_extract.jinja +11 -11
- edsl/questions/prompt_templates/question_free_text.jinja +3 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
- edsl/questions/prompt_templates/question_list.jinja +17 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
- edsl/questions/prompt_templates/question_numerical.jinja +36 -36
- edsl/questions/question_registry.py +177 -177
- edsl/questions/settings.py +12 -12
- edsl/questions/templates/budget/answering_instructions.jinja +7 -7
- edsl/questions/templates/budget/question_presentation.jinja +7 -7
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
- edsl/questions/templates/extract/answering_instructions.jinja +7 -7
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
- edsl/questions/templates/list/answering_instructions.jinja +3 -3
- edsl/questions/templates/list/question_presentation.jinja +5 -5
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
- edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
- edsl/questions/templates/numerical/question_presentation.jinja +6 -6
- edsl/questions/templates/rank/answering_instructions.jinja +11 -11
- edsl/questions/templates/rank/question_presentation.jinja +15 -15
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
- edsl/questions/templates/top_k/question_presentation.jinja +22 -22
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
- edsl/results/CSSParameterizer.py +108 -108
- edsl/results/Dataset.py +424 -424
- edsl/results/DatasetExportMixin.py +731 -731
- edsl/results/DatasetTree.py +275 -275
- edsl/results/Result.py +465 -465
- edsl/results/Results.py +1165 -1165
- edsl/results/ResultsDBMixin.py +238 -238
- edsl/results/ResultsExportMixin.py +43 -43
- edsl/results/ResultsFetchMixin.py +33 -33
- edsl/results/ResultsGGMixin.py +121 -121
- edsl/results/ResultsToolsMixin.py +98 -98
- edsl/results/Selector.py +135 -135
- edsl/results/TableDisplay.py +198 -198
- edsl/results/__init__.py +2 -2
- edsl/results/table_display.css +77 -77
- edsl/results/tree_explore.py +115 -115
- edsl/scenarios/FileStore.py +632 -632
- edsl/scenarios/Scenario.py +601 -601
- edsl/scenarios/ScenarioHtmlMixin.py +64 -64
- edsl/scenarios/ScenarioJoin.py +127 -127
- edsl/scenarios/ScenarioList.py +1287 -1287
- edsl/scenarios/ScenarioListExportMixin.py +52 -52
- edsl/scenarios/ScenarioListPdfMixin.py +261 -261
- edsl/scenarios/__init__.py +4 -4
- edsl/shared.py +1 -1
- edsl/study/ObjectEntry.py +173 -173
- edsl/study/ProofOfWork.py +113 -113
- edsl/study/SnapShot.py +80 -80
- edsl/study/Study.py +528 -528
- edsl/study/__init__.py +4 -4
- edsl/surveys/DAG.py +148 -148
- edsl/surveys/Memory.py +31 -31
- edsl/surveys/MemoryPlan.py +244 -244
- edsl/surveys/Rule.py +326 -326
- edsl/surveys/RuleCollection.py +387 -387
- edsl/surveys/Survey.py +1801 -1801
- edsl/surveys/SurveyCSS.py +261 -261
- edsl/surveys/SurveyExportMixin.py +259 -259
- edsl/surveys/SurveyFlowVisualizationMixin.py +179 -179
- edsl/surveys/SurveyQualtricsImport.py +284 -284
- edsl/surveys/__init__.py +3 -3
- edsl/surveys/base.py +53 -53
- edsl/surveys/descriptors.py +56 -56
- edsl/surveys/instructions/ChangeInstruction.py +49 -49
- edsl/surveys/instructions/Instruction.py +65 -65
- edsl/surveys/instructions/InstructionCollection.py +77 -77
- edsl/templates/error_reporting/base.html +23 -23
- edsl/templates/error_reporting/exceptions_by_model.html +34 -34
- edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
- edsl/templates/error_reporting/exceptions_by_type.html +16 -16
- edsl/templates/error_reporting/interview_details.html +115 -115
- edsl/templates/error_reporting/interviews.html +19 -19
- edsl/templates/error_reporting/overview.html +4 -4
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +73 -73
- edsl/templates/error_reporting/report.html +117 -117
- edsl/templates/error_reporting/report.js +25 -25
- edsl/tools/__init__.py +1 -1
- edsl/tools/clusters.py +192 -192
- edsl/tools/embeddings.py +27 -27
- edsl/tools/embeddings_plotting.py +118 -118
- edsl/tools/plotting.py +112 -112
- edsl/tools/summarize.py +18 -18
- edsl/utilities/SystemInfo.py +28 -28
- edsl/utilities/__init__.py +22 -22
- edsl/utilities/ast_utilities.py +25 -25
- edsl/utilities/data/Registry.py +6 -6
- edsl/utilities/data/__init__.py +1 -1
- edsl/utilities/data/scooter_results.json +1 -1
- edsl/utilities/decorators.py +77 -77
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
- edsl/utilities/interface.py +627 -627
- edsl/utilities/naming_utilities.py +263 -263
- edsl/utilities/repair_functions.py +28 -28
- edsl/utilities/restricted_python.py +70 -70
- edsl/utilities/utilities.py +424 -424
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev3.dist-info}/LICENSE +21 -21
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev3.dist-info}/METADATA +1 -1
- edsl-0.1.39.dev3.dist-info/RECORD +277 -0
- edsl-0.1.39.dev1.dist-info/RECORD +0 -277
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev3.dist-info}/WHEEL +0 -0
edsl/jobs/runners/JobsRunnerStatus.py

@@ -1,330 +1,330 @@

The hunk removes and re-adds the full 330-line file; the before and after contents are identical, so the shared content is shown once:

```python
from __future__ import annotations

import os
import time
import requests
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass

from typing import Any, List, DefaultDict, Optional, Dict
from collections import defaultdict
from uuid import UUID

from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage

InterviewTokenUsageMapping = DefaultDict[str, InterviewTokenUsage]

from edsl.jobs.interviews.InterviewStatistic import InterviewStatistic
from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage


@dataclass
class ModelInfo:
    model_name: str
    TPM_limit_k: float
    RPM_limit_k: float
    num_tasks_waiting: int
    token_usage_info: dict


@dataclass
class ModelTokenUsageStats:
    token_usage_type: str
    details: List[dict]
    cost: str


class JobsRunnerStatusBase(ABC):
    def __init__(
        self,
        jobs_runner: "JobsRunnerAsyncio",
        n: int,
        refresh_rate: float = 1,
        endpoint_url: Optional[str] = "http://localhost:8000",
        job_uuid: Optional[UUID] = None,
        api_key: str = None,
    ):
        self.jobs_runner = jobs_runner

        # The uuid of the job on Coop
        self.job_uuid = job_uuid

        self.base_url = f"{endpoint_url}"

        self.start_time = time.time()
        self.completed_interviews = []
        self.refresh_rate = refresh_rate
        self.statistics = [
            "elapsed_time",
            "total_interviews_requested",
            "completed_interviews",
            # "percent_complete",
            "average_time_per_interview",
            # "task_remaining",
            "estimated_time_remaining",
            "exceptions",
            "unfixed_exceptions",
            "throughput",
        ]
        self.num_total_interviews = n * len(self.jobs_runner.interviews)

        self.distinct_models = list(
            set(i.model.model for i in self.jobs_runner.interviews)
        )

        self.completed_interview_by_model = defaultdict(list)

        self.api_key = api_key or os.getenv("EXPECTED_PARROT_API_KEY")

    @abstractmethod
    def has_ep_api_key(self):
        """
        Checks if the user has an Expected Parrot API key.
        """
        pass

    def get_status_dict(self) -> Dict[str, Any]:
        """
        Converts current status into a JSON-serializable dictionary.
        """
        # Get all statistics
        stats = {}
        for stat_name in self.statistics:
            stat = self._compute_statistic(stat_name)
            name, value = list(stat.items())[0]
            stats[name] = value

        # Calculate overall progress
        total_interviews = len(self.jobs_runner.total_interviews)
        completed = len(self.completed_interviews)

        # Get model-specific progress
        model_progress = {}
        for model in self.distinct_models:
            completed_for_model = len(self.completed_interview_by_model[model])
            target_for_model = int(
                self.num_total_interviews / len(self.distinct_models)
            )
            model_progress[model] = {
                "completed": completed_for_model,
                "total": target_for_model,
                "percent": (
                    (completed_for_model / target_for_model * 100)
                    if target_for_model > 0
                    else 0
                ),
            }

        status_dict = {
            "overall_progress": {
                "completed": completed,
                "total": total_interviews,
                "percent": (
                    (completed / total_interviews * 100) if total_interviews > 0 else 0
                ),
            },
            "language_model_progress": model_progress,
            "statistics": stats,
            "status": "completed" if completed >= total_interviews else "running",
        }

        model_queues = {}
        for model, bucket in self.jobs_runner.bucket_collection.items():
            model_name = model.model
            model_queues[model_name] = {
                "language_model_name": model_name,
                "requests_bucket": {
                    "completed": bucket.requests_bucket.num_released,
                    "requested": bucket.requests_bucket.num_requests,
                    "tokens_returned": bucket.requests_bucket.tokens_returned,
                    "target_rate": round(bucket.requests_bucket.target_rate, 1),
                    "current_rate": round(bucket.requests_bucket.get_throughput(), 1),
                },
                "tokens_bucket": {
                    "completed": bucket.tokens_bucket.num_released,
                    "requested": bucket.tokens_bucket.num_requests,
                    "tokens_returned": bucket.tokens_bucket.tokens_returned,
                    "target_rate": round(bucket.tokens_bucket.target_rate, 1),
                    "current_rate": round(bucket.tokens_bucket.get_throughput(), 1),
                },
            }
        status_dict["language_model_queues"] = model_queues
        return status_dict

    @abstractmethod
    def setup(self):
        """
        Conducts any setup that needs to happen prior to sending status updates.

        Ex. For a local job, creates a job in the Coop database.
        """
        pass

    @abstractmethod
    def send_status_update(self):
        """
        Updates the current status of the job.
        """
        pass

    def add_completed_interview(self, result):
        self.completed_interviews.append(result.interview_hash)

        relevant_model = result.model.model
        self.completed_interview_by_model[relevant_model].append(result.interview_hash)

    def _compute_statistic(self, stat_name: str):
        completed_tasks = self.completed_interviews
        elapsed_time = time.time() - self.start_time
        interviews = self.jobs_runner.total_interviews

        stat_definitions = {
            "elapsed_time": lambda: InterviewStatistic(
                "elapsed_time", value=elapsed_time, digits=1, units="sec."
            ),
            "total_interviews_requested": lambda: InterviewStatistic(
                "total_interviews_requested", value=len(interviews), units=""
            ),
            "completed_interviews": lambda: InterviewStatistic(
                "completed_interviews", value=len(completed_tasks), units=""
            ),
            "percent_complete": lambda: InterviewStatistic(
                "percent_complete",
                value=(
                    len(completed_tasks) / len(interviews) * 100
                    if len(interviews) > 0
                    else 0
                ),
                digits=1,
                units="%",
            ),
            "average_time_per_interview": lambda: InterviewStatistic(
                "average_time_per_interview",
                value=elapsed_time / len(completed_tasks) if completed_tasks else 0,
                digits=2,
                units="sec.",
            ),
            "task_remaining": lambda: InterviewStatistic(
                "task_remaining", value=len(interviews) - len(completed_tasks), units=""
            ),
            "estimated_time_remaining": lambda: InterviewStatistic(
                "estimated_time_remaining",
                value=(
                    (len(interviews) - len(completed_tasks))
                    * (elapsed_time / len(completed_tasks))
                    if len(completed_tasks) > 0
                    else 0
                ),
                digits=1,
                units="sec.",
            ),
            "exceptions": lambda: InterviewStatistic(
                "exceptions",
                value=sum(len(i.exceptions) for i in interviews),
                units="",
            ),
            "unfixed_exceptions": lambda: InterviewStatistic(
                "unfixed_exceptions",
                value=sum(i.exceptions.num_unfixed() for i in interviews),
                units="",
            ),
            "throughput": lambda: InterviewStatistic(
                "throughput",
                value=len(completed_tasks) / elapsed_time if elapsed_time > 0 else 0,
                digits=2,
                units="interviews/sec.",
            ),
        }
        return stat_definitions[stat_name]()

    def update_progress(self, stop_event):
        while not stop_event.is_set():
            self.send_status_update()
            time.sleep(self.refresh_rate)

        self.send_status_update()


class JobsRunnerStatus(JobsRunnerStatusBase):
    @property
    def create_url(self) -> str:
        return f"{self.base_url}/api/v0/local-job"

    @property
    def viewing_url(self) -> str:
        return f"{self.base_url}/home/local-job-progress/{str(self.job_uuid)}"

    @property
    def update_url(self) -> str:
        return f"{self.base_url}/api/v0/local-job/{str(self.job_uuid)}"

    def setup(self) -> None:
        """
        Creates a local job on Coop if one does not already exist.
        """

        headers = {"Content-Type": "application/json"}

        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        else:
            headers["Authorization"] = f"Bearer None"

        if self.job_uuid is None:
            # Create a new local job
            response = requests.post(
                self.create_url,
                headers=headers,
                timeout=1,
            )
            response.raise_for_status()
            data = response.json()
            self.job_uuid = data.get("job_uuid")

        print(f"Running with progress bar. View progress at {self.viewing_url}")

    def send_status_update(self) -> None:
        """
        Sends current status to the web endpoint using the instance's job_uuid.
        """
        try:
            # Get the status dictionary and add the job_id
            status_dict = self.get_status_dict()

            # Make the UUID JSON serializable
            status_dict["job_id"] = str(self.job_uuid)

            headers = {"Content-Type": "application/json"}

            if self.api_key:
                headers["Authorization"] = f"Bearer {self.api_key}"
            else:
                headers["Authorization"] = f"Bearer None"

            # Send the update
            response = requests.patch(
                self.update_url,
                json=status_dict,
                headers=headers,
                timeout=1,
            )
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            print(f"Failed to send status update for job {self.job_uuid}: {e}")

    def has_ep_api_key(self) -> bool:
        """
        Returns True if the user has an Expected Parrot API key. Otherwise, returns False.
        """

        if self.api_key is not None:
            return True
        else:
            return False


if __name__ == "__main__":
    import doctest

    doctest.testmod(optionflags=doctest.ELLIPSIS)
```