edsl 0.1.44__py3-none-any.whl → 0.1.45__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/InvigilatorBase.py +3 -1
- edsl/agents/PromptConstructor.py +62 -34
- edsl/agents/QuestionInstructionPromptBuilder.py +111 -68
- edsl/agents/QuestionTemplateReplacementsBuilder.py +69 -16
- edsl/agents/question_option_processor.py +15 -6
- edsl/coop/CoopFunctionsMixin.py +3 -4
- edsl/coop/coop.py +23 -9
- edsl/enums.py +3 -3
- edsl/inference_services/AnthropicService.py +11 -9
- edsl/inference_services/AvailableModelFetcher.py +2 -0
- edsl/inference_services/AwsBedrock.py +1 -2
- edsl/inference_services/AzureAI.py +12 -9
- edsl/inference_services/GoogleService.py +9 -4
- edsl/inference_services/InferenceServicesCollection.py +2 -2
- edsl/inference_services/MistralAIService.py +1 -2
- edsl/inference_services/OpenAIService.py +9 -4
- edsl/inference_services/PerplexityService.py +2 -1
- edsl/inference_services/{GrokService.py → XAIService.py} +2 -2
- edsl/inference_services/registry.py +2 -2
- edsl/jobs/Jobs.py +9 -0
- edsl/jobs/JobsChecks.py +10 -13
- edsl/jobs/async_interview_runner.py +3 -1
- edsl/jobs/check_survey_scenario_compatibility.py +5 -5
- edsl/jobs/interviews/InterviewExceptionEntry.py +12 -0
- edsl/jobs/tasks/TaskHistory.py +1 -1
- edsl/language_models/LanguageModel.py +0 -3
- edsl/language_models/PriceManager.py +45 -5
- edsl/language_models/model.py +47 -26
- edsl/questions/QuestionBase.py +21 -0
- edsl/questions/QuestionBasePromptsMixin.py +103 -0
- edsl/questions/QuestionFreeText.py +22 -5
- edsl/questions/descriptors.py +4 -0
- edsl/questions/question_base_gen_mixin.py +94 -29
- edsl/results/Dataset.py +65 -0
- edsl/results/DatasetExportMixin.py +299 -32
- edsl/results/Result.py +27 -0
- edsl/results/Results.py +22 -2
- edsl/results/ResultsGGMixin.py +7 -3
- edsl/scenarios/DocumentChunker.py +2 -0
- edsl/scenarios/FileStore.py +10 -0
- edsl/scenarios/PdfExtractor.py +21 -1
- edsl/scenarios/Scenario.py +25 -9
- edsl/scenarios/ScenarioList.py +73 -3
- edsl/scenarios/handlers/__init__.py +1 -0
- edsl/scenarios/handlers/docx.py +5 -1
- edsl/scenarios/handlers/jpeg.py +39 -0
- edsl/surveys/Survey.py +5 -4
- edsl/surveys/SurveyFlowVisualization.py +91 -43
- edsl/templates/error_reporting/exceptions_table.html +7 -8
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/interviews.html +0 -1
- edsl/templates/error_reporting/overview.html +2 -7
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +1 -1
- edsl/utilities/PrettyList.py +14 -0
- edsl-0.1.45.dist-info/METADATA +246 -0
- {edsl-0.1.44.dist-info → edsl-0.1.45.dist-info}/RECORD +60 -59
- edsl-0.1.44.dist-info/METADATA +0 -110
- {edsl-0.1.44.dist-info → edsl-0.1.45.dist-info}/LICENSE +0 -0
- {edsl-0.1.44.dist-info → edsl-0.1.45.dist-info}/WHEEL +0 -0
@@ -7,6 +7,7 @@ from typing import Optional, Tuple, Union, List
|
|
7
7
|
|
8
8
|
from edsl.results.file_exports import CSVExport, ExcelExport, JSONLExport, SQLiteExport
|
9
9
|
|
10
|
+
|
10
11
|
class DatasetExportMixin:
|
11
12
|
"""Mixin class for exporting Dataset objects."""
|
12
13
|
|
@@ -82,7 +83,8 @@ class DatasetExportMixin:
|
|
82
83
|
else:
|
83
84
|
if len(values) != _num_observations:
|
84
85
|
raise ValueError(
|
85
|
-
"The number of observations is not consistent across columns."
|
86
|
+
f"The number of observations is not consistent across columns. "
|
87
|
+
f"Column '{key}' has {len(values)} observations, but previous columns had {_num_observations} observations."
|
86
88
|
)
|
87
89
|
|
88
90
|
return _num_observations
|
@@ -219,7 +221,9 @@ class DatasetExportMixin:
|
|
219
221
|
)
|
220
222
|
return exporter.export()
|
221
223
|
|
222
|
-
def _db(
|
224
|
+
def _db(
|
225
|
+
self, remove_prefix: bool = True, shape: str = "wide"
|
226
|
+
) -> "sqlalchemy.engine.Engine":
|
223
227
|
"""Create a SQLite database in memory and return the connection.
|
224
228
|
|
225
229
|
Args:
|
@@ -229,7 +233,7 @@ class DatasetExportMixin:
|
|
229
233
|
Returns:
|
230
234
|
A database connection
|
231
235
|
>>> from sqlalchemy import text
|
232
|
-
>>> from edsl import Results
|
236
|
+
>>> from edsl import Results
|
233
237
|
>>> engine = Results.example()._db()
|
234
238
|
>>> len(engine.execute(text("SELECT * FROM self")).fetchall())
|
235
239
|
4
|
@@ -247,16 +251,17 @@ class DatasetExportMixin:
|
|
247
251
|
|
248
252
|
if shape == "long":
|
249
253
|
# Melt the dataframe to convert it to long format
|
250
|
-
df = df.melt(
|
251
|
-
var_name='key',
|
252
|
-
value_name='value'
|
253
|
-
)
|
254
|
+
df = df.melt(var_name="key", value_name="value")
|
254
255
|
# Add a row number column for reference
|
255
|
-
df.insert(0,
|
256
|
-
|
256
|
+
df.insert(0, "row_number", range(1, len(df) + 1))
|
257
|
+
|
257
258
|
# Split the key into data_type and key
|
258
|
-
df[
|
259
|
-
|
259
|
+
df["data_type"] = df["key"].apply(
|
260
|
+
lambda x: x.split(".")[0] if "." in x else None
|
261
|
+
)
|
262
|
+
df["key"] = df["key"].apply(
|
263
|
+
lambda x: ".".join(x.split(".")[1:]) if "." in x else x
|
264
|
+
)
|
260
265
|
|
261
266
|
df.to_sql(
|
262
267
|
"self",
|
@@ -276,27 +281,27 @@ class DatasetExportMixin:
|
|
276
281
|
) -> Union["pd.DataFrame", str]:
|
277
282
|
"""Execute a SQL query and return the results as a DataFrame.
|
278
283
|
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
284
|
+
Args:
|
285
|
+
query: The SQL query to execute
|
286
|
+
shape: The shape of the data in the database (wide or long)
|
287
|
+
remove_prefix: Whether to remove the prefix from the column names
|
288
|
+
transpose: Whether to transpose the DataFrame
|
289
|
+
transpose_by: The column to use as the index when transposing
|
290
|
+
csv: Whether to return the DataFrame as a CSV string
|
291
|
+
to_list: Whether to return the results as a list
|
292
|
+
to_latex: Whether to return the results as LaTeX
|
293
|
+
filename: Optional filename to save the results to
|
294
|
+
|
295
|
+
Returns:
|
296
|
+
DataFrame, CSV string, list, or LaTeX string depending on parameters
|
297
|
+
|
298
|
+
Examples:
|
299
|
+
>>> from edsl import Results
|
300
|
+
>>> r = Results.example();
|
301
|
+
>>> len(r.sql("SELECT * FROM self", shape = "wide"))
|
302
|
+
4
|
303
|
+
>>> len(r.sql("SELECT * FROM self", shape = "long"))
|
304
|
+
172
|
300
305
|
"""
|
301
306
|
import pandas as pd
|
302
307
|
|
@@ -538,6 +543,116 @@ class DatasetExportMixin:
|
|
538
543
|
|
539
544
|
if return_link:
|
540
545
|
return filename
|
546
|
+
|
547
|
+
def report(self, *fields: Optional[str], top_n: Optional[int] = None,
           header_fields: Optional[List[str]] = None, divider: bool = True,
           return_string: bool = False) -> Optional[str]:
    """Takes the fields in order and returns a report of the results by iterating through rows.

    The row number is printed as # Observation: <row number>
    The name of the field is used as markdown header at level "##"
    The content of that field is then printed: lists and dicts as an indented
    JSON code block, anything else via ``str``.
    Then the next field and so on.
    Once that row is done, an empty line (or a ``---`` divider) is printed and
    the next row is shown.
    If in a jupyter notebook, the report is displayed as markdown.

    Args:
        *fields: The fields to include in the report. If none provided, all fields are used.
        top_n: Optional limit on the number of observations to include.
        header_fields: Optional list of fields to include in the main header instead of as sections.
        divider: If True, adds a horizontal rule between observations for better visual separation.
        return_string: If True, always return the markdown string. If False, the
            string is returned only when NOT running in a notebook; in a
            notebook it is displayed and None is returned.

    Returns:
        The markdown report string when ``return_string`` is True or when not
        running in a notebook; otherwise None.

    Raises:
        ValueError: if any requested field (body or header) is not one of the
            dataset's columns.

    Examples:
        >>> from edsl.results import Results
        >>> r = Results.example()
        >>> report = r.select('how_feeling', 'how_feeling_yesterday').report(return_string=True)
        >>> "# Observation: 1" in report
        True
        >>> "## answer.how_feeling" in report
        True
        >>> report = r.select('how_feeling').report(header_fields=['answer.how_feeling'], return_string=True)
        >>> "# Observation: 1 (`how_feeling`: OK)" in report
        True
    """
    import json

    from edsl.utilities.utilities import is_notebook

    # Compute the column list once: it is the default selection and the
    # validation reference (previously recomputed for every field).
    available_columns = self.relevant_columns()

    # If no fields specified, use all columns
    if not fields:
        fields = available_columns

    # Initialize header_fields if not provided
    if header_fields is None:
        header_fields = []

    # Validate body fields plus any header-only fields
    all_fields = list(fields) + [f for f in header_fields if f not in fields]
    for field in all_fields:
        if field not in available_columns:
            raise ValueError(f"Field '{field}' not found in dataset")

    # Column values per field; the first entry containing the field wins
    field_data = {}
    for field in all_fields:
        for entry in self:
            if field in entry:
                field_data[field] = entry[field]
                break

    # Number of observations to process
    num_obs = self.num_observations()
    if top_n is not None:
        num_obs = min(num_obs, top_n)

    report_lines = []
    for i in range(num_obs):
        # Header with observation number and any header fields
        header = f"# Observation: {i+1}"
        header_parts = []
        for field in header_fields:
            # Last dotted component ("answer.x" -> "x"); whole name if no dot
            display_name = field.split('.')[-1]
            # Backticks render the name in monospace
            header_parts.append(f"`{display_name}`: {field_data[field][i]}")
        if header_parts:
            header += f" ({', '.join(header_parts)})"
        report_lines.append(header)

        # Remaining fields become "##" sections
        for field in fields:
            if field in header_fields:
                continue
            report_lines.append(f"## {field}")
            value = field_data[field][i]
            if isinstance(value, (list, dict)):
                # default=str keeps the report renderable when a nested value
                # is not JSON-serializable, instead of raising TypeError.
                report_lines.append(f"```\n{json.dumps(value, indent=2, default=str)}\n```")
            else:
                report_lines.append(str(value))

        # Divider between observations if requested (not after the last one)
        if divider and i < num_obs - 1:
            report_lines.append("\n---\n")
        else:
            report_lines.append("")  # Empty line between observations

    report_text = "\n".join(report_lines)

    # In notebooks, display as markdown and optionally return
    in_notebook = is_notebook()
    if in_notebook:
        from IPython.display import Markdown, display
        display(Markdown(report_text))

    # Return the string if requested or if not in a notebook
    if return_string or not in_notebook:
        return report_text
    return None
|
541
656
|
|
542
657
|
def tally(
|
543
658
|
self, *fields: Optional[str], top_n: Optional[int] = None, output="Dataset"
|
@@ -616,6 +731,158 @@ class DatasetExportMixin:
|
|
616
731
|
keys.append("count")
|
617
732
|
return sl.reorder_keys(keys).to_dataset()
|
618
733
|
|
734
|
+
def flatten(self, field, keep_original=False):
    """
    Flatten a field containing a list of dictionaries into separate fields.

    Every key ``k`` that appears in any dictionary of ``field`` becomes a new
    column named ``"{field}.{k}"``. New columns are appended after the copied
    existing columns, in sorted key order. Observations whose value is not a
    dict (e.g. ``None``) or lacks ``k`` get ``None`` in that column.

    For example, if a dataset contains:
        Dataset([{'data': [{'a': 1}, {'b': 2}]}, {'other': ['x', 'y']}])

    After d.flatten('data'), it becomes:
        Dataset([{'other': ['x', 'y']}, {'data.a': [1, None]}, {'data.b': [None, 2]}])

    Args:
        field: The field to flatten
        keep_original: If True, keeps the original field in the dataset

    Returns:
        A new dataset with the flattened fields. If the dataset is empty, a
        copy is returned. If the field is missing, or holds values other than
        dicts/None, a warning is issued and a copy is returned.
    """
    # Nothing to flatten in an empty dataset
    if not self.data:
        return self.copy()

    # Get the number of observations
    num_observations = self.num_observations()

    # Find the column to flatten (first entry that contains the field)
    field_entry = next((entry for entry in self.data if field in entry), None)
    if field_entry is None:
        warnings.warn(
            f"Field '{field}' not found in dataset, returning original dataset"
        )
        return self.copy()

    # Validate the values before doing any copying work; None is tolerated
    field_values = field_entry[field]
    if not all(isinstance(item, dict) for item in field_values if item is not None):
        warnings.warn(
            f"Field '{field}' contains non-dictionary values that cannot be flattened"
        )
        return self.copy()

    from edsl.results.Dataset import Dataset

    # Copy all existing columns except the one being flattened (unless keep_original)
    flattened_data = [
        entry.copy()
        for entry in self.data
        if next(iter(entry.keys())) != field or keep_original
    ]

    # Collect all unique keys across all dictionaries
    all_keys = set()
    for item in field_values:
        if isinstance(item, dict):
            all_keys.update(item.keys())

    # One new column per key, sorted for deterministic output
    for key in sorted(all_keys):
        new_values = []
        for i in range(num_observations):
            item = field_values[i] if i < len(field_values) else None
            new_values.append(item.get(key) if isinstance(item, dict) else None)
        flattened_data.append({f"{field}.{key}": new_values})

    # Return a new Dataset with the flattened data
    return Dataset(flattened_data)
|
810
|
+
|
811
|
+
def unpack_list(
    self,
    field: str,
    new_names: Optional[List[str]] = None,
    keep_original: bool = True,
) -> "Dataset":
    """Unpack list columns into separate columns with provided names or numeric suffixes.

    For example, if a dataset contains:
        Dataset([{'data': [[1, 2, 3], [4, 5, 6]]}, {'other': ['x', 'y']}])

    After d.unpack_list('data'), it becomes:
        Dataset([{'data': [[1, 2, 3], [4, 5, 6]]}, {'other': ['x', 'y']}, {'data_1': [1, 4]}, {'data_2': [2, 5]}, {'data_3': [3, 6]}])

    Lists shorter than the longest one are padded with None. A field with
    zero observations yields no new columns.

    Args:
        field: The field containing lists to unpack
        new_names: Optional list of names for the unpacked fields. If None (or
            shorter than the longest list), numeric suffixes ``{field}_{i}``
            are used for the remaining positions.
        keep_original: If True, keeps the original field in the dataset

    Returns:
        A new Dataset with unpacked columns

    Raises:
        ValueError: if the field is missing or any of its values is not a list.

    Examples:
        >>> from edsl.results.Dataset import Dataset
        >>> d = Dataset([{'data': [[1, 2, 3], [4, 5, 6]]}])
        >>> d.unpack_list('data')
        Dataset([{'data': [[1, 2, 3], [4, 5, 6]]}, {'data_1': [1, 4]}, {'data_2': [2, 5]}, {'data_3': [3, 6]}])

        >>> d.unpack_list('data', new_names=['first', 'second', 'third'])
        Dataset([{'data': [[1, 2, 3], [4, 5, 6]]}, {'first': [1, 4]}, {'second': [2, 5]}, {'third': [3, 6]}])
    """
    # Locate the column and validate it before copying anything
    field_index = next(
        (i for i, entry in enumerate(self.data) if field in entry), None
    )
    if field_index is None:
        raise ValueError(f"Field '{field}' not found in dataset")

    field_data = self.data[field_index][field]

    if not all(isinstance(v, list) for v in field_data):
        raise ValueError(f"Field '{field}' does not contain lists in all entries")

    from edsl.results.Dataset import Dataset

    # Shallow copy of the column list so appending/popping leaves self intact
    result = Dataset(self.data.copy())

    # default=0: an empty column must not make max() raise on an empty sequence
    max_len = max((len(v) for v in field_data), default=0)

    # Create new fields for each index
    for i in range(max_len):
        if new_names and i < len(new_names):
            new_field = new_names[i]
        else:
            new_field = f"{field}_{i+1}"

        # The i-th element of each list, None where the list is too short
        new_values = [item[i] if i < len(item) else None for item in field_data]
        result.data.append({new_field: new_values})

    # Remove the original field if keep_original is False
    if not keep_original:
        result.data.pop(field_index)

    return result
|
885
|
+
|
619
886
|
|
620
887
|
if __name__ == "__main__":
|
621
888
|
import doctest
|
edsl/results/Result.py
CHANGED
@@ -439,6 +439,33 @@ class Result(Base, UserDict):
|
|
439
439
|
from edsl.results.Results import Results
|
440
440
|
|
441
441
|
return Results.example()[0]
|
442
|
+
|
443
|
+
def score_with_answer_key(self, answer_key: dict) -> dict:
    """Score the result using an answer key.

    Each answered question is classified as follows:

    * question not present in ``answer_key`` -> ``'missing'``;
    * answer equal to the key value -> ``'correct'``;
    * key value is a list/tuple/set of acceptable answers and the answer
      equals one of them -> ``'correct'``;
    * otherwise -> ``'incorrect'``.

    Scalar key values (e.g. strings) are compared by equality only. So a
    partial answer such as ``'O'`` is not accepted for ``'OK'``. Answers such
    as ``None`` or a checkbox list scored against a string key count as
    incorrect instead of raising ``TypeError``.

    :param answer_key: A dictionary that maps question_names to the correct
        answer, or to a collection of acceptable answers.
    :return: A dict with integer counts under ``'correct'``, ``'incorrect'``
        and ``'missing'``.

    >>> Result.example()['answer']
    {'how_feeling': 'OK', 'how_feeling_yesterday': 'Great'}

    >>> answer_key = {'how_feeling': 'OK', 'how_feeling_yesterday': 'Great'}
    >>> Result.example().score_with_answer_key(answer_key)
    {'correct': 2, 'incorrect': 0, 'missing': 0}
    >>> answer_key = {'how_feeling': 'OK', 'how_feeling_yesterday': ['Great', 'Good']}
    >>> Result.example().score_with_answer_key(answer_key)
    {'correct': 2, 'incorrect': 0, 'missing': 0}
    """
    final_scores = {'correct': 0, 'incorrect': 0, 'missing': 0}
    for question_name, answer in self.answer.items():
        if question_name not in answer_key:
            final_scores['missing'] += 1
            continue

        expected = answer_key[question_name]
        if answer == expected:
            is_correct = True
        elif isinstance(expected, (list, tuple, set, frozenset)):
            # Element-wise == (not `in`) so unhashable answers such as lists
            # also work when the acceptable answers are given as a set.
            is_correct = any(answer == option for option in expected)
        else:
            is_correct = False

        final_scores['correct' if is_correct else 'incorrect'] += 1

    return final_scores
|
442
469
|
|
443
470
|
def score(self, scoring_function: Callable) -> Union[int, float]:
|
444
471
|
"""Score the result using a passed-in scoring function.
|
edsl/results/Results.py
CHANGED
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
|
|
34
34
|
from simpleeval import EvalWithCompoundTypes
|
35
35
|
|
36
36
|
from edsl.results.ResultsExportMixin import ResultsExportMixin
|
37
|
-
from edsl.results.ResultsGGMixin import
|
37
|
+
from edsl.results.ResultsGGMixin import GGPlotMethod
|
38
38
|
from edsl.results.results_fetch_mixin import ResultsFetchMixin
|
39
39
|
from edsl.utilities.remove_edsl_version import remove_edsl_version
|
40
40
|
|
@@ -100,7 +100,7 @@ class NotReadyObject:
|
|
100
100
|
class Mixins(
|
101
101
|
ResultsExportMixin,
|
102
102
|
ResultsFetchMixin,
|
103
|
-
ResultsGGMixin,
|
103
|
+
# ResultsGGMixin,
|
104
104
|
):
|
105
105
|
def long(self):
|
106
106
|
return self.table().long()
|
@@ -151,6 +151,19 @@ class Results(UserList, Mixins, Base):
|
|
151
151
|
"cache_keys",
|
152
152
|
]
|
153
153
|
|
154
|
+
def ggplot2(
    self,
    ggplot_code: str,
    shape="wide",
    sql: str = None,
    remove_prefix: bool = True,
    debug: bool = False,
    height=4,
    width=6,
    factor_orders: Optional[dict] = None,
):
    """Plot these results with ggplot2 by delegating to ``GGPlotMethod``.

    A ``GGPlotMethod`` is built around this Results object. Every argument is
    forwarded positionally, in the order of this signature, to its
    ``ggplot2`` method; see that method for the exact meaning of each one.

    Returns:
        Whatever ``GGPlotMethod.ggplot2`` returns.
    """
    plotter = GGPlotMethod(self)
    # Positional forwarding: the order here must match GGPlotMethod.ggplot2.
    forwarded_args = (
        ggplot_code,
        shape,
        sql,
        remove_prefix,
        debug,
        height,
        width,
        factor_orders,
    )
    return plotter.ggplot2(*forwarded_args)
|
166
|
+
|
154
167
|
@classmethod
|
155
168
|
def from_job_info(cls, job_info: dict) -> Results:
|
156
169
|
"""
|
@@ -1277,6 +1290,13 @@ class Results(UserList, Mixins, Base):
|
|
1277
1290
|
"""
|
1278
1291
|
return [r.score(f) for r in self.data]
|
1279
1292
|
|
1293
|
+
def score_with_answer_key(self, answer_key: dict) -> list:
    """Score every result in this Results object against an answer key.

    :param answer_key: A dictionary that maps question names to the correct
        answer, or to a collection of acceptable answers. It is passed
        unchanged to each result's ``score_with_answer_key``.
    :return: A list with one score per result, in the same order as the
        results (the value returned by ``Result.score_with_answer_key``).
    """
    return [r.score_with_answer_key(answer_key) for r in self.data]
|
1299
|
+
|
1280
1300
|
|
1281
1301
|
def fetch_remote(self, job_info: "RemoteJobInfo") -> None:
|
1282
1302
|
"""
|
edsl/results/ResultsGGMixin.py
CHANGED
@@ -75,7 +75,11 @@ class GGPlot:
|
|
75
75
|
|
76
76
|
return self._svg_data
|
77
77
|
|
78
|
-
class
|
78
|
+
class GGPlotMethod:
|
79
|
+
|
80
|
+
def __init__(self, results: 'Results'):
|
81
|
+
self.results = results
|
82
|
+
|
79
83
|
"""Mixin class for ggplot2 plotting."""
|
80
84
|
|
81
85
|
def ggplot2(
|
@@ -106,9 +110,9 @@ class ResultsGGMixin:
|
|
106
110
|
sql = "select * from self"
|
107
111
|
|
108
112
|
if shape == "long":
|
109
|
-
df = self.sql(sql, shape="long")
|
113
|
+
df = self.results.sql(sql, shape="long")
|
110
114
|
elif shape == "wide":
|
111
|
-
df = self.sql(sql, remove_prefix=remove_prefix)
|
115
|
+
df = self.results.sql(sql, remove_prefix=remove_prefix)
|
112
116
|
|
113
117
|
# Convert DataFrame to CSV format
|
114
118
|
csv_data = df.to_csv().text
|
@@ -85,6 +85,8 @@ class DocumentChunker:
|
|
85
85
|
new_scenario = copy.deepcopy(self.scenario)
|
86
86
|
new_scenario[field] = chunk
|
87
87
|
new_scenario[field + "_chunk"] = i
|
88
|
+
new_scenario[field + "_char_count"] = len(chunk)
|
89
|
+
new_scenario[field + "_word_count"] = len(chunk.split())
|
88
90
|
if include_original:
|
89
91
|
if hash_original:
|
90
92
|
new_scenario[field + "_original"] = hashlib.md5(
|
edsl/scenarios/FileStore.py
CHANGED
@@ -29,6 +29,12 @@ class FileStore(Scenario):
|
|
29
29
|
if path is None and "filename" in kwargs:
|
30
30
|
path = kwargs["filename"]
|
31
31
|
|
32
|
+
# Check if path is a URL and handle download
|
33
|
+
if path and (path.startswith('http://') or path.startswith('https://')):
|
34
|
+
temp_filestore = self.from_url(path, mime_type=mime_type)
|
35
|
+
path = temp_filestore._path
|
36
|
+
mime_type = temp_filestore.mime_type
|
37
|
+
|
32
38
|
self._path = path # Store the original path privately
|
33
39
|
self._temp_path = None # Track any generated temporary file
|
34
40
|
|
@@ -138,6 +144,10 @@ class FileStore(Scenario):
|
|
138
144
|
base64_encoded_data = base64.b64encode(binary_data)
|
139
145
|
self.binary = True
|
140
146
|
# Convert the base64 bytes to a string
|
147
|
+
except FileNotFoundError:
|
148
|
+
print(f"File not found: {file_path}")
|
149
|
+
print("Current working directory:", os.getcwd())
|
150
|
+
raise
|
141
151
|
base64_string = base64_encoded_data.decode("utf-8")
|
142
152
|
|
143
153
|
return base64_string
|
edsl/scenarios/PdfExtractor.py
CHANGED
@@ -4,10 +4,30 @@ import os
|
|
4
4
|
class PdfExtractor:
|
5
5
|
def __init__(self, pdf_path: str):
    """Store the PDF path and probe once whether PyMuPDF is importable.

    Args:
        pdf_path: Path of the PDF file to extract from.
    """
    self.pdf_path = pdf_path
    # Cached probe result; get_pdf_dict checks it to raise a helpful ImportError.
    self._has_pymupdf = self._check_pymupdf()
|
8
9
|
|
10
|
+
def _check_pymupdf(self):
    """Return True when the ``fitz`` module (PyMuPDF) imports, False on ImportError."""
    try:
        import fitz  # noqa: F401 -- imported only to probe availability
    except ImportError:
        return False
    return True
|
17
|
+
|
9
18
|
def get_pdf_dict(self) -> dict:
|
10
|
-
#
|
19
|
+
# First check if the file exists
|
20
|
+
if not os.path.exists(self.pdf_path):
|
21
|
+
raise FileNotFoundError(f"The file {self.pdf_path} does not exist.")
|
22
|
+
|
23
|
+
# Then check if PyMuPDF is available
|
24
|
+
if not self._has_pymupdf:
|
25
|
+
raise ImportError(
|
26
|
+
"The 'fitz' module (PyMuPDF) is required for PDF extraction. "
|
27
|
+
"Please install it with: pip install pymupdf"
|
28
|
+
)
|
29
|
+
|
30
|
+
# If we get here, we can safely import and use fitz
|
11
31
|
import fitz
|
12
32
|
|
13
33
|
if not os.path.exists(self.pdf_path):
|
edsl/scenarios/Scenario.py
CHANGED
@@ -64,6 +64,15 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
|
|
64
64
|
self.data = data if data is not None else {}
|
65
65
|
self.name = name
|
66
66
|
|
67
|
+
def __mul__(self, scenario_list_or_scenario: Union["ScenarioList", "Scenario"]) -> "ScenarioList":
    """Multiply this scenario with a Scenario or ScenarioList.

    The work is delegated to ``ScenarioList``'s ``*`` operator:

    * a ScenarioList operand is computed as ``operand * self``;
    * a Scenario operand is computed as ``ScenarioList([self]) * operand``.

    Raises:
        TypeError: if the operand is neither a Scenario nor a ScenarioList.
    """
    from edsl.scenarios.ScenarioList import ScenarioList

    other = scenario_list_or_scenario
    if not isinstance(other, (ScenarioList, Scenario)):
        raise TypeError(f"Cannot multiply Scenario with {type(scenario_list_or_scenario)}")

    # ScenarioList is checked first, matching the original dispatch order.
    if isinstance(other, ScenarioList):
        return other * self
    return ScenarioList([self]) * other
|
75
|
+
|
67
76
|
def replicate(self, n: int) -> "ScenarioList":
|
68
77
|
"""Replicate a scenario n times to return a ScenarioList.
|
69
78
|
|
@@ -356,11 +365,18 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
|
|
356
365
|
|
357
366
|
@classmethod
def from_pdf(cls, pdf_path: str):
    """Create a Scenario from a PDF file.

    Extraction is delegated to ``PdfExtractor(pdf_path).get_pdf_dict()``.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        A ``Scenario`` wrapping the dictionary produced by the extractor.

    Raises:
        ImportError: if the extractor or PyMuPDF cannot be imported. The
            underlying error is chained as ``__cause__``.
    """
    try:
        from edsl.scenarios.PdfExtractor import PdfExtractor

        extractor = PdfExtractor(pdf_path)
        # Deliberately a plain Scenario rather than ``cls``: subclasses with a
        # different constructor (e.g. FileStore(Scenario), which takes a path)
        # keep getting a Scenario back, as before.
        return Scenario(extractor.get_pdf_dict())
    except ImportError as e:
        raise ImportError(
            f"Could not extract text from PDF: {str(e)}. "
            "PDF extraction requires the PyMuPDF library. "
            "Install it with: pip install pymupdf"
        ) from e
|
379
|
+
|
364
380
|
@classmethod
|
365
381
|
def from_pdf_to_image(cls, pdf_path, image_format="jpeg"):
|
366
382
|
"""
|
@@ -442,18 +458,18 @@ class Scenario(Base, UserDict, ScenarioHtmlMixin):
|
|
442
458
|
|
443
459
|
>>> s = Scenario({"text": "This is a test.\\nThis is a test.\\n\\nThis is a test."})
|
444
460
|
>>> s.chunk("text", num_lines = 1)
|
445
|
-
ScenarioList([Scenario({'text': 'This is a test.', 'text_chunk': 0}), Scenario({'text': 'This is a test.', 'text_chunk': 1}), Scenario({'text': '', 'text_chunk': 2}), Scenario({'text': 'This is a test.', 'text_chunk': 3})])
|
461
|
+
ScenarioList([Scenario({'text': 'This is a test.', 'text_chunk': 0, 'text_char_count': 15, 'text_word_count': 4}), Scenario({'text': 'This is a test.', 'text_chunk': 1, 'text_char_count': 15, 'text_word_count': 4}), Scenario({'text': '', 'text_chunk': 2, 'text_char_count': 0, 'text_word_count': 0}), Scenario({'text': 'This is a test.', 'text_chunk': 3, 'text_char_count': 15, 'text_word_count': 4})])
|
446
462
|
|
447
463
|
>>> s.chunk("text", num_words = 2)
|
448
|
-
ScenarioList([Scenario({'text': 'This is', 'text_chunk': 0}), Scenario({'text': 'a test.', 'text_chunk': 1}), Scenario({'text': 'This is', 'text_chunk': 2}), Scenario({'text': 'a test.', 'text_chunk': 3}), Scenario({'text': 'This is', 'text_chunk': 4}), Scenario({'text': 'a test.', 'text_chunk': 5})])
|
464
|
+
ScenarioList([Scenario({'text': 'This is', 'text_chunk': 0, 'text_char_count': 7, 'text_word_count': 2}), Scenario({'text': 'a test.', 'text_chunk': 1, 'text_char_count': 7, 'text_word_count': 2}), Scenario({'text': 'This is', 'text_chunk': 2, 'text_char_count': 7, 'text_word_count': 2}), Scenario({'text': 'a test.', 'text_chunk': 3, 'text_char_count': 7, 'text_word_count': 2}), Scenario({'text': 'This is', 'text_chunk': 4, 'text_char_count': 7, 'text_word_count': 2}), Scenario({'text': 'a test.', 'text_chunk': 5, 'text_char_count': 7, 'text_word_count': 2})])
|
449
465
|
|
450
466
|
>>> s = Scenario({"text": "Hello World"})
|
451
467
|
>>> s.chunk("text", num_words = 1, include_original = True)
|
452
|
-
ScenarioList([Scenario({'text': 'Hello', 'text_chunk': 0, 'text_original': 'Hello World'}), Scenario({'text': 'World', 'text_chunk': 1, 'text_original': 'Hello World'})])
|
468
|
+
ScenarioList([Scenario({'text': 'Hello', 'text_chunk': 0, 'text_char_count': 5, 'text_word_count': 1, 'text_original': 'Hello World'}), Scenario({'text': 'World', 'text_chunk': 1, 'text_char_count': 5, 'text_word_count': 1, 'text_original': 'Hello World'})])
|
453
469
|
|
454
470
|
>>> s = Scenario({"text": "Hello World"})
|
455
471
|
>>> s.chunk("text", num_words = 1, include_original = True, hash_original = True)
|
456
|
-
ScenarioList([Scenario({'text': 'Hello', 'text_chunk': 0, 'text_original': 'b10a8db164e0754105b7a99be72e3fe5'}), Scenario({'text': 'World', 'text_chunk': 1, 'text_original': 'b10a8db164e0754105b7a99be72e3fe5'})])
|
472
|
+
ScenarioList([Scenario({'text': 'Hello', 'text_chunk': 0, 'text_char_count': 5, 'text_word_count': 1, 'text_original': 'b10a8db164e0754105b7a99be72e3fe5'}), Scenario({'text': 'World', 'text_chunk': 1, 'text_char_count': 5, 'text_word_count': 1, 'text_original': 'b10a8db164e0754105b7a99be72e3fe5'})])
|
457
473
|
|
458
474
|
>>> s.chunk("text")
|
459
475
|
Traceback (most recent call last):
|