edsl 0.1.59__py3-none-any.whl → 0.1.61__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/agent.py +65 -17
- edsl/agents/agent_list.py +117 -33
- edsl/base/base_class.py +80 -11
- edsl/base/data_transfer_models.py +5 -0
- edsl/base/enums.py +7 -2
- edsl/config/config_class.py +7 -2
- edsl/coop/coop.py +1295 -85
- edsl/coop/coop_prolific_filters.py +171 -0
- edsl/dataset/dataset_operations_mixin.py +2 -2
- edsl/dataset/display/table_display.py +40 -7
- edsl/db_list/sqlite_list.py +102 -3
- edsl/inference_services/services/__init__.py +3 -1
- edsl/inference_services/services/open_ai_service_v2.py +243 -0
- edsl/jobs/data_structures.py +48 -30
- edsl/jobs/jobs.py +73 -2
- edsl/jobs/remote_inference.py +49 -15
- edsl/key_management/key_lookup_builder.py +25 -3
- edsl/language_models/language_model.py +2 -1
- edsl/language_models/raw_response_handler.py +126 -7
- edsl/questions/loop_processor.py +289 -10
- edsl/questions/templates/dict/answering_instructions.jinja +0 -1
- edsl/results/result.py +37 -0
- edsl/results/results.py +1 -0
- edsl/scenarios/scenario_list.py +31 -1
- edsl/scenarios/scenario_source.py +606 -498
- edsl/surveys/survey.py +198 -163
- {edsl-0.1.59.dist-info → edsl-0.1.61.dist-info}/METADATA +4 -4
- {edsl-0.1.59.dist-info → edsl-0.1.61.dist-info}/RECORD +32 -30
- {edsl-0.1.59.dist-info → edsl-0.1.61.dist-info}/LICENSE +0 -0
- {edsl-0.1.59.dist-info → edsl-0.1.61.dist-info}/WHEEL +0 -0
- {edsl-0.1.59.dist-info → edsl-0.1.61.dist-info}/entry_points.txt +0 -0
edsl/coop/coop_prolific_filters.py ADDED
@@ -0,0 +1,171 @@
+import reprlib
+from typing import Optional
+
+from .exceptions import CoopValueError
+from ..scenarios import Scenario, ScenarioList
+
+
+class CoopProlificFilters(ScenarioList):
+    """Base class for Prolific filters supported on Coop.
+
+    This abstract class extends ScenarioList to provide common functionality
+    for working with Prolific filters.
+    """
+
+    def __init__(
+        self, data: Optional[list] = None, codebook: Optional[dict[str, str]] = None
+    ):
+        super().__init__(data, codebook)
+
+    def find(self, filter_id: str) -> Optional[Scenario]:
+        """
+        Find a filter by its ID. Raises a CoopValueError if the filter is not found.
+
+        >>> filters = coop.list_prolific_filters()
+        >>> filters.find("age")
+        Scenario(
+            {
+                "filter_id": "age",
+                "type": "range",
+                "range_filter_min": 18,
+                "range_filter_max": 100,
+                ...
+            }
+        """
+
+        # Prolific has inconsistent naming conventions for filters -
+        # some use underscores and some use dashes, so we need to check for both
+        id_with_dashes = filter_id.replace("_", "-")
+        id_with_underscores = filter_id.replace("-", "_")
+
+        for scenario in self:
+            if (
+                scenario["filter_id"] == id_with_dashes
+                or scenario["filter_id"] == id_with_underscores
+            ):
+                return scenario
+        raise CoopValueError(f"Filter with ID {filter_id} not found.")
+
+    def create_study_filter(
+        self,
+        filter_id: str,
+        min: Optional[int] = None,
+        max: Optional[int] = None,
+        values: Optional[list[str]] = None,
+    ) -> dict:
+        """
+        Create a valid filter dict that is compatible with Coop.create_prolific_study().
+        This function will raise a CoopValueError if:
+        - The filter ID is not found
+        - A range filter is provided with no min or max value, or a value that is outside of the allowed range
+        - A select filter is provided with no values, or a value that is not in the allowed options
+
+        For a select filter, you should pass a list of values:
+        >>> filters = coop.list_prolific_filters()
+        >>> filters.create_study_filter("current_country_of_residence", values=["United States", "Canada"])
+        {
+            "filter_id": "current_country_of_residence",
+            "selected_values": ["1", "45"],
+        }
+
+        For a range filter, you should pass a min and max value:
+        >>> filters.create_study_filter("age", min=20, max=40)
+        {
+            "filter_id": "age",
+            "selected_range": {
+                "lower": 20,
+                "upper": 40,
+            },
+        }
+        """
+        filter = self.find(filter_id)
+
+        # .find() has logic to handle inconsistent naming conventions for filter IDs,
+        # so we need to get the correct filter ID from the filter dict
+        correct_filter_id = filter.get("filter_id")
+
+        filter_type = filter.get("type")
+
+        if filter_type == "range":
+            filter_min = filter.get("range_filter_min")
+            filter_max = filter.get("range_filter_max")
+
+            if min is None and max is None:
+                raise CoopValueError("Range filters require both a min and max value.")
+            if min < filter_min:
+                raise CoopValueError(
+                    f"Min value {min} is less than the minimum allowed value {filter_min}."
+                )
+            if max > filter_max:
+                raise CoopValueError(
+                    f"Max value {max} is greater than the maximum allowed value {filter_max}."
+                )
+            if min > max:
+                raise CoopValueError("Min value cannot be greater than max value.")
+            return {
+                "filter_id": correct_filter_id,
+                "selected_range": {
+                    "lower": min,
+                    "upper": max,
+                },
+            }
+        elif filter_type == "select":
+            if values is None:
+                raise CoopValueError("Select filters require a list of values.")
+
+            if correct_filter_id == "custom_allowlist":
+                return {
+                    "filter_id": correct_filter_id,
+                    "selected_values": values,
+                }
+
+            try:
+                allowed_option_labels = filter.get("select_filter_options", {})
+                option_labels_to_ids = {v: k for k, v in allowed_option_labels.items()}
+                selected_option_ids = [option_labels_to_ids[value] for value in values]
+            except KeyError:
+                raise CoopValueError(
+                    f"Invalid value(s) provided for filter {filter_id}: {values}. "
+                    f"Call find() with the filter ID to examine the allowed values for this filter."
+                )
+
+            return {
+                "filter_id": correct_filter_id,
+                "selected_values": selected_option_ids,
+            }
+        else:
+            raise CoopValueError(f"Unsupported filter type: {filter_type}.")
+
+    def table(
+        self,
+        *fields,
+        tablefmt: Optional[str] = None,
+        pretty_labels: Optional[dict[str, str]] = None,
+    ) -> str:
+        """Return the CoopProlificFilters as a table with truncated options display for select filters."""
+
+        # Create a copy of the data with truncated options
+        truncated_scenarios = []
+        for scenario in self:
+            scenario_dict = dict(scenario)
+            if (
+                "select_filter_options" in scenario_dict
+                and scenario_dict["select_filter_options"] is not None
+            ):
+
+                # Create a truncated representation of the options list
+                formatter = reprlib.Repr()
+                formatter.maxstring = 50
+                select_filter_options = list(
+                    dict(scenario_dict["select_filter_options"]).values()
+                )
+                formatted_options = formatter.repr(select_filter_options)
+                scenario_dict["select_filter_options"] = formatted_options
+            truncated_scenarios.append(scenario_dict)
+
+        temp_scenario_list = ScenarioList([Scenario(s) for s in truncated_scenarios])
+
+        # Display the table with the truncated data
+        return temp_scenario_list.table(
+            *fields, tablefmt=tablefmt, pretty_labels=pretty_labels
+        )
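Taken together, the docstrings above give the intended workflow. A minimal usage sketch (assuming a configured Coop client with a valid Expected Parrot API key; the method names come straight from the docstrings):

from edsl.coop import Coop

coop = Coop()
filters = coop.list_prolific_filters()

# Inspect one filter to see its type and allowed range/options
age_filter = filters.find("age")

# Build filter dicts accepted by Coop.create_prolific_study()
age_range = filters.create_study_filter("age", min=20, max=40)
countries = filters.create_study_filter(
    "current_country_of_residence", values=["United States", "Canada"]
)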
edsl/dataset/dataset_operations_mixin.py CHANGED
@@ -357,7 +357,7 @@ class DataOperationsBase:
         4
         >>> engine = Results.example()._db(shape = "long")
         >>> len(engine.execute(text("SELECT * FROM self")).fetchall())
-
+        212
         """
         # Import needed for database connection
         from sqlalchemy import create_engine
@@ -442,7 +442,7 @@ class DataOperationsBase:
 
         # Using long format
         >>> len(r.sql("SELECT * FROM self", shape="long"))
-
+        212
         """
         import pandas as pd
 
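Both hunks are doctest-only updates: the long ("melted") representation of Results.example() now contains 212 rows. A quick local check (assuming this edsl version is installed):

from edsl import Results

r = Results.example()
rows = r.sql("SELECT * FROM self", shape="long")
assert len(rows) == 212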
edsl/dataset/display/table_display.py CHANGED
@@ -55,13 +55,46 @@ class TableDisplay:
         self.printing_parameters = {}
 
     def _repr_html_(self) -> str:
-
-
-
-
-
-        )
-
+        """
+        HTML representation for Jupyter/Colab notebooks.
+
+        The primary path uses the configured `renderer_class` to build an HTML
+        string. Unfortunately, in shared or long-running notebook runtimes it
+        is not uncommon for binary dependencies (NumPy, Pandas, etc.) to get
+        into an incompatible state, raising import-time errors that would
+        otherwise bubble up to the notebook and obscure the actual table
+        output. To make the developer experience smoother we catch *any*
+        exception, log/annotate it, and fall back to a plain-text rendering via
+        `tabulate`, wrapped in a <pre> block so at least a readable table is
+        shown.
+        """
+        try:
+            table_data = TableData(
+                headers=self.headers,
+                data=self.data,
+                parameters=self.printing_parameters,
+                raw_data_set=self.raw_data_set,
+            )
+            return self.renderer_class(table_data).render_html()
+        except Exception as exc:  # pragma: no cover
+            # --- graceful degradation -------------------------------------------------
+            try:
+                from tabulate import tabulate
+
+                plain = tabulate(
+                    self.data,
+                    headers=self.headers,
+                    tablefmt=self.tablefmt or "simple",
+                )
+            except Exception:
+                # Even `tabulate` failed – resort to the default __repr__.
+                plain = super().__repr__() if hasattr(super(), "__repr__") else str(self.data)
+
+            # Escape HTML-sensitive chars so the browser renders plain text.
+            import html
+
+            safe_plain = html.escape(plain)
+            return f"<pre>{safe_plain}\n\n[TableDisplay fallback – original error: {exc}]</pre>"
 
     def __repr__(self):
         # If rich format is requested, use RichRenderer
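The docstring above explains the intent: render rich HTML when possible, and degrade to escaped plain text rather than surface an exception. The pattern is generic enough to isolate; here is a standalone sketch using only tabulate and the standard library (render_html_with_fallback and broken_renderer are illustrative names, not part of the edsl API):

import html
from tabulate import tabulate

def render_html_with_fallback(render_fn, rows, headers):
    # Try the rich renderer first; on any failure, fall back to <pre> text.
    try:
        return render_fn(rows, headers)
    except Exception as exc:
        plain = tabulate(rows, headers=headers, tablefmt="simple")
        return f"<pre>{html.escape(plain)}\n\n[fallback - original error: {exc}]</pre>"

def broken_renderer(rows, headers):
    raise ImportError("simulated binary-dependency failure")

print(render_html_with_fallback(broken_renderer, [[1, 2]], ["a", "b"]))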
edsl/db_list/sqlite_list.py CHANGED
@@ -4,7 +4,7 @@ import os
 import json
 from typing import Any, Callable, Iterable, Iterator, List, Optional
 from abc import ABC, abstractmethod
-from collections.abc import MutableSequence
+from collections.abc import MutableSequence, MutableMapping
 
 
 class SQLiteList(MutableSequence, ABC):
@@ -97,7 +97,20 @@ class SQLiteList(MutableSequence, ABC):
         row = cursor.fetchone()
         if row is None:
             raise IndexError("list index out of range")
-
+
+        obj = self.deserialize(row[0])
+
+        # If the stored object is a Scenario (or subclass), return a specialised proxy
+        try:
+            from edsl.scenarios.scenario import Scenario
+            if isinstance(obj, Scenario):
+                return self._make_scenario_proxy(self, index, obj)
+        except ImportError:
+            # Scenario not available – fall back to generic proxy
+            pass
+
+        # Generic proxy for other types
+        return self._RowProxy(self, index, obj)
 
     def __setitem__(self, index, value):
         if index < 0:
@@ -346,4 +359,90 @@ class SQLiteList(MutableSequence, ABC):
             self.conn.close()
             os.unlink(self.db_path)
         except:
-            pass
+            pass
+
+    class _RowProxy(MutableMapping):
+        """A write-through proxy returned by SQLiteList.__getitem__.
+
+        Any mutation on the proxy (e.g. proxy[key] = value) is immediately
+        re-serialised and written back to the underlying SQLite storage,
+        ensuring the database stays in sync with in-memory edits.
+        """
+
+        def __init__(self, parent: "SQLiteList", idx: int, obj: Any):
+            self._parent = parent
+            self._idx = idx
+            self._obj = obj  # The real deserialised object (e.g. Scenario)
+
+        # ---- MutableMapping interface ----
+        def __getitem__(self, key):
+            return self._obj[key]
+
+        def __setitem__(self, key, value):
+            self._obj[key] = value
+            # Propagate change back to SQLite via parent list
+            self._parent.__setitem__(self._idx, self._obj)
+
+        def __delitem__(self, key):
+            del self._obj[key]
+            self._parent.__setitem__(self._idx, self._obj)
+
+        def __iter__(self):
+            return iter(self._obj)
+
+        def __len__(self):
+            return len(self._obj)
+
+        # ---- Convenience helpers ----
+        def __getattr__(self, name):  # Delegate attribute access
+            return getattr(self._obj, name)
+
+        def __repr__(self):
+            return repr(self._obj)
+
+    # Specialised proxy for Scenario objects so isinstance(obj, Scenario) remains True.
+    # Defined lazily to avoid importing Scenario at module load time for performance.
+    @staticmethod
+    def _make_scenario_proxy(parent: "SQLiteList", idx: int, scenario_obj: Any):
+        """Create and return an on-the-fly proxy class inheriting from Scenario but
+        immediately removed from the global subclass registry so serialization
+        coverage tests ignore it.
+        """
+        from edsl.scenarios.scenario import Scenario  # local import
+        from edsl.base import RegisterSubclassesMeta
+
+        # Dynamically build class dict with required methods
+        def _proxy_setitem(self, key, value):
+            Scenario.__setitem__(self, key, value)  # super call avoids MRO confusion
+            from edsl.scenarios.scenario import Scenario as S
+            self._parent.__setitem__(self._idx, S(dict(self)))
+
+        def _proxy_delitem(self, key):
+            Scenario.__delitem__(self, key)
+            from edsl.scenarios.scenario import Scenario as S
+            self._parent.__setitem__(self._idx, S(dict(self)))
+
+        def _proxy_reduce(self):
+            from edsl.scenarios.scenario import Scenario as S
+            return (S, (dict(self),))
+
+        proxy_cls = type(
+            "_ScenarioRowProxy",
+            (Scenario,),
+            {
+                "__setitem__": _proxy_setitem,
+                "__delitem__": _proxy_delitem,
+                "__reduce__": _proxy_reduce,
+                "__module__": Scenario.__module__,
+            },
+        )
+
+        # Remove this helper class from global registry so tests ignore it
+        RegisterSubclassesMeta._registry.pop(proxy_cls.__name__, None)
+
+        # Instantiate
+        instance = proxy_cls(dict(scenario_obj))
+        # attach parent tracking attributes
+        instance._parent = parent
+        instance._idx = idx
+        return instance
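The net effect of the proxy machinery (a sketch, assuming a ScenarioList whose storage is backed by SQLiteList, as in this release): index reads return a write-through object, so in-place edits persist to the SQLite store instead of mutating a throwaway deserialised copy.

from edsl import Scenario, ScenarioList

sl = ScenarioList([Scenario({"fruit": "apple"})])
row = sl[0]               # a Scenario-compatible write-through proxy
row["fruit"] = "pear"     # __setitem__ re-serialises the row back to storage
assert sl[0]["fruit"] == "pear"   # the edit survives a fresh read
assert isinstance(row, Scenario)  # the proxy subclasses Scenario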
edsl/inference_services/services/__init__.py CHANGED
@@ -8,6 +8,7 @@ from .groq_service import GroqService
 from .mistral_ai_service import MistralAIService
 from .ollama_service import OllamaService
 from .open_ai_service import OpenAIService
+from .open_ai_service_v2 import OpenAIServiceV2
 from .perplexity_service import PerplexityService
 from .test_service import TestService
 from .together_ai_service import TogetherAIService
@@ -24,8 +25,9 @@ __all__ = [
     "MistralAIService",
     "OllamaService",
     "OpenAIService",
+    "OpenAIServiceV2",
     "PerplexityService",
     "TestService",
     "TogetherAIService",
     "XAIService",
-]
+]
edsl/inference_services/services/open_ai_service_v2.py ADDED
@@ -0,0 +1,243 @@
+from __future__ import annotations
+from typing import Any, List, Optional, Dict, NewType, TYPE_CHECKING
+import os
+
+import openai
+
+from ..inference_service_abc import InferenceServiceABC
+
+# Use TYPE_CHECKING to avoid circular imports at runtime
+if TYPE_CHECKING:
+    from ...language_models import LanguageModel
+from ..rate_limits_cache import rate_limits
+
+# Default to completions API but can use responses API with parameter
+
+if TYPE_CHECKING:
+    from ....scenarios.file_store import FileStore as Files
+    from ....invigilators.invigilator_base import InvigilatorBase as InvigilatorAI
+
+
+APIToken = NewType("APIToken", str)
+
+
+class OpenAIServiceV2(InferenceServiceABC):
+    """OpenAI service class using the Responses API."""
+
+    _inference_service_ = "openai_v2"
+    _env_key_name_ = "OPENAI_API_KEY"
+    _base_url_ = None
+
+    _sync_client_ = openai.OpenAI
+    _async_client_ = openai.AsyncOpenAI
+
+    _sync_client_instances: Dict[APIToken, openai.OpenAI] = {}
+    _async_client_instances: Dict[APIToken, openai.AsyncOpenAI] = {}
+
+    # sequence to extract text from response.output
+    key_sequence = ["output", 1, "content", 0, "text"]
+    usage_sequence = ["usage"]
+    # sequence to extract reasoning summary from response.output
+    reasoning_sequence = ["output", 0, "summary"]
+    input_token_name = "prompt_tokens"
+    output_token_name = "completion_tokens"
+
+    available_models_url = "https://platform.openai.com/docs/models/gp"
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        cls._sync_client_instances = {}
+        cls._async_client_instances = {}
+
+    @classmethod
+    def sync_client(cls, api_key: str) -> openai.OpenAI:
+        if api_key not in cls._sync_client_instances:
+            client = cls._sync_client_(
+                api_key=api_key,
+                base_url=cls._base_url_,
+            )
+            cls._sync_client_instances[api_key] = client
+        return cls._sync_client_instances[api_key]
+
+    @classmethod
+    def async_client(cls, api_key: str) -> openai.AsyncOpenAI:
+        if api_key not in cls._async_client_instances:
+            client = cls._async_client_(
+                api_key=api_key,
+                base_url=cls._base_url_,
+            )
+            cls._async_client_instances[api_key] = client
+        return cls._async_client_instances[api_key]
+
+    model_exclude_list = [
+        "whisper-1",
+        "davinci-002",
+        "dall-e-2",
+        "tts-1-hd-1106",
+        "tts-1-hd",
+        "dall-e-3",
+        "tts-1",
+        "babbage-002",
+        "tts-1-1106",
+        "text-embedding-3-large",
+        "text-embedding-3-small",
+        "text-embedding-ada-002",
+        "ft:davinci-002:mit-horton-lab::8OfuHgoo",
+        "gpt-3.5-turbo-instruct-0914",
+        "gpt-3.5-turbo-instruct",
+    ]
+    _models_list_cache: List[str] = []
+
+    @classmethod
+    def get_model_list(cls, api_key: str | None = None) -> List[str]:
+        if api_key is None:
+            api_key = os.getenv(cls._env_key_name_)
+        raw = cls.sync_client(api_key).models.list()
+        return raw.data if hasattr(raw, "data") else raw
+
+    @classmethod
+    def available(cls, api_token: str | None = None) -> List[str]:
+        if api_token is None:
+            api_token = os.getenv(cls._env_key_name_)
+        if not cls._models_list_cache:
+            data = cls.get_model_list(api_key=api_token)
+            cls._models_list_cache = [
+                m.id for m in data if m.id not in cls.model_exclude_list
+            ]
+        return cls._models_list_cache
+
+    @classmethod
+    def create_model(
+        cls,
+        model_name: str,
+        model_class_name: str | None = None,
+    ) -> LanguageModel:
+        if model_class_name is None:
+            model_class_name = cls.to_class_name(model_name)
+
+        from ...language_models import LanguageModel
+
+        class LLM(LanguageModel):
+            """Child class for OpenAI Responses API"""
+
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+            reasoning_sequence = cls.reasoning_sequence
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+            _inference_service_ = cls._inference_service_
+            _model_ = model_name
+            _parameters_ = {
+                "temperature": 0.5,
+                "max_tokens": 2000,
+                "top_p": 1,
+                "frequency_penalty": 0,
+                "presence_penalty": 0,
+                "logprobs": False,
+                "top_logprobs": 3,
+            }
+
+            def sync_client(self) -> openai.OpenAI:
+                return cls.sync_client(api_key=self.api_token)
+
+            def async_client(self) -> openai.AsyncOpenAI:
+                return cls.async_client(api_key=self.api_token)
+
+            @classmethod
+            def available(cls) -> list[str]:
+                return cls.sync_client().models.list().data
+
+            def get_headers(self) -> dict[str, Any]:
+                client = self.sync_client()
+                response = client.responses.with_raw_response.create(
+                    model=self.model,
+                    input=[{"role": "user", "content": "Say this is a test"}],
+                    store=False,
+                )
+                return dict(response.headers)
+
+            def get_rate_limits(self) -> dict[str, Any]:
+                try:
+                    headers = rate_limits.get("openai", self.get_headers())
+                except Exception:
+                    return {"rpm": 10000, "tpm": 2000000}
+                return {
+                    "rpm": int(headers["x-ratelimit-limit-requests"]),
+                    "tpm": int(headers["x-ratelimit-limit-tokens"]),
+                }
+
+            async def async_execute_model_call(
+                self,
+                user_prompt: str,
+                system_prompt: str = "",
+                files_list: Optional[List[Files]] = None,
+                invigilator: Optional[InvigilatorAI] = None,
+            ) -> dict[str, Any]:
+                content = user_prompt
+                if files_list:
+                    # embed files as separate inputs
+                    content = [{"type": "text", "text": user_prompt}]
+                    for f in files_list:
+                        content.append(
+                            {
+                                "type": "image_url",
+                                "image_url": {
+                                    "url": f"data:{f.mime_type};base64,{f.base64_string}"
+                                },
+                            }
+                        )
+                # build input sequence
+                messages: Any
+                if system_prompt and not self.omit_system_prompt_if_empty:
+                    messages = [
+                        {"role": "system", "content": system_prompt},
+                        {"role": "user", "content": content},
+                    ]
+                else:
+                    messages = [{"role": "user", "content": content}]
+
+                # All OpenAI models with the responses API use these base parameters
+                params = {
+                    "model": self.model,
+                    "input": messages,
+                    "temperature": self.temperature,
+                    "top_p": self.top_p,
+                    "store": False,
+                }
+
+                # Check if this is a reasoning model (o-series models)
+                is_reasoning_model = any(tag in self.model for tag in ["o1", "o1-mini", "o3", "o3-mini", "o1-pro", "o4-mini"])
+
+                # Only add reasoning parameter for reasoning models
+                if is_reasoning_model:
+                    params["reasoning"] = {"summary": "auto"}
+
+                # For all models using the responses API, use max_output_tokens
+                # instead of max_tokens (which is for the completions API)
+                params["max_output_tokens"] = self.max_tokens
+
+                # Specifically for o-series, we also set temperature to 1
+                if is_reasoning_model:
+                    params["temperature"] = 1
+
+                client = self.async_client()
+                try:
+                    response = await client.responses.create(**params)
+
+                except Exception as e:
+                    return {"message": str(e)}
+
+                # convert to dict
+                response_dict = response.model_dump()
+                return response_dict
+
+        LLM.__name__ = model_class_name
+        return LLM
+
+    @staticmethod
+    def _create_reasoning_sequence():
+        """Create the reasoning sequence for extracting reasoning summaries from model responses."""
+        # For OpenAI responses, the reasoning summary is typically found at:
+        # ["output", 0, "summary"]
+        # This is the path to the 'summary' field in the first item of the 'output' array
+        return ["output", 0, "summary"]
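The key_sequence / usage_sequence / reasoning_sequence class attributes are traversal paths that the raw-response handler follows through the Responses API payload to pull out the answer text and the reasoning summary. A minimal sketch of that traversal (the payload shape is abbreviated, and walk() is illustrative rather than edsl's actual handler):

from functools import reduce

key_sequence = ["output", 1, "content", 0, "text"]
reasoning_sequence = ["output", 0, "summary"]

# Abbreviated Responses API payload: reasoning item first, message item second.
response_dict = {
    "output": [
        {"type": "reasoning", "summary": []},
        {"type": "message", "content": [{"type": "output_text", "text": "Hello!"}]},
    ],
    "usage": {"prompt_tokens": 12, "completion_tokens": 3},
}

def walk(payload, path):
    # Follow a mixed key/index path into nested dicts and lists.
    return reduce(lambda node, step: node[step], path, payload)

assert walk(response_dict, key_sequence) == "Hello!"
assert walk(response_dict, reasoning_sequence) == []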