edsl-0.1.43-py3-none-any.whl → edsl-0.1.45-py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- edsl/Base.py +15 -6
- edsl/__version__.py +1 -1
- edsl/agents/InvigilatorBase.py +3 -1
- edsl/agents/PromptConstructor.py +62 -34
- edsl/agents/QuestionInstructionPromptBuilder.py +111 -68
- edsl/agents/QuestionTemplateReplacementsBuilder.py +69 -16
- edsl/agents/question_option_processor.py +15 -6
- edsl/coop/CoopFunctionsMixin.py +3 -4
- edsl/coop/coop.py +56 -10
- edsl/enums.py +4 -1
- edsl/inference_services/AnthropicService.py +12 -8
- edsl/inference_services/AvailableModelFetcher.py +2 -0
- edsl/inference_services/AwsBedrock.py +1 -2
- edsl/inference_services/AzureAI.py +12 -9
- edsl/inference_services/GoogleService.py +10 -3
- edsl/inference_services/InferenceServiceABC.py +1 -0
- edsl/inference_services/InferenceServicesCollection.py +2 -2
- edsl/inference_services/MistralAIService.py +1 -2
- edsl/inference_services/OpenAIService.py +10 -4
- edsl/inference_services/PerplexityService.py +2 -1
- edsl/inference_services/TestService.py +1 -0
- edsl/inference_services/XAIService.py +11 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +9 -0
- edsl/jobs/JobsChecks.py +11 -14
- edsl/jobs/JobsPrompts.py +3 -3
- edsl/jobs/async_interview_runner.py +3 -1
- edsl/jobs/check_survey_scenario_compatibility.py +5 -5
- edsl/jobs/interviews/InterviewExceptionEntry.py +12 -0
- edsl/jobs/tasks/TaskHistory.py +1 -1
- edsl/language_models/LanguageModel.py +3 -3
- edsl/language_models/PriceManager.py +45 -5
- edsl/language_models/model.py +89 -36
- edsl/questions/QuestionBase.py +21 -0
- edsl/questions/QuestionBasePromptsMixin.py +103 -0
- edsl/questions/QuestionFreeText.py +22 -5
- edsl/questions/descriptors.py +4 -0
- edsl/questions/question_base_gen_mixin.py +94 -29
- edsl/results/Dataset.py +65 -0
- edsl/results/DatasetExportMixin.py +299 -32
- edsl/results/Result.py +27 -0
- edsl/results/Results.py +24 -3
- edsl/results/ResultsGGMixin.py +7 -3
- edsl/scenarios/DocumentChunker.py +2 -0
- edsl/scenarios/FileStore.py +29 -8
- edsl/scenarios/PdfExtractor.py +21 -1
- edsl/scenarios/Scenario.py +25 -9
- edsl/scenarios/ScenarioList.py +73 -3
- edsl/scenarios/handlers/__init__.py +1 -0
- edsl/scenarios/handlers/docx.py +5 -1
- edsl/scenarios/handlers/jpeg.py +39 -0
- edsl/surveys/Survey.py +28 -6
- edsl/surveys/SurveyFlowVisualization.py +91 -43
- edsl/templates/error_reporting/exceptions_table.html +7 -8
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/interviews.html +0 -1
- edsl/templates/error_reporting/overview.html +2 -7
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +1 -1
- edsl/utilities/PrettyList.py +14 -0
- edsl-0.1.45.dist-info/METADATA +246 -0
- {edsl-0.1.43.dist-info → edsl-0.1.45.dist-info}/RECORD +64 -62
- edsl-0.1.43.dist-info/METADATA +0 -110
- {edsl-0.1.43.dist-info → edsl-0.1.45.dist-info}/LICENSE +0 -0
- {edsl-0.1.43.dist-info → edsl-0.1.45.dist-info}/WHEEL +0 -0
edsl/language_models/model.py
CHANGED
@@ -1,6 +1,6 @@
 import textwrap
 from random import random
-from typing import Optional, TYPE_CHECKING, List
+from typing import Optional, TYPE_CHECKING, List, Callable
 
 from edsl.utilities.PrettyList import PrettyList
 from edsl.config import CONFIG
@@ -11,17 +11,25 @@ from edsl.inference_services.InferenceServicesCollection import (
 from edsl.inference_services.data_structures import AvailableModels
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 from edsl.enums import InferenceServiceLiteral
+from edsl.exceptions.inference_services import InferenceServiceError
 
 if TYPE_CHECKING:
     from edsl.results.Dataset import Dataset
 
 
-def get_model_class(
+def get_model_class(
+    model_name,
+    registry: Optional[InferenceServicesCollection] = None,
+    service_name: Optional[InferenceServiceLiteral] = None,
+):
     from edsl.inference_services.registry import default
 
     registry = registry or default
-
-
+    try:
+        factory = registry.create_model_factory(model_name, service_name=service_name)
+        return factory
+    except (InferenceServiceError, Exception) as e:
+        return Model._handle_model_error(model_name, e)
 
 
 class Meta(type):
@@ -36,6 +44,9 @@ class Meta(type):
         To get the default model, you can leave out the model name.
         To see the available models, you can do:
         >>> Model.available()
+
+        Or to see the models for a specific service, you can do:
+        >>> Model.available(service='openai')
         """
     )
 
@@ -58,6 +69,33 @@ class Model(metaclass=Meta):
         """Set a new registry"""
         cls._registry = registry
 
+    @classmethod
+    def _handle_model_error(cls, model_name: str, error: Exception):
+        """Handle errors from model creation and execution with notebook-aware behavior."""
+        if isinstance(error, InferenceServiceError):
+            services = [s._inference_service_ for s in cls.get_registry().services]
+            message = (
+                f"Model '{model_name}' not found in any services.\n"
+                "It is likely that our registry is just out of date.\n"
+                "Simply adding the service name to your model call should fix this.\n"
+                f"Available services are: {services}\n"
+                f"To specify a model with a service, use:\n"
+                f'Model("{model_name}", service_name="<service_name>")'
+            )
+        else:
+            message = f"An error occurred: {str(error)}"
+
+        # Check if we're in a notebook environment
+        try:
+            get_ipython()
+            print(message)
+            return None
+        except NameError:
+            # Not in a notebook, raise the exception
+            if isinstance(error, InferenceServiceError):
+                raise InferenceServiceError(message)
+            raise error
+
     def __new__(
         cls,
         model_name: Optional[str] = None,
@@ -66,12 +104,13 @@ class Model(metaclass=Meta):
         *args,
         **kwargs,
     ):
-        "Instantiate a new language model.
+        """Instantiate a new language model.
+        >>> Model()
+        Model(...)
+        """
         # Map index to the respective subclass
         if model_name is None:
-            model_name = (
-                cls.default_model
-            )  # when model_name is None, use the default model, set in the config file
+            model_name = cls.default_model
 
         if registry is not None:
            cls.set_registry(registry)
@@ -79,10 +118,13 @@ class Model(metaclass=Meta):
         if isinstance(model_name, int):  # can refer to a model by index
             model_name = cls.available(name_only=True)[model_name]
 
-
-
-
-
+        try:
+            factory = cls.get_registry().create_model_factory(
+                model_name, service_name=service_name
+            )
+            return factory(*args, **kwargs)
+        except (InferenceServiceError, Exception) as e:
+            return cls._handle_model_error(model_name, e)
 
     @classmethod
     def add_model(cls, service_name, model_name) -> None:
@@ -95,28 +137,25 @@ class Model(metaclass=Meta):
         >>> Model.service_classes()
         [...]
         """
-        return [r for r in cls.services(
+        return [r for r in cls.services()]
 
     @classmethod
     def services(cls, name_only: bool = False) -> List[str]:
-        """Returns a list of services
-
-
-
-
-
-
-            [r._inference_service_ for r in cls.get_registry().services],
-            columns=["Service Name"],
-        )
-        else:
-            return PrettyList(
+        """Returns a list of services excluding 'test', sorted alphabetically.
+
+        >>> Model.services()
+        [...]
+        """
+        return PrettyList(
+            sorted(
                 [
-
+                    [r._inference_service_]
                     for r in cls.get_registry().services
-
-
-                )
+                    if r._inference_service_.lower() != "test"
+                ]
+            ),
+            columns=["Service Name"],
+        )
 
     @classmethod
     def services_with_local_keys(cls) -> set:
@@ -166,7 +205,15 @@ class Model(metaclass=Meta):
         search_term: str = None,
         name_only: bool = False,
         service: Optional[str] = None,
+        force_refresh: bool = False,
     ):
+        """Get available models
+
+        >>> Model.available()
+        [...]
+        >>> Model.available(service='openai')
+        [...]
+        """
         # if search_term is None and service is None:
         #     print("Getting available models...")
         #     print("You have local keys for the following services:")
@@ -177,13 +224,16 @@ class Model(metaclass=Meta):
         #     return None
 
         if service is not None:
-
+            known_services = [x[0] for x in cls.services(name_only=True)]
+            if service not in known_services:
                 raise ValueError(
                     f"Service {service} not found in available services.",
-                    f"Available services are: {
+                    f"Available services are: {known_services}",
                 )
 
-        full_list = cls.get_registry().available(
+        full_list = cls.get_registry().available(
+            service=service, force_refresh=force_refresh
+        )
 
         if search_term is None:
             if name_only:
@@ -287,6 +337,9 @@ class Model(metaclass=Meta):
         """
        Returns an example Model instance.
 
+        >>> Model.example()
+        Model(...)
+
        :param randomize: If True, the temperature is set to a random decimal between 0 and 1.
        """
        temperature = 0.5 if not randomize else round(random(), 2)
@@ -299,7 +352,7 @@ if __name__ == "__main__":
 
     doctest.testmod(optionflags=doctest.ELLIPSIS)
 
-    available = Model.available()
-    m = Model("gpt-4-1106-preview")
-    results = m.execute_model_call("Hello world")
-    print(results)
+    # available = Model.available()
+    # m = Model("gpt-4-1106-preview")
+    # results = m.execute_model_call("Hello world")
+    # print(results)
edsl/questions/QuestionBase.py
CHANGED
@@ -85,6 +85,9 @@ class QuestionBase(
         >>> Q.example()._simulate_answer()
         {'answer': '...', 'generated_tokens': ...}
         """
+        if self.question_type == "free_text":
+            return {"answer": "Hello, how are you?", 'generated_tokens': "Hello, how are you?"}
+
         simulated_answer = self.fake_data_factory.build().dict()
         if human_readable and hasattr(self, "question_options") and self.use_code:
             simulated_answer["answer"] = [
@@ -432,6 +435,24 @@ class QuestionBase(
 
         return Survey([self])
 
+    def humanize(
+        self,
+        project_name: str = "Project",
+        survey_description: Optional[str] = None,
+        survey_alias: Optional[str] = None,
+        survey_visibility: Optional["VisibilityType"] = "unlisted",
+    ) -> dict:
+        """
+        Turn a single question into a survey and send the survey to Coop.
+
+        Then, create a project on Coop so you can share the survey with human respondents.
+        """
+        s = self.to_survey()
+        project_details = s.humanize(
+            project_name, survey_description, survey_alias, survey_visibility
+        )
+        return project_details
+
     def by(self, *args) -> "Jobs":
         """Turn a single question into a survey and then a Job."""
         from edsl.surveys.Survey import Survey
edsl/questions/QuestionBasePromptsMixin.py
CHANGED
@@ -187,6 +187,73 @@ class QuestionBasePromptsMixin:
         from edsl.prompts import Prompt
 
         return Prompt(self.question_presentation) + Prompt(self.answering_instructions)
+
+
+    def detailed_parameters_by_key(self) -> dict[str, set[tuple[str, ...]]]:
+        """
+        Return a dictionary of parameters by key.
+
+        >>> from edsl import QuestionMultipleChoice
+        >>> QuestionMultipleChoice.example().detailed_parameters_by_key()
+        {'question_name': set(), 'question_text': set()}
+
+        >>> from edsl import QuestionFreeText
+        >>> q = QuestionFreeText(question_name = "example", question_text = "What is your name, {{ nickname }}, based on {{ q0.answer }}?")
+        >>> r = q.detailed_parameters_by_key()
+        >>> r == {'question_name': set(), 'question_text': {('q0', 'answer'), ('nickname',)}}
+        True
+        """
+        params_by_key = {}
+        for key, value in self.data.items():
+            if isinstance(value, str):
+                params_by_key[key] = self.extract_parameters(value)
+        return params_by_key
+
+    @staticmethod
+    def extract_parameters(txt: str) -> set[tuple[str, ...]]:
+        """Return all parameters of the question as tuples representing their full paths.
+
+        :param txt: The text to extract parameters from.
+        :return: A set of tuples representing the parameters.
+
+        >>> from edsl.questions import QuestionMultipleChoice
+        >>> d = QuestionMultipleChoice.example().extract_parameters("What is your name, {{ nickname }}, based on {{ q0.answer }}?")
+        >>> d =={('nickname',), ('q0', 'answer')}
+        True
+        """
+        from jinja2 import Environment, nodes
+
+        env = Environment()
+        #txt = self._all_text()
+        ast = env.parse(txt)
+
+        variables = set()
+        processed_nodes = set()  # Keep track of nodes we've processed
+
+        def visit_node(node, path=()):
+            if id(node) in processed_nodes:
+                return
+            processed_nodes.add(id(node))
+
+            if isinstance(node, nodes.Name):
+                # Only add the name if we're not in the middle of building a longer path
+                if not path:
+                    variables.add((node.name,))
+                else:
+                    variables.add((node.name,) + path)
+            elif isinstance(node, nodes.Getattr):
+                # Build path from bottom up
+                new_path = (node.attr,) + path
+                visit_node(node.node, new_path)
+
+        for node in ast.find_all((nodes.Name, nodes.Getattr)):
+            visit_node(node)
+
+        return variables
+
+    @property
+    def detailed_parameters(self):
+        return [".".join(p) for p in self.extract_parameters(self._all_text())]
 
     @property
     def parameters(self) -> set[str]:
@@ -219,3 +286,39 @@ class QuestionBasePromptsMixin:
             return self.new_default_instructions
         else:
             return self.applicable_prompts(model)[0]()
+
+    @staticmethod
+    def sequence_in_dict(d: dict, path: tuple[str, ...]) -> tuple[bool, any]:
+        """Check if a sequence of nested keys exists in a dictionary and return the value.
+
+        Args:
+            d: The dictionary to check
+            path: Tuple of keys representing the nested path
+
+        Returns:
+            tuple[bool, any]: (True, value) if the path exists, (False, None) otherwise
+
+        Example:
+            >>> sequence_in_dict = QuestionBasePromptsMixin.sequence_in_dict
+            >>> d = {'a': {'b': {'c': 1}}}
+            >>> sequence_in_dict(d, ('a', 'b', 'c'))
+            (True, 1)
+            >>> sequence_in_dict(d, ('a', 'b', 'd'))
+            (False, None)
+            >>> sequence_in_dict(d, ('x',))
+            (False, None)
+        """
+        try:
+            current = d
+            for key in path:
+                current = current.get(key)
+                if current is None:
+                    return (False, None)
+            return (True, current)
+        except (AttributeError, TypeError):
+            return (False, None)
+
+
+if __name__ == "__main__":
+    import doctest
+    doctest.testmod()
edsl/questions/QuestionFreeText.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 from typing import Any, Optional
 from uuid import uuid4
 
-from pydantic import field_validator
+from pydantic import field_validator, model_validator
 
 from edsl.questions.QuestionBase import QuestionBase
 from edsl.questions.response_validator_abc import ResponseValidatorABC
@@ -24,6 +24,17 @@ class FreeTextResponse(BaseModel):
     answer: str
     generated_tokens: Optional[str] = None
 
+    @model_validator(mode='after')
+    def validate_tokens_match_answer(self):
+        if self.generated_tokens is not None:  # If generated_tokens exists
+            # Ensure exact string equality
+            if self.answer.strip() != self.generated_tokens.strip():  # They MUST match exactly
+                raise ValueError(
+                    f"answer '{self.answer}' must exactly match generated_tokens '{self.generated_tokens}'. "
+                    f"Type of answer: {type(self.answer)}, Type of tokens: {type(self.generated_tokens)}"
+                )
+        return self
+
 
 class FreeTextResponseValidator(ResponseValidatorABC):
     required_params = []
@@ -37,10 +48,16 @@ class FreeTextResponseValidator(ResponseValidatorABC):
     ]
 
     def fix(self, response, verbose=False):
-
-
-
-
+        if response.get("generated_tokens") != response.get("answer"):
+            return {
+                "answer": str(response.get("generated_tokens")),
+                "generated_tokens": str(response.get("generated_tokens")),
+            }
+        else:
+            return {
+                "answer": str(response.get("generated_tokens")),
+                "generated_tokens": str(response.get("generated_tokens")),
+            }
 
 
 class QuestionFreeText(QuestionBase):
edsl/questions/descriptors.py
CHANGED
@@ -2,6 +2,7 @@
 
 from abc import ABC, abstractmethod
 import re
+import textwrap
 from typing import Any, Callable, List, Optional
 from edsl.exceptions.questions import (
     QuestionCreationValidationError,
@@ -404,6 +405,9 @@ class QuestionTextDescriptor(BaseDescriptor):
             raise Exception("Question is too short!")
         if not isinstance(value, str):
             raise Exception("Question must be a string!")
+
+        #value = textwrap.dedent(value).strip()
+
         if contains_single_braced_substring(value):
             import warnings
 
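
The commented-out line suggests a planned normalization of indented, triple-quoted question texts; a minimal illustration of what that call would do if enabled:

import textwrap

value = """
    What is your favorite color?
"""
# dedent removes the common leading whitespace; strip drops surrounding newlines
print(textwrap.dedent(value).strip())  # -> What is your favorite color?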
edsl/questions/question_base_gen_mixin.py
CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 import copy
 import itertools
-from typing import Optional, List, Callable, Type, TYPE_CHECKING
+from typing import Optional, List, Callable, Type, TYPE_CHECKING, Union
 
 if TYPE_CHECKING:
     from edsl.questions.QuestionBase import QuestionBase
@@ -9,7 +9,11 @@ if TYPE_CHECKING:
 
 
 class QuestionBaseGenMixin:
-    """Mixin for QuestionBase.
+    """Mixin for QuestionBase.
+
+    This mostly has functions that are used to generate new questions from existing ones.
+
+    """
 
     def copy(self) -> QuestionBase:
         """Return a deep copy of the question.
@@ -85,48 +89,110 @@ class QuestionBaseGenMixin:
         lp = LoopProcessor(self)
         return lp.process_templates(scenario_list)
 
-
-        """
-
-        :param replacement_dict: The dictionary of values to replace in the question components.
+    class MaxTemplateNestingExceeded(Exception):
+        """Raised when template rendering exceeds maximum allowed nesting level."""
+        pass
 
+    def render(self, replacement_dict: dict, return_dict: bool = False) -> Union["QuestionBase", dict]:
+        """Render the question components as jinja2 templates with the replacement dictionary.
+        Handles nested template variables by recursively rendering until all variables are resolved.
+
+        Raises:
+            MaxTemplateNestingExceeded: If template nesting exceeds MAX_NESTING levels
+
         >>> from edsl.questions.QuestionFreeText import QuestionFreeText
         >>> q = QuestionFreeText(question_name = "color", question_text = "What is your favorite {{ thing }}?")
         >>> q.render({"thing": "color"})
         Question('free_text', question_name = \"""color\""", question_text = \"""What is your favorite color?\""")
 
+        >>> from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
+        >>> q = QuestionMultipleChoice(question_name = "color", question_text = "What is your favorite {{ thing }}?", question_options = ["red", "blue", "green"])
+        >>> from edsl.scenarios.Scenario import Scenario
+        >>> q.render(Scenario({"thing": "color"})).data
+        {'question_name': 'color', 'question_text': 'What is your favorite color?', 'question_options': ['red', 'blue', 'green']}
+
+        >>> from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
+        >>> q = QuestionMultipleChoice(question_name = "color", question_text = "What is your favorite {{ thing }}?", question_options = ["red", "blue", "green"])
+        >>> q.render({"thing": 1}).data
+        {'question_name': 'color', 'question_text': 'What is your favorite 1?', 'question_options': ['red', 'blue', 'green']}
+
+
+        >>> from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
+        >>> from edsl.scenarios.Scenario import Scenario
+        >>> q = QuestionMultipleChoice(question_name = "color", question_text = "What is your favorite {{ thing }}?", question_options = ["red", "blue", "green"])
+        >>> q.render(Scenario({"thing": "color of {{ object }}", "object":"water"})).data
+        {'question_name': 'color', 'question_text': 'What is your favorite color of water?', 'question_options': ['red', 'blue', 'green']}
+
+
+        >>> from edsl.questions.QuestionFreeText import QuestionFreeText
+        >>> q = QuestionFreeText(question_name = "infinite", question_text = "This has {{ a }}")
+        >>> q.render({"a": "{{ b }}", "b": "{{ a }}"}) # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+        ...
+        edsl.questions.question_base_gen_mixin.QuestionBaseGenMixin.MaxTemplateNestingExceeded:...
         """
-        from jinja2 import Environment
+        from jinja2 import Environment, meta
         from edsl.scenarios.Scenario import Scenario
 
+        MAX_NESTING = 10  # Maximum allowed nesting levels
+
         strings_only_replacement_dict = {
             k: v for k, v in replacement_dict.items() if not isinstance(v, Scenario)
         }
 
+        def _has_unrendered_variables(template_str: str, env: Environment) -> bool:
+            """Check if the template string has any unrendered variables."""
+            if not isinstance(template_str, str):
+                return False
+            ast = env.parse(template_str)
+            return bool(meta.find_undeclared_variables(ast))
+
         def render_string(value: str) -> str:
             if value is None or not isinstance(value, str):
                 return value
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+            try:
+                env = Environment()
+                result = value
+                nesting_count = 0
+
+                while _has_unrendered_variables(result, env):
+                    if nesting_count >= MAX_NESTING:
+                        raise self.MaxTemplateNestingExceeded(
+                            f"Template rendering exceeded {MAX_NESTING} levels of nesting. "
+                            f"Current value: {result}"
+                        )
+
+                    template = env.from_string(result)
+                    new_result = template.render(strings_only_replacement_dict)
+                    if new_result == result:  # Break if no changes made
+                        break
+                    result = new_result
+                    nesting_count += 1
+
+                return result
+            except self.MaxTemplateNestingExceeded:
+                raise
+            except Exception as e:
+                import warnings
+                warnings.warn("Failed to render string: " + value)
+                return value
+        if return_dict:
+            return self._apply_function_dict(render_string)
+        else:
+            return self.apply_function(render_string)
+
 
     def apply_function(
-        self, func: Callable, exclude_components: List[str] = None
+        self, func: Callable, exclude_components: Optional[List[str]] = None
     ) -> QuestionBase:
-
+        from edsl.questions.QuestionBase import QuestionBase
+        d = self._apply_function_dict(func, exclude_components)
+        return QuestionBase.from_dict(d)
+
+    def _apply_function_dict(
+        self, func: Callable, exclude_components: Optional[List[str]] = None
+    ) -> dict:
+        """Apply a function to the question parts, excluding certain components.
 
         :param func: The function to apply to the question parts.
         :param exclude_components: The components to exclude from the function application.
@@ -141,7 +207,6 @@ class QuestionBaseGenMixin:
         Question('free_text', question_name = \"""COLOR\""", question_text = \"""WHAT IS YOUR FAVORITE COLOR?\""")
 
         """
-        from edsl.questions.QuestionBase import QuestionBase
 
         if exclude_components is None:
             exclude_components = ["question_name", "question_type"]
@@ -160,10 +225,10 @@ class QuestionBaseGenMixin:
                 d[key] = value
                 continue
             d[key] = func(value)
-        return
+        return d
 
 
 if __name__ == "__main__":
     import doctest
 
-    doctest.testmod()
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
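
A sketch of the new multi-pass rendering, following the doctests above: nested variables resolve across successive passes, while mutually recursive replacements hit the nesting guard.

from edsl.questions.QuestionFreeText import QuestionFreeText
from edsl.questions.question_base_gen_mixin import QuestionBaseGenMixin
from edsl.scenarios.Scenario import Scenario

q = QuestionFreeText(question_name="color", question_text="What is your favorite {{ thing }}?")

# "{{ thing }}" expands to a string that itself contains "{{ object }}",
# which is resolved on the next rendering pass.
print(q.render(Scenario({"thing": "color of {{ object }}", "object": "water"})))

# Mutually recursive replacements never settle, so after MAX_NESTING (10)
# passes render raises the nested exception class.
q2 = QuestionFreeText(question_name="infinite", question_text="This has {{ a }}")
try:
    q2.render({"a": "{{ b }}", "b": "{{ a }}"})
except QuestionBaseGenMixin.MaxTemplateNestingExceeded as e:
    print(e)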