edsl 0.1.36.dev5__py3-none-any.whl → 0.1.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. edsl/__init__.py +1 -0
  2. edsl/__version__.py +1 -1
  3. edsl/agents/Agent.py +92 -41
  4. edsl/agents/AgentList.py +15 -2
  5. edsl/agents/InvigilatorBase.py +15 -25
  6. edsl/agents/PromptConstructor.py +149 -108
  7. edsl/agents/descriptors.py +17 -4
  8. edsl/conjure/AgentConstructionMixin.py +11 -3
  9. edsl/conversation/Conversation.py +66 -14
  10. edsl/conversation/chips.py +95 -0
  11. edsl/coop/coop.py +148 -39
  12. edsl/data/Cache.py +1 -1
  13. edsl/data/RemoteCacheSync.py +25 -12
  14. edsl/exceptions/BaseException.py +21 -0
  15. edsl/exceptions/__init__.py +7 -3
  16. edsl/exceptions/agents.py +17 -19
  17. edsl/exceptions/results.py +11 -8
  18. edsl/exceptions/scenarios.py +22 -0
  19. edsl/exceptions/surveys.py +13 -10
  20. edsl/inference_services/AwsBedrock.py +7 -2
  21. edsl/inference_services/InferenceServicesCollection.py +42 -13
  22. edsl/inference_services/models_available_cache.py +25 -1
  23. edsl/jobs/Jobs.py +306 -71
  24. edsl/jobs/interviews/Interview.py +24 -14
  25. edsl/jobs/interviews/InterviewExceptionCollection.py +1 -1
  26. edsl/jobs/interviews/InterviewExceptionEntry.py +17 -13
  27. edsl/jobs/interviews/ReportErrors.py +2 -2
  28. edsl/jobs/runners/JobsRunnerAsyncio.py +10 -9
  29. edsl/jobs/tasks/TaskHistory.py +1 -0
  30. edsl/language_models/KeyLookup.py +30 -0
  31. edsl/language_models/LanguageModel.py +47 -59
  32. edsl/language_models/__init__.py +1 -0
  33. edsl/prompts/Prompt.py +11 -12
  34. edsl/questions/QuestionBase.py +53 -13
  35. edsl/questions/QuestionBasePromptsMixin.py +1 -33
  36. edsl/questions/QuestionFreeText.py +1 -0
  37. edsl/questions/QuestionFunctional.py +2 -2
  38. edsl/questions/descriptors.py +23 -28
  39. edsl/results/DatasetExportMixin.py +25 -1
  40. edsl/results/Result.py +27 -10
  41. edsl/results/Results.py +34 -121
  42. edsl/results/ResultsDBMixin.py +1 -1
  43. edsl/results/Selector.py +18 -1
  44. edsl/scenarios/FileStore.py +20 -5
  45. edsl/scenarios/Scenario.py +52 -13
  46. edsl/scenarios/ScenarioHtmlMixin.py +7 -2
  47. edsl/scenarios/ScenarioList.py +12 -1
  48. edsl/scenarios/__init__.py +2 -0
  49. edsl/surveys/Rule.py +10 -4
  50. edsl/surveys/Survey.py +100 -77
  51. edsl/utilities/utilities.py +18 -0
  52. {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/METADATA +1 -1
  53. {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/RECORD +55 -51
  54. {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/LICENSE +0 -0
  55. {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/WHEEL +0 -0
@@ -53,33 +53,12 @@ class BaseDescriptor(ABC):
53
53
 
54
54
  def __set__(self, instance, value: Any) -> None:
55
55
  """Set the value of the attribute."""
56
- self.validate(value, instance)
57
- # from edsl.prompts.registry import get_classes
58
-
59
- instance.__dict__[self.name] = value
60
- # if self.name == "_instructions":
61
- # instructions = value
62
- # if value is not None:
63
- # instance.__dict__[self.name] = instructions
64
- # instance.set_instructions = True
65
- # else:
66
- # potential_prompt_classes = get_classes(
67
- # question_type=instance.question_type
68
- # )
69
- # if len(potential_prompt_classes) > 0:
70
- # instructions = potential_prompt_classes[0]().text
71
- # instance.__dict__[self.name] = instructions
72
- # instance.set_instructions = False
73
- # else:
74
- # if not hasattr(instance, "default_instructions"):
75
- # raise Exception(
76
- # "No default instructions found and no matching prompts!"
77
- # )
78
- # instructions = instance.default_instructions
79
- # instance.__dict__[self.name] = instructions
80
- # instance.set_instructions = False
81
-
82
- # instance.set_instructions = value != instance.default_instructions
56
+ new_value = self.validate(value, instance)
57
+
58
+ if new_value is not None:
59
+ instance.__dict__[self.name] = new_value
60
+ else:
61
+ instance.__dict__[self.name] = value
83
62
 
84
63
  def __set_name__(self, owner, name: str) -> None:
85
64
  """Set the name of the attribute."""
@@ -400,10 +379,24 @@ class QuestionTextDescriptor(BaseDescriptor):
400
379
  if contains_single_braced_substring(value):
401
380
  import warnings
402
381
 
382
+ # # warnings.warn(
383
+ # # f"WARNING: Question text contains a single-braced substring: If you intended to parameterize the question with a Scenario this should be changed to a double-braced substring, e.g. {{variable}}.\nSee details on constructing Scenarios in the docs: https://docs.expectedparrot.com/en/latest/scenarios.html",
384
+ # # UserWarning,
385
+ # # )
403
386
  warnings.warn(
404
- f"WARNING: Question text contains a single-braced substring: If you intended to parameterize the question with a Scenario this should be changed to a double-braced substring, e.g. {{variable}}.\nSee details on constructing Scenarios in the docs: https://docs.expectedparrot.com/en/latest/scenarios.html",
387
+ "WARNING: Question text contains a single-braced substring. "
388
+ "If you intended to parameterize the question with a Scenario, this will "
389
+ "be changed to a double-braced substring, e.g. {{variable}}.\n"
390
+ "See details on constructing Scenarios in the docs: "
391
+ "https://docs.expectedparrot.com/en/latest/scenarios.html",
405
392
  UserWarning,
406
393
  )
394
+ # Automatically replace single braces with double braces
395
+ # This is here because if the user is using an f-string, the double brace will get converted to a single brace.
396
+ # This undoes that.
397
+ value = re.sub(r"\{([^\{\}]+)\}", r"{{\1}}", value)
398
+ return value
399
+
407
400
  # iterate through all doubles braces and check if they are valid python identifiers
408
401
  for match in re.finditer(r"\{\{([^\{\}]+)\}\}", value):
409
402
  if " " in match.group(1).strip():
@@ -411,6 +404,8 @@ class QuestionTextDescriptor(BaseDescriptor):
411
404
  f"Question text contains an invalid identifier: '{match.group(1)}'"
412
405
  )
413
406
 
407
+ return None
408
+
414
409
 
415
410
  if __name__ == "__main__":
416
411
  import doctest
@@ -437,7 +437,30 @@ class DatasetExportMixin:
437
437
  b64 = base64.b64encode(csv_string.encode()).decode()
438
438
  return f'<a href="data:file/csv;base64,{b64}" download="my_data.csv">Download CSV file</a>'
439
439
 
440
- def to_pandas(self, remove_prefix: bool = False) -> "pd.DataFrame":
440
+ def to_pandas(
441
+ self, remove_prefix: bool = False, lists_as_strings=False
442
+ ) -> "DataFrame":
443
+ """Convert the results to a pandas DataFrame, ensuring that lists remain as lists.
444
+
445
+ :param remove_prefix: Whether to remove the prefix from the column names.
446
+
447
+ """
448
+ return self._to_pandas_strings(remove_prefix)
449
+ # if lists_as_strings:
450
+ # return self._to_pandas_strings(remove_prefix=remove_prefix)
451
+
452
+ # import pandas as pd
453
+
454
+ # df = pd.DataFrame(self.data)
455
+
456
+ # if remove_prefix:
457
+ # # Optionally remove prefixes from column names
458
+ # df.columns = [col.split(".")[-1] for col in df.columns]
459
+
460
+ # df_sorted = df.sort_index(axis=1) # Sort columns alphabetically
461
+ # return df_sorted
462
+
463
+ def _to_pandas_strings(self, remove_prefix: bool = False) -> "pd.DataFrame":
441
464
  """Convert the results to a pandas DataFrame.
442
465
 
443
466
  :param remove_prefix: Whether to remove the prefix from the column names.
@@ -451,6 +474,7 @@ class DatasetExportMixin:
451
474
  2 Terrible
452
475
  3 OK
453
476
  """
477
+
454
478
  import pandas as pd
455
479
 
456
480
  csv_string = self.to_csv(remove_prefix=remove_prefix)
edsl/results/Result.py CHANGED
@@ -117,6 +117,7 @@ class Result(Base, UserDict):
117
117
  "raw_model_response": raw_model_response or {},
118
118
  "question_to_attributes": question_to_attributes,
119
119
  "generated_tokens": generated_tokens or {},
120
+ "comments_dict": comments_dict or {},
120
121
  }
121
122
  super().__init__(**data)
122
123
  # but also store the data as attributes
@@ -155,15 +156,15 @@ class Result(Base, UserDict):
155
156
  if key in self.question_to_attributes:
156
157
  # You might be tempted to just use the naked key
157
158
  # but this is a bad idea because it pollutes the namespace
158
- question_text_dict[key + "_question_text"] = (
159
- self.question_to_attributes[key]["question_text"]
160
- )
161
- question_options_dict[key + "_question_options"] = (
162
- self.question_to_attributes[key]["question_options"]
163
- )
164
- question_type_dict[key + "_question_type"] = (
165
- self.question_to_attributes[key]["question_type"]
166
- )
159
+ question_text_dict[
160
+ key + "_question_text"
161
+ ] = self.question_to_attributes[key]["question_text"]
162
+ question_options_dict[
163
+ key + "_question_options"
164
+ ] = self.question_to_attributes[key]["question_options"]
165
+ question_type_dict[
166
+ key + "_question_type"
167
+ ] = self.question_to_attributes[key]["question_type"]
167
168
 
168
169
  return {
169
170
  "agent": self.agent.traits
@@ -256,10 +257,25 @@ class Result(Base, UserDict):
256
257
 
257
258
  """
258
259
  d = {}
259
- data_types = self.sub_dicts.keys()
260
+ problem_keys = []
261
+ data_types = sorted(self.sub_dicts.keys())
260
262
  for data_type in data_types:
261
263
  for key in self.sub_dicts[data_type]:
264
+ if key in d:
265
+ import warnings
266
+
267
+ warnings.warn(
268
+ f"Key '{key}' of data type '{data_type}' is already in use. Renaming to {key}_{data_type}"
269
+ )
270
+ problem_keys.append((key, data_type))
271
+ key = f"{key}_{data_type}"
272
+ # raise ValueError(f"Key '{key}' is already in the dictionary")
262
273
  d[key] = data_type
274
+
275
+ for key, data_type in problem_keys:
276
+ self.sub_dicts[data_type][f"{key}_{data_type}"] = self.sub_dicts[
277
+ data_type
278
+ ].pop(key)
263
279
  return d
264
280
 
265
281
  def rows(self, index) -> tuple[int, str, str, str]:
@@ -370,6 +386,7 @@ class Result(Base, UserDict):
370
386
  ),
371
387
  question_to_attributes=json_dict.get("question_to_attributes", None),
372
388
  generated_tokens=json_dict.get("generated_tokens", {}),
389
+ comments_dict=json_dict.get("comments_dict", {}),
373
390
  )
374
391
  return result
375
392
 
edsl/results/Results.py CHANGED
@@ -7,11 +7,17 @@ from __future__ import annotations
7
7
  import json
8
8
  import random
9
9
  from collections import UserList, defaultdict
10
- from typing import Optional, Callable, Any, Type, Union, List
10
+ from typing import Optional, Callable, Any, Type, Union, List, TYPE_CHECKING
11
+
12
+ if TYPE_CHECKING:
13
+ from edsl import Survey, Cache, AgentList, ModelList, ScenarioList
14
+ from edsl.results.Result import Result
15
+ from edsl.jobs.tasks.TaskHistory import TaskHistory
11
16
 
12
17
  from simpleeval import EvalWithCompoundTypes
13
18
 
14
19
  from edsl.exceptions.results import (
20
+ ResultsError,
15
21
  ResultsBadMutationstringError,
16
22
  ResultsColumnNotFoundError,
17
23
  ResultsInvalidNameError,
@@ -40,7 +46,7 @@ class Mixins(
40
46
  ResultsGGMixin,
41
47
  ResultsToolsMixin,
42
48
  ):
43
- def print_long(self, max_rows=None) -> None:
49
+ def print_long(self, max_rows: int = None) -> None:
44
50
  """Print the results in long format.
45
51
 
46
52
  >>> from edsl.results import Results
@@ -84,13 +90,13 @@ class Results(UserList, Mixins, Base):
84
90
 
85
91
  def __init__(
86
92
  self,
87
- survey: Optional["Survey"] = None,
88
- data: Optional[list["Result"]] = None,
93
+ survey: Optional[Survey] = None,
94
+ data: Optional[list[Result]] = None,
89
95
  created_columns: Optional[list[str]] = None,
90
- cache: Optional["Cache"] = None,
96
+ cache: Optional[Cache] = None,
91
97
  job_uuid: Optional[str] = None,
92
98
  total_results: Optional[int] = None,
93
- task_history: Optional["TaskHistory"] = None,
99
+ task_history: Optional[TaskHistory] = None,
94
100
  ):
95
101
  """Instantiate a `Results` object with a survey and a list of `Result` objects.
96
102
 
@@ -110,7 +116,7 @@ class Results(UserList, Mixins, Base):
110
116
  self._total_results = total_results
111
117
  self.cache = cache or Cache()
112
118
 
113
- self.task_history = task_history or TaskHistory(interviews = [])
119
+ self.task_history = task_history or TaskHistory(interviews=[])
114
120
 
115
121
  if hasattr(self, "_add_output_functions"):
116
122
  self._add_output_functions()
@@ -235,11 +241,11 @@ class Results(UserList, Mixins, Base):
235
241
  >>> r3 = r + r2
236
242
  """
237
243
  if self.survey != other.survey:
238
- raise Exception(
239
- "The surveys are not the same so they cannot be added together."
244
+ raise ResultsError(
245
+ "The surveys are not the same so the the results cannot be added together."
240
246
  )
241
247
  if self.created_columns != other.created_columns:
242
- raise Exception(
248
+ raise ResultsError(
243
249
  "The created columns are not the same so they cannot be added together."
244
250
  )
245
251
 
@@ -258,16 +264,7 @@ class Results(UserList, Mixins, Base):
258
264
  from IPython.display import HTML
259
265
 
260
266
  json_str = json.dumps(self.to_dict()["data"], indent=4)
261
- from pygments import highlight
262
- from pygments.lexers import JsonLexer
263
- from pygments.formatters import HtmlFormatter
264
-
265
- formatted_json = highlight(
266
- json_str,
267
- JsonLexer(),
268
- HtmlFormatter(style="default", full=True, noclasses=True),
269
- )
270
- return HTML(formatted_json).data
267
+ return f"<pre>{json_str}</pre>"
271
268
 
272
269
  def _to_dict(self, sort=False):
273
270
  from edsl.data.Cache import Cache
@@ -301,7 +298,7 @@ class Results(UserList, Mixins, Base):
301
298
  "b_not_a": [other_results[i] for i in indices_other],
302
299
  }
303
300
 
304
- @property
301
+ @property
305
302
  def has_unfixed_exceptions(self):
306
303
  return self.task_history.has_unfixed_exceptions
307
304
 
@@ -326,7 +323,7 @@ class Results(UserList, Mixins, Base):
326
323
  def hashes(self) -> set:
327
324
  return set(hash(result) for result in self.data)
328
325
 
329
- def sample(self, n: int) -> "Results":
326
+ def sample(self, n: int) -> Results:
330
327
  """Return a random sample of the results.
331
328
 
332
329
  :param n: The number of samples to return.
@@ -344,7 +341,7 @@ class Results(UserList, Mixins, Base):
344
341
  indices = list(range(len(values)))
345
342
  sampled_indices = random.sample(indices, n)
346
343
  if n > len(indices):
347
- raise ValueError(
344
+ raise ResultsError(
348
345
  f"Cannot sample {n} items from a list of length {len(indices)}."
349
346
  )
350
347
  entry[key] = [values[i] for i in sampled_indices]
@@ -397,11 +394,12 @@ class Results(UserList, Mixins, Base):
397
394
  - Uses the key_to_data_type property of the Result class.
398
395
  - Includes any columns that the user has created with `mutate`
399
396
  """
400
- d = {}
397
+ d: dict = {}
401
398
  for result in self.data:
402
399
  d.update(result.key_to_data_type)
403
400
  for column in self.created_columns:
404
401
  d[column] = "answer"
402
+
405
403
  return d
406
404
 
407
405
  @property
@@ -451,7 +449,7 @@ class Results(UserList, Mixins, Base):
451
449
  from edsl.utilities.utilities import shorten_string
452
450
 
453
451
  if not self.survey:
454
- raise Exception("Survey is not defined so no answer keys are available.")
452
+ raise ResultsError("Survey is not defined so no answer keys are available.")
455
453
 
456
454
  answer_keys = self._data_type_to_keys["answer"]
457
455
  answer_keys = {k for k in answer_keys if "_comment" not in k}
@@ -464,7 +462,7 @@ class Results(UserList, Mixins, Base):
464
462
  return sorted_dict
465
463
 
466
464
  @property
467
- def agents(self) -> "AgentList":
465
+ def agents(self) -> AgentList:
468
466
  """Return a list of all of the agents in the Results.
469
467
 
470
468
  Example:
@@ -478,7 +476,7 @@ class Results(UserList, Mixins, Base):
478
476
  return AgentList([r.agent for r in self.data])
479
477
 
480
478
  @property
481
- def models(self) -> list[Type["LanguageModel"]]:
479
+ def models(self) -> ModelList:
482
480
  """Return a list of all of the models in the Results.
483
481
 
484
482
  Example:
@@ -487,10 +485,12 @@ class Results(UserList, Mixins, Base):
487
485
  >>> r.models[0]
488
486
  Model(model_name = ...)
489
487
  """
490
- return [r.model for r in self.data]
488
+ from edsl import ModelList
489
+
490
+ return ModelList([r.model for r in self.data])
491
491
 
492
492
  @property
493
- def scenarios(self) -> "ScenarioList":
493
+ def scenarios(self) -> ScenarioList:
494
494
  """Return a list of all of the scenarios in the Results.
495
495
 
496
496
  Example:
@@ -567,7 +567,7 @@ class Results(UserList, Mixins, Base):
567
567
  )
568
568
  return sorted(list(all_keys))
569
569
 
570
- def first(self) -> "Result":
570
+ def first(self) -> Result:
571
571
  """Return the first observation in the results.
572
572
 
573
573
  Example:
@@ -817,7 +817,7 @@ class Results(UserList, Mixins, Base):
817
817
 
818
818
  return Results(survey=self.survey, data=new_data, created_columns=None)
819
819
 
820
- def select(self, *columns: Union[str, list[str]]) -> "Dataset":
820
+ def select(self, *columns: Union[str, list[str]]) -> Results:
821
821
  """
822
822
  Select data from the results and format it.
823
823
 
@@ -830,93 +830,12 @@ class Results(UserList, Mixins, Base):
830
830
  Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}])
831
831
 
832
832
  >>> results.select('how_feeling', 'model', 'how_feeling')
833
- Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'model.model': ['...', '...', '...', '...']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}])
833
+ Dataset([{'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'model.model': ['...', '...', '...', '...']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}, {'answer.how_feeling': ['OK', 'Great', 'Terrible', 'OK']}])
834
834
 
835
835
  >>> from edsl import Results; r = Results.example(); r.select('answer.how_feeling_y')
836
836
  Dataset([{'answer.how_feeling_yesterday': ['Great', 'Good', 'OK', 'Terrible']}])
837
837
  """
838
838
 
839
- # if len(self) == 0:
840
- # raise Exception("No data to select from---the Results object is empty.")
841
-
842
- if not columns or columns == ("*",) or columns == (None,):
843
- # is the users passes nothing, then we'll return all the columns
844
- columns = ("*.*",)
845
-
846
- if isinstance(columns[0], list):
847
- columns = tuple(columns[0])
848
-
849
- def get_data_types_to_return(parsed_data_type):
850
- if parsed_data_type == "*": # they want all of the columns
851
- return self.known_data_types
852
- else:
853
- if parsed_data_type not in self.known_data_types:
854
- raise Exception(
855
- f"Data type {parsed_data_type} not found in data. Did you mean one of {self.known_data_types}"
856
- )
857
- return [parsed_data_type]
858
-
859
- # we're doing to populate this with the data we want to fetch
860
- to_fetch = defaultdict(list)
861
-
862
- new_data = []
863
- items_in_order = []
864
- # iterate through the passed columns
865
- for column in columns:
866
- # a user could pass 'result.how_feeling' or just 'how_feeling'
867
- matches = self._matching_columns(column)
868
- if len(matches) > 1:
869
- raise Exception(
870
- f"Column '{column}' is ambiguous. Did you mean one of {matches}?"
871
- )
872
- if len(matches) == 0 and ".*" not in column:
873
- raise Exception(f"Column '{column}' not found in data.")
874
- if len(matches) == 1:
875
- column = matches[0]
876
-
877
- parsed_data_type, parsed_key = self._parse_column(column)
878
- data_types = get_data_types_to_return(parsed_data_type)
879
- found_once = False # we need to track this to make sure we found the key at least once
880
-
881
- for data_type in data_types:
882
- # the keys for that data_type e.g.,# if data_type is 'answer', then the keys are 'how_feeling', 'how_feeling_comment', etc.
883
- relevant_keys = self._data_type_to_keys[data_type]
884
-
885
- for key in relevant_keys:
886
- if key == parsed_key or parsed_key == "*":
887
- found_once = True
888
- to_fetch[data_type].append(key)
889
- items_in_order.append(data_type + "." + key)
890
-
891
- if not found_once:
892
- raise Exception(f"Key {parsed_key} not found in data.")
893
-
894
- for data_type in to_fetch:
895
- for key in to_fetch[data_type]:
896
- entries = self._fetch_list(data_type, key)
897
- new_data.append({data_type + "." + key: entries})
898
-
899
- def sort_by_key_order(dictionary):
900
- # Extract the single key from the dictionary
901
- single_key = next(iter(dictionary))
902
- # Return the index of this key in the list_of_keys
903
- return items_in_order.index(single_key)
904
-
905
- # sorted(new_data, key=sort_by_key_order)
906
- from edsl.results.Dataset import Dataset
907
-
908
- sorted_new_data = []
909
-
910
- # WORKS but slow
911
- for key in items_in_order:
912
- for d in new_data:
913
- if key in d:
914
- sorted_new_data.append(d)
915
- break
916
-
917
- return Dataset(sorted_new_data)
918
-
919
- def select(self, *columns: Union[str, list[str]]) -> "Results":
920
839
  from edsl.results.Selector import Selector
921
840
 
922
841
  if len(self) == 0:
@@ -1026,6 +945,7 @@ class Results(UserList, Mixins, Base):
1026
945
  Traceback (most recent call last):
1027
946
  ...
1028
947
  edsl.exceptions.results.ResultsFilterError: You must use '==' instead of '=' in the filter expression.
948
+ ...
1029
949
 
1030
950
  >>> r.filter("how_feeling == 'Great' or how_feeling == 'Terrible'").select('how_feeling').print()
1031
951
  ┏━━━━━━━━━━━━━━┓
@@ -1103,6 +1023,7 @@ class Results(UserList, Mixins, Base):
1103
1023
  stop_on_exception=True,
1104
1024
  skip_retry=True,
1105
1025
  raise_validation_errors=True,
1026
+ disable_remote_cache=True,
1106
1027
  disable_remote_inference=True,
1107
1028
  )
1108
1029
  return results
@@ -1110,14 +1031,6 @@ class Results(UserList, Mixins, Base):
1110
1031
  def rich_print(self):
1111
1032
  """Display an object as a table."""
1112
1033
  pass
1113
- # with io.StringIO() as buf:
1114
- # console = Console(file=buf, record=True)
1115
-
1116
- # for index, result in enumerate(self):
1117
- # console.print(f"Result {index}")
1118
- # console.print(result.rich_print())
1119
-
1120
- # return console.export_text()
1121
1034
 
1122
1035
  def __str__(self):
1123
1036
  data = self.to_dict()["data"]
@@ -93,7 +93,7 @@ class ResultsDBMixin:
93
93
  from sqlalchemy import create_engine
94
94
 
95
95
  engine = create_engine("sqlite:///:memory:")
96
- df = self.to_pandas(remove_prefix=remove_prefix)
96
+ df = self.to_pandas(remove_prefix=remove_prefix, lists_as_strings=True)
97
97
  df.to_sql("self", engine, index=False, if_exists="replace")
98
98
  return engine.connect()
99
99
  else:
edsl/results/Selector.py CHANGED
@@ -12,6 +12,7 @@ class Selector:
12
12
  fetch_list_func,
13
13
  columns: List[str],
14
14
  ):
15
+ """Selects columns from a Results object"""
15
16
  self.known_data_types = known_data_types
16
17
  self._data_type_to_keys = data_type_to_keys
17
18
  self._key_to_data_type = key_to_data_type
@@ -21,10 +22,19 @@ class Selector:
21
22
  def select(self, *columns: Union[str, List[str]]) -> "Dataset":
22
23
  columns = self._normalize_columns(columns)
23
24
  to_fetch = self._get_columns_to_fetch(columns)
25
+ # breakpoint()
24
26
  new_data = self._fetch_data(to_fetch)
25
27
  return Dataset(new_data)
26
28
 
27
29
  def _normalize_columns(self, columns: Union[str, List[str]]) -> tuple:
30
+ """Normalize the columns to a tuple of strings
31
+
32
+ >>> s = Selector([], {}, {}, lambda x, y: x, [])
33
+ >>> s._normalize_columns([["a", "b"], ])
34
+ ('a', 'b')
35
+ >>> s._normalize_columns(None)
36
+ ('*.*',)
37
+ """
28
38
  if not columns or columns == ("*",) or columns == (None,):
29
39
  return ("*.*",)
30
40
  if isinstance(columns[0], list):
@@ -37,6 +47,7 @@ class Selector:
37
47
 
38
48
  for column in columns:
39
49
  matches = self._find_matching_columns(column)
50
+ # breakpoint()
40
51
  self._validate_matches(column, matches)
41
52
 
42
53
  if len(matches) == 1:
@@ -52,7 +63,7 @@ class Selector:
52
63
  search_in_list = self.columns
53
64
  else:
54
65
  search_in_list = [s.split(".")[1] for s in self.columns]
55
-
66
+ # breakpoint()
56
67
  matches = [s for s in search_in_list if s.startswith(partial_name)]
57
68
  return [partial_name] if partial_name in matches else matches
58
69
 
@@ -116,3 +127,9 @@ class Selector:
116
127
  new_data.append({f"{data_type}.{key}": entries})
117
128
 
118
129
  return [d for key in self.items_in_order for d in new_data if key in d]
130
+
131
+
132
+ if __name__ == "__main__":
133
+ import doctest
134
+
135
+ doctest.testmod()
@@ -77,8 +77,19 @@ class FileStore(Scenario):
77
77
  def __str__(self):
78
78
  return "FileStore: self.path"
79
79
 
80
+ @classmethod
81
+ def example(self):
82
+ import tempfile
83
+
84
+ with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f:
85
+ f.write(b"Hello, World!")
86
+
87
+ return self(path=f.name)
88
+
80
89
  @property
81
90
  def size(self) -> int:
91
+ if self.base64_string != None:
92
+ return (len(self.base64_string) / 4.0) * 3 # from base64 to char size
82
93
  return os.path.getsize(self.path)
83
94
 
84
95
  def upload_google(self, refresh: bool = False) -> None:
@@ -93,7 +104,7 @@ class FileStore(Scenario):
93
104
  return cls(**d)
94
105
 
95
106
  def __repr__(self):
96
- return f"FileStore({self.path})"
107
+ return f"FileStore(path='{self.path}')"
97
108
 
98
109
  def encode_file_to_base64_string(self, file_path: str):
99
110
  try:
@@ -272,7 +283,8 @@ class CSVFileStore(FileStore):
272
283
 
273
284
  with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as f:
274
285
  r.to_csv(filename=f.name)
275
- return cls(f.name)
286
+
287
+ return cls(f.name)
276
288
 
277
289
  def view(self):
278
290
  import pandas as pd
@@ -352,7 +364,8 @@ class PDFFileStore(FileStore):
352
364
 
353
365
  with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
354
366
  f.write(pdf_string.encode())
355
- return cls(f.name)
367
+
368
+ return cls(f.name)
356
369
 
357
370
 
358
371
  class PNGFileStore(FileStore):
@@ -367,7 +380,8 @@ class PNGFileStore(FileStore):
367
380
 
368
381
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
369
382
  f.write(png_string.encode())
370
- return cls(f.name)
383
+
384
+ return cls(f.name)
371
385
 
372
386
  def view(self):
373
387
  import matplotlib.pyplot as plt
@@ -407,7 +421,8 @@ class HTMLFileStore(FileStore):
407
421
 
408
422
  with tempfile.NamedTemporaryFile(suffix=".html", delete=False) as f:
409
423
  f.write("<html><body><h1>Test</h1></body></html>".encode())
410
- return cls(f.name)
424
+
425
+ return cls(f.name)
411
426
 
412
427
  def view(self):
413
428
  import webbrowser