edsl 0.1.50__py3-none-any.whl → 0.1.51__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two published versions.
- edsl/__version__.py +1 -1
- edsl/base/base_exception.py +2 -2
- edsl/buckets/bucket_collection.py +1 -1
- edsl/buckets/exceptions.py +32 -0
- edsl/buckets/token_bucket_api.py +26 -10
- edsl/caching/cache.py +5 -2
- edsl/caching/remote_cache_sync.py +5 -5
- edsl/caching/sql_dict.py +12 -11
- edsl/config/__init__.py +1 -1
- edsl/config/config_class.py +4 -2
- edsl/conversation/Conversation.py +7 -4
- edsl/conversation/car_buying.py +1 -3
- edsl/conversation/mug_negotiation.py +2 -6
- edsl/coop/__init__.py +11 -8
- edsl/coop/coop.py +13 -13
- edsl/coop/coop_functions.py +1 -1
- edsl/coop/ep_key_handling.py +1 -1
- edsl/coop/price_fetcher.py +2 -2
- edsl/coop/utils.py +2 -2
- edsl/dataset/dataset.py +144 -63
- edsl/dataset/dataset_operations_mixin.py +14 -6
- edsl/dataset/dataset_tree.py +3 -3
- edsl/dataset/display/table_renderers.py +6 -3
- edsl/dataset/file_exports.py +4 -4
- edsl/dataset/r/ggplot.py +3 -3
- edsl/inference_services/available_model_fetcher.py +2 -2
- edsl/inference_services/data_structures.py +5 -5
- edsl/inference_services/inference_service_abc.py +1 -1
- edsl/inference_services/inference_services_collection.py +1 -1
- edsl/inference_services/service_availability.py +3 -3
- edsl/inference_services/services/azure_ai.py +3 -3
- edsl/inference_services/services/google_service.py +1 -1
- edsl/inference_services/services/test_service.py +1 -1
- edsl/instructions/change_instruction.py +5 -4
- edsl/instructions/instruction.py +1 -0
- edsl/instructions/instruction_collection.py +5 -4
- edsl/instructions/instruction_handler.py +10 -8
- edsl/interviews/exception_tracking.py +1 -1
- edsl/interviews/interview.py +1 -1
- edsl/interviews/interview_status_dictionary.py +1 -1
- edsl/interviews/interview_task_manager.py +2 -2
- edsl/interviews/request_token_estimator.py +3 -2
- edsl/interviews/statistics.py +2 -2
- edsl/invigilators/invigilators.py +2 -2
- edsl/jobs/__init__.py +39 -2
- edsl/jobs/async_interview_runner.py +1 -1
- edsl/jobs/check_survey_scenario_compatibility.py +5 -5
- edsl/jobs/data_structures.py +2 -2
- edsl/jobs/jobs.py +2 -2
- edsl/jobs/jobs_checks.py +5 -5
- edsl/jobs/jobs_component_constructor.py +2 -2
- edsl/jobs/jobs_pricing_estimation.py +1 -1
- edsl/jobs/jobs_runner_asyncio.py +2 -2
- edsl/jobs/remote_inference.py +1 -1
- edsl/jobs/results_exceptions_handler.py +2 -2
- edsl/language_models/language_model.py +5 -1
- edsl/notebooks/__init__.py +24 -1
- edsl/notebooks/exceptions.py +82 -0
- edsl/notebooks/notebook.py +7 -3
- edsl/notebooks/notebook_to_latex.py +1 -1
- edsl/prompts/__init__.py +23 -2
- edsl/prompts/prompt.py +1 -1
- edsl/questions/__init__.py +4 -4
- edsl/questions/answer_validator_mixin.py +0 -5
- edsl/questions/compose_questions.py +2 -2
- edsl/questions/descriptors.py +1 -1
- edsl/questions/question_base.py +32 -3
- edsl/questions/question_base_prompts_mixin.py +4 -4
- edsl/questions/question_budget.py +503 -102
- edsl/questions/question_check_box.py +658 -156
- edsl/questions/question_dict.py +176 -2
- edsl/questions/question_extract.py +401 -61
- edsl/questions/question_free_text.py +77 -9
- edsl/questions/question_functional.py +118 -9
- edsl/questions/{derived/question_likert_five.py → question_likert_five.py} +2 -2
- edsl/questions/{derived/question_linear_scale.py → question_linear_scale.py} +3 -4
- edsl/questions/question_list.py +246 -26
- edsl/questions/question_matrix.py +586 -73
- edsl/questions/question_multiple_choice.py +213 -47
- edsl/questions/question_numerical.py +360 -29
- edsl/questions/question_rank.py +401 -124
- edsl/questions/question_registry.py +3 -3
- edsl/questions/{derived/question_top_k.py → question_top_k.py} +3 -3
- edsl/questions/{derived/question_yes_no.py → question_yes_no.py} +3 -4
- edsl/questions/register_questions_meta.py +2 -1
- edsl/questions/response_validator_abc.py +6 -2
- edsl/questions/response_validator_factory.py +10 -12
- edsl/results/report.py +1 -1
- edsl/results/result.py +7 -4
- edsl/results/results.py +471 -271
- edsl/results/results_selector.py +2 -2
- edsl/scenarios/construct_download_link.py +3 -3
- edsl/scenarios/scenario.py +1 -2
- edsl/scenarios/scenario_list.py +41 -23
- edsl/surveys/survey_css.py +3 -3
- edsl/surveys/survey_simulator.py +2 -1
- edsl/tasks/__init__.py +22 -2
- edsl/tasks/exceptions.py +72 -0
- edsl/tasks/task_history.py +3 -3
- edsl/tokens/__init__.py +27 -1
- edsl/tokens/exceptions.py +37 -0
- edsl/tokens/interview_token_usage.py +3 -2
- edsl/tokens/token_usage.py +4 -3
- {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/METADATA +1 -1
- {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/RECORD +108 -106
- edsl/questions/derived/__init__.py +0 -0
- {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/LICENSE +0 -0
- {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/WHEEL +0 -0
- {edsl-0.1.50.dist-info → edsl-0.1.51.dist-info}/entry_points.txt +0 -0
edsl/dataset/dataset.py
CHANGED
@@ -1,5 +1,3 @@
-
-
 from __future__ import annotations
 import sys
 import json
@@ -10,7 +8,8 @@ from typing import Any, Union, Optional, TYPE_CHECKING, Callable
 from ..base import PersistenceMixin, HashingMixin
 
 from .dataset_tree import Tree
-from .exceptions import DatasetKeyError, DatasetValueError
+from .exceptions import DatasetKeyError, DatasetValueError, DatasetTypeError
+
 
 from .display.table_display import TableDisplay
 #from .smart_objects import FirstObject
@@ -121,19 +120,9 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
             new_data.append({key: values[:n]})
         return Dataset(new_data)
 
-    def expand(self, field):
-
-
-    # def view(self):
-    #     from perspective.widget import PerspectiveWidget
+    # def expand(self, field):
+    #     return self.to_scenario_list().expand(field)
 
-    #     w = PerspectiveWidget(
-    #         self.to_pandas(),
-    #         plugin="Datagrid",
-    #         aggregates={"datetime": "any"},
-    #         sort=[["date", "desc"]],
-    #     )
-    #     return w
 
     def keys(self) -> list[str]:
         """Return the keys of the dataset.
@@ -287,12 +276,12 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
         if len(potential_matches) == 1:
             return potential_matches[0][1]
         elif len(potential_matches) > 1:
-            from
+            from .exceptions import DatasetKeyError
             raise DatasetKeyError(
                 f"Key '{key}' found in more than one location: {[m[0] for m in potential_matches]}"
             )
 
-        from
+        from .exceptions import DatasetKeyError
         raise DatasetKeyError(f"Key '{key}' not found in any of the dictionaries.")
 
     def first(self) -> dict[str, Any]:
@@ -376,11 +365,12 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
         >>> d = Dataset([{'person_name':["John"]}])
         >>> from edsl import QuestionFreeText
         >>> q = QuestionFreeText(question_text = "How are you, {{ person_name ?}}?", question_name = "how_feeling")
-        >>> d.to(q)
-
+        >>> jobs = d.to(q)
+        >>> isinstance(jobs, object)
+        True
         """
-        from
-        from
+        from ..surveys import Survey
+        from ..questions import QuestionBase
 
         if isinstance(survey_or_question, Survey):
             return survey_or_question.by(self.to_scenario_list())
@@ -402,7 +392,7 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
         """
         for key in keys:
             if key not in self.keys():
-                from
+                from .exceptions import DatasetValueError
                 raise DatasetValueError(f"Key '{key}' not found in the dataset. "
                                         f"Available keys: {self.keys()}"
                                         )
@@ -479,11 +469,11 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
 
         # Validate the input for sampling parameters
         if n is None and frac is None:
-            from
+            from .exceptions import DatasetValueError
             raise DatasetValueError("Either 'n' or 'frac' must be provided for sampling.")
 
         if n is not None and frac is not None:
-            from
+            from .exceptions import DatasetValueError
             raise DatasetValueError("Only one of 'n' or 'frac' should be specified.")
 
         # Get the length of the lists from the first entry
@@ -495,7 +485,7 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
             n = int(total_length * frac)
 
         if not with_replacement and n > total_length:
-            from
+            from .exceptions import DatasetValueError
             raise DatasetValueError(
                 "Sample size cannot be greater than the number of available elements when sampling without replacement."
             )
@@ -513,47 +503,61 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
 
         return self
 
-    def
-        """Return a new dataset with the observations sorted by the given key.
-
-        :param sort_key: The key to sort the observations by.
-        :param reverse: Whether to sort in reverse order.
-
-        >>> d = Dataset([{'a':[1,2,3,4]}, {'b':[4,3,2,1]}])
-        >>> d.order_by('a')
-        Dataset([{'a': [1, 2, 3, 4]}, {'b': [4, 3, 2, 1]}])
-
-        >>> d.order_by('a', reverse=True)
-        Dataset([{'a': [4, 3, 2, 1]}, {'b': [1, 2, 3, 4]}])
-
-        >>> d = Dataset([{'X.a':[1,2,3,4]}, {'X.b':[4,3,2,1]}])
-        >>> d.order_by('a')
-        Dataset([{'X.a': [1, 2, 3, 4]}, {'X.b': [4, 3, 2, 1]}])
-
-
+    def get_sort_indices(self, lst: list[Any], reverse: bool = False, use_numpy: bool = True) -> list[int]:
         """
-
+        Return the indices that would sort the list, using either numpy or pure Python.
+        None values are placed at the end of the sorted list.
 
-
-
-
+        Args:
+            lst: The list to be sorted
+            reverse: Whether to sort in descending order
+            use_numpy: Whether to use numpy implementation (falls back to pure Python if numpy is unavailable)
 
-
-
-
-
-
-
-
+        Returns:
+            A list of indices that would sort the list
+        """
+        if use_numpy:
+            try:
+                import numpy as np
+                # Convert list to numpy array
+                arr = np.array(lst, dtype=object)
+                # Get mask of non-None values
+                mask = ~(arr is None)
+                # Get indices of non-None and None values
+                non_none_indices = np.where(mask)[0]
+                none_indices = np.where(~mask)[0]
+                # Sort non-None values
+                sorted_indices = non_none_indices[np.argsort(arr[mask])]
+                # Combine sorted non-None indices with None indices
+                indices = np.concatenate([sorted_indices, none_indices]).tolist()
+                if reverse:
+                    # When reversing, keep None values at end
+                    indices = sorted_indices[::-1].tolist() + none_indices.tolist()
+                return indices
+            except ImportError:
+                # Fallback to pure Python if numpy is not available
+                pass
+
+        # Pure Python implementation
+        enumerated = list(enumerate(lst))
+        # Sort None values to end by using (is_none, value) as sort key
+        sorted_pairs = sorted(enumerated,
+                              key=lambda x: (x[1] is None, x[1]),
+                              reverse=reverse)
+        return [index for index, _ in sorted_pairs]
+
+    def order_by(self, sort_key: str, reverse: bool = False, use_numpy: bool = True) -> Dataset:
+        """Return a new dataset with the observations sorted by the given key.
 
+        Args:
+            sort_key: The key to sort the observations by
+            reverse: Whether to sort in reverse order
+            use_numpy: Whether to use numpy for sorting (faster for large lists)
+        """
         number_found = 0
         for obs in self.data:
             key, values = list(obs.items())[0]
-
-            # key = list(obs.keys())[0]
-            if (
-                sort_key == key or sort_key == key.split(".")[-1]
-            ):  # e.g., "age" in "scenario.age"
+            if sort_key == key or sort_key == key.split(".")[-1]:
                 relevant_values = values
                 number_found += 1
 
@@ -562,11 +566,9 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
         elif number_found > 1:
             raise DatasetKeyError(f"Key '{sort_key}' found in more than one dictionary.")
 
-
-        sort_indices_list = sort_indices(relevant_values)
+        sort_indices_list = self.get_sort_indices(relevant_values, reverse=reverse, use_numpy=use_numpy)
         new_data = []
         for observation in self.data:
-            # print(observation)
             key, values = list(observation.items())[0]
             new_values = [values[i] for i in sort_indices_list]
             new_data.append({key: new_values})
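For orientation, a small usage sketch of the reworked ordering from the caller's side. This is illustrative only, written against the signatures in the diff above; it exercises the pure-Python path (use_numpy=False), whose None-at-the-end behavior follows directly from the sort key shown in get_sort_indices.

from edsl.dataset import Dataset

d = Dataset([{'a': [3, None, 1]}, {'b': ['x', 'y', 'z']}])

# Stable sort with None values pushed to the end: [2, 0, 1]
idx = d.get_sort_indices([3, None, 1], use_numpy=False)
print(idx)

# order_by reorders every column by the sort order of the chosen key:
# Dataset([{'a': [1, 3, None]}, {'b': ['z', 'x', 'y']}])
print(d.order_by('a', use_numpy=False))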
@@ -646,7 +648,7 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
 
         if max_rows is not None:
             if max_rows > len(data):
-                from
+                from .exceptions import DatasetValueError
                 raise DatasetValueError(
                     "max_rows cannot be greater than the number of rows in the dataset."
                 )
@@ -685,6 +687,19 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
     def from_pandas_dataframe(cls, df):
         result = cls([{col: df[col].tolist()} for col in df.columns])
         return result
+
+    def to_dict(self) -> dict:
+        """
+        Convert the dataset to a dictionary.
+        """
+        return {'data': self.data}
+
+    @classmethod
+    def from_dict(cls, data: dict) -> 'Dataset':
+        """
+        Convert a dictionary to a dataset.
+        """
+        return cls(data['data'])
 
     def to_docx(self, output_file: str, title: str = None) -> None:
         """
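A serialization round trip using the two methods added above would look like this sketch:

from edsl.dataset import Dataset

d = Dataset([{'a': [1, 2, 3]}])

# to_dict() wraps the underlying list of column dicts under a 'data' key,
# giving a plain JSON-serializable structure.
as_dict = d.to_dict()            # {'data': [{'a': [1, 2, 3]}]}

# from_dict() reverses the wrapping.
restored = Dataset.from_dict(as_dict)
assert restored.data == d.data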
@@ -736,6 +751,72 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
         # Save the document
         doc.save(output_file)
 
+    def expand(self, field: str, number_field: bool = False) -> "Dataset":
+        """
+        Expand a field containing lists into multiple rows.
+
+        Args:
+            field: The field containing lists to expand
+            number_field: If True, adds a number field indicating the position in the original list
+
+        Returns:
+            A new Dataset with the expanded rows
+
+        Example:
+            >>> from edsl.dataset import Dataset
+            >>> d = Dataset([{'a': [[1, 2, 3], [4, 5, 6]]}, {'b': ['x', 'y']}])
+            >>> d.expand('a')
+            Dataset([{'a': [1, 2, 3, 4, 5, 6]}, {'b': ['x', 'x', 'x', 'y', 'y', 'y']}])
+        """
+        from collections.abc import Iterable
+
+        # Find the field in the dataset
+        field_data = None
+        for entry in self.data:
+            key = list(entry.keys())[0]
+            if key == field:
+                field_data = entry[key]
+                break
+
+        if field_data is None:
+            raise DatasetKeyError(f"Field '{field}' not found in dataset. Available fields are: {self.keys()}")
+
+
+        # Validate that the field contains lists
+        if not all(isinstance(v, list) for v in field_data):
+            raise DatasetTypeError(f"Field '{field}' must contain lists in all entries")
+
+        # Create new expanded data structure
+        new_data = []
+
+        # Process each field
+        for entry in self.data:
+            key, values = list(entry.items())[0]
+            new_values = []
+
+            if key == field:
+                # This is the field to expand - flatten all sublists
+                for row_values in values:
+                    if not isinstance(row_values, Iterable) or isinstance(row_values, str):
+                        row_values = [row_values]
+                    new_values.extend(row_values)
+            else:
+                # For other fields, repeat each value the appropriate number of times
+                for i, row_value in enumerate(values):
+                    expand_length = len(field_data[i]) if i < len(field_data) else 0
+                    new_values.extend([row_value] * expand_length)
+
+            new_data.append({key: new_values})
+
+        # Add number field if requested
+        if number_field:
+            number_values = []
+            for i, lst in enumerate(field_data):
+                number_values.extend(range(1, len(lst) + 1))
+            new_data.append({f"{field}_number": number_values})
+
+        return Dataset(new_data)
+
 
 if __name__ == "__main__":
     import doctest
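Following the doctest above, a sketch of the number_field option; the expected output is derived by reading the method body, so treat it as illustrative:

from edsl.dataset import Dataset

d = Dataset([{'a': [[1, 2, 3], [4, 5]]}, {'b': ['x', 'y']}])

expanded = d.expand('a', number_field=True)
# 'a' is flattened, 'b' repeats per sublist length, and 'a_number'
# records each element's 1-based position within its original sublist:
# [{'a': [1, 2, 3, 4, 5]},
#  {'b': ['x', 'x', 'x', 'y', 'y']},
#  {'a_number': [1, 2, 3, 1, 2]}]
print(expanded.data)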
edsl/dataset/dataset_operations_mixin.py
CHANGED
@@ -184,6 +184,13 @@ class DataOperationsBase:
         )
 
         return _num_observations
+
+    def chart(self):
+        """
+        Create a chart from the results.
+        """
+        import altair as alt
+        return alt.Chart(self.to_pandas(remove_prefix=True))
 
     def make_tabular(
         self, remove_prefix: bool, pretty_labels: Optional[dict] = None
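A minimal sketch of how the new chart() helper would be used. Here 'results' stands for any existing edsl results object (an assumption for illustration); chart() only returns a bare alt.Chart over to_pandas(remove_prefix=True), so marks and encodings are added with ordinary Altair chaining, and the altair package must be installed.

chart = (
    results.select('how_feeling')   # narrow to one column, as in the doctests below
           .chart()                 # alt.Chart wrapping the pandas DataFrame
           .mark_bar()              # standard Altair chaining from here on
           .encode(x='how_feeling')
)
chart  # renders in a notebook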
@@ -538,13 +545,14 @@ class DataOperationsBase:
         >>> r.select('how_feeling').to_scenario_list()
         ScenarioList([Scenario({'how_feeling': 'OK'}), Scenario({'how_feeling': 'Great'}), Scenario({'how_feeling': 'Terrible'}), Scenario({'how_feeling': 'OK'})])
         """
-        from
+        from ..scenarios import ScenarioList, Scenario
 
         list_of_dicts = self.to_dicts(remove_prefix=remove_prefix)
         scenarios = []
         for d in list_of_dicts:
             scenarios.append(Scenario(d))
         return ScenarioList(scenarios)
+
 
     def to_agent_list(self, remove_prefix: bool = True):
         """Convert the results to a list of dictionaries, one per agent.
@@ -556,7 +564,7 @@ class DataOperationsBase:
         >>> r.select('how_feeling').to_agent_list()
         AgentList([Agent(traits = {'how_feeling': 'OK'}), Agent(traits = {'how_feeling': 'Great'}), Agent(traits = {'how_feeling': 'Terrible'}), Agent(traits = {'how_feeling': 'OK'})])
         """
-        from
+        from ..agents import Agent, AgentList
 
         list_of_dicts = self.to_dicts(remove_prefix=remove_prefix)
         agents = []
@@ -665,7 +673,7 @@ class DataOperationsBase:
     ):
         import os
         import tempfile
-        from
+        from ..utilities.utilities import is_notebook
        from IPython.display import HTML, display
 
         df = self.to_pandas()
@@ -799,7 +807,7 @@ class DataOperationsBase:
             from docx.shared import Pt
             import json
         except ImportError:
-            from
+            from .exceptions import DatasetImportError
             raise DatasetImportError("The python-docx package is required for DOCX export. Install it with 'pip install python-docx'.")
 
         doc = Document()
@@ -871,7 +879,7 @@ class DataOperationsBase:
         >>> isinstance(doc, object)
         True
         """
-        from
+        from ..utilities.utilities import is_notebook
 
         # Prepare the data for the report
         field_data, num_obs, fields, header_fields = self._prepare_report_data(
@@ -1076,7 +1084,7 @@ class DataOperationsBase:
         # Check if the field is ambiguous
         if len(matching_entries) > 1:
             matching_cols = [next(iter(entry.keys())) for entry in matching_entries]
-            from
+            from .exceptions import DatasetValueError
             raise DatasetValueError(
                 f"Ambiguous field name '{field}'. It matches multiple columns: {matching_cols}. "
                 f"Please specify the full column name to flatten."
edsl/dataset/dataset_tree.py
CHANGED
@@ -51,7 +51,7 @@ class Tree:
         else:
             if not set(node_order).issubset(set(self.data.keys())):
                 invalid_keys = set(node_order) - set(self.data.keys())
-                from
+                from .exceptions import DatasetValueError
                 raise DatasetValueError(f"Invalid keys in node_order: {invalid_keys}")
 
         self.root = TreeNode()
@@ -130,7 +130,7 @@ class Tree:
         doc_buffer.seek(0)
 
         base64_string = base64.b64encode(doc_buffer.getvalue()).decode("utf-8")
-        from
+        from ..scenarios.file_store import FileStore
 
         # Create and return FileStore instance
         return FileStore(
@@ -335,7 +335,7 @@ class Tree:
         Returns:
             A string containing the markdown document, or renders markdown in notebooks.
         """
-        from
+        from ..utilities.utilities import is_notebook
         from IPython.display import Markdown, display
 
         if node is None:
edsl/dataset/display/table_renderers.py
CHANGED
@@ -103,9 +103,12 @@ class PandasStyleRenderer(DataTablesRendererABC):
         else:
             df = pd.DataFrame(self.table_data.data, columns=self.table_data.headers)
 
-        styled_df = df.style.set_properties(
-
-
+        styled_df = df.style.set_properties(**{
+            "text-align": "left",
+            "white-space": "pre-wrap",  # Allows text wrapping
+            "max-width": "300px",  # Maximum width before wrapping
+            "word-wrap": "break-word"  # Breaks words that exceed max-width
+        }).background_gradient()
 
         return f"""
         <div style="max-height: 500px; overflow-y: auto;">
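The dict-splat form in the new code is the standard pandas Styler.set_properties pattern. A standalone sketch (sample column names are invented; background_gradient needs matplotlib installed):

import pandas as pd

df = pd.DataFrame({"answer": ["a long free-text response", "ok"],
                   "score": [26, 2]})

# set_properties(**props) applies the same CSS to every cell;
# background_gradient() then shades the numeric columns.
styled = df.style.set_properties(**{
    "text-align": "left",
    "white-space": "pre-wrap",
    "max-width": "300px",
    "word-wrap": "break-word",
}).background_gradient()

html = styled.to_html()  # the renderer wraps this in a scrollable <div>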
edsl/dataset/file_exports.py
CHANGED
@@ -40,7 +40,7 @@ class FileExport(ABC):
 
     def _create_filestore(self, data: Union[str, bytes]):
         """Create a FileStore instance with encoded data."""
-        from ..scenarios import FileStore
+        from ..scenarios.file_store import FileStore
         if isinstance(data, str):
             base64_string = base64.b64encode(data.encode()).decode()
         else:
@@ -203,7 +203,7 @@ class SQLiteExport(TabularExport):
             (self.table_name,),
         )
         if cursor.fetchone():
-            from
+            from .exceptions import DatasetValueError
             raise DatasetValueError(f"Table {self.table_name} already exists")
 
         # Create table
@@ -245,14 +245,14 @@ class SQLiteExport(TabularExport):
         """Validate initialization parameters."""
         valid_if_exists = {"fail", "replace", "append"}
         if self.if_exists not in valid_if_exists:
-            from
+            from .exceptions import DatasetValueError
            raise DatasetValueError(
                 f"if_exists must be one of {valid_if_exists}, got {self.if_exists}"
             )
 
         # Validate table name (basic SQLite identifier validation)
         if not self.table_name.isalnum() and not all(c in "_" for c in self.table_name):
-            from
+            from .exceptions import DatasetValueError
             raise DatasetValueError(
                 f"Invalid table name: {self.table_name}. Must contain only alphanumeric characters and underscores."
             )
edsl/dataset/r/ggplot.py
CHANGED
@@ -30,12 +30,12 @@ class GGPlot:
 
         if result.returncode != 0:
             if result.returncode == 127:
-                from
+                from ..exceptions import DatasetRuntimeError
                 raise DatasetRuntimeError(
                     "Rscript is probably not installed. Please install R from https://cran.r-project.org/"
                 )
             else:
-                from
+                from ..exceptions import DatasetRuntimeError
                 raise DatasetRuntimeError(
                     f"An error occurred while running Rscript: {result.stderr}"
                 )
@@ -49,7 +49,7 @@ class GGPlot:
         """Save the plot to a file."""
         format = filename.split('.')[-1].lower()
         if format not in ['svg', 'png']:
-            from
+            from ..exceptions import DatasetValueError
             raise DatasetValueError("Only 'svg' and 'png' formats are supported")
 
         save_command = f'\nggsave("{filename}", plot = last_plot(), width = {self.width}, height = {self.height}, device = "{format}")'
edsl/inference_services/available_model_fetcher.py
CHANGED
@@ -55,7 +55,7 @@ class AvailableModelFetcher:
 
     :param service: Optional[InferenceServiceABC] - If specified, only fetch models for this service.
 
-    >>> from
+    >>> from .services.open_ai_service import OpenAIService
     >>> af = AvailableModelFetcher([OpenAIService()], {})
     >>> af.available(service="openai")
     [LanguageModelInfo(model_name='...', service_name='openai'), ...]
@@ -155,7 +155,7 @@ class AvailableModelFetcher:
         """The service name is the _inference_service_ attribute of the service."""
         if service_name in self._service_map:
             return self._service_map[service_name]
-        from
+        from .exceptions import InferenceServiceValueError
         raise InferenceServiceValueError(f"Service {service_name} not found")
 
     def _get_all_models(self, force_refresh=False) -> List[LanguageModelInfo]:
edsl/inference_services/data_structures.py
CHANGED
@@ -43,7 +43,7 @@ class LanguageModelInfo:
         elif key == 1:
             return self.service_name
         else:
-            from
+            from .exceptions import InferenceServiceIndexError
             raise InferenceServiceIndexError("Index out of range")
 
     @classmethod
@@ -70,7 +70,7 @@ class AvailableModels(UserList):
         return self.to_dataset().print()
 
     def to_dataset(self):
-        from
+        from ..scenarios.scenario_list import ScenarioList
 
         models, services = zip(
             *[(model.model_name, model.service_name) for model in self]
@@ -106,14 +106,14 @@ class AvailableModels(UserList):
                 ]
             )
             if len(avm) == 0:
-                from
+                from .exceptions import InferenceServiceValueError
                 raise InferenceServiceValueError(
                     "No models found matching the search pattern: " + pattern
                 )
             else:
                 return avm
         except re.error as e:
-            from
+            from .exceptions import InferenceServiceValueError
             raise InferenceServiceValueError(f"Invalid regular expression pattern: {e}")
 
 
@@ -128,7 +128,7 @@ class ServiceToModelsMapping(UserDict):
     def _validate_service_names(self):
         for service in self.service_names:
             if service not in InferenceServiceLiteral:
-                from
+                from .exceptions import InferenceServiceValueError
                 raise InferenceServiceValueError(f"Invalid service name: {service}")
 
     def model_to_services(self) -> dict:
edsl/inference_services/inference_service_abc.py
CHANGED
@@ -26,7 +26,7 @@ class InferenceServiceABC(ABC):
         ]
         for attr in must_have_attributes:
             if not hasattr(cls, attr):
-                from
+                from .exceptions import InferenceServiceNotImplementedError
                 raise InferenceServiceNotImplementedError(
                     f"Class {cls.__name__} must have a '{attr}' attribute."
                 )
edsl/inference_services/inference_services_collection.py
CHANGED
@@ -1,7 +1,7 @@
 from collections import defaultdict
 from typing import Optional, Protocol, Dict, List, Tuple, TYPE_CHECKING
 
-from
+from ..enums import InferenceServiceLiteral
 from .inference_service_abc import InferenceServiceABC
 from .available_model_fetcher import AvailableModelFetcher
 from .exceptions import InferenceServiceError
edsl/inference_services/service_availability.py
CHANGED
@@ -42,7 +42,7 @@ class ServiceAvailability:
     @classmethod
     def models_from_coop(cls) -> AvailableModels:
         if not cls._coop_model_list:
-            from
+            from ..coop.coop import Coop
 
             c = Coop()
             coop_model_list = c.fetch_models()
@@ -74,7 +74,7 @@ class ServiceAvailability:
                 continue
 
         # If we get here, all sources failed
-        from
+        from .exceptions import InferenceServiceRuntimeError
         raise InferenceServiceRuntimeError(
             f"All sources failed to fetch models. Last error: {last_error}"
         )
@@ -93,7 +93,7 @@ class ServiceAvailability:
     @staticmethod
     def _fetch_from_cache(service: "InferenceServiceABC") -> ModelNamesList:
         """Fetch models from local cache."""
-        from
+        from .models_available_cache import models_available
 
         return models_available.get(service._inference_service_, [])
 
edsl/inference_services/services/azure_ai.py
CHANGED
@@ -46,7 +46,7 @@ class AzureAIService(InferenceServiceABC):
         out = []
         azure_endpoints = os.getenv("AZURE_ENDPOINT_URL_AND_KEY", None)
         if not azure_endpoints:
-            from
+            from ..exceptions import InferenceServiceEnvironmentError
             raise InferenceServiceEnvironmentError("AZURE_ENDPOINT_URL_AND_KEY is not defined")
         azure_endpoints = azure_endpoints.split(",")
         for data in azure_endpoints:
@@ -135,7 +135,7 @@ class AzureAIService(InferenceServiceABC):
             api_key = None
 
         if not api_key:
-            from
+            from ..exceptions import InferenceServiceEnvironmentError
             raise InferenceServiceEnvironmentError(
                 f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
             )
@@ -146,7 +146,7 @@ class AzureAIService(InferenceServiceABC):
             endpoint = None
 
         if not endpoint:
-            from
+            from ..exceptions import InferenceServiceEnvironmentError
             raise InferenceServiceEnvironmentError(
                 f"AZURE_ENDPOINT_URL_AND_KEY doesn't have the endpoint:key pair for your model: {model_name}"
             )
edsl/inference_services/services/google_service.py
CHANGED
@@ -5,7 +5,7 @@ import google.generativeai as genai
 from google.generativeai.types import GenerationConfig
 from google.api_core.exceptions import InvalidArgument
 
-# from
+# from ...exceptions.general import MissingAPIKeyError
 from ..inference_service_abc import InferenceServiceABC
 from ...language_models import LanguageModel
 
edsl/inference_services/services/test_service.py
CHANGED
@@ -74,7 +74,7 @@ class TestService(InferenceServiceABC):
         p = 1
 
         if random.random() < p:
-            from
+            from ..exceptions import InferenceServiceError
             raise InferenceServiceError("This is a test error")
 
         if hasattr(self, "func"):