edsl 0.1.43__py3-none-any.whl → 0.1.45__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. edsl/Base.py +15 -6
  2. edsl/__version__.py +1 -1
  3. edsl/agents/InvigilatorBase.py +3 -1
  4. edsl/agents/PromptConstructor.py +62 -34
  5. edsl/agents/QuestionInstructionPromptBuilder.py +111 -68
  6. edsl/agents/QuestionTemplateReplacementsBuilder.py +69 -16
  7. edsl/agents/question_option_processor.py +15 -6
  8. edsl/coop/CoopFunctionsMixin.py +3 -4
  9. edsl/coop/coop.py +56 -10
  10. edsl/enums.py +4 -1
  11. edsl/inference_services/AnthropicService.py +12 -8
  12. edsl/inference_services/AvailableModelFetcher.py +2 -0
  13. edsl/inference_services/AwsBedrock.py +1 -2
  14. edsl/inference_services/AzureAI.py +12 -9
  15. edsl/inference_services/GoogleService.py +10 -3
  16. edsl/inference_services/InferenceServiceABC.py +1 -0
  17. edsl/inference_services/InferenceServicesCollection.py +2 -2
  18. edsl/inference_services/MistralAIService.py +1 -2
  19. edsl/inference_services/OpenAIService.py +10 -4
  20. edsl/inference_services/PerplexityService.py +2 -1
  21. edsl/inference_services/TestService.py +1 -0
  22. edsl/inference_services/XAIService.py +11 -0
  23. edsl/inference_services/registry.py +2 -0
  24. edsl/jobs/Jobs.py +9 -0
  25. edsl/jobs/JobsChecks.py +11 -14
  26. edsl/jobs/JobsPrompts.py +3 -3
  27. edsl/jobs/async_interview_runner.py +3 -1
  28. edsl/jobs/check_survey_scenario_compatibility.py +5 -5
  29. edsl/jobs/interviews/InterviewExceptionEntry.py +12 -0
  30. edsl/jobs/tasks/TaskHistory.py +1 -1
  31. edsl/language_models/LanguageModel.py +3 -3
  32. edsl/language_models/PriceManager.py +45 -5
  33. edsl/language_models/model.py +89 -36
  34. edsl/questions/QuestionBase.py +21 -0
  35. edsl/questions/QuestionBasePromptsMixin.py +103 -0
  36. edsl/questions/QuestionFreeText.py +22 -5
  37. edsl/questions/descriptors.py +4 -0
  38. edsl/questions/question_base_gen_mixin.py +94 -29
  39. edsl/results/Dataset.py +65 -0
  40. edsl/results/DatasetExportMixin.py +299 -32
  41. edsl/results/Result.py +27 -0
  42. edsl/results/Results.py +24 -3
  43. edsl/results/ResultsGGMixin.py +7 -3
  44. edsl/scenarios/DocumentChunker.py +2 -0
  45. edsl/scenarios/FileStore.py +29 -8
  46. edsl/scenarios/PdfExtractor.py +21 -1
  47. edsl/scenarios/Scenario.py +25 -9
  48. edsl/scenarios/ScenarioList.py +73 -3
  49. edsl/scenarios/handlers/__init__.py +1 -0
  50. edsl/scenarios/handlers/docx.py +5 -1
  51. edsl/scenarios/handlers/jpeg.py +39 -0
  52. edsl/surveys/Survey.py +28 -6
  53. edsl/surveys/SurveyFlowVisualization.py +91 -43
  54. edsl/templates/error_reporting/exceptions_table.html +7 -8
  55. edsl/templates/error_reporting/interview_details.html +1 -1
  56. edsl/templates/error_reporting/interviews.html +0 -1
  57. edsl/templates/error_reporting/overview.html +2 -7
  58. edsl/templates/error_reporting/performance_plot.html +1 -1
  59. edsl/templates/error_reporting/report.css +1 -1
  60. edsl/utilities/PrettyList.py +14 -0
  61. edsl-0.1.45.dist-info/METADATA +246 -0
  62. {edsl-0.1.43.dist-info → edsl-0.1.45.dist-info}/RECORD +64 -62
  63. edsl-0.1.43.dist-info/METADATA +0 -110
  64. {edsl-0.1.43.dist-info → edsl-0.1.45.dist-info}/LICENSE +0 -0
  65. {edsl-0.1.43.dist-info → edsl-0.1.45.dist-info}/WHEEL +0 -0
@@ -1,6 +1,6 @@
1
1
  import textwrap
2
2
  from random import random
3
- from typing import Optional, TYPE_CHECKING, List
3
+ from typing import Optional, TYPE_CHECKING, List, Callable
4
4
 
5
5
  from edsl.utilities.PrettyList import PrettyList
6
6
  from edsl.config import CONFIG
@@ -11,17 +11,25 @@ from edsl.inference_services.InferenceServicesCollection import (
11
11
  from edsl.inference_services.data_structures import AvailableModels
12
12
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
13
13
  from edsl.enums import InferenceServiceLiteral
14
+ from edsl.exceptions.inference_services import InferenceServiceError
14
15
 
15
16
  if TYPE_CHECKING:
16
17
  from edsl.results.Dataset import Dataset
17
18
 
18
19
 
19
- def get_model_class(model_name, registry: Optional[InferenceServicesCollection] = None):
20
+ def get_model_class(
21
+ model_name,
22
+ registry: Optional[InferenceServicesCollection] = None,
23
+ service_name: Optional[InferenceServiceLiteral] = None,
24
+ ):
20
25
  from edsl.inference_services.registry import default
21
26
 
22
27
  registry = registry or default
23
- factory = registry.create_model_factory(model_name)
24
- return factory
28
+ try:
29
+ factory = registry.create_model_factory(model_name, service_name=service_name)
30
+ return factory
31
+ except (InferenceServiceError, Exception) as e:
32
+ return Model._handle_model_error(model_name, e)
25
33
 
26
34
 
27
35
  class Meta(type):
@@ -36,6 +44,9 @@ class Meta(type):
36
44
  To get the default model, you can leave out the model name.
37
45
  To see the available models, you can do:
38
46
  >>> Model.available()
47
+
48
+ Or to see the models for a specific service, you can do:
49
+ >>> Model.available(service='openai')
39
50
  """
40
51
  )
41
52
 
@@ -58,6 +69,33 @@ class Model(metaclass=Meta):
58
69
  """Set a new registry"""
59
70
  cls._registry = registry
60
71
 
72
+ @classmethod
73
+ def _handle_model_error(cls, model_name: str, error: Exception):
74
+ """Handle errors from model creation and execution with notebook-aware behavior."""
75
+ if isinstance(error, InferenceServiceError):
76
+ services = [s._inference_service_ for s in cls.get_registry().services]
77
+ message = (
78
+ f"Model '{model_name}' not found in any services.\n"
79
+ "It is likely that our registry is just out of date.\n"
80
+ "Simply adding the service name to your model call should fix this.\n"
81
+ f"Available services are: {services}\n"
82
+ f"To specify a model with a service, use:\n"
83
+ f'Model("{model_name}", service_name="<service_name>")'
84
+ )
85
+ else:
86
+ message = f"An error occurred: {str(error)}"
87
+
88
+ # Check if we're in a notebook environment
89
+ try:
90
+ get_ipython()
91
+ print(message)
92
+ return None
93
+ except NameError:
94
+ # Not in a notebook, raise the exception
95
+ if isinstance(error, InferenceServiceError):
96
+ raise InferenceServiceError(message)
97
+ raise error
98
+
61
99
  def __new__(
62
100
  cls,
63
101
  model_name: Optional[str] = None,
@@ -66,12 +104,13 @@ class Model(metaclass=Meta):
66
104
  *args,
67
105
  **kwargs,
68
106
  ):
69
- "Instantiate a new language model."
107
+ """Instantiate a new language model.
108
+ >>> Model()
109
+ Model(...)
110
+ """
70
111
  # Map index to the respective subclass
71
112
  if model_name is None:
72
- model_name = (
73
- cls.default_model
74
- ) # when model_name is None, use the default model, set in the config file
113
+ model_name = cls.default_model
75
114
 
76
115
  if registry is not None:
77
116
  cls.set_registry(registry)
@@ -79,10 +118,13 @@ class Model(metaclass=Meta):
79
118
  if isinstance(model_name, int): # can refer to a model by index
80
119
  model_name = cls.available(name_only=True)[model_name]
81
120
 
82
- factory = cls.get_registry().create_model_factory(
83
- model_name, service_name=service_name
84
- )
85
- return factory(*args, **kwargs)
121
+ try:
122
+ factory = cls.get_registry().create_model_factory(
123
+ model_name, service_name=service_name
124
+ )
125
+ return factory(*args, **kwargs)
126
+ except (InferenceServiceError, Exception) as e:
127
+ return cls._handle_model_error(model_name, e)
86
128
 
87
129
  @classmethod
88
130
  def add_model(cls, service_name, model_name) -> None:
@@ -95,28 +137,25 @@ class Model(metaclass=Meta):
95
137
  >>> Model.service_classes()
96
138
  [...]
97
139
  """
98
- return [r for r in cls.services(name_only=True)]
140
+ return [r for r in cls.services()]
99
141
 
100
142
  @classmethod
101
143
  def services(cls, name_only: bool = False) -> List[str]:
102
- """Returns a list of services, annotated with whether the user has local keys for them."""
103
- services_with_local_keys = set(cls.key_info().select("service").to_list())
104
- f = lambda service_name: (
105
- "yes" if service_name in services_with_local_keys else " "
106
- )
107
- if name_only:
108
- return PrettyList(
109
- [r._inference_service_ for r in cls.get_registry().services],
110
- columns=["Service Name"],
111
- )
112
- else:
113
- return PrettyList(
144
+ """Returns a list of services excluding 'test', sorted alphabetically.
145
+
146
+ >>> Model.services()
147
+ [...]
148
+ """
149
+ return PrettyList(
150
+ sorted(
114
151
  [
115
- (r._inference_service_, f(r._inference_service_))
152
+ [r._inference_service_]
116
153
  for r in cls.get_registry().services
117
- ],
118
- columns=["Service Name", "Local key?"],
119
- )
154
+ if r._inference_service_.lower() != "test"
155
+ ]
156
+ ),
157
+ columns=["Service Name"],
158
+ )
120
159
 
121
160
  @classmethod
122
161
  def services_with_local_keys(cls) -> set:
@@ -166,7 +205,15 @@ class Model(metaclass=Meta):
166
205
  search_term: str = None,
167
206
  name_only: bool = False,
168
207
  service: Optional[str] = None,
208
+ force_refresh: bool = False,
169
209
  ):
210
+ """Get available models
211
+
212
+ >>> Model.available()
213
+ [...]
214
+ >>> Model.available(service='openai')
215
+ [...]
216
+ """
170
217
  # if search_term is None and service is None:
171
218
  # print("Getting available models...")
172
219
  # print("You have local keys for the following services:")
@@ -177,13 +224,16 @@ class Model(metaclass=Meta):
177
224
  # return None
178
225
 
179
226
  if service is not None:
180
- if service not in cls.services(name_only=True):
227
+ known_services = [x[0] for x in cls.services(name_only=True)]
228
+ if service not in known_services:
181
229
  raise ValueError(
182
230
  f"Service {service} not found in available services.",
183
- f"Available services are: {cls.services()}",
231
+ f"Available services are: {known_services}",
184
232
  )
185
233
 
186
- full_list = cls.get_registry().available(service=service)
234
+ full_list = cls.get_registry().available(
235
+ service=service, force_refresh=force_refresh
236
+ )
187
237
 
188
238
  if search_term is None:
189
239
  if name_only:
@@ -287,6 +337,9 @@ class Model(metaclass=Meta):
287
337
  """
288
338
  Returns an example Model instance.
289
339
 
340
+ >>> Model.example()
341
+ Model(...)
342
+
290
343
  :param randomize: If True, the temperature is set to a random decimal between 0 and 1.
291
344
  """
292
345
  temperature = 0.5 if not randomize else round(random(), 2)
@@ -299,7 +352,7 @@ if __name__ == "__main__":
299
352
 
300
353
  doctest.testmod(optionflags=doctest.ELLIPSIS)
301
354
 
302
- available = Model.available()
303
- m = Model("gpt-4-1106-preview")
304
- results = m.execute_model_call("Hello world")
305
- print(results)
355
+ # available = Model.available()
356
+ # m = Model("gpt-4-1106-preview")
357
+ # results = m.execute_model_call("Hello world")
358
+ # print(results)
@@ -85,6 +85,9 @@ class QuestionBase(
85
85
  >>> Q.example()._simulate_answer()
86
86
  {'answer': '...', 'generated_tokens': ...}
87
87
  """
88
+ if self.question_type == "free_text":
89
+ return {"answer": "Hello, how are you?", 'generated_tokens': "Hello, how are you?"}
90
+
88
91
  simulated_answer = self.fake_data_factory.build().dict()
89
92
  if human_readable and hasattr(self, "question_options") and self.use_code:
90
93
  simulated_answer["answer"] = [
@@ -432,6 +435,24 @@ class QuestionBase(
432
435
 
433
436
  return Survey([self])
434
437
 
438
+ def humanize(
439
+ self,
440
+ project_name: str = "Project",
441
+ survey_description: Optional[str] = None,
442
+ survey_alias: Optional[str] = None,
443
+ survey_visibility: Optional["VisibilityType"] = "unlisted",
444
+ ) -> dict:
445
+ """
446
+ Turn a single question into a survey and send the survey to Coop.
447
+
448
+ Then, create a project on Coop so you can share the survey with human respondents.
449
+ """
450
+ s = self.to_survey()
451
+ project_details = s.humanize(
452
+ project_name, survey_description, survey_alias, survey_visibility
453
+ )
454
+ return project_details
455
+
435
456
  def by(self, *args) -> "Jobs":
436
457
  """Turn a single question into a survey and then a Job."""
437
458
  from edsl.surveys.Survey import Survey
@@ -187,6 +187,73 @@ class QuestionBasePromptsMixin:
187
187
  from edsl.prompts import Prompt
188
188
 
189
189
  return Prompt(self.question_presentation) + Prompt(self.answering_instructions)
190
+
191
+
192
+ def detailed_parameters_by_key(self) -> dict[str, set[tuple[str, ...]]]:
193
+ """
194
+ Return a dictionary of parameters by key.
195
+
196
+ >>> from edsl import QuestionMultipleChoice
197
+ >>> QuestionMultipleChoice.example().detailed_parameters_by_key()
198
+ {'question_name': set(), 'question_text': set()}
199
+
200
+ >>> from edsl import QuestionFreeText
201
+ >>> q = QuestionFreeText(question_name = "example", question_text = "What is your name, {{ nickname }}, based on {{ q0.answer }}?")
202
+ >>> r = q.detailed_parameters_by_key()
203
+ >>> r == {'question_name': set(), 'question_text': {('q0', 'answer'), ('nickname',)}}
204
+ True
205
+ """
206
+ params_by_key = {}
207
+ for key, value in self.data.items():
208
+ if isinstance(value, str):
209
+ params_by_key[key] = self.extract_parameters(value)
210
+ return params_by_key
211
+
212
+ @staticmethod
213
+ def extract_parameters(txt: str) -> set[tuple[str, ...]]:
214
+ """Return all parameters of the question as tuples representing their full paths.
215
+
216
+ :param txt: The text to extract parameters from.
217
+ :return: A set of tuples representing the parameters.
218
+
219
+ >>> from edsl.questions import QuestionMultipleChoice
220
+ >>> d = QuestionMultipleChoice.example().extract_parameters("What is your name, {{ nickname }}, based on {{ q0.answer }}?")
221
+ >>> d =={('nickname',), ('q0', 'answer')}
222
+ True
223
+ """
224
+ from jinja2 import Environment, nodes
225
+
226
+ env = Environment()
227
+ #txt = self._all_text()
228
+ ast = env.parse(txt)
229
+
230
+ variables = set()
231
+ processed_nodes = set() # Keep track of nodes we've processed
232
+
233
+ def visit_node(node, path=()):
234
+ if id(node) in processed_nodes:
235
+ return
236
+ processed_nodes.add(id(node))
237
+
238
+ if isinstance(node, nodes.Name):
239
+ # Only add the name if we're not in the middle of building a longer path
240
+ if not path:
241
+ variables.add((node.name,))
242
+ else:
243
+ variables.add((node.name,) + path)
244
+ elif isinstance(node, nodes.Getattr):
245
+ # Build path from bottom up
246
+ new_path = (node.attr,) + path
247
+ visit_node(node.node, new_path)
248
+
249
+ for node in ast.find_all((nodes.Name, nodes.Getattr)):
250
+ visit_node(node)
251
+
252
+ return variables
253
+
254
+ @property
255
+ def detailed_parameters(self):
256
+ return [".".join(p) for p in self.extract_parameters(self._all_text())]
190
257
 
191
258
  @property
192
259
  def parameters(self) -> set[str]:
@@ -219,3 +286,39 @@ class QuestionBasePromptsMixin:
219
286
  return self.new_default_instructions
220
287
  else:
221
288
  return self.applicable_prompts(model)[0]()
289
+
290
+ @staticmethod
291
+ def sequence_in_dict(d: dict, path: tuple[str, ...]) -> tuple[bool, any]:
292
+ """Check if a sequence of nested keys exists in a dictionary and return the value.
293
+
294
+ Args:
295
+ d: The dictionary to check
296
+ path: Tuple of keys representing the nested path
297
+
298
+ Returns:
299
+ tuple[bool, any]: (True, value) if the path exists, (False, None) otherwise
300
+
301
+ Example:
302
+ >>> sequence_in_dict = QuestionBasePromptsMixin.sequence_in_dict
303
+ >>> d = {'a': {'b': {'c': 1}}}
304
+ >>> sequence_in_dict(d, ('a', 'b', 'c'))
305
+ (True, 1)
306
+ >>> sequence_in_dict(d, ('a', 'b', 'd'))
307
+ (False, None)
308
+ >>> sequence_in_dict(d, ('x',))
309
+ (False, None)
310
+ """
311
+ try:
312
+ current = d
313
+ for key in path:
314
+ current = current.get(key)
315
+ if current is None:
316
+ return (False, None)
317
+ return (True, current)
318
+ except (AttributeError, TypeError):
319
+ return (False, None)
320
+
321
+
322
+ if __name__ == "__main__":
323
+ import doctest
324
+ doctest.testmod()
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
  from typing import Any, Optional
3
3
  from uuid import uuid4
4
4
 
5
- from pydantic import field_validator
5
+ from pydantic import field_validator, model_validator
6
6
 
7
7
  from edsl.questions.QuestionBase import QuestionBase
8
8
  from edsl.questions.response_validator_abc import ResponseValidatorABC
@@ -24,6 +24,17 @@ class FreeTextResponse(BaseModel):
24
24
  answer: str
25
25
  generated_tokens: Optional[str] = None
26
26
 
27
+ @model_validator(mode='after')
28
+ def validate_tokens_match_answer(self):
29
+ if self.generated_tokens is not None: # If generated_tokens exists
30
+ # Ensure exact string equality
31
+ if self.answer.strip() != self.generated_tokens.strip(): # They MUST match exactly
32
+ raise ValueError(
33
+ f"answer '{self.answer}' must exactly match generated_tokens '{self.generated_tokens}'. "
34
+ f"Type of answer: {type(self.answer)}, Type of tokens: {type(self.generated_tokens)}"
35
+ )
36
+ return self
37
+
27
38
 
28
39
  class FreeTextResponseValidator(ResponseValidatorABC):
29
40
  required_params = []
@@ -37,10 +48,16 @@ class FreeTextResponseValidator(ResponseValidatorABC):
37
48
  ]
38
49
 
39
50
  def fix(self, response, verbose=False):
40
- return {
41
- "answer": str(response.get("generated_tokens")),
42
- "generated_tokens": str(response.get("generated_tokens")),
43
- }
51
+ if response.get("generated_tokens") != response.get("answer"):
52
+ return {
53
+ "answer": str(response.get("generated_tokens")),
54
+ "generated_tokens": str(response.get("generated_tokens")),
55
+ }
56
+ else:
57
+ return {
58
+ "answer": str(response.get("generated_tokens")),
59
+ "generated_tokens": str(response.get("generated_tokens")),
60
+ }
44
61
 
45
62
 
46
63
  class QuestionFreeText(QuestionBase):
@@ -2,6 +2,7 @@
2
2
 
3
3
  from abc import ABC, abstractmethod
4
4
  import re
5
+ import textwrap
5
6
  from typing import Any, Callable, List, Optional
6
7
  from edsl.exceptions.questions import (
7
8
  QuestionCreationValidationError,
@@ -404,6 +405,9 @@ class QuestionTextDescriptor(BaseDescriptor):
404
405
  raise Exception("Question is too short!")
405
406
  if not isinstance(value, str):
406
407
  raise Exception("Question must be a string!")
408
+
409
+ #value = textwrap.dedent(value).strip()
410
+
407
411
  if contains_single_braced_substring(value):
408
412
  import warnings
409
413
 
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
  import copy
3
3
  import itertools
4
- from typing import Optional, List, Callable, Type, TYPE_CHECKING
4
+ from typing import Optional, List, Callable, Type, TYPE_CHECKING, Union
5
5
 
6
6
  if TYPE_CHECKING:
7
7
  from edsl.questions.QuestionBase import QuestionBase
@@ -9,7 +9,11 @@ if TYPE_CHECKING:
9
9
 
10
10
 
11
11
  class QuestionBaseGenMixin:
12
- """Mixin for QuestionBase."""
12
+ """Mixin for QuestionBase.
13
+
14
+ This mostly has functions that are used to generate new questions from existing ones.
15
+
16
+ """
13
17
 
14
18
  def copy(self) -> QuestionBase:
15
19
  """Return a deep copy of the question.
@@ -85,48 +89,110 @@ class QuestionBaseGenMixin:
85
89
  lp = LoopProcessor(self)
86
90
  return lp.process_templates(scenario_list)
87
91
 
88
- def render(self, replacement_dict: dict) -> "QuestionBase":
89
- """Render the question components as jinja2 templates with the replacement dictionary.
90
-
91
- :param replacement_dict: The dictionary of values to replace in the question components.
92
+ class MaxTemplateNestingExceeded(Exception):
93
+ """Raised when template rendering exceeds maximum allowed nesting level."""
94
+ pass
92
95
 
96
+ def render(self, replacement_dict: dict, return_dict: bool = False) -> Union["QuestionBase", dict]:
97
+ """Render the question components as jinja2 templates with the replacement dictionary.
98
+ Handles nested template variables by recursively rendering until all variables are resolved.
99
+
100
+ Raises:
101
+ MaxTemplateNestingExceeded: If template nesting exceeds MAX_NESTING levels
102
+
93
103
  >>> from edsl.questions.QuestionFreeText import QuestionFreeText
94
104
  >>> q = QuestionFreeText(question_name = "color", question_text = "What is your favorite {{ thing }}?")
95
105
  >>> q.render({"thing": "color"})
96
106
  Question('free_text', question_name = \"""color\""", question_text = \"""What is your favorite color?\""")
97
107
 
108
+ >>> from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
109
+ >>> q = QuestionMultipleChoice(question_name = "color", question_text = "What is your favorite {{ thing }}?", question_options = ["red", "blue", "green"])
110
+ >>> from edsl.scenarios.Scenario import Scenario
111
+ >>> q.render(Scenario({"thing": "color"})).data
112
+ {'question_name': 'color', 'question_text': 'What is your favorite color?', 'question_options': ['red', 'blue', 'green']}
113
+
114
+ >>> from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
115
+ >>> q = QuestionMultipleChoice(question_name = "color", question_text = "What is your favorite {{ thing }}?", question_options = ["red", "blue", "green"])
116
+ >>> q.render({"thing": 1}).data
117
+ {'question_name': 'color', 'question_text': 'What is your favorite 1?', 'question_options': ['red', 'blue', 'green']}
118
+
119
+
120
+ >>> from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
121
+ >>> from edsl.scenarios.Scenario import Scenario
122
+ >>> q = QuestionMultipleChoice(question_name = "color", question_text = "What is your favorite {{ thing }}?", question_options = ["red", "blue", "green"])
123
+ >>> q.render(Scenario({"thing": "color of {{ object }}", "object":"water"})).data
124
+ {'question_name': 'color', 'question_text': 'What is your favorite color of water?', 'question_options': ['red', 'blue', 'green']}
125
+
126
+
127
+ >>> from edsl.questions.QuestionFreeText import QuestionFreeText
128
+ >>> q = QuestionFreeText(question_name = "infinite", question_text = "This has {{ a }}")
129
+ >>> q.render({"a": "{{ b }}", "b": "{{ a }}"}) # doctest: +IGNORE_EXCEPTION_DETAIL
130
+ Traceback (most recent call last):
131
+ ...
132
+ edsl.questions.question_base_gen_mixin.QuestionBaseGenMixin.MaxTemplateNestingExceeded:...
98
133
  """
99
- from jinja2 import Environment
134
+ from jinja2 import Environment, meta
100
135
  from edsl.scenarios.Scenario import Scenario
101
136
 
137
+ MAX_NESTING = 10 # Maximum allowed nesting levels
138
+
102
139
  strings_only_replacement_dict = {
103
140
  k: v for k, v in replacement_dict.items() if not isinstance(v, Scenario)
104
141
  }
105
142
 
143
+ def _has_unrendered_variables(template_str: str, env: Environment) -> bool:
144
+ """Check if the template string has any unrendered variables."""
145
+ if not isinstance(template_str, str):
146
+ return False
147
+ ast = env.parse(template_str)
148
+ return bool(meta.find_undeclared_variables(ast))
149
+
106
150
  def render_string(value: str) -> str:
107
151
  if value is None or not isinstance(value, str):
108
152
  return value
109
- else:
110
- try:
111
- return (
112
- Environment()
113
- .from_string(value)
114
- .render(strings_only_replacement_dict)
115
- )
116
- except Exception as e:
117
- #breakpoint()
118
- import warnings
119
-
120
- warnings.warn("Failed to render string: " + value)
121
- # breakpoint()
122
- return value
123
-
124
- return self.apply_function(render_string)
125
-
153
+
154
+ try:
155
+ env = Environment()
156
+ result = value
157
+ nesting_count = 0
158
+
159
+ while _has_unrendered_variables(result, env):
160
+ if nesting_count >= MAX_NESTING:
161
+ raise self.MaxTemplateNestingExceeded(
162
+ f"Template rendering exceeded {MAX_NESTING} levels of nesting. "
163
+ f"Current value: {result}"
164
+ )
165
+
166
+ template = env.from_string(result)
167
+ new_result = template.render(strings_only_replacement_dict)
168
+ if new_result == result: # Break if no changes made
169
+ break
170
+ result = new_result
171
+ nesting_count += 1
172
+
173
+ return result
174
+ except self.MaxTemplateNestingExceeded:
175
+ raise
176
+ except Exception as e:
177
+ import warnings
178
+ warnings.warn("Failed to render string: " + value)
179
+ return value
180
+ if return_dict:
181
+ return self._apply_function_dict(render_string)
182
+ else:
183
+ return self.apply_function(render_string)
184
+
126
185
  def apply_function(
127
- self, func: Callable, exclude_components: List[str] = None
186
+ self, func: Callable, exclude_components: Optional[List[str]] = None
128
187
  ) -> QuestionBase:
129
- """Apply a function to the question parts
188
+ from edsl.questions.QuestionBase import QuestionBase
189
+ d = self._apply_function_dict(func, exclude_components)
190
+ return QuestionBase.from_dict(d)
191
+
192
+ def _apply_function_dict(
193
+ self, func: Callable, exclude_components: Optional[List[str]] = None
194
+ ) -> dict:
195
+ """Apply a function to the question parts, excluding certain components.
130
196
 
131
197
  :param func: The function to apply to the question parts.
132
198
  :param exclude_components: The components to exclude from the function application.
@@ -141,7 +207,6 @@ class QuestionBaseGenMixin:
141
207
  Question('free_text', question_name = \"""COLOR\""", question_text = \"""WHAT IS YOUR FAVORITE COLOR?\""")
142
208
 
143
209
  """
144
- from edsl.questions.QuestionBase import QuestionBase
145
210
 
146
211
  if exclude_components is None:
147
212
  exclude_components = ["question_name", "question_type"]
@@ -160,10 +225,10 @@ class QuestionBaseGenMixin:
160
225
  d[key] = value
161
226
  continue
162
227
  d[key] = func(value)
163
- return QuestionBase.from_dict(d)
228
+ return d
164
229
 
165
230
 
166
231
  if __name__ == "__main__":
167
232
  import doctest
168
233
 
169
- doctest.testmod()
234
+ doctest.testmod(optionflags=doctest.ELLIPSIS)