edsl 0.1.59__py3-none-any.whl → 0.1.61__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/agent.py +65 -17
- edsl/agents/agent_list.py +117 -33
- edsl/base/base_class.py +80 -11
- edsl/base/data_transfer_models.py +5 -0
- edsl/base/enums.py +7 -2
- edsl/config/config_class.py +7 -2
- edsl/coop/coop.py +1295 -85
- edsl/coop/coop_prolific_filters.py +171 -0
- edsl/dataset/dataset_operations_mixin.py +2 -2
- edsl/dataset/display/table_display.py +40 -7
- edsl/db_list/sqlite_list.py +102 -3
- edsl/inference_services/services/__init__.py +3 -1
- edsl/inference_services/services/open_ai_service_v2.py +243 -0
- edsl/jobs/data_structures.py +48 -30
- edsl/jobs/jobs.py +73 -2
- edsl/jobs/remote_inference.py +49 -15
- edsl/key_management/key_lookup_builder.py +25 -3
- edsl/language_models/language_model.py +2 -1
- edsl/language_models/raw_response_handler.py +126 -7
- edsl/questions/loop_processor.py +289 -10
- edsl/questions/templates/dict/answering_instructions.jinja +0 -1
- edsl/results/result.py +37 -0
- edsl/results/results.py +1 -0
- edsl/scenarios/scenario_list.py +31 -1
- edsl/scenarios/scenario_source.py +606 -498
- edsl/surveys/survey.py +198 -163
- {edsl-0.1.59.dist-info → edsl-0.1.61.dist-info}/METADATA +4 -4
- {edsl-0.1.59.dist-info → edsl-0.1.61.dist-info}/RECORD +32 -30
- {edsl-0.1.59.dist-info → edsl-0.1.61.dist-info}/LICENSE +0 -0
- {edsl-0.1.59.dist-info → edsl-0.1.61.dist-info}/WHEEL +0 -0
- {edsl-0.1.59.dist-info → edsl-0.1.61.dist-info}/entry_points.txt +0 -0
edsl/language_models/raw_response_handler.py
CHANGED
@@ -1,5 +1,5 @@
 import json
-from typing import Optional, Any
+from typing import Optional, Any, List
 from .exceptions import (
     LanguageModelBadResponseError,
     LanguageModelTypeError,
@@ -41,10 +41,13 @@ def _extract_item_from_raw_response(data, sequence):
             current_data = current_data[key]
         except Exception as e:
             path = " -> ".join(map(str, sequence[: i + 1]))
-
-
+
+            # Create a safe error message that won't be None
+            if "error" in data and data["error"] is not None:
+                msg = str(data["error"])
             else:
                 msg = f"Error accessing path: {path}. {str(e)}. Full response is: '{data}'"
+
             raise LanguageModelBadResponseError(message=msg, response_json=data)
     if isinstance(current_data, str):
         return current_data.strip()
@@ -55,17 +58,127 @@ def _extract_item_from_raw_response(data, sequence):
 class RawResponseHandler:
     """Class to handle raw responses from language models."""
 
-    def __init__(self, key_sequence: list, usage_sequence: Optional[list] = None):
+    def __init__(self, key_sequence: list, usage_sequence: Optional[list] = None, reasoning_sequence: Optional[list] = None):
         self.key_sequence = key_sequence
         self.usage_sequence = usage_sequence
+        self.reasoning_sequence = reasoning_sequence
 
     def get_generated_token_string(self, raw_response):
-
+        try:
+            return _extract_item_from_raw_response(raw_response, self.key_sequence)
+        except (LanguageModelKeyError, LanguageModelIndexError, LanguageModelTypeError, LanguageModelBadResponseError) as e:
+            # For non-reasoning models or reasoning models with different response formats,
+            # try to extract text directly from common response formats
+            if isinstance(raw_response, dict):
+                # Responses API format for non-reasoning models
+                if 'output' in raw_response and isinstance(raw_response['output'], list):
+                    # Try to get first message content
+                    if len(raw_response['output']) > 0:
+                        item = raw_response['output'][0]
+                        if isinstance(item, dict) and 'content' in item:
+                            if isinstance(item['content'], list) and len(item['content']) > 0:
+                                first_content = item['content'][0]
+                                if isinstance(first_content, dict) and 'text' in first_content:
+                                    return first_content['text']
+                            elif isinstance(item['content'], str):
+                                return item['content']
+
+                # OpenAI completions format
+                if 'choices' in raw_response and isinstance(raw_response['choices'], list) and len(raw_response['choices']) > 0:
+                    choice = raw_response['choices'][0]
+                    if isinstance(choice, dict):
+                        if 'text' in choice:
+                            return choice['text']
+                        elif 'message' in choice and isinstance(choice['message'], dict) and 'content' in choice['message']:
+                            return choice['message']['content']
+
+                # Text directly in response
+                if 'text' in raw_response:
+                    return raw_response['text']
+                elif 'content' in raw_response:
+                    return raw_response['content']
+
+                # Error message - try to return a coherent error for debugging
+                if 'message' in raw_response:
+                    return f"[ERROR: {raw_response['message']}]"
+
+            # If we get a string directly, return it
+            if isinstance(raw_response, str):
+                return raw_response
+
+            # As a last resort, convert the whole response to string
+            try:
+                return f"[ERROR: Could not extract text. Raw response: {str(raw_response)}]"
+            except:
+                return "[ERROR: Could not extract text from response]"
 
     def get_usage_dict(self, raw_response):
         if self.usage_sequence is None:
             return {}
-
+        try:
+            return _extract_item_from_raw_response(raw_response, self.usage_sequence)
+        except (LanguageModelKeyError, LanguageModelIndexError, LanguageModelTypeError, LanguageModelBadResponseError):
+            # For non-reasoning models, try to extract usage from common response formats
+            if isinstance(raw_response, dict):
+                # Standard OpenAI usage format
+                if 'usage' in raw_response:
+                    return raw_response['usage']
+
+                # Look for nested usage info
+                if 'choices' in raw_response and len(raw_response['choices']) > 0:
+                    choice = raw_response['choices'][0]
+                    if isinstance(choice, dict) and 'usage' in choice:
+                        return choice['usage']
+
+            # If no usage info found, return empty dict
+            return {}
+
+    def get_reasoning_summary(self, raw_response):
+        """
+        Extract reasoning summary from the model response.
+
+        Handles various response structures:
+        1. Standard path extraction using self.reasoning_sequence
+        2. Direct access to output[0]['summary'] for OpenAI responses
+        3. List responses where the first item contains the output structure
+        """
+        if self.reasoning_sequence is None:
+            return None
+
+        try:
+            # First try the standard extraction path
+            summary_data = _extract_item_from_raw_response(raw_response, self.reasoning_sequence)
+
+            # If summary_data is a list of dictionaries with 'text' and 'type' fields
+            # (as in OpenAI's response format), combine them into a single string
+            if isinstance(summary_data, list) and all(isinstance(item, dict) and 'text' in item for item in summary_data):
+                return '\n\n'.join(item['text'] for item in summary_data)
+
+            return summary_data
+        except Exception:
+            # Fallback approaches for different response structures
+            try:
+                # Case 1: Direct dict with 'output' field (common OpenAI format)
+                if isinstance(raw_response, dict) and 'output' in raw_response:
+                    output = raw_response['output']
+                    if isinstance(output, list) and len(output) > 0 and 'summary' in output[0]:
+                        summary_data = output[0]['summary']
+                        if isinstance(summary_data, list) and all(isinstance(item, dict) and 'text' in item for item in summary_data):
+                            return '\n\n'.join(item['text'] for item in summary_data)
+
+                # Case 2: List where the first item is a dict with 'output' field
+                if isinstance(raw_response, list) and len(raw_response) > 0:
+                    first_item = raw_response[0]
+                    if isinstance(first_item, dict) and 'output' in first_item:
+                        output = first_item['output']
+                        if isinstance(output, list) and len(output) > 0 and 'summary' in output[0]:
+                            summary_data = output[0]['summary']
+                            if isinstance(summary_data, list) and all(isinstance(item, dict) and 'text' in item for item in summary_data):
+                                return '\n\n'.join(item['text'] for item in summary_data)
            except Exception:
+                pass
+
+            return None
 
     def parse_response(self, raw_response: dict[str, Any]) -> Any:
         """Parses the API response and returns the response text."""
@@ -73,7 +186,11 @@ class RawResponseHandler:
         from edsl.data_transfer_models import EDSLOutput
 
         generated_token_string = self.get_generated_token_string(raw_response)
+        # Ensure generated_token_string is a string before using string methods
+        if not isinstance(generated_token_string, str):
+            generated_token_string = str(generated_token_string)
         last_newline = generated_token_string.rfind("\n")
+        reasoning_summary = self.get_reasoning_summary(raw_response)
 
         if last_newline == -1:
             # There is no comment
@@ -81,12 +198,14 @@ class RawResponseHandler:
                 "answer": self.convert_answer(generated_token_string),
                 "generated_tokens": generated_token_string,
                 "comment": None,
+                "reasoning_summary": reasoning_summary,
             }
         else:
             edsl_dict = {
                 "answer": self.convert_answer(generated_token_string[:last_newline]),
-                "comment": generated_token_string[last_newline + 1
+                "comment": generated_token_string[last_newline + 1:].strip(),
                 "generated_tokens": generated_token_string,
+                "reasoning_summary": reasoning_summary,
             }
         return EDSLOutput(**edsl_dict)
 
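To see the new reasoning-summary plumbing end to end, here is a minimal sketch. The extraction sequences below are illustrative placeholders chosen to match the fallback formats handled above; edsl's real service classes define their own key_sequence/usage_sequence/reasoning_sequence values.

from edsl.language_models.raw_response_handler import RawResponseHandler

# Hypothetical extraction paths, not the exact ones any edsl service uses.
handler = RawResponseHandler(
    key_sequence=["choices", 0, "message", "content"],
    usage_sequence=["usage"],
    reasoning_sequence=["output", 0, "summary"],
)

mock_response = {
    "choices": [{"message": {"content": "Paris\nIt is the capital of France."}}],
    "usage": {"prompt_tokens": 12, "completion_tokens": 9},
    "output": [{"summary": [{"type": "summary_text", "text": "Recalled basic geography."}]}],
}

handler.get_usage_dict(mock_response)         # {'prompt_tokens': 12, 'completion_tokens': 9}
handler.get_reasoning_summary(mock_response)  # 'Recalled basic geography.'
# parse_response(...) now also attaches the summary as EDSLOutput.reasoning_summary
# (data_transfer_models.py gains the corresponding field in this release).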
edsl/questions/loop_processor.py
CHANGED
@@ -1,7 +1,9 @@
-from typing import List, Any, Dict
+from typing import List, Any, Dict, Tuple
 from jinja2 import Environment, Undefined
 from .question_base import QuestionBase
-from ..scenarios import ScenarioList
+from ..scenarios import Scenario, ScenarioList
+from ..surveys import Survey
+
 
 class LoopProcessor:
     def __init__(self, question: QuestionBase):
@@ -88,7 +90,10 @@ class LoopProcessor:
         return value
 
         from .exceptions import QuestionValueError
-
+
+        raise QuestionValueError(
+            f"Unexpected value type: {type(value)} for key '{key}'"
+        )
 
     def _render_template(self, template: str, scenario: Dict[str, Any]) -> str:
         """Render a single template string.
@@ -124,21 +129,21 @@ class LoopProcessor:
         '{{ item.missing }}'
         """
         import re
-
+
         # Regular expression to find Jinja2 variables in the template
-        pattern = r
-
+        pattern = r"(?P<open>\{\{\s*)(?P<var>[a-zA-Z0-9_.]+)(?P<close>\s*\}\})"
+
         def replace_var(match):
-            var_name = match.group(
+            var_name = match.group("var")
             # We're keeping the original formatting with braces
             # but not using these variables directly
             # open_brace = match.group('open')
             # close_brace = match.group('close')
-
+
             # Try to evaluate the variable in the context
             try:
                 # Handle nested attributes (like item.price)
-                parts = var_name.split(
+                parts = var_name.split(".")
                 value = scenario
                 for part in parts:
                     if part in value:
@@ -151,7 +156,7 @@ class LoopProcessor:
             except (KeyError, TypeError):
                 # Return the original variable name with the expected spacing
                 return f"{{ {var_name} }}".replace("{", "{{").replace("}", "}}")
-
+
         # Replace all variables in the template
         result = re.sub(pattern, replace_var, template)
         return result
@@ -191,6 +196,280 @@
         }
 
 
+class LongSurveyLoopProcessor:
+    """
+    A modified LoopProcessor that creates a long survey where each question is rendered for each scenario.
+
+    Returns a tuple of (long_questions, long_scenario_list).
+    The long scenario list is essentially a flattened scenario list with one scenario that has many fields.
+
+    Usage:
+    >>> from edsl.questions import QuestionMultipleChoice
+    >>> from edsl.surveys import Survey
+    >>> from edsl.scenarios import Scenario, ScenarioList
+    >>> q = QuestionMultipleChoice(question_name = "enjoy", question_text = "How much do you enjoy {{ scenario.activity }}?", question_options = ["Not at all", "Somewhat", "Very much"])
+    >>> scenarios = ScenarioList([Scenario({"activity": activity}) for activity in ["tennis", "racecar driving", "cycling"]])
+    >>> survey = Survey([q])
+    >>> loop_processor = LongSurveyLoopProcessor(survey, scenarios)
+    >>> long_questions_list, long_scenario_list = loop_processor.process_templates_for_all_questions()
+    """
+
+    def __init__(self, survey: Survey, scenario_list: ScenarioList):
+        self.survey = survey
+        self.scenario_list = scenario_list
+        self.env = Environment(undefined=Undefined)
+        self.long_scenario_dict = {}
+
+    def process_templates_for_all_questions(
+        self,
+    ) -> Tuple[List[QuestionBase], ScenarioList]:
+        long_questions_list = []
+
+        self.long_scenario_dict = {}
+
+        for question in self.survey.questions:
+            updates_for_one_question = self.process_templates(
+                question, self.scenario_list
+            )
+            long_questions_list.extend(updates_for_one_question)
+
+        long_scenario_list = ScenarioList([Scenario(data=self.long_scenario_dict)])
+
+        return long_questions_list, long_scenario_list
+
+    def process_templates(
+        self, question: QuestionBase, scenario_list: ScenarioList
+    ) -> List[QuestionBase]:
+        """Process templates for each scenario and return list of modified questions.
+
+        Args:
+            scenario_list: List of scenarios to process templates against
+
+        Returns:
+            List of QuestionBase objects with rendered templates
+        """
+        import re
+
+        questions = []
+        starting_name = question.question_name
+
+        # Check for Jinja2 variables in the question text
+        pattern = self._jinja_variable_pattern()
+        variables_in_question_text = (
+            re.search(pattern, question.question_text) is not None
+        )
+        if variables_in_question_text:
+            for index, scenario in enumerate(scenario_list):
+                question_data = question.to_dict().copy()
+                processed_data = self._process_data(question_data, scenario, index)
+
+                if processed_data["question_name"] == starting_name:
+                    processed_data["question_name"] += f"_{index}"
+
+                questions.append(QuestionBase.from_dict(processed_data))
+        else:
+            questions.append(question)
+
+        return questions
+
+    def _process_data(
+        self, data: Dict[str, Any], scenario: Dict[str, Any], scenario_index: int
+    ) -> Dict[str, Any]:
+        """Process all data fields according to their type.
+
+        Args:
+            data: Dictionary of question data
+            scenario: Current scenario to render templates against
+
+        Returns:
+            Processed dictionary with rendered templates
+        """
+        processed = {}
+
+        extended_scenario = scenario.copy()
+        extended_scenario.update({"scenario": scenario})
+
+        for key, value in [(k, v) for k, v in data.items() if v is not None]:
+            processed[key] = self._process_value(
+                key, value, extended_scenario, scenario_index
+            )
+
+        return processed
+
+    def _process_value(
+        self, key: str, value: Any, scenario: Dict[str, Any], scenario_index: int
+    ) -> Any:
+        """Process a single value according to its type.
+
+        Args:
+            key: Dictionary key
+            value: Value to process
+            scenario: Current scenario
+
+        Returns:
+            Processed value
+        """
+        if key == "question_options" and isinstance(value, str):
+            return value
+
+        if key == "option_labels":
+
+            return (
+                eval(self._render_template(value, scenario, scenario_index))
+                if isinstance(value, str)
+                else value
+            )
+
+        if isinstance(value, str):
+            return self._render_template(value, scenario, scenario_index)
+
+        if isinstance(value, list):
+            return self._process_list(value, scenario, scenario_index)
+
+        if isinstance(value, dict):
+            return self._process_dict(value, scenario, scenario_index)
+
+        if isinstance(value, (int, float)):
+            return value
+
+        from edsl.questions.exceptions import QuestionValueError
+
+        raise QuestionValueError(
+            f"Unexpected value type: {type(value)} for key '{key}'"
+        )
+
+    def _jinja_variable_pattern(self) -> str:
+
+        # Regular expression to find Jinja2 variables in the template
+        pattern = (
+            r"(?P<open>\{\{\s*)scenario\.(?P<var>[a-zA-Z0-9_.]+)(?P<close>\s*\}\})"
+        )
+        return pattern
+
+    def _render_template(
+        self, template: str, scenario: Dict[str, Any], scenario_index: int
+    ) -> str:
+        """Render a single template string.
+
+        Args:
+            template: Template string to render
+            scenario: Current scenario
+
+        Returns:
+            Rendered template string, preserving any unmatched template variables
+
+        Examples:
+        >>> from edsl.questions import QuestionBase
+        >>> from edsl.scenarios import Scenario, ScenarioList
+        >>> q = QuestionBase()
+        >>> q.question_text = "test"
+        >>> sl = ScenarioList([Scenario({"name": "World"}), Scenario({"name": "everyone"})])
+        >>> p = LongSurveyLoopProcessor(q, sl)
+        >>> p._render_template("Hello {{scenario.name}}!", {"name": "everyone"}, scenario_index=1)
+        'Hello {{ scenario.name_1 }}!'
+
+        >>> p._render_template("{{scenario.a}} and {{scenario.b}}", {"b": 6}, scenario_index=1)
+        '{{ a }} and {{ scenario.b_1 }}'
+
+        >>> p._render_template("{{scenario.x}} + {{scenario.y}} = {{scenario.z}}", {"x": 2, "y": 3}, scenario_index=5)
+        '{{ scenario.x_5 }} + {{ scenario.y_5 }} = {{ z }}'
+
+        >>> p._render_template("No variables here", {}, scenario_index=0)
+        'No variables here'
+
+        >>> p._render_template("{{scenario.item.price}}", {"item": {"price": 9.99}}, scenario_index=3)
+        '{{ scenario.item_3.price }}'
+
+        >>> p._render_template("{{scenario.item.missing}}", {"item": {"price": 9.99}}, scenario_index=3)
+        '{{ scenario.item_3.missing }}'
+        """
+        import re
+
+        # Regular expression to find Jinja2 variables in the template
+        pattern = self._jinja_variable_pattern()
+
+        def replace_var(match):
+            var_name = match.group("var")
+            # We're keeping the original formatting with braces
+            # but not using these variables directly
+            # open_brace = match.group('open')
+            # close_brace = match.group('close')
+            try:
+                # Handle nested attributes (like item.price)
+                parts = var_name.split(".")
+
+                base_var = parts[0]
+
+                self.long_scenario_dict.update(
+                    {f"{base_var}_{scenario_index}": scenario[base_var]}
+                )
+
+                if len(parts) > 1:
+                    non_name_parts = ".".join(parts[1:])
+                    result = (
+                        f"{{ scenario.{base_var}_{scenario_index}.{non_name_parts} }}"
+                    )
+                else:
+                    result = f"{{ scenario.{base_var}_{scenario_index} }}"
+
+                result = result.replace("{", "{{").replace("}", "}}")
+                return result
+            except (KeyError, TypeError) as e:
+                # Return the original variable name with the expected spacing
+                result = f"{{ {var_name} }}".replace("{", "{{").replace("}", "}}")
+                return result
+
+        # Replace all variables in the template
+        result = re.sub(pattern, replace_var, template)
+        return result
+
+    def _process_list(
+        self, items: List[Any], scenario: Dict[str, Any], scenario_index: int
+    ) -> List[Any]:
+        """Process all items in a list.
+
+        Args:
+            items: List of items to process
+            scenario: Current scenario
+
+        Returns:
+            List of processed items
+        """
+        return [
+            (
+                self._render_template(item, scenario, scenario_index)
+                if isinstance(item, str)
+                else item
+            )
+            for item in items
+        ]
+
+    def _process_dict(
+        self, data: Dict[str, Any], scenario: Dict[str, Any], scenario_index: int
+    ) -> Dict[str, Any]:
+        """Process all keys and values in a dictionary.
+
+        Args:
+            data: Dictionary to process
+            scenario: Current scenario
+
+        Returns:
+            Dictionary with processed keys and values
+        """
+        return {
+            (
+                self._render_template(k, scenario, scenario_index)
+                if isinstance(k, str)
+                else k
+            ): (
+                self._render_template(v, scenario, scenario_index)
+                if isinstance(v, str)
+                else v
+            )
+            for k, v in data.items()
+        }
+
+
 if __name__ == "__main__":
     import doctest
 
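Continuing the class's own docstring example, the transformation turns one templated question plus three scenarios into three concrete questions and a single flattened scenario. A sketch of the expected shapes, following the index-suffix convention in _render_template:

long_questions_list, long_scenario_list = loop_processor.process_templates_for_all_questions()

# One question name per scenario, suffixed with the scenario index.
[q.question_name for q in long_questions_list]
# ['enjoy_0', 'enjoy_1', 'enjoy_2']

# A single scenario holding every looped field, also index-suffixed.
long_scenario_list[0]
# Scenario({'activity_0': 'tennis', 'activity_1': 'racecar driving', 'activity_2': 'cycling'})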
edsl/results/result.py
CHANGED
@@ -95,6 +95,7 @@ class Result(Base, UserDict):
         question_to_attributes: Optional[dict[QuestionName, Any]] = None,
         generated_tokens: Optional[dict] = None,
         comments_dict: Optional[dict] = None,
+        reasoning_summaries_dict: Optional[dict] = None,
         cache_used_dict: Optional[dict[QuestionName, bool]] = None,
         indices: Optional[dict] = None,
         cache_keys: Optional[dict[QuestionName, str]] = None,
@@ -112,6 +113,7 @@ class Result(Base, UserDict):
         :param question_to_attributes: A dictionary of question attributes.
         :param generated_tokens: A dictionary of generated tokens.
         :param comments_dict: A dictionary of comments.
+        :param reasoning_summaries_dict: A dictionary of reasoning summaries.
         :param cache_used_dict: A dictionary of cache usage.
         :param indices: A dictionary of indices.
 
@@ -130,6 +132,7 @@ class Result(Base, UserDict):
             "question_to_attributes": self.question_to_attributes,
             "generated_tokens": generated_tokens or {},
             "comments_dict": comments_dict or {},
+            "reasoning_summaries_dict": reasoning_summaries_dict or {},
             "cache_used_dict": cache_used_dict or {},
             "cache_keys": cache_keys or {},
         }
@@ -236,6 +239,7 @@ class Result(Base, UserDict):
             "answer": self.data["answer"],
             "prompt": self.data["prompt"],
             "comment": self.data["comments_dict"],
+            "reasoning_summary": self.data["reasoning_summaries_dict"],
             "generated_tokens": self.data["generated_tokens"],
             "raw_model_response": self.data["raw_model_response"],
             "question_text": sub_dicts_needing_new_keys["question_text"],
@@ -497,6 +501,7 @@ class Result(Base, UserDict):
             question_to_attributes=json_dict.get("question_to_attributes", None),
             generated_tokens=json_dict.get("generated_tokens", {}),
             comments_dict=json_dict.get("comments_dict", {}),
+            reasoning_summaries_dict=json_dict.get("reasoning_summaries_dict", {}),
             cache_used_dict=json_dict.get("cache_used_dict", {}),
             cache_keys=json_dict.get("cache_keys", {}),
             indices=json_dict.get("indices", None),
@@ -631,6 +636,36 @@ class Result(Base, UserDict):
         }
         return comments_dict
 
+    def get_reasoning_summaries_dict(answer_key_names) -> dict[str, Any]:
+        reasoning_summaries_dict = {}
+        for k in answer_key_names:
+            reasoning_summary = question_results[k].reasoning_summary
+
+            # If reasoning summary is None but we have a raw model response, try to extract it
+            if reasoning_summary is None and hasattr(question_results[k], 'raw_model_response'):
+                try:
+                    # Get the model class to access the reasoning_sequence
+                    model_class = interview.model.__class__ if hasattr(interview, 'model') else None
+
+                    if model_class and hasattr(model_class, 'reasoning_sequence'):
+                        from ..language_models.raw_response_handler import RawResponseHandler
+
+                        # Create a handler with the model's reasoning sequence
+                        handler = RawResponseHandler(
+                            key_sequence=model_class.key_sequence if hasattr(model_class, 'key_sequence') else None,
+                            usage_sequence=model_class.usage_sequence if hasattr(model_class, 'usage_sequence') else None,
+                            reasoning_sequence=model_class.reasoning_sequence
+                        )
+
+                        # Try to extract the reasoning summary
+                        reasoning_summary = handler.get_reasoning_summary(question_results[k].raw_model_response)
+                except Exception:
+                    # If extraction fails, keep it as None
+                    pass
+
+            reasoning_summaries_dict[k + "_reasoning_summary"] = reasoning_summary
+        return reasoning_summaries_dict
+
     def get_question_name_to_prompts(
         model_response_objects,
     ) -> dict[str, dict[str, str]]:
@@ -705,6 +740,7 @@ class Result(Base, UserDict):
     answer_key_names = list(question_results.keys())
     generated_tokens_dict = get_generated_tokens_dict(answer_key_names) if answer_key_names else {}
     comments_dict = get_comments_dict(answer_key_names) if answer_key_names else {}
+    reasoning_summaries_dict = get_reasoning_summaries_dict(answer_key_names) if answer_key_names else {}
 
     # Get answers that are in the question results
     answer_dict = {}
@@ -735,6 +771,7 @@ class Result(Base, UserDict):
         survey=survey_copy,
         generated_tokens=generated_tokens_dict,
         comments_dict=comments_dict,
+        reasoning_summaries_dict=reasoning_summaries_dict,
         cache_used_dict=cache_used_dictionary,
         indices=indices_copy,
         cache_keys=cache_keys,
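With these changes each answered question contributes a "<question_name>_reasoning_summary" entry alongside its comment, surfaced through the "reasoning_summary" sub-dict shown above. A hedged sketch of reading it back, assuming a reasoning-capable model and edsl's usual "<sub_dict>.<key>" select convention:

# Hypothetical read-back; the question name "enjoy" is from the example above.
results = survey.by(reasoning_model).run()
results.select("answer.enjoy", "reasoning_summary.enjoy_reasoning_summary")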
edsl/results/results.py
CHANGED
edsl/scenarios/scenario_list.py
CHANGED
@@ -159,7 +159,15 @@ class ScenarioList(MutableSequence, Base, ScenarioListOperationsMixin):
 
     # Required MutableSequence abstract methods
     def __getitem__(self, index):
-        """Get item at index.
+        """Get item at index.
+
+        Example:
+        >>> from edsl.scenarios import Scenario, ScenarioList
+        >>> sl = ScenarioList([Scenario({'a': 12})])
+        >>> sl[0]['b'] = 100 # modify in-place
+        >>> sl[0]['b']
+        100
+        """
         if isinstance(index, slice):
             return self.__class__(list(self.data[index]), self.codebook.copy())
         return self.data[index]
@@ -356,7 +364,29 @@ class ScenarioList(MutableSequence, Base, ScenarioListOperationsMixin):
             new_scenarios.append(Scenario(new_scenario))
 
         return new_scenarios
+
+    @classmethod
+    def from_prompt(self, description: str, name:Optional[str] = "item", target_number:int = 10, verbose = False):
+        from ..questions.question_list import QuestionList
+        q = QuestionList(question_name = name,
+                         question_text = description + f"\n Please try to return {target_number} examples.")
+        results = q.run(verbose = verbose)
+        return results.select(name).to_scenario_list().expand(name)
 
+
+    def __add__(self, other):
+        if isinstance(other, Scenario):
+            new_list = self.duplicate()
+            new_list.append(other)
+            return new_list
+        elif isinstance(other, ScenarioList):
+            new_list = self.duplicate()
+            for item in other:
+                new_list.append(item)
+        else:
+            raise ScenarioError("Don't know how to combine!")
+        return new_list
+
     @classmethod
     def from_search_terms(cls, search_terms: List[str]) -> ScenarioList:
         """Create a ScenarioList from a list of search terms, using Wikipedia.