edsl 0.1.39.dev1__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194) hide show
  1. edsl/Base.py +169 -116
  2. edsl/__init__.py +14 -6
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +358 -146
  5. edsl/agents/AgentList.py +211 -73
  6. edsl/agents/Invigilator.py +88 -36
  7. edsl/agents/InvigilatorBase.py +59 -70
  8. edsl/agents/PromptConstructor.py +117 -219
  9. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  10. edsl/agents/QuestionOptionProcessor.py +172 -0
  11. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  12. edsl/agents/__init__.py +0 -1
  13. edsl/agents/prompt_helpers.py +3 -3
  14. edsl/config.py +22 -2
  15. edsl/conversation/car_buying.py +2 -1
  16. edsl/coop/CoopFunctionsMixin.py +15 -0
  17. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  18. edsl/coop/PriceFetcher.py +1 -1
  19. edsl/coop/coop.py +104 -42
  20. edsl/coop/utils.py +14 -14
  21. edsl/data/Cache.py +21 -14
  22. edsl/data/CacheEntry.py +12 -15
  23. edsl/data/CacheHandler.py +33 -12
  24. edsl/data/__init__.py +4 -3
  25. edsl/data_transfer_models.py +2 -1
  26. edsl/enums.py +20 -0
  27. edsl/exceptions/__init__.py +50 -50
  28. edsl/exceptions/agents.py +12 -0
  29. edsl/exceptions/inference_services.py +5 -0
  30. edsl/exceptions/questions.py +24 -6
  31. edsl/exceptions/scenarios.py +7 -0
  32. edsl/inference_services/AnthropicService.py +0 -3
  33. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  34. edsl/inference_services/AvailableModelFetcher.py +209 -0
  35. edsl/inference_services/AwsBedrock.py +0 -2
  36. edsl/inference_services/AzureAI.py +0 -2
  37. edsl/inference_services/GoogleService.py +2 -11
  38. edsl/inference_services/InferenceServiceABC.py +18 -85
  39. edsl/inference_services/InferenceServicesCollection.py +105 -80
  40. edsl/inference_services/MistralAIService.py +0 -3
  41. edsl/inference_services/OpenAIService.py +1 -4
  42. edsl/inference_services/PerplexityService.py +0 -3
  43. edsl/inference_services/ServiceAvailability.py +135 -0
  44. edsl/inference_services/TestService.py +11 -8
  45. edsl/inference_services/data_structures.py +62 -0
  46. edsl/jobs/AnswerQuestionFunctionConstructor.py +188 -0
  47. edsl/jobs/Answers.py +1 -14
  48. edsl/jobs/FetchInvigilator.py +40 -0
  49. edsl/jobs/InterviewTaskManager.py +98 -0
  50. edsl/jobs/InterviewsConstructor.py +48 -0
  51. edsl/jobs/Jobs.py +102 -243
  52. edsl/jobs/JobsChecks.py +35 -10
  53. edsl/jobs/JobsComponentConstructor.py +189 -0
  54. edsl/jobs/JobsPrompts.py +5 -3
  55. edsl/jobs/JobsRemoteInferenceHandler.py +128 -80
  56. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  57. edsl/jobs/RequestTokenEstimator.py +30 -0
  58. edsl/jobs/buckets/BucketCollection.py +44 -3
  59. edsl/jobs/buckets/TokenBucket.py +53 -21
  60. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  61. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  62. edsl/jobs/decorators.py +35 -0
  63. edsl/jobs/interviews/Interview.py +77 -380
  64. edsl/jobs/jobs_status_enums.py +9 -0
  65. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  66. edsl/jobs/runners/JobsRunnerAsyncio.py +4 -49
  67. edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
  68. edsl/jobs/tasks/TaskHistory.py +14 -15
  69. edsl/jobs/tasks/task_status_enum.py +0 -2
  70. edsl/language_models/ComputeCost.py +63 -0
  71. edsl/language_models/LanguageModel.py +137 -234
  72. edsl/language_models/ModelList.py +11 -13
  73. edsl/language_models/PriceManager.py +127 -0
  74. edsl/language_models/RawResponseHandler.py +106 -0
  75. edsl/language_models/ServiceDataSources.py +0 -0
  76. edsl/language_models/__init__.py +0 -1
  77. edsl/language_models/key_management/KeyLookup.py +63 -0
  78. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  79. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  80. edsl/language_models/key_management/__init__.py +0 -0
  81. edsl/language_models/key_management/models.py +131 -0
  82. edsl/language_models/registry.py +49 -59
  83. edsl/language_models/repair.py +2 -2
  84. edsl/language_models/utilities.py +5 -4
  85. edsl/notebooks/Notebook.py +19 -14
  86. edsl/notebooks/NotebookToLaTeX.py +142 -0
  87. edsl/prompts/Prompt.py +29 -39
  88. edsl/questions/AnswerValidatorMixin.py +47 -2
  89. edsl/questions/ExceptionExplainer.py +77 -0
  90. edsl/questions/HTMLQuestion.py +103 -0
  91. edsl/questions/LoopProcessor.py +149 -0
  92. edsl/questions/QuestionBase.py +37 -192
  93. edsl/questions/QuestionBaseGenMixin.py +52 -48
  94. edsl/questions/QuestionBasePromptsMixin.py +7 -3
  95. edsl/questions/QuestionCheckBox.py +1 -1
  96. edsl/questions/QuestionExtract.py +1 -1
  97. edsl/questions/QuestionFreeText.py +1 -2
  98. edsl/questions/QuestionList.py +3 -5
  99. edsl/questions/QuestionMatrix.py +265 -0
  100. edsl/questions/QuestionMultipleChoice.py +66 -22
  101. edsl/questions/QuestionNumerical.py +1 -3
  102. edsl/questions/QuestionRank.py +6 -16
  103. edsl/questions/ResponseValidatorABC.py +37 -11
  104. edsl/questions/ResponseValidatorFactory.py +28 -0
  105. edsl/questions/SimpleAskMixin.py +4 -3
  106. edsl/questions/__init__.py +1 -0
  107. edsl/questions/derived/QuestionLinearScale.py +6 -3
  108. edsl/questions/derived/QuestionTopK.py +1 -1
  109. edsl/questions/descriptors.py +17 -3
  110. edsl/questions/question_registry.py +1 -1
  111. edsl/questions/templates/matrix/__init__.py +1 -0
  112. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  113. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  114. edsl/results/CSSParameterizer.py +1 -1
  115. edsl/results/Dataset.py +170 -7
  116. edsl/results/DatasetExportMixin.py +224 -302
  117. edsl/results/DatasetTree.py +28 -8
  118. edsl/results/MarkdownToDocx.py +122 -0
  119. edsl/results/MarkdownToPDF.py +111 -0
  120. edsl/results/Result.py +192 -206
  121. edsl/results/Results.py +120 -113
  122. edsl/results/ResultsExportMixin.py +2 -0
  123. edsl/results/Selector.py +23 -13
  124. edsl/results/TableDisplay.py +98 -171
  125. edsl/results/TextEditor.py +50 -0
  126. edsl/results/__init__.py +1 -1
  127. edsl/results/smart_objects.py +96 -0
  128. edsl/results/table_data_class.py +12 -0
  129. edsl/results/table_renderers.py +118 -0
  130. edsl/scenarios/ConstructDownloadLink.py +109 -0
  131. edsl/scenarios/DirectoryScanner.py +96 -0
  132. edsl/scenarios/DocumentChunker.py +102 -0
  133. edsl/scenarios/DocxScenario.py +16 -0
  134. edsl/scenarios/FileStore.py +118 -239
  135. edsl/scenarios/PdfExtractor.py +40 -0
  136. edsl/scenarios/Scenario.py +90 -193
  137. edsl/scenarios/ScenarioHtmlMixin.py +4 -3
  138. edsl/scenarios/ScenarioJoin.py +10 -6
  139. edsl/scenarios/ScenarioList.py +383 -240
  140. edsl/scenarios/ScenarioListExportMixin.py +0 -7
  141. edsl/scenarios/ScenarioListPdfMixin.py +15 -37
  142. edsl/scenarios/ScenarioSelector.py +156 -0
  143. edsl/scenarios/__init__.py +1 -2
  144. edsl/scenarios/file_methods.py +85 -0
  145. edsl/scenarios/handlers/__init__.py +13 -0
  146. edsl/scenarios/handlers/csv.py +38 -0
  147. edsl/scenarios/handlers/docx.py +76 -0
  148. edsl/scenarios/handlers/html.py +37 -0
  149. edsl/scenarios/handlers/json.py +111 -0
  150. edsl/scenarios/handlers/latex.py +5 -0
  151. edsl/scenarios/handlers/md.py +51 -0
  152. edsl/scenarios/handlers/pdf.py +68 -0
  153. edsl/scenarios/handlers/png.py +39 -0
  154. edsl/scenarios/handlers/pptx.py +105 -0
  155. edsl/scenarios/handlers/py.py +294 -0
  156. edsl/scenarios/handlers/sql.py +313 -0
  157. edsl/scenarios/handlers/sqlite.py +149 -0
  158. edsl/scenarios/handlers/txt.py +33 -0
  159. edsl/study/ObjectEntry.py +1 -1
  160. edsl/study/SnapShot.py +1 -1
  161. edsl/study/Study.py +5 -12
  162. edsl/surveys/ConstructDAG.py +92 -0
  163. edsl/surveys/EditSurvey.py +221 -0
  164. edsl/surveys/InstructionHandler.py +100 -0
  165. edsl/surveys/MemoryManagement.py +72 -0
  166. edsl/surveys/Rule.py +5 -4
  167. edsl/surveys/RuleCollection.py +25 -27
  168. edsl/surveys/RuleManager.py +172 -0
  169. edsl/surveys/Simulator.py +75 -0
  170. edsl/surveys/Survey.py +199 -771
  171. edsl/surveys/SurveyCSS.py +20 -8
  172. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
  173. edsl/surveys/SurveyToApp.py +141 -0
  174. edsl/surveys/__init__.py +4 -2
  175. edsl/surveys/descriptors.py +6 -2
  176. edsl/surveys/instructions/ChangeInstruction.py +1 -2
  177. edsl/surveys/instructions/Instruction.py +4 -13
  178. edsl/surveys/instructions/InstructionCollection.py +11 -6
  179. edsl/templates/error_reporting/interview_details.html +1 -1
  180. edsl/templates/error_reporting/report.html +1 -1
  181. edsl/tools/plotting.py +1 -1
  182. edsl/utilities/PrettyList.py +56 -0
  183. edsl/utilities/is_notebook.py +18 -0
  184. edsl/utilities/is_valid_variable_name.py +11 -0
  185. edsl/utilities/remove_edsl_version.py +24 -0
  186. edsl/utilities/utilities.py +35 -23
  187. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +12 -10
  188. edsl-0.1.39.dev2.dist-info/RECORD +352 -0
  189. edsl/language_models/KeyLookup.py +0 -30
  190. edsl/language_models/unused/ReplicateBase.py +0 -83
  191. edsl/results/ResultsDBMixin.py +0 -238
  192. edsl-0.1.39.dev1.dist-info/RECORD +0 -277
  193. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
  194. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +0 -0
edsl/coop/PriceFetcher.py CHANGED
@@ -18,7 +18,7 @@ class PriceFetcher:
18
18
 
19
19
  import os
20
20
  import requests
21
- from edsl import CONFIG
21
+ from edsl.config import CONFIG
22
22
 
23
23
  try:
24
24
  # Fetch the pricing data
edsl/coop/coop.py CHANGED
@@ -1,11 +1,19 @@
1
1
  import aiohttp
2
2
  import json
3
- import os
4
3
  import requests
5
- from typing import Any, Optional, Union, Literal
4
+
5
+ from typing import Any, Optional, Union, Literal, TypedDict
6
6
  from uuid import UUID
7
+ from collections import UserDict, defaultdict
8
+
7
9
  import edsl
8
- from edsl import CONFIG, CacheEntry, Jobs, Survey
10
+ from pathlib import Path
11
+
12
+ from edsl.config import CONFIG
13
+ from edsl.data.CacheEntry import CacheEntry
14
+ from edsl.jobs.Jobs import Jobs
15
+ from edsl.surveys.Survey import Survey
16
+
9
17
  from edsl.exceptions.coop import CoopNoUUIDError, CoopServerResponseError
10
18
  from edsl.coop.utils import (
11
19
  EDSLObject,
@@ -15,19 +23,48 @@ from edsl.coop.utils import (
15
23
  VisibilityType,
16
24
  )
17
25
 
26
+ from edsl.coop.CoopFunctionsMixin import CoopFunctionsMixin
27
+ from edsl.coop.ExpectedParrotKeyHandler import ExpectedParrotKeyHandler
28
+
29
+ from edsl.inference_services.data_structures import ServiceToModelsMapping
30
+
31
+
32
+ class RemoteInferenceResponse(TypedDict):
33
+ job_uuid: str
34
+ results_uuid: str
35
+ results_url: str
36
+ latest_error_report_uuid: str
37
+ latest_error_report_url: str
38
+ status: str
39
+ reason: str
40
+ credits_consumed: float
41
+ version: str
42
+
43
+
44
+ class RemoteInferenceCreationInfo(TypedDict):
45
+ uuid: str
46
+ description: str
47
+ status: str
48
+ iterations: int
49
+ visibility: str
50
+ version: str
18
51
 
19
- class Coop:
52
+
53
+ class Coop(CoopFunctionsMixin):
20
54
  """
21
55
  Client for the Expected Parrot API.
22
56
  """
23
57
 
24
- def __init__(self, api_key: str = None, url: str = None) -> None:
58
+ def __init__(
59
+ self, api_key: Optional[str] = None, url: Optional[str] = None
60
+ ) -> None:
25
61
  """
26
62
  Initialize the client.
27
63
  - Provide an API key directly, or through an env variable.
28
64
  - Provide a URL directly, or use the default one.
29
65
  """
30
- self.api_key = api_key or os.getenv("EXPECTED_PARROT_API_KEY")
66
+ self.ep_key_handler = ExpectedParrotKeyHandler()
67
+ self.api_key = api_key or self.ep_key_handler.get_ep_api_key()
31
68
 
32
69
  self.url = url or CONFIG.EXPECTED_PARROT_URL
33
70
  if self.url.endswith("/"):
@@ -163,19 +200,27 @@ class Coop:
163
200
  edsl_auth_token = secrets.token_urlsafe(16)
164
201
 
165
202
  print("Your Expected Parrot API key is invalid.")
166
- print(
167
- "\nUse the link below to log in to Expected Parrot so we can automatically update your API key."
203
+ self._display_login_url(
204
+ edsl_auth_token=edsl_auth_token,
205
+ link_description="\n🔗 Use the link below to log in to Expected Parrot so we can automatically update your API key.",
168
206
  )
169
- self._display_login_url(edsl_auth_token=edsl_auth_token)
170
207
  api_key = self._poll_for_api_key(edsl_auth_token)
171
208
 
172
209
  if api_key is None:
173
210
  print("\nTimed out waiting for login. Please try again.")
174
211
  return
175
212
 
176
- write_api_key_to_env(api_key)
177
- print("\n✨ API key retrieved and written to .env file.")
178
- print("Rerun your code to try again with a valid API key.")
213
+ print("\n✨ API key retrieved.")
214
+
215
+ if stored_in_user_space := self.ep_key_handler.ask_to_store(api_key):
216
+ pass
217
+ else:
218
+ path_to_env = write_api_key_to_env(api_key)
219
+ print(
220
+ "\n✨ API key retrieved and written to .env file at the following path:"
221
+ )
222
+ print(f" {path_to_env}")
223
+ print("Rerun your code to try again with a valid API key.")
179
224
  return
180
225
 
181
226
  elif "Authorization" in message:
@@ -268,6 +313,7 @@ class Coop:
268
313
  self,
269
314
  object: EDSLObject,
270
315
  description: Optional[str] = None,
316
+ alias: Optional[str] = None,
271
317
  visibility: Optional[VisibilityType] = "unlisted",
272
318
  ) -> dict:
273
319
  """
@@ -279,6 +325,7 @@ class Coop:
279
325
  method="POST",
280
326
  payload={
281
327
  "description": description,
328
+ "alias": alias,
282
329
  "json_string": json.dumps(
283
330
  object.to_dict(),
284
331
  default=self._json_handle_none,
@@ -373,6 +420,7 @@ class Coop:
373
420
  uuid: Union[str, UUID] = None,
374
421
  url: str = None,
375
422
  description: Optional[str] = None,
423
+ alias: Optional[str] = None,
376
424
  value: Optional[EDSLObject] = None,
377
425
  visibility: Optional[VisibilityType] = None,
378
426
  ) -> dict:
@@ -389,6 +437,7 @@ class Coop:
389
437
  params={"uuid": uuid},
390
438
  payload={
391
439
  "description": description,
440
+ "alias": alias,
392
441
  "json_string": (
393
442
  json.dumps(
394
443
  value.to_dict(),
@@ -613,7 +662,7 @@ class Coop:
613
662
  visibility: Optional[VisibilityType] = "unlisted",
614
663
  initial_results_visibility: Optional[VisibilityType] = "unlisted",
615
664
  iterations: Optional[int] = 1,
616
- ) -> dict:
665
+ ) -> RemoteInferenceCreationInfo:
617
666
  """
618
667
  Send a remote inference job to the server.
619
668
 
@@ -645,18 +694,21 @@ class Coop:
645
694
  )
646
695
  self._resolve_server_response(response)
647
696
  response_json = response.json()
648
- return {
649
- "uuid": response_json.get("job_uuid"),
650
- "description": response_json.get("description"),
651
- "status": response_json.get("status"),
652
- "iterations": response_json.get("iterations"),
653
- "visibility": response_json.get("visibility"),
654
- "version": self._edsl_version,
655
- }
697
+
698
+ return RemoteInferenceCreationInfo(
699
+ **{
700
+ "uuid": response_json.get("job_uuid"),
701
+ "description": response_json.get("description"),
702
+ "status": response_json.get("status"),
703
+ "iterations": response_json.get("iterations"),
704
+ "visibility": response_json.get("visibility"),
705
+ "version": self._edsl_version,
706
+ }
707
+ )
656
708
 
657
709
  def remote_inference_get(
658
710
  self, job_uuid: Optional[str] = None, results_uuid: Optional[str] = None
659
- ) -> dict:
711
+ ) -> RemoteInferenceResponse:
660
712
  """
661
713
  Get the details of a remote inference job.
662
714
  You can pass either the job uuid or the results uuid as a parameter.
@@ -698,17 +750,19 @@ class Coop:
698
750
  f"{self.url}/home/remote-inference/error/{latest_error_report_uuid}"
699
751
  )
700
752
 
701
- return {
702
- "job_uuid": data.get("job_uuid"),
703
- "results_uuid": results_uuid,
704
- "results_url": results_url,
705
- "latest_error_report_uuid": latest_error_report_uuid,
706
- "latest_error_report_url": latest_error_report_url,
707
- "status": data.get("status"),
708
- "reason": data.get("reason"),
709
- "credits_consumed": data.get("price"),
710
- "version": data.get("version"),
711
- }
753
+ return RemoteInferenceResponse(
754
+ **{
755
+ "job_uuid": data.get("job_uuid"),
756
+ "results_uuid": results_uuid,
757
+ "results_url": results_url,
758
+ "latest_error_report_uuid": latest_error_report_uuid,
759
+ "latest_error_report_url": latest_error_report_url,
760
+ "status": data.get("status"),
761
+ "reason": data.get("reason"),
762
+ "credits_consumed": data.get("price"),
763
+ "version": data.get("version"),
764
+ }
765
+ )
712
766
 
713
767
  def remote_inference_cost(
714
768
  self, input: Union[Jobs, Survey], iterations: int = 1
@@ -810,7 +864,7 @@ class Coop:
810
864
  "Invalid EDSL_FETCH_TOKEN_PRICES value---should be 'True' or 'False'."
811
865
  )
812
866
 
813
- def fetch_models(self) -> dict:
867
+ def fetch_models(self) -> ServiceToModelsMapping:
814
868
  """
815
869
  Fetch a dict of available models from Coop.
816
870
 
@@ -819,7 +873,7 @@ class Coop:
819
873
  response = self._send_server_request(uri="api/v0/models", method="GET")
820
874
  self._resolve_server_response(response)
821
875
  data = response.json()
822
- return data
876
+ return ServiceToModelsMapping(data)
823
877
 
824
878
  def fetch_rate_limit_config_vars(self) -> dict:
825
879
  """
@@ -835,7 +889,9 @@ class Coop:
835
889
  data = response.json()
836
890
  return data
837
891
 
838
- def _display_login_url(self, edsl_auth_token: str):
892
+ def _display_login_url(
893
+ self, edsl_auth_token: str, link_description: Optional[str] = None
894
+ ):
839
895
  """
840
896
  Uses rich.print to display a login URL.
841
897
 
@@ -845,7 +901,12 @@ class Coop:
845
901
 
846
902
  url = f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"
847
903
 
848
- rich_print(f"[#38bdf8][link={url}]{url}[/link][/#38bdf8]")
904
+ if link_description:
905
+ rich_print(
906
+ f"{link_description}\n [#38bdf8][link={url}]{url}[/link][/#38bdf8]"
907
+ )
908
+ else:
909
+ rich_print(f" [#38bdf8][link={url}]{url}[/link][/#38bdf8]")
849
910
 
850
911
  def _get_api_key(self, edsl_auth_token: str):
851
912
  """
@@ -873,17 +934,18 @@ class Coop:
873
934
 
874
935
  edsl_auth_token = secrets.token_urlsafe(16)
875
936
 
876
- print(
877
- "\nUse the link below to log in to Expected Parrot so we can automatically update your API key."
937
+ self._display_login_url(
938
+ edsl_auth_token=edsl_auth_token,
939
+ link_description="\n🔗 Use the link below to log in to Expected Parrot so we can automatically update your API key.",
878
940
  )
879
- self._display_login_url(edsl_auth_token=edsl_auth_token)
880
941
  api_key = self._poll_for_api_key(edsl_auth_token)
881
942
 
882
943
  if api_key is None:
883
944
  raise Exception("Timed out waiting for login. Please try again.")
884
945
 
885
- write_api_key_to_env(api_key)
886
- print("\n✨ API key retrieved and written to .env file.")
946
+ path_to_env = write_api_key_to_env(api_key)
947
+ print("\n✨ API key retrieved and written to .env file at the following path:")
948
+ print(f" {path_to_env}")
887
949
 
888
950
  # Add API key to environment
889
951
  load_dotenv()
edsl/coop/utils.py CHANGED
@@ -1,19 +1,19 @@
1
- from edsl import (
2
- Agent,
3
- AgentList,
4
- Cache,
5
- ModelList,
6
- Notebook,
7
- Results,
8
- Scenario,
9
- ScenarioList,
10
- Survey,
11
- Study,
12
- )
13
- from edsl.language_models import LanguageModel
14
- from edsl.questions import QuestionBase
15
1
  from typing import Literal, Optional, Type, Union
16
2
 
3
+ from edsl.agents.Agent import Agent
4
+ from edsl.agents.AgentList import AgentList
5
+ from edsl.data.Cache import Cache
6
+ from edsl.language_models.ModelList import ModelList
7
+ from edsl.notebooks.Notebook import Notebook
8
+ from edsl.results.Results import Results
9
+ from edsl.scenarios.Scenario import Scenario
10
+ from edsl.scenarios.ScenarioList import ScenarioList
11
+ from edsl.surveys.Survey import Survey
12
+ from edsl.study.Study import Study
13
+
14
+ from edsl.language_models.LanguageModel import LanguageModel
15
+ from edsl.questions.QuestionBase import QuestionBase
16
+
17
17
  EDSLObject = Union[
18
18
  Agent,
19
19
  AgentList,
edsl/data/Cache.py CHANGED
@@ -6,12 +6,12 @@ from __future__ import annotations
6
6
  import json
7
7
  import os
8
8
  import warnings
9
- import copy
10
9
  from typing import Optional, Union
11
10
  from edsl.Base import Base
12
- from edsl.data.CacheEntry import CacheEntry
13
- from edsl.utilities.utilities import dict_hash
14
- from edsl.utilities.decorators import remove_edsl_version
11
+
12
+
13
+ # from edsl.utilities.decorators import remove_edsl_version
14
+ from edsl.utilities.remove_edsl_version import remove_edsl_version
15
15
  from edsl.exceptions.cache import CacheError
16
16
 
17
17
 
@@ -83,9 +83,9 @@ class Cache(Base):
83
83
 
84
84
  self._perform_checks()
85
85
 
86
- def rich_print(sefl):
87
- pass
88
- # raise NotImplementedError("This method is not implemented yet.")
86
+ # def rich_print(sefl):
87
+ # pass
88
+ # # raise NotImplementedError("This method is not implemented yet.")
89
89
 
90
90
  def code(sefl):
91
91
  pass
@@ -201,6 +201,7 @@ class Cache(Base):
201
201
  >>> len(c)
202
202
  1
203
203
  """
204
+ from edsl.data.CacheEntry import CacheEntry
204
205
 
205
206
  entry = CacheEntry(
206
207
  model=model,
@@ -226,6 +227,7 @@ class Cache(Base):
226
227
 
227
228
  :param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
228
229
  """
230
+ from edsl.data.CacheEntry import CacheEntry
229
231
 
230
232
  for key, value in new_data.items():
231
233
  if key in self.data:
@@ -246,6 +248,8 @@ class Cache(Base):
246
248
 
247
249
  :param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
248
250
  """
251
+ from edsl.data.CacheEntry import CacheEntry
252
+
249
253
  with open(filename, "a+") as f:
250
254
  f.seek(0)
251
255
  lines = f.readlines()
@@ -353,7 +357,8 @@ class Cache(Base):
353
357
  f.write(json.dumps({key: value.to_dict()}) + "\n")
354
358
 
355
359
  def to_scenario_list(self):
356
- from edsl import ScenarioList, Scenario
360
+ from edsl.scenarios.ScenarioList import ScenarioList
361
+ from edsl.scenarios.Scenario import Scenario
357
362
 
358
363
  scenarios = []
359
364
  for key, value in self.data.items():
@@ -399,6 +404,8 @@ class Cache(Base):
399
404
  ####################
400
405
  def __hash__(self):
401
406
  """Return the hash of the Cache."""
407
+ from edsl.utilities.utilities import dict_hash
408
+
402
409
  return dict_hash(self.to_dict(add_edsl_version=False))
403
410
 
404
411
  def to_dict(self, add_edsl_version=True) -> dict:
@@ -414,12 +421,6 @@ class Cache(Base):
414
421
  def _summary(self):
415
422
  return {"EDSL Class": "Cache", "Number of entries": len(self.data)}
416
423
 
417
- def _repr_html_(self):
418
- # from edsl.utilities.utilities import data_to_html
419
- # return data_to_html(self.to_dict())
420
- footer = f"<a href={self.__documentation__}>(docs)</a>"
421
- return str(self.summary(format="html")) + footer
422
-
423
424
  def table(
424
425
  self,
425
426
  *fields,
@@ -443,6 +444,8 @@ class Cache(Base):
443
444
  @remove_edsl_version
444
445
  def from_dict(cls, data) -> Cache:
445
446
  """Construct a Cache from a dictionary."""
447
+ from edsl.data.CacheEntry import CacheEntry
448
+
446
449
  newdata = {k: CacheEntry.from_dict(v) for k, v in data.items()}
447
450
  return cls(data=newdata)
448
451
 
@@ -485,6 +488,8 @@ class Cache(Base):
485
488
  """
486
489
  Create an example input for a 'fetch' operation.
487
490
  """
491
+ from edsl.data.CacheEntry import CacheEntry
492
+
488
493
  return CacheEntry.fetch_input_example()
489
494
 
490
495
  def to_html(self):
@@ -541,6 +546,8 @@ class Cache(Base):
541
546
 
542
547
  :param randomize: If True, uses CacheEntry's randomize method.
543
548
  """
549
+ from edsl.data.CacheEntry import CacheEntry
550
+
544
551
  return cls(
545
552
  data={
546
553
  CacheEntry.example(randomize).key: CacheEntry.example(),
edsl/data/CacheEntry.py CHANGED
@@ -5,8 +5,12 @@ import hashlib
5
5
  from typing import Optional
6
6
  from uuid import uuid4
7
7
 
8
+ from edsl.utilities.decorators import remove_edsl_version
8
9
 
9
- class CacheEntry:
10
+ from edsl.Base import RepresentationMixin
11
+
12
+
13
+ class CacheEntry(RepresentationMixin):
10
14
  """
11
15
  A Class to represent a cache entry.
12
16
  """
@@ -78,11 +82,11 @@ class CacheEntry:
78
82
  d = {k: value for k, value in self.__dict__.items() if k in self.key_fields}
79
83
  return self.gen_key(**d)
80
84
 
81
- def to_dict(self) -> dict:
85
+ def to_dict(self, add_edsl_version=True) -> dict:
82
86
  """
83
87
  Returns a dictionary representation of a CacheEntry.
84
88
  """
85
- return {
89
+ d = {
86
90
  "model": self.model,
87
91
  "parameters": self.parameters,
88
92
  "system_prompt": self.system_prompt,
@@ -91,19 +95,12 @@ class CacheEntry:
91
95
  "iteration": self.iteration,
92
96
  "timestamp": self.timestamp,
93
97
  }
98
+ # if add_edsl_version:
99
+ # from edsl import __version__
94
100
 
95
- def _repr_html_(self) -> str:
96
- """
97
- Returns an HTML representation of a CacheEntry.
98
- """
99
- # from edsl.utilities.utilities import data_to_html
100
- # return data_to_html(self.to_dict())
101
- d = self.to_dict()
102
- data = [[k, v] for k, v in d.items()]
103
- from tabulate import tabulate
104
-
105
- table = str(tabulate(data, headers=["keys", "values"], tablefmt="html"))
106
- return f"<pre>{table}</pre>"
101
+ # d["edsl_version"] = __version__
102
+ # d["edsl_class_name"] = self.__class__.__name__
103
+ return d
107
104
 
108
105
  def keys(self):
109
106
  return list(self.to_dict().keys())
edsl/data/CacheHandler.py CHANGED
@@ -3,19 +3,19 @@ import ast
3
3
  import json
4
4
  import os
5
5
  import shutil
6
- import sqlite3
7
- from edsl.config import CONFIG
8
- from edsl.data.Cache import Cache
9
- from edsl.data.CacheEntry import CacheEntry
10
- from edsl.data.SQLiteDict import SQLiteDict
6
+ from typing import TYPE_CHECKING
11
7
 
12
- from edsl.config import CONFIG
8
+ if TYPE_CHECKING:
9
+ from edsl.data.Cache import Cache
10
+ from edsl.data.CacheEntry import CacheEntry
13
11
 
14
12
 
15
- def set_session_cache(cache: Cache) -> None:
13
+ def set_session_cache(cache: "Cache") -> None:
16
14
  """
17
15
  Set the session cache.
18
16
  """
17
+ from edsl.config import CONFIG
18
+
19
19
  CONFIG.EDSL_SESSION_CACHE = cache
20
20
 
21
21
 
@@ -23,6 +23,8 @@ def unset_session_cache() -> None:
23
23
  """
24
24
  Unset the session cache.
25
25
  """
26
+ from edsl.config import CONFIG
27
+
26
28
  if hasattr(CONFIG, "EDSL_SESSION_CACHE"):
27
29
  del CONFIG.EDSL_SESSION_CACHE
28
30
 
@@ -32,7 +34,11 @@ class CacheHandler:
32
34
  This CacheHandler figures out what caches are available and does migrations, as needed.
33
35
  """
34
36
 
35
- CACHE_PATH = CONFIG.get("EDSL_DATABASE_PATH")
37
+ @property
38
+ def CACHE_PATH(self):
39
+ from edsl.config import CONFIG
40
+
41
+ return CONFIG.get("EDSL_DATABASE_PATH")
36
42
 
37
43
  def __init__(self, test: bool = False):
38
44
  self.test = test
@@ -52,16 +58,24 @@ class CacheHandler:
52
58
  if notify:
53
59
  print(f"Created cache directory: {dir_path}")
54
60
 
55
- def gen_cache(self) -> Cache:
61
+ def gen_cache(self) -> "Cache":
56
62
  """
57
63
  Generate a Cache object.
58
64
  """
65
+ from edsl.data.Cache import Cache
66
+
59
67
  if self.test:
60
68
  return Cache(data={})
61
69
 
70
+ # if self.CACHE_PATH is not None:
71
+ # return self.CACHE_PATH
72
+ from edsl.config import CONFIG
73
+
62
74
  if hasattr(CONFIG, "EDSL_SESSION_CACHE"):
63
75
  return CONFIG.EDSL_SESSION_CACHE
64
76
 
77
+ from edsl.data.SQLiteDict import SQLiteDict
78
+
65
79
  cache = Cache(data=SQLiteDict(self.CACHE_PATH))
66
80
  return cache
67
81
 
@@ -76,6 +90,8 @@ class CacheHandler:
76
90
  if not os.path.exists(os.path.join(os.getcwd(), path)):
77
91
  return old_data
78
92
  try:
93
+ import sqlite3
94
+
79
95
  conn = sqlite3.connect(path)
80
96
  with conn:
81
97
  cur = conn.cursor()
@@ -108,6 +124,8 @@ class CacheHandler:
108
124
  entry_dict["user_prompt"] = entry_dict.pop("prompt")
109
125
  parameters = entry_dict["parameters"]
110
126
  entry_dict["parameters"] = ast.literal_eval(parameters)
127
+ from edsl.data.CacheEntry import CacheEntry
128
+
111
129
  entry = CacheEntry(**entry_dict)
112
130
  return entry
113
131
 
@@ -117,7 +135,7 @@ class CacheHandler:
117
135
  ###############
118
136
  # NOT IN USE
119
137
  ###############
120
- def from_sqlite(uri="new_edsl_cache.db") -> dict[str, CacheEntry]:
138
+ def from_sqlite(uri="new_edsl_cache.db") -> dict[str, "CacheEntry"]:
121
139
  """
122
140
  Read in a new-style sqlite cache and return a dictionary of dictionaries.
123
141
  """
@@ -131,7 +149,7 @@ class CacheHandler:
131
149
  newdata[entry.key] = entry
132
150
  return newdata
133
151
 
134
- def from_jsonl(filename="edsl_cache.jsonl") -> dict[str, CacheEntry]:
152
+ def from_jsonl(filename="edsl_cache.jsonl") -> dict[str, "CacheEntry"]:
135
153
  """Read in a jsonl file and return a dictionary of CacheEntry objects."""
136
154
  with open(filename, "a+") as f:
137
155
  f.seek(0)
@@ -146,4 +164,7 @@ class CacheHandler:
146
164
 
147
165
 
148
166
  if __name__ == "__main__":
149
- ch = CacheHandler()
167
+ # ch = CacheHandler()
168
+ import doctest
169
+
170
+ doctest.testmod()
edsl/data/__init__.py CHANGED
@@ -1,4 +1,5 @@
1
- from edsl.data.CacheEntry import CacheEntry
2
- from edsl.data.SQLiteDict import SQLiteDict
1
+ # from edsl.data.CacheEntry import CacheEntry
2
+ # from edsl.data.SQLiteDict import SQLiteDict
3
3
  from edsl.data.Cache import Cache
4
- from edsl.data.CacheHandler import CacheHandler
4
+
5
+ # from edsl.data.CacheHandler import CacheHandler
@@ -1,6 +1,5 @@
1
1
  from typing import NamedTuple, Dict, List, Optional, Any
2
2
  from dataclasses import dataclass, fields
3
- import reprlib
4
3
 
5
4
 
6
5
  class ModelInputs(NamedTuple):
@@ -56,6 +55,8 @@ class ImageInfo:
56
55
  encoded_image: str
57
56
 
58
57
  def __repr__(self):
58
+ import reprlib
59
+
59
60
  reprlib_instance = reprlib.Repr()
60
61
  reprlib_instance.maxstring = 30 # Limit the string length for the encoded image
61
62
 
edsl/enums.py CHANGED
@@ -1,6 +1,7 @@
1
1
  """Enums for the different types of questions, language models, and inference services."""
2
2
 
3
3
  from enum import Enum
4
+ from typing import Literal
4
5
 
5
6
 
6
7
  class EnumWithChecks(Enum):
@@ -67,6 +68,25 @@ class InferenceServiceType(EnumWithChecks):
67
68
  PERPLEXITY = "perplexity"
68
69
 
69
70
 
71
+ # unavoidable violation of the DRY principle but it is necessary
72
+ # checked w/ a unit test to make sure consistent with services in enums.py
73
+ InferenceServiceLiteral = Literal[
74
+ "bedrock",
75
+ "deep_infra",
76
+ "replicate",
77
+ "openai",
78
+ "google",
79
+ "test",
80
+ "anthropic",
81
+ "groq",
82
+ "azure",
83
+ "ollama",
84
+ "mistral",
85
+ "together",
86
+ "perplexity",
87
+ ]
88
+
89
+
70
90
  service_to_api_keyname = {
71
91
  InferenceServiceType.BEDROCK.value: "TBD",
72
92
  InferenceServiceType.DEEP_INFRA.value: "DEEP_INFRA_API_KEY",