edsl 0.1.39__py3-none-any.whl → 0.1.39.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. edsl/Base.py +116 -197
  2. edsl/__init__.py +7 -15
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +147 -351
  5. edsl/agents/AgentList.py +73 -211
  6. edsl/agents/Invigilator.py +50 -101
  7. edsl/agents/InvigilatorBase.py +70 -62
  8. edsl/agents/PromptConstructor.py +225 -143
  9. edsl/agents/__init__.py +1 -0
  10. edsl/agents/prompt_helpers.py +3 -3
  11. edsl/auto/AutoStudy.py +5 -18
  12. edsl/auto/StageBase.py +40 -53
  13. edsl/auto/StageQuestions.py +1 -2
  14. edsl/auto/utilities.py +6 -0
  15. edsl/config.py +2 -22
  16. edsl/conversation/car_buying.py +1 -2
  17. edsl/coop/PriceFetcher.py +1 -1
  18. edsl/coop/coop.py +47 -125
  19. edsl/coop/utils.py +14 -14
  20. edsl/data/Cache.py +27 -45
  21. edsl/data/CacheEntry.py +15 -12
  22. edsl/data/CacheHandler.py +12 -31
  23. edsl/data/RemoteCacheSync.py +46 -154
  24. edsl/data/__init__.py +3 -4
  25. edsl/data_transfer_models.py +1 -2
  26. edsl/enums.py +0 -27
  27. edsl/exceptions/__init__.py +50 -50
  28. edsl/exceptions/agents.py +0 -12
  29. edsl/exceptions/questions.py +6 -24
  30. edsl/exceptions/scenarios.py +0 -7
  31. edsl/inference_services/AnthropicService.py +19 -38
  32. edsl/inference_services/AwsBedrock.py +2 -0
  33. edsl/inference_services/AzureAI.py +2 -0
  34. edsl/inference_services/GoogleService.py +12 -7
  35. edsl/inference_services/InferenceServiceABC.py +85 -18
  36. edsl/inference_services/InferenceServicesCollection.py +79 -120
  37. edsl/inference_services/MistralAIService.py +3 -0
  38. edsl/inference_services/OpenAIService.py +35 -47
  39. edsl/inference_services/PerplexityService.py +3 -0
  40. edsl/inference_services/TestService.py +10 -11
  41. edsl/inference_services/TogetherAIService.py +3 -5
  42. edsl/jobs/Answers.py +14 -1
  43. edsl/jobs/Jobs.py +431 -356
  44. edsl/jobs/JobsChecks.py +10 -35
  45. edsl/jobs/JobsPrompts.py +4 -6
  46. edsl/jobs/JobsRemoteInferenceHandler.py +133 -205
  47. edsl/jobs/buckets/BucketCollection.py +3 -44
  48. edsl/jobs/buckets/TokenBucket.py +21 -53
  49. edsl/jobs/interviews/Interview.py +408 -143
  50. edsl/jobs/runners/JobsRunnerAsyncio.py +403 -88
  51. edsl/jobs/runners/JobsRunnerStatus.py +165 -133
  52. edsl/jobs/tasks/QuestionTaskCreator.py +19 -21
  53. edsl/jobs/tasks/TaskHistory.py +18 -38
  54. edsl/jobs/tasks/task_status_enum.py +2 -0
  55. edsl/language_models/KeyLookup.py +30 -0
  56. edsl/language_models/LanguageModel.py +236 -194
  57. edsl/language_models/ModelList.py +19 -28
  58. edsl/language_models/__init__.py +2 -1
  59. edsl/language_models/registry.py +190 -0
  60. edsl/language_models/repair.py +2 -2
  61. edsl/language_models/unused/ReplicateBase.py +83 -0
  62. edsl/language_models/utilities.py +4 -5
  63. edsl/notebooks/Notebook.py +14 -19
  64. edsl/prompts/Prompt.py +39 -29
  65. edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +2 -47
  66. edsl/questions/QuestionBase.py +214 -68
  67. edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +50 -57
  68. edsl/questions/QuestionBasePromptsMixin.py +3 -7
  69. edsl/questions/QuestionBudget.py +1 -1
  70. edsl/questions/QuestionCheckBox.py +3 -3
  71. edsl/questions/QuestionExtract.py +7 -5
  72. edsl/questions/QuestionFreeText.py +3 -2
  73. edsl/questions/QuestionList.py +18 -10
  74. edsl/questions/QuestionMultipleChoice.py +23 -67
  75. edsl/questions/QuestionNumerical.py +4 -2
  76. edsl/questions/QuestionRank.py +17 -7
  77. edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +26 -40
  78. edsl/questions/SimpleAskMixin.py +3 -4
  79. edsl/questions/__init__.py +1 -2
  80. edsl/questions/derived/QuestionLinearScale.py +3 -6
  81. edsl/questions/derived/QuestionTopK.py +1 -1
  82. edsl/questions/descriptors.py +3 -17
  83. edsl/questions/question_registry.py +1 -1
  84. edsl/results/CSSParameterizer.py +1 -1
  85. edsl/results/Dataset.py +7 -170
  86. edsl/results/DatasetExportMixin.py +305 -168
  87. edsl/results/DatasetTree.py +8 -28
  88. edsl/results/Result.py +206 -298
  89. edsl/results/Results.py +131 -149
  90. edsl/results/ResultsDBMixin.py +238 -0
  91. edsl/results/ResultsExportMixin.py +0 -2
  92. edsl/results/{results_selector.py → Selector.py} +13 -23
  93. edsl/results/TableDisplay.py +171 -98
  94. edsl/results/__init__.py +1 -1
  95. edsl/scenarios/FileStore.py +239 -150
  96. edsl/scenarios/Scenario.py +193 -90
  97. edsl/scenarios/ScenarioHtmlMixin.py +3 -4
  98. edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +6 -10
  99. edsl/scenarios/ScenarioList.py +244 -415
  100. edsl/scenarios/ScenarioListExportMixin.py +7 -0
  101. edsl/scenarios/ScenarioListPdfMixin.py +37 -15
  102. edsl/scenarios/__init__.py +2 -1
  103. edsl/study/ObjectEntry.py +1 -1
  104. edsl/study/SnapShot.py +1 -1
  105. edsl/study/Study.py +12 -5
  106. edsl/surveys/Rule.py +4 -5
  107. edsl/surveys/RuleCollection.py +27 -25
  108. edsl/surveys/Survey.py +791 -270
  109. edsl/surveys/SurveyCSS.py +8 -20
  110. edsl/surveys/{SurveyFlowVisualization.py → SurveyFlowVisualizationMixin.py} +9 -11
  111. edsl/surveys/__init__.py +2 -4
  112. edsl/surveys/descriptors.py +2 -6
  113. edsl/surveys/instructions/ChangeInstruction.py +2 -1
  114. edsl/surveys/instructions/Instruction.py +13 -4
  115. edsl/surveys/instructions/InstructionCollection.py +6 -11
  116. edsl/templates/error_reporting/interview_details.html +1 -1
  117. edsl/templates/error_reporting/report.html +1 -1
  118. edsl/tools/plotting.py +1 -1
  119. edsl/utilities/utilities.py +23 -35
  120. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/METADATA +10 -12
  121. edsl-0.1.39.dev1.dist-info/RECORD +277 -0
  122. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/WHEEL +1 -1
  123. edsl/agents/QuestionInstructionPromptBuilder.py +0 -128
  124. edsl/agents/QuestionTemplateReplacementsBuilder.py +0 -137
  125. edsl/agents/question_option_processor.py +0 -172
  126. edsl/coop/CoopFunctionsMixin.py +0 -15
  127. edsl/coop/ExpectedParrotKeyHandler.py +0 -125
  128. edsl/exceptions/inference_services.py +0 -5
  129. edsl/inference_services/AvailableModelCacheHandler.py +0 -184
  130. edsl/inference_services/AvailableModelFetcher.py +0 -215
  131. edsl/inference_services/ServiceAvailability.py +0 -135
  132. edsl/inference_services/data_structures.py +0 -134
  133. edsl/jobs/AnswerQuestionFunctionConstructor.py +0 -223
  134. edsl/jobs/FetchInvigilator.py +0 -47
  135. edsl/jobs/InterviewTaskManager.py +0 -98
  136. edsl/jobs/InterviewsConstructor.py +0 -50
  137. edsl/jobs/JobsComponentConstructor.py +0 -189
  138. edsl/jobs/JobsRemoteInferenceLogger.py +0 -239
  139. edsl/jobs/RequestTokenEstimator.py +0 -30
  140. edsl/jobs/async_interview_runner.py +0 -138
  141. edsl/jobs/buckets/TokenBucketAPI.py +0 -211
  142. edsl/jobs/buckets/TokenBucketClient.py +0 -191
  143. edsl/jobs/check_survey_scenario_compatibility.py +0 -85
  144. edsl/jobs/data_structures.py +0 -120
  145. edsl/jobs/decorators.py +0 -35
  146. edsl/jobs/jobs_status_enums.py +0 -9
  147. edsl/jobs/loggers/HTMLTableJobLogger.py +0 -304
  148. edsl/jobs/results_exceptions_handler.py +0 -98
  149. edsl/language_models/ComputeCost.py +0 -63
  150. edsl/language_models/PriceManager.py +0 -127
  151. edsl/language_models/RawResponseHandler.py +0 -106
  152. edsl/language_models/ServiceDataSources.py +0 -0
  153. edsl/language_models/key_management/KeyLookup.py +0 -63
  154. edsl/language_models/key_management/KeyLookupBuilder.py +0 -273
  155. edsl/language_models/key_management/KeyLookupCollection.py +0 -38
  156. edsl/language_models/key_management/__init__.py +0 -0
  157. edsl/language_models/key_management/models.py +0 -131
  158. edsl/language_models/model.py +0 -256
  159. edsl/notebooks/NotebookToLaTeX.py +0 -142
  160. edsl/questions/ExceptionExplainer.py +0 -77
  161. edsl/questions/HTMLQuestion.py +0 -103
  162. edsl/questions/QuestionMatrix.py +0 -265
  163. edsl/questions/data_structures.py +0 -20
  164. edsl/questions/loop_processor.py +0 -149
  165. edsl/questions/response_validator_factory.py +0 -34
  166. edsl/questions/templates/matrix/__init__.py +0 -1
  167. edsl/questions/templates/matrix/answering_instructions.jinja +0 -5
  168. edsl/questions/templates/matrix/question_presentation.jinja +0 -20
  169. edsl/results/MarkdownToDocx.py +0 -122
  170. edsl/results/MarkdownToPDF.py +0 -111
  171. edsl/results/TextEditor.py +0 -50
  172. edsl/results/file_exports.py +0 -252
  173. edsl/results/smart_objects.py +0 -96
  174. edsl/results/table_data_class.py +0 -12
  175. edsl/results/table_renderers.py +0 -118
  176. edsl/scenarios/ConstructDownloadLink.py +0 -109
  177. edsl/scenarios/DocumentChunker.py +0 -102
  178. edsl/scenarios/DocxScenario.py +0 -16
  179. edsl/scenarios/PdfExtractor.py +0 -40
  180. edsl/scenarios/directory_scanner.py +0 -96
  181. edsl/scenarios/file_methods.py +0 -85
  182. edsl/scenarios/handlers/__init__.py +0 -13
  183. edsl/scenarios/handlers/csv.py +0 -49
  184. edsl/scenarios/handlers/docx.py +0 -76
  185. edsl/scenarios/handlers/html.py +0 -37
  186. edsl/scenarios/handlers/json.py +0 -111
  187. edsl/scenarios/handlers/latex.py +0 -5
  188. edsl/scenarios/handlers/md.py +0 -51
  189. edsl/scenarios/handlers/pdf.py +0 -68
  190. edsl/scenarios/handlers/png.py +0 -39
  191. edsl/scenarios/handlers/pptx.py +0 -105
  192. edsl/scenarios/handlers/py.py +0 -294
  193. edsl/scenarios/handlers/sql.py +0 -313
  194. edsl/scenarios/handlers/sqlite.py +0 -149
  195. edsl/scenarios/handlers/txt.py +0 -33
  196. edsl/scenarios/scenario_selector.py +0 -156
  197. edsl/surveys/ConstructDAG.py +0 -92
  198. edsl/surveys/EditSurvey.py +0 -221
  199. edsl/surveys/InstructionHandler.py +0 -100
  200. edsl/surveys/MemoryManagement.py +0 -72
  201. edsl/surveys/RuleManager.py +0 -172
  202. edsl/surveys/Simulator.py +0 -75
  203. edsl/surveys/SurveyToApp.py +0 -141
  204. edsl/utilities/PrettyList.py +0 -56
  205. edsl/utilities/is_notebook.py +0 -18
  206. edsl/utilities/is_valid_variable_name.py +0 -11
  207. edsl/utilities/remove_edsl_version.py +0 -24
  208. edsl-0.1.39.dist-info/RECORD +0 -358
  209. /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
  210. /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
  211. /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
  212. {edsl-0.1.39.dist-info → edsl-0.1.39.dev1.dist-info}/LICENSE +0 -0
edsl/auto/StageBase.py CHANGED
@@ -1,5 +1,4 @@
1
1
  from abc import ABC, abstractmethod
2
- import json
3
2
  from typing import Dict, List, Any, TypeVar, Generator, Dict, Callable
4
3
  from dataclasses import dataclass, field, KW_ONLY, fields, asdict
5
4
  import textwrap
@@ -36,13 +35,6 @@ class FlowDataBase:
36
35
  sent_to_stage_name: str = field(default_factory=str)
37
36
  came_from_stage_name: str = field(default_factory=str)
38
37
 
39
- def to_dict(self):
40
- return asdict(self)
41
-
42
- @classmethod
43
- def from_dict(cls, data: dict):
44
- return cls(**data)
45
-
46
38
  def __getitem__(self, key):
47
39
  """Allows dictionary-style getting."""
48
40
  return getattr(self, key)
@@ -134,10 +126,6 @@ class StageBase(ABC):
134
126
  else:
135
127
  self.next_stage = None
136
128
 
137
- @classmethod
138
- def function_parameters(self):
139
- return fields(self.input)
140
-
141
129
  @classmethod
142
130
  def func(cls, **kwargs):
143
131
  "This provides a shortcut for running a stage by passing keyword arguments to the input function."
@@ -185,59 +173,58 @@ class StageBase(ABC):
185
173
 
186
174
 
187
175
  if __name__ == "__main__":
188
- pass
189
- # try:
176
+ try:
190
177
 
191
- # class StageMissing(StageBase):
192
- # def handle_data(self, data):
193
- # return data
178
+ class StageMissing(StageBase):
179
+ def handle_data(self, data):
180
+ return data
194
181
 
195
- # except NotImplementedError as e:
196
- # print(e)
197
- # else:
198
- # raise Exception("Should have raised NotImplementedError")
182
+ except NotImplementedError as e:
183
+ print(e)
184
+ else:
185
+ raise Exception("Should have raised NotImplementedError")
199
186
 
200
- # try:
187
+ try:
201
188
 
202
- # class StageMissingInput(StageBase):
203
- # output = FlowDataBase
189
+ class StageMissingInput(StageBase):
190
+ output = FlowDataBase
204
191
 
205
- # except NotImplementedError as e:
206
- # print(e)
192
+ except NotImplementedError as e:
193
+ print(e)
207
194
 
208
- # else:
209
- # raise Exception("Should have raised NotImplementedError")
195
+ else:
196
+ raise Exception("Should have raised NotImplementedError")
210
197
 
211
- # @dataclass
212
- # class MockInputOutput(FlowDataBase):
213
- # text: str
198
+ @dataclass
199
+ class MockInputOutput(FlowDataBase):
200
+ text: str
214
201
 
215
- # class StageTest(StageBase):
216
- # input = MockInputOutput
217
- # output = MockInputOutput
202
+ class StageTest(StageBase):
203
+ input = MockInputOutput
204
+ output = MockInputOutput
218
205
 
219
- # def handle_data(self, data):
220
- # return self.output(text=data["text"] + "processed")
206
+ def handle_data(self, data):
207
+ return self.output(text=data["text"] + "processed")
221
208
 
222
- # result = StageTest().process(MockInputOutput(text="Hello world!"))
223
- # print(result.text)
209
+ result = StageTest().process(MockInputOutput(text="Hello world!"))
210
+ print(result.text)
224
211
 
225
- # pipeline = StageTest(next_stage=StageTest(next_stage=StageTest()))
226
- # result = pipeline.process(MockInputOutput(text="Hello world!"))
227
- # print(result.text)
212
+ pipeline = StageTest(next_stage=StageTest(next_stage=StageTest()))
213
+ result = pipeline.process(MockInputOutput(text="Hello world!"))
214
+ print(result.text)
228
215
 
229
- # class BadMockInput(FlowDataBase):
230
- # text: str
231
- # other: str
216
+ class BadMockInput(FlowDataBase):
217
+ text: str
218
+ other: str
232
219
 
233
- # class StageBad(StageBase):
234
- # input = BadMockInput
235
- # output = BadMockInput
220
+ class StageBad(StageBase):
221
+ input = BadMockInput
222
+ output = BadMockInput
236
223
 
237
- # def handle_data(self, data):
238
- # return self.output(text=data["text"] + "processed")
224
+ def handle_data(self, data):
225
+ return self.output(text=data["text"] + "processed")
239
226
 
240
- # try:
241
- # pipeline = StageTest(next_stage=StageBad(next_stage=StageTest()))
242
- # except ExceptionPipesDoNotFit as e:
243
- # print(e)
227
+ try:
228
+ pipeline = StageTest(next_stage=StageBad(next_stage=StageTest()))
229
+ except ExceptionPipesDoNotFit as e:
230
+ print(e)
@@ -68,7 +68,6 @@ if __name__ == "__main__":
68
68
  population="Consumers",
69
69
  )
70
70
  )
71
-
72
- results = StageQuestions.func(
71
+ StageQuestions.func(
73
72
  overall_question="Why aren't my students studying more?", population="Tech"
74
73
  )
edsl/auto/utilities.py CHANGED
@@ -88,6 +88,12 @@ def agent_eligibility(
88
88
  q_eligibility(model=model, questions=questions, persona=persona, cache=cache)
89
89
  == "Yes"
90
90
  )
91
+ # results = (
92
+ # q.by(model)
93
+ # .by(Scenario({"questions": questions, "persona": persona}))
94
+ # .run(cache=cache)
95
+ # )
96
+ # return results.select("eligibility").first() == "Yes"
91
97
 
92
98
 
93
99
  def gen_agent_traits(dimension_dict: dict, seed_value: Optional[str] = None):
edsl/config.py CHANGED
@@ -1,16 +1,12 @@
1
1
  """This module provides a Config class that loads environment variables from a .env file and sets them as class attributes."""
2
2
 
3
3
  import os
4
- import platformdirs
5
4
  from dotenv import load_dotenv, find_dotenv
6
- from edsl.exceptions.configuration import (
5
+ from edsl.exceptions import (
7
6
  InvalidEnvironmentVariableError,
8
7
  MissingEnvironmentVariableError,
9
8
  )
10
9
 
11
- cache_dir = platformdirs.user_cache_dir("edsl")
12
- os.makedirs(cache_dir, exist_ok=True)
13
-
14
10
  # valid values for EDSL_RUN_MODE
15
11
  EDSL_RUN_MODES = [
16
12
  "development",
@@ -38,8 +34,7 @@ CONFIG_MAP = {
38
34
  "info": "This config var determines the maximum number of seconds to wait before retrying a failed API call.",
39
35
  },
40
36
  "EDSL_DATABASE_PATH": {
41
- # "default": f"sqlite:///{os.path.join(os.getcwd(), '.edsl_cache/data.db')}",
42
- "default": f"sqlite:///{os.path.join(platformdirs.user_cache_dir('edsl'), 'lm_model_calls.db')}",
37
+ "default": f"sqlite:///{os.path.join(os.getcwd(), '.edsl_cache/data.db')}",
43
38
  "info": "This config var determines the path to the cache file.",
44
39
  },
45
40
  "EDSL_DEFAULT_MODEL": {
@@ -74,10 +69,6 @@ CONFIG_MAP = {
74
69
  "default": "False",
75
70
  "info": "This config var determines whether to open the exception report URL in the browser",
76
71
  },
77
- "EDSL_REMOTE_TOKEN_BUCKET_URL": {
78
- "default": "None",
79
- "info": "This config var holds the URL of the remote token bucket server.",
80
- },
81
72
  }
82
73
 
83
74
 
@@ -90,9 +81,6 @@ class Config:
90
81
  self._load_dotenv()
91
82
  self._set_env_vars()
92
83
 
93
- def show_path_to_dot_env(self):
94
- print(find_dotenv(usecwd=True))
95
-
96
84
  def _set_run_mode(self) -> None:
97
85
  """
98
86
  Sets EDSL_RUN_MODE as a class attribute.
@@ -156,14 +144,6 @@ class Config:
156
144
  raise MissingEnvironmentVariableError(f"{env_var} is not set. {info}")
157
145
  return self.__dict__.get(env_var)
158
146
 
159
- def __iter__(self):
160
- """Iterate over the environment variables."""
161
- return iter(self.__dict__)
162
-
163
- def items(self):
164
- """Iterate over the environment variables and their values."""
165
- return self.__dict__.items()
166
-
167
147
  def show(self) -> str:
168
148
  """Print the currently set environment vars."""
169
149
  max_env_var_length = max(len(env_var) for env_var in self.__dict__)
@@ -29,8 +29,7 @@ a3 = Agent(
29
29
  c1 = Conversation(agent_list=AgentList([a1, a3, a2]), max_turns=5, verbose=True)
30
30
  c2 = Conversation(agent_list=AgentList([a1, a2]), max_turns=5, verbose=True)
31
31
 
32
- # c = Cache.load("car_talk.json.gz")
33
- c = Cache()
32
+ c = Cache.load("car_talk.json.gz")
34
33
  # breakpoint()
35
34
  combo = ConversationList([c1, c2], cache=c)
36
35
  combo.run()
edsl/coop/PriceFetcher.py CHANGED
@@ -18,7 +18,7 @@ class PriceFetcher:
18
18
 
19
19
  import os
20
20
  import requests
21
- from edsl.config import CONFIG
21
+ from edsl import CONFIG
22
22
 
23
23
  try:
24
24
  # Fetch the pricing data
edsl/coop/coop.py CHANGED
@@ -1,19 +1,11 @@
1
1
  import aiohttp
2
2
  import json
3
+ import os
3
4
  import requests
4
-
5
- from typing import Any, Optional, Union, Literal, TypedDict
5
+ from typing import Any, Optional, Union, Literal
6
6
  from uuid import UUID
7
- from collections import UserDict, defaultdict
8
-
9
7
  import edsl
10
- from pathlib import Path
11
-
12
- from edsl.config import CONFIG
13
- from edsl.data.CacheEntry import CacheEntry
14
- from edsl.jobs.Jobs import Jobs
15
- from edsl.surveys.Survey import Survey
16
-
8
+ from edsl import CONFIG, CacheEntry, Jobs, Survey
17
9
  from edsl.exceptions.coop import CoopNoUUIDError, CoopServerResponseError
18
10
  from edsl.coop.utils import (
19
11
  EDSLObject,
@@ -23,48 +15,19 @@ from edsl.coop.utils import (
23
15
  VisibilityType,
24
16
  )
25
17
 
26
- from edsl.coop.CoopFunctionsMixin import CoopFunctionsMixin
27
- from edsl.coop.ExpectedParrotKeyHandler import ExpectedParrotKeyHandler
28
-
29
- from edsl.inference_services.data_structures import ServiceToModelsMapping
30
-
31
18
 
32
- class RemoteInferenceResponse(TypedDict):
33
- job_uuid: str
34
- results_uuid: str
35
- results_url: str
36
- latest_error_report_uuid: str
37
- latest_error_report_url: str
38
- status: str
39
- reason: str
40
- credits_consumed: float
41
- version: str
42
-
43
-
44
- class RemoteInferenceCreationInfo(TypedDict):
45
- uuid: str
46
- description: str
47
- status: str
48
- iterations: int
49
- visibility: str
50
- version: str
51
-
52
-
53
- class Coop(CoopFunctionsMixin):
19
+ class Coop:
54
20
  """
55
21
  Client for the Expected Parrot API.
56
22
  """
57
23
 
58
- def __init__(
59
- self, api_key: Optional[str] = None, url: Optional[str] = None
60
- ) -> None:
24
+ def __init__(self, api_key: str = None, url: str = None) -> None:
61
25
  """
62
26
  Initialize the client.
63
27
  - Provide an API key directly, or through an env variable.
64
28
  - Provide a URL directly, or use the default one.
65
29
  """
66
- self.ep_key_handler = ExpectedParrotKeyHandler()
67
- self.api_key = api_key or self.ep_key_handler.get_ep_api_key()
30
+ self.api_key = api_key or os.getenv("EXPECTED_PARROT_API_KEY")
68
31
 
69
32
  self.url = url or CONFIG.EXPECTED_PARROT_URL
70
33
  if self.url.endswith("/"):
@@ -179,7 +142,6 @@ class Coop(CoopFunctionsMixin):
179
142
  Check the response from the server and raise errors as appropriate.
180
143
  """
181
144
  # Get EDSL version from header
182
- # breakpoint()
183
145
  server_edsl_version = response.headers.get("X-EDSL-Version")
184
146
 
185
147
  if server_edsl_version:
@@ -188,18 +150,11 @@ class Coop(CoopFunctionsMixin):
188
150
  server_version_str=server_edsl_version,
189
151
  ):
190
152
  print(
191
- "Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip install --upgrade edsl`"
153
+ "Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip upgrade edsl`"
192
154
  )
193
155
 
194
156
  if response.status_code >= 400:
195
- try:
196
- message = response.json().get("detail")
197
- except json.JSONDecodeError:
198
- raise CoopServerResponseError(
199
- f"Server returned status code {response.status_code}."
200
- "JSON response could not be decoded.",
201
- "The server response was: " + response.text,
202
- )
157
+ message = response.json().get("detail")
203
158
  # print(response.text)
204
159
  if "The API key you provided is invalid" in message and check_api_key:
205
160
  import secrets
@@ -208,27 +163,19 @@ class Coop(CoopFunctionsMixin):
208
163
  edsl_auth_token = secrets.token_urlsafe(16)
209
164
 
210
165
  print("Your Expected Parrot API key is invalid.")
211
- self._display_login_url(
212
- edsl_auth_token=edsl_auth_token,
213
- link_description="\n🔗 Use the link below to log in to Expected Parrot so we can automatically update your API key.",
166
+ print(
167
+ "\nUse the link below to log in to Expected Parrot so we can automatically update your API key."
214
168
  )
169
+ self._display_login_url(edsl_auth_token=edsl_auth_token)
215
170
  api_key = self._poll_for_api_key(edsl_auth_token)
216
171
 
217
172
  if api_key is None:
218
173
  print("\nTimed out waiting for login. Please try again.")
219
174
  return
220
175
 
221
- print("\n✨ API key retrieved.")
222
-
223
- if stored_in_user_space := self.ep_key_handler.ask_to_store(api_key):
224
- pass
225
- else:
226
- path_to_env = write_api_key_to_env(api_key)
227
- print(
228
- "\n✨ API key retrieved and written to .env file at the following path:"
229
- )
230
- print(f" {path_to_env}")
231
- print("Rerun your code to try again with a valid API key.")
176
+ write_api_key_to_env(api_key)
177
+ print("\n✨ API key retrieved and written to .env file.")
178
+ print("Rerun your code to try again with a valid API key.")
232
179
  return
233
180
 
234
181
  elif "Authorization" in message:
@@ -321,7 +268,6 @@ class Coop(CoopFunctionsMixin):
321
268
  self,
322
269
  object: EDSLObject,
323
270
  description: Optional[str] = None,
324
- alias: Optional[str] = None,
325
271
  visibility: Optional[VisibilityType] = "unlisted",
326
272
  ) -> dict:
327
273
  """
@@ -333,7 +279,6 @@ class Coop(CoopFunctionsMixin):
333
279
  method="POST",
334
280
  payload={
335
281
  "description": description,
336
- "alias": alias,
337
282
  "json_string": json.dumps(
338
283
  object.to_dict(),
339
284
  default=self._json_handle_none,
@@ -428,7 +373,6 @@ class Coop(CoopFunctionsMixin):
428
373
  uuid: Union[str, UUID] = None,
429
374
  url: str = None,
430
375
  description: Optional[str] = None,
431
- alias: Optional[str] = None,
432
376
  value: Optional[EDSLObject] = None,
433
377
  visibility: Optional[VisibilityType] = None,
434
378
  ) -> dict:
@@ -445,7 +389,6 @@ class Coop(CoopFunctionsMixin):
445
389
  params={"uuid": uuid},
446
390
  payload={
447
391
  "description": description,
448
- "alias": alias,
449
392
  "json_string": (
450
393
  json.dumps(
451
394
  value.to_dict(),
@@ -659,6 +602,9 @@ class Coop(CoopFunctionsMixin):
659
602
  self._resolve_server_response(response)
660
603
  return response.json()
661
604
 
605
+ ################
606
+ # Remote Inference
607
+ ################
662
608
  def remote_inference_create(
663
609
  self,
664
610
  job: Jobs,
@@ -667,7 +613,7 @@ class Coop(CoopFunctionsMixin):
667
613
  visibility: Optional[VisibilityType] = "unlisted",
668
614
  initial_results_visibility: Optional[VisibilityType] = "unlisted",
669
615
  iterations: Optional[int] = 1,
670
- ) -> RemoteInferenceCreationInfo:
616
+ ) -> dict:
671
617
  """
672
618
  Send a remote inference job to the server.
673
619
 
@@ -699,21 +645,18 @@ class Coop(CoopFunctionsMixin):
699
645
  )
700
646
  self._resolve_server_response(response)
701
647
  response_json = response.json()
702
-
703
- return RemoteInferenceCreationInfo(
704
- **{
705
- "uuid": response_json.get("job_uuid"),
706
- "description": response_json.get("description"),
707
- "status": response_json.get("status"),
708
- "iterations": response_json.get("iterations"),
709
- "visibility": response_json.get("visibility"),
710
- "version": self._edsl_version,
711
- }
712
- )
648
+ return {
649
+ "uuid": response_json.get("job_uuid"),
650
+ "description": response_json.get("description"),
651
+ "status": response_json.get("status"),
652
+ "iterations": response_json.get("iterations"),
653
+ "visibility": response_json.get("visibility"),
654
+ "version": self._edsl_version,
655
+ }
713
656
 
714
657
  def remote_inference_get(
715
658
  self, job_uuid: Optional[str] = None, results_uuid: Optional[str] = None
716
- ) -> RemoteInferenceResponse:
659
+ ) -> dict:
717
660
  """
718
661
  Get the details of a remote inference job.
719
662
  You can pass either the job uuid or the results uuid as a parameter.
@@ -755,30 +698,17 @@ class Coop(CoopFunctionsMixin):
755
698
  f"{self.url}/home/remote-inference/error/{latest_error_report_uuid}"
756
699
  )
757
700
 
758
- return RemoteInferenceResponse(
759
- **{
760
- "job_uuid": data.get("job_uuid"),
761
- "results_uuid": results_uuid,
762
- "results_url": results_url,
763
- "latest_error_report_uuid": latest_error_report_uuid,
764
- "latest_error_report_url": latest_error_report_url,
765
- "status": data.get("status"),
766
- "reason": data.get("reason"),
767
- "credits_consumed": data.get("price"),
768
- "version": data.get("version"),
769
- }
770
- )
771
-
772
- def get_running_jobs(self) -> list[str]:
773
- """
774
- Get a list of currently running job IDs.
775
-
776
- Returns:
777
- list[str]: List of running job UUIDs
778
- """
779
- response = self._send_server_request(uri="jobs/status", method="GET")
780
- self._resolve_server_response(response)
781
- return response.json().get("running_jobs", [])
701
+ return {
702
+ "job_uuid": data.get("job_uuid"),
703
+ "results_uuid": results_uuid,
704
+ "results_url": results_url,
705
+ "latest_error_report_uuid": latest_error_report_uuid,
706
+ "latest_error_report_url": latest_error_report_url,
707
+ "status": data.get("status"),
708
+ "reason": data.get("reason"),
709
+ "credits_consumed": data.get("price"),
710
+ "version": data.get("version"),
711
+ }
782
712
 
783
713
  def remote_inference_cost(
784
714
  self, input: Union[Jobs, Survey], iterations: int = 1
@@ -880,7 +810,7 @@ class Coop(CoopFunctionsMixin):
880
810
  "Invalid EDSL_FETCH_TOKEN_PRICES value---should be 'True' or 'False'."
881
811
  )
882
812
 
883
- def fetch_models(self) -> ServiceToModelsMapping:
813
+ def fetch_models(self) -> dict:
884
814
  """
885
815
  Fetch a dict of available models from Coop.
886
816
 
@@ -889,7 +819,7 @@ class Coop(CoopFunctionsMixin):
889
819
  response = self._send_server_request(uri="api/v0/models", method="GET")
890
820
  self._resolve_server_response(response)
891
821
  data = response.json()
892
- return ServiceToModelsMapping(data)
822
+ return data
893
823
 
894
824
  def fetch_rate_limit_config_vars(self) -> dict:
895
825
  """
@@ -905,9 +835,7 @@ class Coop(CoopFunctionsMixin):
905
835
  data = response.json()
906
836
  return data
907
837
 
908
- def _display_login_url(
909
- self, edsl_auth_token: str, link_description: Optional[str] = None
910
- ):
838
+ def _display_login_url(self, edsl_auth_token: str):
911
839
  """
912
840
  Uses rich.print to display a login URL.
913
841
 
@@ -917,12 +845,7 @@ class Coop(CoopFunctionsMixin):
917
845
 
918
846
  url = f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"
919
847
 
920
- if link_description:
921
- rich_print(
922
- f"{link_description}\n [#38bdf8][link={url}]{url}[/link][/#38bdf8]"
923
- )
924
- else:
925
- rich_print(f" [#38bdf8][link={url}]{url}[/link][/#38bdf8]")
848
+ rich_print(f"[#38bdf8][link={url}]{url}[/link][/#38bdf8]")
926
849
 
927
850
  def _get_api_key(self, edsl_auth_token: str):
928
851
  """
@@ -950,18 +873,17 @@ class Coop(CoopFunctionsMixin):
950
873
 
951
874
  edsl_auth_token = secrets.token_urlsafe(16)
952
875
 
953
- self._display_login_url(
954
- edsl_auth_token=edsl_auth_token,
955
- link_description="\n🔗 Use the link below to log in to Expected Parrot so we can automatically update your API key.",
876
+ print(
877
+ "\nUse the link below to log in to Expected Parrot so we can automatically update your API key."
956
878
  )
879
+ self._display_login_url(edsl_auth_token=edsl_auth_token)
957
880
  api_key = self._poll_for_api_key(edsl_auth_token)
958
881
 
959
882
  if api_key is None:
960
883
  raise Exception("Timed out waiting for login. Please try again.")
961
884
 
962
- path_to_env = write_api_key_to_env(api_key)
963
- print("\n✨ API key retrieved and written to .env file at the following path:")
964
- print(f" {path_to_env}")
885
+ write_api_key_to_env(api_key)
886
+ print("\n✨ API key retrieved and written to .env file.")
965
887
 
966
888
  # Add API key to environment
967
889
  load_dotenv()
edsl/coop/utils.py CHANGED
@@ -1,19 +1,19 @@
1
+ from edsl import (
2
+ Agent,
3
+ AgentList,
4
+ Cache,
5
+ ModelList,
6
+ Notebook,
7
+ Results,
8
+ Scenario,
9
+ ScenarioList,
10
+ Survey,
11
+ Study,
12
+ )
13
+ from edsl.language_models import LanguageModel
14
+ from edsl.questions import QuestionBase
1
15
  from typing import Literal, Optional, Type, Union
2
16
 
3
- from edsl.agents.Agent import Agent
4
- from edsl.agents.AgentList import AgentList
5
- from edsl.data.Cache import Cache
6
- from edsl.language_models.ModelList import ModelList
7
- from edsl.notebooks.Notebook import Notebook
8
- from edsl.results.Results import Results
9
- from edsl.scenarios.Scenario import Scenario
10
- from edsl.scenarios.ScenarioList import ScenarioList
11
- from edsl.surveys.Survey import Survey
12
- from edsl.study.Study import Study
13
-
14
- from edsl.language_models.LanguageModel import LanguageModel
15
- from edsl.questions.QuestionBase import QuestionBase
16
-
17
17
  EDSLObject = Union[
18
18
  Agent,
19
19
  AgentList,