edsl 0.1.33.dev1__py3-none-any.whl → 0.1.33.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +8 -4
- edsl/agents/Agent.py +46 -14
- edsl/agents/AgentList.py +43 -0
- edsl/agents/Invigilator.py +125 -212
- edsl/agents/InvigilatorBase.py +140 -32
- edsl/agents/PromptConstructionMixin.py +43 -66
- edsl/agents/__init__.py +1 -0
- edsl/auto/AutoStudy.py +117 -0
- edsl/auto/StageBase.py +230 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +73 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +224 -0
- edsl/config.py +38 -39
- edsl/coop/PriceFetcher.py +58 -0
- edsl/coop/coop.py +39 -5
- edsl/data/Cache.py +35 -1
- edsl/data_transfer_models.py +120 -38
- edsl/enums.py +2 -0
- edsl/exceptions/language_models.py +25 -1
- edsl/exceptions/questions.py +62 -5
- edsl/exceptions/results.py +4 -0
- edsl/inference_services/AnthropicService.py +13 -11
- edsl/inference_services/AwsBedrock.py +19 -17
- edsl/inference_services/AzureAI.py +37 -20
- edsl/inference_services/GoogleService.py +16 -12
- edsl/inference_services/GroqService.py +2 -0
- edsl/inference_services/InferenceServiceABC.py +24 -0
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OpenAIService.py +41 -50
- edsl/inference_services/TestService.py +71 -0
- edsl/inference_services/models_available_cache.py +0 -6
- edsl/inference_services/registry.py +4 -0
- edsl/jobs/Answers.py +10 -12
- edsl/jobs/FailedQuestion.py +78 -0
- edsl/jobs/Jobs.py +18 -13
- edsl/jobs/buckets/TokenBucket.py +39 -14
- edsl/jobs/interviews/Interview.py +297 -77
- edsl/jobs/interviews/InterviewExceptionEntry.py +83 -19
- edsl/jobs/interviews/interview_exception_tracking.py +0 -70
- edsl/jobs/interviews/retry_management.py +3 -1
- edsl/jobs/runners/JobsRunnerAsyncio.py +116 -70
- edsl/jobs/runners/JobsRunnerStatusMixin.py +1 -1
- edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
- edsl/jobs/tasks/TaskHistory.py +131 -213
- edsl/language_models/LanguageModel.py +239 -129
- edsl/language_models/ModelList.py +2 -2
- edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/registry.py +15 -2
- edsl/language_models/repair.py +0 -19
- edsl/language_models/utilities.py +61 -0
- edsl/prompts/Prompt.py +52 -2
- edsl/questions/AnswerValidatorMixin.py +23 -26
- edsl/questions/QuestionBase.py +273 -242
- edsl/questions/QuestionBaseGenMixin.py +133 -0
- edsl/questions/QuestionBasePromptsMixin.py +266 -0
- edsl/questions/QuestionBudget.py +6 -0
- edsl/questions/QuestionCheckBox.py +227 -35
- edsl/questions/QuestionExtract.py +98 -27
- edsl/questions/QuestionFreeText.py +46 -29
- edsl/questions/QuestionFunctional.py +7 -0
- edsl/questions/QuestionList.py +141 -22
- edsl/questions/QuestionMultipleChoice.py +173 -64
- edsl/questions/QuestionNumerical.py +87 -46
- edsl/questions/QuestionRank.py +182 -24
- edsl/questions/RegisterQuestionsMeta.py +31 -12
- edsl/questions/ResponseValidatorABC.py +169 -0
- edsl/questions/__init__.py +3 -4
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +10 -5
- edsl/questions/derived/QuestionLinearScale.py +11 -1
- edsl/questions/derived/QuestionTopK.py +6 -0
- edsl/questions/derived/QuestionYesNo.py +16 -1
- edsl/questions/descriptors.py +43 -7
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_registry.py +6 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/Dataset.py +20 -0
- edsl/results/DatasetExportMixin.py +41 -47
- edsl/results/DatasetTree.py +145 -0
- edsl/results/Result.py +32 -5
- edsl/results/Results.py +131 -45
- edsl/results/ResultsDBMixin.py +3 -3
- edsl/results/Selector.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/Scenario.py +10 -4
- edsl/scenarios/ScenarioList.py +348 -39
- edsl/scenarios/ScenarioListExportMixin.py +9 -0
- edsl/study/SnapShot.py +8 -1
- edsl/surveys/RuleCollection.py +2 -2
- edsl/surveys/Survey.py +634 -315
- edsl/surveys/SurveyExportMixin.py +71 -9
- edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
- edsl/surveys/SurveyQualtricsImport.py +75 -4
- edsl/surveys/instructions/ChangeInstruction.py +47 -0
- edsl/surveys/instructions/Instruction.py +34 -0
- edsl/surveys/instructions/InstructionCollection.py +77 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +111 -0
- edsl/templates/error_reporting/interviews.html +10 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/METADATA +4 -2
- edsl-0.1.33.dev2.dist-info/RECORD +289 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -286
- edsl/utilities/gcp_bucket/simple_example.py +0 -9
- edsl-0.1.33.dev1.dist-info/RECORD +0 -209
- {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/WHEEL +0 -0
@@ -0,0 +1,12 @@
|
|
1
|
+
{# Question Presentation #}
|
2
|
+
{{question_text}}
|
3
|
+
{% if use_code %}
|
4
|
+
{%- for option in question_options %}
|
5
|
+
{{ loop.index0 }}: {{option}}
|
6
|
+
{% endfor %}
|
7
|
+
{% else %}
|
8
|
+
{% for option in question_options %}
|
9
|
+
{{option}}
|
10
|
+
{% endfor %}
|
11
|
+
{% endif %}
|
12
|
+
Only 1 option may be selected.
|
File without changes
|
@@ -0,0 +1,8 @@
|
|
1
|
+
This question requires a numerical response in the form of an integer or decimal (e.g., -12, 0, 1, 2, 3.45, ...).
|
2
|
+
Respond with just your number on a single line.
|
3
|
+
If your response is equivalent to zero, report '0'
|
4
|
+
If you cannot determine the answer, report 'None'
|
5
|
+
|
6
|
+
{% if include_comment %}
|
7
|
+
After the answer, put a comment explaining your choice on the next line.
|
8
|
+
{% endif %}
|
@@ -0,0 +1,11 @@
|
|
1
|
+
{# Answering Instructions #}
|
2
|
+
{% if use_code %}
|
3
|
+
Please respond only with a comma-separated list of the codes of the ranked options, with square brackets. E.g., [0, 1, 3]
|
4
|
+
{% else %}
|
5
|
+
Please respond only with a comma-separated list of the ranked options, with square brackets. E.g., ['Good', 'Bad', 'Ugly']
|
6
|
+
{% endif %}
|
7
|
+
{% if include_comment %}
|
8
|
+
After the answer, you can put a comment explaining your choice on the next line.
|
9
|
+
{% endif %}
|
10
|
+
|
11
|
+
|
@@ -0,0 +1,15 @@
|
|
1
|
+
{{question_text}}
|
2
|
+
{% if use_code %}
|
3
|
+
The options are
|
4
|
+
{% for option in question_options %}
|
5
|
+
{{ loop.index0 }}: {{option}}
|
6
|
+
{% endfor %}
|
7
|
+
{% else %}
|
8
|
+
The options are:
|
9
|
+
{% for option in question_options %}
|
10
|
+
{{option}}
|
11
|
+
{% endfor %}
|
12
|
+
{% endif %}
|
13
|
+
{% if num_selections %}
|
14
|
+
You can include up to {{num_selections}} options in your answer.
|
15
|
+
{% endif %}
|
File without changes
|
@@ -0,0 +1,8 @@
|
|
1
|
+
{# Answering Instructions #}
|
2
|
+
Please respond with valid JSON, formatted like so:
|
3
|
+
{% if include_comment %}
|
4
|
+
{"answer": [<put comma-separated list here>], "comment": "<put explanation here>"}
|
5
|
+
{% else %}
|
6
|
+
{"answer": [<put comma-separated list here>]}
|
7
|
+
{% endif %}
|
8
|
+
|
@@ -0,0 +1,22 @@
|
|
1
|
+
{{question_text}}
|
2
|
+
{% if use_code %}
|
3
|
+
{% for option in question_options %}
|
4
|
+
{{ loop.index0 }}: {{option}}
|
5
|
+
{% endfor %}
|
6
|
+
{% else %}
|
7
|
+
{% for option in question_options %}
|
8
|
+
{{ option }}
|
9
|
+
{% endfor %}
|
10
|
+
{% endif %}
|
11
|
+
|
12
|
+
{# Restrictions #}
|
13
|
+
{% if min_selections != None and max_selections != None and min_selections == max_selections %}
|
14
|
+
You must select exactly {{min_selections}} options.
|
15
|
+
{% elif min_selections != None and max_selections != None %}
|
16
|
+
Minimum number of options that must be selected: {{min_selections}}.
|
17
|
+
Maximum number of options that may be selected: {{max_selections}}.
|
18
|
+
{% elif min_selections != None %}
|
19
|
+
Minimum number of options that must be selected: {{min_selections}}.
|
20
|
+
{% elif max_selections != None %}
|
21
|
+
Maximum number of options that may be selected: {{max_selections}}.
|
22
|
+
{% endif %}
|
File without changes
|
@@ -0,0 +1,12 @@
|
|
1
|
+
{# Question Presentation #}
|
2
|
+
{{question_text}}
|
3
|
+
{% if use_code %}
|
4
|
+
{%- for option in question_options %}
|
5
|
+
{{ loop.index0 }}: {{option}}
|
6
|
+
{% endfor %}
|
7
|
+
{% else %}
|
8
|
+
{% for option in question_options %}
|
9
|
+
{{option}}
|
10
|
+
{% endfor %}
|
11
|
+
{% endif %}
|
12
|
+
Only 1 option may be selected.
|
edsl/results/Dataset.py
CHANGED
@@ -8,6 +8,7 @@ from typing import Any, Union, Optional
|
|
8
8
|
import numpy as np
|
9
9
|
|
10
10
|
from edsl.results.ResultsExportMixin import ResultsExportMixin
|
11
|
+
from edsl.results.DatasetTree import Tree
|
11
12
|
|
12
13
|
|
13
14
|
class Dataset(UserList, ResultsExportMixin):
|
@@ -30,6 +31,15 @@ class Dataset(UserList, ResultsExportMixin):
|
|
30
31
|
_, values = list(self.data[0].items())[0]
|
31
32
|
return len(values)
|
32
33
|
|
34
|
+
def keys(self):
|
35
|
+
"""Return the key of each entry (column) in the dataset.
|
36
|
+
|
37
|
+
>>> d = Dataset([{'a.b':[1,2,3,4]}])
|
38
|
+
>>> d.keys()
|
39
|
+
['a.b']
|
40
|
+
"""
|
41
|
+
return [list(o.keys())[0] for o in self]
|
42
|
+
|
33
43
|
def __repr__(self) -> str:
|
34
44
|
"""Return a string representation of the dataset."""
|
35
45
|
return f"Dataset({self.data})"
|
@@ -245,6 +255,16 @@ class Dataset(UserList, ResultsExportMixin):
|
|
245
255
|
|
246
256
|
return Dataset(new_data)
|
247
257
|
|
258
|
+
@property
|
259
|
+
def tree(self):
|
260
|
+
"""Return a tree representation of the dataset.
|
261
|
+
|
262
|
+
>>> d = Dataset([{'a':[1,2,3,4]}, {'b':[4,3,2,1]}])
|
263
|
+
>>> d.tree.print_tree()
|
264
|
+
Tree has not been constructed yet.
|
265
|
+
"""
|
266
|
+
return Tree(self)
|
267
|
+
|
248
268
|
@classmethod
|
249
269
|
def example(self):
|
250
270
|
"""Return an example dataset.
|
@@ -4,6 +4,7 @@ import base64
|
|
4
4
|
import csv
|
5
5
|
import io
|
6
6
|
import html
|
7
|
+
from typing import Optional
|
7
8
|
|
8
9
|
from typing import Literal, Optional, Union, List
|
9
10
|
|
@@ -41,7 +42,7 @@ class DatasetExportMixin:
|
|
41
42
|
>>> Results.example().relevant_columns(data_type = "flimflam")
|
42
43
|
Traceback (most recent call last):
|
43
44
|
...
|
44
|
-
ValueError: No columns found for data type: flimflam. Available data types are:
|
45
|
+
ValueError: No columns found for data type: flimflam. Available data types are: ...
|
45
46
|
"""
|
46
47
|
columns = [list(x.keys())[0] for x in self]
|
47
48
|
if remove_prefix:
|
@@ -156,12 +157,13 @@ class DatasetExportMixin:
|
|
156
157
|
iframe_height: int = 200,
|
157
158
|
iframe_width: int = 600,
|
158
159
|
web=False,
|
159
|
-
|
160
|
+
return_string: bool = False,
|
161
|
+
) -> Union[None, str, "Results"]:
|
160
162
|
"""Print the results in a pretty format.
|
161
163
|
|
162
164
|
:param pretty_labels: A dictionary of pretty labels for the columns.
|
163
165
|
:param filename: The filename to save the results to.
|
164
|
-
:param format: The format to print the results in. Options are 'rich', 'html', or '
|
166
|
+
:param format: The format to print the results in. Options are 'rich', 'html', 'markdown', or 'latex'.
|
165
167
|
:param interactive: Whether to print the results interactively in a Jupyter notebook.
|
166
168
|
:param split_at_dot: Whether to split the column names at the last dot w/ a newline.
|
167
169
|
:param max_rows: The maximum number of rows to print.
|
@@ -170,6 +172,9 @@ class DatasetExportMixin:
|
|
170
172
|
:param iframe_height: The height of the iframe.
|
171
173
|
:param iframe_width: The width of the iframe.
|
172
174
|
:param web: Whether to display the table in a web browser.
|
175
|
+
:param return_string: Whether to return the output as a string instead of printing.
|
176
|
+
|
177
|
+
:return: None if tee is False and return_string is False, the dataset if tee is True, or a string if return_string is True.
|
173
178
|
|
174
179
|
Example: Print in rich format at the terminal
|
175
180
|
|
@@ -253,11 +258,14 @@ class DatasetExportMixin:
|
|
253
258
|
|
254
259
|
>>> r.select('how_feeling').print(format='latex')
|
255
260
|
\\begin{tabular}{l}
|
256
|
-
\\toprule
|
257
261
|
...
|
262
|
+
\\end{tabular}
|
263
|
+
<BLANKLINE>
|
258
264
|
"""
|
259
265
|
from IPython.display import HTML, display
|
260
266
|
from edsl.utilities.utilities import is_notebook
|
267
|
+
import io
|
268
|
+
import sys
|
261
269
|
|
262
270
|
def _determine_format(format):
|
263
271
|
if format is None:
|
@@ -266,7 +274,9 @@ class DatasetExportMixin:
|
|
266
274
|
else:
|
267
275
|
format = "rich"
|
268
276
|
if format not in ["rich", "html", "markdown", "latex"]:
|
269
|
-
raise ValueError(
|
277
|
+
raise ValueError(
|
278
|
+
"format must be one of 'rich', 'html', 'markdown', or 'latex'."
|
279
|
+
)
|
270
280
|
|
271
281
|
return format
|
272
282
|
|
@@ -285,21 +295,24 @@ class DatasetExportMixin:
|
|
285
295
|
|
286
296
|
new_data = list(_create_data())
|
287
297
|
|
298
|
+
# Capture output if return_string is True
|
299
|
+
if return_string:
|
300
|
+
old_stdout = sys.stdout
|
301
|
+
sys.stdout = io.StringIO()
|
302
|
+
|
303
|
+
output = None
|
304
|
+
|
288
305
|
if format == "rich":
|
289
306
|
from edsl.utilities.interface import print_dataset_with_rich
|
290
307
|
|
291
|
-
print_dataset_with_rich(
|
308
|
+
output = print_dataset_with_rich(
|
292
309
|
new_data, filename=filename, split_at_dot=split_at_dot
|
293
310
|
)
|
294
|
-
|
295
|
-
|
296
|
-
if format == "markdown":
|
311
|
+
elif format == "markdown":
|
297
312
|
from edsl.utilities.interface import print_list_of_dicts_as_markdown_table
|
298
313
|
|
299
|
-
print_list_of_dicts_as_markdown_table(new_data, filename=filename)
|
300
|
-
|
301
|
-
|
302
|
-
if format == "latex":
|
314
|
+
output = print_list_of_dicts_as_markdown_table(new_data, filename=filename)
|
315
|
+
elif format == "latex":
|
303
316
|
df = self.to_pandas()
|
304
317
|
df.columns = [col.replace("_", " ") for col in df.columns]
|
305
318
|
latex_string = df.to_latex(index=False)
|
@@ -309,23 +322,14 @@ class DatasetExportMixin:
|
|
309
322
|
f.write(latex_string)
|
310
323
|
else:
|
311
324
|
print(latex_string)
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
if format == "html":
|
325
|
+
output = latex_string
|
326
|
+
elif format == "html":
|
316
327
|
from edsl.utilities.interface import print_list_of_dicts_as_html_table
|
317
328
|
|
318
329
|
html_source = print_list_of_dicts_as_html_table(
|
319
330
|
new_data, interactive=interactive
|
320
331
|
)
|
321
332
|
|
322
|
-
# if download_link:
|
323
|
-
# from IPython.display import HTML, display
|
324
|
-
# csv_file = output.getvalue()
|
325
|
-
# b64 = base64.b64encode(csv_file.encode()).decode()
|
326
|
-
# download_link = f'<a href="data:file/csv;base64,{b64}" download="my_data.csv">Download CSV file</a>'
|
327
|
-
# #display(HTML(download_link))
|
328
|
-
|
329
333
|
if iframe:
|
330
334
|
iframe = f""""
|
331
335
|
<iframe srcdoc="{ html.escape(html_source) }" style="width: {iframe_width}px; height: {iframe_height}px;"></iframe>
|
@@ -338,7 +342,18 @@ class DatasetExportMixin:
|
|
338
342
|
|
339
343
|
view_html(html_source)
|
340
344
|
|
341
|
-
|
345
|
+
output = html_source
|
346
|
+
|
347
|
+
# Restore stdout and get captured output if return_string is True
|
348
|
+
if return_string:
|
349
|
+
captured_output = sys.stdout.getvalue()
|
350
|
+
sys.stdout = old_stdout
|
351
|
+
return captured_output or output
|
352
|
+
|
353
|
+
if tee:
|
354
|
+
return self
|
355
|
+
|
356
|
+
return None
|
342
357
|
|
343
358
|
def to_csv(
|
344
359
|
self,
|
@@ -501,7 +516,7 @@ class DatasetExportMixin:
|
|
501
516
|
|
502
517
|
return list_of_dicts
|
503
518
|
|
504
|
-
def to_list(self, flatten=False, remove_none=False) -> list[list]:
|
519
|
+
def to_list(self, flatten=False, remove_none=False, unzipped=False) -> list[list]:
|
505
520
|
"""Convert the results to a list of lists.
|
506
521
|
|
507
522
|
:param flatten: Whether to flatten the list of lists.
|
@@ -596,27 +611,6 @@ class DatasetExportMixin:
|
|
596
611
|
if return_link:
|
597
612
|
return filename
|
598
613
|
|
599
|
-
def to_docx(self, filename: Optional[str] = None, separator: str = "\n"):
|
600
|
-
"""Export the results to a Word document.
|
601
|
-
|
602
|
-
:param filename: The filename to save the Word document to.
|
603
|
-
|
604
|
-
|
605
|
-
"""
|
606
|
-
from docx import Document
|
607
|
-
|
608
|
-
doc = Document()
|
609
|
-
for entry in self:
|
610
|
-
key, values = list(entry.items())[0]
|
611
|
-
doc.add_paragraph(key)
|
612
|
-
line = separator.join(values)
|
613
|
-
doc.add_paragraph(line)
|
614
|
-
|
615
|
-
if filename is not None:
|
616
|
-
doc.save(filename)
|
617
|
-
else:
|
618
|
-
return doc
|
619
|
-
|
620
614
|
def tally(
|
621
615
|
self, *fields: Optional[str], top_n: Optional[int] = None, output="Dataset"
|
622
616
|
) -> Union[dict, "Dataset"]:
|
@@ -0,0 +1,145 @@
|
|
1
|
+
from typing import Dict, List, Any, Optional
|
2
|
+
from docx import Document
|
3
|
+
from docx.shared import Inches, Pt
|
4
|
+
from docx.enum.text import WD_ALIGN_PARAGRAPH
|
5
|
+
from docx.enum.style import WD_STYLE_TYPE
|
6
|
+
|
7
|
+
|
8
|
+
class TreeNode:
|
9
|
+
def __init__(self, key=None, value=None):
|
10
|
+
self.key = key
|
11
|
+
self.value = value
|
12
|
+
self.children = {}
|
13
|
+
|
14
|
+
|
15
|
+
class Tree:
|
16
|
+
def __init__(self, data: "Dataset"):
|
17
|
+
d = {}
|
18
|
+
for entry in data:
|
19
|
+
d.update(entry)
|
20
|
+
self.data = d
|
21
|
+
self.root = None
|
22
|
+
|
23
|
+
def unique_values_by_keys(self) -> dict:
|
24
|
+
unique_values = {}
|
25
|
+
for key, values in self.data.items():
|
26
|
+
unique_values[key] = list(set(values))
|
27
|
+
return unique_values
|
28
|
+
|
29
|
+
def construct_tree(self, node_order: Optional[List[str]] = None):
|
30
|
+
# Validate node_order
|
31
|
+
if node_order is None:
|
32
|
+
unique_values = self.unique_values_by_keys()
|
33
|
+
# Sort keys by number of unique values
|
34
|
+
node_order = sorted(
|
35
|
+
unique_values, key=lambda k: len(unique_values[k]), reverse=True
|
36
|
+
)
|
37
|
+
else:
|
38
|
+
if not set(node_order).issubset(set(self.data.keys())):
|
39
|
+
invalid_keys = set(node_order) - set(self.data.keys())
|
40
|
+
raise ValueError(f"Invalid keys in node_order: {invalid_keys}")
|
41
|
+
|
42
|
+
self.root = TreeNode()
|
43
|
+
|
44
|
+
for i in range(len(self.data[node_order[0]])):
|
45
|
+
current = self.root
|
46
|
+
for level in node_order[:-1]:
|
47
|
+
value = self.data[level][i]
|
48
|
+
if value not in current.children:
|
49
|
+
current.children[value] = TreeNode(key=level, value=value)
|
50
|
+
current = current.children[value]
|
51
|
+
|
52
|
+
leaf_key = node_order[-1]
|
53
|
+
leaf_value = self.data[leaf_key][i]
|
54
|
+
if leaf_value not in current.children:
|
55
|
+
current.children[leaf_value] = TreeNode(key=leaf_key, value=leaf_value)
|
56
|
+
|
57
|
+
def print_tree(
|
58
|
+
self, node: Optional[TreeNode] = None, level: int = 0, print_keys: bool = False
|
59
|
+
):
|
60
|
+
if node is None:
|
61
|
+
node = self.root
|
62
|
+
if node is None:
|
63
|
+
print("Tree has not been constructed yet.")
|
64
|
+
return
|
65
|
+
|
66
|
+
if node.value is not None:
|
67
|
+
if print_keys and node.key is not None:
|
68
|
+
print(" " * level + f"{node.key}: {node.value}")
|
69
|
+
else:
|
70
|
+
print(" " * level + str(node.value))
|
71
|
+
for child in node.children.values():
|
72
|
+
self.print_tree(child, level + 1, print_keys)
|
73
|
+
|
74
|
+
def to_docx(self, filename: str):
|
75
|
+
doc = Document()
|
76
|
+
|
77
|
+
# Create styles for headings
|
78
|
+
for i in range(1, 10): # Up to 9 levels of headings
|
79
|
+
style_name = f"Heading {i}"
|
80
|
+
if style_name not in doc.styles:
|
81
|
+
doc.styles.add_style(style_name, WD_STYLE_TYPE.PARAGRAPH)
|
82
|
+
|
83
|
+
# Get or create the 'Body Text' style
|
84
|
+
if "Body Text" not in doc.styles:
|
85
|
+
body_style = doc.styles.add_style("Body Text", WD_STYLE_TYPE.PARAGRAPH)
|
86
|
+
else:
|
87
|
+
body_style = doc.styles["Body Text"]
|
88
|
+
|
89
|
+
body_style.font.size = Pt(11)
|
90
|
+
|
91
|
+
self._add_to_docx(doc, self.root, 0)
|
92
|
+
doc.save(filename)
|
93
|
+
|
94
|
+
def _add_to_docx(self, doc, node: TreeNode, level: int):
|
95
|
+
if node.value is not None:
|
96
|
+
if level == 0:
|
97
|
+
doc.add_heading(str(node.value), level=level + 1)
|
98
|
+
elif node.children: # If the node has children, it's not the last level
|
99
|
+
para = doc.add_paragraph(str(node.value))
|
100
|
+
para.style = f"Heading {level+1}"
|
101
|
+
else: # If the node has no children, it's the last level (body text)
|
102
|
+
para = doc.add_paragraph(str(node.value))
|
103
|
+
para.style = "Body Text"
|
104
|
+
|
105
|
+
# Process child nodes (moved outside the if block)
|
106
|
+
for child in node.children.values():
|
107
|
+
self._add_to_docx(doc, child, level + 1)
|
108
|
+
|
109
|
+
|
110
|
+
# Example usage (commented out)
|
111
|
+
"""
|
112
|
+
from edsl.results.Dataset import Dataset
|
113
|
+
|
114
|
+
data = Dataset(
|
115
|
+
[
|
116
|
+
{"continent": ["North America", "Asia", "Europe", "North America", "Asia"]},
|
117
|
+
{"country": ["US", "China", "France", "Canada", "Japan"]},
|
118
|
+
{"city": ["New York", "Beijing", "Paris", "Toronto", "Tokyo"]},
|
119
|
+
{"population": [8419000, 21540000, 2161000, 2930000, 13960000]},
|
120
|
+
]
|
121
|
+
)
|
122
|
+
|
123
|
+
tree = Tree(data)
|
124
|
+
|
125
|
+
try:
|
126
|
+
tree.construct_tree(["continent", "country", "city", "population"])
|
127
|
+
print("Tree without key names:")
|
128
|
+
tree.print_tree()
|
129
|
+
print("\nTree with key names:")
|
130
|
+
tree.print_tree(print_keys=True)
|
131
|
+
except ValueError as e:
|
132
|
+
print(f"Error: {e}")
|
133
|
+
|
134
|
+
# Demonstrating validation
|
135
|
+
try:
|
136
|
+
tree.construct_tree(["continent", "country", "invalid_key"])
|
137
|
+
except ValueError as e:
|
138
|
+
print(f"\nValidation Error: {e}")
|
139
|
+
|
140
|
+
tree = Tree(data)
|
141
|
+
tree.construct_tree(["continent", "country", "city", "population"])
|
142
|
+
tree.print_tree(print_keys=True)
|
143
|
+
tree.to_docx("tree_structure.docx")
|
144
|
+
print("DocX file 'tree_structure.docx' has been created.")
|
145
|
+
"""
|
edsl/results/Result.py
CHANGED
@@ -53,8 +53,8 @@ class Result(Base, UserDict):
|
|
53
53
|
|
54
54
|
>>> import warnings
|
55
55
|
>>> warnings.simplefilter("ignore", UserWarning)
|
56
|
-
>>> Result.example().answer
|
57
|
-
|
56
|
+
>>> Result.example().answer == {'how_feeling_yesterday': 'Great', 'how_feeling': 'OK'}
|
57
|
+
True
|
58
58
|
|
59
59
|
Its main data is an Agent, a Scenario, a Model, an Iteration, and an Answer.
|
60
60
|
These are stored both in the UserDict and as attributes.
|
@@ -73,6 +73,8 @@ class Result(Base, UserDict):
|
|
73
73
|
raw_model_response=None,
|
74
74
|
survey: Optional["Survey"] = None,
|
75
75
|
question_to_attributes: Optional[dict] = None,
|
76
|
+
generated_tokens: Optional[dict] = None,
|
77
|
+
comments_dict: Optional[dict] = None,
|
76
78
|
):
|
77
79
|
"""Initialize a Result object.
|
78
80
|
|
@@ -113,6 +115,7 @@ class Result(Base, UserDict):
|
|
113
115
|
"prompt": prompt or {},
|
114
116
|
"raw_model_response": raw_model_response or {},
|
115
117
|
"question_to_attributes": question_to_attributes,
|
118
|
+
"generated_tokens": generated_tokens or {},
|
116
119
|
}
|
117
120
|
super().__init__(**data)
|
118
121
|
# but also store the data as attributes
|
@@ -125,6 +128,8 @@ class Result(Base, UserDict):
|
|
125
128
|
self.raw_model_response = raw_model_response or {}
|
126
129
|
self.survey = survey
|
127
130
|
self.question_to_attributes = question_to_attributes
|
131
|
+
self.generated_tokens = generated_tokens
|
132
|
+
self.comments_dict = comments_dict or {}
|
128
133
|
|
129
134
|
self._combined_dict = None
|
130
135
|
self._problem_keys = None
|
@@ -140,7 +145,7 @@ class Result(Base, UserDict):
|
|
140
145
|
else:
|
141
146
|
agent_name = self.agent.name
|
142
147
|
|
143
|
-
comments_dict = {k: v for k, v in self.answer.items() if k.endswith("_comment")}
|
148
|
+
# comments_dict = {k: v for k, v in self.answer.items() if k.endswith("_comment")}
|
144
149
|
question_text_dict = {}
|
145
150
|
question_options_dict = {}
|
146
151
|
question_type_dict = {}
|
@@ -167,11 +172,12 @@ class Result(Base, UserDict):
|
|
167
172
|
"answer": self.answer,
|
168
173
|
"prompt": self.prompt,
|
169
174
|
"raw_model_response": self.raw_model_response,
|
170
|
-
|
175
|
+
"iteration": {"iteration": self.iteration},
|
171
176
|
"question_text": question_text_dict,
|
172
177
|
"question_options": question_options_dict,
|
173
178
|
"question_type": question_type_dict,
|
174
|
-
"comment": comments_dict,
|
179
|
+
"comment": self.comments_dict,
|
180
|
+
"generated_tokens": self.generated_tokens,
|
175
181
|
}
|
176
182
|
|
177
183
|
def check_expression(self, expression) -> None:
|
@@ -260,6 +266,26 @@ class Result(Base, UserDict):
|
|
260
266
|
for key, value in subdict.items():
|
261
267
|
yield (index, data_type, key, str(value))
|
262
268
|
|
269
|
+
def leaves(self):
|
270
|
+
leaves = []
|
271
|
+
for question_name, answer in self.answer.items():
|
272
|
+
if not question_name.endswith("_comment"):
|
273
|
+
leaves.append(
|
274
|
+
{
|
275
|
+
"question": f"({question_name}): "
|
276
|
+
+ str(
|
277
|
+
self.question_to_attributes[question_name]["question_text"]
|
278
|
+
),
|
279
|
+
"answer": answer,
|
280
|
+
"comment": self.answer.get(question_name + "_comment", ""),
|
281
|
+
"scenario": repr(self.scenario),
|
282
|
+
"agent": repr(self.agent),
|
283
|
+
"model": repr(self.model),
|
284
|
+
"iteration": self.iteration,
|
285
|
+
}
|
286
|
+
)
|
287
|
+
return leaves
|
288
|
+
|
263
289
|
###############
|
264
290
|
# Useful
|
265
291
|
###############
|
@@ -341,6 +367,7 @@ class Result(Base, UserDict):
|
|
341
367
|
"raw_model_response", {"raw_model_response": "No raw model response"}
|
342
368
|
),
|
343
369
|
question_to_attributes=json_dict.get("question_to_attributes", None),
|
370
|
+
generated_tokens=json_dict.get("generated_tokens", None),
|
344
371
|
)
|
345
372
|
return result
|
346
373
|
|