edsl 0.1.58__py3-none-any.whl → 0.1.60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/agent.py +23 -4
- edsl/agents/agent_list.py +36 -6
- edsl/base/data_transfer_models.py +5 -0
- edsl/base/enums.py +7 -2
- edsl/coop/coop.py +103 -1
- edsl/dataset/dataset.py +74 -0
- edsl/dataset/dataset_operations_mixin.py +69 -64
- edsl/inference_services/services/__init__.py +3 -1
- edsl/inference_services/services/open_ai_service_v2.py +243 -0
- edsl/inference_services/services/test_service.py +1 -1
- edsl/interviews/exception_tracking.py +66 -20
- edsl/invigilators/invigilators.py +5 -1
- edsl/invigilators/prompt_constructor.py +299 -136
- edsl/jobs/data_structures.py +3 -0
- edsl/jobs/html_table_job_logger.py +18 -1
- edsl/jobs/jobs_pricing_estimation.py +6 -2
- edsl/jobs/jobs_remote_inference_logger.py +2 -0
- edsl/jobs/remote_inference.py +34 -7
- edsl/key_management/key_lookup_builder.py +25 -3
- edsl/language_models/language_model.py +41 -3
- edsl/language_models/raw_response_handler.py +126 -7
- edsl/prompts/prompt.py +1 -0
- edsl/questions/question_list.py +76 -20
- edsl/results/result.py +37 -0
- edsl/results/results.py +9 -1
- edsl/scenarios/file_store.py +8 -12
- edsl/scenarios/scenario.py +50 -2
- edsl/scenarios/scenario_list.py +34 -12
- edsl/surveys/survey.py +4 -0
- edsl/tasks/task_history.py +180 -6
- edsl/utilities/wikipedia.py +194 -0
- {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/METADATA +5 -4
- {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/RECORD +37 -35
- {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/LICENSE +0 -0
- {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/WHEEL +0 -0
- {edsl-0.1.58.dist-info → edsl-0.1.60.dist-info}/entry_points.txt +0 -0
edsl/jobs/data_structures.py
CHANGED
@@ -213,6 +213,9 @@ class Answers(UserDict):
         if comment:
             self[question.question_name + "_comment"] = comment
 
+        if getattr(response, "reasoning_summary", None):
+            self[question.question_name + "_reasoning_summary"] = response.reasoning_summary
+
     def replace_missing_answers_with_none(self, survey: "Survey") -> None:
         """
         Replace missing answers with None for all questions in the survey.
edsl/jobs/html_table_job_logger.py
CHANGED
@@ -217,6 +217,17 @@ class HTMLTableJobLogger(JobLogger):
         )
         total_cost = total_input_cost + total_output_cost
 
+        # Calculate credit totals
+        total_input_credits = sum(
+            cost.input_cost_credits_with_cache or 0
+            for cost in self.jobs_info.model_costs
+        )
+        total_output_credits = sum(
+            cost.output_cost_credits_with_cache or 0
+            for cost in self.jobs_info.model_costs
+        )
+        total_credits = total_input_credits + total_output_credits
+
         # Generate cost rows HTML with class names for right alignment
         cost_rows = "".join(
             f"""
@@ -228,6 +239,7 @@ class HTMLTableJobLogger(JobLogger):
             <td class='token-count'>{cost.output_tokens:,}</td>
             <td class='cost-value'>${cost.output_cost_usd:.4f}</td>
             <td class='cost-value'>${(cost.input_cost_usd or 0) + (cost.output_cost_usd or 0):.4f}</td>
+            <td class='cost-value'>{(cost.input_cost_credits_with_cache or 0) + (cost.output_cost_credits_with_cache or 0):,.2f}</td>
             </tr>
             """
             for cost in self.jobs_info.model_costs
@@ -242,6 +254,7 @@ class HTMLTableJobLogger(JobLogger):
             <td class='token-count'>{total_output_tokens:,}</td>
             <td class='cost-value'>${total_output_cost:.4f}</td>
             <td class='cost-value'>${total_cost:.4f}</td>
+            <td class='cost-value'>{total_credits:,.2f}</td>
             </tr>
         """
 
@@ -249,7 +262,7 @@ class HTMLTableJobLogger(JobLogger):
             <div class="model-costs-section">
                 <div class="model-costs-header" onclick="{self._collapse(f'model-costs-content-{self.log_id}', f'model-costs-arrow-{self.log_id}')}">
                     <span id="model-costs-arrow-{self.log_id}" class="expand-toggle">⌃</span>
-                    <span>Model Costs (${total_cost:.4f} total)</span>
+                    <span>Model Costs (${total_cost:.4f} / {total_credits:,.2f} credits total)</span>
                     <span style="flex-grow: 1;"></span>
                 </div>
                 <div id="model-costs-content-{self.log_id}" class="model-costs-content">
@@ -263,6 +276,7 @@ class HTMLTableJobLogger(JobLogger):
                             <th class="cost-header">Output Tokens</th>
                             <th class="cost-header">Output Cost</th>
                             <th class="cost-header">Total Cost</th>
+                            <th class="cost-header">Total Credits</th>
                         </tr>
                     </thead>
                     <tbody>
@@ -270,6 +284,9 @@ class HTMLTableJobLogger(JobLogger):
                         {total_row}
                     </tbody>
                 </table>
+                <p style="font-style: italic; margin-top: 8px; font-size: 0.85em; color: #4b5563;">
+                    You can obtain the total credit cost by multiplying the total USD cost by 100. A lower credit cost indicates that you saved money by retrieving responses from the universal remote cache.
+                </p>
             </div>
         </div>
         """
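The footnote added above fixes the conversion at 100 credits per USD, and the new credit columns use the "with cache" cost fields, so they drop whenever responses come back from the universal remote cache instead of the provider. A minimal sketch of that arithmetic, with made-up rows standing in for edsl's model-cost entries:

from dataclasses import dataclass

# Sketch only: stand-ins for the cost rows rendered by HTMLTableJobLogger.
@dataclass
class CostRow:
    input_cost_usd: float
    output_cost_usd: float
    input_cost_credits_with_cache: float   # credits actually billed (cache hits excluded)
    output_cost_credits_with_cache: float

def usd_to_credits(usd: float) -> float:
    return usd * 100  # conversion stated in the footnote above

rows = [CostRow(0.0120, 0.0300, 0.60, 1.50), CostRow(0.0050, 0.0100, 0.50, 1.00)]
total_usd = sum(r.input_cost_usd + r.output_cost_usd for r in rows)
total_credits = sum(r.input_cost_credits_with_cache + r.output_cost_credits_with_cache for r in rows)

print(f"Model Costs (${total_usd:.4f} / {total_credits:,.2f} credits total)")
# total_credits below usd_to_credits(total_usd) means part of the job was served from cache.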
edsl/jobs/jobs_pricing_estimation.py
CHANGED
@@ -88,7 +88,6 @@ class PromptCostEstimator:
 
 
 class JobsPrompts:
-
     relevant_keys = [
         "user_prompt",
         "system_prompt",
@@ -171,13 +170,18 @@ class JobsPrompts:
             cost = prompt_cost["cost_usd"]
 
             # Generate cache keys for each iteration
+            files_list = prompts.get("files_list", None)
+            if files_list:
+                files_hash = "+".join([str(hash(file)) for file in files_list])
+                user_prompt_with_hashes = user_prompt + f" {files_hash}"
            cache_keys = []
+
             for iteration in range(iterations):
                 cache_key = CacheEntry.gen_key(
                     model=model,
                     parameters=invigilator.model.parameters,
                     system_prompt=system_prompt,
-                    user_prompt=user_prompt,
+                    user_prompt=user_prompt_with_hashes if files_list else user_prompt,
                     iteration=iteration,
                 )
                 cache_keys.append(cache_key)
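The change above makes the estimated cache keys sensitive to attachments: when a prompt carries files, a hash of each file is appended to the user prompt before the key is generated, so identical text with different attachments maps to different cache entries. A rough sketch of the idea, using hashlib in place of edsl's own hashing and CacheEntry.gen_key:

import hashlib

# Sketch only: illustrates why file hashes are folded into the user prompt
# before a cache key is computed; hashlib stands in for edsl's hashing.
def gen_key(model: str, system_prompt: str, user_prompt: str, iteration: int) -> str:
    blob = f"{model}|{system_prompt}|{user_prompt}|{iteration}".encode()
    return hashlib.sha256(blob).hexdigest()

def key_for(model, system_prompt, user_prompt, files_list, iteration=0):
    if files_list:
        files_hash = "+".join(hashlib.sha256(f).hexdigest() for f in files_list)
        user_prompt = user_prompt + f" {files_hash}"
    return gen_key(model, system_prompt, user_prompt, iteration)

k1 = key_for("gpt-4o", "You are helpful.", "Describe the attached image.", [b"image-bytes-A"])
k2 = key_for("gpt-4o", "You are helpful.", "Describe the attached image.", [b"image-bytes-B"])
assert k1 != k2  # same text, different attachments -> different cache entries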
edsl/jobs/remote_inference.py
CHANGED
@@ -279,9 +279,7 @@ class JobsRemoteInferenceHandler:
             )
             time.sleep(self.poll_interval)
 
-    def _get_expenses_from_results(
-        self, results: "Results", include_cached_responses_in_cost: bool = False
-    ) -> dict:
+    def _get_expenses_from_results(self, results: "Results") -> dict:
         """
         Calculates expenses from Results object.
 
@@ -309,10 +307,6 @@ class JobsRemoteInferenceHandler:
                 question_name = key.removesuffix("_cost")
                 cache_used = result["cache_used_dict"][question_name]
 
-                # Skip if we're excluding cached responses and this was cached
-                if not include_cached_responses_in_cost and cache_used:
-                    continue
-
                 # Get expense keys for input and output tokens
                 input_key = (
                     result["model"]._inference_service_,
@@ -332,6 +326,7 @@ class JobsRemoteInferenceHandler:
                     expenses[input_key] = {
                         "tokens": 0,
                         "cost_usd": 0,
+                        "cost_usd_with_cache": 0,
                     }
 
                 input_price_per_million_tokens = input_key[3]
@@ -341,11 +336,15 @@ class JobsRemoteInferenceHandler:
                 expenses[input_key]["tokens"] += input_tokens
                 expenses[input_key]["cost_usd"] += input_cost
 
+                if not cache_used:
+                    expenses[input_key]["cost_usd_with_cache"] += input_cost
+
                 # Update output token expenses
                 if output_key not in expenses:
                     expenses[output_key] = {
                         "tokens": 0,
                         "cost_usd": 0,
+                        "cost_usd_with_cache": 0,
                     }
 
                 output_price_per_million_tokens = output_key[3]
@@ -357,6 +356,9 @@ class JobsRemoteInferenceHandler:
                 expenses[output_key]["tokens"] += output_tokens
                 expenses[output_key]["cost_usd"] += output_cost
 
+                if not cache_used:
+                    expenses[output_key]["cost_usd_with_cache"] += output_cost
+
         expenses_by_model = {}
         for expense_key, expense_usage in expenses.items():
             service, model, token_type, _ = expense_key
@@ -368,8 +370,10 @@ class JobsRemoteInferenceHandler:
                     "model": model,
                     "input_tokens": 0,
                    "input_cost_usd": 0,
+                    "input_cost_usd_with_cache": 0,
                     "output_tokens": 0,
                     "output_cost_usd": 0,
+                    "output_cost_usd_with_cache": 0,
                 }
 
             if token_type == "input":
@@ -377,14 +381,22 @@ class JobsRemoteInferenceHandler:
                 expenses_by_model[model_key]["input_cost_usd"] += expense_usage[
                     "cost_usd"
                 ]
+                expenses_by_model[model_key][
+                    "input_cost_usd_with_cache"
+                ] += expense_usage["cost_usd_with_cache"]
             elif token_type == "output":
                 expenses_by_model[model_key]["output_tokens"] += expense_usage["tokens"]
                 expenses_by_model[model_key]["output_cost_usd"] += expense_usage[
                     "cost_usd"
                 ]
+                expenses_by_model[model_key][
+                    "output_cost_usd_with_cache"
+                ] += expense_usage["cost_usd_with_cache"]
 
         converter = CostConverter()
         for model_key, model_cost_dict in expenses_by_model.items():
+
+            # Handle full cost (without cache)
            input_cost = model_cost_dict["input_cost_usd"]
            output_cost = model_cost_dict["output_cost_usd"]
            model_cost_dict["input_cost_credits"] = converter.usd_to_credits(input_cost)
@@ -399,6 +411,15 @@ class JobsRemoteInferenceHandler:
                model_cost_dict["output_cost_credits"]
            )
 
+            # Handle cost with cache
+            input_cost_with_cache = model_cost_dict["input_cost_usd_with_cache"]
+            output_cost_with_cache = model_cost_dict["output_cost_usd_with_cache"]
+            model_cost_dict["input_cost_credits_with_cache"] = converter.usd_to_credits(
+                input_cost_with_cache
+            )
+            model_cost_dict["output_cost_credits_with_cache"] = (
+                converter.usd_to_credits(output_cost_with_cache)
+            )
         return list(expenses_by_model.values())
 
     def _fetch_results_and_log(
@@ -423,6 +444,12 @@ class JobsRemoteInferenceHandler:
                     input_cost_usd=model_cost_dict.get("input_cost_usd"),
                     output_tokens=model_cost_dict.get("output_tokens"),
                     output_cost_usd=model_cost_dict.get("output_cost_usd"),
+                    input_cost_credits_with_cache=model_cost_dict.get(
+                        "input_cost_credits_with_cache"
+                    ),
+                    output_cost_credits_with_cache=model_cost_dict.get(
+                        "output_cost_credits_with_cache"
+                    ),
                 )
                 for model_cost_dict in model_cost_dicts
             ]
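The removed include_cached_responses_in_cost flag is replaced by a second running total: cost_usd still counts every response, while cost_usd_with_cache only accumulates responses that were not served from the cache, i.e. what was actually billed. A small illustration with invented per-question costs:

# Sketch only: mirrors the dual accumulation above with invented numbers.
questions = [
    {"input_cost": 0.002, "output_cost": 0.004, "cache_used": False},
    {"input_cost": 0.002, "output_cost": 0.004, "cache_used": True},
]

totals = {"cost_usd": 0.0, "cost_usd_with_cache": 0.0}
for q in questions:
    cost = q["input_cost"] + q["output_cost"]
    totals["cost_usd"] += cost                 # full price, cache hits included
    if not q["cache_used"]:
        totals["cost_usd_with_cache"] += cost  # only what was actually charged

print(totals)  # roughly {'cost_usd': 0.012, 'cost_usd_with_cache': 0.006}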
edsl/key_management/key_lookup_builder.py
CHANGED
@@ -363,13 +363,35 @@ class KeyLookupBuilder:
         >>> builder._add_api_key("OPENAI_API_KEY", "sk-1234", "env")
         >>> 'sk-1234' == builder.key_data["openai"][-1].value
         True
+        >>> 'sk-1234' == builder.key_data["openai_v2"][-1].value
+        True
         """
         service = api_keyname_to_service[key]
         new_entry = APIKeyEntry(service=service, name=key, value=value, source=source)
-
-
+
+        # Special case for OPENAI_API_KEY - add to both openai and openai_v2
+        if key == "OPENAI_API_KEY":
+            # Add to openai service
+            openai_service = "openai"
+            openai_entry = APIKeyEntry(service=openai_service, name=key, value=value, source=source)
+            if openai_service not in self.key_data:
+                self.key_data[openai_service] = [openai_entry]
+            else:
+                self.key_data[openai_service].append(openai_entry)
+
+            # Add to openai_v2 service
+            openai_v2_service = "openai_v2"
+            openai_v2_entry = APIKeyEntry(service=openai_v2_service, name=key, value=value, source=source)
+            if openai_v2_service not in self.key_data:
+                self.key_data[openai_v2_service] = [openai_v2_entry]
+            else:
+                self.key_data[openai_v2_service].append(openai_v2_entry)
         else:
-
+            # Normal case for all other API keys
+            if service not in self.key_data:
+                self.key_data[service] = [new_entry]
+            else:
+                self.key_data[service].append(new_entry)
 
     def update_from_dict(self, d: dict) -> None:
         """
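The effect of the special case is that a single OPENAI_API_KEY entry now feeds both the openai and the new openai_v2 service, so the Responses-API-based service works without extra configuration. A toy version of the bookkeeping, with a plain dict and namedtuple standing in for KeyLookupBuilder.key_data and APIKeyEntry:

from collections import namedtuple

# Sketch only: toy stand-ins, not edsl's KeyLookupBuilder.
APIKeyEntry = namedtuple("APIKeyEntry", "service name value source")
key_data = {}

def add_api_key(key, value, source, service):
    # OPENAI_API_KEY is fanned out to both the classic and the v2 service.
    services = ["openai", "openai_v2"] if key == "OPENAI_API_KEY" else [service]
    for svc in services:
        key_data.setdefault(svc, []).append(APIKeyEntry(svc, key, value, source))

add_api_key("OPENAI_API_KEY", "sk-1234", "env", service="openai")
assert key_data["openai"][-1].value == key_data["openai_v2"][-1].value == "sk-1234"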
edsl/language_models/language_model.py
CHANGED
@@ -174,7 +174,8 @@ class LanguageModel(
         """
         key_sequence = cls.key_sequence
         usage_sequence = cls.usage_sequence if hasattr(cls, "usage_sequence") else None
-
+        reasoning_sequence = cls.reasoning_sequence if hasattr(cls, "reasoning_sequence") else None
+        return RawResponseHandler(key_sequence, usage_sequence, reasoning_sequence)
 
     def __init__(
         self,
@@ -769,8 +770,45 @@ class LanguageModel(
         params["question_name"] = invigilator.question.question_name
         # Get timeout from configuration
         from ..config import CONFIG
-
-
+        import logging
+
+        logger = logging.getLogger(__name__)
+        base_timeout = float(CONFIG.get("EDSL_API_TIMEOUT"))
+
+        # Adjust timeout if files are present
+        import time
+
+        start = time.time()
+        if files_list:
+            # Calculate total size of attached files in MB
+            file_sizes = []
+            for file in files_list:
+                # Try different attributes that might contain the file content
+                if hasattr(file, "base64_string") and file.base64_string:
+                    file_sizes.append(len(file.base64_string) / (1024 * 1024))
+                elif hasattr(file, "content") and file.content:
+                    file_sizes.append(len(file.content) / (1024 * 1024))
+                elif hasattr(file, "data") and file.data:
+                    file_sizes.append(len(file.data) / (1024 * 1024))
+                else:
+                    # Default minimum size if we can't determine actual size
+                    file_sizes.append(1)  # Assume at least 1MB
+            total_size_mb = sum(file_sizes)
+
+            # Increase timeout proportionally to file size
+            # For each MB of file size, add 10 seconds to the timeout (adjust as needed)
+            size_adjustment = total_size_mb * 10
+
+            # Cap the maximum timeout adjustment at 5 minutes (300 seconds)
+            size_adjustment = min(size_adjustment, 300)
+
+            TIMEOUT = base_timeout + size_adjustment
+
+            logger.info(
+                f"Adjusted timeout for API call with {len(files_list)} files (total size: {total_size_mb:.2f}MB). Base timeout: {base_timeout}s, New timeout: {TIMEOUT}s"
+            )
+        else:
+            TIMEOUT = base_timeout
 
         # Execute the model call with timeout
         response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
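The timeout adjustment above is linear in the size of the base64 payload with a hard cap: 10 extra seconds per MB of attachments, but never more than 300 extra seconds on top of EDSL_API_TIMEOUT. A worked example of the rule; the 60-second base is an assumed value, not the edsl default:

# Sketch only: the size-based timeout rule, with an assumed base timeout.
def adjusted_timeout(base_timeout: float, file_sizes_mb: list) -> float:
    if not file_sizes_mb:
        return base_timeout
    size_adjustment = min(sum(file_sizes_mb) * 10, 300)  # +10 s per MB, capped at 300 s
    return base_timeout + size_adjustment

print(adjusted_timeout(60.0, []))          # 60.0  (no attachments)
print(adjusted_timeout(60.0, [2.5, 1.5]))  # 100.0 (4 MB -> +40 s)
print(adjusted_timeout(60.0, [50.0]))      # 360.0 (cap reached: at most +300 s)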
edsl/language_models/raw_response_handler.py
CHANGED
@@ -1,5 +1,5 @@
 import json
-from typing import Optional, Any
+from typing import Optional, Any, List
 from .exceptions import (
     LanguageModelBadResponseError,
     LanguageModelTypeError,
@@ -41,10 +41,13 @@ def _extract_item_from_raw_response(data, sequence):
            current_data = current_data[key]
        except Exception as e:
            path = " -> ".join(map(str, sequence[: i + 1]))
-
-
+
+            # Create a safe error message that won't be None
+            if "error" in data and data["error"] is not None:
+                msg = str(data["error"])
            else:
                msg = f"Error accessing path: {path}. {str(e)}. Full response is: '{data}'"
+
            raise LanguageModelBadResponseError(message=msg, response_json=data)
    if isinstance(current_data, str):
        return current_data.strip()
@@ -55,17 +58,127 @@ def _extract_item_from_raw_response(data, sequence):
 class RawResponseHandler:
     """Class to handle raw responses from language models."""
 
-    def __init__(self, key_sequence: list, usage_sequence: Optional[list] = None):
+    def __init__(self, key_sequence: list, usage_sequence: Optional[list] = None, reasoning_sequence: Optional[list] = None):
         self.key_sequence = key_sequence
         self.usage_sequence = usage_sequence
+        self.reasoning_sequence = reasoning_sequence
 
     def get_generated_token_string(self, raw_response):
-
+        try:
+            return _extract_item_from_raw_response(raw_response, self.key_sequence)
+        except (LanguageModelKeyError, LanguageModelIndexError, LanguageModelTypeError, LanguageModelBadResponseError) as e:
+            # For non-reasoning models or reasoning models with different response formats,
+            # try to extract text directly from common response formats
+            if isinstance(raw_response, dict):
+                # Responses API format for non-reasoning models
+                if 'output' in raw_response and isinstance(raw_response['output'], list):
+                    # Try to get first message content
+                    if len(raw_response['output']) > 0:
+                        item = raw_response['output'][0]
+                        if isinstance(item, dict) and 'content' in item:
+                            if isinstance(item['content'], list) and len(item['content']) > 0:
+                                first_content = item['content'][0]
+                                if isinstance(first_content, dict) and 'text' in first_content:
+                                    return first_content['text']
+                            elif isinstance(item['content'], str):
+                                return item['content']
+
+                # OpenAI completions format
+                if 'choices' in raw_response and isinstance(raw_response['choices'], list) and len(raw_response['choices']) > 0:
+                    choice = raw_response['choices'][0]
+                    if isinstance(choice, dict):
+                        if 'text' in choice:
+                            return choice['text']
+                        elif 'message' in choice and isinstance(choice['message'], dict) and 'content' in choice['message']:
+                            return choice['message']['content']
+
+                # Text directly in response
+                if 'text' in raw_response:
+                    return raw_response['text']
+                elif 'content' in raw_response:
+                    return raw_response['content']
+
+                # Error message - try to return a coherent error for debugging
+                if 'message' in raw_response:
+                    return f"[ERROR: {raw_response['message']}]"
+
+            # If we get a string directly, return it
+            if isinstance(raw_response, str):
+                return raw_response
+
+            # As a last resort, convert the whole response to string
+            try:
+                return f"[ERROR: Could not extract text. Raw response: {str(raw_response)}]"
+            except:
+                return "[ERROR: Could not extract text from response]"
 
     def get_usage_dict(self, raw_response):
         if self.usage_sequence is None:
             return {}
-
+        try:
+            return _extract_item_from_raw_response(raw_response, self.usage_sequence)
+        except (LanguageModelKeyError, LanguageModelIndexError, LanguageModelTypeError, LanguageModelBadResponseError):
+            # For non-reasoning models, try to extract usage from common response formats
+            if isinstance(raw_response, dict):
+                # Standard OpenAI usage format
+                if 'usage' in raw_response:
+                    return raw_response['usage']
+
+                # Look for nested usage info
+                if 'choices' in raw_response and len(raw_response['choices']) > 0:
+                    choice = raw_response['choices'][0]
+                    if isinstance(choice, dict) and 'usage' in choice:
+                        return choice['usage']
+
+            # If no usage info found, return empty dict
+            return {}
+
+    def get_reasoning_summary(self, raw_response):
+        """
+        Extract reasoning summary from the model response.
+
+        Handles various response structures:
+        1. Standard path extraction using self.reasoning_sequence
+        2. Direct access to output[0]['summary'] for OpenAI responses
+        3. List responses where the first item contains the output structure
+        """
+        if self.reasoning_sequence is None:
+            return None
+
+        try:
+            # First try the standard extraction path
+            summary_data = _extract_item_from_raw_response(raw_response, self.reasoning_sequence)
+
+            # If summary_data is a list of dictionaries with 'text' and 'type' fields
+            # (as in OpenAI's response format), combine them into a single string
+            if isinstance(summary_data, list) and all(isinstance(item, dict) and 'text' in item for item in summary_data):
+                return '\n\n'.join(item['text'] for item in summary_data)
+
+            return summary_data
+        except Exception:
+            # Fallback approaches for different response structures
+            try:
+                # Case 1: Direct dict with 'output' field (common OpenAI format)
+                if isinstance(raw_response, dict) and 'output' in raw_response:
+                    output = raw_response['output']
+                    if isinstance(output, list) and len(output) > 0 and 'summary' in output[0]:
+                        summary_data = output[0]['summary']
+                        if isinstance(summary_data, list) and all(isinstance(item, dict) and 'text' in item for item in summary_data):
+                            return '\n\n'.join(item['text'] for item in summary_data)

+                # Case 2: List where the first item is a dict with 'output' field
+                if isinstance(raw_response, list) and len(raw_response) > 0:
+                    first_item = raw_response[0]
+                    if isinstance(first_item, dict) and 'output' in first_item:
+                        output = first_item['output']
+                        if isinstance(output, list) and len(output) > 0 and 'summary' in output[0]:
+                            summary_data = output[0]['summary']
+                            if isinstance(summary_data, list) and all(isinstance(item, dict) and 'text' in item for item in summary_data):
+                                return '\n\n'.join(item['text'] for item in summary_data)
+            except Exception:
+                pass
+
+        return None
 
     def parse_response(self, raw_response: dict[str, Any]) -> Any:
         """Parses the API response and returns the response text."""
@@ -73,7 +186,11 @@ class RawResponseHandler:
         from edsl.data_transfer_models import EDSLOutput
 
         generated_token_string = self.get_generated_token_string(raw_response)
+        # Ensure generated_token_string is a string before using string methods
+        if not isinstance(generated_token_string, str):
+            generated_token_string = str(generated_token_string)
         last_newline = generated_token_string.rfind("\n")
+        reasoning_summary = self.get_reasoning_summary(raw_response)
 
         if last_newline == -1:
             # There is no comment
@@ -81,12 +198,14 @@ class RawResponseHandler:
                 "answer": self.convert_answer(generated_token_string),
                 "generated_tokens": generated_token_string,
                 "comment": None,
+                "reasoning_summary": reasoning_summary,
             }
         else:
             edsl_dict = {
                 "answer": self.convert_answer(generated_token_string[:last_newline]),
-                "comment": generated_token_string[last_newline + 1
+                "comment": generated_token_string[last_newline + 1:].strip(),
                 "generated_tokens": generated_token_string,
+                "reasoning_summary": reasoning_summary,
             }
         return EDSLOutput(**edsl_dict)
 
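For OpenAI's Responses API the reasoning summary arrives as a list of {'text': ..., 'type': ...} fragments, which the handler joins with blank lines; the generated text itself is still split into answer and comment at the last newline. A compact illustration with synthetic data rather than a real API payload:

# Sketch only: synthetic data showing the summary join and the answer/comment split.
summary_data = [
    {"type": "summary_text", "text": "Considered both options."},
    {"type": "summary_text", "text": "Chose the cheaper one."},
]
reasoning_summary = "\n\n".join(item["text"] for item in summary_data)

generated = "42\nBecause the question asked for a single number."
last_newline = generated.rfind("\n")
edsl_dict = {
    "answer": generated[:last_newline],                # "42"
    "comment": generated[last_newline + 1:].strip(),   # text after the last newline
    "reasoning_summary": reasoning_summary,
}
print(edsl_dict["answer"], "|", edsl_dict["comment"])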
edsl/prompts/prompt.py
CHANGED
edsl/questions/question_list.py
CHANGED
@@ -299,23 +299,24 @@ class ListResponseValidator(ResponseValidatorABC):
         # This method can now be removed since validation is handled in the Pydantic model
         pass
 
-    def fix(self, response, verbose=False):
+    def fix(self, response, verbose=False) -> dict[str, Any]:
         """
         Fix common issues in list responses by splitting strings into lists.
 
         Examples:
             >>> from edsl import QuestionList
-            >>>
-            >>>
+            >>> q_constrained = QuestionList.example(min_list_items=2, max_list_items=4)
+            >>> validator_constrained = q_constrained.response_validator
 
+            >>> q_permissive = QuestionList.example(permissive=True)
+            >>> validator_permissive = q_permissive.response_validator
+
             >>> # Fix a string that should be a list
             >>> bad_response = {"answer": "apple,banana,cherry"}
-            >>>
-
-
-
-            ... validated = validator.validate(fixed)
-            ... validated # Show full response
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': ['apple', 'banana', 'cherry']}
+            >>> validator_constrained.validate(fixed) # Show full response after validation
             {'answer': ['apple', 'banana', 'cherry'], 'comment': None, 'generated_tokens': None}
 
             >>> # Fix using generated_tokens when answer is invalid
@@ -323,12 +324,10 @@ class ListResponseValidator(ResponseValidatorABC):
             ... "answer": None,
             ... "generated_tokens": "pizza, pasta, salad"
             ... }
-            >>>
-
-
-
-            ... validated = validator.validate(fixed)
-            ... validated
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': ['pizza', ' pasta', ' salad']}
+            >>> validator_constrained.validate(fixed)
             {'answer': ['pizza', ' pasta', ' salad'], 'comment': None, 'generated_tokens': None}
 
             >>> # Preserve comments during fixing
@@ -336,17 +335,74 @@ class ListResponseValidator(ResponseValidatorABC):
             ... "answer": "red,blue,green",
             ... "comment": "These are colors"
             ... }
-            >>>
-            >>>
+            >>> fixed_output = validator_constrained.fix(bad_response)
+            >>> fixed_output
+            {'answer': ['red', 'blue', 'green'], 'comment': 'These are colors'}
+            >>> validated_output = validator_constrained.validate(fixed_output)
+            >>> validated_output == {
             ... "answer": ["red", "blue", "green"],
-            ... "comment": "These are colors"
+            ... "comment": "These are colors",
+            ... "generated_tokens": None
             ... }
             True
+
+            >>> # Fix an empty string answer
+            >>> bad_response = {"answer": ""}
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': []}
+            >>> validator_permissive.validate(fixed)
+            {'answer': [], 'comment': None, 'generated_tokens': None}
+
+            >>> # Fix a single item string answer (no commas)
+            >>> bad_response = {"answer": "single_item"}
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': ['single_item']}
+            >>> validator_permissive.validate(fixed)
+            {'answer': ['single_item'], 'comment': None, 'generated_tokens': None}
+
+            >>> # Fix when answer is None and no generated_tokens
+            >>> bad_response = {"answer": None}
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': []}
+            >>> validator_permissive.validate(fixed)
+            {'answer': [], 'comment': None, 'generated_tokens': None}
+
+            >>> # Fix when answer key is missing but generated_tokens is present
+            >>> bad_response = {"generated_tokens": "token1,token2"}
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': ['token1', 'token2']}
+            >>> validator_constrained.validate(fixed) # 2 items, OK for constrained validator
+            {'answer': ['token1', 'token2'], 'comment': None, 'generated_tokens': None}
+
+            >>> # Fix when answer key is missing and generated_tokens is an empty string
+            >>> bad_response = {"generated_tokens": ""}
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': []}
+            >>> validator_permissive.validate(fixed)
+            {'answer': [], 'comment': None, 'generated_tokens': None}
+
+            >>> # Fix when answer key is missing and generated_tokens is a single item
+            >>> bad_response = {"generated_tokens": "single_token"}
+            >>> fixed = validator_constrained.fix(bad_response)
+            >>> fixed
+            {'answer': ['single_token']}
+            >>> validator_permissive.validate(fixed)
+            {'answer': ['single_token'], 'comment': None, 'generated_tokens': None}
         """
         if verbose:
             print(f"Fixing list response: {response}")
         answer = str(response.get("answer") or response.get("generated_tokens", ""))
-
+        if "," in answer:
+            result = {"answer": answer.split(",")}
+        elif answer == "":
+            result = {"answer": []}
+        else:
+            result = {"answer": [answer]}
         if "comment" in response:
             result["comment"] = response["comment"]
         return result
@@ -395,7 +451,7 @@ class QuestionList(QuestionBase):
 
         self.include_comment = include_comment
         self.answering_instructions = answering_instructions
-        self.
+        self.question_presentation = question_presentation
 
     def create_response_model(self):
         return create_model(self.min_list_items, self.max_list_items, self.permissive)