edsl 0.1.60__py3-none-any.whl → 0.1.61__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/agent.py +65 -17
- edsl/agents/agent_list.py +117 -33
- edsl/base/base_class.py +80 -11
- edsl/config/config_class.py +7 -2
- edsl/coop/coop.py +1295 -85
- edsl/coop/coop_prolific_filters.py +171 -0
- edsl/dataset/display/table_display.py +40 -7
- edsl/db_list/sqlite_list.py +102 -3
- edsl/jobs/data_structures.py +46 -31
- edsl/jobs/jobs.py +73 -2
- edsl/jobs/remote_inference.py +49 -15
- edsl/questions/loop_processor.py +289 -10
- edsl/questions/templates/dict/answering_instructions.jinja +0 -1
- edsl/scenarios/scenario_list.py +31 -1
- edsl/scenarios/scenario_source.py +606 -498
- edsl/surveys/survey.py +198 -163
- {edsl-0.1.60.dist-info → edsl-0.1.61.dist-info}/METADATA +3 -3
- {edsl-0.1.60.dist-info → edsl-0.1.61.dist-info}/RECORD +22 -21
- {edsl-0.1.60.dist-info → edsl-0.1.61.dist-info}/LICENSE +0 -0
- {edsl-0.1.60.dist-info → edsl-0.1.61.dist-info}/WHEEL +0 -0
- {edsl-0.1.60.dist-info → edsl-0.1.61.dist-info}/entry_points.txt +0 -0
edsl/jobs/jobs.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1
1
|
"""
|
2
2
|
The Jobs module is the core orchestration component of the EDSL framework.
|
3
3
|
|
4
|
-
It provides functionality to define, configure, and execute computational jobs that
|
5
|
-
involve multiple agents, scenarios, models, and a survey. Jobs are the primary way
|
4
|
+
It provides functionality to define, configure, and execute computational jobs that
|
5
|
+
involve multiple agents, scenarios, models, and a survey. Jobs are the primary way
|
6
6
|
that users run large-scale experiments or simulations in EDSL.
|
7
7
|
|
8
8
|
The Jobs class handles:
|
@@ -15,6 +15,7 @@ The Jobs class handles:
|
|
15
15
|
This module is designed to be used by both application developers and researchers
|
16
16
|
who need to run complex simulations with language models.
|
17
17
|
"""
|
18
|
+
|
18
19
|
from __future__ import annotations
|
19
20
|
import asyncio
|
20
21
|
from typing import Optional, Union, TypeVar, Callable, cast
|
@@ -564,6 +565,7 @@ class Jobs(Base):
|
|
564
565
|
remote_inference_description=self.run_config.parameters.remote_inference_description,
|
565
566
|
remote_inference_results_visibility=self.run_config.parameters.remote_inference_results_visibility,
|
566
567
|
fresh=self.run_config.parameters.fresh,
|
568
|
+
new_format=self.run_config.parameters.new_format,
|
567
569
|
)
|
568
570
|
return job_info
|
569
571
|
|
@@ -829,6 +831,7 @@ class Jobs(Base):
|
|
829
831
|
key_lookup (KeyLookup, optional): Object to manage API keys
|
830
832
|
memory_threshold (int, optional): Memory threshold in bytes for the Results object's SQLList,
|
831
833
|
controlling when data is offloaded to SQLite storage
|
834
|
+
new_format (bool): If True, uses remote_inference_create method, if False uses old_remote_inference_create method (default: True)
|
832
835
|
|
833
836
|
Returns:
|
834
837
|
Results: A Results object containing all responses and metadata
|
@@ -889,6 +892,7 @@ class Jobs(Base):
|
|
889
892
|
key_lookup (KeyLookup, optional): Object to manage API keys
|
890
893
|
memory_threshold (int, optional): Memory threshold in bytes for the Results object's SQLList,
|
891
894
|
controlling when data is offloaded to SQLite storage
|
895
|
+
new_format (bool): If True, uses remote_inference_create method, if False uses old_remote_inference_create method (default: True)
|
892
896
|
|
893
897
|
Returns:
|
894
898
|
Results: A Results object containing all responses and metadata
|
@@ -1084,6 +1088,73 @@ class Jobs(Base):
|
|
1084
1088
|
"""Return the code to create this instance."""
|
1085
1089
|
raise JobsImplementationError("Code generation not implemented yet")
|
1086
1090
|
|
1091
|
+
def humanize(
    self,
    project_name: str = "Project",
    scenario_list_method: Optional[
        Literal["randomize", "loop", "single_scenario"]
    ] = None,
    survey_description: Optional[str] = None,
    survey_alias: Optional[str] = None,
    survey_visibility: Optional["VisibilityType"] = "unlisted",
    scenario_list_description: Optional[str] = None,
    scenario_list_alias: Optional[str] = None,
    scenario_list_visibility: Optional["VisibilityType"] = "unlisted",
):
    """
    Send the survey and scenario list to Coop.

    Then, create a project on Coop so you can share the survey with human respondents.

    Args:
        project_name: Name given to the Coop project.
        scenario_list_method: How the job's scenarios are combined with the
            survey. It is required when the job has scenarios and must be
            omitted when it has none.
            - ``"loop"`` rewrites the survey via ``Survey.to_long_format``
              into one long survey backed by a single combined scenario.
            - ``"single_scenario"`` requires exactly one scenario.
            - ``"randomize"`` forwards the scenario list unchanged.
        survey_description: Description forwarded to Coop for the survey.
        survey_alias: Alias forwarded to Coop for the survey.
        survey_visibility: Visibility forwarded to Coop for the survey.
        scenario_list_description: Description forwarded to Coop for the
            scenario list.
        scenario_list_alias: Alias forwarded to Coop for the scenario list.
        scenario_list_visibility: Visibility forwarded to Coop for the
            scenario list.

    Returns:
        The project details returned by ``Coop.create_project``.

    Raises:
        CoopValueError: Raised in any of these cases:
            - the job has agents or models;
            - scenarios and ``scenario_list_method`` are not supplied together;
            - the method name is unknown;
            - the chosen method cannot be applied to the scenarios.
    """
    from edsl.coop import Coop
    from edsl.coop.exceptions import CoopValueError
    from edsl.surveys import Survey

    valid_methods = ("randomize", "loop", "single_scenario")

    if len(self.agents) > 0 or len(self.models) > 0:
        raise CoopValueError("We don't support humanize with agents or models yet.")

    if scenario_list_method is not None and scenario_list_method not in valid_methods:
        raise CoopValueError(
            f"Unknown scenario_list_method {scenario_list_method!r}. "
            f"Expected one of: {', '.join(valid_methods)}."
        )

    # Work on locals so that uploading a project never rewrites this Jobs
    # object's own survey/scenarios (the loop conversion replaces both).
    survey = self.survey
    scenarios = self.scenarios

    if len(scenarios) > 0 and scenario_list_method is None:
        raise CoopValueError(
            "You must specify both a scenario list and a scenario list method to use scenarios with your survey."
        )
    elif len(scenarios) == 0 and scenario_list_method is not None:
        raise CoopValueError(
            "You must specify both a scenario list and a scenario list method to use scenarios with your survey."
        )
    elif scenario_list_method == "loop":
        # Use equality, not identity: `is` on strings depends on interning.
        questions, long_scenario_list = survey.to_long_format(scenarios)

        # Replace the questions with new ones from the loop method
        survey = Survey(questions)
        scenarios = long_scenario_list

        if len(scenarios) != 1:
            raise CoopValueError("Something went wrong with the loop method.")
    elif scenario_list_method == "single_scenario" and len(scenarios) != 1:
        raise CoopValueError(
            f"The single_scenario method requires exactly one scenario. "
            f"If you have a scenario list with multiple scenarios, try using the randomize or loop methods."
        )

    # Coop expects None rather than an empty scenario list.
    scenario_list = scenarios if len(scenarios) > 0 else None

    c = Coop()
    project_details = c.create_project(
        survey,
        scenario_list,
        scenario_list_method,
        project_name,
        survey_description,
        survey_alias,
        survey_visibility,
        scenario_list_description,
        scenario_list_alias,
        scenario_list_visibility,
    )
    return project_details
|
1157
|
+
|
1087
1158
|
|
1088
1159
|
def main():
|
1089
1160
|
"""Run the module's doctests."""
|
edsl/jobs/remote_inference.py
CHANGED
@@ -31,6 +31,7 @@ class RemoteJobInfo:
|
|
31
31
|
creation_data: RemoteInferenceCreationInfo
|
32
32
|
job_uuid: JobUUID
|
33
33
|
logger: JobLogger
|
34
|
+
new_format: bool = True
|
34
35
|
|
35
36
|
|
36
37
|
class JobsRemoteInferenceHandler:
|
@@ -85,7 +86,21 @@ class JobsRemoteInferenceHandler:
|
|
85
86
|
remote_inference_description: Optional[str] = None,
|
86
87
|
remote_inference_results_visibility: Optional["VisibilityType"] = "unlisted",
|
87
88
|
fresh: Optional[bool] = False,
|
89
|
+
new_format: Optional[bool] = True,
|
88
90
|
) -> RemoteJobInfo:
|
91
|
+
"""
|
92
|
+
Create a remote inference job and return job information.
|
93
|
+
|
94
|
+
Args:
|
95
|
+
iterations: Number of times to run each interview
|
96
|
+
remote_inference_description: Optional description for the remote job
|
97
|
+
remote_inference_results_visibility: Visibility setting for results
|
98
|
+
fresh: If True, ignore existing cache entries and generate new results
|
99
|
+
new_format: If True, use pull method for result retrieval; if False, use legacy get method
|
100
|
+
|
101
|
+
Returns:
|
102
|
+
RemoteJobInfo: Information about the created job including UUID and logger
|
103
|
+
"""
|
89
104
|
from ..coop import Coop
|
90
105
|
|
91
106
|
logger = self._create_logger()
|
@@ -101,14 +116,24 @@ class JobsRemoteInferenceHandler:
|
|
101
116
|
logger.add_info(
|
102
117
|
"remote_cache_url", f"{self.expected_parrot_url}/home/remote-cache"
|
103
118
|
)
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
119
|
+
if new_format:
|
120
|
+
remote_job_creation_data = coop.remote_inference_create(
|
121
|
+
self.jobs,
|
122
|
+
description=remote_inference_description,
|
123
|
+
status="queued",
|
124
|
+
iterations=iterations,
|
125
|
+
initial_results_visibility=remote_inference_results_visibility,
|
126
|
+
fresh=fresh,
|
127
|
+
)
|
128
|
+
else:
|
129
|
+
remote_job_creation_data = coop.old_remote_inference_create(
|
130
|
+
self.jobs,
|
131
|
+
description=remote_inference_description,
|
132
|
+
status="queued",
|
133
|
+
iterations=iterations,
|
134
|
+
initial_results_visibility=remote_inference_results_visibility,
|
135
|
+
fresh=fresh,
|
136
|
+
)
|
112
137
|
logger.update(
|
113
138
|
"Your survey is running at the Expected Parrot server...",
|
114
139
|
status=JobsStatus.RUNNING,
|
@@ -141,6 +166,7 @@ class JobsRemoteInferenceHandler:
|
|
141
166
|
creation_data=remote_job_creation_data,
|
142
167
|
job_uuid=job_uuid,
|
143
168
|
logger=logger,
|
169
|
+
new_format=new_format,
|
144
170
|
)
|
145
171
|
|
146
172
|
@staticmethod
|
@@ -164,7 +190,7 @@ class JobsRemoteInferenceHandler:
|
|
164
190
|
return coop.remote_inference_get
|
165
191
|
|
166
192
|
def _construct_object_fetcher(
|
167
|
-
self, testing_simulated_response: Optional[Any] = None
|
193
|
+
self, new_format: bool = True, testing_simulated_response: Optional[Any] = None
|
168
194
|
) -> Callable:
|
169
195
|
"Constructs a function to fetch the results object from Coop."
|
170
196
|
if testing_simulated_response is not None:
|
@@ -173,7 +199,10 @@ class JobsRemoteInferenceHandler:
|
|
173
199
|
from ..coop import Coop
|
174
200
|
|
175
201
|
coop = Coop()
|
176
|
-
|
202
|
+
if new_format:
|
203
|
+
return coop.pull
|
204
|
+
else:
|
205
|
+
return coop.get
|
177
206
|
|
178
207
|
def _handle_cancelled_job(self, job_info: RemoteJobInfo) -> None:
|
179
208
|
"Handles a cancelled job by logging the cancellation and updating the job status."
|
@@ -395,7 +424,6 @@ class JobsRemoteInferenceHandler:
|
|
395
424
|
|
396
425
|
converter = CostConverter()
|
397
426
|
for model_key, model_cost_dict in expenses_by_model.items():
|
398
|
-
|
399
427
|
# Handle full cost (without cache)
|
400
428
|
input_cost = model_cost_dict["input_cost_usd"]
|
401
429
|
output_cost = model_cost_dict["output_cost_usd"]
|
@@ -417,9 +445,9 @@ class JobsRemoteInferenceHandler:
|
|
417
445
|
model_cost_dict["input_cost_credits_with_cache"] = converter.usd_to_credits(
|
418
446
|
input_cost_with_cache
|
419
447
|
)
|
420
|
-
model_cost_dict[
|
421
|
-
|
422
|
-
)
|
448
|
+
model_cost_dict[
|
449
|
+
"output_cost_credits_with_cache"
|
450
|
+
] = converter.usd_to_credits(output_cost_with_cache)
|
423
451
|
return list(expenses_by_model.values())
|
424
452
|
|
425
453
|
def _fetch_results_and_log(
|
@@ -525,7 +553,10 @@ class JobsRemoteInferenceHandler:
|
|
525
553
|
remote_job_data_fetcher = self._construct_remote_job_fetcher(
|
526
554
|
testing_simulated_response
|
527
555
|
)
|
528
|
-
object_fetcher = self._construct_object_fetcher(
|
556
|
+
object_fetcher = self._construct_object_fetcher(
|
557
|
+
new_format=job_info.new_format,
|
558
|
+
testing_simulated_response=testing_simulated_response,
|
559
|
+
)
|
529
560
|
|
530
561
|
job_in_queue = True
|
531
562
|
while job_in_queue:
|
@@ -540,6 +571,7 @@ class JobsRemoteInferenceHandler:
|
|
540
571
|
iterations: int = 1,
|
541
572
|
remote_inference_description: Optional[str] = None,
|
542
573
|
remote_inference_results_visibility: Optional[VisibilityType] = "unlisted",
|
574
|
+
new_format: Optional[bool] = True,
|
543
575
|
) -> Union["Results", None]:
|
544
576
|
"""
|
545
577
|
Creates and polls a remote inference job asynchronously.
|
@@ -548,6 +580,7 @@ class JobsRemoteInferenceHandler:
|
|
548
580
|
:param iterations: Number of times to run each interview
|
549
581
|
:param remote_inference_description: Optional description for the remote job
|
550
582
|
:param remote_inference_results_visibility: Visibility setting for results
|
583
|
+
:param new_format: If True, use pull method for result retrieval; if False, use legacy get method
|
551
584
|
:return: Results object if successful, None if job fails or is cancelled
|
552
585
|
"""
|
553
586
|
import asyncio
|
@@ -562,6 +595,7 @@ class JobsRemoteInferenceHandler:
|
|
562
595
|
iterations=iterations,
|
563
596
|
remote_inference_description=remote_inference_description,
|
564
597
|
remote_inference_results_visibility=remote_inference_results_visibility,
|
598
|
+
new_format=new_format,
|
565
599
|
),
|
566
600
|
)
|
567
601
|
if job_info is None:
|
edsl/questions/loop_processor.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1
|
-
from typing import List, Any, Dict
|
1
|
+
from typing import List, Any, Dict, Tuple
|
2
2
|
from jinja2 import Environment, Undefined
|
3
3
|
from .question_base import QuestionBase
|
4
|
-
from ..scenarios import ScenarioList
|
4
|
+
from ..scenarios import Scenario, ScenarioList
|
5
|
+
from ..surveys import Survey
|
6
|
+
|
5
7
|
|
6
8
|
class LoopProcessor:
|
7
9
|
def __init__(self, question: QuestionBase):
|
@@ -88,7 +90,10 @@ class LoopProcessor:
|
|
88
90
|
return value
|
89
91
|
|
90
92
|
from .exceptions import QuestionValueError
|
91
|
-
|
93
|
+
|
94
|
+
raise QuestionValueError(
|
95
|
+
f"Unexpected value type: {type(value)} for key '{key}'"
|
96
|
+
)
|
92
97
|
|
93
98
|
def _render_template(self, template: str, scenario: Dict[str, Any]) -> str:
|
94
99
|
"""Render a single template string.
|
@@ -124,21 +129,21 @@ class LoopProcessor:
|
|
124
129
|
'{{ item.missing }}'
|
125
130
|
"""
|
126
131
|
import re
|
127
|
-
|
132
|
+
|
128
133
|
# Regular expression to find Jinja2 variables in the template
|
129
|
-
pattern = r
|
130
|
-
|
134
|
+
pattern = r"(?P<open>\{\{\s*)(?P<var>[a-zA-Z0-9_.]+)(?P<close>\s*\}\})"
|
135
|
+
|
131
136
|
def replace_var(match):
|
132
|
-
var_name = match.group(
|
137
|
+
var_name = match.group("var")
|
133
138
|
# We're keeping the original formatting with braces
|
134
139
|
# but not using these variables directly
|
135
140
|
# open_brace = match.group('open')
|
136
141
|
# close_brace = match.group('close')
|
137
|
-
|
142
|
+
|
138
143
|
# Try to evaluate the variable in the context
|
139
144
|
try:
|
140
145
|
# Handle nested attributes (like item.price)
|
141
|
-
parts = var_name.split(
|
146
|
+
parts = var_name.split(".")
|
142
147
|
value = scenario
|
143
148
|
for part in parts:
|
144
149
|
if part in value:
|
@@ -151,7 +156,7 @@ class LoopProcessor:
|
|
151
156
|
except (KeyError, TypeError):
|
152
157
|
# Return the original variable name with the expected spacing
|
153
158
|
return f"{{ {var_name} }}".replace("{", "{{").replace("}", "}}")
|
154
|
-
|
159
|
+
|
155
160
|
# Replace all variables in the template
|
156
161
|
result = re.sub(pattern, replace_var, template)
|
157
162
|
return result
|
@@ -191,6 +196,280 @@ class LoopProcessor:
|
|
191
196
|
}
|
192
197
|
|
193
198
|
|
199
|
+
class LongSurveyLoopProcessor:
    """
    A modified LoopProcessor that creates a long survey where each question is rendered for each scenario.

    Returns a tuple of (long_questions, long_scenario_list).
    The long scenario list is essentially a flattened scenario list with one scenario that has many fields.

    Scenario references such as ``{{ scenario.activity }}`` are not
    substituted with values. They are rewritten to indexed placeholders such
    as ``{{ scenario.activity_0 }}``. The referenced values are collected into
    a single combined scenario under those indexed keys.

    Usage:
    >>> from edsl.questions import QuestionMultipleChoice
    >>> from edsl.surveys import Survey
    >>> from edsl.scenarios import Scenario, ScenarioList
    >>> q = QuestionMultipleChoice(question_name = "enjoy", question_text = "How much do you enjoy {{ scenario.activity }}?", question_options = ["Not at all", "Somewhat", "Very much"])
    >>> scenarios = ScenarioList([Scenario({"activity": activity}) for activity in ["tennis", "racecar driving", "cycling"]])
    >>> survey = Survey([q])
    >>> loop_processor = LongSurveyLoopProcessor(survey, scenarios)
    >>> long_questions_list, long_scenario_list = loop_processor.process_templates_for_all_questions()
    """

    def __init__(self, survey: "Survey", scenario_list: "ScenarioList"):
        self.survey = survey
        self.scenario_list = scenario_list
        # NOTE(review): not used by any method of this class; kept for parity
        # with LoopProcessor.
        self.env = Environment(undefined=Undefined)
        # Maps "<field>_<scenario_index>" -> value for every scenario field
        # referenced by a rendered template.
        self.long_scenario_dict: Dict[str, Any] = {}

    def process_templates_for_all_questions(
        self,
    ) -> Tuple[List["QuestionBase"], "ScenarioList"]:
        """Expand every survey question across all scenarios.

        Returns:
            A tuple of:
            - the list of expanded questions;
            - a ScenarioList holding one Scenario with all indexed fields.
        """
        long_questions_list = []

        # Reset so that repeated calls do not accumulate stale fields.
        self.long_scenario_dict = {}

        for question in self.survey.questions:
            updates_for_one_question = self.process_templates(
                question, self.scenario_list
            )
            long_questions_list.extend(updates_for_one_question)

        long_scenario_list = ScenarioList([Scenario(data=self.long_scenario_dict)])

        return long_questions_list, long_scenario_list

    def process_templates(
        self, question: "QuestionBase", scenario_list: "ScenarioList"
    ) -> List["QuestionBase"]:
        """Process templates for each scenario and return list of modified questions.

        Args:
            question: Question to expand
            scenario_list: List of scenarios to process templates against

        Returns:
            List of QuestionBase objects with rendered templates. If the
            question text contains no ``{{ scenario.* }}`` reference, the
            result is a one-element list with the question unchanged.
        """
        import re

        questions = []
        starting_name = question.question_name

        # Check for Jinja2 variables in the question text
        # NOTE(review): only question_text is inspected. Scenario references
        # that appear solely in options are not expanded and are not added to
        # the long scenario.
        pattern = self._jinja_variable_pattern()
        variables_in_question_text = (
            re.search(pattern, question.question_text) is not None
        )
        if variables_in_question_text:
            for index, scenario in enumerate(scenario_list):
                question_data = question.to_dict().copy()
                processed_data = self._process_data(question_data, scenario, index)

                # Keep question names unique across the expanded copies.
                if processed_data["question_name"] == starting_name:
                    processed_data["question_name"] += f"_{index}"

                questions.append(QuestionBase.from_dict(processed_data))
        else:
            questions.append(question)

        return questions

    def _process_data(
        self, data: Dict[str, Any], scenario: Dict[str, Any], scenario_index: int
    ) -> Dict[str, Any]:
        """Process all data fields according to their type.

        Args:
            data: Dictionary of question data
            scenario: Current scenario to render templates against
            scenario_index: Position of the scenario, used as the key suffix

        Returns:
            Processed dictionary with rendered templates. Fields whose value
            is None are dropped.
        """
        processed = {}

        # Also expose the scenario under the name "scenario".
        extended_scenario = scenario.copy()
        extended_scenario.update({"scenario": scenario})

        for key, value in [(k, v) for k, v in data.items() if v is not None]:
            processed[key] = self._process_value(
                key, value, extended_scenario, scenario_index
            )

        return processed

    def _process_value(
        self, key: str, value: Any, scenario: Dict[str, Any], scenario_index: int
    ) -> Any:
        """Process a single value according to its type.

        Args:
            key: Dictionary key
            value: Value to process
            scenario: Current scenario
            scenario_index: Position of the scenario, used as the key suffix

        Returns:
            Processed value

        Raises:
            QuestionValueError: If the value has an unsupported type.
        """
        import ast

        if key == "question_options" and isinstance(value, str):
            return value

        if key == "option_labels":
            if not isinstance(value, str):
                return value
            rendered = self._render_template(value, scenario, scenario_index)
            # ast.literal_eval only accepts Python literals. It replaces
            # eval(), which would execute arbitrary code taken from question
            # data. Scenario references are rewritten to placeholders rather
            # than substituted, so the rendered text may still be a template.
            # In that case it is returned as a string.
            try:
                return ast.literal_eval(rendered)
            except (ValueError, SyntaxError, TypeError):
                return rendered

        if isinstance(value, str):
            return self._render_template(value, scenario, scenario_index)

        if isinstance(value, list):
            return self._process_list(value, scenario, scenario_index)

        if isinstance(value, dict):
            return self._process_dict(value, scenario, scenario_index)

        if isinstance(value, (int, float)):
            return value

        from edsl.questions.exceptions import QuestionValueError

        raise QuestionValueError(
            f"Unexpected value type: {type(value)} for key '{key}'"
        )

    def _jinja_variable_pattern(self) -> str:
        """Return the regex matching ``{{ scenario.<dotted.name> }}`` references."""
        # Regular expression to find Jinja2 variables in the template
        pattern = (
            r"(?P<open>\{\{\s*)scenario\.(?P<var>[a-zA-Z0-9_.]+)(?P<close>\s*\}\})"
        )
        return pattern

    def _render_template(
        self, template: str, scenario: Dict[str, Any], scenario_index: int
    ) -> str:
        """Render a single template string.

        Each ``{{ scenario.<field>... }}`` reference whose base field exists
        in ``scenario`` is rewritten to
        ``{{ scenario.<field>_<scenario_index>... }}``. The field's value is
        recorded in ``self.long_scenario_dict``. References to missing fields
        keep pointing at ``scenario``, with normalised spacing.

        Args:
            template: Template string to render
            scenario: Current scenario
            scenario_index: Position of the scenario, used as the key suffix

        Returns:
            Rendered template string, preserving any unmatched template variables

        Examples:
            >>> from edsl.scenarios import Scenario, ScenarioList
            >>> sl = ScenarioList([Scenario({"name": "World"}), Scenario({"name": "everyone"})])
            >>> p = LongSurveyLoopProcessor(None, sl)
            >>> p._render_template("Hello {{scenario.name}}!", {"name": "everyone"}, scenario_index=1)
            'Hello {{ scenario.name_1 }}!'

            >>> p._render_template("{{scenario.a}} and {{scenario.b}}", {"b": 6}, scenario_index=1)
            '{{ scenario.a }} and {{ scenario.b_1 }}'

            >>> p._render_template("{{scenario.x}} + {{scenario.y}} = {{scenario.z}}", {"x": 2, "y": 3}, scenario_index=5)
            '{{ scenario.x_5 }} + {{ scenario.y_5 }} = {{ scenario.z }}'

            >>> p._render_template("No variables here", {}, scenario_index=0)
            'No variables here'

            >>> p._render_template("{{scenario.item.price}}", {"item": {"price": 9.99}}, scenario_index=3)
            '{{ scenario.item_3.price }}'

            >>> p._render_template("{{scenario.item.missing}}", {"item": {"price": 9.99}}, scenario_index=3)
            '{{ scenario.item_3.missing }}'
        """
        import re

        # Regular expression to find Jinja2 variables in the template
        pattern = self._jinja_variable_pattern()

        def replace_var(match):
            var_name = match.group("var")
            # Handle nested attributes (like item.price): only the first
            # segment is a scenario field; the rest is kept as-is.
            parts = var_name.split(".")
            base_var = parts[0]

            try:
                value = scenario[base_var]
            except (KeyError, TypeError):
                # Unknown field: keep the reference under the `scenario`
                # namespace so it still resolves to the same variable later.
                return "{{ scenario." + var_name + " }}"

            long_key = f"{base_var}_{scenario_index}"
            self.long_scenario_dict[long_key] = value

            new_var = ".".join([long_key, *parts[1:]])
            return "{{ scenario." + new_var + " }}"

        # Replace all variables in the template
        result = re.sub(pattern, replace_var, template)
        return result

    def _process_list(
        self, items: List[Any], scenario: Dict[str, Any], scenario_index: int
    ) -> List[Any]:
        """Process all items in a list.

        Only string items are rendered; other items are returned unchanged.

        Args:
            items: List of items to process
            scenario: Current scenario
            scenario_index: Position of the scenario, used as the key suffix

        Returns:
            List of processed items
        """
        return [
            (
                self._render_template(item, scenario, scenario_index)
                if isinstance(item, str)
                else item
            )
            for item in items
        ]

    def _process_dict(
        self, data: Dict[str, Any], scenario: Dict[str, Any], scenario_index: int
    ) -> Dict[str, Any]:
        """Process all keys and values in a dictionary.

        Only string keys and string values are rendered (not recursively).

        Args:
            data: Dictionary to process
            scenario: Current scenario
            scenario_index: Position of the scenario, used as the key suffix

        Returns:
            Dictionary with processed keys and values
        """
        return {
            (
                self._render_template(k, scenario, scenario_index)
                if isinstance(k, str)
                else k
            ): (
                self._render_template(v, scenario, scenario_index)
                if isinstance(v, str)
                else v
            )
            for k, v in data.items()
        }
|
471
|
+
|
472
|
+
|
194
473
|
if __name__ == "__main__":
|
195
474
|
import doctest
|
196
475
|
|
edsl/scenarios/scenario_list.py
CHANGED
@@ -159,7 +159,15 @@ class ScenarioList(MutableSequence, Base, ScenarioListOperationsMixin):
|
|
159
159
|
|
160
160
|
# Required MutableSequence abstract methods
|
161
161
|
def __getitem__(self, index):
    """Get item at index.

    An integer index returns the stored Scenario itself, not a copy, so
    in-place edits are visible in the list. A slice returns a new instance of
    the same class holding the selected scenarios and a copy of the codebook.

    Example:
        >>> from edsl.scenarios import Scenario, ScenarioList
        >>> sl = ScenarioList([Scenario({'a': 12})])
        >>> sl[0]['b'] = 100  # modify in-place
        >>> sl[0]['b']
        100
    """
    if not isinstance(index, slice):
        return self.data[index]
    selected = list(self.data[index])
    return self.__class__(selected, self.codebook.copy())
|
@@ -356,7 +364,29 @@ class ScenarioList(MutableSequence, Base, ScenarioListOperationsMixin):
|
|
356
364
|
new_scenarios.append(Scenario(new_scenario))
|
357
365
|
|
358
366
|
return new_scenarios
|
367
|
+
|
368
|
+
@classmethod
def from_prompt(
    cls,
    description: str,
    name: Optional[str] = "item",
    target_number: int = 10,
    verbose=False,
) -> "ScenarioList":
    """Build a ScenarioList by asking a model to list examples.

    A ``QuestionList`` named ``name`` is created with ``description`` as its
    text, plus a request for ``target_number`` examples. The question is run,
    and the answer column is expanded into one Scenario per list element.

    Args:
        description: Prompt describing the items to generate.
        name: Question name; also the key of each resulting Scenario.
        target_number: Number of examples requested (a hint to the model,
            not enforced).
        verbose: Forwarded to ``QuestionList.run``.

    Returns:
        The ScenarioList produced by ``expand(name)`` on the selected results.
    """
    from ..questions.question_list import QuestionList

    q = QuestionList(
        question_name=name,
        question_text=description
        + f"\n Please try to return {target_number} examples.",
    )
    results = q.run(verbose=verbose)
    return results.select(name).to_scenario_list().expand(name)
|
359
375
|
|
376
|
+
|
377
|
+
def __add__(self, other):
    """Return a new list made of this list's scenarios followed by *other*.

    ``other`` may be a single Scenario, which is appended, or a ScenarioList,
    whose items are appended in order. ``self`` is left unmodified: the
    result is built on ``self.duplicate()``.

    Raises:
        ScenarioError: If ``other`` is neither a Scenario nor a ScenarioList.
    """
    if isinstance(other, Scenario):
        extra_items = [other]
    elif isinstance(other, ScenarioList):
        extra_items = other
    else:
        raise ScenarioError("Don't know how to combine!")

    combined = self.duplicate()
    for scenario in extra_items:
        combined.append(scenario)
    return combined
|
389
|
+
|
360
390
|
@classmethod
|
361
391
|
def from_search_terms(cls, search_terms: List[str]) -> ScenarioList:
|
362
392
|
"""Create a ScenarioList from a list of search terms, using Wikipedia.
|