edsl 0.1.44__py3-none-any.whl → 0.1.45__py3-none-any.whl

This diff shows the changes between package versions that have been publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (61)
  1. edsl/__version__.py +1 -1
  2. edsl/agents/InvigilatorBase.py +3 -1
  3. edsl/agents/PromptConstructor.py +62 -34
  4. edsl/agents/QuestionInstructionPromptBuilder.py +111 -68
  5. edsl/agents/QuestionTemplateReplacementsBuilder.py +69 -16
  6. edsl/agents/question_option_processor.py +15 -6
  7. edsl/coop/CoopFunctionsMixin.py +3 -4
  8. edsl/coop/coop.py +23 -9
  9. edsl/enums.py +3 -3
  10. edsl/inference_services/AnthropicService.py +11 -9
  11. edsl/inference_services/AvailableModelFetcher.py +2 -0
  12. edsl/inference_services/AwsBedrock.py +1 -2
  13. edsl/inference_services/AzureAI.py +12 -9
  14. edsl/inference_services/GoogleService.py +9 -4
  15. edsl/inference_services/InferenceServicesCollection.py +2 -2
  16. edsl/inference_services/MistralAIService.py +1 -2
  17. edsl/inference_services/OpenAIService.py +9 -4
  18. edsl/inference_services/PerplexityService.py +2 -1
  19. edsl/inference_services/{GrokService.py → XAIService.py} +2 -2
  20. edsl/inference_services/registry.py +2 -2
  21. edsl/jobs/Jobs.py +9 -0
  22. edsl/jobs/JobsChecks.py +10 -13
  23. edsl/jobs/async_interview_runner.py +3 -1
  24. edsl/jobs/check_survey_scenario_compatibility.py +5 -5
  25. edsl/jobs/interviews/InterviewExceptionEntry.py +12 -0
  26. edsl/jobs/tasks/TaskHistory.py +1 -1
  27. edsl/language_models/LanguageModel.py +0 -3
  28. edsl/language_models/PriceManager.py +45 -5
  29. edsl/language_models/model.py +47 -26
  30. edsl/questions/QuestionBase.py +21 -0
  31. edsl/questions/QuestionBasePromptsMixin.py +103 -0
  32. edsl/questions/QuestionFreeText.py +22 -5
  33. edsl/questions/descriptors.py +4 -0
  34. edsl/questions/question_base_gen_mixin.py +94 -29
  35. edsl/results/Dataset.py +65 -0
  36. edsl/results/DatasetExportMixin.py +299 -32
  37. edsl/results/Result.py +27 -0
  38. edsl/results/Results.py +22 -2
  39. edsl/results/ResultsGGMixin.py +7 -3
  40. edsl/scenarios/DocumentChunker.py +2 -0
  41. edsl/scenarios/FileStore.py +10 -0
  42. edsl/scenarios/PdfExtractor.py +21 -1
  43. edsl/scenarios/Scenario.py +25 -9
  44. edsl/scenarios/ScenarioList.py +73 -3
  45. edsl/scenarios/handlers/__init__.py +1 -0
  46. edsl/scenarios/handlers/docx.py +5 -1
  47. edsl/scenarios/handlers/jpeg.py +39 -0
  48. edsl/surveys/Survey.py +5 -4
  49. edsl/surveys/SurveyFlowVisualization.py +91 -43
  50. edsl/templates/error_reporting/exceptions_table.html +7 -8
  51. edsl/templates/error_reporting/interview_details.html +1 -1
  52. edsl/templates/error_reporting/interviews.html +0 -1
  53. edsl/templates/error_reporting/overview.html +2 -7
  54. edsl/templates/error_reporting/performance_plot.html +1 -1
  55. edsl/templates/error_reporting/report.css +1 -1
  56. edsl/utilities/PrettyList.py +14 -0
  57. edsl-0.1.45.dist-info/METADATA +246 -0
  58. {edsl-0.1.44.dist-info → edsl-0.1.45.dist-info}/RECORD +60 -59
  59. edsl-0.1.44.dist-info/METADATA +0 -110
  60. {edsl-0.1.44.dist-info → edsl-0.1.45.dist-info}/LICENSE +0 -0
  61. {edsl-0.1.44.dist-info → edsl-0.1.45.dist-info}/WHEEL +0 -0
@@ -7,6 +7,7 @@ from typing import Optional, Tuple, Union, List
 
 from edsl.results.file_exports import CSVExport, ExcelExport, JSONLExport, SQLiteExport
 
+
 class DatasetExportMixin:
     """Mixin class for exporting Dataset objects."""
 
@@ -82,7 +83,8 @@ class DatasetExportMixin:
             else:
                 if len(values) != _num_observations:
                     raise ValueError(
-                        "The number of observations is not consistent across columns."
+                        f"The number of observations is not consistent across columns. "
+                        f"Column '{key}' has {len(values)} observations, but previous columns had {_num_observations} observations."
                     )
 
         return _num_observations
@@ -219,7 +221,9 @@ class DatasetExportMixin:
         )
         return exporter.export()
 
-    def _db(self, remove_prefix: bool = True, shape: str = "wide") -> "sqlalchemy.engine.Engine":
+    def _db(
+        self, remove_prefix: bool = True, shape: str = "wide"
+    ) -> "sqlalchemy.engine.Engine":
         """Create a SQLite database in memory and return the connection.
 
         Args:
@@ -229,7 +233,7 @@ class DatasetExportMixin:
         Returns:
             A database connection
             >>> from sqlalchemy import text
-        >>> from edsl import Results
+            >>> from edsl import Results
             >>> engine = Results.example()._db()
             >>> len(engine.execute(text("SELECT * FROM self")).fetchall())
             4
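
As a usage sketch of the wrapped `_db` signature (this just mirrors the doctest above; note that `engine.execute` implies SQLAlchemy 1.x-style usage):

```python
from sqlalchemy import text
from edsl import Results

# In-memory SQLite database exposing the results as a 'self' table
engine = Results.example()._db(remove_prefix=True, shape="wide")
rows = engine.execute(text("SELECT * FROM self")).fetchall()
print(len(rows))  # -> 4, per the doctest
```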
@@ -247,16 +251,17 @@ class DatasetExportMixin:
 
         if shape == "long":
             # Melt the dataframe to convert it to long format
-            df = df.melt(
-                var_name='key',
-                value_name='value'
-            )
+            df = df.melt(var_name="key", value_name="value")
             # Add a row number column for reference
-            df.insert(0, 'row_number', range(1, len(df) + 1))
-
+            df.insert(0, "row_number", range(1, len(df) + 1))
+
             # Split the key into data_type and key
-            df['data_type'] = df['key'].apply(lambda x: x.split('.')[0] if '.' in x else None)
-            df['key'] = df['key'].apply(lambda x: '.'.join(x.split('.')[1:]) if '.' in x else x)
+            df["data_type"] = df["key"].apply(
+                lambda x: x.split(".")[0] if "." in x else None
+            )
+            df["key"] = df["key"].apply(
+                lambda x: ".".join(x.split(".")[1:]) if "." in x else x
+            )
 
         df.to_sql(
             "self",
@@ -276,27 +281,27 @@ class DatasetExportMixin:
     ) -> Union["pd.DataFrame", str]:
         """Execute a SQL query and return the results as a DataFrame.
 
-            Args:
-                query: The SQL query to execute
-                shape: The shape of the data in the database (wide or long)
-                remove_prefix: Whether to remove the prefix from the column names
-                transpose: Whether to transpose the DataFrame
-                transpose_by: The column to use as the index when transposing
-                csv: Whether to return the DataFrame as a CSV string
-                to_list: Whether to return the results as a list
-                to_latex: Whether to return the results as LaTeX
-                filename: Optional filename to save the results to
-
-            Returns:
-                DataFrame, CSV string, list, or LaTeX string depending on parameters
-
-            Examples:
-                >>> from edsl import Results
-                >>> r = Results.example();
-                >>> len(r.sql("SELECT * FROM self", shape = "wide"))
-                4
-                >>> len(r.sql("SELECT * FROM self", shape = "long"))
-                172
+        Args:
+            query: The SQL query to execute
+            shape: The shape of the data in the database (wide or long)
+            remove_prefix: Whether to remove the prefix from the column names
+            transpose: Whether to transpose the DataFrame
+            transpose_by: The column to use as the index when transposing
+            csv: Whether to return the DataFrame as a CSV string
+            to_list: Whether to return the results as a list
+            to_latex: Whether to return the results as LaTeX
+            filename: Optional filename to save the results to
+
+        Returns:
+            DataFrame, CSV string, list, or LaTeX string depending on parameters
+
+        Examples:
+            >>> from edsl import Results
+            >>> r = Results.example();
+            >>> len(r.sql("SELECT * FROM self", shape = "wide"))
+            4
+            >>> len(r.sql("SELECT * FROM self", shape = "long"))
+            172
         """
         import pandas as pd
 
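
The long shape consumed by `sql(..., shape="long")` is produced by the melt logic in the earlier hunk; a minimal standalone pandas sketch of the same transformation (the sample DataFrame is illustrative, not from the package):

```python
import pandas as pd

# Illustrative wide data with edsl-style dotted column names
df = pd.DataFrame({"answer.how_feeling": ["OK", "Great"], "agent.persona": ["p1", "p2"]})

# Melt to long format: one (key, value) pair per original cell
df = df.melt(var_name="key", value_name="value")
df.insert(0, "row_number", range(1, len(df) + 1))

# Split the dotted key into a data_type prefix and the bare key
df["data_type"] = df["key"].apply(lambda x: x.split(".")[0] if "." in x else None)
df["key"] = df["key"].apply(lambda x: ".".join(x.split(".")[1:]) if "." in x else x)

print(df)  # columns: row_number, key, value, data_type
```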
@@ -538,6 +543,116 @@ class DatasetExportMixin:
 
         if return_link:
             return filename
+
+    def report(self, *fields: Optional[str], top_n: Optional[int] = None,
+               header_fields: Optional[List[str]] = None, divider: bool = True,
+               return_string: bool = False) -> Optional[str]:
+        """Takes the fields in order and returns a markdown report of the results, iterating through rows.
+
+        The row number is printed as "# Observation: <row number>". Each field name
+        becomes a level-"##" markdown header followed by that field's content, field
+        by field. Once a row is done, a blank line (or divider) is printed and the
+        next row is shown. In a Jupyter notebook, the report is displayed as
+        rendered markdown.
+
+        Args:
+            *fields: The fields to include in the report. If none provided, all fields are used.
+            top_n: Optional limit on the number of observations to include.
+            header_fields: Optional list of fields to include in the main header instead of as sections.
+            divider: If True, adds a horizontal rule between observations for better visual separation.
+            return_string: If True, returns the markdown string. If False (default in notebooks),
+                only displays the markdown without returning.
+
+        Returns:
+            A string containing the markdown report if return_string is True, otherwise None.
+
+        Examples:
+            >>> from edsl.results import Results
+            >>> r = Results.example()
+            >>> report = r.select('how_feeling', 'how_feeling_yesterday').report(return_string=True)
+            >>> "# Observation: 1" in report
+            True
+            >>> "## answer.how_feeling" in report
+            True
+            >>> report = r.select('how_feeling').report(header_fields=['answer.how_feeling'], return_string=True)
+            >>> "# Observation: 1 (`how_feeling`: OK)" in report
+            True
+        """
+        from edsl.utilities.utilities import is_notebook
+
+        # If no fields specified, use all columns
+        if not fields:
+            fields = self.relevant_columns()
+
+        # Initialize header_fields if not provided
+        if header_fields is None:
+            header_fields = []
+
+        # Validate all fields
+        all_fields = list(fields) + [f for f in header_fields if f not in fields]
+        for field in all_fields:
+            if field not in self.relevant_columns():
+                raise ValueError(f"Field '{field}' not found in dataset")
+
+        # Get data for each field
+        field_data = {}
+        for field in all_fields:
+            for entry in self:
+                if field in entry:
+                    field_data[field] = entry[field]
+                    break
+
+        # Number of observations to process
+        num_obs = self.num_observations()
+        if top_n is not None:
+            num_obs = min(num_obs, top_n)
+
+        # Build the report
+        report_lines = []
+        for i in range(num_obs):
+            # Create header with observation number and any header fields
+            header = f"# Observation: {i+1}"
+            if header_fields:
+                header_parts = []
+                for field in header_fields:
+                    value = field_data[field][i]
+                    # Get the field name without prefix for cleaner display
+                    display_name = field.split('.')[-1] if '.' in field else field
+                    # Format with backticks for monospace
+                    header_parts.append(f"`{display_name}`: {value}")
+                if header_parts:
+                    header += f" ({', '.join(header_parts)})"
+            report_lines.append(header)
+
+            # Add the remaining fields
+            for field in fields:
+                if field not in header_fields:
+                    report_lines.append(f"## {field}")
+                    value = field_data[field][i]
+                    if isinstance(value, (list, dict)):
+                        import json
+                        report_lines.append(f"```\n{json.dumps(value, indent=2)}\n```")
+                    else:
+                        report_lines.append(str(value))
+
+            # Add divider between observations if requested
+            if divider and i < num_obs - 1:
+                report_lines.append("\n---\n")
+            else:
+                report_lines.append("")  # Empty line between observations
+
+        report_text = "\n".join(report_lines)
+
+        # In notebooks, display as markdown and optionally return
+        is_nb = is_notebook()
+        if is_nb:
+            from IPython.display import Markdown, display
+            display(Markdown(report_text))
+
+        # Return the string if requested or if not in a notebook
+        if return_string or not is_nb:
+            return report_text
+        return None
 
     def tally(
         self, *fields: Optional[str], top_n: Optional[int] = None, output="Dataset"
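
As a usage sketch of the new `report()` (grounded in the doctests above; output abridged):

```python
from edsl.results import Results

r = Results.example()
md = r.select("how_feeling", "how_feeling_yesterday").report(return_string=True)

print(md.splitlines()[0])  # -> "# Observation: 1"
# Each selected field follows as a "## answer.how_feeling"-style section,
# with a horizontal rule between observations (divider=True by default).
```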
@@ -616,6 +731,158 @@ class DatasetExportMixin:
         keys.append("count")
         return sl.reorder_keys(keys).to_dataset()
 
+    def flatten(self, field, keep_original=False):
+        """
+        Flatten a field containing a list of dictionaries into separate fields.
+
+        For example, if a dataset contains:
+            [{'data': [{'a': 1}, {'b': 2}], 'other': ['x', 'y']}]
+
+        then after d.flatten('data') it becomes:
+            [{'other': ['x', 'y'], 'data.a': [1, None], 'data.b': [None, 2]}]
+
+        Args:
+            field: The field to flatten
+            keep_original: If True, keeps the original field in the dataset
+
+        Returns:
+            A new Dataset with the flattened fields
+        """
+        import warnings
+        from edsl.results.Dataset import Dataset
+
+        # Ensure the dataset isn't empty
+        if not self.data:
+            return self.copy()
+
+        # Get the number of observations
+        num_observations = self.num_observations()
+
+        # Find the column to flatten
+        field_entry = None
+        for entry in self.data:
+            if field in entry:
+                field_entry = entry
+                break
+
+        if field_entry is None:
+            warnings.warn(
+                f"Field '{field}' not found in dataset, returning original dataset"
+            )
+            return self.copy()
+
+        # Create a new list for the flattened data
+        flattened_data = []
+
+        # Copy all existing columns except the one we're flattening (unless keep_original)
+        for entry in self.data:
+            col_name = next(iter(entry.keys()))
+            if col_name != field or keep_original:
+                flattened_data.append(entry.copy())
+
+        # Get field data and make sure it's valid
+        field_values = field_entry[field]
+        if not all(isinstance(item, dict) for item in field_values if item is not None):
+            warnings.warn(
+                f"Field '{field}' contains non-dictionary values that cannot be flattened"
+            )
+            return self.copy()
+
+        # Collect all unique keys across all dictionaries
+        all_keys = set()
+        for item in field_values:
+            if isinstance(item, dict):
+                all_keys.update(item.keys())
+
+        # Create new columns for each key
+        for key in sorted(all_keys):  # Sort for consistent output
+            new_values = []
+            for i in range(num_observations):
+                value = None
+                if i < len(field_values) and isinstance(field_values[i], dict):
+                    value = field_values[i].get(key, None)
+                new_values.append(value)
+
+            # Add this as a new column
+            flattened_data.append({f"{field}.{key}": new_values})
+
+        # Return a new Dataset with the flattened data
+        return Dataset(flattened_data)
+
+    def unpack_list(
+        self,
+        field: str,
+        new_names: Optional[List[str]] = None,
+        keep_original: bool = True,
+    ) -> "Dataset":
+        """Unpack list columns into separate columns with provided names or numeric suffixes.
+
+        For example, if a dataset contains:
+            [{'data': [[1, 2, 3], [4, 5, 6]], 'other': ['x', 'y']}]
+
+        then after d.unpack_list('data') it becomes:
+            [{'other': ['x', 'y'], 'data_1': [1, 4], 'data_2': [2, 5], 'data_3': [3, 6]}]
+
+        Args:
+            field: The field containing lists to unpack
+            new_names: Optional list of names for the unpacked fields. If None, uses numeric suffixes.
+            keep_original: If True, keeps the original field in the dataset
+
+        Returns:
+            A new Dataset with unpacked columns
+
+        Examples:
+            >>> from edsl.results.Dataset import Dataset
+            >>> d = Dataset([{'data': [[1, 2, 3], [4, 5, 6]]}])
+            >>> d.unpack_list('data')
+            Dataset([{'data': [[1, 2, 3], [4, 5, 6]]}, {'data_1': [1, 4]}, {'data_2': [2, 5]}, {'data_3': [3, 6]}])
+
+            >>> d.unpack_list('data', new_names=['first', 'second', 'third'])
+            Dataset([{'data': [[1, 2, 3], [4, 5, 6]]}, {'first': [1, 4]}, {'second': [2, 5]}, {'third': [3, 6]}])
+        """
+        from edsl.results.Dataset import Dataset
+
+        # Create a copy of the dataset
+        result = Dataset(self.data.copy())
+
+        # Find the field in the dataset
+        field_index = None
+        for i, entry in enumerate(result.data):
+            if field in entry:
+                field_index = i
+                break
+
+        if field_index is None:
+            raise ValueError(f"Field '{field}' not found in dataset")
+
+        field_data = result.data[field_index][field]
+
+        # Check if values are lists
+        if not all(isinstance(v, list) for v in field_data):
+            raise ValueError(f"Field '{field}' does not contain lists in all entries")
+
+        # Get the maximum length of the lists
+        max_len = max(len(v) for v in field_data)
+
+        # Create a new field for each index
+        for i in range(max_len):
+            if new_names and i < len(new_names):
+                new_field = new_names[i]
+            else:
+                new_field = f"{field}_{i+1}"
+
+            # Extract the i-th element from each list
+            new_values = []
+            for item in field_data:
+                new_values.append(item[i] if i < len(item) else None)
+
+            result.data.append({new_field: new_values})
+
+        # Remove the original field if keep_original is False
+        if not keep_original:
+            result.data.pop(field_index)
+
+        return result
+
 
 if __name__ == "__main__":
     import doctest
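
The two new reshapers complement each other: `flatten()` spreads a column of dictionaries into dotted sub-columns, while `unpack_list()` spreads a column of lists into positionally named columns. A short sketch based on the docstrings above (Dataset entries are written one column per dict here, matching how the code iterates over `self.data`):

```python
from edsl.results.Dataset import Dataset

# flatten: a column of dicts becomes one 'data.<key>' column per key
d = Dataset([{"data": [{"a": 1}, {"b": 2}]}, {"other": ["x", "y"]}])
flat = d.flatten("data")
# -> columns: other, data.a = [1, None], data.b = [None, 2]

# unpack_list: a column of lists becomes one column per position
d2 = Dataset([{"data": [[1, 2, 3], [4, 5, 6]]}])
wide = d2.unpack_list("data", new_names=["first", "second", "third"])
# -> keeps 'data' (keep_original=True) and adds first=[1, 4], second=[2, 5], third=[3, 6]
```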
edsl/results/Result.py CHANGED
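The notable addition below is score_with_answer_key. One subtlety: each answer-key value may be a single correct answer or a list of acceptable ones, since equality is checked first and membership second. A usage sketch (import path inferred from the module layout):

```python
from edsl.results.Result import Result

result = Result.example()  # answers: how_feeling='OK', how_feeling_yesterday='Great'

# Scalar values: exact match required
result.score_with_answer_key({"how_feeling": "OK", "how_feeling_yesterday": "Great"})
# -> {'correct': 2, 'incorrect': 0, 'missing': 0}

# List values: any listed answer counts as correct
result.score_with_answer_key({"how_feeling": "OK", "how_feeling_yesterday": ["Great", "Good"]})
# -> {'correct': 2, 'incorrect': 0, 'missing': 0}
```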
@@ -439,6 +439,33 @@ class Result(Base, UserDict):
         from edsl.results.Results import Results
 
         return Results.example()[0]
+
+    def score_with_answer_key(self, answer_key: dict) -> dict:
+        """Score the result using an answer key.
+
+        :param answer_key: A dictionary that maps question_names to correct answers (single values or lists of acceptable answers)
+
+        >>> Result.example()['answer']
+        {'how_feeling': 'OK', 'how_feeling_yesterday': 'Great'}
+
+        >>> answer_key = {'how_feeling': 'OK', 'how_feeling_yesterday': 'Great'}
+        >>> Result.example().score_with_answer_key(answer_key)
+        {'correct': 2, 'incorrect': 0, 'missing': 0}
+        >>> answer_key = {'how_feeling': 'OK', 'how_feeling_yesterday': ['Great', 'Good']}
+        >>> Result.example().score_with_answer_key(answer_key)
+        {'correct': 2, 'incorrect': 0, 'missing': 0}
+        """
+        final_scores = {'correct': 0, 'incorrect': 0, 'missing': 0}
+        for question_name, answer in self.answer.items():
+            if question_name in answer_key:
+                if answer == answer_key[question_name] or answer in answer_key[question_name]:
+                    final_scores['correct'] += 1
+                else:
+                    final_scores['incorrect'] += 1
+            else:
+                final_scores['missing'] += 1
+
+        return final_scores
 
     def score(self, scoring_function: Callable) -> Union[int, float]:
         """Score the result using a passed-in scoring function.
edsl/results/Results.py CHANGED
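Two of the changes below surface at the Results level: ggplot2 now delegates to the new GGPlotMethod, and score_with_answer_key simply maps the Result-level scorer over rows, returning one score dict per row. A sketch of the latter (the aggregation is illustrative, not an edsl API):

```python
from edsl import Results

r = Results.example()
scores = r.score_with_answer_key({"how_feeling": "OK"})
# -> one {'correct': ..., 'incorrect': ..., 'missing': ...} dict per Result

# Illustrative aggregation across rows
total_correct = sum(s["correct"] for s in scores)
```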
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
     from simpleeval import EvalWithCompoundTypes
 
 from edsl.results.ResultsExportMixin import ResultsExportMixin
-from edsl.results.ResultsGGMixin import ResultsGGMixin
+from edsl.results.ResultsGGMixin import GGPlotMethod
 from edsl.results.results_fetch_mixin import ResultsFetchMixin
 from edsl.utilities.remove_edsl_version import remove_edsl_version
 
@@ -100,7 +100,7 @@ class NotReadyObject:
 class Mixins(
     ResultsExportMixin,
     ResultsFetchMixin,
-    ResultsGGMixin,
+    # ResultsGGMixin,
 ):
     def long(self):
         return self.table().long()
@@ -151,6 +151,19 @@ class Results(UserList, Mixins, Base):
         "cache_keys",
     ]
 
+    def ggplot2(
+        self,
+        ggplot_code: str,
+        shape="wide",
+        sql: str = None,
+        remove_prefix: bool = True,
+        debug: bool = False,
+        height=4,
+        width=6,
+        factor_orders: Optional[dict] = None,
+    ):
+        return GGPlotMethod(self).ggplot2(ggplot_code, shape, sql, remove_prefix, debug, height, width, factor_orders)
+
     @classmethod
     def from_job_info(cls, job_info: dict) -> Results:
         """
@@ -1277,6 +1290,13 @@ class Results(UserList, Mixins, Base):
         """
         return [r.score(f) for r in self.data]
 
+    def score_with_answer_key(self, answer_key: dict) -> list:
+        """Score the results using an answer key.
+
+        :param answer_key: A dictionary that maps question_names to correct answers.
+        """
+        return [r.score_with_answer_key(answer_key) for r in self.data]
+
 
     def fetch_remote(self, job_info: "RemoteJobInfo") -> None:
         """
edsl/results/ResultsGGMixin.py CHANGED
@@ -75,7 +75,11 @@ class GGPlot:
 
         return self._svg_data
 
-class ResultsGGMixin:
+class GGPlotMethod:
+
+    def __init__(self, results: 'Results'):
+        self.results = results
+
     """Mixin class for ggplot2 plotting."""
 
     def ggplot2(
@@ -106,9 +110,9 @@ class ResultsGGMixin:
             sql = "select * from self"
 
         if shape == "long":
-            df = self.sql(sql, shape="long")
+            df = self.results.sql(sql, shape="long")
         elif shape == "wide":
-            df = self.sql(sql, remove_prefix=remove_prefix)
+            df = self.results.sql(sql, remove_prefix=remove_prefix)
 
         # Convert DataFrame to CSV format
         csv_data = df.to_csv().text
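
The refactor swaps inheritance for delegation: Results.ggplot2 wraps the instance in GGPlotMethod, which fetches its DataFrame via self.results.sql(...). The public call is unchanged; a sketch (the ggplot2 snippet is illustrative and a working R setup is assumed):

```python
from edsl import Results

r = Results.example()
# Same entry point as before the refactor; delegation happens internally.
plot = r.ggplot2(
    "ggplot(df, aes(x = how_feeling)) + geom_bar()",  # illustrative R code
    shape="wide",
)
```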
edsl/scenarios/DocumentChunker.py CHANGED
@@ -85,6 +85,8 @@ class DocumentChunker:
             new_scenario = copy.deepcopy(self.scenario)
             new_scenario[field] = chunk
             new_scenario[field + "_chunk"] = i
+            new_scenario[field + "_char_count"] = len(chunk)
+            new_scenario[field + "_word_count"] = len(chunk.split())
             if include_original:
                 if hash_original:
                     new_scenario[field + "_original"] = hashlib.md5(
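
Each chunk now carries its size next to the text. A sketch of the resulting fields (values follow directly from the Scenario.chunk doctests updated later in this diff):

```python
from edsl import Scenario

s = Scenario({"text": "Hello World"})
chunks = s.chunk("text", num_words=1)
# Each chunk Scenario now also includes:
#   text_chunk       -> chunk index (0, 1, ...)
#   text_char_count  -> len(chunk), e.g. 5 for 'Hello'
#   text_word_count  -> len(chunk.split()), e.g. 1
```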
edsl/scenarios/FileStore.py CHANGED
@@ -29,6 +29,12 @@ class FileStore(Scenario):
         if path is None and "filename" in kwargs:
             path = kwargs["filename"]
 
+        # Check if path is a URL and handle download
+        if path and (path.startswith('http://') or path.startswith('https://')):
+            temp_filestore = self.from_url(path, mime_type=mime_type)
+            path = temp_filestore._path
+            mime_type = temp_filestore.mime_type
+
         self._path = path  # Store the original path privately
         self._temp_path = None  # Track any generated temporary file
 
@@ -138,6 +144,10 @@ class FileStore(Scenario):
             base64_encoded_data = base64.b64encode(binary_data)
             self.binary = True
             # Convert the base64 bytes to a string
+        except FileNotFoundError:
+            print(f"File not found: {file_path}")
+            print("Current working directory:", os.getcwd())
+            raise
         base64_string = base64_encoded_data.decode("utf-8")
 
         return base64_string
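
With the new branch in the constructor, FileStore accepts a URL directly and defers to the existing from_url classmethod to download it first. A sketch (path and URL are placeholders):

```python
from edsl.scenarios.FileStore import FileStore

# Local path: behavior unchanged
fs_local = FileStore("data/survey.pdf")  # placeholder path

# URL: downloaded via FileStore.from_url, then handled like a local file
fs_remote = FileStore("https://example.com/survey.pdf")  # placeholder URL
```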
edsl/scenarios/PdfExtractor.py CHANGED
@@ -4,10 +4,30 @@ import os
 class PdfExtractor:
     def __init__(self, pdf_path: str):
         self.pdf_path = pdf_path
+        self._has_pymupdf = self._check_pymupdf()
         #self.constructor = parent_object.__class__
 
+    def _check_pymupdf(self):
+        """Check if PyMuPDF is installed."""
+        try:
+            import fitz
+            return True
+        except ImportError:
+            return False
+
     def get_pdf_dict(self) -> dict:
-        # Ensure the file exists
+        # First check if the file exists
+        if not os.path.exists(self.pdf_path):
+            raise FileNotFoundError(f"The file {self.pdf_path} does not exist.")
+
+        # Then check if PyMuPDF is available
+        if not self._has_pymupdf:
+            raise ImportError(
+                "The 'fitz' module (PyMuPDF) is required for PDF extraction. "
+                "Please install it with: pip install pymupdf"
+            )
+
+        # If we get here, we can safely import and use fitz
         import fitz
 
         if not os.path.exists(self.pdf_path):
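
The ordering matters here: the file-existence check runs before the dependency check, so a bad path raises FileNotFoundError rather than a misleading install prompt. Usage sketch (path is a placeholder):

```python
from edsl.scenarios.PdfExtractor import PdfExtractor

extractor = PdfExtractor("paper.pdf")  # placeholder path
# Raises FileNotFoundError for a missing file; only then raises
# ImportError (with install instructions) if PyMuPDF is absent.
pdf_dict = extractor.get_pdf_dict()
```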
edsl/scenarios/Scenario.py CHANGED
@@ -64,6 +64,15 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
         self.data = data if data is not None else {}
         self.name = name
 
+    def __mul__(self, scenario_list_or_scenario: Union["ScenarioList", "Scenario"]) -> "ScenarioList":
+        from edsl.scenarios.ScenarioList import ScenarioList
+        if isinstance(scenario_list_or_scenario, ScenarioList):
+            return scenario_list_or_scenario * self
+        elif isinstance(scenario_list_or_scenario, Scenario):
+            return ScenarioList([self]) * scenario_list_or_scenario
+        else:
+            raise TypeError(f"Cannot multiply Scenario with {type(scenario_list_or_scenario)}")
+
     def replicate(self, n: int) -> "ScenarioList":
         """Replicate a scenario n times to return a ScenarioList.
 
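
The new __mul__ makes cross-products composable from either operand: Scenario * ScenarioList delegates to ScenarioList.__mul__, and Scenario * Scenario lifts the left operand into a one-element ScenarioList first. A sketch (field names illustrative):

```python
from edsl import Scenario
from edsl.scenarios.ScenarioList import ScenarioList

a = Scenario({"color": "red"})
b = Scenario({"size": "large"})

# Scenario * Scenario -> ScenarioList([a]) * b
combined = a * b

# Scenario * ScenarioList -> delegates to ScenarioList.__mul__
sl = ScenarioList([Scenario({"size": s}) for s in ("small", "large")])
cross = a * sl
```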
@@ -356,11 +365,18 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
 
     @classmethod
     def from_pdf(cls, pdf_path: str):
-        from edsl.scenarios.PdfExtractor import PdfExtractor
-
-        extractor = PdfExtractor(pdf_path)
-        return Scenario(extractor.get_pdf_dict())
-
+        """Create a Scenario from a PDF file."""
+        try:
+            from edsl.scenarios.PdfExtractor import PdfExtractor
+            extractor = PdfExtractor(pdf_path)
+            return Scenario(extractor.get_pdf_dict())
+        except ImportError as e:
+            raise ImportError(
+                f"Could not extract text from PDF: {str(e)}. "
+                "PDF extraction requires the PyMuPDF library. "
+                "Install it with: pip install pymupdf"
+            )
+
     @classmethod
     def from_pdf_to_image(cls, pdf_path, image_format="jpeg"):
         """
@@ -442,18 +458,18 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
 
         >>> s = Scenario({"text": "This is a test.\\nThis is a test.\\n\\nThis is a test."})
         >>> s.chunk("text", num_lines = 1)
-        ScenarioList([Scenario({'text': 'This is a test.', 'text_chunk': 0}), Scenario({'text': 'This is a test.', 'text_chunk': 1}), Scenario({'text': '', 'text_chunk': 2}), Scenario({'text': 'This is a test.', 'text_chunk': 3})])
+        ScenarioList([Scenario({'text': 'This is a test.', 'text_chunk': 0, 'text_char_count': 15, 'text_word_count': 4}), Scenario({'text': 'This is a test.', 'text_chunk': 1, 'text_char_count': 15, 'text_word_count': 4}), Scenario({'text': '', 'text_chunk': 2, 'text_char_count': 0, 'text_word_count': 0}), Scenario({'text': 'This is a test.', 'text_chunk': 3, 'text_char_count': 15, 'text_word_count': 4})])
 
         >>> s.chunk("text", num_words = 2)
-        ScenarioList([Scenario({'text': 'This is', 'text_chunk': 0}), Scenario({'text': 'a test.', 'text_chunk': 1}), Scenario({'text': 'This is', 'text_chunk': 2}), Scenario({'text': 'a test.', 'text_chunk': 3}), Scenario({'text': 'This is', 'text_chunk': 4}), Scenario({'text': 'a test.', 'text_chunk': 5})])
+        ScenarioList([Scenario({'text': 'This is', 'text_chunk': 0, 'text_char_count': 7, 'text_word_count': 2}), Scenario({'text': 'a test.', 'text_chunk': 1, 'text_char_count': 7, 'text_word_count': 2}), Scenario({'text': 'This is', 'text_chunk': 2, 'text_char_count': 7, 'text_word_count': 2}), Scenario({'text': 'a test.', 'text_chunk': 3, 'text_char_count': 7, 'text_word_count': 2}), Scenario({'text': 'This is', 'text_chunk': 4, 'text_char_count': 7, 'text_word_count': 2}), Scenario({'text': 'a test.', 'text_chunk': 5, 'text_char_count': 7, 'text_word_count': 2})])
 
         >>> s = Scenario({"text": "Hello World"})
         >>> s.chunk("text", num_words = 1, include_original = True)
-        ScenarioList([Scenario({'text': 'Hello', 'text_chunk': 0, 'text_original': 'Hello World'}), Scenario({'text': 'World', 'text_chunk': 1, 'text_original': 'Hello World'})])
+        ScenarioList([Scenario({'text': 'Hello', 'text_chunk': 0, 'text_char_count': 5, 'text_word_count': 1, 'text_original': 'Hello World'}), Scenario({'text': 'World', 'text_chunk': 1, 'text_char_count': 5, 'text_word_count': 1, 'text_original': 'Hello World'})])
 
         >>> s = Scenario({"text": "Hello World"})
         >>> s.chunk("text", num_words = 1, include_original = True, hash_original = True)
-        ScenarioList([Scenario({'text': 'Hello', 'text_chunk': 0, 'text_original': 'b10a8db164e0754105b7a99be72e3fe5'}), Scenario({'text': 'World', 'text_chunk': 1, 'text_original': 'b10a8db164e0754105b7a99be72e3fe5'})])
+        ScenarioList([Scenario({'text': 'Hello', 'text_chunk': 0, 'text_char_count': 5, 'text_word_count': 1, 'text_original': 'b10a8db164e0754105b7a99be72e3fe5'}), Scenario({'text': 'World', 'text_chunk': 1, 'text_char_count': 5, 'text_word_count': 1, 'text_original': 'b10a8db164e0754105b7a99be72e3fe5'})])
 
         >>> s.chunk("text")
         Traceback (most recent call last):