edsl 0.1.50__py3-none-any.whl → 0.1.51__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/base/base_exception.py +2 -2
- edsl/buckets/bucket_collection.py +1 -1
- edsl/buckets/exceptions.py +32 -0
- edsl/buckets/token_bucket_api.py +26 -10
- edsl/caching/cache.py +5 -2
- edsl/caching/remote_cache_sync.py +5 -5
- edsl/caching/sql_dict.py +12 -11
- edsl/config/__init__.py +1 -1
- edsl/config/config_class.py +4 -2
- edsl/conversation/Conversation.py +7 -4
- edsl/conversation/car_buying.py +1 -3
- edsl/conversation/mug_negotiation.py +2 -6
- edsl/coop/__init__.py +11 -8
- edsl/coop/coop.py +13 -13
- edsl/coop/coop_functions.py +1 -1
- edsl/coop/ep_key_handling.py +1 -1
- edsl/coop/price_fetcher.py +2 -2
- edsl/coop/utils.py +2 -2
- edsl/dataset/dataset.py +144 -63
- edsl/dataset/dataset_operations_mixin.py +14 -6
- edsl/dataset/dataset_tree.py +3 -3
- edsl/dataset/display/table_renderers.py +6 -3
- edsl/dataset/file_exports.py +4 -4
- edsl/dataset/r/ggplot.py +3 -3
- edsl/inference_services/available_model_fetcher.py +2 -2
- edsl/inference_services/data_structures.py +5 -5
- edsl/inference_services/inference_service_abc.py +1 -1
- edsl/inference_services/inference_services_collection.py +1 -1
- edsl/inference_services/service_availability.py +3 -3
- edsl/inference_services/services/azure_ai.py +3 -3
- edsl/inference_services/services/google_service.py +1 -1
- edsl/inference_services/services/test_service.py +1 -1
- edsl/instructions/change_instruction.py +5 -4
- edsl/instructions/instruction.py +1 -0
- edsl/instructions/instruction_collection.py +5 -4
- edsl/instructions/instruction_handler.py +10 -8
- edsl/interviews/exception_tracking.py +1 -1
- edsl/interviews/interview.py +1 -1
- edsl/interviews/interview_status_dictionary.py +1 -1
- edsl/interviews/interview_task_manager.py +2 -2
- edsl/interviews/request_token_estimator.py +3 -2
- edsl/interviews/statistics.py +2 -2
- edsl/invigilators/invigilators.py +2 -2
- edsl/jobs/__init__.py +39 -2
- edsl/jobs/async_interview_runner.py +1 -1
- edsl/jobs/check_survey_scenario_compatibility.py +5 -5
- edsl/jobs/data_structures.py +2 -2
- edsl/jobs/jobs.py +2 -2
- edsl/jobs/jobs_checks.py +5 -5
- edsl/jobs/jobs_component_constructor.py +2 -2
- edsl/jobs/jobs_pricing_estimation.py +1 -1
- edsl/jobs/jobs_runner_asyncio.py +2 -2
- edsl/jobs/remote_inference.py +1 -1
- edsl/jobs/results_exceptions_handler.py +2 -2
- edsl/language_models/language_model.py +5 -1
- edsl/notebooks/__init__.py +24 -1
- edsl/notebooks/exceptions.py +82 -0
- edsl/notebooks/notebook.py +7 -3
- edsl/notebooks/notebook_to_latex.py +1 -1
- edsl/prompts/__init__.py +23 -2
- edsl/prompts/prompt.py +1 -1
- edsl/questions/__init__.py +4 -4
- edsl/questions/answer_validator_mixin.py +0 -5
- edsl/questions/compose_questions.py +2 -2
- edsl/questions/descriptors.py +1 -1
- edsl/questions/question_base.py +32 -3
- edsl/questions/question_base_prompts_mixin.py +4 -4
- edsl/questions/question_budget.py +503 -102
- edsl/questions/question_check_box.py +658 -156
- edsl/questions/question_dict.py +176 -2
- edsl/questions/question_extract.py +401 -61
- edsl/questions/question_free_text.py +77 -9
- edsl/questions/question_functional.py +118 -9
- edsl/questions/{derived/question_likert_five.py → question_likert_five.py} +2 -2
- edsl/questions/{derived/question_linear_scale.py → question_linear_scale.py} +3 -4
- edsl/questions/question_list.py +246 -26
- edsl/questions/question_matrix.py +586 -73
- edsl/questions/question_multiple_choice.py +213 -47
- edsl/questions/question_numerical.py +360 -29
- edsl/questions/question_rank.py +401 -124
- edsl/questions/question_registry.py +3 -3
- edsl/questions/{derived/question_top_k.py → question_top_k.py} +3 -3
- edsl/questions/{derived/question_yes_no.py → question_yes_no.py} +3 -4
- edsl/questions/register_questions_meta.py +2 -1
- edsl/questions/response_validator_abc.py +6 -2
- edsl/questions/response_validator_factory.py +10 -12
- edsl/results/report.py +1 -1
- edsl/results/result.py +7 -4
- edsl/results/results.py +471 -271
- edsl/results/results_selector.py +2 -2
- edsl/scenarios/construct_download_link.py +3 -3
- edsl/scenarios/scenario.py +1 -2
- edsl/scenarios/scenario_list.py +41 -23
- edsl/surveys/survey_css.py +3 -3
- edsl/surveys/survey_simulator.py +2 -1
- edsl/tasks/__init__.py +22 -2
- edsl/tasks/exceptions.py +72 -0
- edsl/tasks/task_history.py +3 -3
- edsl/tokens/__init__.py +27 -1
- edsl/tokens/exceptions.py +37 -0
- edsl/tokens/interview_token_usage.py +3 -2
- edsl/tokens/token_usage.py +4 -3
- {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/METADATA +1 -1
- {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/RECORD +108 -106
- edsl/questions/derived/__init__.py +0 -0
- {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/LICENSE +0 -0
- {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/WHEEL +0 -0
- {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/entry_points.txt +0 -0
edsl/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.1.
|
1
|
+
__version__ = "0.1.51"
|
edsl/base/base_exception.py
CHANGED
@@ -146,7 +146,7 @@ class BaseException(Exception):
|
|
146
146
|
# )
|
147
147
|
# )
|
148
148
|
# except:
|
149
|
-
print(f"❌ EDSL ERROR: {etype.__name__}: {evalue}", file=sys.stderr)
|
149
|
+
print(f"❌ E[🦃] EDSL ERROR: {etype.__name__}: {evalue}", file=sys.stderr)
|
150
150
|
# Suppress IPython’s normal traceback
|
151
151
|
return
|
152
152
|
# Otherwise, fall back to the usual traceback
|
@@ -178,7 +178,7 @@ class BaseException(Exception):
|
|
178
178
|
# )
|
179
179
|
# except:
|
180
180
|
print(
|
181
|
-
f"❌ EDSL ERROR: {exc_type.__name__}: {exc_value}",
|
181
|
+
f"❌ E[🦃]EDSL ERROR: {exc_type.__name__}: {exc_value}",
|
182
182
|
exc_traceback,
|
183
183
|
file=sys.stderr,
|
184
184
|
)
|
@@ -135,7 +135,7 @@ class BucketCollection(UserDict):
|
|
135
135
|
>>> # The following would raise an exception:
|
136
136
|
>>> # bucket_collection.get_tokens(m, 'tokens', 10)
|
137
137
|
"""
|
138
|
-
from
|
138
|
+
from .exceptions import BucketError
|
139
139
|
raise BucketError("This method is deprecated and should not be used")
|
140
140
|
# The following code is kept for reference only
|
141
141
|
# relevant_bucket = getattr(self[model], bucket_type)
|
edsl/buckets/exceptions.py
CHANGED
@@ -72,4 +72,36 @@ class BucketConfigurationError(BucketError):
|
|
72
72
|
TokenBucket(name="test", capacity=-100, refill_rate=10) # Would raise BucketConfigurationError
|
73
73
|
```
|
74
74
|
"""
|
75
|
+
relevant_doc = "https://docs.expectedparrot.com/"
|
76
|
+
|
77
|
+
|
78
|
+
class BucketNotFoundError(BucketError):
|
79
|
+
"""
|
80
|
+
Exception raised when a requested bucket cannot be found.
|
81
|
+
|
82
|
+
This exception occurs when attempting to access a bucket that doesn't exist
|
83
|
+
in the system or has been removed.
|
84
|
+
|
85
|
+
Examples:
|
86
|
+
```python
|
87
|
+
# Attempting to access a non-existent bucket:
|
88
|
+
bucket_collection.get_bucket("non_existent_bucket") # Would raise BucketNotFoundError
|
89
|
+
```
|
90
|
+
"""
|
91
|
+
relevant_doc = "https://docs.expectedparrot.com/"
|
92
|
+
|
93
|
+
|
94
|
+
class InvalidBucketParameterError(BucketConfigurationError):
|
95
|
+
"""
|
96
|
+
Exception raised when an invalid parameter is provided for bucket operations.
|
97
|
+
|
98
|
+
This exception occurs when providing invalid parameters to bucket methods,
|
99
|
+
such as negative token amounts, invalid capacity values, etc.
|
100
|
+
|
101
|
+
Examples:
|
102
|
+
```python
|
103
|
+
# Attempting to use invalid parameters:
|
104
|
+
bucket.add_tokens(-100) # Would raise InvalidBucketParameterError
|
105
|
+
```
|
106
|
+
"""
|
75
107
|
relevant_doc = "https://docs.expectedparrot.com/"
|
edsl/buckets/token_bucket_api.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1
|
-
from fastapi import FastAPI
|
1
|
+
from fastapi import FastAPI
|
2
|
+
from fastapi.responses import JSONResponse
|
2
3
|
from pydantic import BaseModel
|
3
4
|
from typing import Union, Dict
|
4
5
|
from typing import Optional
|
5
6
|
|
6
7
|
from .token_bucket import TokenBucket # Original implementation
|
8
|
+
from .exceptions import BucketNotFoundError, InvalidBucketParameterError
|
7
9
|
|
8
10
|
def safe_float_for_json(value: float) -> Union[float, str]:
|
9
11
|
"""Convert float('inf') to 'infinity' for JSON serialization.
|
@@ -24,6 +26,20 @@ app = FastAPI()
|
|
24
26
|
# In-memory storage for TokenBucket instances
|
25
27
|
buckets: Dict[str, TokenBucket] = {}
|
26
28
|
|
29
|
+
@app.exception_handler(BucketNotFoundError)
|
30
|
+
async def bucket_not_found_handler(request, exc):
|
31
|
+
return JSONResponse(
|
32
|
+
status_code=404,
|
33
|
+
content={"detail": str(exc)},
|
34
|
+
)
|
35
|
+
|
36
|
+
@app.exception_handler(InvalidBucketParameterError)
|
37
|
+
async def invalid_parameter_handler(request, exc):
|
38
|
+
return JSONResponse(
|
39
|
+
status_code=400,
|
40
|
+
content={"detail": str(exc)},
|
41
|
+
)
|
42
|
+
|
27
43
|
|
28
44
|
class TokenBucketCreate(BaseModel):
|
29
45
|
bucket_name: str
|
@@ -83,13 +99,13 @@ async def list_buckets(
|
|
83
99
|
async def add_tokens(bucket_id: str, amount: float):
|
84
100
|
"""Add tokens to an existing bucket."""
|
85
101
|
if bucket_id not in buckets:
|
86
|
-
raise
|
102
|
+
raise BucketNotFoundError(f"Bucket with ID '{bucket_id}' not found")
|
87
103
|
|
88
104
|
if not isinstance(amount, (int, float)) or amount != amount: # Check for NaN
|
89
|
-
raise
|
105
|
+
raise InvalidBucketParameterError("Invalid amount specified")
|
90
106
|
|
91
107
|
if amount == float("inf") or amount == float("-inf"):
|
92
|
-
raise
|
108
|
+
raise InvalidBucketParameterError("Amount cannot be infinite")
|
93
109
|
|
94
110
|
bucket = buckets[bucket_id]
|
95
111
|
bucket.add_tokens(amount)
|
@@ -124,14 +140,14 @@ async def create_bucket(bucket: TokenBucketCreate):
|
|
124
140
|
not isinstance(bucket.capacity, (int, float))
|
125
141
|
or bucket.capacity != bucket.capacity
|
126
142
|
): # Check for NaN
|
127
|
-
raise
|
143
|
+
raise InvalidBucketParameterError("Invalid capacity value")
|
128
144
|
if (
|
129
145
|
not isinstance(bucket.refill_rate, (int, float))
|
130
146
|
or bucket.refill_rate != bucket.refill_rate
|
131
147
|
): # Check for NaN
|
132
|
-
raise
|
148
|
+
raise InvalidBucketParameterError("Invalid refill rate value")
|
133
149
|
if bucket.capacity == float("inf") or bucket.refill_rate == float("inf"):
|
134
|
-
raise
|
150
|
+
raise InvalidBucketParameterError("Values cannot be infinite")
|
135
151
|
bucket_id = f"{bucket.bucket_name}_{bucket.bucket_type}"
|
136
152
|
if bucket_id in buckets:
|
137
153
|
# Instead of error, return success with "existing" status
|
@@ -156,7 +172,7 @@ async def create_bucket(bucket: TokenBucketCreate):
|
|
156
172
|
@app.post("/bucket/{bucket_id}/get_tokens")
|
157
173
|
async def get_tokens(bucket_id: str, amount: float, cheat_bucket_capacity: bool = True):
|
158
174
|
if bucket_id not in buckets:
|
159
|
-
raise
|
175
|
+
raise BucketNotFoundError(f"Bucket with ID '{bucket_id}' not found")
|
160
176
|
|
161
177
|
bucket = buckets[bucket_id]
|
162
178
|
await bucket.get_tokens(amount, cheat_bucket_capacity)
|
@@ -166,7 +182,7 @@ async def get_tokens(bucket_id: str, amount: float, cheat_bucket_capacity: bool
|
|
166
182
|
@app.post("/bucket/{bucket_id}/turbo_mode/{state}")
|
167
183
|
async def set_turbo_mode(bucket_id: str, state: bool):
|
168
184
|
if bucket_id not in buckets:
|
169
|
-
raise
|
185
|
+
raise BucketNotFoundError(f"Bucket with ID '{bucket_id}' not found")
|
170
186
|
|
171
187
|
bucket = buckets[bucket_id]
|
172
188
|
if state:
|
@@ -179,7 +195,7 @@ async def set_turbo_mode(bucket_id: str, state: bool):
|
|
179
195
|
@app.get("/bucket/{bucket_id}/status")
|
180
196
|
async def get_bucket_status(bucket_id: str):
|
181
197
|
if bucket_id not in buckets:
|
182
|
-
raise
|
198
|
+
raise BucketNotFoundError(f"Bucket with ID '{bucket_id}' not found")
|
183
199
|
|
184
200
|
bucket = buckets[bucket_id]
|
185
201
|
status = {
|
edsl/caching/cache.py
CHANGED
@@ -176,8 +176,11 @@ class Cache(Base):
|
|
176
176
|
|
177
177
|
Examples:
|
178
178
|
>>> from edsl import Cache
|
179
|
-
>>> Cache.example().values()
|
180
|
-
|
179
|
+
>>> entries = Cache.example().values()
|
180
|
+
>>> len(entries)
|
181
|
+
1
|
182
|
+
>>> entries[0] # doctest: +ELLIPSIS
|
183
|
+
CacheEntry(model='gpt-3.5-turbo', parameters={'temperature': 0.5}, ...)
|
181
184
|
"""
|
182
185
|
return list(self.data.values())
|
183
186
|
|
@@ -5,7 +5,7 @@ from collections import UserList
|
|
5
5
|
|
6
6
|
if TYPE_CHECKING:
|
7
7
|
from .cache import Cache
|
8
|
-
from
|
8
|
+
from ..coop.coop import Coop
|
9
9
|
from .cache_entry import CacheEntry
|
10
10
|
|
11
11
|
|
@@ -34,7 +34,7 @@ class CacheEntriesList(UserList):
|
|
34
34
|
return f"CacheEntries({entries_repr})"
|
35
35
|
|
36
36
|
def to_cache(self) -> "Cache":
|
37
|
-
from
|
37
|
+
from .cache import Cache
|
38
38
|
|
39
39
|
return Cache({entry.key: entry for entry in self.data})
|
40
40
|
|
@@ -178,9 +178,9 @@ if __name__ == "__main__":
|
|
178
178
|
|
179
179
|
doctest.testmod()
|
180
180
|
|
181
|
-
from
|
182
|
-
from
|
183
|
-
from
|
181
|
+
from ..coop.coop import Coop
|
182
|
+
from .cache import Cache
|
183
|
+
from .cache_entry import CacheEntry
|
184
184
|
|
185
185
|
r = RemoteCacheSync(Coop(), Cache(), print)
|
186
186
|
diff = r._get_cache_difference()
|
edsl/caching/sql_dict.py
CHANGED
@@ -64,8 +64,9 @@ class SQLiteDict:
|
|
64
64
|
|
65
65
|
Example:
|
66
66
|
>>> temp_db_path = SQLiteDict._get_temp_path()
|
67
|
-
>>> SQLiteDict(f"sqlite:///{temp_db_path}") # Use the temp file for SQLite
|
68
|
-
SQLiteDict
|
67
|
+
>>> db = SQLiteDict(f"sqlite:///{temp_db_path}") # Use the temp file for SQLite
|
68
|
+
>>> isinstance(db, SQLiteDict)
|
69
|
+
True
|
69
70
|
>>> import os; os.unlink(temp_db_path) # Clean up the temp file after the test
|
70
71
|
"""
|
71
72
|
from sqlalchemy.exc import SQLAlchemyError
|
@@ -76,13 +77,13 @@ class SQLiteDict:
|
|
76
77
|
if not self.db_path.startswith("sqlite:///"):
|
77
78
|
self.db_path = f"sqlite:///{self.db_path}"
|
78
79
|
try:
|
79
|
-
from
|
80
|
+
from .orm import Base
|
80
81
|
|
81
82
|
self.engine = create_engine(self.db_path, echo=False, future=True)
|
82
83
|
Base.metadata.create_all(self.engine)
|
83
84
|
self.Session = sessionmaker(bind=self.engine)
|
84
85
|
except SQLAlchemyError as e:
|
85
|
-
from
|
86
|
+
from .exceptions import CacheError
|
86
87
|
raise CacheError(
|
87
88
|
f"""Database initialization error: {e}. The attempted DB path was {db_path}"""
|
88
89
|
) from e
|
@@ -123,10 +124,10 @@ class SQLiteDict:
|
|
123
124
|
>>> d["foo"] = CacheEntry.example()
|
124
125
|
"""
|
125
126
|
if not isinstance(value, CacheEntry):
|
126
|
-
from
|
127
|
+
from .exceptions import CacheValueError
|
127
128
|
raise CacheValueError(f"Value must be a CacheEntry object (got {type(value)}).")
|
128
129
|
with self.Session() as db:
|
129
|
-
from
|
130
|
+
from .orm import Data
|
130
131
|
|
131
132
|
db.merge(Data(key=key, value=json.dumps(value.to_dict())))
|
132
133
|
db.commit()
|
@@ -155,11 +156,11 @@ class SQLiteDict:
|
|
155
156
|
True
|
156
157
|
"""
|
157
158
|
with self.Session() as db:
|
158
|
-
from
|
159
|
+
from .orm import Data
|
159
160
|
|
160
161
|
value = db.query(Data).filter_by(key=key).first()
|
161
162
|
if not value:
|
162
|
-
from
|
163
|
+
from .exceptions import CacheKeyError
|
163
164
|
raise CacheKeyError(f"Key '{key}' not found.")
|
164
165
|
return CacheEntry.from_dict(json.loads(value.value))
|
165
166
|
|
@@ -183,7 +184,7 @@ class SQLiteDict:
|
|
183
184
|
>>> d.get("foo", "bar")
|
184
185
|
'bar'
|
185
186
|
"""
|
186
|
-
from
|
187
|
+
from .exceptions import CacheKeyError
|
187
188
|
try:
|
188
189
|
return self[key]
|
189
190
|
except (KeyError, CacheKeyError):
|
@@ -236,7 +237,7 @@ class SQLiteDict:
|
|
236
237
|
the database from being locked for too long.
|
237
238
|
"""
|
238
239
|
if not (isinstance(new_d, dict) or isinstance(new_d, SQLiteDict)):
|
239
|
-
from
|
240
|
+
from .exceptions import CacheValueError
|
240
241
|
raise CacheValueError(
|
241
242
|
f"new_d must be a dict or SQLiteDict object (got {type(new_d)})"
|
242
243
|
)
|
@@ -305,7 +306,7 @@ class SQLiteDict:
|
|
305
306
|
db.delete(instance)
|
306
307
|
db.commit()
|
307
308
|
else:
|
308
|
-
from
|
309
|
+
from .exceptions import CacheKeyError
|
309
310
|
raise CacheKeyError(f"Key '{key}' not found.")
|
310
311
|
|
311
312
|
def __contains__(self, key: str) -> bool:
|
edsl/config/__init__.py
CHANGED
@@ -3,6 +3,6 @@
|
|
3
3
|
This module provides a Config class that loads environment variables from a .env file and sets them as class attributes.
|
4
4
|
"""
|
5
5
|
|
6
|
-
from
|
6
|
+
from .config_class import Config, CONFIG, CONFIG_MAP, EDSL_RUN_MODES, cache_dir
|
7
7
|
|
8
8
|
__all__ = ["Config", "CONFIG", "CONFIG_MAP", "EDSL_RUN_MODES", "cache_dir"]
|
edsl/config/config_class.py
CHANGED
@@ -3,8 +3,10 @@
|
|
3
3
|
import os
|
4
4
|
import platformdirs
|
5
5
|
from dotenv import load_dotenv, find_dotenv
|
6
|
-
from
|
7
|
-
|
6
|
+
from ..base import BaseException
|
7
|
+
import logging
|
8
|
+
|
9
|
+
logger = logging.getLogger(__name__)
|
8
10
|
|
9
11
|
class InvalidEnvironmentVariableError(BaseException):
|
10
12
|
"""Raised when an environment variable is invalid."""
|
@@ -1,13 +1,16 @@
|
|
1
1
|
from collections import UserList
|
2
2
|
import asyncio
|
3
3
|
import inspect
|
4
|
-
from typing import Optional, Callable
|
5
|
-
from .. import QuestionFreeText, Results, AgentList, ScenarioList, Scenario
|
4
|
+
from typing import Optional, Callable, TYPE_CHECKING
|
5
|
+
from .. import QuestionFreeText, Results, AgentList, ScenarioList, Scenario, Model
|
6
6
|
from ..questions import QuestionBase
|
7
7
|
from ..results.Result import Result
|
8
8
|
from jinja2 import Template
|
9
9
|
from ..caching import Cache
|
10
10
|
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from ..language_models.model import Model
|
13
|
+
|
11
14
|
from .next_speaker_utilities import (
|
12
15
|
default_turn_taking_generator,
|
13
16
|
speaker_closure,
|
@@ -71,7 +74,7 @@ class Conversation:
|
|
71
74
|
conversation_index: Optional[int] = None,
|
72
75
|
cache=None,
|
73
76
|
disable_remote_inference=False,
|
74
|
-
default_model: Optional[
|
77
|
+
default_model: Optional[Model] = None,
|
75
78
|
):
|
76
79
|
self.disable_remote_inference = disable_remote_inference
|
77
80
|
self.per_round_message_template = per_round_message_template
|
@@ -120,7 +123,7 @@ What do you say next?"""
|
|
120
123
|
per_round_message_template
|
121
124
|
and "{{ round_message }}" not in next_statement_question.question_text
|
122
125
|
):
|
123
|
-
from
|
126
|
+
from .exceptions import ConversationValueError
|
124
127
|
raise ConversationValueError(
|
125
128
|
"If you pass in a per_round_message_template, you must include {{ round_message }} in the question_text."
|
126
129
|
)
|
edsl/conversation/car_buying.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
|
-
from .. import Agent, AgentList, QuestionFreeText
|
2
|
-
from .. import Cache
|
1
|
+
from .. import Agent, AgentList, QuestionFreeText, Cache, QuestionList
|
3
2
|
from .Conversation import Conversation, ConversationList
|
4
3
|
|
5
4
|
a1 = Agent(
|
@@ -46,7 +45,6 @@ q = QuestionFreeText(
|
|
46
45
|
question_name="car_brand",
|
47
46
|
)
|
48
47
|
|
49
|
-
from .. import QuestionList
|
50
48
|
|
51
49
|
q_actors = QuestionList(
|
52
50
|
question_text="""This was a conversation about buying a car: {{ transcript }}.
|
@@ -1,5 +1,5 @@
|
|
1
|
-
from
|
2
|
-
from
|
1
|
+
from .. import Agent, AgentList, QuestionYesNo, QuestionNumerical
|
2
|
+
from .Conversation import Conversation, ConversationList
|
3
3
|
|
4
4
|
|
5
5
|
def bargaining_pairs(alice_valuation, bob_valuation):
|
@@ -43,10 +43,6 @@ results.select("conversation_index", "index", "agent_name", "dialogue").print(
|
|
43
43
|
format="rich"
|
44
44
|
)
|
45
45
|
|
46
|
-
from edsl import (
|
47
|
-
QuestionYesNo,
|
48
|
-
QuestionNumerical,
|
49
|
-
)
|
50
46
|
|
51
47
|
q_deal = QuestionYesNo(
|
52
48
|
question_text="""This was a negotiation: {{ transcript }}.
|
edsl/coop/__init__.py
CHANGED
@@ -13,16 +13,19 @@ This module enables EDSL to interact with cloud-based resources for enhanced fun
|
|
13
13
|
The primary interface is the Coop class, which serves as a client for the
|
14
14
|
Expected Parrot API. Most users will only need to interact with the Coop class directly.
|
15
15
|
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
16
|
+
Examples:
|
17
|
+
|
18
|
+
```python
|
19
|
+
from edsl.coop import Coop
|
20
|
+
coop = Coop() # Uses API key from environment or stored location
|
21
|
+
survey = my_survey.push() # Uploads survey to Expected Parrot
|
22
|
+
job_info = coop.remote_inference_create(my_job) # Creates remote job
|
21
23
|
|
22
24
|
# Working with plugins
|
23
|
-
|
24
|
-
|
25
|
-
|
25
|
+
from edsl.coop import get_available_plugins
|
26
|
+
plugins = get_available_plugins()
|
27
|
+
plugin_names = [p.name for p in plugins]
|
28
|
+
```
|
26
29
|
"""
|
27
30
|
|
28
31
|
from .utils import EDSLObject, ObjectType, VisibilityType, ObjectRegistry
|
edsl/coop/coop.py
CHANGED
@@ -180,7 +180,7 @@ class Coop(CoopFunctionsMixin):
|
|
180
180
|
timeout=timeout,
|
181
181
|
)
|
182
182
|
else:
|
183
|
-
from
|
183
|
+
from .exceptions import CoopInvalidMethodError
|
184
184
|
|
185
185
|
raise CoopInvalidMethodError(f"Invalid {method=}.")
|
186
186
|
except requests.ConnectionError:
|
@@ -303,7 +303,7 @@ class Coop(CoopFunctionsMixin):
|
|
303
303
|
message = root.find("Message").text
|
304
304
|
details = root.find("Details").text
|
305
305
|
except Exception:
|
306
|
-
from
|
306
|
+
from .exceptions import CoopServerResponseError
|
307
307
|
|
308
308
|
raise CoopServerResponseError(
|
309
309
|
f"Server returned status code {response.status_code}. "
|
@@ -311,7 +311,7 @@ class Coop(CoopFunctionsMixin):
|
|
311
311
|
f"The server response was: {response.text}"
|
312
312
|
)
|
313
313
|
|
314
|
-
from
|
314
|
+
from .exceptions import CoopServerResponseError
|
315
315
|
|
316
316
|
raise CoopServerResponseError(
|
317
317
|
f"An error occurred: {code} - {message} - {details}"
|
@@ -538,7 +538,7 @@ class Coop(CoopFunctionsMixin):
|
|
538
538
|
if response_json.get("upload_signed_url"):
|
539
539
|
signed_url = response_json.get("upload_signed_url")
|
540
540
|
else:
|
541
|
-
from
|
541
|
+
from .exceptions import CoopResponseError
|
542
542
|
|
543
543
|
raise CoopResponseError("No signed url was provided received")
|
544
544
|
|
@@ -551,7 +551,7 @@ class Coop(CoopFunctionsMixin):
|
|
551
551
|
"file_store_upload_signed_url"
|
552
552
|
)
|
553
553
|
if file_store_metadata and not file_store_upload_signed_url:
|
554
|
-
from
|
554
|
+
from .exceptions import CoopResponseError
|
555
555
|
|
556
556
|
raise CoopResponseError("No file store signed url provided.")
|
557
557
|
elif file_store_metadata:
|
@@ -659,7 +659,7 @@ class Coop(CoopFunctionsMixin):
|
|
659
659
|
json_string = object_data.text
|
660
660
|
object_type = response.json().get("object_type")
|
661
661
|
if expected_object_type and object_type != expected_object_type:
|
662
|
-
from
|
662
|
+
from .exceptions import CoopObjectTypeError
|
663
663
|
|
664
664
|
raise CoopObjectTypeError(
|
665
665
|
f"Expected {expected_object_type=} but got {object_type=}"
|
@@ -754,7 +754,7 @@ class Coop(CoopFunctionsMixin):
|
|
754
754
|
and value is None
|
755
755
|
and alias is None
|
756
756
|
):
|
757
|
-
from
|
757
|
+
from .exceptions import CoopPatchError
|
758
758
|
|
759
759
|
raise CoopPatchError("Nothing to patch.")
|
760
760
|
|
@@ -887,7 +887,7 @@ class Coop(CoopFunctionsMixin):
|
|
887
887
|
[CacheEntry(...), CacheEntry(...), ...]
|
888
888
|
"""
|
889
889
|
if job_uuid is None:
|
890
|
-
from
|
890
|
+
from .exceptions import CoopValueError
|
891
891
|
|
892
892
|
raise CoopValueError("Must provide a job_uuid.")
|
893
893
|
response = self._send_server_request(
|
@@ -917,7 +917,7 @@ class Coop(CoopFunctionsMixin):
|
|
917
917
|
[CacheEntry(...), CacheEntry(...), ...]
|
918
918
|
"""
|
919
919
|
if select_keys is None or len(select_keys) == 0:
|
920
|
-
from
|
920
|
+
from .exceptions import CoopValueError
|
921
921
|
|
922
922
|
raise CoopValueError("Must provide a non-empty list of select_keys.")
|
923
923
|
response = self._send_server_request(
|
@@ -1182,7 +1182,7 @@ class Coop(CoopFunctionsMixin):
|
|
1182
1182
|
... print(f"Results available at: {job_status['results_url']}")
|
1183
1183
|
"""
|
1184
1184
|
if job_uuid is None and results_uuid is None:
|
1185
|
-
from
|
1185
|
+
from .exceptions import CoopValueError
|
1186
1186
|
|
1187
1187
|
raise CoopValueError("Either job_uuid or results_uuid must be provided.")
|
1188
1188
|
elif job_uuid is not None:
|
@@ -1258,7 +1258,7 @@ class Coop(CoopFunctionsMixin):
|
|
1258
1258
|
elif isinstance(input, Survey):
|
1259
1259
|
job = Jobs(survey=input)
|
1260
1260
|
else:
|
1261
|
-
from
|
1261
|
+
from .exceptions import CoopTypeError
|
1262
1262
|
|
1263
1263
|
raise CoopTypeError("Input must be either a Job or a Survey.")
|
1264
1264
|
|
@@ -1395,7 +1395,7 @@ class Coop(CoopFunctionsMixin):
|
|
1395
1395
|
elif CONFIG.get("EDSL_FETCH_TOKEN_PRICES") == "False":
|
1396
1396
|
return {}
|
1397
1397
|
else:
|
1398
|
-
from
|
1398
|
+
from .exceptions import CoopValueError
|
1399
1399
|
|
1400
1400
|
raise CoopValueError(
|
1401
1401
|
"Invalid EDSL_FETCH_TOKEN_PRICES value---should be 'True' or 'False'."
|
@@ -1553,7 +1553,7 @@ class Coop(CoopFunctionsMixin):
|
|
1553
1553
|
api_key = self._poll_for_api_key(edsl_auth_token)
|
1554
1554
|
|
1555
1555
|
if api_key is None:
|
1556
|
-
from
|
1556
|
+
from .exceptions import CoopTimeoutError
|
1557
1557
|
|
1558
1558
|
raise CoopTimeoutError("Timed out waiting for login. Please try again.")
|
1559
1559
|
|
edsl/coop/coop_functions.py
CHANGED
edsl/coop/ep_key_handling.py
CHANGED
@@ -70,7 +70,7 @@ class ExpectedParrotKeyHandler:
|
|
70
70
|
|
71
71
|
def ok_to_ask_to_store(self):
|
72
72
|
"""Check if it's okay to ask the user to store the key."""
|
73
|
-
from
|
73
|
+
from ..config import CONFIG
|
74
74
|
|
75
75
|
if CONFIG.get("EDSL_RUN_MODE") != "production":
|
76
76
|
return False
|
edsl/coop/price_fetcher.py
CHANGED
@@ -7,6 +7,7 @@ that price information is only fetched once and then cached for efficiency.
|
|
7
7
|
"""
|
8
8
|
|
9
9
|
import requests
|
10
|
+
import os
|
10
11
|
from typing import Dict, Tuple, Any
|
11
12
|
|
12
13
|
|
@@ -82,7 +83,6 @@ class PriceFetcher:
|
|
82
83
|
if self._cached_prices is not None:
|
83
84
|
return self._cached_prices
|
84
85
|
|
85
|
-
import os
|
86
86
|
from ..config import CONFIG
|
87
87
|
|
88
88
|
try:
|
@@ -120,4 +120,4 @@ class PriceFetcher:
|
|
120
120
|
except requests.RequestException:
|
121
121
|
# Silently handle errors and return empty dict
|
122
122
|
# print(f"An error occurred: {e}")
|
123
|
-
return {}
|
123
|
+
return {}
|
edsl/coop/utils.py
CHANGED
@@ -129,7 +129,7 @@ class ObjectRegistry:
|
|
129
129
|
# Look up the object type
|
130
130
|
object_type = cls.edsl_class_to_object_type.get(edsl_class_name)
|
131
131
|
if object_type is None:
|
132
|
-
from
|
132
|
+
from .exceptions import CoopValueError
|
133
133
|
raise CoopValueError(f"Object type not found for {edsl_object=}")
|
134
134
|
return object_type
|
135
135
|
|
@@ -152,7 +152,7 @@ class ObjectRegistry:
|
|
152
152
|
"""
|
153
153
|
EDSL_class = cls.object_type_to_edsl_class.get(object_type)
|
154
154
|
if EDSL_class is None:
|
155
|
-
from
|
155
|
+
from .exceptions import CoopValueError
|
156
156
|
raise CoopValueError(f"EDSL class not found for {object_type=}")
|
157
157
|
return EDSL_class
|
158
158
|
|