edsl 0.1.59__py3-none-any.whl → 0.1.61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,171 @@
+ import reprlib
+ from typing import Optional
+ 
+ from .exceptions import CoopValueError
+ from ..scenarios import Scenario, ScenarioList
+ 
+ 
+ class CoopProlificFilters(ScenarioList):
+     """Base class for Prolific filters supported on Coop.
+ 
+     This abstract class extends ScenarioList to provide common functionality
+     for working with Prolific filters.
+     """
+ 
+     def __init__(
+         self, data: Optional[list] = None, codebook: Optional[dict[str, str]] = None
+     ):
+         super().__init__(data, codebook)
+ 
+     def find(self, filter_id: str) -> Optional[Scenario]:
+         """
+         Find a filter by its ID. Raises a CoopValueError if the filter is not found.
+ 
+         >>> filters = coop.list_prolific_filters()
+         >>> filters.find("age")
+         Scenario(
+             {
+                 "filter_id": "age",
+                 "type": "range",
+                 "range_filter_min": 18,
+                 "range_filter_max": 100,
+                 ...
+             }
+         )
+         """
+ 
+         # Prolific has inconsistent naming conventions for filters -
+         # some use underscores and some use dashes, so we need to check for both
+         id_with_dashes = filter_id.replace("_", "-")
+         id_with_underscores = filter_id.replace("-", "_")
+ 
+         for scenario in self:
+             if (
+                 scenario["filter_id"] == id_with_dashes
+                 or scenario["filter_id"] == id_with_underscores
+             ):
+                 return scenario
+         raise CoopValueError(f"Filter with ID {filter_id} not found.")
+ 
+     def create_study_filter(
+         self,
+         filter_id: str,
+         min: Optional[int] = None,
+         max: Optional[int] = None,
+         values: Optional[list[str]] = None,
+     ) -> dict:
+         """
+         Create a valid filter dict that is compatible with Coop.create_prolific_study().
+         This function will raise a CoopValueError if:
+         - The filter ID is not found
+         - A range filter is provided with no min or max value, or a value that is outside of the allowed range
+         - A select filter is provided with no values, or a value that is not in the allowed options
+ 
+         For a select filter, you should pass a list of values:
+         >>> filters = coop.list_prolific_filters()
+         >>> filters.create_study_filter("current_country_of_residence", values=["United States", "Canada"])
+         {
+             "filter_id": "current_country_of_residence",
+             "selected_values": ["1", "45"],
+         }
+ 
+         For a range filter, you should pass a min and max value:
+         >>> filters.create_study_filter("age", min=20, max=40)
+         {
+             "filter_id": "age",
+             "selected_range": {
+                 "lower": 20,
+                 "upper": 40,
+             },
+         }
+         """
+         filter = self.find(filter_id)
+ 
+         # .find() has logic to handle inconsistent naming conventions for filter IDs,
+         # so we need to get the correct filter ID from the filter dict
+         correct_filter_id = filter.get("filter_id")
+ 
+         filter_type = filter.get("type")
+ 
+         if filter_type == "range":
+             filter_min = filter.get("range_filter_min")
+             filter_max = filter.get("range_filter_max")
+ 
+             if min is None or max is None:
+                 raise CoopValueError("Range filters require both a min and max value.")
+             if min < filter_min:
+                 raise CoopValueError(
+                     f"Min value {min} is less than the minimum allowed value {filter_min}."
+                 )
+             if max > filter_max:
+                 raise CoopValueError(
+                     f"Max value {max} is greater than the maximum allowed value {filter_max}."
+                 )
+             if min > max:
+                 raise CoopValueError("Min value cannot be greater than max value.")
+             return {
+                 "filter_id": correct_filter_id,
+                 "selected_range": {
+                     "lower": min,
+                     "upper": max,
+                 },
+             }
+         elif filter_type == "select":
+             if values is None:
+                 raise CoopValueError("Select filters require a list of values.")
+ 
+             if correct_filter_id == "custom_allowlist":
+                 return {
+                     "filter_id": correct_filter_id,
+                     "selected_values": values,
+                 }
+ 
+             try:
+                 allowed_option_labels = filter.get("select_filter_options", {})
+                 option_labels_to_ids = {v: k for k, v in allowed_option_labels.items()}
+                 selected_option_ids = [option_labels_to_ids[value] for value in values]
+             except KeyError:
+                 raise CoopValueError(
+                     f"Invalid value(s) provided for filter {filter_id}: {values}. "
+                     f"Call find() with the filter ID to examine the allowed values for this filter."
+                 )
+ 
+             return {
+                 "filter_id": correct_filter_id,
+                 "selected_values": selected_option_ids,
+             }
+         else:
+             raise CoopValueError(f"Unsupported filter type: {filter_type}.")
+ 
+     def table(
+         self,
+         *fields,
+         tablefmt: Optional[str] = None,
+         pretty_labels: Optional[dict[str, str]] = None,
+     ) -> str:
+         """Return the CoopProlificFilters as a table with truncated options display for select filters."""
+ 
+         # Create a copy of the data with truncated options
+         truncated_scenarios = []
+         for scenario in self:
+             scenario_dict = dict(scenario)
+             if (
+                 "select_filter_options" in scenario_dict
+                 and scenario_dict["select_filter_options"] is not None
+             ):
+ 
+                 # Create a truncated representation of the options list
+                 formatter = reprlib.Repr()
+                 formatter.maxstring = 50
+                 select_filter_options = list(
+                     dict(scenario_dict["select_filter_options"]).values()
+                 )
+                 formatted_options = formatter.repr(select_filter_options)
+                 scenario_dict["select_filter_options"] = formatted_options
+             truncated_scenarios.append(scenario_dict)
+ 
+         temp_scenario_list = ScenarioList([Scenario(s) for s in truncated_scenarios])
+ 
+         # Display the table with the truncated data
+         return temp_scenario_list.table(
+             *fields, tablefmt=tablefmt, pretty_labels=pretty_labels
+         )
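
To show how the pieces above fit together, here is a brief usage sketch. It assumes a Coop client that exposes list_prolific_filters() and create_prolific_study(), as referenced in the docstrings; the study arguments other than filters are illustrative placeholders, not a confirmed signature.

    # Illustrative sketch; create_prolific_study's exact parameters may differ.
    from edsl.coop import Coop

    coop = Coop()
    filters = coop.list_prolific_filters()        # a CoopProlificFilters instance

    age_filter = filters.find("age")              # dash/underscore variants both resolve
    age = filters.create_study_filter("age", min=25, max=45)
    country = filters.create_study_filter(
        "current_country_of_residence", values=["United States"]
    )

    # Hypothetical call shape for attaching the filters to a study
    coop.create_prolific_study(
        project_uuid="...",                       # placeholder
        name="Example study",
        filters=[age, country],
    )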
@@ -357,7 +357,7 @@ class DataOperationsBase:
          4
          >>> engine = Results.example()._db(shape = "long")
          >>> len(engine.execute(text("SELECT * FROM self")).fetchall())
-         204
+         212
          """
          # Import needed for database connection
          from sqlalchemy import create_engine
@@ -442,7 +442,7 @@ class DataOperationsBase:
  
          # Using long format
          >>> len(r.sql("SELECT * FROM self", shape="long"))
-         204
+         212
          """
          import pandas as pd
  
@@ -55,13 +55,46 @@ class TableDisplay:
          self.printing_parameters = {}
  
      def _repr_html_(self) -> str:
-         table_data = TableData(
-             headers=self.headers,
-             data=self.data,
-             parameters=self.printing_parameters,
-             raw_data_set=self.raw_data_set,
-         )
-         return self.renderer_class(table_data).render_html()
+         """
+         HTML representation for Jupyter/Colab notebooks.
+ 
+         The primary path uses the configured `renderer_class` to build an HTML
+         string. Unfortunately, in shared or long-running notebook runtimes it
+         is not uncommon for binary dependencies (NumPy, Pandas, etc.) to get
+         into an incompatible state, raising import-time errors that would
+         otherwise bubble up to the notebook and obscure the actual table
+         output. To make the developer experience smoother we catch *any*
+         exception, log/annotate it, and fall back to a plain-text rendering via
+         `tabulate`, wrapped in a <pre> block so at least a readable table is
+         shown.
+         """
+         try:
+             table_data = TableData(
+                 headers=self.headers,
+                 data=self.data,
+                 parameters=self.printing_parameters,
+                 raw_data_set=self.raw_data_set,
+             )
+             return self.renderer_class(table_data).render_html()
+         except Exception as exc:  # pragma: no cover
+             # --- graceful degradation -------------------------------------------------
+             try:
+                 from tabulate import tabulate
+ 
+                 plain = tabulate(
+                     self.data,
+                     headers=self.headers,
+                     tablefmt=self.tablefmt or "simple",
+                 )
+             except Exception:
+                 # Even `tabulate` failed – resort to the default __repr__.
+                 plain = super().__repr__() if hasattr(super(), "__repr__") else str(self.data)
+ 
+             # Escape HTML-sensitive chars so the browser renders plain text.
+             import html
+ 
+             safe_plain = html.escape(plain)
+             return f"<pre>{safe_plain}\n\n[TableDisplay fallback – original error: {exc}]</pre>"
  
      def __repr__(self):
          # If rich format is requested, use RichRenderer
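
The fallback above leans on two pieces that are easy to see in isolation: tabulate for a plain-text table and html.escape so the <pre> content is not interpreted as markup. A minimal standalone sketch of the same render-or-degrade pattern (safe_html_table and render_html are illustrative names, not part of TableDisplay):

    import html
    from tabulate import tabulate

    def safe_html_table(headers, rows, render_html=None):
        try:
            if render_html is None:
                raise RuntimeError("no rich renderer available")  # simulated failure
            return render_html(headers, rows)
        except Exception as exc:
            # Degrade to a plain-text table wrapped in <pre>, with the error noted.
            plain = tabulate(rows, headers=headers, tablefmt="simple")
            return f"<pre>{html.escape(plain)}\n\n[fallback - original error: {exc}]</pre>"

    print(safe_html_table(["a", "b"], [[1, 2], [3, 4]]))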
@@ -4,7 +4,7 @@ import os
  import json
  from typing import Any, Callable, Iterable, Iterator, List, Optional
  from abc import ABC, abstractmethod
- from collections.abc import MutableSequence
+ from collections.abc import MutableSequence, MutableMapping
  
  
  class SQLiteList(MutableSequence, ABC):
@@ -97,7 +97,20 @@ class SQLiteList(MutableSequence, ABC):
          row = cursor.fetchone()
          if row is None:
              raise IndexError("list index out of range")
-         return self.deserialize(row[0])
+ 
+         obj = self.deserialize(row[0])
+ 
+         # If the stored object is a Scenario (or subclass), return a specialised proxy
+         try:
+             from edsl.scenarios.scenario import Scenario
+             if isinstance(obj, Scenario):
+                 return self._make_scenario_proxy(self, index, obj)
+         except ImportError:
+             # Scenario not available – fall back to generic proxy
+             pass
+ 
+         # Generic proxy for other types
+         return self._RowProxy(self, index, obj)
  
      def __setitem__(self, index, value):
          if index < 0:
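
The practical effect of returning a proxy is that item-level edits persist without an explicit re-assignment back into the list. A self-contained toy version of the write-through idea (a plain Python list stands in for the SQLite-backed storage; none of these names are edsl internals):

    from collections.abc import MutableMapping

    class RowProxy(MutableMapping):
        """Write-through view of one row held by a parent store."""

        def __init__(self, parent, idx, obj):
            self._parent, self._idx, self._obj = parent, idx, obj

        def __getitem__(self, key):
            return self._obj[key]

        def __setitem__(self, key, value):
            self._obj[key] = value
            self._parent[self._idx] = self._obj   # persist the whole row immediately

        def __delitem__(self, key):
            del self._obj[key]
            self._parent[self._idx] = self._obj

        def __iter__(self):
            return iter(self._obj)

        def __len__(self):
            return len(self._obj)

    store = [{"name": "alice"}]                   # stands in for the SQLite table
    row = RowProxy(store, 0, dict(store[0]))
    row["name"] = "bob"                           # edit lands back in the store
    assert store[0]["name"] == "bob"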
@@ -346,4 +359,90 @@ class SQLiteList(MutableSequence, ABC):
              self.conn.close()
              os.unlink(self.db_path)
          except:
-             pass
+             pass
+ 
+     class _RowProxy(MutableMapping):
+         """A write-through proxy returned by SQLiteList.__getitem__.
+ 
+         Any mutation on the proxy (e.g. proxy[key] = value) is immediately
+         re-serialised and written back to the underlying SQLite storage,
+         ensuring the database stays in sync with in-memory edits.
+         """
+ 
+         def __init__(self, parent: "SQLiteList", idx: int, obj: Any):
+             self._parent = parent
+             self._idx = idx
+             self._obj = obj  # The real deserialised object (e.g. Scenario)
+ 
+         # ---- MutableMapping interface ----
+         def __getitem__(self, key):
+             return self._obj[key]
+ 
+         def __setitem__(self, key, value):
+             self._obj[key] = value
+             # Propagate change back to SQLite via parent list
+             self._parent.__setitem__(self._idx, self._obj)
+ 
+         def __delitem__(self, key):
+             del self._obj[key]
+             self._parent.__setitem__(self._idx, self._obj)
+ 
+         def __iter__(self):
+             return iter(self._obj)
+ 
+         def __len__(self):
+             return len(self._obj)
+ 
+         # ---- Convenience helpers ----
+         def __getattr__(self, name):  # Delegate attribute access
+             return getattr(self._obj, name)
+ 
+         def __repr__(self):
+             return repr(self._obj)
+ 
+     # Specialised proxy for Scenario objects so isinstance(obj, Scenario) remains True.
+     # Defined lazily to avoid importing Scenario at module load time for performance.
+     @staticmethod
+     def _make_scenario_proxy(parent: "SQLiteList", idx: int, scenario_obj: Any):
+         """Create and return an on-the-fly proxy class inheriting from Scenario but
+         immediately removed from the global subclass registry so serialization
+         coverage tests ignore it.
+         """
+         from edsl.scenarios.scenario import Scenario  # local import
+         from edsl.base import RegisterSubclassesMeta
+ 
+         # Dynamically build class dict with required methods
+         def _proxy_setitem(self, key, value):
+             Scenario.__setitem__(self, key, value)  # super call avoids MRO confusion
+             from edsl.scenarios.scenario import Scenario as S
+             self._parent.__setitem__(self._idx, S(dict(self)))
+ 
+         def _proxy_delitem(self, key):
+             Scenario.__delitem__(self, key)
+             from edsl.scenarios.scenario import Scenario as S
+             self._parent.__setitem__(self._idx, S(dict(self)))
+ 
+         def _proxy_reduce(self):
+             from edsl.scenarios.scenario import Scenario as S
+             return (S, (dict(self),))
+ 
+         proxy_cls = type(
+             "_ScenarioRowProxy",
+             (Scenario,),
+             {
+                 "__setitem__": _proxy_setitem,
+                 "__delitem__": _proxy_delitem,
+                 "__reduce__": _proxy_reduce,
+                 "__module__": Scenario.__module__,
+             },
+         )
+ 
+         # Remove this helper class from global registry so tests ignore it
+         RegisterSubclassesMeta._registry.pop(proxy_cls.__name__, None)
+ 
+         # Instantiate
+         instance = proxy_cls(dict(scenario_obj))
+         # attach parent tracking attributes
+         instance._parent = parent
+         instance._idx = idx
+         return instance
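
The __reduce__ override is what keeps the throwaway proxy class out of serialized data: pickling an instance reconstructs the plain base type rather than the dynamically created subclass. A toy version of that trick, with a stand-in Base class instead of Scenario:

    import pickle

    class Base(dict):
        pass

    def _proxy_reduce(self):
        # Pickle as the plain base type, not the dynamically created proxy class.
        return (Base, (dict(self),))

    ProxyCls = type("_RowProxyDemo", (Base,), {"__reduce__": _proxy_reduce})

    p = ProxyCls({"x": 1})
    assert isinstance(p, Base)                    # proxy still passes isinstance checks
    restored = pickle.loads(pickle.dumps(p))
    assert type(restored) is Base and restored == {"x": 1}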
@@ -8,6 +8,7 @@ from .groq_service import GroqService
  from .mistral_ai_service import MistralAIService
  from .ollama_service import OllamaService
  from .open_ai_service import OpenAIService
+ from .open_ai_service_v2 import OpenAIServiceV2
  from .perplexity_service import PerplexityService
  from .test_service import TestService
  from .together_ai_service import TogetherAIService
@@ -24,8 +25,9 @@ __all__ = [
      "MistralAIService",
      "OllamaService",
      "OpenAIService",
+     "OpenAIServiceV2",
      "PerplexityService",
      "TestService",
      "TogetherAIService",
      "XAIService",
- ]
+ ]
@@ -0,0 +1,243 @@
+ from __future__ import annotations
+ from typing import Any, List, Optional, Dict, NewType, TYPE_CHECKING
+ import os
+ 
+ import openai
+ 
+ from ..inference_service_abc import InferenceServiceABC
+ 
+ # Use TYPE_CHECKING to avoid circular imports at runtime
+ if TYPE_CHECKING:
+     from ...language_models import LanguageModel
+ from ..rate_limits_cache import rate_limits
+ 
+ # Default to completions API but can use responses API with parameter
+ 
+ if TYPE_CHECKING:
+     from ....scenarios.file_store import FileStore as Files
+     from ....invigilators.invigilator_base import InvigilatorBase as InvigilatorAI
+ 
+ 
+ APIToken = NewType("APIToken", str)
+ 
+ 
+ class OpenAIServiceV2(InferenceServiceABC):
+     """OpenAI service class using the Responses API."""
+ 
+     _inference_service_ = "openai_v2"
+     _env_key_name_ = "OPENAI_API_KEY"
+     _base_url_ = None
+ 
+     _sync_client_ = openai.OpenAI
+     _async_client_ = openai.AsyncOpenAI
+ 
+     _sync_client_instances: Dict[APIToken, openai.OpenAI] = {}
+     _async_client_instances: Dict[APIToken, openai.AsyncOpenAI] = {}
+ 
+     # sequence to extract text from response.output
+     key_sequence = ["output", 1, "content", 0, "text"]
+     usage_sequence = ["usage"]
+     # sequence to extract reasoning summary from response.output
+     reasoning_sequence = ["output", 0, "summary"]
+     input_token_name = "prompt_tokens"
+     output_token_name = "completion_tokens"
+ 
+     available_models_url = "https://platform.openai.com/docs/models/gp"
+ 
+     def __init_subclass__(cls, **kwargs):
+         super().__init_subclass__(**kwargs)
+         cls._sync_client_instances = {}
+         cls._async_client_instances = {}
+ 
+     @classmethod
+     def sync_client(cls, api_key: str) -> openai.OpenAI:
+         if api_key not in cls._sync_client_instances:
+             client = cls._sync_client_(
+                 api_key=api_key,
+                 base_url=cls._base_url_,
+             )
+             cls._sync_client_instances[api_key] = client
+         return cls._sync_client_instances[api_key]
+ 
+     @classmethod
+     def async_client(cls, api_key: str) -> openai.AsyncOpenAI:
+         if api_key not in cls._async_client_instances:
+             client = cls._async_client_(
+                 api_key=api_key,
+                 base_url=cls._base_url_,
+             )
+             cls._async_client_instances[api_key] = client
+         return cls._async_client_instances[api_key]
+ 
+     model_exclude_list = [
+         "whisper-1",
+         "davinci-002",
+         "dall-e-2",
+         "tts-1-hd-1106",
+         "tts-1-hd",
+         "dall-e-3",
+         "tts-1",
+         "babbage-002",
+         "tts-1-1106",
+         "text-embedding-3-large",
+         "text-embedding-3-small",
+         "text-embedding-ada-002",
+         "ft:davinci-002:mit-horton-lab::8OfuHgoo",
+         "gpt-3.5-turbo-instruct-0914",
+         "gpt-3.5-turbo-instruct",
+     ]
+     _models_list_cache: List[str] = []
+ 
+     @classmethod
+     def get_model_list(cls, api_key: str | None = None) -> List[str]:
+         if api_key is None:
+             api_key = os.getenv(cls._env_key_name_)
+         raw = cls.sync_client(api_key).models.list()
+         return raw.data if hasattr(raw, "data") else raw
+ 
+     @classmethod
+     def available(cls, api_token: str | None = None) -> List[str]:
+         if api_token is None:
+             api_token = os.getenv(cls._env_key_name_)
+         if not cls._models_list_cache:
+             data = cls.get_model_list(api_key=api_token)
+             cls._models_list_cache = [
+                 m.id for m in data if m.id not in cls.model_exclude_list
+             ]
+         return cls._models_list_cache
+ 
+     @classmethod
+     def create_model(
+         cls,
+         model_name: str,
+         model_class_name: str | None = None,
+     ) -> LanguageModel:
+         if model_class_name is None:
+             model_class_name = cls.to_class_name(model_name)
+ 
+         from ...language_models import LanguageModel
+ 
+         class LLM(LanguageModel):
+             """Child class for OpenAI Responses API"""
+ 
+             key_sequence = cls.key_sequence
+             usage_sequence = cls.usage_sequence
+             reasoning_sequence = cls.reasoning_sequence
+             input_token_name = cls.input_token_name
+             output_token_name = cls.output_token_name
+             _inference_service_ = cls._inference_service_
+             _model_ = model_name
+             _parameters_ = {
+                 "temperature": 0.5,
+                 "max_tokens": 2000,
+                 "top_p": 1,
+                 "frequency_penalty": 0,
+                 "presence_penalty": 0,
+                 "logprobs": False,
+                 "top_logprobs": 3,
+             }
+ 
+             def sync_client(self) -> openai.OpenAI:
+                 return cls.sync_client(api_key=self.api_token)
+ 
+             def async_client(self) -> openai.AsyncOpenAI:
+                 return cls.async_client(api_key=self.api_token)
+ 
+             @classmethod
+             def available(cls) -> list[str]:
+                 return cls.sync_client().models.list().data
+ 
+             def get_headers(self) -> dict[str, Any]:
+                 client = self.sync_client()
+                 response = client.responses.with_raw_response.create(
+                     model=self.model,
+                     input=[{"role": "user", "content": "Say this is a test"}],
+                     store=False,
+                 )
+                 return dict(response.headers)
+ 
+             def get_rate_limits(self) -> dict[str, Any]:
+                 try:
+                     headers = rate_limits.get("openai", self.get_headers())
+                 except Exception:
+                     return {"rpm": 10000, "tpm": 2000000}
+                 return {
+                     "rpm": int(headers["x-ratelimit-limit-requests"]),
+                     "tpm": int(headers["x-ratelimit-limit-tokens"]),
+                 }
+ 
+             async def async_execute_model_call(
+                 self,
+                 user_prompt: str,
+                 system_prompt: str = "",
+                 files_list: Optional[List[Files]] = None,
+                 invigilator: Optional[InvigilatorAI] = None,
+             ) -> dict[str, Any]:
+                 content = user_prompt
+                 if files_list:
+                     # embed files as separate inputs
+                     content = [{"type": "text", "text": user_prompt}]
+                     for f in files_list:
+                         content.append(
+                             {
+                                 "type": "image_url",
+                                 "image_url": {
+                                     "url": f"data:{f.mime_type};base64,{f.base64_string}"
+                                 },
+                             }
+                         )
+                 # build input sequence
+                 messages: Any
+                 if system_prompt and not self.omit_system_prompt_if_empty:
+                     messages = [
+                         {"role": "system", "content": system_prompt},
+                         {"role": "user", "content": content},
+                     ]
+                 else:
+                     messages = [{"role": "user", "content": content}]
+ 
+                 # All OpenAI models with the responses API use these base parameters
+                 params = {
+                     "model": self.model,
+                     "input": messages,
+                     "temperature": self.temperature,
+                     "top_p": self.top_p,
+                     "store": False,
+                 }
+ 
+                 # Check if this is a reasoning model (o-series models)
+                 is_reasoning_model = any(
+                     tag in self.model
+                     for tag in ["o1", "o1-mini", "o3", "o3-mini", "o1-pro", "o4-mini"]
+                 )
+ 
+                 # Only add reasoning parameter for reasoning models
+                 if is_reasoning_model:
+                     params["reasoning"] = {"summary": "auto"}
+ 
+                 # For all models using the responses API, use max_output_tokens
+                 # instead of max_tokens (which is for the completions API)
+                 params["max_output_tokens"] = self.max_tokens
+ 
+                 # Specifically for o-series, we also set temperature to 1
+                 if is_reasoning_model:
+                     params["temperature"] = 1
+ 
+                 client = self.async_client()
+                 try:
+                     response = await client.responses.create(**params)
+                 except Exception as e:
+                     return {"message": str(e)}
+ 
+                 # convert to dict
+                 response_dict = response.model_dump()
+                 return response_dict
+ 
+         LLM.__name__ = model_class_name
+         return LLM
+ 
+     @staticmethod
+     def _create_reasoning_sequence():
+         """Create the reasoning sequence for extracting reasoning summaries from model responses."""
+         # For OpenAI responses, the reasoning summary is typically found at:
+         # ["output", 0, "summary"]
+         # This is the path to the 'summary' field in the first item of the 'output' array
+         return ["output", 0, "summary"]
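
For orientation, a sketch of how key_sequence is meant to be consumed: it is a path into the Responses API payload where, per the sequences defined above, output[0] holds the reasoning summary and output[1] holds the message whose first content block carries the answer text. The dict below is hand-written to mimic that shape, not a captured response, and the commented import path is inferred from the services __init__ hunk above and may differ.

    from functools import reduce

    key_sequence = ["output", 1, "content", 0, "text"]   # as defined on OpenAIServiceV2

    fake_response = {
        "output": [
            {"summary": []},                                          # reasoning item
            {"content": [{"type": "output_text", "text": "Paris"}]},  # message item
        ],
        "usage": {"prompt_tokens": 12, "completion_tokens": 1},
    }

    # Walk the payload along key_sequence to pull out the answer text.
    answer = reduce(lambda node, key: node[key], key_sequence, fake_response)
    assert answer == "Paris"

    # Building a bound model class (requires OPENAI_API_KEY for real calls):
    # from edsl.inference_services.services import OpenAIServiceV2
    # GPT4o = OpenAIServiceV2.create_model("gpt-4o")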