edsl-0.1.44-py3-none-any.whl → edsl-0.1.46-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. edsl/Base.py +7 -3
  2. edsl/__version__.py +1 -1
  3. edsl/agents/InvigilatorBase.py +3 -1
  4. edsl/agents/PromptConstructor.py +66 -91
  5. edsl/agents/QuestionInstructionPromptBuilder.py +160 -79
  6. edsl/agents/QuestionTemplateReplacementsBuilder.py +80 -17
  7. edsl/agents/question_option_processor.py +15 -6
  8. edsl/coop/CoopFunctionsMixin.py +3 -4
  9. edsl/coop/coop.py +171 -96
  10. edsl/data/RemoteCacheSync.py +10 -9
  11. edsl/enums.py +3 -3
  12. edsl/inference_services/AnthropicService.py +11 -9
  13. edsl/inference_services/AvailableModelFetcher.py +2 -0
  14. edsl/inference_services/AwsBedrock.py +1 -2
  15. edsl/inference_services/AzureAI.py +12 -9
  16. edsl/inference_services/GoogleService.py +9 -4
  17. edsl/inference_services/InferenceServicesCollection.py +2 -2
  18. edsl/inference_services/MistralAIService.py +1 -2
  19. edsl/inference_services/OpenAIService.py +9 -4
  20. edsl/inference_services/PerplexityService.py +2 -1
  21. edsl/inference_services/{GrokService.py → XAIService.py} +2 -2
  22. edsl/inference_services/registry.py +2 -2
  23. edsl/jobs/AnswerQuestionFunctionConstructor.py +12 -1
  24. edsl/jobs/Jobs.py +24 -17
  25. edsl/jobs/JobsChecks.py +10 -13
  26. edsl/jobs/JobsPrompts.py +49 -26
  27. edsl/jobs/JobsRemoteInferenceHandler.py +4 -5
  28. edsl/jobs/async_interview_runner.py +3 -1
  29. edsl/jobs/check_survey_scenario_compatibility.py +5 -5
  30. edsl/jobs/data_structures.py +3 -0
  31. edsl/jobs/interviews/Interview.py +6 -3
  32. edsl/jobs/interviews/InterviewExceptionEntry.py +12 -0
  33. edsl/jobs/tasks/TaskHistory.py +1 -1
  34. edsl/language_models/LanguageModel.py +6 -3
  35. edsl/language_models/PriceManager.py +45 -5
  36. edsl/language_models/model.py +47 -26
  37. edsl/questions/QuestionBase.py +21 -0
  38. edsl/questions/QuestionBasePromptsMixin.py +103 -0
  39. edsl/questions/QuestionFreeText.py +22 -5
  40. edsl/questions/descriptors.py +4 -0
  41. edsl/questions/question_base_gen_mixin.py +96 -29
  42. edsl/results/Dataset.py +65 -0
  43. edsl/results/DatasetExportMixin.py +320 -32
  44. edsl/results/Result.py +27 -0
  45. edsl/results/Results.py +22 -2
  46. edsl/results/ResultsGGMixin.py +7 -3
  47. edsl/scenarios/DocumentChunker.py +2 -0
  48. edsl/scenarios/FileStore.py +10 -0
  49. edsl/scenarios/PdfExtractor.py +21 -1
  50. edsl/scenarios/Scenario.py +25 -9
  51. edsl/scenarios/ScenarioList.py +226 -24
  52. edsl/scenarios/handlers/__init__.py +1 -0
  53. edsl/scenarios/handlers/docx.py +5 -1
  54. edsl/scenarios/handlers/jpeg.py +39 -0
  55. edsl/surveys/Survey.py +5 -4
  56. edsl/surveys/SurveyFlowVisualization.py +91 -43
  57. edsl/templates/error_reporting/exceptions_table.html +7 -8
  58. edsl/templates/error_reporting/interview_details.html +1 -1
  59. edsl/templates/error_reporting/interviews.html +0 -1
  60. edsl/templates/error_reporting/overview.html +2 -7
  61. edsl/templates/error_reporting/performance_plot.html +1 -1
  62. edsl/templates/error_reporting/report.css +1 -1
  63. edsl/utilities/PrettyList.py +14 -0
  64. edsl-0.1.46.dist-info/METADATA +246 -0
  65. {edsl-0.1.44.dist-info → edsl-0.1.46.dist-info}/RECORD +67 -66
  66. edsl-0.1.44.dist-info/METADATA +0 -110
  67. {edsl-0.1.44.dist-info → edsl-0.1.46.dist-info}/LICENSE +0 -0
  68. {edsl-0.1.44.dist-info → edsl-0.1.46.dist-info}/WHEEL +0 -0
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
  import copy
3
3
  import itertools
4
- from typing import Optional, List, Callable, Type, TYPE_CHECKING
4
+ from typing import Optional, List, Callable, Type, TYPE_CHECKING, Union
5
5
 
6
6
  if TYPE_CHECKING:
7
7
  from edsl.questions.QuestionBase import QuestionBase
@@ -9,7 +9,11 @@ if TYPE_CHECKING:
9
9
 
10
10
 
11
11
  class QuestionBaseGenMixin:
12
- """Mixin for QuestionBase."""
12
+ """Mixin for QuestionBase.
13
+
14
+ This mostly has functions that are used to generate new questions from existing ones.
15
+
16
+ """
13
17
 
14
18
  def copy(self) -> QuestionBase:
15
19
  """Return a deep copy of the question.
@@ -85,48 +89,112 @@ class QuestionBaseGenMixin:
85
89
  lp = LoopProcessor(self)
86
90
  return lp.process_templates(scenario_list)
87
91
 
88
- def render(self, replacement_dict: dict) -> "QuestionBase":
89
- """Render the question components as jinja2 templates with the replacement dictionary.
90
-
91
- :param replacement_dict: The dictionary of values to replace in the question components.
92
+ class MaxTemplateNestingExceeded(Exception):
93
+ """Raised when template rendering exceeds maximum allowed nesting level."""
94
+ pass
92
95
 
96
+ def render(self, replacement_dict: dict, return_dict: bool = False) -> Union["QuestionBase", dict]:
97
+ """Render the question components as jinja2 templates with the replacement dictionary.
98
+ Handles nested template variables by recursively rendering until all variables are resolved.
99
+
100
+ Raises:
101
+ MaxTemplateNestingExceeded: If template nesting exceeds MAX_NESTING levels
102
+
93
103
  >>> from edsl.questions.QuestionFreeText import QuestionFreeText
94
104
  >>> q = QuestionFreeText(question_name = "color", question_text = "What is your favorite {{ thing }}?")
95
105
  >>> q.render({"thing": "color"})
96
106
  Question('free_text', question_name = \"""color\""", question_text = \"""What is your favorite color?\""")
97
107
 
108
+ >>> from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
109
+ >>> q = QuestionMultipleChoice(question_name = "color", question_text = "What is your favorite {{ thing }}?", question_options = ["red", "blue", "green"])
110
+ >>> from edsl.scenarios.Scenario import Scenario
111
+ >>> q.render(Scenario({"thing": "color"})).data
112
+ {'question_name': 'color', 'question_text': 'What is your favorite color?', 'question_options': ['red', 'blue', 'green']}
113
+
114
+ >>> from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
115
+ >>> q = QuestionMultipleChoice(question_name = "color", question_text = "What is your favorite {{ thing }}?", question_options = ["red", "blue", "green"])
116
+ >>> q.render({"thing": 1}).data
117
+ {'question_name': 'color', 'question_text': 'What is your favorite 1?', 'question_options': ['red', 'blue', 'green']}
118
+
119
+
120
+ >>> from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
121
+ >>> from edsl.scenarios.Scenario import Scenario
122
+ >>> q = QuestionMultipleChoice(question_name = "color", question_text = "What is your favorite {{ thing }}?", question_options = ["red", "blue", "green"])
123
+ >>> q.render(Scenario({"thing": "color of {{ object }}", "object":"water"})).data
124
+ {'question_name': 'color', 'question_text': 'What is your favorite color of water?', 'question_options': ['red', 'blue', 'green']}
125
+
126
+
127
+ >>> from edsl.questions.QuestionFreeText import QuestionFreeText
128
+ >>> q = QuestionFreeText(question_name = "infinite", question_text = "This has {{ a }}")
129
+ >>> q.render({"a": "{{ b }}", "b": "{{ a }}"}) # doctest: +IGNORE_EXCEPTION_DETAIL
130
+ Traceback (most recent call last):
131
+ ...
132
+ edsl.questions.question_base_gen_mixin.QuestionBaseGenMixin.MaxTemplateNestingExceeded:...
98
133
  """
99
- from jinja2 import Environment
134
+ from jinja2 import Environment, meta
100
135
  from edsl.scenarios.Scenario import Scenario
101
136
 
137
+ MAX_NESTING = 10 # Maximum allowed nesting levels
138
+
102
139
  strings_only_replacement_dict = {
103
140
  k: v for k, v in replacement_dict.items() if not isinstance(v, Scenario)
104
141
  }
105
142
 
143
+ strings_only_replacement_dict['scenario'] = strings_only_replacement_dict
144
+
145
+ def _has_unrendered_variables(template_str: str, env: Environment) -> bool:
146
+ """Check if the template string has any unrendered variables."""
147
+ if not isinstance(template_str, str):
148
+ return False
149
+ ast = env.parse(template_str)
150
+ return bool(meta.find_undeclared_variables(ast))
151
+
106
152
  def render_string(value: str) -> str:
107
153
  if value is None or not isinstance(value, str):
108
154
  return value
109
- else:
110
- try:
111
- return (
112
- Environment()
113
- .from_string(value)
114
- .render(strings_only_replacement_dict)
115
- )
116
- except Exception as e:
117
- #breakpoint()
118
- import warnings
119
-
120
- warnings.warn("Failed to render string: " + value)
121
- # breakpoint()
122
- return value
123
-
124
- return self.apply_function(render_string)
125
-
155
+
156
+ try:
157
+ env = Environment()
158
+ result = value
159
+ nesting_count = 0
160
+
161
+ while _has_unrendered_variables(result, env):
162
+ if nesting_count >= MAX_NESTING:
163
+ raise self.MaxTemplateNestingExceeded(
164
+ f"Template rendering exceeded {MAX_NESTING} levels of nesting. "
165
+ f"Current value: {result}"
166
+ )
167
+
168
+ template = env.from_string(result)
169
+ new_result = template.render(strings_only_replacement_dict)
170
+ if new_result == result: # Break if no changes made
171
+ break
172
+ result = new_result
173
+ nesting_count += 1
174
+
175
+ return result
176
+ except self.MaxTemplateNestingExceeded:
177
+ raise
178
+ except Exception as e:
179
+ import warnings
180
+ warnings.warn("Failed to render string: " + value)
181
+ return value
182
+ if return_dict:
183
+ return self._apply_function_dict(render_string)
184
+ else:
185
+ return self.apply_function(render_string)
186
+
126
187
  def apply_function(
127
- self, func: Callable, exclude_components: List[str] = None
188
+ self, func: Callable, exclude_components: Optional[List[str]] = None
128
189
  ) -> QuestionBase:
129
- """Apply a function to the question parts
190
+ from edsl.questions.QuestionBase import QuestionBase
191
+ d = self._apply_function_dict(func, exclude_components)
192
+ return QuestionBase.from_dict(d)
193
+
194
+ def _apply_function_dict(
195
+ self, func: Callable, exclude_components: Optional[List[str]] = None
196
+ ) -> dict:
197
+ """Apply a function to the question parts, excluding certain components.
130
198
 
131
199
  :param func: The function to apply to the question parts.
132
200
  :param exclude_components: The components to exclude from the function application.
@@ -141,7 +209,6 @@ class QuestionBaseGenMixin:
141
209
  Question('free_text', question_name = \"""COLOR\""", question_text = \"""WHAT IS YOUR FAVORITE COLOR?\""")
142
210
 
143
211
  """
144
- from edsl.questions.QuestionBase import QuestionBase
145
212
 
146
213
  if exclude_components is None:
147
214
  exclude_components = ["question_name", "question_type"]
@@ -160,10 +227,10 @@ class QuestionBaseGenMixin:
160
227
  d[key] = value
161
228
  continue
162
229
  d[key] = func(value)
163
- return QuestionBase.from_dict(d)
230
+ return d
164
231
 
165
232
 
166
233
  if __name__ == "__main__":
167
234
  import doctest
168
235
 
169
- doctest.testmod()
236
+ doctest.testmod(optionflags=doctest.ELLIPSIS)
edsl/results/Dataset.py CHANGED
@@ -15,6 +15,7 @@ from edsl.Base import PersistenceMixin, HashingMixin
15
15
 
16
16
  from edsl.results.smart_objects import FirstObject
17
17
 
18
+ from edsl.results.ResultsGGMixin import GGPlotMethod
18
19
 
19
20
  class Dataset(UserList, ResultsExportMixin, PersistenceMixin, HashingMixin):
20
21
  """A class to represent a dataset of observations."""
@@ -26,6 +27,20 @@ class Dataset(UserList, ResultsExportMixin, PersistenceMixin, HashingMixin):
26
27
  super().__init__(data)
27
28
  self.print_parameters = print_parameters
28
29
 
30
+
31
+ def ggplot2(
32
+ self,
33
+ ggplot_code: str,
34
+ shape="wide",
35
+ sql: str = None,
36
+ remove_prefix: bool = True,
37
+ debug: bool = False,
38
+ height=4,
39
+ width=6,
40
+ factor_orders: Optional[dict] = None,
41
+ ):
42
+ return GGPlotMethod(self).ggplot2(ggplot_code, shape, sql, remove_prefix, debug, height, width, factor_orders)
43
+
29
44
  def __len__(self) -> int:
30
45
  """Return the number of observations in the dataset.
31
46
 
@@ -580,6 +595,56 @@ class Dataset(UserList, ResultsExportMixin, PersistenceMixin, HashingMixin):
580
595
  result = cls([{col: df[col].tolist()} for col in df.columns])
581
596
  return result
582
597
 
598
+ def to_docx(self, output_file: str, title: str = None) -> None:
599
+ """
600
+ Convert the dataset to a Word document.
601
+
602
+ Args:
603
+ output_file (str): Path to save the Word document
604
+ title (str, optional): Title for the document
605
+ """
606
+ from docx import Document
607
+ from docx.shared import Inches
608
+ from docx.enum.text import WD_ALIGN_PARAGRAPH
609
+
610
+ # Create document
611
+ doc = Document()
612
+
613
+ # Add title if provided
614
+ if title:
615
+ title_heading = doc.add_heading(title, level=1)
616
+ title_heading.alignment = WD_ALIGN_PARAGRAPH.CENTER
617
+
618
+ # Get headers and data
619
+ headers, data = self._tabular()
620
+
621
+ # Create table
622
+ table = doc.add_table(rows=len(data) + 1, cols=len(headers))
623
+ table.style = 'Table Grid'
624
+
625
+ # Add headers
626
+ for j, header in enumerate(headers):
627
+ cell = table.cell(0, j)
628
+ cell.text = str(header)
629
+
630
+ # Add data
631
+ for i, row in enumerate(data):
632
+ for j, cell_content in enumerate(row):
633
+ cell = table.cell(i + 1, j)
634
+ cell.text = str(cell_content) if cell_content is not None else ""
635
+
636
+ # Adjust column widths
637
+ for column in table.columns:
638
+ max_width = 0
639
+ for cell in column.cells:
640
+ text_width = len(str(cell.text))
641
+ max_width = max(max_width, text_width)
642
+ for cell in column.cells:
643
+ cell.width = Inches(min(max_width * 0.1 + 0.5, 6))
644
+
645
+ # Save the document
646
+ doc.save(output_file)
647
+
583
648
 
584
649
  if __name__ == "__main__":
585
650
  import doctest
@@ -7,6 +7,7 @@ from typing import Optional, Tuple, Union, List
7
7
 
8
8
  from edsl.results.file_exports import CSVExport, ExcelExport, JSONLExport, SQLiteExport
9
9
 
10
+
10
11
  class DatasetExportMixin:
11
12
  """Mixin class for exporting Dataset objects."""
12
13
 
@@ -82,7 +83,8 @@ class DatasetExportMixin:
82
83
  else:
83
84
  if len(values) != _num_observations:
84
85
  raise ValueError(
85
- "The number of observations is not consistent across columns."
86
+ f"The number of observations is not consistent across columns. "
87
+ f"Column '{key}' has {len(values)} observations, but previous columns had {_num_observations} observations."
86
88
  )
87
89
 
88
90
  return _num_observations
@@ -219,7 +221,9 @@ class DatasetExportMixin:
219
221
  )
220
222
  return exporter.export()
221
223
 
222
- def _db(self, remove_prefix: bool = True, shape: str = "wide") -> "sqlalchemy.engine.Engine":
224
+ def _db(
225
+ self, remove_prefix: bool = True, shape: str = "wide"
226
+ ) -> "sqlalchemy.engine.Engine":
223
227
  """Create a SQLite database in memory and return the connection.
224
228
 
225
229
  Args:
@@ -229,7 +233,7 @@ class DatasetExportMixin:
229
233
  Returns:
230
234
  A database connection
231
235
  >>> from sqlalchemy import text
232
- >>> from edsl import Results
236
+ >>> from edsl import Results
233
237
  >>> engine = Results.example()._db()
234
238
  >>> len(engine.execute(text("SELECT * FROM self")).fetchall())
235
239
  4
@@ -247,16 +251,17 @@ class DatasetExportMixin:
247
251
 
248
252
  if shape == "long":
249
253
  # Melt the dataframe to convert it to long format
250
- df = df.melt(
251
- var_name='key',
252
- value_name='value'
253
- )
254
+ df = df.melt(var_name="key", value_name="value")
254
255
  # Add a row number column for reference
255
- df.insert(0, 'row_number', range(1, len(df) + 1))
256
-
256
+ df.insert(0, "row_number", range(1, len(df) + 1))
257
+
257
258
  # Split the key into data_type and key
258
- df['data_type'] = df['key'].apply(lambda x: x.split('.')[0] if '.' in x else None)
259
- df['key'] = df['key'].apply(lambda x: '.'.join(x.split('.')[1:]) if '.' in x else x)
259
+ df["data_type"] = df["key"].apply(
260
+ lambda x: x.split(".")[0] if "." in x else None
261
+ )
262
+ df["key"] = df["key"].apply(
263
+ lambda x: ".".join(x.split(".")[1:]) if "." in x else x
264
+ )
260
265
 
261
266
  df.to_sql(
262
267
  "self",
@@ -276,27 +281,27 @@ class DatasetExportMixin:
276
281
  ) -> Union["pd.DataFrame", str]:
277
282
  """Execute a SQL query and return the results as a DataFrame.
278
283
 
279
- Args:
280
- query: The SQL query to execute
281
- shape: The shape of the data in the database (wide or long)
282
- remove_prefix: Whether to remove the prefix from the column names
283
- transpose: Whether to transpose the DataFrame
284
- transpose_by: The column to use as the index when transposing
285
- csv: Whether to return the DataFrame as a CSV string
286
- to_list: Whether to return the results as a list
287
- to_latex: Whether to return the results as LaTeX
288
- filename: Optional filename to save the results to
289
-
290
- Returns:
291
- DataFrame, CSV string, list, or LaTeX string depending on parameters
292
-
293
- Examples:
294
- >>> from edsl import Results
295
- >>> r = Results.example();
296
- >>> len(r.sql("SELECT * FROM self", shape = "wide"))
297
- 4
298
- >>> len(r.sql("SELECT * FROM self", shape = "long"))
299
- 172
284
+ Args:
285
+ query: The SQL query to execute
286
+ shape: The shape of the data in the database (wide or long)
287
+ remove_prefix: Whether to remove the prefix from the column names
288
+ transpose: Whether to transpose the DataFrame
289
+ transpose_by: The column to use as the index when transposing
290
+ csv: Whether to return the DataFrame as a CSV string
291
+ to_list: Whether to return the results as a list
292
+ to_latex: Whether to return the results as LaTeX
293
+ filename: Optional filename to save the results to
294
+
295
+ Returns:
296
+ DataFrame, CSV string, list, or LaTeX string depending on parameters
297
+
298
+ Examples:
299
+ >>> from edsl import Results
300
+ >>> r = Results.example();
301
+ >>> len(r.sql("SELECT * FROM self", shape = "wide"))
302
+ 4
303
+ >>> len(r.sql("SELECT * FROM self", shape = "long"))
304
+ 172
300
305
  """
301
306
  import pandas as pd
302
307
 
@@ -538,6 +543,116 @@ class DatasetExportMixin:
538
543
 
539
544
  if return_link:
540
545
  return filename
546
+
547
+ def report(self, *fields: Optional[str], top_n: Optional[int] = None,
548
+ header_fields: Optional[List[str]] = None, divider: bool = True,
549
+ return_string: bool = False) -> Optional[str]:
550
+ """Takes the fields in order and returns a report of the results by iterating through rows.
551
+ The row number is printed as # Observation: <row number>
552
+ The name of the field is used as markdown header at level "##"
553
+ The content of that field is then printed.
554
+ Then the next field and so on.
555
+ Once that row is done, a new line is printed and the next row is shown.
556
+ If in a jupyter notebook, the report is displayed as markdown.
557
+
558
+ Args:
559
+ *fields: The fields to include in the report. If none provided, all fields are used.
560
+ top_n: Optional limit on the number of observations to include.
561
+ header_fields: Optional list of fields to include in the main header instead of as sections.
562
+ divider: If True, adds a horizontal rule between observations for better visual separation.
563
+ return_string: If True, returns the markdown string. If False (default in notebooks),
564
+ only displays the markdown without returning.
565
+
566
+ Returns:
567
+ A string containing the markdown report if return_string is True, otherwise None.
568
+
569
+ Examples:
570
+ >>> from edsl.results import Results
571
+ >>> r = Results.example()
572
+ >>> report = r.select('how_feeling', 'how_feeling_yesterday').report(return_string=True)
573
+ >>> "# Observation: 1" in report
574
+ True
575
+ >>> "## answer.how_feeling" in report
576
+ True
577
+ >>> report = r.select('how_feeling').report(header_fields=['answer.how_feeling'], return_string=True)
578
+ >>> "# Observation: 1 (`how_feeling`: OK)" in report
579
+ True
580
+ """
581
+ from edsl.utilities.utilities import is_notebook
582
+
583
+ # If no fields specified, use all columns
584
+ if not fields:
585
+ fields = self.relevant_columns()
586
+
587
+ # Initialize header_fields if not provided
588
+ if header_fields is None:
589
+ header_fields = []
590
+
591
+ # Validate all fields
592
+ all_fields = list(fields) + [f for f in header_fields if f not in fields]
593
+ for field in all_fields:
594
+ if field not in self.relevant_columns():
595
+ raise ValueError(f"Field '{field}' not found in dataset")
596
+
597
+ # Get data for each field
598
+ field_data = {}
599
+ for field in all_fields:
600
+ for entry in self:
601
+ if field in entry:
602
+ field_data[field] = entry[field]
603
+ break
604
+
605
+ # Number of observations to process
606
+ num_obs = self.num_observations()
607
+ if top_n is not None:
608
+ num_obs = min(num_obs, top_n)
609
+
610
+ # Build the report
611
+ report_lines = []
612
+ for i in range(num_obs):
613
+ # Create header with observation number and any header fields
614
+ header = f"# Observation: {i+1}"
615
+ if header_fields:
616
+ header_parts = []
617
+ for field in header_fields:
618
+ value = field_data[field][i]
619
+ # Get the field name without prefix for cleaner display
620
+ display_name = field.split('.')[-1] if '.' in field else field
621
+ # Format with backticks for monospace
622
+ header_parts.append(f"`{display_name}`: {value}")
623
+ if header_parts:
624
+ header += f" ({', '.join(header_parts)})"
625
+ report_lines.append(header)
626
+
627
+ # Add the remaining fields
628
+ for field in fields:
629
+ if field not in header_fields:
630
+ report_lines.append(f"## {field}")
631
+ value = field_data[field][i]
632
+ if isinstance(value, list) or isinstance(value, dict):
633
+ import json
634
+ report_lines.append(f"```\n{json.dumps(value, indent=2)}\n```")
635
+ else:
636
+ report_lines.append(str(value))
637
+
638
+ # Add divider between observations if requested
639
+ if divider and i < num_obs - 1:
640
+ report_lines.append("\n---\n")
641
+ else:
642
+ report_lines.append("") # Empty line between observations
643
+
644
+ report_text = "\n".join(report_lines)
645
+
646
+ # In notebooks, display as markdown and optionally return
647
+ is_nb = is_notebook()
648
+ if is_nb:
649
+ from IPython.display import Markdown, display
650
+ display(Markdown(report_text))
651
+
652
+ # Return the string if requested or if not in a notebook
653
+ if return_string or not is_nb:
654
+ return report_text
655
+ return None
541
656
 
542
657
  def tally(
543
658
  self, *fields: Optional[str], top_n: Optional[int] = None, output="Dataset"
@@ -616,6 +731,179 @@ class DatasetExportMixin:
616
731
  keys.append("count")
617
732
  return sl.reorder_keys(keys).to_dataset()
618
733
 
734
+ def flatten(self, field, keep_original=False):
735
+ """
736
+ Flatten a field containing a list of dictionaries into separate fields.
737
+
738
+ >>> from edsl.results.Dataset import Dataset
739
+ >>> Dataset([{'a': [{'a': 1, 'b': 2}]}, {'c': [5] }]).flatten('a')
740
+ Dataset([{'c': [5]}, {'a.a': [1]}, {'a.b': [2]}])
741
+
742
+
743
+ >>> Dataset([{'answer.example': [{'a': 1, 'b': 2}]}, {'c': [5] }]).flatten('answer.example')
744
+ Dataset([{'c': [5]}, {'answer.example.a': [1]}, {'answer.example.b': [2]}])
745
+
746
+
747
+ Args:
748
+ field: The field to flatten
749
+ keep_original: If True, keeps the original field in the dataset
750
+
751
+ Returns:
752
+ A new dataset with the flattened fields
753
+ """
754
+ from edsl.results.Dataset import Dataset
755
+
756
+ # Ensure the dataset isn't empty
757
+ if not self.data:
758
+ return self.copy()
759
+
760
+ # Find all columns that contain the field
761
+ matching_entries = []
762
+ for entry in self.data:
763
+ col_name = next(iter(entry.keys()))
764
+ if field == col_name or (
765
+ '.' in col_name and
766
+ (col_name.endswith('.' + field) or col_name.startswith(field + '.'))
767
+ ):
768
+ matching_entries.append(entry)
769
+
770
+ # Check if the field is ambiguous
771
+ if len(matching_entries) > 1:
772
+ matching_cols = [next(iter(entry.keys())) for entry in matching_entries]
773
+ raise ValueError(
774
+ f"Ambiguous field name '{field}'. It matches multiple columns: {matching_cols}. "
775
+ f"Please specify the full column name to flatten."
776
+ )
777
+
778
+ # Get the number of observations
779
+ num_observations = self.num_observations()
780
+
781
+ # Find the column to flatten
782
+ field_entry = None
783
+ for entry in self.data:
784
+ if field in entry:
785
+ field_entry = entry
786
+ break
787
+
788
+ if field_entry is None:
789
+ warnings.warn(
790
+ f"Field '{field}' not found in dataset, returning original dataset"
791
+ )
792
+ return self.copy()
793
+
794
+ # Create new dictionary for flattened data
795
+ flattened_data = []
796
+
797
+ # Copy all existing columns except the one we're flattening (if keep_original is False)
798
+ for entry in self.data:
799
+ col_name = next(iter(entry.keys()))
800
+ if col_name != field or keep_original:
801
+ flattened_data.append(entry.copy())
802
+
803
+ # Get field data and make sure it's valid
804
+ field_values = field_entry[field]
805
+ if not all(isinstance(item, dict) for item in field_values if item is not None):
806
+ warnings.warn(
807
+ f"Field '{field}' contains non-dictionary values that cannot be flattened"
808
+ )
809
+ return self.copy()
810
+
811
+ # Collect all unique keys across all dictionaries
812
+ all_keys = set()
813
+ for item in field_values:
814
+ if isinstance(item, dict):
815
+ all_keys.update(item.keys())
816
+
817
+ # Create new columns for each key
818
+ for key in sorted(all_keys): # Sort for consistent output
819
+ new_values = []
820
+ for i in range(num_observations):
821
+ value = None
822
+ if i < len(field_values) and isinstance(field_values[i], dict):
823
+ value = field_values[i].get(key, None)
824
+ new_values.append(value)
825
+
826
+ # Add this as a new column
827
+ flattened_data.append({f"{field}.{key}": new_values})
828
+
829
+ # Return a new Dataset with the flattened data
830
+ return Dataset(flattened_data)
831
+
832
+ def unpack_list(
833
+ self,
834
+ field: str,
835
+ new_names: Optional[List[str]] = None,
836
+ keep_original: bool = True,
837
+ ) -> "Dataset":
838
+ """Unpack list columns into separate columns with provided names or numeric suffixes.
839
+
840
+ For example, if a dataset contains:
841
+ [{'data': [[1, 2, 3], [4, 5, 6]], 'other': ['x', 'y']}]
842
+
843
+ After d.unpack_list('data'), it should become:
844
+ [{'other': ['x', 'y'], 'data_1': [1, 4], 'data_2': [2, 5], 'data_3': [3, 6]}]
845
+
846
+ Args:
847
+ field: The field containing lists to unpack
848
+ new_names: Optional list of names for the unpacked fields. If None, uses numeric suffixes.
849
+ keep_original: If True, keeps the original field in the dataset
850
+
851
+ Returns:
852
+ A new Dataset with unpacked columns
853
+
854
+ Examples:
855
+ >>> from edsl.results.Dataset import Dataset
856
+ >>> d = Dataset([{'data': [[1, 2, 3], [4, 5, 6]]}])
857
+ >>> d.unpack_list('data')
858
+ Dataset([{'data': [[1, 2, 3], [4, 5, 6]]}, {'data_1': [1, 4]}, {'data_2': [2, 5]}, {'data_3': [3, 6]}])
859
+
860
+ >>> d.unpack_list('data', new_names=['first', 'second', 'third'])
861
+ Dataset([{'data': [[1, 2, 3], [4, 5, 6]]}, {'first': [1, 4]}, {'second': [2, 5]}, {'third': [3, 6]}])
862
+ """
863
+ from edsl.results.Dataset import Dataset
864
+
865
+ # Create a copy of the dataset
866
+ result = Dataset(self.data.copy())
867
+
868
+ # Find the field in the dataset
869
+ field_index = None
870
+ for i, entry in enumerate(result.data):
871
+ if field in entry:
872
+ field_index = i
873
+ break
874
+
875
+ if field_index is None:
876
+ raise ValueError(f"Field '{field}' not found in dataset")
877
+
878
+ field_data = result.data[field_index][field]
879
+
880
+ # Check if values are lists
881
+ if not all(isinstance(v, list) for v in field_data):
882
+ raise ValueError(f"Field '{field}' does not contain lists in all entries")
883
+
884
+ # Get the maximum length of lists
885
+ max_len = max(len(v) for v in field_data)
886
+
887
+ # Create new fields for each index
888
+ for i in range(max_len):
889
+ if new_names and i < len(new_names):
890
+ new_field = new_names[i]
891
+ else:
892
+ new_field = f"{field}_{i+1}"
893
+
894
+ # Extract the i-th element from each list
895
+ new_values = []
896
+ for item in field_data:
897
+ new_values.append(item[i] if i < len(item) else None)
898
+
899
+ result.data.append({new_field: new_values})
900
+
901
+ # Remove the original field if keep_original is False
902
+ if not keep_original:
903
+ result.data.pop(field_index)
904
+
905
+ return result
906
+
619
907
 
620
908
  if __name__ == "__main__":
621
909
  import doctest