edsl 0.1.39__py3-none-any.whl → 0.1.39.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +116 -197
- edsl/__init__.py +7 -15
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +147 -351
- edsl/agents/AgentList.py +73 -211
- edsl/agents/Invigilator.py +50 -101
- edsl/agents/InvigilatorBase.py +70 -62
- edsl/agents/PromptConstructor.py +225 -143
- edsl/agents/__init__.py +1 -0
- edsl/agents/prompt_helpers.py +3 -3
- edsl/auto/AutoStudy.py +5 -18
- edsl/auto/StageBase.py +40 -53
- edsl/auto/StageQuestions.py +1 -2
- edsl/auto/utilities.py +6 -0
- edsl/config.py +2 -22
- edsl/conversation/car_buying.py +1 -2
- edsl/coop/PriceFetcher.py +1 -1
- edsl/coop/coop.py +47 -125
- edsl/coop/utils.py +14 -14
- edsl/data/Cache.py +27 -45
- edsl/data/CacheEntry.py +15 -12
- edsl/data/CacheHandler.py +12 -31
- edsl/data/RemoteCacheSync.py +46 -154
- edsl/data/__init__.py +3 -4
- edsl/data_transfer_models.py +1 -2
- edsl/enums.py +0 -27
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +0 -12
- edsl/exceptions/questions.py +6 -24
- edsl/exceptions/scenarios.py +0 -7
- edsl/inference_services/AnthropicService.py +19 -38
- edsl/inference_services/AwsBedrock.py +2 -0
- edsl/inference_services/AzureAI.py +2 -0
- edsl/inference_services/GoogleService.py +12 -7
- edsl/inference_services/InferenceServiceABC.py +85 -18
- edsl/inference_services/InferenceServicesCollection.py +79 -120
- edsl/inference_services/MistralAIService.py +3 -0
- edsl/inference_services/OpenAIService.py +35 -47
- edsl/inference_services/PerplexityService.py +3 -0
- edsl/inference_services/TestService.py +10 -11
- edsl/inference_services/TogetherAIService.py +3 -5
- edsl/jobs/Answers.py +14 -1
- edsl/jobs/Jobs.py +431 -356
- edsl/jobs/JobsChecks.py +10 -35
- edsl/jobs/JobsPrompts.py +4 -6
- edsl/jobs/JobsRemoteInferenceHandler.py +133 -205
- edsl/jobs/buckets/BucketCollection.py +3 -44
- edsl/jobs/buckets/TokenBucket.py +21 -53
- edsl/jobs/interviews/Interview.py +408 -143
- edsl/jobs/runners/JobsRunnerAsyncio.py +403 -88
- edsl/jobs/runners/JobsRunnerStatus.py +165 -133
- edsl/jobs/tasks/QuestionTaskCreator.py +19 -21
- edsl/jobs/tasks/TaskHistory.py +18 -38
- edsl/jobs/tasks/task_status_enum.py +2 -0
- edsl/language_models/KeyLookup.py +30 -0
- edsl/language_models/LanguageModel.py +236 -194
- edsl/language_models/ModelList.py +19 -28
- edsl/language_models/__init__.py +2 -1
- edsl/language_models/registry.py +190 -0
- edsl/language_models/repair.py +2 -2
- edsl/language_models/unused/ReplicateBase.py +83 -0
- edsl/language_models/utilities.py +4 -5
- edsl/notebooks/Notebook.py +14 -19
- edsl/prompts/Prompt.py +39 -29
- edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +2 -47
- edsl/questions/QuestionBase.py +214 -68
- edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +50 -57
- edsl/questions/QuestionBasePromptsMixin.py +3 -7
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +3 -3
- edsl/questions/QuestionExtract.py +7 -5
- edsl/questions/QuestionFreeText.py +3 -2
- edsl/questions/QuestionList.py +18 -10
- edsl/questions/QuestionMultipleChoice.py +23 -67
- edsl/questions/QuestionNumerical.py +4 -2
- edsl/questions/QuestionRank.py +17 -7
- edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +26 -40
- edsl/questions/SimpleAskMixin.py +3 -4
- edsl/questions/__init__.py +1 -2
- edsl/questions/derived/QuestionLinearScale.py +3 -6
- edsl/questions/derived/QuestionTopK.py +1 -1
- edsl/questions/descriptors.py +3 -17
- edsl/questions/question_registry.py +1 -1
- edsl/results/CSSParameterizer.py +1 -1
- edsl/results/Dataset.py +7 -170
- edsl/results/DatasetExportMixin.py +305 -168
- edsl/results/DatasetTree.py +8 -28
- edsl/results/Result.py +206 -298
- edsl/results/Results.py +131 -149
- edsl/results/ResultsDBMixin.py +238 -0
- edsl/results/ResultsExportMixin.py +0 -2
- edsl/results/{results_selector.py → Selector.py} +13 -23
- edsl/results/TableDisplay.py +171 -98
- edsl/results/__init__.py +1 -1
- edsl/scenarios/FileStore.py +239 -150
- edsl/scenarios/Scenario.py +193 -90
- edsl/scenarios/ScenarioHtmlMixin.py +3 -4
- edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +6 -10
- edsl/scenarios/ScenarioList.py +244 -415
- edsl/scenarios/ScenarioListExportMixin.py +7 -0
- edsl/scenarios/ScenarioListPdfMixin.py +37 -15
- edsl/scenarios/__init__.py +2 -1
- edsl/study/ObjectEntry.py +1 -1
- edsl/study/SnapShot.py +1 -1
- edsl/study/Study.py +12 -5
- edsl/surveys/Rule.py +4 -5
- edsl/surveys/RuleCollection.py +27 -25
- edsl/surveys/Survey.py +791 -270
- edsl/surveys/SurveyCSS.py +8 -20
- edsl/surveys/{SurveyFlowVisualization.py → SurveyFlowVisualizationMixin.py} +9 -11
- edsl/surveys/__init__.py +2 -4
- edsl/surveys/descriptors.py +2 -6
- edsl/surveys/instructions/ChangeInstruction.py +2 -1
- edsl/surveys/instructions/Instruction.py +13 -4
- edsl/surveys/instructions/InstructionCollection.py +6 -11
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/report.html +1 -1
- edsl/tools/plotting.py +1 -1
- edsl/utilities/utilities.py +23 -35
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/METADATA +10 -12
- edsl-0.1.39.dev1.dist-info/RECORD +277 -0
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/WHEEL +1 -1
- edsl/agents/QuestionInstructionPromptBuilder.py +0 -128
- edsl/agents/QuestionTemplateReplacementsBuilder.py +0 -137
- edsl/agents/question_option_processor.py +0 -172
- edsl/coop/CoopFunctionsMixin.py +0 -15
- edsl/coop/ExpectedParrotKeyHandler.py +0 -125
- edsl/exceptions/inference_services.py +0 -5
- edsl/inference_services/AvailableModelCacheHandler.py +0 -184
- edsl/inference_services/AvailableModelFetcher.py +0 -215
- edsl/inference_services/ServiceAvailability.py +0 -135
- edsl/inference_services/data_structures.py +0 -134
- edsl/jobs/AnswerQuestionFunctionConstructor.py +0 -223
- edsl/jobs/FetchInvigilator.py +0 -47
- edsl/jobs/InterviewTaskManager.py +0 -98
- edsl/jobs/InterviewsConstructor.py +0 -50
- edsl/jobs/JobsComponentConstructor.py +0 -189
- edsl/jobs/JobsRemoteInferenceLogger.py +0 -239
- edsl/jobs/RequestTokenEstimator.py +0 -30
- edsl/jobs/async_interview_runner.py +0 -138
- edsl/jobs/buckets/TokenBucketAPI.py +0 -211
- edsl/jobs/buckets/TokenBucketClient.py +0 -191
- edsl/jobs/check_survey_scenario_compatibility.py +0 -85
- edsl/jobs/data_structures.py +0 -120
- edsl/jobs/decorators.py +0 -35
- edsl/jobs/jobs_status_enums.py +0 -9
- edsl/jobs/loggers/HTMLTableJobLogger.py +0 -304
- edsl/jobs/results_exceptions_handler.py +0 -98
- edsl/language_models/ComputeCost.py +0 -63
- edsl/language_models/PriceManager.py +0 -127
- edsl/language_models/RawResponseHandler.py +0 -106
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/key_management/KeyLookup.py +0 -63
- edsl/language_models/key_management/KeyLookupBuilder.py +0 -273
- edsl/language_models/key_management/KeyLookupCollection.py +0 -38
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +0 -131
- edsl/language_models/model.py +0 -256
- edsl/notebooks/NotebookToLaTeX.py +0 -142
- edsl/questions/ExceptionExplainer.py +0 -77
- edsl/questions/HTMLQuestion.py +0 -103
- edsl/questions/QuestionMatrix.py +0 -265
- edsl/questions/data_structures.py +0 -20
- edsl/questions/loop_processor.py +0 -149
- edsl/questions/response_validator_factory.py +0 -34
- edsl/questions/templates/matrix/__init__.py +0 -1
- edsl/questions/templates/matrix/answering_instructions.jinja +0 -5
- edsl/questions/templates/matrix/question_presentation.jinja +0 -20
- edsl/results/MarkdownToDocx.py +0 -122
- edsl/results/MarkdownToPDF.py +0 -111
- edsl/results/TextEditor.py +0 -50
- edsl/results/file_exports.py +0 -252
- edsl/results/smart_objects.py +0 -96
- edsl/results/table_data_class.py +0 -12
- edsl/results/table_renderers.py +0 -118
- edsl/scenarios/ConstructDownloadLink.py +0 -109
- edsl/scenarios/DocumentChunker.py +0 -102
- edsl/scenarios/DocxScenario.py +0 -16
- edsl/scenarios/PdfExtractor.py +0 -40
- edsl/scenarios/directory_scanner.py +0 -96
- edsl/scenarios/file_methods.py +0 -85
- edsl/scenarios/handlers/__init__.py +0 -13
- edsl/scenarios/handlers/csv.py +0 -49
- edsl/scenarios/handlers/docx.py +0 -76
- edsl/scenarios/handlers/html.py +0 -37
- edsl/scenarios/handlers/json.py +0 -111
- edsl/scenarios/handlers/latex.py +0 -5
- edsl/scenarios/handlers/md.py +0 -51
- edsl/scenarios/handlers/pdf.py +0 -68
- edsl/scenarios/handlers/png.py +0 -39
- edsl/scenarios/handlers/pptx.py +0 -105
- edsl/scenarios/handlers/py.py +0 -294
- edsl/scenarios/handlers/sql.py +0 -313
- edsl/scenarios/handlers/sqlite.py +0 -149
- edsl/scenarios/handlers/txt.py +0 -33
- edsl/scenarios/scenario_selector.py +0 -156
- edsl/surveys/ConstructDAG.py +0 -92
- edsl/surveys/EditSurvey.py +0 -221
- edsl/surveys/InstructionHandler.py +0 -100
- edsl/surveys/MemoryManagement.py +0 -72
- edsl/surveys/RuleManager.py +0 -172
- edsl/surveys/Simulator.py +0 -75
- edsl/surveys/SurveyToApp.py +0 -141
- edsl/utilities/PrettyList.py +0 -56
- edsl/utilities/is_notebook.py +0 -18
- edsl/utilities/is_valid_variable_name.py +0 -11
- edsl/utilities/remove_edsl_version.py +0 -24
- edsl-0.1.39.dist-info/RECORD +0 -358
- /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
- /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
- /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/LICENSE +0 -0
edsl/auto/StageBase.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
1
|
from abc import ABC, abstractmethod
|
2
|
-
import json
|
3
2
|
from typing import Dict, List, Any, TypeVar, Generator, Dict, Callable
|
4
3
|
from dataclasses import dataclass, field, KW_ONLY, fields, asdict
|
5
4
|
import textwrap
|
@@ -36,13 +35,6 @@ class FlowDataBase:
|
|
36
35
|
sent_to_stage_name: str = field(default_factory=str)
|
37
36
|
came_from_stage_name: str = field(default_factory=str)
|
38
37
|
|
39
|
-
def to_dict(self):
|
40
|
-
return asdict(self)
|
41
|
-
|
42
|
-
@classmethod
|
43
|
-
def from_dict(cls, data: dict):
|
44
|
-
return cls(**data)
|
45
|
-
|
46
38
|
def __getitem__(self, key):
|
47
39
|
"""Allows dictionary-style getting."""
|
48
40
|
return getattr(self, key)
|
@@ -134,10 +126,6 @@ class StageBase(ABC):
|
|
134
126
|
else:
|
135
127
|
self.next_stage = None
|
136
128
|
|
137
|
-
@classmethod
|
138
|
-
def function_parameters(self):
|
139
|
-
return fields(self.input)
|
140
|
-
|
141
129
|
@classmethod
|
142
130
|
def func(cls, **kwargs):
|
143
131
|
"This provides a shortcut for running a stage by passing keyword arguments to the input function."
|
@@ -185,59 +173,58 @@ class StageBase(ABC):
|
|
185
173
|
|
186
174
|
|
187
175
|
if __name__ == "__main__":
|
188
|
-
|
189
|
-
# try:
|
176
|
+
try:
|
190
177
|
|
191
|
-
|
192
|
-
|
193
|
-
|
178
|
+
class StageMissing(StageBase):
|
179
|
+
def handle_data(self, data):
|
180
|
+
return data
|
194
181
|
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
182
|
+
except NotImplementedError as e:
|
183
|
+
print(e)
|
184
|
+
else:
|
185
|
+
raise Exception("Should have raised NotImplementedError")
|
199
186
|
|
200
|
-
|
187
|
+
try:
|
201
188
|
|
202
|
-
|
203
|
-
|
189
|
+
class StageMissingInput(StageBase):
|
190
|
+
output = FlowDataBase
|
204
191
|
|
205
|
-
|
206
|
-
|
192
|
+
except NotImplementedError as e:
|
193
|
+
print(e)
|
207
194
|
|
208
|
-
|
209
|
-
|
195
|
+
else:
|
196
|
+
raise Exception("Should have raised NotImplementedError")
|
210
197
|
|
211
|
-
|
212
|
-
|
213
|
-
|
198
|
+
@dataclass
|
199
|
+
class MockInputOutput(FlowDataBase):
|
200
|
+
text: str
|
214
201
|
|
215
|
-
|
216
|
-
|
217
|
-
|
202
|
+
class StageTest(StageBase):
|
203
|
+
input = MockInputOutput
|
204
|
+
output = MockInputOutput
|
218
205
|
|
219
|
-
|
220
|
-
|
206
|
+
def handle_data(self, data):
|
207
|
+
return self.output(text=data["text"] + "processed")
|
221
208
|
|
222
|
-
|
223
|
-
|
209
|
+
result = StageTest().process(MockInputOutput(text="Hello world!"))
|
210
|
+
print(result.text)
|
224
211
|
|
225
|
-
|
226
|
-
|
227
|
-
|
212
|
+
pipeline = StageTest(next_stage=StageTest(next_stage=StageTest()))
|
213
|
+
result = pipeline.process(MockInputOutput(text="Hello world!"))
|
214
|
+
print(result.text)
|
228
215
|
|
229
|
-
|
230
|
-
|
231
|
-
|
216
|
+
class BadMockInput(FlowDataBase):
|
217
|
+
text: str
|
218
|
+
other: str
|
232
219
|
|
233
|
-
|
234
|
-
|
235
|
-
|
220
|
+
class StageBad(StageBase):
|
221
|
+
input = BadMockInput
|
222
|
+
output = BadMockInput
|
236
223
|
|
237
|
-
|
238
|
-
|
224
|
+
def handle_data(self, data):
|
225
|
+
return self.output(text=data["text"] + "processed")
|
239
226
|
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
227
|
+
try:
|
228
|
+
pipeline = StageTest(next_stage=StageBad(next_stage=StageTest()))
|
229
|
+
except ExceptionPipesDoNotFit as e:
|
230
|
+
print(e)
|
edsl/auto/StageQuestions.py
CHANGED
edsl/auto/utilities.py
CHANGED
@@ -88,6 +88,12 @@ def agent_eligibility(
|
|
88
88
|
q_eligibility(model=model, questions=questions, persona=persona, cache=cache)
|
89
89
|
== "Yes"
|
90
90
|
)
|
91
|
+
# results = (
|
92
|
+
# q.by(model)
|
93
|
+
# .by(Scenario({"questions": questions, "persona": persona}))
|
94
|
+
# .run(cache=cache)
|
95
|
+
# )
|
96
|
+
# return results.select("eligibility").first() == "Yes"
|
91
97
|
|
92
98
|
|
93
99
|
def gen_agent_traits(dimension_dict: dict, seed_value: Optional[str] = None):
|
edsl/config.py
CHANGED
@@ -1,16 +1,12 @@
|
|
1
1
|
"""This module provides a Config class that loads environment variables from a .env file and sets them as class attributes."""
|
2
2
|
|
3
3
|
import os
|
4
|
-
import platformdirs
|
5
4
|
from dotenv import load_dotenv, find_dotenv
|
6
|
-
from edsl.exceptions
|
5
|
+
from edsl.exceptions import (
|
7
6
|
InvalidEnvironmentVariableError,
|
8
7
|
MissingEnvironmentVariableError,
|
9
8
|
)
|
10
9
|
|
11
|
-
cache_dir = platformdirs.user_cache_dir("edsl")
|
12
|
-
os.makedirs(cache_dir, exist_ok=True)
|
13
|
-
|
14
10
|
# valid values for EDSL_RUN_MODE
|
15
11
|
EDSL_RUN_MODES = [
|
16
12
|
"development",
|
@@ -38,8 +34,7 @@ CONFIG_MAP = {
|
|
38
34
|
"info": "This config var determines the maximum number of seconds to wait before retrying a failed API call.",
|
39
35
|
},
|
40
36
|
"EDSL_DATABASE_PATH": {
|
41
|
-
|
42
|
-
"default": f"sqlite:///{os.path.join(platformdirs.user_cache_dir('edsl'), 'lm_model_calls.db')}",
|
37
|
+
"default": f"sqlite:///{os.path.join(os.getcwd(), '.edsl_cache/data.db')}",
|
43
38
|
"info": "This config var determines the path to the cache file.",
|
44
39
|
},
|
45
40
|
"EDSL_DEFAULT_MODEL": {
|
@@ -74,10 +69,6 @@ CONFIG_MAP = {
|
|
74
69
|
"default": "False",
|
75
70
|
"info": "This config var determines whether to open the exception report URL in the browser",
|
76
71
|
},
|
77
|
-
"EDSL_REMOTE_TOKEN_BUCKET_URL": {
|
78
|
-
"default": "None",
|
79
|
-
"info": "This config var holds the URL of the remote token bucket server.",
|
80
|
-
},
|
81
72
|
}
|
82
73
|
|
83
74
|
|
@@ -90,9 +81,6 @@ class Config:
|
|
90
81
|
self._load_dotenv()
|
91
82
|
self._set_env_vars()
|
92
83
|
|
93
|
-
def show_path_to_dot_env(self):
|
94
|
-
print(find_dotenv(usecwd=True))
|
95
|
-
|
96
84
|
def _set_run_mode(self) -> None:
|
97
85
|
"""
|
98
86
|
Sets EDSL_RUN_MODE as a class attribute.
|
@@ -156,14 +144,6 @@ class Config:
|
|
156
144
|
raise MissingEnvironmentVariableError(f"{env_var} is not set. {info}")
|
157
145
|
return self.__dict__.get(env_var)
|
158
146
|
|
159
|
-
def __iter__(self):
|
160
|
-
"""Iterate over the environment variables."""
|
161
|
-
return iter(self.__dict__)
|
162
|
-
|
163
|
-
def items(self):
|
164
|
-
"""Iterate over the environment variables and their values."""
|
165
|
-
return self.__dict__.items()
|
166
|
-
|
167
147
|
def show(self) -> str:
|
168
148
|
"""Print the currently set environment vars."""
|
169
149
|
max_env_var_length = max(len(env_var) for env_var in self.__dict__)
|
edsl/conversation/car_buying.py
CHANGED
@@ -29,8 +29,7 @@ a3 = Agent(
|
|
29
29
|
c1 = Conversation(agent_list=AgentList([a1, a3, a2]), max_turns=5, verbose=True)
|
30
30
|
c2 = Conversation(agent_list=AgentList([a1, a2]), max_turns=5, verbose=True)
|
31
31
|
|
32
|
-
|
33
|
-
c = Cache()
|
32
|
+
c = Cache.load("car_talk.json.gz")
|
34
33
|
# breakpoint()
|
35
34
|
combo = ConversationList([c1, c2], cache=c)
|
36
35
|
combo.run()
|
edsl/coop/PriceFetcher.py
CHANGED
edsl/coop/coop.py
CHANGED
@@ -1,19 +1,11 @@
|
|
1
1
|
import aiohttp
|
2
2
|
import json
|
3
|
+
import os
|
3
4
|
import requests
|
4
|
-
|
5
|
-
from typing import Any, Optional, Union, Literal, TypedDict
|
5
|
+
from typing import Any, Optional, Union, Literal
|
6
6
|
from uuid import UUID
|
7
|
-
from collections import UserDict, defaultdict
|
8
|
-
|
9
7
|
import edsl
|
10
|
-
from
|
11
|
-
|
12
|
-
from edsl.config import CONFIG
|
13
|
-
from edsl.data.CacheEntry import CacheEntry
|
14
|
-
from edsl.jobs.Jobs import Jobs
|
15
|
-
from edsl.surveys.Survey import Survey
|
16
|
-
|
8
|
+
from edsl import CONFIG, CacheEntry, Jobs, Survey
|
17
9
|
from edsl.exceptions.coop import CoopNoUUIDError, CoopServerResponseError
|
18
10
|
from edsl.coop.utils import (
|
19
11
|
EDSLObject,
|
@@ -23,48 +15,19 @@ from edsl.coop.utils import (
|
|
23
15
|
VisibilityType,
|
24
16
|
)
|
25
17
|
|
26
|
-
from edsl.coop.CoopFunctionsMixin import CoopFunctionsMixin
|
27
|
-
from edsl.coop.ExpectedParrotKeyHandler import ExpectedParrotKeyHandler
|
28
|
-
|
29
|
-
from edsl.inference_services.data_structures import ServiceToModelsMapping
|
30
|
-
|
31
18
|
|
32
|
-
class
|
33
|
-
job_uuid: str
|
34
|
-
results_uuid: str
|
35
|
-
results_url: str
|
36
|
-
latest_error_report_uuid: str
|
37
|
-
latest_error_report_url: str
|
38
|
-
status: str
|
39
|
-
reason: str
|
40
|
-
credits_consumed: float
|
41
|
-
version: str
|
42
|
-
|
43
|
-
|
44
|
-
class RemoteInferenceCreationInfo(TypedDict):
|
45
|
-
uuid: str
|
46
|
-
description: str
|
47
|
-
status: str
|
48
|
-
iterations: int
|
49
|
-
visibility: str
|
50
|
-
version: str
|
51
|
-
|
52
|
-
|
53
|
-
class Coop(CoopFunctionsMixin):
|
19
|
+
class Coop:
|
54
20
|
"""
|
55
21
|
Client for the Expected Parrot API.
|
56
22
|
"""
|
57
23
|
|
58
|
-
def __init__(
|
59
|
-
self, api_key: Optional[str] = None, url: Optional[str] = None
|
60
|
-
) -> None:
|
24
|
+
def __init__(self, api_key: str = None, url: str = None) -> None:
|
61
25
|
"""
|
62
26
|
Initialize the client.
|
63
27
|
- Provide an API key directly, or through an env variable.
|
64
28
|
- Provide a URL directly, or use the default one.
|
65
29
|
"""
|
66
|
-
self.
|
67
|
-
self.api_key = api_key or self.ep_key_handler.get_ep_api_key()
|
30
|
+
self.api_key = api_key or os.getenv("EXPECTED_PARROT_API_KEY")
|
68
31
|
|
69
32
|
self.url = url or CONFIG.EXPECTED_PARROT_URL
|
70
33
|
if self.url.endswith("/"):
|
@@ -179,7 +142,6 @@ class Coop(CoopFunctionsMixin):
|
|
179
142
|
Check the response from the server and raise errors as appropriate.
|
180
143
|
"""
|
181
144
|
# Get EDSL version from header
|
182
|
-
# breakpoint()
|
183
145
|
server_edsl_version = response.headers.get("X-EDSL-Version")
|
184
146
|
|
185
147
|
if server_edsl_version:
|
@@ -188,18 +150,11 @@ class Coop(CoopFunctionsMixin):
|
|
188
150
|
server_version_str=server_edsl_version,
|
189
151
|
):
|
190
152
|
print(
|
191
|
-
"Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip
|
153
|
+
"Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip upgrade edsl`"
|
192
154
|
)
|
193
155
|
|
194
156
|
if response.status_code >= 400:
|
195
|
-
|
196
|
-
message = response.json().get("detail")
|
197
|
-
except json.JSONDecodeError:
|
198
|
-
raise CoopServerResponseError(
|
199
|
-
f"Server returned status code {response.status_code}."
|
200
|
-
"JSON response could not be decoded.",
|
201
|
-
"The server response was: " + response.text,
|
202
|
-
)
|
157
|
+
message = response.json().get("detail")
|
203
158
|
# print(response.text)
|
204
159
|
if "The API key you provided is invalid" in message and check_api_key:
|
205
160
|
import secrets
|
@@ -208,27 +163,19 @@ class Coop(CoopFunctionsMixin):
|
|
208
163
|
edsl_auth_token = secrets.token_urlsafe(16)
|
209
164
|
|
210
165
|
print("Your Expected Parrot API key is invalid.")
|
211
|
-
|
212
|
-
|
213
|
-
link_description="\n🔗 Use the link below to log in to Expected Parrot so we can automatically update your API key.",
|
166
|
+
print(
|
167
|
+
"\nUse the link below to log in to Expected Parrot so we can automatically update your API key."
|
214
168
|
)
|
169
|
+
self._display_login_url(edsl_auth_token=edsl_auth_token)
|
215
170
|
api_key = self._poll_for_api_key(edsl_auth_token)
|
216
171
|
|
217
172
|
if api_key is None:
|
218
173
|
print("\nTimed out waiting for login. Please try again.")
|
219
174
|
return
|
220
175
|
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
pass
|
225
|
-
else:
|
226
|
-
path_to_env = write_api_key_to_env(api_key)
|
227
|
-
print(
|
228
|
-
"\n✨ API key retrieved and written to .env file at the following path:"
|
229
|
-
)
|
230
|
-
print(f" {path_to_env}")
|
231
|
-
print("Rerun your code to try again with a valid API key.")
|
176
|
+
write_api_key_to_env(api_key)
|
177
|
+
print("\n✨ API key retrieved and written to .env file.")
|
178
|
+
print("Rerun your code to try again with a valid API key.")
|
232
179
|
return
|
233
180
|
|
234
181
|
elif "Authorization" in message:
|
@@ -321,7 +268,6 @@ class Coop(CoopFunctionsMixin):
|
|
321
268
|
self,
|
322
269
|
object: EDSLObject,
|
323
270
|
description: Optional[str] = None,
|
324
|
-
alias: Optional[str] = None,
|
325
271
|
visibility: Optional[VisibilityType] = "unlisted",
|
326
272
|
) -> dict:
|
327
273
|
"""
|
@@ -333,7 +279,6 @@ class Coop(CoopFunctionsMixin):
|
|
333
279
|
method="POST",
|
334
280
|
payload={
|
335
281
|
"description": description,
|
336
|
-
"alias": alias,
|
337
282
|
"json_string": json.dumps(
|
338
283
|
object.to_dict(),
|
339
284
|
default=self._json_handle_none,
|
@@ -428,7 +373,6 @@ class Coop(CoopFunctionsMixin):
|
|
428
373
|
uuid: Union[str, UUID] = None,
|
429
374
|
url: str = None,
|
430
375
|
description: Optional[str] = None,
|
431
|
-
alias: Optional[str] = None,
|
432
376
|
value: Optional[EDSLObject] = None,
|
433
377
|
visibility: Optional[VisibilityType] = None,
|
434
378
|
) -> dict:
|
@@ -445,7 +389,6 @@ class Coop(CoopFunctionsMixin):
|
|
445
389
|
params={"uuid": uuid},
|
446
390
|
payload={
|
447
391
|
"description": description,
|
448
|
-
"alias": alias,
|
449
392
|
"json_string": (
|
450
393
|
json.dumps(
|
451
394
|
value.to_dict(),
|
@@ -659,6 +602,9 @@ class Coop(CoopFunctionsMixin):
|
|
659
602
|
self._resolve_server_response(response)
|
660
603
|
return response.json()
|
661
604
|
|
605
|
+
################
|
606
|
+
# Remote Inference
|
607
|
+
################
|
662
608
|
def remote_inference_create(
|
663
609
|
self,
|
664
610
|
job: Jobs,
|
@@ -667,7 +613,7 @@ class Coop(CoopFunctionsMixin):
|
|
667
613
|
visibility: Optional[VisibilityType] = "unlisted",
|
668
614
|
initial_results_visibility: Optional[VisibilityType] = "unlisted",
|
669
615
|
iterations: Optional[int] = 1,
|
670
|
-
) ->
|
616
|
+
) -> dict:
|
671
617
|
"""
|
672
618
|
Send a remote inference job to the server.
|
673
619
|
|
@@ -699,21 +645,18 @@ class Coop(CoopFunctionsMixin):
|
|
699
645
|
)
|
700
646
|
self._resolve_server_response(response)
|
701
647
|
response_json = response.json()
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
"version": self._edsl_version,
|
711
|
-
}
|
712
|
-
)
|
648
|
+
return {
|
649
|
+
"uuid": response_json.get("job_uuid"),
|
650
|
+
"description": response_json.get("description"),
|
651
|
+
"status": response_json.get("status"),
|
652
|
+
"iterations": response_json.get("iterations"),
|
653
|
+
"visibility": response_json.get("visibility"),
|
654
|
+
"version": self._edsl_version,
|
655
|
+
}
|
713
656
|
|
714
657
|
def remote_inference_get(
|
715
658
|
self, job_uuid: Optional[str] = None, results_uuid: Optional[str] = None
|
716
|
-
) ->
|
659
|
+
) -> dict:
|
717
660
|
"""
|
718
661
|
Get the details of a remote inference job.
|
719
662
|
You can pass either the job uuid or the results uuid as a parameter.
|
@@ -755,30 +698,17 @@ class Coop(CoopFunctionsMixin):
|
|
755
698
|
f"{self.url}/home/remote-inference/error/{latest_error_report_uuid}"
|
756
699
|
)
|
757
700
|
|
758
|
-
return
|
759
|
-
|
760
|
-
|
761
|
-
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
}
|
770
|
-
)
|
771
|
-
|
772
|
-
def get_running_jobs(self) -> list[str]:
|
773
|
-
"""
|
774
|
-
Get a list of currently running job IDs.
|
775
|
-
|
776
|
-
Returns:
|
777
|
-
list[str]: List of running job UUIDs
|
778
|
-
"""
|
779
|
-
response = self._send_server_request(uri="jobs/status", method="GET")
|
780
|
-
self._resolve_server_response(response)
|
781
|
-
return response.json().get("running_jobs", [])
|
701
|
+
return {
|
702
|
+
"job_uuid": data.get("job_uuid"),
|
703
|
+
"results_uuid": results_uuid,
|
704
|
+
"results_url": results_url,
|
705
|
+
"latest_error_report_uuid": latest_error_report_uuid,
|
706
|
+
"latest_error_report_url": latest_error_report_url,
|
707
|
+
"status": data.get("status"),
|
708
|
+
"reason": data.get("reason"),
|
709
|
+
"credits_consumed": data.get("price"),
|
710
|
+
"version": data.get("version"),
|
711
|
+
}
|
782
712
|
|
783
713
|
def remote_inference_cost(
|
784
714
|
self, input: Union[Jobs, Survey], iterations: int = 1
|
@@ -880,7 +810,7 @@ class Coop(CoopFunctionsMixin):
|
|
880
810
|
"Invalid EDSL_FETCH_TOKEN_PRICES value---should be 'True' or 'False'."
|
881
811
|
)
|
882
812
|
|
883
|
-
def fetch_models(self) ->
|
813
|
+
def fetch_models(self) -> dict:
|
884
814
|
"""
|
885
815
|
Fetch a dict of available models from Coop.
|
886
816
|
|
@@ -889,7 +819,7 @@ class Coop(CoopFunctionsMixin):
|
|
889
819
|
response = self._send_server_request(uri="api/v0/models", method="GET")
|
890
820
|
self._resolve_server_response(response)
|
891
821
|
data = response.json()
|
892
|
-
return
|
822
|
+
return data
|
893
823
|
|
894
824
|
def fetch_rate_limit_config_vars(self) -> dict:
|
895
825
|
"""
|
@@ -905,9 +835,7 @@ class Coop(CoopFunctionsMixin):
|
|
905
835
|
data = response.json()
|
906
836
|
return data
|
907
837
|
|
908
|
-
def _display_login_url(
|
909
|
-
self, edsl_auth_token: str, link_description: Optional[str] = None
|
910
|
-
):
|
838
|
+
def _display_login_url(self, edsl_auth_token: str):
|
911
839
|
"""
|
912
840
|
Uses rich.print to display a login URL.
|
913
841
|
|
@@ -917,12 +845,7 @@ class Coop(CoopFunctionsMixin):
|
|
917
845
|
|
918
846
|
url = f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"
|
919
847
|
|
920
|
-
|
921
|
-
rich_print(
|
922
|
-
f"{link_description}\n [#38bdf8][link={url}]{url}[/link][/#38bdf8]"
|
923
|
-
)
|
924
|
-
else:
|
925
|
-
rich_print(f" [#38bdf8][link={url}]{url}[/link][/#38bdf8]")
|
848
|
+
rich_print(f"[#38bdf8][link={url}]{url}[/link][/#38bdf8]")
|
926
849
|
|
927
850
|
def _get_api_key(self, edsl_auth_token: str):
|
928
851
|
"""
|
@@ -950,18 +873,17 @@ class Coop(CoopFunctionsMixin):
|
|
950
873
|
|
951
874
|
edsl_auth_token = secrets.token_urlsafe(16)
|
952
875
|
|
953
|
-
|
954
|
-
|
955
|
-
link_description="\n🔗 Use the link below to log in to Expected Parrot so we can automatically update your API key.",
|
876
|
+
print(
|
877
|
+
"\nUse the link below to log in to Expected Parrot so we can automatically update your API key."
|
956
878
|
)
|
879
|
+
self._display_login_url(edsl_auth_token=edsl_auth_token)
|
957
880
|
api_key = self._poll_for_api_key(edsl_auth_token)
|
958
881
|
|
959
882
|
if api_key is None:
|
960
883
|
raise Exception("Timed out waiting for login. Please try again.")
|
961
884
|
|
962
|
-
|
963
|
-
print("\n✨ API key retrieved and written to .env file
|
964
|
-
print(f" {path_to_env}")
|
885
|
+
write_api_key_to_env(api_key)
|
886
|
+
print("\n✨ API key retrieved and written to .env file.")
|
965
887
|
|
966
888
|
# Add API key to environment
|
967
889
|
load_dotenv()
|
edsl/coop/utils.py
CHANGED
@@ -1,19 +1,19 @@
|
|
1
|
+
from edsl import (
|
2
|
+
Agent,
|
3
|
+
AgentList,
|
4
|
+
Cache,
|
5
|
+
ModelList,
|
6
|
+
Notebook,
|
7
|
+
Results,
|
8
|
+
Scenario,
|
9
|
+
ScenarioList,
|
10
|
+
Survey,
|
11
|
+
Study,
|
12
|
+
)
|
13
|
+
from edsl.language_models import LanguageModel
|
14
|
+
from edsl.questions import QuestionBase
|
1
15
|
from typing import Literal, Optional, Type, Union
|
2
16
|
|
3
|
-
from edsl.agents.Agent import Agent
|
4
|
-
from edsl.agents.AgentList import AgentList
|
5
|
-
from edsl.data.Cache import Cache
|
6
|
-
from edsl.language_models.ModelList import ModelList
|
7
|
-
from edsl.notebooks.Notebook import Notebook
|
8
|
-
from edsl.results.Results import Results
|
9
|
-
from edsl.scenarios.Scenario import Scenario
|
10
|
-
from edsl.scenarios.ScenarioList import ScenarioList
|
11
|
-
from edsl.surveys.Survey import Survey
|
12
|
-
from edsl.study.Study import Study
|
13
|
-
|
14
|
-
from edsl.language_models.LanguageModel import LanguageModel
|
15
|
-
from edsl.questions.QuestionBase import QuestionBase
|
16
|
-
|
17
17
|
EDSLObject = Union[
|
18
18
|
Agent,
|
19
19
|
AgentList,
|