edsl 0.1.38.dev4__py3-none-any.whl → 0.1.39__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +197 -116
- edsl/__init__.py +15 -7
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +351 -147
- edsl/agents/AgentList.py +211 -73
- edsl/agents/Invigilator.py +101 -50
- edsl/agents/InvigilatorBase.py +62 -70
- edsl/agents/PromptConstructor.py +143 -225
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/__init__.py +0 -1
- edsl/agents/prompt_helpers.py +3 -3
- edsl/agents/question_option_processor.py +172 -0
- edsl/auto/AutoStudy.py +18 -5
- edsl/auto/StageBase.py +53 -40
- edsl/auto/StageQuestions.py +2 -1
- edsl/auto/utilities.py +0 -6
- edsl/config.py +22 -2
- edsl/conversation/car_buying.py +2 -1
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +1 -1
- edsl/coop/coop.py +125 -47
- edsl/coop/utils.py +14 -14
- edsl/data/Cache.py +45 -27
- edsl/data/CacheEntry.py +12 -15
- edsl/data/CacheHandler.py +31 -12
- edsl/data/RemoteCacheSync.py +154 -46
- edsl/data/__init__.py +4 -3
- edsl/data_transfer_models.py +2 -1
- edsl/enums.py +27 -0
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +12 -0
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/questions.py +24 -6
- edsl/exceptions/scenarios.py +7 -0
- edsl/inference_services/AnthropicService.py +38 -19
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +215 -0
- edsl/inference_services/AwsBedrock.py +0 -2
- edsl/inference_services/AzureAI.py +0 -2
- edsl/inference_services/GoogleService.py +7 -12
- edsl/inference_services/InferenceServiceABC.py +18 -85
- edsl/inference_services/InferenceServicesCollection.py +120 -79
- edsl/inference_services/MistralAIService.py +0 -3
- edsl/inference_services/OpenAIService.py +47 -35
- edsl/inference_services/PerplexityService.py +0 -3
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +11 -10
- edsl/inference_services/TogetherAIService.py +5 -3
- edsl/inference_services/data_structures.py +134 -0
- edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
- edsl/jobs/Answers.py +1 -14
- edsl/jobs/FetchInvigilator.py +47 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +50 -0
- edsl/jobs/Jobs.py +356 -431
- edsl/jobs/JobsChecks.py +35 -10
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +6 -4
- edsl/jobs/JobsRemoteInferenceHandler.py +205 -133
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/buckets/BucketCollection.py +44 -3
- edsl/jobs/buckets/TokenBucket.py +53 -21
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +143 -408
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +88 -403
- edsl/jobs/runners/JobsRunnerStatus.py +133 -165
- edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
- edsl/jobs/tasks/TaskHistory.py +38 -18
- edsl/jobs/tasks/task_status_enum.py +0 -2
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +194 -236
- edsl/language_models/ModelList.py +28 -19
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/__init__.py +1 -2
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +2 -2
- edsl/language_models/utilities.py +5 -4
- edsl/notebooks/Notebook.py +19 -14
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/prompts/Prompt.py +29 -39
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/QuestionBase.py +68 -214
- edsl/questions/QuestionBasePromptsMixin.py +7 -3
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +3 -3
- edsl/questions/QuestionExtract.py +5 -7
- edsl/questions/QuestionFreeText.py +2 -3
- edsl/questions/QuestionList.py +10 -18
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +67 -23
- edsl/questions/QuestionNumerical.py +2 -4
- edsl/questions/QuestionRank.py +7 -17
- edsl/questions/SimpleAskMixin.py +4 -3
- edsl/questions/__init__.py +2 -1
- edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +47 -2
- edsl/questions/data_structures.py +20 -0
- edsl/questions/derived/QuestionLinearScale.py +6 -3
- edsl/questions/derived/QuestionTopK.py +1 -1
- edsl/questions/descriptors.py +17 -3
- edsl/questions/loop_processor.py +149 -0
- edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +57 -50
- edsl/questions/question_registry.py +1 -1
- edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +40 -26
- edsl/questions/response_validator_factory.py +34 -0
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/results/CSSParameterizer.py +1 -1
- edsl/results/Dataset.py +170 -7
- edsl/results/DatasetExportMixin.py +168 -305
- edsl/results/DatasetTree.py +28 -8
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +298 -206
- edsl/results/Results.py +149 -131
- edsl/results/ResultsExportMixin.py +2 -0
- edsl/results/TableDisplay.py +98 -171
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +1 -1
- edsl/results/file_exports.py +252 -0
- edsl/results/{Selector.py → results_selector.py} +23 -13
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_renderers.py +118 -0
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +150 -239
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +90 -193
- edsl/scenarios/ScenarioHtmlMixin.py +4 -3
- edsl/scenarios/ScenarioList.py +415 -244
- edsl/scenarios/ScenarioListExportMixin.py +0 -7
- edsl/scenarios/ScenarioListPdfMixin.py +15 -37
- edsl/scenarios/__init__.py +1 -2
- edsl/scenarios/directory_scanner.py +96 -0
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +49 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +10 -6
- edsl/scenarios/scenario_selector.py +156 -0
- edsl/study/ObjectEntry.py +1 -1
- edsl/study/SnapShot.py +1 -1
- edsl/study/Study.py +5 -12
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/Rule.py +5 -4
- edsl/surveys/RuleCollection.py +25 -27
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +270 -791
- edsl/surveys/SurveyCSS.py +20 -8
- edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +4 -2
- edsl/surveys/descriptors.py +6 -2
- edsl/surveys/instructions/ChangeInstruction.py +1 -2
- edsl/surveys/instructions/Instruction.py +4 -13
- edsl/surveys/instructions/InstructionCollection.py +11 -6
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/report.html +1 -1
- edsl/tools/plotting.py +1 -1
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/utilities.py +35 -23
- {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/METADATA +12 -10
- edsl-0.1.39.dist-info/RECORD +358 -0
- {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/WHEEL +1 -1
- edsl/language_models/KeyLookup.py +0 -30
- edsl/language_models/registry.py +0 -190
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/results/ResultsDBMixin.py +0 -238
- edsl-0.1.38.dev4.dist-info/RECORD +0 -277
- /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
- /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
- /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
- {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/LICENSE +0 -0
@@ -0,0 +1,125 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
import os
|
3
|
+
import platformdirs
|
4
|
+
|
5
|
+
|
6
|
+
import sys
|
7
|
+
import select
|
8
|
+
|
9
|
+
|
10
|
+
def get_input_with_timeout(prompt, timeout=5, default="y"):
    """Prompt on stdout and return one line read from stdin, or *default*.

    :param prompt: Text printed (without a trailing newline) before waiting.
    :param timeout: Seconds to wait for input before giving up.
    :param default: Value returned when nothing arrives within *timeout*, or
        when stdin cannot be polled at all.
    :return: The entered line with surrounding whitespace stripped, or *default*.
    """
    print(prompt, end="", flush=True)
    try:
        ready, _, _ = select.select([sys.stdin], [], [], timeout)
    except (OSError, ValueError, TypeError):
        # select() only accepts sockets on Windows, notebook/IDE stdin objects
        # often have no usable fileno() (io.UnsupportedOperation), and stdin may
        # be None under pythonw. Treat "cannot wait for input" like "no answer".
        print(f"\nCannot wait for input here. Using default: {default}")
        return default
    if ready:
        return sys.stdin.readline().strip()
    print(f"\nNo input received within {timeout} seconds. Using default: {default}")
    return default
|
17
|
+
|
18
|
+
|
19
|
+
class ExpectedParrotKeyHandler:
    """Locate, store, and remove the user's Expected Parrot API key.

    The key can come from the ``EXPECTED_PARROT_API_KEY`` environment variable
    or from a file in the per-user config directory. When both are set, the
    environment variable wins.
    """

    # Marker file; if it exists, we do not offer to store the key again.
    asked_to_store_file_name = "asked_to_store.txt"
    # File (inside config_dir) that holds the stored API key.
    ep_key_file_name = "ep_api_key.txt"
    # Application name passed to platformdirs to choose the config directory.
    application_name = "edsl"

    @property
    def config_dir(self):
        """Per-user configuration directory for edsl, as chosen by platformdirs."""
        return platformdirs.user_config_dir(self.application_name)

    def _ep_key_file_exists(self) -> bool:
        """Check if the Expected Parrot key file exists."""
        return Path(self.config_dir).joinpath(self.ep_key_file_name).exists()

    def ok_to_ask_to_store(self) -> bool:
        """Return True if we may offer to store the key.

        That is only the case when ``EDSL_RUN_MODE`` is ``"production"`` and
        the asked-to-store marker file has not been written yet.
        """
        from edsl.config import CONFIG

        if CONFIG.get("EDSL_RUN_MODE") != "production":
            return False

        marker = Path(self.config_dir).joinpath(self.asked_to_store_file_name)
        return not marker.exists()

    def reset_asked_to_store(self):
        """Reset the flag that indicates whether the user has been asked to store the key."""
        asked_to_store_path = Path(self.config_dir).joinpath(
            self.asked_to_store_file_name
        )
        if asked_to_store_path.exists():
            os.remove(asked_to_store_path)
            print(
                "Deleted the file that indicates whether the user has been asked to store the key."
            )

    def ask_to_store(self, api_key) -> bool:
        """Store *api_key* in the config directory if we are allowed to.

        :param api_key: The key to persist.
        :return: True if the key was written to the key file, otherwise False.
        """
        if not self.ok_to_ask_to_store():
            # Return an explicit bool (never None) so the annotation holds.
            return False

        # The interactive y/n prompt (get_input_with_timeout) is currently
        # disabled, so consent is assumed. The refusal branch below is kept for
        # when the prompt is re-enabled.
        can_we_store = "y"
        if can_we_store.lower() == "y":
            # store_ep_api_key creates config_dir if needed.
            self.store_ep_api_key(api_key)
            return True

        # Remember the refusal so the user is not asked again.
        config_dir = Path(self.config_dir)
        config_dir.mkdir(parents=True, exist_ok=True)
        with open(
            config_dir.joinpath(self.asked_to_store_file_name), "w", encoding="utf-8"
        ) as f:
            f.write("Yes")
        return False

    def get_ep_api_key(self):
        """Return the Expected Parrot API key, or None if none is configured.

        The ``EXPECTED_PARROT_API_KEY`` environment variable takes precedence
        over the stored key file, and a warning is issued when both are set and
        differ. Empty values count as unset. When a key is found, it is passed
        to :meth:`ask_to_store`, which may write it to the key file.
        """
        api_key_from_cache = None
        if self._ep_key_file_exists():
            key_path = Path(self.config_dir).joinpath(self.ep_key_file_name)
            with open(key_path, "r", encoding="utf-8") as f:
                api_key_from_cache = f.read().strip()

        api_key_from_os = os.getenv("EXPECTED_PARROT_API_KEY")

        if (
            api_key_from_os
            and api_key_from_cache
            and api_key_from_os != api_key_from_cache
        ):
            import warnings

            warnings.warn(
                "WARNING: The Expected Parrot API key from the environment variable "
                "differs from the one stored in the config directory. Using the one "
                "from the environment variable."
            )

        # Environment first, then the file; collapse "" to None.
        api_key = api_key_from_os or api_key_from_cache or None

        if api_key is not None:
            self.ask_to_store(api_key)
        return api_key

    def delete_ep_api_key(self):
        """Delete the stored key file, if present."""
        key_path = Path(self.config_dir) / self.ep_key_file_name
        if key_path.exists():
            os.remove(key_path)
            print("Deleted Expected Parrot API key at ", key_path)

    def store_ep_api_key(self, api_key):
        """Write *api_key* to the key file, readable and writable only by its owner.

        :param api_key: The key to persist; it overwrites any existing stored key.
        """
        # Create the directory if it doesn't exist.
        os.makedirs(self.config_dir, exist_ok=True)

        key_path = Path(self.config_dir) / self.ep_key_file_name

        # The key is a secret, so other local users must not be able to read it.
        # os.open's mode only applies when the file is created, so tighten a
        # pre-existing file before writing the new key into it.
        if key_path.exists():
            os.chmod(key_path, 0o600)
        fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(api_key)
|
edsl/coop/PriceFetcher.py
CHANGED
edsl/coop/coop.py
CHANGED
@@ -1,11 +1,19 @@
|
|
1
1
|
import aiohttp
|
2
2
|
import json
|
3
|
-
import os
|
4
3
|
import requests
|
5
|
-
|
4
|
+
|
5
|
+
from typing import Any, Optional, Union, Literal, TypedDict
|
6
6
|
from uuid import UUID
|
7
|
+
from collections import UserDict, defaultdict
|
8
|
+
|
7
9
|
import edsl
|
8
|
-
from
|
10
|
+
from pathlib import Path
|
11
|
+
|
12
|
+
from edsl.config import CONFIG
|
13
|
+
from edsl.data.CacheEntry import CacheEntry
|
14
|
+
from edsl.jobs.Jobs import Jobs
|
15
|
+
from edsl.surveys.Survey import Survey
|
16
|
+
|
9
17
|
from edsl.exceptions.coop import CoopNoUUIDError, CoopServerResponseError
|
10
18
|
from edsl.coop.utils import (
|
11
19
|
EDSLObject,
|
@@ -15,19 +23,48 @@ from edsl.coop.utils import (
|
|
15
23
|
VisibilityType,
|
16
24
|
)
|
17
25
|
|
26
|
+
from edsl.coop.CoopFunctionsMixin import CoopFunctionsMixin
|
27
|
+
from edsl.coop.ExpectedParrotKeyHandler import ExpectedParrotKeyHandler
|
28
|
+
|
29
|
+
from edsl.inference_services.data_structures import ServiceToModelsMapping
|
30
|
+
|
18
31
|
|
19
|
-
class
|
32
|
+
class RemoteInferenceResponse(TypedDict):
    """Details of a remote inference job, as returned by ``Coop.remote_inference_get``."""

    # Identifiers and links for the job, its results, and its latest error report.
    job_uuid: str
    results_uuid: str
    results_url: str
    latest_error_report_uuid: str
    latest_error_report_url: str
    # Job state as reported by the server.
    status: str
    reason: str
    # Filled from the server's "price" field in remote_inference_get.
    credits_consumed: float
    version: str
|
41
|
+
version: str
|
42
|
+
|
43
|
+
|
44
|
+
class RemoteInferenceCreationInfo(TypedDict):
    """Summary of a newly created remote inference job, as returned by ``Coop.remote_inference_create``."""

    # Populated from the server's "job_uuid" field.
    uuid: str
    description: str
    status: str
    iterations: int
    visibility: str
    # Set from Coop._edsl_version when the job is created.
    version: str
|
51
|
+
|
52
|
+
|
53
|
+
class Coop(CoopFunctionsMixin):
|
20
54
|
"""
|
21
55
|
Client for the Expected Parrot API.
|
22
56
|
"""
|
23
57
|
|
24
|
-
def __init__(
|
58
|
+
def __init__(
|
59
|
+
self, api_key: Optional[str] = None, url: Optional[str] = None
|
60
|
+
) -> None:
|
25
61
|
"""
|
26
62
|
Initialize the client.
|
27
63
|
- Provide an API key directly, or through an env variable.
|
28
64
|
- Provide a URL directly, or use the default one.
|
29
65
|
"""
|
30
|
-
self.
|
66
|
+
self.ep_key_handler = ExpectedParrotKeyHandler()
|
67
|
+
self.api_key = api_key or self.ep_key_handler.get_ep_api_key()
|
31
68
|
|
32
69
|
self.url = url or CONFIG.EXPECTED_PARROT_URL
|
33
70
|
if self.url.endswith("/"):
|
@@ -142,6 +179,7 @@ class Coop:
|
|
142
179
|
Check the response from the server and raise errors as appropriate.
|
143
180
|
"""
|
144
181
|
# Get EDSL version from header
|
182
|
+
# breakpoint()
|
145
183
|
server_edsl_version = response.headers.get("X-EDSL-Version")
|
146
184
|
|
147
185
|
if server_edsl_version:
|
@@ -150,11 +188,18 @@ class Coop:
|
|
150
188
|
server_version_str=server_edsl_version,
|
151
189
|
):
|
152
190
|
print(
|
153
|
-
"Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip upgrade edsl`"
|
191
|
+
"Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip install --upgrade edsl`"
|
154
192
|
)
|
155
193
|
|
156
194
|
if response.status_code >= 400:
|
157
|
-
|
195
|
+
try:
|
196
|
+
message = response.json().get("detail")
|
197
|
+
except json.JSONDecodeError:
|
198
|
+
raise CoopServerResponseError(
|
199
|
+
f"Server returned status code {response.status_code}."
|
200
|
+
"JSON response could not be decoded.",
|
201
|
+
"The server response was: " + response.text,
|
202
|
+
)
|
158
203
|
# print(response.text)
|
159
204
|
if "The API key you provided is invalid" in message and check_api_key:
|
160
205
|
import secrets
|
@@ -163,19 +208,27 @@ class Coop:
|
|
163
208
|
edsl_auth_token = secrets.token_urlsafe(16)
|
164
209
|
|
165
210
|
print("Your Expected Parrot API key is invalid.")
|
166
|
-
|
167
|
-
|
211
|
+
self._display_login_url(
|
212
|
+
edsl_auth_token=edsl_auth_token,
|
213
|
+
link_description="\n🔗 Use the link below to log in to Expected Parrot so we can automatically update your API key.",
|
168
214
|
)
|
169
|
-
self._display_login_url(edsl_auth_token=edsl_auth_token)
|
170
215
|
api_key = self._poll_for_api_key(edsl_auth_token)
|
171
216
|
|
172
217
|
if api_key is None:
|
173
218
|
print("\nTimed out waiting for login. Please try again.")
|
174
219
|
return
|
175
220
|
|
176
|
-
|
177
|
-
|
178
|
-
|
221
|
+
print("\n✨ API key retrieved.")
|
222
|
+
|
223
|
+
if stored_in_user_space := self.ep_key_handler.ask_to_store(api_key):
|
224
|
+
pass
|
225
|
+
else:
|
226
|
+
path_to_env = write_api_key_to_env(api_key)
|
227
|
+
print(
|
228
|
+
"\n✨ API key retrieved and written to .env file at the following path:"
|
229
|
+
)
|
230
|
+
print(f" {path_to_env}")
|
231
|
+
print("Rerun your code to try again with a valid API key.")
|
179
232
|
return
|
180
233
|
|
181
234
|
elif "Authorization" in message:
|
@@ -268,6 +321,7 @@ class Coop:
|
|
268
321
|
self,
|
269
322
|
object: EDSLObject,
|
270
323
|
description: Optional[str] = None,
|
324
|
+
alias: Optional[str] = None,
|
271
325
|
visibility: Optional[VisibilityType] = "unlisted",
|
272
326
|
) -> dict:
|
273
327
|
"""
|
@@ -279,6 +333,7 @@ class Coop:
|
|
279
333
|
method="POST",
|
280
334
|
payload={
|
281
335
|
"description": description,
|
336
|
+
"alias": alias,
|
282
337
|
"json_string": json.dumps(
|
283
338
|
object.to_dict(),
|
284
339
|
default=self._json_handle_none,
|
@@ -373,6 +428,7 @@ class Coop:
|
|
373
428
|
uuid: Union[str, UUID] = None,
|
374
429
|
url: str = None,
|
375
430
|
description: Optional[str] = None,
|
431
|
+
alias: Optional[str] = None,
|
376
432
|
value: Optional[EDSLObject] = None,
|
377
433
|
visibility: Optional[VisibilityType] = None,
|
378
434
|
) -> dict:
|
@@ -389,6 +445,7 @@ class Coop:
|
|
389
445
|
params={"uuid": uuid},
|
390
446
|
payload={
|
391
447
|
"description": description,
|
448
|
+
"alias": alias,
|
392
449
|
"json_string": (
|
393
450
|
json.dumps(
|
394
451
|
value.to_dict(),
|
@@ -602,9 +659,6 @@ class Coop:
|
|
602
659
|
self._resolve_server_response(response)
|
603
660
|
return response.json()
|
604
661
|
|
605
|
-
################
|
606
|
-
# Remote Inference
|
607
|
-
################
|
608
662
|
def remote_inference_create(
|
609
663
|
self,
|
610
664
|
job: Jobs,
|
@@ -613,7 +667,7 @@ class Coop:
|
|
613
667
|
visibility: Optional[VisibilityType] = "unlisted",
|
614
668
|
initial_results_visibility: Optional[VisibilityType] = "unlisted",
|
615
669
|
iterations: Optional[int] = 1,
|
616
|
-
) ->
|
670
|
+
) -> RemoteInferenceCreationInfo:
|
617
671
|
"""
|
618
672
|
Send a remote inference job to the server.
|
619
673
|
|
@@ -645,18 +699,21 @@ class Coop:
|
|
645
699
|
)
|
646
700
|
self._resolve_server_response(response)
|
647
701
|
response_json = response.json()
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
702
|
+
|
703
|
+
return RemoteInferenceCreationInfo(
|
704
|
+
**{
|
705
|
+
"uuid": response_json.get("job_uuid"),
|
706
|
+
"description": response_json.get("description"),
|
707
|
+
"status": response_json.get("status"),
|
708
|
+
"iterations": response_json.get("iterations"),
|
709
|
+
"visibility": response_json.get("visibility"),
|
710
|
+
"version": self._edsl_version,
|
711
|
+
}
|
712
|
+
)
|
656
713
|
|
657
714
|
def remote_inference_get(
|
658
715
|
self, job_uuid: Optional[str] = None, results_uuid: Optional[str] = None
|
659
|
-
) ->
|
716
|
+
) -> RemoteInferenceResponse:
|
660
717
|
"""
|
661
718
|
Get the details of a remote inference job.
|
662
719
|
You can pass either the job uuid or the results uuid as a parameter.
|
@@ -698,17 +755,30 @@ class Coop:
|
|
698
755
|
f"{self.url}/home/remote-inference/error/{latest_error_report_uuid}"
|
699
756
|
)
|
700
757
|
|
701
|
-
return
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
758
|
+
return RemoteInferenceResponse(
|
759
|
+
**{
|
760
|
+
"job_uuid": data.get("job_uuid"),
|
761
|
+
"results_uuid": results_uuid,
|
762
|
+
"results_url": results_url,
|
763
|
+
"latest_error_report_uuid": latest_error_report_uuid,
|
764
|
+
"latest_error_report_url": latest_error_report_url,
|
765
|
+
"status": data.get("status"),
|
766
|
+
"reason": data.get("reason"),
|
767
|
+
"credits_consumed": data.get("price"),
|
768
|
+
"version": data.get("version"),
|
769
|
+
}
|
770
|
+
)
|
771
|
+
|
772
|
+
def get_running_jobs(self) -> list[str]:
    """
    Ask the server which jobs are currently running.

    Returns:
        list[str]: UUIDs of the running jobs; an empty list when the
        response carries no ``running_jobs`` entry.
    """
    status_response = self._send_server_request(uri="jobs/status", method="GET")
    self._resolve_server_response(status_response)
    payload = status_response.json()
    running = payload.get("running_jobs", [])
    return running
|
712
782
|
|
713
783
|
def remote_inference_cost(
|
714
784
|
self, input: Union[Jobs, Survey], iterations: int = 1
|
@@ -810,7 +880,7 @@ class Coop:
|
|
810
880
|
"Invalid EDSL_FETCH_TOKEN_PRICES value---should be 'True' or 'False'."
|
811
881
|
)
|
812
882
|
|
813
|
-
def fetch_models(self) ->
|
883
|
+
def fetch_models(self) -> ServiceToModelsMapping:
|
814
884
|
"""
|
815
885
|
Fetch a dict of available models from Coop.
|
816
886
|
|
@@ -819,7 +889,7 @@ class Coop:
|
|
819
889
|
response = self._send_server_request(uri="api/v0/models", method="GET")
|
820
890
|
self._resolve_server_response(response)
|
821
891
|
data = response.json()
|
822
|
-
return data
|
892
|
+
return ServiceToModelsMapping(data)
|
823
893
|
|
824
894
|
def fetch_rate_limit_config_vars(self) -> dict:
|
825
895
|
"""
|
@@ -835,7 +905,9 @@ class Coop:
|
|
835
905
|
data = response.json()
|
836
906
|
return data
|
837
907
|
|
838
|
-
def _display_login_url(
|
908
|
+
def _display_login_url(
|
909
|
+
self, edsl_auth_token: str, link_description: Optional[str] = None
|
910
|
+
):
|
839
911
|
"""
|
840
912
|
Uses rich.print to display a login URL.
|
841
913
|
|
@@ -845,7 +917,12 @@ class Coop:
|
|
845
917
|
|
846
918
|
url = f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"
|
847
919
|
|
848
|
-
|
920
|
+
if link_description:
|
921
|
+
rich_print(
|
922
|
+
f"{link_description}\n [#38bdf8][link={url}]{url}[/link][/#38bdf8]"
|
923
|
+
)
|
924
|
+
else:
|
925
|
+
rich_print(f" [#38bdf8][link={url}]{url}[/link][/#38bdf8]")
|
849
926
|
|
850
927
|
def _get_api_key(self, edsl_auth_token: str):
|
851
928
|
"""
|
@@ -873,17 +950,18 @@ class Coop:
|
|
873
950
|
|
874
951
|
edsl_auth_token = secrets.token_urlsafe(16)
|
875
952
|
|
876
|
-
|
877
|
-
|
953
|
+
self._display_login_url(
|
954
|
+
edsl_auth_token=edsl_auth_token,
|
955
|
+
link_description="\n🔗 Use the link below to log in to Expected Parrot so we can automatically update your API key.",
|
878
956
|
)
|
879
|
-
self._display_login_url(edsl_auth_token=edsl_auth_token)
|
880
957
|
api_key = self._poll_for_api_key(edsl_auth_token)
|
881
958
|
|
882
959
|
if api_key is None:
|
883
960
|
raise Exception("Timed out waiting for login. Please try again.")
|
884
961
|
|
885
|
-
write_api_key_to_env(api_key)
|
886
|
-
print("\n✨ API key retrieved and written to .env file
|
962
|
+
path_to_env = write_api_key_to_env(api_key)
|
963
|
+
print("\n✨ API key retrieved and written to .env file at the following path:")
|
964
|
+
print(f" {path_to_env}")
|
887
965
|
|
888
966
|
# Add API key to environment
|
889
967
|
load_dotenv()
|
edsl/coop/utils.py
CHANGED
@@ -1,19 +1,19 @@
|
|
1
|
-
from edsl import (
|
2
|
-
Agent,
|
3
|
-
AgentList,
|
4
|
-
Cache,
|
5
|
-
ModelList,
|
6
|
-
Notebook,
|
7
|
-
Results,
|
8
|
-
Scenario,
|
9
|
-
ScenarioList,
|
10
|
-
Survey,
|
11
|
-
Study,
|
12
|
-
)
|
13
|
-
from edsl.language_models import LanguageModel
|
14
|
-
from edsl.questions import QuestionBase
|
15
1
|
from typing import Literal, Optional, Type, Union
|
16
2
|
|
3
|
+
from edsl.agents.Agent import Agent
|
4
|
+
from edsl.agents.AgentList import AgentList
|
5
|
+
from edsl.data.Cache import Cache
|
6
|
+
from edsl.language_models.ModelList import ModelList
|
7
|
+
from edsl.notebooks.Notebook import Notebook
|
8
|
+
from edsl.results.Results import Results
|
9
|
+
from edsl.scenarios.Scenario import Scenario
|
10
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
11
|
+
from edsl.surveys.Survey import Survey
|
12
|
+
from edsl.study.Study import Study
|
13
|
+
|
14
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
15
|
+
from edsl.questions.QuestionBase import QuestionBase
|
16
|
+
|
17
17
|
EDSLObject = Union[
|
18
18
|
Agent,
|
19
19
|
AgentList,
|
edsl/data/Cache.py
CHANGED
@@ -6,12 +6,10 @@ from __future__ import annotations
|
|
6
6
|
import json
|
7
7
|
import os
|
8
8
|
import warnings
|
9
|
-
import
|
10
|
-
from typing import Optional, Union
|
9
|
+
from typing import Optional, Union, TYPE_CHECKING
|
11
10
|
from edsl.Base import Base
|
12
|
-
|
13
|
-
from edsl.utilities.
|
14
|
-
from edsl.utilities.decorators import remove_edsl_version
|
11
|
+
|
12
|
+
from edsl.utilities.remove_edsl_version import remove_edsl_version
|
15
13
|
from edsl.exceptions.cache import CacheError
|
16
14
|
|
17
15
|
|
@@ -83,10 +81,6 @@ class Cache(Base):
|
|
83
81
|
|
84
82
|
self._perform_checks()
|
85
83
|
|
86
|
-
def rich_print(sefl):
|
87
|
-
pass
|
88
|
-
# raise NotImplementedError("This method is not implemented yet.")
|
89
|
-
|
90
84
|
def code(sefl):
|
91
85
|
pass
|
92
86
|
# raise NotImplementedError("This method is not implemented yet.")
|
@@ -201,6 +195,7 @@ class Cache(Base):
|
|
201
195
|
>>> len(c)
|
202
196
|
1
|
203
197
|
"""
|
198
|
+
from edsl.data.CacheEntry import CacheEntry
|
204
199
|
|
205
200
|
entry = CacheEntry(
|
206
201
|
model=model,
|
@@ -226,6 +221,7 @@ class Cache(Base):
|
|
226
221
|
|
227
222
|
:param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
|
228
223
|
"""
|
224
|
+
from edsl.data.CacheEntry import CacheEntry
|
229
225
|
|
230
226
|
for key, value in new_data.items():
|
231
227
|
if key in self.data:
|
@@ -246,6 +242,8 @@ class Cache(Base):
|
|
246
242
|
|
247
243
|
:param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
|
248
244
|
"""
|
245
|
+
from edsl.data.CacheEntry import CacheEntry
|
246
|
+
|
249
247
|
with open(filename, "a+") as f:
|
250
248
|
f.seek(0)
|
251
249
|
lines = f.readlines()
|
@@ -289,8 +287,8 @@ class Cache(Base):
|
|
289
287
|
|
290
288
|
CACHE_PATH = CONFIG.get("EDSL_DATABASE_PATH")
|
291
289
|
path = CACHE_PATH.replace("sqlite:///", "")
|
292
|
-
db_path = os.path.join(os.path.dirname(path), "data.db")
|
293
|
-
return cls.from_sqlite_db(
|
290
|
+
# db_path = os.path.join(os.path.dirname(path), "data.db")
|
291
|
+
return cls.from_sqlite_db(path)
|
294
292
|
|
295
293
|
@classmethod
|
296
294
|
def from_jsonl(cls, jsonlfile: str, db_path: Optional[str] = None) -> Cache:
|
@@ -353,7 +351,8 @@ class Cache(Base):
|
|
353
351
|
f.write(json.dumps({key: value.to_dict()}) + "\n")
|
354
352
|
|
355
353
|
def to_scenario_list(self):
|
356
|
-
from edsl import ScenarioList
|
354
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
355
|
+
from edsl.scenarios.Scenario import Scenario
|
357
356
|
|
358
357
|
scenarios = []
|
359
358
|
for key, value in self.data.items():
|
@@ -363,12 +362,32 @@ class Cache(Base):
|
|
363
362
|
scenarios.append(s)
|
364
363
|
return ScenarioList(scenarios)
|
365
364
|
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
365
|
+
def __floordiv__(self, other: "Cache") -> "Cache":
    """
    Return a new Cache containing entries that are in self but not in other.
    Uses // operator as alternative to subtraction.

    :param other: Another Cache object to compare against
    :return: A new Cache object containing unique entries
    :raises CacheError: If ``other`` is not a Cache.

    >>> from edsl.data.CacheEntry import CacheEntry
    >>> ce1 = CacheEntry.example(randomize = True)
    >>> ce2 = CacheEntry.example(randomize = True)
    >>> c1 = Cache(data={ce1.key: ce1, ce2.key: ce2})
    >>> c2 = Cache(data={ce1.key: ce1})
    >>> c3 = c1 // c2
    >>> len(c3)
    1
    >>> c3.data[ce2.key] == ce2
    True
    """
    if not isinstance(other, Cache):
        raise CacheError("Can only compare two caches")

    # Entries are matched by key only; the stored values are not compared.
    diff_data = {k: v for k, v in self.data.items() if k not in other.data}
    return Cache(data=diff_data, immediate_write=self.immediate_write)
|
390
|
+
|
372
391
|
@classmethod
|
373
392
|
def from_url(cls, db_path=None) -> Cache:
|
374
393
|
"""
|
@@ -394,11 +413,10 @@ class Cache(Base):
|
|
394
413
|
if self.filename:
|
395
414
|
self.write(self.filename)
|
396
415
|
|
397
|
-
####################
|
398
|
-
# DUNDER / USEFUL
|
399
|
-
####################
|
400
416
|
def __hash__(self):
|
401
417
|
"""Return the hash of the Cache."""
|
418
|
+
from edsl.utilities.utilities import dict_hash
|
419
|
+
|
402
420
|
return dict_hash(self.to_dict(add_edsl_version=False))
|
403
421
|
|
404
422
|
def to_dict(self, add_edsl_version=True) -> dict:
|
@@ -414,12 +432,6 @@ class Cache(Base):
|
|
414
432
|
def _summary(self):
|
415
433
|
return {"EDSL Class": "Cache", "Number of entries": len(self.data)}
|
416
434
|
|
417
|
-
def _repr_html_(self):
|
418
|
-
# from edsl.utilities.utilities import data_to_html
|
419
|
-
# return data_to_html(self.to_dict())
|
420
|
-
footer = f"<a href={self.__documentation__}>(docs)</a>"
|
421
|
-
return str(self.summary(format="html")) + footer
|
422
|
-
|
423
435
|
def table(
|
424
436
|
self,
|
425
437
|
*fields,
|
@@ -443,6 +455,8 @@ class Cache(Base):
|
|
443
455
|
@remove_edsl_version
|
444
456
|
def from_dict(cls, data) -> Cache:
|
445
457
|
"""Construct a Cache from a dictionary."""
|
458
|
+
from edsl.data.CacheEntry import CacheEntry
|
459
|
+
|
446
460
|
newdata = {k: CacheEntry.from_dict(v) for k, v in data.items()}
|
447
461
|
return cls(data=newdata)
|
448
462
|
|
@@ -485,6 +499,8 @@ class Cache(Base):
|
|
485
499
|
"""
|
486
500
|
Create an example input for a 'fetch' operation.
|
487
501
|
"""
|
502
|
+
from edsl.data.CacheEntry import CacheEntry
|
503
|
+
|
488
504
|
return CacheEntry.fetch_input_example()
|
489
505
|
|
490
506
|
def to_html(self):
|
@@ -541,6 +557,8 @@ class Cache(Base):
|
|
541
557
|
|
542
558
|
:param randomize: If True, uses CacheEntry's randomize method.
|
543
559
|
"""
|
560
|
+
from edsl.data.CacheEntry import CacheEntry
|
561
|
+
|
544
562
|
return cls(
|
545
563
|
data={
|
546
564
|
CacheEntry.example(randomize).key: CacheEntry.example(),
|