edsl 0.1.38.dev2__py3-none-any.whl → 0.1.38.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +60 -31
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +18 -9
- edsl/agents/AgentList.py +59 -8
- edsl/agents/Invigilator.py +18 -7
- edsl/agents/InvigilatorBase.py +0 -19
- edsl/agents/PromptConstructor.py +5 -4
- edsl/config.py +8 -0
- edsl/coop/coop.py +74 -7
- edsl/data/Cache.py +27 -2
- edsl/data/CacheEntry.py +8 -3
- edsl/data/RemoteCacheSync.py +0 -19
- edsl/enums.py +2 -0
- edsl/inference_services/GoogleService.py +7 -15
- edsl/inference_services/PerplexityService.py +163 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +88 -548
- edsl/jobs/JobsChecks.py +147 -0
- edsl/jobs/JobsPrompts.py +268 -0
- edsl/jobs/JobsRemoteInferenceHandler.py +239 -0
- edsl/jobs/interviews/Interview.py +11 -11
- edsl/jobs/runners/JobsRunnerAsyncio.py +140 -35
- edsl/jobs/runners/JobsRunnerStatus.py +0 -2
- edsl/jobs/tasks/TaskHistory.py +15 -16
- edsl/language_models/LanguageModel.py +44 -84
- edsl/language_models/ModelList.py +47 -1
- edsl/language_models/registry.py +57 -4
- edsl/prompts/Prompt.py +8 -3
- edsl/questions/QuestionBase.py +20 -16
- edsl/questions/QuestionExtract.py +3 -4
- edsl/questions/question_registry.py +36 -6
- edsl/results/CSSParameterizer.py +108 -0
- edsl/results/Dataset.py +146 -15
- edsl/results/DatasetExportMixin.py +231 -217
- edsl/results/DatasetTree.py +134 -4
- edsl/results/Result.py +18 -9
- edsl/results/Results.py +145 -51
- edsl/results/TableDisplay.py +198 -0
- edsl/results/table_display.css +78 -0
- edsl/scenarios/FileStore.py +187 -13
- edsl/scenarios/Scenario.py +61 -4
- edsl/scenarios/ScenarioJoin.py +127 -0
- edsl/scenarios/ScenarioList.py +237 -62
- edsl/surveys/Survey.py +16 -2
- edsl/surveys/SurveyFlowVisualizationMixin.py +67 -9
- edsl/surveys/instructions/Instruction.py +12 -0
- edsl/templates/error_reporting/interview_details.html +3 -3
- edsl/templates/error_reporting/interviews.html +18 -9
- edsl/utilities/utilities.py +15 -0
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/METADATA +2 -1
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/RECORD +53 -45
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/LICENSE +0 -0
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/WHEEL +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Optional
|
1
|
+
from typing import Optional, List
|
2
2
|
from collections import UserList
|
3
3
|
from edsl import Model
|
4
4
|
|
@@ -10,6 +10,8 @@ from edsl.utilities.utilities import dict_hash
|
|
10
10
|
|
11
11
|
|
12
12
|
class ModelList(Base, UserList):
|
13
|
+
__documentation__ = """https://docs.expectedparrot.com/en/latest/language_models.html#module-edsl.language_models.ModelList"""
|
14
|
+
|
13
15
|
def __init__(self, data: Optional[list] = None):
|
14
16
|
"""Initialize the ScenarioList class.
|
15
17
|
|
@@ -37,6 +39,9 @@ class ModelList(Base, UserList):
|
|
37
39
|
def __repr__(self):
|
38
40
|
return f"ModelList({super().__repr__()})"
|
39
41
|
|
42
|
+
def _summary(self):
|
43
|
+
return {"EDSL Class": "ModelList", "Number of Models": len(self)}
|
44
|
+
|
40
45
|
def __hash__(self):
|
41
46
|
"""Return a hash of the ModelList. This is used for comparison of ModelLists.
|
42
47
|
|
@@ -48,6 +53,42 @@ class ModelList(Base, UserList):
|
|
48
53
|
|
49
54
|
return dict_hash(self.to_dict(sort=True, add_edsl_version=False))
|
50
55
|
|
56
|
+
def to_scenario_list(self):
    """Convert the models into a ScenarioList, one Scenario per model.

    Each Scenario starts with the model name under ``"model"`` and is then
    updated with the model's ``parameters`` (so a parameter literally named
    ``"model"`` would override the name).
    """
    from edsl import ScenarioList, Scenario

    scenarios = ScenarioList()
    for m in self:
        row = dict(model=m.model)
        row.update(m.parameters)
        scenarios.append(Scenario(row))
    return scenarios
|
65
|
+
|
66
|
+
def tree(self, node_list: Optional[List[str]] = None):
|
67
|
+
return self.to_scenario_list().tree(node_list)
|
68
|
+
|
69
|
+
def table(
|
70
|
+
self,
|
71
|
+
*fields,
|
72
|
+
tablefmt: Optional[str] = None,
|
73
|
+
pretty_labels: Optional[dict] = None,
|
74
|
+
):
|
75
|
+
"""
|
76
|
+
>>> ModelList.example().table("model")
|
77
|
+
model
|
78
|
+
-------
|
79
|
+
gpt-4o
|
80
|
+
gpt-4o
|
81
|
+
gpt-4o
|
82
|
+
"""
|
83
|
+
return (
|
84
|
+
self.to_scenario_list()
|
85
|
+
.to_dataset()
|
86
|
+
.table(*fields, tablefmt=tablefmt, pretty_labels=pretty_labels)
|
87
|
+
)
|
88
|
+
|
89
|
+
def to_list(self):
|
90
|
+
return self.to_scenario_list().to_list()
|
91
|
+
|
51
92
|
def to_dict(self, sort=False, add_edsl_version=True):
|
52
93
|
if sort:
|
53
94
|
model_list = sorted([model for model in self], key=lambda x: hash(x))
|
@@ -71,6 +112,11 @@ class ModelList(Base, UserList):
|
|
71
112
|
|
72
113
|
return d
|
73
114
|
|
115
|
+
def _repr_html_(self):
|
116
|
+
"""Return an HTML representation of the ModelList."""
|
117
|
+
footer = f"<a href={self.__documentation__}>(docs)</a>"
|
118
|
+
return str(self.summary(format="html")) + footer
|
119
|
+
|
74
120
|
@classmethod
|
75
121
|
def from_names(self, *args, **kwargs):
|
76
122
|
"""A a model list from a list of names"""
|
edsl/language_models/registry.py
CHANGED
@@ -7,6 +7,49 @@ from edsl.config import CONFIG
|
|
7
7
|
# else:
|
8
8
|
# default_model = CONFIG.get("EDSL_DEFAULT_MODEL")
|
9
9
|
|
10
|
+
from collections import UserList
|
11
|
+
|
12
|
+
|
13
|
+
class PrettyList(UserList):
    """A list that renders itself as an HTML table in notebook front-ends.

    Items that are lists or tuples become one row each, with one cell per
    element; any other item becomes a single-cell row.
    """

    def __init__(self, data=None, columns=None):
        """Create the list.

        :param data: initial items (any iterable accepted by ``UserList``).
        :param columns: optional header labels; when omitted, column indices
            0..n-1 are used as headers.
        """
        super().__init__(data)
        self.columns = columns

    @staticmethod
    def _cells(row):
        """Return the cells of *row*: its elements if it is a list/tuple, else [row]."""
        return row if isinstance(row, (list, tuple)) else [row]

    def _repr_html_(self):
        """Return an HTML table of the items.

        The column count comes from the first item, or from ``columns`` when
        the list is empty (an empty list no longer raises IndexError). Header
        labels beyond that count are dropped, so a one-column list given
        several labels still yields a well-formed table. All header and cell
        text is HTML-escaped.
        """
        import html

        if self.data:
            num_cols = len(self._cells(self.data[0]))
        else:
            num_cols = len(self.columns) if self.columns else 0

        if self.columns:
            labels = list(self.columns)[:num_cols]
        else:
            labels = list(range(num_cols))

        header = ""
        if labels:
            header = (
                "<tr>"
                + "".join(f"<th>{html.escape(str(label))}</th>" for label in labels)
                + "</tr>"
            )
        body = "".join(
            "<tr>"
            + "".join(f"<td>{html.escape(str(cell))}</td>" for cell in self._cells(row))
            + "</tr>"
            for row in self.data
        )
        return f"<pre><table>{header}{body}</table></pre>"
|
52
|
+
|
10
53
|
|
11
54
|
def get_model_class(model_name, registry=None):
|
12
55
|
from edsl.inference_services.registry import default
|
@@ -82,17 +125,27 @@ class Model(metaclass=Meta):
|
|
82
125
|
|
83
126
|
if search_term is None:
|
84
127
|
if name_only:
|
85
|
-
return
|
128
|
+
return PrettyList(
|
129
|
+
[m[0] for m in full_list],
|
130
|
+
columns=["Model Name", "Service Name", "Code"],
|
131
|
+
)
|
86
132
|
else:
|
87
|
-
return
|
133
|
+
return PrettyList(
|
134
|
+
full_list, columns=["Model Name", "Service Name", "Code"]
|
135
|
+
)
|
88
136
|
else:
|
89
137
|
filtered_results = [
|
90
138
|
m for m in full_list if search_term in m[0] or search_term in m[1]
|
91
139
|
]
|
92
140
|
if name_only:
|
93
|
-
return
|
141
|
+
return PrettyList(
|
142
|
+
[m[0] for m in filtered_results],
|
143
|
+
columns=["Model Name", "Service Name", "Code"],
|
144
|
+
)
|
94
145
|
else:
|
95
|
-
return
|
146
|
+
return PrettyList(
|
147
|
+
filtered_results, columns=["Model Name", "Service Name", "Code"]
|
148
|
+
)
|
96
149
|
|
97
150
|
@classmethod
|
98
151
|
def check_models(cls, verbose=False):
|
edsl/prompts/Prompt.py
CHANGED
@@ -29,9 +29,14 @@ class Prompt(PersistenceMixin, RichPrintingMixin):
|
|
29
29
|
|
30
30
|
def _repr_html_(self):
    """Return an HTML representation of the Prompt.

    Each ``to_dict()`` entry becomes one row of a two-column "keys"/"values"
    HTML table (rendered by ``tabulate``), wrapped in a ``<pre>`` element.
    """
    rows = [[key, value] for key, value in self.to_dict().items()]
    from tabulate import tabulate

    html_table = str(tabulate(rows, headers=["keys", "values"], tablefmt="html"))
    return "<pre>" + html_table + "</pre>"
|
35
40
|
|
36
41
|
def __len__(self):
|
37
42
|
"""Return the length of the prompt text."""
|
edsl/questions/QuestionBase.py
CHANGED
@@ -82,7 +82,8 @@ class QuestionBase(
|
|
82
82
|
if not hasattr(self, "_fake_data_factory"):
|
83
83
|
from polyfactory.factories.pydantic_factory import ModelFactory
|
84
84
|
|
85
|
-
class FakeData(ModelFactory[self.response_model]):
|
85
|
+
class FakeData(ModelFactory[self.response_model]):
|
86
|
+
...
|
86
87
|
|
87
88
|
self._fake_data_factory = FakeData
|
88
89
|
return self._fake_data_factory
|
@@ -263,12 +264,9 @@ class QuestionBase(
|
|
263
264
|
>>> m.execute_model_call("", "")
|
264
265
|
{'message': [{'text': "Yo, what's up?"}], 'usage': {'prompt_tokens': 1, 'completion_tokens': 1}}
|
265
266
|
>>> Q.run_example(show_answer = True, model = m, disable_remote_cache = True, disable_remote_inference = True)
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
┡━━━━━━━━━━━━━━━━┩
|
270
|
-
│ Yo, what's up? │
|
271
|
-
└────────────────┘
|
267
|
+
answer.how_are_you
|
268
|
+
--------------------
|
269
|
+
Yo, what's up?
|
272
270
|
"""
|
273
271
|
if model is None:
|
274
272
|
from edsl import Model
|
@@ -284,7 +282,7 @@ class QuestionBase(
|
|
284
282
|
)
|
285
283
|
)
|
286
284
|
if show_answer:
|
287
|
-
results.select("answer.*").print()
|
285
|
+
return results.select("answer.*").print()
|
288
286
|
else:
|
289
287
|
return results
|
290
288
|
|
@@ -362,16 +360,22 @@ class QuestionBase(
|
|
362
360
|
|
363
361
|
# region: Magic methods
|
364
362
|
def _repr_html_(self):
    """Return an HTML key/value table of the question's serialized fields.

    The question is serialized with ``to_dict(add_edsl_version=False)``.
    Each top-level entry becomes one row of a two-column "keys"/"values"
    table rendered by ``tabulate`` in HTML format.
    """
    from tabulate import tabulate

    data = self.to_dict(add_edsl_version=False)
    # Materialize the (key, value) pairs so tabulate receives a plain list of rows.
    rows = list(data.items())
    return tabulate(rows, headers=["keys", "values"], tablefmt="html")
|
375
379
|
|
376
380
|
def __getitem__(self, key: str) -> Any:
|
377
381
|
"""Get an attribute of the question so it can be treated like a dictionary.
|
@@ -1,4 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
+
import json
|
3
|
+
import re
|
4
|
+
|
2
5
|
from typing import Any, Optional, Dict
|
3
6
|
from edsl.questions.QuestionBase import QuestionBase
|
4
7
|
from edsl.questions.descriptors import AnswerTemplateDescriptor
|
@@ -11,9 +14,6 @@ from edsl.questions.decorators import inject_exception
|
|
11
14
|
from typing import Dict, Any
|
12
15
|
from pydantic import create_model, Field
|
13
16
|
|
14
|
-
import json
|
15
|
-
import re
|
16
|
-
|
17
17
|
|
18
18
|
def extract_json(text, expected_keys, verbose=False):
|
19
19
|
# Escape special regex characters in keys
|
@@ -112,7 +112,6 @@ class QuestionExtract(QuestionBase):
|
|
112
112
|
|
113
113
|
:param question_name: The name of the question.
|
114
114
|
:param question_text: The text of the question.
|
115
|
-
:param question_options: The options the respondent should select from.
|
116
115
|
:param answer_template: The template for the answer.
|
117
116
|
"""
|
118
117
|
self.question_name = question_name
|
@@ -90,6 +90,22 @@ class Question(metaclass=Meta):
|
|
90
90
|
coop = Coop()
|
91
91
|
return coop.patch(uuid, url, description, value, visibility)
|
92
92
|
|
93
|
+
@classmethod
def list_question_types(cls):
    """Return a list of available question types.

    The registered type names are returned in sorted order, excluding
    "budget".

    >>> from edsl import Question
    >>> Question.list_question_types()
    ['checkbox', 'extract', 'free_text', 'functional', 'likert_five', 'linear_scale', 'list', 'multiple_choice', 'numerical', 'rank', 'top_k', 'yes_no']
    """
    excluded = ["budget"]
    registered = RegisterQuestionsMeta.question_types_to_classes().keys()
    return [name for name in sorted(registered) if name not in excluded]
|
108
|
+
|
93
109
|
@classmethod
|
94
110
|
def available(cls, show_class_names: bool = False) -> Union[list, dict]:
|
95
111
|
"""Return a list of available question types.
|
@@ -98,18 +114,32 @@ class Question(metaclass=Meta):
|
|
98
114
|
|
99
115
|
Example usage:
|
100
116
|
|
101
|
-
>>> from edsl import Question
|
102
|
-
>>> Question.available()
|
103
|
-
['checkbox', 'extract', 'free_text', 'functional', 'likert_five', 'linear_scale', 'list', 'multiple_choice', 'numerical', 'rank', 'top_k', 'yes_no']
|
104
117
|
"""
|
118
|
+
from edsl.results.Dataset import Dataset
|
119
|
+
|
105
120
|
exclude = ["budget"]
|
106
121
|
if show_class_names:
|
107
122
|
return RegisterQuestionsMeta.question_types_to_classes()
|
108
123
|
else:
|
109
|
-
question_list =
|
110
|
-
|
124
|
+
question_list = [
|
125
|
+
q
|
126
|
+
for q in sorted(
|
127
|
+
set(RegisterQuestionsMeta.question_types_to_classes().keys())
|
128
|
+
)
|
129
|
+
if q not in exclude
|
130
|
+
]
|
131
|
+
d = RegisterQuestionsMeta.question_types_to_classes()
|
132
|
+
question_classes = [d[q] for q in question_list]
|
133
|
+
example_questions = [repr(q.example()) for q in question_classes]
|
134
|
+
|
135
|
+
return Dataset(
|
136
|
+
[
|
137
|
+
{"question_type": [q for q in question_list]},
|
138
|
+
{"question_class": [q.__name__ for q in question_classes]},
|
139
|
+
{"example_question": example_questions},
|
140
|
+
],
|
141
|
+
print_parameters={"containerHeight": "auto"},
|
111
142
|
)
|
112
|
-
return [q for q in question_list if q not in exclude]
|
113
143
|
|
114
144
|
|
115
145
|
def get_question_class(question_type):
|
@@ -0,0 +1,108 @@
|
|
1
|
+
import re
|
2
|
+
from typing import Dict, Set, Optional
|
3
|
+
|
4
|
+
|
5
|
+
class CSSParameterizer:
    """A utility class to parameterize CSS with custom properties (variables).

    The CSS text is scanned for ``var(--name)`` references. Callers supply
    values for those names and get back the CSS prefixed by a ``:root`` block
    that declares them.
    """

    # Captures the custom-property name inside ``var(...)`` and the character
    # that follows it: ")" means no fallback, "," means a fallback value is
    # given (e.g. ``var(--name, 10px)``). Whitespace around the name is allowed.
    _VAR_PATTERN = re.compile(r"var\(\s*(--[a-zA-Z0-9_-]+)\s*([,)])")

    def __init__(self, css_content: str):
        """
        Initialize with CSS content to be parameterized.

        Args:
            css_content (str): The CSS content containing var() declarations
        """
        self.css_content = css_content
        self._extract_variables()

    @staticmethod
    def _format_name(name: str) -> str:
        """Return *name* with a leading ``--`` added if it does not have one."""
        return name if name.startswith("--") else f"--{name}"

    def _extract_variables(self) -> None:
        """Extract all CSS custom properties referenced via var() in the CSS.

        Sets ``self.variables`` to every referenced name. Sets
        ``self.required_variables`` to the names referenced at least once
        without a fallback value; only those must be supplied.
        """
        self.variables = set()
        self.required_variables = set()
        for name, terminator in self._VAR_PATTERN.findall(self.css_content):
            self.variables.add(name)
            if terminator == ")":
                self.required_variables.add(name)

    def _validate_parameters(self, parameters: Dict[str, str]) -> Set[str]:
        """
        Validate the provided parameters against the CSS variables.

        Prints a warning listing any parameter that the CSS never references.

        Args:
            parameters (Dict[str, str]): Variable names (with or without the
                leading ``--``) mapped to their values.

        Returns:
            Set[str]: Required variables (used without a fallback) that were
            not supplied.
        """
        formatted_params = {self._format_name(k) for k in parameters}

        missing_vars = self.required_variables - formatted_params
        extra_vars = formatted_params - self.variables

        if extra_vars:
            print(f"Warning: Found unused parameters: {extra_vars}")

        return missing_vars

    def generate_root(self, **parameters: str) -> Optional[str]:
        """
        Generate a :root block declaring the provided parameters.

        Only names referenced in the CSS are emitted, sorted by name.

        Args:
            **parameters: Variable names (with or without the leading ``--``)
                mapped to their values.

        Returns:
            str: The :root block, or None if a required variable is missing
            (an error message is printed in that case).

        Example:
            >>> CSSParameterizer("body { height: var(--bodyHeight); }").generate_root(bodyHeight="100vh")
            ':root {\\n --bodyHeight: 100vh;\\n}'
        """
        missing_vars = self._validate_parameters(parameters)

        if missing_vars:
            print(f"Error: Missing required variables: {missing_vars}")
            return None

        formatted_params = {self._format_name(k): v for k, v in parameters.items()}

        root_block = [":root {"]
        root_block.extend(
            f" {var_name}: {value};"
            for var_name, value in sorted(formatted_params.items())
            if var_name in self.variables
        )
        root_block.append("}")

        return "\n".join(root_block)

    def apply_parameters(self, parameters: dict) -> Optional[str]:
        """
        Generate the complete CSS: the :root block followed by the original CSS.

        Args:
            parameters (dict): Variable names (with or without the leading
                ``--``) mapped to their values.

        Returns:
            str: The :root block, a blank line, then the original CSS content,
            or None if a required variable is missing.

        Example:
            >>> css = "body { height: var(--bodyHeight); }"
            >>> CSSParameterizer(css).apply_parameters({'bodyHeight': "100vh"})
            ':root {\\n --bodyHeight: 100vh;\\n}\\n\\nbody { height: var(--bodyHeight); }'
        """
        root_block = self.generate_root(**parameters)
        if root_block is None:
            return None

        return f"{root_block}\n\n{self.css_content}"


if __name__ == "__main__":
    # Run the doctests embedded in this module.
    import doctest

    doctest.testmod()
|
edsl/results/Dataset.py
CHANGED
@@ -5,19 +5,23 @@ import random
|
|
5
5
|
import json
|
6
6
|
from collections import UserList
|
7
7
|
from typing import Any, Union, Optional
|
8
|
-
|
8
|
+
import sys
|
9
9
|
import numpy as np
|
10
10
|
|
11
11
|
from edsl.results.ResultsExportMixin import ResultsExportMixin
|
12
12
|
from edsl.results.DatasetTree import Tree
|
13
|
+
from edsl.results.TableDisplay import TableDisplay
|
13
14
|
|
14
15
|
|
15
16
|
class Dataset(UserList, ResultsExportMixin):
|
16
17
|
"""A class to represent a dataset of observations."""
|
17
18
|
|
18
|
-
def __init__(
|
19
|
+
def __init__(
|
20
|
+
self, data: list[dict[str, Any]] = None, print_parameters: Optional[dict] = None
|
21
|
+
):
|
19
22
|
"""Initialize the dataset with the given data."""
|
20
23
|
super().__init__(data)
|
24
|
+
self.print_parameters = print_parameters
|
21
25
|
|
22
26
|
def __len__(self) -> int:
|
23
27
|
"""Return the number of observations in the dataset.
|
@@ -32,7 +36,7 @@ class Dataset(UserList, ResultsExportMixin):
|
|
32
36
|
_, values = list(self.data[0].items())[0]
|
33
37
|
return len(values)
|
34
38
|
|
35
|
-
def keys(self):
|
39
|
+
def keys(self) -> list[str]:
|
36
40
|
"""Return the keys of the first observation in the dataset.
|
37
41
|
|
38
42
|
>>> d = Dataset([{'a.b':[1,2,3,4]}])
|
@@ -41,10 +45,45 @@ class Dataset(UserList, ResultsExportMixin):
|
|
41
45
|
"""
|
42
46
|
return [list(o.keys())[0] for o in self]
|
43
47
|
|
48
|
+
def filter(self, expression):
|
49
|
+
return self.to_scenario_list().filter(expression).to_dataset()
|
50
|
+
|
44
51
|
def __repr__(self) -> str:
|
45
52
|
"""Return a string representation of the dataset."""
|
46
53
|
return f"Dataset({self.data})"
|
47
54
|
|
55
|
+
def write(self, filename: str, tablefmt: Optional[str] = None) -> None:
|
56
|
+
return self.table(tablefmt=tablefmt).write(filename)
|
57
|
+
|
58
|
+
def _repr_html_(self):
|
59
|
+
# headers, data = self._tabular()
|
60
|
+
return self.table(print_parameters=self.print_parameters)._repr_html_()
|
61
|
+
# return TableDisplay(headers=headers, data=data, raw_data_set=self)
|
62
|
+
|
63
|
+
def _tabular(self) -> tuple[list[str], list[list[Any]]]:
|
64
|
+
# Extract headers
|
65
|
+
headers = []
|
66
|
+
for entry in self.data:
|
67
|
+
headers.extend(entry.keys())
|
68
|
+
headers = list(dict.fromkeys(headers)) # Ensure unique headers
|
69
|
+
|
70
|
+
# Extract data
|
71
|
+
max_len = max(len(values) for entry in self.data for values in entry.values())
|
72
|
+
rows = []
|
73
|
+
for i in range(max_len):
|
74
|
+
row = []
|
75
|
+
for header in headers:
|
76
|
+
for entry in self.data:
|
77
|
+
if header in entry:
|
78
|
+
values = entry[header]
|
79
|
+
row.append(values[i] if i < len(values) else None)
|
80
|
+
break
|
81
|
+
else:
|
82
|
+
row.append(None) # Default to None if header is missing
|
83
|
+
rows.append(row)
|
84
|
+
|
85
|
+
return headers, rows
|
86
|
+
|
48
87
|
def _key_to_value(self, key: str) -> Any:
|
49
88
|
"""Retrieve the value associated with the given key from the dataset.
|
50
89
|
|
@@ -89,7 +128,25 @@ class Dataset(UserList, ResultsExportMixin):
|
|
89
128
|
|
90
129
|
return get_values(self.data[0])[0]
|
91
130
|
|
92
|
-
def
|
131
|
+
def print(self, pretty_labels=None, **kwargs):
|
132
|
+
if "format" in kwargs:
|
133
|
+
if kwargs["format"] not in ["html", "markdown", "rich", "latex"]:
|
134
|
+
raise ValueError(f"Format '{kwargs['format']}' not supported.")
|
135
|
+
if pretty_labels is None:
|
136
|
+
pretty_labels = {}
|
137
|
+
else:
|
138
|
+
return self.rename(pretty_labels).print(**kwargs)
|
139
|
+
return self.table()
|
140
|
+
|
141
|
+
def rename(self, rename_dic) -> Dataset:
|
142
|
+
new_data = []
|
143
|
+
for observation in self.data:
|
144
|
+
key, values = list(observation.items())[0]
|
145
|
+
new_key = rename_dic.get(key, key)
|
146
|
+
new_data.append({new_key: values})
|
147
|
+
return Dataset(new_data)
|
148
|
+
|
149
|
+
def select(self, *keys) -> Dataset:
|
93
150
|
"""Return a new dataset with only the selected keys.
|
94
151
|
|
95
152
|
:param keys: The keys to select.
|
@@ -122,12 +179,6 @@ class Dataset(UserList, ResultsExportMixin):
|
|
122
179
|
json.dumps(self.data)
|
123
180
|
) # janky but I want to make sure it's serializable & deserializable
|
124
181
|
|
125
|
-
def _repr_html_(self) -> str:
|
126
|
-
"""Return an HTML representation of the dataset."""
|
127
|
-
from edsl.utilities.utilities import data_to_html
|
128
|
-
|
129
|
-
return data_to_html(self.data)
|
130
|
-
|
131
182
|
def shuffle(self, seed=None) -> Dataset:
|
132
183
|
"""Return a new dataset with the observations shuffled.
|
133
184
|
|
@@ -149,6 +200,9 @@ class Dataset(UserList, ResultsExportMixin):
|
|
149
200
|
|
150
201
|
return self
|
151
202
|
|
203
|
+
def expand(self, field):
|
204
|
+
return self.to_scenario_list().expand(field).to_dataset()
|
205
|
+
|
152
206
|
def sample(
|
153
207
|
self,
|
154
208
|
n: int = None,
|
@@ -267,15 +321,92 @@ class Dataset(UserList, ResultsExportMixin):
|
|
267
321
|
|
268
322
|
return Dataset(new_data)
|
269
323
|
|
270
|
-
|
271
|
-
def tree(self, node_order: Optional[list[str]] = None) -> Tree:
    """Return a tree representation of the dataset.

    :param node_order: optional ordering of keys, passed to ``Tree`` as
        ``node_order``.

    >>> d = Dataset([{'a':[1,2,3,4]}, {'b':[4,3,2,1]}])
    >>> d.tree()
    Tree(Dataset({'a': [1, 2, 3, 4], 'b': [4, 3, 2, 1]}))
    """
    tree_view = Tree(self, node_order=node_order)
    return tree_view
|
332
|
+
|
333
|
+
def table(
|
334
|
+
self,
|
335
|
+
*fields,
|
336
|
+
tablefmt: Optional[str] = None,
|
337
|
+
max_rows: Optional[int] = None,
|
338
|
+
pretty_labels=None,
|
339
|
+
print_parameters: Optional[dict] = None,
|
340
|
+
):
|
341
|
+
if pretty_labels is not None:
|
342
|
+
new_fields = []
|
343
|
+
for field in fields:
|
344
|
+
new_fields.append(pretty_labels.get(field, field))
|
345
|
+
return self.rename(pretty_labels).table(
|
346
|
+
*new_fields, tablefmt=tablefmt, max_rows=max_rows
|
347
|
+
)
|
348
|
+
|
349
|
+
self.print_parameters = print_parameters
|
350
|
+
|
351
|
+
headers, data = self._tabular()
|
352
|
+
|
353
|
+
if tablefmt is not None:
|
354
|
+
from tabulate import tabulate_formats
|
355
|
+
|
356
|
+
if tablefmt not in tabulate_formats:
|
357
|
+
print(
|
358
|
+
f"Error: The following table format is not supported: {tablefmt}",
|
359
|
+
file=sys.stderr,
|
360
|
+
)
|
361
|
+
print(f"\nAvailable formats are: {tabulate_formats}", file=sys.stderr)
|
362
|
+
return None
|
363
|
+
|
364
|
+
if max_rows:
|
365
|
+
if len(data) < max_rows:
|
366
|
+
max_rows = None
|
367
|
+
|
368
|
+
if fields:
|
369
|
+
full_data = data
|
370
|
+
data = []
|
371
|
+
indices = []
|
372
|
+
for field in fields:
|
373
|
+
if field not in headers:
|
374
|
+
print(
|
375
|
+
f"Error: The following field was not found: {field}",
|
376
|
+
file=sys.stderr,
|
377
|
+
)
|
378
|
+
print(f"\nAvailable fields are: {headers}", file=sys.stderr)
|
379
|
+
|
380
|
+
# Optional: Suggest similar fields using difflib
|
381
|
+
import difflib
|
382
|
+
|
383
|
+
matches = difflib.get_close_matches(field, headers)
|
384
|
+
if matches:
|
385
|
+
print(f"\nDid you mean: {matches[0]} ?", file=sys.stderr)
|
386
|
+
return None
|
387
|
+
indices.append(headers.index(field))
|
388
|
+
headers = fields
|
389
|
+
for row in full_data:
|
390
|
+
data.append([row[i] for i in indices])
|
391
|
+
|
392
|
+
if max_rows is not None:
|
393
|
+
if max_rows > len(data):
|
394
|
+
raise ValueError(
|
395
|
+
"max_rows cannot be greater than the number of rows in the dataset."
|
396
|
+
)
|
397
|
+
last_line = data[-1]
|
398
|
+
spaces = len(data[max_rows])
|
399
|
+
filler_line = ["." for i in range(spaces)]
|
400
|
+
data = data[:max_rows]
|
401
|
+
data.append(filler_line)
|
402
|
+
data.append(last_line)
|
403
|
+
|
404
|
+
return TableDisplay(
|
405
|
+
data=data, headers=headers, tablefmt=tablefmt, raw_data_set=self
|
406
|
+
)
|
407
|
+
|
408
|
+
def summary(self):
|
409
|
+
return Dataset([{"num_observations": [len(self)], "keys": [self.keys()]}])
|
279
410
|
|
280
411
|
@classmethod
|
281
412
|
def example(self):
|