edsl 0.1.41__py3-none-any.whl → 0.1.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. edsl/__version__.py +1 -1
  2. edsl/agents/Invigilator.py +4 -3
  3. edsl/agents/InvigilatorBase.py +2 -1
  4. edsl/agents/PromptConstructor.py +92 -21
  5. edsl/agents/QuestionInstructionPromptBuilder.py +68 -9
  6. edsl/agents/QuestionTemplateReplacementsBuilder.py +7 -2
  7. edsl/agents/prompt_helpers.py +2 -2
  8. edsl/coop/coop.py +97 -19
  9. edsl/enums.py +3 -1
  10. edsl/exceptions/coop.py +4 -0
  11. edsl/exceptions/jobs.py +1 -9
  12. edsl/exceptions/language_models.py +8 -4
  13. edsl/exceptions/questions.py +8 -11
  14. edsl/inference_services/AvailableModelFetcher.py +4 -1
  15. edsl/inference_services/DeepSeekService.py +18 -0
  16. edsl/inference_services/registry.py +2 -0
  17. edsl/jobs/Jobs.py +60 -34
  18. edsl/jobs/JobsPrompts.py +64 -3
  19. edsl/jobs/JobsRemoteInferenceHandler.py +42 -25
  20. edsl/jobs/JobsRemoteInferenceLogger.py +1 -1
  21. edsl/jobs/buckets/BucketCollection.py +30 -0
  22. edsl/jobs/data_structures.py +1 -0
  23. edsl/jobs/interviews/Interview.py +1 -1
  24. edsl/jobs/loggers/HTMLTableJobLogger.py +6 -1
  25. edsl/jobs/results_exceptions_handler.py +2 -7
  26. edsl/jobs/tasks/TaskHistory.py +49 -17
  27. edsl/language_models/LanguageModel.py +7 -4
  28. edsl/language_models/ModelList.py +1 -1
  29. edsl/language_models/key_management/KeyLookupBuilder.py +47 -20
  30. edsl/language_models/key_management/models.py +10 -4
  31. edsl/language_models/model.py +49 -0
  32. edsl/prompts/Prompt.py +124 -61
  33. edsl/questions/descriptors.py +37 -23
  34. edsl/questions/question_base_gen_mixin.py +1 -0
  35. edsl/results/DatasetExportMixin.py +35 -6
  36. edsl/results/Result.py +9 -3
  37. edsl/results/Results.py +180 -2
  38. edsl/results/ResultsGGMixin.py +117 -60
  39. edsl/scenarios/PdfExtractor.py +3 -6
  40. edsl/scenarios/Scenario.py +35 -1
  41. edsl/scenarios/ScenarioList.py +22 -3
  42. edsl/scenarios/ScenarioListPdfMixin.py +9 -3
  43. edsl/surveys/Survey.py +1 -1
  44. edsl/templates/error_reporting/base.html +2 -4
  45. edsl/templates/error_reporting/exceptions_table.html +35 -0
  46. edsl/templates/error_reporting/interview_details.html +67 -53
  47. edsl/templates/error_reporting/interviews.html +4 -17
  48. edsl/templates/error_reporting/overview.html +31 -5
  49. edsl/templates/error_reporting/performance_plot.html +1 -1
  50. {edsl-0.1.41.dist-info → edsl-0.1.43.dist-info}/METADATA +2 -3
  51. {edsl-0.1.41.dist-info → edsl-0.1.43.dist-info}/RECORD +53 -51
  52. {edsl-0.1.41.dist-info → edsl-0.1.43.dist-info}/LICENSE +0 -0
  53. {edsl-0.1.41.dist-info → edsl-0.1.43.dist-info}/WHEEL +0 -0
edsl/results/DatasetExportMixin.py CHANGED
@@ -7,7 +7,6 @@ from typing import Optional, Tuple, Union, List

 from edsl.results.file_exports import CSVExport, ExcelExport, JSONLExport, SQLiteExport

-
 class DatasetExportMixin:
     """Mixin class for exporting Dataset objects."""

@@ -220,23 +219,45 @@ class DatasetExportMixin:
         )
         return exporter.export()

-    def _db(self, remove_prefix: bool = True):
+    def _db(self, remove_prefix: bool = True, shape: str = "wide") -> "sqlalchemy.engine.Engine":
         """Create a SQLite database in memory and return the connection.

         Args:
-            shape: The shape of the data in the database (wide or long)
             remove_prefix: Whether to remove the prefix from the column names
+            shape: The shape of the data in the database ("wide" or "long")

         Returns:
             A database connection
+        >>> from sqlalchemy import text
+        >>> from edsl import Results
+        >>> engine = Results.example()._db()
+        >>> len(engine.execute(text("SELECT * FROM self")).fetchall())
+        4
+        >>> engine = Results.example()._db(shape = "long")
+        >>> len(engine.execute(text("SELECT * FROM self")).fetchall())
+        172
         """
-        from sqlalchemy import create_engine
+        from sqlalchemy import create_engine, text

         engine = create_engine("sqlite:///:memory:")
-        if remove_prefix:
+        if remove_prefix and shape == "wide":
             df = self.remove_prefix().to_pandas(lists_as_strings=True)
         else:
             df = self.to_pandas(lists_as_strings=True)
+
+        if shape == "long":
+            # Melt the dataframe to convert it to long format
+            df = df.melt(
+                var_name='key',
+                value_name='value'
+            )
+            # Add a row number column for reference
+            df.insert(0, 'row_number', range(1, len(df) + 1))
+
+            # Split the key into data_type and key
+            df['data_type'] = df['key'].apply(lambda x: x.split('.')[0] if '.' in x else None)
+            df['key'] = df['key'].apply(lambda x: '.'.join(x.split('.')[1:]) if '.' in x else x)
+
         df.to_sql(
             "self",
             engine,
@@ -251,6 +272,7 @@ class DatasetExportMixin:
         transpose: bool = None,
         transpose_by: str = None,
         remove_prefix: bool = True,
+        shape: str = "wide",
     ) -> Union["pd.DataFrame", str]:
         """Execute a SQL query and return the results as a DataFrame.

@@ -268,10 +290,17 @@
         Returns:
             DataFrame, CSV string, list, or LaTeX string depending on parameters

+        Examples:
+            >>> from edsl import Results
+            >>> r = Results.example();
+            >>> len(r.sql("SELECT * FROM self", shape = "wide"))
+            4
+            >>> len(r.sql("SELECT * FROM self", shape = "long"))
+            172
         """
         import pandas as pd

-        conn = self._db(remove_prefix=remove_prefix)
+        conn = self._db(remove_prefix=remove_prefix, shape=shape)
         df = pd.read_sql_query(query, conn)

         # Transpose the DataFrame if transpose is True
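With the new `shape` argument, `sql()` can query either the familiar wide table or a melted long table whose columns are `row_number`, `data_type`, `key`, and `value` (per the melt logic above). A minimal sketch, assuming the standard `Results.example()` data; exact row counts depend on the example survey:

    from edsl import Results

    r = Results.example()

    # Wide: one row per Result, one column per key
    wide_df = r.sql("SELECT * FROM self", shape="wide")

    # Long: one row per (result, key) pair, convenient for aggregation
    long_df = r.sql(
        "SELECT data_type, COUNT(*) AS n FROM self GROUP BY data_type",
        shape="long",
    )
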
edsl/results/Result.py CHANGED
@@ -78,7 +78,6 @@ class Result(Base, UserDict):
         self.question_to_attributes = (
             question_to_attributes or self._create_question_to_attributes(survey)
         )
-
         data = {
             "agent": agent,
             "scenario": scenario,
@@ -87,7 +86,7 @@ class Result(Base, UserDict):
             "answer": answer,
             "prompt": prompt or {},
             "raw_model_response": raw_model_response or {},
-            "question_to_attributes": question_to_attributes,
+            "question_to_attributes": self.question_to_attributes,
             "generated_tokens": generated_tokens or {},
             "comments_dict": comments_dict or {},
             "cache_used_dict": cache_used_dict or {},
@@ -154,7 +153,9 @@ class Result(Base, UserDict):
     @staticmethod
     def _create_model_sub_dict(model) -> dict:
         return {
-            "model": model.parameters | {"model": model.model},
+            "model": model.parameters
+            | {"model": model.model}
+            | {"inference_service": model._inference_service_},
         }

     @staticmethod
@@ -365,6 +366,10 @@ class Result(Base, UserDict):
                 else prompt_obj.to_dict()
             )
             d[key] = new_prompt_dict
+
+        if self.indices is not None:
+            d["indices"] = self.indices
+
         if add_edsl_version:
             from edsl import __version__

@@ -414,6 +419,7 @@ class Result(Base, UserDict):
             comments_dict=json_dict.get("comments_dict", {}),
             cache_used_dict=json_dict.get("cache_used_dict", {}),
             cache_keys=json_dict.get("cache_keys", {}),
+            indices = json_dict.get("indices", None)
         )
         return result

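Serialization now records `indices` (when present) and the model sub-dict carries `inference_service`, so both survive a round trip. A rough sketch, assuming `Result` exposes the usual `example()` and `from_dict()` helpers that other edsl `Base` classes provide:

    from edsl.results.Result import Result

    result = Result.example()
    d = result.to_dict()            # includes "indices" only if they were set
    restored = Result.from_dict(d)  # model parameters now include inference_service
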
edsl/results/Results.py CHANGED
@@ -38,6 +38,64 @@ from edsl.results.ResultsGGMixin import ResultsGGMixin
 from edsl.results.results_fetch_mixin import ResultsFetchMixin
 from edsl.utilities.remove_edsl_version import remove_edsl_version

+def ensure_fetched(method):
+    """A decorator that checks if remote data is loaded, and if not, attempts to fetch it."""
+    def wrapper(self, *args, **kwargs):
+        if not self._fetched:
+            # If not fetched, try fetching now.
+            # (If you know you have job info stored in self.job_info)
+            self.fetch_remote(self.job_info)
+        return method(self, *args, **kwargs)
+    return wrapper
+
+def ensure_ready(method):
+    """
+    Decorator for Results methods.
+
+    If the Results object is not ready, for most methods we return a NotReadyObject.
+    However, for __repr__ (and other methods that need to return a string), we return
+    the string representation of NotReadyObject.
+    """
+    from functools import wraps
+
+    @wraps(method)
+    def wrapper(self, *args, **kwargs):
+        if self.completed:
+            return method(self, *args, **kwargs)
+        # Attempt to fetch remote data
+        try:
+            if hasattr(self, "job_info"):
+                self.fetch_remote(self.job_info)
+        except Exception as e:
+            print(f"Error during fetch_remote in {method.__name__}: {e}")
+        if not self.completed:
+            not_ready = NotReadyObject(name = method.__name__, job_info = self.job_info)
+            # For __repr__, ensure we return a string
+            if method.__name__ == "__repr__" or method.__name__ == "__str__":
+                return not_ready.__repr__()
+            return not_ready
+        return method(self, *args, **kwargs)
+
+    return wrapper
+
+class NotReadyObject:
+    """A placeholder object that prints a message when any attribute is accessed."""
+    def __init__(self, name: str, job_info: RemoteJobInfo):
+        self.name = name
+        self.job_info = job_info
+        #print(f"Not ready to call {name}")
+
+    def __repr__(self):
+        message = f"""Results not ready - job still running on server."""
+        for key, value in self.job_info.creation_data.items():
+            message += f"\n{key}: {value}"
+        return message
+
+    def __getattr__(self, _):
+        return self
+
+    def __call__(self, *args, **kwargs):
+        return self

 class Mixins(
     ResultsExportMixin,
@@ -93,6 +151,16 @@ class Results(UserList, Mixins, Base):
         "cache_keys",
     ]

+    @classmethod
+    def from_job_info(cls, job_info: dict) -> Results:
+        """
+        Instantiate a `Results` object from a job info dictionary.
+        """
+        results = cls()
+        results.completed = False
+        results.job_info = job_info
+        return results
+
     def __init__(
         self,
         survey: Optional[Survey] = None,
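Taken together, `from_job_info` and the `ensure_ready` decorator let a `Results` handle be returned before the remote job finishes: decorated methods first try `fetch_remote()` and otherwise hand back the `NotReadyObject` placeholder rather than raising. A rough sketch of the intended flow (the `job_info` value is a stand-in for whatever the remote inference handler returns):

    from edsl import Results

    job_info = ...                      # RemoteJobInfo from remote job creation (placeholder)
    results = Results.from_job_info(job_info)

    # While the job is still running, decorated calls resolve to the
    # "Results not ready - job still running on server." placeholder.
    print(results)

    # Once the server reports completion, the same calls operate on real data.
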
@@ -112,6 +180,8 @@ class Results(UserList, Mixins, Base):
        :param total_results: An integer representing the total number of results.
        :cache: A Cache object.
        """
+        self.completed = True
+        self._fetching = False
        super().__init__(data)
        from edsl.data.Cache import Cache
        from edsl.jobs.tasks.TaskHistory import TaskHistory
@@ -315,7 +385,22 @@ class Results(UserList, Mixins, Base):
             data=self.data + other.data,
             created_columns=self.created_columns,
         )
-
+
+    def _repr_html_(self):
+        if not self.completed:
+            if hasattr(self, "job_info"):
+                self.fetch_remote(self.job_info)
+
+        if not self.completed:
+            return f"Results not ready to call"
+
+        return super()._repr_html_()
+
+    # @ensure_ready
+    # def __str__(self):
+    #     super().__str__()
+
+    @ensure_ready
     def __repr__(self) -> str:
         return f"Results(data = {self.data}, survey = {repr(self.survey)}, created_columns = {self.created_columns})"

@@ -647,7 +732,7 @@ class Results(UserList, Mixins, Base):

        >>> r = Results.example()
        >>> r.model_keys
-        ['frequency_penalty', 'logprobs', 'max_tokens', 'model', 'model_index', 'presence_penalty', 'temperature', 'top_logprobs', 'top_p']
+        ['frequency_penalty', 'inference_service', 'logprobs', 'max_tokens', 'model', 'model_index', 'presence_penalty', 'temperature', 'top_logprobs', 'top_p']
        """
        return sorted(self._data_type_to_keys["model"])

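Because the model sub-dict now includes `inference_service`, it behaves like any other model column. A small sketch (dotted column syntax assumed from the existing select API):

    from edsl import Results

    r = Results.example()
    print(r.model_keys)    # now lists 'inference_service' alongside 'model', 'temperature', ...

    subset = r.select("model.model", "model.inference_service")
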
@@ -732,6 +817,7 @@ class Results(UserList, Mixins, Base):

         return self.recode(column, recode_function=f, new_var_name=new_var_name)

+    @ensure_ready
     def recode(
         self, column: str, recode_function: Optional[Callable], new_var_name=None
     ) -> Results:
@@ -760,6 +846,7 @@ class Results(UserList, Mixins, Base):
             created_columns=self.created_columns + [new_var_name],
         )

+    @ensure_ready
     def add_column(self, column_name: str, values: list) -> Results:
         """Adds columns to Results

@@ -780,6 +867,7 @@ class Results(UserList, Mixins, Base):
             created_columns=self.created_columns + [column_name],
         )

+    @ensure_ready
     def add_columns_from_dict(self, columns: List[dict]) -> Results:
         """Adds columns to Results from a list of dictionaries.

@@ -829,6 +917,7 @@ class Results(UserList, Mixins, Base):
         evaluator.functions.update(int=int, float=float)
         return evaluator

+    @ensure_ready
     def mutate(
         self, new_var_string: str, functions_dict: Optional[dict] = None
     ) -> Results:
@@ -879,6 +968,7 @@ class Results(UserList, Mixins, Base):
             created_columns=self.created_columns + [var_name],
         )

+    @ensure_ready
     def add_column(self, column_name: str, values: list) -> Results:
         """Adds columns to Results

@@ -899,6 +989,7 @@ class Results(UserList, Mixins, Base):
             created_columns=self.created_columns + [column_name],
         )

+    @ensure_ready
     def rename(self, old_name: str, new_name: str) -> Results:
         """Rename an answer column in a Results object.

@@ -916,6 +1007,7 @@ class Results(UserList, Mixins, Base):

         return self

+    @ensure_ready
     def shuffle(self, seed: Optional[str] = "edsl") -> Results:
         """Shuffle the results.

@@ -932,6 +1024,7 @@ class Results(UserList, Mixins, Base):
         random.shuffle(new_data)
         return Results(survey=self.survey, data=new_data, created_columns=None)

+    @ensure_ready
     def sample(
         self,
         n: Optional[int] = None,
@@ -971,6 +1064,7 @@ class Results(UserList, Mixins, Base):

         return Results(survey=self.survey, data=new_data, created_columns=None)

+    @ensure_ready
     def select(self, *columns: Union[str, list[str]]) -> Results:
         """
         Select data from the results and format it.
@@ -1004,6 +1098,7 @@ class Results(UserList, Mixins, Base):
         )
         return selector.select(*columns)

+    @ensure_ready
     def sort_by(self, *columns: str, reverse: bool = False) -> Results:
         """Sort the results by one or more columns."""
         import warnings
@@ -1019,6 +1114,7 @@ class Results(UserList, Mixins, Base):
             return column.split(".")
         return self._key_to_data_type[column], column

+    @ensure_ready
     def order_by(self, *columns: str, reverse: bool = False) -> Results:
         """Sort the results by one or more columns.

@@ -1055,6 +1151,7 @@ class Results(UserList, Mixins, Base):
         new_data = sorted(self.data, key=sort_key, reverse=reverse)
         return Results(survey=self.survey, data=new_data, created_columns=None)

+    @ensure_ready
     def filter(self, expression: str) -> Results:
         """
         Filter based on the given expression and returns the filtered `Results`.
@@ -1156,6 +1253,7 @@ class Results(UserList, Mixins, Base):
         """Display an object as a table."""
         pass

+    @ensure_ready
     def __str__(self):
         data = self.to_dict()["data"]
         return json.dumps(data, indent=4)
@@ -1178,6 +1276,86 @@ class Results(UserList, Mixins, Base):
         [1, 1, 0, 0]
         """
         return [r.score(f) for r in self.data]
+
+
+    def fetch_remote(self, job_info: "RemoteJobInfo") -> None:
+        """
+        Fetches the remote Results object using the provided RemoteJobInfo and updates this instance with the remote data.
+
+        This is useful when you have a Results object that was created locally but want to sync it with
+        the latest data from the remote server.
+
+        Args:
+            job_info: RemoteJobInfo object containing the job_uuid and other remote job details
+
+        """
+        #print("Calling fetch_remote")
+        try:
+            from edsl.coop.coop import Coop
+            from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
+
+            # Get the remote job data
+            remote_job_data = JobsRemoteInferenceHandler.check_status(job_info.job_uuid)
+
+            if remote_job_data.get("status") not in ["completed", "failed"]:
+                return False
+            #
+            results_uuid = remote_job_data.get("results_uuid")
+            if not results_uuid:
+                raise ResultsError("No results_uuid found in remote job data")
+
+            # Fetch the remote Results object
+            coop = Coop()
+            remote_results = coop.get(results_uuid, expected_object_type="results")
+
+            # Update this instance with remote data
+            self.data = remote_results.data
+            self.survey = remote_results.survey
+            self.created_columns = remote_results.created_columns
+            self.cache = remote_results.cache
+            self.task_history = remote_results.task_history
+            self.completed = True
+
+            # Set job_uuid and results_uuid from remote data
+            self.job_uuid = job_info.job_uuid
+            if hasattr(remote_results, 'results_uuid'):
+                self.results_uuid = remote_results.results_uuid
+
+            return True
+
+        except Exception as e:
+            raise ResultsError(f"Failed to fetch remote results: {str(e)}")
+
+    def fetch(self, polling_interval: float = 1.0) -> Results:
+        """
+        Polls the server for job completion and updates this Results instance with the completed data.
+
+        Args:
+            polling_interval: Number of seconds to wait between polling attempts (default: 1.0)
+
+        Returns:
+            self: The updated Results instance
+        """
+        if not hasattr(self, "job_info"):
+            raise ResultsError("No job info available - this Results object wasn't created from a remote job")
+
+        from edsl.jobs.JobsRemoteInferenceHandler import JobsRemoteInferenceHandler
+
+        try:
+            # Get the remote job data
+            remote_job_data = JobsRemoteInferenceHandler.check_status(self.job_info.job_uuid)
+
+            while remote_job_data.get("status") not in ["completed", "failed"]:
+                import time
+                time.sleep(polling_interval)
+                remote_job_data = JobsRemoteInferenceHandler.check_status(self.job_info.job_uuid)
+
+            # Once complete, fetch the full results
+            self.fetch_remote(self.job_info)
+            return self
+
+        except Exception as e:
+            raise ResultsError(f"Failed to fetch remote results: {str(e)}")


 def main(): # pragma: no cover
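The two new methods split the remote workflow: `fetch_remote()` performs a single status check and loads results only if the job already finished (returning True/False), while `fetch()` blocks and polls until the server reports completed or failed. A hedged sketch, again treating `job_info` as a placeholder for the handler's RemoteJobInfo:

    from edsl import Results

    job_info = ...                              # RemoteJobInfo placeholder
    results = Results.from_job_info(job_info)

    if not results.fetch_remote(job_info):
        # Not done yet: poll every 5 seconds until completion, then load in place
        results.fetch(polling_interval=5.0)
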
edsl/results/ResultsGGMixin.py CHANGED
@@ -5,46 +5,113 @@ import tempfile
 from typing import Optional


+class GGPlot:
+    """A class to handle ggplot2 plot display and saving."""
+
+    def __init__(self, r_code: str, width: float = 6, height: float = 4):
+        """Initialize with R code and dimensions."""
+        self.r_code = r_code
+        self.width = width
+        self.height = height
+        self._svg_data = None
+        self._saved = False  # Track if the plot was saved
+
+    def _execute_r_code(self, save_command: str = ""):
+        """Execute R code with optional save command."""
+        full_r_code = self.r_code + save_command
+
+        result = subprocess.run(
+            ["Rscript", "-"],
+            input=full_r_code,
+            text=True,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+        )
+
+        if result.returncode != 0:
+            if result.returncode == 127:
+                raise RuntimeError(
+                    "Rscript is probably not installed. Please install R from https://cran.r-project.org/"
+                )
+            else:
+                raise RuntimeError(
+                    f"An error occurred while running Rscript: {result.stderr}"
+                )
+
+        if result.stderr:
+            print("Error in R script:", result.stderr)
+
+        return result
+
+    def save(self, filename: str):
+        """Save the plot to a file."""
+        format = filename.split('.')[-1].lower()
+        if format not in ['svg', 'png']:
+            raise ValueError("Only 'svg' and 'png' formats are supported")
+
+        save_command = f'\nggsave("{filename}", plot = last_plot(), width = {self.width}, height = {self.height}, device = "{format}")'
+        self._execute_r_code(save_command)
+
+        self._saved = True
+        print(f"File saved to: {filename}")
+        return None  # Return None instead of self
+
+    def _repr_html_(self):
+        """Display the plot in a Jupyter notebook."""
+        # Don't display if the plot was saved
+        if self._saved:
+            return None
+
+        import tempfile
+
+        # Generate SVG if we haven't already
+        if self._svg_data is None:
+            # Create temporary SVG file
+            with tempfile.NamedTemporaryFile(suffix='.svg') as tmp:
+                save_command = f'\nggsave("{tmp.name}", plot = last_plot(), width = {self.width}, height = {self.height}, device = "svg")'
+                self._execute_r_code(save_command)
+                with open(tmp.name, 'r') as f:
+                    self._svg_data = f.read()
+
+        return self._svg_data
+
 class ResultsGGMixin:
     """Mixin class for ggplot2 plotting."""

     def ggplot2(
         self,
         ggplot_code: str,
-        filename: str = None,
         shape="wide",
         sql: str = None,
         remove_prefix: bool = True,
         debug: bool = False,
         height=4,
         width=6,
-        format="svg",
         factor_orders: Optional[dict] = None,
     ):
         """Create a ggplot2 plot from a DataFrame.

+        Returns a GGPlot object that can be displayed in a notebook or saved to a file.
+
         :param ggplot_code: The ggplot2 code to execute.
-        :param filename: The filename to save the plot to.
         :param shape: The shape of the data in the DataFrame (wide or long).
         :param sql: The SQL query to execute beforehand to manipulate the data.
         :param remove_prefix: Whether to remove the prefix from the column names.
         :param debug: Whether to print the R code instead of executing it.
         :param height: The height of the plot in inches.
         :param width: The width of the plot in inches.
-        :param format: The format to save the plot in (png or svg).
         :param factor_orders: A dictionary of factor columns and their order.
         """
-
         if sql == None:
             sql = "select * from self"

         if shape == "long":
             df = self.sql(sql, shape="long")
         elif shape == "wide":
-            df = self.sql(sql, shape="wide", remove_prefix=remove_prefix)
+            df = self.sql(sql, remove_prefix=remove_prefix)

         # Convert DataFrame to CSV format
-        csv_data = df.to_csv(index=False)
+        csv_data = df.to_csv().text

         # Embed the CSV data within the R script
         csv_data_escaped = csv_data.replace("\n", "\\n").replace("'", "\\'")
@@ -52,70 +119,60 @@ class ResultsGGMixin:

         if factor_orders is not None:
             for factor, order in factor_orders.items():
-                # read_csv_code += f"""self${{{factor}}} <- factor(self${{{factor}}}, levels=c({','.join(['"{}"'.format(x) for x in order])}))"""
-
                 level_string = ", ".join([f'"{x}"' for x in order])
                 read_csv_code += (
                     f"self${factor} <- factor(self${factor}, levels=c({level_string}))"
                 )
                 read_csv_code += "\n"

-        # Load ggplot2 library
-        load_ggplot2 = "library(ggplot2)\n"
-
-        # Check if a filename is provided for the plot, if not create a temporary one
-        if not filename:
-            filename = tempfile.mktemp(suffix=f".{format}")
-
-        # Combine all R script parts
-        full_r_code = load_ggplot2 + read_csv_code + ggplot_code
-
-        # Add command to save the plot to a file
-        full_r_code += f'\nggsave("{filename}", plot = last_plot(), width = {width}, height = {height}, device = "{format}")'
+        # Load ggplot2 library and combine all R script parts
+        full_r_code = "library(ggplot2)\n" + read_csv_code + ggplot_code

         if debug:
             print(full_r_code)
             return

-        result = subprocess.run(
-            ["Rscript", "-"],
-            input=full_r_code,
-            text=True,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE,
-        )
-
-        if result.returncode != 0:
-            if result.returncode == 127:  # 'command not found'
-                raise RuntimeError(
-                    "Rscript is probably not installed. Please install R from https://cran.r-project.org/"
-                )
-            else:
-                raise RuntimeError(
-                    f"An error occurred while running Rscript: {result.stderr}"
-                )
-
-        if result.stderr:
-            print("Error in R script:", result.stderr)
-        else:
-            self._display_plot(filename, width, height)
+        return GGPlot(full_r_code, width=width, height=height)

     def _display_plot(self, filename: str, width: float, height: float):
-        """Display the plot in the notebook."""
-        import matplotlib.pyplot as plt
-        import matplotlib.image as mpimg
-
-        if filename.endswith(".png"):
-            img = mpimg.imread(filename)
-            plt.figure(
-                figsize=(width, height)
-            )  # Set the figure size (width, height) in inches
-            plt.imshow(img)
-            plt.axis("off")
-            plt.show()
-        elif filename.endswith(".svg"):
-            from IPython.display import SVG, display
-
-            display(SVG(filename=filename))
+        """Display the plot in the notebook or open in system viewer if running from terminal."""
+        try:
+            # Try to import IPython-related modules
+            import matplotlib.pyplot as plt
+            import matplotlib.image as mpimg
+            from IPython import get_ipython
+
+            # Check if we're in a notebook environment
+            if get_ipython() is not None:
+                if filename.endswith(".png"):
+                    img = mpimg.imread(filename)
+                    plt.figure(figsize=(width, height))
+                    plt.imshow(img)
+                    plt.axis("off")
+                    plt.show()
+                elif filename.endswith(".svg"):
+                    from IPython.display import SVG, display
+                    display(SVG(filename=filename))
+                else:
+                    print("Unsupported file format. Please provide a PNG or SVG file.")
+                return
+
+        except ImportError:
+            pass
+
+        # If we're not in a notebook or imports failed, open with system viewer
+        import platform
+        import os
+
+        system = platform.system()
+        if system == 'Darwin':  # macOS
+            if filename.endswith('.svg'):
+                subprocess.run(['open', '-a', 'Preview', filename])
+            else:
+                subprocess.run(['open', filename])
+        elif system == 'Linux':
+            subprocess.run(['xdg-open', filename])
+        elif system == 'Windows':
+            os.startfile(filename)
         else:
-            print("Unsupported file format. Please provide a PNG or SVG file.")
+            print(f"File saved to: {filename}")
edsl/scenarios/PdfExtractor.py CHANGED
@@ -2,14 +2,11 @@ import os


 class PdfExtractor:
-    def __init__(self, pdf_path: str, parent_object: object):
+    def __init__(self, pdf_path: str):
         self.pdf_path = pdf_path
-        self.constructor = parent_object.__class__
+        #self.constructor = parent_object.__class__

-    def get_object(self) -> object:
-        return self.constructor(self._get_pdf_dict())
-
-    def _get_pdf_dict(self) -> dict:
+    def get_pdf_dict(self) -> dict:
         # Ensure the file exists
         import fitz
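Callers now take the extracted dictionary directly instead of receiving a reconstructed parent object. A brief sketch (the path is a placeholder; PyMuPDF's `fitz` must be installed):

    from edsl.scenarios.PdfExtractor import PdfExtractor

    pdf_dict = PdfExtractor("example.pdf").get_pdf_dict()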