edsl 0.1.44__py3-none-any.whl → 0.1.46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +7 -3
- edsl/__version__.py +1 -1
- edsl/agents/InvigilatorBase.py +3 -1
- edsl/agents/PromptConstructor.py +66 -91
- edsl/agents/QuestionInstructionPromptBuilder.py +160 -79
- edsl/agents/QuestionTemplateReplacementsBuilder.py +80 -17
- edsl/agents/question_option_processor.py +15 -6
- edsl/coop/CoopFunctionsMixin.py +3 -4
- edsl/coop/coop.py +171 -96
- edsl/data/RemoteCacheSync.py +10 -9
- edsl/enums.py +3 -3
- edsl/inference_services/AnthropicService.py +11 -9
- edsl/inference_services/AvailableModelFetcher.py +2 -0
- edsl/inference_services/AwsBedrock.py +1 -2
- edsl/inference_services/AzureAI.py +12 -9
- edsl/inference_services/GoogleService.py +9 -4
- edsl/inference_services/InferenceServicesCollection.py +2 -2
- edsl/inference_services/MistralAIService.py +1 -2
- edsl/inference_services/OpenAIService.py +9 -4
- edsl/inference_services/PerplexityService.py +2 -1
- edsl/inference_services/{GrokService.py → XAIService.py} +2 -2
- edsl/inference_services/registry.py +2 -2
- edsl/jobs/AnswerQuestionFunctionConstructor.py +12 -1
- edsl/jobs/Jobs.py +24 -17
- edsl/jobs/JobsChecks.py +10 -13
- edsl/jobs/JobsPrompts.py +49 -26
- edsl/jobs/JobsRemoteInferenceHandler.py +4 -5
- edsl/jobs/async_interview_runner.py +3 -1
- edsl/jobs/check_survey_scenario_compatibility.py +5 -5
- edsl/jobs/data_structures.py +3 -0
- edsl/jobs/interviews/Interview.py +6 -3
- edsl/jobs/interviews/InterviewExceptionEntry.py +12 -0
- edsl/jobs/tasks/TaskHistory.py +1 -1
- edsl/language_models/LanguageModel.py +6 -3
- edsl/language_models/PriceManager.py +45 -5
- edsl/language_models/model.py +47 -26
- edsl/questions/QuestionBase.py +21 -0
- edsl/questions/QuestionBasePromptsMixin.py +103 -0
- edsl/questions/QuestionFreeText.py +22 -5
- edsl/questions/descriptors.py +4 -0
- edsl/questions/question_base_gen_mixin.py +96 -29
- edsl/results/Dataset.py +65 -0
- edsl/results/DatasetExportMixin.py +320 -32
- edsl/results/Result.py +27 -0
- edsl/results/Results.py +22 -2
- edsl/results/ResultsGGMixin.py +7 -3
- edsl/scenarios/DocumentChunker.py +2 -0
- edsl/scenarios/FileStore.py +10 -0
- edsl/scenarios/PdfExtractor.py +21 -1
- edsl/scenarios/Scenario.py +25 -9
- edsl/scenarios/ScenarioList.py +226 -24
- edsl/scenarios/handlers/__init__.py +1 -0
- edsl/scenarios/handlers/docx.py +5 -1
- edsl/scenarios/handlers/jpeg.py +39 -0
- edsl/surveys/Survey.py +5 -4
- edsl/surveys/SurveyFlowVisualization.py +91 -43
- edsl/templates/error_reporting/exceptions_table.html +7 -8
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/interviews.html +0 -1
- edsl/templates/error_reporting/overview.html +2 -7
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +1 -1
- edsl/utilities/PrettyList.py +14 -0
- edsl-0.1.46.dist-info/METADATA +246 -0
- {edsl-0.1.44.dist-info → edsl-0.1.46.dist-info}/RECORD +67 -66
- edsl-0.1.44.dist-info/METADATA +0 -110
- {edsl-0.1.44.dist-info → edsl-0.1.46.dist-info}/LICENSE +0 -0
- {edsl-0.1.44.dist-info → edsl-0.1.46.dist-info}/WHEEL +0 -0
@@ -1,7 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import copy
|
3
3
|
import itertools
|
4
|
-
from typing import Optional, List, Callable, Type, TYPE_CHECKING
|
4
|
+
from typing import Optional, List, Callable, Type, TYPE_CHECKING, Union
|
5
5
|
|
6
6
|
if TYPE_CHECKING:
|
7
7
|
from edsl.questions.QuestionBase import QuestionBase
|
@@ -9,7 +9,11 @@ if TYPE_CHECKING:
|
|
9
9
|
|
10
10
|
|
11
11
|
class QuestionBaseGenMixin:
|
12
|
-
"""Mixin for QuestionBase.
|
12
|
+
"""Mixin for QuestionBase.
|
13
|
+
|
14
|
+
This mostly has functions that are used to generate new questions from existing ones.
|
15
|
+
|
16
|
+
"""
|
13
17
|
|
14
18
|
def copy(self) -> QuestionBase:
|
15
19
|
"""Return a deep copy of the question.
|
@@ -85,48 +89,112 @@ class QuestionBaseGenMixin:
|
|
85
89
|
lp = LoopProcessor(self)
|
86
90
|
return lp.process_templates(scenario_list)
|
87
91
|
|
88
|
-
|
89
|
-
"""
|
90
|
-
|
91
|
-
:param replacement_dict: The dictionary of values to replace in the question components.
|
92
|
+
class MaxTemplateNestingExceeded(Exception):
|
93
|
+
"""Raised when template rendering exceeds maximum allowed nesting level."""
|
94
|
+
pass
|
92
95
|
|
96
|
+
def render(self, replacement_dict: dict, return_dict: bool = False) -> Union["QuestionBase", dict]:
|
97
|
+
"""Render the question components as jinja2 templates with the replacement dictionary.
|
98
|
+
Handles nested template variables by recursively rendering until all variables are resolved.
|
99
|
+
|
100
|
+
Raises:
|
101
|
+
MaxTemplateNestingExceeded: If template nesting exceeds MAX_NESTING levels
|
102
|
+
|
93
103
|
>>> from edsl.questions.QuestionFreeText import QuestionFreeText
|
94
104
|
>>> q = QuestionFreeText(question_name = "color", question_text = "What is your favorite {{ thing }}?")
|
95
105
|
>>> q.render({"thing": "color"})
|
96
106
|
Question('free_text', question_name = \"""color\""", question_text = \"""What is your favorite color?\""")
|
97
107
|
|
108
|
+
>>> from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
|
109
|
+
>>> q = QuestionMultipleChoice(question_name = "color", question_text = "What is your favorite {{ thing }}?", question_options = ["red", "blue", "green"])
|
110
|
+
>>> from edsl.scenarios.Scenario import Scenario
|
111
|
+
>>> q.render(Scenario({"thing": "color"})).data
|
112
|
+
{'question_name': 'color', 'question_text': 'What is your favorite color?', 'question_options': ['red', 'blue', 'green']}
|
113
|
+
|
114
|
+
>>> from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
|
115
|
+
>>> q = QuestionMultipleChoice(question_name = "color", question_text = "What is your favorite {{ thing }}?", question_options = ["red", "blue", "green"])
|
116
|
+
>>> q.render({"thing": 1}).data
|
117
|
+
{'question_name': 'color', 'question_text': 'What is your favorite 1?', 'question_options': ['red', 'blue', 'green']}
|
118
|
+
|
119
|
+
|
120
|
+
>>> from edsl.questions.QuestionMultipleChoice import QuestionMultipleChoice
|
121
|
+
>>> from edsl.scenarios.Scenario import Scenario
|
122
|
+
>>> q = QuestionMultipleChoice(question_name = "color", question_text = "What is your favorite {{ thing }}?", question_options = ["red", "blue", "green"])
|
123
|
+
>>> q.render(Scenario({"thing": "color of {{ object }}", "object":"water"})).data
|
124
|
+
{'question_name': 'color', 'question_text': 'What is your favorite color of water?', 'question_options': ['red', 'blue', 'green']}
|
125
|
+
|
126
|
+
|
127
|
+
>>> from edsl.questions.QuestionFreeText import QuestionFreeText
|
128
|
+
>>> q = QuestionFreeText(question_name = "infinite", question_text = "This has {{ a }}")
|
129
|
+
>>> q.render({"a": "{{ b }}", "b": "{{ a }}"}) # doctest: +IGNORE_EXCEPTION_DETAIL
|
130
|
+
Traceback (most recent call last):
|
131
|
+
...
|
132
|
+
edsl.questions.question_base_gen_mixin.QuestionBaseGenMixin.MaxTemplateNestingExceeded:...
|
98
133
|
"""
|
99
|
-
from jinja2 import Environment
|
134
|
+
from jinja2 import Environment, meta
|
100
135
|
from edsl.scenarios.Scenario import Scenario
|
101
136
|
|
137
|
+
MAX_NESTING = 10 # Maximum allowed nesting levels
|
138
|
+
|
102
139
|
strings_only_replacement_dict = {
|
103
140
|
k: v for k, v in replacement_dict.items() if not isinstance(v, Scenario)
|
104
141
|
}
|
105
142
|
|
143
|
+
strings_only_replacement_dict['scenario'] = strings_only_replacement_dict
|
144
|
+
|
145
|
+
def _has_unrendered_variables(template_str: str, env: Environment) -> bool:
|
146
|
+
"""Check if the template string has any unrendered variables."""
|
147
|
+
if not isinstance(template_str, str):
|
148
|
+
return False
|
149
|
+
ast = env.parse(template_str)
|
150
|
+
return bool(meta.find_undeclared_variables(ast))
|
151
|
+
|
106
152
|
def render_string(value: str) -> str:
|
107
153
|
if value is None or not isinstance(value, str):
|
108
154
|
return value
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
155
|
+
|
156
|
+
try:
|
157
|
+
env = Environment()
|
158
|
+
result = value
|
159
|
+
nesting_count = 0
|
160
|
+
|
161
|
+
while _has_unrendered_variables(result, env):
|
162
|
+
if nesting_count >= MAX_NESTING:
|
163
|
+
raise self.MaxTemplateNestingExceeded(
|
164
|
+
f"Template rendering exceeded {MAX_NESTING} levels of nesting. "
|
165
|
+
f"Current value: {result}"
|
166
|
+
)
|
167
|
+
|
168
|
+
template = env.from_string(result)
|
169
|
+
new_result = template.render(strings_only_replacement_dict)
|
170
|
+
if new_result == result: # Break if no changes made
|
171
|
+
break
|
172
|
+
result = new_result
|
173
|
+
nesting_count += 1
|
174
|
+
|
175
|
+
return result
|
176
|
+
except self.MaxTemplateNestingExceeded:
|
177
|
+
raise
|
178
|
+
except Exception as e:
|
179
|
+
import warnings
|
180
|
+
warnings.warn("Failed to render string: " + value)
|
181
|
+
return value
|
182
|
+
if return_dict:
|
183
|
+
return self._apply_function_dict(render_string)
|
184
|
+
else:
|
185
|
+
return self.apply_function(render_string)
|
186
|
+
|
126
187
|
def apply_function(
|
127
|
-
self, func: Callable, exclude_components: List[str] = None
|
188
|
+
self, func: Callable, exclude_components: Optional[List[str]] = None
|
128
189
|
) -> QuestionBase:
|
129
|
-
|
190
|
+
from edsl.questions.QuestionBase import QuestionBase
|
191
|
+
d = self._apply_function_dict(func, exclude_components)
|
192
|
+
return QuestionBase.from_dict(d)
|
193
|
+
|
194
|
+
def _apply_function_dict(
|
195
|
+
self, func: Callable, exclude_components: Optional[List[str]] = None
|
196
|
+
) -> dict:
|
197
|
+
"""Apply a function to the question parts, excluding certain components.
|
130
198
|
|
131
199
|
:param func: The function to apply to the question parts.
|
132
200
|
:param exclude_components: The components to exclude from the function application.
|
@@ -141,7 +209,6 @@ class QuestionBaseGenMixin:
|
|
141
209
|
Question('free_text', question_name = \"""COLOR\""", question_text = \"""WHAT IS YOUR FAVORITE COLOR?\""")
|
142
210
|
|
143
211
|
"""
|
144
|
-
from edsl.questions.QuestionBase import QuestionBase
|
145
212
|
|
146
213
|
if exclude_components is None:
|
147
214
|
exclude_components = ["question_name", "question_type"]
|
@@ -160,10 +227,10 @@ class QuestionBaseGenMixin:
|
|
160
227
|
d[key] = value
|
161
228
|
continue
|
162
229
|
d[key] = func(value)
|
163
|
-
return
|
230
|
+
return d
|
164
231
|
|
165
232
|
|
166
233
|
if __name__ == "__main__":
|
167
234
|
import doctest
|
168
235
|
|
169
|
-
doctest.testmod()
|
236
|
+
doctest.testmod(optionflags=doctest.ELLIPSIS)
|
edsl/results/Dataset.py
CHANGED
@@ -15,6 +15,7 @@ from edsl.Base import PersistenceMixin, HashingMixin
|
|
15
15
|
|
16
16
|
from edsl.results.smart_objects import FirstObject
|
17
17
|
|
18
|
+
from edsl.results.ResultsGGMixin import GGPlotMethod
|
18
19
|
|
19
20
|
class Dataset(UserList, ResultsExportMixin, PersistenceMixin, HashingMixin):
|
20
21
|
"""A class to represent a dataset of observations."""
|
@@ -26,6 +27,20 @@ class Dataset(UserList, ResultsExportMixin, PersistenceMixin, HashingMixin):
|
|
26
27
|
super().__init__(data)
|
27
28
|
self.print_parameters = print_parameters
|
28
29
|
|
30
|
+
|
31
|
+
def ggplot2(
|
32
|
+
self,
|
33
|
+
ggplot_code: str,
|
34
|
+
shape="wide",
|
35
|
+
sql: str = None,
|
36
|
+
remove_prefix: bool = True,
|
37
|
+
debug: bool = False,
|
38
|
+
height=4,
|
39
|
+
width=6,
|
40
|
+
factor_orders: Optional[dict] = None,
|
41
|
+
):
|
42
|
+
return GGPlotMethod(self).ggplot2(ggplot_code, shape, sql, remove_prefix, debug, height, width, factor_orders)
|
43
|
+
|
29
44
|
def __len__(self) -> int:
|
30
45
|
"""Return the number of observations in the dataset.
|
31
46
|
|
@@ -580,6 +595,56 @@ class Dataset(UserList, ResultsExportMixin, PersistenceMixin, HashingMixin):
|
|
580
595
|
result = cls([{col: df[col].tolist()} for col in df.columns])
|
581
596
|
return result
|
582
597
|
|
598
|
+
def to_docx(self, output_file: str, title: str = None) -> None:
|
599
|
+
"""
|
600
|
+
Convert the dataset to a Word document.
|
601
|
+
|
602
|
+
Args:
|
603
|
+
output_file (str): Path to save the Word document
|
604
|
+
title (str, optional): Title for the document
|
605
|
+
"""
|
606
|
+
from docx import Document
|
607
|
+
from docx.shared import Inches
|
608
|
+
from docx.enum.text import WD_ALIGN_PARAGRAPH
|
609
|
+
|
610
|
+
# Create document
|
611
|
+
doc = Document()
|
612
|
+
|
613
|
+
# Add title if provided
|
614
|
+
if title:
|
615
|
+
title_heading = doc.add_heading(title, level=1)
|
616
|
+
title_heading.alignment = WD_ALIGN_PARAGRAPH.CENTER
|
617
|
+
|
618
|
+
# Get headers and data
|
619
|
+
headers, data = self._tabular()
|
620
|
+
|
621
|
+
# Create table
|
622
|
+
table = doc.add_table(rows=len(data) + 1, cols=len(headers))
|
623
|
+
table.style = 'Table Grid'
|
624
|
+
|
625
|
+
# Add headers
|
626
|
+
for j, header in enumerate(headers):
|
627
|
+
cell = table.cell(0, j)
|
628
|
+
cell.text = str(header)
|
629
|
+
|
630
|
+
# Add data
|
631
|
+
for i, row in enumerate(data):
|
632
|
+
for j, cell_content in enumerate(row):
|
633
|
+
cell = table.cell(i + 1, j)
|
634
|
+
cell.text = str(cell_content) if cell_content is not None else ""
|
635
|
+
|
636
|
+
# Adjust column widths
|
637
|
+
for column in table.columns:
|
638
|
+
max_width = 0
|
639
|
+
for cell in column.cells:
|
640
|
+
text_width = len(str(cell.text))
|
641
|
+
max_width = max(max_width, text_width)
|
642
|
+
for cell in column.cells:
|
643
|
+
cell.width = Inches(min(max_width * 0.1 + 0.5, 6))
|
644
|
+
|
645
|
+
# Save the document
|
646
|
+
doc.save(output_file)
|
647
|
+
|
583
648
|
|
584
649
|
if __name__ == "__main__":
|
585
650
|
import doctest
|
@@ -7,6 +7,7 @@ from typing import Optional, Tuple, Union, List
|
|
7
7
|
|
8
8
|
from edsl.results.file_exports import CSVExport, ExcelExport, JSONLExport, SQLiteExport
|
9
9
|
|
10
|
+
|
10
11
|
class DatasetExportMixin:
|
11
12
|
"""Mixin class for exporting Dataset objects."""
|
12
13
|
|
@@ -82,7 +83,8 @@ class DatasetExportMixin:
|
|
82
83
|
else:
|
83
84
|
if len(values) != _num_observations:
|
84
85
|
raise ValueError(
|
85
|
-
"The number of observations is not consistent across columns."
|
86
|
+
f"The number of observations is not consistent across columns. "
|
87
|
+
f"Column '{key}' has {len(values)} observations, but previous columns had {_num_observations} observations."
|
86
88
|
)
|
87
89
|
|
88
90
|
return _num_observations
|
@@ -219,7 +221,9 @@ class DatasetExportMixin:
|
|
219
221
|
)
|
220
222
|
return exporter.export()
|
221
223
|
|
222
|
-
def _db(
|
224
|
+
def _db(
|
225
|
+
self, remove_prefix: bool = True, shape: str = "wide"
|
226
|
+
) -> "sqlalchemy.engine.Engine":
|
223
227
|
"""Create a SQLite database in memory and return the connection.
|
224
228
|
|
225
229
|
Args:
|
@@ -229,7 +233,7 @@ class DatasetExportMixin:
|
|
229
233
|
Returns:
|
230
234
|
A database connection
|
231
235
|
>>> from sqlalchemy import text
|
232
|
-
>>> from edsl import Results
|
236
|
+
>>> from edsl import Results
|
233
237
|
>>> engine = Results.example()._db()
|
234
238
|
>>> len(engine.execute(text("SELECT * FROM self")).fetchall())
|
235
239
|
4
|
@@ -247,16 +251,17 @@ class DatasetExportMixin:
|
|
247
251
|
|
248
252
|
if shape == "long":
|
249
253
|
# Melt the dataframe to convert it to long format
|
250
|
-
df = df.melt(
|
251
|
-
var_name='key',
|
252
|
-
value_name='value'
|
253
|
-
)
|
254
|
+
df = df.melt(var_name="key", value_name="value")
|
254
255
|
# Add a row number column for reference
|
255
|
-
df.insert(0,
|
256
|
-
|
256
|
+
df.insert(0, "row_number", range(1, len(df) + 1))
|
257
|
+
|
257
258
|
# Split the key into data_type and key
|
258
|
-
df[
|
259
|
-
|
259
|
+
df["data_type"] = df["key"].apply(
|
260
|
+
lambda x: x.split(".")[0] if "." in x else None
|
261
|
+
)
|
262
|
+
df["key"] = df["key"].apply(
|
263
|
+
lambda x: ".".join(x.split(".")[1:]) if "." in x else x
|
264
|
+
)
|
260
265
|
|
261
266
|
df.to_sql(
|
262
267
|
"self",
|
@@ -276,27 +281,27 @@ class DatasetExportMixin:
|
|
276
281
|
) -> Union["pd.DataFrame", str]:
|
277
282
|
"""Execute a SQL query and return the results as a DataFrame.
|
278
283
|
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
284
|
+
Args:
|
285
|
+
query: The SQL query to execute
|
286
|
+
shape: The shape of the data in the database (wide or long)
|
287
|
+
remove_prefix: Whether to remove the prefix from the column names
|
288
|
+
transpose: Whether to transpose the DataFrame
|
289
|
+
transpose_by: The column to use as the index when transposing
|
290
|
+
csv: Whether to return the DataFrame as a CSV string
|
291
|
+
to_list: Whether to return the results as a list
|
292
|
+
to_latex: Whether to return the results as LaTeX
|
293
|
+
filename: Optional filename to save the results to
|
294
|
+
|
295
|
+
Returns:
|
296
|
+
DataFrame, CSV string, list, or LaTeX string depending on parameters
|
297
|
+
|
298
|
+
Examples:
|
299
|
+
>>> from edsl import Results
|
300
|
+
>>> r = Results.example();
|
301
|
+
>>> len(r.sql("SELECT * FROM self", shape = "wide"))
|
302
|
+
4
|
303
|
+
>>> len(r.sql("SELECT * FROM self", shape = "long"))
|
304
|
+
172
|
300
305
|
"""
|
301
306
|
import pandas as pd
|
302
307
|
|
@@ -538,6 +543,116 @@ class DatasetExportMixin:
|
|
538
543
|
|
539
544
|
if return_link:
|
540
545
|
return filename
|
546
|
+
|
547
|
+
def report(self, *fields: Optional[str], top_n: Optional[int] = None,
|
548
|
+
header_fields: Optional[List[str]] = None, divider: bool = True,
|
549
|
+
return_string: bool = False) -> Optional[str]:
|
550
|
+
"""Takes the fields in order and returns a report of the results by iterating through rows.
|
551
|
+
The row number is printed as # Observation: <row number>
|
552
|
+
The name of the field is used as markdown header at level "##"
|
553
|
+
The content of that field is then printed.
|
554
|
+
Then the next field and so on.
|
555
|
+
Once that row is done, a new line is printed and the next row is shown.
|
556
|
+
If in a jupyter notebook, the report is displayed as markdown.
|
557
|
+
|
558
|
+
Args:
|
559
|
+
*fields: The fields to include in the report. If none provided, all fields are used.
|
560
|
+
top_n: Optional limit on the number of observations to include.
|
561
|
+
header_fields: Optional list of fields to include in the main header instead of as sections.
|
562
|
+
divider: If True, adds a horizontal rule between observations for better visual separation.
|
563
|
+
return_string: If True, returns the markdown string. If False (default in notebooks),
|
564
|
+
only displays the markdown without returning.
|
565
|
+
|
566
|
+
Returns:
|
567
|
+
A string containing the markdown report if return_string is True, otherwise None.
|
568
|
+
|
569
|
+
Examples:
|
570
|
+
>>> from edsl.results import Results
|
571
|
+
>>> r = Results.example()
|
572
|
+
>>> report = r.select('how_feeling', 'how_feeling_yesterday').report(return_string=True)
|
573
|
+
>>> "# Observation: 1" in report
|
574
|
+
True
|
575
|
+
>>> "## answer.how_feeling" in report
|
576
|
+
True
|
577
|
+
>>> report = r.select('how_feeling').report(header_fields=['answer.how_feeling'], return_string=True)
|
578
|
+
>>> "# Observation: 1 (`how_feeling`: OK)" in report
|
579
|
+
True
|
580
|
+
"""
|
581
|
+
from edsl.utilities.utilities import is_notebook
|
582
|
+
|
583
|
+
# If no fields specified, use all columns
|
584
|
+
if not fields:
|
585
|
+
fields = self.relevant_columns()
|
586
|
+
|
587
|
+
# Initialize header_fields if not provided
|
588
|
+
if header_fields is None:
|
589
|
+
header_fields = []
|
590
|
+
|
591
|
+
# Validate all fields
|
592
|
+
all_fields = list(fields) + [f for f in header_fields if f not in fields]
|
593
|
+
for field in all_fields:
|
594
|
+
if field not in self.relevant_columns():
|
595
|
+
raise ValueError(f"Field '{field}' not found in dataset")
|
596
|
+
|
597
|
+
# Get data for each field
|
598
|
+
field_data = {}
|
599
|
+
for field in all_fields:
|
600
|
+
for entry in self:
|
601
|
+
if field in entry:
|
602
|
+
field_data[field] = entry[field]
|
603
|
+
break
|
604
|
+
|
605
|
+
# Number of observations to process
|
606
|
+
num_obs = self.num_observations()
|
607
|
+
if top_n is not None:
|
608
|
+
num_obs = min(num_obs, top_n)
|
609
|
+
|
610
|
+
# Build the report
|
611
|
+
report_lines = []
|
612
|
+
for i in range(num_obs):
|
613
|
+
# Create header with observation number and any header fields
|
614
|
+
header = f"# Observation: {i+1}"
|
615
|
+
if header_fields:
|
616
|
+
header_parts = []
|
617
|
+
for field in header_fields:
|
618
|
+
value = field_data[field][i]
|
619
|
+
# Get the field name without prefix for cleaner display
|
620
|
+
display_name = field.split('.')[-1] if '.' in field else field
|
621
|
+
# Format with backticks for monospace
|
622
|
+
header_parts.append(f"`{display_name}`: {value}")
|
623
|
+
if header_parts:
|
624
|
+
header += f" ({', '.join(header_parts)})"
|
625
|
+
report_lines.append(header)
|
626
|
+
|
627
|
+
# Add the remaining fields
|
628
|
+
for field in fields:
|
629
|
+
if field not in header_fields:
|
630
|
+
report_lines.append(f"## {field}")
|
631
|
+
value = field_data[field][i]
|
632
|
+
if isinstance(value, list) or isinstance(value, dict):
|
633
|
+
import json
|
634
|
+
report_lines.append(f"```\n{json.dumps(value, indent=2)}\n```")
|
635
|
+
else:
|
636
|
+
report_lines.append(str(value))
|
637
|
+
|
638
|
+
# Add divider between observations if requested
|
639
|
+
if divider and i < num_obs - 1:
|
640
|
+
report_lines.append("\n---\n")
|
641
|
+
else:
|
642
|
+
report_lines.append("") # Empty line between observations
|
643
|
+
|
644
|
+
report_text = "\n".join(report_lines)
|
645
|
+
|
646
|
+
# In notebooks, display as markdown and optionally return
|
647
|
+
is_nb = is_notebook()
|
648
|
+
if is_nb:
|
649
|
+
from IPython.display import Markdown, display
|
650
|
+
display(Markdown(report_text))
|
651
|
+
|
652
|
+
# Return the string if requested or if not in a notebook
|
653
|
+
if return_string or not is_nb:
|
654
|
+
return report_text
|
655
|
+
return None
|
541
656
|
|
542
657
|
def tally(
|
543
658
|
self, *fields: Optional[str], top_n: Optional[int] = None, output="Dataset"
|
@@ -616,6 +731,179 @@ class DatasetExportMixin:
|
|
616
731
|
keys.append("count")
|
617
732
|
return sl.reorder_keys(keys).to_dataset()
|
618
733
|
|
734
|
+
def flatten(self, field, keep_original=False):
|
735
|
+
"""
|
736
|
+
Flatten a field containing a list of dictionaries into separate fields.
|
737
|
+
|
738
|
+
>>> from edsl.results.Dataset import Dataset
|
739
|
+
>>> Dataset([{'a': [{'a': 1, 'b': 2}]}, {'c': [5] }]).flatten('a')
|
740
|
+
Dataset([{'c': [5]}, {'a.a': [1]}, {'a.b': [2]}])
|
741
|
+
|
742
|
+
|
743
|
+
>>> Dataset([{'answer.example': [{'a': 1, 'b': 2}]}, {'c': [5] }]).flatten('answer.example')
|
744
|
+
Dataset([{'c': [5]}, {'answer.example.a': [1]}, {'answer.example.b': [2]}])
|
745
|
+
|
746
|
+
|
747
|
+
Args:
|
748
|
+
field: The field to flatten
|
749
|
+
keep_original: If True, keeps the original field in the dataset
|
750
|
+
|
751
|
+
Returns:
|
752
|
+
A new dataset with the flattened fields
|
753
|
+
"""
|
754
|
+
from edsl.results.Dataset import Dataset
|
755
|
+
|
756
|
+
# Ensure the dataset isn't empty
|
757
|
+
if not self.data:
|
758
|
+
return self.copy()
|
759
|
+
|
760
|
+
# Find all columns that contain the field
|
761
|
+
matching_entries = []
|
762
|
+
for entry in self.data:
|
763
|
+
col_name = next(iter(entry.keys()))
|
764
|
+
if field == col_name or (
|
765
|
+
'.' in col_name and
|
766
|
+
(col_name.endswith('.' + field) or col_name.startswith(field + '.'))
|
767
|
+
):
|
768
|
+
matching_entries.append(entry)
|
769
|
+
|
770
|
+
# Check if the field is ambiguous
|
771
|
+
if len(matching_entries) > 1:
|
772
|
+
matching_cols = [next(iter(entry.keys())) for entry in matching_entries]
|
773
|
+
raise ValueError(
|
774
|
+
f"Ambiguous field name '{field}'. It matches multiple columns: {matching_cols}. "
|
775
|
+
f"Please specify the full column name to flatten."
|
776
|
+
)
|
777
|
+
|
778
|
+
# Get the number of observations
|
779
|
+
num_observations = self.num_observations()
|
780
|
+
|
781
|
+
# Find the column to flatten
|
782
|
+
field_entry = None
|
783
|
+
for entry in self.data:
|
784
|
+
if field in entry:
|
785
|
+
field_entry = entry
|
786
|
+
break
|
787
|
+
|
788
|
+
if field_entry is None:
|
789
|
+
warnings.warn(
|
790
|
+
f"Field '{field}' not found in dataset, returning original dataset"
|
791
|
+
)
|
792
|
+
return self.copy()
|
793
|
+
|
794
|
+
# Create new dictionary for flattened data
|
795
|
+
flattened_data = []
|
796
|
+
|
797
|
+
# Copy all existing columns except the one we're flattening (if keep_original is False)
|
798
|
+
for entry in self.data:
|
799
|
+
col_name = next(iter(entry.keys()))
|
800
|
+
if col_name != field or keep_original:
|
801
|
+
flattened_data.append(entry.copy())
|
802
|
+
|
803
|
+
# Get field data and make sure it's valid
|
804
|
+
field_values = field_entry[field]
|
805
|
+
if not all(isinstance(item, dict) for item in field_values if item is not None):
|
806
|
+
warnings.warn(
|
807
|
+
f"Field '{field}' contains non-dictionary values that cannot be flattened"
|
808
|
+
)
|
809
|
+
return self.copy()
|
810
|
+
|
811
|
+
# Collect all unique keys across all dictionaries
|
812
|
+
all_keys = set()
|
813
|
+
for item in field_values:
|
814
|
+
if isinstance(item, dict):
|
815
|
+
all_keys.update(item.keys())
|
816
|
+
|
817
|
+
# Create new columns for each key
|
818
|
+
for key in sorted(all_keys): # Sort for consistent output
|
819
|
+
new_values = []
|
820
|
+
for i in range(num_observations):
|
821
|
+
value = None
|
822
|
+
if i < len(field_values) and isinstance(field_values[i], dict):
|
823
|
+
value = field_values[i].get(key, None)
|
824
|
+
new_values.append(value)
|
825
|
+
|
826
|
+
# Add this as a new column
|
827
|
+
flattened_data.append({f"{field}.{key}": new_values})
|
828
|
+
|
829
|
+
# Return a new Dataset with the flattened data
|
830
|
+
return Dataset(flattened_data)
|
831
|
+
|
832
|
+
def unpack_list(
|
833
|
+
self,
|
834
|
+
field: str,
|
835
|
+
new_names: Optional[List[str]] = None,
|
836
|
+
keep_original: bool = True,
|
837
|
+
) -> "Dataset":
|
838
|
+
"""Unpack list columns into separate columns with provided names or numeric suffixes.
|
839
|
+
|
840
|
+
For example, if a dataset contains:
|
841
|
+
[{'data': [[1, 2, 3], [4, 5, 6]], 'other': ['x', 'y']}]
|
842
|
+
|
843
|
+
After d.unpack_list('data'), it should become:
|
844
|
+
[{'other': ['x', 'y'], 'data_1': [1, 4], 'data_2': [2, 5], 'data_3': [3, 6]}]
|
845
|
+
|
846
|
+
Args:
|
847
|
+
field: The field containing lists to unpack
|
848
|
+
new_names: Optional list of names for the unpacked fields. If None, uses numeric suffixes.
|
849
|
+
keep_original: If True, keeps the original field in the dataset
|
850
|
+
|
851
|
+
Returns:
|
852
|
+
A new Dataset with unpacked columns
|
853
|
+
|
854
|
+
Examples:
|
855
|
+
>>> from edsl.results.Dataset import Dataset
|
856
|
+
>>> d = Dataset([{'data': [[1, 2, 3], [4, 5, 6]]}])
|
857
|
+
>>> d.unpack_list('data')
|
858
|
+
Dataset([{'data': [[1, 2, 3], [4, 5, 6]]}, {'data_1': [1, 4]}, {'data_2': [2, 5]}, {'data_3': [3, 6]}])
|
859
|
+
|
860
|
+
>>> d.unpack_list('data', new_names=['first', 'second', 'third'])
|
861
|
+
Dataset([{'data': [[1, 2, 3], [4, 5, 6]]}, {'first': [1, 4]}, {'second': [2, 5]}, {'third': [3, 6]}])
|
862
|
+
"""
|
863
|
+
from edsl.results.Dataset import Dataset
|
864
|
+
|
865
|
+
# Create a copy of the dataset
|
866
|
+
result = Dataset(self.data.copy())
|
867
|
+
|
868
|
+
# Find the field in the dataset
|
869
|
+
field_index = None
|
870
|
+
for i, entry in enumerate(result.data):
|
871
|
+
if field in entry:
|
872
|
+
field_index = i
|
873
|
+
break
|
874
|
+
|
875
|
+
if field_index is None:
|
876
|
+
raise ValueError(f"Field '{field}' not found in dataset")
|
877
|
+
|
878
|
+
field_data = result.data[field_index][field]
|
879
|
+
|
880
|
+
# Check if values are lists
|
881
|
+
if not all(isinstance(v, list) for v in field_data):
|
882
|
+
raise ValueError(f"Field '{field}' does not contain lists in all entries")
|
883
|
+
|
884
|
+
# Get the maximum length of lists
|
885
|
+
max_len = max(len(v) for v in field_data)
|
886
|
+
|
887
|
+
# Create new fields for each index
|
888
|
+
for i in range(max_len):
|
889
|
+
if new_names and i < len(new_names):
|
890
|
+
new_field = new_names[i]
|
891
|
+
else:
|
892
|
+
new_field = f"{field}_{i+1}"
|
893
|
+
|
894
|
+
# Extract the i-th element from each list
|
895
|
+
new_values = []
|
896
|
+
for item in field_data:
|
897
|
+
new_values.append(item[i] if i < len(item) else None)
|
898
|
+
|
899
|
+
result.data.append({new_field: new_values})
|
900
|
+
|
901
|
+
# Remove the original field if keep_original is False
|
902
|
+
if not keep_original:
|
903
|
+
result.data.pop(field_index)
|
904
|
+
|
905
|
+
return result
|
906
|
+
|
619
907
|
|
620
908
|
if __name__ == "__main__":
|
621
909
|
import doctest
|