edsl 0.1.33.dev1__py3-none-any.whl → 0.1.33.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. edsl/TemplateLoader.py +24 -0
  2. edsl/__init__.py +8 -4
  3. edsl/agents/Agent.py +46 -14
  4. edsl/agents/AgentList.py +43 -0
  5. edsl/agents/Invigilator.py +125 -212
  6. edsl/agents/InvigilatorBase.py +140 -32
  7. edsl/agents/PromptConstructionMixin.py +43 -66
  8. edsl/agents/__init__.py +1 -0
  9. edsl/auto/AutoStudy.py +117 -0
  10. edsl/auto/StageBase.py +230 -0
  11. edsl/auto/StageGenerateSurvey.py +178 -0
  12. edsl/auto/StageLabelQuestions.py +125 -0
  13. edsl/auto/StagePersona.py +61 -0
  14. edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
  15. edsl/auto/StagePersonaDimensionValues.py +74 -0
  16. edsl/auto/StagePersonaDimensions.py +69 -0
  17. edsl/auto/StageQuestions.py +73 -0
  18. edsl/auto/SurveyCreatorPipeline.py +21 -0
  19. edsl/auto/utilities.py +224 -0
  20. edsl/config.py +38 -39
  21. edsl/coop/PriceFetcher.py +58 -0
  22. edsl/coop/coop.py +39 -5
  23. edsl/data/Cache.py +35 -1
  24. edsl/data_transfer_models.py +120 -38
  25. edsl/enums.py +2 -0
  26. edsl/exceptions/language_models.py +25 -1
  27. edsl/exceptions/questions.py +62 -5
  28. edsl/exceptions/results.py +4 -0
  29. edsl/inference_services/AnthropicService.py +13 -11
  30. edsl/inference_services/AwsBedrock.py +19 -17
  31. edsl/inference_services/AzureAI.py +37 -20
  32. edsl/inference_services/GoogleService.py +16 -12
  33. edsl/inference_services/GroqService.py +2 -0
  34. edsl/inference_services/InferenceServiceABC.py +24 -0
  35. edsl/inference_services/MistralAIService.py +120 -0
  36. edsl/inference_services/OpenAIService.py +41 -50
  37. edsl/inference_services/TestService.py +71 -0
  38. edsl/inference_services/models_available_cache.py +0 -6
  39. edsl/inference_services/registry.py +4 -0
  40. edsl/jobs/Answers.py +10 -12
  41. edsl/jobs/FailedQuestion.py +78 -0
  42. edsl/jobs/Jobs.py +18 -13
  43. edsl/jobs/buckets/TokenBucket.py +39 -14
  44. edsl/jobs/interviews/Interview.py +297 -77
  45. edsl/jobs/interviews/InterviewExceptionEntry.py +83 -19
  46. edsl/jobs/interviews/interview_exception_tracking.py +0 -70
  47. edsl/jobs/interviews/retry_management.py +3 -1
  48. edsl/jobs/runners/JobsRunnerAsyncio.py +116 -70
  49. edsl/jobs/runners/JobsRunnerStatusMixin.py +1 -1
  50. edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
  51. edsl/jobs/tasks/TaskHistory.py +131 -213
  52. edsl/language_models/LanguageModel.py +239 -129
  53. edsl/language_models/ModelList.py +2 -2
  54. edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
  55. edsl/language_models/fake_openai_call.py +15 -0
  56. edsl/language_models/fake_openai_service.py +61 -0
  57. edsl/language_models/registry.py +15 -2
  58. edsl/language_models/repair.py +0 -19
  59. edsl/language_models/utilities.py +61 -0
  60. edsl/prompts/Prompt.py +52 -2
  61. edsl/questions/AnswerValidatorMixin.py +23 -26
  62. edsl/questions/QuestionBase.py +273 -242
  63. edsl/questions/QuestionBaseGenMixin.py +133 -0
  64. edsl/questions/QuestionBasePromptsMixin.py +266 -0
  65. edsl/questions/QuestionBudget.py +6 -0
  66. edsl/questions/QuestionCheckBox.py +227 -35
  67. edsl/questions/QuestionExtract.py +98 -27
  68. edsl/questions/QuestionFreeText.py +46 -29
  69. edsl/questions/QuestionFunctional.py +7 -0
  70. edsl/questions/QuestionList.py +141 -22
  71. edsl/questions/QuestionMultipleChoice.py +173 -64
  72. edsl/questions/QuestionNumerical.py +87 -46
  73. edsl/questions/QuestionRank.py +182 -24
  74. edsl/questions/RegisterQuestionsMeta.py +31 -12
  75. edsl/questions/ResponseValidatorABC.py +169 -0
  76. edsl/questions/__init__.py +3 -4
  77. edsl/questions/decorators.py +21 -0
  78. edsl/questions/derived/QuestionLikertFive.py +10 -5
  79. edsl/questions/derived/QuestionLinearScale.py +11 -1
  80. edsl/questions/derived/QuestionTopK.py +6 -0
  81. edsl/questions/derived/QuestionYesNo.py +16 -1
  82. edsl/questions/descriptors.py +43 -7
  83. edsl/questions/prompt_templates/question_budget.jinja +13 -0
  84. edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
  85. edsl/questions/prompt_templates/question_extract.jinja +11 -0
  86. edsl/questions/prompt_templates/question_free_text.jinja +3 -0
  87. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
  88. edsl/questions/prompt_templates/question_list.jinja +17 -0
  89. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
  90. edsl/questions/prompt_templates/question_numerical.jinja +37 -0
  91. edsl/questions/question_registry.py +6 -2
  92. edsl/questions/templates/__init__.py +0 -0
  93. edsl/questions/templates/checkbox/__init__.py +0 -0
  94. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
  95. edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
  96. edsl/questions/templates/extract/answering_instructions.jinja +7 -0
  97. edsl/questions/templates/extract/question_presentation.jinja +1 -0
  98. edsl/questions/templates/free_text/__init__.py +0 -0
  99. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  100. edsl/questions/templates/free_text/question_presentation.jinja +1 -0
  101. edsl/questions/templates/likert_five/__init__.py +0 -0
  102. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
  103. edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
  104. edsl/questions/templates/linear_scale/__init__.py +0 -0
  105. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
  106. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
  107. edsl/questions/templates/list/__init__.py +0 -0
  108. edsl/questions/templates/list/answering_instructions.jinja +4 -0
  109. edsl/questions/templates/list/question_presentation.jinja +5 -0
  110. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  111. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
  112. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  113. edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
  114. edsl/questions/templates/numerical/__init__.py +0 -0
  115. edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
  116. edsl/questions/templates/numerical/question_presentation.jinja +7 -0
  117. edsl/questions/templates/rank/answering_instructions.jinja +11 -0
  118. edsl/questions/templates/rank/question_presentation.jinja +15 -0
  119. edsl/questions/templates/top_k/__init__.py +0 -0
  120. edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
  121. edsl/questions/templates/top_k/question_presentation.jinja +22 -0
  122. edsl/questions/templates/yes_no/__init__.py +0 -0
  123. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
  124. edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
  125. edsl/results/Dataset.py +20 -0
  126. edsl/results/DatasetExportMixin.py +41 -47
  127. edsl/results/DatasetTree.py +145 -0
  128. edsl/results/Result.py +32 -5
  129. edsl/results/Results.py +131 -45
  130. edsl/results/ResultsDBMixin.py +3 -3
  131. edsl/results/Selector.py +118 -0
  132. edsl/results/tree_explore.py +115 -0
  133. edsl/scenarios/Scenario.py +10 -4
  134. edsl/scenarios/ScenarioList.py +348 -39
  135. edsl/scenarios/ScenarioListExportMixin.py +9 -0
  136. edsl/study/SnapShot.py +8 -1
  137. edsl/surveys/RuleCollection.py +2 -2
  138. edsl/surveys/Survey.py +634 -315
  139. edsl/surveys/SurveyExportMixin.py +71 -9
  140. edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
  141. edsl/surveys/SurveyQualtricsImport.py +75 -4
  142. edsl/surveys/instructions/ChangeInstruction.py +47 -0
  143. edsl/surveys/instructions/Instruction.py +34 -0
  144. edsl/surveys/instructions/InstructionCollection.py +77 -0
  145. edsl/surveys/instructions/__init__.py +0 -0
  146. edsl/templates/error_reporting/base.html +24 -0
  147. edsl/templates/error_reporting/exceptions_by_model.html +35 -0
  148. edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
  149. edsl/templates/error_reporting/exceptions_by_type.html +17 -0
  150. edsl/templates/error_reporting/interview_details.html +111 -0
  151. edsl/templates/error_reporting/interviews.html +10 -0
  152. edsl/templates/error_reporting/overview.html +5 -0
  153. edsl/templates/error_reporting/performance_plot.html +2 -0
  154. edsl/templates/error_reporting/report.css +74 -0
  155. edsl/templates/error_reporting/report.html +118 -0
  156. edsl/templates/error_reporting/report.js +25 -0
  157. {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/METADATA +4 -2
  158. edsl-0.1.33.dev2.dist-info/RECORD +289 -0
  159. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -286
  160. edsl/utilities/gcp_bucket/simple_example.py +0 -9
  161. edsl-0.1.33.dev1.dist-info/RECORD +0 -209
  162. {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/LICENSE +0 -0
  163. {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/WHEEL +0 -0
@@ -0,0 +1,12 @@
1
+ {# Question Presentation #}
2
+ {{question_text}}
3
+ {% if use_code %}
4
+ {%- for option in question_options %}
5
+ {{ loop.index0 }}: {{option}}
6
+ {% endfor %}
7
+ {% else %}
8
+ {% for option in question_options %}
9
+ {{option}}
10
+ {% endfor %}
11
+ {% endif %}
12
+ Only 1 option may be selected.
File without changes
@@ -0,0 +1,8 @@
1
+ This question requires a numerical response in the form of an integer or decimal (e.g., -12, 0, 1, 2, 3.45, ...).
2
+ Respond with just your number on a single line.
3
+ If your response is equivalent to zero, report '0'
4
+ If you cannot determine the answer, report 'None'
5
+
6
+ {% if include_comment %}
7
+ After the answer, put a comment explaining your choice on the next line.
8
+ {% endif %}
@@ -0,0 +1,7 @@
1
+ {{question_text}}
2
+ {% if min_value is not none %}
3
+ Minimum answer value: {{min_value}}
4
+ {% endif %}
5
+ {% if max_value is not none %}
6
+ Maximum answer value: {{max_value}}
7
+ {% endif %}
@@ -0,0 +1,11 @@
1
+ {# Answering Instructions #}
2
+ {% if use_code %}
3
+ Please respond only with a comma-separated list of the codes of the ranked options, with square brackets. E.g., [0, 1, 3]
4
+ {% else %}
5
+ Please respond only with a comma-separated list of the ranked options, with square brackets. E.g., ['Good', 'Bad', 'Ugly']
6
+ {% endif %}
7
+ {% if include_comment %}
8
+ After the answer, you can put a comment explaining your choice on the next line.
9
+ {% endif %}
10
+
11
+
@@ -0,0 +1,15 @@
1
+ {{question_text}}
2
+ {% if use_code %}
3
+ The options are:
4
+ {% for option in question_options %}
5
+ {{ loop.index0 }}: {{option}}
6
+ {% endfor %}
7
+ {% else %}
8
+ The options are:
9
+ {% for option in question_options %}
10
+ {{option}}
11
+ {% endfor %}
12
+ {% endif %}
13
+ {% if num_selections %}
14
+ You can include up to {{num_selections}} options in your answer.
15
+ {% endif %}
File without changes
@@ -0,0 +1,8 @@
1
+ {# Answering Instructions #}
2
+ Please respond with valid JSON, formatted like so:
3
+ {% if include_comment %}
4
+ {"answer": [<put comma-separated list here>], "comment": "<put explanation here>"}
5
+ {% else %}
6
+ {"answer": [<put comma-separated list here>]}
7
+ {% endif %}
8
+
@@ -0,0 +1,22 @@
1
+ {{question_text}}
2
+ {% if use_code %}
3
+ {% for option in question_options %}
4
+ {{ loop.index0 }}: {{option}}
5
+ {% endfor %}
6
+ {% else %}
7
+ {% for option in question_options %}
8
+ {{ option }}
9
+ {% endfor %}
10
+ {% endif %}
11
+
12
+ {# Restrictions #}
13
+ {% if min_selections != None and max_selections != None and min_selections == max_selections %}
14
+ You must select exactly {{min_selections}} options.
15
+ {% elif min_selections != None and max_selections != None %}
16
+ Minimum number of options that must be selected: {{min_selections}}.
17
+ Maximum number of options that may be selected: {{max_selections}}.
18
+ {% elif min_selections != None %}
19
+ Minimum number of options that must be selected: {{min_selections}}.
20
+ {% elif max_selections != None %}
21
+ Maximum number of options that may be selected: {{max_selections}}.
22
+ {% endif %}
File without changes
@@ -0,0 +1,6 @@
1
+ {# Answering Instructions #}
2
+ Please respond with just your answer.
3
+
4
+ {% if include_comment %}
5
+ After the answer, you can put a comment explaining your response.
6
+ {% endif %}
@@ -0,0 +1,12 @@
1
+ {# Question Presentation #}
2
+ {{question_text}}
3
+ {% if use_code %}
4
+ {%- for option in question_options %}
5
+ {{ loop.index0 }}: {{option}}
6
+ {% endfor %}
7
+ {% else %}
8
+ {% for option in question_options %}
9
+ {{option}}
10
+ {% endfor %}
11
+ {% endif %}
12
+ Only 1 option may be selected.
edsl/results/Dataset.py CHANGED
@@ -8,6 +8,7 @@ from typing import Any, Union, Optional
8
8
  import numpy as np
9
9
 
10
10
  from edsl.results.ResultsExportMixin import ResultsExportMixin
11
+ from edsl.results.DatasetTree import Tree
11
12
 
12
13
 
13
14
  class Dataset(UserList, ResultsExportMixin):
@@ -30,6 +31,15 @@ class Dataset(UserList, ResultsExportMixin):
30
31
  _, values = list(self.data[0].items())[0]
31
32
  return len(values)
32
33
 
34
+ def keys(self):
35
+ """Return the keys of the first observation in the dataset.
36
+
37
+ >>> d = Dataset([{'a.b':[1,2,3,4]}])
38
+ >>> d.keys()
39
+ ['a.b']
40
+ """
41
+ return [list(o.keys())[0] for o in self]
42
+
33
43
  def __repr__(self) -> str:
34
44
  """Return a string representation of the dataset."""
35
45
  return f"Dataset({self.data})"
@@ -245,6 +255,16 @@ class Dataset(UserList, ResultsExportMixin):
245
255
 
246
256
  return Dataset(new_data)
247
257
 
258
+ @property
259
+ def tree(self):
260
+ """Return a tree representation of the dataset.
261
+
262
+ >>> d = Dataset([{'a':[1,2,3,4]}, {'b':[4,3,2,1]}])
263
+ >>> d.tree.print_tree()
264
+ Tree has not been constructed yet.
265
+ """
266
+ return Tree(self)
267
+
248
268
  @classmethod
249
269
  def example(self):
250
270
  """Return an example dataset.
@@ -4,6 +4,7 @@ import base64
4
4
  import csv
5
5
  import io
6
6
  import html
7
+ from typing import Optional
7
8
 
8
9
  from typing import Literal, Optional, Union, List
9
10
 
@@ -41,7 +42,7 @@ class DatasetExportMixin:
41
42
  >>> Results.example().relevant_columns(data_type = "flimflam")
42
43
  Traceback (most recent call last):
43
44
  ...
44
- ValueError: No columns found for data type: flimflam. Available data types are: ['agent', 'answer', 'comment', 'model', 'prompt', 'question_options', 'question_text', 'question_type', 'raw_model_response', 'scenario'].
45
+ ValueError: No columns found for data type: flimflam. Available data types are: ...
45
46
  """
46
47
  columns = [list(x.keys())[0] for x in self]
47
48
  if remove_prefix:
@@ -156,12 +157,13 @@ class DatasetExportMixin:
156
157
  iframe_height: int = 200,
157
158
  iframe_width: int = 600,
158
159
  web=False,
159
- ) -> None:
160
+ return_string: bool = False,
161
+ ) -> Union[None, str, "Results"]:
160
162
  """Print the results in a pretty format.
161
163
 
162
164
  :param pretty_labels: A dictionary of pretty labels for the columns.
163
165
  :param filename: The filename to save the results to.
164
- :param format: The format to print the results in. Options are 'rich', 'html', or 'markdown'.
166
+ :param format: The format to print the results in. Options are 'rich', 'html', 'markdown', or 'latex'.
165
167
  :param interactive: Whether to print the results interactively in a Jupyter notebook.
166
168
  :param split_at_dot: Whether to split the column names at the last dot w/ a newline.
167
169
  :param max_rows: The maximum number of rows to print.
@@ -170,6 +172,9 @@ class DatasetExportMixin:
170
172
  :param iframe_height: The height of the iframe.
171
173
  :param iframe_width: The width of the iframe.
172
174
  :param web: Whether to display the table in a web browser.
175
+ :param return_string: Whether to return the output as a string instead of printing.
176
+
177
+ :return: None if tee is False and return_string is False, the dataset if tee is True, or a string if return_string is True.
173
178
 
174
179
  Example: Print in rich format at the terminal
175
180
 
@@ -253,11 +258,14 @@ class DatasetExportMixin:
253
258
 
254
259
  >>> r.select('how_feeling').print(format='latex')
255
260
  \\begin{tabular}{l}
256
- \\toprule
257
261
  ...
262
+ \\end{tabular}
263
+ <BLANKLINE>
258
264
  """
259
265
  from IPython.display import HTML, display
260
266
  from edsl.utilities.utilities import is_notebook
267
+ import io
268
+ import sys
261
269
 
262
270
  def _determine_format(format):
263
271
  if format is None:
@@ -266,7 +274,9 @@ class DatasetExportMixin:
266
274
  else:
267
275
  format = "rich"
268
276
  if format not in ["rich", "html", "markdown", "latex"]:
269
- raise ValueError("format must be one of 'rich', 'html', or 'markdown'.")
277
+ raise ValueError(
278
+ "format must be one of 'rich', 'html', 'markdown', or 'latex'."
279
+ )
270
280
 
271
281
  return format
272
282
 
@@ -285,21 +295,24 @@ class DatasetExportMixin:
285
295
 
286
296
  new_data = list(_create_data())
287
297
 
298
+ # Capture output if return_string is True
299
+ if return_string:
300
+ old_stdout = sys.stdout
301
+ sys.stdout = io.StringIO()
302
+
303
+ output = None
304
+
288
305
  if format == "rich":
289
306
  from edsl.utilities.interface import print_dataset_with_rich
290
307
 
291
- print_dataset_with_rich(
308
+ output = print_dataset_with_rich(
292
309
  new_data, filename=filename, split_at_dot=split_at_dot
293
310
  )
294
- return self if tee else None
295
-
296
- if format == "markdown":
311
+ elif format == "markdown":
297
312
  from edsl.utilities.interface import print_list_of_dicts_as_markdown_table
298
313
 
299
- print_list_of_dicts_as_markdown_table(new_data, filename=filename)
300
- return self if tee else None
301
-
302
- if format == "latex":
314
+ output = print_list_of_dicts_as_markdown_table(new_data, filename=filename)
315
+ elif format == "latex":
303
316
  df = self.to_pandas()
304
317
  df.columns = [col.replace("_", " ") for col in df.columns]
305
318
  latex_string = df.to_latex(index=False)
@@ -309,23 +322,14 @@ class DatasetExportMixin:
309
322
  f.write(latex_string)
310
323
  else:
311
324
  print(latex_string)
312
-
313
- return self if tee else None
314
-
315
- if format == "html":
325
+ output = latex_string
326
+ elif format == "html":
316
327
  from edsl.utilities.interface import print_list_of_dicts_as_html_table
317
328
 
318
329
  html_source = print_list_of_dicts_as_html_table(
319
330
  new_data, interactive=interactive
320
331
  )
321
332
 
322
- # if download_link:
323
- # from IPython.display import HTML, display
324
- # csv_file = output.getvalue()
325
- # b64 = base64.b64encode(csv_file.encode()).decode()
326
- # download_link = f'<a href="data:file/csv;base64,{b64}" download="my_data.csv">Download CSV file</a>'
327
- # #display(HTML(download_link))
328
-
329
333
  if iframe:
330
334
  iframe = f""""
331
335
  <iframe srcdoc="{ html.escape(html_source) }" style="width: {iframe_width}px; height: {iframe_height}px;"></iframe>
@@ -338,7 +342,18 @@ class DatasetExportMixin:
338
342
 
339
343
  view_html(html_source)
340
344
 
341
- return self if tee else None
345
+ output = html_source
346
+
347
+ # Restore stdout and get captured output if return_string is True
348
+ if return_string:
349
+ captured_output = sys.stdout.getvalue()
350
+ sys.stdout = old_stdout
351
+ return captured_output or output
352
+
353
+ if tee:
354
+ return self
355
+
356
+ return None
342
357
 
343
358
  def to_csv(
344
359
  self,
@@ -501,7 +516,7 @@ class DatasetExportMixin:
501
516
 
502
517
  return list_of_dicts
503
518
 
504
- def to_list(self, flatten=False, remove_none=False) -> list[list]:
519
+ def to_list(self, flatten=False, remove_none=False, unzipped=False) -> list[list]:
505
520
  """Convert the results to a list of lists.
506
521
 
507
522
  :param flatten: Whether to flatten the list of lists.
@@ -596,27 +611,6 @@ class DatasetExportMixin:
596
611
  if return_link:
597
612
  return filename
598
613
 
599
- def to_docx(self, filename: Optional[str] = None, separator: str = "\n"):
600
- """Export the results to a Word document.
601
-
602
- :param filename: The filename to save the Word document to.
603
-
604
-
605
- """
606
- from docx import Document
607
-
608
- doc = Document()
609
- for entry in self:
610
- key, values = list(entry.items())[0]
611
- doc.add_paragraph(key)
612
- line = separator.join(values)
613
- doc.add_paragraph(line)
614
-
615
- if filename is not None:
616
- doc.save(filename)
617
- else:
618
- return doc
619
-
620
614
  def tally(
621
615
  self, *fields: Optional[str], top_n: Optional[int] = None, output="Dataset"
622
616
  ) -> Union[dict, "Dataset"]:
@@ -0,0 +1,145 @@
1
+ from typing import Dict, List, Any, Optional
2
+ from docx import Document
3
+ from docx.shared import Inches, Pt
4
+ from docx.enum.text import WD_ALIGN_PARAGRAPH
5
+ from docx.enum.style import WD_STYLE_TYPE
6
+
7
+
8
+ class TreeNode:
9
+ def __init__(self, key=None, value=None):
10
+ self.key = key
11
+ self.value = value
12
+ self.children = {}
13
+
14
+
15
+ class Tree:
16
+ def __init__(self, data: "Dataset"):
17
+ d = {}
18
+ for entry in data:
19
+ d.update(entry)
20
+ self.data = d
21
+ self.root = None
22
+
23
+ def unique_values_by_keys(self) -> dict:
24
+ unique_values = {}
25
+ for key, values in self.data.items():
26
+ unique_values[key] = list(set(values))
27
+ return unique_values
28
+
29
+ def construct_tree(self, node_order: Optional[List[str]] = None):
30
+ # Validate node_order
31
+ if node_order is None:
32
+ unique_values = self.unique_values_by_keys()
33
+ # Sort keys by number of unique values
34
+ node_order = sorted(
35
+ unique_values, key=lambda k: len(unique_values[k]), reverse=True
36
+ )
37
+ else:
38
+ if not set(node_order).issubset(set(self.data.keys())):
39
+ invalid_keys = set(node_order) - set(self.data.keys())
40
+ raise ValueError(f"Invalid keys in node_order: {invalid_keys}")
41
+
42
+ self.root = TreeNode()
43
+
44
+ for i in range(len(self.data[node_order[0]])):
45
+ current = self.root
46
+ for level in node_order[:-1]:
47
+ value = self.data[level][i]
48
+ if value not in current.children:
49
+ current.children[value] = TreeNode(key=level, value=value)
50
+ current = current.children[value]
51
+
52
+ leaf_key = node_order[-1]
53
+ leaf_value = self.data[leaf_key][i]
54
+ if leaf_value not in current.children:
55
+ current.children[leaf_value] = TreeNode(key=leaf_key, value=leaf_value)
56
+
57
+ def print_tree(
58
+ self, node: Optional[TreeNode] = None, level: int = 0, print_keys: bool = False
59
+ ):
60
+ if node is None:
61
+ node = self.root
62
+ if node is None:
63
+ print("Tree has not been constructed yet.")
64
+ return
65
+
66
+ if node.value is not None:
67
+ if print_keys and node.key is not None:
68
+ print(" " * level + f"{node.key}: {node.value}")
69
+ else:
70
+ print(" " * level + str(node.value))
71
+ for child in node.children.values():
72
+ self.print_tree(child, level + 1, print_keys)
73
+
74
+ def to_docx(self, filename: str):
75
+ doc = Document()
76
+
77
+ # Create styles for headings
78
+ for i in range(1, 10): # Up to 9 levels of headings
79
+ style_name = f"Heading {i}"
80
+ if style_name not in doc.styles:
81
+ doc.styles.add_style(style_name, WD_STYLE_TYPE.PARAGRAPH)
82
+
83
+ # Get or create the 'Body Text' style
84
+ if "Body Text" not in doc.styles:
85
+ body_style = doc.styles.add_style("Body Text", WD_STYLE_TYPE.PARAGRAPH)
86
+ else:
87
+ body_style = doc.styles["Body Text"]
88
+
89
+ body_style.font.size = Pt(11)
90
+
91
+ self._add_to_docx(doc, self.root, 0)
92
+ doc.save(filename)
93
+
94
+ def _add_to_docx(self, doc, node: TreeNode, level: int):
95
+ if node.value is not None:
96
+ if level == 0:
97
+ doc.add_heading(str(node.value), level=level + 1)
98
+ elif node.children: # If the node has children, it's not the last level
99
+ para = doc.add_paragraph(str(node.value))
100
+ para.style = f"Heading {level+1}"
101
+ else: # If the node has no children, it's the last level (body text)
102
+ para = doc.add_paragraph(str(node.value))
103
+ para.style = "Body Text"
104
+
105
+ # Process child nodes (moved outside the if block)
106
+ for child in node.children.values():
107
+ self._add_to_docx(doc, child, level + 1)
108
+
109
+
110
+ # Example usage (commented out)
111
+ """
112
+ from edsl.results.Dataset import Dataset
113
+
114
+ data = Dataset(
115
+ [
116
+ {"continent": ["North America", "Asia", "Europe", "North America", "Asia"]},
117
+ {"country": ["US", "China", "France", "Canada", "Japan"]},
118
+ {"city": ["New York", "Beijing", "Paris", "Toronto", "Tokyo"]},
119
+ {"population": [8419000, 21540000, 2161000, 2930000, 13960000]},
120
+ ]
121
+ )
122
+
123
+ tree = Tree(data)
124
+
125
+ try:
126
+ tree.construct_tree(["continent", "country", "city", "population"])
127
+ print("Tree without key names:")
128
+ tree.print_tree()
129
+ print("\nTree with key names:")
130
+ tree.print_tree(print_keys=True)
131
+ except ValueError as e:
132
+ print(f"Error: {e}")
133
+
134
+ # Demonstrating validation
135
+ try:
136
+ tree.construct_tree(["continent", "country", "invalid_key"])
137
+ except ValueError as e:
138
+ print(f"\nValidation Error: {e}")
139
+
140
+ tree = Tree(data)
141
+ tree.construct_tree(["continent", "country", "city", "population"])
142
+ tree.print_tree(print_keys=True)
143
+ tree.to_docx("tree_structure.docx")
144
+ print("DocX file 'tree_structure.docx' has been created.")
145
+ """
edsl/results/Result.py CHANGED
@@ -53,8 +53,8 @@ class Result(Base, UserDict):
53
53
 
54
54
  >>> import warnings
55
55
  >>> warnings.simplefilter("ignore", UserWarning)
56
- >>> Result.example().answer
57
- {'how_feeling': 'OK', 'how_feeling_comment': 'This is a real survey response from a human.', 'how_feeling_yesterday': 'Great', 'how_feeling_yesterday_comment': 'This is a real survey response from a human.'}
56
+ >>> Result.example().answer == {'how_feeling_yesterday': 'Great', 'how_feeling': 'OK'}
57
+ True
58
58
 
59
59
  Its main data is an Agent, a Scenario, a Model, an Iteration, and an Answer.
60
60
  These are stored both in the UserDict and as attributes.
@@ -73,6 +73,8 @@ class Result(Base, UserDict):
73
73
  raw_model_response=None,
74
74
  survey: Optional["Survey"] = None,
75
75
  question_to_attributes: Optional[dict] = None,
76
+ generated_tokens: Optional[dict] = None,
77
+ comments_dict: Optional[dict] = None,
76
78
  ):
77
79
  """Initialize a Result object.
78
80
 
@@ -113,6 +115,7 @@ class Result(Base, UserDict):
113
115
  "prompt": prompt or {},
114
116
  "raw_model_response": raw_model_response or {},
115
117
  "question_to_attributes": question_to_attributes,
118
+ "generated_tokens": generated_tokens or {},
116
119
  }
117
120
  super().__init__(**data)
118
121
  # but also store the data as attributes
@@ -125,6 +128,8 @@ class Result(Base, UserDict):
125
128
  self.raw_model_response = raw_model_response or {}
126
129
  self.survey = survey
127
130
  self.question_to_attributes = question_to_attributes
131
+ self.generated_tokens = generated_tokens
132
+ self.comments_dict = comments_dict or {}
128
133
 
129
134
  self._combined_dict = None
130
135
  self._problem_keys = None
@@ -140,7 +145,7 @@ class Result(Base, UserDict):
140
145
  else:
141
146
  agent_name = self.agent.name
142
147
 
143
- comments_dict = {k: v for k, v in self.answer.items() if k.endswith("_comment")}
148
+ # comments_dict = {k: v for k, v in self.answer.items() if k.endswith("_comment")}
144
149
  question_text_dict = {}
145
150
  question_options_dict = {}
146
151
  question_type_dict = {}
@@ -167,11 +172,12 @@ class Result(Base, UserDict):
167
172
  "answer": self.answer,
168
173
  "prompt": self.prompt,
169
174
  "raw_model_response": self.raw_model_response,
170
- # "iteration": {"iteration": self.iteration},
175
+ "iteration": {"iteration": self.iteration},
171
176
  "question_text": question_text_dict,
172
177
  "question_options": question_options_dict,
173
178
  "question_type": question_type_dict,
174
- "comment": comments_dict,
179
+ "comment": self.comments_dict,
180
+ "generated_tokens": self.generated_tokens,
175
181
  }
176
182
 
177
183
  def check_expression(self, expression) -> None:
@@ -260,6 +266,26 @@ class Result(Base, UserDict):
260
266
  for key, value in subdict.items():
261
267
  yield (index, data_type, key, str(value))
262
268
 
269
+ def leaves(self):
270
+ leaves = []
271
+ for question_name, answer in self.answer.items():
272
+ if not question_name.endswith("_comment"):
273
+ leaves.append(
274
+ {
275
+ "question": f"({question_name}): "
276
+ + str(
277
+ self.question_to_attributes[question_name]["question_text"]
278
+ ),
279
+ "answer": answer,
280
+ "comment": self.answer.get(question_name + "_comment", ""),
281
+ "scenario": repr(self.scenario),
282
+ "agent": repr(self.agent),
283
+ "model": repr(self.model),
284
+ "iteration": self.iteration,
285
+ }
286
+ )
287
+ return leaves
288
+
263
289
  ###############
264
290
  # Useful
265
291
  ###############
@@ -341,6 +367,7 @@ class Result(Base, UserDict):
341
367
  "raw_model_response", {"raw_model_response": "No raw model response"}
342
368
  ),
343
369
  question_to_attributes=json_dict.get("question_to_attributes", None),
370
+ generated_tokens=json_dict.get("generated_tokens", None),
344
371
  )
345
372
  return result
346
373