edsl 0.1.39__py3-none-any.whl → 0.1.39.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +116 -197
- edsl/__init__.py +7 -15
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +147 -351
- edsl/agents/AgentList.py +73 -211
- edsl/agents/Invigilator.py +50 -101
- edsl/agents/InvigilatorBase.py +70 -62
- edsl/agents/PromptConstructor.py +225 -143
- edsl/agents/__init__.py +1 -0
- edsl/agents/prompt_helpers.py +3 -3
- edsl/auto/AutoStudy.py +5 -18
- edsl/auto/StageBase.py +40 -53
- edsl/auto/StageQuestions.py +1 -2
- edsl/auto/utilities.py +6 -0
- edsl/config.py +2 -22
- edsl/conversation/car_buying.py +1 -2
- edsl/coop/PriceFetcher.py +1 -1
- edsl/coop/coop.py +47 -125
- edsl/coop/utils.py +14 -14
- edsl/data/Cache.py +27 -45
- edsl/data/CacheEntry.py +15 -12
- edsl/data/CacheHandler.py +12 -31
- edsl/data/RemoteCacheSync.py +46 -154
- edsl/data/__init__.py +3 -4
- edsl/data_transfer_models.py +1 -2
- edsl/enums.py +0 -27
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +0 -12
- edsl/exceptions/questions.py +6 -24
- edsl/exceptions/scenarios.py +0 -7
- edsl/inference_services/AnthropicService.py +19 -38
- edsl/inference_services/AwsBedrock.py +2 -0
- edsl/inference_services/AzureAI.py +2 -0
- edsl/inference_services/GoogleService.py +12 -7
- edsl/inference_services/InferenceServiceABC.py +85 -18
- edsl/inference_services/InferenceServicesCollection.py +79 -120
- edsl/inference_services/MistralAIService.py +3 -0
- edsl/inference_services/OpenAIService.py +35 -47
- edsl/inference_services/PerplexityService.py +3 -0
- edsl/inference_services/TestService.py +10 -11
- edsl/inference_services/TogetherAIService.py +3 -5
- edsl/jobs/Answers.py +14 -1
- edsl/jobs/Jobs.py +431 -356
- edsl/jobs/JobsChecks.py +10 -35
- edsl/jobs/JobsPrompts.py +4 -6
- edsl/jobs/JobsRemoteInferenceHandler.py +133 -205
- edsl/jobs/buckets/BucketCollection.py +3 -44
- edsl/jobs/buckets/TokenBucket.py +21 -53
- edsl/jobs/interviews/Interview.py +408 -143
- edsl/jobs/runners/JobsRunnerAsyncio.py +403 -88
- edsl/jobs/runners/JobsRunnerStatus.py +165 -133
- edsl/jobs/tasks/QuestionTaskCreator.py +19 -21
- edsl/jobs/tasks/TaskHistory.py +18 -38
- edsl/jobs/tasks/task_status_enum.py +2 -0
- edsl/language_models/KeyLookup.py +30 -0
- edsl/language_models/LanguageModel.py +236 -194
- edsl/language_models/ModelList.py +19 -28
- edsl/language_models/__init__.py +2 -1
- edsl/language_models/registry.py +190 -0
- edsl/language_models/repair.py +2 -2
- edsl/language_models/unused/ReplicateBase.py +83 -0
- edsl/language_models/utilities.py +4 -5
- edsl/notebooks/Notebook.py +14 -19
- edsl/prompts/Prompt.py +39 -29
- edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +2 -47
- edsl/questions/QuestionBase.py +214 -68
- edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +50 -57
- edsl/questions/QuestionBasePromptsMixin.py +3 -7
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +3 -3
- edsl/questions/QuestionExtract.py +7 -5
- edsl/questions/QuestionFreeText.py +3 -2
- edsl/questions/QuestionList.py +18 -10
- edsl/questions/QuestionMultipleChoice.py +23 -67
- edsl/questions/QuestionNumerical.py +4 -2
- edsl/questions/QuestionRank.py +17 -7
- edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +26 -40
- edsl/questions/SimpleAskMixin.py +3 -4
- edsl/questions/__init__.py +1 -2
- edsl/questions/derived/QuestionLinearScale.py +3 -6
- edsl/questions/derived/QuestionTopK.py +1 -1
- edsl/questions/descriptors.py +3 -17
- edsl/questions/question_registry.py +1 -1
- edsl/results/CSSParameterizer.py +1 -1
- edsl/results/Dataset.py +7 -170
- edsl/results/DatasetExportMixin.py +305 -168
- edsl/results/DatasetTree.py +8 -28
- edsl/results/Result.py +206 -298
- edsl/results/Results.py +131 -149
- edsl/results/ResultsDBMixin.py +238 -0
- edsl/results/ResultsExportMixin.py +0 -2
- edsl/results/{results_selector.py → Selector.py} +13 -23
- edsl/results/TableDisplay.py +171 -98
- edsl/results/__init__.py +1 -1
- edsl/scenarios/FileStore.py +239 -150
- edsl/scenarios/Scenario.py +193 -90
- edsl/scenarios/ScenarioHtmlMixin.py +3 -4
- edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +6 -10
- edsl/scenarios/ScenarioList.py +244 -415
- edsl/scenarios/ScenarioListExportMixin.py +7 -0
- edsl/scenarios/ScenarioListPdfMixin.py +37 -15
- edsl/scenarios/__init__.py +2 -1
- edsl/study/ObjectEntry.py +1 -1
- edsl/study/SnapShot.py +1 -1
- edsl/study/Study.py +12 -5
- edsl/surveys/Rule.py +4 -5
- edsl/surveys/RuleCollection.py +27 -25
- edsl/surveys/Survey.py +791 -270
- edsl/surveys/SurveyCSS.py +8 -20
- edsl/surveys/{SurveyFlowVisualization.py → SurveyFlowVisualizationMixin.py} +9 -11
- edsl/surveys/__init__.py +2 -4
- edsl/surveys/descriptors.py +2 -6
- edsl/surveys/instructions/ChangeInstruction.py +2 -1
- edsl/surveys/instructions/Instruction.py +13 -4
- edsl/surveys/instructions/InstructionCollection.py +6 -11
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/report.html +1 -1
- edsl/tools/plotting.py +1 -1
- edsl/utilities/utilities.py +23 -35
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/METADATA +10 -12
- edsl-0.1.39.dev1.dist-info/RECORD +277 -0
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/WHEEL +1 -1
- edsl/agents/QuestionInstructionPromptBuilder.py +0 -128
- edsl/agents/QuestionTemplateReplacementsBuilder.py +0 -137
- edsl/agents/question_option_processor.py +0 -172
- edsl/coop/CoopFunctionsMixin.py +0 -15
- edsl/coop/ExpectedParrotKeyHandler.py +0 -125
- edsl/exceptions/inference_services.py +0 -5
- edsl/inference_services/AvailableModelCacheHandler.py +0 -184
- edsl/inference_services/AvailableModelFetcher.py +0 -215
- edsl/inference_services/ServiceAvailability.py +0 -135
- edsl/inference_services/data_structures.py +0 -134
- edsl/jobs/AnswerQuestionFunctionConstructor.py +0 -223
- edsl/jobs/FetchInvigilator.py +0 -47
- edsl/jobs/InterviewTaskManager.py +0 -98
- edsl/jobs/InterviewsConstructor.py +0 -50
- edsl/jobs/JobsComponentConstructor.py +0 -189
- edsl/jobs/JobsRemoteInferenceLogger.py +0 -239
- edsl/jobs/RequestTokenEstimator.py +0 -30
- edsl/jobs/async_interview_runner.py +0 -138
- edsl/jobs/buckets/TokenBucketAPI.py +0 -211
- edsl/jobs/buckets/TokenBucketClient.py +0 -191
- edsl/jobs/check_survey_scenario_compatibility.py +0 -85
- edsl/jobs/data_structures.py +0 -120
- edsl/jobs/decorators.py +0 -35
- edsl/jobs/jobs_status_enums.py +0 -9
- edsl/jobs/loggers/HTMLTableJobLogger.py +0 -304
- edsl/jobs/results_exceptions_handler.py +0 -98
- edsl/language_models/ComputeCost.py +0 -63
- edsl/language_models/PriceManager.py +0 -127
- edsl/language_models/RawResponseHandler.py +0 -106
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/key_management/KeyLookup.py +0 -63
- edsl/language_models/key_management/KeyLookupBuilder.py +0 -273
- edsl/language_models/key_management/KeyLookupCollection.py +0 -38
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +0 -131
- edsl/language_models/model.py +0 -256
- edsl/notebooks/NotebookToLaTeX.py +0 -142
- edsl/questions/ExceptionExplainer.py +0 -77
- edsl/questions/HTMLQuestion.py +0 -103
- edsl/questions/QuestionMatrix.py +0 -265
- edsl/questions/data_structures.py +0 -20
- edsl/questions/loop_processor.py +0 -149
- edsl/questions/response_validator_factory.py +0 -34
- edsl/questions/templates/matrix/__init__.py +0 -1
- edsl/questions/templates/matrix/answering_instructions.jinja +0 -5
- edsl/questions/templates/matrix/question_presentation.jinja +0 -20
- edsl/results/MarkdownToDocx.py +0 -122
- edsl/results/MarkdownToPDF.py +0 -111
- edsl/results/TextEditor.py +0 -50
- edsl/results/file_exports.py +0 -252
- edsl/results/smart_objects.py +0 -96
- edsl/results/table_data_class.py +0 -12
- edsl/results/table_renderers.py +0 -118
- edsl/scenarios/ConstructDownloadLink.py +0 -109
- edsl/scenarios/DocumentChunker.py +0 -102
- edsl/scenarios/DocxScenario.py +0 -16
- edsl/scenarios/PdfExtractor.py +0 -40
- edsl/scenarios/directory_scanner.py +0 -96
- edsl/scenarios/file_methods.py +0 -85
- edsl/scenarios/handlers/__init__.py +0 -13
- edsl/scenarios/handlers/csv.py +0 -49
- edsl/scenarios/handlers/docx.py +0 -76
- edsl/scenarios/handlers/html.py +0 -37
- edsl/scenarios/handlers/json.py +0 -111
- edsl/scenarios/handlers/latex.py +0 -5
- edsl/scenarios/handlers/md.py +0 -51
- edsl/scenarios/handlers/pdf.py +0 -68
- edsl/scenarios/handlers/png.py +0 -39
- edsl/scenarios/handlers/pptx.py +0 -105
- edsl/scenarios/handlers/py.py +0 -294
- edsl/scenarios/handlers/sql.py +0 -313
- edsl/scenarios/handlers/sqlite.py +0 -149
- edsl/scenarios/handlers/txt.py +0 -33
- edsl/scenarios/scenario_selector.py +0 -156
- edsl/surveys/ConstructDAG.py +0 -92
- edsl/surveys/EditSurvey.py +0 -221
- edsl/surveys/InstructionHandler.py +0 -100
- edsl/surveys/MemoryManagement.py +0 -72
- edsl/surveys/RuleManager.py +0 -172
- edsl/surveys/Simulator.py +0 -75
- edsl/surveys/SurveyToApp.py +0 -141
- edsl/utilities/PrettyList.py +0 -56
- edsl/utilities/is_notebook.py +0 -18
- edsl/utilities/is_valid_variable_name.py +0 -11
- edsl/utilities/remove_edsl_version.py +0 -24
- edsl-0.1.39.dist-info/RECORD +0 -358
- /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
- /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
- /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/LICENSE +0 -0
edsl/jobs/JobsChecks.py
CHANGED
@@ -1,21 +1,16 @@
 import os
-from edsl.exceptions
+from edsl.exceptions import MissingAPIKeyError
 
 
 class JobsChecks:
     def __init__(self, jobs):
-        """
+        """ """
         self.jobs = jobs
 
     def check_api_keys(self) -> None:
-        from edsl
+        from edsl import Model
 
-
-            models = [Model()]
-        else:
-            models = self.jobs.models
-
-        for model in models:  # + [Model()]:
+        for model in self.jobs.models + [Model()]:
             if not model.has_valid_api_key():
                 raise MissingAPIKeyError(
                     model_name=str(model.model),
@@ -28,7 +23,7 @@ class JobsChecks:
         """
         missing_api_keys = set()
 
-        from edsl
+        from edsl import Model
         from edsl.enums import service_to_api_keyname
 
         for model in self.jobs.models + [Model()]:
@@ -100,33 +95,16 @@ class JobsChecks:
         return True
 
     def needs_key_process(self):
-        """
-        A User needs the key process when:
-        1. They don't have all the model keys
-        2. They don't have the EP API
-        3. They need external LLMs to run the job
-        """
         return (
             not self.user_has_all_model_keys()
            and not self.user_has_ep_api_key()
            and self.needs_external_llms()
        )
 
-    def status(self) -> dict:
-        """
-        Returns a dictionary with the status of the job checks.
-        """
-        return {
-            "user_has_ep_api_key": self.user_has_ep_api_key(),
-            "user_has_all_model_keys": self.user_has_all_model_keys(),
-            "needs_external_llms": self.needs_external_llms(),
-            "needs_key_process": self.needs_key_process(),
-        }
-
     def key_process(self):
         import secrets
         from dotenv import load_dotenv
-        from edsl
+        from edsl import CONFIG
         from edsl.coop.coop import Coop
         from edsl.utilities.utilities import write_api_key_to_env
 
@@ -141,12 +119,10 @@ class JobsChecks:
             "\nYou can either add the missing keys to your .env file, or use remote inference."
         )
         print("Remote inference allows you to run jobs on our server.")
+        print("\n🚀 To use remote inference, sign up at the following link:")
 
         coop = Coop()
-        coop._display_login_url(
-            edsl_auth_token=edsl_auth_token,
-            link_description="\n🚀 To use remote inference, sign up at the following link:",
-        )
+        coop._display_login_url(edsl_auth_token=edsl_auth_token)
 
         print(
             "\nOnce you log in, we will automatically retrieve your Expected Parrot API key and continue your job remotely."
@@ -158,9 +134,8 @@ class JobsChecks:
         print("\nTimed out waiting for login. Please try again.")
         return
 
-
-        print("
-        print(f" {path_to_env}")
+        write_api_key_to_env(api_key)
+        print("✨ API key retrieved and written to .env file.\n")
 
         # Retrieve API key so we can continue running the job
         load_dotenv()
edsl/jobs/JobsPrompts.py
CHANGED
@@ -11,8 +11,6 @@ if TYPE_CHECKING:
     # from edsl.scenarios.ScenarioList import ScenarioList
     # from edsl.surveys.Survey import Survey
 
-from edsl.jobs.FetchInvigilator import FetchInvigilator
-
 
 class JobsPrompts:
     def __init__(self, jobs: "Jobs"):
@@ -25,7 +23,7 @@ class JobsPrompts:
     @property
     def price_lookup(self):
         if self._price_lookup is None:
-            from edsl
+            from edsl import Coop
 
             c = Coop()
             self._price_lookup = c.fetch_prices()
@@ -50,8 +48,8 @@ class JobsPrompts:
 
         for interview_index, interview in enumerate(interviews):
             invigilators = [
-
-                for question in
+                interview._get_invigilator(question)
+                for question in self.survey.questions
             ]
             for _, invigilator in enumerate(invigilators):
                 prompts = invigilator.get_prompts()
@@ -186,7 +184,7 @@ class JobsPrompts:
         data = []
         for interview in interviews:
             invigilators = [
-
+                interview._get_invigilator(question)
                 for question in self.survey.questions
             ]
             for invigilator in invigilators:
edsl/jobs/JobsRemoteInferenceHandler.py
CHANGED
@@ -1,78 +1,47 @@
-from typing import Optional, Union, Literal
-
-
-
-
-Seconds = NewType("Seconds", float)
-JobUUID = NewType("JobUUID", str)
-
+from typing import Optional, Union, Literal
+import requests
+import sys
 from edsl.exceptions.coop import CoopServerResponseError
 
-
-
-    from edsl.jobs.Jobs import Jobs
-    from edsl.coop.coop import RemoteInferenceResponse, RemoteInferenceCreationInfo
-    from edsl.jobs.JobsRemoteInferenceLogger import JobLogger
-
-from edsl.coop.coop import RemoteInferenceResponse, RemoteInferenceCreationInfo
-
-from edsl.jobs.jobs_status_enums import JobsStatus
-from edsl.coop.utils import VisibilityType
-from edsl.jobs.JobsRemoteInferenceLogger import JobLogger
-
-
-class RemoteJobConstants:
-    """Constants for remote job handling."""
-
-    REMOTE_JOB_POLL_INTERVAL = 1
-    REMOTE_JOB_VERBOSE = False
-    DISCORD_URL = "https://discord.com/invite/mxAYkjfy9m"
-
-
-@dataclass
-class RemoteJobInfo:
-    creation_data: RemoteInferenceCreationInfo
-    job_uuid: JobUUID
-    logger: JobLogger
+# from edsl.enums import VisibilityType
+from edsl.results import Results
 
 
 class JobsRemoteInferenceHandler:
-    def __init__(
-
-        jobs
-
-
-
-        """
+    def __init__(self, jobs, verbose=False, poll_interval=3):
+        """
+        >>> from edsl.jobs import Jobs
+        >>> jh = JobsRemoteInferenceHandler(Jobs.example(), verbose=True)
+        >>> jh.use_remote_inference(True)
+        False
+        >>> jh._poll_remote_inference_job({'uuid':1234}, testing_simulated_response={"status": "failed"}) # doctest: +NORMALIZE_WHITESPACE
+        Job failed.
+        ...
+        >>> jh._poll_remote_inference_job({'uuid':1234}, testing_simulated_response={"status": "completed"}) # doctest: +NORMALIZE_WHITESPACE
+        Job completed and Results stored on Coop: None.
+        Results(...)
+        """
         self.jobs = jobs
         self.verbose = verbose
         self.poll_interval = poll_interval
 
-
-
-        self.expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
-        self.remote_inference_url = f"{self.expected_parrot_url}/home/remote-inference"
+        self._remote_job_creation_data = None
+        self._job_uuid = None
 
-
-
-
-            JupyterJobLogger,
-            StdOutJobLogger,
-        )
-        from edsl.jobs.loggers.HTMLTableJobLogger import HTMLTableJobLogger
+    @property
+    def remote_job_creation_data(self):
+        return self._remote_job_creation_data
 
-
-
-            return
+    @property
+    def job_uuid(self):
+        return self._job_uuid
 
     def use_remote_inference(self, disable_remote_inference: bool) -> bool:
-        import requests
-
         if disable_remote_inference:
             return False
         if not disable_remote_inference:
             try:
-                from edsl
+                from edsl import Coop
 
                 user_edsl_settings = Coop().edsl_settings
                 return user_edsl_settings.get("remote_inference", False)
@@ -87,19 +56,16 @@ class JobsRemoteInferenceHandler:
         self,
         iterations: int = 1,
         remote_inference_description: Optional[str] = None,
-        remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
-
-
+        remote_inference_results_visibility: Optional["VisibilityType"] = "unlisted",
+        verbose=False,
+    ):
+        """ """
         from edsl.config import CONFIG
         from edsl.coop.coop import Coop
-
-        logger = self._create_logger()
+        from rich import print as rich_print
 
         coop = Coop()
-
-            "Remote inference activated. Sending job to server...",
-            status=JobsStatus.QUEUED,
-        )
+        print("Remote inference activated. Sending job to server...")
         remote_job_creation_data = coop.remote_inference_create(
             self.jobs,
             description=remote_inference_description,
@@ -107,172 +73,136 @@ class JobsRemoteInferenceHandler:
             iterations=iterations,
             initial_results_visibility=remote_inference_results_visibility,
         )
-        logger.update(
-            "Your survey is running at the Expected Parrot server...",
-            status=JobsStatus.RUNNING,
-        )
         job_uuid = remote_job_creation_data.get("uuid")
-
-            message=f"Job sent to server. (Job uuid={job_uuid}).",
-            status=JobsStatus.RUNNING,
-        )
-        logger.add_info("job_uuid", job_uuid)
+        print(f"Job sent to server. (Job uuid={job_uuid}).")
 
-
-
-            status=JobsStatus.RUNNING,
-        )
-        progress_bar_url = (
-            f"{self.expected_parrot_url}/home/remote-job-progress/{job_uuid}"
-        )
-        logger.add_info("progress_bar_url", progress_bar_url)
-        logger.update(
-            f"View job progress here: {progress_bar_url}", status=JobsStatus.RUNNING
-        )
+        expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
+        progress_bar_url = f"{expected_parrot_url}/home/remote-job-progress/{job_uuid}"
 
-
-
-            job_uuid=job_uuid,
-            logger=logger,
+        rich_print(
+            f"View job progress here: [#38bdf8][link={progress_bar_url}]{progress_bar_url}[/link][/#38bdf8]"
         )
 
+        self._remote_job_creation_data = remote_job_creation_data
+        self._job_uuid = job_uuid
+        # return remote_job_creation_data
+
     @staticmethod
-    def check_status(
-        job_uuid: JobUUID,
-    ) -> RemoteInferenceResponse:
+    def check_status(job_uuid):
         from edsl.coop.coop import Coop
 
         coop = Coop()
         return coop.remote_inference_get(job_uuid)
 
-    def
-        self
-
-        if testing_simulated_response is not None:
-            return lambda job_uuid: testing_simulated_response
-        else:
-            from edsl.coop.coop import Coop
-
-            coop = Coop()
-            return coop.remote_inference_get
-
-    def _construct_object_fetcher(
-        self, testing_simulated_response: Optional[Any] = None
-    ) -> Callable:
-        "Constructs a function to fetch the results object from Coop."
-        if testing_simulated_response is not None:
-            return lambda results_uuid, expected_object_type: Results.example()
-        else:
-            from edsl.coop.coop import Coop
-
-            coop = Coop()
-            return coop.get
-
-    def _handle_cancelled_job(self, job_info: RemoteJobInfo) -> None:
-        "Handles a cancelled job by logging the cancellation and updating the job status."
-
-        job_info.logger.update(
-            message="Job cancelled by the user.", status=JobsStatus.CANCELLED
-        )
-        job_info.logger.update(
-            f"See {self.expected_parrot_url}/home/remote-inference for more details.",
-            status=JobsStatus.CANCELLED,
-        )
-
-    def _handle_failed_job(
-        self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
-    ) -> None:
-        "Handles a failed job by logging the error and updating the job status."
-        latest_error_report_url = remote_job_data.get("latest_error_report_url")
-        if latest_error_report_url:
-            job_info.logger.add_info("error_report_url", latest_error_report_url)
-
-        job_info.logger.update("Job failed.", status=JobsStatus.FAILED)
-        job_info.logger.update(
-            f"See {self.expected_parrot_url}/home/remote-inference for more details.",
-            status=JobsStatus.FAILED,
-        )
-        job_info.logger.update(
-            f"Need support? Visit Discord: {RemoteJobConstants.DISCORD_URL}",
-            status=JobsStatus.FAILED,
+    def poll_remote_inference_job(self):
+        return self._poll_remote_inference_job(
+            self.remote_job_creation_data, verbose=self.verbose
         )
 
-    def
+    def _poll_remote_inference_job(
+        self,
+        remote_job_creation_data: dict,
+        verbose=False,
+        poll_interval: Optional[float] = None,
+        testing_simulated_response: Optional[dict] = None,
+    ) -> Union[Results, None]:
         import time
         from datetime import datetime
+        from edsl.config import CONFIG
+        from edsl.coop.coop import Coop
 
-
-
-            f"Job status: {status} - last update: {time_checked}",
-            status=JobsStatus.RUNNING,
-        )
-        time.sleep(self.poll_interval)
+        if poll_interval is None:
+            poll_interval = self.poll_interval
 
-
-        self,
-        job_info: RemoteJobInfo,
-        results_uuid: str,
-        remote_job_data: RemoteInferenceResponse,
-        object_fetcher: Callable,
-    ) -> "Results":
-        "Fetches the results object and logs the results URL."
-        job_info.logger.add_info("results_uuid", results_uuid)
-        results = object_fetcher(results_uuid, expected_object_type="results")
-        results_url = remote_job_data.get("results_url")
-        job_info.logger.update(
-            f"Job completed and Results stored on Coop: {results_url}",
-            status=JobsStatus.COMPLETED,
-        )
-        results.job_uuid = job_info.job_uuid
-        results.results_uuid = results_uuid
-        return results
+        expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
 
-
-
-        job_info: RemoteJobInfo,
-        testing_simulated_response=None,
-    ) -> Union[None, "Results"]:
-        """Polls a remote inference job for completion and returns the results."""
+        job_uuid = remote_job_creation_data.get("uuid")
+        coop = Coop()
 
-
-            testing_simulated_response
-
-
+        if testing_simulated_response is not None:
+            remote_job_data_fetcher = lambda job_uuid: testing_simulated_response
+            object_fetcher = (
+                lambda results_uuid, expected_object_type: Results.example()
+            )
+        else:
+            remote_job_data_fetcher = coop.remote_inference_get
+            object_fetcher = coop.get
 
         job_in_queue = True
         while job_in_queue:
-            remote_job_data = remote_job_data_fetcher(
+            remote_job_data = remote_job_data_fetcher(job_uuid)
             status = remote_job_data.get("status")
-
             if status == "cancelled":
-
+                print("\r" + " " * 80 + "\r", end="")
+                print("Job cancelled by the user.")
+                print(
+                    f"See {expected_parrot_url}/home/remote-inference for more details."
+                )
                 return None
-
-
-
-
-
-
-
-
-
-
-
-                    object_fetcher=object_fetcher,
+            elif status == "failed":
+                print("\r" + " " * 80 + "\r", end="")
+                # write to stderr
+                latest_error_report_url = remote_job_data.get("latest_error_report_url")
+                if latest_error_report_url:
+                    print("Job failed.")
+                    print(
+                        f"Your job generated exceptions. Details on these exceptions can be found in the following report: {latest_error_report_url}"
+                    )
+                    print(
+                        f"Need support? Post a message at the Expected Parrot Discord channel (https://discord.com/invite/mxAYkjfy9m) or send an email to info@expectedparrot.com."
                     )
-                return results
                 else:
-
-
+                    print("Job failed.")
+                    print(
+                        f"See {expected_parrot_url}/home/remote-inference for more details."
+                    )
+                return None
+            elif status == "completed":
+                results_uuid = remote_job_data.get("results_uuid")
+                results_url = remote_job_data.get("results_url")
+                results = object_fetcher(results_uuid, expected_object_type="results")
+                print("\r" + " " * 80 + "\r", end="")
+                print(f"Job completed and Results stored on Coop: {results_url}.")
+                return results
             else:
-
+                duration = poll_interval
+                time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
+                frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
+                start_time = time.time()
+                i = 0
+                while time.time() - start_time < duration:
+                    print(
+                        f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
+                        end="",
+                        flush=True,
+                    )
+                    time.sleep(0.1)
+                    i += 1
+
+    def use_remote_inference(self, disable_remote_inference: bool) -> bool:
+        if disable_remote_inference:
+            return False
+        if not disable_remote_inference:
+            try:
+                from edsl import Coop
+
+                user_edsl_settings = Coop().edsl_settings
+                return user_edsl_settings.get("remote_inference", False)
+            except requests.ConnectionError:
+                pass
+            except CoopServerResponseError as e:
+                pass
+
+        return False
 
     async def create_and_poll_remote_job(
         self,
         iterations: int = 1,
         remote_inference_description: Optional[str] = None,
-        remote_inference_results_visibility: Optional[
-
+        remote_inference_results_visibility: Optional[
+            Literal["private", "public", "unlisted"]
+        ] = "unlisted",
+    ) -> Union[Results, None]:
         """
         Creates and polls a remote inference job asynchronously.
         Reuses existing synchronous methods but runs them in an async context.
@@ -287,7 +217,7 @@ class JobsRemoteInferenceHandler:
 
         # Create job using existing method
         loop = asyncio.get_event_loop()
-
+        remote_job_creation_data = await loop.run_in_executor(
            None,
            partial(
                self.create_remote_inference_job,
@@ -296,12 +226,10 @@ class JobsRemoteInferenceHandler:
                 remote_inference_results_visibility=remote_inference_results_visibility,
             ),
         )
-        if job_info is None:
-            raise ValueError("Remote job creation failed.")
 
+        # Poll using existing method but with async sleep
         return await loop.run_in_executor(
-            None,
-            partial(self.poll_remote_inference_job, job_info),
+            None, partial(self.poll_remote_inference_job, remote_job_creation_data)
        )
 
 
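The doctests added to __init__ above indicate the intended call pattern; a hedged sketch mirroring them (the simulated response is purely illustrative):

from edsl.jobs import Jobs
from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler

jh = JobsRemoteInferenceHandler(Jobs.example(), verbose=True)
jh.use_remote_inference(True)   # disable flag set, so this returns False

# A testing_simulated_response bypasses the Coop round trip, as in the doctest.
results = jh._poll_remote_inference_job(
    {"uuid": 1234},
    testing_simulated_response={"status": "completed"},
)   # prints "Job completed and Results stored on Coop: None." and returns Results.example()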
edsl/jobs/buckets/BucketCollection.py
CHANGED
@@ -1,15 +1,8 @@
-from typing import Optional
 from collections import UserDict
 from edsl.jobs.buckets.TokenBucket import TokenBucket
 from edsl.jobs.buckets.ModelBuckets import ModelBuckets
 
-# from functools import wraps
-from threading import RLock
 
-from edsl.jobs.decorators import synchronized_class
-
-
-@synchronized_class
 class BucketCollection(UserDict):
     """A Jobs object will have a whole collection of model buckets, as multiple models could be used.
 
@@ -17,43 +10,11 @@ class BucketCollection(UserDict):
     Models themselves are hashable, so this works.
     """
 
-    def __init__(self, infinity_buckets
-        """Create a new BucketCollection.
-        An infinity bucket is a bucket that never runs out of tokens or requests.
-        """
+    def __init__(self, infinity_buckets=False):
         super().__init__()
         self.infinity_buckets = infinity_buckets
         self.models_to_services = {}
         self.services_to_buckets = {}
-        self._lock = RLock()
-
-        from edsl.config import CONFIG
-        import os
-
-        url = os.environ.get("EDSL_REMOTE_TOKEN_BUCKET_URL", None)
-
-        if url == "None" or url is None:
-            self.remote_url = None
-            # print(f"Using remote token bucket URL: {url}")
-        else:
-            self.remote_url = url
-
-    @classmethod
-    def from_models(
-        cls, models_list: list, infinity_buckets: bool = False
-    ) -> "BucketCollection":
-        """Create a BucketCollection from a list of models."""
-        bucket_collection = cls(infinity_buckets=infinity_buckets)
-        for model in models_list:
-            bucket_collection.add_model(model)
-        return bucket_collection
-
-    def get_tokens(
-        self, model: "LanguageModel", bucket_type: str, num_tokens: int
-    ) -> int:
-        """Get the number of tokens remaining in the bucket."""
-        relevant_bucket = getattr(self[model], bucket_type)
-        return relevant_bucket.get_tokens(num_tokens)
 
     def __repr__(self):
         return f"BucketCollection({self.data})"
@@ -65,8 +26,8 @@ class BucketCollection(UserDict):
 
         # compute the TPS and RPS from the model
         if not self.infinity_buckets:
-            TPS = model.
-            RPS = model.
+            TPS = model.TPM / 60.0
+            RPS = model.RPM / 60.0
         else:
             TPS = float("inf")
             RPS = float("inf")
@@ -79,14 +40,12 @@ class BucketCollection(UserDict):
             bucket_type="requests",
             capacity=RPS,
             refill_rate=RPS,
-            remote_url=self.remote_url,
         )
         tokens_bucket = TokenBucket(
             bucket_name=service,
             bucket_type="tokens",
             capacity=TPS,
             refill_rate=TPS,
-            remote_url=self.remote_url,
         )
         self.services_to_buckets[service] = ModelBuckets(
             requests_bucket, tokens_bucket