edsl 0.1.60__py3-none-any.whl → 0.1.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -26,7 +26,7 @@ class CoopJobsObjects(CoopObjects):
26
26
 
27
27
  c = Coop()
28
28
  job_details = [
29
- c.remote_inference_get(obj["uuid"], include_json_string=True)
29
+ c.new_remote_inference_get(obj["uuid"], include_json_string=True)
30
30
  for obj in self
31
31
  ]
32
32
 
@@ -53,7 +53,7 @@ class CoopJobsObjects(CoopObjects):
53
53
 
54
54
  for obj in self:
55
55
  if obj.get("results_uuid"):
56
- result = c.get(obj["results_uuid"])
56
+ result = c.pull(obj["results_uuid"], expected_object_type="results")
57
57
  results.append(result)
58
58
 
59
59
  return results
@@ -0,0 +1,171 @@
1
+ import reprlib
2
+ from typing import Optional
3
+
4
+ from .exceptions import CoopValueError
5
+ from ..scenarios import Scenario, ScenarioList
6
+
7
+
8
class CoopProlificFilters(ScenarioList):
    """Collection of the Prolific filters supported on Coop.

    Extends ScenarioList (one Scenario per filter) with helpers to look a
    filter up by ID, to build the filter dicts accepted by
    ``Coop.create_prolific_study()``, and to render a compact table.
    """

    def __init__(
        self, data: Optional[list] = None, codebook: Optional[dict[str, str]] = None
    ):
        super().__init__(data, codebook)

    def find(self, filter_id: str) -> Scenario:
        """
        Find a filter by its ID.

        The lookup accepts either underscores or dashes in ``filter_id``
        and matches whichever spelling the stored filter uses.

        Raises:
            CoopValueError: If no filter matches ``filter_id``.

        >>> filters = coop.list_prolific_filters()  # doctest: +SKIP
        >>> filters.find("age")  # doctest: +SKIP
        Scenario(
            {
                "filter_id": "age",
                "type": "range",
                "range_filter_min": 18,
                "range_filter_max": 100,
                ...
            }
        )
        """

        # Prolific has inconsistent naming conventions for filters -
        # some use underscores and some use dashes, so we need to check for both
        candidate_ids = (
            filter_id.replace("_", "-"),
            filter_id.replace("-", "_"),
        )

        for scenario in self:
            if scenario["filter_id"] in candidate_ids:
                return scenario
        raise CoopValueError(f"Filter with ID {filter_id} not found.")

    def create_study_filter(
        self,
        filter_id: str,
        min: Optional[int] = None,
        max: Optional[int] = None,
        values: Optional[list[str]] = None,
    ) -> dict:
        """
        Create a valid filter dict that is compatible with Coop.create_prolific_study().

        Parameters:
            filter_id: ID of the filter (underscores and dashes are interchangeable).
            min: Lower bound, required for range filters.
            max: Upper bound, required for range filters.
            values: Option labels to select, required for select filters.

        Raises:
            CoopValueError: If
                - the filter ID is not found,
                - a range filter is missing its min or max value, or a value
                  is outside of the allowed range,
                - a select filter is given no values, or a value that is not
                  one of the allowed options,
                - the filter has a type other than "range" or "select".

        For a select filter, you should pass a list of values:
        >>> filters = coop.list_prolific_filters()  # doctest: +SKIP
        >>> filters.create_study_filter("current_country_of_residence", values=["United States", "Canada"])  # doctest: +SKIP
        {
            "filter_id": "current_country_of_residence",
            "selected_values": ["1", "45"],
        }

        For a range filter, you should pass a min and max value:
        >>> filters.create_study_filter("age", min=20, max=40)  # doctest: +SKIP
        {
            "filter_id": "age",
            "selected_range": {
                "lower": 20,
                "upper": 40,
            },
        }
        """
        prolific_filter = self.find(filter_id)

        # .find() has logic to handle inconsistent naming conventions for filter IDs,
        # so we need to get the correct filter ID from the filter dict
        correct_filter_id = prolific_filter.get("filter_id")

        filter_type = prolific_filter.get("type")

        if filter_type == "range":
            filter_min = prolific_filter.get("range_filter_min")
            filter_max = prolific_filter.get("range_filter_max")

            # Both bounds are mandatory: comparing None against an int below
            # would otherwise raise a TypeError instead of a CoopValueError.
            if min is None or max is None:
                raise CoopValueError("Range filters require both a min and max value.")
            # A bound missing from the filter metadata is treated as
            # "no limit" rather than crashing on a None comparison.
            if filter_min is not None and min < filter_min:
                raise CoopValueError(
                    f"Min value {min} is less than the minimum allowed value {filter_min}."
                )
            if filter_max is not None and max > filter_max:
                raise CoopValueError(
                    f"Max value {max} is greater than the maximum allowed value {filter_max}."
                )
            if min > max:
                raise CoopValueError("Min value cannot be greater than max value.")
            return {
                "filter_id": correct_filter_id,
                "selected_range": {
                    "lower": min,
                    "upper": max,
                },
            }

        if filter_type == "select":
            if values is None:
                raise CoopValueError("Select filters require a list of values.")

            # The allowlist filter takes its values verbatim; there is no
            # label-to-ID translation for it.
            if correct_filter_id == "custom_allowlist":
                return {
                    "filter_id": correct_filter_id,
                    "selected_values": values,
                }

            # The options are stored as {option_id: label}; callers pass labels.
            # `or {}` also covers a key that is present but set to None.
            options_by_id = prolific_filter.get("select_filter_options") or {}
            option_labels_to_ids = {
                label: option_id for option_id, label in options_by_id.items()
            }
            try:
                selected_option_ids = [option_labels_to_ids[value] for value in values]
            except KeyError as err:
                raise CoopValueError(
                    f"Invalid value(s) provided for filter {filter_id}: {values}. "
                    f"Call find() with the filter ID to examine the allowed values for this filter."
                ) from err

            return {
                "filter_id": correct_filter_id,
                "selected_values": selected_option_ids,
            }

        raise CoopValueError(f"Unsupported filter type: {filter_type}.")

    def table(
        self,
        *fields,
        tablefmt: Optional[str] = None,
        pretty_labels: Optional[dict[str, str]] = None,
    ) -> str:
        """Return the CoopProlificFilters as a table with truncated options display for select filters."""

        # One formatter is enough for every row; only its string limit is customised.
        formatter = reprlib.Repr()
        formatter.maxstring = 50

        # Work on copies so the stored scenarios keep their full option mapping.
        truncated_scenarios = []
        for scenario in self:
            scenario_dict = dict(scenario)
            options = scenario_dict.get("select_filter_options")
            if options is not None:
                # Show only the option labels, abbreviated by reprlib.
                option_labels = list(dict(options).values())
                scenario_dict["select_filter_options"] = formatter.repr(option_labels)
            truncated_scenarios.append(scenario_dict)

        temp_scenario_list = ScenarioList([Scenario(s) for s in truncated_scenarios])

        # Display the table with the truncated data
        return temp_scenario_list.table(
            *fields, tablefmt=tablefmt, pretty_labels=pretty_labels
        )
@@ -23,4 +23,6 @@ class CoopRegularObjects(CoopObjects):
23
23
  from ..coop import Coop
24
24
 
25
25
  c = Coop()
26
- return [c.get(obj["uuid"]) for obj in self]
26
+ return [
27
+ c.pull(obj["uuid"], expected_object_type=obj["object_type"]) for obj in self
28
+ ]
@@ -55,13 +55,46 @@ class TableDisplay:
55
55
  self.printing_parameters = {}
56
56
 
57
57
  def _repr_html_(self) -> str:
58
- table_data = TableData(
59
- headers=self.headers,
60
- data=self.data,
61
- parameters=self.printing_parameters,
62
- raw_data_set=self.raw_data_set,
63
- )
64
- return self.renderer_class(table_data).render_html()
58
+ """
59
+ HTML representation for Jupyter/Colab notebooks.
60
+
61
+ The primary path uses the configured `renderer_class` to build an HTML
62
+ string. Unfortunately, in shared or long-running notebook runtimes it
63
+ is not uncommon for binary dependencies (NumPy, Pandas, etc.) to get
64
+ into an incompatible state, raising import-time errors that would
65
+ otherwise bubble up to the notebook and obscure the actual table
66
+ output. To make the developer experience smoother we catch *any*
67
+ exception, log/annotate it, and fall back to a plain-text rendering via
68
+ `tabulate`, wrapped in a <pre> block so at least a readable table is
69
+ shown.
70
+ """
71
+ try:
72
+ table_data = TableData(
73
+ headers=self.headers,
74
+ data=self.data,
75
+ parameters=self.printing_parameters,
76
+ raw_data_set=self.raw_data_set,
77
+ )
78
+ return self.renderer_class(table_data).render_html()
79
+ except Exception as exc: # pragma: no cover
80
+ # --- graceful degradation -------------------------------------------------
81
+ try:
82
+ from tabulate import tabulate
83
+
84
+ plain = tabulate(
85
+ self.data,
86
+ headers=self.headers,
87
+ tablefmt=self.tablefmt or "simple",
88
+ )
89
+ except Exception:
90
+ # Even `tabulate` failed – resort to the default __repr__.
91
+ plain = super().__repr__() if hasattr(super(), "__repr__") else str(self.data)
92
+
93
+ # Escape HTML-sensitive chars so the browser renders plain text.
94
+ import html
95
+
96
+ safe_plain = html.escape(plain)
97
+ return f"<pre>{safe_plain}\n\n[TableDisplay fallback – original error: {exc}]</pre>"
65
98
 
66
99
  def __repr__(self):
67
100
  # If rich format is requested, use RichRenderer
@@ -4,7 +4,7 @@ import os
4
4
  import json
5
5
  from typing import Any, Callable, Iterable, Iterator, List, Optional
6
6
  from abc import ABC, abstractmethod
7
- from collections.abc import MutableSequence
7
+ from collections.abc import MutableSequence, MutableMapping
8
8
 
9
9
 
10
10
  class SQLiteList(MutableSequence, ABC):
@@ -97,7 +97,20 @@ class SQLiteList(MutableSequence, ABC):
97
97
  row = cursor.fetchone()
98
98
  if row is None:
99
99
  raise IndexError("list index out of range")
100
- return self.deserialize(row[0])
100
+
101
+ obj = self.deserialize(row[0])
102
+
103
+ # If the stored object is a Scenario (or subclass), return a specialised proxy
104
+ try:
105
+ from edsl.scenarios.scenario import Scenario
106
+ if isinstance(obj, Scenario):
107
+ return self._make_scenario_proxy(self, index, obj)
108
+ except ImportError:
109
+ # Scenario not available – fall back to generic proxy
110
+ pass
111
+
112
+ # Generic proxy for other types
113
+ return self._RowProxy(self, index, obj)
101
114
 
102
115
  def __setitem__(self, index, value):
103
116
  if index < 0:
@@ -346,4 +359,90 @@ class SQLiteList(MutableSequence, ABC):
346
359
  self.conn.close()
347
360
  os.unlink(self.db_path)
348
361
  except:
349
- pass
362
+ pass
363
+
364
class _RowProxy(MutableMapping):
    """A write-through proxy returned by SQLiteList.__getitem__.

    Any mutation on the proxy (e.g. proxy[key] = value) is immediately
    written back through the parent's ``__setitem__`` at the row's index,
    so the parent's storage stays in sync with in-memory edits.

    Parameters:
        parent: The list the row came from; must support ``__setitem__``.
        idx: Index of the row inside ``parent``.
        obj: The deserialised row object (a mapping) being wrapped.
    """

    def __init__(self, parent: "SQLiteList", idx: int, obj: Any):
        self._parent = parent
        self._idx = idx
        self._obj = obj  # The real deserialised object (e.g. Scenario)

    # ---- MutableMapping interface ----
    def __getitem__(self, key):
        return self._obj[key]

    def __setitem__(self, key, value):
        self._obj[key] = value
        # Propagate change back to the parent list
        self._parent.__setitem__(self._idx, self._obj)

    def __delitem__(self, key):
        del self._obj[key]
        self._parent.__setitem__(self._idx, self._obj)

    def __iter__(self):
        return iter(self._obj)

    def __len__(self):
        return len(self._obj)

    # ---- Convenience helpers ----
    def __getattr__(self, name):
        """Delegate unknown attributes to the wrapped object.

        ``__getattr__`` only runs when normal lookup fails. If one of the
        proxy's own slots is missing (e.g. on a blank instance created by
        ``copy`` or ``pickle`` before its state is restored), delegating
        would look up ``self._obj`` again and recurse without end, so an
        AttributeError is raised instead.
        """
        if name in ("_parent", "_idx", "_obj"):
            raise AttributeError(name)
        return getattr(self._obj, name)

    def __repr__(self):
        return repr(self._obj)
402
+
403
+ # Specialised proxy for Scenario objects so isinstance(obj, Scenario) remains True.
404
+ # Defined lazily to avoid importing Scenario at module load time for performance.
405
@staticmethod
def _make_scenario_proxy(parent: "SQLiteList", idx: int, scenario_obj: Any):
    """Wrap ``scenario_obj`` in a write-through Scenario subclass instance.

    A throwaway subclass of Scenario is built on each call so that
    ``isinstance(result, Scenario)`` holds. Its ``__setitem__`` and
    ``__delitem__`` apply the change and then store a plain Scenario copy
    back into ``parent`` at ``idx``; ``__reduce__`` makes it pickle as a
    plain Scenario. The class name is popped from
    ``RegisterSubclassesMeta._registry`` right after creation.

    Parameters:
        parent: The list the row came from.
        idx: Index of the row inside ``parent``.
        scenario_obj: The Scenario whose contents are copied into the proxy.

    Returns:
        The proxy instance, with ``_parent`` and ``_idx`` attached.
    """
    # Imported here rather than at module level.
    from edsl.scenarios.scenario import Scenario
    from edsl.base import RegisterSubclassesMeta

    def _store_plain_copy(proxy):
        # Persist a plain Scenario snapshot, not the proxy itself.
        proxy._parent.__setitem__(proxy._idx, Scenario(dict(proxy)))

    def _proxy_setitem(self, key, value):
        # Call Scenario's implementation explicitly, then write through.
        Scenario.__setitem__(self, key, value)
        _store_plain_copy(self)

    def _proxy_delitem(self, key):
        Scenario.__delitem__(self, key)
        _store_plain_copy(self)

    def _proxy_reduce(self):
        return (Scenario, (dict(self),))

    class_namespace = {
        "__setitem__": _proxy_setitem,
        "__delitem__": _proxy_delitem,
        "__reduce__": _proxy_reduce,
        "__module__": Scenario.__module__,
    }
    proxy_cls = type("_ScenarioRowProxy", (Scenario,), class_namespace)

    # Drop the helper class from the global registry again.
    RegisterSubclassesMeta._registry.pop(proxy_cls.__name__, None)

    proxy = proxy_cls(dict(scenario_obj))
    # Tracking attributes used by the write-through hooks above.
    proxy._parent = parent
    proxy._idx = idx
    return proxy
@@ -5,6 +5,7 @@ from ..data_transfer_models import EDSLResultObjectInput
5
5
 
6
6
  # from edsl.data_transfer_models import VisibilityType
7
7
  from ..caching import Cache
8
+
8
9
  # Import BucketCollection lazily to avoid circular imports
9
10
  from ..key_management import KeyLookup
10
11
  from ..base import Base
@@ -18,23 +19,27 @@ if TYPE_CHECKING:
18
19
 
19
20
  VisibilityType = Literal["private", "public", "unlisted"]
20
21
 
22
+
21
23
  @dataclass
22
24
  class RunEnvironment:
23
25
  """
24
26
  Contains environment-related resources for job execution.
25
-
26
- This dataclass holds references to shared resources and infrastructure components
27
- needed for job execution. These components are typically long-lived and may be
27
+
28
+ This dataclass holds references to shared resources and infrastructure components
29
+ needed for job execution. These components are typically long-lived and may be
28
30
  shared across multiple job runs.
29
-
31
+
30
32
  Attributes:
31
33
  cache (Cache, optional): Cache for storing and retrieving interview results
32
34
  bucket_collection (BucketCollection, optional): Collection of token rate limit buckets
33
35
  key_lookup (KeyLookup, optional): Manager for API keys across models
34
36
  jobs_runner_status (JobsRunnerStatus, optional): Tracker for job execution progress
35
37
  """
38
+
36
39
  cache: Optional[Cache] = None
37
- bucket_collection: Optional[Any] = None # Using Any to avoid circular import of BucketCollection
40
+ bucket_collection: Optional[
41
+ Any
42
+ ] = None # Using Any to avoid circular import of BucketCollection
38
43
  key_lookup: Optional[KeyLookup] = None
39
44
  jobs_runner_status: Optional["JobsRunnerStatus"] = None
40
45
 
@@ -43,11 +48,11 @@ class RunEnvironment:
43
48
  class RunParameters(Base):
44
49
  """
45
50
  Contains execution-specific parameters for job runs.
46
-
51
+
47
52
  This dataclass holds parameters that control the behavior of a specific job run,
48
53
  such as iteration count, error handling preferences, and remote execution options.
49
54
  Unlike RunEnvironment, these parameters are specific to a single job execution.
50
-
55
+
51
56
  Attributes:
52
57
  n (int): Number of iterations to run each interview, default is 1
53
58
  progress_bar (bool): Whether to show a progress bar, default is False
@@ -66,7 +71,9 @@ class RunParameters(Base):
66
71
  disable_remote_inference (bool): Whether to disable remote inference, default is False
67
72
  job_uuid (str, optional): UUID for the job, used for tracking
68
73
  fresh (bool): If True, ignore cache and generate new results, default is False
74
+ new_format (bool): If True, uses remote_inference_create method, if False uses old_remote_inference_create method, default is True
69
75
  """
76
+
70
77
  n: int = 1
71
78
  progress_bar: bool = False
72
79
  stop_on_exception: bool = False
@@ -82,8 +89,13 @@ class RunParameters(Base):
82
89
  disable_remote_cache: bool = False
83
90
  disable_remote_inference: bool = False
84
91
  job_uuid: Optional[str] = None
85
- fresh: bool = False # if True, will not use cache and will save new results to cache
86
- memory_threshold: Optional[int] = None # Threshold in bytes for Results SQLList memory management
92
+ fresh: bool = (
93
+ False # if True, will not use cache and will save new results to cache
94
+ )
95
+ memory_threshold: Optional[
96
+ int
97
+ ] = None # Threshold in bytes for Results SQLList memory management
98
+ new_format: bool = True # if True, uses remote_inference_create, if False uses old_remote_inference_create
87
99
 
88
100
  def to_dict(self, add_edsl_version=False) -> dict:
89
101
  d = asdict(self)
@@ -110,24 +122,25 @@ class RunParameters(Base):
110
122
  class RunConfig:
111
123
  """
112
124
  Combines environment resources and execution parameters for a job run.
113
-
125
+
114
126
  This class brings together the two aspects of job configuration:
115
127
  1. Environment resources (caches, API keys, etc.) via RunEnvironment
116
128
  2. Execution parameters (iterations, error handling, etc.) via RunParameters
117
-
129
+
118
130
  It provides helper methods for modifying environment components after construction.
119
-
131
+
120
132
  Attributes:
121
133
  environment (RunEnvironment): The environment resources for the job
122
134
  parameters (RunParameters): The execution parameters for the job
123
135
  """
136
+
124
137
  environment: RunEnvironment
125
138
  parameters: RunParameters
126
139
 
127
140
  def add_environment(self, environment: RunEnvironment) -> None:
128
141
  """
129
142
  Replace the entire environment configuration.
130
-
143
+
131
144
  Parameters:
132
145
  environment (RunEnvironment): The new environment configuration
133
146
  """
@@ -136,7 +149,7 @@ class RunConfig:
136
149
  def add_bucket_collection(self, bucket_collection: "BucketCollection") -> None:
137
150
  """
138
151
  Set or replace the bucket collection in the environment.
139
-
152
+
140
153
  Parameters:
141
154
  bucket_collection (BucketCollection): The bucket collection to use
142
155
  """
@@ -145,7 +158,7 @@ class RunConfig:
145
158
  def add_cache(self, cache: Cache) -> None:
146
159
  """
147
160
  Set or replace the cache in the environment.
148
-
161
+
149
162
  Parameters:
150
163
  cache (Cache): The cache to use
151
164
  """
@@ -154,7 +167,7 @@ class RunConfig:
154
167
  def add_key_lookup(self, key_lookup: KeyLookup) -> None:
155
168
  """
156
169
  Set or replace the key lookup in the environment.
157
-
170
+
158
171
  Parameters:
159
172
  key_lookup (KeyLookup): The key lookup to use
160
173
  """
@@ -169,10 +182,10 @@ Additional data structures for working with job results and answers.
169
182
  class Answers(UserDict):
170
183
  """
171
184
  A specialized dictionary for holding interview response data.
172
-
185
+
173
186
  This class extends UserDict to provide a flexible container for survey answers,
174
187
  with special handling for response metadata like comments and token usage.
175
-
188
+
176
189
  Key features:
177
190
  - Stores answers by question name
178
191
  - Associates comments with their respective questions
@@ -185,14 +198,14 @@ class Answers(UserDict):
185
198
  ) -> None:
186
199
  """
187
200
  Add a response to the answers dictionary.
188
-
201
+
189
202
  This method processes a response and stores it in the dictionary with appropriate
190
203
  naming conventions for the answer itself, comments, and token usage tracking.
191
-
204
+
192
205
  Parameters:
193
206
  response (EDSLResultObjectInput): The response object containing answer data
194
207
  question (QuestionBase): The question that was answered
195
-
208
+
196
209
  Notes:
197
210
  - The main answer is stored with the question's name as the key
198
211
  - Comments are stored with "_comment" appended to the question name
@@ -201,31 +214,33 @@ class Answers(UserDict):
201
214
  answer = response.answer
202
215
  comment = response.comment
203
216
  generated_tokens = response.generated_tokens
204
-
217
+
205
218
  # Record token usage if available
206
219
  if generated_tokens:
207
220
  self[question.question_name + "_generated_tokens"] = generated_tokens
208
-
221
+
209
222
  # Record the primary answer
210
223
  self[question.question_name] = answer
211
-
224
+
212
225
  # Record comment if present
213
226
  if comment:
214
227
  self[question.question_name + "_comment"] = comment
215
228
 
216
229
  if getattr(response, "reasoning_summary", None):
217
- self[question.question_name + "_reasoning_summary"] = response.reasoning_summary
230
+ self[
231
+ question.question_name + "_reasoning_summary"
232
+ ] = response.reasoning_summary
218
233
 
219
234
  def replace_missing_answers_with_none(self, survey: "Survey") -> None:
220
235
  """
221
236
  Replace missing answers with None for all questions in the survey.
222
-
237
+
223
238
  This method ensures that all questions in the survey have an entry in the
224
239
  answers dictionary, even if they were skipped during the interview.
225
-
240
+
226
241
  Parameters:
227
242
  survey (Survey): The survey containing the questions to check
228
-
243
+
229
244
  Notes:
230
245
  - Answers can be missing if the agent skips a question due to skip logic
231
246
  - This ensures consistent data structure even with partial responses
@@ -237,7 +252,7 @@ class Answers(UserDict):
237
252
  def to_dict(self) -> dict:
238
253
  """
239
254
  Convert the answers to a standard dictionary.
240
-
255
+
241
256
  Returns:
242
257
  dict: A plain dictionary containing all the answers data
243
258
  """
@@ -247,10 +262,10 @@ class Answers(UserDict):
247
262
  def from_dict(cls, d: dict) -> "Answers":
248
263
  """
249
264
  Create an Answers object from a dictionary.
250
-
265
+
251
266
  Parameters:
252
267
  d (dict): The dictionary containing answer data
253
-
268
+
254
269
  Returns:
255
270
  Answers: A new Answers instance with the provided data
256
271
  """