edsl 0.1.49__py3-none-any.whl → 0.1.51__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__init__.py +124 -53
- edsl/__version__.py +1 -1
- edsl/agents/agent.py +21 -21
- edsl/agents/agent_list.py +2 -5
- edsl/agents/exceptions.py +119 -5
- edsl/base/__init__.py +10 -35
- edsl/base/base_class.py +71 -36
- edsl/base/base_exception.py +204 -0
- edsl/base/data_transfer_models.py +1 -1
- edsl/base/exceptions.py +94 -0
- edsl/buckets/__init__.py +15 -1
- edsl/buckets/bucket_collection.py +3 -4
- edsl/buckets/exceptions.py +107 -0
- edsl/buckets/model_buckets.py +1 -2
- edsl/buckets/token_bucket.py +11 -6
- edsl/buckets/token_bucket_api.py +27 -12
- edsl/buckets/token_bucket_client.py +9 -7
- edsl/caching/cache.py +12 -4
- edsl/caching/cache_entry.py +10 -9
- edsl/caching/exceptions.py +113 -7
- edsl/caching/remote_cache_sync.py +6 -7
- edsl/caching/sql_dict.py +20 -14
- edsl/cli.py +43 -0
- edsl/config/__init__.py +1 -1
- edsl/config/config_class.py +32 -6
- edsl/conversation/Conversation.py +8 -4
- edsl/conversation/car_buying.py +1 -3
- edsl/conversation/exceptions.py +58 -0
- edsl/conversation/mug_negotiation.py +2 -8
- edsl/coop/__init__.py +28 -6
- edsl/coop/coop.py +120 -29
- edsl/coop/coop_functions.py +1 -1
- edsl/coop/ep_key_handling.py +1 -1
- edsl/coop/exceptions.py +188 -9
- edsl/coop/price_fetcher.py +5 -8
- edsl/coop/utils.py +4 -6
- edsl/dataset/__init__.py +5 -4
- edsl/dataset/dataset.py +177 -86
- edsl/dataset/dataset_operations_mixin.py +98 -76
- edsl/dataset/dataset_tree.py +11 -7
- edsl/dataset/display/table_display.py +0 -2
- edsl/dataset/display/table_renderers.py +6 -4
- edsl/dataset/exceptions.py +125 -0
- edsl/dataset/file_exports.py +18 -11
- edsl/dataset/r/ggplot.py +13 -6
- edsl/display/__init__.py +27 -0
- edsl/display/core.py +147 -0
- edsl/display/plugin.py +189 -0
- edsl/display/utils.py +52 -0
- edsl/inference_services/__init__.py +9 -1
- edsl/inference_services/available_model_cache_handler.py +1 -1
- edsl/inference_services/available_model_fetcher.py +5 -6
- edsl/inference_services/data_structures.py +10 -7
- edsl/inference_services/exceptions.py +132 -1
- edsl/inference_services/inference_service_abc.py +2 -2
- edsl/inference_services/inference_services_collection.py +2 -6
- edsl/inference_services/registry.py +4 -3
- edsl/inference_services/service_availability.py +4 -3
- edsl/inference_services/services/anthropic_service.py +4 -1
- edsl/inference_services/services/aws_bedrock.py +13 -12
- edsl/inference_services/services/azure_ai.py +12 -10
- edsl/inference_services/services/deep_infra_service.py +1 -4
- edsl/inference_services/services/deep_seek_service.py +1 -5
- edsl/inference_services/services/google_service.py +7 -3
- edsl/inference_services/services/groq_service.py +1 -1
- edsl/inference_services/services/mistral_ai_service.py +4 -2
- edsl/inference_services/services/ollama_service.py +1 -1
- edsl/inference_services/services/open_ai_service.py +7 -5
- edsl/inference_services/services/perplexity_service.py +6 -2
- edsl/inference_services/services/test_service.py +8 -7
- edsl/inference_services/services/together_ai_service.py +2 -3
- edsl/inference_services/services/xai_service.py +1 -1
- edsl/instructions/__init__.py +1 -1
- edsl/instructions/change_instruction.py +7 -5
- edsl/instructions/exceptions.py +61 -0
- edsl/instructions/instruction.py +6 -2
- edsl/instructions/instruction_collection.py +6 -4
- edsl/instructions/instruction_handler.py +12 -15
- edsl/interviews/ReportErrors.py +0 -3
- edsl/interviews/__init__.py +9 -2
- edsl/interviews/answering_function.py +11 -13
- edsl/interviews/exception_tracking.py +15 -8
- edsl/interviews/exceptions.py +79 -0
- edsl/interviews/interview.py +33 -30
- edsl/interviews/interview_status_dictionary.py +4 -2
- edsl/interviews/interview_status_log.py +2 -1
- edsl/interviews/interview_task_manager.py +5 -5
- edsl/interviews/request_token_estimator.py +5 -2
- edsl/interviews/statistics.py +3 -4
- edsl/invigilators/__init__.py +7 -1
- edsl/invigilators/exceptions.py +79 -0
- edsl/invigilators/invigilator_base.py +0 -1
- edsl/invigilators/invigilators.py +9 -13
- edsl/invigilators/prompt_constructor.py +1 -5
- edsl/invigilators/prompt_helpers.py +8 -4
- edsl/invigilators/question_instructions_prompt_builder.py +1 -1
- edsl/invigilators/question_option_processor.py +9 -5
- edsl/invigilators/question_template_replacements_builder.py +3 -2
- edsl/jobs/__init__.py +42 -5
- edsl/jobs/async_interview_runner.py +25 -23
- edsl/jobs/check_survey_scenario_compatibility.py +11 -10
- edsl/jobs/data_structures.py +8 -5
- edsl/jobs/exceptions.py +177 -8
- edsl/jobs/fetch_invigilator.py +1 -1
- edsl/jobs/jobs.py +74 -69
- edsl/jobs/jobs_checks.py +6 -7
- edsl/jobs/jobs_component_constructor.py +4 -4
- edsl/jobs/jobs_pricing_estimation.py +4 -3
- edsl/jobs/jobs_remote_inference_logger.py +5 -4
- edsl/jobs/jobs_runner_asyncio.py +3 -4
- edsl/jobs/jobs_runner_status.py +8 -9
- edsl/jobs/remote_inference.py +27 -24
- edsl/jobs/results_exceptions_handler.py +10 -7
- edsl/key_management/__init__.py +3 -1
- edsl/key_management/exceptions.py +62 -0
- edsl/key_management/key_lookup.py +1 -1
- edsl/key_management/key_lookup_builder.py +37 -14
- edsl/key_management/key_lookup_collection.py +2 -0
- edsl/language_models/__init__.py +1 -1
- edsl/language_models/exceptions.py +302 -14
- edsl/language_models/language_model.py +9 -8
- edsl/language_models/model.py +4 -4
- edsl/language_models/model_list.py +1 -1
- edsl/language_models/price_manager.py +1 -1
- edsl/language_models/raw_response_handler.py +14 -9
- edsl/language_models/registry.py +17 -21
- edsl/language_models/repair.py +0 -6
- edsl/language_models/unused/fake_openai_service.py +0 -1
- edsl/load_plugins.py +69 -0
- edsl/logger.py +146 -0
- edsl/notebooks/__init__.py +24 -1
- edsl/notebooks/exceptions.py +82 -0
- edsl/notebooks/notebook.py +7 -3
- edsl/notebooks/notebook_to_latex.py +1 -2
- edsl/plugins/__init__.py +63 -0
- edsl/plugins/built_in/export_example.py +50 -0
- edsl/plugins/built_in/pig_latin.py +67 -0
- edsl/plugins/cli.py +372 -0
- edsl/plugins/cli_typer.py +283 -0
- edsl/plugins/exceptions.py +31 -0
- edsl/plugins/hookspec.py +51 -0
- edsl/plugins/plugin_host.py +128 -0
- edsl/plugins/plugin_manager.py +633 -0
- edsl/plugins/plugins_registry.py +168 -0
- edsl/prompts/__init__.py +24 -1
- edsl/prompts/exceptions.py +107 -5
- edsl/prompts/prompt.py +15 -7
- edsl/questions/HTMLQuestion.py +5 -11
- edsl/questions/Quick.py +0 -1
- edsl/questions/__init__.py +6 -4
- edsl/questions/answer_validator_mixin.py +318 -323
- edsl/questions/compose_questions.py +3 -3
- edsl/questions/descriptors.py +11 -50
- edsl/questions/exceptions.py +278 -22
- edsl/questions/loop_processor.py +7 -5
- edsl/questions/prompt_templates/question_list.jinja +3 -0
- edsl/questions/question_base.py +46 -19
- edsl/questions/question_base_gen_mixin.py +2 -2
- edsl/questions/question_base_prompts_mixin.py +13 -7
- edsl/questions/question_budget.py +503 -98
- edsl/questions/question_check_box.py +660 -160
- edsl/questions/question_dict.py +345 -194
- edsl/questions/question_extract.py +401 -61
- edsl/questions/question_free_text.py +80 -14
- edsl/questions/question_functional.py +119 -9
- edsl/questions/{derived/question_likert_five.py → question_likert_five.py} +2 -2
- edsl/questions/{derived/question_linear_scale.py → question_linear_scale.py} +3 -4
- edsl/questions/question_list.py +275 -28
- edsl/questions/question_matrix.py +643 -96
- edsl/questions/question_multiple_choice.py +219 -51
- edsl/questions/question_numerical.py +361 -32
- edsl/questions/question_rank.py +401 -124
- edsl/questions/question_registry.py +7 -5
- edsl/questions/{derived/question_top_k.py → question_top_k.py} +3 -3
- edsl/questions/{derived/question_yes_no.py → question_yes_no.py} +3 -4
- edsl/questions/register_questions_meta.py +2 -2
- edsl/questions/response_validator_abc.py +13 -15
- edsl/questions/response_validator_factory.py +10 -12
- edsl/questions/templates/dict/answering_instructions.jinja +1 -0
- edsl/questions/templates/rank/question_presentation.jinja +1 -1
- edsl/results/__init__.py +1 -1
- edsl/results/exceptions.py +141 -7
- edsl/results/report.py +1 -2
- edsl/results/result.py +11 -9
- edsl/results/results.py +480 -321
- edsl/results/results_selector.py +8 -4
- edsl/scenarios/PdfExtractor.py +2 -2
- edsl/scenarios/construct_download_link.py +69 -35
- edsl/scenarios/directory_scanner.py +33 -14
- edsl/scenarios/document_chunker.py +1 -1
- edsl/scenarios/exceptions.py +238 -14
- edsl/scenarios/file_methods.py +1 -1
- edsl/scenarios/file_store.py +7 -3
- edsl/scenarios/handlers/__init__.py +17 -0
- edsl/scenarios/handlers/docx_file_store.py +0 -5
- edsl/scenarios/handlers/pdf_file_store.py +0 -1
- edsl/scenarios/handlers/pptx_file_store.py +0 -5
- edsl/scenarios/handlers/py_file_store.py +0 -1
- edsl/scenarios/handlers/sql_file_store.py +1 -4
- edsl/scenarios/handlers/sqlite_file_store.py +0 -1
- edsl/scenarios/handlers/txt_file_store.py +1 -1
- edsl/scenarios/scenario.py +1 -3
- edsl/scenarios/scenario_list.py +179 -27
- edsl/scenarios/scenario_list_pdf_tools.py +1 -0
- edsl/scenarios/scenario_selector.py +0 -1
- edsl/surveys/__init__.py +3 -4
- edsl/surveys/dag/__init__.py +4 -2
- edsl/surveys/descriptors.py +1 -1
- edsl/surveys/edit_survey.py +1 -0
- edsl/surveys/exceptions.py +165 -9
- edsl/surveys/memory/__init__.py +5 -3
- edsl/surveys/memory/memory_management.py +1 -0
- edsl/surveys/memory/memory_plan.py +6 -15
- edsl/surveys/rules/__init__.py +5 -3
- edsl/surveys/rules/rule.py +1 -2
- edsl/surveys/rules/rule_collection.py +1 -1
- edsl/surveys/survey.py +12 -24
- edsl/surveys/survey_css.py +3 -3
- edsl/surveys/survey_export.py +6 -3
- edsl/surveys/survey_flow_visualization.py +10 -1
- edsl/surveys/survey_simulator.py +2 -1
- edsl/tasks/__init__.py +23 -1
- edsl/tasks/exceptions.py +72 -0
- edsl/tasks/question_task_creator.py +3 -3
- edsl/tasks/task_creators.py +1 -3
- edsl/tasks/task_history.py +8 -10
- edsl/tasks/task_status_log.py +1 -2
- edsl/tokens/__init__.py +29 -1
- edsl/tokens/exceptions.py +37 -0
- edsl/tokens/interview_token_usage.py +3 -2
- edsl/tokens/token_usage.py +4 -3
- edsl/utilities/__init__.py +21 -1
- edsl/utilities/decorators.py +1 -2
- edsl/utilities/markdown_to_docx.py +2 -2
- edsl/utilities/markdown_to_pdf.py +1 -1
- edsl/utilities/repair_functions.py +0 -1
- edsl/utilities/restricted_python.py +0 -1
- edsl/utilities/template_loader.py +2 -3
- edsl/utilities/utilities.py +8 -29
- {edsl-0.1.49.dist-info → edsl-0.1.51.dist-info}/METADATA +32 -2
- edsl-0.1.51.dist-info/RECORD +365 -0
- edsl-0.1.51.dist-info/entry_points.txt +3 -0
- edsl/dataset/smart_objects.py +0 -96
- edsl/exceptions/BaseException.py +0 -21
- edsl/exceptions/__init__.py +0 -54
- edsl/exceptions/configuration.py +0 -16
- edsl/exceptions/general.py +0 -34
- edsl/questions/derived/__init__.py +0 -0
- edsl/study/ObjectEntry.py +0 -173
- edsl/study/ProofOfWork.py +0 -113
- edsl/study/SnapShot.py +0 -80
- edsl/study/Study.py +0 -520
- edsl/study/__init__.py +0 -6
- edsl/utilities/interface.py +0 -135
- edsl-0.1.49.dist-info/RECORD +0 -347
- {edsl-0.1.49.dist-info → edsl-0.1.51.dist-info}/LICENSE +0 -0
- {edsl-0.1.49.dist-info → edsl-0.1.51.dist-info}/WHEEL +0 -0
edsl/dataset/dataset.py
CHANGED
@@ -1,24 +1,25 @@
|
|
1
|
-
|
2
|
-
|
3
1
|
from __future__ import annotations
|
4
2
|
import sys
|
5
3
|
import json
|
6
4
|
import random
|
7
5
|
from collections import UserList
|
8
|
-
from typing import Any, Union, Optional, TYPE_CHECKING
|
6
|
+
from typing import Any, Union, Optional, TYPE_CHECKING, Callable
|
9
7
|
|
10
8
|
from ..base import PersistenceMixin, HashingMixin
|
11
9
|
|
12
10
|
from .dataset_tree import Tree
|
11
|
+
from .exceptions import DatasetKeyError, DatasetValueError, DatasetTypeError
|
12
|
+
|
13
13
|
|
14
14
|
from .display.table_display import TableDisplay
|
15
|
-
from .smart_objects import FirstObject
|
16
|
-
from .r.ggplot import GGPlotMethod
|
15
|
+
#from .smart_objects import FirstObject
|
17
16
|
from .dataset_operations_mixin import DatasetOperationsMixin
|
18
17
|
|
19
18
|
if TYPE_CHECKING:
|
20
19
|
from ..surveys import Survey
|
21
|
-
from ..questions
|
20
|
+
from ..questions import QuestionBase
|
21
|
+
from ..jobs import Job # noqa: F401
|
22
|
+
|
22
23
|
|
23
24
|
class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
24
25
|
"""
|
@@ -76,6 +77,7 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
76
77
|
Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible']}])
|
77
78
|
"""
|
78
79
|
super().__init__(data)
|
80
|
+
#self.data = data
|
79
81
|
self.print_parameters = print_parameters
|
80
82
|
|
81
83
|
|
@@ -118,19 +120,9 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
118
120
|
new_data.append({key: values[:n]})
|
119
121
|
return Dataset(new_data)
|
120
122
|
|
121
|
-
def expand(self, field):
|
122
|
-
|
123
|
+
# def expand(self, field):
|
124
|
+
# return self.to_scenario_list().expand(field)
|
123
125
|
|
124
|
-
def view(self):
|
125
|
-
from perspective.widget import PerspectiveWidget
|
126
|
-
|
127
|
-
w = PerspectiveWidget(
|
128
|
-
self.to_pandas(),
|
129
|
-
plugin="Datagrid",
|
130
|
-
aggregates={"datetime": "any"},
|
131
|
-
sort=[["date", "desc"]],
|
132
|
-
)
|
133
|
-
return w
|
134
126
|
|
135
127
|
def keys(self) -> list[str]:
|
136
128
|
"""Return the keys of the dataset.
|
@@ -212,7 +204,7 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
212
204
|
values = value_dict["value"]
|
213
205
|
|
214
206
|
if not (len(rows) == len(keys) == len(values)):
|
215
|
-
raise
|
207
|
+
raise DatasetValueError("All input arrays must have the same length")
|
216
208
|
|
217
209
|
# Get unique keys and row indices
|
218
210
|
unique_keys = sorted(set(keys))
|
@@ -272,12 +264,6 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
272
264
|
>>> d = Dataset([{'a.b':[1,2,3,4]}])
|
273
265
|
>>> d._key_to_value('a.b')
|
274
266
|
[1, 2, 3, 4]
|
275
|
-
|
276
|
-
>>> d._key_to_value('a')
|
277
|
-
Traceback (most recent call last):
|
278
|
-
...
|
279
|
-
KeyError: "Key 'a' not found in any of the dictionaries."
|
280
|
-
|
281
267
|
"""
|
282
268
|
potential_matches = []
|
283
269
|
for data_dict in self.data:
|
@@ -290,11 +276,13 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
290
276
|
if len(potential_matches) == 1:
|
291
277
|
return potential_matches[0][1]
|
292
278
|
elif len(potential_matches) > 1:
|
293
|
-
|
279
|
+
from .exceptions import DatasetKeyError
|
280
|
+
raise DatasetKeyError(
|
294
281
|
f"Key '{key}' found in more than one location: {[m[0] for m in potential_matches]}"
|
295
282
|
)
|
296
283
|
|
297
|
-
|
284
|
+
from .exceptions import DatasetKeyError
|
285
|
+
raise DatasetKeyError(f"Key '{key}' not found in any of the dictionaries.")
|
298
286
|
|
299
287
|
def first(self) -> dict[str, Any]:
|
300
288
|
"""Get the first value of the first key in the first dictionary.
|
@@ -308,7 +296,7 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
308
296
|
"""Get the values of the first key in the dictionary."""
|
309
297
|
return list(d.values())[0]
|
310
298
|
|
311
|
-
return
|
299
|
+
return get_values(self.data[0])[0]
|
312
300
|
|
313
301
|
def latex(self, **kwargs):
|
314
302
|
return self.table().latex()
|
@@ -338,7 +326,7 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
338
326
|
"""
|
339
327
|
if "format" in kwargs:
|
340
328
|
if kwargs["format"] not in ["html", "markdown", "rich", "latex"]:
|
341
|
-
raise
|
329
|
+
raise DatasetValueError(f"Format '{kwargs['format']}' not supported.")
|
342
330
|
|
343
331
|
# If rich format is requested, set tablefmt accordingly
|
344
332
|
if kwargs["format"] == "rich":
|
@@ -371,10 +359,18 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
371
359
|
merged_df = df1.merge(df2, how="left", left_on=by_x, right_on=by_y)
|
372
360
|
return Dataset.from_pandas_dataframe(merged_df)
|
373
361
|
|
374
|
-
def to(self, survey_or_question: Union["Survey", "QuestionBase"]) -> "
|
375
|
-
"""Return a new dataset with the observations transformed by the given survey or question.
|
376
|
-
|
377
|
-
|
362
|
+
def to(self, survey_or_question: Union["Survey", "QuestionBase"]) -> "Job":
|
363
|
+
"""Return a new dataset with the observations transformed by the given survey or question.
|
364
|
+
|
365
|
+
>>> d = Dataset([{'person_name':["John"]}])
|
366
|
+
>>> from edsl import QuestionFreeText
|
367
|
+
>>> q = QuestionFreeText(question_text = "How are you, {{ person_name ?}}?", question_name = "how_feeling")
|
368
|
+
>>> jobs = d.to(q)
|
369
|
+
>>> isinstance(jobs, object)
|
370
|
+
True
|
371
|
+
"""
|
372
|
+
from ..surveys import Survey
|
373
|
+
from ..questions import QuestionBase
|
378
374
|
|
379
375
|
if isinstance(survey_or_question, Survey):
|
380
376
|
return survey_or_question.by(self.to_scenario_list())
|
@@ -396,9 +392,10 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
396
392
|
"""
|
397
393
|
for key in keys:
|
398
394
|
if key not in self.keys():
|
399
|
-
|
400
|
-
|
401
|
-
|
395
|
+
from .exceptions import DatasetValueError
|
396
|
+
raise DatasetValueError(f"Key '{key}' not found in the dataset. "
|
397
|
+
f"Available keys: {self.keys()}"
|
398
|
+
)
|
402
399
|
|
403
400
|
if isinstance(keys, str):
|
404
401
|
keys = [keys]
|
@@ -442,7 +439,11 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
442
439
|
|
443
440
|
return self
|
444
441
|
|
445
|
-
def
|
442
|
+
def expand_field(self, field):
    """Expand a list-valued field by round-tripping through a scenario list.

    The dataset is converted with ``self.to_scenario_list()``, ``expand(field)``
    is called on the result, and the expanded object is converted back with
    ``to_dataset()``.

    This method is kept separate from ``Dataset.expand``, which explodes the
    column data in place without the scenario-list round trip; the distinct
    name keeps the two entry points from colliding.

    :param field: Name of the field to expand.
    :returns: Whatever ``to_dataset()`` returns for the expanded scenario list.
    """
    scenarios = self.to_scenario_list()
    return scenarios.expand(field).to_dataset()
|
447
448
|
|
448
449
|
def sample(
|
@@ -462,21 +463,18 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
462
463
|
>>> d = Dataset([{'a.b':[1,2,3,4]}])
|
463
464
|
>>> d.sample(n=2, seed=0, with_replacement=True)
|
464
465
|
Dataset([{'a.b': [4, 4]}])
|
465
|
-
|
466
|
-
>>> d.sample(n = 10, seed=0, with_replacement=False)
|
467
|
-
Traceback (most recent call last):
|
468
|
-
...
|
469
|
-
ValueError: Sample size cannot be greater than the number of available elements when sampling without replacement.
|
470
466
|
"""
|
471
467
|
if seed is not None:
|
472
468
|
random.seed(seed)
|
473
469
|
|
474
470
|
# Validate the input for sampling parameters
|
475
471
|
if n is None and frac is None:
|
476
|
-
|
472
|
+
from .exceptions import DatasetValueError
|
473
|
+
raise DatasetValueError("Either 'n' or 'frac' must be provided for sampling.")
|
477
474
|
|
478
475
|
if n is not None and frac is not None:
|
479
|
-
|
476
|
+
from .exceptions import DatasetValueError
|
477
|
+
raise DatasetValueError("Only one of 'n' or 'frac' should be specified.")
|
480
478
|
|
481
479
|
# Get the length of the lists from the first entry
|
482
480
|
first_key, first_values = list(self[0].items())[0]
|
@@ -487,7 +485,8 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
487
485
|
n = int(total_length * frac)
|
488
486
|
|
489
487
|
if not with_replacement and n > total_length:
|
490
|
-
|
488
|
+
from .exceptions import DatasetValueError
|
489
|
+
raise DatasetValueError(
|
491
490
|
"Sample size cannot be greater than the number of available elements when sampling without replacement."
|
492
491
|
)
|
493
492
|
|
@@ -504,60 +503,72 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
504
503
|
|
505
504
|
return self
|
506
505
|
|
507
|
-
def
|
508
|
-
"""Return a new dataset with the observations sorted by the given key.
|
509
|
-
|
510
|
-
:param sort_key: The key to sort the observations by.
|
511
|
-
:param reverse: Whether to sort in reverse order.
|
512
|
-
|
513
|
-
>>> d = Dataset([{'a':[1,2,3,4]}, {'b':[4,3,2,1]}])
|
514
|
-
>>> d.order_by('a')
|
515
|
-
Dataset([{'a': [1, 2, 3, 4]}, {'b': [4, 3, 2, 1]}])
|
516
|
-
|
517
|
-
>>> d.order_by('a', reverse=True)
|
518
|
-
Dataset([{'a': [4, 3, 2, 1]}, {'b': [1, 2, 3, 4]}])
|
519
|
-
|
520
|
-
>>> d = Dataset([{'X.a':[1,2,3,4]}, {'X.b':[4,3,2,1]}])
|
521
|
-
>>> d.order_by('a')
|
522
|
-
Dataset([{'X.a': [1, 2, 3, 4]}, {'X.b': [4, 3, 2, 1]}])
|
523
|
-
|
524
|
-
|
506
|
+
def get_sort_indices(self, lst: list[Any], reverse: bool = False, use_numpy: bool = True) -> list[int]:
    """
    Return the indices that would sort the list, using either numpy or pure Python.

    None values are never compared: they are always placed at the end, in their
    original order, for both ascending and descending sorts. Ties among non-None
    values keep their original relative order (stable), exactly as ``sorted``
    does, so both implementations return the same result.

    Args:
        lst: The list to be sorted
        reverse: Whether to sort in descending order
        use_numpy: Whether to use numpy implementation (falls back to pure Python if numpy is unavailable)

    Returns:
        A list of indices that would sort the list
    """
    # Split positions up front so that None never takes part in a comparison.
    value_positions = [i for i, v in enumerate(lst) if v is not None]
    none_positions = [i for i, v in enumerate(lst) if v is None]

    ordered = None
    if use_numpy and value_positions:
        try:
            import numpy as np
        except ImportError:
            np = None  # fall back to the pure Python path below
        if np is not None:
            # Fill element by element: np.array(values, dtype=object) can build a
            # 2-D array (e.g. for equal-length list values) and sort the wrong axis.
            values = np.empty(len(value_positions), dtype=object)
            for k, pos in enumerate(value_positions):
                values[k] = lst[pos]
            if reverse:
                # Stable argsort of the reversed array, flipped and mapped back,
                # gives descending order with ties kept in original order --
                # the same result as sorted(..., reverse=True).
                last = len(values) - 1
                order = last - np.argsort(values[::-1], kind="stable")[::-1]
            else:
                order = np.argsort(values, kind="stable")
            ordered = [value_positions[k] for k in order.tolist()]

    if ordered is None:
        # Pure Python implementation (also used when numpy is unavailable).
        ordered = sorted(value_positions, key=lambda i: lst[i], reverse=reverse)

    return ordered + none_positions
|
548
|
+
|
549
|
+
def order_by(self, sort_key: str, reverse: bool = False, use_numpy: bool = True) -> Dataset:
|
550
|
+
"""Return a new dataset with the observations sorted by the given key.
|
539
551
|
|
552
|
+
Args:
|
553
|
+
sort_key: The key to sort the observations by
|
554
|
+
reverse: Whether to sort in reverse order
|
555
|
+
use_numpy: Whether to use numpy for sorting (faster for large lists)
|
556
|
+
"""
|
540
557
|
number_found = 0
|
541
558
|
for obs in self.data:
|
542
559
|
key, values = list(obs.items())[0]
|
543
|
-
|
544
|
-
# key = list(obs.keys())[0]
|
545
|
-
if (
|
546
|
-
sort_key == key or sort_key == key.split(".")[-1]
|
547
|
-
): # e.g., "age" in "scenario.age"
|
560
|
+
if sort_key == key or sort_key == key.split(".")[-1]:
|
548
561
|
relevant_values = values
|
549
562
|
number_found += 1
|
550
563
|
|
551
564
|
if number_found == 0:
|
552
|
-
raise
|
565
|
+
raise DatasetKeyError(f"Key '{sort_key}' not found in any of the dictionaries.")
|
553
566
|
elif number_found > 1:
|
554
|
-
raise
|
567
|
+
raise DatasetKeyError(f"Key '{sort_key}' found in more than one dictionary.")
|
555
568
|
|
556
|
-
|
557
|
-
sort_indices_list = sort_indices(relevant_values)
|
569
|
+
sort_indices_list = self.get_sort_indices(relevant_values, reverse=reverse, use_numpy=use_numpy)
|
558
570
|
new_data = []
|
559
571
|
for observation in self.data:
|
560
|
-
# print(observation)
|
561
572
|
key, values = list(observation.items())[0]
|
562
573
|
new_values = [values[i] for i in sort_indices_list]
|
563
574
|
new_data.append({key: new_values})
|
@@ -578,7 +589,7 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
578
589
|
def table(
|
579
590
|
self,
|
580
591
|
*fields,
|
581
|
-
tablefmt: Optional[str] =
|
592
|
+
tablefmt: Optional[str] = "rich",
|
582
593
|
max_rows: Optional[int] = None,
|
583
594
|
pretty_labels=None,
|
584
595
|
print_parameters: Optional[dict] = None,
|
@@ -637,7 +648,8 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
637
648
|
|
638
649
|
if max_rows is not None:
|
639
650
|
if max_rows > len(data):
|
640
|
-
|
651
|
+
from .exceptions import DatasetValueError
|
652
|
+
raise DatasetValueError(
|
641
653
|
"max_rows cannot be greater than the number of rows in the dataset."
|
642
654
|
)
|
643
655
|
last_line = data[-1]
|
@@ -675,6 +687,19 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
675
687
|
def from_pandas_dataframe(cls, df):
|
676
688
|
result = cls([{col: df[col].tolist()} for col in df.columns])
|
677
689
|
return result
|
690
|
+
|
691
|
+
def to_dict(self) -> dict:
    """
    Serialize the dataset as a plain dictionary.

    Returns:
        A mapping with a single ``'data'`` key whose value is the dataset's
        underlying ``data`` list. The list is returned by reference, not copied.
    """
    serialized = dict(data=self.data)
    return serialized
|
696
|
+
|
697
|
+
@classmethod
def from_dict(cls, data: dict) -> 'Dataset':
    """
    Rebuild a dataset from a mapping such as the one ``to_dict`` produces.

    Args:
        data: A mapping whose ``'data'`` entry is passed to the constructor

    Returns:
        A new instance of ``cls`` built from ``data['data']``

    Raises:
        KeyError: If ``data`` has no ``'data'`` key
    """
    rows = data["data"]
    return cls(rows)
|
678
703
|
|
679
704
|
def to_docx(self, output_file: str, title: str = None) -> None:
|
680
705
|
"""
|
@@ -726,6 +751,72 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
|
|
726
751
|
# Save the document
|
727
752
|
doc.save(output_file)
|
728
753
|
|
754
|
+
def expand(self, field: str, number_field: bool = False) -> "Dataset":
    """
    Explode a list-valued column so that every list element becomes its own row.

    All other columns are repeated once per element of the matching list in
    ``field`` so the columns stay aligned; a row whose list is empty therefore
    produces no output rows.

    Args:
        field: Exact name of the column whose entries are lists
        number_field: If True, also append a ``"<field>_number"`` column holding
            the 1-based position of each element within its original list

    Returns:
        A new Dataset with the exploded columns

    Raises:
        DatasetKeyError: If no column is named ``field``
        DatasetTypeError: If any entry of ``field`` is not a list

    Example:
        >>> from edsl.dataset import Dataset
        >>> d = Dataset([{'a': [[1, 2, 3], [4, 5, 6]]}, {'b': ['x', 'y']}])
        >>> d.expand('a')
        Dataset([{'a': [1, 2, 3, 4, 5, 6]}, {'b': ['x', 'x', 'x', 'y', 'y', 'y']}])
    """
    from collections.abc import Iterable

    # The first column whose name equals `field` supplies the lists to explode.
    source_lists = None
    for column in self.data:
        column_name = list(column)[0]
        if column_name != field:
            continue
        source_lists = column[column_name]
        break

    if source_lists is None:
        raise DatasetKeyError(f"Field '{field}' not found in dataset. Available fields are: {self.keys()}")

    if any(not isinstance(cell, list) for cell in source_lists):
        raise DatasetTypeError(f"Field '{field}' must contain lists in all entries")

    exploded = []
    for column in self.data:
        column_name, cells = list(column.items())[0]
        if column_name == field:
            # Flatten the per-row lists into one long column.
            flattened = []
            for cell in cells:
                # Defensive: wrap scalars and strings so extend() adds them whole.
                if isinstance(cell, str) or not isinstance(cell, Iterable):
                    cell = [cell]
                flattened.extend(cell)
            exploded.append({column_name: flattened})
        else:
            # Repeat each value once per element of the matching source list.
            repeated = []
            for row_index, cell in enumerate(cells):
                times = len(source_lists[row_index]) if row_index < len(source_lists) else 0
                repeated.extend([cell] * times)
            exploded.append({column_name: repeated})

    if number_field:
        positions = [
            position
            for cell in source_lists
            for position in range(1, len(cell) + 1)
        ]
        exploded.append({f"{field}_number": positions})

    return Dataset(exploded)
|
819
|
+
|
729
820
|
|
730
821
|
if __name__ == "__main__":
|
731
822
|
import doctest
|