edsl 0.1.38.dev4__py3-none-any.whl → 0.1.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212) hide show
  1. edsl/Base.py +197 -116
  2. edsl/__init__.py +15 -7
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +351 -147
  5. edsl/agents/AgentList.py +211 -73
  6. edsl/agents/Invigilator.py +101 -50
  7. edsl/agents/InvigilatorBase.py +62 -70
  8. edsl/agents/PromptConstructor.py +143 -225
  9. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  10. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  11. edsl/agents/__init__.py +0 -1
  12. edsl/agents/prompt_helpers.py +3 -3
  13. edsl/agents/question_option_processor.py +172 -0
  14. edsl/auto/AutoStudy.py +18 -5
  15. edsl/auto/StageBase.py +53 -40
  16. edsl/auto/StageQuestions.py +2 -1
  17. edsl/auto/utilities.py +0 -6
  18. edsl/config.py +22 -2
  19. edsl/conversation/car_buying.py +2 -1
  20. edsl/coop/CoopFunctionsMixin.py +15 -0
  21. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  22. edsl/coop/PriceFetcher.py +1 -1
  23. edsl/coop/coop.py +125 -47
  24. edsl/coop/utils.py +14 -14
  25. edsl/data/Cache.py +45 -27
  26. edsl/data/CacheEntry.py +12 -15
  27. edsl/data/CacheHandler.py +31 -12
  28. edsl/data/RemoteCacheSync.py +154 -46
  29. edsl/data/__init__.py +4 -3
  30. edsl/data_transfer_models.py +2 -1
  31. edsl/enums.py +27 -0
  32. edsl/exceptions/__init__.py +50 -50
  33. edsl/exceptions/agents.py +12 -0
  34. edsl/exceptions/inference_services.py +5 -0
  35. edsl/exceptions/questions.py +24 -6
  36. edsl/exceptions/scenarios.py +7 -0
  37. edsl/inference_services/AnthropicService.py +38 -19
  38. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  39. edsl/inference_services/AvailableModelFetcher.py +215 -0
  40. edsl/inference_services/AwsBedrock.py +0 -2
  41. edsl/inference_services/AzureAI.py +0 -2
  42. edsl/inference_services/GoogleService.py +7 -12
  43. edsl/inference_services/InferenceServiceABC.py +18 -85
  44. edsl/inference_services/InferenceServicesCollection.py +120 -79
  45. edsl/inference_services/MistralAIService.py +0 -3
  46. edsl/inference_services/OpenAIService.py +47 -35
  47. edsl/inference_services/PerplexityService.py +0 -3
  48. edsl/inference_services/ServiceAvailability.py +135 -0
  49. edsl/inference_services/TestService.py +11 -10
  50. edsl/inference_services/TogetherAIService.py +5 -3
  51. edsl/inference_services/data_structures.py +134 -0
  52. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  53. edsl/jobs/Answers.py +1 -14
  54. edsl/jobs/FetchInvigilator.py +47 -0
  55. edsl/jobs/InterviewTaskManager.py +98 -0
  56. edsl/jobs/InterviewsConstructor.py +50 -0
  57. edsl/jobs/Jobs.py +356 -431
  58. edsl/jobs/JobsChecks.py +35 -10
  59. edsl/jobs/JobsComponentConstructor.py +189 -0
  60. edsl/jobs/JobsPrompts.py +6 -4
  61. edsl/jobs/JobsRemoteInferenceHandler.py +205 -133
  62. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  63. edsl/jobs/RequestTokenEstimator.py +30 -0
  64. edsl/jobs/async_interview_runner.py +138 -0
  65. edsl/jobs/buckets/BucketCollection.py +44 -3
  66. edsl/jobs/buckets/TokenBucket.py +53 -21
  67. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  68. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  69. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  70. edsl/jobs/data_structures.py +120 -0
  71. edsl/jobs/decorators.py +35 -0
  72. edsl/jobs/interviews/Interview.py +143 -408
  73. edsl/jobs/jobs_status_enums.py +9 -0
  74. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  75. edsl/jobs/results_exceptions_handler.py +98 -0
  76. edsl/jobs/runners/JobsRunnerAsyncio.py +88 -403
  77. edsl/jobs/runners/JobsRunnerStatus.py +133 -165
  78. edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
  79. edsl/jobs/tasks/TaskHistory.py +38 -18
  80. edsl/jobs/tasks/task_status_enum.py +0 -2
  81. edsl/language_models/ComputeCost.py +63 -0
  82. edsl/language_models/LanguageModel.py +194 -236
  83. edsl/language_models/ModelList.py +28 -19
  84. edsl/language_models/PriceManager.py +127 -0
  85. edsl/language_models/RawResponseHandler.py +106 -0
  86. edsl/language_models/ServiceDataSources.py +0 -0
  87. edsl/language_models/__init__.py +1 -2
  88. edsl/language_models/key_management/KeyLookup.py +63 -0
  89. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  90. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  91. edsl/language_models/key_management/__init__.py +0 -0
  92. edsl/language_models/key_management/models.py +131 -0
  93. edsl/language_models/model.py +256 -0
  94. edsl/language_models/repair.py +2 -2
  95. edsl/language_models/utilities.py +5 -4
  96. edsl/notebooks/Notebook.py +19 -14
  97. edsl/notebooks/NotebookToLaTeX.py +142 -0
  98. edsl/prompts/Prompt.py +29 -39
  99. edsl/questions/ExceptionExplainer.py +77 -0
  100. edsl/questions/HTMLQuestion.py +103 -0
  101. edsl/questions/QuestionBase.py +68 -214
  102. edsl/questions/QuestionBasePromptsMixin.py +7 -3
  103. edsl/questions/QuestionBudget.py +1 -1
  104. edsl/questions/QuestionCheckBox.py +3 -3
  105. edsl/questions/QuestionExtract.py +5 -7
  106. edsl/questions/QuestionFreeText.py +2 -3
  107. edsl/questions/QuestionList.py +10 -18
  108. edsl/questions/QuestionMatrix.py +265 -0
  109. edsl/questions/QuestionMultipleChoice.py +67 -23
  110. edsl/questions/QuestionNumerical.py +2 -4
  111. edsl/questions/QuestionRank.py +7 -17
  112. edsl/questions/SimpleAskMixin.py +4 -3
  113. edsl/questions/__init__.py +2 -1
  114. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +47 -2
  115. edsl/questions/data_structures.py +20 -0
  116. edsl/questions/derived/QuestionLinearScale.py +6 -3
  117. edsl/questions/derived/QuestionTopK.py +1 -1
  118. edsl/questions/descriptors.py +17 -3
  119. edsl/questions/loop_processor.py +149 -0
  120. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +57 -50
  121. edsl/questions/question_registry.py +1 -1
  122. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +40 -26
  123. edsl/questions/response_validator_factory.py +34 -0
  124. edsl/questions/templates/matrix/__init__.py +1 -0
  125. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  126. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  127. edsl/results/CSSParameterizer.py +1 -1
  128. edsl/results/Dataset.py +170 -7
  129. edsl/results/DatasetExportMixin.py +168 -305
  130. edsl/results/DatasetTree.py +28 -8
  131. edsl/results/MarkdownToDocx.py +122 -0
  132. edsl/results/MarkdownToPDF.py +111 -0
  133. edsl/results/Result.py +298 -206
  134. edsl/results/Results.py +149 -131
  135. edsl/results/ResultsExportMixin.py +2 -0
  136. edsl/results/TableDisplay.py +98 -171
  137. edsl/results/TextEditor.py +50 -0
  138. edsl/results/__init__.py +1 -1
  139. edsl/results/file_exports.py +252 -0
  140. edsl/results/{Selector.py → results_selector.py} +23 -13
  141. edsl/results/smart_objects.py +96 -0
  142. edsl/results/table_data_class.py +12 -0
  143. edsl/results/table_renderers.py +118 -0
  144. edsl/scenarios/ConstructDownloadLink.py +109 -0
  145. edsl/scenarios/DocumentChunker.py +102 -0
  146. edsl/scenarios/DocxScenario.py +16 -0
  147. edsl/scenarios/FileStore.py +150 -239
  148. edsl/scenarios/PdfExtractor.py +40 -0
  149. edsl/scenarios/Scenario.py +90 -193
  150. edsl/scenarios/ScenarioHtmlMixin.py +4 -3
  151. edsl/scenarios/ScenarioList.py +415 -244
  152. edsl/scenarios/ScenarioListExportMixin.py +0 -7
  153. edsl/scenarios/ScenarioListPdfMixin.py +15 -37
  154. edsl/scenarios/__init__.py +1 -2
  155. edsl/scenarios/directory_scanner.py +96 -0
  156. edsl/scenarios/file_methods.py +85 -0
  157. edsl/scenarios/handlers/__init__.py +13 -0
  158. edsl/scenarios/handlers/csv.py +49 -0
  159. edsl/scenarios/handlers/docx.py +76 -0
  160. edsl/scenarios/handlers/html.py +37 -0
  161. edsl/scenarios/handlers/json.py +111 -0
  162. edsl/scenarios/handlers/latex.py +5 -0
  163. edsl/scenarios/handlers/md.py +51 -0
  164. edsl/scenarios/handlers/pdf.py +68 -0
  165. edsl/scenarios/handlers/png.py +39 -0
  166. edsl/scenarios/handlers/pptx.py +105 -0
  167. edsl/scenarios/handlers/py.py +294 -0
  168. edsl/scenarios/handlers/sql.py +313 -0
  169. edsl/scenarios/handlers/sqlite.py +149 -0
  170. edsl/scenarios/handlers/txt.py +33 -0
  171. edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +10 -6
  172. edsl/scenarios/scenario_selector.py +156 -0
  173. edsl/study/ObjectEntry.py +1 -1
  174. edsl/study/SnapShot.py +1 -1
  175. edsl/study/Study.py +5 -12
  176. edsl/surveys/ConstructDAG.py +92 -0
  177. edsl/surveys/EditSurvey.py +221 -0
  178. edsl/surveys/InstructionHandler.py +100 -0
  179. edsl/surveys/MemoryManagement.py +72 -0
  180. edsl/surveys/Rule.py +5 -4
  181. edsl/surveys/RuleCollection.py +25 -27
  182. edsl/surveys/RuleManager.py +172 -0
  183. edsl/surveys/Simulator.py +75 -0
  184. edsl/surveys/Survey.py +270 -791
  185. edsl/surveys/SurveyCSS.py +20 -8
  186. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
  187. edsl/surveys/SurveyToApp.py +141 -0
  188. edsl/surveys/__init__.py +4 -2
  189. edsl/surveys/descriptors.py +6 -2
  190. edsl/surveys/instructions/ChangeInstruction.py +1 -2
  191. edsl/surveys/instructions/Instruction.py +4 -13
  192. edsl/surveys/instructions/InstructionCollection.py +11 -6
  193. edsl/templates/error_reporting/interview_details.html +1 -1
  194. edsl/templates/error_reporting/report.html +1 -1
  195. edsl/tools/plotting.py +1 -1
  196. edsl/utilities/PrettyList.py +56 -0
  197. edsl/utilities/is_notebook.py +18 -0
  198. edsl/utilities/is_valid_variable_name.py +11 -0
  199. edsl/utilities/remove_edsl_version.py +24 -0
  200. edsl/utilities/utilities.py +35 -23
  201. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/METADATA +12 -10
  202. edsl-0.1.39.dist-info/RECORD +358 -0
  203. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/WHEEL +1 -1
  204. edsl/language_models/KeyLookup.py +0 -30
  205. edsl/language_models/registry.py +0 -190
  206. edsl/language_models/unused/ReplicateBase.py +0 -83
  207. edsl/results/ResultsDBMixin.py +0 -238
  208. edsl-0.1.38.dev4.dist-info/RECORD +0 -277
  209. /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
  210. /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
  211. /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
  212. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/LICENSE +0 -0
@@ -0,0 +1,125 @@
1
+ from pathlib import Path
2
+ import os
3
+ import platformdirs
4
+
5
+
6
+ import sys
7
+ import select
8
+
9
+
10
+ def get_input_with_timeout(prompt, timeout=5, default="y"):
11
+ print(prompt, end="", flush=True)
12
+ ready, _, _ = select.select([sys.stdin], [], [], timeout)
13
+ if ready:
14
+ return sys.stdin.readline().strip()
15
+ print(f"\nNo input received within {timeout} seconds. Using default: {default}")
16
+ return default
17
+
18
+
19
+ class ExpectedParrotKeyHandler:
20
+ asked_to_store_file_name = "asked_to_store.txt"
21
+ ep_key_file_name = "ep_api_key.txt"
22
+ application_name = "edsl"
23
+
24
+ @property
25
+ def config_dir(self):
26
+ return platformdirs.user_config_dir(self.application_name)
27
+
28
+ def _ep_key_file_exists(self) -> bool:
29
+ """Check if the Expected Parrot key file exists."""
30
+ return Path(self.config_dir).joinpath(self.ep_key_file_name).exists()
31
+
32
+ def ok_to_ask_to_store(self):
33
+ """Check if it's okay to ask the user to store the key."""
34
+ from edsl.config import CONFIG
35
+
36
+ if CONFIG.get("EDSL_RUN_MODE") != "production":
37
+ return False
38
+
39
+ return (
40
+ not Path(self.config_dir).joinpath(self.asked_to_store_file_name).exists()
41
+ )
42
+
43
+ def reset_asked_to_store(self):
44
+ """Reset the flag that indicates whether the user has been asked to store the key."""
45
+ asked_to_store_path = Path(self.config_dir).joinpath(
46
+ self.asked_to_store_file_name
47
+ )
48
+ if asked_to_store_path.exists():
49
+ os.remove(asked_to_store_path)
50
+ print(
51
+ "Deleted the file that indicates whether the user has been asked to store the key."
52
+ )
53
+
54
+ def ask_to_store(self, api_key) -> bool:
55
+ """Ask the user if they want to store the Expected Parrot key. If they say "yes", store it."""
56
+ if self.ok_to_ask_to_store():
57
+ # can_we_store = get_input_with_timeout(
58
+ # "Would you like to store your Expected Parrot key for future use? (y/n): ",
59
+ # timeout=5,
60
+ # default="y",
61
+ # )
62
+ can_we_store = "y"
63
+ if can_we_store.lower() == "y":
64
+ Path(self.config_dir).mkdir(parents=True, exist_ok=True)
65
+ self.store_ep_api_key(api_key)
66
+ # print("Stored Expected Parrot API key at ", self.config_dir)
67
+ return True
68
+ else:
69
+ Path(self.config_dir).mkdir(parents=True, exist_ok=True)
70
+ with open(
71
+ Path(self.config_dir).joinpath(self.asked_to_store_file_name), "w"
72
+ ) as f:
73
+ f.write("Yes")
74
+ return False
75
+
76
+ def get_ep_api_key(self):
77
+ # check if the key is stored in the config_dir
78
+ api_key = None
79
+ api_key_from_cache = None
80
+ api_key_from_os = None
81
+
82
+ if self._ep_key_file_exists():
83
+ with open(Path(self.config_dir).joinpath(self.ep_key_file_name), "r") as f:
84
+ api_key_from_cache = f.read().strip()
85
+
86
+ api_key_from_os = os.getenv("EXPECTED_PARROT_API_KEY")
87
+
88
+ if api_key_from_os and api_key_from_cache:
89
+ if api_key_from_os != api_key_from_cache:
90
+ import warnings
91
+
92
+ warnings.warn(
93
+ "WARNING: The Expected Parrot API key from the environment variable "
94
+ "differs from the one stored in the config directory. Using the one "
95
+ "from the environment variable."
96
+ )
97
+ api_key = api_key_from_os
98
+
99
+ if api_key_from_os and not api_key_from_cache:
100
+ api_key = api_key_from_os
101
+
102
+ if not api_key_from_os and api_key_from_cache:
103
+ api_key = api_key_from_cache
104
+
105
+ if api_key is not None:
106
+ _ = self.ask_to_store(api_key)
107
+ return api_key
108
+
109
+ def delete_ep_api_key(self):
110
+ key_path = Path(self.config_dir) / self.ep_key_file_name
111
+ if key_path.exists():
112
+ os.remove(key_path)
113
+ print("Deleted Expected Parrot API key at ", key_path)
114
+
115
+ def store_ep_api_key(self, api_key):
116
+ # Create the directory if it doesn't exist
117
+ os.makedirs(self.config_dir, exist_ok=True)
118
+
119
+ # Create the path for the key file
120
+ key_path = Path(self.config_dir) / self.ep_key_file_name
121
+
122
+ # Save the key
123
+ with open(key_path, "w") as f:
124
+ f.write(api_key)
125
+ # print("Stored Expected Parrot API key at ", key_path)
edsl/coop/PriceFetcher.py CHANGED
@@ -18,7 +18,7 @@ class PriceFetcher:
18
18
 
19
19
  import os
20
20
  import requests
21
- from edsl import CONFIG
21
+ from edsl.config import CONFIG
22
22
 
23
23
  try:
24
24
  # Fetch the pricing data
edsl/coop/coop.py CHANGED
@@ -1,11 +1,19 @@
1
1
  import aiohttp
2
2
  import json
3
- import os
4
3
  import requests
5
- from typing import Any, Optional, Union, Literal
4
+
5
+ from typing import Any, Optional, Union, Literal, TypedDict
6
6
  from uuid import UUID
7
+ from collections import UserDict, defaultdict
8
+
7
9
  import edsl
8
- from edsl import CONFIG, CacheEntry, Jobs, Survey
10
+ from pathlib import Path
11
+
12
+ from edsl.config import CONFIG
13
+ from edsl.data.CacheEntry import CacheEntry
14
+ from edsl.jobs.Jobs import Jobs
15
+ from edsl.surveys.Survey import Survey
16
+
9
17
  from edsl.exceptions.coop import CoopNoUUIDError, CoopServerResponseError
10
18
  from edsl.coop.utils import (
11
19
  EDSLObject,
@@ -15,19 +23,48 @@ from edsl.coop.utils import (
15
23
  VisibilityType,
16
24
  )
17
25
 
26
+ from edsl.coop.CoopFunctionsMixin import CoopFunctionsMixin
27
+ from edsl.coop.ExpectedParrotKeyHandler import ExpectedParrotKeyHandler
28
+
29
+ from edsl.inference_services.data_structures import ServiceToModelsMapping
30
+
18
31
 
19
- class Coop:
32
+ class RemoteInferenceResponse(TypedDict):
33
+ job_uuid: str
34
+ results_uuid: str
35
+ results_url: str
36
+ latest_error_report_uuid: str
37
+ latest_error_report_url: str
38
+ status: str
39
+ reason: str
40
+ credits_consumed: float
41
+ version: str
42
+
43
+
44
+ class RemoteInferenceCreationInfo(TypedDict):
45
+ uuid: str
46
+ description: str
47
+ status: str
48
+ iterations: int
49
+ visibility: str
50
+ version: str
51
+
52
+
53
+ class Coop(CoopFunctionsMixin):
20
54
  """
21
55
  Client for the Expected Parrot API.
22
56
  """
23
57
 
24
- def __init__(self, api_key: str = None, url: str = None) -> None:
58
+ def __init__(
59
+ self, api_key: Optional[str] = None, url: Optional[str] = None
60
+ ) -> None:
25
61
  """
26
62
  Initialize the client.
27
63
  - Provide an API key directly, or through an env variable.
28
64
  - Provide a URL directly, or use the default one.
29
65
  """
30
- self.api_key = api_key or os.getenv("EXPECTED_PARROT_API_KEY")
66
+ self.ep_key_handler = ExpectedParrotKeyHandler()
67
+ self.api_key = api_key or self.ep_key_handler.get_ep_api_key()
31
68
 
32
69
  self.url = url or CONFIG.EXPECTED_PARROT_URL
33
70
  if self.url.endswith("/"):
@@ -142,6 +179,7 @@ class Coop:
142
179
  Check the response from the server and raise errors as appropriate.
143
180
  """
144
181
  # Get EDSL version from header
182
+ # breakpoint()
145
183
  server_edsl_version = response.headers.get("X-EDSL-Version")
146
184
 
147
185
  if server_edsl_version:
@@ -150,11 +188,18 @@ class Coop:
150
188
  server_version_str=server_edsl_version,
151
189
  ):
152
190
  print(
153
- "Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip upgrade edsl`"
191
+ "Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip install --upgrade edsl`"
154
192
  )
155
193
 
156
194
  if response.status_code >= 400:
157
- message = response.json().get("detail")
195
+ try:
196
+ message = response.json().get("detail")
197
+ except json.JSONDecodeError:
198
+ raise CoopServerResponseError(
199
+ f"Server returned status code {response.status_code}."
200
+ "JSON response could not be decoded.",
201
+ "The server response was: " + response.text,
202
+ )
158
203
  # print(response.text)
159
204
  if "The API key you provided is invalid" in message and check_api_key:
160
205
  import secrets
@@ -163,19 +208,27 @@ class Coop:
163
208
  edsl_auth_token = secrets.token_urlsafe(16)
164
209
 
165
210
  print("Your Expected Parrot API key is invalid.")
166
- print(
167
- "\nUse the link below to log in to Expected Parrot so we can automatically update your API key."
211
+ self._display_login_url(
212
+ edsl_auth_token=edsl_auth_token,
213
+ link_description="\n🔗 Use the link below to log in to Expected Parrot so we can automatically update your API key.",
168
214
  )
169
- self._display_login_url(edsl_auth_token=edsl_auth_token)
170
215
  api_key = self._poll_for_api_key(edsl_auth_token)
171
216
 
172
217
  if api_key is None:
173
218
  print("\nTimed out waiting for login. Please try again.")
174
219
  return
175
220
 
176
- write_api_key_to_env(api_key)
177
- print("\n✨ API key retrieved and written to .env file.")
178
- print("Rerun your code to try again with a valid API key.")
221
+ print("\n✨ API key retrieved.")
222
+
223
+ if stored_in_user_space := self.ep_key_handler.ask_to_store(api_key):
224
+ pass
225
+ else:
226
+ path_to_env = write_api_key_to_env(api_key)
227
+ print(
228
+ "\n✨ API key retrieved and written to .env file at the following path:"
229
+ )
230
+ print(f" {path_to_env}")
231
+ print("Rerun your code to try again with a valid API key.")
179
232
  return
180
233
 
181
234
  elif "Authorization" in message:
@@ -268,6 +321,7 @@ class Coop:
268
321
  self,
269
322
  object: EDSLObject,
270
323
  description: Optional[str] = None,
324
+ alias: Optional[str] = None,
271
325
  visibility: Optional[VisibilityType] = "unlisted",
272
326
  ) -> dict:
273
327
  """
@@ -279,6 +333,7 @@ class Coop:
279
333
  method="POST",
280
334
  payload={
281
335
  "description": description,
336
+ "alias": alias,
282
337
  "json_string": json.dumps(
283
338
  object.to_dict(),
284
339
  default=self._json_handle_none,
@@ -373,6 +428,7 @@ class Coop:
373
428
  uuid: Union[str, UUID] = None,
374
429
  url: str = None,
375
430
  description: Optional[str] = None,
431
+ alias: Optional[str] = None,
376
432
  value: Optional[EDSLObject] = None,
377
433
  visibility: Optional[VisibilityType] = None,
378
434
  ) -> dict:
@@ -389,6 +445,7 @@ class Coop:
389
445
  params={"uuid": uuid},
390
446
  payload={
391
447
  "description": description,
448
+ "alias": alias,
392
449
  "json_string": (
393
450
  json.dumps(
394
451
  value.to_dict(),
@@ -602,9 +659,6 @@ class Coop:
602
659
  self._resolve_server_response(response)
603
660
  return response.json()
604
661
 
605
- ################
606
- # Remote Inference
607
- ################
608
662
  def remote_inference_create(
609
663
  self,
610
664
  job: Jobs,
@@ -613,7 +667,7 @@ class Coop:
613
667
  visibility: Optional[VisibilityType] = "unlisted",
614
668
  initial_results_visibility: Optional[VisibilityType] = "unlisted",
615
669
  iterations: Optional[int] = 1,
616
- ) -> dict:
670
+ ) -> RemoteInferenceCreationInfo:
617
671
  """
618
672
  Send a remote inference job to the server.
619
673
 
@@ -645,18 +699,21 @@ class Coop:
645
699
  )
646
700
  self._resolve_server_response(response)
647
701
  response_json = response.json()
648
- return {
649
- "uuid": response_json.get("job_uuid"),
650
- "description": response_json.get("description"),
651
- "status": response_json.get("status"),
652
- "iterations": response_json.get("iterations"),
653
- "visibility": response_json.get("visibility"),
654
- "version": self._edsl_version,
655
- }
702
+
703
+ return RemoteInferenceCreationInfo(
704
+ **{
705
+ "uuid": response_json.get("job_uuid"),
706
+ "description": response_json.get("description"),
707
+ "status": response_json.get("status"),
708
+ "iterations": response_json.get("iterations"),
709
+ "visibility": response_json.get("visibility"),
710
+ "version": self._edsl_version,
711
+ }
712
+ )
656
713
 
657
714
  def remote_inference_get(
658
715
  self, job_uuid: Optional[str] = None, results_uuid: Optional[str] = None
659
- ) -> dict:
716
+ ) -> RemoteInferenceResponse:
660
717
  """
661
718
  Get the details of a remote inference job.
662
719
  You can pass either the job uuid or the results uuid as a parameter.
@@ -698,17 +755,30 @@ class Coop:
698
755
  f"{self.url}/home/remote-inference/error/{latest_error_report_uuid}"
699
756
  )
700
757
 
701
- return {
702
- "job_uuid": data.get("job_uuid"),
703
- "results_uuid": results_uuid,
704
- "results_url": results_url,
705
- "latest_error_report_uuid": latest_error_report_uuid,
706
- "latest_error_report_url": latest_error_report_url,
707
- "status": data.get("status"),
708
- "reason": data.get("reason"),
709
- "credits_consumed": data.get("price"),
710
- "version": data.get("version"),
711
- }
758
+ return RemoteInferenceResponse(
759
+ **{
760
+ "job_uuid": data.get("job_uuid"),
761
+ "results_uuid": results_uuid,
762
+ "results_url": results_url,
763
+ "latest_error_report_uuid": latest_error_report_uuid,
764
+ "latest_error_report_url": latest_error_report_url,
765
+ "status": data.get("status"),
766
+ "reason": data.get("reason"),
767
+ "credits_consumed": data.get("price"),
768
+ "version": data.get("version"),
769
+ }
770
+ )
771
+
772
+ def get_running_jobs(self) -> list[str]:
773
+ """
774
+ Get a list of currently running job IDs.
775
+
776
+ Returns:
777
+ list[str]: List of running job UUIDs
778
+ """
779
+ response = self._send_server_request(uri="jobs/status", method="GET")
780
+ self._resolve_server_response(response)
781
+ return response.json().get("running_jobs", [])
712
782
 
713
783
  def remote_inference_cost(
714
784
  self, input: Union[Jobs, Survey], iterations: int = 1
@@ -810,7 +880,7 @@ class Coop:
810
880
  "Invalid EDSL_FETCH_TOKEN_PRICES value---should be 'True' or 'False'."
811
881
  )
812
882
 
813
- def fetch_models(self) -> dict:
883
+ def fetch_models(self) -> ServiceToModelsMapping:
814
884
  """
815
885
  Fetch a dict of available models from Coop.
816
886
 
@@ -819,7 +889,7 @@ class Coop:
819
889
  response = self._send_server_request(uri="api/v0/models", method="GET")
820
890
  self._resolve_server_response(response)
821
891
  data = response.json()
822
- return data
892
+ return ServiceToModelsMapping(data)
823
893
 
824
894
  def fetch_rate_limit_config_vars(self) -> dict:
825
895
  """
@@ -835,7 +905,9 @@ class Coop:
835
905
  data = response.json()
836
906
  return data
837
907
 
838
- def _display_login_url(self, edsl_auth_token: str):
908
+ def _display_login_url(
909
+ self, edsl_auth_token: str, link_description: Optional[str] = None
910
+ ):
839
911
  """
840
912
  Uses rich.print to display a login URL.
841
913
 
@@ -845,7 +917,12 @@ class Coop:
845
917
 
846
918
  url = f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"
847
919
 
848
- rich_print(f"[#38bdf8][link={url}]{url}[/link][/#38bdf8]")
920
+ if link_description:
921
+ rich_print(
922
+ f"{link_description}\n [#38bdf8][link={url}]{url}[/link][/#38bdf8]"
923
+ )
924
+ else:
925
+ rich_print(f" [#38bdf8][link={url}]{url}[/link][/#38bdf8]")
849
926
 
850
927
  def _get_api_key(self, edsl_auth_token: str):
851
928
  """
@@ -873,17 +950,18 @@ class Coop:
873
950
 
874
951
  edsl_auth_token = secrets.token_urlsafe(16)
875
952
 
876
- print(
877
- "\nUse the link below to log in to Expected Parrot so we can automatically update your API key."
953
+ self._display_login_url(
954
+ edsl_auth_token=edsl_auth_token,
955
+ link_description="\n🔗 Use the link below to log in to Expected Parrot so we can automatically update your API key.",
878
956
  )
879
- self._display_login_url(edsl_auth_token=edsl_auth_token)
880
957
  api_key = self._poll_for_api_key(edsl_auth_token)
881
958
 
882
959
  if api_key is None:
883
960
  raise Exception("Timed out waiting for login. Please try again.")
884
961
 
885
- write_api_key_to_env(api_key)
886
- print("\n✨ API key retrieved and written to .env file.")
962
+ path_to_env = write_api_key_to_env(api_key)
963
+ print("\n✨ API key retrieved and written to .env file at the following path:")
964
+ print(f" {path_to_env}")
887
965
 
888
966
  # Add API key to environment
889
967
  load_dotenv()
edsl/coop/utils.py CHANGED
@@ -1,19 +1,19 @@
1
- from edsl import (
2
- Agent,
3
- AgentList,
4
- Cache,
5
- ModelList,
6
- Notebook,
7
- Results,
8
- Scenario,
9
- ScenarioList,
10
- Survey,
11
- Study,
12
- )
13
- from edsl.language_models import LanguageModel
14
- from edsl.questions import QuestionBase
15
1
  from typing import Literal, Optional, Type, Union
16
2
 
3
+ from edsl.agents.Agent import Agent
4
+ from edsl.agents.AgentList import AgentList
5
+ from edsl.data.Cache import Cache
6
+ from edsl.language_models.ModelList import ModelList
7
+ from edsl.notebooks.Notebook import Notebook
8
+ from edsl.results.Results import Results
9
+ from edsl.scenarios.Scenario import Scenario
10
+ from edsl.scenarios.ScenarioList import ScenarioList
11
+ from edsl.surveys.Survey import Survey
12
+ from edsl.study.Study import Study
13
+
14
+ from edsl.language_models.LanguageModel import LanguageModel
15
+ from edsl.questions.QuestionBase import QuestionBase
16
+
17
17
  EDSLObject = Union[
18
18
  Agent,
19
19
  AgentList,
edsl/data/Cache.py CHANGED
@@ -6,12 +6,10 @@ from __future__ import annotations
6
6
  import json
7
7
  import os
8
8
  import warnings
9
- import copy
10
- from typing import Optional, Union
9
+ from typing import Optional, Union, TYPE_CHECKING
11
10
  from edsl.Base import Base
12
- from edsl.data.CacheEntry import CacheEntry
13
- from edsl.utilities.utilities import dict_hash
14
- from edsl.utilities.decorators import remove_edsl_version
11
+
12
+ from edsl.utilities.remove_edsl_version import remove_edsl_version
15
13
  from edsl.exceptions.cache import CacheError
16
14
 
17
15
 
@@ -83,10 +81,6 @@ class Cache(Base):
83
81
 
84
82
  self._perform_checks()
85
83
 
86
- def rich_print(sefl):
87
- pass
88
- # raise NotImplementedError("This method is not implemented yet.")
89
-
90
84
  def code(sefl):
91
85
  pass
92
86
  # raise NotImplementedError("This method is not implemented yet.")
@@ -201,6 +195,7 @@ class Cache(Base):
201
195
  >>> len(c)
202
196
  1
203
197
  """
198
+ from edsl.data.CacheEntry import CacheEntry
204
199
 
205
200
  entry = CacheEntry(
206
201
  model=model,
@@ -226,6 +221,7 @@ class Cache(Base):
226
221
 
227
222
  :param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
228
223
  """
224
+ from edsl.data.CacheEntry import CacheEntry
229
225
 
230
226
  for key, value in new_data.items():
231
227
  if key in self.data:
@@ -246,6 +242,8 @@ class Cache(Base):
246
242
 
247
243
  :param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
248
244
  """
245
+ from edsl.data.CacheEntry import CacheEntry
246
+
249
247
  with open(filename, "a+") as f:
250
248
  f.seek(0)
251
249
  lines = f.readlines()
@@ -289,8 +287,8 @@ class Cache(Base):
289
287
 
290
288
  CACHE_PATH = CONFIG.get("EDSL_DATABASE_PATH")
291
289
  path = CACHE_PATH.replace("sqlite:///", "")
292
- db_path = os.path.join(os.path.dirname(path), "data.db")
293
- return cls.from_sqlite_db(db_path=db_path)
290
+ # db_path = os.path.join(os.path.dirname(path), "data.db")
291
+ return cls.from_sqlite_db(path)
294
292
 
295
293
  @classmethod
296
294
  def from_jsonl(cls, jsonlfile: str, db_path: Optional[str] = None) -> Cache:
@@ -353,7 +351,8 @@ class Cache(Base):
353
351
  f.write(json.dumps({key: value.to_dict()}) + "\n")
354
352
 
355
353
  def to_scenario_list(self):
356
- from edsl import ScenarioList, Scenario
354
+ from edsl.scenarios.ScenarioList import ScenarioList
355
+ from edsl.scenarios.Scenario import Scenario
357
356
 
358
357
  scenarios = []
359
358
  for key, value in self.data.items():
@@ -363,12 +362,32 @@ class Cache(Base):
363
362
  scenarios.append(s)
364
363
  return ScenarioList(scenarios)
365
364
 
366
- ####################
367
- # REMOTE
368
- ####################
369
- # TODO: Make this work
370
- # - Need to decide whether the cache belongs to a user and what can be shared
371
- # - I.e., some cache entries? all or nothing?
365
+ def __floordiv__(self, other: "Cache") -> "Cache":
366
+ """
367
+ Return a new Cache containing entries that are in self but not in other.
368
+ Uses // operator as alternative to subtraction.
369
+
370
+ :param other: Another Cache object to compare against
371
+ :return: A new Cache object containing unique entries
372
+
373
+ >>> from edsl.data.CacheEntry import CacheEntry
374
+ >>> ce1 = CacheEntry.example(randomize = True)
375
+ >>> ce2 = CacheEntry.example(randomize = True)
376
+ >>> ce2 = CacheEntry.example(randomize = True)
377
+ >>> c1 = Cache(data={ce1.key: ce1, ce2.key: ce2})
378
+ >>> c2 = Cache(data={ce1.key: ce1})
379
+ >>> c3 = c1 // c2
380
+ >>> len(c3)
381
+ 1
382
+ >>> c3.data[ce2.key] == ce2
383
+ True
384
+ """
385
+ if not isinstance(other, Cache):
386
+ raise CacheError("Can only compare two caches")
387
+
388
+ diff_data = {k: v for k, v in self.data.items() if k not in other.data}
389
+ return Cache(data=diff_data, immediate_write=self.immediate_write)
390
+
372
391
  @classmethod
373
392
  def from_url(cls, db_path=None) -> Cache:
374
393
  """
@@ -394,11 +413,10 @@ class Cache(Base):
394
413
  if self.filename:
395
414
  self.write(self.filename)
396
415
 
397
- ####################
398
- # DUNDER / USEFUL
399
- ####################
400
416
  def __hash__(self):
401
417
  """Return the hash of the Cache."""
418
+ from edsl.utilities.utilities import dict_hash
419
+
402
420
  return dict_hash(self.to_dict(add_edsl_version=False))
403
421
 
404
422
  def to_dict(self, add_edsl_version=True) -> dict:
@@ -414,12 +432,6 @@ class Cache(Base):
414
432
  def _summary(self):
415
433
  return {"EDSL Class": "Cache", "Number of entries": len(self.data)}
416
434
 
417
- def _repr_html_(self):
418
- # from edsl.utilities.utilities import data_to_html
419
- # return data_to_html(self.to_dict())
420
- footer = f"<a href={self.__documentation__}>(docs)</a>"
421
- return str(self.summary(format="html")) + footer
422
-
423
435
  def table(
424
436
  self,
425
437
  *fields,
@@ -443,6 +455,8 @@ class Cache(Base):
443
455
  @remove_edsl_version
444
456
  def from_dict(cls, data) -> Cache:
445
457
  """Construct a Cache from a dictionary."""
458
+ from edsl.data.CacheEntry import CacheEntry
459
+
446
460
  newdata = {k: CacheEntry.from_dict(v) for k, v in data.items()}
447
461
  return cls(data=newdata)
448
462
 
@@ -485,6 +499,8 @@ class Cache(Base):
485
499
  """
486
500
  Create an example input for a 'fetch' operation.
487
501
  """
502
+ from edsl.data.CacheEntry import CacheEntry
503
+
488
504
  return CacheEntry.fetch_input_example()
489
505
 
490
506
  def to_html(self):
@@ -541,6 +557,8 @@ class Cache(Base):
541
557
 
542
558
  :param randomize: If True, uses CacheEntry's randomize method.
543
559
  """
560
+ from edsl.data.CacheEntry import CacheEntry
561
+
544
562
  return cls(
545
563
  data={
546
564
  CacheEntry.example(randomize).key: CacheEntry.example(),