edsl 0.1.58__py3-none-any.whl → 0.1.59__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -217,6 +217,17 @@ class HTMLTableJobLogger(JobLogger):
         )
         total_cost = total_input_cost + total_output_cost
 
+        # Calculate credit totals
+        total_input_credits = sum(
+            cost.input_cost_credits_with_cache or 0
+            for cost in self.jobs_info.model_costs
+        )
+        total_output_credits = sum(
+            cost.output_cost_credits_with_cache or 0
+            for cost in self.jobs_info.model_costs
+        )
+        total_credits = total_input_credits + total_output_credits
+
         # Generate cost rows HTML with class names for right alignment
         cost_rows = "".join(
             f"""
@@ -228,6 +239,7 @@ class HTMLTableJobLogger(JobLogger):
                 <td class='token-count'>{cost.output_tokens:,}</td>
                 <td class='cost-value'>${cost.output_cost_usd:.4f}</td>
                 <td class='cost-value'>${(cost.input_cost_usd or 0) + (cost.output_cost_usd or 0):.4f}</td>
+                <td class='cost-value'>{(cost.input_cost_credits_with_cache or 0) + (cost.output_cost_credits_with_cache or 0):,.2f}</td>
             </tr>
             """
             for cost in self.jobs_info.model_costs
@@ -242,6 +254,7 @@ class HTMLTableJobLogger(JobLogger):
                 <td class='token-count'>{total_output_tokens:,}</td>
                 <td class='cost-value'>${total_output_cost:.4f}</td>
                 <td class='cost-value'>${total_cost:.4f}</td>
+                <td class='cost-value'>{total_credits:,.2f}</td>
             </tr>
             """
 
@@ -249,7 +262,7 @@ class HTMLTableJobLogger(JobLogger):
             <div class="model-costs-section">
                 <div class="model-costs-header" onclick="{self._collapse(f'model-costs-content-{self.log_id}', f'model-costs-arrow-{self.log_id}')}">
                     <span id="model-costs-arrow-{self.log_id}" class="expand-toggle">&#8963;</span>
-                    <span>Model Costs (${total_cost:.4f} total)</span>
+                    <span>Model Costs (${total_cost:.4f} / {total_credits:,.2f} credits total)</span>
                     <span style="flex-grow: 1;"></span>
                 </div>
                 <div id="model-costs-content-{self.log_id}" class="model-costs-content">
@@ -263,6 +276,7 @@ class HTMLTableJobLogger(JobLogger):
                         <th class="cost-header">Output Tokens</th>
                         <th class="cost-header">Output Cost</th>
                         <th class="cost-header">Total Cost</th>
+                        <th class="cost-header">Total Credits</th>
                     </tr>
                 </thead>
                 <tbody>
@@ -270,6 +284,9 @@ class HTMLTableJobLogger(JobLogger):
                     {total_row}
                 </tbody>
             </table>
+            <p style="font-style: italic; margin-top: 8px; font-size: 0.85em; color: #4b5563;">
+                You can obtain the total credit cost by multiplying the total USD cost by 100. A lower credit cost indicates that you saved money by retrieving responses from the universal remote cache.
+            </p>
         </div>
     </div>
     """
@@ -88,7 +88,6 @@ class PromptCostEstimator:
 
 
 class JobsPrompts:
-
     relevant_keys = [
         "user_prompt",
         "system_prompt",
@@ -171,13 +170,18 @@ class JobsPrompts:
             cost = prompt_cost["cost_usd"]
 
             # Generate cache keys for each iteration
+            files_list = prompts.get("files_list", None)
+            if files_list:
+                files_hash = "+".join([str(hash(file)) for file in files_list])
+                user_prompt_with_hashes = user_prompt + f" {files_hash}"
             cache_keys = []
+
             for iteration in range(iterations):
                 cache_key = CacheEntry.gen_key(
                     model=model,
                     parameters=invigilator.model.parameters,
                     system_prompt=system_prompt,
-                    user_prompt=user_prompt,
+                    user_prompt=user_prompt_with_hashes if files_list else user_prompt,
                     iteration=iteration,
                 )
                 cache_keys.append(cache_key)
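
This hunk makes file attachments part of the cache identity: when files_list is present, a hash of each file is appended to the user prompt before CacheEntry.gen_key runs, so otherwise-identical prompts with different attachments stop sharing a cache entry. A simplified standalone sketch of the idea (stand-in key function; the real gen_key also hashes model, parameters, system prompt, and iteration):

    import hashlib

    def gen_key(user_prompt: str) -> str:  # stand-in for CacheEntry.gen_key
        return hashlib.sha256(user_prompt.encode()).hexdigest()

    user_prompt = "Describe this image."
    files_hash = "+".join(str(hash(f)) for f in ("file_a", "file_b"))  # stand-in files
    assert gen_key(user_prompt + f" {files_hash}") != gen_key(user_prompt)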
@@ -40,6 +40,8 @@ class ModelCost:
     input_cost_usd: float = None
     output_tokens: int = None
    output_cost_usd: float = None
+    input_cost_credits_with_cache: int = None
+    output_cost_credits_with_cache: int = None
 
 
 @dataclass
@@ -279,9 +279,7 @@ class JobsRemoteInferenceHandler:
             )
             time.sleep(self.poll_interval)
 
-    def _get_expenses_from_results(
-        self, results: "Results", include_cached_responses_in_cost: bool = False
-    ) -> dict:
+    def _get_expenses_from_results(self, results: "Results") -> dict:
         """
         Calculates expenses from Results object.
 
@@ -309,10 +307,6 @@ class JobsRemoteInferenceHandler:
                 question_name = key.removesuffix("_cost")
                 cache_used = result["cache_used_dict"][question_name]
 
-                # Skip if we're excluding cached responses and this was cached
-                if not include_cached_responses_in_cost and cache_used:
-                    continue
-
                 # Get expense keys for input and output tokens
                 input_key = (
                     result["model"]._inference_service_,
@@ -332,6 +326,7 @@ class JobsRemoteInferenceHandler:
                     expenses[input_key] = {
                         "tokens": 0,
                         "cost_usd": 0,
+                        "cost_usd_with_cache": 0,
                     }
 
                 input_price_per_million_tokens = input_key[3]
@@ -341,11 +336,15 @@ class JobsRemoteInferenceHandler:
                 expenses[input_key]["tokens"] += input_tokens
                 expenses[input_key]["cost_usd"] += input_cost
 
+                if not cache_used:
+                    expenses[input_key]["cost_usd_with_cache"] += input_cost
+
                 # Update output token expenses
                 if output_key not in expenses:
                     expenses[output_key] = {
                         "tokens": 0,
                         "cost_usd": 0,
+                        "cost_usd_with_cache": 0,
                     }
 
                 output_price_per_million_tokens = output_key[3]
@@ -357,6 +356,9 @@ class JobsRemoteInferenceHandler:
                 expenses[output_key]["tokens"] += output_tokens
                 expenses[output_key]["cost_usd"] += output_cost
 
+                if not cache_used:
+                    expenses[output_key]["cost_usd_with_cache"] += output_cost
+
         expenses_by_model = {}
         for expense_key, expense_usage in expenses.items():
             service, model, token_type, _ = expense_key
@@ -368,8 +370,10 @@ class JobsRemoteInferenceHandler:
                     "model": model,
                     "input_tokens": 0,
                     "input_cost_usd": 0,
+                    "input_cost_usd_with_cache": 0,
                     "output_tokens": 0,
                     "output_cost_usd": 0,
+                    "output_cost_usd_with_cache": 0,
                 }
 
             if token_type == "input":
@@ -377,14 +381,22 @@ class JobsRemoteInferenceHandler:
                 expenses_by_model[model_key]["input_cost_usd"] += expense_usage[
                     "cost_usd"
                 ]
+                expenses_by_model[model_key][
+                    "input_cost_usd_with_cache"
+                ] += expense_usage["cost_usd_with_cache"]
             elif token_type == "output":
                 expenses_by_model[model_key]["output_tokens"] += expense_usage["tokens"]
                 expenses_by_model[model_key]["output_cost_usd"] += expense_usage[
                     "cost_usd"
                 ]
+                expenses_by_model[model_key][
+                    "output_cost_usd_with_cache"
+                ] += expense_usage["cost_usd_with_cache"]
 
         converter = CostConverter()
         for model_key, model_cost_dict in expenses_by_model.items():
+
+            # Handle full cost (without cache)
             input_cost = model_cost_dict["input_cost_usd"]
             output_cost = model_cost_dict["output_cost_usd"]
             model_cost_dict["input_cost_credits"] = converter.usd_to_credits(input_cost)
@@ -399,6 +411,15 @@ class JobsRemoteInferenceHandler:
                 model_cost_dict["output_cost_credits"]
             )
 
+            # Handle cost with cache
+            input_cost_with_cache = model_cost_dict["input_cost_usd_with_cache"]
+            output_cost_with_cache = model_cost_dict["output_cost_usd_with_cache"]
+            model_cost_dict["input_cost_credits_with_cache"] = converter.usd_to_credits(
+                input_cost_with_cache
+            )
+            model_cost_dict["output_cost_credits_with_cache"] = (
+                converter.usd_to_credits(output_cost_with_cache)
+            )
         return list(expenses_by_model.values())
 
     def _fetch_results_and_log(
@@ -423,6 +444,12 @@ class JobsRemoteInferenceHandler:
                 input_cost_usd=model_cost_dict.get("input_cost_usd"),
                 output_tokens=model_cost_dict.get("output_tokens"),
                 output_cost_usd=model_cost_dict.get("output_cost_usd"),
+                input_cost_credits_with_cache=model_cost_dict.get(
+                    "input_cost_credits_with_cache"
+                ),
+                output_cost_credits_with_cache=model_cost_dict.get(
+                    "output_cost_credits_with_cache"
+                ),
             )
             for model_cost_dict in model_cost_dicts
         ]
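
The net effect of these hunks: every expense bucket now carries two running totals, cost_usd (all tokens, cached or not) and cost_usd_with_cache, which accumulates only when cache_used is false and therefore reflects what was actually billed after cache savings. A standalone sketch of the aggregation, assuming a simplified record shape:

    records = [  # assumed shape: one entry per answered question
        {"tokens": 1000, "cost_usd": 0.002, "cache_used": False},
        {"tokens": 1000, "cost_usd": 0.002, "cache_used": True},
    ]
    bucket = {"tokens": 0, "cost_usd": 0.0, "cost_usd_with_cache": 0.0}
    for r in records:
        bucket["tokens"] += r["tokens"]
        bucket["cost_usd"] += r["cost_usd"]  # full cost, cached or not
        if not r["cache_used"]:
            bucket["cost_usd_with_cache"] += r["cost_usd"]  # billed portion only
    assert bucket["cost_usd"] == 0.004
    assert bucket["cost_usd_with_cache"] == 0.002  # half came from the cache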
@@ -769,8 +769,45 @@ class LanguageModel(
             params["question_name"] = invigilator.question.question_name
         # Get timeout from configuration
         from ..config import CONFIG
-
-        TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
+        import logging
+
+        logger = logging.getLogger(__name__)
+        base_timeout = float(CONFIG.get("EDSL_API_TIMEOUT"))
+
+        # Adjust timeout if files are present
+        import time
+
+        start = time.time()
+        if files_list:
+            # Calculate total size of attached files in MB
+            file_sizes = []
+            for file in files_list:
+                # Try different attributes that might contain the file content
+                if hasattr(file, "base64_string") and file.base64_string:
+                    file_sizes.append(len(file.base64_string) / (1024 * 1024))
+                elif hasattr(file, "content") and file.content:
+                    file_sizes.append(len(file.content) / (1024 * 1024))
+                elif hasattr(file, "data") and file.data:
+                    file_sizes.append(len(file.data) / (1024 * 1024))
+                else:
+                    # Default minimum size if we can't determine actual size
+                    file_sizes.append(1)  # Assume at least 1MB
+            total_size_mb = sum(file_sizes)
+
+            # Increase timeout proportionally to file size
+            # For each MB of file size, add 10 seconds to the timeout (adjust as needed)
+            size_adjustment = total_size_mb * 10
+
+            # Cap the maximum timeout adjustment at 5 minutes (300 seconds)
+            size_adjustment = min(size_adjustment, 300)
+
+            TIMEOUT = base_timeout + size_adjustment
+
+            logger.info(
+                f"Adjusted timeout for API call with {len(files_list)} files (total size: {total_size_mb:.2f}MB). Base timeout: {base_timeout}s, New timeout: {TIMEOUT}s"
+            )
+        else:
+            TIMEOUT = base_timeout
 
         # Execute the model call with timeout
         response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
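
The timeout change reduces to a simple rule: the configured EDSL_API_TIMEOUT plus 10 seconds per MB of attached files, with the file-size surcharge capped at 300 seconds. A worked sketch of that formula:

    def adjusted_timeout(base_timeout: float, total_size_mb: float) -> float:
        # base + 10s per MB of attachments; surcharge capped at 300s
        return base_timeout + min(total_size_mb * 10, 300)

    assert adjusted_timeout(60, 0) == 60     # no files: base timeout unchanged
    assert adjusted_timeout(60, 2.5) == 85   # 2.5 MB adds 25s
    assert adjusted_timeout(60, 100) == 360  # large uploads hit the 300s cap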
edsl/prompts/prompt.py CHANGED
@@ -290,6 +290,7 @@ class Prompt(PersistenceMixin, RepresentationMixin):
             return result
         except Exception as e:
             print(f"Error rendering prompt: {e}")
+            raise e
         return self
 
     @staticmethod
@@ -299,23 +299,24 @@ class ListResponseValidator(ResponseValidatorABC):
         # This method can now be removed since validation is handled in the Pydantic model
         pass
 
-    def fix(self, response, verbose=False):
+    def fix(self, response, verbose=False) -> dict[str, Any]:
         """
         Fix common issues in list responses by splitting strings into lists.
 
         Examples:
             >>> from edsl import QuestionList
-            >>> q = QuestionList.example(min_list_items=2, max_list_items=4)
-            >>> validator = q.response_validator
+            >>> q_constrained = QuestionList.example(min_list_items=2, max_list_items=4)
+            >>> validator_constrained = q_constrained.response_validator
 
+            >>> q_permissive = QuestionList.example(permissive=True)
+            >>> validator_permissive = q_permissive.response_validator
+
             >>> # Fix a string that should be a list
             >>> bad_response = {"answer": "apple,banana,cherry"}
-            >>> try:
-            ...     validator.validate(bad_response)
-            ... except Exception:
-            ...     fixed = validator.fix(bad_response)
-            ...     validated = validator.validate(fixed)
-            ...     validated # Show full response
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': ['apple', 'banana', 'cherry']}
+            >>> validator_constrained.validate(fixed)  # Show full response after validation
             {'answer': ['apple', 'banana', 'cherry'], 'comment': None, 'generated_tokens': None}
 
             >>> # Fix using generated_tokens when answer is invalid
@@ -323,12 +324,10 @@ class ListResponseValidator(ResponseValidatorABC):
             ...     "answer": None,
             ...     "generated_tokens": "pizza, pasta, salad"
             ... }
-            >>> try:
-            ...     validator.validate(bad_response)
-            ... except Exception:
-            ...     fixed = validator.fix(bad_response)
-            ...     validated = validator.validate(fixed)
-            ...     validated
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': ['pizza', ' pasta', ' salad']}
+            >>> validator_constrained.validate(fixed)
             {'answer': ['pizza', ' pasta', ' salad'], 'comment': None, 'generated_tokens': None}
 
             >>> # Preserve comments during fixing
@@ -336,17 +335,74 @@ class ListResponseValidator(ResponseValidatorABC):
             ...     "answer": "red,blue,green",
             ...     "comment": "These are colors"
             ... }
-            >>> fixed = validator.fix(bad_response)
-            >>> fixed == {
+            >>> fixed_output = validator_constrained.fix(bad_response)
+            >>> fixed_output
+            {'answer': ['red', 'blue', 'green'], 'comment': 'These are colors'}
+            >>> validated_output = validator_constrained.validate(fixed_output)
+            >>> validated_output == {
             ...     "answer": ["red", "blue", "green"],
-            ...     "comment": "These are colors"
+            ...     "comment": "These are colors",
+            ...     "generated_tokens": None
             ... }
             True
+
+            >>> # Fix an empty string answer
+            >>> bad_response = {"answer": ""}
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': []}
+            >>> validator_permissive.validate(fixed)
+            {'answer': [], 'comment': None, 'generated_tokens': None}
+
+            >>> # Fix a single item string answer (no commas)
+            >>> bad_response = {"answer": "single_item"}
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': ['single_item']}
+            >>> validator_permissive.validate(fixed)
+            {'answer': ['single_item'], 'comment': None, 'generated_tokens': None}
+
+            >>> # Fix when answer is None and no generated_tokens
+            >>> bad_response = {"answer": None}
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': []}
+            >>> validator_permissive.validate(fixed)
+            {'answer': [], 'comment': None, 'generated_tokens': None}
+
+            >>> # Fix when answer key is missing but generated_tokens is present
+            >>> bad_response = {"generated_tokens": "token1,token2"}
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': ['token1', 'token2']}
+            >>> validator_constrained.validate(fixed)  # 2 items, OK for constrained validator
+            {'answer': ['token1', 'token2'], 'comment': None, 'generated_tokens': None}
+
+            >>> # Fix when answer key is missing and generated_tokens is an empty string
+            >>> bad_response = {"generated_tokens": ""}
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': []}
+            >>> validator_permissive.validate(fixed)
+            {'answer': [], 'comment': None, 'generated_tokens': None}
+
+            >>> # Fix when answer key is missing and generated_tokens is a single item
+            >>> bad_response = {"generated_tokens": "single_token"}
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': ['single_token']}
+            >>> validator_permissive.validate(fixed)
+            {'answer': ['single_token'], 'comment': None, 'generated_tokens': None}
         """
         if verbose:
             print(f"Fixing list response: {response}")
         answer = str(response.get("answer") or response.get("generated_tokens", ""))
-        result = {"answer": answer.split(",")}
+        if "," in answer:
+            result = {"answer": answer.split(",")}
+        elif answer == "":
+            result = {"answer": []}
+        else:
+            result = {"answer": [answer]}
         if "comment" in response:
             result["comment"] = response["comment"]
         return result
@@ -395,7 +451,7 @@ class QuestionList(QuestionBase):
 
         self.include_comment = include_comment
         self.answering_instructions = answering_instructions
-        self.question_presentations = question_presentation
+        self.question_presentation = question_presentation
 
     def create_response_model(self):
         return create_model(self.min_list_items, self.max_list_items, self.permissive)
edsl/results/results.py CHANGED
@@ -771,6 +771,10 @@ class Results(MutableSequence, ResultsOperationsMixin, Base):
     def to_dataset(self) -> "Dataset":
         return self.select()
 
+    def optimzie_scenarios(self):
+        for result in self.data:
+            result.scenario.offload(inplace=True)
+
     def to_dict(
         self,
         sort: bool = False,
@@ -778,9 +782,12 @@ class Results(MutableSequence, ResultsOperationsMixin, Base):
         include_cache: bool = True,
         include_task_history: bool = False,
         include_cache_info: bool = True,
+        offload_scenarios: bool = True,
     ) -> dict[str, Any]:
         from ..caching import Cache
 
+        if offload_scenarios:
+            self.optimzie_scenarios()
         if sort:
             data = sorted([result for result in self.data], key=lambda x: hash(x))
         else:
@@ -809,7 +816,7 @@ class Results(MutableSequence, ResultsOperationsMixin, Base):
             )
 
         if self.task_history.has_unfixed_exceptions or include_task_history:
-            d.update({"task_history": self.task_history.to_dict()})
+            d.update({"task_history": self.task_history.to_dict(offload_content=True)})
 
         if add_edsl_version:
             from .. import __version__
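
Taken together, optimzie_scenarios (the transposed letters are in the released method name) and the new offload_scenarios flag mean Results.to_dict() now strips scenario base64 payloads by default. A usage sketch, assuming an existing Results instance named results (not executed here):

    d_small = results.to_dict()                        # offload_scenarios defaults to True
    d_full = results.to_dict(offload_scenarios=False)  # keeps base64 payloads intact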
@@ -446,9 +446,7 @@ class FileStore(Scenario):
         if suffix is None:
             suffix = self.suffix
         if self.binary:
-            file_like_object = self.base64_to_file(
-                self["base64_string"], is_binary=True
-            )
+            file_like_object = self.base64_to_file(self.base64_string, is_binary=True)
         else:
             file_like_object = self.base64_to_text_file(self.base64_string)
 
@@ -765,15 +763,13 @@ class FileStore(Scenario):
         if name.startswith("__") and name.endswith("__"):
             raise AttributeError(name)
 
-        # Only try to access suffix if it's in our __dict__
-        if hasattr(self, "_data") and "suffix" in self._data:
-            if self._data["suffix"] == "csv":
-                # Get the pandas DataFrame
-                df = self.to_pandas()
-                # Check if the requested attribute exists in the DataFrame
-                if hasattr(df, name):
-                    return getattr(df, name)
-            # If not a CSV or attribute doesn't exist in DataFrame, raise AttributeError
+        # Check for _data directly in __dict__ to avoid recursion
+        _data = self.__dict__.get("_data", None)
+        if _data and _data.get("suffix") == "csv":
+            df = self.to_pandas()
+            if hasattr(df, name):
+                return getattr(df, name)
+
         raise AttributeError(
             f"'{self.__class__.__name__}' object has no attribute '{name}'"
         )
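
The __getattr__ rewrite reads _data straight out of self.__dict__ because calling hasattr(self, "_data") from inside __getattr__ re-enters __getattr__ whenever _data is absent, recursing without bound. A standalone sketch of the safe pattern:

    class Safe:
        def __getattr__(self, name):
            # plain dict lookup: never re-enters __getattr__
            _data = self.__dict__.get("_data", None)
            if _data and name in _data:
                return _data[name]
            raise AttributeError(name)

    s = Safe()
    try:
        s.suffix
    except AttributeError:
        pass  # clean AttributeError instead of RecursionError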
@@ -264,9 +264,49 @@ class Scenario(Base, UserDict):
         """Display a scenario as a table."""
         return self.to_dataset().table(tablefmt=tablefmt)
 
-    def to_dict(self, add_edsl_version: bool = True) -> dict:
+    def offload(self, inplace=False) -> "Scenario":
+        """
+        Offloads base64-encoded content from the scenario by replacing 'base64_string'
+        fields with 'offloaded'. This reduces memory usage.
+
+        Args:
+            inplace (bool): If True, modify the current scenario. If False, return a new one.
+
+        Returns:
+            Scenario: The modified scenario (either self or a new instance).
+        """
+        from edsl.scenarios import FileStore
+        from edsl.prompts import Prompt
+
+        target = self if inplace else Scenario()
+
+        for key, value in self.items():
+            if isinstance(value, FileStore):
+                file_store_dict = value.to_dict()
+                if "base64_string" in file_store_dict:
+                    file_store_dict["base64_string"] = "offloaded"
+                modified_value = FileStore.from_dict(file_store_dict)
+            elif isinstance(value, dict) and "base64_string" in value:
+                value_copy = value.copy()
+                value_copy["base64_string"] = "offloaded"
+                modified_value = value_copy
+            else:
+                modified_value = value
+
+            target[key] = modified_value
+
+        return target
+
+    def to_dict(
+        self, add_edsl_version: bool = True, offload_base64: bool = False
+    ) -> dict:
         """Convert a scenario to a dictionary.
 
+        Args:
+            add_edsl_version: If True, adds the EDSL version to the returned dictionary.
+            offload_base64: If True, replaces any base64_string fields with 'offloaded'
+                to reduce memory usage.
+
         Example:
 
         >>> s = Scenario({"food": "wood chips"})
@@ -283,7 +323,15 @@ class Scenario(Base, UserDict):
         d = self.data.copy()
         for key, value in d.items():
             if isinstance(value, FileStore) or isinstance(value, Prompt):
-                d[key] = value.to_dict(add_edsl_version=add_edsl_version)
+                value_dict = value.to_dict(add_edsl_version=add_edsl_version)
+                if (
+                    offload_base64
+                    and isinstance(value_dict, dict)
+                    and "base64_string" in value_dict
+                ):
+                    value_dict["base64_string"] = "offloaded"
+                d[key] = value_dict
+
         if add_edsl_version:
             from edsl import __version__
 
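
offload() and to_dict(offload_base64=True) apply the same rule: any base64_string field is replaced with the literal string "offloaded". A standalone sketch with stand-in dicts (the real code round-trips FileStore values through to_dict/from_dict):

    scenario = {
        "question": "What is shown?",
        "image": {"base64_string": "iVBORw0KGgo...", "suffix": "png"},
    }
    for key, value in list(scenario.items()):
        if isinstance(value, dict) and "base64_string" in value:
            scenario[key] = dict(value, base64_string="offloaded")
    assert scenario["image"]["base64_string"] == "offloaded"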
@@ -145,22 +145,18 @@ class ScenarioList(MutableSequence, Base, ScenarioListOperationsMixin):
         """Initialize a new ScenarioList with optional data and codebook."""
         self._data_class = data_class
         self.data = self._data_class([])
-        warned = False
         for item in data or []:
-            try:
-                _ = json.dumps(item.to_dict())
-            except:
-                import warnings
-                if not warned:
-                    warnings.warn(
-                        f"One or more items in the data list are not JSON serializable. "
-                        "This would prevent running a job that uses this ScenarioList."
-                        "One solution is to use 'str(item)' to convert the item to a string before adding."
-                    )
-                    warned = True
             self.data.append(item)
         self.codebook = codebook or {}
 
+    def is_serializable(self):
+        for item in self.data:
+            try:
+                _ = json.dumps(item.to_dict())
+            except Exception as e:
+                return False
+        return True
+
     # Required MutableSequence abstract methods
     def __getitem__(self, index):
         """Get item at index."""
@@ -360,6 +356,32 @@ class ScenarioList(MutableSequence, Base, ScenarioListOperationsMixin):
             new_scenarios.append(Scenario(new_scenario))
 
         return new_scenarios
+
+    @classmethod
+    def from_search_terms(cls, search_terms: List[str]) -> ScenarioList:
+        """Create a ScenarioList from a list of search terms, using Wikipedia.
+
+        Args:
+            search_terms: A list of search terms.
+        """
+        from ..utilities.wikipedia import fetch_wikipedia_content
+        results = fetch_wikipedia_content(search_terms)
+        return cls([Scenario(result) for result in results])
+
+    def augment_with_wikipedia(self, search_key: str, content_only: bool = True, key_name: str = "wikipedia_content") -> ScenarioList:
+        """Augment the ScenarioList with Wikipedia content."""
+        search_terms = self.select(search_key).to_list()
+        wikipedia_results = ScenarioList.from_search_terms(search_terms)
+        new_sl = ScenarioList(data=[], codebook=self.codebook)
+        for scenario, wikipedia_result in zip(self, wikipedia_results):
+            if content_only:
+                scenario[key_name] = wikipedia_result["content"]
+                new_sl.append(scenario)
+            else:
+                scenario[key_name] = wikipedia_result
+                new_sl.append(scenario)
+        return new_sl
+
 
     def pivot(
         self,
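
A usage sketch for the new Wikipedia helpers, assuming network access and a hypothetical scenario field named "topic" (the "content" key is taken from the code above):

    from edsl import Scenario, ScenarioList

    sl = ScenarioList([Scenario({"topic": "Large language model"})])
    augmented = sl.augment_with_wikipedia(search_key="topic")
    # each scenario now also carries a "wikipedia_content" key with the article text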
edsl/surveys/survey.py CHANGED
@@ -384,6 +384,10 @@ class Survey(Base):
         if question_name not in self.question_name_to_index:
             raise SurveyError(f"Question name {question_name} not found in survey.")
         return self.questions[self.question_name_to_index[question_name]]
+
+    def get(self, question_name: str) -> QuestionBase:
+        """Return the question object given the question name."""
+        return self._get_question_by_name(question_name)
 
     def question_names_to_questions(self) -> dict:
         """Return a dictionary mapping question names to question attributes."""