edsl 0.1.44__py3-none-any.whl → 0.1.45__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/InvigilatorBase.py +3 -1
- edsl/agents/PromptConstructor.py +62 -34
- edsl/agents/QuestionInstructionPromptBuilder.py +111 -68
- edsl/agents/QuestionTemplateReplacementsBuilder.py +69 -16
- edsl/agents/question_option_processor.py +15 -6
- edsl/coop/CoopFunctionsMixin.py +3 -4
- edsl/coop/coop.py +23 -9
- edsl/enums.py +3 -3
- edsl/inference_services/AnthropicService.py +11 -9
- edsl/inference_services/AvailableModelFetcher.py +2 -0
- edsl/inference_services/AwsBedrock.py +1 -2
- edsl/inference_services/AzureAI.py +12 -9
- edsl/inference_services/GoogleService.py +9 -4
- edsl/inference_services/InferenceServicesCollection.py +2 -2
- edsl/inference_services/MistralAIService.py +1 -2
- edsl/inference_services/OpenAIService.py +9 -4
- edsl/inference_services/PerplexityService.py +2 -1
- edsl/inference_services/{GrokService.py → XAIService.py} +2 -2
- edsl/inference_services/registry.py +2 -2
- edsl/jobs/Jobs.py +9 -0
- edsl/jobs/JobsChecks.py +10 -13
- edsl/jobs/async_interview_runner.py +3 -1
- edsl/jobs/check_survey_scenario_compatibility.py +5 -5
- edsl/jobs/interviews/InterviewExceptionEntry.py +12 -0
- edsl/jobs/tasks/TaskHistory.py +1 -1
- edsl/language_models/LanguageModel.py +0 -3
- edsl/language_models/PriceManager.py +45 -5
- edsl/language_models/model.py +47 -26
- edsl/questions/QuestionBase.py +21 -0
- edsl/questions/QuestionBasePromptsMixin.py +103 -0
- edsl/questions/QuestionFreeText.py +22 -5
- edsl/questions/descriptors.py +4 -0
- edsl/questions/question_base_gen_mixin.py +94 -29
- edsl/results/Dataset.py +65 -0
- edsl/results/DatasetExportMixin.py +299 -32
- edsl/results/Result.py +27 -0
- edsl/results/Results.py +22 -2
- edsl/results/ResultsGGMixin.py +7 -3
- edsl/scenarios/DocumentChunker.py +2 -0
- edsl/scenarios/FileStore.py +10 -0
- edsl/scenarios/PdfExtractor.py +21 -1
- edsl/scenarios/Scenario.py +25 -9
- edsl/scenarios/ScenarioList.py +73 -3
- edsl/scenarios/handlers/__init__.py +1 -0
- edsl/scenarios/handlers/docx.py +5 -1
- edsl/scenarios/handlers/jpeg.py +39 -0
- edsl/surveys/Survey.py +5 -4
- edsl/surveys/SurveyFlowVisualization.py +91 -43
- edsl/templates/error_reporting/exceptions_table.html +7 -8
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/interviews.html +0 -1
- edsl/templates/error_reporting/overview.html +2 -7
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +1 -1
- edsl/utilities/PrettyList.py +14 -0
- edsl-0.1.45.dist-info/METADATA +246 -0
- {edsl-0.1.44.dist-info → edsl-0.1.45.dist-info}/RECORD +60 -59
- edsl-0.1.44.dist-info/METADATA +0 -110
- {edsl-0.1.44.dist-info → edsl-0.1.45.dist-info}/LICENSE +0 -0
- {edsl-0.1.44.dist-info → edsl-0.1.45.dist-info}/WHEEL +0 -0
edsl/language_models/model.py
CHANGED
@@ -17,7 +17,11 @@ if TYPE_CHECKING:
     from edsl.results.Dataset import Dataset
 
 
-def get_model_class(
+def get_model_class(
+    model_name,
+    registry: Optional[InferenceServicesCollection] = None,
+    service_name: Optional[InferenceServiceLiteral] = None,
+):
     from edsl.inference_services.registry import default
 
     registry = registry or default
@@ -40,6 +44,9 @@ class Meta(type):
         To get the default model, you can leave out the model name.
         To see the available models, you can do:
         >>> Model.available()
+
+        Or to see the models for a specific service, you can do:
+        >>> Model.available(service='openai')
         """
     )
 
@@ -97,7 +104,10 @@ class Model(metaclass=Meta):
         *args,
         **kwargs,
     ):
-        "Instantiate a new language model.
+        """Instantiate a new language model.
+        >>> Model()
+        Model(...)
+        """
         # Map index to the respective subclass
         if model_name is None:
             model_name = cls.default_model
@@ -127,28 +137,25 @@ class Model(metaclass=Meta):
         >>> Model.service_classes()
         [...]
         """
-        return [r for r in cls.services(
+        return [r for r in cls.services()]
 
     @classmethod
     def services(cls, name_only: bool = False) -> List[str]:
-        """Returns a list of services
-
-
-
-
-
-
-            [r._inference_service_ for r in cls.get_registry().services],
-            columns=["Service Name"],
-        )
-        else:
-            return PrettyList(
+        """Returns a list of services excluding 'test', sorted alphabetically.
+
+        >>> Model.services()
+        [...]
+        """
+        return PrettyList(
+            sorted(
                 [
-
+                    [r._inference_service_]
                     for r in cls.get_registry().services
-
-
-                )
+                    if r._inference_service_.lower() != "test"
+                ]
+            ),
+            columns=["Service Name"],
+        )
 
     @classmethod
     def services_with_local_keys(cls) -> set:
@@ -198,7 +205,15 @@ class Model(metaclass=Meta):
         search_term: str = None,
         name_only: bool = False,
         service: Optional[str] = None,
+        force_refresh: bool = False,
     ):
+        """Get available models
+
+        >>> Model.available()
+        [...]
+        >>> Model.available(service='openai')
+        [...]
+        """
         # if search_term is None and service is None:
         #     print("Getting available models...")
         #     print("You have local keys for the following services:")
@@ -209,13 +224,16 @@ class Model(metaclass=Meta):
         #     return None
 
         if service is not None:
-
+            known_services = [x[0] for x in cls.services(name_only=True)]
+            if service not in known_services:
                 raise ValueError(
                     f"Service {service} not found in available services.",
-                    f"Available services are: {
+                    f"Available services are: {known_services}",
                 )
 
-        full_list = cls.get_registry().available(
+        full_list = cls.get_registry().available(
+            service=service, force_refresh=force_refresh
+        )
 
         if search_term is None:
             if name_only:
@@ -319,6 +337,9 @@ class Model(metaclass=Meta):
         """
         Returns an example Model instance.
 
+        >>> Model.example()
+        Model(...)
+
         :param randomize: If True, the temperature is set to a random decimal between 0 and 1.
         """
         temperature = 0.5 if not randomize else round(random(), 2)
@@ -331,7 +352,7 @@ if __name__ == "__main__":
 
     doctest.testmod(optionflags=doctest.ELLIPSIS)
 
-    available = Model.available()
-    m = Model("gpt-4-1106-preview")
-    results = m.execute_model_call("Hello world")
-    print(results)
+    # available = Model.available()
+    # m = Model("gpt-4-1106-preview")
+    # results = m.execute_model_call("Hello world")
+    # print(results)
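With the changes above, models can be listed per service and the cached model list can be refreshed. A quick sketch (the model name is illustrative, not taken from the diff):

    from edsl import Model

    Model.services()                                       # service names, 'test' excluded, sorted
    Model.available(service='openai')                      # models for a single service
    Model.available(service='openai', force_refresh=True)  # new flag: bypass the cached list
    m = Model('gpt-4o')                                    # illustrative model name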
edsl/questions/QuestionBase.py
CHANGED
@@ -85,6 +85,9 @@ class QuestionBase(
         >>> Q.example()._simulate_answer()
         {'answer': '...', 'generated_tokens': ...}
         """
+        if self.question_type == "free_text":
+            return {"answer": "Hello, how are you?", 'generated_tokens': "Hello, how are you?"}
+
         simulated_answer = self.fake_data_factory.build().dict()
         if human_readable and hasattr(self, "question_options") and self.use_code:
             simulated_answer["answer"] = [
@@ -432,6 +435,24 @@ class QuestionBase(
 
         return Survey([self])
 
+    def humanize(
+        self,
+        project_name: str = "Project",
+        survey_description: Optional[str] = None,
+        survey_alias: Optional[str] = None,
+        survey_visibility: Optional["VisibilityType"] = "unlisted",
+    ) -> dict:
+        """
+        Turn a single question into a survey and send the survey to Coop.
+
+        Then, create a project on Coop so you can share the survey with human respondents.
+        """
+        s = self.to_survey()
+        project_details = s.humanize(
+            project_name, survey_description, survey_alias, survey_visibility
+        )
+        return project_details
+
     def by(self, *args) -> "Jobs":
         """Turn a single question into a survey and then a Job."""
         from edsl.surveys.Survey import Survey
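The new `humanize` is a convenience wrapper around `Survey.humanize`. A sketch of the flow, assuming a configured Coop account (the question and project name are placeholders):

    from edsl import QuestionFreeText

    q = QuestionFreeText(question_name="feedback", question_text="What did you think?")
    # Wraps the question in a one-question Survey, sends it to Coop,
    # and creates a project for human respondents (placeholder names)
    details = q.humanize(project_name="Feedback Pilot", survey_visibility="unlisted")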
edsl/questions/QuestionBasePromptsMixin.py
CHANGED
@@ -187,6 +187,73 @@ class QuestionBasePromptsMixin:
         from edsl.prompts import Prompt
 
         return Prompt(self.question_presentation) + Prompt(self.answering_instructions)
+
+
+    def detailed_parameters_by_key(self) -> dict[str, set[tuple[str, ...]]]:
+        """
+        Return a dictionary of parameters by key.
+
+        >>> from edsl import QuestionMultipleChoice
+        >>> QuestionMultipleChoice.example().detailed_parameters_by_key()
+        {'question_name': set(), 'question_text': set()}
+
+        >>> from edsl import QuestionFreeText
+        >>> q = QuestionFreeText(question_name = "example", question_text = "What is your name, {{ nickname }}, based on {{ q0.answer }}?")
+        >>> r = q.detailed_parameters_by_key()
+        >>> r == {'question_name': set(), 'question_text': {('q0', 'answer'), ('nickname',)}}
+        True
+        """
+        params_by_key = {}
+        for key, value in self.data.items():
+            if isinstance(value, str):
+                params_by_key[key] = self.extract_parameters(value)
+        return params_by_key
+
+    @staticmethod
+    def extract_parameters(txt: str) -> set[tuple[str, ...]]:
+        """Return all parameters of the question as tuples representing their full paths.
+
+        :param txt: The text to extract parameters from.
+        :return: A set of tuples representing the parameters.
+
+        >>> from edsl.questions import QuestionMultipleChoice
+        >>> d = QuestionMultipleChoice.example().extract_parameters("What is your name, {{ nickname }}, based on {{ q0.answer }}?")
+        >>> d == {('nickname',), ('q0', 'answer')}
+        True
+        """
+        from jinja2 import Environment, nodes
+
+        env = Environment()
+        #txt = self._all_text()
+        ast = env.parse(txt)
+
+        variables = set()
+        processed_nodes = set()  # Keep track of nodes we've processed
+
+        def visit_node(node, path=()):
+            if id(node) in processed_nodes:
+                return
+            processed_nodes.add(id(node))
+
+            if isinstance(node, nodes.Name):
+                # Only add the name if we're not in the middle of building a longer path
+                if not path:
+                    variables.add((node.name,))
+                else:
+                    variables.add((node.name,) + path)
+            elif isinstance(node, nodes.Getattr):
+                # Build path from bottom up
+                new_path = (node.attr,) + path
+                visit_node(node.node, new_path)
+
+        for node in ast.find_all((nodes.Name, nodes.Getattr)):
+            visit_node(node)
+
+        return variables
+
+    @property
+    def detailed_parameters(self):
+        return [".".join(p) for p in self.extract_parameters(self._all_text())]
 
     @property
     def parameters(self) -> set[str]:
@@ -219,3 +286,39 @@ class QuestionBasePromptsMixin:
             return self.new_default_instructions
         else:
             return self.applicable_prompts(model)[0]()
+
+    @staticmethod
+    def sequence_in_dict(d: dict, path: tuple[str, ...]) -> tuple[bool, any]:
+        """Check if a sequence of nested keys exists in a dictionary and return the value.
+
+        Args:
+            d: The dictionary to check
+            path: Tuple of keys representing the nested path
+
+        Returns:
+            tuple[bool, any]: (True, value) if the path exists, (False, None) otherwise
+
+        Example:
+            >>> sequence_in_dict = QuestionBasePromptsMixin.sequence_in_dict
+            >>> d = {'a': {'b': {'c': 1}}}
+            >>> sequence_in_dict(d, ('a', 'b', 'c'))
+            (True, 1)
+            >>> sequence_in_dict(d, ('a', 'b', 'd'))
+            (False, None)
+            >>> sequence_in_dict(d, ('x',))
+            (False, None)
+        """
+        try:
+            current = d
+            for key in path:
+                current = current.get(key)
+                if current is None:
+                    return (False, None)
+            return (True, current)
+        except (AttributeError, TypeError):
+            return (False, None)
+
+
+if __name__ == "__main__":
+    import doctest
+    doctest.testmod()
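Because `extract_parameters` walks the jinja2 AST, dotted references such as `q0.answer` come back as path tuples rather than bare names. A sketch that mirrors the doctests above:

    from edsl import QuestionFreeText

    q = QuestionFreeText(
        question_name="example",
        question_text="What is your name, {{ nickname }}, based on {{ q0.answer }}?",
    )
    q.extract_parameters(q.question_text)  # {('nickname',), ('q0', 'answer')}
    q.detailed_parameters                  # ['nickname', 'q0.answer'] (order may vary; built from a set)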
edsl/questions/QuestionFreeText.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 from typing import Any, Optional
 from uuid import uuid4
 
-from pydantic import field_validator
+from pydantic import field_validator, model_validator
 
 from edsl.questions.QuestionBase import QuestionBase
 from edsl.questions.response_validator_abc import ResponseValidatorABC
@@ -24,6 +24,17 @@ class FreeTextResponse(BaseModel):
     answer: str
     generated_tokens: Optional[str] = None
 
+    @model_validator(mode='after')
+    def validate_tokens_match_answer(self):
+        if self.generated_tokens is not None:  # If generated_tokens exists
+            # Ensure exact string equality
+            if self.answer.strip() != self.generated_tokens.strip():  # They MUST match exactly
+                raise ValueError(
+                    f"answer '{self.answer}' must exactly match generated_tokens '{self.generated_tokens}'. "
+                    f"Type of answer: {type(self.answer)}, Type of tokens: {type(self.generated_tokens)}"
+                )
+        return self
+
 
 class FreeTextResponseValidator(ResponseValidatorABC):
     required_params = []
@@ -37,10 +48,16 @@ class FreeTextResponseValidator(ResponseValidatorABC):
     ]
 
     def fix(self, response, verbose=False):
-
-
-
-
+        if response.get("generated_tokens") != response.get("answer"):
+            return {
+                "answer": str(response.get("generated_tokens")),
+                "generated_tokens": str(response.get("generated_tokens")),
+            }
+        else:
+            return {
+                "answer": str(response.get("generated_tokens")),
+                "generated_tokens": str(response.get("generated_tokens")),
+            }
 
 
 class QuestionFreeText(QuestionBase):
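With the validator above, a `FreeTextResponse` whose `answer` and `generated_tokens` disagree now fails pydantic validation, and `FreeTextResponseValidator.fix` repairs such responses by copying `generated_tokens` into both fields. A minimal sketch of the constraint:

    from pydantic import ValidationError
    from edsl.questions.QuestionFreeText import FreeTextResponse

    FreeTextResponse(answer="hello", generated_tokens="hello")  # validates
    try:
        FreeTextResponse(answer="hello", generated_tokens="bye")
    except ValidationError as err:
        print(err)  # answer 'hello' must exactly match generated_tokens 'bye'. ...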
edsl/questions/descriptors.py
CHANGED
@@ -2,6 +2,7 @@
 
 from abc import ABC, abstractmethod
 import re
+import textwrap
 from typing import Any, Callable, List, Optional
 from edsl.exceptions.questions import (
     QuestionCreationValidationError,
@@ -404,6 +405,9 @@ class QuestionTextDescriptor(BaseDescriptor):
             raise Exception("Question is too short!")
         if not isinstance(value, str):
             raise Exception("Question must be a string!")
+
+        #value = textwrap.dedent(value).strip()
+
         if contains_single_braced_substring(value):
             import warnings
 
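The `textwrap` import supports the dedent step, which remains commented out in this release. For illustration only, enabling it would normalize indented question text like so:

    import textwrap

    value = "\n    What is your\n    favorite color?\n"
    textwrap.dedent(value).strip()  # 'What is your\nfavorite color?'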
edsl/questions/question_base_gen_mixin.py
CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 import copy
 import itertools
-from typing import Optional, List, Callable, Type, TYPE_CHECKING
+from typing import Optional, List, Callable, Type, TYPE_CHECKING, Union
 
 if TYPE_CHECKING:
     from edsl.questions.QuestionBase import QuestionBase
@@ -9,7 +9,11 @@ if TYPE_CHECKING:
 
 
 class QuestionBaseGenMixin:
-    """Mixin for QuestionBase.
+    """Mixin for QuestionBase.
+
+    This mostly has functions that are used to generate new questions from existing ones.
+
+    """
 
     def copy(self) -> QuestionBase:
         """Return a deep copy of the question.
@@ -85,48 +89,110 @@ class QuestionBaseGenMixin:
         lp = LoopProcessor(self)
         return lp.process_templates(scenario_list)
 
-
-        """
-
-        :param replacement_dict: The dictionary of values to replace in the question components.
+    class MaxTemplateNestingExceeded(Exception):
+        """Raised when template rendering exceeds maximum allowed nesting level."""
+        pass
 
+    def render(self, replacement_dict: dict, return_dict: bool = False) -> Union["QuestionBase", dict]:
+        """Render the question components as jinja2 templates with the replacement dictionary.
+        Handles nested template variables by recursively rendering until all variables are resolved.
+
+        Raises:
+            MaxTemplateNestingExceeded: If template nesting exceeds MAX_NESTING levels
+
         >>> from edsl.questions.QuestionFreeText import QuestionFreeText
         >>> q = QuestionFreeText(question_name = "color", question_text = "What is your favorite {{ thing }}?")
         >>> q.render({"thing": "color"})
         Question('free_text', question_name = \"""color\""", question_text = \"""What is your favorite color?\""")
 
+        >>> from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
+        >>> q = QuestionMultipleChoice(question_name = "color", question_text = "What is your favorite {{ thing }}?", question_options = ["red", "blue", "green"])
+        >>> from edsl.scenarios.Scenario import Scenario
+        >>> q.render(Scenario({"thing": "color"})).data
+        {'question_name': 'color', 'question_text': 'What is your favorite color?', 'question_options': ['red', 'blue', 'green']}
+
+        >>> from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
+        >>> q = QuestionMultipleChoice(question_name = "color", question_text = "What is your favorite {{ thing }}?", question_options = ["red", "blue", "green"])
+        >>> q.render({"thing": 1}).data
+        {'question_name': 'color', 'question_text': 'What is your favorite 1?', 'question_options': ['red', 'blue', 'green']}
+
+
+        >>> from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
+        >>> from edsl.scenarios.Scenario import Scenario
+        >>> q = QuestionMultipleChoice(question_name = "color", question_text = "What is your favorite {{ thing }}?", question_options = ["red", "blue", "green"])
+        >>> q.render(Scenario({"thing": "color of {{ object }}", "object":"water"})).data
+        {'question_name': 'color', 'question_text': 'What is your favorite color of water?', 'question_options': ['red', 'blue', 'green']}
+
+
+        >>> from edsl.questions.QuestionFreeText import QuestionFreeText
+        >>> q = QuestionFreeText(question_name = "infinite", question_text = "This has {{ a }}")
+        >>> q.render({"a": "{{ b }}", "b": "{{ a }}"}) # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+        ...
+        edsl.questions.question_base_gen_mixin.QuestionBaseGenMixin.MaxTemplateNestingExceeded:...
         """
-        from jinja2 import Environment
+        from jinja2 import Environment, meta
         from edsl.scenarios.Scenario import Scenario
 
+        MAX_NESTING = 10  # Maximum allowed nesting levels
+
         strings_only_replacement_dict = {
             k: v for k, v in replacement_dict.items() if not isinstance(v, Scenario)
         }
 
+        def _has_unrendered_variables(template_str: str, env: Environment) -> bool:
+            """Check if the template string has any unrendered variables."""
+            if not isinstance(template_str, str):
+                return False
+            ast = env.parse(template_str)
+            return bool(meta.find_undeclared_variables(ast))
+
         def render_string(value: str) -> str:
             if value is None or not isinstance(value, str):
                 return value
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+            try:
+                env = Environment()
+                result = value
+                nesting_count = 0
+
+                while _has_unrendered_variables(result, env):
+                    if nesting_count >= MAX_NESTING:
+                        raise self.MaxTemplateNestingExceeded(
+                            f"Template rendering exceeded {MAX_NESTING} levels of nesting. "
+                            f"Current value: {result}"
+                        )
+
+                    template = env.from_string(result)
+                    new_result = template.render(strings_only_replacement_dict)
+                    if new_result == result:  # Break if no changes made
+                        break
+                    result = new_result
+                    nesting_count += 1
+
+                return result
+            except self.MaxTemplateNestingExceeded:
+                raise
+            except Exception as e:
+                import warnings
+                warnings.warn("Failed to render string: " + value)
+                return value
+        if return_dict:
+            return self._apply_function_dict(render_string)
+        else:
+            return self.apply_function(render_string)
+
     def apply_function(
-        self, func: Callable, exclude_components: List[str] = None
+        self, func: Callable, exclude_components: Optional[List[str]] = None
     ) -> QuestionBase:
-
+        from edsl.questions.QuestionBase import QuestionBase
+        d = self._apply_function_dict(func, exclude_components)
+        return QuestionBase.from_dict(d)
+
+    def _apply_function_dict(
+        self, func: Callable, exclude_components: Optional[List[str]] = None
+    ) -> dict:
+        """Apply a function to the question parts, excluding certain components.
 
         :param func: The function to apply to the question parts.
         :param exclude_components: The components to exclude from the function application.
@@ -141,7 +207,6 @@ class QuestionBaseGenMixin:
         Question('free_text', question_name = \"""COLOR\""", question_text = \"""WHAT IS YOUR FAVORITE COLOR?\""")
 
         """
-        from edsl.questions.QuestionBase import QuestionBase
 
         if exclude_components is None:
             exclude_components = ["question_name", "question_type"]
@@ -160,10 +225,10 @@ class QuestionBaseGenMixin:
                 d[key] = value
                 continue
             d[key] = func(value)
-        return
+        return d
 
 
 if __name__ == "__main__":
     import doctest
 
-    doctest.testmod()
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
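`render` now re-renders until no template variables remain, capped at MAX_NESTING (10) passes, so replacement values may themselves contain templates. A sketch mirroring the doctests above:

    from edsl.questions.QuestionFreeText import QuestionFreeText

    q = QuestionFreeText(question_name="color", question_text="What is your favorite {{ thing }}?")
    q.render({"thing": "color of {{ object }}", "object": "water"})
    # question_text becomes 'What is your favorite color of water?'

    q2 = QuestionFreeText(question_name="loop", question_text="This has {{ a }}")
    q2.render({"a": "{{ b }}", "b": "{{ a }}"})  # raises MaxTemplateNestingExceeded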
edsl/results/Dataset.py
CHANGED
@@ -15,6 +15,7 @@ from edsl.Base import PersistenceMixin, HashingMixin
 
 from edsl.results.smart_objects import FirstObject
 
+from edsl.results.ResultsGGMixin import GGPlotMethod
 
 class Dataset(UserList, ResultsExportMixin, PersistenceMixin, HashingMixin):
     """A class to represent a dataset of observations."""
@@ -26,6 +27,20 @@ class Dataset(UserList, ResultsExportMixin, PersistenceMixin, HashingMixin):
         super().__init__(data)
         self.print_parameters = print_parameters
 
+
+    def ggplot2(
+        self,
+        ggplot_code: str,
+        shape="wide",
+        sql: str = None,
+        remove_prefix: bool = True,
+        debug: bool = False,
+        height=4,
+        width=6,
+        factor_orders: Optional[dict] = None,
+    ):
+        return GGPlotMethod(self).ggplot2(ggplot_code, shape, sql, remove_prefix, debug, height, width, factor_orders)
+
     def __len__(self) -> int:
         """Return the number of observations in the dataset.
 
@@ -580,6 +595,56 @@ class Dataset(UserList, ResultsExportMixin, PersistenceMixin, HashingMixin):
         result = cls([{col: df[col].tolist()} for col in df.columns])
         return result
 
+    def to_docx(self, output_file: str, title: str = None) -> None:
+        """
+        Convert the dataset to a Word document.
+
+        Args:
+            output_file (str): Path to save the Word document
+            title (str, optional): Title for the document
+        """
+        from docx import Document
+        from docx.shared import Inches
+        from docx.enum.text import WD_ALIGN_PARAGRAPH
+
+        # Create document
+        doc = Document()
+
+        # Add title if provided
+        if title:
+            title_heading = doc.add_heading(title, level=1)
+            title_heading.alignment = WD_ALIGN_PARAGRAPH.CENTER
+
+        # Get headers and data
+        headers, data = self._tabular()
+
+        # Create table
+        table = doc.add_table(rows=len(data) + 1, cols=len(headers))
+        table.style = 'Table Grid'
+
+        # Add headers
+        for j, header in enumerate(headers):
+            cell = table.cell(0, j)
+            cell.text = str(header)
+
+        # Add data
+        for i, row in enumerate(data):
+            for j, cell_content in enumerate(row):
+                cell = table.cell(i + 1, j)
+                cell.text = str(cell_content) if cell_content is not None else ""
+
+        # Adjust column widths
+        for column in table.columns:
+            max_width = 0
+            for cell in column.cells:
+                text_width = len(str(cell.text))
+                max_width = max(max_width, text_width)
+            for cell in column.cells:
+                cell.width = Inches(min(max_width * 0.1 + 0.5, 6))
+
+        # Save the document
+        doc.save(output_file)
+
 
 if __name__ == "__main__":
     import doctest
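Both additions are thin wrappers: `ggplot2` delegates to the new `GGPlotMethod` (requires a local R installation with ggplot2), and `to_docx` requires `python-docx`. A sketch; the file name and plot code are placeholders, and the `df` data-frame name follows edsl's ggplot examples rather than anything in this diff:

    from edsl.results.Dataset import Dataset

    d = Dataset([{"a": [1, 2, 3]}, {"b": [4, 5, 6]}])
    d.to_docx("observations.docx", title="Observations")   # placeholder file name
    d.ggplot2("ggplot(df, aes(x=a, y=b)) + geom_point()")  # 'df' assumed per edsl ggplot docs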