edsl 0.1.31.dev4__py3-none-any.whl → 0.1.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +9 -3
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +8 -3
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +40 -8
- edsl/agents/AgentList.py +43 -0
- edsl/agents/Invigilator.py +136 -221
- edsl/agents/InvigilatorBase.py +148 -59
- edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +154 -85
- edsl/agents/__init__.py +1 -0
- edsl/auto/AutoStudy.py +117 -0
- edsl/auto/StageBase.py +230 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +73 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +224 -0
- edsl/config.py +48 -47
- edsl/conjure/Conjure.py +6 -0
- edsl/coop/PriceFetcher.py +58 -0
- edsl/coop/coop.py +50 -7
- edsl/data/Cache.py +35 -1
- edsl/data/CacheHandler.py +3 -4
- edsl/data_transfer_models.py +73 -38
- edsl/enums.py +8 -0
- edsl/exceptions/general.py +10 -8
- edsl/exceptions/language_models.py +25 -1
- edsl/exceptions/questions.py +62 -5
- edsl/exceptions/results.py +4 -0
- edsl/inference_services/AnthropicService.py +13 -11
- edsl/inference_services/AwsBedrock.py +112 -0
- edsl/inference_services/AzureAI.py +214 -0
- edsl/inference_services/DeepInfraService.py +4 -3
- edsl/inference_services/GoogleService.py +16 -12
- edsl/inference_services/GroqService.py +5 -4
- edsl/inference_services/InferenceServiceABC.py +58 -3
- edsl/inference_services/InferenceServicesCollection.py +13 -8
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OllamaService.py +18 -0
- edsl/inference_services/OpenAIService.py +55 -56
- edsl/inference_services/TestService.py +80 -0
- edsl/inference_services/TogetherAIService.py +170 -0
- edsl/inference_services/models_available_cache.py +25 -0
- edsl/inference_services/registry.py +19 -1
- edsl/jobs/Answers.py +10 -12
- edsl/jobs/FailedQuestion.py +78 -0
- edsl/jobs/Jobs.py +137 -41
- edsl/jobs/buckets/BucketCollection.py +24 -15
- edsl/jobs/buckets/TokenBucket.py +105 -18
- edsl/jobs/interviews/Interview.py +393 -83
- edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +22 -18
- edsl/jobs/interviews/InterviewExceptionEntry.py +167 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +152 -160
- edsl/jobs/runners/JobsRunnerStatus.py +331 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
- edsl/jobs/tasks/TaskCreators.py +1 -1
- edsl/jobs/tasks/TaskHistory.py +205 -126
- edsl/language_models/LanguageModel.py +297 -177
- edsl/language_models/ModelList.py +2 -2
- edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/registry.py +25 -8
- edsl/language_models/repair.py +0 -19
- edsl/language_models/utilities.py +61 -0
- edsl/notebooks/Notebook.py +20 -2
- edsl/prompts/Prompt.py +52 -2
- edsl/questions/AnswerValidatorMixin.py +23 -26
- edsl/questions/QuestionBase.py +330 -249
- edsl/questions/QuestionBaseGenMixin.py +133 -0
- edsl/questions/QuestionBasePromptsMixin.py +266 -0
- edsl/questions/QuestionBudget.py +99 -42
- edsl/questions/QuestionCheckBox.py +227 -36
- edsl/questions/QuestionExtract.py +98 -28
- edsl/questions/QuestionFreeText.py +47 -31
- edsl/questions/QuestionFunctional.py +7 -0
- edsl/questions/QuestionList.py +141 -23
- edsl/questions/QuestionMultipleChoice.py +159 -66
- edsl/questions/QuestionNumerical.py +88 -47
- edsl/questions/QuestionRank.py +182 -25
- edsl/questions/Quick.py +41 -0
- edsl/questions/RegisterQuestionsMeta.py +31 -12
- edsl/questions/ResponseValidatorABC.py +170 -0
- edsl/questions/__init__.py +3 -4
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +10 -5
- edsl/questions/derived/QuestionLinearScale.py +15 -2
- edsl/questions/derived/QuestionTopK.py +10 -1
- edsl/questions/derived/QuestionYesNo.py +24 -3
- edsl/questions/descriptors.py +43 -7
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_registry.py +6 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/Dataset.py +20 -0
- edsl/results/DatasetExportMixin.py +58 -30
- edsl/results/DatasetTree.py +145 -0
- edsl/results/Result.py +32 -5
- edsl/results/Results.py +135 -46
- edsl/results/ResultsDBMixin.py +3 -3
- edsl/results/Selector.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/FileStore.py +71 -10
- edsl/scenarios/Scenario.py +109 -24
- edsl/scenarios/ScenarioImageMixin.py +2 -2
- edsl/scenarios/ScenarioList.py +546 -21
- edsl/scenarios/ScenarioListExportMixin.py +24 -4
- edsl/scenarios/ScenarioListPdfMixin.py +153 -4
- edsl/study/SnapShot.py +8 -1
- edsl/study/Study.py +32 -0
- edsl/surveys/Rule.py +15 -3
- edsl/surveys/RuleCollection.py +21 -5
- edsl/surveys/Survey.py +707 -298
- edsl/surveys/SurveyExportMixin.py +71 -9
- edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
- edsl/surveys/SurveyQualtricsImport.py +284 -0
- edsl/surveys/instructions/ChangeInstruction.py +47 -0
- edsl/surveys/instructions/Instruction.py +34 -0
- edsl/surveys/instructions/InstructionCollection.py +77 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +116 -0
- edsl/templates/error_reporting/interviews.html +10 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- edsl/utilities/utilities.py +40 -1
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/METADATA +8 -2
- edsl-0.1.33.dist-info/RECORD +295 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -271
- edsl/jobs/interviews/retry_management.py +0 -37
- edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -303
- edsl/utilities/gcp_bucket/simple_example.py +0 -9
- edsl-0.1.31.dev4.dist-info/RECORD +0 -204
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/LICENSE +0 -0
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/WHEEL +0 -0
@@ -20,13 +20,33 @@ def to_dataset(func):
|
|
20
20
|
return wrapper
|
21
21
|
|
22
22
|
|
23
|
-
def
|
24
|
-
for attr_name, attr_value in
|
25
|
-
if callable(attr_value):
|
23
|
+
def decorate_methods_from_mixin(cls, mixin_cls):
|
24
|
+
for attr_name, attr_value in mixin_cls.__dict__.items():
|
25
|
+
if callable(attr_value) and not attr_name.startswith("__"):
|
26
26
|
setattr(cls, attr_name, to_dataset(attr_value))
|
27
27
|
return cls
|
28
28
|
|
29
29
|
|
30
|
-
|
30
|
+
# def decorate_all_methods(cls):
|
31
|
+
# for attr_name, attr_value in cls.__dict__.items():
|
32
|
+
# if callable(attr_value):
|
33
|
+
# setattr(cls, attr_name, to_dataset(attr_value))
|
34
|
+
# return cls
|
35
|
+
|
36
|
+
|
37
|
+
# @decorate_all_methods
|
31
38
|
class ScenarioListExportMixin(DatasetExportMixin):
|
32
39
|
"""Mixin class for exporting Results objects."""
|
40
|
+
|
41
|
+
def __init_subclass__(cls, **kwargs):
|
42
|
+
super().__init_subclass__(**kwargs)
|
43
|
+
decorate_methods_from_mixin(cls, DatasetExportMixin)
|
44
|
+
|
45
|
+
def to_docx(self, filename: str):
|
46
|
+
"""Export the ScenarioList to a .docx file."""
|
47
|
+
dataset = self.to_dataset()
|
48
|
+
from edsl.results.DatasetTree import Tree
|
49
|
+
|
50
|
+
tree = Tree(dataset)
|
51
|
+
tree.construct_tree()
|
52
|
+
tree.to_docx(filename)
|
@@ -1,15 +1,161 @@
|
|
1
1
|
import fitz # PyMuPDF
|
2
2
|
import os
|
3
|
+
import copy
|
3
4
|
import subprocess
|
5
|
+
import requests
|
6
|
+
import tempfile
|
7
|
+
import os
|
8
|
+
|
9
|
+
# import urllib.parse as urlparse
|
10
|
+
from urllib.parse import urlparse
|
4
11
|
|
5
12
|
# from edsl import Scenario
|
6
13
|
|
14
|
+
import requests
|
15
|
+
import re
|
16
|
+
import tempfile
|
17
|
+
import os
|
18
|
+
import atexit
|
19
|
+
from urllib.parse import urlparse, parse_qs
|
20
|
+
|
21
|
+
|
22
|
+
class GoogleDriveDownloader:
|
23
|
+
_temp_dir = None
|
24
|
+
_temp_file_path = None
|
25
|
+
|
26
|
+
@classmethod
|
27
|
+
def fetch_from_drive(cls, url, filename=None):
|
28
|
+
# Extract file ID from the URL
|
29
|
+
file_id = cls._extract_file_id(url)
|
30
|
+
if not file_id:
|
31
|
+
raise ValueError("Invalid Google Drive URL")
|
32
|
+
|
33
|
+
# Construct the download URL
|
34
|
+
download_url = f"https://drive.google.com/uc?export=download&id={file_id}"
|
35
|
+
|
36
|
+
# Send a GET request to the URL
|
37
|
+
session = requests.Session()
|
38
|
+
response = session.get(download_url, stream=True)
|
39
|
+
response.raise_for_status()
|
40
|
+
|
41
|
+
# Check for large file download prompt
|
42
|
+
for key, value in response.cookies.items():
|
43
|
+
if key.startswith("download_warning"):
|
44
|
+
params = {"id": file_id, "confirm": value}
|
45
|
+
response = session.get(download_url, params=params, stream=True)
|
46
|
+
break
|
47
|
+
|
48
|
+
# Create a temporary file to save the download
|
49
|
+
if not filename:
|
50
|
+
filename = "downloaded_file"
|
51
|
+
|
52
|
+
if cls._temp_dir is None:
|
53
|
+
cls._temp_dir = tempfile.TemporaryDirectory()
|
54
|
+
atexit.register(cls._cleanup)
|
55
|
+
|
56
|
+
cls._temp_file_path = os.path.join(cls._temp_dir.name, filename)
|
57
|
+
|
58
|
+
# Write the content to the temporary file
|
59
|
+
with open(cls._temp_file_path, "wb") as f:
|
60
|
+
for chunk in response.iter_content(32768):
|
61
|
+
if chunk:
|
62
|
+
f.write(chunk)
|
63
|
+
|
64
|
+
print(f"File saved to: {cls._temp_file_path}")
|
65
|
+
|
66
|
+
return cls._temp_file_path
|
67
|
+
|
68
|
+
@staticmethod
|
69
|
+
def _extract_file_id(url):
|
70
|
+
# Try to extract file ID from '/file/d/' format
|
71
|
+
file_id_match = re.search(r"/d/([a-zA-Z0-9-_]+)", url)
|
72
|
+
if file_id_match:
|
73
|
+
return file_id_match.group(1)
|
74
|
+
|
75
|
+
# If not found, try to extract from 'open?id=' format
|
76
|
+
parsed_url = urlparse(url)
|
77
|
+
query_params = parse_qs(parsed_url.query)
|
78
|
+
if "id" in query_params:
|
79
|
+
return query_params["id"][0]
|
80
|
+
|
81
|
+
return None
|
82
|
+
|
83
|
+
@classmethod
|
84
|
+
def _cleanup(cls):
|
85
|
+
if cls._temp_dir:
|
86
|
+
cls._temp_dir.cleanup()
|
87
|
+
|
88
|
+
@classmethod
|
89
|
+
def get_temp_file_path(cls):
|
90
|
+
return cls._temp_file_path
|
91
|
+
|
92
|
+
|
93
|
+
def fetch_and_save_pdf(url, filename):
|
94
|
+
# Send a GET request to the URL
|
95
|
+
response = requests.get(url)
|
96
|
+
|
97
|
+
# Check if the request was successful
|
98
|
+
response.raise_for_status()
|
99
|
+
|
100
|
+
# Create a temporary directory
|
101
|
+
with tempfile.TemporaryDirectory() as temp_dir:
|
102
|
+
# Construct the full path for the file
|
103
|
+
temp_file_path = os.path.join(temp_dir, filename)
|
104
|
+
|
105
|
+
# Write the content to the temporary file
|
106
|
+
with open(temp_file_path, "wb") as file:
|
107
|
+
file.write(response.content)
|
108
|
+
|
109
|
+
print(f"PDF saved to: {temp_file_path}")
|
110
|
+
|
111
|
+
# Here you can perform operations with the file
|
112
|
+
# The file will be automatically deleted when you exit this block
|
113
|
+
|
114
|
+
return temp_file_path
|
115
|
+
|
116
|
+
|
117
|
+
# Example usage:
|
118
|
+
# url = "https://example.com/sample.pdf"
|
119
|
+
# fetch_and_save_pdf(url, "sample.pdf")
|
120
|
+
|
7
121
|
|
8
122
|
class ScenarioListPdfMixin:
|
9
123
|
@classmethod
|
10
|
-
def from_pdf(cls,
|
11
|
-
|
12
|
-
|
124
|
+
def from_pdf(cls, filename_or_url, collapse_pages=False):
|
125
|
+
# Check if the input is a URL
|
126
|
+
if cls.is_url(filename_or_url):
|
127
|
+
# Check if it's a Google Drive URL
|
128
|
+
if "drive.google.com" in filename_or_url:
|
129
|
+
temp_filename = GoogleDriveDownloader.fetch_from_drive(
|
130
|
+
filename_or_url, "temp_pdf.pdf"
|
131
|
+
)
|
132
|
+
else:
|
133
|
+
# For other URLs, use the previous fetch_and_save_pdf function
|
134
|
+
temp_filename = fetch_and_save_pdf(filename_or_url, "temp_pdf.pdf")
|
135
|
+
|
136
|
+
scenarios = list(cls.extract_text_from_pdf(temp_filename))
|
137
|
+
else:
|
138
|
+
# If it's not a URL, assume it's a local file path
|
139
|
+
scenarios = list(cls.extract_text_from_pdf(filename_or_url))
|
140
|
+
if not collapse_pages:
|
141
|
+
return cls(scenarios)
|
142
|
+
else:
|
143
|
+
txt = ""
|
144
|
+
for scenario in scenarios:
|
145
|
+
txt += scenario["text"]
|
146
|
+
from edsl.scenarios import Scenario
|
147
|
+
|
148
|
+
base_scenario = copy.copy(scenarios[0])
|
149
|
+
base_scenario["text"] = txt
|
150
|
+
return base_scenario
|
151
|
+
|
152
|
+
@staticmethod
|
153
|
+
def is_url(string):
|
154
|
+
try:
|
155
|
+
result = urlparse(string)
|
156
|
+
return all([result.scheme, result.netloc])
|
157
|
+
except ValueError:
|
158
|
+
return False
|
13
159
|
|
14
160
|
@classmethod
|
15
161
|
def _from_pdf_to_image(cls, pdf_path, image_format="jpeg"):
|
@@ -38,11 +184,14 @@ class ScenarioListPdfMixin:
|
|
38
184
|
scenario = Scenario._from_filepath_image(image_path)
|
39
185
|
scenarios.append(scenario)
|
40
186
|
|
41
|
-
print(f"Saved {len(images)} pages as images in {output_folder}")
|
187
|
+
# print(f"Saved {len(images)} pages as images in {output_folder}")
|
42
188
|
return cls(scenarios)
|
43
189
|
|
44
190
|
@staticmethod
|
45
191
|
def extract_text_from_pdf(pdf_path):
|
192
|
+
from edsl import Scenario
|
193
|
+
|
194
|
+
# TODO: Add test case
|
46
195
|
# Ensure the file exists
|
47
196
|
if not os.path.exists(pdf_path):
|
48
197
|
raise FileNotFoundError(f"The file {pdf_path} does not exist.")
|
edsl/study/SnapShot.py
CHANGED
@@ -57,10 +57,17 @@ class SnapShot:
|
|
57
57
|
from edsl.Base import Base
|
58
58
|
from edsl.study.Study import Study
|
59
59
|
|
60
|
+
def is_edsl_object(obj):
|
61
|
+
package_name = "edsl"
|
62
|
+
cls = obj.__class__
|
63
|
+
module_name = cls.__module__
|
64
|
+
return module_name.startswith(package_name)
|
65
|
+
|
60
66
|
for name, value in namespace.items():
|
61
67
|
# TODO check this code logic (if there are other objects with to_dict method that are not from edsl)
|
62
68
|
if (
|
63
|
-
|
69
|
+
is_edsl_object(value)
|
70
|
+
and hasattr(value, "to_dict")
|
64
71
|
and not inspect.isclass(value)
|
65
72
|
and value.__class__ not in [o.__class__ for o in self.exclude]
|
66
73
|
):
|
edsl/study/Study.py
CHANGED
@@ -469,6 +469,38 @@ class Study:
|
|
469
469
|
coop = Coop()
|
470
470
|
return coop.create(self, description=self.description)
|
471
471
|
|
472
|
+
def delete_object(self, identifier: Union[str, UUID]):
|
473
|
+
"""
|
474
|
+
Delete an EDSL object from the study.
|
475
|
+
|
476
|
+
:param identifier: Either the variable name or the hash of the object to delete
|
477
|
+
:raises ValueError: If the object is not found in the study
|
478
|
+
"""
|
479
|
+
if isinstance(identifier, str):
|
480
|
+
# If identifier is a variable name or a string representation of UUID
|
481
|
+
for hash, obj_entry in list(self.objects.items()):
|
482
|
+
if obj_entry.variable_name == identifier or hash == identifier:
|
483
|
+
del self.objects[hash]
|
484
|
+
self._create_mapping_dicts() # Update internal mappings
|
485
|
+
if self.verbose:
|
486
|
+
print(f"Deleted object with identifier: {identifier}")
|
487
|
+
return
|
488
|
+
raise ValueError(f"No object found with identifier: {identifier}")
|
489
|
+
elif isinstance(identifier, UUID):
|
490
|
+
# If identifier is a UUID object
|
491
|
+
hash_str = str(identifier)
|
492
|
+
if hash_str in self.objects:
|
493
|
+
del self.objects[hash_str]
|
494
|
+
self._create_mapping_dicts() # Update internal mappings
|
495
|
+
if self.verbose:
|
496
|
+
print(f"Deleted object with hash: {hash_str}")
|
497
|
+
return
|
498
|
+
raise ValueError(f"No object found with hash: {hash_str}")
|
499
|
+
else:
|
500
|
+
raise TypeError(
|
501
|
+
"Identifier must be either a string (variable name or hash) or a UUID object"
|
502
|
+
)
|
503
|
+
|
472
504
|
@classmethod
|
473
505
|
def pull(cls, uuid: Optional[Union[str, UUID]] = None, url: Optional[str] = None):
|
474
506
|
"""Pull the object from coop."""
|
edsl/surveys/Rule.py
CHANGED
@@ -18,6 +18,7 @@ with a low (-1) priority.
|
|
18
18
|
"""
|
19
19
|
|
20
20
|
import ast
|
21
|
+
import random
|
21
22
|
from typing import Any, Union, List
|
22
23
|
|
23
24
|
from jinja2 import Template
|
@@ -117,13 +118,15 @@ class Rule:
|
|
117
118
|
def _checks(self):
|
118
119
|
pass
|
119
120
|
|
120
|
-
|
121
|
+
# def _to_dict(self):
|
122
|
+
|
123
|
+
# @add_edsl_version
|
121
124
|
def to_dict(self):
|
122
125
|
"""Convert the rule to a dictionary for serialization.
|
123
126
|
|
124
127
|
>>> r = Rule.example()
|
125
128
|
>>> r.to_dict()
|
126
|
-
{'current_q': 1, 'expression': "q1 == 'yes'", 'next_q': 2, 'priority': 0, 'question_name_to_index': {'q1': 1}, 'before_rule': False
|
129
|
+
{'current_q': 1, 'expression': "q1 == 'yes'", 'next_q': 2, 'priority': 0, 'question_name_to_index': {'q1': 1}, 'before_rule': False}
|
127
130
|
"""
|
128
131
|
return {
|
129
132
|
"current_q": self.current_q,
|
@@ -133,6 +136,7 @@ class Rule:
|
|
133
136
|
"question_name_to_index": self.question_name_to_index,
|
134
137
|
"before_rule": self.before_rule,
|
135
138
|
}
|
139
|
+
# return self._to_dict()
|
136
140
|
|
137
141
|
@classmethod
|
138
142
|
@remove_edsl_version
|
@@ -251,8 +255,16 @@ class Rule:
|
|
251
255
|
msg = f"""Exception in evaluation: {e}. The expression is: {self.expression}. The current info env trying to substitute in is: {current_info_env}. After the substition, the expression was: {to_evaluate}."""
|
252
256
|
raise SurveyRuleCannotEvaluateError(msg)
|
253
257
|
|
258
|
+
random_functions = {
|
259
|
+
"randint": random.randint,
|
260
|
+
"choice": random.choice,
|
261
|
+
"random": random.random,
|
262
|
+
"uniform": random.uniform,
|
263
|
+
# Add any other random functions you want to allow
|
264
|
+
}
|
265
|
+
|
254
266
|
try:
|
255
|
-
return EvalWithCompoundTypes().eval(to_evaluate)
|
267
|
+
return EvalWithCompoundTypes(functions=random_functions).eval(to_evaluate)
|
256
268
|
except Exception as e:
|
257
269
|
msg = f"""Exception in evaluation: {e}. The expression is: {self.expression}. The current info env trying to substitute in is: {current_info_env}. After the substition, the expression was: {to_evaluate}."""
|
258
270
|
raise SurveyRuleCannotEvaluateError(msg)
|
edsl/surveys/RuleCollection.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
"""A collection of rules for a survey."""
|
2
2
|
|
3
|
-
from typing import List, Union, Any
|
3
|
+
from typing import List, Union, Any, Optional
|
4
4
|
from collections import defaultdict, UserList
|
5
5
|
|
6
6
|
from edsl.exceptions import (
|
@@ -24,7 +24,7 @@ NextQuestion = namedtuple(
|
|
24
24
|
class RuleCollection(UserList):
|
25
25
|
"""A collection of rules for a particular survey."""
|
26
26
|
|
27
|
-
def __init__(self, num_questions: int = None, rules: List[Rule] = None):
|
27
|
+
def __init__(self, num_questions: Optional[int] = None, rules: List[Rule] = None):
|
28
28
|
"""Initialize the RuleCollection object.
|
29
29
|
|
30
30
|
:param num_questions: The number of questions in the survey.
|
@@ -172,7 +172,8 @@ class RuleCollection(UserList):
|
|
172
172
|
|
173
173
|
def next_question(self, q_now: int, answers: dict[str, Any]) -> NextQuestion:
|
174
174
|
"""Find the next question by index, given the rule collection.
|
175
|
-
|
175
|
+
|
176
|
+
This rule is applied after the question is answered.
|
176
177
|
|
177
178
|
:param q_now: The current question index.
|
178
179
|
:param answers: The answers to the survey questions so far, including the current question.
|
@@ -182,8 +183,17 @@ class RuleCollection(UserList):
|
|
182
183
|
NextQuestion(next_q=3, num_rules_found=2, expressions_evaluating_to_true=1, priority=1)
|
183
184
|
|
184
185
|
"""
|
185
|
-
#
|
186
|
-
|
186
|
+
# # is this the first question? If it is, we need to check if it should be skipped.
|
187
|
+
# if q_now == 0:
|
188
|
+
# if self.skip_question_before_running(q_now, answers):
|
189
|
+
# return NextQuestion(
|
190
|
+
# next_q=q_now + 1,
|
191
|
+
# num_rules_found=0,
|
192
|
+
# expressions_evaluating_to_true=0,
|
193
|
+
# priority=-1,
|
194
|
+
# )
|
195
|
+
|
196
|
+
# breakpoint()
|
187
197
|
expressions_evaluating_to_true = 0
|
188
198
|
next_q = None
|
189
199
|
highest_priority = -2 # start with -2 to 'pick up' the default rule added
|
@@ -205,6 +215,12 @@ class RuleCollection(UserList):
|
|
205
215
|
f"No rules found for question {q_now}"
|
206
216
|
)
|
207
217
|
|
218
|
+
# breakpoint()
|
219
|
+
## Now we need to check if the *next question* has any 'before; rules that we should follow
|
220
|
+
for rule in self.applicable_rules(next_q, before_rule=True):
|
221
|
+
if rule.evaluate(answers): # rule evaluates to True
|
222
|
+
return self.next_question(next_q, answers)
|
223
|
+
|
208
224
|
return NextQuestion(
|
209
225
|
next_q, num_rules_found, expressions_evaluating_to_true, highest_priority
|
210
226
|
)
|