edsl 0.1.50__py3-none-any.whl → 0.1.51__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (109)
  1. edsl/__version__.py +1 -1
  2. edsl/base/base_exception.py +2 -2
  3. edsl/buckets/bucket_collection.py +1 -1
  4. edsl/buckets/exceptions.py +32 -0
  5. edsl/buckets/token_bucket_api.py +26 -10
  6. edsl/caching/cache.py +5 -2
  7. edsl/caching/remote_cache_sync.py +5 -5
  8. edsl/caching/sql_dict.py +12 -11
  9. edsl/config/__init__.py +1 -1
  10. edsl/config/config_class.py +4 -2
  11. edsl/conversation/Conversation.py +7 -4
  12. edsl/conversation/car_buying.py +1 -3
  13. edsl/conversation/mug_negotiation.py +2 -6
  14. edsl/coop/__init__.py +11 -8
  15. edsl/coop/coop.py +13 -13
  16. edsl/coop/coop_functions.py +1 -1
  17. edsl/coop/ep_key_handling.py +1 -1
  18. edsl/coop/price_fetcher.py +2 -2
  19. edsl/coop/utils.py +2 -2
  20. edsl/dataset/dataset.py +144 -63
  21. edsl/dataset/dataset_operations_mixin.py +14 -6
  22. edsl/dataset/dataset_tree.py +3 -3
  23. edsl/dataset/display/table_renderers.py +6 -3
  24. edsl/dataset/file_exports.py +4 -4
  25. edsl/dataset/r/ggplot.py +3 -3
  26. edsl/inference_services/available_model_fetcher.py +2 -2
  27. edsl/inference_services/data_structures.py +5 -5
  28. edsl/inference_services/inference_service_abc.py +1 -1
  29. edsl/inference_services/inference_services_collection.py +1 -1
  30. edsl/inference_services/service_availability.py +3 -3
  31. edsl/inference_services/services/azure_ai.py +3 -3
  32. edsl/inference_services/services/google_service.py +1 -1
  33. edsl/inference_services/services/test_service.py +1 -1
  34. edsl/instructions/change_instruction.py +5 -4
  35. edsl/instructions/instruction.py +1 -0
  36. edsl/instructions/instruction_collection.py +5 -4
  37. edsl/instructions/instruction_handler.py +10 -8
  38. edsl/interviews/exception_tracking.py +1 -1
  39. edsl/interviews/interview.py +1 -1
  40. edsl/interviews/interview_status_dictionary.py +1 -1
  41. edsl/interviews/interview_task_manager.py +2 -2
  42. edsl/interviews/request_token_estimator.py +3 -2
  43. edsl/interviews/statistics.py +2 -2
  44. edsl/invigilators/invigilators.py +2 -2
  45. edsl/jobs/__init__.py +39 -2
  46. edsl/jobs/async_interview_runner.py +1 -1
  47. edsl/jobs/check_survey_scenario_compatibility.py +5 -5
  48. edsl/jobs/data_structures.py +2 -2
  49. edsl/jobs/jobs.py +2 -2
  50. edsl/jobs/jobs_checks.py +5 -5
  51. edsl/jobs/jobs_component_constructor.py +2 -2
  52. edsl/jobs/jobs_pricing_estimation.py +1 -1
  53. edsl/jobs/jobs_runner_asyncio.py +2 -2
  54. edsl/jobs/remote_inference.py +1 -1
  55. edsl/jobs/results_exceptions_handler.py +2 -2
  56. edsl/language_models/language_model.py +5 -1
  57. edsl/notebooks/__init__.py +24 -1
  58. edsl/notebooks/exceptions.py +82 -0
  59. edsl/notebooks/notebook.py +7 -3
  60. edsl/notebooks/notebook_to_latex.py +1 -1
  61. edsl/prompts/__init__.py +23 -2
  62. edsl/prompts/prompt.py +1 -1
  63. edsl/questions/__init__.py +4 -4
  64. edsl/questions/answer_validator_mixin.py +0 -5
  65. edsl/questions/compose_questions.py +2 -2
  66. edsl/questions/descriptors.py +1 -1
  67. edsl/questions/question_base.py +32 -3
  68. edsl/questions/question_base_prompts_mixin.py +4 -4
  69. edsl/questions/question_budget.py +503 -102
  70. edsl/questions/question_check_box.py +658 -156
  71. edsl/questions/question_dict.py +176 -2
  72. edsl/questions/question_extract.py +401 -61
  73. edsl/questions/question_free_text.py +77 -9
  74. edsl/questions/question_functional.py +118 -9
  75. edsl/questions/{derived/question_likert_five.py → question_likert_five.py} +2 -2
  76. edsl/questions/{derived/question_linear_scale.py → question_linear_scale.py} +3 -4
  77. edsl/questions/question_list.py +246 -26
  78. edsl/questions/question_matrix.py +586 -73
  79. edsl/questions/question_multiple_choice.py +213 -47
  80. edsl/questions/question_numerical.py +360 -29
  81. edsl/questions/question_rank.py +401 -124
  82. edsl/questions/question_registry.py +3 -3
  83. edsl/questions/{derived/question_top_k.py → question_top_k.py} +3 -3
  84. edsl/questions/{derived/question_yes_no.py → question_yes_no.py} +3 -4
  85. edsl/questions/register_questions_meta.py +2 -1
  86. edsl/questions/response_validator_abc.py +6 -2
  87. edsl/questions/response_validator_factory.py +10 -12
  88. edsl/results/report.py +1 -1
  89. edsl/results/result.py +7 -4
  90. edsl/results/results.py +471 -271
  91. edsl/results/results_selector.py +2 -2
  92. edsl/scenarios/construct_download_link.py +3 -3
  93. edsl/scenarios/scenario.py +1 -2
  94. edsl/scenarios/scenario_list.py +41 -23
  95. edsl/surveys/survey_css.py +3 -3
  96. edsl/surveys/survey_simulator.py +2 -1
  97. edsl/tasks/__init__.py +22 -2
  98. edsl/tasks/exceptions.py +72 -0
  99. edsl/tasks/task_history.py +3 -3
  100. edsl/tokens/__init__.py +27 -1
  101. edsl/tokens/exceptions.py +37 -0
  102. edsl/tokens/interview_token_usage.py +3 -2
  103. edsl/tokens/token_usage.py +4 -3
  104. {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/METADATA +1 -1
  105. {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/RECORD +108 -106
  106. edsl/questions/derived/__init__.py +0 -0
  107. {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/LICENSE +0 -0
  108. {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/WHEEL +0 -0
  109. {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/entry_points.txt +0 -0
edsl/dataset/dataset.py CHANGED
@@ -1,5 +1,3 @@
-
-
 from __future__ import annotations
 import sys
 import json
@@ -10,7 +8,8 @@ from typing import Any, Union, Optional, TYPE_CHECKING, Callable
 from ..base import PersistenceMixin, HashingMixin
 
 from .dataset_tree import Tree
-from .exceptions import DatasetKeyError, DatasetValueError
+from .exceptions import DatasetKeyError, DatasetValueError, DatasetTypeError
+
 
 from .display.table_display import TableDisplay
 #from .smart_objects import FirstObject
@@ -121,19 +120,9 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
             new_data.append({key: values[:n]})
         return Dataset(new_data)
 
-    def expand(self, field):
-        return self.to_scenario_list().expand(field)
-
-    # def view(self):
-    #     from perspective.widget import PerspectiveWidget
+    # def expand(self, field):
+    #     return self.to_scenario_list().expand(field)
 
-    #     w = PerspectiveWidget(
-    #         self.to_pandas(),
-    #         plugin="Datagrid",
-    #         aggregates={"datetime": "any"},
-    #         sort=[["date", "desc"]],
-    #     )
-    #     return w
 
     def keys(self) -> list[str]:
         """Return the keys of the dataset.
@@ -287,12 +276,12 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
         if len(potential_matches) == 1:
             return potential_matches[0][1]
         elif len(potential_matches) > 1:
-            from edsl.dataset.exceptions import DatasetKeyError
+            from .exceptions import DatasetKeyError
            raise DatasetKeyError(
                 f"Key '{key}' found in more than one location: {[m[0] for m in potential_matches]}"
             )
 
-        from edsl.dataset.exceptions import DatasetKeyError
+        from .exceptions import DatasetKeyError
         raise DatasetKeyError(f"Key '{key}' not found in any of the dictionaries.")
 
     def first(self) -> dict[str, Any]:
@@ -376,11 +365,12 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
         >>> d = Dataset([{'person_name':["John"]}])
         >>> from edsl import QuestionFreeText
         >>> q = QuestionFreeText(question_text = "How are you, {{ person_name ?}}?", question_name = "how_feeling")
-        >>> d.to(q)
-        Jobs(...)
+        >>> jobs = d.to(q)
+        >>> isinstance(jobs, object)
+        True
         """
-        from edsl.surveys import Survey
-        from edsl.questions import QuestionBase
+        from ..surveys import Survey
+        from ..questions import QuestionBase
 
         if isinstance(survey_or_question, Survey):
             return survey_or_question.by(self.to_scenario_list())
@@ -402,7 +392,7 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
         """
         for key in keys:
             if key not in self.keys():
-                from edsl.dataset.exceptions import DatasetValueError
+                from .exceptions import DatasetValueError
                 raise DatasetValueError(f"Key '{key}' not found in the dataset. "
                                         f"Available keys: {self.keys()}"
                                         )
@@ -479,11 +469,11 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
 
         # Validate the input for sampling parameters
         if n is None and frac is None:
-            from edsl.dataset.exceptions import DatasetValueError
+            from .exceptions import DatasetValueError
             raise DatasetValueError("Either 'n' or 'frac' must be provided for sampling.")
 
         if n is not None and frac is not None:
-            from edsl.dataset.exceptions import DatasetValueError
+            from .exceptions import DatasetValueError
             raise DatasetValueError("Only one of 'n' or 'frac' should be specified.")
 
         # Get the length of the lists from the first entry
@@ -495,7 +485,7 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
             n = int(total_length * frac)
 
         if not with_replacement and n > total_length:
-            from edsl.dataset.exceptions import DatasetValueError
+            from .exceptions import DatasetValueError
             raise DatasetValueError(
                 "Sample size cannot be greater than the number of available elements when sampling without replacement."
            )
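
As a quick orientation to the validation above, a hypothetical usage sketch (the `n`, `frac`, and `with_replacement` parameters appear in the hunk; their defaults are not shown in this diff):

    from edsl.dataset import Dataset

    d = Dataset([{'a': [1, 2, 3, 4]}])

    d.sample(n=2)       # OK: exactly one of n / frac is given
    d.sample(frac=0.5)  # OK: n is derived as int(4 * 0.5) == 2
    # d.sample()                              -> DatasetValueError: need n or frac
    # d.sample(n=2, frac=0.5)                 -> DatasetValueError: not both
    # d.sample(n=10, with_replacement=False)  -> DatasetValueError: n > 4
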
@@ -513,47 +503,61 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
 
         return self
 
-    def order_by(self, sort_key: str, reverse: bool = False) -> Dataset:
-        """Return a new dataset with the observations sorted by the given key.
-
-        :param sort_key: The key to sort the observations by.
-        :param reverse: Whether to sort in reverse order.
-
-        >>> d = Dataset([{'a':[1,2,3,4]}, {'b':[4,3,2,1]}])
-        >>> d.order_by('a')
-        Dataset([{'a': [1, 2, 3, 4]}, {'b': [4, 3, 2, 1]}])
-
-        >>> d.order_by('a', reverse=True)
-        Dataset([{'a': [4, 3, 2, 1]}, {'b': [1, 2, 3, 4]}])
-
-        >>> d = Dataset([{'X.a':[1,2,3,4]}, {'X.b':[4,3,2,1]}])
-        >>> d.order_by('a')
-        Dataset([{'X.a': [1, 2, 3, 4]}, {'X.b': [4, 3, 2, 1]}])
-
-
+    def get_sort_indices(self, lst: list[Any], reverse: bool = False, use_numpy: bool = True) -> list[int]:
         """
-        import numpy as np
+        Return the indices that would sort the list, using either numpy or pure Python.
+        None values are placed at the end of the sorted list.
 
-        def sort_indices(lst: list[Any]) -> list[int]:
-            """
-            Return the indices that would sort the list.
+        Args:
+            lst: The list to be sorted
+            reverse: Whether to sort in descending order
+            use_numpy: Whether to use numpy implementation (falls back to pure Python if numpy is unavailable)
 
-            :param lst: The list to be sorted.
-            :return: A list of indices that would sort the list.
-            """
-            indices = np.argsort(lst).tolist()
-            if reverse:
-                indices.reverse()
-            return indices
+        Returns:
+            A list of indices that would sort the list
+        """
+        if use_numpy:
+            try:
+                import numpy as np
+                # Convert list to numpy array
+                arr = np.array(lst, dtype=object)
+                # Get mask of non-None values
+                mask = ~(arr is None)
+                # Get indices of non-None and None values
+                non_none_indices = np.where(mask)[0]
+                none_indices = np.where(~mask)[0]
+                # Sort non-None values
+                sorted_indices = non_none_indices[np.argsort(arr[mask])]
+                # Combine sorted non-None indices with None indices
+                indices = np.concatenate([sorted_indices, none_indices]).tolist()
+                if reverse:
+                    # When reversing, keep None values at end
+                    indices = sorted_indices[::-1].tolist() + none_indices.tolist()
+                return indices
+            except ImportError:
+                # Fallback to pure Python if numpy is not available
+                pass
+
+        # Pure Python implementation
+        enumerated = list(enumerate(lst))
+        # Sort None values to end by using (is_none, value) as sort key
+        sorted_pairs = sorted(enumerated,
+                              key=lambda x: (x[1] is None, x[1]),
+                              reverse=reverse)
+        return [index for index, _ in sorted_pairs]
+
+    def order_by(self, sort_key: str, reverse: bool = False, use_numpy: bool = True) -> Dataset:
+        """Return a new dataset with the observations sorted by the given key.
 
+        Args:
+            sort_key: The key to sort the observations by
+            reverse: Whether to sort in reverse order
+            use_numpy: Whether to use numpy for sorting (faster for large lists)
+        """
         number_found = 0
         for obs in self.data:
             key, values = list(obs.items())[0]
-            # an observation is {'a':[1,2,3,4]}
-            # key = list(obs.keys())[0]
-            if (
-                sort_key == key or sort_key == key.split(".")[-1]
-            ):  # e.g., "age" in "scenario.age"
+            if sort_key == key or sort_key == key.split(".")[-1]:
                 relevant_values = values
                 number_found += 1
 
@@ -562,11 +566,9 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
         elif number_found > 1:
             raise DatasetKeyError(f"Key '{sort_key}' found in more than one dictionary.")
 
-        # relevant_values = self._key_to_value(sort_key)
-        sort_indices_list = sort_indices(relevant_values)
+        sort_indices_list = self.get_sort_indices(relevant_values, reverse=reverse, use_numpy=use_numpy)
         new_data = []
         for observation in self.data:
-            # print(observation)
             key, values = list(observation.items())[0]
             new_values = [values[i] for i in sort_indices_list]
             new_data.append({key: new_values})
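
A minimal sketch (not part of the diff) of what the pure-Python branch of `get_sort_indices` computes; the `(is_none, value)` key pushes a `None` entry to the end when `reverse=False`:

    values = [3, None, 1, 2]
    enumerated = list(enumerate(values))
    # Non-None values sort under (False, value); the single None sorts last as (True, None).
    sorted_pairs = sorted(enumerated, key=lambda x: (x[1] is None, x[1]))
    print([i for i, _ in sorted_pairs])  # [2, 3, 0, 1]
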
@@ -646,7 +648,7 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
 
         if max_rows is not None:
             if max_rows > len(data):
-                from edsl.dataset.exceptions import DatasetValueError
+                from .exceptions import DatasetValueError
                 raise DatasetValueError(
                     "max_rows cannot be greater than the number of rows in the dataset."
                 )
@@ -685,6 +687,19 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
     def from_pandas_dataframe(cls, df):
         result = cls([{col: df[col].tolist()} for col in df.columns])
         return result
+
+    def to_dict(self) -> dict:
+        """
+        Convert the dataset to a dictionary.
+        """
+        return {'data': self.data}
+
+    @classmethod
+    def from_dict(cls, data: dict) -> 'Dataset':
+        """
+        Convert a dictionary to a dataset.
+        """
+        return cls(data['data'])
 
     def to_docx(self, output_file: str, title: str = None) -> None:
         """
@@ -736,6 +751,72 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
         # Save the document
         doc.save(output_file)
 
+    def expand(self, field: str, number_field: bool = False) -> "Dataset":
+        """
+        Expand a field containing lists into multiple rows.
+
+        Args:
+            field: The field containing lists to expand
+            number_field: If True, adds a number field indicating the position in the original list
+
+        Returns:
+            A new Dataset with the expanded rows
+
+        Example:
+            >>> from edsl.dataset import Dataset
+            >>> d = Dataset([{'a': [[1, 2, 3], [4, 5, 6]]}, {'b': ['x', 'y']}])
+            >>> d.expand('a')
+            Dataset([{'a': [1, 2, 3, 4, 5, 6]}, {'b': ['x', 'x', 'x', 'y', 'y', 'y']}])
+        """
+        from collections.abc import Iterable
+
+        # Find the field in the dataset
+        field_data = None
+        for entry in self.data:
+            key = list(entry.keys())[0]
+            if key == field:
+                field_data = entry[key]
+                break
+
+        if field_data is None:
+            raise DatasetKeyError(f"Field '{field}' not found in dataset. Available fields are: {self.keys()}")
+
+
+        # Validate that the field contains lists
+        if not all(isinstance(v, list) for v in field_data):
+            raise DatasetTypeError(f"Field '{field}' must contain lists in all entries")
+
+        # Create new expanded data structure
+        new_data = []
+
+        # Process each field
+        for entry in self.data:
+            key, values = list(entry.items())[0]
+            new_values = []
+
+            if key == field:
+                # This is the field to expand - flatten all sublists
+                for row_values in values:
+                    if not isinstance(row_values, Iterable) or isinstance(row_values, str):
+                        row_values = [row_values]
+                    new_values.extend(row_values)
+            else:
+                # For other fields, repeat each value the appropriate number of times
+                for i, row_value in enumerate(values):
+                    expand_length = len(field_data[i]) if i < len(field_data) else 0
+                    new_values.extend([row_value] * expand_length)
+
+            new_data.append({key: new_values})
+
+        # Add number field if requested
+        if number_field:
+            number_values = []
+            for i, lst in enumerate(field_data):
+                number_values.extend(range(1, len(lst) + 1))
+            new_data.append({f"{field}_number": number_values})
+
+        return Dataset(new_data)
+
 
 if __name__ == "__main__":
     import doctest
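
Extending the docstring example, a sketch of the `number_field=True` path, which appends a companion `<field>_number` column recording each element's 1-based position within its original sublist:

    from edsl.dataset import Dataset

    d = Dataset([{'a': [[1, 2, 3], [4, 5, 6]]}, {'b': ['x', 'y']}])
    d.expand('a', number_field=True)
    # Dataset([{'a': [1, 2, 3, 4, 5, 6]},
    #          {'b': ['x', 'x', 'x', 'y', 'y', 'y']},
    #          {'a_number': [1, 2, 3, 1, 2, 3]}])
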
edsl/dataset/dataset_operations_mixin.py CHANGED
@@ -184,6 +184,13 @@ class DataOperationsBase:
         )
 
         return _num_observations
+
+    def chart(self):
+        """
+        Create a chart from the results.
+        """
+        import altair as alt
+        return alt.Chart(self.to_pandas(remove_prefix=True))
 
     def make_tabular(
         self, remove_prefix: bool, pretty_labels: Optional[dict] = None
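
The new `chart` method returns a bare altair `Chart` over the results dataframe; marks and encodings are left to the caller. A usage sketch, assuming the `Results.example()` fixture used in the doctests and an installed `altair`:

    from edsl import Results

    r = Results.example()
    chart = r.select('how_feeling').chart().mark_bar().encode(
        x='how_feeling',  # prefix already stripped via remove_prefix=True
        y='count()',      # altair's aggregate shorthand
    )
    chart  # renders inline in a notebook
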
@@ -538,13 +545,14 @@ class DataOperationsBase:
         >>> r.select('how_feeling').to_scenario_list()
         ScenarioList([Scenario({'how_feeling': 'OK'}), Scenario({'how_feeling': 'Great'}), Scenario({'how_feeling': 'Terrible'}), Scenario({'how_feeling': 'OK'})])
         """
-        from edsl.scenarios import ScenarioList, Scenario
+        from ..scenarios import ScenarioList, Scenario
 
         list_of_dicts = self.to_dicts(remove_prefix=remove_prefix)
         scenarios = []
         for d in list_of_dicts:
             scenarios.append(Scenario(d))
         return ScenarioList(scenarios)
+
 
     def to_agent_list(self, remove_prefix: bool = True):
         """Convert the results to a list of dictionaries, one per agent.
@@ -556,7 +564,7 @@ class DataOperationsBase:
         >>> r.select('how_feeling').to_agent_list()
         AgentList([Agent(traits = {'how_feeling': 'OK'}), Agent(traits = {'how_feeling': 'Great'}), Agent(traits = {'how_feeling': 'Terrible'}), Agent(traits = {'how_feeling': 'OK'})])
         """
-        from edsl.agents import Agent, AgentList
+        from ..agents import Agent, AgentList
 
         list_of_dicts = self.to_dicts(remove_prefix=remove_prefix)
         agents = []
@@ -665,7 +673,7 @@ class DataOperationsBase:
     ):
         import os
         import tempfile
-        from edsl.utilities.utilities import is_notebook
+        from ..utilities.utilities import is_notebook
         from IPython.display import HTML, display
 
         df = self.to_pandas()
@@ -799,7 +807,7 @@ class DataOperationsBase:
             from docx.shared import Pt
             import json
         except ImportError:
-            from edsl.dataset.exceptions import DatasetImportError
+            from .exceptions import DatasetImportError
             raise DatasetImportError("The python-docx package is required for DOCX export. Install it with 'pip install python-docx'.")
 
         doc = Document()
@@ -871,7 +879,7 @@ class DataOperationsBase:
         >>> isinstance(doc, object)
         True
         """
-        from edsl.utilities.utilities import is_notebook
+        from ..utilities.utilities import is_notebook
 
         # Prepare the data for the report
         field_data, num_obs, fields, header_fields = self._prepare_report_data(
@@ -1076,7 +1084,7 @@ class DataOperationsBase:
         # Check if the field is ambiguous
         if len(matching_entries) > 1:
             matching_cols = [next(iter(entry.keys())) for entry in matching_entries]
-            from edsl.dataset.exceptions import DatasetValueError
+            from .exceptions import DatasetValueError
             raise DatasetValueError(
                 f"Ambiguous field name '{field}'. It matches multiple columns: {matching_cols}. "
                 f"Please specify the full column name to flatten."
edsl/dataset/dataset_tree.py CHANGED
@@ -51,7 +51,7 @@ class Tree:
         else:
             if not set(node_order).issubset(set(self.data.keys())):
                 invalid_keys = set(node_order) - set(self.data.keys())
-                from edsl.dataset.exceptions import DatasetValueError
+                from .exceptions import DatasetValueError
                 raise DatasetValueError(f"Invalid keys in node_order: {invalid_keys}")
 
         self.root = TreeNode()
@@ -130,7 +130,7 @@ class Tree:
         doc_buffer.seek(0)
 
         base64_string = base64.b64encode(doc_buffer.getvalue()).decode("utf-8")
-        from edsl.scenarios.FileStore import FileStore
+        from ..scenarios.file_store import FileStore
 
         # Create and return FileStore instance
         return FileStore(
@@ -335,7 +335,7 @@ class Tree:
         Returns:
             A string containing the markdown document, or renders markdown in notebooks.
         """
-        from edsl.utilities.utilities import is_notebook
+        from ..utilities.utilities import is_notebook
         from IPython.display import Markdown, display
 
         if node is None:
edsl/dataset/display/table_renderers.py CHANGED
@@ -103,9 +103,12 @@ class PandasStyleRenderer(DataTablesRendererABC):
         else:
             df = pd.DataFrame(self.table_data.data, columns=self.table_data.headers)
 
-        styled_df = df.style.set_properties(
-            **{"text-align": "left"}
-        ).background_gradient()
+        styled_df = df.style.set_properties(**{
+            "text-align": "left",
+            "white-space": "pre-wrap",   # Allows text wrapping
+            "max-width": "300px",        # Maximum width before wrapping
+            "word-wrap": "break-word"    # Breaks words that exceed max-width
+        }).background_gradient()
 
         return f"""
         <div style="max-height: 500px; overflow-y: auto;">
edsl/dataset/file_exports.py CHANGED
@@ -40,7 +40,7 @@ class FileExport(ABC):
 
     def _create_filestore(self, data: Union[str, bytes]):
         """Create a FileStore instance with encoded data."""
-        from ..scenarios import FileStore
+        from ..scenarios.file_store import FileStore
         if isinstance(data, str):
             base64_string = base64.b64encode(data.encode()).decode()
         else:
@@ -203,7 +203,7 @@ class SQLiteExport(TabularExport):
             (self.table_name,),
         )
         if cursor.fetchone():
-            from edsl.dataset.exceptions import DatasetValueError
+            from .exceptions import DatasetValueError
             raise DatasetValueError(f"Table {self.table_name} already exists")
 
         # Create table
@@ -245,14 +245,14 @@ class SQLiteExport(TabularExport):
         """Validate initialization parameters."""
         valid_if_exists = {"fail", "replace", "append"}
         if self.if_exists not in valid_if_exists:
-            from edsl.dataset.exceptions import DatasetValueError
+            from .exceptions import DatasetValueError
             raise DatasetValueError(
                 f"if_exists must be one of {valid_if_exists}, got {self.if_exists}"
             )
 
         # Validate table name (basic SQLite identifier validation)
         if not self.table_name.isalnum() and not all(c in "_" for c in self.table_name):
-            from edsl.dataset.exceptions import DatasetValueError
+            from .exceptions import DatasetValueError
             raise DatasetValueError(
                 f"Invalid table name: {self.table_name}. Must contain only alphanumeric characters and underscores."
             )
edsl/dataset/r/ggplot.py CHANGED
@@ -30,12 +30,12 @@ class GGPlot:
 
         if result.returncode != 0:
             if result.returncode == 127:
-                from edsl.dataset.exceptions import DatasetRuntimeError
+                from ..exceptions import DatasetRuntimeError
                 raise DatasetRuntimeError(
                     "Rscript is probably not installed. Please install R from https://cran.r-project.org/"
                 )
             else:
-                from edsl.dataset.exceptions import DatasetRuntimeError
+                from ..exceptions import DatasetRuntimeError
                 raise DatasetRuntimeError(
                     f"An error occurred while running Rscript: {result.stderr}"
                 )
@@ -49,7 +49,7 @@ class GGPlot:
         """Save the plot to a file."""
         format = filename.split('.')[-1].lower()
         if format not in ['svg', 'png']:
-            from edsl.dataset.exceptions import DatasetValueError
+            from ..exceptions import DatasetValueError
             raise DatasetValueError("Only 'svg' and 'png' formats are supported")
 
         save_command = f'\nggsave("{filename}", plot = last_plot(), width = {self.width}, height = {self.height}, device = "{format}")'
edsl/inference_services/available_model_fetcher.py CHANGED
@@ -55,7 +55,7 @@ class AvailableModelFetcher:
 
     :param service: Optional[InferenceServiceABC] - If specified, only fetch models for this service.
 
-    >>> from edsl.inference_services.services.open_ai_service import OpenAIService
+    >>> from .services.open_ai_service import OpenAIService
     >>> af = AvailableModelFetcher([OpenAIService()], {})
    >>> af.available(service="openai")
     [LanguageModelInfo(model_name='...', service_name='openai'), ...]
@@ -155,7 +155,7 @@ class AvailableModelFetcher:
         """The service name is the _inference_service_ attribute of the service."""
         if service_name in self._service_map:
             return self._service_map[service_name]
-        from edsl.inference_services.exceptions import InferenceServiceValueError
+        from .exceptions import InferenceServiceValueError
         raise InferenceServiceValueError(f"Service {service_name} not found")
 
     def _get_all_models(self, force_refresh=False) -> List[LanguageModelInfo]:
edsl/inference_services/data_structures.py CHANGED
@@ -43,7 +43,7 @@ class LanguageModelInfo:
         elif key == 1:
             return self.service_name
         else:
-            from edsl.inference_services.exceptions import InferenceServiceIndexError
+            from .exceptions import InferenceServiceIndexError
             raise InferenceServiceIndexError("Index out of range")
 
     @classmethod
@@ -70,7 +70,7 @@ class AvailableModels(UserList):
         return self.to_dataset().print()
 
     def to_dataset(self):
-        from edsl.scenarios.ScenarioList import ScenarioList
+        from ..scenarios.scenario_list import ScenarioList
 
         models, services = zip(
             *[(model.model_name, model.service_name) for model in self]
@@ -106,14 +106,14 @@ class AvailableModels(UserList):
                 ]
             )
             if len(avm) == 0:
-                from edsl.inference_services.exceptions import InferenceServiceValueError
+                from .exceptions import InferenceServiceValueError
                 raise InferenceServiceValueError(
                     "No models found matching the search pattern: " + pattern
                 )
             else:
                 return avm
         except re.error as e:
-            from edsl.inference_services.exceptions import InferenceServiceValueError
+            from .exceptions import InferenceServiceValueError
             raise InferenceServiceValueError(f"Invalid regular expression pattern: {e}")
 
 
@@ -128,7 +128,7 @@ class ServiceToModelsMapping(UserDict):
     def _validate_service_names(self):
         for service in self.service_names:
             if service not in InferenceServiceLiteral:
-                from edsl.inference_services.exceptions import InferenceServiceValueError
+                from .exceptions import InferenceServiceValueError
                 raise InferenceServiceValueError(f"Invalid service name: {service}")
 
     def model_to_services(self) -> dict:
edsl/inference_services/inference_service_abc.py CHANGED
@@ -26,7 +26,7 @@ class InferenceServiceABC(ABC):
         ]
         for attr in must_have_attributes:
             if not hasattr(cls, attr):
-                from edsl.inference_services.exceptions import InferenceServiceNotImplementedError
+                from .exceptions import InferenceServiceNotImplementedError
                 raise InferenceServiceNotImplementedError(
                     f"Class {cls.__name__} must have a '{attr}' attribute."
                 )
edsl/inference_services/inference_services_collection.py CHANGED
@@ -1,7 +1,7 @@
 from collections import defaultdict
 from typing import Optional, Protocol, Dict, List, Tuple, TYPE_CHECKING
 
-from edsl.enums import InferenceServiceLiteral
+from ..enums import InferenceServiceLiteral
 from .inference_service_abc import InferenceServiceABC
 from .available_model_fetcher import AvailableModelFetcher
 from .exceptions import InferenceServiceError
edsl/inference_services/service_availability.py CHANGED
@@ -42,7 +42,7 @@ class ServiceAvailability:
     @classmethod
     def models_from_coop(cls) -> AvailableModels:
         if not cls._coop_model_list:
-            from edsl.coop.coop import Coop
+            from ..coop.coop import Coop
 
             c = Coop()
             coop_model_list = c.fetch_models()
@@ -74,7 +74,7 @@ class ServiceAvailability:
                 continue
 
         # If we get here, all sources failed
-        from edsl.inference_services.exceptions import InferenceServiceRuntimeError
+        from .exceptions import InferenceServiceRuntimeError
         raise InferenceServiceRuntimeError(
             f"All sources failed to fetch models. Last error: {last_error}"
         )
@@ -93,7 +93,7 @@ class ServiceAvailability:
     @staticmethod
     def _fetch_from_cache(service: "InferenceServiceABC") -> ModelNamesList:
         """Fetch models from local cache."""
-        from edsl.inference_services.models_available_cache import models_available
+        from .models_available_cache import models_available
 
         return models_available.get(service._inference_service_, [])
edsl/inference_services/services/azure_ai.py CHANGED
@@ -46,7 +46,7 @@ class AzureAIService(InferenceServiceABC):
         out = []
         azure_endpoints = os.getenv("AZURE_ENDPOINT_URL_AND_KEY", None)
         if not azure_endpoints:
-            from edsl.inference_services.exceptions import InferenceServiceEnvironmentError
+            from ..exceptions import InferenceServiceEnvironmentError
             raise InferenceServiceEnvironmentError("AZURE_ENDPOINT_URL_AND_KEY is not defined")
         azure_endpoints = azure_endpoints.split(",")
         for data in azure_endpoints:
@@ -135,7 +135,7 @@ class AzureAIService(InferenceServiceABC):
             api_key = None
 
         if not api_key:
-            from edsl.inference_services.exceptions import InferenceServiceEnvironmentError
+            from ..exceptions import InferenceServiceEnvironmentError
             raise InferenceServiceEnvironmentError(
                 f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
             )
@@ -146,7 +146,7 @@ class AzureAIService(InferenceServiceABC):
             endpoint = None
 
         if not endpoint:
-            from edsl.inference_services.exceptions import InferenceServiceEnvironmentError
+            from ..exceptions import InferenceServiceEnvironmentError
             raise InferenceServiceEnvironmentError(
                 f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
             )
edsl/inference_services/services/google_service.py CHANGED
@@ -5,7 +5,7 @@ import google.generativeai as genai
 from google.generativeai.types import GenerationConfig
 from google.api_core.exceptions import InvalidArgument
 
-# from edsl.exceptions.general import MissingAPIKeyError
+# from ...exceptions.general import MissingAPIKeyError
 from ..inference_service_abc import InferenceServiceABC
 from ...language_models import LanguageModel
edsl/inference_services/services/test_service.py CHANGED
@@ -74,7 +74,7 @@ class TestService(InferenceServiceABC):
         p = 1
 
         if random.random() < p:
-            from edsl.inference_services.exceptions import InferenceServiceError
+            from ..exceptions import InferenceServiceError
             raise InferenceServiceError("This is a test error")
 
         if hasattr(self, "func"):