edsl 0.1.58__py3-none-any.whl → 0.1.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. edsl/__version__.py +1 -1
  2. edsl/agents/agent.py +23 -4
  3. edsl/agents/agent_list.py +36 -6
  4. edsl/base/data_transfer_models.py +5 -0
  5. edsl/base/enums.py +7 -2
  6. edsl/coop/coop.py +103 -1
  7. edsl/dataset/dataset.py +74 -0
  8. edsl/dataset/dataset_operations_mixin.py +69 -64
  9. edsl/inference_services/services/__init__.py +3 -1
  10. edsl/inference_services/services/open_ai_service_v2.py +243 -0
  11. edsl/inference_services/services/test_service.py +1 -1
  12. edsl/interviews/exception_tracking.py +66 -20
  13. edsl/invigilators/invigilators.py +5 -1
  14. edsl/invigilators/prompt_constructor.py +299 -136
  15. edsl/jobs/data_structures.py +3 -0
  16. edsl/jobs/html_table_job_logger.py +18 -1
  17. edsl/jobs/jobs_pricing_estimation.py +6 -2
  18. edsl/jobs/jobs_remote_inference_logger.py +2 -0
  19. edsl/jobs/remote_inference.py +34 -7
  20. edsl/key_management/key_lookup_builder.py +25 -3
  21. edsl/language_models/language_model.py +41 -3
  22. edsl/language_models/raw_response_handler.py +126 -7
  23. edsl/prompts/prompt.py +1 -0
  24. edsl/questions/question_list.py +76 -20
  25. edsl/results/result.py +37 -0
  26. edsl/results/results.py +9 -1
  27. edsl/scenarios/file_store.py +8 -12
  28. edsl/scenarios/scenario.py +50 -2
  29. edsl/scenarios/scenario_list.py +34 -12
  30. edsl/surveys/survey.py +4 -0
  31. edsl/tasks/task_history.py +180 -6
  32. edsl/utilities/wikipedia.py +194 -0
  33. {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/METADATA +5 -4
  34. {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/RECORD +37 -35
  35. {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/LICENSE +0 -0
  36. {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/WHEEL +0 -0
  37. {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/entry_points.txt +0 -0
edsl/jobs/data_structures.py CHANGED
@@ -213,6 +213,9 @@ class Answers(UserDict):
         if comment:
             self[question.question_name + "_comment"] = comment
 
+        if getattr(response, "reasoning_summary", None):
+            self[question.question_name + "_reasoning_summary"] = response.reasoning_summary
+
     def replace_missing_answers_with_none(self, survey: "Survey") -> None:
         """
         Replace missing answers with None for all questions in the survey.
edsl/jobs/html_table_job_logger.py CHANGED
@@ -217,6 +217,17 @@ class HTMLTableJobLogger(JobLogger):
         )
         total_cost = total_input_cost + total_output_cost
 
+        # Calculate credit totals
+        total_input_credits = sum(
+            cost.input_cost_credits_with_cache or 0
+            for cost in self.jobs_info.model_costs
+        )
+        total_output_credits = sum(
+            cost.output_cost_credits_with_cache or 0
+            for cost in self.jobs_info.model_costs
+        )
+        total_credits = total_input_credits + total_output_credits
+
         # Generate cost rows HTML with class names for right alignment
         cost_rows = "".join(
             f"""
@@ -228,6 +239,7 @@ class HTMLTableJobLogger(JobLogger):
             <td class='token-count'>{cost.output_tokens:,}</td>
             <td class='cost-value'>${cost.output_cost_usd:.4f}</td>
             <td class='cost-value'>${(cost.input_cost_usd or 0) + (cost.output_cost_usd or 0):.4f}</td>
+            <td class='cost-value'>{(cost.input_cost_credits_with_cache or 0) + (cost.output_cost_credits_with_cache or 0):,.2f}</td>
         </tr>
         """
             for cost in self.jobs_info.model_costs
@@ -242,6 +254,7 @@ class HTMLTableJobLogger(JobLogger):
             <td class='token-count'>{total_output_tokens:,}</td>
             <td class='cost-value'>${total_output_cost:.4f}</td>
             <td class='cost-value'>${total_cost:.4f}</td>
+            <td class='cost-value'>{total_credits:,.2f}</td>
         </tr>
         """
 
@@ -249,7 +262,7 @@ class HTMLTableJobLogger(JobLogger):
         <div class="model-costs-section">
             <div class="model-costs-header" onclick="{self._collapse(f'model-costs-content-{self.log_id}', f'model-costs-arrow-{self.log_id}')}">
                 <span id="model-costs-arrow-{self.log_id}" class="expand-toggle">&#8963;</span>
-                <span>Model Costs (${total_cost:.4f} total)</span>
+                <span>Model Costs (${total_cost:.4f} / {total_credits:,.2f} credits total)</span>
                 <span style="flex-grow: 1;"></span>
             </div>
             <div id="model-costs-content-{self.log_id}" class="model-costs-content">
@@ -263,6 +276,7 @@ class HTMLTableJobLogger(JobLogger):
                         <th class="cost-header">Output Tokens</th>
                         <th class="cost-header">Output Cost</th>
                         <th class="cost-header">Total Cost</th>
+                        <th class="cost-header">Total Credits</th>
                     </tr>
                 </thead>
                 <tbody>
@@ -270,6 +284,9 @@ class HTMLTableJobLogger(JobLogger):
                     {total_row}
                 </tbody>
             </table>
+            <p style="font-style: italic; margin-top: 8px; font-size: 0.85em; color: #4b5563;">
+                You can obtain the total credit cost by multiplying the total USD cost by 100. A lower credit cost indicates that you saved money by retrieving responses from the universal remote cache.
+            </p>
         </div>
     </div>
     """
edsl/jobs/jobs_pricing_estimation.py CHANGED
@@ -88,7 +88,6 @@ class PromptCostEstimator:
 
 
 class JobsPrompts:
-
     relevant_keys = [
         "user_prompt",
         "system_prompt",
@@ -171,13 +170,18 @@ class JobsPrompts:
             cost = prompt_cost["cost_usd"]
 
             # Generate cache keys for each iteration
+            files_list = prompts.get("files_list", None)
+            if files_list:
+                files_hash = "+".join([str(hash(file)) for file in files_list])
+                user_prompt_with_hashes = user_prompt + f" {files_hash}"
             cache_keys = []
+
            for iteration in range(iterations):
                 cache_key = CacheEntry.gen_key(
                     model=model,
                     parameters=invigilator.model.parameters,
                     system_prompt=system_prompt,
-                    user_prompt=user_prompt,
+                    user_prompt=user_prompt_with_hashes if files_list else user_prompt,
                     iteration=iteration,
                 )
                 cache_keys.append(cache_key)
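
Note on the change above: when files are attached, their hashes are folded into the prompt text before the cache key is derived, so cost estimates line up with the keys the job will actually produce at run time. A minimal sketch of that idea, assuming a stand-alone helper (cache_key_prompt is illustrative and not part of edsl):

    def cache_key_prompt(user_prompt, files_list=None):
        # Append a hash of each attached file so otherwise identical prompts
        # with different attachments map to different cache keys.
        if not files_list:
            return user_prompt
        files_hash = "+".join(str(hash(f)) for f in files_list)
        return f"{user_prompt} {files_hash}"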
edsl/jobs/jobs_remote_inference_logger.py CHANGED
@@ -40,6 +40,8 @@ class ModelCost:
     input_cost_usd: float = None
     output_tokens: int = None
     output_cost_usd: float = None
+    input_cost_credits_with_cache: int = None
+    output_cost_credits_with_cache: int = None
 
 
 @dataclass
edsl/jobs/remote_inference.py CHANGED
@@ -279,9 +279,7 @@ class JobsRemoteInferenceHandler:
             )
             time.sleep(self.poll_interval)
 
-    def _get_expenses_from_results(
-        self, results: "Results", include_cached_responses_in_cost: bool = False
-    ) -> dict:
+    def _get_expenses_from_results(self, results: "Results") -> dict:
         """
         Calculates expenses from Results object.
 
@@ -309,10 +307,6 @@ class JobsRemoteInferenceHandler:
                 question_name = key.removesuffix("_cost")
                 cache_used = result["cache_used_dict"][question_name]
 
-                # Skip if we're excluding cached responses and this was cached
-                if not include_cached_responses_in_cost and cache_used:
-                    continue
-
                 # Get expense keys for input and output tokens
                 input_key = (
                     result["model"]._inference_service_,
@@ -332,6 +326,7 @@ class JobsRemoteInferenceHandler:
                     expenses[input_key] = {
                         "tokens": 0,
                         "cost_usd": 0,
+                        "cost_usd_with_cache": 0,
                     }
 
                 input_price_per_million_tokens = input_key[3]
@@ -341,11 +336,15 @@ class JobsRemoteInferenceHandler:
                 expenses[input_key]["tokens"] += input_tokens
                 expenses[input_key]["cost_usd"] += input_cost
 
+                if not cache_used:
+                    expenses[input_key]["cost_usd_with_cache"] += input_cost
+
                 # Update output token expenses
                 if output_key not in expenses:
                     expenses[output_key] = {
                         "tokens": 0,
                         "cost_usd": 0,
+                        "cost_usd_with_cache": 0,
                     }
 
                 output_price_per_million_tokens = output_key[3]
@@ -357,6 +356,9 @@ class JobsRemoteInferenceHandler:
                 expenses[output_key]["tokens"] += output_tokens
                 expenses[output_key]["cost_usd"] += output_cost
 
+                if not cache_used:
+                    expenses[output_key]["cost_usd_with_cache"] += output_cost
+
         expenses_by_model = {}
         for expense_key, expense_usage in expenses.items():
             service, model, token_type, _ = expense_key
@@ -368,8 +370,10 @@ class JobsRemoteInferenceHandler:
                     "model": model,
                     "input_tokens": 0,
                     "input_cost_usd": 0,
+                    "input_cost_usd_with_cache": 0,
                     "output_tokens": 0,
                     "output_cost_usd": 0,
+                    "output_cost_usd_with_cache": 0,
                 }
 
             if token_type == "input":
@@ -377,14 +381,22 @@ class JobsRemoteInferenceHandler:
                 expenses_by_model[model_key]["input_cost_usd"] += expense_usage[
                     "cost_usd"
                 ]
+                expenses_by_model[model_key][
+                    "input_cost_usd_with_cache"
+                ] += expense_usage["cost_usd_with_cache"]
             elif token_type == "output":
                 expenses_by_model[model_key]["output_tokens"] += expense_usage["tokens"]
                 expenses_by_model[model_key]["output_cost_usd"] += expense_usage[
                     "cost_usd"
                 ]
+                expenses_by_model[model_key][
+                    "output_cost_usd_with_cache"
+                ] += expense_usage["cost_usd_with_cache"]
 
         converter = CostConverter()
         for model_key, model_cost_dict in expenses_by_model.items():
+
+            # Handle full cost (without cache)
             input_cost = model_cost_dict["input_cost_usd"]
             output_cost = model_cost_dict["output_cost_usd"]
             model_cost_dict["input_cost_credits"] = converter.usd_to_credits(input_cost)
@@ -399,6 +411,15 @@ class JobsRemoteInferenceHandler:
                 model_cost_dict["output_cost_credits"]
             )
 
+            # Handle cost with cache
+            input_cost_with_cache = model_cost_dict["input_cost_usd_with_cache"]
+            output_cost_with_cache = model_cost_dict["output_cost_usd_with_cache"]
+            model_cost_dict["input_cost_credits_with_cache"] = converter.usd_to_credits(
+                input_cost_with_cache
+            )
+            model_cost_dict["output_cost_credits_with_cache"] = (
+                converter.usd_to_credits(output_cost_with_cache)
+            )
         return list(expenses_by_model.values())
 
     def _fetch_results_and_log(
@@ -423,6 +444,12 @@ class JobsRemoteInferenceHandler:
                 input_cost_usd=model_cost_dict.get("input_cost_usd"),
                 output_tokens=model_cost_dict.get("output_tokens"),
                 output_cost_usd=model_cost_dict.get("output_cost_usd"),
+                input_cost_credits_with_cache=model_cost_dict.get(
+                    "input_cost_credits_with_cache"
+                ),
+                output_cost_credits_with_cache=model_cost_dict.get(
+                    "output_cost_credits_with_cache"
+                ),
             )
             for model_cost_dict in model_cost_dicts
         ]
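
Note on the change above: alongside the full totals, the handler now accumulates a parallel "with cache" figure that only grows when a response was not served from the cache, and converts both to credits. A rough sketch of the accounting, assuming the 1 credit = $0.01 relationship stated in the job-logger note above (usd_to_credits here is a stand-in for edsl's CostConverter, whose exact rounding may differ):

    def usd_to_credits(cost_usd):
        # Assumed conversion per the job-logger note: credits = USD * 100.
        return round(cost_usd * 100, 2)

    responses = [
        {"cost_usd": 0.004, "cache_used": False},
        {"cost_usd": 0.004, "cache_used": True},  # served from the remote cache
    ]
    cost_usd = sum(r["cost_usd"] for r in responses)
    cost_usd_with_cache = sum(r["cost_usd"] for r in responses if not r["cache_used"])

    print(usd_to_credits(cost_usd))             # 0.8 credits if nothing were cached
    print(usd_to_credits(cost_usd_with_cache))  # 0.4 credits actually charged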
edsl/key_management/key_lookup_builder.py CHANGED
@@ -363,13 +363,35 @@ class KeyLookupBuilder:
         >>> builder._add_api_key("OPENAI_API_KEY", "sk-1234", "env")
         >>> 'sk-1234' == builder.key_data["openai"][-1].value
         True
+        >>> 'sk-1234' == builder.key_data["openai_v2"][-1].value
+        True
         """
         service = api_keyname_to_service[key]
         new_entry = APIKeyEntry(service=service, name=key, value=value, source=source)
-        if service not in self.key_data:
-            self.key_data[service] = [new_entry]
+
+        # Special case for OPENAI_API_KEY - add to both openai and openai_v2
+        if key == "OPENAI_API_KEY":
+            # Add to openai service
+            openai_service = "openai"
+            openai_entry = APIKeyEntry(service=openai_service, name=key, value=value, source=source)
+            if openai_service not in self.key_data:
+                self.key_data[openai_service] = [openai_entry]
+            else:
+                self.key_data[openai_service].append(openai_entry)
+
+            # Add to openai_v2 service
+            openai_v2_service = "openai_v2"
+            openai_v2_entry = APIKeyEntry(service=openai_v2_service, name=key, value=value, source=source)
+            if openai_v2_service not in self.key_data:
+                self.key_data[openai_v2_service] = [openai_v2_entry]
+            else:
+                self.key_data[openai_v2_service].append(openai_v2_entry)
         else:
-            self.key_data[service].append(new_entry)
+            # Normal case for all other API keys
+            if service not in self.key_data:
+                self.key_data[service] = [new_entry]
+            else:
+                self.key_data[service].append(new_entry)
 
     def update_from_dict(self, d: dict) -> None:
         """
edsl/language_models/language_model.py CHANGED
@@ -174,7 +174,8 @@ class LanguageModel(
         """
         key_sequence = cls.key_sequence
         usage_sequence = cls.usage_sequence if hasattr(cls, "usage_sequence") else None
-        return RawResponseHandler(key_sequence, usage_sequence)
+        reasoning_sequence = cls.reasoning_sequence if hasattr(cls, "reasoning_sequence") else None
+        return RawResponseHandler(key_sequence, usage_sequence, reasoning_sequence)
 
     def __init__(
         self,
@@ -769,8 +770,45 @@ class LanguageModel(
             params["question_name"] = invigilator.question.question_name
         # Get timeout from configuration
         from ..config import CONFIG
-
-        TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
+        import logging
+
+        logger = logging.getLogger(__name__)
+        base_timeout = float(CONFIG.get("EDSL_API_TIMEOUT"))
+
+        # Adjust timeout if files are present
+        import time
+
+        start = time.time()
+        if files_list:
+            # Calculate total size of attached files in MB
+            file_sizes = []
+            for file in files_list:
+                # Try different attributes that might contain the file content
+                if hasattr(file, "base64_string") and file.base64_string:
+                    file_sizes.append(len(file.base64_string) / (1024 * 1024))
+                elif hasattr(file, "content") and file.content:
+                    file_sizes.append(len(file.content) / (1024 * 1024))
+                elif hasattr(file, "data") and file.data:
+                    file_sizes.append(len(file.data) / (1024 * 1024))
+                else:
+                    # Default minimum size if we can't determine actual size
+                    file_sizes.append(1)  # Assume at least 1MB
+            total_size_mb = sum(file_sizes)
+
+            # Increase timeout proportionally to file size
+            # For each MB of file size, add 10 seconds to the timeout (adjust as needed)
+            size_adjustment = total_size_mb * 10
+
+            # Cap the maximum timeout adjustment at 5 minutes (300 seconds)
+            size_adjustment = min(size_adjustment, 300)
+
+            TIMEOUT = base_timeout + size_adjustment
+
+            logger.info(
+                f"Adjusted timeout for API call with {len(files_list)} files (total size: {total_size_mb:.2f}MB). Base timeout: {base_timeout}s, New timeout: {TIMEOUT}s"
+            )
+        else:
+            TIMEOUT = base_timeout
 
         # Execute the model call with timeout
         response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
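
Note on the change above: the API-call timeout now scales with the size of any attached files, ten extra seconds per megabyte, with the adjustment capped at five minutes. A compact restatement of that arithmetic (adjusted_timeout is an illustrative helper, not an edsl function):

    def adjusted_timeout(base_timeout, total_size_mb):
        # +10 s per MB of attachments, with the adjustment capped at 300 s.
        return base_timeout + min(total_size_mb * 10, 300)

    assert adjusted_timeout(60, 0) == 60        # no attachments: base EDSL_API_TIMEOUT
    assert adjusted_timeout(60, 12.5) == 185    # 12.5 MB of attachments
    assert adjusted_timeout(60, 100) == 360     # adjustment capped at 5 minutes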
edsl/language_models/raw_response_handler.py CHANGED
@@ -1,5 +1,5 @@
 import json
-from typing import Optional, Any
+from typing import Optional, Any, List
 from .exceptions import (
     LanguageModelBadResponseError,
     LanguageModelTypeError,
@@ -41,10 +41,13 @@ def _extract_item_from_raw_response(data, sequence):
             current_data = current_data[key]
         except Exception as e:
             path = " -> ".join(map(str, sequence[: i + 1]))
-            if "error" in data:
-                msg = data["error"]
+
+            # Create a safe error message that won't be None
+            if "error" in data and data["error"] is not None:
+                msg = str(data["error"])
             else:
                 msg = f"Error accessing path: {path}. {str(e)}. Full response is: '{data}'"
+
             raise LanguageModelBadResponseError(message=msg, response_json=data)
     if isinstance(current_data, str):
         return current_data.strip()
@@ -55,17 +58,127 @@
 class RawResponseHandler:
     """Class to handle raw responses from language models."""
 
-    def __init__(self, key_sequence: list, usage_sequence: Optional[list] = None):
+    def __init__(self, key_sequence: list, usage_sequence: Optional[list] = None, reasoning_sequence: Optional[list] = None):
         self.key_sequence = key_sequence
         self.usage_sequence = usage_sequence
+        self.reasoning_sequence = reasoning_sequence
 
     def get_generated_token_string(self, raw_response):
-        return _extract_item_from_raw_response(raw_response, self.key_sequence)
+        try:
+            return _extract_item_from_raw_response(raw_response, self.key_sequence)
+        except (LanguageModelKeyError, LanguageModelIndexError, LanguageModelTypeError, LanguageModelBadResponseError) as e:
+            # For non-reasoning models or reasoning models with different response formats,
+            # try to extract text directly from common response formats
+            if isinstance(raw_response, dict):
+                # Responses API format for non-reasoning models
+                if 'output' in raw_response and isinstance(raw_response['output'], list):
+                    # Try to get first message content
+                    if len(raw_response['output']) > 0:
+                        item = raw_response['output'][0]
+                        if isinstance(item, dict) and 'content' in item:
+                            if isinstance(item['content'], list) and len(item['content']) > 0:
+                                first_content = item['content'][0]
+                                if isinstance(first_content, dict) and 'text' in first_content:
+                                    return first_content['text']
+                            elif isinstance(item['content'], str):
+                                return item['content']
+
+                # OpenAI completions format
+                if 'choices' in raw_response and isinstance(raw_response['choices'], list) and len(raw_response['choices']) > 0:
+                    choice = raw_response['choices'][0]
+                    if isinstance(choice, dict):
+                        if 'text' in choice:
+                            return choice['text']
+                        elif 'message' in choice and isinstance(choice['message'], dict) and 'content' in choice['message']:
+                            return choice['message']['content']
+
+                # Text directly in response
+                if 'text' in raw_response:
+                    return raw_response['text']
+                elif 'content' in raw_response:
+                    return raw_response['content']
+
+                # Error message - try to return a coherent error for debugging
+                if 'message' in raw_response:
+                    return f"[ERROR: {raw_response['message']}]"
+
+            # If we get a string directly, return it
+            if isinstance(raw_response, str):
+                return raw_response
+
+            # As a last resort, convert the whole response to string
+            try:
+                return f"[ERROR: Could not extract text. Raw response: {str(raw_response)}]"
+            except:
+                return "[ERROR: Could not extract text from response]"
 
     def get_usage_dict(self, raw_response):
         if self.usage_sequence is None:
             return {}
-        return _extract_item_from_raw_response(raw_response, self.usage_sequence)
+        try:
+            return _extract_item_from_raw_response(raw_response, self.usage_sequence)
+        except (LanguageModelKeyError, LanguageModelIndexError, LanguageModelTypeError, LanguageModelBadResponseError):
+            # For non-reasoning models, try to extract usage from common response formats
+            if isinstance(raw_response, dict):
+                # Standard OpenAI usage format
+                if 'usage' in raw_response:
+                    return raw_response['usage']
+
+                # Look for nested usage info
+                if 'choices' in raw_response and len(raw_response['choices']) > 0:
+                    choice = raw_response['choices'][0]
+                    if isinstance(choice, dict) and 'usage' in choice:
+                        return choice['usage']
+
+            # If no usage info found, return empty dict
+            return {}
+
+    def get_reasoning_summary(self, raw_response):
+        """
+        Extract reasoning summary from the model response.
+
+        Handles various response structures:
+        1. Standard path extraction using self.reasoning_sequence
+        2. Direct access to output[0]['summary'] for OpenAI responses
+        3. List responses where the first item contains the output structure
+        """
+        if self.reasoning_sequence is None:
+            return None
+
+        try:
+            # First try the standard extraction path
+            summary_data = _extract_item_from_raw_response(raw_response, self.reasoning_sequence)
+
+            # If summary_data is a list of dictionaries with 'text' and 'type' fields
+            # (as in OpenAI's response format), combine them into a single string
+            if isinstance(summary_data, list) and all(isinstance(item, dict) and 'text' in item for item in summary_data):
+                return '\n\n'.join(item['text'] for item in summary_data)
+
+            return summary_data
+        except Exception:
+            # Fallback approaches for different response structures
+            try:
+                # Case 1: Direct dict with 'output' field (common OpenAI format)
+                if isinstance(raw_response, dict) and 'output' in raw_response:
+                    output = raw_response['output']
+                    if isinstance(output, list) and len(output) > 0 and 'summary' in output[0]:
+                        summary_data = output[0]['summary']
+                        if isinstance(summary_data, list) and all(isinstance(item, dict) and 'text' in item for item in summary_data):
+                            return '\n\n'.join(item['text'] for item in summary_data)
+
+                # Case 2: List where the first item is a dict with 'output' field
+                if isinstance(raw_response, list) and len(raw_response) > 0:
+                    first_item = raw_response[0]
+                    if isinstance(first_item, dict) and 'output' in first_item:
+                        output = first_item['output']
+                        if isinstance(output, list) and len(output) > 0 and 'summary' in output[0]:
+                            summary_data = output[0]['summary']
+                            if isinstance(summary_data, list) and all(isinstance(item, dict) and 'text' in item for item in summary_data):
+                                return '\n\n'.join(item['text'] for item in summary_data)
            except Exception:
+                pass
+
+            return None
 
     def parse_response(self, raw_response: dict[str, Any]) -> Any:
         """Parses the API response and returns the response text."""
@@ -73,7 +186,11 @@ class RawResponseHandler:
         from edsl.data_transfer_models import EDSLOutput
 
         generated_token_string = self.get_generated_token_string(raw_response)
+        # Ensure generated_token_string is a string before using string methods
+        if not isinstance(generated_token_string, str):
+            generated_token_string = str(generated_token_string)
         last_newline = generated_token_string.rfind("\n")
+        reasoning_summary = self.get_reasoning_summary(raw_response)
 
         if last_newline == -1:
             # There is no comment
@@ -81,12 +198,14 @@ class RawResponseHandler:
                 "answer": self.convert_answer(generated_token_string),
                 "generated_tokens": generated_token_string,
                 "comment": None,
+                "reasoning_summary": reasoning_summary,
             }
         else:
             edsl_dict = {
                 "answer": self.convert_answer(generated_token_string[:last_newline]),
-                "comment": generated_token_string[last_newline + 1 :].strip(),
+                "comment": generated_token_string[last_newline + 1:].strip(),
                 "generated_tokens": generated_token_string,
+                "reasoning_summary": reasoning_summary,
             }
         return EDSLOutput(**edsl_dict)
 
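
Note on the change above: for reasoning models, get_reasoning_summary() flattens the summary the provider returns, a list of dicts carrying a 'text' field (as in OpenAI's Responses API), into a single string that parse_response() attaches as reasoning_summary and that Answers then stores under "<question_name>_reasoning_summary". A small illustration of that flattening (the payload below is invented for the example; only the 'text' key is what the handler relies on):

    summary_data = [
        {"type": "summary_text", "text": "Compared the two options on price."},
        {"type": "summary_text", "text": "Chose the cheaper one."},
    ]
    # Same join the handler performs when every item is a dict with a 'text' key.
    reasoning_summary = "\n\n".join(item["text"] for item in summary_data)
    print(reasoning_summary)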
edsl/prompts/prompt.py CHANGED
@@ -290,6 +290,7 @@ class Prompt(PersistenceMixin, RepresentationMixin):
             return result
         except Exception as e:
             print(f"Error rendering prompt: {e}")
+            raise e
         return self
 
     @staticmethod
edsl/questions/question_list.py CHANGED
@@ -299,23 +299,24 @@ class ListResponseValidator(ResponseValidatorABC):
         # This method can now be removed since validation is handled in the Pydantic model
         pass
 
-    def fix(self, response, verbose=False):
+    def fix(self, response, verbose=False) -> dict[str, Any]:
         """
         Fix common issues in list responses by splitting strings into lists.
 
         Examples:
             >>> from edsl import QuestionList
-            >>> q = QuestionList.example(min_list_items=2, max_list_items=4)
-            >>> validator = q.response_validator
+            >>> q_constrained = QuestionList.example(min_list_items=2, max_list_items=4)
+            >>> validator_constrained = q_constrained.response_validator
 
+            >>> q_permissive = QuestionList.example(permissive=True)
+            >>> validator_permissive = q_permissive.response_validator
+
             >>> # Fix a string that should be a list
             >>> bad_response = {"answer": "apple,banana,cherry"}
-            >>> try:
-            ...     validator.validate(bad_response)
-            ... except Exception:
-            ...     fixed = validator.fix(bad_response)
-            ...     validated = validator.validate(fixed)
-            ...     validated  # Show full response
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': ['apple', 'banana', 'cherry']}
+            >>> validator_constrained.validate(fixed)  # Show full response after validation
             {'answer': ['apple', 'banana', 'cherry'], 'comment': None, 'generated_tokens': None}
 
             >>> # Fix using generated_tokens when answer is invalid
@@ -323,12 +324,10 @@ class ListResponseValidator(ResponseValidatorABC):
             ...     "answer": None,
             ...     "generated_tokens": "pizza, pasta, salad"
             ... }
-            >>> try:
-            ...     validator.validate(bad_response)
-            ... except Exception:
-            ...     fixed = validator.fix(bad_response)
-            ...     validated = validator.validate(fixed)
-            ...     validated
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': ['pizza', ' pasta', ' salad']}
+            >>> validator_constrained.validate(fixed)
             {'answer': ['pizza', ' pasta', ' salad'], 'comment': None, 'generated_tokens': None}
 
             >>> # Preserve comments during fixing
@@ -336,17 +335,74 @@ class ListResponseValidator(ResponseValidatorABC):
             ...     "answer": "red,blue,green",
             ...     "comment": "These are colors"
             ... }
-            >>> fixed = validator.fix(bad_response)
-            >>> fixed == {
+            >>> fixed_output = validator_constrained.fix(bad_response)
+            >>> fixed_output
+            {'answer': ['red', 'blue', 'green'], 'comment': 'These are colors'}
+            >>> validated_output = validator_constrained.validate(fixed_output)
+            >>> validated_output == {
             ...     "answer": ["red", "blue", "green"],
-            ...     "comment": "These are colors"
+            ...     "comment": "These are colors",
+            ...     "generated_tokens": None
             ... }
             True
+
+            >>> # Fix an empty string answer
+            >>> bad_response = {"answer": ""}
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': []}
+            >>> validator_permissive.validate(fixed)
+            {'answer': [], 'comment': None, 'generated_tokens': None}
+
+            >>> # Fix a single item string answer (no commas)
+            >>> bad_response = {"answer": "single_item"}
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': ['single_item']}
+            >>> validator_permissive.validate(fixed)
+            {'answer': ['single_item'], 'comment': None, 'generated_tokens': None}
+
+            >>> # Fix when answer is None and no generated_tokens
+            >>> bad_response = {"answer": None}
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': []}
+            >>> validator_permissive.validate(fixed)
+            {'answer': [], 'comment': None, 'generated_tokens': None}
+
+            >>> # Fix when answer key is missing but generated_tokens is present
+            >>> bad_response = {"generated_tokens": "token1,token2"}
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': ['token1', 'token2']}
+            >>> validator_constrained.validate(fixed)  # 2 items, OK for constrained validator
+            {'answer': ['token1', 'token2'], 'comment': None, 'generated_tokens': None}
+
+            >>> # Fix when answer key is missing and generated_tokens is an empty string
+            >>> bad_response = {"generated_tokens": ""}
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': []}
+            >>> validator_permissive.validate(fixed)
+            {'answer': [], 'comment': None, 'generated_tokens': None}
+
+            >>> # Fix when answer key is missing and generated_tokens is a single item
+            >>> bad_response = {"generated_tokens": "single_token"}
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': ['single_token']}
+            >>> validator_permissive.validate(fixed)
+            {'answer': ['single_token'], 'comment': None, 'generated_tokens': None}
         """
         if verbose:
             print(f"Fixing list response: {response}")
         answer = str(response.get("answer") or response.get("generated_tokens", ""))
-        result = {"answer": answer.split(",")}
+        if "," in answer:
+            result = {"answer": answer.split(",")}
+        elif answer == "":
+            result = {"answer": []}
+        else:
+            result = {"answer": [answer]}
         if "comment" in response:
             result["comment"] = response["comment"]
         return result
@@ -395,7 +451,7 @@ class QuestionList(QuestionBase):
 
         self.include_comment = include_comment
         self.answering_instructions = answering_instructions
-        self.question_presentations = question_presentation
+        self.question_presentation = question_presentation
 
     def create_response_model(self):
         return create_model(self.min_list_items, self.max_list_items, self.permissive)