edsl 0.1.50__py3-none-any.whl → 0.1.51__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109) hide show
  1. edsl/__version__.py +1 -1
  2. edsl/base/base_exception.py +2 -2
  3. edsl/buckets/bucket_collection.py +1 -1
  4. edsl/buckets/exceptions.py +32 -0
  5. edsl/buckets/token_bucket_api.py +26 -10
  6. edsl/caching/cache.py +5 -2
  7. edsl/caching/remote_cache_sync.py +5 -5
  8. edsl/caching/sql_dict.py +12 -11
  9. edsl/config/__init__.py +1 -1
  10. edsl/config/config_class.py +4 -2
  11. edsl/conversation/Conversation.py +7 -4
  12. edsl/conversation/car_buying.py +1 -3
  13. edsl/conversation/mug_negotiation.py +2 -6
  14. edsl/coop/__init__.py +11 -8
  15. edsl/coop/coop.py +13 -13
  16. edsl/coop/coop_functions.py +1 -1
  17. edsl/coop/ep_key_handling.py +1 -1
  18. edsl/coop/price_fetcher.py +2 -2
  19. edsl/coop/utils.py +2 -2
  20. edsl/dataset/dataset.py +144 -63
  21. edsl/dataset/dataset_operations_mixin.py +14 -6
  22. edsl/dataset/dataset_tree.py +3 -3
  23. edsl/dataset/display/table_renderers.py +6 -3
  24. edsl/dataset/file_exports.py +4 -4
  25. edsl/dataset/r/ggplot.py +3 -3
  26. edsl/inference_services/available_model_fetcher.py +2 -2
  27. edsl/inference_services/data_structures.py +5 -5
  28. edsl/inference_services/inference_service_abc.py +1 -1
  29. edsl/inference_services/inference_services_collection.py +1 -1
  30. edsl/inference_services/service_availability.py +3 -3
  31. edsl/inference_services/services/azure_ai.py +3 -3
  32. edsl/inference_services/services/google_service.py +1 -1
  33. edsl/inference_services/services/test_service.py +1 -1
  34. edsl/instructions/change_instruction.py +5 -4
  35. edsl/instructions/instruction.py +1 -0
  36. edsl/instructions/instruction_collection.py +5 -4
  37. edsl/instructions/instruction_handler.py +10 -8
  38. edsl/interviews/exception_tracking.py +1 -1
  39. edsl/interviews/interview.py +1 -1
  40. edsl/interviews/interview_status_dictionary.py +1 -1
  41. edsl/interviews/interview_task_manager.py +2 -2
  42. edsl/interviews/request_token_estimator.py +3 -2
  43. edsl/interviews/statistics.py +2 -2
  44. edsl/invigilators/invigilators.py +2 -2
  45. edsl/jobs/__init__.py +39 -2
  46. edsl/jobs/async_interview_runner.py +1 -1
  47. edsl/jobs/check_survey_scenario_compatibility.py +5 -5
  48. edsl/jobs/data_structures.py +2 -2
  49. edsl/jobs/jobs.py +2 -2
  50. edsl/jobs/jobs_checks.py +5 -5
  51. edsl/jobs/jobs_component_constructor.py +2 -2
  52. edsl/jobs/jobs_pricing_estimation.py +1 -1
  53. edsl/jobs/jobs_runner_asyncio.py +2 -2
  54. edsl/jobs/remote_inference.py +1 -1
  55. edsl/jobs/results_exceptions_handler.py +2 -2
  56. edsl/language_models/language_model.py +5 -1
  57. edsl/notebooks/__init__.py +24 -1
  58. edsl/notebooks/exceptions.py +82 -0
  59. edsl/notebooks/notebook.py +7 -3
  60. edsl/notebooks/notebook_to_latex.py +1 -1
  61. edsl/prompts/__init__.py +23 -2
  62. edsl/prompts/prompt.py +1 -1
  63. edsl/questions/__init__.py +4 -4
  64. edsl/questions/answer_validator_mixin.py +0 -5
  65. edsl/questions/compose_questions.py +2 -2
  66. edsl/questions/descriptors.py +1 -1
  67. edsl/questions/question_base.py +32 -3
  68. edsl/questions/question_base_prompts_mixin.py +4 -4
  69. edsl/questions/question_budget.py +503 -102
  70. edsl/questions/question_check_box.py +658 -156
  71. edsl/questions/question_dict.py +176 -2
  72. edsl/questions/question_extract.py +401 -61
  73. edsl/questions/question_free_text.py +77 -9
  74. edsl/questions/question_functional.py +118 -9
  75. edsl/questions/{derived/question_likert_five.py → question_likert_five.py} +2 -2
  76. edsl/questions/{derived/question_linear_scale.py → question_linear_scale.py} +3 -4
  77. edsl/questions/question_list.py +246 -26
  78. edsl/questions/question_matrix.py +586 -73
  79. edsl/questions/question_multiple_choice.py +213 -47
  80. edsl/questions/question_numerical.py +360 -29
  81. edsl/questions/question_rank.py +401 -124
  82. edsl/questions/question_registry.py +3 -3
  83. edsl/questions/{derived/question_top_k.py → question_top_k.py} +3 -3
  84. edsl/questions/{derived/question_yes_no.py → question_yes_no.py} +3 -4
  85. edsl/questions/register_questions_meta.py +2 -1
  86. edsl/questions/response_validator_abc.py +6 -2
  87. edsl/questions/response_validator_factory.py +10 -12
  88. edsl/results/report.py +1 -1
  89. edsl/results/result.py +7 -4
  90. edsl/results/results.py +471 -271
  91. edsl/results/results_selector.py +2 -2
  92. edsl/scenarios/construct_download_link.py +3 -3
  93. edsl/scenarios/scenario.py +1 -2
  94. edsl/scenarios/scenario_list.py +41 -23
  95. edsl/surveys/survey_css.py +3 -3
  96. edsl/surveys/survey_simulator.py +2 -1
  97. edsl/tasks/__init__.py +22 -2
  98. edsl/tasks/exceptions.py +72 -0
  99. edsl/tasks/task_history.py +3 -3
  100. edsl/tokens/__init__.py +27 -1
  101. edsl/tokens/exceptions.py +37 -0
  102. edsl/tokens/interview_token_usage.py +3 -2
  103. edsl/tokens/token_usage.py +4 -3
  104. {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/METADATA +1 -1
  105. {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/RECORD +108 -106
  106. edsl/questions/derived/__init__.py +0 -0
  107. {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/LICENSE +0 -0
  108. {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/WHEEL +0 -0
  109. {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/entry_points.txt +0 -0
edsl/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.50"
1
+ __version__ = "0.1.51"
@@ -146,7 +146,7 @@ class BaseException(Exception):
146
146
  # )
147
147
  # )
148
148
  # except:
149
- print(f"❌ EDSL ERROR: {etype.__name__}: {evalue}", file=sys.stderr)
149
+ print(f"❌ E[🦃] EDSL ERROR: {etype.__name__}: {evalue}", file=sys.stderr)
150
150
  # Suppress IPython’s normal traceback
151
151
  return
152
152
  # Otherwise, fall back to the usual traceback
@@ -178,7 +178,7 @@ class BaseException(Exception):
178
178
  # )
179
179
  # except:
180
180
  print(
181
- f"❌ EDSL ERROR: {exc_type.__name__}: {exc_value}",
181
+ f"❌ E[🦃]EDSL ERROR: {exc_type.__name__}: {exc_value}",
182
182
  exc_traceback,
183
183
  file=sys.stderr,
184
184
  )
@@ -135,7 +135,7 @@ class BucketCollection(UserDict):
135
135
  >>> # The following would raise an exception:
136
136
  >>> # bucket_collection.get_tokens(m, 'tokens', 10)
137
137
  """
138
- from edsl.buckets.exceptions import BucketError
138
+ from .exceptions import BucketError
139
139
  raise BucketError("This method is deprecated and should not be used")
140
140
  # The following code is kept for reference only
141
141
  # relevant_bucket = getattr(self[model], bucket_type)
@@ -72,4 +72,36 @@ class BucketConfigurationError(BucketError):
72
72
  TokenBucket(name="test", capacity=-100, refill_rate=10) # Would raise BucketConfigurationError
73
73
  ```
74
74
  """
75
+ relevant_doc = "https://docs.expectedparrot.com/"
76
+
77
+
78
+ class BucketNotFoundError(BucketError):
79
+ """
80
+ Exception raised when a requested bucket cannot be found.
81
+
82
+ This exception occurs when attempting to access a bucket that doesn't exist
83
+ in the system or has been removed.
84
+
85
+ Examples:
86
+ ```python
87
+ # Attempting to access a non-existent bucket:
88
+ bucket_collection.get_bucket("non_existent_bucket") # Would raise BucketNotFoundError
89
+ ```
90
+ """
91
+ relevant_doc = "https://docs.expectedparrot.com/"
92
+
93
+
94
+ class InvalidBucketParameterError(BucketConfigurationError):
95
+ """
96
+ Exception raised when an invalid parameter is provided for bucket operations.
97
+
98
+ This exception occurs when providing invalid parameters to bucket methods,
99
+ such as negative token amounts, invalid capacity values, etc.
100
+
101
+ Examples:
102
+ ```python
103
+ # Attempting to use invalid parameters:
104
+ bucket.add_tokens(-100) # Would raise InvalidBucketParameterError
105
+ ```
106
+ """
75
107
  relevant_doc = "https://docs.expectedparrot.com/"
@@ -1,9 +1,11 @@
1
- from fastapi import FastAPI, HTTPException
1
+ from fastapi import FastAPI
2
+ from fastapi.responses import JSONResponse
2
3
  from pydantic import BaseModel
3
4
  from typing import Union, Dict
4
5
  from typing import Optional
5
6
 
6
7
  from .token_bucket import TokenBucket # Original implementation
8
+ from .exceptions import BucketNotFoundError, InvalidBucketParameterError
7
9
 
8
10
  def safe_float_for_json(value: float) -> Union[float, str]:
9
11
  """Convert float('inf') to 'infinity' for JSON serialization.
@@ -24,6 +26,20 @@ app = FastAPI()
24
26
  # In-memory storage for TokenBucket instances
25
27
  buckets: Dict[str, TokenBucket] = {}
26
28
 
29
+ @app.exception_handler(BucketNotFoundError)
30
+ async def bucket_not_found_handler(request, exc):
31
+ return JSONResponse(
32
+ status_code=404,
33
+ content={"detail": str(exc)},
34
+ )
35
+
36
+ @app.exception_handler(InvalidBucketParameterError)
37
+ async def invalid_parameter_handler(request, exc):
38
+ return JSONResponse(
39
+ status_code=400,
40
+ content={"detail": str(exc)},
41
+ )
42
+
27
43
 
28
44
  class TokenBucketCreate(BaseModel):
29
45
  bucket_name: str
@@ -83,13 +99,13 @@ async def list_buckets(
83
99
  async def add_tokens(bucket_id: str, amount: float):
84
100
  """Add tokens to an existing bucket."""
85
101
  if bucket_id not in buckets:
86
- raise HTTPException(status_code=404, detail="Bucket not found")
102
+ raise BucketNotFoundError(f"Bucket with ID '{bucket_id}' not found")
87
103
 
88
104
  if not isinstance(amount, (int, float)) or amount != amount: # Check for NaN
89
- raise HTTPException(status_code=400, detail="Invalid amount specified")
105
+ raise InvalidBucketParameterError("Invalid amount specified")
90
106
 
91
107
  if amount == float("inf") or amount == float("-inf"):
92
- raise HTTPException(status_code=400, detail="Amount cannot be infinite")
108
+ raise InvalidBucketParameterError("Amount cannot be infinite")
93
109
 
94
110
  bucket = buckets[bucket_id]
95
111
  bucket.add_tokens(amount)
@@ -124,14 +140,14 @@ async def create_bucket(bucket: TokenBucketCreate):
124
140
  not isinstance(bucket.capacity, (int, float))
125
141
  or bucket.capacity != bucket.capacity
126
142
  ): # Check for NaN
127
- raise HTTPException(status_code=400, detail="Invalid capacity value")
143
+ raise InvalidBucketParameterError("Invalid capacity value")
128
144
  if (
129
145
  not isinstance(bucket.refill_rate, (int, float))
130
146
  or bucket.refill_rate != bucket.refill_rate
131
147
  ): # Check for NaN
132
- raise HTTPException(status_code=400, detail="Invalid refill rate value")
148
+ raise InvalidBucketParameterError("Invalid refill rate value")
133
149
  if bucket.capacity == float("inf") or bucket.refill_rate == float("inf"):
134
- raise HTTPException(status_code=400, detail="Values cannot be infinite")
150
+ raise InvalidBucketParameterError("Values cannot be infinite")
135
151
  bucket_id = f"{bucket.bucket_name}_{bucket.bucket_type}"
136
152
  if bucket_id in buckets:
137
153
  # Instead of error, return success with "existing" status
@@ -156,7 +172,7 @@ async def create_bucket(bucket: TokenBucketCreate):
156
172
  @app.post("/bucket/{bucket_id}/get_tokens")
157
173
  async def get_tokens(bucket_id: str, amount: float, cheat_bucket_capacity: bool = True):
158
174
  if bucket_id not in buckets:
159
- raise HTTPException(status_code=404, detail="Bucket not found")
175
+ raise BucketNotFoundError(f"Bucket with ID '{bucket_id}' not found")
160
176
 
161
177
  bucket = buckets[bucket_id]
162
178
  await bucket.get_tokens(amount, cheat_bucket_capacity)
@@ -166,7 +182,7 @@ async def get_tokens(bucket_id: str, amount: float, cheat_bucket_capacity: bool
166
182
  @app.post("/bucket/{bucket_id}/turbo_mode/{state}")
167
183
  async def set_turbo_mode(bucket_id: str, state: bool):
168
184
  if bucket_id not in buckets:
169
- raise HTTPException(status_code=404, detail="Bucket not found")
185
+ raise BucketNotFoundError(f"Bucket with ID '{bucket_id}' not found")
170
186
 
171
187
  bucket = buckets[bucket_id]
172
188
  if state:
@@ -179,7 +195,7 @@ async def set_turbo_mode(bucket_id: str, state: bool):
179
195
  @app.get("/bucket/{bucket_id}/status")
180
196
  async def get_bucket_status(bucket_id: str):
181
197
  if bucket_id not in buckets:
182
- raise HTTPException(status_code=404, detail="Bucket not found")
198
+ raise BucketNotFoundError(f"Bucket with ID '{bucket_id}' not found")
183
199
 
184
200
  bucket = buckets[bucket_id]
185
201
  status = {
edsl/caching/cache.py CHANGED
@@ -176,8 +176,11 @@ class Cache(Base):
176
176
 
177
177
  Examples:
178
178
  >>> from edsl import Cache
179
- >>> Cache.example().values()
180
- [CacheEntry(...)]
179
+ >>> entries = Cache.example().values()
180
+ >>> len(entries)
181
+ 1
182
+ >>> entries[0] # doctest: +ELLIPSIS
183
+ CacheEntry(model='gpt-3.5-turbo', parameters={'temperature': 0.5}, ...)
181
184
  """
182
185
  return list(self.data.values())
183
186
 
@@ -5,7 +5,7 @@ from collections import UserList
5
5
 
6
6
  if TYPE_CHECKING:
7
7
  from .cache import Cache
8
- from edsl.coop.coop import Coop
8
+ from ..coop.coop import Coop
9
9
  from .cache_entry import CacheEntry
10
10
 
11
11
 
@@ -34,7 +34,7 @@ class CacheEntriesList(UserList):
34
34
  return f"CacheEntries({entries_repr})"
35
35
 
36
36
  def to_cache(self) -> "Cache":
37
- from edsl.caching.cache import Cache
37
+ from .cache import Cache
38
38
 
39
39
  return Cache({entry.key: entry for entry in self.data})
40
40
 
@@ -178,9 +178,9 @@ if __name__ == "__main__":
178
178
 
179
179
  doctest.testmod()
180
180
 
181
- from edsl.coop.coop import Coop
182
- from edsl.data.Cache import Cache
183
- from edsl.data.CacheEntry import CacheEntry
181
+ from ..coop.coop import Coop
182
+ from .cache import Cache
183
+ from .cache_entry import CacheEntry
184
184
 
185
185
  r = RemoteCacheSync(Coop(), Cache(), print)
186
186
  diff = r._get_cache_difference()
edsl/caching/sql_dict.py CHANGED
@@ -64,8 +64,9 @@ class SQLiteDict:
64
64
 
65
65
  Example:
66
66
  >>> temp_db_path = SQLiteDict._get_temp_path()
67
- >>> SQLiteDict(f"sqlite:///{temp_db_path}") # Use the temp file for SQLite
68
- SQLiteDict(db_path='...')
67
+ >>> db = SQLiteDict(f"sqlite:///{temp_db_path}") # Use the temp file for SQLite
68
+ >>> isinstance(db, SQLiteDict)
69
+ True
69
70
  >>> import os; os.unlink(temp_db_path) # Clean up the temp file after the test
70
71
  """
71
72
  from sqlalchemy.exc import SQLAlchemyError
@@ -76,13 +77,13 @@ class SQLiteDict:
76
77
  if not self.db_path.startswith("sqlite:///"):
77
78
  self.db_path = f"sqlite:///{self.db_path}"
78
79
  try:
79
- from edsl.caching.orm import Base
80
+ from .orm import Base
80
81
 
81
82
  self.engine = create_engine(self.db_path, echo=False, future=True)
82
83
  Base.metadata.create_all(self.engine)
83
84
  self.Session = sessionmaker(bind=self.engine)
84
85
  except SQLAlchemyError as e:
85
- from edsl.caching.exceptions import CacheError
86
+ from .exceptions import CacheError
86
87
  raise CacheError(
87
88
  f"""Database initialization error: {e}. The attempted DB path was {db_path}"""
88
89
  ) from e
@@ -123,10 +124,10 @@ class SQLiteDict:
123
124
  >>> d["foo"] = CacheEntry.example()
124
125
  """
125
126
  if not isinstance(value, CacheEntry):
126
- from edsl.caching.exceptions import CacheValueError
127
+ from .exceptions import CacheValueError
127
128
  raise CacheValueError(f"Value must be a CacheEntry object (got {type(value)}).")
128
129
  with self.Session() as db:
129
- from edsl.caching.orm import Data
130
+ from .orm import Data
130
131
 
131
132
  db.merge(Data(key=key, value=json.dumps(value.to_dict())))
132
133
  db.commit()
@@ -155,11 +156,11 @@ class SQLiteDict:
155
156
  True
156
157
  """
157
158
  with self.Session() as db:
158
- from edsl.caching.orm import Data
159
+ from .orm import Data
159
160
 
160
161
  value = db.query(Data).filter_by(key=key).first()
161
162
  if not value:
162
- from edsl.caching.exceptions import CacheKeyError
163
+ from .exceptions import CacheKeyError
163
164
  raise CacheKeyError(f"Key '{key}' not found.")
164
165
  return CacheEntry.from_dict(json.loads(value.value))
165
166
 
@@ -183,7 +184,7 @@ class SQLiteDict:
183
184
  >>> d.get("foo", "bar")
184
185
  'bar'
185
186
  """
186
- from edsl.caching.exceptions import CacheKeyError
187
+ from .exceptions import CacheKeyError
187
188
  try:
188
189
  return self[key]
189
190
  except (KeyError, CacheKeyError):
@@ -236,7 +237,7 @@ class SQLiteDict:
236
237
  the database from being locked for too long.
237
238
  """
238
239
  if not (isinstance(new_d, dict) or isinstance(new_d, SQLiteDict)):
239
- from edsl.caching.exceptions import CacheValueError
240
+ from .exceptions import CacheValueError
240
241
  raise CacheValueError(
241
242
  f"new_d must be a dict or SQLiteDict object (got {type(new_d)})"
242
243
  )
@@ -305,7 +306,7 @@ class SQLiteDict:
305
306
  db.delete(instance)
306
307
  db.commit()
307
308
  else:
308
- from edsl.caching.exceptions import CacheKeyError
309
+ from .exceptions import CacheKeyError
309
310
  raise CacheKeyError(f"Key '{key}' not found.")
310
311
 
311
312
  def __contains__(self, key: str) -> bool:
edsl/config/__init__.py CHANGED
@@ -3,6 +3,6 @@
3
3
  This module provides a Config class that loads environment variables from a .env file and sets them as class attributes.
4
4
  """
5
5
 
6
- from edsl.config.config_class import Config, CONFIG, CONFIG_MAP, EDSL_RUN_MODES, cache_dir
6
+ from .config_class import Config, CONFIG, CONFIG_MAP, EDSL_RUN_MODES, cache_dir
7
7
 
8
8
  __all__ = ["Config", "CONFIG", "CONFIG_MAP", "EDSL_RUN_MODES", "cache_dir"]
@@ -3,8 +3,10 @@
3
3
  import os
4
4
  import platformdirs
5
5
  from dotenv import load_dotenv, find_dotenv
6
- from edsl.base import BaseException
7
- from edsl import logger
6
+ from ..base import BaseException
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
8
10
 
9
11
  class InvalidEnvironmentVariableError(BaseException):
10
12
  """Raised when an environment variable is invalid."""
@@ -1,13 +1,16 @@
1
1
  from collections import UserList
2
2
  import asyncio
3
3
  import inspect
4
- from typing import Optional, Callable
5
- from .. import QuestionFreeText, Results, AgentList, ScenarioList, Scenario
4
+ from typing import Optional, Callable, TYPE_CHECKING
5
+ from .. import QuestionFreeText, Results, AgentList, ScenarioList, Scenario, Model
6
6
  from ..questions import QuestionBase
7
7
  from ..results.Result import Result
8
8
  from jinja2 import Template
9
9
  from ..caching import Cache
10
10
 
11
+ if TYPE_CHECKING:
12
+ from ..language_models.model import Model
13
+
11
14
  from .next_speaker_utilities import (
12
15
  default_turn_taking_generator,
13
16
  speaker_closure,
@@ -71,7 +74,7 @@ class Conversation:
71
74
  conversation_index: Optional[int] = None,
72
75
  cache=None,
73
76
  disable_remote_inference=False,
74
- default_model: Optional["LanguageModel"] = None,
77
+ default_model: Optional[Model] = None,
75
78
  ):
76
79
  self.disable_remote_inference = disable_remote_inference
77
80
  self.per_round_message_template = per_round_message_template
@@ -120,7 +123,7 @@ What do you say next?"""
120
123
  per_round_message_template
121
124
  and "{{ round_message }}" not in next_statement_question.question_text
122
125
  ):
123
- from edsl.conversation.exceptions import ConversationValueError
126
+ from .exceptions import ConversationValueError
124
127
  raise ConversationValueError(
125
128
  "If you pass in a per_round_message_template, you must include {{ round_message }} in the question_text."
126
129
  )
@@ -1,5 +1,4 @@
1
- from .. import Agent, AgentList, QuestionFreeText
2
- from .. import Cache
1
+ from .. import Agent, AgentList, QuestionFreeText, Cache, QuestionList
3
2
  from .Conversation import Conversation, ConversationList
4
3
 
5
4
  a1 = Agent(
@@ -46,7 +45,6 @@ q = QuestionFreeText(
46
45
  question_name="car_brand",
47
46
  )
48
47
 
49
- from .. import QuestionList
50
48
 
51
49
  q_actors = QuestionList(
52
50
  question_text="""This was a conversation about buying a car: {{ transcript }}.
@@ -1,5 +1,5 @@
1
- from edsl import Agent, AgentList
2
- from edsl.conversation.Conversation import Conversation, ConversationList
1
+ from .. import Agent, AgentList, QuestionYesNo, QuestionNumerical
2
+ from .Conversation import Conversation, ConversationList
3
3
 
4
4
 
5
5
  def bargaining_pairs(alice_valuation, bob_valuation):
@@ -43,10 +43,6 @@ results.select("conversation_index", "index", "agent_name", "dialogue").print(
43
43
  format="rich"
44
44
  )
45
45
 
46
- from edsl import (
47
- QuestionYesNo,
48
- QuestionNumerical,
49
- )
50
46
 
51
47
  q_deal = QuestionYesNo(
52
48
  question_text="""This was a negotiation: {{ transcript }}.
edsl/coop/__init__.py CHANGED
@@ -13,16 +13,19 @@ This module enables EDSL to interact with cloud-based resources for enhanced fun
13
13
  The primary interface is the Coop class, which serves as a client for the
14
14
  Expected Parrot API. Most users will only need to interact with the Coop class directly.
15
15
 
16
- Example:
17
- >>> from edsl.coop import Coop
18
- >>> coop = Coop() # Uses API key from environment or stored location
19
- >>> survey = my_survey.push() # Uploads survey to Expected Parrot
20
- >>> job_info = coop.remote_inference_create(my_job) # Creates remote job
16
+ Examples:
17
+
18
+ ```python
19
+ from edsl.coop import Coop
20
+ coop = Coop() # Uses API key from environment or stored location
21
+ survey = my_survey.push() # Uploads survey to Expected Parrot
22
+ job_info = coop.remote_inference_create(my_job) # Creates remote job
21
23
 
22
24
  # Working with plugins
23
- >>> from edsl.coop import get_available_plugins
24
- >>> plugins = get_available_plugins()
25
- >>> plugin_names = [p.name for p in plugins]
25
+ from edsl.coop import get_available_plugins
26
+ plugins = get_available_plugins()
27
+ plugin_names = [p.name for p in plugins]
28
+ ```
26
29
  """
27
30
 
28
31
  from .utils import EDSLObject, ObjectType, VisibilityType, ObjectRegistry
edsl/coop/coop.py CHANGED
@@ -180,7 +180,7 @@ class Coop(CoopFunctionsMixin):
180
180
  timeout=timeout,
181
181
  )
182
182
  else:
183
- from edsl.coop.exceptions import CoopInvalidMethodError
183
+ from .exceptions import CoopInvalidMethodError
184
184
 
185
185
  raise CoopInvalidMethodError(f"Invalid {method=}.")
186
186
  except requests.ConnectionError:
@@ -303,7 +303,7 @@ class Coop(CoopFunctionsMixin):
303
303
  message = root.find("Message").text
304
304
  details = root.find("Details").text
305
305
  except Exception:
306
- from edsl.coop.exceptions import CoopServerResponseError
306
+ from .exceptions import CoopServerResponseError
307
307
 
308
308
  raise CoopServerResponseError(
309
309
  f"Server returned status code {response.status_code}. "
@@ -311,7 +311,7 @@ class Coop(CoopFunctionsMixin):
311
311
  f"The server response was: {response.text}"
312
312
  )
313
313
 
314
- from edsl.coop.exceptions import CoopServerResponseError
314
+ from .exceptions import CoopServerResponseError
315
315
 
316
316
  raise CoopServerResponseError(
317
317
  f"An error occurred: {code} - {message} - {details}"
@@ -538,7 +538,7 @@ class Coop(CoopFunctionsMixin):
538
538
  if response_json.get("upload_signed_url"):
539
539
  signed_url = response_json.get("upload_signed_url")
540
540
  else:
541
- from edsl.coop.exceptions import CoopResponseError
541
+ from .exceptions import CoopResponseError
542
542
 
543
543
  raise CoopResponseError("No signed url was provided received")
544
544
 
@@ -551,7 +551,7 @@ class Coop(CoopFunctionsMixin):
551
551
  "file_store_upload_signed_url"
552
552
  )
553
553
  if file_store_metadata and not file_store_upload_signed_url:
554
- from edsl.coop.exceptions import CoopResponseError
554
+ from .exceptions import CoopResponseError
555
555
 
556
556
  raise CoopResponseError("No file store signed url provided.")
557
557
  elif file_store_metadata:
@@ -659,7 +659,7 @@ class Coop(CoopFunctionsMixin):
659
659
  json_string = object_data.text
660
660
  object_type = response.json().get("object_type")
661
661
  if expected_object_type and object_type != expected_object_type:
662
- from edsl.coop.exceptions import CoopObjectTypeError
662
+ from .exceptions import CoopObjectTypeError
663
663
 
664
664
  raise CoopObjectTypeError(
665
665
  f"Expected {expected_object_type=} but got {object_type=}"
@@ -754,7 +754,7 @@ class Coop(CoopFunctionsMixin):
754
754
  and value is None
755
755
  and alias is None
756
756
  ):
757
- from edsl.coop.exceptions import CoopPatchError
757
+ from .exceptions import CoopPatchError
758
758
 
759
759
  raise CoopPatchError("Nothing to patch.")
760
760
 
@@ -887,7 +887,7 @@ class Coop(CoopFunctionsMixin):
887
887
  [CacheEntry(...), CacheEntry(...), ...]
888
888
  """
889
889
  if job_uuid is None:
890
- from edsl.coop.exceptions import CoopValueError
890
+ from .exceptions import CoopValueError
891
891
 
892
892
  raise CoopValueError("Must provide a job_uuid.")
893
893
  response = self._send_server_request(
@@ -917,7 +917,7 @@ class Coop(CoopFunctionsMixin):
917
917
  [CacheEntry(...), CacheEntry(...), ...]
918
918
  """
919
919
  if select_keys is None or len(select_keys) == 0:
920
- from edsl.coop.exceptions import CoopValueError
920
+ from .exceptions import CoopValueError
921
921
 
922
922
  raise CoopValueError("Must provide a non-empty list of select_keys.")
923
923
  response = self._send_server_request(
@@ -1182,7 +1182,7 @@ class Coop(CoopFunctionsMixin):
1182
1182
  ... print(f"Results available at: {job_status['results_url']}")
1183
1183
  """
1184
1184
  if job_uuid is None and results_uuid is None:
1185
- from edsl.coop.exceptions import CoopValueError
1185
+ from .exceptions import CoopValueError
1186
1186
 
1187
1187
  raise CoopValueError("Either job_uuid or results_uuid must be provided.")
1188
1188
  elif job_uuid is not None:
@@ -1258,7 +1258,7 @@ class Coop(CoopFunctionsMixin):
1258
1258
  elif isinstance(input, Survey):
1259
1259
  job = Jobs(survey=input)
1260
1260
  else:
1261
- from edsl.coop.exceptions import CoopTypeError
1261
+ from .exceptions import CoopTypeError
1262
1262
 
1263
1263
  raise CoopTypeError("Input must be either a Job or a Survey.")
1264
1264
 
@@ -1395,7 +1395,7 @@ class Coop(CoopFunctionsMixin):
1395
1395
  elif CONFIG.get("EDSL_FETCH_TOKEN_PRICES") == "False":
1396
1396
  return {}
1397
1397
  else:
1398
- from edsl.coop.exceptions import CoopValueError
1398
+ from .exceptions import CoopValueError
1399
1399
 
1400
1400
  raise CoopValueError(
1401
1401
  "Invalid EDSL_FETCH_TOKEN_PRICES value---should be 'True' or 'False'."
@@ -1553,7 +1553,7 @@ class Coop(CoopFunctionsMixin):
1553
1553
  api_key = self._poll_for_api_key(edsl_auth_token)
1554
1554
 
1555
1555
  if api_key is None:
1556
- from edsl.coop.exceptions import CoopTimeoutError
1556
+ from .exceptions import CoopTimeoutError
1557
1557
 
1558
1558
  raise CoopTimeoutError("Timed out waiting for login. Please try again.")
1559
1559
 
@@ -1,6 +1,6 @@
1
1
  class CoopFunctionsMixin:
2
2
  def better_names(self, existing_names):
3
- from edsl import QuestionList, Scenario
3
+ from .. import QuestionList, Scenario
4
4
 
5
5
  s = Scenario({"existing_names": existing_names})
6
6
  q = QuestionList(
@@ -70,7 +70,7 @@ class ExpectedParrotKeyHandler:
70
70
 
71
71
  def ok_to_ask_to_store(self):
72
72
  """Check if it's okay to ask the user to store the key."""
73
- from edsl.config import CONFIG
73
+ from ..config import CONFIG
74
74
 
75
75
  if CONFIG.get("EDSL_RUN_MODE") != "production":
76
76
  return False
@@ -7,6 +7,7 @@ that price information is only fetched once and then cached for efficiency.
7
7
  """
8
8
 
9
9
  import requests
10
+ import os
10
11
  from typing import Dict, Tuple, Any
11
12
 
12
13
 
@@ -82,7 +83,6 @@ class PriceFetcher:
82
83
  if self._cached_prices is not None:
83
84
  return self._cached_prices
84
85
 
85
- import os
86
86
  from ..config import CONFIG
87
87
 
88
88
  try:
@@ -120,4 +120,4 @@ class PriceFetcher:
120
120
  except requests.RequestException:
121
121
  # Silently handle errors and return empty dict
122
122
  # print(f"An error occurred: {e}")
123
- return {}
123
+ return {}
edsl/coop/utils.py CHANGED
@@ -129,7 +129,7 @@ class ObjectRegistry:
129
129
  # Look up the object type
130
130
  object_type = cls.edsl_class_to_object_type.get(edsl_class_name)
131
131
  if object_type is None:
132
- from edsl.coop.exceptions import CoopValueError
132
+ from .exceptions import CoopValueError
133
133
  raise CoopValueError(f"Object type not found for {edsl_object=}")
134
134
  return object_type
135
135
 
@@ -152,7 +152,7 @@ class ObjectRegistry:
152
152
  """
153
153
  EDSL_class = cls.object_type_to_edsl_class.get(object_type)
154
154
  if EDSL_class is None:
155
- from edsl.coop.exceptions import CoopValueError
155
+ from .exceptions import CoopValueError
156
156
  raise CoopValueError(f"EDSL class not found for {object_type=}")
157
157
  return EDSL_class
158
158