edsl 0.1.38.dev2__py3-none-any.whl → 0.1.38.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. edsl/Base.py +60 -31
  2. edsl/__version__.py +1 -1
  3. edsl/agents/Agent.py +18 -9
  4. edsl/agents/AgentList.py +59 -8
  5. edsl/agents/Invigilator.py +18 -7
  6. edsl/agents/InvigilatorBase.py +0 -19
  7. edsl/agents/PromptConstructor.py +5 -4
  8. edsl/config.py +8 -0
  9. edsl/coop/coop.py +74 -7
  10. edsl/data/Cache.py +27 -2
  11. edsl/data/CacheEntry.py +8 -3
  12. edsl/data/RemoteCacheSync.py +0 -19
  13. edsl/enums.py +2 -0
  14. edsl/inference_services/GoogleService.py +7 -15
  15. edsl/inference_services/PerplexityService.py +163 -0
  16. edsl/inference_services/registry.py +2 -0
  17. edsl/jobs/Jobs.py +88 -548
  18. edsl/jobs/JobsChecks.py +147 -0
  19. edsl/jobs/JobsPrompts.py +268 -0
  20. edsl/jobs/JobsRemoteInferenceHandler.py +239 -0
  21. edsl/jobs/interviews/Interview.py +11 -11
  22. edsl/jobs/runners/JobsRunnerAsyncio.py +140 -35
  23. edsl/jobs/runners/JobsRunnerStatus.py +0 -2
  24. edsl/jobs/tasks/TaskHistory.py +15 -16
  25. edsl/language_models/LanguageModel.py +44 -84
  26. edsl/language_models/ModelList.py +47 -1
  27. edsl/language_models/registry.py +57 -4
  28. edsl/prompts/Prompt.py +8 -3
  29. edsl/questions/QuestionBase.py +20 -16
  30. edsl/questions/QuestionExtract.py +3 -4
  31. edsl/questions/question_registry.py +36 -6
  32. edsl/results/CSSParameterizer.py +108 -0
  33. edsl/results/Dataset.py +146 -15
  34. edsl/results/DatasetExportMixin.py +231 -217
  35. edsl/results/DatasetTree.py +134 -4
  36. edsl/results/Result.py +18 -9
  37. edsl/results/Results.py +145 -51
  38. edsl/results/TableDisplay.py +198 -0
  39. edsl/results/table_display.css +78 -0
  40. edsl/scenarios/FileStore.py +187 -13
  41. edsl/scenarios/Scenario.py +61 -4
  42. edsl/scenarios/ScenarioJoin.py +127 -0
  43. edsl/scenarios/ScenarioList.py +237 -62
  44. edsl/surveys/Survey.py +16 -2
  45. edsl/surveys/SurveyFlowVisualizationMixin.py +67 -9
  46. edsl/surveys/instructions/Instruction.py +12 -0
  47. edsl/templates/error_reporting/interview_details.html +3 -3
  48. edsl/templates/error_reporting/interviews.html +18 -9
  49. edsl/utilities/utilities.py +15 -0
  50. {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/METADATA +2 -1
  51. {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/RECORD +53 -45
  52. {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/LICENSE +0 -0
  53. {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/WHEEL +0 -0
@@ -1,4 +1,4 @@
1
- from typing import Optional
1
+ from typing import Optional, List
2
2
  from collections import UserList
3
3
  from edsl import Model
4
4
 
@@ -10,6 +10,8 @@ from edsl.utilities.utilities import dict_hash
10
10
 
11
11
 
12
12
  class ModelList(Base, UserList):
13
+ __documentation__ = """https://docs.expectedparrot.com/en/latest/language_models.html#module-edsl.language_models.ModelList"""
14
+
13
15
  def __init__(self, data: Optional[list] = None):
14
16
  """Initialize the ScenarioList class.
15
17
 
@@ -37,6 +39,9 @@ class ModelList(Base, UserList):
37
39
  def __repr__(self):
38
40
  return f"ModelList({super().__repr__()})"
39
41
 
42
+ def _summary(self):
43
+ return {"EDSL Class": "ModelList", "Number of Models": len(self)}
44
+
40
45
  def __hash__(self):
41
46
  """Return a hash of the ModelList. This is used for comparison of ModelLists.
42
47
 
@@ -48,6 +53,42 @@ class ModelList(Base, UserList):
48
53
 
49
54
  return dict_hash(self.to_dict(sort=True, add_edsl_version=False))
50
55
 
56
+ def to_scenario_list(self):
57
+ from edsl import ScenarioList, Scenario
58
+
59
+ sl = ScenarioList()
60
+ for model in self:
61
+ d = {"model": model.model}
62
+ d.update(model.parameters)
63
+ sl.append(Scenario(d))
64
+ return sl
65
+
66
+ def tree(self, node_list: Optional[List[str]] = None):
67
+ return self.to_scenario_list().tree(node_list)
68
+
69
+ def table(
70
+ self,
71
+ *fields,
72
+ tablefmt: Optional[str] = None,
73
+ pretty_labels: Optional[dict] = None,
74
+ ):
75
+ """
76
+ >>> ModelList.example().table("model")
77
+ model
78
+ -------
79
+ gpt-4o
80
+ gpt-4o
81
+ gpt-4o
82
+ """
83
+ return (
84
+ self.to_scenario_list()
85
+ .to_dataset()
86
+ .table(*fields, tablefmt=tablefmt, pretty_labels=pretty_labels)
87
+ )
88
+
89
+ def to_list(self):
90
+ return self.to_scenario_list().to_list()
91
+
51
92
  def to_dict(self, sort=False, add_edsl_version=True):
52
93
  if sort:
53
94
  model_list = sorted([model for model in self], key=lambda x: hash(x))
@@ -71,6 +112,11 @@ class ModelList(Base, UserList):
71
112
 
72
113
  return d
73
114
 
115
+ def _repr_html_(self):
116
+ """Return an HTML representation of the ModelList."""
117
+ footer = f"<a href={self.__documentation__}>(docs)</a>"
118
+ return str(self.summary(format="html")) + footer
119
+
74
120
  @classmethod
75
121
  def from_names(self, *args, **kwargs):
76
122
  """A a model list from a list of names"""
@@ -7,6 +7,49 @@ from edsl.config import CONFIG
7
7
  # else:
8
8
  # default_model = CONFIG.get("EDSL_DEFAULT_MODEL")
9
9
 
10
+ from collections import UserList
11
+
12
+
13
+ class PrettyList(UserList):
14
+ def __init__(self, data=None, columns=None):
15
+ super().__init__(data)
16
+ self.columns = columns
17
+
18
+ def _repr_html_(self):
19
+ if isinstance(self[0], list) or isinstance(self[0], tuple):
20
+ num_cols = len(self[0])
21
+ else:
22
+ num_cols = 1
23
+
24
+ if self.columns:
25
+ columns = self.columns
26
+ else:
27
+ columns = list(range(num_cols))
28
+
29
+ if num_cols > 1:
30
+ return (
31
+ "<pre><table>"
32
+ + "".join(["<th>" + str(column) + "</th>" for column in columns])
33
+ + "".join(
34
+ [
35
+ "<tr>"
36
+ + "".join(["<td>" + str(x) + "</td>" for x in row])
37
+ + "</tr>"
38
+ for row in self
39
+ ]
40
+ )
41
+ + "</table></pre>"
42
+ )
43
+ else:
44
+ return (
45
+ "<pre><table>"
46
+ + "".join(["<th>" + str(index) + "</th>" for index in columns])
47
+ + "".join(
48
+ ["<tr>" + "<td>" + str(row) + "</td>" + "</tr>" for row in self]
49
+ )
50
+ + "</table></pre>"
51
+ )
52
+
10
53
 
11
54
  def get_model_class(model_name, registry=None):
12
55
  from edsl.inference_services.registry import default
@@ -82,17 +125,27 @@ class Model(metaclass=Meta):
82
125
 
83
126
  if search_term is None:
84
127
  if name_only:
85
- return [m[0] for m in full_list]
128
+ return PrettyList(
129
+ [m[0] for m in full_list],
130
+ columns=["Model Name", "Service Name", "Code"],
131
+ )
86
132
  else:
87
- return full_list
133
+ return PrettyList(
134
+ full_list, columns=["Model Name", "Service Name", "Code"]
135
+ )
88
136
  else:
89
137
  filtered_results = [
90
138
  m for m in full_list if search_term in m[0] or search_term in m[1]
91
139
  ]
92
140
  if name_only:
93
- return [m[0] for m in filtered_results]
141
+ return PrettyList(
142
+ [m[0] for m in filtered_results],
143
+ columns=["Model Name", "Service Name", "Code"],
144
+ )
94
145
  else:
95
- return filtered_results
146
+ return PrettyList(
147
+ filtered_results, columns=["Model Name", "Service Name", "Code"]
148
+ )
96
149
 
97
150
  @classmethod
98
151
  def check_models(cls, verbose=False):
edsl/prompts/Prompt.py CHANGED
@@ -29,9 +29,14 @@ class Prompt(PersistenceMixin, RichPrintingMixin):
29
29
 
30
30
  def _repr_html_(self):
31
31
  """Return an HTML representation of the Prompt."""
32
- from edsl.utilities.utilities import data_to_html
33
-
34
- return data_to_html(self.to_dict())
32
+ # from edsl.utilities.utilities import data_to_html
33
+ # return data_to_html(self.to_dict())
34
+ d = self.to_dict()
35
+ data = [[k, v] for k, v in d.items()]
36
+ from tabulate import tabulate
37
+
38
+ table = str(tabulate(data, headers=["keys", "values"], tablefmt="html"))
39
+ return f"<pre>{table}</pre>"
35
40
 
36
41
  def __len__(self):
37
42
  """Return the length of the prompt text."""
@@ -82,7 +82,8 @@ class QuestionBase(
82
82
  if not hasattr(self, "_fake_data_factory"):
83
83
  from polyfactory.factories.pydantic_factory import ModelFactory
84
84
 
85
- class FakeData(ModelFactory[self.response_model]): ...
85
+ class FakeData(ModelFactory[self.response_model]):
86
+ ...
86
87
 
87
88
  self._fake_data_factory = FakeData
88
89
  return self._fake_data_factory
@@ -263,12 +264,9 @@ class QuestionBase(
263
264
  >>> m.execute_model_call("", "")
264
265
  {'message': [{'text': "Yo, what's up?"}], 'usage': {'prompt_tokens': 1, 'completion_tokens': 1}}
265
266
  >>> Q.run_example(show_answer = True, model = m, disable_remote_cache = True, disable_remote_inference = True)
266
- ┏━━━━━━━━━━━━━━━━┓
267
- ┃ answer ┃
268
- .how_are_you ┃
269
- ┡━━━━━━━━━━━━━━━━┩
270
- │ Yo, what's up? │
271
- └────────────────┘
267
+ answer.how_are_you
268
+ --------------------
269
+ Yo, what's up?
272
270
  """
273
271
  if model is None:
274
272
  from edsl import Model
@@ -284,7 +282,7 @@ class QuestionBase(
284
282
  )
285
283
  )
286
284
  if show_answer:
287
- results.select("answer.*").print()
285
+ return results.select("answer.*").print()
288
286
  else:
289
287
  return results
290
288
 
@@ -362,16 +360,22 @@ class QuestionBase(
362
360
 
363
361
  # region: Magic methods
364
362
  def _repr_html_(self):
365
- from edsl.utilities.utilities import data_to_html
363
+ # from edsl.utilities.utilities import data_to_html
366
364
 
367
- data = self.to_dict()
368
- try:
369
- _ = data.pop("edsl_version")
370
- _ = data.pop("edsl_class_name")
371
- except KeyError:
372
- print("Serialized question lacks edsl version, but is should have it.")
365
+ data = self.to_dict(add_edsl_version=False)
366
+ # keys = list(data.keys())
367
+ # values = list(data.values())
368
+ from tabulate import tabulate
369
+
370
+ return tabulate(data.items(), headers=["keys", "values"], tablefmt="html")
371
+
372
+ # try:
373
+ # _ = data.pop("edsl_version")
374
+ # _ = data.pop("edsl_class_name")
375
+ # except KeyError:
376
+ # print("Serialized question lacks edsl version, but is should have it.")
373
377
 
374
- return data_to_html(data)
378
+ # return data_to_html(data)
375
379
 
376
380
  def __getitem__(self, key: str) -> Any:
377
381
  """Get an attribute of the question so it can be treated like a dictionary.
@@ -1,4 +1,7 @@
1
1
  from __future__ import annotations
2
+ import json
3
+ import re
4
+
2
5
  from typing import Any, Optional, Dict
3
6
  from edsl.questions.QuestionBase import QuestionBase
4
7
  from edsl.questions.descriptors import AnswerTemplateDescriptor
@@ -11,9 +14,6 @@ from edsl.questions.decorators import inject_exception
11
14
  from typing import Dict, Any
12
15
  from pydantic import create_model, Field
13
16
 
14
- import json
15
- import re
16
-
17
17
 
18
18
  def extract_json(text, expected_keys, verbose=False):
19
19
  # Escape special regex characters in keys
@@ -112,7 +112,6 @@ class QuestionExtract(QuestionBase):
112
112
 
113
113
  :param question_name: The name of the question.
114
114
  :param question_text: The text of the question.
115
- :param question_options: The options the respondent should select from.
116
115
  :param answer_template: The template for the answer.
117
116
  """
118
117
  self.question_name = question_name
@@ -90,6 +90,22 @@ class Question(metaclass=Meta):
90
90
  coop = Coop()
91
91
  return coop.patch(uuid, url, description, value, visibility)
92
92
 
93
+ @classmethod
94
+ def list_question_types(cls):
95
+ """Return a list of available question types.
96
+
97
+ >>> from edsl import Question
98
+ >>> Question.list_question_types()
99
+ ['checkbox', 'extract', 'free_text', 'functional', 'likert_five', 'linear_scale', 'list', 'multiple_choice', 'numerical', 'rank', 'top_k', 'yes_no']
100
+ """
101
+ return [
102
+ q
103
+ for q in sorted(
104
+ list(RegisterQuestionsMeta.question_types_to_classes().keys())
105
+ )
106
+ if q not in ["budget"]
107
+ ]
108
+
93
109
  @classmethod
94
110
  def available(cls, show_class_names: bool = False) -> Union[list, dict]:
95
111
  """Return a list of available question types.
@@ -98,18 +114,32 @@ class Question(metaclass=Meta):
98
114
 
99
115
  Example usage:
100
116
 
101
- >>> from edsl import Question
102
- >>> Question.available()
103
- ['checkbox', 'extract', 'free_text', 'functional', 'likert_five', 'linear_scale', 'list', 'multiple_choice', 'numerical', 'rank', 'top_k', 'yes_no']
104
117
  """
118
+ from edsl.results.Dataset import Dataset
119
+
105
120
  exclude = ["budget"]
106
121
  if show_class_names:
107
122
  return RegisterQuestionsMeta.question_types_to_classes()
108
123
  else:
109
- question_list = sorted(
110
- set(RegisterQuestionsMeta.question_types_to_classes().keys())
124
+ question_list = [
125
+ q
126
+ for q in sorted(
127
+ set(RegisterQuestionsMeta.question_types_to_classes().keys())
128
+ )
129
+ if q not in exclude
130
+ ]
131
+ d = RegisterQuestionsMeta.question_types_to_classes()
132
+ question_classes = [d[q] for q in question_list]
133
+ example_questions = [repr(q.example()) for q in question_classes]
134
+
135
+ return Dataset(
136
+ [
137
+ {"question_type": [q for q in question_list]},
138
+ {"question_class": [q.__name__ for q in question_classes]},
139
+ {"example_question": example_questions},
140
+ ],
141
+ print_parameters={"containerHeight": "auto"},
111
142
  )
112
- return [q for q in question_list if q not in exclude]
113
143
 
114
144
 
115
145
  def get_question_class(question_type):
@@ -0,0 +1,108 @@
1
+ import re
2
+ from typing import Dict, Set, Optional
3
+
4
+
5
+ class CSSParameterizer:
6
+ """A utility class to parameterize CSS with custom properties (variables)."""
7
+
8
+ def __init__(self, css_content: str):
9
+ """
10
+ Initialize with CSS content to be parameterized.
11
+
12
+ Args:
13
+ css_content (str): The CSS content containing var() declarations
14
+ """
15
+ self.css_content = css_content
16
+ self._extract_variables()
17
+
18
+ def _extract_variables(self) -> None:
19
+ """Extract all CSS custom properties (variables) from the CSS content."""
20
+ # Find all var(...) declarations in the CSS
21
+ var_pattern = r"var\((--[a-zA-Z0-9-]+)\)"
22
+ self.variables = set(re.findall(var_pattern, self.css_content))
23
+
24
+ def _validate_parameters(self, parameters: Dict[str, str]) -> Set[str]:
25
+ """
26
+ Validate the provided parameters against the CSS variables.
27
+
28
+ Args:
29
+ parameters (Dict[str, str]): Dictionary of variable names and their values
30
+
31
+ Returns:
32
+ Set[str]: Set of missing variables
33
+ """
34
+ # Convert parameter keys to CSS variable format if they don't already have --
35
+ formatted_params = {
36
+ f"--{k}" if not k.startswith("--") else k for k in parameters.keys()
37
+ }
38
+
39
+ # print("Variables from CSS:", self.variables)
40
+ # print("Formatted parameters:", formatted_params)
41
+
42
+ # Find missing and extra variables
43
+ missing_vars = self.variables - formatted_params
44
+ extra_vars = formatted_params - self.variables
45
+
46
+ if extra_vars:
47
+ print(f"Warning: Found unused parameters: {extra_vars}")
48
+
49
+ return missing_vars
50
+
51
+ def generate_root(self, **parameters: str) -> Optional[str]:
52
+ """
53
+ Generate a :root block with the provided parameters.
54
+
55
+ Args:
56
+ **parameters: Keyword arguments where keys are variable names and values are their values
57
+
58
+ Returns:
59
+ str: Generated :root block with variables, or None if validation fails
60
+
61
+ Example:
62
+ >>> css = "body { height: var(--bodyHeight); }"
63
+ >>> parameterizer = CSSParameterizer(css)
64
+ >>> parameterizer.apply_parameters({'bodyHeight':"100vh"})
65
+ ':root {\\n --bodyHeight: 100vh;\\n}\\n\\nbody { height: var(--bodyHeight); }'
66
+ """
67
+ missing_vars = self._validate_parameters(parameters)
68
+
69
+ if missing_vars:
70
+ print(f"Error: Missing required variables: {missing_vars}")
71
+ return None
72
+
73
+ # Format parameters with -- prefix if not present
74
+ formatted_params = {
75
+ f"--{k}" if not k.startswith("--") else k: v for k, v in parameters.items()
76
+ }
77
+
78
+ # Generate the :root block
79
+ root_block = [":root {"]
80
+ for var_name, value in sorted(formatted_params.items()):
81
+ if var_name in self.variables:
82
+ root_block.append(f" {var_name}: {value};")
83
+ root_block.append("}")
84
+
85
+ return "\n".join(root_block)
86
+
87
+ def apply_parameters(self, parameters: dict) -> Optional[str]:
88
+ """
89
+ Generate the complete CSS with the :root block and original CSS content.
90
+
91
+ Args:
92
+ **parameters: Keyword arguments where keys are variable names and values are their values
93
+
94
+ Returns:
95
+ str: Complete CSS with :root block and original content, or None if validation fails
96
+ """
97
+ root_block = self.generate_root(**parameters)
98
+ if root_block is None:
99
+ return None
100
+
101
+ return f"{root_block}\n\n{self.css_content}"
102
+
103
+
104
+ # Example usage
105
+ if __name__ == "__main__":
106
+ import doctest
107
+
108
+ doctest.testmod()
edsl/results/Dataset.py CHANGED
@@ -5,19 +5,23 @@ import random
5
5
  import json
6
6
  from collections import UserList
7
7
  from typing import Any, Union, Optional
8
-
8
+ import sys
9
9
  import numpy as np
10
10
 
11
11
  from edsl.results.ResultsExportMixin import ResultsExportMixin
12
12
  from edsl.results.DatasetTree import Tree
13
+ from edsl.results.TableDisplay import TableDisplay
13
14
 
14
15
 
15
16
  class Dataset(UserList, ResultsExportMixin):
16
17
  """A class to represent a dataset of observations."""
17
18
 
18
- def __init__(self, data: list[dict[str, Any]] = None):
19
+ def __init__(
20
+ self, data: list[dict[str, Any]] = None, print_parameters: Optional[dict] = None
21
+ ):
19
22
  """Initialize the dataset with the given data."""
20
23
  super().__init__(data)
24
+ self.print_parameters = print_parameters
21
25
 
22
26
  def __len__(self) -> int:
23
27
  """Return the number of observations in the dataset.
@@ -32,7 +36,7 @@ class Dataset(UserList, ResultsExportMixin):
32
36
  _, values = list(self.data[0].items())[0]
33
37
  return len(values)
34
38
 
35
- def keys(self):
39
+ def keys(self) -> list[str]:
36
40
  """Return the keys of the first observation in the dataset.
37
41
 
38
42
  >>> d = Dataset([{'a.b':[1,2,3,4]}])
@@ -41,10 +45,45 @@ class Dataset(UserList, ResultsExportMixin):
41
45
  """
42
46
  return [list(o.keys())[0] for o in self]
43
47
 
48
+ def filter(self, expression):
49
+ return self.to_scenario_list().filter(expression).to_dataset()
50
+
44
51
  def __repr__(self) -> str:
45
52
  """Return a string representation of the dataset."""
46
53
  return f"Dataset({self.data})"
47
54
 
55
+ def write(self, filename: str, tablefmt: Optional[str] = None) -> None:
56
+ return self.table(tablefmt=tablefmt).write(filename)
57
+
58
+ def _repr_html_(self):
59
+ # headers, data = self._tabular()
60
+ return self.table(print_parameters=self.print_parameters)._repr_html_()
61
+ # return TableDisplay(headers=headers, data=data, raw_data_set=self)
62
+
63
+ def _tabular(self) -> tuple[list[str], list[list[Any]]]:
64
+ # Extract headers
65
+ headers = []
66
+ for entry in self.data:
67
+ headers.extend(entry.keys())
68
+ headers = list(dict.fromkeys(headers)) # Ensure unique headers
69
+
70
+ # Extract data
71
+ max_len = max(len(values) for entry in self.data for values in entry.values())
72
+ rows = []
73
+ for i in range(max_len):
74
+ row = []
75
+ for header in headers:
76
+ for entry in self.data:
77
+ if header in entry:
78
+ values = entry[header]
79
+ row.append(values[i] if i < len(values) else None)
80
+ break
81
+ else:
82
+ row.append(None) # Default to None if header is missing
83
+ rows.append(row)
84
+
85
+ return headers, rows
86
+
48
87
  def _key_to_value(self, key: str) -> Any:
49
88
  """Retrieve the value associated with the given key from the dataset.
50
89
 
@@ -89,7 +128,25 @@ class Dataset(UserList, ResultsExportMixin):
89
128
 
90
129
  return get_values(self.data[0])[0]
91
130
 
92
- def select(self, *keys):
131
+ def print(self, pretty_labels=None, **kwargs):
132
+ if "format" in kwargs:
133
+ if kwargs["format"] not in ["html", "markdown", "rich", "latex"]:
134
+ raise ValueError(f"Format '{kwargs['format']}' not supported.")
135
+ if pretty_labels is None:
136
+ pretty_labels = {}
137
+ else:
138
+ return self.rename(pretty_labels).print(**kwargs)
139
+ return self.table()
140
+
141
+ def rename(self, rename_dic) -> Dataset:
142
+ new_data = []
143
+ for observation in self.data:
144
+ key, values = list(observation.items())[0]
145
+ new_key = rename_dic.get(key, key)
146
+ new_data.append({new_key: values})
147
+ return Dataset(new_data)
148
+
149
+ def select(self, *keys) -> Dataset:
93
150
  """Return a new dataset with only the selected keys.
94
151
 
95
152
  :param keys: The keys to select.
@@ -122,12 +179,6 @@ class Dataset(UserList, ResultsExportMixin):
122
179
  json.dumps(self.data)
123
180
  ) # janky but I want to make sure it's serializable & deserializable
124
181
 
125
- def _repr_html_(self) -> str:
126
- """Return an HTML representation of the dataset."""
127
- from edsl.utilities.utilities import data_to_html
128
-
129
- return data_to_html(self.data)
130
-
131
182
  def shuffle(self, seed=None) -> Dataset:
132
183
  """Return a new dataset with the observations shuffled.
133
184
 
@@ -149,6 +200,9 @@ class Dataset(UserList, ResultsExportMixin):
149
200
 
150
201
  return self
151
202
 
203
+ def expand(self, field):
204
+ return self.to_scenario_list().expand(field).to_dataset()
205
+
152
206
  def sample(
153
207
  self,
154
208
  n: int = None,
@@ -267,15 +321,92 @@ class Dataset(UserList, ResultsExportMixin):
267
321
 
268
322
  return Dataset(new_data)
269
323
 
270
- @property
271
- def tree(self):
324
+ def tree(self, node_order: Optional[list[str]] = None) -> Tree:
272
325
  """Return a tree representation of the dataset.
273
326
 
274
327
  >>> d = Dataset([{'a':[1,2,3,4]}, {'b':[4,3,2,1]}])
275
- >>> d.tree.print_tree()
276
- Tree has not been constructed yet.
328
+ >>> d.tree()
329
+ Tree(Dataset({'a': [1, 2, 3, 4], 'b': [4, 3, 2, 1]}))
277
330
  """
278
- return Tree(self)
331
+ return Tree(self, node_order=node_order)
332
+
333
+ def table(
334
+ self,
335
+ *fields,
336
+ tablefmt: Optional[str] = None,
337
+ max_rows: Optional[int] = None,
338
+ pretty_labels=None,
339
+ print_parameters: Optional[dict] = None,
340
+ ):
341
+ if pretty_labels is not None:
342
+ new_fields = []
343
+ for field in fields:
344
+ new_fields.append(pretty_labels.get(field, field))
345
+ return self.rename(pretty_labels).table(
346
+ *new_fields, tablefmt=tablefmt, max_rows=max_rows
347
+ )
348
+
349
+ self.print_parameters = print_parameters
350
+
351
+ headers, data = self._tabular()
352
+
353
+ if tablefmt is not None:
354
+ from tabulate import tabulate_formats
355
+
356
+ if tablefmt not in tabulate_formats:
357
+ print(
358
+ f"Error: The following table format is not supported: {tablefmt}",
359
+ file=sys.stderr,
360
+ )
361
+ print(f"\nAvailable formats are: {tabulate_formats}", file=sys.stderr)
362
+ return None
363
+
364
+ if max_rows:
365
+ if len(data) < max_rows:
366
+ max_rows = None
367
+
368
+ if fields:
369
+ full_data = data
370
+ data = []
371
+ indices = []
372
+ for field in fields:
373
+ if field not in headers:
374
+ print(
375
+ f"Error: The following field was not found: {field}",
376
+ file=sys.stderr,
377
+ )
378
+ print(f"\nAvailable fields are: {headers}", file=sys.stderr)
379
+
380
+ # Optional: Suggest similar fields using difflib
381
+ import difflib
382
+
383
+ matches = difflib.get_close_matches(field, headers)
384
+ if matches:
385
+ print(f"\nDid you mean: {matches[0]} ?", file=sys.stderr)
386
+ return None
387
+ indices.append(headers.index(field))
388
+ headers = fields
389
+ for row in full_data:
390
+ data.append([row[i] for i in indices])
391
+
392
+ if max_rows is not None:
393
+ if max_rows > len(data):
394
+ raise ValueError(
395
+ "max_rows cannot be greater than the number of rows in the dataset."
396
+ )
397
+ last_line = data[-1]
398
+ spaces = len(data[max_rows])
399
+ filler_line = ["." for i in range(spaces)]
400
+ data = data[:max_rows]
401
+ data.append(filler_line)
402
+ data.append(last_line)
403
+
404
+ return TableDisplay(
405
+ data=data, headers=headers, tablefmt=tablefmt, raw_data_set=self
406
+ )
407
+
408
+ def summary(self):
409
+ return Dataset([{"num_observations": [len(self)], "keys": [self.keys()]}])
279
410
 
280
411
  @classmethod
281
412
  def example(self):