edsl 0.1.49__py3-none-any.whl → 0.1.51__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (257)
  1. edsl/__init__.py +124 -53
  2. edsl/__version__.py +1 -1
  3. edsl/agents/agent.py +21 -21
  4. edsl/agents/agent_list.py +2 -5
  5. edsl/agents/exceptions.py +119 -5
  6. edsl/base/__init__.py +10 -35
  7. edsl/base/base_class.py +71 -36
  8. edsl/base/base_exception.py +204 -0
  9. edsl/base/data_transfer_models.py +1 -1
  10. edsl/base/exceptions.py +94 -0
  11. edsl/buckets/__init__.py +15 -1
  12. edsl/buckets/bucket_collection.py +3 -4
  13. edsl/buckets/exceptions.py +107 -0
  14. edsl/buckets/model_buckets.py +1 -2
  15. edsl/buckets/token_bucket.py +11 -6
  16. edsl/buckets/token_bucket_api.py +27 -12
  17. edsl/buckets/token_bucket_client.py +9 -7
  18. edsl/caching/cache.py +12 -4
  19. edsl/caching/cache_entry.py +10 -9
  20. edsl/caching/exceptions.py +113 -7
  21. edsl/caching/remote_cache_sync.py +6 -7
  22. edsl/caching/sql_dict.py +20 -14
  23. edsl/cli.py +43 -0
  24. edsl/config/__init__.py +1 -1
  25. edsl/config/config_class.py +32 -6
  26. edsl/conversation/Conversation.py +8 -4
  27. edsl/conversation/car_buying.py +1 -3
  28. edsl/conversation/exceptions.py +58 -0
  29. edsl/conversation/mug_negotiation.py +2 -8
  30. edsl/coop/__init__.py +28 -6
  31. edsl/coop/coop.py +120 -29
  32. edsl/coop/coop_functions.py +1 -1
  33. edsl/coop/ep_key_handling.py +1 -1
  34. edsl/coop/exceptions.py +188 -9
  35. edsl/coop/price_fetcher.py +5 -8
  36. edsl/coop/utils.py +4 -6
  37. edsl/dataset/__init__.py +5 -4
  38. edsl/dataset/dataset.py +177 -86
  39. edsl/dataset/dataset_operations_mixin.py +98 -76
  40. edsl/dataset/dataset_tree.py +11 -7
  41. edsl/dataset/display/table_display.py +0 -2
  42. edsl/dataset/display/table_renderers.py +6 -4
  43. edsl/dataset/exceptions.py +125 -0
  44. edsl/dataset/file_exports.py +18 -11
  45. edsl/dataset/r/ggplot.py +13 -6
  46. edsl/display/__init__.py +27 -0
  47. edsl/display/core.py +147 -0
  48. edsl/display/plugin.py +189 -0
  49. edsl/display/utils.py +52 -0
  50. edsl/inference_services/__init__.py +9 -1
  51. edsl/inference_services/available_model_cache_handler.py +1 -1
  52. edsl/inference_services/available_model_fetcher.py +5 -6
  53. edsl/inference_services/data_structures.py +10 -7
  54. edsl/inference_services/exceptions.py +132 -1
  55. edsl/inference_services/inference_service_abc.py +2 -2
  56. edsl/inference_services/inference_services_collection.py +2 -6
  57. edsl/inference_services/registry.py +4 -3
  58. edsl/inference_services/service_availability.py +4 -3
  59. edsl/inference_services/services/anthropic_service.py +4 -1
  60. edsl/inference_services/services/aws_bedrock.py +13 -12
  61. edsl/inference_services/services/azure_ai.py +12 -10
  62. edsl/inference_services/services/deep_infra_service.py +1 -4
  63. edsl/inference_services/services/deep_seek_service.py +1 -5
  64. edsl/inference_services/services/google_service.py +7 -3
  65. edsl/inference_services/services/groq_service.py +1 -1
  66. edsl/inference_services/services/mistral_ai_service.py +4 -2
  67. edsl/inference_services/services/ollama_service.py +1 -1
  68. edsl/inference_services/services/open_ai_service.py +7 -5
  69. edsl/inference_services/services/perplexity_service.py +6 -2
  70. edsl/inference_services/services/test_service.py +8 -7
  71. edsl/inference_services/services/together_ai_service.py +2 -3
  72. edsl/inference_services/services/xai_service.py +1 -1
  73. edsl/instructions/__init__.py +1 -1
  74. edsl/instructions/change_instruction.py +7 -5
  75. edsl/instructions/exceptions.py +61 -0
  76. edsl/instructions/instruction.py +6 -2
  77. edsl/instructions/instruction_collection.py +6 -4
  78. edsl/instructions/instruction_handler.py +12 -15
  79. edsl/interviews/ReportErrors.py +0 -3
  80. edsl/interviews/__init__.py +9 -2
  81. edsl/interviews/answering_function.py +11 -13
  82. edsl/interviews/exception_tracking.py +15 -8
  83. edsl/interviews/exceptions.py +79 -0
  84. edsl/interviews/interview.py +33 -30
  85. edsl/interviews/interview_status_dictionary.py +4 -2
  86. edsl/interviews/interview_status_log.py +2 -1
  87. edsl/interviews/interview_task_manager.py +5 -5
  88. edsl/interviews/request_token_estimator.py +5 -2
  89. edsl/interviews/statistics.py +3 -4
  90. edsl/invigilators/__init__.py +7 -1
  91. edsl/invigilators/exceptions.py +79 -0
  92. edsl/invigilators/invigilator_base.py +0 -1
  93. edsl/invigilators/invigilators.py +9 -13
  94. edsl/invigilators/prompt_constructor.py +1 -5
  95. edsl/invigilators/prompt_helpers.py +8 -4
  96. edsl/invigilators/question_instructions_prompt_builder.py +1 -1
  97. edsl/invigilators/question_option_processor.py +9 -5
  98. edsl/invigilators/question_template_replacements_builder.py +3 -2
  99. edsl/jobs/__init__.py +42 -5
  100. edsl/jobs/async_interview_runner.py +25 -23
  101. edsl/jobs/check_survey_scenario_compatibility.py +11 -10
  102. edsl/jobs/data_structures.py +8 -5
  103. edsl/jobs/exceptions.py +177 -8
  104. edsl/jobs/fetch_invigilator.py +1 -1
  105. edsl/jobs/jobs.py +74 -69
  106. edsl/jobs/jobs_checks.py +6 -7
  107. edsl/jobs/jobs_component_constructor.py +4 -4
  108. edsl/jobs/jobs_pricing_estimation.py +4 -3
  109. edsl/jobs/jobs_remote_inference_logger.py +5 -4
  110. edsl/jobs/jobs_runner_asyncio.py +3 -4
  111. edsl/jobs/jobs_runner_status.py +8 -9
  112. edsl/jobs/remote_inference.py +27 -24
  113. edsl/jobs/results_exceptions_handler.py +10 -7
  114. edsl/key_management/__init__.py +3 -1
  115. edsl/key_management/exceptions.py +62 -0
  116. edsl/key_management/key_lookup.py +1 -1
  117. edsl/key_management/key_lookup_builder.py +37 -14
  118. edsl/key_management/key_lookup_collection.py +2 -0
  119. edsl/language_models/__init__.py +1 -1
  120. edsl/language_models/exceptions.py +302 -14
  121. edsl/language_models/language_model.py +9 -8
  122. edsl/language_models/model.py +4 -4
  123. edsl/language_models/model_list.py +1 -1
  124. edsl/language_models/price_manager.py +1 -1
  125. edsl/language_models/raw_response_handler.py +14 -9
  126. edsl/language_models/registry.py +17 -21
  127. edsl/language_models/repair.py +0 -6
  128. edsl/language_models/unused/fake_openai_service.py +0 -1
  129. edsl/load_plugins.py +69 -0
  130. edsl/logger.py +146 -0
  131. edsl/notebooks/__init__.py +24 -1
  132. edsl/notebooks/exceptions.py +82 -0
  133. edsl/notebooks/notebook.py +7 -3
  134. edsl/notebooks/notebook_to_latex.py +1 -2
  135. edsl/plugins/__init__.py +63 -0
  136. edsl/plugins/built_in/export_example.py +50 -0
  137. edsl/plugins/built_in/pig_latin.py +67 -0
  138. edsl/plugins/cli.py +372 -0
  139. edsl/plugins/cli_typer.py +283 -0
  140. edsl/plugins/exceptions.py +31 -0
  141. edsl/plugins/hookspec.py +51 -0
  142. edsl/plugins/plugin_host.py +128 -0
  143. edsl/plugins/plugin_manager.py +633 -0
  144. edsl/plugins/plugins_registry.py +168 -0
  145. edsl/prompts/__init__.py +24 -1
  146. edsl/prompts/exceptions.py +107 -5
  147. edsl/prompts/prompt.py +15 -7
  148. edsl/questions/HTMLQuestion.py +5 -11
  149. edsl/questions/Quick.py +0 -1
  150. edsl/questions/__init__.py +6 -4
  151. edsl/questions/answer_validator_mixin.py +318 -323
  152. edsl/questions/compose_questions.py +3 -3
  153. edsl/questions/descriptors.py +11 -50
  154. edsl/questions/exceptions.py +278 -22
  155. edsl/questions/loop_processor.py +7 -5
  156. edsl/questions/prompt_templates/question_list.jinja +3 -0
  157. edsl/questions/question_base.py +46 -19
  158. edsl/questions/question_base_gen_mixin.py +2 -2
  159. edsl/questions/question_base_prompts_mixin.py +13 -7
  160. edsl/questions/question_budget.py +503 -98
  161. edsl/questions/question_check_box.py +660 -160
  162. edsl/questions/question_dict.py +345 -194
  163. edsl/questions/question_extract.py +401 -61
  164. edsl/questions/question_free_text.py +80 -14
  165. edsl/questions/question_functional.py +119 -9
  166. edsl/questions/{derived/question_likert_five.py → question_likert_five.py} +2 -2
  167. edsl/questions/{derived/question_linear_scale.py → question_linear_scale.py} +3 -4
  168. edsl/questions/question_list.py +275 -28
  169. edsl/questions/question_matrix.py +643 -96
  170. edsl/questions/question_multiple_choice.py +219 -51
  171. edsl/questions/question_numerical.py +361 -32
  172. edsl/questions/question_rank.py +401 -124
  173. edsl/questions/question_registry.py +7 -5
  174. edsl/questions/{derived/question_top_k.py → question_top_k.py} +3 -3
  175. edsl/questions/{derived/question_yes_no.py → question_yes_no.py} +3 -4
  176. edsl/questions/register_questions_meta.py +2 -2
  177. edsl/questions/response_validator_abc.py +13 -15
  178. edsl/questions/response_validator_factory.py +10 -12
  179. edsl/questions/templates/dict/answering_instructions.jinja +1 -0
  180. edsl/questions/templates/rank/question_presentation.jinja +1 -1
  181. edsl/results/__init__.py +1 -1
  182. edsl/results/exceptions.py +141 -7
  183. edsl/results/report.py +1 -2
  184. edsl/results/result.py +11 -9
  185. edsl/results/results.py +480 -321
  186. edsl/results/results_selector.py +8 -4
  187. edsl/scenarios/PdfExtractor.py +2 -2
  188. edsl/scenarios/construct_download_link.py +69 -35
  189. edsl/scenarios/directory_scanner.py +33 -14
  190. edsl/scenarios/document_chunker.py +1 -1
  191. edsl/scenarios/exceptions.py +238 -14
  192. edsl/scenarios/file_methods.py +1 -1
  193. edsl/scenarios/file_store.py +7 -3
  194. edsl/scenarios/handlers/__init__.py +17 -0
  195. edsl/scenarios/handlers/docx_file_store.py +0 -5
  196. edsl/scenarios/handlers/pdf_file_store.py +0 -1
  197. edsl/scenarios/handlers/pptx_file_store.py +0 -5
  198. edsl/scenarios/handlers/py_file_store.py +0 -1
  199. edsl/scenarios/handlers/sql_file_store.py +1 -4
  200. edsl/scenarios/handlers/sqlite_file_store.py +0 -1
  201. edsl/scenarios/handlers/txt_file_store.py +1 -1
  202. edsl/scenarios/scenario.py +1 -3
  203. edsl/scenarios/scenario_list.py +179 -27
  204. edsl/scenarios/scenario_list_pdf_tools.py +1 -0
  205. edsl/scenarios/scenario_selector.py +0 -1
  206. edsl/surveys/__init__.py +3 -4
  207. edsl/surveys/dag/__init__.py +4 -2
  208. edsl/surveys/descriptors.py +1 -1
  209. edsl/surveys/edit_survey.py +1 -0
  210. edsl/surveys/exceptions.py +165 -9
  211. edsl/surveys/memory/__init__.py +5 -3
  212. edsl/surveys/memory/memory_management.py +1 -0
  213. edsl/surveys/memory/memory_plan.py +6 -15
  214. edsl/surveys/rules/__init__.py +5 -3
  215. edsl/surveys/rules/rule.py +1 -2
  216. edsl/surveys/rules/rule_collection.py +1 -1
  217. edsl/surveys/survey.py +12 -24
  218. edsl/surveys/survey_css.py +3 -3
  219. edsl/surveys/survey_export.py +6 -3
  220. edsl/surveys/survey_flow_visualization.py +10 -1
  221. edsl/surveys/survey_simulator.py +2 -1
  222. edsl/tasks/__init__.py +23 -1
  223. edsl/tasks/exceptions.py +72 -0
  224. edsl/tasks/question_task_creator.py +3 -3
  225. edsl/tasks/task_creators.py +1 -3
  226. edsl/tasks/task_history.py +8 -10
  227. edsl/tasks/task_status_log.py +1 -2
  228. edsl/tokens/__init__.py +29 -1
  229. edsl/tokens/exceptions.py +37 -0
  230. edsl/tokens/interview_token_usage.py +3 -2
  231. edsl/tokens/token_usage.py +4 -3
  232. edsl/utilities/__init__.py +21 -1
  233. edsl/utilities/decorators.py +1 -2
  234. edsl/utilities/markdown_to_docx.py +2 -2
  235. edsl/utilities/markdown_to_pdf.py +1 -1
  236. edsl/utilities/repair_functions.py +0 -1
  237. edsl/utilities/restricted_python.py +0 -1
  238. edsl/utilities/template_loader.py +2 -3
  239. edsl/utilities/utilities.py +8 -29
  240. {edsl-0.1.49.dist-info → edsl-0.1.51.dist-info}/METADATA +32 -2
  241. edsl-0.1.51.dist-info/RECORD +365 -0
  242. edsl-0.1.51.dist-info/entry_points.txt +3 -0
  243. edsl/dataset/smart_objects.py +0 -96
  244. edsl/exceptions/BaseException.py +0 -21
  245. edsl/exceptions/__init__.py +0 -54
  246. edsl/exceptions/configuration.py +0 -16
  247. edsl/exceptions/general.py +0 -34
  248. edsl/questions/derived/__init__.py +0 -0
  249. edsl/study/ObjectEntry.py +0 -173
  250. edsl/study/ProofOfWork.py +0 -113
  251. edsl/study/SnapShot.py +0 -80
  252. edsl/study/Study.py +0 -520
  253. edsl/study/__init__.py +0 -6
  254. edsl/utilities/interface.py +0 -135
  255. edsl-0.1.49.dist-info/RECORD +0 -347
  256. {edsl-0.1.49.dist-info → edsl-0.1.51.dist-info}/LICENSE +0 -0
  257. {edsl-0.1.49.dist-info → edsl-0.1.51.dist-info}/WHEEL +0 -0
@@ -3,10 +3,18 @@
3
3
  import os
4
4
  import platformdirs
5
5
  from dotenv import load_dotenv, find_dotenv
6
- from edsl.exceptions.configuration import (
7
- InvalidEnvironmentVariableError,
8
- MissingEnvironmentVariableError,
9
- )
6
+ from ..base import BaseException
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ class InvalidEnvironmentVariableError(BaseException):
12
+ """Raised when an environment variable is invalid."""
13
+ pass
14
+
15
+ class MissingEnvironmentVariableError(BaseException):
16
+ """Raised when an expected environment variable is missing."""
17
+ pass
10
18
 
11
19
  cache_dir = platformdirs.user_cache_dir("edsl")
12
20
  os.makedirs(cache_dir, exist_ok=True)
@@ -50,6 +58,10 @@ CONFIG_MAP = {
50
58
  "default": "True",
51
59
  "info": "This config var determines whether to fetch prices for tokens used in remote inference",
52
60
  },
61
+ "EDSL_LOG_LEVEL": {
62
+ "default": "ERROR",
63
+ "info": "This config var determines the logging level for the EDSL package (DEBUG, INFO, WARNING, ERROR, CRITICAL).",
64
+ },
53
65
  "EDSL_MAX_ATTEMPTS": {
54
66
  "default": "5",
55
67
  "info": "This config var determines the maximum number of times to retry a failed API call.",
@@ -86,9 +98,11 @@ class Config:
86
98
 
87
99
  def __init__(self):
88
100
  """Initialize the Config class."""
101
+ logger.debug("Initializing Config class")
89
102
  self._set_run_mode()
90
103
  self._load_dotenv()
91
104
  self._set_env_vars()
105
+ logger.info(f"Config initialized with run mode: {self.EDSL_RUN_MODE}")
92
106
 
93
107
  def show_path_to_dot_env(self):
94
108
  print(find_dotenv(usecwd=True))
@@ -101,7 +115,12 @@ class Config:
101
115
  default = CONFIG_MAP.get("EDSL_RUN_MODE").get("default")
102
116
  if run_mode is None:
103
117
  run_mode = default
118
+ logger.debug(f"EDSL_RUN_MODE not set, using default: {default}")
119
+ else:
120
+ logger.debug(f"EDSL_RUN_MODE set to: {run_mode}")
121
+
104
122
  if run_mode not in EDSL_RUN_MODES:
123
+ logger.error(f"Invalid EDSL_RUN_MODE: {run_mode}")
105
124
  raise InvalidEnvironmentVariableError(
106
125
  f"Value `{run_mode}` is not allowed for EDSL_RUN_MODE."
107
126
  )
@@ -149,12 +168,19 @@ class Config:
149
168
  """
150
169
  Returns the value of an environment variable.
151
170
  """
171
+ logger.debug(f"Getting config value for: {env_var}")
172
+
152
173
  if env_var not in CONFIG_MAP:
174
+ logger.error(f"Invalid environment variable requested: {env_var}")
153
175
  raise InvalidEnvironmentVariableError(f"{env_var} is not a valid env var. ")
154
176
  elif env_var not in self.__dict__:
155
177
  info = CONFIG_MAP[env_var].get("info")
178
+ logger.error(f"Missing environment variable: {env_var}")
156
179
  raise MissingEnvironmentVariableError(f"{env_var} is not set. {info}")
157
- return self.__dict__.get(env_var)
180
+
181
+ value = self.__dict__.get(env_var)
182
+ logger.debug(f"Config value for {env_var}: {value}")
183
+ return value
158
184
 
159
185
  def __iter__(self):
160
186
  """Iterate over the environment variables."""
@@ -174,4 +200,4 @@ class Config:
174
200
 
175
201
  # Note: Python modules are singletons. As such, once this module is imported
176
202
  # the same instance of it is reused across the application.
177
- CONFIG = Config()
203
+ CONFIG = Config()
@@ -1,13 +1,16 @@
1
1
  from collections import UserList
2
2
  import asyncio
3
3
  import inspect
4
- from typing import Optional, Callable
5
- from .. import Agent, QuestionFreeText, Results, AgentList, ScenarioList, Scenario
4
+ from typing import Optional, Callable, TYPE_CHECKING
5
+ from .. import QuestionFreeText, Results, AgentList, ScenarioList, Scenario, Model
6
6
  from ..questions import QuestionBase
7
7
  from ..results.Result import Result
8
8
  from jinja2 import Template
9
9
  from ..caching import Cache
10
10
 
11
+ if TYPE_CHECKING:
12
+ from ..language_models.model import Model
13
+
11
14
  from .next_speaker_utilities import (
12
15
  default_turn_taking_generator,
13
16
  speaker_closure,
@@ -71,7 +74,7 @@ class Conversation:
71
74
  conversation_index: Optional[int] = None,
72
75
  cache=None,
73
76
  disable_remote_inference=False,
74
- default_model: Optional["LanguageModel"] = None,
77
+ default_model: Optional[Model] = None,
75
78
  ):
76
79
  self.disable_remote_inference = disable_remote_inference
77
80
  self.per_round_message_template = per_round_message_template
@@ -120,7 +123,8 @@ What do you say next?"""
120
123
  per_round_message_template
121
124
  and "{{ round_message }}" not in next_statement_question.question_text
122
125
  ):
123
- raise ValueError(
126
+ from .exceptions import ConversationValueError
127
+ raise ConversationValueError(
124
128
  "If you pass in a per_round_message_template, you must include {{ round_message }} in the question_text."
125
129
  )
126
130
 
@@ -1,5 +1,4 @@
1
- from .. import Agent, AgentList, QuestionFreeText
2
- from .. import Cache
1
+ from .. import Agent, AgentList, QuestionFreeText, Cache, QuestionList
3
2
  from .Conversation import Conversation, ConversationList
4
3
 
5
4
  a1 = Agent(
@@ -46,7 +45,6 @@ q = QuestionFreeText(
46
45
  question_name="car_brand",
47
46
  )
48
47
 
49
- from .. import QuestionList
50
48
 
51
49
  q_actors = QuestionList(
52
50
  question_text="""This was a conversation about buying a car: {{ transcript }}.
@@ -0,0 +1,58 @@
1
+ """
2
+ Exceptions for the conversation module.
3
+
4
+ This module defines custom exceptions for the conversation module,
5
+ including errors for invalid participant configurations, agent interaction
6
+ failures, and conversation state errors.
7
+ """
8
+
9
+ from ..base import BaseException
10
+
11
+
12
+ class ConversationError(BaseException):
13
+ """
14
+ Base exception class for all conversation-related errors.
15
+
16
+ This is the parent class for all exceptions related to conversation
17
+ operations, including agent communication, turn management, and
18
+ participant configuration.
19
+ """
20
+ relevant_doc = "https://docs.expectedparrot.com/"
21
+
22
+
23
+ class ConversationValueError(ConversationError):
24
+ """
25
+ Exception raised when an invalid value is provided to a conversation.
26
+
27
+ This exception occurs when attempting to create or modify a conversation
28
+ with invalid values, such as:
29
+ - Invalid participant configurations
30
+ - Inappropriate agent parameters
31
+ - Incompatible conversation settings
32
+
33
+ Examples:
34
+ ```python
35
+ # Attempting to add an invalid participant to a conversation
36
+ conversation.add_participant(None) # Raises ConversationValueError
37
+ ```
38
+ """
39
+ relevant_doc = "https://docs.expectedparrot.com/"
40
+
41
+
42
+ class ConversationStateError(ConversationError):
43
+ """
44
+ Exception raised when the conversation is in an invalid state.
45
+
46
+ This exception occurs when attempting to perform an operation that
47
+ is incompatible with the current state of the conversation, such as:
48
+ - Ending a conversation that hasn't started
49
+ - Starting a conversation that's already in progress
50
+ - Accessing a participant that doesn't exist
51
+
52
+ Examples:
53
+ ```python
54
+ # Attempting to get the next speaker when the conversation is empty
55
+ empty_conversation.next_speaker() # Raises ConversationStateError
56
+ ```
57
+ """
58
+ relevant_doc = "https://docs.expectedparrot.com/"
@@ -1,5 +1,5 @@
1
- from edsl import Agent, AgentList
2
- from edsl.conversation.Conversation import Conversation, ConversationList
1
+ from .. import Agent, AgentList, QuestionYesNo, QuestionNumerical
2
+ from .Conversation import Conversation, ConversationList
3
3
 
4
4
 
5
5
  def bargaining_pairs(alice_valuation, bob_valuation):
@@ -43,12 +43,6 @@ results.select("conversation_index", "index", "agent_name", "dialogue").print(
43
43
  format="rich"
44
44
  )
45
45
 
46
- from edsl import (
47
- QuestionFreeText,
48
- QuestionMultipleChoice,
49
- QuestionYesNo,
50
- QuestionNumerical,
51
- )
52
46
 
53
47
  q_deal = QuestionYesNo(
54
48
  question_text="""This was a negotiation: {{ transcript }}.
edsl/coop/__init__.py CHANGED
@@ -8,18 +8,40 @@ This module enables EDSL to interact with cloud-based resources for enhanced fun
8
8
  3. Caching of interview results for improved performance and cost savings
9
9
  4. API key management and authentication
10
10
  5. Price and model availability information
11
+ 6. Plugin registry and discovery
11
12
 
12
13
  The primary interface is the Coop class, which serves as a client for the
13
14
  Expected Parrot API. Most users will only need to interact with the Coop class directly.
14
15
 
15
- Example:
16
- >>> from edsl.coop import Coop
17
- >>> coop = Coop() # Uses API key from environment or stored location
18
- >>> survey = my_survey.push() # Uploads survey to Expected Parrot
19
- >>> job_info = coop.remote_inference_create(my_job) # Creates remote job
16
+ Examples:
17
+
18
+ ```python
19
+ from edsl.coop import Coop
20
+ coop = Coop() # Uses API key from environment or stored location
21
+ survey = my_survey.push() # Uploads survey to Expected Parrot
22
+ job_info = coop.remote_inference_create(my_job) # Creates remote job
23
+
24
+ # Working with plugins
25
+ from edsl.coop import get_available_plugins
26
+ plugins = get_available_plugins()
27
+ plugin_names = [p.name for p in plugins]
28
+ ```
20
29
  """
21
30
 
22
31
  from .utils import EDSLObject, ObjectType, VisibilityType, ObjectRegistry
23
32
  from .coop import Coop
24
33
  from .exceptions import CoopServerResponseError
25
- __all__ = ["Coop"]
34
+
35
+ __all__ = [
36
+ "Coop",
37
+ "EDSLObject",
38
+ "ObjectType",
39
+ "VisibilityType",
40
+ "ObjectRegistry",
41
+ "CoopServerResponseError",
42
+ "AvailablePlugin",
43
+ "get_available_plugins",
44
+ "search_plugins",
45
+ "get_plugin_details",
46
+ "PluginRegistryError"
47
+ ]
edsl/coop/coop.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import aiohttp
2
+ import base64
2
3
  import json
3
4
  import requests
4
5
 
@@ -140,7 +141,7 @@ class Coop(CoopFunctionsMixin):
140
141
  if self.api_key:
141
142
  headers["Authorization"] = f"Bearer {self.api_key}"
142
143
  else:
143
- headers["Authorization"] = f"Bearer None"
144
+ headers["Authorization"] = "Bearer None"
144
145
  return headers
145
146
 
146
147
  def _send_server_request(
@@ -149,7 +150,7 @@ class Coop(CoopFunctionsMixin):
149
150
  method: str,
150
151
  payload: Optional[dict[str, Any]] = None,
151
152
  params: Optional[dict[str, Any]] = None,
152
- timeout: Optional[float] = 5,
153
+ timeout: Optional[float] = 10,
153
154
  ) -> requests.Response:
154
155
  """
155
156
  Send a request to the server and return the response.
@@ -159,7 +160,7 @@ class Coop(CoopFunctionsMixin):
159
160
  if payload is None:
160
161
  timeout = 40
161
162
  elif (
162
- method.upper() == "POST"
163
+ (method.upper() == "POST" or method.upper() == "PATCH")
163
164
  and "json_string" in payload
164
165
  and payload.get("json_string") is not None
165
166
  ):
@@ -179,7 +180,9 @@ class Coop(CoopFunctionsMixin):
179
180
  timeout=timeout,
180
181
  )
181
182
  else:
182
- raise Exception(f"Invalid {method=}.")
183
+ from .exceptions import CoopInvalidMethodError
184
+
185
+ raise CoopInvalidMethodError(f"Invalid {method=}.")
183
186
  except requests.ConnectionError:
184
187
  raise requests.ConnectionError(f"Could not connect to the server at {url}.")
185
188
 
@@ -226,7 +229,8 @@ class Coop(CoopFunctionsMixin):
226
229
  """
227
230
  # Get EDSL version from header
228
231
  # breakpoint()
229
- server_edsl_version = response.headers.get("X-EDSL-Version")
232
+ # Commented out as currently unused
233
+ # server_edsl_version = response.headers.get("X-EDSL-Version")
230
234
 
231
235
  # if server_edsl_version:
232
236
  # if self._user_version_is_outdated(
@@ -266,7 +270,7 @@ class Coop(CoopFunctionsMixin):
266
270
 
267
271
  print("\n✨ API key retrieved.")
268
272
 
269
- if stored_in_user_space := self.ep_key_handler.ask_to_store(api_key):
273
+ if self.ep_key_handler.ask_to_store(api_key):
270
274
  pass
271
275
  else:
272
276
  path_to_env = write_api_key_to_env(api_key)
@@ -299,13 +303,19 @@ class Coop(CoopFunctionsMixin):
299
303
  message = root.find("Message").text
300
304
  details = root.find("Details").text
301
305
  except Exception:
302
- raise Exception(
303
- f"Server returned status code {response.status_code}",
304
- "XML response could not be decoded.",
305
- "The server response was: " + response.text,
306
+ from .exceptions import CoopServerResponseError
307
+
308
+ raise CoopServerResponseError(
309
+ f"Server returned status code {response.status_code}. "
310
+ f"XML response could not be decoded. "
311
+ f"The server response was: {response.text}"
306
312
  )
307
313
 
308
- raise Exception(f"An error occurred: {code} - {message} - {details}")
314
+ from .exceptions import CoopServerResponseError
315
+
316
+ raise CoopServerResponseError(
317
+ f"An error occurred: {code} - {message} - {details}"
318
+ )
309
319
 
310
320
  def _poll_for_api_key(
311
321
  self, edsl_auth_token: str, timeout: int = 120
@@ -432,6 +442,23 @@ class Coop(CoopFunctionsMixin):
432
442
  else:
433
443
  return None
434
444
 
445
+ def _scenario_is_file_store(self, scenario_dict: dict) -> bool:
446
+ """
447
+ Check if the scenario object is a valid FileStore.
448
+
449
+ Matches keys in the scenario dict against the expected keys for a FileStore.
450
+ """
451
+ file_store_keys = [
452
+ "path",
453
+ "base64_string",
454
+ "binary",
455
+ "suffix",
456
+ "mime_type",
457
+ "external_locations",
458
+ "extracted_text",
459
+ ]
460
+ return all(key in scenario_dict.keys() for key in file_store_keys)
461
+
435
462
  def create(
436
463
  self,
437
464
  object: EDSLObject,
@@ -471,21 +498,30 @@ class Coop(CoopFunctionsMixin):
471
498
  >>> print(result["url"]) # URL to access the survey
472
499
  """
473
500
  object_type = ObjectRegistry.get_object_type_by_edsl_class(object)
501
+ object_dict = object.to_dict()
502
+ if object_type == "scenario" and self._scenario_is_file_store(object_dict):
503
+ file_store_metadata = {
504
+ "suffix": object_dict["suffix"],
505
+ "mime_type": object_dict["mime_type"],
506
+ }
507
+ else:
508
+ file_store_metadata = None
474
509
  response = self._send_server_request(
475
- uri=f"api/v0/object",
510
+ uri="api/v0/object",
476
511
  method="POST",
477
512
  payload={
478
513
  "description": description,
479
514
  "alias": alias,
480
515
  "json_string": (
481
516
  json.dumps(
482
- object.to_dict(),
517
+ object_dict,
483
518
  default=self._json_handle_none,
484
519
  )
485
520
  if object_type != "scenario"
486
521
  else ""
487
522
  ),
488
523
  "object_type": object_type,
524
+ "file_store_metadata": file_store_metadata,
489
525
  "visibility": visibility,
490
526
  "version": self._edsl_version,
491
527
  },
@@ -495,19 +531,57 @@ class Coop(CoopFunctionsMixin):
495
531
 
496
532
  if object_type == "scenario":
497
533
  json_data = json.dumps(
498
- object.to_dict(),
534
+ object_dict,
499
535
  default=self._json_handle_none,
500
536
  )
501
537
  headers = {"Content-Type": "application/json"}
502
538
  if response_json.get("upload_signed_url"):
503
539
  signed_url = response_json.get("upload_signed_url")
504
540
  else:
505
- raise Exception("No signed url provided received")
541
+ from .exceptions import CoopResponseError
542
+
543
+ raise CoopResponseError("No signed url was provided received")
506
544
 
507
545
  response = requests.put(
508
546
  signed_url, data=json_data.encode(), headers=headers
509
547
  )
510
548
  self._resolve_gcs_response(response)
549
+
550
+ file_store_upload_signed_url = response_json.get(
551
+ "file_store_upload_signed_url"
552
+ )
553
+ if file_store_metadata and not file_store_upload_signed_url:
554
+ from .exceptions import CoopResponseError
555
+
556
+ raise CoopResponseError("No file store signed url provided.")
557
+ elif file_store_metadata:
558
+ headers = {"Content-Type": file_store_metadata["mime_type"]}
559
+ # Lint json files prior to upload
560
+ if file_store_metadata["suffix"] == "json":
561
+ file_store_bytes = base64.b64decode(object_dict["base64_string"])
562
+ pretty_json_string = json.dumps(
563
+ json.loads(file_store_bytes), indent=4
564
+ )
565
+ byte_data = pretty_json_string.encode("utf-8")
566
+ # Lint python files prior to upload
567
+ elif file_store_metadata["suffix"] == "py":
568
+ import black
569
+
570
+ file_store_bytes = base64.b64decode(object_dict["base64_string"])
571
+ python_string = file_store_bytes.decode("utf-8")
572
+ formatted_python_string = black.format_str(
573
+ python_string, mode=black.Mode()
574
+ )
575
+ byte_data = formatted_python_string.encode("utf-8")
576
+ else:
577
+ byte_data = base64.b64decode(object_dict["base64_string"])
578
+ response = requests.put(
579
+ file_store_upload_signed_url,
580
+ data=byte_data,
581
+ headers=headers,
582
+ )
583
+ self._resolve_gcs_response(response)
584
+
511
585
  owner_username = response_json.get("owner_username")
512
586
  object_alias = response_json.get("alias")
513
587
 
@@ -519,7 +593,6 @@ class Coop(CoopFunctionsMixin):
519
593
  "uuid": response_json.get("uuid"),
520
594
  "version": self._edsl_version,
521
595
  "visibility": response_json.get("visibility"),
522
- "upload_signed_url": response_json.get("upload_signed_url", None),
523
596
  }
524
597
 
525
598
  def get(
@@ -566,13 +639,13 @@ class Coop(CoopFunctionsMixin):
566
639
 
567
640
  if obj_uuid:
568
641
  response = self._send_server_request(
569
- uri=f"api/v0/object",
642
+ uri="api/v0/object",
570
643
  method="GET",
571
644
  params={"uuid": obj_uuid},
572
645
  )
573
646
  else:
574
647
  response = self._send_server_request(
575
- uri=f"api/v0/object/alias",
648
+ uri="api/v0/object/alias",
576
649
  method="GET",
577
650
  params={"owner_username": owner_username, "alias": alias},
578
651
  )
@@ -586,7 +659,11 @@ class Coop(CoopFunctionsMixin):
586
659
  json_string = object_data.text
587
660
  object_type = response.json().get("object_type")
588
661
  if expected_object_type and object_type != expected_object_type:
589
- raise Exception(f"Expected {expected_object_type=} but got {object_type=}")
662
+ from .exceptions import CoopObjectTypeError
663
+
664
+ raise CoopObjectTypeError(
665
+ f"Expected {expected_object_type=} but got {object_type=}"
666
+ )
590
667
  edsl_class = ObjectRegistry.object_type_to_edsl_class.get(object_type)
591
668
  object = edsl_class.from_dict(json.loads(json_string))
592
669
  return object
@@ -597,7 +674,7 @@ class Coop(CoopFunctionsMixin):
597
674
  """
598
675
  edsl_class = ObjectRegistry.object_type_to_edsl_class.get(object_type)
599
676
  response = self._send_server_request(
600
- uri=f"api/v0/objects",
677
+ uri="api/v0/objects",
601
678
  method="GET",
602
679
  params={"type": object_type},
603
680
  )
@@ -677,7 +754,9 @@ class Coop(CoopFunctionsMixin):
677
754
  and value is None
678
755
  and alias is None
679
756
  ):
680
- raise Exception("Nothing to patch.")
757
+ from .exceptions import CoopPatchError
758
+
759
+ raise CoopPatchError("Nothing to patch.")
681
760
 
682
761
  obj_uuid, owner_username, obj_alias = self._resolve_uuid_or_alias(url_or_uuid)
683
762
 
@@ -808,7 +887,9 @@ class Coop(CoopFunctionsMixin):
808
887
  [CacheEntry(...), CacheEntry(...), ...]
809
888
  """
810
889
  if job_uuid is None:
811
- raise ValueError("Must provide a job_uuid.")
890
+ from .exceptions import CoopValueError
891
+
892
+ raise CoopValueError("Must provide a job_uuid.")
812
893
  response = self._send_server_request(
813
894
  uri="api/v0/remote-cache/get-many-by-job",
814
895
  method="POST",
@@ -836,7 +917,9 @@ class Coop(CoopFunctionsMixin):
836
917
  [CacheEntry(...), CacheEntry(...), ...]
837
918
  """
838
919
  if select_keys is None or len(select_keys) == 0:
839
- raise ValueError("Must provide a non-empty list of select_keys.")
920
+ from .exceptions import CoopValueError
921
+
922
+ raise CoopValueError("Must provide a non-empty list of select_keys.")
840
923
  response = self._send_server_request(
841
924
  uri="api/v0/remote-cache/get-many-by-key",
842
925
  method="POST",
@@ -1099,7 +1182,9 @@ class Coop(CoopFunctionsMixin):
1099
1182
  ... print(f"Results available at: {job_status['results_url']}")
1100
1183
  """
1101
1184
  if job_uuid is None and results_uuid is None:
1102
- raise ValueError("Either job_uuid or results_uuid must be provided.")
1185
+ from .exceptions import CoopValueError
1186
+
1187
+ raise CoopValueError("Either job_uuid or results_uuid must be provided.")
1103
1188
  elif job_uuid is not None:
1104
1189
  params = {"job_uuid": job_uuid}
1105
1190
  else:
@@ -1136,7 +1221,7 @@ class Coop(CoopFunctionsMixin):
1136
1221
  "latest_error_report_uuid": latest_error_report_uuid,
1137
1222
  "latest_error_report_url": latest_error_report_url,
1138
1223
  "status": data.get("status"),
1139
- "reason": data.get("reason"),
1224
+ "reason": data.get("latest_failure_reason"),
1140
1225
  "credits_consumed": data.get("price"),
1141
1226
  "version": data.get("version"),
1142
1227
  }
@@ -1173,7 +1258,9 @@ class Coop(CoopFunctionsMixin):
1173
1258
  elif isinstance(input, Survey):
1174
1259
  job = Jobs(survey=input)
1175
1260
  else:
1176
- raise TypeError("Input must be either a Job or a Survey.")
1261
+ from .exceptions import CoopTypeError
1262
+
1263
+ raise CoopTypeError("Input must be either a Job or a Survey.")
1177
1264
 
1178
1265
  response = self._send_server_request(
1179
1266
  uri="api/v0/remote-inference/cost",
@@ -1215,7 +1302,7 @@ class Coop(CoopFunctionsMixin):
1215
1302
  )
1216
1303
  survey_uuid = survey_details.get("uuid")
1217
1304
  response = self._send_server_request(
1218
- uri=f"api/v0/projects/create-from-survey",
1305
+ uri="api/v0/projects/create-from-survey",
1219
1306
  method="POST",
1220
1307
  payload={"project_name": project_name, "survey_uuid": str(survey_uuid)},
1221
1308
  )
@@ -1308,7 +1395,9 @@ class Coop(CoopFunctionsMixin):
1308
1395
  elif CONFIG.get("EDSL_FETCH_TOKEN_PRICES") == "False":
1309
1396
  return {}
1310
1397
  else:
1311
- raise ValueError(
1398
+ from .exceptions import CoopValueError
1399
+
1400
+ raise CoopValueError(
1312
1401
  "Invalid EDSL_FETCH_TOKEN_PRICES value---should be 'True' or 'False'."
1313
1402
  )
1314
1403
 
@@ -1464,7 +1553,9 @@ class Coop(CoopFunctionsMixin):
1464
1553
  api_key = self._poll_for_api_key(edsl_auth_token)
1465
1554
 
1466
1555
  if api_key is None:
1467
- raise Exception("Timed out waiting for login. Please try again.")
1556
+ from .exceptions import CoopTimeoutError
1557
+
1558
+ raise CoopTimeoutError("Timed out waiting for login. Please try again.")
1468
1559
 
1469
1560
  path_to_env = write_api_key_to_env(api_key)
1470
1561
  print("\n✨ API key retrieved and written to .env file at the following path:")
@@ -1,6 +1,6 @@
1
1
  class CoopFunctionsMixin:
2
2
  def better_names(self, existing_names):
3
- from edsl import QuestionList, Scenario
3
+ from .. import QuestionList, Scenario
4
4
 
5
5
  s = Scenario({"existing_names": existing_names})
6
6
  q = QuestionList(
@@ -70,7 +70,7 @@ class ExpectedParrotKeyHandler:
70
70
 
71
71
  def ok_to_ask_to_store(self):
72
72
  """Check if it's okay to ask the user to store the key."""
73
- from edsl.config import CONFIG
73
+ from ..config import CONFIG
74
74
 
75
75
  if CONFIG.get("EDSL_RUN_MODE") != "production":
76
76
  return False