edsl 0.1.43__py3-none-any.whl → 0.1.45__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +15 -6
- edsl/__version__.py +1 -1
- edsl/agents/InvigilatorBase.py +3 -1
- edsl/agents/PromptConstructor.py +62 -34
- edsl/agents/QuestionInstructionPromptBuilder.py +111 -68
- edsl/agents/QuestionTemplateReplacementsBuilder.py +69 -16
- edsl/agents/question_option_processor.py +15 -6
- edsl/coop/CoopFunctionsMixin.py +3 -4
- edsl/coop/coop.py +56 -10
- edsl/enums.py +4 -1
- edsl/inference_services/AnthropicService.py +12 -8
- edsl/inference_services/AvailableModelFetcher.py +2 -0
- edsl/inference_services/AwsBedrock.py +1 -2
- edsl/inference_services/AzureAI.py +12 -9
- edsl/inference_services/GoogleService.py +10 -3
- edsl/inference_services/InferenceServiceABC.py +1 -0
- edsl/inference_services/InferenceServicesCollection.py +2 -2
- edsl/inference_services/MistralAIService.py +1 -2
- edsl/inference_services/OpenAIService.py +10 -4
- edsl/inference_services/PerplexityService.py +2 -1
- edsl/inference_services/TestService.py +1 -0
- edsl/inference_services/XAIService.py +11 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +9 -0
- edsl/jobs/JobsChecks.py +11 -14
- edsl/jobs/JobsPrompts.py +3 -3
- edsl/jobs/async_interview_runner.py +3 -1
- edsl/jobs/check_survey_scenario_compatibility.py +5 -5
- edsl/jobs/interviews/InterviewExceptionEntry.py +12 -0
- edsl/jobs/tasks/TaskHistory.py +1 -1
- edsl/language_models/LanguageModel.py +3 -3
- edsl/language_models/PriceManager.py +45 -5
- edsl/language_models/model.py +89 -36
- edsl/questions/QuestionBase.py +21 -0
- edsl/questions/QuestionBasePromptsMixin.py +103 -0
- edsl/questions/QuestionFreeText.py +22 -5
- edsl/questions/descriptors.py +4 -0
- edsl/questions/question_base_gen_mixin.py +94 -29
- edsl/results/Dataset.py +65 -0
- edsl/results/DatasetExportMixin.py +299 -32
- edsl/results/Result.py +27 -0
- edsl/results/Results.py +24 -3
- edsl/results/ResultsGGMixin.py +7 -3
- edsl/scenarios/DocumentChunker.py +2 -0
- edsl/scenarios/FileStore.py +29 -8
- edsl/scenarios/PdfExtractor.py +21 -1
- edsl/scenarios/Scenario.py +25 -9
- edsl/scenarios/ScenarioList.py +73 -3
- edsl/scenarios/handlers/__init__.py +1 -0
- edsl/scenarios/handlers/docx.py +5 -1
- edsl/scenarios/handlers/jpeg.py +39 -0
- edsl/surveys/Survey.py +28 -6
- edsl/surveys/SurveyFlowVisualization.py +91 -43
- edsl/templates/error_reporting/exceptions_table.html +7 -8
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/interviews.html +0 -1
- edsl/templates/error_reporting/overview.html +2 -7
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +1 -1
- edsl/utilities/PrettyList.py +14 -0
- edsl-0.1.45.dist-info/METADATA +246 -0
- {edsl-0.1.43.dist-info → edsl-0.1.45.dist-info}/RECORD +64 -62
- edsl-0.1.43.dist-info/METADATA +0 -110
- {edsl-0.1.43.dist-info → edsl-0.1.45.dist-info}/LICENSE +0 -0
- {edsl-0.1.43.dist-info → edsl-0.1.45.dist-info}/WHEEL +0 -0
edsl/results/Dataset.py
CHANGED
@@ -15,6 +15,7 @@ from edsl.Base import PersistenceMixin, HashingMixin
 
 from edsl.results.smart_objects import FirstObject
 
+from edsl.results.ResultsGGMixin import GGPlotMethod
 
 class Dataset(UserList, ResultsExportMixin, PersistenceMixin, HashingMixin):
     """A class to represent a dataset of observations."""
@@ -26,6 +27,20 @@ class Dataset(UserList, ResultsExportMixin, PersistenceMixin, HashingMixin):
         super().__init__(data)
         self.print_parameters = print_parameters
 
+
+    def ggplot2(
+        self,
+        ggplot_code: str,
+        shape="wide",
+        sql: str = None,
+        remove_prefix: bool = True,
+        debug: bool = False,
+        height=4,
+        width=6,
+        factor_orders: Optional[dict] = None,
+    ):
+        return GGPlotMethod(self).ggplot2(ggplot_code, shape, sql, remove_prefix, debug, height, width, factor_orders)
+
     def __len__(self) -> int:
         """Return the number of observations in the dataset.
 
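A usage sketch for the method added above (assumes a local R installation with the ggplot2 package, since GGPlotMethod shells out to R; `how_feeling` comes from the `Results.example()` doctests elsewhere in this diff, and the data frame being exposed to the R code as `df` is an assumption based on edsl's documented examples):

    from edsl import Results

    # select() returns a Dataset, which now forwards ggplot2 to GGPlotMethod.
    plot = Results.example().select("how_feeling").ggplot2(
        "ggplot(df, aes(x = how_feeling)) + geom_bar()"
    )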
@@ -580,6 +595,56 @@ class Dataset(UserList, ResultsExportMixin, PersistenceMixin, HashingMixin):
         result = cls([{col: df[col].tolist()} for col in df.columns])
         return result
 
+    def to_docx(self, output_file: str, title: str = None) -> None:
+        """
+        Convert the dataset to a Word document.
+
+        Args:
+            output_file (str): Path to save the Word document
+            title (str, optional): Title for the document
+        """
+        from docx import Document
+        from docx.shared import Inches
+        from docx.enum.text import WD_ALIGN_PARAGRAPH
+
+        # Create document
+        doc = Document()
+
+        # Add title if provided
+        if title:
+            title_heading = doc.add_heading(title, level=1)
+            title_heading.alignment = WD_ALIGN_PARAGRAPH.CENTER
+
+        # Get headers and data
+        headers, data = self._tabular()
+
+        # Create table
+        table = doc.add_table(rows=len(data) + 1, cols=len(headers))
+        table.style = 'Table Grid'
+
+        # Add headers
+        for j, header in enumerate(headers):
+            cell = table.cell(0, j)
+            cell.text = str(header)
+
+        # Add data
+        for i, row in enumerate(data):
+            for j, cell_content in enumerate(row):
+                cell = table.cell(i + 1, j)
+                cell.text = str(cell_content) if cell_content is not None else ""
+
+        # Adjust column widths
+        for column in table.columns:
+            max_width = 0
+            for cell in column.cells:
+                text_width = len(str(cell.text))
+                max_width = max(max_width, text_width)
+            for cell in column.cells:
+                cell.width = Inches(min(max_width * 0.1 + 0.5, 6))
+
+        # Save the document
+        doc.save(output_file)
+
 
 if __name__ == "__main__":
     import doctest
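A minimal usage sketch for the new `to_docx` export (requires the python-docx package; the output path is a placeholder):

    from edsl import Results

    d = Results.example().select("how_feeling")
    d.to_docx("observations.docx", title="How respondents felt")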
edsl/results/DatasetExportMixin.py
CHANGED
@@ -7,6 +7,7 @@ from typing import Optional, Tuple, Union, List
 
 from edsl.results.file_exports import CSVExport, ExcelExport, JSONLExport, SQLiteExport
 
+
 class DatasetExportMixin:
     """Mixin class for exporting Dataset objects."""
 
@@ -82,7 +83,8 @@ class DatasetExportMixin:
|
|
82
83
|
else:
|
83
84
|
if len(values) != _num_observations:
|
84
85
|
raise ValueError(
|
85
|
-
"The number of observations is not consistent across columns."
|
86
|
+
f"The number of observations is not consistent across columns. "
|
87
|
+
f"Column '{key}' has {len(values)} observations, but previous columns had {_num_observations} observations."
|
86
88
|
)
|
87
89
|
|
88
90
|
return _num_observations
|
@@ -219,7 +221,9 @@ class DatasetExportMixin:
         )
         return exporter.export()
 
-    def _db(
+    def _db(
+        self, remove_prefix: bool = True, shape: str = "wide"
+    ) -> "sqlalchemy.engine.Engine":
         """Create a SQLite database in memory and return the connection.
 
         Args:
@@ -229,7 +233,7 @@ class DatasetExportMixin:
         Returns:
             A database connection
         >>> from sqlalchemy import text
-        >>> from edsl import Results
+        >>> from edsl import Results
         >>> engine = Results.example()._db()
         >>> len(engine.execute(text("SELECT * FROM self")).fetchall())
         4
@@ -247,16 +251,17 @@ class DatasetExportMixin:
 
         if shape == "long":
             # Melt the dataframe to convert it to long format
-            df = df.melt(
-                var_name='key',
-                value_name='value'
-            )
+            df = df.melt(var_name="key", value_name="value")
             # Add a row number column for reference
-            df.insert(0,
+            df.insert(0, "row_number", range(1, len(df) + 1))
+
             # Split the key into data_type and key
-            df[
+            df["data_type"] = df["key"].apply(
+                lambda x: x.split(".")[0] if "." in x else None
+            )
+            df["key"] = df["key"].apply(
+                lambda x: ".".join(x.split(".")[1:]) if "." in x else x
+            )
 
             df.to_sql(
                 "self",
@@ -276,27 +281,27 @@ class DatasetExportMixin:
     ) -> Union["pd.DataFrame", str]:
         """Execute a SQL query and return the results as a DataFrame.
 
-        [original docstring lines not rendered by the source diff viewer]
+        Args:
+            query: The SQL query to execute
+            shape: The shape of the data in the database (wide or long)
+            remove_prefix: Whether to remove the prefix from the column names
+            transpose: Whether to transpose the DataFrame
+            transpose_by: The column to use as the index when transposing
+            csv: Whether to return the DataFrame as a CSV string
+            to_list: Whether to return the results as a list
+            to_latex: Whether to return the results as LaTeX
+            filename: Optional filename to save the results to
+
+        Returns:
+            DataFrame, CSV string, list, or LaTeX string depending on parameters
+
+        Examples:
+            >>> from edsl import Results
+            >>> r = Results.example();
+            >>> len(r.sql("SELECT * FROM self", shape = "wide"))
+            4
+            >>> len(r.sql("SELECT * FROM self", shape = "long"))
+            172
         """
         import pandas as pd
 
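For orientation, the long shape built in the `_db` hunk above yields an in-memory table with columns `row_number`, `data_type`, `key`, and `value` (the names follow from the `melt`/`insert` lines in that hunk). A small sketch of querying it:

    from edsl import Results

    r = Results.example()
    df = r.sql(
        "SELECT data_type, key, value FROM self WHERE data_type = 'answer'",
        shape="long",
    )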
@@ -538,6 +543,116 @@ class DatasetExportMixin:
 
         if return_link:
             return filename
+
+    def report(self, *fields: Optional[str], top_n: Optional[int] = None,
+               header_fields: Optional[List[str]] = None, divider: bool = True,
+               return_string: bool = False) -> Optional[str]:
+        """Takes the fields in order and returns a report of the results by iterating through rows.
+        The row number is printed as # Observation: <row number>
+        The name of the field is used as markdown header at level "##"
+        The content of that field is then printed.
+        Then the next field and so on.
+        Once that row is done, a new line is printed and the next row is shown.
+        If in a jupyter notebook, the report is displayed as markdown.
+
+        Args:
+            *fields: The fields to include in the report. If none provided, all fields are used.
+            top_n: Optional limit on the number of observations to include.
+            header_fields: Optional list of fields to include in the main header instead of as sections.
+            divider: If True, adds a horizontal rule between observations for better visual separation.
+            return_string: If True, returns the markdown string. If False (default in notebooks),
+                only displays the markdown without returning.
+
+        Returns:
+            A string containing the markdown report if return_string is True, otherwise None.
+
+        Examples:
+            >>> from edsl.results import Results
+            >>> r = Results.example()
+            >>> report = r.select('how_feeling', 'how_feeling_yesterday').report(return_string=True)
+            >>> "# Observation: 1" in report
+            True
+            >>> "## answer.how_feeling" in report
+            True
+            >>> report = r.select('how_feeling').report(header_fields=['answer.how_feeling'], return_string=True)
+            >>> "# Observation: 1 (`how_feeling`: OK)" in report
+            True
+        """
+        from edsl.utilities.utilities import is_notebook
+
+        # If no fields specified, use all columns
+        if not fields:
+            fields = self.relevant_columns()
+
+        # Initialize header_fields if not provided
+        if header_fields is None:
+            header_fields = []
+
+        # Validate all fields
+        all_fields = list(fields) + [f for f in header_fields if f not in fields]
+        for field in all_fields:
+            if field not in self.relevant_columns():
+                raise ValueError(f"Field '{field}' not found in dataset")
+
+        # Get data for each field
+        field_data = {}
+        for field in all_fields:
+            for entry in self:
+                if field in entry:
+                    field_data[field] = entry[field]
+                    break
+
+        # Number of observations to process
+        num_obs = self.num_observations()
+        if top_n is not None:
+            num_obs = min(num_obs, top_n)
+
+        # Build the report
+        report_lines = []
+        for i in range(num_obs):
+            # Create header with observation number and any header fields
+            header = f"# Observation: {i+1}"
+            if header_fields:
+                header_parts = []
+                for field in header_fields:
+                    value = field_data[field][i]
+                    # Get the field name without prefix for cleaner display
+                    display_name = field.split('.')[-1] if '.' in field else field
+                    # Format with backticks for monospace
+                    header_parts.append(f"`{display_name}`: {value}")
+                if header_parts:
+                    header += f" ({', '.join(header_parts)})"
+            report_lines.append(header)
+
+            # Add the remaining fields
+            for field in fields:
+                if field not in header_fields:
+                    report_lines.append(f"## {field}")
+                    value = field_data[field][i]
+                    if isinstance(value, list) or isinstance(value, dict):
+                        import json
+                        report_lines.append(f"```\n{json.dumps(value, indent=2)}\n```")
+                    else:
+                        report_lines.append(str(value))
+
+            # Add divider between observations if requested
+            if divider and i < num_obs - 1:
+                report_lines.append("\n---\n")
+            else:
+                report_lines.append("")  # Empty line between observations
+
+        report_text = "\n".join(report_lines)
+
+        # In notebooks, display as markdown and optionally return
+        is_nb = is_notebook()
+        if is_nb:
+            from IPython.display import Markdown, display
+            display(Markdown(report_text))
+
+        # Return the string if requested or if not in a notebook
+        if return_string or not is_nb:
+            return report_text
+        return None
 
     def tally(
         self, *fields: Optional[str], top_n: Optional[int] = None, output="Dataset"
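A short usage sketch of `report`, mirroring the doctests above; outside a notebook the markdown string is returned directly:

    from edsl import Results

    r = Results.example()
    md = r.select("how_feeling", "how_feeling_yesterday").report(
        top_n=2,
        header_fields=["answer.how_feeling"],
        return_string=True,
    )
    print(md)  # "# Observation: 1 (`how_feeling`: OK)" followed by "## ..." sections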
@@ -616,6 +731,158 @@ class DatasetExportMixin:
         keys.append("count")
         return sl.reorder_keys(keys).to_dataset()
 
+    def flatten(self, field, keep_original=False):
+        """
+        Flatten a field containing a list of dictionaries into separate fields.
+
+        For example, if a dataset contains:
+        [{'data': [{'a': 1}, {'b': 2}], 'other': ['x', 'y']}]
+
+        After d.flatten('data'), it should become:
+        [{'other': ['x', 'y'], 'data.a': [1, None], 'data.b': [None, 2]}]
+
+        Args:
+            field: The field to flatten
+            keep_original: If True, keeps the original field in the dataset
+
+        Returns:
+            A new dataset with the flattened fields
+        """
+        from edsl.results.Dataset import Dataset
+
+        # Ensure the dataset isn't empty
+        if not self.data:
+            return self.copy()
+
+        # Get the number of observations
+        num_observations = self.num_observations()
+
+        # Find the column to flatten
+        field_entry = None
+        for entry in self.data:
+            if field in entry:
+                field_entry = entry
+                break
+
+        if field_entry is None:
+            warnings.warn(
+                f"Field '{field}' not found in dataset, returning original dataset"
+            )
+            return self.copy()
+
+        # Create new dictionary for flattened data
+        flattened_data = []
+
+        # Copy all existing columns except the one we're flattening (if keep_original is False)
+        for entry in self.data:
+            col_name = next(iter(entry.keys()))
+            if col_name != field or keep_original:
+                flattened_data.append(entry.copy())
+
+        # Get field data and make sure it's valid
+        field_values = field_entry[field]
+        if not all(isinstance(item, dict) for item in field_values if item is not None):
+            warnings.warn(
+                f"Field '{field}' contains non-dictionary values that cannot be flattened"
+            )
+            return self.copy()
+
+        # Collect all unique keys across all dictionaries
+        all_keys = set()
+        for item in field_values:
+            if isinstance(item, dict):
+                all_keys.update(item.keys())
+
+        # Create new columns for each key
+        for key in sorted(all_keys):  # Sort for consistent output
+            new_values = []
+            for i in range(num_observations):
+                value = None
+                if i < len(field_values) and isinstance(field_values[i], dict):
+                    value = field_values[i].get(key, None)
+                new_values.append(value)
+
+            # Add this as a new column
+            flattened_data.append({f"{field}.{key}": new_values})
+
+        # Return a new Dataset with the flattened data
+        return Dataset(flattened_data)
+
+    def unpack_list(
+        self,
+        field: str,
+        new_names: Optional[List[str]] = None,
+        keep_original: bool = True,
+    ) -> "Dataset":
+        """Unpack list columns into separate columns with provided names or numeric suffixes.
+
+        For example, if a dataset contains:
+        [{'data': [[1, 2, 3], [4, 5, 6]], 'other': ['x', 'y']}]
+
+        After d.unpack_list('data'), it should become:
+        [{'other': ['x', 'y'], 'data_1': [1, 4], 'data_2': [2, 5], 'data_3': [3, 6]}]
+
+        Args:
+            field: The field containing lists to unpack
+            new_names: Optional list of names for the unpacked fields. If None, uses numeric suffixes.
+            keep_original: If True, keeps the original field in the dataset
+
+        Returns:
+            A new Dataset with unpacked columns
+
+        Examples:
+            >>> from edsl.results.Dataset import Dataset
+            >>> d = Dataset([{'data': [[1, 2, 3], [4, 5, 6]]}])
+            >>> d.unpack_list('data')
+            Dataset([{'data': [[1, 2, 3], [4, 5, 6]]}, {'data_1': [1, 4]}, {'data_2': [2, 5]}, {'data_3': [3, 6]}])
+
+            >>> d.unpack_list('data', new_names=['first', 'second', 'third'])
+            Dataset([{'data': [[1, 2, 3], [4, 5, 6]]}, {'first': [1, 4]}, {'second': [2, 5]}, {'third': [3, 6]}])
+        """
+        from edsl.results.Dataset import Dataset
+
+        # Create a copy of the dataset
+        result = Dataset(self.data.copy())
+
+        # Find the field in the dataset
+        field_index = None
+        for i, entry in enumerate(result.data):
+            if field in entry:
+                field_index = i
+                break
+
+        if field_index is None:
+            raise ValueError(f"Field '{field}' not found in dataset")
+
+        field_data = result.data[field_index][field]
+
+        # Check if values are lists
+        if not all(isinstance(v, list) for v in field_data):
+            raise ValueError(f"Field '{field}' does not contain lists in all entries")
+
+        # Get the maximum length of lists
+        max_len = max(len(v) for v in field_data)
+
+        # Create new fields for each index
+        for i in range(max_len):
+            if new_names and i < len(new_names):
+                new_field = new_names[i]
+            else:
+                new_field = f"{field}_{i+1}"
+
+            # Extract the i-th element from each list
+            new_values = []
+            for item in field_data:
+                new_values.append(item[i] if i < len(item) else None)
+
+            result.data.append({new_field: new_values})
+
+        # Remove the original field if keep_original is False
+        if not keep_original:
+            result.data.pop(field_index)
+
+        return result
+
 
 if __name__ == "__main__":
     import doctest
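A sketch tying the two new reshaping helpers together, using the shapes from their docstrings (recall a Dataset is a list of single-key column dicts):

    from edsl.results.Dataset import Dataset

    d = Dataset([{'data': [{'a': 1}, {'b': 2}]}, {'other': ['x', 'y']}])
    flat = d.flatten('data')              # -> columns other, data.a, data.b

    lists = Dataset([{'scores': [[1, 2], [3, 4]]}])
    wide = lists.unpack_list('scores', new_names=['first', 'second'])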
edsl/results/Result.py
CHANGED
@@ -439,6 +439,33 @@ class Result(Base, UserDict):
         from edsl.results.Results import Results
 
         return Results.example()[0]
+
+    def score_with_answer_key(self, answer_key: dict) -> Union[int, float]:
+        """Score the result using an answer key.
+
+        :param answer_key: A dictionary that maps question_names to answers
+
+        >>> Result.example()['answer']
+        {'how_feeling': 'OK', 'how_feeling_yesterday': 'Great'}
+
+        >>> answer_key = {'how_feeling': 'OK', 'how_feeling_yesterday': 'Great'}
+        >>> Result.example().score_with_answer_key(answer_key)
+        {'correct': 2, 'incorrect': 0, 'missing': 0}
+        >>> answer_key = {'how_feeling': 'OK', 'how_feeling_yesterday': ['Great', 'Good']}
+        >>> Result.example().score_with_answer_key(answer_key)
+        {'correct': 2, 'incorrect': 0, 'missing': 0}
+        """
+        final_scores = {'correct': 0, 'incorrect': 0, 'missing': 0}
+        for question_name, answer in self.answer.items():
+            if question_name in answer_key:
+                if answer == answer_key[question_name] or answer in answer_key[question_name]:
+                    final_scores['correct'] += 1
+                else:
+                    final_scores['incorrect'] += 1
+            else:
+                final_scores['missing'] += 1
+
+        return final_scores
 
     def score(self, scoring_function: Callable) -> Union[int, float]:
         """Score the result using a passed-in scoring function.
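One behavioral note on the comparison above: `answer == key or answer in key` lets an answer-key entry be a list of acceptable answers, but when the entry is a plain string the `in` test does substring matching. A sketch (the import path is an assumption; the doctests use a bare `Result`):

    from edsl.results.Result import Result

    r = Result.example()
    # A list entry accepts any of the listed answers.
    r.score_with_answer_key({'how_feeling': ['OK', 'Good'],
                             'how_feeling_yesterday': 'Great'})
    # Substring caveat: 'OK' in 'OKAY' is True, so this also scores as correct.
    r.score_with_answer_key({'how_feeling': 'OKAY',
                             'how_feeling_yesterday': 'Great'})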
edsl/results/Results.py
CHANGED
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
     from simpleeval import EvalWithCompoundTypes
 
 from edsl.results.ResultsExportMixin import ResultsExportMixin
-from edsl.results.ResultsGGMixin import
+from edsl.results.ResultsGGMixin import GGPlotMethod
 from edsl.results.results_fetch_mixin import ResultsFetchMixin
 from edsl.utilities.remove_edsl_version import remove_edsl_version
 
@@ -100,7 +100,7 @@ class NotReadyObject:
 class Mixins(
     ResultsExportMixin,
     ResultsFetchMixin,
-    ResultsGGMixin,
+    # ResultsGGMixin,
 ):
     def long(self):
         return self.table().long()
@@ -151,6 +151,19 @@ class Results(UserList, Mixins, Base):
         "cache_keys",
     ]
 
+    def ggplot2(
+        self,
+        ggplot_code: str,
+        shape="wide",
+        sql: str = None,
+        remove_prefix: bool = True,
+        debug: bool = False,
+        height=4,
+        width=6,
+        factor_orders: Optional[dict] = None,
+    ):
+        return GGPlotMethod(self).ggplot2(ggplot_code, shape, sql, remove_prefix, debug, height, width, factor_orders)
+
     @classmethod
     def from_job_info(cls, job_info: dict) -> Results:
         """
@@ -1277,6 +1290,13 @@ class Results(UserList, Mixins, Base):
         """
         return [r.score(f) for r in self.data]
 
+    def score_with_answer_key(self, answer_key: dict) -> list:
+        """Score the results using an answer key.
+
+        :param answer_key: A dictionary that maps answer values to scores.
+        """
+        return [r.score_with_answer_key(answer_key) for r in self.data]
+
 
     def fetch_remote(self, job_info: "RemoteJobInfo") -> None:
         """
@@ -1326,7 +1346,7 @@ class Results(UserList, Mixins, Base):
         except Exception as e:
             raise ResultsError(f"Failed to fetch remote results: {str(e)}")
 
-    def fetch(self, polling_interval: float = 1.0) -> Results:
+    def fetch(self, polling_interval: [float, int] = 1.0) -> Results:
         """
         Polls the server for job completion and updates this Results instance with the completed data.
 
@@ -1346,6 +1366,7 @@ class Results(UserList, Mixins, Base):
         remote_job_data = JobsRemoteInferenceHandler.check_status(self.job_info.job_uuid)
 
         while remote_job_data.get("status") not in ["completed", "failed"]:
+            print("Waiting for remote job to complete...")
             import time
             time.sleep(polling_interval)
             remote_job_data = JobsRemoteInferenceHandler.check_status(self.job_info.job_uuid)
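Two small things in this hunk: the new annotation `[float, int]` is a plain list rather than `typing.Union[float, int]` (harmless at runtime, since Python does not enforce annotations), and the added `print` makes each poll visible. A usage sketch (assumes `job_info` was populated by a remote run):

    # Poll the server every 2 seconds until the job completes or fails.
    completed = results.fetch(polling_interval=2)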
edsl/results/ResultsGGMixin.py
CHANGED
@@ -75,7 +75,11 @@ class GGPlot:
 
         return self._svg_data
 
-class ResultsGGMixin:
+class GGPlotMethod:
+
+    def __init__(self, results: 'Results'):
+        self.results = results
+
     """Mixin class for ggplot2 plotting."""
 
     def ggplot2(
@@ -106,9 +110,9 @@ class ResultsGGMixin:
         sql = "select * from self"
 
         if shape == "long":
-            df = self.sql(sql, shape="long")
+            df = self.results.sql(sql, shape="long")
         elif shape == "wide":
-            df = self.sql(sql, remove_prefix=remove_prefix)
+            df = self.results.sql(sql, remove_prefix=remove_prefix)
 
         # Convert DataFrame to CSV format
         csv_data = df.to_csv().text
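Taken together with the Results.py and Dataset.py hunks, this moves plotting from mixin inheritance (ResultsGGMixin on the class MRO, calling `self.sql`) to delegation: a GGPlotMethod wrapper holds the results object and calls its public `.sql()`. A sketch of the resulting call path:

    # Thin forwarding method on Results/Dataset ...
    plot = results.ggplot2("ggplot(df, aes(x = how_feeling)) + geom_bar()")
    # ... is equivalent to constructing the wrapper directly:
    plot = GGPlotMethod(results).ggplot2("ggplot(df, aes(x = how_feeling)) + geom_bar()")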
edsl/scenarios/DocumentChunker.py
CHANGED
@@ -85,6 +85,8 @@ class DocumentChunker:
             new_scenario = copy.deepcopy(self.scenario)
             new_scenario[field] = chunk
             new_scenario[field + "_chunk"] = i
+            new_scenario[field + "_char_count"] = len(chunk)
+            new_scenario[field + "_word_count"] = len(chunk.split())
             if include_original:
                 if hash_original:
                     new_scenario[field + "_original"] = hashlib.md5(
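Effect of the two added lines: every chunk scenario now carries size metadata next to the text. Sketch (the `chunk` entry point and `num_words` parameter are assumptions about the API surrounding this hunk):

    # Hypothetical call into the chunker for a scenario with a "text" field:
    for s in scenario.chunk("text", num_words=300):
        print(s["text_chunk"], s["text_word_count"], s["text_char_count"])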
edsl/scenarios/FileStore.py
CHANGED
@@ -9,7 +9,8 @@ from edsl.scenarios.Scenario import Scenario
 from edsl.utilities.remove_edsl_version import remove_edsl_version
 
 from edsl.scenarios.file_methods import FileMethods
-
+from typing import Union
+from uuid import UUID
 
 class FileStore(Scenario):
     __documentation__ = "https://docs.expectedparrot.com/en/latest/filestore.html"
@@ -28,6 +29,12 @@ class FileStore(Scenario):
         if path is None and "filename" in kwargs:
             path = kwargs["filename"]
 
+        # Check if path is a URL and handle download
+        if path and (path.startswith('http://') or path.startswith('https://')):
+            temp_filestore = self.from_url(path, mime_type=mime_type)
+            path = temp_filestore._path
+            mime_type = temp_filestore.mime_type
+
         self._path = path  # Store the original path privately
         self._temp_path = None  # Track any generated temporary file
 
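With this hunk the constructor accepts URLs directly, downloading via `from_url` before storing the local path. Sketch (placeholder URL):

    from edsl import FileStore

    fs = FileStore("https://example.com/report.pdf")  # downloaded to a temp file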
@@ -137,6 +144,10 @@ class FileStore(Scenario):
             base64_encoded_data = base64.b64encode(binary_data)
             self.binary = True
             # Convert the base64 bytes to a string
+        except FileNotFoundError:
+            print(f"File not found: {file_path}")
+            print("Current working directory:", os.getcwd())
+            raise
         base64_string = base64_encoded_data.decode("utf-8")
 
         return base64_string
@@ -262,7 +273,12 @@ class FileStore(Scenario):
         # raise TypeError("No text method found for this file type.")
 
     def push(
-        self,
+        self,
+        description: Optional[str] = None,
+        alias: Optional[str] = None,
+        visibility: Optional[str] = "unlisted",
+        expected_parrot_url: Optional[str] = None,
+
     ) -> dict:
         """
         Push the object to Coop.
@@ -272,17 +288,22 @@ class FileStore(Scenario):
         scenario_version = Scenario.from_dict(self.to_dict())
         if description is None:
             description = "File: " + self.path
-        info = scenario_version.push(description=description, visibility=visibility)
+        info = scenario_version.push(description=description, visibility=visibility, expected_parrot_url=expected_parrot_url, alias=alias)
         return info
 
     @classmethod
-    def pull(cls,
+    def pull(cls, url_or_uuid: Union[str, UUID]) -> "FileStore":
         """
-
-
-        :
+        Pull a FileStore object from Coop.
+
+        Args:
+            url_or_uuid: Either a UUID string or a URL pointing to the object
+            expected_parrot_url: Optional URL for the Parrot server
+
+        Returns:
+            FileStore: The pulled FileStore object
         """
-        scenario_version = Scenario.pull(
+        scenario_version = Scenario.pull(url_or_uuid)
         return cls.from_dict(scenario_version.to_dict())
 
     @classmethod