edsl 0.1.44__py3-none-any.whl → 0.1.45__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (61)
  1. edsl/__version__.py +1 -1
  2. edsl/agents/InvigilatorBase.py +3 -1
  3. edsl/agents/PromptConstructor.py +62 -34
  4. edsl/agents/QuestionInstructionPromptBuilder.py +111 -68
  5. edsl/agents/QuestionTemplateReplacementsBuilder.py +69 -16
  6. edsl/agents/question_option_processor.py +15 -6
  7. edsl/coop/CoopFunctionsMixin.py +3 -4
  8. edsl/coop/coop.py +23 -9
  9. edsl/enums.py +3 -3
  10. edsl/inference_services/AnthropicService.py +11 -9
  11. edsl/inference_services/AvailableModelFetcher.py +2 -0
  12. edsl/inference_services/AwsBedrock.py +1 -2
  13. edsl/inference_services/AzureAI.py +12 -9
  14. edsl/inference_services/GoogleService.py +9 -4
  15. edsl/inference_services/InferenceServicesCollection.py +2 -2
  16. edsl/inference_services/MistralAIService.py +1 -2
  17. edsl/inference_services/OpenAIService.py +9 -4
  18. edsl/inference_services/PerplexityService.py +2 -1
  19. edsl/inference_services/{GrokService.py → XAIService.py} +2 -2
  20. edsl/inference_services/registry.py +2 -2
  21. edsl/jobs/Jobs.py +9 -0
  22. edsl/jobs/JobsChecks.py +10 -13
  23. edsl/jobs/async_interview_runner.py +3 -1
  24. edsl/jobs/check_survey_scenario_compatibility.py +5 -5
  25. edsl/jobs/interviews/InterviewExceptionEntry.py +12 -0
  26. edsl/jobs/tasks/TaskHistory.py +1 -1
  27. edsl/language_models/LanguageModel.py +0 -3
  28. edsl/language_models/PriceManager.py +45 -5
  29. edsl/language_models/model.py +47 -26
  30. edsl/questions/QuestionBase.py +21 -0
  31. edsl/questions/QuestionBasePromptsMixin.py +103 -0
  32. edsl/questions/QuestionFreeText.py +22 -5
  33. edsl/questions/descriptors.py +4 -0
  34. edsl/questions/question_base_gen_mixin.py +94 -29
  35. edsl/results/Dataset.py +65 -0
  36. edsl/results/DatasetExportMixin.py +299 -32
  37. edsl/results/Result.py +27 -0
  38. edsl/results/Results.py +22 -2
  39. edsl/results/ResultsGGMixin.py +7 -3
  40. edsl/scenarios/DocumentChunker.py +2 -0
  41. edsl/scenarios/FileStore.py +10 -0
  42. edsl/scenarios/PdfExtractor.py +21 -1
  43. edsl/scenarios/Scenario.py +25 -9
  44. edsl/scenarios/ScenarioList.py +73 -3
  45. edsl/scenarios/handlers/__init__.py +1 -0
  46. edsl/scenarios/handlers/docx.py +5 -1
  47. edsl/scenarios/handlers/jpeg.py +39 -0
  48. edsl/surveys/Survey.py +5 -4
  49. edsl/surveys/SurveyFlowVisualization.py +91 -43
  50. edsl/templates/error_reporting/exceptions_table.html +7 -8
  51. edsl/templates/error_reporting/interview_details.html +1 -1
  52. edsl/templates/error_reporting/interviews.html +0 -1
  53. edsl/templates/error_reporting/overview.html +2 -7
  54. edsl/templates/error_reporting/performance_plot.html +1 -1
  55. edsl/templates/error_reporting/report.css +1 -1
  56. edsl/utilities/PrettyList.py +14 -0
  57. edsl-0.1.45.dist-info/METADATA +246 -0
  58. {edsl-0.1.44.dist-info → edsl-0.1.45.dist-info}/RECORD +60 -59
  59. edsl-0.1.44.dist-info/METADATA +0 -110
  60. {edsl-0.1.44.dist-info → edsl-0.1.45.dist-info}/LICENSE +0 -0
  61. {edsl-0.1.44.dist-info → edsl-0.1.45.dist-info}/WHEEL +0 -0
edsl/language_models/model.py CHANGED
@@ -17,7 +17,11 @@ if TYPE_CHECKING:
     from edsl.results.Dataset import Dataset


-def get_model_class(model_name, registry: Optional[InferenceServicesCollection] = None, service_name: Optional[InferenceServiceLiteral] = None):
+def get_model_class(
+    model_name,
+    registry: Optional[InferenceServicesCollection] = None,
+    service_name: Optional[InferenceServiceLiteral] = None,
+):
     from edsl.inference_services.registry import default

     registry = registry or default
@@ -40,6 +44,9 @@ class Meta(type):
         To get the default model, you can leave out the model name.
         To see the available models, you can do:
         >>> Model.available()
+
+        Or to see the models for a specific service, you can do:
+        >>> Model.available(service='openai')
         """
         )

@@ -97,7 +104,10 @@ class Model(metaclass=Meta):
         *args,
         **kwargs,
     ):
-        "Instantiate a new language model."
+        """Instantiate a new language model.
+        >>> Model()
+        Model(...)
+        """
         # Map index to the respective subclass
         if model_name is None:
             model_name = cls.default_model
@@ -127,28 +137,25 @@ class Model(metaclass=Meta):
         >>> Model.service_classes()
         [...]
         """
-        return [r for r in cls.services(name_only=True)]
+        return [r for r in cls.services()]

     @classmethod
     def services(cls, name_only: bool = False) -> List[str]:
-        """Returns a list of services, annotated with whether the user has local keys for them."""
-        services_with_local_keys = set(cls.key_info().select("service").to_list())
-        f = lambda service_name: (
-            "yes" if service_name in services_with_local_keys else " "
-        )
-        if name_only:
-            return PrettyList(
-                [r._inference_service_ for r in cls.get_registry().services],
-                columns=["Service Name"],
-            )
-        else:
-            return PrettyList(
+        """Returns a list of services excluding 'test', sorted alphabetically.
+
+        >>> Model.services()
+        [...]
+        """
+        return PrettyList(
+            sorted(
                 [
-                    (r._inference_service_, f(r._inference_service_))
+                    [r._inference_service_]
                     for r in cls.get_registry().services
-                ],
-                columns=["Service Name", "Local key?"],
-            )
+                    if r._inference_service_.lower() != "test"
+                ]
+            ),
+            columns=["Service Name"],
+        )

     @classmethod
     def services_with_local_keys(cls) -> set:
@@ -198,7 +205,15 @@ class Model(metaclass=Meta):
         search_term: str = None,
         name_only: bool = False,
         service: Optional[str] = None,
+        force_refresh: bool = False,
     ):
+        """Get available models
+
+        >>> Model.available()
+        [...]
+        >>> Model.available(service='openai')
+        [...]
+        """
         # if search_term is None and service is None:
         # print("Getting available models...")
         # print("You have local keys for the following services:")
@@ -209,13 +224,16 @@ class Model(metaclass=Meta):
         # return None

         if service is not None:
-            if service not in cls.services(name_only=True):
+            known_services = [x[0] for x in cls.services(name_only=True)]
+            if service not in known_services:
                 raise ValueError(
                     f"Service {service} not found in available services.",
-                    f"Available services are: {cls.services()}",
+                    f"Available services are: {known_services}",
                 )

-        full_list = cls.get_registry().available(service=service)
+        full_list = cls.get_registry().available(
+            service=service, force_refresh=force_refresh
+        )

         if search_term is None:
             if name_only:
@@ -319,6 +337,9 @@ class Model(metaclass=Meta):
         """
         Returns an example Model instance.

+        >>> Model.example()
+        Model(...)
+
         :param randomize: If True, the temperature is set to a random decimal between 0 and 1.
         """
         temperature = 0.5 if not randomize else round(random(), 2)
@@ -331,7 +352,7 @@ if __name__ == "__main__":

     doctest.testmod(optionflags=doctest.ELLIPSIS)

-    available = Model.available()
-    m = Model("gpt-4-1106-preview")
-    results = m.execute_model_call("Hello world")
-    print(results)
+    # available = Model.available()
+    # m = Model("gpt-4-1106-preview")
+    # results = m.execute_model_call("Hello world")
+    # print(results)
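
For orientation, a minimal usage sketch of the reshaped listing API (assumes an installed edsl 0.1.45; the printed shapes are illustrative, not exact):

    from edsl import Model

    # services() now returns a sorted PrettyList of [service_name] rows,
    # with the internal "test" service filtered out.
    print(Model.services())

    # available() gains a force_refresh flag that is forwarded to the
    # registry so a cached model list can be bypassed.
    print(Model.available(service="openai", force_refresh=True))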
edsl/questions/QuestionBase.py CHANGED
@@ -85,6 +85,9 @@ class QuestionBase(
         >>> Q.example()._simulate_answer()
         {'answer': '...', 'generated_tokens': ...}
         """
+        if self.question_type == "free_text":
+            return {"answer": "Hello, how are you?", 'generated_tokens': "Hello, how are you?"}
+
         simulated_answer = self.fake_data_factory.build().dict()
         if human_readable and hasattr(self, "question_options") and self.use_code:
             simulated_answer["answer"] = [
@@ -432,6 +435,24 @@ class QuestionBase(

         return Survey([self])

+    def humanize(
+        self,
+        project_name: str = "Project",
+        survey_description: Optional[str] = None,
+        survey_alias: Optional[str] = None,
+        survey_visibility: Optional["VisibilityType"] = "unlisted",
+    ) -> dict:
+        """
+        Turn a single question into a survey and send the survey to Coop.
+
+        Then, create a project on Coop so you can share the survey with human respondents.
+        """
+        s = self.to_survey()
+        project_details = s.humanize(
+            project_name, survey_description, survey_alias, survey_visibility
+        )
+        return project_details
+
     def by(self, *args) -> "Jobs":
         """Turn a single question into a survey and then a Job."""
         from edsl.surveys.Survey import Survey
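
A hedged usage sketch for the new QuestionBase.humanize (assumes a configured Coop API key; the keys of the returned dict are illustrative):

    from edsl import QuestionFreeText

    q = QuestionFreeText(question_name="feedback", question_text="Any feedback on the product?")
    # Wraps the question in a one-question Survey, pushes it to Coop,
    # and creates a shareable project for human respondents.
    project_details = q.humanize(project_name="Feedback pilot")
    print(project_details)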
edsl/questions/QuestionBasePromptsMixin.py CHANGED
@@ -187,6 +187,73 @@ class QuestionBasePromptsMixin:
         from edsl.prompts import Prompt

         return Prompt(self.question_presentation) + Prompt(self.answering_instructions)
+
+
+    def detailed_parameters_by_key(self) -> dict[str, set[tuple[str, ...]]]:
+        """
+        Return a dictionary of parameters by key.
+
+        >>> from edsl import QuestionMultipleChoice
+        >>> QuestionMultipleChoice.example().detailed_parameters_by_key()
+        {'question_name': set(), 'question_text': set()}
+
+        >>> from edsl import QuestionFreeText
+        >>> q = QuestionFreeText(question_name = "example", question_text = "What is your name, {{ nickname }}, based on {{ q0.answer }}?")
+        >>> r = q.detailed_parameters_by_key()
+        >>> r == {'question_name': set(), 'question_text': {('q0', 'answer'), ('nickname',)}}
+        True
+        """
+        params_by_key = {}
+        for key, value in self.data.items():
+            if isinstance(value, str):
+                params_by_key[key] = self.extract_parameters(value)
+        return params_by_key
+
+    @staticmethod
+    def extract_parameters(txt: str) -> set[tuple[str, ...]]:
+        """Return all parameters of the question as tuples representing their full paths.
+
+        :param txt: The text to extract parameters from.
+        :return: A set of tuples representing the parameters.
+
+        >>> from edsl.questions import QuestionMultipleChoice
+        >>> d = QuestionMultipleChoice.example().extract_parameters("What is your name, {{ nickname }}, based on {{ q0.answer }}?")
+        >>> d == {('nickname',), ('q0', 'answer')}
+        True
+        """
+        from jinja2 import Environment, nodes
+
+        env = Environment()
+        #txt = self._all_text()
+        ast = env.parse(txt)
+
+        variables = set()
+        processed_nodes = set()  # Keep track of nodes we've processed
+
+        def visit_node(node, path=()):
+            if id(node) in processed_nodes:
+                return
+            processed_nodes.add(id(node))
+
+            if isinstance(node, nodes.Name):
+                # Only add the name if we're not in the middle of building a longer path
+                if not path:
+                    variables.add((node.name,))
+                else:
+                    variables.add((node.name,) + path)
+            elif isinstance(node, nodes.Getattr):
+                # Build path from bottom up
+                new_path = (node.attr,) + path
+                visit_node(node.node, new_path)
+
+        for node in ast.find_all((nodes.Name, nodes.Getattr)):
+            visit_node(node)
+
+        return variables
+
+    @property
+    def detailed_parameters(self):
+        return [".".join(p) for p in self.extract_parameters(self._all_text())]

     @property
     def parameters(self) -> set[str]:
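
The extraction logic above rides on jinja2's own parser; a standalone sketch of the same AST walk (jinja2 only, no edsl required):

    from jinja2 import Environment, nodes

    ast = Environment().parse("Hi {{ nickname }}, based on {{ q0.answer }}")
    # Bare variables parse as Name nodes; dotted access parses as Getattr
    # wrapping a Name, which extract_parameters folds into path tuples.
    for node in ast.find_all(nodes.Getattr):
        print("dotted:", (node.node.name, node.attr))  # ('q0', 'answer')
    for node in ast.find_all(nodes.Name):
        print("name:", node.name)                      # nickname, then q0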
@@ -219,3 +286,39 @@ class QuestionBasePromptsMixin:
             return self.new_default_instructions
         else:
             return self.applicable_prompts(model)[0]()
+
+    @staticmethod
+    def sequence_in_dict(d: dict, path: tuple[str, ...]) -> tuple[bool, any]:
+        """Check if a sequence of nested keys exists in a dictionary and return the value.
+
+        Args:
+            d: The dictionary to check
+            path: Tuple of keys representing the nested path
+
+        Returns:
+            tuple[bool, any]: (True, value) if the path exists, (False, None) otherwise
+
+        Example:
+            >>> sequence_in_dict = QuestionBasePromptsMixin.sequence_in_dict
+            >>> d = {'a': {'b': {'c': 1}}}
+            >>> sequence_in_dict(d, ('a', 'b', 'c'))
+            (True, 1)
+            >>> sequence_in_dict(d, ('a', 'b', 'd'))
+            (False, None)
+            >>> sequence_in_dict(d, ('x',))
+            (False, None)
+        """
+        try:
+            current = d
+            for key in path:
+                current = current.get(key)
+                if current is None:
+                    return (False, None)
+            return (True, current)
+        except (AttributeError, TypeError):
+            return (False, None)
+
+
+if __name__ == "__main__":
+    import doctest
+    doctest.testmod()
edsl/questions/QuestionFreeText.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 from typing import Any, Optional
 from uuid import uuid4

-from pydantic import field_validator
+from pydantic import field_validator, model_validator

 from edsl.questions.QuestionBase import QuestionBase
 from edsl.questions.response_validator_abc import ResponseValidatorABC
@@ -24,6 +24,17 @@ class FreeTextResponse(BaseModel):
     answer: str
     generated_tokens: Optional[str] = None

+    @model_validator(mode='after')
+    def validate_tokens_match_answer(self):
+        if self.generated_tokens is not None:  # If generated_tokens exists
+            # Ensure exact string equality
+            if self.answer.strip() != self.generated_tokens.strip():  # They MUST match exactly
+                raise ValueError(
+                    f"answer '{self.answer}' must exactly match generated_tokens '{self.generated_tokens}'. "
+                    f"Type of answer: {type(self.answer)}, Type of tokens: {type(self.generated_tokens)}"
+                )
+        return self
+

 class FreeTextResponseValidator(ResponseValidatorABC):
     required_params = []
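
A minimal sketch of what the new validator enforces (pydantic v2 semantics; importing FreeTextResponse from this module is an assumption based on the diff):

    from pydantic import ValidationError
    from edsl.questions.QuestionFreeText import FreeTextResponse

    FreeTextResponse(answer="blue", generated_tokens="blue")   # validates
    try:
        FreeTextResponse(answer="blue", generated_tokens="red")
    except ValidationError as err:
        print(err)  # answer must exactly match generated_tokens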
@@ -37,10 +48,16 @@ class FreeTextResponseValidator(ResponseValidatorABC):
     ]

     def fix(self, response, verbose=False):
-        return {
-            "answer": str(response.get("generated_tokens")),
-            "generated_tokens": str(response.get("generated_tokens")),
-        }
+        if response.get("generated_tokens") != response.get("answer"):
+            return {
+                "answer": str(response.get("generated_tokens")),
+                "generated_tokens": str(response.get("generated_tokens")),
+            }
+        else:
+            return {
+                "answer": str(response.get("generated_tokens")),
+                "generated_tokens": str(response.get("generated_tokens")),
+            }


 class QuestionFreeText(QuestionBase):
edsl/questions/descriptors.py CHANGED
@@ -2,6 +2,7 @@

 from abc import ABC, abstractmethod
 import re
+import textwrap
 from typing import Any, Callable, List, Optional
 from edsl.exceptions.questions import (
     QuestionCreationValidationError,
@@ -404,6 +405,9 @@ class QuestionTextDescriptor(BaseDescriptor):
             raise Exception("Question is too short!")
         if not isinstance(value, str):
             raise Exception("Question must be a string!")
+
+        #value = textwrap.dedent(value).strip()
+
         if contains_single_braced_substring(value):
             import warnings

edsl/questions/question_base_gen_mixin.py CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 import copy
 import itertools
-from typing import Optional, List, Callable, Type, TYPE_CHECKING
+from typing import Optional, List, Callable, Type, TYPE_CHECKING, Union

 if TYPE_CHECKING:
     from edsl.questions.QuestionBase import QuestionBase
@@ -9,7 +9,11 @@ if TYPE_CHECKING:


 class QuestionBaseGenMixin:
-    """Mixin for QuestionBase."""
+    """Mixin for QuestionBase.
+
+    This mostly has functions that are used to generate new questions from existing ones.
+
+    """

     def copy(self) -> QuestionBase:
         """Return a deep copy of the question.
@@ -85,48 +89,110 @@ class QuestionBaseGenMixin:
         lp = LoopProcessor(self)
         return lp.process_templates(scenario_list)

-    def render(self, replacement_dict: dict) -> "QuestionBase":
-        """Render the question components as jinja2 templates with the replacement dictionary.
-
-        :param replacement_dict: The dictionary of values to replace in the question components.
+    class MaxTemplateNestingExceeded(Exception):
+        """Raised when template rendering exceeds maximum allowed nesting level."""
+        pass

+    def render(self, replacement_dict: dict, return_dict: bool = False) -> Union["QuestionBase", dict]:
+        """Render the question components as jinja2 templates with the replacement dictionary.
+        Handles nested template variables by recursively rendering until all variables are resolved.
+
+        Raises:
+            MaxTemplateNestingExceeded: If template nesting exceeds MAX_NESTING levels
+
         >>> from edsl.questions.QuestionFreeText import QuestionFreeText
         >>> q = QuestionFreeText(question_name = "color", question_text = "What is your favorite {{ thing }}?")
         >>> q.render({"thing": "color"})
         Question('free_text', question_name = \"""color\""", question_text = \"""What is your favorite color?\""")

+        >>> from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
+        >>> q = QuestionMultipleChoice(question_name = "color", question_text = "What is your favorite {{ thing }}?", question_options = ["red", "blue", "green"])
+        >>> from edsl.scenarios.Scenario import Scenario
+        >>> q.render(Scenario({"thing": "color"})).data
+        {'question_name': 'color', 'question_text': 'What is your favorite color?', 'question_options': ['red', 'blue', 'green']}
+
+        >>> from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
+        >>> q = QuestionMultipleChoice(question_name = "color", question_text = "What is your favorite {{ thing }}?", question_options = ["red", "blue", "green"])
+        >>> q.render({"thing": 1}).data
+        {'question_name': 'color', 'question_text': 'What is your favorite 1?', 'question_options': ['red', 'blue', 'green']}
+
+
+        >>> from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
+        >>> from edsl.scenarios.Scenario import Scenario
+        >>> q = QuestionMultipleChoice(question_name = "color", question_text = "What is your favorite {{ thing }}?", question_options = ["red", "blue", "green"])
+        >>> q.render(Scenario({"thing": "color of {{ object }}", "object":"water"})).data
+        {'question_name': 'color', 'question_text': 'What is your favorite color of water?', 'question_options': ['red', 'blue', 'green']}
+
+
+        >>> from edsl.questions.QuestionFreeText import QuestionFreeText
+        >>> q = QuestionFreeText(question_name = "infinite", question_text = "This has {{ a }}")
+        >>> q.render({"a": "{{ b }}", "b": "{{ a }}"}) # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+        ...
+        edsl.questions.question_base_gen_mixin.QuestionBaseGenMixin.MaxTemplateNestingExceeded:...
         """
-        from jinja2 import Environment
+        from jinja2 import Environment, meta
         from edsl.scenarios.Scenario import Scenario

+        MAX_NESTING = 10  # Maximum allowed nesting levels
+
         strings_only_replacement_dict = {
             k: v for k, v in replacement_dict.items() if not isinstance(v, Scenario)
         }

+        def _has_unrendered_variables(template_str: str, env: Environment) -> bool:
+            """Check if the template string has any unrendered variables."""
+            if not isinstance(template_str, str):
+                return False
+            ast = env.parse(template_str)
+            return bool(meta.find_undeclared_variables(ast))
+
         def render_string(value: str) -> str:
             if value is None or not isinstance(value, str):
                 return value
-            else:
-                try:
-                    return (
-                        Environment()
-                        .from_string(value)
-                        .render(strings_only_replacement_dict)
-                    )
-                except Exception as e:
-                    #breakpoint()
-                    import warnings
-
-                    warnings.warn("Failed to render string: " + value)
-                    # breakpoint()
-                    return value
-
-        return self.apply_function(render_string)
-
+
+            try:
+                env = Environment()
+                result = value
+                nesting_count = 0
+
+                while _has_unrendered_variables(result, env):
+                    if nesting_count >= MAX_NESTING:
+                        raise self.MaxTemplateNestingExceeded(
+                            f"Template rendering exceeded {MAX_NESTING} levels of nesting. "
+                            f"Current value: {result}"
+                        )
+
+                    template = env.from_string(result)
+                    new_result = template.render(strings_only_replacement_dict)
+                    if new_result == result:  # Break if no changes made
+                        break
+                    result = new_result
+                    nesting_count += 1
+
+                return result
+            except self.MaxTemplateNestingExceeded:
+                raise
+            except Exception as e:
+                import warnings
+                warnings.warn("Failed to render string: " + value)
+                return value
+        if return_dict:
+            return self._apply_function_dict(render_string)
+        else:
+            return self.apply_function(render_string)
+
     def apply_function(
-        self, func: Callable, exclude_components: List[str] = None
+        self, func: Callable, exclude_components: Optional[List[str]] = None
     ) -> QuestionBase:
-        """Apply a function to the question parts
+        from edsl.questions.QuestionBase import QuestionBase
+        d = self._apply_function_dict(func, exclude_components)
+        return QuestionBase.from_dict(d)
+
+    def _apply_function_dict(
+        self, func: Callable, exclude_components: Optional[List[str]] = None
+    ) -> dict:
+        """Apply a function to the question parts, excluding certain components.

         :param func: The function to apply to the question parts.
         :param exclude_components: The components to exclude from the function application.
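
The core of the new render is a bounded fixed-point loop; a standalone sketch of the idea (jinja2 only; names are illustrative, not edsl's):

    from jinja2 import Environment, meta

    def render_nested(text: str, variables: dict, max_nesting: int = 10) -> str:
        # Re-render until no undeclared variables remain; the budget keeps
        # mutually recursive values like {"a": "{{ b }}", "b": "{{ a }}"}
        # from looping forever.
        env = Environment()
        for _ in range(max_nesting):
            if not meta.find_undeclared_variables(env.parse(text)):
                return text
            rendered = env.from_string(text).render(variables)
            if rendered == text:  # nothing changed; stop early
                return text
            text = rendered
        raise RuntimeError(f"exceeded {max_nesting} levels of template nesting")

    print(render_nested("Favorite {{ thing }}?", {"thing": "color of {{ obj }}", "obj": "water"}))
    # Favorite color of water?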
@@ -141,7 +207,6 @@ class QuestionBaseGenMixin:
         Question('free_text', question_name = \"""COLOR\""", question_text = \"""WHAT IS YOUR FAVORITE COLOR?\""")

         """
-        from edsl.questions.QuestionBase import QuestionBase

         if exclude_components is None:
             exclude_components = ["question_name", "question_type"]
@@ -160,10 +225,10 @@ class QuestionBaseGenMixin:
                 d[key] = value
                 continue
             d[key] = func(value)
-        return QuestionBase.from_dict(d)
+        return d


 if __name__ == "__main__":
     import doctest

-    doctest.testmod()
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
edsl/results/Dataset.py CHANGED
@@ -15,6 +15,7 @@ from edsl.Base import PersistenceMixin, HashingMixin

 from edsl.results.smart_objects import FirstObject

+from edsl.results.ResultsGGMixin import GGPlotMethod

 class Dataset(UserList, ResultsExportMixin, PersistenceMixin, HashingMixin):
     """A class to represent a dataset of observations."""
@@ -26,6 +27,20 @@ class Dataset(UserList, ResultsExportMixin, PersistenceMixin, HashingMixin):
         super().__init__(data)
         self.print_parameters = print_parameters

+
+    def ggplot2(
+        self,
+        ggplot_code: str,
+        shape="wide",
+        sql: str = None,
+        remove_prefix: bool = True,
+        debug: bool = False,
+        height=4,
+        width=6,
+        factor_orders: Optional[dict] = None,
+    ):
+        return GGPlotMethod(self).ggplot2(ggplot_code, shape, sql, remove_prefix, debug, height, width, factor_orders)
+
     def __len__(self) -> int:
         """Return the number of observations in the dataset.

@@ -580,6 +595,56 @@ class Dataset(UserList, ResultsExportMixin, PersistenceMixin, HashingMixin):
         result = cls([{col: df[col].tolist()} for col in df.columns])
         return result

+    def to_docx(self, output_file: str, title: str = None) -> None:
+        """
+        Convert the dataset to a Word document.
+
+        Args:
+            output_file (str): Path to save the Word document
+            title (str, optional): Title for the document
+        """
+        from docx import Document
+        from docx.shared import Inches
+        from docx.enum.text import WD_ALIGN_PARAGRAPH
+
+        # Create document
+        doc = Document()
+
+        # Add title if provided
+        if title:
+            title_heading = doc.add_heading(title, level=1)
+            title_heading.alignment = WD_ALIGN_PARAGRAPH.CENTER
+
+        # Get headers and data
+        headers, data = self._tabular()
+
+        # Create table
+        table = doc.add_table(rows=len(data) + 1, cols=len(headers))
+        table.style = 'Table Grid'
+
+        # Add headers
+        for j, header in enumerate(headers):
+            cell = table.cell(0, j)
+            cell.text = str(header)
+
+        # Add data
+        for i, row in enumerate(data):
+            for j, cell_content in enumerate(row):
+                cell = table.cell(i + 1, j)
+                cell.text = str(cell_content) if cell_content is not None else ""
+
+        # Adjust column widths
+        for column in table.columns:
+            max_width = 0
+            for cell in column.cells:
+                text_width = len(str(cell.text))
+                max_width = max(max_width, text_width)
+            for cell in column.cells:
+                cell.width = Inches(min(max_width * 0.1 + 0.5, 6))
+
+        # Save the document
+        doc.save(output_file)
+

 if __name__ == "__main__":
     import doctest
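
A hedged usage sketch for the new export (requires python-docx; the column-oriented Dataset constructor shape mirrors the form used throughout edsl):

    from edsl.results.Dataset import Dataset

    d = Dataset([{"answer.color": ["red", "blue"]}, {"model": ["gpt-4o", "gpt-4o"]}])
    # Writes one header row plus two data rows into a gridded Word table.
    d.to_docx("observations.docx", title="Observations")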