edsl 0.1.49__py3-none-any.whl → 0.1.51__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (257)
  1. edsl/__init__.py +124 -53
  2. edsl/__version__.py +1 -1
  3. edsl/agents/agent.py +21 -21
  4. edsl/agents/agent_list.py +2 -5
  5. edsl/agents/exceptions.py +119 -5
  6. edsl/base/__init__.py +10 -35
  7. edsl/base/base_class.py +71 -36
  8. edsl/base/base_exception.py +204 -0
  9. edsl/base/data_transfer_models.py +1 -1
  10. edsl/base/exceptions.py +94 -0
  11. edsl/buckets/__init__.py +15 -1
  12. edsl/buckets/bucket_collection.py +3 -4
  13. edsl/buckets/exceptions.py +107 -0
  14. edsl/buckets/model_buckets.py +1 -2
  15. edsl/buckets/token_bucket.py +11 -6
  16. edsl/buckets/token_bucket_api.py +27 -12
  17. edsl/buckets/token_bucket_client.py +9 -7
  18. edsl/caching/cache.py +12 -4
  19. edsl/caching/cache_entry.py +10 -9
  20. edsl/caching/exceptions.py +113 -7
  21. edsl/caching/remote_cache_sync.py +6 -7
  22. edsl/caching/sql_dict.py +20 -14
  23. edsl/cli.py +43 -0
  24. edsl/config/__init__.py +1 -1
  25. edsl/config/config_class.py +32 -6
  26. edsl/conversation/Conversation.py +8 -4
  27. edsl/conversation/car_buying.py +1 -3
  28. edsl/conversation/exceptions.py +58 -0
  29. edsl/conversation/mug_negotiation.py +2 -8
  30. edsl/coop/__init__.py +28 -6
  31. edsl/coop/coop.py +120 -29
  32. edsl/coop/coop_functions.py +1 -1
  33. edsl/coop/ep_key_handling.py +1 -1
  34. edsl/coop/exceptions.py +188 -9
  35. edsl/coop/price_fetcher.py +5 -8
  36. edsl/coop/utils.py +4 -6
  37. edsl/dataset/__init__.py +5 -4
  38. edsl/dataset/dataset.py +177 -86
  39. edsl/dataset/dataset_operations_mixin.py +98 -76
  40. edsl/dataset/dataset_tree.py +11 -7
  41. edsl/dataset/display/table_display.py +0 -2
  42. edsl/dataset/display/table_renderers.py +6 -4
  43. edsl/dataset/exceptions.py +125 -0
  44. edsl/dataset/file_exports.py +18 -11
  45. edsl/dataset/r/ggplot.py +13 -6
  46. edsl/display/__init__.py +27 -0
  47. edsl/display/core.py +147 -0
  48. edsl/display/plugin.py +189 -0
  49. edsl/display/utils.py +52 -0
  50. edsl/inference_services/__init__.py +9 -1
  51. edsl/inference_services/available_model_cache_handler.py +1 -1
  52. edsl/inference_services/available_model_fetcher.py +5 -6
  53. edsl/inference_services/data_structures.py +10 -7
  54. edsl/inference_services/exceptions.py +132 -1
  55. edsl/inference_services/inference_service_abc.py +2 -2
  56. edsl/inference_services/inference_services_collection.py +2 -6
  57. edsl/inference_services/registry.py +4 -3
  58. edsl/inference_services/service_availability.py +4 -3
  59. edsl/inference_services/services/anthropic_service.py +4 -1
  60. edsl/inference_services/services/aws_bedrock.py +13 -12
  61. edsl/inference_services/services/azure_ai.py +12 -10
  62. edsl/inference_services/services/deep_infra_service.py +1 -4
  63. edsl/inference_services/services/deep_seek_service.py +1 -5
  64. edsl/inference_services/services/google_service.py +7 -3
  65. edsl/inference_services/services/groq_service.py +1 -1
  66. edsl/inference_services/services/mistral_ai_service.py +4 -2
  67. edsl/inference_services/services/ollama_service.py +1 -1
  68. edsl/inference_services/services/open_ai_service.py +7 -5
  69. edsl/inference_services/services/perplexity_service.py +6 -2
  70. edsl/inference_services/services/test_service.py +8 -7
  71. edsl/inference_services/services/together_ai_service.py +2 -3
  72. edsl/inference_services/services/xai_service.py +1 -1
  73. edsl/instructions/__init__.py +1 -1
  74. edsl/instructions/change_instruction.py +7 -5
  75. edsl/instructions/exceptions.py +61 -0
  76. edsl/instructions/instruction.py +6 -2
  77. edsl/instructions/instruction_collection.py +6 -4
  78. edsl/instructions/instruction_handler.py +12 -15
  79. edsl/interviews/ReportErrors.py +0 -3
  80. edsl/interviews/__init__.py +9 -2
  81. edsl/interviews/answering_function.py +11 -13
  82. edsl/interviews/exception_tracking.py +15 -8
  83. edsl/interviews/exceptions.py +79 -0
  84. edsl/interviews/interview.py +33 -30
  85. edsl/interviews/interview_status_dictionary.py +4 -2
  86. edsl/interviews/interview_status_log.py +2 -1
  87. edsl/interviews/interview_task_manager.py +5 -5
  88. edsl/interviews/request_token_estimator.py +5 -2
  89. edsl/interviews/statistics.py +3 -4
  90. edsl/invigilators/__init__.py +7 -1
  91. edsl/invigilators/exceptions.py +79 -0
  92. edsl/invigilators/invigilator_base.py +0 -1
  93. edsl/invigilators/invigilators.py +9 -13
  94. edsl/invigilators/prompt_constructor.py +1 -5
  95. edsl/invigilators/prompt_helpers.py +8 -4
  96. edsl/invigilators/question_instructions_prompt_builder.py +1 -1
  97. edsl/invigilators/question_option_processor.py +9 -5
  98. edsl/invigilators/question_template_replacements_builder.py +3 -2
  99. edsl/jobs/__init__.py +42 -5
  100. edsl/jobs/async_interview_runner.py +25 -23
  101. edsl/jobs/check_survey_scenario_compatibility.py +11 -10
  102. edsl/jobs/data_structures.py +8 -5
  103. edsl/jobs/exceptions.py +177 -8
  104. edsl/jobs/fetch_invigilator.py +1 -1
  105. edsl/jobs/jobs.py +74 -69
  106. edsl/jobs/jobs_checks.py +6 -7
  107. edsl/jobs/jobs_component_constructor.py +4 -4
  108. edsl/jobs/jobs_pricing_estimation.py +4 -3
  109. edsl/jobs/jobs_remote_inference_logger.py +5 -4
  110. edsl/jobs/jobs_runner_asyncio.py +3 -4
  111. edsl/jobs/jobs_runner_status.py +8 -9
  112. edsl/jobs/remote_inference.py +27 -24
  113. edsl/jobs/results_exceptions_handler.py +10 -7
  114. edsl/key_management/__init__.py +3 -1
  115. edsl/key_management/exceptions.py +62 -0
  116. edsl/key_management/key_lookup.py +1 -1
  117. edsl/key_management/key_lookup_builder.py +37 -14
  118. edsl/key_management/key_lookup_collection.py +2 -0
  119. edsl/language_models/__init__.py +1 -1
  120. edsl/language_models/exceptions.py +302 -14
  121. edsl/language_models/language_model.py +9 -8
  122. edsl/language_models/model.py +4 -4
  123. edsl/language_models/model_list.py +1 -1
  124. edsl/language_models/price_manager.py +1 -1
  125. edsl/language_models/raw_response_handler.py +14 -9
  126. edsl/language_models/registry.py +17 -21
  127. edsl/language_models/repair.py +0 -6
  128. edsl/language_models/unused/fake_openai_service.py +0 -1
  129. edsl/load_plugins.py +69 -0
  130. edsl/logger.py +146 -0
  131. edsl/notebooks/__init__.py +24 -1
  132. edsl/notebooks/exceptions.py +82 -0
  133. edsl/notebooks/notebook.py +7 -3
  134. edsl/notebooks/notebook_to_latex.py +1 -2
  135. edsl/plugins/__init__.py +63 -0
  136. edsl/plugins/built_in/export_example.py +50 -0
  137. edsl/plugins/built_in/pig_latin.py +67 -0
  138. edsl/plugins/cli.py +372 -0
  139. edsl/plugins/cli_typer.py +283 -0
  140. edsl/plugins/exceptions.py +31 -0
  141. edsl/plugins/hookspec.py +51 -0
  142. edsl/plugins/plugin_host.py +128 -0
  143. edsl/plugins/plugin_manager.py +633 -0
  144. edsl/plugins/plugins_registry.py +168 -0
  145. edsl/prompts/__init__.py +24 -1
  146. edsl/prompts/exceptions.py +107 -5
  147. edsl/prompts/prompt.py +15 -7
  148. edsl/questions/HTMLQuestion.py +5 -11
  149. edsl/questions/Quick.py +0 -1
  150. edsl/questions/__init__.py +6 -4
  151. edsl/questions/answer_validator_mixin.py +318 -323
  152. edsl/questions/compose_questions.py +3 -3
  153. edsl/questions/descriptors.py +11 -50
  154. edsl/questions/exceptions.py +278 -22
  155. edsl/questions/loop_processor.py +7 -5
  156. edsl/questions/prompt_templates/question_list.jinja +3 -0
  157. edsl/questions/question_base.py +46 -19
  158. edsl/questions/question_base_gen_mixin.py +2 -2
  159. edsl/questions/question_base_prompts_mixin.py +13 -7
  160. edsl/questions/question_budget.py +503 -98
  161. edsl/questions/question_check_box.py +660 -160
  162. edsl/questions/question_dict.py +345 -194
  163. edsl/questions/question_extract.py +401 -61
  164. edsl/questions/question_free_text.py +80 -14
  165. edsl/questions/question_functional.py +119 -9
  166. edsl/questions/{derived/question_likert_five.py → question_likert_five.py} +2 -2
  167. edsl/questions/{derived/question_linear_scale.py → question_linear_scale.py} +3 -4
  168. edsl/questions/question_list.py +275 -28
  169. edsl/questions/question_matrix.py +643 -96
  170. edsl/questions/question_multiple_choice.py +219 -51
  171. edsl/questions/question_numerical.py +361 -32
  172. edsl/questions/question_rank.py +401 -124
  173. edsl/questions/question_registry.py +7 -5
  174. edsl/questions/{derived/question_top_k.py → question_top_k.py} +3 -3
  175. edsl/questions/{derived/question_yes_no.py → question_yes_no.py} +3 -4
  176. edsl/questions/register_questions_meta.py +2 -2
  177. edsl/questions/response_validator_abc.py +13 -15
  178. edsl/questions/response_validator_factory.py +10 -12
  179. edsl/questions/templates/dict/answering_instructions.jinja +1 -0
  180. edsl/questions/templates/rank/question_presentation.jinja +1 -1
  181. edsl/results/__init__.py +1 -1
  182. edsl/results/exceptions.py +141 -7
  183. edsl/results/report.py +1 -2
  184. edsl/results/result.py +11 -9
  185. edsl/results/results.py +480 -321
  186. edsl/results/results_selector.py +8 -4
  187. edsl/scenarios/PdfExtractor.py +2 -2
  188. edsl/scenarios/construct_download_link.py +69 -35
  189. edsl/scenarios/directory_scanner.py +33 -14
  190. edsl/scenarios/document_chunker.py +1 -1
  191. edsl/scenarios/exceptions.py +238 -14
  192. edsl/scenarios/file_methods.py +1 -1
  193. edsl/scenarios/file_store.py +7 -3
  194. edsl/scenarios/handlers/__init__.py +17 -0
  195. edsl/scenarios/handlers/docx_file_store.py +0 -5
  196. edsl/scenarios/handlers/pdf_file_store.py +0 -1
  197. edsl/scenarios/handlers/pptx_file_store.py +0 -5
  198. edsl/scenarios/handlers/py_file_store.py +0 -1
  199. edsl/scenarios/handlers/sql_file_store.py +1 -4
  200. edsl/scenarios/handlers/sqlite_file_store.py +0 -1
  201. edsl/scenarios/handlers/txt_file_store.py +1 -1
  202. edsl/scenarios/scenario.py +1 -3
  203. edsl/scenarios/scenario_list.py +179 -27
  204. edsl/scenarios/scenario_list_pdf_tools.py +1 -0
  205. edsl/scenarios/scenario_selector.py +0 -1
  206. edsl/surveys/__init__.py +3 -4
  207. edsl/surveys/dag/__init__.py +4 -2
  208. edsl/surveys/descriptors.py +1 -1
  209. edsl/surveys/edit_survey.py +1 -0
  210. edsl/surveys/exceptions.py +165 -9
  211. edsl/surveys/memory/__init__.py +5 -3
  212. edsl/surveys/memory/memory_management.py +1 -0
  213. edsl/surveys/memory/memory_plan.py +6 -15
  214. edsl/surveys/rules/__init__.py +5 -3
  215. edsl/surveys/rules/rule.py +1 -2
  216. edsl/surveys/rules/rule_collection.py +1 -1
  217. edsl/surveys/survey.py +12 -24
  218. edsl/surveys/survey_css.py +3 -3
  219. edsl/surveys/survey_export.py +6 -3
  220. edsl/surveys/survey_flow_visualization.py +10 -1
  221. edsl/surveys/survey_simulator.py +2 -1
  222. edsl/tasks/__init__.py +23 -1
  223. edsl/tasks/exceptions.py +72 -0
  224. edsl/tasks/question_task_creator.py +3 -3
  225. edsl/tasks/task_creators.py +1 -3
  226. edsl/tasks/task_history.py +8 -10
  227. edsl/tasks/task_status_log.py +1 -2
  228. edsl/tokens/__init__.py +29 -1
  229. edsl/tokens/exceptions.py +37 -0
  230. edsl/tokens/interview_token_usage.py +3 -2
  231. edsl/tokens/token_usage.py +4 -3
  232. edsl/utilities/__init__.py +21 -1
  233. edsl/utilities/decorators.py +1 -2
  234. edsl/utilities/markdown_to_docx.py +2 -2
  235. edsl/utilities/markdown_to_pdf.py +1 -1
  236. edsl/utilities/repair_functions.py +0 -1
  237. edsl/utilities/restricted_python.py +0 -1
  238. edsl/utilities/template_loader.py +2 -3
  239. edsl/utilities/utilities.py +8 -29
  240. {edsl-0.1.49.dist-info → edsl-0.1.51.dist-info}/METADATA +32 -2
  241. edsl-0.1.51.dist-info/RECORD +365 -0
  242. edsl-0.1.51.dist-info/entry_points.txt +3 -0
  243. edsl/dataset/smart_objects.py +0 -96
  244. edsl/exceptions/BaseException.py +0 -21
  245. edsl/exceptions/__init__.py +0 -54
  246. edsl/exceptions/configuration.py +0 -16
  247. edsl/exceptions/general.py +0 -34
  248. edsl/questions/derived/__init__.py +0 -0
  249. edsl/study/ObjectEntry.py +0 -173
  250. edsl/study/ProofOfWork.py +0 -113
  251. edsl/study/SnapShot.py +0 -80
  252. edsl/study/Study.py +0 -520
  253. edsl/study/__init__.py +0 -6
  254. edsl/utilities/interface.py +0 -135
  255. edsl-0.1.49.dist-info/RECORD +0 -347
  256. {edsl-0.1.49.dist-info → edsl-0.1.51.dist-info}/LICENSE +0 -0
  257. {edsl-0.1.49.dist-info → edsl-0.1.51.dist-info}/WHEEL +0 -0
edsl/dataset/dataset.py CHANGED
@@ -1,24 +1,25 @@
1
-
2
-
3
1
  from __future__ import annotations
4
2
  import sys
5
3
  import json
6
4
  import random
7
5
  from collections import UserList
8
- from typing import Any, Union, Optional, TYPE_CHECKING
6
+ from typing import Any, Union, Optional, TYPE_CHECKING, Callable
9
7
 
10
8
  from ..base import PersistenceMixin, HashingMixin
11
9
 
12
10
  from .dataset_tree import Tree
11
+ from .exceptions import DatasetKeyError, DatasetValueError, DatasetTypeError
12
+
13
13
 
14
14
  from .display.table_display import TableDisplay
15
- from .smart_objects import FirstObject
16
- from .r.ggplot import GGPlotMethod
15
+ #from .smart_objects import FirstObject
17
16
  from .dataset_operations_mixin import DatasetOperationsMixin
18
17
 
19
18
  if TYPE_CHECKING:
20
19
  from ..surveys import Survey
21
- from ..questions.QuestionBase import QuestionBase
20
+ from ..questions import QuestionBase
21
+ from ..jobs import Job # noqa: F401
22
+
22
23
 
23
24
  class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
24
25
  """
@@ -76,6 +77,7 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
76
77
  Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible']}])
77
78
  """
78
79
  super().__init__(data)
80
+ #self.data = data
79
81
  self.print_parameters = print_parameters
80
82
 
81
83
 
@@ -118,19 +120,9 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
118
120
  new_data.append({key: values[:n]})
119
121
  return Dataset(new_data)
120
122
 
121
- def expand(self, field):
122
- return self.to_scenario_list().expand(field)
123
+ # def expand(self, field):
124
+ # return self.to_scenario_list().expand(field)
123
125
 
124
- def view(self):
125
- from perspective.widget import PerspectiveWidget
126
-
127
- w = PerspectiveWidget(
128
- self.to_pandas(),
129
- plugin="Datagrid",
130
- aggregates={"datetime": "any"},
131
- sort=[["date", "desc"]],
132
- )
133
- return w
134
126
 
135
127
  def keys(self) -> list[str]:
136
128
  """Return the keys of the dataset.
@@ -212,7 +204,7 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
212
204
  values = value_dict["value"]
213
205
 
214
206
  if not (len(rows) == len(keys) == len(values)):
215
- raise ValueError("All input arrays must have the same length")
207
+ raise DatasetValueError("All input arrays must have the same length")
216
208
 
217
209
  # Get unique keys and row indices
218
210
  unique_keys = sorted(set(keys))
@@ -272,12 +264,6 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
272
264
  >>> d = Dataset([{'a.b':[1,2,3,4]}])
273
265
  >>> d._key_to_value('a.b')
274
266
  [1, 2, 3, 4]
275
-
276
- >>> d._key_to_value('a')
277
- Traceback (most recent call last):
278
- ...
279
- KeyError: "Key 'a' not found in any of the dictionaries."
280
-
281
267
  """
282
268
  potential_matches = []
283
269
  for data_dict in self.data:
@@ -290,11 +276,13 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
290
276
  if len(potential_matches) == 1:
291
277
  return potential_matches[0][1]
292
278
  elif len(potential_matches) > 1:
293
- raise KeyError(
279
+ from .exceptions import DatasetKeyError
280
+ raise DatasetKeyError(
294
281
  f"Key '{key}' found in more than one location: {[m[0] for m in potential_matches]}"
295
282
  )
296
283
 
297
- raise KeyError(f"Key '{key}' not found in any of the dictionaries.")
284
+ from .exceptions import DatasetKeyError
285
+ raise DatasetKeyError(f"Key '{key}' not found in any of the dictionaries.")
298
286
 
299
287
  def first(self) -> dict[str, Any]:
300
288
  """Get the first value of the first key in the first dictionary.
@@ -308,7 +296,7 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
308
296
  """Get the values of the first key in the dictionary."""
309
297
  return list(d.values())[0]
310
298
 
311
- return FirstObject(get_values(self.data[0])[0])
299
+ return get_values(self.data[0])[0]
312
300
 
313
301
  def latex(self, **kwargs):
314
302
  return self.table().latex()
@@ -338,7 +326,7 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
338
326
  """
339
327
  if "format" in kwargs:
340
328
  if kwargs["format"] not in ["html", "markdown", "rich", "latex"]:
341
- raise ValueError(f"Format '{kwargs['format']}' not supported.")
329
+ raise DatasetValueError(f"Format '{kwargs['format']}' not supported.")
342
330
 
343
331
  # If rich format is requested, set tablefmt accordingly
344
332
  if kwargs["format"] == "rich":
@@ -371,10 +359,18 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
371
359
  merged_df = df1.merge(df2, how="left", left_on=by_x, right_on=by_y)
372
360
  return Dataset.from_pandas_dataframe(merged_df)
373
361
 
374
- def to(self, survey_or_question: Union["Survey", "QuestionBase"]) -> "Jobs":
375
- """Return a new dataset with the observations transformed by the given survey or question."""
376
- from edsl.surveys import Survey
377
- from edsl.questions.QuestionBase import QuestionBase
362
+ def to(self, survey_or_question: Union["Survey", "QuestionBase"]) -> "Job":
363
+ """Return a new dataset with the observations transformed by the given survey or question.
364
+
365
+ >>> d = Dataset([{'person_name':["John"]}])
366
+ >>> from edsl import QuestionFreeText
367
+ >>> q = QuestionFreeText(question_text = "How are you, {{ person_name ?}}?", question_name = "how_feeling")
368
+ >>> jobs = d.to(q)
369
+ >>> isinstance(jobs, object)
370
+ True
371
+ """
372
+ from ..surveys import Survey
373
+ from ..questions import QuestionBase
378
374
 
379
375
  if isinstance(survey_or_question, Survey):
380
376
  return survey_or_question.by(self.to_scenario_list())
@@ -396,9 +392,10 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
396
392
  """
397
393
  for key in keys:
398
394
  if key not in self.keys():
399
- raise ValueError(f"Key '{key}' not found in the dataset."
400
- f"Available keys: {self.keys()}"
401
- )
395
+ from .exceptions import DatasetValueError
396
+ raise DatasetValueError(f"Key '{key}' not found in the dataset. "
397
+ f"Available keys: {self.keys()}"
398
+ )
402
399
 
403
400
  if isinstance(keys, str):
404
401
  keys = [keys]
@@ -442,7 +439,11 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
442
439
 
443
440
  return self
444
441
 
445
- def expand(self, field):
442
def expand_field(self, field):
    """Expand a list-valued field via the scenario-list representation.

    The dataset is converted with ``to_scenario_list()``, that object's
    ``expand(field)`` is applied, and the outcome is converted back with
    ``to_dataset()``.

    This method has a distinct name because the class also defines its own
    native ``expand(field, number_field=False)`` implementation.
    """
    scenarios = self.to_scenario_list()
    expanded_scenarios = scenarios.expand(field)
    return expanded_scenarios.to_dataset()
447
448
 
448
449
  def sample(
@@ -462,21 +463,18 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
462
463
  >>> d = Dataset([{'a.b':[1,2,3,4]}])
463
464
  >>> d.sample(n=2, seed=0, with_replacement=True)
464
465
  Dataset([{'a.b': [4, 4]}])
465
-
466
- >>> d.sample(n = 10, seed=0, with_replacement=False)
467
- Traceback (most recent call last):
468
- ...
469
- ValueError: Sample size cannot be greater than the number of available elements when sampling without replacement.
470
466
  """
471
467
  if seed is not None:
472
468
  random.seed(seed)
473
469
 
474
470
  # Validate the input for sampling parameters
475
471
  if n is None and frac is None:
476
- raise ValueError("Either 'n' or 'frac' must be provided for sampling.")
472
+ from .exceptions import DatasetValueError
473
+ raise DatasetValueError("Either 'n' or 'frac' must be provided for sampling.")
477
474
 
478
475
  if n is not None and frac is not None:
479
- raise ValueError("Only one of 'n' or 'frac' should be specified.")
476
+ from .exceptions import DatasetValueError
477
+ raise DatasetValueError("Only one of 'n' or 'frac' should be specified.")
480
478
 
481
479
  # Get the length of the lists from the first entry
482
480
  first_key, first_values = list(self[0].items())[0]
@@ -487,7 +485,8 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
487
485
  n = int(total_length * frac)
488
486
 
489
487
  if not with_replacement and n > total_length:
490
- raise ValueError(
488
+ from .exceptions import DatasetValueError
489
+ raise DatasetValueError(
491
490
  "Sample size cannot be greater than the number of available elements when sampling without replacement."
492
491
  )
493
492
 
@@ -504,60 +503,72 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
504
503
 
505
504
  return self
506
505
 
507
- def order_by(self, sort_key: str, reverse: bool = False) -> Dataset:
508
- """Return a new dataset with the observations sorted by the given key.
509
-
510
- :param sort_key: The key to sort the observations by.
511
- :param reverse: Whether to sort in reverse order.
512
-
513
- >>> d = Dataset([{'a':[1,2,3,4]}, {'b':[4,3,2,1]}])
514
- >>> d.order_by('a')
515
- Dataset([{'a': [1, 2, 3, 4]}, {'b': [4, 3, 2, 1]}])
516
-
517
- >>> d.order_by('a', reverse=True)
518
- Dataset([{'a': [4, 3, 2, 1]}, {'b': [1, 2, 3, 4]}])
519
-
520
- >>> d = Dataset([{'X.a':[1,2,3,4]}, {'X.b':[4,3,2,1]}])
521
- >>> d.order_by('a')
522
- Dataset([{'X.a': [1, 2, 3, 4]}, {'X.b': [4, 3, 2, 1]}])
523
-
524
-
506
+ def get_sort_indices(self, lst: list[Any], reverse: bool = False, use_numpy: bool = True) -> list[int]:
525
507
  """
526
- import numpy as np
508
+ Return the indices that would sort the list, using either numpy or pure Python.
509
+ None values are placed at the end of the sorted list.
527
510
 
528
- def sort_indices(lst: list[Any]) -> list[int]:
529
- """
530
- Return the indices that would sort the list.
511
+ Args:
512
+ lst: The list to be sorted
513
+ reverse: Whether to sort in descending order
514
+ use_numpy: Whether to use numpy implementation (falls back to pure Python if numpy is unavailable)
531
515
 
532
- :param lst: The list to be sorted.
533
- :return: A list of indices that would sort the list.
534
- """
535
- indices = np.argsort(lst).tolist()
536
- if reverse:
537
- indices.reverse()
538
- return indices
516
+ Returns:
517
+ A list of indices that would sort the list
518
+ """
519
+ if use_numpy:
520
+ try:
521
+ import numpy as np
522
+ # Convert list to numpy array
523
+ arr = np.array(lst, dtype=object)
524
+ # Get mask of non-None values
525
+ mask = ~(arr is None)
526
+ # Get indices of non-None and None values
527
+ non_none_indices = np.where(mask)[0]
528
+ none_indices = np.where(~mask)[0]
529
+ # Sort non-None values
530
+ sorted_indices = non_none_indices[np.argsort(arr[mask])]
531
+ # Combine sorted non-None indices with None indices
532
+ indices = np.concatenate([sorted_indices, none_indices]).tolist()
533
+ if reverse:
534
+ # When reversing, keep None values at end
535
+ indices = sorted_indices[::-1].tolist() + none_indices.tolist()
536
+ return indices
537
+ except ImportError:
538
+ # Fallback to pure Python if numpy is not available
539
+ pass
540
+
541
+ # Pure Python implementation
542
+ enumerated = list(enumerate(lst))
543
+ # Sort None values to end by using (is_none, value) as sort key
544
+ sorted_pairs = sorted(enumerated,
545
+ key=lambda x: (x[1] is None, x[1]),
546
+ reverse=reverse)
547
+ return [index for index, _ in sorted_pairs]
548
+
549
+ def order_by(self, sort_key: str, reverse: bool = False, use_numpy: bool = True) -> Dataset:
550
+ """Return a new dataset with the observations sorted by the given key.
539
551
 
552
+ Args:
553
+ sort_key: The key to sort the observations by
554
+ reverse: Whether to sort in reverse order
555
+ use_numpy: Whether to use numpy for sorting (faster for large lists)
556
+ """
540
557
  number_found = 0
541
558
  for obs in self.data:
542
559
  key, values = list(obs.items())[0]
543
- # an obseration is {'a':[1,2,3,4]}
544
- # key = list(obs.keys())[0]
545
- if (
546
- sort_key == key or sort_key == key.split(".")[-1]
547
- ): # e.g., "age" in "scenario.age"
560
+ if sort_key == key or sort_key == key.split(".")[-1]:
548
561
  relevant_values = values
549
562
  number_found += 1
550
563
 
551
564
  if number_found == 0:
552
- raise ValueError(f"Key '{sort_key}' not found in any of the dictionaries.")
565
+ raise DatasetKeyError(f"Key '{sort_key}' not found in any of the dictionaries.")
553
566
  elif number_found > 1:
554
- raise ValueError(f"Key '{sort_key}' found in more than one dictionary.")
567
+ raise DatasetKeyError(f"Key '{sort_key}' found in more than one dictionary.")
555
568
 
556
- # relevant_values = self._key_to_value(sort_key)
557
- sort_indices_list = sort_indices(relevant_values)
569
+ sort_indices_list = self.get_sort_indices(relevant_values, reverse=reverse, use_numpy=use_numpy)
558
570
  new_data = []
559
571
  for observation in self.data:
560
- # print(observation)
561
572
  key, values = list(observation.items())[0]
562
573
  new_values = [values[i] for i in sort_indices_list]
563
574
  new_data.append({key: new_values})
@@ -578,7 +589,7 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
578
589
  def table(
579
590
  self,
580
591
  *fields,
581
- tablefmt: Optional[str] = None,
592
+ tablefmt: Optional[str] = "rich",
582
593
  max_rows: Optional[int] = None,
583
594
  pretty_labels=None,
584
595
  print_parameters: Optional[dict] = None,
@@ -637,7 +648,8 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
637
648
 
638
649
  if max_rows is not None:
639
650
  if max_rows > len(data):
640
- raise ValueError(
651
+ from .exceptions import DatasetValueError
652
+ raise DatasetValueError(
641
653
  "max_rows cannot be greater than the number of rows in the dataset."
642
654
  )
643
655
  last_line = data[-1]
@@ -675,6 +687,19 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
675
687
  def from_pandas_dataframe(cls, df):
676
688
  result = cls([{col: df[col].tolist()} for col in df.columns])
677
689
  return result
690
+
691
def to_dict(self) -> dict:
    """
    Serialize the dataset into a plain dictionary.

    Returns:
        A dict with a single ``'data'`` key whose value is this dataset's
        underlying ``data`` list (the same object, not a copy).
    """
    serialized = {}
    serialized['data'] = self.data
    return serialized
696
+
697
@classmethod
def from_dict(cls, data: dict) -> 'Dataset':
    """
    Build a dataset from a dictionary shaped like the output of ``to_dict``.

    Args:
        data: Mapping whose ``'data'`` entry holds the list of column dicts.

    Returns:
        A new instance of ``cls`` constructed from ``data['data']``.

    Raises:
        KeyError: If ``data`` has no ``'data'`` key.
    """
    rows = data['data']
    return cls(rows)
678
703
 
679
704
  def to_docx(self, output_file: str, title: str = None) -> None:
680
705
  """
@@ -726,6 +751,72 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
726
751
  # Save the document
727
752
  doc.save(output_file)
728
753
 
754
def expand(self, field: str, number_field: bool = False) -> "Dataset":
    """
    Expand a field containing lists into multiple rows.

    Every element of each list in ``field`` becomes its own row. The values
    of all other columns are repeated once per element of the list in the
    same row, so a row whose list is empty disappears from the result.

    ``field`` may be the full column key (e.g. ``'answer.how_feeling'``) or,
    when it is unambiguous, just the last dotted component of the key
    (e.g. ``'how_feeling'``). This is the same matching rule ``order_by``
    uses. An exact key match always takes precedence.

    Args:
        field: The field containing lists to expand
        number_field: If True, adds a ``f"{field}_number"`` column giving the
            1-based position of each value within its original list

    Returns:
        A new Dataset with the expanded rows

    Raises:
        DatasetKeyError: If ``field`` matches no column, or its short form
            matches more than one column.
        DatasetTypeError: If any value of the field is not a list.

    Example:
        >>> from edsl.dataset import Dataset
        >>> d = Dataset([{'a': [[1, 2, 3], [4, 5, 6]]}, {'b': ['x', 'y']}])
        >>> d.expand('a')
        Dataset([{'a': [1, 2, 3, 4, 5, 6]}, {'b': ['x', 'x', 'x', 'y', 'y', 'y']}])
    """
    # Each entry is a single-key dict {column_key: column_values}.
    columns = [list(entry.items())[0] for entry in self.data]

    # Resolve the target column: exact key first, then a unique dotted suffix.
    exact = [pos for pos, (key, _) in enumerate(columns) if key == field]
    if exact:
        target_pos = exact[0]
    else:
        by_suffix = [
            pos for pos, (key, _) in enumerate(columns) if key.split(".")[-1] == field
        ]
        if not by_suffix:
            raise DatasetKeyError(
                f"Field '{field}' not found in dataset. Available fields are: {self.keys()}"
            )
        if len(by_suffix) > 1:
            raise DatasetKeyError(
                f"Field '{field}' is ambiguous; it matches {[columns[p][0] for p in by_suffix]}."
            )
        target_pos = by_suffix[0]

    field_data = columns[target_pos][1]

    # Validate that the field contains lists
    if not all(isinstance(v, list) for v in field_data):
        raise DatasetTypeError(f"Field '{field}' must contain lists in all entries")

    # Number of output rows produced by each input row.
    lengths = [len(sublist) for sublist in field_data]

    new_data = []
    for pos, (key, values) in enumerate(columns):
        if pos == target_pos:
            # The expanded column: flatten all sublists in order.
            new_values = [item for sublist in values for item in sublist]
        else:
            # Repeat each value once per element of its row's list. zip()
            # stops at the shorter sequence, so rows beyond the expanded
            # column's length contribute nothing.
            new_values = []
            for value, count in zip(values, lengths):
                new_values.extend([value] * count)
        new_data.append({key: new_values})

    # Add number field if requested
    if number_field:
        new_data.append(
            {f"{field}_number": [n for length in lengths for n in range(1, length + 1)]}
        )

    return Dataset(new_data)
819
+
729
820
 
730
821
  if __name__ == "__main__":
731
822
  import doctest