edsl 0.1.57__py3-none-any.whl → 0.1.59__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
edsl/dataset/dataset.py CHANGED
@@ -93,6 +93,38 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
         """
         _, values = list(self.data[0].items())[0]
         return len(values)
+
+    def drop(self, field_name):
+        """
+        Returns a new Dataset with the specified field removed.
+
+        Args:
+            field_name (str): The name of the field to remove.
+
+        Returns:
+            Dataset: A new Dataset instance without the specified field.
+
+        Raises:
+            KeyError: If the field_name doesn't exist in the dataset.
+
+        Examples:
+            >>> from .dataset import Dataset
+            >>> d = Dataset([{'a': [1, 2, 3]}, {'b': [4, 5, 6]}])
+            >>> d.drop('a')
+            Dataset([{'b': [4, 5, 6]}])
+
+            >>> # Testing drop with nonexistent field raises DatasetKeyError - tested in unit tests
+        """
+        from .dataset import Dataset
+
+        # Check if field exists in the dataset
+        if field_name not in self.relevant_columns():
+            raise DatasetKeyError(f"Field '{field_name}' not found in dataset")
+
+        # Create a new dataset without the specified field
+        new_data = [entry for entry in self.data if field_name not in entry]
+        return Dataset(new_data)
+
 
     def tail(self, n: int = 5) -> Dataset:
         """Return the last n observations in the dataset.
@@ -1054,6 +1086,48 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
 
         return Dataset(new_data)
 
+    def unique(self) -> "Dataset":
+        """Return a new dataset with only unique observations.
+
+        Examples:
+            >>> d = Dataset([{'a': [1, 2, 2, 3]}, {'b': [4, 5, 5, 6]}])
+            >>> d.unique().data
+            [{'a': [1, 2, 3]}, {'b': [4, 5, 6]}]
+
+            >>> d = Dataset([{'x': ['a', 'a', 'b']}, {'y': [1, 1, 2]}])
+            >>> d.unique().data
+            [{'x': ['a', 'b']}, {'y': [1, 2]}]
+        """
+        # Get all column names and values
+        headers, data = self._tabular()
+
+        # Create a list of unique rows
+        unique_rows = []
+        seen = set()
+
+        for row in data:
+            # Convert the row to a hashable representation for comparison
+            # We need to handle potential unhashable types
+            try:
+                row_key = tuple(map(lambda x: str(x) if isinstance(x, (list, dict)) else x, row))
+                if row_key not in seen:
+                    seen.add(row_key)
+                    unique_rows.append(row)
+            except:
+                # Fallback for complex objects: compare based on string representation
+                row_str = str(row)
+                if row_str not in seen:
+                    seen.add(row_str)
+                    unique_rows.append(row)
+
+        # Create a new dataset with unique combinations
+        new_data = []
+        for i, header in enumerate(headers):
+            values = [row[i] for row in unique_rows]
+            new_data.append({header: values})
+
+        return Dataset(new_data)
+
 
 if __name__ == "__main__":
     import doctest
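
For reference, the two methods added above compose naturally. A minimal sketch based on the doctests in this diff, assuming the public import path edsl.dataset used elsewhere in the file:

    >>> from edsl.dataset import Dataset
    >>> d = Dataset([{'a': [1, 2, 2, 3]}, {'b': [4, 5, 5, 6]}])
    >>> d.unique().data
    [{'a': [1, 2, 3]}, {'b': [4, 5, 6]}]
    >>> d.unique().drop('b')
    Dataset([{'a': [1, 2, 3]}])

Note that unique() deduplicates whole rows (value combinations across all columns), not values within a single column, and that drop() actually raises DatasetKeyError rather than the KeyError named in its docstring.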
@@ -1070,7 +1070,6 @@ class DataOperationsBase:
         - All dictionaries in the field must have compatible structures
         - If a dictionary is missing a key, the corresponding value will be None
         - Non-dictionary values in the field will cause a warning
-
         Examples:
             >>> from edsl.dataset import Dataset
 
@@ -1086,48 +1085,85 @@ class DataOperationsBase:
         >>> d = Dataset([{'a': [{'a': 1, 'b': 2}]}, {'c': [5]}])
         >>> d.flatten('a', keep_original=True)
         Dataset([{'a': [{'a': 1, 'b': 2}]}, {'c': [5]}, {'a.a': [1]}, {'a.b': [2]}])
+
+        # Can also use unambiguous unprefixed field name
+        >>> result = Dataset([{'answer.pros_cons': [{'pros': ['Safety'], 'cons': ['Cost']}]}]).flatten('pros_cons')
+        >>> sorted(result.keys()) == ['answer.pros_cons.cons', 'answer.pros_cons.pros']
+        True
+        >>> sorted(result.to_dicts()[0].items()) == sorted({'cons': ['Cost'], 'pros': ['Safety']}.items())
+        True
         """
         from ..dataset import Dataset
 
         # Ensure the dataset isn't empty
         if not self.data:
             return self.copy()
-
-        # Find all columns that contain the field
-        matching_entries = []
-        for entry in self.data:
-            col_name = next(iter(entry.keys()))
-            if field == col_name or (
-                "." in col_name
-                and (col_name.endswith("." + field) or col_name.startswith(field + "."))
-            ):
-                matching_entries.append(entry)
-
-        # Check if the field is ambiguous
-        if len(matching_entries) > 1:
-            matching_cols = [next(iter(entry.keys())) for entry in matching_entries]
-            from .exceptions import DatasetValueError
-
-            raise DatasetValueError(
-                f"Ambiguous field name '{field}'. It matches multiple columns: {matching_cols}. "
-                f"Please specify the full column name to flatten."
-            )
-
-        # Get the number of observations
-        num_observations = self.num_observations()
-
-        # Find the column to flatten
+
+        # First try direct match with the exact field name
         field_entry = None
         for entry in self.data:
-            if field in entry:
+            col_name = next(iter(entry.keys()))
+            if field == col_name:
                 field_entry = entry
                 break
+
+        # If not found, try to match by unprefixed name
+        if field_entry is None:
+            # Find any columns that have field as their unprefixed name
+            candidates = []
+            for entry in self.data:
+                col_name = next(iter(entry.keys()))
+                if '.' in col_name:
+                    prefix, col_field = col_name.split('.', 1)
+                    if col_field == field:
+                        candidates.append(entry)
+
+            # If we found exactly one match by unprefixed name, use it
+            if len(candidates) == 1:
+                field_entry = candidates[0]
+            # If we found multiple matches, it's ambiguous
+            elif len(candidates) > 1:
+                matching_cols = [next(iter(entry.keys())) for entry in candidates]
+                from .exceptions import DatasetValueError
+                raise DatasetValueError(
+                    f"Ambiguous field name '{field}'. It matches multiple columns: {matching_cols}. "
+                    f"Please specify the full column name to flatten."
+                )
+            # If no candidates by unprefixed name, check partial matches
+            else:
+                partial_matches = []
+                for entry in self.data:
+                    col_name = next(iter(entry.keys()))
+                    if '.' in col_name and (
+                        col_name.endswith('.' + field) or
+                        col_name.startswith(field + '.')
+                    ):
+                        partial_matches.append(entry)
+
+                # If we found exactly one partial match, use it
+                if len(partial_matches) == 1:
+                    field_entry = partial_matches[0]
+                # If we found multiple partial matches, it's ambiguous
+                elif len(partial_matches) > 1:
+                    matching_cols = [next(iter(entry.keys())) for entry in partial_matches]
+                    from .exceptions import DatasetValueError
+                    raise DatasetValueError(
+                        f"Ambiguous field name '{field}'. It matches multiple columns: {matching_cols}. "
+                        f"Please specify the full column name to flatten."
+                    )
+
+        # Get the number of observations
+        num_observations = self.num_observations()
 
+        # If we still haven't found the field, it's not in the dataset
         if field_entry is None:
             warnings.warn(
                 f"Field '{field}' not found in dataset, returning original dataset"
             )
             return self.copy()
+
+        # Get the actual field name as it appears in the data
+        actual_field = next(iter(field_entry.keys()))
 
         # Create new dictionary for flattened data
         flattened_data = []
@@ -1135,14 +1171,14 @@ class DataOperationsBase:
         # Copy all existing columns except the one we're flattening (if keep_original is False)
         for entry in self.data:
             col_name = next(iter(entry.keys()))
-            if col_name != field or keep_original:
+            if col_name != actual_field or keep_original:
                 flattened_data.append(entry.copy())
 
         # Get field data and make sure it's valid
-        field_values = field_entry[field]
+        field_values = field_entry[actual_field]
         if not all(isinstance(item, dict) for item in field_values if item is not None):
             warnings.warn(
-                f"Field '{field}' contains non-dictionary values that cannot be flattened"
+                f"Field '{actual_field}' contains non-dictionary values that cannot be flattened"
             )
             return self.copy()
 
@@ -1162,7 +1198,7 @@ class DataOperationsBase:
                 new_values.append(value)
 
             # Add this as a new column
-            flattened_data.append({f"{field}.{key}": new_values})
+            flattened_data.append({f"{actual_field}.{key}": new_values})
 
         # Return a new Dataset with the flattened data
         return Dataset(flattened_data)
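
Editorial note on the rewritten lookup: flatten() now resolves a field in three steps: an exact column-name match, then a unique match on the unprefixed name (the part after the first '.'), then a unique prefix/suffix partial match. Multiple matches at either fuzzy step raise DatasetValueError; no match at all warns and returns a copy unchanged. A minimal sketch of the new behavior, adapted from the doctest above:

    >>> from edsl.dataset import Dataset
    >>> d = Dataset([{'answer.pros_cons': [{'pros': ['Safety'], 'cons': ['Cost']}]}])
    >>> sorted(d.flatten('pros_cons').keys())
    ['answer.pros_cons.cons', 'answer.pros_cons.pros']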
@@ -1244,37 +1280,6 @@ class DataOperationsBase:
 
         return result
 
-    def drop(self, field_name):
-        """
-        Returns a new Dataset with the specified field removed.
-
-        Args:
-            field_name (str): The name of the field to remove.
-
-        Returns:
-            Dataset: A new Dataset instance without the specified field.
-
-        Raises:
-            KeyError: If the field_name doesn't exist in the dataset.
-
-        Examples:
-            >>> from .dataset import Dataset
-            >>> d = Dataset([{'a': [1, 2, 3]}, {'b': [4, 5, 6]}])
-            >>> d.drop('a')
-            Dataset([{'b': [4, 5, 6]}])
-
-            >>> # Testing drop with nonexistent field raises DatasetKeyError - tested in unit tests
-        """
-        from .dataset import Dataset
-
-        # Check if field exists in the dataset
-        if field_name not in self.relevant_columns():
-            raise DatasetKeyError(f"Field '{field_name}' not found in dataset")
-
-        # Create a new dataset without the specified field
-        new_data = [entry for entry in self.data if field_name not in entry]
-        return Dataset(new_data)
-
     def remove_prefix(self):
         """Returns a new Dataset with the prefix removed from all column names.
 
@@ -54,7 +54,7 @@ class TestService(InferenceServiceABC):
         input_token_name = cls.input_token_name
         output_token_name = cls.output_token_name
         _rpm = 1000
-        _tpm = 100000
+        _tpm = 8000000
 
         @property
         def _canned_response(self):
@@ -16,13 +16,19 @@ class InterviewExceptionEntry:
         invigilator: "InvigilatorBase",
         traceback_format="text",
         answers=None,
+        time=None,  # Added time parameter for deserialization
     ):
-        self.time = datetime.datetime.now().isoformat()
+        self.time = time or datetime.datetime.now().isoformat()
         self.exception = exception
         self.invigilator = invigilator
         self.traceback_format = traceback_format
         self.answers = answers
 
+    @property
+    def exception_type(self) -> str:
+        """Return the type of the exception."""
+        return type(self.exception).__name__
+
     @property
     def question_type(self) -> str:
         """Return the type of the question that failed."""
@@ -125,7 +131,12 @@ class InterviewExceptionEntry:
        'Traceback (most recent call last):...'
        """
        e = self.exception
-        tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+        # Check if the exception has a traceback attribute
+        if hasattr(e, "__traceback__") and e.__traceback__:
+            tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+        else:
+            # Use the message as traceback if no traceback available
+            tb_str = f"Exception: {str(e)}"
        return tb_str
 
    @property
@@ -139,14 +150,19 @@ class InterviewExceptionEntry:
 
        console = Console(file=html_output, record=True)
 
-        tb = Traceback.from_exception(
-            type(self.exception),
-            self.exception,
-            self.exception.__traceback__,
-            show_locals=True,
-        )
-        console.print(tb)
-        return html_output.getvalue()
+        # Check if the exception has a traceback attribute
+        if hasattr(self.exception, "__traceback__") and self.exception.__traceback__:
+            tb = Traceback.from_exception(
+                type(self.exception),
+                self.exception,
+                self.exception.__traceback__,
+                show_locals=True,
+            )
+            console.print(tb)
+            return html_output.getvalue()
+        else:
+            # Return a simple string if no traceback available
+            return f"<pre>Exception: {str(self.exception)}</pre>"
 
    @staticmethod
    def serialize_exception(exception: Exception) -> dict:
@@ -155,14 +171,25 @@ class InterviewExceptionEntry:
        >>> entry = InterviewExceptionEntry.example()
        >>> _ = entry.serialize_exception(entry.exception)
        """
-        return {
-            "type": type(exception).__name__,
-            "message": str(exception),
-            "traceback": "".join(
+        # Store the original exception type for proper reconstruction
+        exception_type = type(exception).__name__
+        module_name = getattr(type(exception), "__module__", "builtins")
+
+        # Extract traceback if available
+        if hasattr(exception, "__traceback__") and exception.__traceback__:
+            tb_str = "".join(
                traceback.format_exception(
                    type(exception), exception, exception.__traceback__
                )
-            ),
+            )
+        else:
+            tb_str = f"Exception: {str(exception)}"
+
+        return {
+            "type": exception_type,
+            "module": module_name,
+            "message": str(exception),
+            "traceback": tb_str,
        }
 
    @staticmethod
@@ -172,11 +199,31 @@ class InterviewExceptionEntry:
        >>> entry = InterviewExceptionEntry.example()
        >>> _ = entry.deserialize_exception(entry.to_dict()["exception"])
        """
+        exception_type = data.get("type", "Exception")
+        module_name = data.get("module", "builtins")
+        message = data.get("message", "")
+
        try:
-            exception_class = globals()[data["type"]]
-        except KeyError:
-            exception_class = Exception
-        return exception_class(data["message"])
+            # Try to import the module and get the exception class
+            # if module_name != "builtins":
+            #     import importlib
+
+            #     module = importlib.import_module(module_name)
+            #     exception_class = getattr(module, exception_type, Exception)
+            # else:
+            #     # Look for exception in builtins
+            import builtins
+
+            exception_class = getattr(builtins, exception_type, Exception)
+
+        except (ImportError, AttributeError):
+            # Fall back to a generic Exception but preserve the type name
+            exception = Exception(message)
+            exception.__class__.__name__ = exception_type
+            return exception
+
+        # Create instance of the original exception type if possible
+        return exception_class(message)
 
    def to_dict(self) -> dict:
        """Return the exception as a dictionary.
@@ -216,7 +263,11 @@ class InterviewExceptionEntry:
            invigilator = None
        else:
            invigilator = InvigilatorAI.from_dict(data["invigilator"])
-        return cls(exception=exception, invigilator=invigilator)
+
+        # Use the original timestamp from serialization
+        time = data.get("time")
+
+        return cls(exception=exception, invigilator=invigilator, time=time)
 
 
 class InterviewExceptionCollection(UserDict):
@@ -239,6 +290,27 @@ class InterviewExceptionCollection(UserDict):
        """Return the number of unfixed exceptions."""
        return sum(len(v) for v in self.unfixed_exceptions().values())
 
+    def list(self) -> list[dict]:
+        """
+        Return a list of exception dicts with the following metadata:
+        - exception_type: the type of the exception
+        - inference_service: the inference service used
+        - model: the model used
+        - question_name: the name of the question that failed
+        """
+        exception_list = []
+        for question_name, exceptions in self.data.items():
+            for exception in exceptions:
+                exception_list.append(
+                    {
+                        "exception_type": exception.exception_type,
+                        "inference_service": exception.invigilator.model._inference_service_,
+                        "model": exception.invigilator.model.model,
+                        "question_name": question_name,
+                    }
+                )
+        return exception_list
+
    def num_unfixed(self) -> int:
        """Return a list of unfixed questions."""
        return len([k for k in self.data.keys() if k not in self.fixed])
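
A minimal usage sketch for the new list() helper; coll is an illustrative name for an already-populated InterviewExceptionCollection:

    >>> for row in coll.list():
    ...     print(row['question_name'], row['exception_type'], row['model'])

Each dict exposes exception_type (from the new property above), inference_service, model, and question_name.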
@@ -105,7 +105,11 @@ class InvigilatorBase(ABC):
        value = getattr(self, attr)
        if value is None:
            return None
-        if hasattr(value, "to_dict"):
+        if attr == "scenario" and hasattr(value, "offload"):
+            # Use the scenario's offload method to replace base64_string values
+            offloaded = value.offload()
+            return offloaded.to_dict()
+        elif hasattr(value, "to_dict"):
            return value.to_dict()
        if isinstance(value, (int, float, str, bool, dict, list)):
            return value
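
A rough sketch of what the scenario special case changes; the names here are illustrative, and per the diff comment offload() replaces base64_string values before serialization, presumably so serialized invigilators do not embed raw file payloads:

    >>> compact = scenario.offload().to_dict()  # base64_string values replaced
    >>> full = scenario.to_dict()               # would embed the raw base64 payloads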