edsl 0.1.41__py3-none-any.whl → 0.1.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +4 -3
- edsl/agents/InvigilatorBase.py +2 -1
- edsl/agents/PromptConstructor.py +92 -21
- edsl/agents/QuestionInstructionPromptBuilder.py +68 -9
- edsl/agents/QuestionTemplateReplacementsBuilder.py +7 -2
- edsl/agents/prompt_helpers.py +2 -2
- edsl/coop/coop.py +97 -19
- edsl/enums.py +3 -1
- edsl/exceptions/coop.py +4 -0
- edsl/exceptions/jobs.py +1 -9
- edsl/exceptions/language_models.py +8 -4
- edsl/exceptions/questions.py +8 -11
- edsl/inference_services/AvailableModelFetcher.py +4 -1
- edsl/inference_services/DeepSeekService.py +18 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +60 -34
- edsl/jobs/JobsPrompts.py +64 -3
- edsl/jobs/JobsRemoteInferenceHandler.py +42 -25
- edsl/jobs/JobsRemoteInferenceLogger.py +1 -1
- edsl/jobs/buckets/BucketCollection.py +30 -0
- edsl/jobs/data_structures.py +1 -0
- edsl/jobs/interviews/Interview.py +1 -1
- edsl/jobs/loggers/HTMLTableJobLogger.py +6 -1
- edsl/jobs/results_exceptions_handler.py +2 -7
- edsl/jobs/tasks/TaskHistory.py +49 -17
- edsl/language_models/LanguageModel.py +7 -4
- edsl/language_models/ModelList.py +1 -1
- edsl/language_models/key_management/KeyLookupBuilder.py +47 -20
- edsl/language_models/key_management/models.py +10 -4
- edsl/language_models/model.py +49 -0
- edsl/prompts/Prompt.py +124 -61
- edsl/questions/descriptors.py +37 -23
- edsl/questions/question_base_gen_mixin.py +1 -0
- edsl/results/DatasetExportMixin.py +35 -6
- edsl/results/Result.py +9 -3
- edsl/results/Results.py +180 -2
- edsl/results/ResultsGGMixin.py +117 -60
- edsl/scenarios/PdfExtractor.py +3 -6
- edsl/scenarios/Scenario.py +35 -1
- edsl/scenarios/ScenarioList.py +22 -3
- edsl/scenarios/ScenarioListPdfMixin.py +9 -3
- edsl/surveys/Survey.py +1 -1
- edsl/templates/error_reporting/base.html +2 -4
- edsl/templates/error_reporting/exceptions_table.html +35 -0
- edsl/templates/error_reporting/interview_details.html +67 -53
- edsl/templates/error_reporting/interviews.html +4 -17
- edsl/templates/error_reporting/overview.html +31 -5
- edsl/templates/error_reporting/performance_plot.html +1 -1
- {edsl-0.1.41.dist-info → edsl-0.1.43.dist-info}/METADATA +2 -3
- {edsl-0.1.41.dist-info → edsl-0.1.43.dist-info}/RECORD +53 -51
- {edsl-0.1.41.dist-info → edsl-0.1.43.dist-info}/LICENSE +0 -0
- {edsl-0.1.41.dist-info → edsl-0.1.43.dist-info}/WHEEL +0 -0
edsl/results/DatasetExportMixin.py
CHANGED
@@ -7,7 +7,6 @@ from typing import Optional, Tuple, Union, List
 
 from edsl.results.file_exports import CSVExport, ExcelExport, JSONLExport, SQLiteExport
 
-
 class DatasetExportMixin:
     """Mixin class for exporting Dataset objects."""
 
@@ -220,23 +219,45 @@ class DatasetExportMixin:
         )
         return exporter.export()
 
-    def _db(self, remove_prefix: bool = True):
+    def _db(self, remove_prefix: bool = True, shape: str = "wide") -> "sqlalchemy.engine.Engine":
         """Create a SQLite database in memory and return the connection.
 
         Args:
-            shape: The shape of the data in the database (wide or long)
             remove_prefix: Whether to remove the prefix from the column names
+            shape: The shape of the data in the database ("wide" or "long")
 
         Returns:
             A database connection
+        >>> from sqlalchemy import text
+        >>> from edsl import Results
+        >>> engine = Results.example()._db()
+        >>> len(engine.execute(text("SELECT * FROM self")).fetchall())
+        4
+        >>> engine = Results.example()._db(shape = "long")
+        >>> len(engine.execute(text("SELECT * FROM self")).fetchall())
+        172
         """
-        from sqlalchemy import create_engine
+        from sqlalchemy import create_engine, text
 
         engine = create_engine("sqlite:///:memory:")
-        if remove_prefix:
+        if remove_prefix and shape == "wide":
             df = self.remove_prefix().to_pandas(lists_as_strings=True)
         else:
             df = self.to_pandas(lists_as_strings=True)
+
+        if shape == "long":
+            # Melt the dataframe to convert it to long format
+            df = df.melt(
+                var_name='key',
+                value_name='value'
+            )
+            # Add a row number column for reference
+            df.insert(0, 'row_number', range(1, len(df) + 1))
+
+            # Split the key into data_type and key
+            df['data_type'] = df['key'].apply(lambda x: x.split('.')[0] if '.' in x else None)
+            df['key'] = df['key'].apply(lambda x: '.'.join(x.split('.')[1:]) if '.' in x else x)
+
         df.to_sql(
             "self",
             engine,
@@ -251,6 +272,7 @@ class DatasetExportMixin:
         transpose: bool = None,
         transpose_by: str = None,
         remove_prefix: bool = True,
+        shape: str = "wide",
     ) -> Union["pd.DataFrame", str]:
         """Execute a SQL query and return the results as a DataFrame.
 
@@ -268,10 +290,17 @@ class DatasetExportMixin:
         Returns:
             DataFrame, CSV string, list, or LaTeX string depending on parameters
 
+        Examples:
+            >>> from edsl import Results
+            >>> r = Results.example();
+            >>> len(r.sql("SELECT * FROM self", shape = "wide"))
+            4
+            >>> len(r.sql("SELECT * FROM self", shape = "long"))
+            172
         """
         import pandas as pd
 
-        conn = self._db(remove_prefix=remove_prefix)
+        conn = self._db(remove_prefix=remove_prefix, shape=shape)
         df = pd.read_sql_query(query, conn)
 
         # Transpose the DataFrame if transpose is True
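A minimal usage sketch of the new long-format SQL support (not part of the diff; row counts and column names follow the doctests and melt code above):

from edsl import Results

r = Results.example()

# Wide shape (default): one row per Result, one column per field.
print(len(r.sql("SELECT * FROM self", shape="wide")))   # 4

# Long shape: the table is melted into row_number / data_type / key / value rows,
# so individual fields can be filtered by key or data_type.
answers = r.sql(
    "SELECT data_type, key, value FROM self WHERE data_type = 'answer'",
    shape="long",
)
print(len(r.sql("SELECT * FROM self", shape="long")))    # 172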
edsl/results/Result.py
CHANGED
@@ -78,7 +78,6 @@ class Result(Base, UserDict):
         self.question_to_attributes = (
             question_to_attributes or self._create_question_to_attributes(survey)
         )
-
         data = {
             "agent": agent,
             "scenario": scenario,
@@ -87,7 +86,7 @@ class Result(Base, UserDict):
             "answer": answer,
             "prompt": prompt or {},
             "raw_model_response": raw_model_response or {},
-            "question_to_attributes": question_to_attributes,
+            "question_to_attributes": self.question_to_attributes,
             "generated_tokens": generated_tokens or {},
             "comments_dict": comments_dict or {},
             "cache_used_dict": cache_used_dict or {},
@@ -154,7 +153,9 @@ class Result(Base, UserDict):
     @staticmethod
     def _create_model_sub_dict(model) -> dict:
         return {
-            "model": model.parameters
+            "model": model.parameters
+            | {"model": model.model}
+            | {"inference_service": model._inference_service_},
         }
 
     @staticmethod
@@ -365,6 +366,10 @@ class Result(Base, UserDict):
                 else prompt_obj.to_dict()
             )
             d[key] = new_prompt_dict
+
+        if self.indices is not None:
+            d["indices"] = self.indices
+
         if add_edsl_version:
             from edsl import __version__
 
@@ -414,6 +419,7 @@ class Result(Base, UserDict):
             comments_dict=json_dict.get("comments_dict", {}),
             cache_used_dict=json_dict.get("cache_used_dict", {}),
             cache_keys=json_dict.get("cache_keys", {}),
+            indices = json_dict.get("indices", None)
         )
         return result
 
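A small sketch of the structure the reworked _create_model_sub_dict now produces, using Python's dict-union operator (the parameter values and model name below are hypothetical):

parameters = {"temperature": 0.5, "max_tokens": 1000}    # hypothetical values
model_sub_dict = {
    "model": parameters
    | {"model": "gpt-4o"}                 # the model name is now stored with the parameters
    | {"inference_service": "openai"},    # as is the inference service
}
print(model_sub_dict["model"]["inference_service"])       # openai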
edsl/results/Results.py
CHANGED
@@ -38,6 +38,64 @@ from edsl.results.ResultsGGMixin import ResultsGGMixin
 from edsl.results.results_fetch_mixin import ResultsFetchMixin
 from edsl.utilities.remove_edsl_version import remove_edsl_version
 
+def ensure_fetched(method):
+    """A decorator that checks if remote data is loaded, and if not, attempts to fetch it."""
+    def wrapper(self, *args, **kwargs):
+        if not self._fetched:
+            # If not fetched, try fetching now.
+            # (If you know you have job info stored in self.job_info)
+            self.fetch_remote(self.job_info)
+        return method(self, *args, **kwargs)
+    return wrapper
+
+def ensure_ready(method):
+    """
+    Decorator for Results methods.
+
+    If the Results object is not ready, for most methods we return a NotReadyObject.
+    However, for __repr__ (and other methods that need to return a string), we return
+    the string representation of NotReadyObject.
+    """
+    from functools import wraps
+
+    @wraps(method)
+    def wrapper(self, *args, **kwargs):
+        if self.completed:
+            return method(self, *args, **kwargs)
+        # Attempt to fetch remote data
+        try:
+            if hasattr(self, "job_info"):
+                self.fetch_remote(self.job_info)
+        except Exception as e:
+            print(f"Error during fetch_remote in {method.__name__}: {e}")
+        if not self.completed:
+            not_ready = NotReadyObject(name = method.__name__, job_info = self.job_info)
+            # For __repr__, ensure we return a string
+            if method.__name__ == "__repr__" or method.__name__ == "__str__":
+                return not_ready.__repr__()
+            return not_ready
+        return method(self, *args, **kwargs)
+
+    return wrapper
+
+class NotReadyObject:
+    """A placeholder object that prints a message when any attribute is accessed."""
+    def __init__(self, name: str, job_info: RemoteJobInfo):
+        self.name = name
+        self.job_info = job_info
+        #print(f"Not ready to call {name}")
+
+    def __repr__(self):
+        message = f"""Results not ready - job still running on server."""
+        for key, value in self.job_info.creation_data.items():
+            message += f"\n{key}: {value}"
+        return message
+
+    def __getattr__(self, _):
+        return self
+
+    def __call__(self, *args, **kwargs):
+        return self
 
 class Mixins(
     ResultsExportMixin,
@@ -93,6 +151,16 @@ class Results(UserList, Mixins, Base):
         "cache_keys",
     ]
 
+    @classmethod
+    def from_job_info(cls, job_info: dict) -> Results:
+        """
+        Instantiate a `Results` object from a job info dictionary.
+        """
+        results = cls()
+        results.completed = False
+        results.job_info = job_info
+        return results
+
     def __init__(
         self,
         survey: Optional[Survey] = None,
@@ -112,6 +180,8 @@ class Results(UserList, Mixins, Base):
         :param total_results: An integer representing the total number of results.
         :cache: A Cache object.
         """
+        self.completed = True
+        self._fetching = False
         super().__init__(data)
         from edsl.data.Cache import Cache
         from edsl.jobs.tasks.TaskHistory import TaskHistory
@@ -315,7 +385,22 @@ class Results(UserList, Mixins, Base):
             data=self.data + other.data,
             created_columns=self.created_columns,
         )
-
+
+    def _repr_html_(self):
+        if not self.completed:
+            if hasattr(self, "job_info"):
+                self.fetch_remote(self.job_info)
+
+            if not self.completed:
+                return f"Results not ready to call"
+
+        return super()._repr_html_()
+
+    # @ensure_ready
+    # def __str__(self):
+    #     super().__str__()
+
+    @ensure_ready
     def __repr__(self) -> str:
         return f"Results(data = {self.data}, survey = {repr(self.survey)}, created_columns = {self.created_columns})"
 
@@ -647,7 +732,7 @@ class Results(UserList, Mixins, Base):
 
         >>> r = Results.example()
        >>> r.model_keys
-        ['frequency_penalty', 'logprobs', 'max_tokens', 'model', 'model_index', 'presence_penalty', 'temperature', 'top_logprobs', 'top_p']
+        ['frequency_penalty', 'inference_service', 'logprobs', 'max_tokens', 'model', 'model_index', 'presence_penalty', 'temperature', 'top_logprobs', 'top_p']
         """
         return sorted(self._data_type_to_keys["model"])
 
@@ -732,6 +817,7 @@ class Results(UserList, Mixins, Base):
 
         return self.recode(column, recode_function=f, new_var_name=new_var_name)
 
+    @ensure_ready
     def recode(
         self, column: str, recode_function: Optional[Callable], new_var_name=None
     ) -> Results:
@@ -760,6 +846,7 @@ class Results(UserList, Mixins, Base):
             created_columns=self.created_columns + [new_var_name],
         )
 
+    @ensure_ready
     def add_column(self, column_name: str, values: list) -> Results:
         """Adds columns to Results
 
@@ -780,6 +867,7 @@ class Results(UserList, Mixins, Base):
             created_columns=self.created_columns + [column_name],
         )
 
+    @ensure_ready
     def add_columns_from_dict(self, columns: List[dict]) -> Results:
         """Adds columns to Results from a list of dictionaries.
 
@@ -829,6 +917,7 @@ class Results(UserList, Mixins, Base):
         evaluator.functions.update(int=int, float=float)
         return evaluator
 
+    @ensure_ready
     def mutate(
         self, new_var_string: str, functions_dict: Optional[dict] = None
     ) -> Results:
@@ -879,6 +968,7 @@ class Results(UserList, Mixins, Base):
             created_columns=self.created_columns + [var_name],
         )
 
+    @ensure_ready
     def add_column(self, column_name: str, values: list) -> Results:
         """Adds columns to Results
 
@@ -899,6 +989,7 @@ class Results(UserList, Mixins, Base):
             created_columns=self.created_columns + [column_name],
         )
 
+    @ensure_ready
     def rename(self, old_name: str, new_name: str) -> Results:
         """Rename an answer column in a Results object.
 
@@ -916,6 +1007,7 @@ class Results(UserList, Mixins, Base):
 
         return self
 
+    @ensure_ready
     def shuffle(self, seed: Optional[str] = "edsl") -> Results:
         """Shuffle the results.
 
@@ -932,6 +1024,7 @@ class Results(UserList, Mixins, Base):
         random.shuffle(new_data)
         return Results(survey=self.survey, data=new_data, created_columns=None)
 
+    @ensure_ready
     def sample(
         self,
         n: Optional[int] = None,
@@ -971,6 +1064,7 @@ class Results(UserList, Mixins, Base):
 
         return Results(survey=self.survey, data=new_data, created_columns=None)
 
+    @ensure_ready
     def select(self, *columns: Union[str, list[str]]) -> Results:
         """
         Select data from the results and format it.
@@ -1004,6 +1098,7 @@ class Results(UserList, Mixins, Base):
         )
         return selector.select(*columns)
 
+    @ensure_ready
     def sort_by(self, *columns: str, reverse: bool = False) -> Results:
         """Sort the results by one or more columns."""
         import warnings
@@ -1019,6 +1114,7 @@ class Results(UserList, Mixins, Base):
             return column.split(".")
         return self._key_to_data_type[column], column
 
+    @ensure_ready
    def order_by(self, *columns: str, reverse: bool = False) -> Results:
         """Sort the results by one or more columns.
 
@@ -1055,6 +1151,7 @@ class Results(UserList, Mixins, Base):
         new_data = sorted(self.data, key=sort_key, reverse=reverse)
         return Results(survey=self.survey, data=new_data, created_columns=None)
 
+    @ensure_ready
     def filter(self, expression: str) -> Results:
         """
         Filter based on the given expression and returns the filtered `Results`.
@@ -1156,6 +1253,7 @@ class Results(UserList, Mixins, Base):
         """Display an object as a table."""
         pass
 
+    @ensure_ready
     def __str__(self):
         data = self.to_dict()["data"]
         return json.dumps(data, indent=4)
@@ -1178,6 +1276,86 @@ class Results(UserList, Mixins, Base):
         [1, 1, 0, 0]
         """
         return [r.score(f) for r in self.data]
+
+
+    def fetch_remote(self, job_info: "RemoteJobInfo") -> None:
+        """
+        Fetches the remote Results object using the provided RemoteJobInfo and updates this instance with the remote data.
+
+        This is useful when you have a Results object that was created locally but want to sync it with
+        the latest data from the remote server.
+
+        Args:
+            job_info: RemoteJobInfo object containing the job_uuid and other remote job details
+
+        """
+        #print("Calling fetch_remote")
+        try:
+            from edsl.coop.coop import Coop
+            from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
+
+            # Get the remote job data
+            remote_job_data = JobsRemoteInferenceHandler.check_status(job_info.job_uuid)
+
+            if remote_job_data.get("status") not in ["completed", "failed"]:
+                return False
+            #
+            results_uuid = remote_job_data.get("results_uuid")
+            if not results_uuid:
+                raise ResultsError("No results_uuid found in remote job data")
+
+            # Fetch the remote Results object
+            coop = Coop()
+            remote_results = coop.get(results_uuid, expected_object_type="results")
+
+            # Update this instance with remote data
+            self.data = remote_results.data
+            self.survey = remote_results.survey
+            self.created_columns = remote_results.created_columns
+            self.cache = remote_results.cache
+            self.task_history = remote_results.task_history
+            self.completed = True
+
+            # Set job_uuid and results_uuid from remote data
+            self.job_uuid = job_info.job_uuid
+            if hasattr(remote_results, 'results_uuid'):
+                self.results_uuid = remote_results.results_uuid
+
+            return True
+
+        except Exception as e:
+            raise ResultsError(f"Failed to fetch remote results: {str(e)}")
+
+    def fetch(self, polling_interval: float = 1.0) -> Results:
+        """
+        Polls the server for job completion and updates this Results instance with the completed data.
+
+        Args:
+            polling_interval: Number of seconds to wait between polling attempts (default: 1.0)
+
+        Returns:
+            self: The updated Results instance
+        """
+        if not hasattr(self, "job_info"):
+            raise ResultsError("No job info available - this Results object wasn't created from a remote job")
+
+        from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
+
+        try:
+            # Get the remote job data
+            remote_job_data = JobsRemoteInferenceHandler.check_status(self.job_info.job_uuid)
+
+            while remote_job_data.get("status") not in ["completed", "failed"]:
+                import time
+                time.sleep(polling_interval)
+                remote_job_data = JobsRemoteInferenceHandler.check_status(self.job_info.job_uuid)
+
+            # Once complete, fetch the full results
+            self.fetch_remote(self.job_info)
+            return self
+
+        except Exception as e:
+            raise ResultsError(f"Failed to fetch remote results: {str(e)}")
 
 
 def main():  # pragma: no cover
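Taken together, these hunks add a non-blocking workflow for remote jobs: a placeholder Results is built from job info, @ensure_ready methods try to fetch remote data on first use, and fetch() polls until the job finishes. A rough usage sketch under those assumptions (job_info stands in for the RemoteJobInfo returned when a job is sent to remote inference and is not constructed here):

from edsl.results.Results import Results

# Assume `job_info` is the RemoteJobInfo returned by remote inference.
results = Results.from_job_info(job_info)

# Any @ensure_ready method (select, filter, sort_by, ...) first tries fetch_remote();
# while the job is still running it returns a NotReadyObject that reports the job status.
print(results)

# Or block until the server reports "completed" or "failed", polling every 5 seconds.
results = results.fetch(polling_interval=5.0)
results.select("answer.*")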
edsl/results/ResultsGGMixin.py
CHANGED
@@ -5,46 +5,113 @@ import tempfile
 from typing import Optional
 
 
+class GGPlot:
+    """A class to handle ggplot2 plot display and saving."""
+
+    def __init__(self, r_code: str, width: float = 6, height: float = 4):
+        """Initialize with R code and dimensions."""
+        self.r_code = r_code
+        self.width = width
+        self.height = height
+        self._svg_data = None
+        self._saved = False  # Track if the plot was saved
+
+    def _execute_r_code(self, save_command: str = ""):
+        """Execute R code with optional save command."""
+        full_r_code = self.r_code + save_command
+
+        result = subprocess.run(
+            ["Rscript", "-"],
+            input=full_r_code,
+            text=True,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+        )
+
+        if result.returncode != 0:
+            if result.returncode == 127:
+                raise RuntimeError(
+                    "Rscript is probably not installed. Please install R from https://cran.r-project.org/"
+                )
+            else:
+                raise RuntimeError(
+                    f"An error occurred while running Rscript: {result.stderr}"
+                )
+
+        if result.stderr:
+            print("Error in R script:", result.stderr)
+
+        return result
+
+    def save(self, filename: str):
+        """Save the plot to a file."""
+        format = filename.split('.')[-1].lower()
+        if format not in ['svg', 'png']:
+            raise ValueError("Only 'svg' and 'png' formats are supported")
+
+        save_command = f'\nggsave("{filename}", plot = last_plot(), width = {self.width}, height = {self.height}, device = "{format}")'
+        self._execute_r_code(save_command)
+
+        self._saved = True
+        print(f"File saved to: {filename}")
+        return None  # Return None instead of self
+
+    def _repr_html_(self):
+        """Display the plot in a Jupyter notebook."""
+        # Don't display if the plot was saved
+        if self._saved:
+            return None
+
+        import tempfile
+
+        # Generate SVG if we haven't already
+        if self._svg_data is None:
+            # Create temporary SVG file
+            with tempfile.NamedTemporaryFile(suffix='.svg') as tmp:
+                save_command = f'\nggsave("{tmp.name}", plot = last_plot(), width = {self.width}, height = {self.height}, device = "svg")'
+                self._execute_r_code(save_command)
+                with open(tmp.name, 'r') as f:
+                    self._svg_data = f.read()
+
+        return self._svg_data
+
 class ResultsGGMixin:
     """Mixin class for ggplot2 plotting."""
 
     def ggplot2(
         self,
         ggplot_code: str,
-        filename: str = None,
         shape="wide",
         sql: str = None,
         remove_prefix: bool = True,
         debug: bool = False,
         height=4,
         width=6,
-        format="svg",
         factor_orders: Optional[dict] = None,
     ):
         """Create a ggplot2 plot from a DataFrame.
 
+        Returns a GGPlot object that can be displayed in a notebook or saved to a file.
+
         :param ggplot_code: The ggplot2 code to execute.
-        :param filename: The filename to save the plot to.
         :param shape: The shape of the data in the DataFrame (wide or long).
         :param sql: The SQL query to execute beforehand to manipulate the data.
         :param remove_prefix: Whether to remove the prefix from the column names.
         :param debug: Whether to print the R code instead of executing it.
         :param height: The height of the plot in inches.
         :param width: The width of the plot in inches.
-        :param format: The format to save the plot in (png or svg).
         :param factor_orders: A dictionary of factor columns and their order.
         """
-
         if sql == None:
             sql = "select * from self"
 
         if shape == "long":
             df = self.sql(sql, shape="long")
         elif shape == "wide":
-            df = self.sql(sql,
+            df = self.sql(sql, remove_prefix=remove_prefix)
 
         # Convert DataFrame to CSV format
-        csv_data = df.to_csv(
+        csv_data = df.to_csv().text
 
         # Embed the CSV data within the R script
         csv_data_escaped = csv_data.replace("\n", "\\n").replace("'", "\\'")
@@ -52,70 +119,60 @@ class ResultsGGMixin:
 
         if factor_orders is not None:
             for factor, order in factor_orders.items():
-                # read_csv_code += f"""self${{{factor}}} <- factor(self${{{factor}}}, levels=c({','.join(['"{}"'.format(x) for x in order])}))"""
-
                 level_string = ", ".join([f'"{x}"' for x in order])
                 read_csv_code += (
                     f"self${factor} <- factor(self${factor}, levels=c({level_string}))"
                 )
                 read_csv_code += "\n"
 
-        # Load ggplot2 library
-
-
-        # Check if a filename is provided for the plot, if not create a temporary one
-        if not filename:
-            filename = tempfile.mktemp(suffix=f".{format}")
-
-        # Combine all R script parts
-        full_r_code = load_ggplot2 + read_csv_code + ggplot_code
-
-        # Add command to save the plot to a file
-        full_r_code += f'\nggsave("{filename}", plot = last_plot(), width = {width}, height = {height}, device = "{format}")'
+        # Load ggplot2 library and combine all R script parts
+        full_r_code = "library(ggplot2)\n" + read_csv_code + ggplot_code
 
         if debug:
             print(full_r_code)
             return
 
-
-            ["Rscript", "-"],
-            input=full_r_code,
-            text=True,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE,
-        )
-
-        if result.returncode != 0:
-            if result.returncode == 127:  # 'command not found'
-                raise RuntimeError(
-                    "Rscript is probably not installed. Please install R from https://cran.r-project.org/"
-                )
-            else:
-                raise RuntimeError(
-                    f"An error occurred while running Rscript: {result.stderr}"
-                )
-
-        if result.stderr:
-            print("Error in R script:", result.stderr)
-        else:
-            self._display_plot(filename, width, height)
+        return GGPlot(full_r_code, width=width, height=height)
 
     def _display_plot(self, filename: str, width: float, height: float):
-        """Display the plot in the notebook."""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        """Display the plot in the notebook or open in system viewer if running from terminal."""
+        try:
+            # Try to import IPython-related modules
+            import matplotlib.pyplot as plt
+            import matplotlib.image as mpimg
+            from IPython import get_ipython
+
+            # Check if we're in a notebook environment
+            if get_ipython() is not None:
+                if filename.endswith(".png"):
+                    img = mpimg.imread(filename)
+                    plt.figure(figsize=(width, height))
+                    plt.imshow(img)
+                    plt.axis("off")
+                    plt.show()
+                elif filename.endswith(".svg"):
+                    from IPython.display import SVG, display
+                    display(SVG(filename=filename))
+                else:
+                    print("Unsupported file format. Please provide a PNG or SVG file.")
+                return
+
+        except ImportError:
+            pass
+
+        # If we're not in a notebook or imports failed, open with system viewer
+        import platform
+        import os
+
+        system = platform.system()
+        if system == 'Darwin':  # macOS
+            if filename.endswith('.svg'):
+                subprocess.run(['open', '-a', 'Preview', filename])
+            else:
+                subprocess.run(['open', filename])
+        elif system == 'Linux':
+            subprocess.run(['xdg-open', filename])
+        elif system == 'Windows':
+            os.startfile(filename)
         else:
-            print("
+            print(f"File saved to: {filename}")
edsl/scenarios/PdfExtractor.py
CHANGED
@@ -2,14 +2,11 @@ import os
 
 
 class PdfExtractor:
-    def __init__(self, pdf_path: str
+    def __init__(self, pdf_path: str):
         self.pdf_path = pdf_path
-        self.constructor = parent_object.__class__
+        #self.constructor = parent_object.__class__
 
-    def
-        return self.constructor(self._get_pdf_dict())
-
-    def _get_pdf_dict(self) -> dict:
+    def get_pdf_dict(self) -> dict:
         # Ensure the file exists
         import fitz
 
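A short sketch of the simplified extractor (the PDF path is illustrative; the keys of the returned dict are not shown in this hunk):

from edsl.scenarios.PdfExtractor import PdfExtractor

# The constructor now takes only the path, and the renamed get_pdf_dict()
# returns a plain dict rather than constructing a parent object.
extractor = PdfExtractor("example.pdf")
pdf_dict = extractor.get_pdf_dict()
print(list(pdf_dict.keys()))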