edsl 0.1.39.dev1__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +169 -116
- edsl/__init__.py +14 -6
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +358 -146
- edsl/agents/AgentList.py +211 -73
- edsl/agents/Invigilator.py +88 -36
- edsl/agents/InvigilatorBase.py +59 -70
- edsl/agents/PromptConstructor.py +117 -219
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionOptionProcessor.py +172 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/__init__.py +0 -1
- edsl/agents/prompt_helpers.py +3 -3
- edsl/config.py +22 -2
- edsl/conversation/car_buying.py +2 -1
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +1 -1
- edsl/coop/coop.py +104 -42
- edsl/coop/utils.py +14 -14
- edsl/data/Cache.py +21 -14
- edsl/data/CacheEntry.py +12 -15
- edsl/data/CacheHandler.py +33 -12
- edsl/data/__init__.py +4 -3
- edsl/data_transfer_models.py +2 -1
- edsl/enums.py +20 -0
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +12 -0
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/questions.py +24 -6
- edsl/exceptions/scenarios.py +7 -0
- edsl/inference_services/AnthropicService.py +0 -3
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +209 -0
- edsl/inference_services/AwsBedrock.py +0 -2
- edsl/inference_services/AzureAI.py +0 -2
- edsl/inference_services/GoogleService.py +2 -11
- edsl/inference_services/InferenceServiceABC.py +18 -85
- edsl/inference_services/InferenceServicesCollection.py +105 -80
- edsl/inference_services/MistralAIService.py +0 -3
- edsl/inference_services/OpenAIService.py +1 -4
- edsl/inference_services/PerplexityService.py +0 -3
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +11 -8
- edsl/inference_services/data_structures.py +62 -0
- edsl/jobs/AnswerQuestionFunctionConstructor.py +188 -0
- edsl/jobs/Answers.py +1 -14
- edsl/jobs/FetchInvigilator.py +40 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +48 -0
- edsl/jobs/Jobs.py +102 -243
- edsl/jobs/JobsChecks.py +35 -10
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +5 -3
- edsl/jobs/JobsRemoteInferenceHandler.py +128 -80
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/buckets/BucketCollection.py +44 -3
- edsl/jobs/buckets/TokenBucket.py +53 -21
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +77 -380
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +4 -49
- edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
- edsl/jobs/tasks/TaskHistory.py +14 -15
- edsl/jobs/tasks/task_status_enum.py +0 -2
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +137 -234
- edsl/language_models/ModelList.py +11 -13
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/__init__.py +0 -1
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/registry.py +49 -59
- edsl/language_models/repair.py +2 -2
- edsl/language_models/utilities.py +5 -4
- edsl/notebooks/Notebook.py +19 -14
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/prompts/Prompt.py +29 -39
- edsl/questions/AnswerValidatorMixin.py +47 -2
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/LoopProcessor.py +149 -0
- edsl/questions/QuestionBase.py +37 -192
- edsl/questions/QuestionBaseGenMixin.py +52 -48
- edsl/questions/QuestionBasePromptsMixin.py +7 -3
- edsl/questions/QuestionCheckBox.py +1 -1
- edsl/questions/QuestionExtract.py +1 -1
- edsl/questions/QuestionFreeText.py +1 -2
- edsl/questions/QuestionList.py +3 -5
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +66 -22
- edsl/questions/QuestionNumerical.py +1 -3
- edsl/questions/QuestionRank.py +6 -16
- edsl/questions/ResponseValidatorABC.py +37 -11
- edsl/questions/ResponseValidatorFactory.py +28 -0
- edsl/questions/SimpleAskMixin.py +4 -3
- edsl/questions/__init__.py +1 -0
- edsl/questions/derived/QuestionLinearScale.py +6 -3
- edsl/questions/derived/QuestionTopK.py +1 -1
- edsl/questions/descriptors.py +17 -3
- edsl/questions/question_registry.py +1 -1
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/results/CSSParameterizer.py +1 -1
- edsl/results/Dataset.py +170 -7
- edsl/results/DatasetExportMixin.py +224 -302
- edsl/results/DatasetTree.py +28 -8
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +192 -206
- edsl/results/Results.py +120 -113
- edsl/results/ResultsExportMixin.py +2 -0
- edsl/results/Selector.py +23 -13
- edsl/results/TableDisplay.py +98 -171
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +1 -1
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_renderers.py +118 -0
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DirectoryScanner.py +96 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +118 -239
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +90 -193
- edsl/scenarios/ScenarioHtmlMixin.py +4 -3
- edsl/scenarios/ScenarioJoin.py +10 -6
- edsl/scenarios/ScenarioList.py +383 -240
- edsl/scenarios/ScenarioListExportMixin.py +0 -7
- edsl/scenarios/ScenarioListPdfMixin.py +15 -37
- edsl/scenarios/ScenarioSelector.py +156 -0
- edsl/scenarios/__init__.py +1 -2
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +38 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/study/ObjectEntry.py +1 -1
- edsl/study/SnapShot.py +1 -1
- edsl/study/Study.py +5 -12
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/Rule.py +5 -4
- edsl/surveys/RuleCollection.py +25 -27
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +199 -771
- edsl/surveys/SurveyCSS.py +20 -8
- edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +4 -2
- edsl/surveys/descriptors.py +6 -2
- edsl/surveys/instructions/ChangeInstruction.py +1 -2
- edsl/surveys/instructions/Instruction.py +4 -13
- edsl/surveys/instructions/InstructionCollection.py +11 -6
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/report.html +1 -1
- edsl/tools/plotting.py +1 -1
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/utilities.py +35 -23
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +12 -10
- edsl-0.1.39.dev2.dist-info/RECORD +352 -0
- edsl/language_models/KeyLookup.py +0 -30
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/results/ResultsDBMixin.py +0 -238
- edsl-0.1.39.dev1.dist-info/RECORD +0 -277
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +0 -0
edsl/coop/PriceFetcher.py
CHANGED
edsl/coop/coop.py
CHANGED
@@ -1,11 +1,19 @@
|
|
1
1
|
import aiohttp
|
2
2
|
import json
|
3
|
-
import os
|
4
3
|
import requests
|
5
|
-
|
4
|
+
|
5
|
+
from typing import Any, Optional, Union, Literal, TypedDict
|
6
6
|
from uuid import UUID
|
7
|
+
from collections import UserDict, defaultdict
|
8
|
+
|
7
9
|
import edsl
|
8
|
-
from
|
10
|
+
from pathlib import Path
|
11
|
+
|
12
|
+
from edsl.config import CONFIG
|
13
|
+
from edsl.data.CacheEntry import CacheEntry
|
14
|
+
from edsl.jobs.Jobs import Jobs
|
15
|
+
from edsl.surveys.Survey import Survey
|
16
|
+
|
9
17
|
from edsl.exceptions.coop import CoopNoUUIDError, CoopServerResponseError
|
10
18
|
from edsl.coop.utils import (
|
11
19
|
EDSLObject,
|
@@ -15,19 +23,48 @@ from edsl.coop.utils import (
|
|
15
23
|
VisibilityType,
|
16
24
|
)
|
17
25
|
|
26
|
+
from edsl.coop.CoopFunctionsMixin import CoopFunctionsMixin
|
27
|
+
from edsl.coop.ExpectedParrotKeyHandler import ExpectedParrotKeyHandler
|
28
|
+
|
29
|
+
from edsl.inference_services.data_structures import ServiceToModelsMapping
|
30
|
+
|
31
|
+
|
32
|
+
class RemoteInferenceResponse(TypedDict):
|
33
|
+
job_uuid: str
|
34
|
+
results_uuid: str
|
35
|
+
results_url: str
|
36
|
+
latest_error_report_uuid: str
|
37
|
+
latest_error_report_url: str
|
38
|
+
status: str
|
39
|
+
reason: str
|
40
|
+
credits_consumed: float
|
41
|
+
version: str
|
42
|
+
|
43
|
+
|
44
|
+
class RemoteInferenceCreationInfo(TypedDict):
|
45
|
+
uuid: str
|
46
|
+
description: str
|
47
|
+
status: str
|
48
|
+
iterations: int
|
49
|
+
visibility: str
|
50
|
+
version: str
|
18
51
|
|
19
|
-
|
52
|
+
|
53
|
+
class Coop(CoopFunctionsMixin):
|
20
54
|
"""
|
21
55
|
Client for the Expected Parrot API.
|
22
56
|
"""
|
23
57
|
|
24
|
-
def __init__(
|
58
|
+
def __init__(
|
59
|
+
self, api_key: Optional[str] = None, url: Optional[str] = None
|
60
|
+
) -> None:
|
25
61
|
"""
|
26
62
|
Initialize the client.
|
27
63
|
- Provide an API key directly, or through an env variable.
|
28
64
|
- Provide a URL directly, or use the default one.
|
29
65
|
"""
|
30
|
-
self.
|
66
|
+
self.ep_key_handler = ExpectedParrotKeyHandler()
|
67
|
+
self.api_key = api_key or self.ep_key_handler.get_ep_api_key()
|
31
68
|
|
32
69
|
self.url = url or CONFIG.EXPECTED_PARROT_URL
|
33
70
|
if self.url.endswith("/"):
|
@@ -163,19 +200,27 @@ class Coop:
|
|
163
200
|
edsl_auth_token = secrets.token_urlsafe(16)
|
164
201
|
|
165
202
|
print("Your Expected Parrot API key is invalid.")
|
166
|
-
|
167
|
-
|
203
|
+
self._display_login_url(
|
204
|
+
edsl_auth_token=edsl_auth_token,
|
205
|
+
link_description="\n🔗 Use the link below to log in to Expected Parrot so we can automatically update your API key.",
|
168
206
|
)
|
169
|
-
self._display_login_url(edsl_auth_token=edsl_auth_token)
|
170
207
|
api_key = self._poll_for_api_key(edsl_auth_token)
|
171
208
|
|
172
209
|
if api_key is None:
|
173
210
|
print("\nTimed out waiting for login. Please try again.")
|
174
211
|
return
|
175
212
|
|
176
|
-
|
177
|
-
|
178
|
-
|
213
|
+
print("\n✨ API key retrieved.")
|
214
|
+
|
215
|
+
if stored_in_user_space := self.ep_key_handler.ask_to_store(api_key):
|
216
|
+
pass
|
217
|
+
else:
|
218
|
+
path_to_env = write_api_key_to_env(api_key)
|
219
|
+
print(
|
220
|
+
"\n✨ API key retrieved and written to .env file at the following path:"
|
221
|
+
)
|
222
|
+
print(f" {path_to_env}")
|
223
|
+
print("Rerun your code to try again with a valid API key.")
|
179
224
|
return
|
180
225
|
|
181
226
|
elif "Authorization" in message:
|
@@ -268,6 +313,7 @@ class Coop:
|
|
268
313
|
self,
|
269
314
|
object: EDSLObject,
|
270
315
|
description: Optional[str] = None,
|
316
|
+
alias: Optional[str] = None,
|
271
317
|
visibility: Optional[VisibilityType] = "unlisted",
|
272
318
|
) -> dict:
|
273
319
|
"""
|
@@ -279,6 +325,7 @@ class Coop:
|
|
279
325
|
method="POST",
|
280
326
|
payload={
|
281
327
|
"description": description,
|
328
|
+
"alias": alias,
|
282
329
|
"json_string": json.dumps(
|
283
330
|
object.to_dict(),
|
284
331
|
default=self._json_handle_none,
|
@@ -373,6 +420,7 @@ class Coop:
|
|
373
420
|
uuid: Union[str, UUID] = None,
|
374
421
|
url: str = None,
|
375
422
|
description: Optional[str] = None,
|
423
|
+
alias: Optional[str] = None,
|
376
424
|
value: Optional[EDSLObject] = None,
|
377
425
|
visibility: Optional[VisibilityType] = None,
|
378
426
|
) -> dict:
|
@@ -389,6 +437,7 @@ class Coop:
|
|
389
437
|
params={"uuid": uuid},
|
390
438
|
payload={
|
391
439
|
"description": description,
|
440
|
+
"alias": alias,
|
392
441
|
"json_string": (
|
393
442
|
json.dumps(
|
394
443
|
value.to_dict(),
|
@@ -613,7 +662,7 @@ class Coop:
|
|
613
662
|
visibility: Optional[VisibilityType] = "unlisted",
|
614
663
|
initial_results_visibility: Optional[VisibilityType] = "unlisted",
|
615
664
|
iterations: Optional[int] = 1,
|
616
|
-
) ->
|
665
|
+
) -> RemoteInferenceCreationInfo:
|
617
666
|
"""
|
618
667
|
Send a remote inference job to the server.
|
619
668
|
|
@@ -645,18 +694,21 @@ class Coop:
|
|
645
694
|
)
|
646
695
|
self._resolve_server_response(response)
|
647
696
|
response_json = response.json()
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
697
|
+
|
698
|
+
return RemoteInferenceCreationInfo(
|
699
|
+
**{
|
700
|
+
"uuid": response_json.get("job_uuid"),
|
701
|
+
"description": response_json.get("description"),
|
702
|
+
"status": response_json.get("status"),
|
703
|
+
"iterations": response_json.get("iterations"),
|
704
|
+
"visibility": response_json.get("visibility"),
|
705
|
+
"version": self._edsl_version,
|
706
|
+
}
|
707
|
+
)
|
656
708
|
|
657
709
|
def remote_inference_get(
|
658
710
|
self, job_uuid: Optional[str] = None, results_uuid: Optional[str] = None
|
659
|
-
) ->
|
711
|
+
) -> RemoteInferenceResponse:
|
660
712
|
"""
|
661
713
|
Get the details of a remote inference job.
|
662
714
|
You can pass either the job uuid or the results uuid as a parameter.
|
@@ -698,17 +750,19 @@ class Coop:
|
|
698
750
|
f"{self.url}/home/remote-inference/error/{latest_error_report_uuid}"
|
699
751
|
)
|
700
752
|
|
701
|
-
return
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
753
|
+
return RemoteInferenceResponse(
|
754
|
+
**{
|
755
|
+
"job_uuid": data.get("job_uuid"),
|
756
|
+
"results_uuid": results_uuid,
|
757
|
+
"results_url": results_url,
|
758
|
+
"latest_error_report_uuid": latest_error_report_uuid,
|
759
|
+
"latest_error_report_url": latest_error_report_url,
|
760
|
+
"status": data.get("status"),
|
761
|
+
"reason": data.get("reason"),
|
762
|
+
"credits_consumed": data.get("price"),
|
763
|
+
"version": data.get("version"),
|
764
|
+
}
|
765
|
+
)
|
712
766
|
|
713
767
|
def remote_inference_cost(
|
714
768
|
self, input: Union[Jobs, Survey], iterations: int = 1
|
@@ -810,7 +864,7 @@ class Coop:
|
|
810
864
|
"Invalid EDSL_FETCH_TOKEN_PRICES value---should be 'True' or 'False'."
|
811
865
|
)
|
812
866
|
|
813
|
-
def fetch_models(self) ->
|
867
|
+
def fetch_models(self) -> ServiceToModelsMapping:
|
814
868
|
"""
|
815
869
|
Fetch a dict of available models from Coop.
|
816
870
|
|
@@ -819,7 +873,7 @@ class Coop:
|
|
819
873
|
response = self._send_server_request(uri="api/v0/models", method="GET")
|
820
874
|
self._resolve_server_response(response)
|
821
875
|
data = response.json()
|
822
|
-
return data
|
876
|
+
return ServiceToModelsMapping(data)
|
823
877
|
|
824
878
|
def fetch_rate_limit_config_vars(self) -> dict:
|
825
879
|
"""
|
@@ -835,7 +889,9 @@ class Coop:
|
|
835
889
|
data = response.json()
|
836
890
|
return data
|
837
891
|
|
838
|
-
def _display_login_url(
|
892
|
+
def _display_login_url(
|
893
|
+
self, edsl_auth_token: str, link_description: Optional[str] = None
|
894
|
+
):
|
839
895
|
"""
|
840
896
|
Uses rich.print to display a login URL.
|
841
897
|
|
@@ -845,7 +901,12 @@ class Coop:
|
|
845
901
|
|
846
902
|
url = f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"
|
847
903
|
|
848
|
-
|
904
|
+
if link_description:
|
905
|
+
rich_print(
|
906
|
+
f"{link_description}\n [#38bdf8][link={url}]{url}[/link][/#38bdf8]"
|
907
|
+
)
|
908
|
+
else:
|
909
|
+
rich_print(f" [#38bdf8][link={url}]{url}[/link][/#38bdf8]")
|
849
910
|
|
850
911
|
def _get_api_key(self, edsl_auth_token: str):
|
851
912
|
"""
|
@@ -873,17 +934,18 @@ class Coop:
|
|
873
934
|
|
874
935
|
edsl_auth_token = secrets.token_urlsafe(16)
|
875
936
|
|
876
|
-
|
877
|
-
|
937
|
+
self._display_login_url(
|
938
|
+
edsl_auth_token=edsl_auth_token,
|
939
|
+
link_description="\n🔗 Use the link below to log in to Expected Parrot so we can automatically update your API key.",
|
878
940
|
)
|
879
|
-
self._display_login_url(edsl_auth_token=edsl_auth_token)
|
880
941
|
api_key = self._poll_for_api_key(edsl_auth_token)
|
881
942
|
|
882
943
|
if api_key is None:
|
883
944
|
raise Exception("Timed out waiting for login. Please try again.")
|
884
945
|
|
885
|
-
write_api_key_to_env(api_key)
|
886
|
-
print("\n✨ API key retrieved and written to .env file
|
946
|
+
path_to_env = write_api_key_to_env(api_key)
|
947
|
+
print("\n✨ API key retrieved and written to .env file at the following path:")
|
948
|
+
print(f" {path_to_env}")
|
887
949
|
|
888
950
|
# Add API key to environment
|
889
951
|
load_dotenv()
|
edsl/coop/utils.py
CHANGED
@@ -1,19 +1,19 @@
|
|
1
|
-
from edsl import (
|
2
|
-
Agent,
|
3
|
-
AgentList,
|
4
|
-
Cache,
|
5
|
-
ModelList,
|
6
|
-
Notebook,
|
7
|
-
Results,
|
8
|
-
Scenario,
|
9
|
-
ScenarioList,
|
10
|
-
Survey,
|
11
|
-
Study,
|
12
|
-
)
|
13
|
-
from edsl.language_models import LanguageModel
|
14
|
-
from edsl.questions import QuestionBase
|
15
1
|
from typing import Literal, Optional, Type, Union
|
16
2
|
|
3
|
+
from edsl.agents.Agent import Agent
|
4
|
+
from edsl.agents.AgentList import AgentList
|
5
|
+
from edsl.data.Cache import Cache
|
6
|
+
from edsl.language_models.ModelList import ModelList
|
7
|
+
from edsl.notebooks.Notebook import Notebook
|
8
|
+
from edsl.results.Results import Results
|
9
|
+
from edsl.scenarios.Scenario import Scenario
|
10
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
11
|
+
from edsl.surveys.Survey import Survey
|
12
|
+
from edsl.study.Study import Study
|
13
|
+
|
14
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
15
|
+
from edsl.questions.QuestionBase import QuestionBase
|
16
|
+
|
17
17
|
EDSLObject = Union[
|
18
18
|
Agent,
|
19
19
|
AgentList,
|
edsl/data/Cache.py
CHANGED
@@ -6,12 +6,12 @@ from __future__ import annotations
|
|
6
6
|
import json
|
7
7
|
import os
|
8
8
|
import warnings
|
9
|
-
import copy
|
10
9
|
from typing import Optional, Union
|
11
10
|
from edsl.Base import Base
|
12
|
-
|
13
|
-
|
14
|
-
from edsl.utilities.decorators import remove_edsl_version
|
11
|
+
|
12
|
+
|
13
|
+
# from edsl.utilities.decorators import remove_edsl_version
|
14
|
+
from edsl.utilities.remove_edsl_version import remove_edsl_version
|
15
15
|
from edsl.exceptions.cache import CacheError
|
16
16
|
|
17
17
|
|
@@ -83,9 +83,9 @@ class Cache(Base):
|
|
83
83
|
|
84
84
|
self._perform_checks()
|
85
85
|
|
86
|
-
def rich_print(sefl):
|
87
|
-
|
88
|
-
|
86
|
+
# def rich_print(sefl):
|
87
|
+
# pass
|
88
|
+
# # raise NotImplementedError("This method is not implemented yet.")
|
89
89
|
|
90
90
|
def code(sefl):
|
91
91
|
pass
|
@@ -201,6 +201,7 @@ class Cache(Base):
|
|
201
201
|
>>> len(c)
|
202
202
|
1
|
203
203
|
"""
|
204
|
+
from edsl.data.CacheEntry import CacheEntry
|
204
205
|
|
205
206
|
entry = CacheEntry(
|
206
207
|
model=model,
|
@@ -226,6 +227,7 @@ class Cache(Base):
|
|
226
227
|
|
227
228
|
:param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
|
228
229
|
"""
|
230
|
+
from edsl.data.CacheEntry import CacheEntry
|
229
231
|
|
230
232
|
for key, value in new_data.items():
|
231
233
|
if key in self.data:
|
@@ -246,6 +248,8 @@ class Cache(Base):
|
|
246
248
|
|
247
249
|
:param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
|
248
250
|
"""
|
251
|
+
from edsl.data.CacheEntry import CacheEntry
|
252
|
+
|
249
253
|
with open(filename, "a+") as f:
|
250
254
|
f.seek(0)
|
251
255
|
lines = f.readlines()
|
@@ -353,7 +357,8 @@ class Cache(Base):
|
|
353
357
|
f.write(json.dumps({key: value.to_dict()}) + "\n")
|
354
358
|
|
355
359
|
def to_scenario_list(self):
|
356
|
-
from edsl import ScenarioList
|
360
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
361
|
+
from edsl.scenarios.Scenario import Scenario
|
357
362
|
|
358
363
|
scenarios = []
|
359
364
|
for key, value in self.data.items():
|
@@ -399,6 +404,8 @@ class Cache(Base):
|
|
399
404
|
####################
|
400
405
|
def __hash__(self):
|
401
406
|
"""Return the hash of the Cache."""
|
407
|
+
from edsl.utilities.utilities import dict_hash
|
408
|
+
|
402
409
|
return dict_hash(self.to_dict(add_edsl_version=False))
|
403
410
|
|
404
411
|
def to_dict(self, add_edsl_version=True) -> dict:
|
@@ -414,12 +421,6 @@ class Cache(Base):
|
|
414
421
|
def _summary(self):
|
415
422
|
return {"EDSL Class": "Cache", "Number of entries": len(self.data)}
|
416
423
|
|
417
|
-
def _repr_html_(self):
|
418
|
-
# from edsl.utilities.utilities import data_to_html
|
419
|
-
# return data_to_html(self.to_dict())
|
420
|
-
footer = f"<a href={self.__documentation__}>(docs)</a>"
|
421
|
-
return str(self.summary(format="html")) + footer
|
422
|
-
|
423
424
|
def table(
|
424
425
|
self,
|
425
426
|
*fields,
|
@@ -443,6 +444,8 @@ class Cache(Base):
|
|
443
444
|
@remove_edsl_version
|
444
445
|
def from_dict(cls, data) -> Cache:
|
445
446
|
"""Construct a Cache from a dictionary."""
|
447
|
+
from edsl.data.CacheEntry import CacheEntry
|
448
|
+
|
446
449
|
newdata = {k: CacheEntry.from_dict(v) for k, v in data.items()}
|
447
450
|
return cls(data=newdata)
|
448
451
|
|
@@ -485,6 +488,8 @@ class Cache(Base):
|
|
485
488
|
"""
|
486
489
|
Create an example input for a 'fetch' operation.
|
487
490
|
"""
|
491
|
+
from edsl.data.CacheEntry import CacheEntry
|
492
|
+
|
488
493
|
return CacheEntry.fetch_input_example()
|
489
494
|
|
490
495
|
def to_html(self):
|
@@ -541,6 +546,8 @@ class Cache(Base):
|
|
541
546
|
|
542
547
|
:param randomize: If True, uses CacheEntry's randomize method.
|
543
548
|
"""
|
549
|
+
from edsl.data.CacheEntry import CacheEntry
|
550
|
+
|
544
551
|
return cls(
|
545
552
|
data={
|
546
553
|
CacheEntry.example(randomize).key: CacheEntry.example(),
|
edsl/data/CacheEntry.py
CHANGED
@@ -5,8 +5,12 @@ import hashlib
|
|
5
5
|
from typing import Optional
|
6
6
|
from uuid import uuid4
|
7
7
|
|
8
|
+
from edsl.utilities.decorators import remove_edsl_version
|
8
9
|
|
9
|
-
|
10
|
+
from edsl.Base import RepresentationMixin
|
11
|
+
|
12
|
+
|
13
|
+
class CacheEntry(RepresentationMixin):
|
10
14
|
"""
|
11
15
|
A Class to represent a cache entry.
|
12
16
|
"""
|
@@ -78,11 +82,11 @@ class CacheEntry:
|
|
78
82
|
d = {k: value for k, value in self.__dict__.items() if k in self.key_fields}
|
79
83
|
return self.gen_key(**d)
|
80
84
|
|
81
|
-
def to_dict(self) -> dict:
|
85
|
+
def to_dict(self, add_edsl_version=True) -> dict:
|
82
86
|
"""
|
83
87
|
Returns a dictionary representation of a CacheEntry.
|
84
88
|
"""
|
85
|
-
|
89
|
+
d = {
|
86
90
|
"model": self.model,
|
87
91
|
"parameters": self.parameters,
|
88
92
|
"system_prompt": self.system_prompt,
|
@@ -91,19 +95,12 @@ class CacheEntry:
|
|
91
95
|
"iteration": self.iteration,
|
92
96
|
"timestamp": self.timestamp,
|
93
97
|
}
|
98
|
+
# if add_edsl_version:
|
99
|
+
# from edsl import __version__
|
94
100
|
|
95
|
-
|
96
|
-
""
|
97
|
-
|
98
|
-
"""
|
99
|
-
# from edsl.utilities.utilities import data_to_html
|
100
|
-
# return data_to_html(self.to_dict())
|
101
|
-
d = self.to_dict()
|
102
|
-
data = [[k, v] for k, v in d.items()]
|
103
|
-
from tabulate import tabulate
|
104
|
-
|
105
|
-
table = str(tabulate(data, headers=["keys", "values"], tablefmt="html"))
|
106
|
-
return f"<pre>{table}</pre>"
|
101
|
+
# d["edsl_version"] = __version__
|
102
|
+
# d["edsl_class_name"] = self.__class__.__name__
|
103
|
+
return d
|
107
104
|
|
108
105
|
def keys(self):
|
109
106
|
return list(self.to_dict().keys())
|
edsl/data/CacheHandler.py
CHANGED
@@ -3,19 +3,19 @@ import ast
|
|
3
3
|
import json
|
4
4
|
import os
|
5
5
|
import shutil
|
6
|
-
import
|
7
|
-
from edsl.config import CONFIG
|
8
|
-
from edsl.data.Cache import Cache
|
9
|
-
from edsl.data.CacheEntry import CacheEntry
|
10
|
-
from edsl.data.SQLiteDict import SQLiteDict
|
6
|
+
from typing import TYPE_CHECKING
|
11
7
|
|
12
|
-
|
8
|
+
if TYPE_CHECKING:
|
9
|
+
from edsl.data.Cache import Cache
|
10
|
+
from edsl.data.CacheEntry import CacheEntry
|
13
11
|
|
14
12
|
|
15
|
-
def set_session_cache(cache: Cache) -> None:
|
13
|
+
def set_session_cache(cache: "Cache") -> None:
|
16
14
|
"""
|
17
15
|
Set the session cache.
|
18
16
|
"""
|
17
|
+
from edsl.config import CONFIG
|
18
|
+
|
19
19
|
CONFIG.EDSL_SESSION_CACHE = cache
|
20
20
|
|
21
21
|
|
@@ -23,6 +23,8 @@ def unset_session_cache() -> None:
|
|
23
23
|
"""
|
24
24
|
Unset the session cache.
|
25
25
|
"""
|
26
|
+
from edsl.config import CONFIG
|
27
|
+
|
26
28
|
if hasattr(CONFIG, "EDSL_SESSION_CACHE"):
|
27
29
|
del CONFIG.EDSL_SESSION_CACHE
|
28
30
|
|
@@ -32,7 +34,11 @@ class CacheHandler:
|
|
32
34
|
This CacheHandler figures out what caches are available and does migrations, as needed.
|
33
35
|
"""
|
34
36
|
|
35
|
-
|
37
|
+
@property
|
38
|
+
def CACHE_PATH(self):
|
39
|
+
from edsl.config import CONFIG
|
40
|
+
|
41
|
+
return CONFIG.get("EDSL_DATABASE_PATH")
|
36
42
|
|
37
43
|
def __init__(self, test: bool = False):
|
38
44
|
self.test = test
|
@@ -52,16 +58,24 @@ class CacheHandler:
|
|
52
58
|
if notify:
|
53
59
|
print(f"Created cache directory: {dir_path}")
|
54
60
|
|
55
|
-
def gen_cache(self) -> Cache:
|
61
|
+
def gen_cache(self) -> "Cache":
|
56
62
|
"""
|
57
63
|
Generate a Cache object.
|
58
64
|
"""
|
65
|
+
from edsl.data.Cache import Cache
|
66
|
+
|
59
67
|
if self.test:
|
60
68
|
return Cache(data={})
|
61
69
|
|
70
|
+
# if self.CACHE_PATH is not None:
|
71
|
+
# return self.CACHE_PATH
|
72
|
+
from edsl.config import CONFIG
|
73
|
+
|
62
74
|
if hasattr(CONFIG, "EDSL_SESSION_CACHE"):
|
63
75
|
return CONFIG.EDSL_SESSION_CACHE
|
64
76
|
|
77
|
+
from edsl.data.SQLiteDict import SQLiteDict
|
78
|
+
|
65
79
|
cache = Cache(data=SQLiteDict(self.CACHE_PATH))
|
66
80
|
return cache
|
67
81
|
|
@@ -76,6 +90,8 @@ class CacheHandler:
|
|
76
90
|
if not os.path.exists(os.path.join(os.getcwd(), path)):
|
77
91
|
return old_data
|
78
92
|
try:
|
93
|
+
import sqlite3
|
94
|
+
|
79
95
|
conn = sqlite3.connect(path)
|
80
96
|
with conn:
|
81
97
|
cur = conn.cursor()
|
@@ -108,6 +124,8 @@ class CacheHandler:
|
|
108
124
|
entry_dict["user_prompt"] = entry_dict.pop("prompt")
|
109
125
|
parameters = entry_dict["parameters"]
|
110
126
|
entry_dict["parameters"] = ast.literal_eval(parameters)
|
127
|
+
from edsl.data.CacheEntry import CacheEntry
|
128
|
+
|
111
129
|
entry = CacheEntry(**entry_dict)
|
112
130
|
return entry
|
113
131
|
|
@@ -117,7 +135,7 @@ class CacheHandler:
|
|
117
135
|
###############
|
118
136
|
# NOT IN USE
|
119
137
|
###############
|
120
|
-
def from_sqlite(uri="new_edsl_cache.db") -> dict[str, CacheEntry]:
|
138
|
+
def from_sqlite(uri="new_edsl_cache.db") -> dict[str, "CacheEntry"]:
|
121
139
|
"""
|
122
140
|
Read in a new-style sqlite cache and return a dictionary of dictionaries.
|
123
141
|
"""
|
@@ -131,7 +149,7 @@ class CacheHandler:
|
|
131
149
|
newdata[entry.key] = entry
|
132
150
|
return newdata
|
133
151
|
|
134
|
-
def from_jsonl(filename="edsl_cache.jsonl") -> dict[str, CacheEntry]:
|
152
|
+
def from_jsonl(filename="edsl_cache.jsonl") -> dict[str, "CacheEntry"]:
|
135
153
|
"""Read in a jsonl file and return a dictionary of CacheEntry objects."""
|
136
154
|
with open(filename, "a+") as f:
|
137
155
|
f.seek(0)
|
@@ -146,4 +164,7 @@ class CacheHandler:
|
|
146
164
|
|
147
165
|
|
148
166
|
if __name__ == "__main__":
|
149
|
-
ch = CacheHandler()
|
167
|
+
# ch = CacheHandler()
|
168
|
+
import doctest
|
169
|
+
|
170
|
+
doctest.testmod()
|
edsl/data/__init__.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
|
-
from edsl.data.CacheEntry import CacheEntry
|
2
|
-
from edsl.data.SQLiteDict import SQLiteDict
|
1
|
+
# from edsl.data.CacheEntry import CacheEntry
|
2
|
+
# from edsl.data.SQLiteDict import SQLiteDict
|
3
3
|
from edsl.data.Cache import Cache
|
4
|
-
|
4
|
+
|
5
|
+
# from edsl.data.CacheHandler import CacheHandler
|
edsl/data_transfer_models.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1
1
|
from typing import NamedTuple, Dict, List, Optional, Any
|
2
2
|
from dataclasses import dataclass, fields
|
3
|
-
import reprlib
|
4
3
|
|
5
4
|
|
6
5
|
class ModelInputs(NamedTuple):
|
@@ -56,6 +55,8 @@ class ImageInfo:
|
|
56
55
|
encoded_image: str
|
57
56
|
|
58
57
|
def __repr__(self):
|
58
|
+
import reprlib
|
59
|
+
|
59
60
|
reprlib_instance = reprlib.Repr()
|
60
61
|
reprlib_instance.maxstring = 30 # Limit the string length for the encoded image
|
61
62
|
|
edsl/enums.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
"""Enums for the different types of questions, language models, and inference services."""
|
2
2
|
|
3
3
|
from enum import Enum
|
4
|
+
from typing import Literal
|
4
5
|
|
5
6
|
|
6
7
|
class EnumWithChecks(Enum):
|
@@ -67,6 +68,25 @@ class InferenceServiceType(EnumWithChecks):
|
|
67
68
|
PERPLEXITY = "perplexity"
|
68
69
|
|
69
70
|
|
71
|
+
# unavoidable violation of the DRY principle but it is necessary
|
72
|
+
# checked w/ a unit test to make sure consistent with services in enums.py
|
73
|
+
InferenceServiceLiteral = Literal[
|
74
|
+
"bedrock",
|
75
|
+
"deep_infra",
|
76
|
+
"replicate",
|
77
|
+
"openai",
|
78
|
+
"google",
|
79
|
+
"test",
|
80
|
+
"anthropic",
|
81
|
+
"groq",
|
82
|
+
"azure",
|
83
|
+
"ollama",
|
84
|
+
"mistral",
|
85
|
+
"together",
|
86
|
+
"perplexity",
|
87
|
+
]
|
88
|
+
|
89
|
+
|
70
90
|
service_to_api_keyname = {
|
71
91
|
InferenceServiceType.BEDROCK.value: "TBD",
|
72
92
|
InferenceServiceType.DEEP_INFRA.value: "DEEP_INFRA_API_KEY",
|