edsl-0.1.42-py3-none-any.whl → edsl-0.1.44-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +15 -6
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +1 -1
- edsl/agents/PromptConstructor.py +92 -21
- edsl/agents/QuestionInstructionPromptBuilder.py +68 -9
- edsl/agents/prompt_helpers.py +2 -2
- edsl/coop/coop.py +100 -22
- edsl/enums.py +3 -1
- edsl/exceptions/coop.py +4 -0
- edsl/inference_services/AnthropicService.py +2 -0
- edsl/inference_services/AvailableModelFetcher.py +4 -1
- edsl/inference_services/GoogleService.py +2 -0
- edsl/inference_services/GrokService.py +11 -0
- edsl/inference_services/InferenceServiceABC.py +1 -0
- edsl/inference_services/OpenAIService.py +1 -0
- edsl/inference_services/TestService.py +1 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +54 -35
- edsl/jobs/JobsChecks.py +7 -7
- edsl/jobs/JobsPrompts.py +57 -6
- edsl/jobs/JobsRemoteInferenceHandler.py +41 -25
- edsl/jobs/buckets/BucketCollection.py +30 -0
- edsl/jobs/data_structures.py +1 -0
- edsl/language_models/LanguageModel.py +5 -2
- edsl/language_models/key_management/KeyLookupBuilder.py +47 -20
- edsl/language_models/key_management/models.py +10 -4
- edsl/language_models/model.py +43 -11
- edsl/prompts/Prompt.py +124 -61
- edsl/questions/descriptors.py +32 -18
- edsl/questions/question_base_gen_mixin.py +1 -0
- edsl/results/DatasetExportMixin.py +35 -6
- edsl/results/Results.py +180 -1
- edsl/results/ResultsGGMixin.py +117 -60
- edsl/scenarios/FileStore.py +19 -8
- edsl/scenarios/Scenario.py +33 -0
- edsl/scenarios/ScenarioList.py +22 -3
- edsl/scenarios/ScenarioListPdfMixin.py +9 -3
- edsl/surveys/Survey.py +27 -6
- {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/METADATA +3 -4
- {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/RECORD +42 -41
- {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/LICENSE +0 -0
- {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/WHEEL +0 -0
edsl/results/Results.py
CHANGED
@@ -38,6 +38,64 @@ from edsl.results.ResultsGGMixin import ResultsGGMixin
|
|
38
38
|
from edsl.results.results_fetch_mixin import ResultsFetchMixin
|
39
39
|
from edsl.utilities.remove_edsl_version import remove_edsl_version
|
40
40
|
|
41
|
+
def ensure_fetched(method):
|
42
|
+
"""A decorator that checks if remote data is loaded, and if not, attempts to fetch it."""
|
43
|
+
def wrapper(self, *args, **kwargs):
|
44
|
+
if not self._fetched:
|
45
|
+
# If not fetched, try fetching now.
|
46
|
+
# (If you know you have job info stored in self.job_info)
|
47
|
+
self.fetch_remote(self.job_info)
|
48
|
+
return method(self, *args, **kwargs)
|
49
|
+
return wrapper
|
50
|
+
|
51
|
+
def ensure_ready(method):
|
52
|
+
"""
|
53
|
+
Decorator for Results methods.
|
54
|
+
|
55
|
+
If the Results object is not ready, for most methods we return a NotReadyObject.
|
56
|
+
However, for __repr__ (and other methods that need to return a string), we return
|
57
|
+
the string representation of NotReadyObject.
|
58
|
+
"""
|
59
|
+
from functools import wraps
|
60
|
+
|
61
|
+
@wraps(method)
|
62
|
+
def wrapper(self, *args, **kwargs):
|
63
|
+
if self.completed:
|
64
|
+
return method(self, *args, **kwargs)
|
65
|
+
# Attempt to fetch remote data
|
66
|
+
try:
|
67
|
+
if hasattr(self, "job_info"):
|
68
|
+
self.fetch_remote(self.job_info)
|
69
|
+
except Exception as e:
|
70
|
+
print(f"Error during fetch_remote in {method.__name__}: {e}")
|
71
|
+
if not self.completed:
|
72
|
+
not_ready = NotReadyObject(name = method.__name__, job_info = self.job_info)
|
73
|
+
# For __repr__, ensure we return a string
|
74
|
+
if method.__name__ == "__repr__" or method.__name__ == "__str__":
|
75
|
+
return not_ready.__repr__()
|
76
|
+
return not_ready
|
77
|
+
return method(self, *args, **kwargs)
|
78
|
+
|
79
|
+
return wrapper
|
80
|
+
|
81
|
+
class NotReadyObject:
|
82
|
+
"""A placeholder object that prints a message when any attribute is accessed."""
|
83
|
+
def __init__(self, name: str, job_info: RemoteJobInfo):
|
84
|
+
self.name = name
|
85
|
+
self.job_info = job_info
|
86
|
+
#print(f"Not ready to call {name}")
|
87
|
+
|
88
|
+
def __repr__(self):
|
89
|
+
message = f"""Results not ready - job still running on server."""
|
90
|
+
for key, value in self.job_info.creation_data.items():
|
91
|
+
message += f"\n{key}: {value}"
|
92
|
+
return message
|
93
|
+
|
94
|
+
def __getattr__(self, _):
|
95
|
+
return self
|
96
|
+
|
97
|
+
def __call__(self, *args, **kwargs):
|
98
|
+
return self
|
41
99
|
|
42
100
|
class Mixins(
|
43
101
|
ResultsExportMixin,
|
@@ -93,6 +151,16 @@ class Results(UserList, Mixins, Base):
|
|
93
151
|
"cache_keys",
|
94
152
|
]
|
95
153
|
|
154
|
+
@classmethod
|
155
|
+
def from_job_info(cls, job_info: dict) -> Results:
|
156
|
+
"""
|
157
|
+
Instantiate a `Results` object from a job info dictionary.
|
158
|
+
"""
|
159
|
+
results = cls()
|
160
|
+
results.completed = False
|
161
|
+
results.job_info = job_info
|
162
|
+
return results
|
163
|
+
|
96
164
|
def __init__(
|
97
165
|
self,
|
98
166
|
survey: Optional[Survey] = None,
|
@@ -112,6 +180,8 @@ class Results(UserList, Mixins, Base):
|
|
112
180
|
:param total_results: An integer representing the total number of results.
|
113
181
|
:cache: A Cache object.
|
114
182
|
"""
|
183
|
+
self.completed = True
|
184
|
+
self._fetching = False
|
115
185
|
super().__init__(data)
|
116
186
|
from edsl.data.Cache import Cache
|
117
187
|
from edsl.jobs.tasks.TaskHistory import TaskHistory
|
@@ -315,7 +385,22 @@ class Results(UserList, Mixins, Base):
|
|
315
385
|
data=self.data + other.data,
|
316
386
|
created_columns=self.created_columns,
|
317
387
|
)
|
318
|
-
|
388
|
+
|
389
|
+
def _repr_html_(self):
|
390
|
+
if not self.completed:
|
391
|
+
if hasattr(self, "job_info"):
|
392
|
+
self.fetch_remote(self.job_info)
|
393
|
+
|
394
|
+
if not self.completed:
|
395
|
+
return f"Results not ready to call"
|
396
|
+
|
397
|
+
return super()._repr_html_()
|
398
|
+
|
399
|
+
# @ensure_ready
|
400
|
+
# def __str__(self):
|
401
|
+
# super().__str__()
|
402
|
+
|
403
|
+
@ensure_ready
|
319
404
|
def __repr__(self) -> str:
|
320
405
|
return f"Results(data = {self.data}, survey = {repr(self.survey)}, created_columns = {self.created_columns})"
|
321
406
|
|
@@ -732,6 +817,7 @@ class Results(UserList, Mixins, Base):
|
|
732
817
|
|
733
818
|
return self.recode(column, recode_function=f, new_var_name=new_var_name)
|
734
819
|
|
820
|
+
@ensure_ready
|
735
821
|
def recode(
|
736
822
|
self, column: str, recode_function: Optional[Callable], new_var_name=None
|
737
823
|
) -> Results:
|
@@ -760,6 +846,7 @@ class Results(UserList, Mixins, Base):
|
|
760
846
|
created_columns=self.created_columns + [new_var_name],
|
761
847
|
)
|
762
848
|
|
849
|
+
@ensure_ready
|
763
850
|
def add_column(self, column_name: str, values: list) -> Results:
|
764
851
|
"""Adds columns to Results
|
765
852
|
|
@@ -780,6 +867,7 @@ class Results(UserList, Mixins, Base):
|
|
780
867
|
created_columns=self.created_columns + [column_name],
|
781
868
|
)
|
782
869
|
|
870
|
+
@ensure_ready
|
783
871
|
def add_columns_from_dict(self, columns: List[dict]) -> Results:
|
784
872
|
"""Adds columns to Results from a list of dictionaries.
|
785
873
|
|
@@ -829,6 +917,7 @@ class Results(UserList, Mixins, Base):
|
|
829
917
|
evaluator.functions.update(int=int, float=float)
|
830
918
|
return evaluator
|
831
919
|
|
920
|
+
@ensure_ready
|
832
921
|
def mutate(
|
833
922
|
self, new_var_string: str, functions_dict: Optional[dict] = None
|
834
923
|
) -> Results:
|
@@ -879,6 +968,7 @@ class Results(UserList, Mixins, Base):
|
|
879
968
|
created_columns=self.created_columns + [var_name],
|
880
969
|
)
|
881
970
|
|
971
|
+
@ensure_ready
|
882
972
|
def add_column(self, column_name: str, values: list) -> Results:
|
883
973
|
"""Adds columns to Results
|
884
974
|
|
@@ -899,6 +989,7 @@ class Results(UserList, Mixins, Base):
|
|
899
989
|
created_columns=self.created_columns + [column_name],
|
900
990
|
)
|
901
991
|
|
992
|
+
@ensure_ready
|
902
993
|
def rename(self, old_name: str, new_name: str) -> Results:
|
903
994
|
"""Rename an answer column in a Results object.
|
904
995
|
|
@@ -916,6 +1007,7 @@ class Results(UserList, Mixins, Base):
|
|
916
1007
|
|
917
1008
|
return self
|
918
1009
|
|
1010
|
+
@ensure_ready
|
919
1011
|
def shuffle(self, seed: Optional[str] = "edsl") -> Results:
|
920
1012
|
"""Shuffle the results.
|
921
1013
|
|
@@ -932,6 +1024,7 @@ class Results(UserList, Mixins, Base):
|
|
932
1024
|
random.shuffle(new_data)
|
933
1025
|
return Results(survey=self.survey, data=new_data, created_columns=None)
|
934
1026
|
|
1027
|
+
@ensure_ready
|
935
1028
|
def sample(
|
936
1029
|
self,
|
937
1030
|
n: Optional[int] = None,
|
@@ -971,6 +1064,7 @@ class Results(UserList, Mixins, Base):
|
|
971
1064
|
|
972
1065
|
return Results(survey=self.survey, data=new_data, created_columns=None)
|
973
1066
|
|
1067
|
+
@ensure_ready
|
974
1068
|
def select(self, *columns: Union[str, list[str]]) -> Results:
|
975
1069
|
"""
|
976
1070
|
Select data from the results and format it.
|
@@ -1004,6 +1098,7 @@ class Results(UserList, Mixins, Base):
|
|
1004
1098
|
)
|
1005
1099
|
return selector.select(*columns)
|
1006
1100
|
|
1101
|
+
@ensure_ready
|
1007
1102
|
def sort_by(self, *columns: str, reverse: bool = False) -> Results:
|
1008
1103
|
"""Sort the results by one or more columns."""
|
1009
1104
|
import warnings
|
@@ -1019,6 +1114,7 @@ class Results(UserList, Mixins, Base):
|
|
1019
1114
|
return column.split(".")
|
1020
1115
|
return self._key_to_data_type[column], column
|
1021
1116
|
|
1117
|
+
@ensure_ready
|
1022
1118
|
def order_by(self, *columns: str, reverse: bool = False) -> Results:
|
1023
1119
|
"""Sort the results by one or more columns.
|
1024
1120
|
|
@@ -1055,6 +1151,7 @@ class Results(UserList, Mixins, Base):
|
|
1055
1151
|
new_data = sorted(self.data, key=sort_key, reverse=reverse)
|
1056
1152
|
return Results(survey=self.survey, data=new_data, created_columns=None)
|
1057
1153
|
|
1154
|
+
@ensure_ready
|
1058
1155
|
def filter(self, expression: str) -> Results:
|
1059
1156
|
"""
|
1060
1157
|
Filter based on the given expression and returns the filtered `Results`.
|
@@ -1156,6 +1253,7 @@ class Results(UserList, Mixins, Base):
|
|
1156
1253
|
"""Display an object as a table."""
|
1157
1254
|
pass
|
1158
1255
|
|
1256
|
+
@ensure_ready
|
1159
1257
|
def __str__(self):
|
1160
1258
|
data = self.to_dict()["data"]
|
1161
1259
|
return json.dumps(data, indent=4)
|
@@ -1178,6 +1276,87 @@ class Results(UserList, Mixins, Base):
|
|
1178
1276
|
[1, 1, 0, 0]
|
1179
1277
|
"""
|
1180
1278
|
return [r.score(f) for r in self.data]
|
1279
|
+
|
1280
|
+
|
1281
|
+
def fetch_remote(self, job_info: "RemoteJobInfo") -> None:
|
1282
|
+
"""
|
1283
|
+
Fetches the remote Results object using the provided RemoteJobInfo and updates this instance with the remote data.
|
1284
|
+
|
1285
|
+
This is useful when you have a Results object that was created locally but want to sync it with
|
1286
|
+
the latest data from the remote server.
|
1287
|
+
|
1288
|
+
Args:
|
1289
|
+
job_info: RemoteJobInfo object containing the job_uuid and other remote job details
|
1290
|
+
|
1291
|
+
"""
|
1292
|
+
#print("Calling fetch_remote")
|
1293
|
+
try:
|
1294
|
+
from edsl.coop.coop import Coop
|
1295
|
+
from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
|
1296
|
+
|
1297
|
+
# Get the remote job data
|
1298
|
+
remote_job_data = JobsRemoteInferenceHandler.check_status(job_info.job_uuid)
|
1299
|
+
|
1300
|
+
if remote_job_data.get("status") not in ["completed", "failed"]:
|
1301
|
+
return False
|
1302
|
+
#
|
1303
|
+
results_uuid = remote_job_data.get("results_uuid")
|
1304
|
+
if not results_uuid:
|
1305
|
+
raise ResultsError("No results_uuid found in remote job data")
|
1306
|
+
|
1307
|
+
# Fetch the remote Results object
|
1308
|
+
coop = Coop()
|
1309
|
+
remote_results = coop.get(results_uuid, expected_object_type="results")
|
1310
|
+
|
1311
|
+
# Update this instance with remote data
|
1312
|
+
self.data = remote_results.data
|
1313
|
+
self.survey = remote_results.survey
|
1314
|
+
self.created_columns = remote_results.created_columns
|
1315
|
+
self.cache = remote_results.cache
|
1316
|
+
self.task_history = remote_results.task_history
|
1317
|
+
self.completed = True
|
1318
|
+
|
1319
|
+
# Set job_uuid and results_uuid from remote data
|
1320
|
+
self.job_uuid = job_info.job_uuid
|
1321
|
+
if hasattr(remote_results, 'results_uuid'):
|
1322
|
+
self.results_uuid = remote_results.results_uuid
|
1323
|
+
|
1324
|
+
return True
|
1325
|
+
|
1326
|
+
except Exception as e:
|
1327
|
+
raise ResultsError(f"Failed to fetch remote results: {str(e)}")
|
1328
|
+
|
1329
|
+
def fetch(self, polling_interval: [float, int] = 1.0) -> Results:
|
1330
|
+
"""
|
1331
|
+
Polls the server for job completion and updates this Results instance with the completed data.
|
1332
|
+
|
1333
|
+
Args:
|
1334
|
+
polling_interval: Number of seconds to wait between polling attempts (default: 1.0)
|
1335
|
+
|
1336
|
+
Returns:
|
1337
|
+
self: The updated Results instance
|
1338
|
+
"""
|
1339
|
+
if not hasattr(self, "job_info"):
|
1340
|
+
raise ResultsError("No job info available - this Results object wasn't created from a remote job")
|
1341
|
+
|
1342
|
+
from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
|
1343
|
+
|
1344
|
+
try:
|
1345
|
+
# Get the remote job data
|
1346
|
+
remote_job_data = JobsRemoteInferenceHandler.check_status(self.job_info.job_uuid)
|
1347
|
+
|
1348
|
+
while remote_job_data.get("status") not in ["completed", "failed"]:
|
1349
|
+
print("Waiting for remote job to complete...")
|
1350
|
+
import time
|
1351
|
+
time.sleep(polling_interval)
|
1352
|
+
remote_job_data = JobsRemoteInferenceHandler.check_status(self.job_info.job_uuid)
|
1353
|
+
|
1354
|
+
# Once complete, fetch the full results
|
1355
|
+
self.fetch_remote(self.job_info)
|
1356
|
+
return self
|
1357
|
+
|
1358
|
+
except Exception as e:
|
1359
|
+
raise ResultsError(f"Failed to fetch remote results: {str(e)}")
|
1181
1360
|
|
1182
1361
|
|
1183
1362
|
def main(): # pragma: no cover
|
edsl/results/ResultsGGMixin.py
CHANGED
@@ -5,46 +5,113 @@ import tempfile
|
|
5
5
|
from typing import Optional
|
6
6
|
|
7
7
|
|
8
|
+
class GGPlot:
|
9
|
+
"""A class to handle ggplot2 plot display and saving."""
|
10
|
+
|
11
|
+
def __init__(self, r_code: str, width: float = 6, height: float = 4):
|
12
|
+
"""Initialize with R code and dimensions."""
|
13
|
+
self.r_code = r_code
|
14
|
+
self.width = width
|
15
|
+
self.height = height
|
16
|
+
self._svg_data = None
|
17
|
+
self._saved = False # Track if the plot was saved
|
18
|
+
|
19
|
+
def _execute_r_code(self, save_command: str = ""):
|
20
|
+
"""Execute R code with optional save command."""
|
21
|
+
full_r_code = self.r_code + save_command
|
22
|
+
|
23
|
+
result = subprocess.run(
|
24
|
+
["Rscript", "-"],
|
25
|
+
input=full_r_code,
|
26
|
+
text=True,
|
27
|
+
stdout=subprocess.PIPE,
|
28
|
+
stderr=subprocess.PIPE,
|
29
|
+
)
|
30
|
+
|
31
|
+
if result.returncode != 0:
|
32
|
+
if result.returncode == 127:
|
33
|
+
raise RuntimeError(
|
34
|
+
"Rscript is probably not installed. Please install R from https://cran.r-project.org/"
|
35
|
+
)
|
36
|
+
else:
|
37
|
+
raise RuntimeError(
|
38
|
+
f"An error occurred while running Rscript: {result.stderr}"
|
39
|
+
)
|
40
|
+
|
41
|
+
if result.stderr:
|
42
|
+
print("Error in R script:", result.stderr)
|
43
|
+
|
44
|
+
return result
|
45
|
+
|
46
|
+
def save(self, filename: str):
|
47
|
+
"""Save the plot to a file."""
|
48
|
+
format = filename.split('.')[-1].lower()
|
49
|
+
if format not in ['svg', 'png']:
|
50
|
+
raise ValueError("Only 'svg' and 'png' formats are supported")
|
51
|
+
|
52
|
+
save_command = f'\nggsave("{filename}", plot = last_plot(), width = {self.width}, height = {self.height}, device = "{format}")'
|
53
|
+
self._execute_r_code(save_command)
|
54
|
+
|
55
|
+
self._saved = True
|
56
|
+
print(f"File saved to: {filename}")
|
57
|
+
return None # Return None instead of self
|
58
|
+
|
59
|
+
def _repr_html_(self):
|
60
|
+
"""Display the plot in a Jupyter notebook."""
|
61
|
+
# Don't display if the plot was saved
|
62
|
+
if self._saved:
|
63
|
+
return None
|
64
|
+
|
65
|
+
import tempfile
|
66
|
+
|
67
|
+
# Generate SVG if we haven't already
|
68
|
+
if self._svg_data is None:
|
69
|
+
# Create temporary SVG file
|
70
|
+
with tempfile.NamedTemporaryFile(suffix='.svg') as tmp:
|
71
|
+
save_command = f'\nggsave("{tmp.name}", plot = last_plot(), width = {self.width}, height = {self.height}, device = "svg")'
|
72
|
+
self._execute_r_code(save_command)
|
73
|
+
with open(tmp.name, 'r') as f:
|
74
|
+
self._svg_data = f.read()
|
75
|
+
|
76
|
+
return self._svg_data
|
77
|
+
|
8
78
|
class ResultsGGMixin:
|
9
79
|
"""Mixin class for ggplot2 plotting."""
|
10
80
|
|
11
81
|
def ggplot2(
|
12
82
|
self,
|
13
83
|
ggplot_code: str,
|
14
|
-
filename: str = None,
|
15
84
|
shape="wide",
|
16
85
|
sql: str = None,
|
17
86
|
remove_prefix: bool = True,
|
18
87
|
debug: bool = False,
|
19
88
|
height=4,
|
20
89
|
width=6,
|
21
|
-
format="svg",
|
22
90
|
factor_orders: Optional[dict] = None,
|
23
91
|
):
|
24
92
|
"""Create a ggplot2 plot from a DataFrame.
|
25
93
|
|
94
|
+
Returns a GGPlot object that can be displayed in a notebook or saved to a file.
|
95
|
+
|
26
96
|
:param ggplot_code: The ggplot2 code to execute.
|
27
|
-
:param filename: The filename to save the plot to.
|
28
97
|
:param shape: The shape of the data in the DataFrame (wide or long).
|
29
98
|
:param sql: The SQL query to execute beforehand to manipulate the data.
|
30
99
|
:param remove_prefix: Whether to remove the prefix from the column names.
|
31
100
|
:param debug: Whether to print the R code instead of executing it.
|
32
101
|
:param height: The height of the plot in inches.
|
33
102
|
:param width: The width of the plot in inches.
|
34
|
-
:param format: The format to save the plot in (png or svg).
|
35
103
|
:param factor_orders: A dictionary of factor columns and their order.
|
36
104
|
"""
|
37
|
-
|
38
105
|
if sql == None:
|
39
106
|
sql = "select * from self"
|
40
107
|
|
41
108
|
if shape == "long":
|
42
109
|
df = self.sql(sql, shape="long")
|
43
110
|
elif shape == "wide":
|
44
|
-
df = self.sql(sql,
|
111
|
+
df = self.sql(sql, remove_prefix=remove_prefix)
|
45
112
|
|
46
113
|
# Convert DataFrame to CSV format
|
47
|
-
csv_data = df.to_csv(
|
114
|
+
csv_data = df.to_csv().text
|
48
115
|
|
49
116
|
# Embed the CSV data within the R script
|
50
117
|
csv_data_escaped = csv_data.replace("\n", "\\n").replace("'", "\\'")
|
@@ -52,70 +119,60 @@ class ResultsGGMixin:
|
|
52
119
|
|
53
120
|
if factor_orders is not None:
|
54
121
|
for factor, order in factor_orders.items():
|
55
|
-
# read_csv_code += f"""self${{{factor}}} <- factor(self${{{factor}}}, levels=c({','.join(['"{}"'.format(x) for x in order])}))"""
|
56
|
-
|
57
122
|
level_string = ", ".join([f'"{x}"' for x in order])
|
58
123
|
read_csv_code += (
|
59
124
|
f"self${factor} <- factor(self${factor}, levels=c({level_string}))"
|
60
125
|
)
|
61
126
|
read_csv_code += "\n"
|
62
127
|
|
63
|
-
# Load ggplot2 library
|
64
|
-
|
65
|
-
|
66
|
-
# Check if a filename is provided for the plot, if not create a temporary one
|
67
|
-
if not filename:
|
68
|
-
filename = tempfile.mktemp(suffix=f".{format}")
|
69
|
-
|
70
|
-
# Combine all R script parts
|
71
|
-
full_r_code = load_ggplot2 + read_csv_code + ggplot_code
|
72
|
-
|
73
|
-
# Add command to save the plot to a file
|
74
|
-
full_r_code += f'\nggsave("{filename}", plot = last_plot(), width = {width}, height = {height}, device = "{format}")'
|
128
|
+
# Load ggplot2 library and combine all R script parts
|
129
|
+
full_r_code = "library(ggplot2)\n" + read_csv_code + ggplot_code
|
75
130
|
|
76
131
|
if debug:
|
77
132
|
print(full_r_code)
|
78
133
|
return
|
79
134
|
|
80
|
-
|
81
|
-
["Rscript", "-"],
|
82
|
-
input=full_r_code,
|
83
|
-
text=True,
|
84
|
-
stdout=subprocess.PIPE,
|
85
|
-
stderr=subprocess.PIPE,
|
86
|
-
)
|
87
|
-
|
88
|
-
if result.returncode != 0:
|
89
|
-
if result.returncode == 127: # 'command not found'
|
90
|
-
raise RuntimeError(
|
91
|
-
"Rscript is probably not installed. Please install R from https://cran.r-project.org/"
|
92
|
-
)
|
93
|
-
else:
|
94
|
-
raise RuntimeError(
|
95
|
-
f"An error occurred while running Rscript: {result.stderr}"
|
96
|
-
)
|
97
|
-
|
98
|
-
if result.stderr:
|
99
|
-
print("Error in R script:", result.stderr)
|
100
|
-
else:
|
101
|
-
self._display_plot(filename, width, height)
|
135
|
+
return GGPlot(full_r_code, width=width, height=height)
|
102
136
|
|
103
137
|
def _display_plot(self, filename: str, width: float, height: float):
|
104
|
-
"""Display the plot in the notebook."""
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
138
|
+
"""Display the plot in the notebook or open in system viewer if running from terminal."""
|
139
|
+
try:
|
140
|
+
# Try to import IPython-related modules
|
141
|
+
import matplotlib.pyplot as plt
|
142
|
+
import matplotlib.image as mpimg
|
143
|
+
from IPython import get_ipython
|
144
|
+
|
145
|
+
# Check if we're in a notebook environment
|
146
|
+
if get_ipython() is not None:
|
147
|
+
if filename.endswith(".png"):
|
148
|
+
img = mpimg.imread(filename)
|
149
|
+
plt.figure(figsize=(width, height))
|
150
|
+
plt.imshow(img)
|
151
|
+
plt.axis("off")
|
152
|
+
plt.show()
|
153
|
+
elif filename.endswith(".svg"):
|
154
|
+
from IPython.display import SVG, display
|
155
|
+
display(SVG(filename=filename))
|
156
|
+
else:
|
157
|
+
print("Unsupported file format. Please provide a PNG or SVG file.")
|
158
|
+
return
|
159
|
+
|
160
|
+
except ImportError:
|
161
|
+
pass
|
162
|
+
|
163
|
+
# If we're not in a notebook or imports failed, open with system viewer
|
164
|
+
import platform
|
165
|
+
import os
|
166
|
+
|
167
|
+
system = platform.system()
|
168
|
+
if system == 'Darwin': # macOS
|
169
|
+
if filename.endswith('.svg'):
|
170
|
+
subprocess.run(['open', '-a', 'Preview', filename])
|
171
|
+
else:
|
172
|
+
subprocess.run(['open', filename])
|
173
|
+
elif system == 'Linux':
|
174
|
+
subprocess.run(['xdg-open', filename])
|
175
|
+
elif system == 'Windows':
|
176
|
+
os.startfile(filename)
|
120
177
|
else:
|
121
|
-
print("
|
178
|
+
print(f"File saved to: {filename}")
|
edsl/scenarios/FileStore.py
CHANGED
@@ -9,7 +9,8 @@ from edsl.scenarios.Scenario import Scenario
|
|
9
9
|
from edsl.utilities.remove_edsl_version import remove_edsl_version
|
10
10
|
|
11
11
|
from edsl.scenarios.file_methods import FileMethods
|
12
|
-
|
12
|
+
from typing import Union
|
13
|
+
from uuid import UUID
|
13
14
|
|
14
15
|
class FileStore(Scenario):
|
15
16
|
__documentation__ = "https://docs.expectedparrot.com/en/latest/filestore.html"
|
@@ -262,7 +263,12 @@ class FileStore(Scenario):
|
|
262
263
|
# raise TypeError("No text method found for this file type.")
|
263
264
|
|
264
265
|
def push(
|
265
|
-
self,
|
266
|
+
self,
|
267
|
+
description: Optional[str] = None,
|
268
|
+
alias: Optional[str] = None,
|
269
|
+
visibility: Optional[str] = "unlisted",
|
270
|
+
expected_parrot_url: Optional[str] = None,
|
271
|
+
|
266
272
|
) -> dict:
|
267
273
|
"""
|
268
274
|
Push the object to Coop.
|
@@ -272,17 +278,22 @@ class FileStore(Scenario):
|
|
272
278
|
scenario_version = Scenario.from_dict(self.to_dict())
|
273
279
|
if description is None:
|
274
280
|
description = "File: " + self.path
|
275
|
-
info = scenario_version.push(description=description, visibility=visibility)
|
281
|
+
info = scenario_version.push(description=description, visibility=visibility, expected_parrot_url=expected_parrot_url, alias=alias)
|
276
282
|
return info
|
277
283
|
|
278
284
|
@classmethod
|
279
|
-
def pull(cls,
|
285
|
+
def pull(cls, url_or_uuid: Union[str, UUID]) -> "FileStore":
|
280
286
|
"""
|
281
|
-
|
282
|
-
|
283
|
-
:
|
287
|
+
Pull a FileStore object from Coop.
|
288
|
+
|
289
|
+
Args:
|
290
|
+
url_or_uuid: Either a UUID string or a URL pointing to the object
|
291
|
+
expected_parrot_url: Optional URL for the Parrot server
|
292
|
+
|
293
|
+
Returns:
|
294
|
+
FileStore: The pulled FileStore object
|
284
295
|
"""
|
285
|
-
scenario_version = Scenario.pull(
|
296
|
+
scenario_version = Scenario.pull(url_or_uuid)
|
286
297
|
return cls.from_dict(scenario_version.to_dict())
|
287
298
|
|
288
299
|
@classmethod
|
edsl/scenarios/Scenario.py
CHANGED
@@ -361,6 +361,39 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
|
|
361
361
|
extractor = PdfExtractor(pdf_path)
|
362
362
|
return Scenario(extractor.get_pdf_dict())
|
363
363
|
|
364
|
+
@classmethod
|
365
|
+
def from_pdf_to_image(cls, pdf_path, image_format="jpeg"):
|
366
|
+
"""
|
367
|
+
Convert each page of a PDF into an image and create key/value for it.
|
368
|
+
|
369
|
+
:param pdf_path: Path to the PDF file.
|
370
|
+
:param image_format: Format of the output images (default is 'jpeg').
|
371
|
+
:return: ScenarioList instance containing the Scenario instances.
|
372
|
+
|
373
|
+
The scenario has a key "filepath" and one or more keys "page_{i}" for each page.
|
374
|
+
"""
|
375
|
+
import tempfile
|
376
|
+
from pdf2image import convert_from_path
|
377
|
+
from edsl.scenarios import Scenario
|
378
|
+
|
379
|
+
with tempfile.TemporaryDirectory() as output_folder:
|
380
|
+
# Convert PDF to images
|
381
|
+
images = convert_from_path(pdf_path)
|
382
|
+
|
383
|
+
scenario_dict = {"filepath":pdf_path}
|
384
|
+
|
385
|
+
# Save each page as an image and create Scenario instances
|
386
|
+
for i, image in enumerate(images):
|
387
|
+
image_path = os.path.join(output_folder, f"page_{i}.{image_format}")
|
388
|
+
image.save(image_path, image_format.upper())
|
389
|
+
|
390
|
+
from edsl import FileStore
|
391
|
+
scenario_dict[f"page_{i}"] = FileStore(image_path)
|
392
|
+
|
393
|
+
scenario = Scenario(scenario_dict)
|
394
|
+
|
395
|
+
return cls(scenario)
|
396
|
+
|
364
397
|
@classmethod
|
365
398
|
def from_docx(cls, docx_path: str) -> "Scenario":
|
366
399
|
"""Creates a scenario from the text of a docx file.
|
edsl/scenarios/ScenarioList.py
CHANGED
@@ -1135,7 +1135,7 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1135
1135
|
return cls(observations)
|
1136
1136
|
|
1137
1137
|
@classmethod
|
1138
|
-
def from_google_sheet(cls, url: str, sheet_name: str = None) -> ScenarioList:
|
1138
|
+
def from_google_sheet(cls, url: str, sheet_name: str = None, column_names: Optional[List[str]]= None) -> ScenarioList:
|
1139
1139
|
"""Create a ScenarioList from a Google Sheet.
|
1140
1140
|
|
1141
1141
|
This method downloads the Google Sheet as an Excel file, saves it to a temporary file,
|
@@ -1145,6 +1145,8 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1145
1145
|
url (str): The URL to the Google Sheet.
|
1146
1146
|
sheet_name (str, optional): The name of the sheet to load. If None, the method will behave
|
1147
1147
|
the same as from_excel regarding multiple sheets.
|
1148
|
+
column_names (List[str], optional): If provided, use these names for the columns instead
|
1149
|
+
of the default column names from the sheet.
|
1148
1150
|
|
1149
1151
|
Returns:
|
1150
1152
|
ScenarioList: An instance of the ScenarioList class.
|
@@ -1172,8 +1174,25 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
1172
1174
|
temp_file.write(response.content)
|
1173
1175
|
temp_filename = temp_file.name
|
1174
1176
|
|
1175
|
-
#
|
1176
|
-
|
1177
|
+
# First create the ScenarioList with default column names
|
1178
|
+
scenario_list = cls.from_excel(temp_filename, sheet_name=sheet_name)
|
1179
|
+
|
1180
|
+
# If column_names is provided, create a new ScenarioList with the specified names
|
1181
|
+
if column_names is not None:
|
1182
|
+
if len(column_names) != len(scenario_list[0].keys()):
|
1183
|
+
raise ValueError(
|
1184
|
+
f"Number of provided column names ({len(column_names)}) "
|
1185
|
+
f"does not match number of columns in sheet ({len(scenario_list[0].keys())})"
|
1186
|
+
)
|
1187
|
+
|
1188
|
+
# Create a codebook mapping original keys to new names
|
1189
|
+
original_keys = list(scenario_list[0].keys())
|
1190
|
+
codebook = dict(zip(original_keys, column_names))
|
1191
|
+
|
1192
|
+
# Return new ScenarioList with renamed columns
|
1193
|
+
return scenario_list.rename(codebook)
|
1194
|
+
else:
|
1195
|
+
return scenario_list
|
1177
1196
|
|
1178
1197
|
@classmethod
|
1179
1198
|
def from_delimited_file(
|