edsl 0.1.57__py3-none-any.whl → 0.1.59__py3-none-any.whl
This diff compares publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- edsl/__version__.py +1 -1
- edsl/agents/agent.py +23 -4
- edsl/agents/agent_list.py +36 -6
- edsl/coop/coop.py +274 -35
- edsl/coop/utils.py +63 -0
- edsl/dataset/dataset.py +74 -0
- edsl/dataset/dataset_operations_mixin.py +67 -62
- edsl/inference_services/services/test_service.py +1 -1
- edsl/interviews/exception_tracking.py +92 -20
- edsl/invigilators/invigilators.py +5 -1
- edsl/invigilators/prompt_constructor.py +299 -136
- edsl/jobs/html_table_job_logger.py +394 -48
- edsl/jobs/jobs_pricing_estimation.py +19 -114
- edsl/jobs/jobs_remote_inference_logger.py +29 -0
- edsl/jobs/jobs_runner_status.py +52 -21
- edsl/jobs/remote_inference.py +214 -30
- edsl/language_models/language_model.py +40 -3
- edsl/language_models/price_manager.py +91 -57
- edsl/prompts/prompt.py +1 -0
- edsl/questions/question_list.py +76 -20
- edsl/results/results.py +8 -1
- edsl/scenarios/file_store.py +8 -12
- edsl/scenarios/scenario.py +50 -2
- edsl/scenarios/scenario_list.py +34 -12
- edsl/surveys/survey.py +4 -0
- edsl/tasks/task_history.py +180 -6
- edsl/utilities/wikipedia.py +194 -0
- {edsl-0.1.57.dist-info → edsl-0.1.59.dist-info}/METADATA +4 -3
- {edsl-0.1.57.dist-info → edsl-0.1.59.dist-info}/RECORD +32 -32
- edsl/language_models/compute_cost.py +0 -78
- {edsl-0.1.57.dist-info → edsl-0.1.59.dist-info}/LICENSE +0 -0
- {edsl-0.1.57.dist-info → edsl-0.1.59.dist-info}/WHEEL +0 -0
- {edsl-0.1.57.dist-info → edsl-0.1.59.dist-info}/entry_points.txt +0 -0
edsl/dataset/dataset.py
CHANGED
@@ -93,6 +93,38 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
         """
         _, values = list(self.data[0].items())[0]
         return len(values)
+
+    def drop(self, field_name):
+        """
+        Returns a new Dataset with the specified field removed.
+
+        Args:
+            field_name (str): The name of the field to remove.
+
+        Returns:
+            Dataset: A new Dataset instance without the specified field.
+
+        Raises:
+            KeyError: If the field_name doesn't exist in the dataset.
+
+        Examples:
+            >>> from .dataset import Dataset
+            >>> d = Dataset([{'a': [1, 2, 3]}, {'b': [4, 5, 6]}])
+            >>> d.drop('a')
+            Dataset([{'b': [4, 5, 6]}])
+
+            >>> # Testing drop with nonexistent field raises DatasetKeyError - tested in unit tests
+        """
+        from .dataset import Dataset
+
+        # Check if field exists in the dataset
+        if field_name not in self.relevant_columns():
+            raise DatasetKeyError(f"Field '{field_name}' not found in dataset")
+
+        # Create a new dataset without the specified field
+        new_data = [entry for entry in self.data if field_name not in entry]
+        return Dataset(new_data)
+
 
     def tail(self, n: int = 5) -> Dataset:
         """Return the last n observations in the dataset.
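Note on the new `Dataset.drop`: it returns a new object rather than mutating in place, and the implementation raises `DatasetKeyError` rather than the `KeyError` named in the docstring. A minimal usage sketch, assuming an installed edsl with these changes:

    from edsl.dataset import Dataset

    d = Dataset([{'a': [1, 2, 3]}, {'b': [4, 5, 6]}])
    trimmed = d.drop('a')      # new Dataset; `d` is unchanged
    print(trimmed)             # Dataset([{'b': [4, 5, 6]}])

    # An unknown field raises DatasetKeyError (per the implementation above):
    # d.drop('missing')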
@@ -1054,6 +1086,48 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
 
         return Dataset(new_data)
 
+    def unique(self) -> "Dataset":
+        """Return a new dataset with only unique observations.
+
+        Examples:
+            >>> d = Dataset([{'a': [1, 2, 2, 3]}, {'b': [4, 5, 5, 6]}])
+            >>> d.unique().data
+            [{'a': [1, 2, 3]}, {'b': [4, 5, 6]}]
+
+            >>> d = Dataset([{'x': ['a', 'a', 'b']}, {'y': [1, 1, 2]}])
+            >>> d.unique().data
+            [{'x': ['a', 'b']}, {'y': [1, 2]}]
+        """
+        # Get all column names and values
+        headers, data = self._tabular()
+
+        # Create a list of unique rows
+        unique_rows = []
+        seen = set()
+
+        for row in data:
+            # Convert the row to a hashable representation for comparison
+            # We need to handle potential unhashable types
+            try:
+                row_key = tuple(map(lambda x: str(x) if isinstance(x, (list, dict)) else x, row))
+                if row_key not in seen:
+                    seen.add(row_key)
+                    unique_rows.append(row)
+            except:
+                # Fallback for complex objects: compare based on string representation
+                row_str = str(row)
+                if row_str not in seen:
+                    seen.add(row_str)
+                    unique_rows.append(row)
+
+        # Create a new dataset with unique combinations
+        new_data = []
+        for i, header in enumerate(headers):
+            values = [row[i] for row in unique_rows]
+            new_data.append({header: values})
+
+        return Dataset(new_data)
+
 
 if __name__ == "__main__":
     import doctest
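`unique` deduplicates whole observations: rows are formed by zipping all columns (via `_tabular()`), so a value repeats in the output whenever it co-occurs with a distinct value in another column. A sketch of that semantics:

    from edsl.dataset import Dataset

    # Row-wise: the row (2, 5) appears twice and is kept once
    Dataset([{'a': [1, 2, 2, 3]}, {'b': [4, 5, 5, 6]}]).unique().data
    # -> [{'a': [1, 2, 3]}, {'b': [4, 5, 6]}]

    # (2, 5) and (2, 6) are distinct rows, so the repeated 2 survives
    Dataset([{'a': [2, 2]}, {'b': [5, 6]}]).unique().data
    # -> [{'a': [2, 2]}, {'b': [5, 6]}]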
edsl/dataset/dataset_operations_mixin.py
CHANGED
@@ -1070,7 +1070,6 @@ class DataOperationsBase:
         - All dictionaries in the field must have compatible structures
         - If a dictionary is missing a key, the corresponding value will be None
         - Non-dictionary values in the field will cause a warning
-
         Examples:
             >>> from edsl.dataset import Dataset
 
@@ -1086,48 +1085,85 @@ class DataOperationsBase:
         >>> d = Dataset([{'a': [{'a': 1, 'b': 2}]}, {'c': [5]}])
         >>> d.flatten('a', keep_original=True)
         Dataset([{'a': [{'a': 1, 'b': 2}]}, {'c': [5]}, {'a.a': [1]}, {'a.b': [2]}])
+
+        # Can also use unambiguous unprefixed field name
+        >>> result = Dataset([{'answer.pros_cons': [{'pros': ['Safety'], 'cons': ['Cost']}]}]).flatten('pros_cons')
+        >>> sorted(result.keys()) == ['answer.pros_cons.cons', 'answer.pros_cons.pros']
+        True
+        >>> sorted(result.to_dicts()[0].items()) == sorted({'cons': ['Cost'], 'pros': ['Safety']}.items())
+        True
         """
         from ..dataset import Dataset
 
         # Ensure the dataset isn't empty
         if not self.data:
             return self.copy()
-
-        #
-        matching_entries = []
-        for entry in self.data:
-            col_name = next(iter(entry.keys()))
-            if field == col_name or (
-                "." in col_name
-                and (col_name.endswith("." + field) or col_name.startswith(field + "."))
-            ):
-                matching_entries.append(entry)
-
-        # Check if the field is ambiguous
-        if len(matching_entries) > 1:
-            matching_cols = [next(iter(entry.keys())) for entry in matching_entries]
-            from .exceptions import DatasetValueError
-
-            raise DatasetValueError(
-                f"Ambiguous field name '{field}'. It matches multiple columns: {matching_cols}. "
-                f"Please specify the full column name to flatten."
-            )
-
-        # Get the number of observations
-        num_observations = self.num_observations()
-
-        # Find the column to flatten
+
+        # First try direct match with the exact field name
         field_entry = None
         for entry in self.data:
-
+            col_name = next(iter(entry.keys()))
+            if field == col_name:
                 field_entry = entry
                 break
+
+        # If not found, try to match by unprefixed name
+        if field_entry is None:
+            # Find any columns that have field as their unprefixed name
+            candidates = []
+            for entry in self.data:
+                col_name = next(iter(entry.keys()))
+                if '.' in col_name:
+                    prefix, col_field = col_name.split('.', 1)
+                    if col_field == field:
+                        candidates.append(entry)
+
+            # If we found exactly one match by unprefixed name, use it
+            if len(candidates) == 1:
+                field_entry = candidates[0]
+            # If we found multiple matches, it's ambiguous
+            elif len(candidates) > 1:
+                matching_cols = [next(iter(entry.keys())) for entry in candidates]
+                from .exceptions import DatasetValueError
+                raise DatasetValueError(
+                    f"Ambiguous field name '{field}'. It matches multiple columns: {matching_cols}. "
+                    f"Please specify the full column name to flatten."
+                )
+            # If no candidates by unprefixed name, check partial matches
+            else:
+                partial_matches = []
+                for entry in self.data:
+                    col_name = next(iter(entry.keys()))
+                    if '.' in col_name and (
+                        col_name.endswith('.' + field) or
+                        col_name.startswith(field + '.')
+                    ):
+                        partial_matches.append(entry)
+
+                # If we found exactly one partial match, use it
+                if len(partial_matches) == 1:
+                    field_entry = partial_matches[0]
+                # If we found multiple partial matches, it's ambiguous
+                elif len(partial_matches) > 1:
+                    matching_cols = [next(iter(entry.keys())) for entry in partial_matches]
+                    from .exceptions import DatasetValueError
+                    raise DatasetValueError(
+                        f"Ambiguous field name '{field}'. It matches multiple columns: {matching_cols}. "
+                        f"Please specify the full column name to flatten."
+                    )
+
+        # Get the number of observations
+        num_observations = self.num_observations()
 
+        # If we still haven't found the field, it's not in the dataset
         if field_entry is None:
             warnings.warn(
                 f"Field '{field}' not found in dataset, returning original dataset"
             )
             return self.copy()
+
+        # Get the actual field name as it appears in the data
+        actual_field = next(iter(field_entry.keys()))
 
         # Create new dictionary for flattened data
         flattened_data = []
@@ -1135,14 +1171,14 @@ class DataOperationsBase:
         # Copy all existing columns except the one we're flattening (if keep_original is False)
         for entry in self.data:
             col_name = next(iter(entry.keys()))
-            if col_name !=
+            if col_name != actual_field or keep_original:
                 flattened_data.append(entry.copy())
 
         # Get field data and make sure it's valid
-        field_values = field_entry[
+        field_values = field_entry[actual_field]
         if not all(isinstance(item, dict) for item in field_values if item is not None):
             warnings.warn(
-                f"Field '{
+                f"Field '{actual_field}' contains non-dictionary values that cannot be flattened"
             )
             return self.copy()
 
@@ -1162,7 +1198,7 @@ class DataOperationsBase:
             new_values.append(value)
 
         # Add this as a new column
-        flattened_data.append({f"{
+        flattened_data.append({f"{actual_field}.{key}": new_values})
 
         # Return a new Dataset with the flattened data
         return Dataset(flattened_data)
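Net effect of the `flatten` rewrite above: lookup now proceeds exact column name, then unique unprefixed match (`answer.pros_cons` for field `pros_cons`), then unique prefix/suffix partial match, raising `DatasetValueError` only when a step finds several candidates; the resolved `actual_field` then drives both filtering and the `actual_field.key` names of the new columns. A sketch of the lookup behavior:

    from edsl.dataset import Dataset

    d = Dataset([{'answer.pros_cons': [{'pros': ['Safety'], 'cons': ['Cost']}]}])
    d.flatten('pros_cons')   # unique unprefixed match -> answer.pros_cons.pros / .cons

    d = Dataset([{'answer.x': [{'k': 1}]}, {'comment.x': [{'k': 2}]}])
    # d.flatten('x')         # two unprefixed matches -> DatasetValueError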
@@ -1244,37 +1280,6 @@ class DataOperationsBase:
 
         return result
 
-    def drop(self, field_name):
-        """
-        Returns a new Dataset with the specified field removed.
-
-        Args:
-            field_name (str): The name of the field to remove.
-
-        Returns:
-            Dataset: A new Dataset instance without the specified field.
-
-        Raises:
-            KeyError: If the field_name doesn't exist in the dataset.
-
-        Examples:
-            >>> from .dataset import Dataset
-            >>> d = Dataset([{'a': [1, 2, 3]}, {'b': [4, 5, 6]}])
-            >>> d.drop('a')
-            Dataset([{'b': [4, 5, 6]}])
-
-            >>> # Testing drop with nonexistent field raises DatasetKeyError - tested in unit tests
-        """
-        from .dataset import Dataset
-
-        # Check if field exists in the dataset
-        if field_name not in self.relevant_columns():
-            raise DatasetKeyError(f"Field '{field_name}' not found in dataset")
-
-        # Create a new dataset without the specified field
-        new_data = [entry for entry in self.data if field_name not in entry]
-        return Dataset(new_data)
-
     def remove_prefix(self):
         """Returns a new Dataset with the prefix removed from all column names.
 
edsl/interviews/exception_tracking.py
CHANGED
@@ -16,13 +16,19 @@ class InterviewExceptionEntry:
         invigilator: "InvigilatorBase",
         traceback_format="text",
         answers=None,
+        time=None,  # Added time parameter for deserialization
     ):
-        self.time = datetime.datetime.now().isoformat()
+        self.time = time or datetime.datetime.now().isoformat()
         self.exception = exception
         self.invigilator = invigilator
         self.traceback_format = traceback_format
         self.answers = answers
 
+    @property
+    def exception_type(self) -> str:
+        """Return the type of the exception."""
+        return type(self.exception).__name__
+
     @property
     def question_type(self) -> str:
         """Return the type of the question that failed."""
@@ -125,7 +131,12 @@ class InterviewExceptionEntry:
         'Traceback (most recent call last):...'
         """
         e = self.exception
-
+        # Check if the exception has a traceback attribute
+        if hasattr(e, "__traceback__") and e.__traceback__:
+            tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+        else:
+            # Use the message as traceback if no traceback available
+            tb_str = f"Exception: {str(e)}"
         return tb_str
 
     @property
@@ -139,14 +150,19 @@ class InterviewExceptionEntry:
 
         console = Console(file=html_output, record=True)
 
-
-
-
-
-
-
-
-
+        # Check if the exception has a traceback attribute
+        if hasattr(self.exception, "__traceback__") and self.exception.__traceback__:
+            tb = Traceback.from_exception(
+                type(self.exception),
+                self.exception,
+                self.exception.__traceback__,
+                show_locals=True,
+            )
+            console.print(tb)
+            return html_output.getvalue()
+        else:
+            # Return a simple string if no traceback available
+            return f"<pre>Exception: {str(self.exception)}</pre>"
 
     @staticmethod
     def serialize_exception(exception: Exception) -> dict:
@@ -155,14 +171,25 @@ class InterviewExceptionEntry:
         >>> entry = InterviewExceptionEntry.example()
         >>> _ = entry.serialize_exception(entry.exception)
         """
-
-
-
-
+        # Store the original exception type for proper reconstruction
+        exception_type = type(exception).__name__
+        module_name = getattr(type(exception), "__module__", "builtins")
+
+        # Extract traceback if available
+        if hasattr(exception, "__traceback__") and exception.__traceback__:
+            tb_str = "".join(
                 traceback.format_exception(
                     type(exception), exception, exception.__traceback__
                 )
-        )
+            )
+        else:
+            tb_str = f"Exception: {str(exception)}"
+
+        return {
+            "type": exception_type,
+            "module": module_name,
+            "message": str(exception),
+            "traceback": tb_str,
         }
 
     @staticmethod
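The serialized payload now carries `type` and `module` alongside the existing `message` and `traceback`; for an exception that was never raised (so `__traceback__` is None), the traceback degrades to the message. Roughly:

    payload = InterviewExceptionEntry.serialize_exception(ValueError("bad input"))
    # {'type': 'ValueError', 'module': 'builtins', 'message': 'bad input',
    #  'traceback': 'Exception: bad input'}   # message-only: never raised, no traceback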
@@ -172,11 +199,31 @@ class InterviewExceptionEntry:
         >>> entry = InterviewExceptionEntry.example()
         >>> _ = entry.deserialize_exception(entry.to_dict()["exception"])
         """
+        exception_type = data.get("type", "Exception")
+        module_name = data.get("module", "builtins")
+        message = data.get("message", "")
+
         try:
-
-
-
-
+            # Try to import the module and get the exception class
+            # if module_name != "builtins":
+            #     import importlib
+
+            #     module = importlib.import_module(module_name)
+            #     exception_class = getattr(module, exception_type, Exception)
+            # else:
+            #     # Look for exception in builtins
+            import builtins
+
+            exception_class = getattr(builtins, exception_type, Exception)
+
+        except (ImportError, AttributeError):
+            # Fall back to a generic Exception but preserve the type name
+            exception = Exception(message)
+            exception.__class__.__name__ = exception_type
+            return exception
+
+        # Create instance of the original exception type if possible
+        return exception_class(message)
 
     def to_dict(self) -> dict:
         """Return the exception as a dictionary.
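As written, only builtin exception types are reconstructed (the module-aware lookup is commented out), so a non-builtin `type` quietly degrades to a plain `Exception`; the `except` fallback is only reached if the `builtins` lookup itself fails. A sketch:

    exc = InterviewExceptionEntry.deserialize_exception(
        {"type": "ValueError", "module": "builtins", "message": "bad input"}
    )
    assert type(exc) is ValueError      # builtin types round-trip

    exc = InterviewExceptionEntry.deserialize_exception(
        {"type": "MyCustomError", "module": "my_pkg.errors", "message": "boom"}
    )
    assert type(exc) is Exception       # non-builtin degrades to Exception("boom")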
@@ -216,7 +263,11 @@ class InterviewExceptionEntry:
             invigilator = None
         else:
             invigilator = InvigilatorAI.from_dict(data["invigilator"])
-
+
+        # Use the original timestamp from serialization
+        time = data.get("time")
+
+        return cls(exception=exception, invigilator=invigilator, time=time)
 
 
 class InterviewExceptionCollection(UserDict):
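Combined with the new `time` parameter on `__init__`, `from_dict` now restores the original timestamp instead of stamping deserialized entries with the deserialization time:

    entry = InterviewExceptionEntry.example()
    restored = InterviewExceptionEntry.from_dict(entry.to_dict())
    assert restored.time == entry.time   # previously re-stamped with datetime.now()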
@@ -239,6 +290,27 @@ class InterviewExceptionCollection(UserDict):
         """Return the number of unfixed exceptions."""
         return sum(len(v) for v in self.unfixed_exceptions().values())
 
+    def list(self) -> list[dict]:
+        """
+        Return a list of exception dicts with the following metadata:
+        - exception_type: the type of the exception
+        - inference_service: the inference service used
+        - model: the model used
+        - question_name: the name of the question that failed
+        """
+        exception_list = []
+        for question_name, exceptions in self.data.items():
+            for exception in exceptions:
+                exception_list.append(
+                    {
+                        "exception_type": exception.exception_type,
+                        "inference_service": exception.invigilator.model._inference_service_,
+                        "model": exception.invigilator.model.model,
+                        "question_name": question_name,
+                    }
+                )
+        return exception_list
+
     def num_unfixed(self) -> int:
         """Return a list of unfixed questions."""
         return len([k for k in self.data.keys() if k not in self.fixed])
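A usage sketch for the new `list` helper. Since the collection is a `UserDict` keyed by question name, a hand-built collection serves for illustration, assuming `example()` attaches an invigilator with a model, as `list` requires:

    from edsl.interviews.exception_tracking import (
        InterviewExceptionCollection,
        InterviewExceptionEntry,
    )

    coll = InterviewExceptionCollection()
    coll["how_feeling"] = [InterviewExceptionEntry.example()]
    for row in coll.list():
        print(row["question_name"], row["exception_type"], row["model"])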
edsl/invigilators/invigilators.py
CHANGED
@@ -105,7 +105,11 @@ class InvigilatorBase(ABC):
            value = getattr(self, attr)
            if value is None:
                return None
-           if hasattr(value, "
+           if attr == "scenario" and hasattr(value, "offload"):
+               # Use the scenario's offload method to replace base64_string values
+               offloaded = value.offload()
+               return offloaded.to_dict()
+           elif hasattr(value, "to_dict"):
                return value.to_dict()
            if isinstance(value, (int, float, str, bool, dict, list)):
                return value