edsl 0.1.42__py3-none-any.whl → 0.1.44__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. edsl/Base.py +15 -6
  2. edsl/__version__.py +1 -1
  3. edsl/agents/Invigilator.py +1 -1
  4. edsl/agents/PromptConstructor.py +92 -21
  5. edsl/agents/QuestionInstructionPromptBuilder.py +68 -9
  6. edsl/agents/prompt_helpers.py +2 -2
  7. edsl/coop/coop.py +100 -22
  8. edsl/enums.py +3 -1
  9. edsl/exceptions/coop.py +4 -0
  10. edsl/inference_services/AnthropicService.py +2 -0
  11. edsl/inference_services/AvailableModelFetcher.py +4 -1
  12. edsl/inference_services/GoogleService.py +2 -0
  13. edsl/inference_services/GrokService.py +11 -0
  14. edsl/inference_services/InferenceServiceABC.py +1 -0
  15. edsl/inference_services/OpenAIService.py +1 -0
  16. edsl/inference_services/TestService.py +1 -0
  17. edsl/inference_services/registry.py +2 -0
  18. edsl/jobs/Jobs.py +54 -35
  19. edsl/jobs/JobsChecks.py +7 -7
  20. edsl/jobs/JobsPrompts.py +57 -6
  21. edsl/jobs/JobsRemoteInferenceHandler.py +41 -25
  22. edsl/jobs/buckets/BucketCollection.py +30 -0
  23. edsl/jobs/data_structures.py +1 -0
  24. edsl/language_models/LanguageModel.py +5 -2
  25. edsl/language_models/key_management/KeyLookupBuilder.py +47 -20
  26. edsl/language_models/key_management/models.py +10 -4
  27. edsl/language_models/model.py +43 -11
  28. edsl/prompts/Prompt.py +124 -61
  29. edsl/questions/descriptors.py +32 -18
  30. edsl/questions/question_base_gen_mixin.py +1 -0
  31. edsl/results/DatasetExportMixin.py +35 -6
  32. edsl/results/Results.py +180 -1
  33. edsl/results/ResultsGGMixin.py +117 -60
  34. edsl/scenarios/FileStore.py +19 -8
  35. edsl/scenarios/Scenario.py +33 -0
  36. edsl/scenarios/ScenarioList.py +22 -3
  37. edsl/scenarios/ScenarioListPdfMixin.py +9 -3
  38. edsl/surveys/Survey.py +27 -6
  39. {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/METADATA +3 -4
  40. {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/RECORD +42 -41
  41. {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/LICENSE +0 -0
  42. {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/WHEEL +0 -0
edsl/results/Results.py CHANGED
@@ -38,6 +38,64 @@ from edsl.results.ResultsGGMixin import ResultsGGMixin
38
38
  from edsl.results.results_fetch_mixin import ResultsFetchMixin
39
39
  from edsl.utilities.remove_edsl_version import remove_edsl_version
40
40
 
41
def ensure_fetched(method):
    """Decorator that fetches remote data before calling *method* if it is not loaded yet.

    Whether the data is loaded is read from the instance's ``completed`` flag
    (``Results.__init__`` sets it to ``True``; ``Results.from_job_info`` sets it to
    ``False``). The previous check read ``self._fetched``, which ``__init__`` never
    sets (it sets ``_fetching``), so every decorated call raised AttributeError.

    :param method: The instance method to wrap.
    :return: A wrapper that calls ``self.fetch_remote(self.job_info)`` first when
        the instance is not completed and has ``job_info``, then calls *method*.
    """
    from functools import wraps

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        # Instances without a ``completed`` flag are treated as already loaded.
        if not getattr(self, "completed", True) and hasattr(self, "job_info"):
            self.fetch_remote(self.job_info)
        return method(self, *args, **kwargs)

    return wrapper
50
+
51
def ensure_ready(method):
    """
    Decorator for Results methods.

    If the Results object is not ready, for most methods we return a NotReadyObject.
    However, for __repr__ (and other methods that need to return a string), we return
    the string representation of NotReadyObject.

    Before giving up, one attempt is made to fetch the remote data via
    ``self.fetch_remote(self.job_info)``. A failed fetch is printed, not raised.
    A missing ``job_info`` attribute no longer raises AttributeError; the
    placeholder simply receives ``None``.
    """
    from functools import wraps

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        if self.completed:
            return method(self, *args, **kwargs)

        job_info = getattr(self, "job_info", None)

        # Attempt to fetch remote data
        if job_info is not None:
            try:
                self.fetch_remote(job_info)
            except Exception as e:
                print(f"Error during fetch_remote in {method.__name__}: {e}")

        if self.completed:
            return method(self, *args, **kwargs)

        not_ready = NotReadyObject(name=method.__name__, job_info=job_info)
        # __repr__/__str__ must return a string, not the placeholder object.
        if method.__name__ in ("__repr__", "__str__"):
            return repr(not_ready)
        return not_ready

    return wrapper
80
+
81
class NotReadyObject:
    """Placeholder returned by Results methods while the remote job is still running.

    Any non-dunder attribute access or call returns the same placeholder, so a
    chained expression such as ``results.select("x").to_pandas()`` evaluates to
    the placeholder instead of raising. Its ``repr`` lists the job's creation data.
    """

    def __init__(self, name: str, job_info: "RemoteJobInfo"):
        self.name = name  # name of the Results method that was called
        self.job_info = job_info

    def __repr__(self):
        lines = ["Results not ready - job still running on server."]
        # Tolerate job_info being None or lacking creation_data.
        creation_data = getattr(self.job_info, "creation_data", None) or {}
        lines.extend(f"{key}: {value}" for key, value in creation_data.items())
        return "\n".join(lines)

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. Refusing dunder names
        # keeps protocols probed via getattr/hasattr (copy's __setstate__,
        # pickle's __getstate__, ...) working instead of handing them a
        # callable placeholder.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return self

    def __call__(self, *args, **kwargs):
        return self
41
99
 
42
100
  class Mixins(
43
101
  ResultsExportMixin,
@@ -93,6 +151,16 @@ class Results(UserList, Mixins, Base):
93
151
  "cache_keys",
94
152
  ]
95
153
 
154
@classmethod
def from_job_info(cls, job_info: "RemoteJobInfo") -> Results:
    """
    Instantiate a placeholder `Results` object for a job that is still running remotely.

    The returned object is empty and marked ``completed = False``. The stored
    ``job_info`` is what ``fetch_remote``/``fetch`` later use to download the
    finished results.

    :param job_info: Remote job details. ``fetch_remote`` reads ``job_info.job_uuid``
        and ``NotReadyObject`` reads ``job_info.creation_data``, so this is a
        ``RemoteJobInfo``-like object rather than a plain dict.
    :return: An empty, not-yet-completed Results instance.
    """
    results = cls()
    # __init__ marks every Results as completed; override for the pending case.
    results.completed = False
    results.job_info = job_info
    return results
163
+
96
164
  def __init__(
97
165
  self,
98
166
  survey: Optional[Survey] = None,
@@ -112,6 +180,8 @@ class Results(UserList, Mixins, Base):
112
180
  :param total_results: An integer representing the total number of results.
113
181
  :cache: A Cache object.
114
182
  """
183
+ self.completed = True
184
+ self._fetching = False
115
185
  super().__init__(data)
116
186
  from edsl.data.Cache import Cache
117
187
  from edsl.jobs.tasks.TaskHistory import TaskHistory
@@ -315,7 +385,22 @@ class Results(UserList, Mixins, Base):
315
385
  data=self.data + other.data,
316
386
  created_columns=self.created_columns,
317
387
  )
318
-
388
+
389
def _repr_html_(self):
    """Return the notebook HTML representation, fetching remote results first if needed.

    While the remote job is unfinished, a plain "not ready" message is returned.
    Errors from ``fetch_remote`` are printed rather than raised, as the
    ``ensure_ready`` decorator does. Displaying a pending Results therefore
    does not fail with a traceback.
    """
    if not self.completed and hasattr(self, "job_info"):
        try:
            self.fetch_remote(self.job_info)
        except Exception as e:
            print(f"Error during fetch_remote in _repr_html_: {e}")

    if not self.completed:
        return "Results not ready to call"

    return super()._repr_html_()
398
+
399
+ # @ensure_ready
400
+ # def __str__(self):
401
+ # super().__str__()
402
+
403
@ensure_ready
def __repr__(self) -> str:
    """Return a string showing the data, the survey repr and the created columns."""
    return (
        f"Results(data = {self.data}, survey = {self.survey!r}, "
        f"created_columns = {self.created_columns})"
    )
321
406
 
@@ -732,6 +817,7 @@ class Results(UserList, Mixins, Base):
732
817
 
733
818
  return self.recode(column, recode_function=f, new_var_name=new_var_name)
734
819
 
820
+ @ensure_ready
735
821
  def recode(
736
822
  self, column: str, recode_function: Optional[Callable], new_var_name=None
737
823
  ) -> Results:
@@ -760,6 +846,7 @@ class Results(UserList, Mixins, Base):
760
846
  created_columns=self.created_columns + [new_var_name],
761
847
  )
762
848
 
849
+ @ensure_ready
763
850
  def add_column(self, column_name: str, values: list) -> Results:
764
851
  """Adds columns to Results
765
852
 
@@ -780,6 +867,7 @@ class Results(UserList, Mixins, Base):
780
867
  created_columns=self.created_columns + [column_name],
781
868
  )
782
869
 
870
+ @ensure_ready
783
871
  def add_columns_from_dict(self, columns: List[dict]) -> Results:
784
872
  """Adds columns to Results from a list of dictionaries.
785
873
 
@@ -829,6 +917,7 @@ class Results(UserList, Mixins, Base):
829
917
  evaluator.functions.update(int=int, float=float)
830
918
  return evaluator
831
919
 
920
+ @ensure_ready
832
921
  def mutate(
833
922
  self, new_var_string: str, functions_dict: Optional[dict] = None
834
923
  ) -> Results:
@@ -879,6 +968,7 @@ class Results(UserList, Mixins, Base):
879
968
  created_columns=self.created_columns + [var_name],
880
969
  )
881
970
 
971
+ @ensure_ready
882
972
  def add_column(self, column_name: str, values: list) -> Results:
883
973
  """Adds columns to Results
884
974
 
@@ -899,6 +989,7 @@ class Results(UserList, Mixins, Base):
899
989
  created_columns=self.created_columns + [column_name],
900
990
  )
901
991
 
992
+ @ensure_ready
902
993
  def rename(self, old_name: str, new_name: str) -> Results:
903
994
  """Rename an answer column in a Results object.
904
995
 
@@ -916,6 +1007,7 @@ class Results(UserList, Mixins, Base):
916
1007
 
917
1008
  return self
918
1009
 
1010
+ @ensure_ready
919
1011
  def shuffle(self, seed: Optional[str] = "edsl") -> Results:
920
1012
  """Shuffle the results.
921
1013
 
@@ -932,6 +1024,7 @@ class Results(UserList, Mixins, Base):
932
1024
  random.shuffle(new_data)
933
1025
  return Results(survey=self.survey, data=new_data, created_columns=None)
934
1026
 
1027
+ @ensure_ready
935
1028
  def sample(
936
1029
  self,
937
1030
  n: Optional[int] = None,
@@ -971,6 +1064,7 @@ class Results(UserList, Mixins, Base):
971
1064
 
972
1065
  return Results(survey=self.survey, data=new_data, created_columns=None)
973
1066
 
1067
+ @ensure_ready
974
1068
  def select(self, *columns: Union[str, list[str]]) -> Results:
975
1069
  """
976
1070
  Select data from the results and format it.
@@ -1004,6 +1098,7 @@ class Results(UserList, Mixins, Base):
1004
1098
  )
1005
1099
  return selector.select(*columns)
1006
1100
 
1101
+ @ensure_ready
1007
1102
  def sort_by(self, *columns: str, reverse: bool = False) -> Results:
1008
1103
  """Sort the results by one or more columns."""
1009
1104
  import warnings
@@ -1019,6 +1114,7 @@ class Results(UserList, Mixins, Base):
1019
1114
  return column.split(".")
1020
1115
  return self._key_to_data_type[column], column
1021
1116
 
1117
+ @ensure_ready
1022
1118
  def order_by(self, *columns: str, reverse: bool = False) -> Results:
1023
1119
  """Sort the results by one or more columns.
1024
1120
 
@@ -1055,6 +1151,7 @@ class Results(UserList, Mixins, Base):
1055
1151
  new_data = sorted(self.data, key=sort_key, reverse=reverse)
1056
1152
  return Results(survey=self.survey, data=new_data, created_columns=None)
1057
1153
 
1154
+ @ensure_ready
1058
1155
  def filter(self, expression: str) -> Results:
1059
1156
  """
1060
1157
  Filter based on the given expression and returns the filtered `Results`.
@@ -1156,6 +1253,7 @@ class Results(UserList, Mixins, Base):
1156
1253
  """Display an object as a table."""
1157
1254
  pass
1158
1255
 
1256
@ensure_ready
def __str__(self):
    """Serialize the "data" entry of ``to_dict()`` as JSON indented by 4 spaces."""
    as_dict = self.to_dict()
    return json.dumps(as_dict["data"], indent=4)
@@ -1178,6 +1276,87 @@ class Results(UserList, Mixins, Base):
1178
1276
  [1, 1, 0, 0]
1179
1277
  """
1180
1278
  return [r.score(f) for r in self.data]
1279
+
1280
+
1281
def fetch_remote(self, job_info: "RemoteJobInfo") -> bool:
    """
    Fetches the remote Results object using the provided RemoteJobInfo and updates this instance with the remote data.

    This is useful when you have a Results object that was created locally but want to sync it with
    the latest data from the remote server.

    Args:
        job_info: RemoteJobInfo object containing the job_uuid and other remote job details

    Returns:
        False if the remote job's status is neither "completed" nor "failed"
        (nothing is changed). True once data, survey, created_columns, cache and
        task_history have been copied from the remote Results and ``completed``
        is set.

    Raises:
        ResultsError: if the finished job has no ``results_uuid``, or if checking
            the job status or downloading the results fails. The underlying
            exception is chained as ``__cause__``.
    """
    try:
        from edsl.coop.coop import Coop
        from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler

        # Get the remote job data
        remote_job_data = JobsRemoteInferenceHandler.check_status(job_info.job_uuid)

        if remote_job_data.get("status") not in ["completed", "failed"]:
            return False

        results_uuid = remote_job_data.get("results_uuid")
        if not results_uuid:
            raise ResultsError("No results_uuid found in remote job data")

        # Fetch the remote Results object
        remote_results = Coop().get(results_uuid, expected_object_type="results")

        # Update this instance with remote data
        self.data = remote_results.data
        self.survey = remote_results.survey
        self.created_columns = remote_results.created_columns
        self.cache = remote_results.cache
        self.task_history = remote_results.task_history
        self.completed = True

        # Set job_uuid and results_uuid from remote data
        self.job_uuid = job_info.job_uuid
        if hasattr(remote_results, "results_uuid"):
            self.results_uuid = remote_results.results_uuid

        return True

    except ResultsError:
        # Already descriptive; re-wrapping would duplicate the message prefix.
        raise
    except Exception as e:
        raise ResultsError(f"Failed to fetch remote results: {str(e)}") from e
1328
+
1329
def fetch(self, polling_interval: Union[float, int] = 1.0) -> Results:
    """
    Polls the server for job completion and updates this Results instance with the completed data.

    Blocks until the remote job's status is "completed" or "failed". It then
    calls ``fetch_remote`` to copy the results into this instance.

    Args:
        polling_interval: Number of seconds to wait between polling attempts (default: 1.0)

    Returns:
        self: The updated Results instance

    Raises:
        ResultsError: if this object has no ``job_info``, or if polling or
            fetching fails (the underlying exception is chained as ``__cause__``).
    """
    if not hasattr(self, "job_info"):
        raise ResultsError("No job info available - this Results object wasn't created from a remote job")

    import time

    from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler

    try:
        # Get the remote job data
        remote_job_data = JobsRemoteInferenceHandler.check_status(self.job_info.job_uuid)

        while remote_job_data.get("status") not in ["completed", "failed"]:
            print("Waiting for remote job to complete...")
            time.sleep(polling_interval)
            remote_job_data = JobsRemoteInferenceHandler.check_status(self.job_info.job_uuid)

        # Once complete, fetch the full results
        self.fetch_remote(self.job_info)
        return self

    except ResultsError:
        # Raised by fetch_remote with its own message; don't wrap it again.
        raise
    except Exception as e:
        raise ResultsError(f"Failed to fetch remote results: {str(e)}") from e
1181
1360
 
1182
1361
 
1183
1362
  def main(): # pragma: no cover
@@ -5,46 +5,113 @@ import tempfile
5
5
  from typing import Optional
6
6
 
7
7
 
8
class GGPlot:
    """A class to handle ggplot2 plot display and saving.

    Rendering shells out to ``Rscript``. Nothing runs until the plot is displayed
    in a notebook (``_repr_html_``) or written to disk (``save``).
    """

    def __init__(self, r_code: str, width: float = 6, height: float = 4):
        """Initialize with R code and the width/height passed to ``ggsave``."""
        self.r_code = r_code
        self.width = width
        self.height = height
        self._svg_data = None  # cached SVG markup for notebook display
        self._saved = False  # Track if the plot was saved

    @staticmethod
    def _r_string(value: str) -> str:
        """Return *value* as a double-quoted R string literal.

        Backslashes and double quotes are escaped. Paths such as Windows
        temp-file names therefore reach R unchanged instead of being read as
        R escape sequences.
        """
        escaped = value.replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    def _ggsave_command(self, filename: str, device: str) -> str:
        """Build the R ``ggsave`` call that writes ``last_plot()`` to *filename*."""
        return (
            f"\nggsave({self._r_string(filename)}, plot = last_plot(), "
            f'width = {self.width}, height = {self.height}, device = "{device}")'
        )

    def _execute_r_code(self, save_command: str = ""):
        """Run ``self.r_code`` followed by *save_command* through ``Rscript -``.

        :raises RuntimeError: if Rscript exits non-zero (exit code 127 is reported
            as R probably not being installed).
        :return: The ``subprocess.CompletedProcess`` of the run.
        """
        full_r_code = self.r_code + save_command

        result = subprocess.run(
            ["Rscript", "-"],
            input=full_r_code,
            text=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )

        if result.returncode != 0:
            if result.returncode == 127:
                raise RuntimeError(
                    "Rscript is probably not installed. Please install R from https://cran.r-project.org/"
                )
            else:
                raise RuntimeError(
                    f"An error occurred while running Rscript: {result.stderr}"
                )

        if result.stderr:
            print("Error in R script:", result.stderr)

        return result

    def save(self, filename: str):
        """Save the plot to *filename*; the device is taken from its extension.

        After a successful save, ``_repr_html_`` returns ``None``.

        :raises ValueError: if the extension is not ``svg`` or ``png``.
        :raises RuntimeError: if running Rscript fails.
        """
        device = filename.split(".")[-1].lower()
        if device not in ["svg", "png"]:
            raise ValueError("Only 'svg' and 'png' formats are supported")

        self._execute_r_code(self._ggsave_command(filename, device))

        self._saved = True
        print(f"File saved to: {filename}")
        return None  # Return None instead of self

    def _repr_html_(self):
        """Display the plot in a Jupyter notebook as inline SVG markup.

        Returns ``None`` once the plot has been saved. The SVG is rendered once
        and cached in ``self._svg_data``.
        """
        # Don't display if the plot was saved
        if self._saved:
            return None

        if self._svg_data is None:
            import os
            import tempfile

            # A temporary directory (not an open NamedTemporaryFile) is used
            # because on Windows an open NamedTemporaryFile cannot be opened
            # again by name, by Rscript or by open() below.
            with tempfile.TemporaryDirectory() as tmp_dir:
                svg_path = os.path.join(tmp_dir, "plot.svg")
                self._execute_r_code(self._ggsave_command(svg_path, "svg"))
                with open(svg_path, "r", encoding="utf-8") as f:
                    self._svg_data = f.read()

        return self._svg_data
77
+
8
78
  class ResultsGGMixin:
9
79
  """Mixin class for ggplot2 plotting."""
10
80
 
11
81
  def ggplot2(
12
82
  self,
13
83
  ggplot_code: str,
14
- filename: str = None,
15
84
  shape="wide",
16
85
  sql: str = None,
17
86
  remove_prefix: bool = True,
18
87
  debug: bool = False,
19
88
  height=4,
20
89
  width=6,
21
- format="svg",
22
90
  factor_orders: Optional[dict] = None,
23
91
  ):
24
92
  """Create a ggplot2 plot from a DataFrame.
25
93
 
94
+ Returns a GGPlot object that can be displayed in a notebook or saved to a file.
95
+
26
96
  :param ggplot_code: The ggplot2 code to execute.
27
- :param filename: The filename to save the plot to.
28
97
  :param shape: The shape of the data in the DataFrame (wide or long).
29
98
  :param sql: The SQL query to execute beforehand to manipulate the data.
30
99
  :param remove_prefix: Whether to remove the prefix from the column names.
31
100
  :param debug: Whether to print the R code instead of executing it.
32
101
  :param height: The height of the plot in inches.
33
102
  :param width: The width of the plot in inches.
34
- :param format: The format to save the plot in (png or svg).
35
103
  :param factor_orders: A dictionary of factor columns and their order.
36
104
  """
37
-
38
105
  if sql == None:
39
106
  sql = "select * from self"
40
107
 
41
108
  if shape == "long":
42
109
  df = self.sql(sql, shape="long")
43
110
  elif shape == "wide":
44
- df = self.sql(sql, shape="wide", remove_prefix=remove_prefix)
111
+ df = self.sql(sql, remove_prefix=remove_prefix)
45
112
 
46
113
  # Convert DataFrame to CSV format
47
- csv_data = df.to_csv(index=False)
114
+ csv_data = df.to_csv().text
48
115
 
49
116
  # Embed the CSV data within the R script
50
117
  csv_data_escaped = csv_data.replace("\n", "\\n").replace("'", "\\'")
@@ -52,70 +119,60 @@ class ResultsGGMixin:
52
119
 
53
120
  if factor_orders is not None:
54
121
  for factor, order in factor_orders.items():
55
- # read_csv_code += f"""self${{{factor}}} <- factor(self${{{factor}}}, levels=c({','.join(['"{}"'.format(x) for x in order])}))"""
56
-
57
122
  level_string = ", ".join([f'"{x}"' for x in order])
58
123
  read_csv_code += (
59
124
  f"self${factor} <- factor(self${factor}, levels=c({level_string}))"
60
125
  )
61
126
  read_csv_code += "\n"
62
127
 
63
- # Load ggplot2 library
64
- load_ggplot2 = "library(ggplot2)\n"
65
-
66
- # Check if a filename is provided for the plot, if not create a temporary one
67
- if not filename:
68
- filename = tempfile.mktemp(suffix=f".{format}")
69
-
70
- # Combine all R script parts
71
- full_r_code = load_ggplot2 + read_csv_code + ggplot_code
72
-
73
- # Add command to save the plot to a file
74
- full_r_code += f'\nggsave("{filename}", plot = last_plot(), width = {width}, height = {height}, device = "{format}")'
128
+ # Load ggplot2 library and combine all R script parts
129
+ full_r_code = "library(ggplot2)\n" + read_csv_code + ggplot_code
75
130
 
76
131
  if debug:
77
132
  print(full_r_code)
78
133
  return
79
134
 
80
- result = subprocess.run(
81
- ["Rscript", "-"],
82
- input=full_r_code,
83
- text=True,
84
- stdout=subprocess.PIPE,
85
- stderr=subprocess.PIPE,
86
- )
87
-
88
- if result.returncode != 0:
89
- if result.returncode == 127: # 'command not found'
90
- raise RuntimeError(
91
- "Rscript is probably not installed. Please install R from https://cran.r-project.org/"
92
- )
93
- else:
94
- raise RuntimeError(
95
- f"An error occurred while running Rscript: {result.stderr}"
96
- )
97
-
98
- if result.stderr:
99
- print("Error in R script:", result.stderr)
100
- else:
101
- self._display_plot(filename, width, height)
135
+ return GGPlot(full_r_code, width=width, height=height)
102
136
 
103
137
def _display_plot(self, filename: str, width: float, height: float):
    """Display the plot in the notebook or open in system viewer if running from terminal.

    Inside IPython, an ``.svg`` file is shown with ``IPython.display.SVG`` and a
    ``.png`` file with matplotlib. matplotlib is only needed for PNG; if it is
    missing, the PNG goes to the system viewer. Any other extension prints an
    "Unsupported file format" message. Outside IPython, the file is opened with
    ``open`` (macOS; Preview for SVG), ``xdg-open`` (Linux) or ``os.startfile``
    (Windows); on other platforms only the path is printed.

    :param filename: Path to the plot file.
    :param width: Figure width in inches (PNG display only).
    :param height: Figure height in inches (PNG display only).
    """
    in_ipython = False
    try:
        from IPython import get_ipython

        in_ipython = get_ipython() is not None
    except ImportError:
        pass

    if in_ipython:
        if filename.endswith(".svg"):
            from IPython.display import SVG, display

            display(SVG(filename=filename))
            return
        if filename.endswith(".png"):
            try:
                import matplotlib.image as mpimg
                import matplotlib.pyplot as plt
            except ImportError:
                pass  # no matplotlib: fall through to the system viewer below
            else:
                img = mpimg.imread(filename)
                plt.figure(figsize=(width, height))
                plt.imshow(img)
                plt.axis("off")
                plt.show()
                return
        else:
            print("Unsupported file format. Please provide a PNG or SVG file.")
            return

    # If we're not in a notebook or imports failed, open with system viewer
    import os
    import platform

    system = platform.system()
    if system == "Darwin":  # macOS
        if filename.endswith(".svg"):
            subprocess.run(["open", "-a", "Preview", filename])
        else:
            subprocess.run(["open", filename])
    elif system == "Linux":
        subprocess.run(["xdg-open", filename])
    elif system == "Windows":
        os.startfile(filename)
    else:
        print(f"File saved to: {filename}")
@@ -9,7 +9,8 @@ from edsl.scenarios.Scenario import Scenario
9
9
  from edsl.utilities.remove_edsl_version import remove_edsl_version
10
10
 
11
11
  from edsl.scenarios.file_methods import FileMethods
12
-
12
+ from typing import Union
13
+ from uuid import UUID
13
14
 
14
15
  class FileStore(Scenario):
15
16
  __documentation__ = "https://docs.expectedparrot.com/en/latest/filestore.html"
@@ -262,7 +263,12 @@ class FileStore(Scenario):
262
263
  # raise TypeError("No text method found for this file type.")
263
264
 
264
265
  def push(
265
- self, description: Optional[str] = None, visibility: str = "unlisted"
266
+ self,
267
+ description: Optional[str] = None,
268
+ alias: Optional[str] = None,
269
+ visibility: Optional[str] = "unlisted",
270
+ expected_parrot_url: Optional[str] = None,
271
+
266
272
  ) -> dict:
267
273
  """
268
274
  Push the object to Coop.
@@ -272,17 +278,22 @@ class FileStore(Scenario):
272
278
  scenario_version = Scenario.from_dict(self.to_dict())
273
279
  if description is None:
274
280
  description = "File: " + self.path
275
- info = scenario_version.push(description=description, visibility=visibility)
281
+ info = scenario_version.push(description=description, visibility=visibility, expected_parrot_url=expected_parrot_url, alias=alias)
276
282
  return info
277
283
 
278
284
@classmethod
def pull(cls, url_or_uuid: Union[str, UUID]) -> "FileStore":
    """
    Pull a FileStore object from Coop.

    The object is fetched with ``Scenario.pull`` and rebuilt as a ``FileStore``
    from its dictionary form.

    Args:
        url_or_uuid: Either a UUID string or a URL pointing to the object

    Returns:
        FileStore: The pulled FileStore object
    """
    scenario_version = Scenario.pull(url_or_uuid)
    return cls.from_dict(scenario_version.to_dict())
287
298
 
288
299
  @classmethod
@@ -361,6 +361,39 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
361
361
  extractor = PdfExtractor(pdf_path)
362
362
  return Scenario(extractor.get_pdf_dict())
363
363
 
364
@classmethod
def from_pdf_to_image(cls, pdf_path, image_format="jpeg"):
    """
    Convert each page of a PDF into an image and collect the pages in one Scenario.

    :param pdf_path: Path to the PDF file.
    :param image_format: Format of the output images (default is 'jpeg'). It is used
        as the file extension and, upper-cased, as the Pillow save format.
        ``"jpg"`` is accepted as an alias for JPEG.
    :return: A single Scenario (built via ``cls``) with a key "filepath" holding
        ``pdf_path`` and one key "page_{i}" per page (0-based), holding a
        FileStore created from that page's image file.
    """
    import os
    import tempfile

    from pdf2image import convert_from_path
    from edsl.scenarios import Scenario
    from edsl import FileStore

    # Pillow registers JPEG under the format name "JPEG"; "JPG" is not accepted.
    save_format = "JPEG" if image_format.lower() == "jpg" else image_format.upper()

    with tempfile.TemporaryDirectory() as output_folder:
        # Convert PDF to images
        images = convert_from_path(pdf_path)

        scenario_dict = {"filepath": pdf_path}

        # Save each page as an image and wrap it in a FileStore
        for i, image in enumerate(images):
            image_path = os.path.join(output_folder, f"page_{i}.{image_format}")
            image.save(image_path, save_format)
            scenario_dict[f"page_{i}"] = FileStore(image_path)

        # Built inside the with-block, while the temporary image files still exist.
        scenario = Scenario(scenario_dict)

    return cls(scenario)
396
+
364
397
  @classmethod
365
398
  def from_docx(cls, docx_path: str) -> "Scenario":
366
399
  """Creates a scenario from the text of a docx file.
@@ -1135,7 +1135,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
1135
1135
  return cls(observations)
1136
1136
 
1137
1137
  @classmethod
1138
- def from_google_sheet(cls, url: str, sheet_name: str = None) -> ScenarioList:
1138
+ def from_google_sheet(cls, url: str, sheet_name: str = None, column_names: Optional[List[str]]= None) -> ScenarioList:
1139
1139
  """Create a ScenarioList from a Google Sheet.
1140
1140
 
1141
1141
  This method downloads the Google Sheet as an Excel file, saves it to a temporary file,
@@ -1145,6 +1145,8 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
1145
1145
  url (str): The URL to the Google Sheet.
1146
1146
  sheet_name (str, optional): The name of the sheet to load. If None, the method will behave
1147
1147
  the same as from_excel regarding multiple sheets.
1148
+ column_names (List[str], optional): If provided, use these names for the columns instead
1149
+ of the default column names from the sheet.
1148
1150
 
1149
1151
  Returns:
1150
1152
  ScenarioList: An instance of the ScenarioList class.
@@ -1172,8 +1174,25 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
1172
1174
  temp_file.write(response.content)
1173
1175
  temp_filename = temp_file.name
1174
1176
 
1175
- # Call the from_excel class method with the temporary file
1176
- return cls.from_excel(temp_filename, sheet_name=sheet_name)
1177
+ # First create the ScenarioList with default column names
1178
+ scenario_list = cls.from_excel(temp_filename, sheet_name=sheet_name)
1179
+
1180
+ # If column_names is provided, create a new ScenarioList with the specified names
1181
+ if column_names is not None:
1182
+ if len(column_names) != len(scenario_list[0].keys()):
1183
+ raise ValueError(
1184
+ f"Number of provided column names ({len(column_names)}) "
1185
+ f"does not match number of columns in sheet ({len(scenario_list[0].keys())})"
1186
+ )
1187
+
1188
+ # Create a codebook mapping original keys to new names
1189
+ original_keys = list(scenario_list[0].keys())
1190
+ codebook = dict(zip(original_keys, column_names))
1191
+
1192
+ # Return new ScenarioList with renamed columns
1193
+ return scenario_list.rename(codebook)
1194
+ else:
1195
+ return scenario_list
1177
1196
 
1178
1197
  @classmethod
1179
1198
  def from_delimited_file(