edsl 0.1.50__py3-none-any.whl → 0.1.52__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. edsl/__init__.py +45 -34
  2. edsl/__version__.py +1 -1
  3. edsl/base/base_exception.py +2 -2
  4. edsl/buckets/bucket_collection.py +1 -1
  5. edsl/buckets/exceptions.py +32 -0
  6. edsl/buckets/token_bucket_api.py +26 -10
  7. edsl/caching/cache.py +5 -2
  8. edsl/caching/remote_cache_sync.py +5 -5
  9. edsl/caching/sql_dict.py +12 -11
  10. edsl/config/__init__.py +1 -1
  11. edsl/config/config_class.py +4 -2
  12. edsl/conversation/Conversation.py +9 -5
  13. edsl/conversation/car_buying.py +1 -3
  14. edsl/conversation/mug_negotiation.py +2 -6
  15. edsl/coop/__init__.py +11 -8
  16. edsl/coop/coop.py +15 -13
  17. edsl/coop/coop_functions.py +1 -1
  18. edsl/coop/ep_key_handling.py +1 -1
  19. edsl/coop/price_fetcher.py +2 -2
  20. edsl/coop/utils.py +2 -2
  21. edsl/dataset/dataset.py +144 -63
  22. edsl/dataset/dataset_operations_mixin.py +14 -6
  23. edsl/dataset/dataset_tree.py +3 -3
  24. edsl/dataset/display/table_renderers.py +6 -3
  25. edsl/dataset/file_exports.py +4 -4
  26. edsl/dataset/r/ggplot.py +3 -3
  27. edsl/inference_services/available_model_fetcher.py +2 -2
  28. edsl/inference_services/data_structures.py +5 -5
  29. edsl/inference_services/inference_service_abc.py +1 -1
  30. edsl/inference_services/inference_services_collection.py +1 -1
  31. edsl/inference_services/service_availability.py +3 -3
  32. edsl/inference_services/services/azure_ai.py +3 -3
  33. edsl/inference_services/services/google_service.py +1 -1
  34. edsl/inference_services/services/test_service.py +1 -1
  35. edsl/instructions/change_instruction.py +5 -4
  36. edsl/instructions/instruction.py +1 -0
  37. edsl/instructions/instruction_collection.py +5 -4
  38. edsl/instructions/instruction_handler.py +10 -8
  39. edsl/interviews/answering_function.py +20 -21
  40. edsl/interviews/exception_tracking.py +3 -2
  41. edsl/interviews/interview.py +1 -1
  42. edsl/interviews/interview_status_dictionary.py +1 -1
  43. edsl/interviews/interview_task_manager.py +7 -4
  44. edsl/interviews/request_token_estimator.py +3 -2
  45. edsl/interviews/statistics.py +2 -2
  46. edsl/invigilators/invigilators.py +34 -6
  47. edsl/jobs/__init__.py +39 -2
  48. edsl/jobs/async_interview_runner.py +1 -1
  49. edsl/jobs/check_survey_scenario_compatibility.py +5 -5
  50. edsl/jobs/data_structures.py +2 -2
  51. edsl/jobs/html_table_job_logger.py +494 -257
  52. edsl/jobs/jobs.py +2 -2
  53. edsl/jobs/jobs_checks.py +5 -5
  54. edsl/jobs/jobs_component_constructor.py +2 -2
  55. edsl/jobs/jobs_pricing_estimation.py +1 -1
  56. edsl/jobs/jobs_runner_asyncio.py +2 -2
  57. edsl/jobs/jobs_status_enums.py +1 -0
  58. edsl/jobs/remote_inference.py +47 -13
  59. edsl/jobs/results_exceptions_handler.py +2 -2
  60. edsl/language_models/language_model.py +151 -145
  61. edsl/notebooks/__init__.py +24 -1
  62. edsl/notebooks/exceptions.py +82 -0
  63. edsl/notebooks/notebook.py +7 -3
  64. edsl/notebooks/notebook_to_latex.py +1 -1
  65. edsl/prompts/__init__.py +23 -2
  66. edsl/prompts/prompt.py +1 -1
  67. edsl/questions/__init__.py +4 -4
  68. edsl/questions/answer_validator_mixin.py +0 -5
  69. edsl/questions/compose_questions.py +2 -2
  70. edsl/questions/descriptors.py +1 -1
  71. edsl/questions/question_base.py +32 -3
  72. edsl/questions/question_base_prompts_mixin.py +4 -4
  73. edsl/questions/question_budget.py +503 -102
  74. edsl/questions/question_check_box.py +658 -156
  75. edsl/questions/question_dict.py +176 -2
  76. edsl/questions/question_extract.py +401 -61
  77. edsl/questions/question_free_text.py +77 -9
  78. edsl/questions/question_functional.py +118 -9
  79. edsl/questions/{derived/question_likert_five.py → question_likert_five.py} +2 -2
  80. edsl/questions/{derived/question_linear_scale.py → question_linear_scale.py} +3 -4
  81. edsl/questions/question_list.py +246 -26
  82. edsl/questions/question_matrix.py +586 -73
  83. edsl/questions/question_multiple_choice.py +213 -47
  84. edsl/questions/question_numerical.py +360 -29
  85. edsl/questions/question_rank.py +401 -124
  86. edsl/questions/question_registry.py +3 -3
  87. edsl/questions/{derived/question_top_k.py → question_top_k.py} +3 -3
  88. edsl/questions/{derived/question_yes_no.py → question_yes_no.py} +3 -4
  89. edsl/questions/register_questions_meta.py +2 -1
  90. edsl/questions/response_validator_abc.py +6 -2
  91. edsl/questions/response_validator_factory.py +10 -12
  92. edsl/results/report.py +1 -1
  93. edsl/results/result.py +7 -4
  94. edsl/results/results.py +500 -271
  95. edsl/results/results_selector.py +2 -2
  96. edsl/scenarios/construct_download_link.py +3 -3
  97. edsl/scenarios/scenario.py +1 -2
  98. edsl/scenarios/scenario_list.py +41 -23
  99. edsl/surveys/survey_css.py +3 -3
  100. edsl/surveys/survey_simulator.py +2 -1
  101. edsl/tasks/__init__.py +22 -2
  102. edsl/tasks/exceptions.py +72 -0
  103. edsl/tasks/task_history.py +48 -11
  104. edsl/templates/error_reporting/base.html +37 -4
  105. edsl/templates/error_reporting/exceptions_table.html +105 -33
  106. edsl/templates/error_reporting/interview_details.html +130 -126
  107. edsl/templates/error_reporting/overview.html +21 -25
  108. edsl/templates/error_reporting/report.css +215 -46
  109. edsl/templates/error_reporting/report.js +122 -20
  110. edsl/tokens/__init__.py +27 -1
  111. edsl/tokens/exceptions.py +37 -0
  112. edsl/tokens/interview_token_usage.py +3 -2
  113. edsl/tokens/token_usage.py +4 -3
  114. {edsl-0.1.50.dist-info → edsl-0.1.52.dist-info}/METADATA +1 -1
  115. {edsl-0.1.50.dist-info → edsl-0.1.52.dist-info}/RECORD +118 -116
  116. edsl/questions/derived/__init__.py +0 -0
  117. {edsl-0.1.50.dist-info → edsl-0.1.52.dist-info}/LICENSE +0 -0
  118. {edsl-0.1.50.dist-info → edsl-0.1.52.dist-info}/WHEEL +0 -0
  119. {edsl-0.1.50.dist-info → edsl-0.1.52.dist-info}/entry_points.txt +0 -0
edsl/coop/coop.py CHANGED
@@ -180,7 +180,7 @@ class Coop(CoopFunctionsMixin):
180
180
  timeout=timeout,
181
181
  )
182
182
  else:
183
- from edsl.coop.exceptions import CoopInvalidMethodError
183
+ from .exceptions import CoopInvalidMethodError
184
184
 
185
185
  raise CoopInvalidMethodError(f"Invalid {method=}.")
186
186
  except requests.ConnectionError:
@@ -303,7 +303,7 @@ class Coop(CoopFunctionsMixin):
303
303
  message = root.find("Message").text
304
304
  details = root.find("Details").text
305
305
  except Exception:
306
- from edsl.coop.exceptions import CoopServerResponseError
306
+ from .exceptions import CoopServerResponseError
307
307
 
308
308
  raise CoopServerResponseError(
309
309
  f"Server returned status code {response.status_code}. "
@@ -311,7 +311,7 @@ class Coop(CoopFunctionsMixin):
311
311
  f"The server response was: {response.text}"
312
312
  )
313
313
 
314
- from edsl.coop.exceptions import CoopServerResponseError
314
+ from .exceptions import CoopServerResponseError
315
315
 
316
316
  raise CoopServerResponseError(
317
317
  f"An error occurred: {code} - {message} - {details}"
@@ -538,7 +538,7 @@ class Coop(CoopFunctionsMixin):
538
538
  if response_json.get("upload_signed_url"):
539
539
  signed_url = response_json.get("upload_signed_url")
540
540
  else:
541
- from edsl.coop.exceptions import CoopResponseError
541
+ from .exceptions import CoopResponseError
542
542
 
543
543
  raise CoopResponseError("No signed url was provided received")
544
544
 
@@ -551,7 +551,7 @@ class Coop(CoopFunctionsMixin):
551
551
  "file_store_upload_signed_url"
552
552
  )
553
553
  if file_store_metadata and not file_store_upload_signed_url:
554
- from edsl.coop.exceptions import CoopResponseError
554
+ from .exceptions import CoopResponseError
555
555
 
556
556
  raise CoopResponseError("No file store signed url provided.")
557
557
  elif file_store_metadata:
@@ -659,13 +659,15 @@ class Coop(CoopFunctionsMixin):
659
659
  json_string = object_data.text
660
660
  object_type = response.json().get("object_type")
661
661
  if expected_object_type and object_type != expected_object_type:
662
- from edsl.coop.exceptions import CoopObjectTypeError
662
+ from .exceptions import CoopObjectTypeError
663
663
 
664
664
  raise CoopObjectTypeError(
665
665
  f"Expected {expected_object_type=} but got {object_type=}"
666
666
  )
667
667
  edsl_class = ObjectRegistry.object_type_to_edsl_class.get(object_type)
668
668
  object = edsl_class.from_dict(json.loads(json_string))
669
+ if object_type == "results":
670
+ object.initialize_cache_from_results()
669
671
  return object
670
672
 
671
673
  def get_all(self, object_type: ObjectType) -> list[dict[str, Any]]:
@@ -754,7 +756,7 @@ class Coop(CoopFunctionsMixin):
754
756
  and value is None
755
757
  and alias is None
756
758
  ):
757
- from edsl.coop.exceptions import CoopPatchError
759
+ from .exceptions import CoopPatchError
758
760
 
759
761
  raise CoopPatchError("Nothing to patch.")
760
762
 
@@ -887,7 +889,7 @@ class Coop(CoopFunctionsMixin):
887
889
  [CacheEntry(...), CacheEntry(...), ...]
888
890
  """
889
891
  if job_uuid is None:
890
- from edsl.coop.exceptions import CoopValueError
892
+ from .exceptions import CoopValueError
891
893
 
892
894
  raise CoopValueError("Must provide a job_uuid.")
893
895
  response = self._send_server_request(
@@ -917,7 +919,7 @@ class Coop(CoopFunctionsMixin):
917
919
  [CacheEntry(...), CacheEntry(...), ...]
918
920
  """
919
921
  if select_keys is None or len(select_keys) == 0:
920
- from edsl.coop.exceptions import CoopValueError
922
+ from .exceptions import CoopValueError
921
923
 
922
924
  raise CoopValueError("Must provide a non-empty list of select_keys.")
923
925
  response = self._send_server_request(
@@ -1182,7 +1184,7 @@ class Coop(CoopFunctionsMixin):
1182
1184
  ... print(f"Results available at: {job_status['results_url']}")
1183
1185
  """
1184
1186
  if job_uuid is None and results_uuid is None:
1185
- from edsl.coop.exceptions import CoopValueError
1187
+ from .exceptions import CoopValueError
1186
1188
 
1187
1189
  raise CoopValueError("Either job_uuid or results_uuid must be provided.")
1188
1190
  elif job_uuid is not None:
@@ -1258,7 +1260,7 @@ class Coop(CoopFunctionsMixin):
1258
1260
  elif isinstance(input, Survey):
1259
1261
  job = Jobs(survey=input)
1260
1262
  else:
1261
- from edsl.coop.exceptions import CoopTypeError
1263
+ from .exceptions import CoopTypeError
1262
1264
 
1263
1265
  raise CoopTypeError("Input must be either a Job or a Survey.")
1264
1266
 
@@ -1395,7 +1397,7 @@ class Coop(CoopFunctionsMixin):
1395
1397
  elif CONFIG.get("EDSL_FETCH_TOKEN_PRICES") == "False":
1396
1398
  return {}
1397
1399
  else:
1398
- from edsl.coop.exceptions import CoopValueError
1400
+ from .exceptions import CoopValueError
1399
1401
 
1400
1402
  raise CoopValueError(
1401
1403
  "Invalid EDSL_FETCH_TOKEN_PRICES value---should be 'True' or 'False'."
@@ -1553,7 +1555,7 @@ class Coop(CoopFunctionsMixin):
1553
1555
  api_key = self._poll_for_api_key(edsl_auth_token)
1554
1556
 
1555
1557
  if api_key is None:
1556
- from edsl.coop.exceptions import CoopTimeoutError
1558
+ from .exceptions import CoopTimeoutError
1557
1559
 
1558
1560
  raise CoopTimeoutError("Timed out waiting for login. Please try again.")
1559
1561
 
edsl/coop/coop_functions.py CHANGED
@@ -1,6 +1,6 @@
1
1
  class CoopFunctionsMixin:
2
2
  def better_names(self, existing_names):
3
- from edsl import QuestionList, Scenario
3
+ from .. import QuestionList, Scenario
4
4
 
5
5
  s = Scenario({"existing_names": existing_names})
6
6
  q = QuestionList(
edsl/coop/ep_key_handling.py CHANGED
@@ -70,7 +70,7 @@ class ExpectedParrotKeyHandler:
70
70
 
71
71
  def ok_to_ask_to_store(self):
72
72
  """Check if it's okay to ask the user to store the key."""
73
- from edsl.config import CONFIG
73
+ from ..config import CONFIG
74
74
 
75
75
  if CONFIG.get("EDSL_RUN_MODE") != "production":
76
76
  return False
edsl/coop/price_fetcher.py CHANGED
@@ -7,6 +7,7 @@ that price information is only fetched once and then cached for efficiency.
7
7
  """
8
8
 
9
9
  import requests
10
+ import os
10
11
  from typing import Dict, Tuple, Any
11
12
 
12
13
 
@@ -82,7 +83,6 @@ class PriceFetcher:
82
83
  if self._cached_prices is not None:
83
84
  return self._cached_prices
84
85
 
85
- import os
86
86
  from ..config import CONFIG
87
87
 
88
88
  try:
@@ -120,4 +120,4 @@ class PriceFetcher:
120
120
  except requests.RequestException:
121
121
  # Silently handle errors and return empty dict
122
122
  # print(f"An error occurred: {e}")
123
- return {}
123
+ return {}
edsl/coop/utils.py CHANGED
@@ -129,7 +129,7 @@ class ObjectRegistry:
129
129
  # Look up the object type
130
130
  object_type = cls.edsl_class_to_object_type.get(edsl_class_name)
131
131
  if object_type is None:
132
- from edsl.coop.exceptions import CoopValueError
132
+ from .exceptions import CoopValueError
133
133
  raise CoopValueError(f"Object type not found for {edsl_object=}")
134
134
  return object_type
135
135
 
@@ -152,7 +152,7 @@ class ObjectRegistry:
152
152
  """
153
153
  EDSL_class = cls.object_type_to_edsl_class.get(object_type)
154
154
  if EDSL_class is None:
155
- from edsl.coop.exceptions import CoopValueError
155
+ from .exceptions import CoopValueError
156
156
  raise CoopValueError(f"EDSL class not found for {object_type=}")
157
157
  return EDSL_class
158
158
 
edsl/dataset/dataset.py CHANGED
@@ -1,5 +1,3 @@
1
-
2
-
3
1
  from __future__ import annotations
4
2
  import sys
5
3
  import json
@@ -10,7 +8,8 @@ from typing import Any, Union, Optional, TYPE_CHECKING, Callable
10
8
  from ..base import PersistenceMixin, HashingMixin
11
9
 
12
10
  from .dataset_tree import Tree
13
- from .exceptions import DatasetKeyError, DatasetValueError
11
+ from .exceptions import DatasetKeyError, DatasetValueError, DatasetTypeError
12
+
14
13
 
15
14
  from .display.table_display import TableDisplay
16
15
  #from .smart_objects import FirstObject
@@ -121,19 +120,9 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
121
120
  new_data.append({key: values[:n]})
122
121
  return Dataset(new_data)
123
122
 
124
- def expand(self, field):
125
- return self.to_scenario_list().expand(field)
126
-
127
- # def view(self):
128
- # from perspective.widget import PerspectiveWidget
123
+ # def expand(self, field):
124
+ # return self.to_scenario_list().expand(field)
129
125
 
130
- # w = PerspectiveWidget(
131
- # self.to_pandas(),
132
- # plugin="Datagrid",
133
- # aggregates={"datetime": "any"},
134
- # sort=[["date", "desc"]],
135
- # )
136
- # return w
137
126
 
138
127
  def keys(self) -> list[str]:
139
128
  """Return the keys of the dataset.
@@ -287,12 +276,12 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
287
276
  if len(potential_matches) == 1:
288
277
  return potential_matches[0][1]
289
278
  elif len(potential_matches) > 1:
290
- from edsl.dataset.exceptions import DatasetKeyError
279
+ from .exceptions import DatasetKeyError
291
280
  raise DatasetKeyError(
292
281
  f"Key '{key}' found in more than one location: {[m[0] for m in potential_matches]}"
293
282
  )
294
283
 
295
- from edsl.dataset.exceptions import DatasetKeyError
284
+ from .exceptions import DatasetKeyError
296
285
  raise DatasetKeyError(f"Key '{key}' not found in any of the dictionaries.")
297
286
 
298
287
  def first(self) -> dict[str, Any]:
@@ -376,11 +365,12 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
376
365
  >>> d = Dataset([{'person_name':["John"]}])
377
366
  >>> from edsl import QuestionFreeText
378
367
  >>> q = QuestionFreeText(question_text = "How are you, {{ person_name ?}}?", question_name = "how_feeling")
379
- >>> d.to(q)
380
- Jobs(...)
368
+ >>> jobs = d.to(q)
369
+ >>> isinstance(jobs, object)
370
+ True
381
371
  """
382
- from edsl.surveys import Survey
383
- from edsl.questions import QuestionBase
372
+ from ..surveys import Survey
373
+ from ..questions import QuestionBase
384
374
 
385
375
  if isinstance(survey_or_question, Survey):
386
376
  return survey_or_question.by(self.to_scenario_list())
@@ -402,7 +392,7 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
402
392
  """
403
393
  for key in keys:
404
394
  if key not in self.keys():
405
- from edsl.dataset.exceptions import DatasetValueError
395
+ from .exceptions import DatasetValueError
406
396
  raise DatasetValueError(f"Key '{key}' not found in the dataset. "
407
397
  f"Available keys: {self.keys()}"
408
398
  )
@@ -479,11 +469,11 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
479
469
 
480
470
  # Validate the input for sampling parameters
481
471
  if n is None and frac is None:
482
- from edsl.dataset.exceptions import DatasetValueError
472
+ from .exceptions import DatasetValueError
483
473
  raise DatasetValueError("Either 'n' or 'frac' must be provided for sampling.")
484
474
 
485
475
  if n is not None and frac is not None:
486
- from edsl.dataset.exceptions import DatasetValueError
476
+ from .exceptions import DatasetValueError
487
477
  raise DatasetValueError("Only one of 'n' or 'frac' should be specified.")
488
478
 
489
479
  # Get the length of the lists from the first entry
@@ -495,7 +485,7 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
495
485
  n = int(total_length * frac)
496
486
 
497
487
  if not with_replacement and n > total_length:
498
- from edsl.dataset.exceptions import DatasetValueError
488
+ from .exceptions import DatasetValueError
499
489
  raise DatasetValueError(
500
490
  "Sample size cannot be greater than the number of available elements when sampling without replacement."
501
491
  )
@@ -513,47 +503,61 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
513
503
 
514
504
  return self
515
505
 
516
- def order_by(self, sort_key: str, reverse: bool = False) -> Dataset:
517
- """Return a new dataset with the observations sorted by the given key.
518
-
519
- :param sort_key: The key to sort the observations by.
520
- :param reverse: Whether to sort in reverse order.
521
-
522
- >>> d = Dataset([{'a':[1,2,3,4]}, {'b':[4,3,2,1]}])
523
- >>> d.order_by('a')
524
- Dataset([{'a': [1, 2, 3, 4]}, {'b': [4, 3, 2, 1]}])
525
-
526
- >>> d.order_by('a', reverse=True)
527
- Dataset([{'a': [4, 3, 2, 1]}, {'b': [1, 2, 3, 4]}])
528
-
529
- >>> d = Dataset([{'X.a':[1,2,3,4]}, {'X.b':[4,3,2,1]}])
530
- >>> d.order_by('a')
531
- Dataset([{'X.a': [1, 2, 3, 4]}, {'X.b': [4, 3, 2, 1]}])
532
-
533
-
506
+ def get_sort_indices(self, lst: list[Any], reverse: bool = False, use_numpy: bool = True) -> list[int]:
534
507
  """
535
- import numpy as np
508
+ Return the indices that would sort the list, using either numpy or pure Python.
509
+ None values are placed at the end of the sorted list.
536
510
 
537
- def sort_indices(lst: list[Any]) -> list[int]:
538
- """
539
- Return the indices that would sort the list.
511
+ Args:
512
+ lst: The list to be sorted
513
+ reverse: Whether to sort in descending order
514
+ use_numpy: Whether to use numpy implementation (falls back to pure Python if numpy is unavailable)
540
515
 
541
- :param lst: The list to be sorted.
542
- :return: A list of indices that would sort the list.
543
- """
544
- indices = np.argsort(lst).tolist()
545
- if reverse:
546
- indices.reverse()
547
- return indices
516
+ Returns:
517
+ A list of indices that would sort the list
518
+ """
519
+ if use_numpy:
520
+ try:
521
+ import numpy as np
522
+ # Convert list to numpy array
523
+ arr = np.array(lst, dtype=object)
524
+ # Get mask of non-None values
525
+ mask = ~(arr is None)
526
+ # Get indices of non-None and None values
527
+ non_none_indices = np.where(mask)[0]
528
+ none_indices = np.where(~mask)[0]
529
+ # Sort non-None values
530
+ sorted_indices = non_none_indices[np.argsort(arr[mask])]
531
+ # Combine sorted non-None indices with None indices
532
+ indices = np.concatenate([sorted_indices, none_indices]).tolist()
533
+ if reverse:
534
+ # When reversing, keep None values at end
535
+ indices = sorted_indices[::-1].tolist() + none_indices.tolist()
536
+ return indices
537
+ except ImportError:
538
+ # Fallback to pure Python if numpy is not available
539
+ pass
540
+
541
+ # Pure Python implementation
542
+ enumerated = list(enumerate(lst))
543
+ # Sort None values to end by using (is_none, value) as sort key
544
+ sorted_pairs = sorted(enumerated,
545
+ key=lambda x: (x[1] is None, x[1]),
546
+ reverse=reverse)
547
+ return [index for index, _ in sorted_pairs]
548
+
549
+ def order_by(self, sort_key: str, reverse: bool = False, use_numpy: bool = True) -> Dataset:
550
+ """Return a new dataset with the observations sorted by the given key.
548
551
 
552
+ Args:
553
+ sort_key: The key to sort the observations by
554
+ reverse: Whether to sort in reverse order
555
+ use_numpy: Whether to use numpy for sorting (faster for large lists)
556
+ """
549
557
  number_found = 0
550
558
  for obs in self.data:
551
559
  key, values = list(obs.items())[0]
552
- # an obseration is {'a':[1,2,3,4]}
553
- # key = list(obs.keys())[0]
554
- if (
555
- sort_key == key or sort_key == key.split(".")[-1]
556
- ): # e.g., "age" in "scenario.age"
560
+ if sort_key == key or sort_key == key.split(".")[-1]:
557
561
  relevant_values = values
558
562
  number_found += 1
559
563
 
@@ -562,11 +566,9 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
562
566
  elif number_found > 1:
563
567
  raise DatasetKeyError(f"Key '{sort_key}' found in more than one dictionary.")
564
568
 
565
- # relevant_values = self._key_to_value(sort_key)
566
- sort_indices_list = sort_indices(relevant_values)
569
+ sort_indices_list = self.get_sort_indices(relevant_values, reverse=reverse, use_numpy=use_numpy)
567
570
  new_data = []
568
571
  for observation in self.data:
569
- # print(observation)
570
572
  key, values = list(observation.items())[0]
571
573
  new_values = [values[i] for i in sort_indices_list]
572
574
  new_data.append({key: new_values})
@@ -646,7 +648,7 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
646
648
 
647
649
  if max_rows is not None:
648
650
  if max_rows > len(data):
649
- from edsl.dataset.exceptions import DatasetValueError
651
+ from .exceptions import DatasetValueError
650
652
  raise DatasetValueError(
651
653
  "max_rows cannot be greater than the number of rows in the dataset."
652
654
  )
@@ -685,6 +687,19 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
685
687
  def from_pandas_dataframe(cls, df):
686
688
  result = cls([{col: df[col].tolist()} for col in df.columns])
687
689
  return result
690
+
691
+ def to_dict(self) -> dict:
692
+ """
693
+ Convert the dataset to a dictionary.
694
+ """
695
+ return {'data': self.data}
696
+
697
+ @classmethod
698
+ def from_dict(cls, data: dict) -> 'Dataset':
699
+ """
700
+ Convert a dictionary to a dataset.
701
+ """
702
+ return cls(data['data'])
688
703
 
689
704
  def to_docx(self, output_file: str, title: str = None) -> None:
690
705
  """
@@ -736,6 +751,72 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
736
751
  # Save the document
737
752
  doc.save(output_file)
738
753
 
754
+ def expand(self, field: str, number_field: bool = False) -> "Dataset":
755
+ """
756
+ Expand a field containing lists into multiple rows.
757
+
758
+ Args:
759
+ field: The field containing lists to expand
760
+ number_field: If True, adds a number field indicating the position in the original list
761
+
762
+ Returns:
763
+ A new Dataset with the expanded rows
764
+
765
+ Example:
766
+ >>> from edsl.dataset import Dataset
767
+ >>> d = Dataset([{'a': [[1, 2, 3], [4, 5, 6]]}, {'b': ['x', 'y']}])
768
+ >>> d.expand('a')
769
+ Dataset([{'a': [1, 2, 3, 4, 5, 6]}, {'b': ['x', 'x', 'x', 'y', 'y', 'y']}])
770
+ """
771
+ from collections.abc import Iterable
772
+
773
+ # Find the field in the dataset
774
+ field_data = None
775
+ for entry in self.data:
776
+ key = list(entry.keys())[0]
777
+ if key == field:
778
+ field_data = entry[key]
779
+ break
780
+
781
+ if field_data is None:
782
+ raise DatasetKeyError(f"Field '{field}' not found in dataset. Available fields are: {self.keys()}")
783
+
784
+
785
+ # Validate that the field contains lists
786
+ if not all(isinstance(v, list) for v in field_data):
787
+ raise DatasetTypeError(f"Field '{field}' must contain lists in all entries")
788
+
789
+ # Create new expanded data structure
790
+ new_data = []
791
+
792
+ # Process each field
793
+ for entry in self.data:
794
+ key, values = list(entry.items())[0]
795
+ new_values = []
796
+
797
+ if key == field:
798
+ # This is the field to expand - flatten all sublists
799
+ for row_values in values:
800
+ if not isinstance(row_values, Iterable) or isinstance(row_values, str):
801
+ row_values = [row_values]
802
+ new_values.extend(row_values)
803
+ else:
804
+ # For other fields, repeat each value the appropriate number of times
805
+ for i, row_value in enumerate(values):
806
+ expand_length = len(field_data[i]) if i < len(field_data) else 0
807
+ new_values.extend([row_value] * expand_length)
808
+
809
+ new_data.append({key: new_values})
810
+
811
+ # Add number field if requested
812
+ if number_field:
813
+ number_values = []
814
+ for i, lst in enumerate(field_data):
815
+ number_values.extend(range(1, len(lst) + 1))
816
+ new_data.append({f"{field}_number": number_values})
817
+
818
+ return Dataset(new_data)
819
+
739
820
 
740
821
  if __name__ == "__main__":
741
822
  import doctest
edsl/dataset/dataset_operations_mixin.py CHANGED
@@ -184,6 +184,13 @@ class DataOperationsBase:
184
184
  )
185
185
 
186
186
  return _num_observations
187
+
188
+ def chart(self):
189
+ """
190
+ Create a chart from the results.
191
+ """
192
+ import altair as alt
193
+ return alt.Chart(self.to_pandas(remove_prefix=True))
187
194
 
188
195
  def make_tabular(
189
196
  self, remove_prefix: bool, pretty_labels: Optional[dict] = None
@@ -538,13 +545,14 @@ class DataOperationsBase:
538
545
  >>> r.select('how_feeling').to_scenario_list()
539
546
  ScenarioList([Scenario({'how_feeling': 'OK'}), Scenario({'how_feeling': 'Great'}), Scenario({'how_feeling': 'Terrible'}), Scenario({'how_feeling': 'OK'})])
540
547
  """
541
- from edsl.scenarios import ScenarioList, Scenario
548
+ from ..scenarios import ScenarioList, Scenario
542
549
 
543
550
  list_of_dicts = self.to_dicts(remove_prefix=remove_prefix)
544
551
  scenarios = []
545
552
  for d in list_of_dicts:
546
553
  scenarios.append(Scenario(d))
547
554
  return ScenarioList(scenarios)
555
+
548
556
 
549
557
  def to_agent_list(self, remove_prefix: bool = True):
550
558
  """Convert the results to a list of dictionaries, one per agent.
@@ -556,7 +564,7 @@ class DataOperationsBase:
556
564
  >>> r.select('how_feeling').to_agent_list()
557
565
  AgentList([Agent(traits = {'how_feeling': 'OK'}), Agent(traits = {'how_feeling': 'Great'}), Agent(traits = {'how_feeling': 'Terrible'}), Agent(traits = {'how_feeling': 'OK'})])
558
566
  """
559
- from edsl.agents import Agent, AgentList
567
+ from ..agents import Agent, AgentList
560
568
 
561
569
  list_of_dicts = self.to_dicts(remove_prefix=remove_prefix)
562
570
  agents = []
@@ -665,7 +673,7 @@ class DataOperationsBase:
665
673
  ):
666
674
  import os
667
675
  import tempfile
668
- from edsl.utilities.utilities import is_notebook
676
+ from ..utilities.utilities import is_notebook
669
677
  from IPython.display import HTML, display
670
678
 
671
679
  df = self.to_pandas()
@@ -799,7 +807,7 @@ class DataOperationsBase:
799
807
  from docx.shared import Pt
800
808
  import json
801
809
  except ImportError:
802
- from edsl.dataset.exceptions import DatasetImportError
810
+ from .exceptions import DatasetImportError
803
811
  raise DatasetImportError("The python-docx package is required for DOCX export. Install it with 'pip install python-docx'.")
804
812
 
805
813
  doc = Document()
@@ -871,7 +879,7 @@ class DataOperationsBase:
871
879
  >>> isinstance(doc, object)
872
880
  True
873
881
  """
874
- from edsl.utilities.utilities import is_notebook
882
+ from ..utilities.utilities import is_notebook
875
883
 
876
884
  # Prepare the data for the report
877
885
  field_data, num_obs, fields, header_fields = self._prepare_report_data(
@@ -1076,7 +1084,7 @@ class DataOperationsBase:
1076
1084
  # Check if the field is ambiguous
1077
1085
  if len(matching_entries) > 1:
1078
1086
  matching_cols = [next(iter(entry.keys())) for entry in matching_entries]
1079
- from edsl.dataset.exceptions import DatasetValueError
1087
+ from .exceptions import DatasetValueError
1080
1088
  raise DatasetValueError(
1081
1089
  f"Ambiguous field name '{field}'. It matches multiple columns: {matching_cols}. "
1082
1090
  f"Please specify the full column name to flatten."
edsl/dataset/dataset_tree.py CHANGED
@@ -51,7 +51,7 @@ class Tree:
51
51
  else:
52
52
  if not set(node_order).issubset(set(self.data.keys())):
53
53
  invalid_keys = set(node_order) - set(self.data.keys())
54
- from edsl.dataset.exceptions import DatasetValueError
54
+ from .exceptions import DatasetValueError
55
55
  raise DatasetValueError(f"Invalid keys in node_order: {invalid_keys}")
56
56
 
57
57
  self.root = TreeNode()
@@ -130,7 +130,7 @@ class Tree:
130
130
  doc_buffer.seek(0)
131
131
 
132
132
  base64_string = base64.b64encode(doc_buffer.getvalue()).decode("utf-8")
133
- from edsl.scenarios.FileStore import FileStore
133
+ from ..scenarios.file_store import FileStore
134
134
 
135
135
  # Create and return FileStore instance
136
136
  return FileStore(
@@ -335,7 +335,7 @@ class Tree:
335
335
  Returns:
336
336
  A string containing the markdown document, or renders markdown in notebooks.
337
337
  """
338
- from edsl.utilities.utilities import is_notebook
338
+ from ..utilities.utilities import is_notebook
339
339
  from IPython.display import Markdown, display
340
340
 
341
341
  if node is None:
edsl/dataset/display/table_renderers.py CHANGED
@@ -103,9 +103,12 @@ class PandasStyleRenderer(DataTablesRendererABC):
103
103
  else:
104
104
  df = pd.DataFrame(self.table_data.data, columns=self.table_data.headers)
105
105
 
106
- styled_df = df.style.set_properties(
107
- **{"text-align": "left"}
108
- ).background_gradient()
106
+ styled_df = df.style.set_properties(**{
107
+ "text-align": "left",
108
+ "white-space": "pre-wrap", # Allows text wrapping
109
+ "max-width": "300px", # Maximum width before wrapping
110
+ "word-wrap": "break-word" # Breaks words that exceed max-width
111
+ }).background_gradient()
109
112
 
110
113
  return f"""
111
114
  <div style="max-height: 500px; overflow-y: auto;">
edsl/dataset/file_exports.py CHANGED
@@ -40,7 +40,7 @@ class FileExport(ABC):
40
40
 
41
41
  def _create_filestore(self, data: Union[str, bytes]):
42
42
  """Create a FileStore instance with encoded data."""
43
- from ..scenarios import FileStore
43
+ from ..scenarios.file_store import FileStore
44
44
  if isinstance(data, str):
45
45
  base64_string = base64.b64encode(data.encode()).decode()
46
46
  else:
@@ -203,7 +203,7 @@ class SQLiteExport(TabularExport):
203
203
  (self.table_name,),
204
204
  )
205
205
  if cursor.fetchone():
206
- from edsl.dataset.exceptions import DatasetValueError
206
+ from .exceptions import DatasetValueError
207
207
  raise DatasetValueError(f"Table {self.table_name} already exists")
208
208
 
209
209
  # Create table
@@ -245,14 +245,14 @@ class SQLiteExport(TabularExport):
245
245
  """Validate initialization parameters."""
246
246
  valid_if_exists = {"fail", "replace", "append"}
247
247
  if self.if_exists not in valid_if_exists:
248
- from edsl.dataset.exceptions import DatasetValueError
248
+ from .exceptions import DatasetValueError
249
249
  raise DatasetValueError(
250
250
  f"if_exists must be one of {valid_if_exists}, got {self.if_exists}"
251
251
  )
252
252
 
253
253
  # Validate table name (basic SQLite identifier validation)
254
254
  if not self.table_name.isalnum() and not all(c in "_" for c in self.table_name):
255
- from edsl.dataset.exceptions import DatasetValueError
255
+ from .exceptions import DatasetValueError
256
256
  raise DatasetValueError(
257
257
  f"Invalid table name: {self.table_name}. Must contain only alphanumeric characters and underscores."
258
258
  )
edsl/dataset/r/ggplot.py CHANGED
@@ -30,12 +30,12 @@ class GGPlot:
30
30
 
31
31
  if result.returncode != 0:
32
32
  if result.returncode == 127:
33
- from edsl.dataset.exceptions import DatasetRuntimeError
33
+ from ..exceptions import DatasetRuntimeError
34
34
  raise DatasetRuntimeError(
35
35
  "Rscript is probably not installed. Please install R from https://cran.r-project.org/"
36
36
  )
37
37
  else:
38
- from edsl.dataset.exceptions import DatasetRuntimeError
38
+ from ..exceptions import DatasetRuntimeError
39
39
  raise DatasetRuntimeError(
40
40
  f"An error occurred while running Rscript: {result.stderr}"
41
41
  )
@@ -49,7 +49,7 @@ class GGPlot:
49
49
  """Save the plot to a file."""
50
50
  format = filename.split('.')[-1].lower()
51
51
  if format not in ['svg', 'png']:
52
- from edsl.dataset.exceptions import DatasetValueError
52
+ from ..exceptions import DatasetValueError
53
53
  raise DatasetValueError("Only 'svg' and 'png' formats are supported")
54
54
 
55
55
  save_command = f'\nggsave("{filename}", plot = last_plot(), width = {self.width}, height = {self.height}, device = "{format}")'