edsl 0.1.58__py3-none-any.whl → 0.1.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. edsl/__version__.py +1 -1
  2. edsl/agents/agent.py +23 -4
  3. edsl/agents/agent_list.py +36 -6
  4. edsl/base/data_transfer_models.py +5 -0
  5. edsl/base/enums.py +7 -2
  6. edsl/coop/coop.py +103 -1
  7. edsl/dataset/dataset.py +74 -0
  8. edsl/dataset/dataset_operations_mixin.py +69 -64
  9. edsl/inference_services/services/__init__.py +3 -1
  10. edsl/inference_services/services/open_ai_service_v2.py +243 -0
  11. edsl/inference_services/services/test_service.py +1 -1
  12. edsl/interviews/exception_tracking.py +66 -20
  13. edsl/invigilators/invigilators.py +5 -1
  14. edsl/invigilators/prompt_constructor.py +299 -136
  15. edsl/jobs/data_structures.py +3 -0
  16. edsl/jobs/html_table_job_logger.py +18 -1
  17. edsl/jobs/jobs_pricing_estimation.py +6 -2
  18. edsl/jobs/jobs_remote_inference_logger.py +2 -0
  19. edsl/jobs/remote_inference.py +34 -7
  20. edsl/key_management/key_lookup_builder.py +25 -3
  21. edsl/language_models/language_model.py +41 -3
  22. edsl/language_models/raw_response_handler.py +126 -7
  23. edsl/prompts/prompt.py +1 -0
  24. edsl/questions/question_list.py +76 -20
  25. edsl/results/result.py +37 -0
  26. edsl/results/results.py +9 -1
  27. edsl/scenarios/file_store.py +8 -12
  28. edsl/scenarios/scenario.py +50 -2
  29. edsl/scenarios/scenario_list.py +34 -12
  30. edsl/surveys/survey.py +4 -0
  31. edsl/tasks/task_history.py +180 -6
  32. edsl/utilities/wikipedia.py +194 -0
  33. {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/METADATA +5 -4
  34. {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/RECORD +37 -35
  35. {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/LICENSE +0 -0
  36. {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/WHEEL +0 -0
  37. {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/entry_points.txt +0 -0
edsl/inference_services/services/open_ai_service_v2.py (new file)
@@ -0,0 +1,243 @@
+from __future__ import annotations
+from typing import Any, List, Optional, Dict, NewType, TYPE_CHECKING
+import os
+
+import openai
+
+from ..inference_service_abc import InferenceServiceABC
+
+# Use TYPE_CHECKING to avoid circular imports at runtime
+if TYPE_CHECKING:
+    from ...language_models import LanguageModel
+    from ..rate_limits_cache import rate_limits
+
+# Default to completions API but can use responses API with parameter
+
+if TYPE_CHECKING:
+    from ....scenarios.file_store import FileStore as Files
+    from ....invigilators.invigilator_base import InvigilatorBase as InvigilatorAI
+
+
+APIToken = NewType("APIToken", str)
+
+
+class OpenAIServiceV2(InferenceServiceABC):
+    """OpenAI service class using the Responses API."""
+
+    _inference_service_ = "openai_v2"
+    _env_key_name_ = "OPENAI_API_KEY"
+    _base_url_ = None
+
+    _sync_client_ = openai.OpenAI
+    _async_client_ = openai.AsyncOpenAI
+
+    _sync_client_instances: Dict[APIToken, openai.OpenAI] = {}
+    _async_client_instances: Dict[APIToken, openai.AsyncOpenAI] = {}
+
+    # sequence to extract text from response.output
+    key_sequence = ["output", 1, "content", 0, "text"]
+    usage_sequence = ["usage"]
+    # sequence to extract reasoning summary from response.output
+    reasoning_sequence = ["output", 0, "summary"]
+    input_token_name = "prompt_tokens"
+    output_token_name = "completion_tokens"
+
+    available_models_url = "https://platform.openai.com/docs/models/gp"
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        cls._sync_client_instances = {}
+        cls._async_client_instances = {}
+
+    @classmethod
+    def sync_client(cls, api_key: str) -> openai.OpenAI:
+        if api_key not in cls._sync_client_instances:
+            client = cls._sync_client_(
+                api_key=api_key,
+                base_url=cls._base_url_,
+            )
+            cls._sync_client_instances[api_key] = client
+        return cls._sync_client_instances[api_key]
+
+    @classmethod
+    def async_client(cls, api_key: str) -> openai.AsyncOpenAI:
+        if api_key not in cls._async_client_instances:
+            client = cls._async_client_(
+                api_key=api_key,
+                base_url=cls._base_url_,
+            )
+            cls._async_client_instances[api_key] = client
+        return cls._async_client_instances[api_key]
+
+    model_exclude_list = [
+        "whisper-1",
+        "davinci-002",
+        "dall-e-2",
+        "tts-1-hd-1106",
+        "tts-1-hd",
+        "dall-e-3",
+        "tts-1",
+        "babbage-002",
+        "tts-1-1106",
+        "text-embedding-3-large",
+        "text-embedding-3-small",
+        "text-embedding-ada-002",
+        "ft:davinci-002:mit-horton-lab::8OfuHgoo",
+        "gpt-3.5-turbo-instruct-0914",
+        "gpt-3.5-turbo-instruct",
+    ]
+    _models_list_cache: List[str] = []
+
+    @classmethod
+    def get_model_list(cls, api_key: str | None = None) -> List[str]:
+        if api_key is None:
+            api_key = os.getenv(cls._env_key_name_)
+        raw = cls.sync_client(api_key).models.list()
+        return raw.data if hasattr(raw, "data") else raw
+
+    @classmethod
+    def available(cls, api_token: str | None = None) -> List[str]:
+        if api_token is None:
+            api_token = os.getenv(cls._env_key_name_)
+        if not cls._models_list_cache:
+            data = cls.get_model_list(api_key=api_token)
+            cls._models_list_cache = [
+                m.id for m in data if m.id not in cls.model_exclude_list
+            ]
+        return cls._models_list_cache
+
+    @classmethod
+    def create_model(
+        cls,
+        model_name: str,
+        model_class_name: str | None = None,
+    ) -> LanguageModel:
+        if model_class_name is None:
+            model_class_name = cls.to_class_name(model_name)
+
+        from ...language_models import LanguageModel
+
+        class LLM(LanguageModel):
+            """Child class for OpenAI Responses API"""
+
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+            reasoning_sequence = cls.reasoning_sequence
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+            _inference_service_ = cls._inference_service_
+            _model_ = model_name
+            _parameters_ = {
+                "temperature": 0.5,
+                "max_tokens": 2000,
+                "top_p": 1,
+                "frequency_penalty": 0,
+                "presence_penalty": 0,
+                "logprobs": False,
+                "top_logprobs": 3,
+            }
+
+            def sync_client(self) -> openai.OpenAI:
+                return cls.sync_client(api_key=self.api_token)
+
+            def async_client(self) -> openai.AsyncOpenAI:
+                return cls.async_client(api_key=self.api_token)
+
+            @classmethod
+            def available(cls) -> list[str]:
+                return cls.sync_client().models.list().data
+
+            def get_headers(self) -> dict[str, Any]:
+                client = self.sync_client()
+                response = client.responses.with_raw_response.create(
+                    model=self.model,
+                    input=[{"role": "user", "content": "Say this is a test"}],
+                    store=False,
+                )
+                return dict(response.headers)
+
+            def get_rate_limits(self) -> dict[str, Any]:
+                try:
+                    headers = rate_limits.get("openai", self.get_headers())
+                except Exception:
+                    return {"rpm": 10000, "tpm": 2000000}
+                return {
+                    "rpm": int(headers["x-ratelimit-limit-requests"]),
+                    "tpm": int(headers["x-ratelimit-limit-tokens"]),
+                }
+
+            async def async_execute_model_call(
+                self,
+                user_prompt: str,
+                system_prompt: str = "",
+                files_list: Optional[List[Files]] = None,
+                invigilator: Optional[InvigilatorAI] = None,
+            ) -> dict[str, Any]:
+                content = user_prompt
+                if files_list:
+                    # embed files as separate inputs
+                    content = [{"type": "text", "text": user_prompt}]
+                    for f in files_list:
+                        content.append(
+                            {
+                                "type": "image_url",
+                                "image_url": {
+                                    "url": f"data:{f.mime_type};base64,{f.base64_string}"
+                                },
+                            }
+                        )
+                # build input sequence
+                messages: Any
+                if system_prompt and not self.omit_system_prompt_if_empty:
+                    messages = [
+                        {"role": "system", "content": system_prompt},
+                        {"role": "user", "content": content},
+                    ]
+                else:
+                    messages = [{"role": "user", "content": content}]
+
+                # All OpenAI models with the responses API use these base parameters
+                params = {
+                    "model": self.model,
+                    "input": messages,
+                    "temperature": self.temperature,
+                    "top_p": self.top_p,
+                    "store": False,
+                }
+
+                # Check if this is a reasoning model (o-series models)
+                is_reasoning_model = any(tag in self.model for tag in ["o1", "o1-mini", "o3", "o3-mini", "o1-pro", "o4-mini"])
+
+                # Only add reasoning parameter for reasoning models
+                if is_reasoning_model:
+                    params["reasoning"] = {"summary": "auto"}
+
+                # For all models using the responses API, use max_output_tokens
+                # instead of max_tokens (which is for the completions API)
+                params["max_output_tokens"] = self.max_tokens
+
+                # Specifically for o-series, we also set temperature to 1
+                if is_reasoning_model:
+                    params["temperature"] = 1
+
+                client = self.async_client()
+                try:
+                    response = await client.responses.create(**params)
+
+                except Exception as e:
+                    return {"message": str(e)}
+
+                # convert to dict
+                response_dict = response.model_dump()
+                return response_dict
+
+        LLM.__name__ = model_class_name
+        return LLM
+
+    @staticmethod
+    def _create_reasoning_sequence():
+        """Create the reasoning sequence for extracting reasoning summaries from model responses."""
+        # For OpenAI responses, the reasoning summary is typically found at:
+        # ["output", 0, "summary"]
+        # This is the path to the 'summary' field in the first item of the 'output' array
+        return ["output", 0, "summary"]
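The key_sequence, usage_sequence, and reasoning_sequence attributes above are index paths that are walked through the dict returned by response.model_dump(). A minimal sketch of that lookup against a hypothetical, trimmed-down Responses API payload (real payloads carry many more fields, and follow_path is illustrative, not the edsl helper):

def follow_path(data, path):
    # Walk a nested dict/list using a sequence of keys and list indices.
    for step in path:
        data = data[step]
    return data

response = {
    "output": [
        {"summary": "reasoning summary"},               # reasoning item, index 0
        {"content": [{"text": "the model's answer"}]},  # message item, index 1
    ],
    "usage": {"prompt_tokens": 12, "completion_tokens": 34},
}

assert follow_path(response, ["output", 1, "content", 0, "text"]) == "the model's answer"
assert follow_path(response, ["output", 0, "summary"]) == "reasoning summary"

The hard-coded indices appear to assume the reasoning item precedes the message item in output, which would explain why key_sequence reads the text from output[1] while reasoning_sequence reads output[0].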
edsl/inference_services/services/test_service.py
@@ -54,7 +54,7 @@ class TestService(InferenceServiceABC):
             input_token_name = cls.input_token_name
             output_token_name = cls.output_token_name
             _rpm = 1000
-            _tpm = 100000
+            _tpm = 8000000
 
             @property
             def _canned_response(self):
edsl/interviews/exception_tracking.py
@@ -16,8 +16,9 @@ class InterviewExceptionEntry:
         invigilator: "InvigilatorBase",
         traceback_format="text",
         answers=None,
+        time=None,  # Added time parameter for deserialization
     ):
-        self.time = datetime.datetime.now().isoformat()
+        self.time = time or datetime.datetime.now().isoformat()
         self.exception = exception
         self.invigilator = invigilator
         self.traceback_format = traceback_format
@@ -130,7 +131,12 @@ class InterviewExceptionEntry:
         'Traceback (most recent call last):...'
         """
         e = self.exception
-        tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+        # Check if the exception has a traceback attribute
+        if hasattr(e, "__traceback__") and e.__traceback__:
+            tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+        else:
+            # Use the message as traceback if no traceback available
+            tb_str = f"Exception: {str(e)}"
         return tb_str
 
     @property
@@ -144,14 +150,19 @@ class InterviewExceptionEntry:
 
         console = Console(file=html_output, record=True)
 
-        tb = Traceback.from_exception(
-            type(self.exception),
-            self.exception,
-            self.exception.__traceback__,
-            show_locals=True,
-        )
-        console.print(tb)
-        return html_output.getvalue()
+        # Check if the exception has a traceback attribute
+        if hasattr(self.exception, "__traceback__") and self.exception.__traceback__:
+            tb = Traceback.from_exception(
+                type(self.exception),
+                self.exception,
+                self.exception.__traceback__,
+                show_locals=True,
+            )
+            console.print(tb)
+            return html_output.getvalue()
+        else:
+            # Return a simple string if no traceback available
+            return f"<pre>Exception: {str(self.exception)}</pre>"
 
     @staticmethod
     def serialize_exception(exception: Exception) -> dict:
@@ -160,14 +171,25 @@ class InterviewExceptionEntry:
         >>> entry = InterviewExceptionEntry.example()
        >>> _ = entry.serialize_exception(entry.exception)
         """
-        return {
-            "type": type(exception).__name__,
-            "message": str(exception),
-            "traceback": "".join(
+        # Store the original exception type for proper reconstruction
+        exception_type = type(exception).__name__
+        module_name = getattr(type(exception), "__module__", "builtins")
+
+        # Extract traceback if available
+        if hasattr(exception, "__traceback__") and exception.__traceback__:
+            tb_str = "".join(
                 traceback.format_exception(
                     type(exception), exception, exception.__traceback__
                 )
-            ),
+            )
+        else:
+            tb_str = f"Exception: {str(exception)}"
+
+        return {
+            "type": exception_type,
+            "module": module_name,
+            "message": str(exception),
+            "traceback": tb_str,
         }
 
     @staticmethod
@@ -177,11 +199,31 @@ class InterviewExceptionEntry:
         >>> entry = InterviewExceptionEntry.example()
         >>> _ = entry.deserialize_exception(entry.to_dict()["exception"])
         """
+        exception_type = data.get("type", "Exception")
+        module_name = data.get("module", "builtins")
+        message = data.get("message", "")
+
         try:
-            exception_class = globals()[data["type"]]
-        except KeyError:
-            exception_class = Exception
-        return exception_class(data["message"])
+            # Try to import the module and get the exception class
+            # if module_name != "builtins":
+            #     import importlib
+
+            #     module = importlib.import_module(module_name)
+            #     exception_class = getattr(module, exception_type, Exception)
+            # else:
+            #     # Look for exception in builtins
+            import builtins
+
+            exception_class = getattr(builtins, exception_type, Exception)
+
+        except (ImportError, AttributeError):
+            # Fall back to a generic Exception but preserve the type name
+            exception = Exception(message)
+            exception.__class__.__name__ = exception_type
+            return exception
+
+        # Create instance of the original exception type if possible
+        return exception_class(message)
 
     def to_dict(self) -> dict:
         """Return the exception as a dictionary.
@@ -221,7 +263,11 @@ class InterviewExceptionEntry:
             invigilator = None
         else:
            invigilator = InvigilatorAI.from_dict(data["invigilator"])
-        return cls(exception=exception, invigilator=invigilator)
+
+        # Use the original timestamp from serialization
+        time = data.get("time")
+
+        return cls(exception=exception, invigilator=invigilator, time=time)
 
 
 class InterviewExceptionCollection(UserDict):
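Taken together, these changes let an InterviewExceptionEntry round-trip through serialization: the exception's type name and module are recorded, the traceback degrades gracefully to the message when absent, and the original timestamp is restored through the new time parameter. A standalone sketch of the new serialize/deserialize behavior (the names mirror the methods above, but this is not the edsl API itself):

import builtins
import traceback

def serialize_exception(exception):
    # Fall back to the message when there is no traceback,
    # e.g. for exceptions that were themselves deserialized.
    if getattr(exception, "__traceback__", None):
        tb_str = "".join(
            traceback.format_exception(
                type(exception), exception, exception.__traceback__
            )
        )
    else:
        tb_str = f"Exception: {exception}"
    return {
        "type": type(exception).__name__,
        "module": getattr(type(exception), "__module__", "builtins"),
        "message": str(exception),
        "traceback": tb_str,
    }

def deserialize_exception(data):
    # Only builtin exception types are reconstructed; anything else
    # degrades to a plain Exception carrying the original message.
    exception_class = getattr(builtins, data.get("type", "Exception"), Exception)
    return exception_class(data.get("message", ""))

restored = deserialize_exception(serialize_exception(ValueError("bad input")))
assert type(restored) is ValueError and str(restored) == "bad input"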
edsl/invigilators/invigilators.py
@@ -105,7 +105,11 @@ class InvigilatorBase(ABC):
         value = getattr(self, attr)
         if value is None:
             return None
-        if hasattr(value, "to_dict"):
+        if attr == "scenario" and hasattr(value, "offload"):
+            # Use the scenario's offload method to replace base64_string values
+            offloaded = value.offload()
+            return offloaded.to_dict()
+        elif hasattr(value, "to_dict"):
             return value.to_dict()
         if isinstance(value, (int, float, str, bool, dict, list)):
             return value
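The scenario branch above calls Scenario.offload, which (per the scenario.py changes in this release) replaces large base64_string payloads before serialization so invigilator dicts stay small. A hypothetical sketch of the pattern on a plain dict, not the actual Scenario implementation:

def offload(scenario: dict, replacement: str = "offloaded") -> dict:
    # Replace any nested base64_string values with a small placeholder.
    slimmed = {}
    for key, value in scenario.items():
        if isinstance(value, dict) and "base64_string" in value:
            value = {**value, "base64_string": replacement}
        slimmed[key] = value
    return slimmed

scenario = {"photo": {"mime_type": "image/png", "base64_string": "iVBORw0KGgoAAAANS"}}
print(offload(scenario))
# {'photo': {'mime_type': 'image/png', 'base64_string': 'offloaded'}}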