edsl 0.1.33.dev2__py3-none-any.whl → 0.1.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +24 -14
- edsl/__init__.py +1 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +6 -6
- edsl/agents/Invigilator.py +28 -6
- edsl/agents/InvigilatorBase.py +8 -27
- edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +150 -182
- edsl/agents/prompt_helpers.py +129 -0
- edsl/config.py +26 -34
- edsl/coop/coop.py +14 -4
- edsl/data_transfer_models.py +26 -73
- edsl/enums.py +2 -0
- edsl/inference_services/AnthropicService.py +5 -2
- edsl/inference_services/AwsBedrock.py +5 -2
- edsl/inference_services/AzureAI.py +5 -2
- edsl/inference_services/GoogleService.py +108 -33
- edsl/inference_services/InferenceServiceABC.py +44 -13
- edsl/inference_services/MistralAIService.py +5 -2
- edsl/inference_services/OpenAIService.py +10 -6
- edsl/inference_services/TestService.py +34 -16
- edsl/inference_services/TogetherAIService.py +170 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +109 -18
- edsl/jobs/buckets/BucketCollection.py +24 -15
- edsl/jobs/buckets/TokenBucket.py +64 -10
- edsl/jobs/interviews/Interview.py +130 -49
- edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +119 -173
- edsl/jobs/runners/JobsRunnerStatus.py +332 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
- edsl/jobs/tasks/TaskHistory.py +17 -0
- edsl/language_models/LanguageModel.py +36 -38
- edsl/language_models/registry.py +13 -9
- edsl/language_models/utilities.py +5 -2
- edsl/questions/QuestionBase.py +74 -16
- edsl/questions/QuestionBaseGenMixin.py +28 -0
- edsl/questions/QuestionBudget.py +93 -41
- edsl/questions/QuestionCheckBox.py +1 -1
- edsl/questions/QuestionFreeText.py +6 -0
- edsl/questions/QuestionMultipleChoice.py +13 -24
- edsl/questions/QuestionNumerical.py +5 -4
- edsl/questions/Quick.py +41 -0
- edsl/questions/ResponseValidatorABC.py +11 -6
- edsl/questions/derived/QuestionLinearScale.py +4 -1
- edsl/questions/derived/QuestionTopK.py +4 -1
- edsl/questions/derived/QuestionYesNo.py +8 -2
- edsl/questions/descriptors.py +12 -11
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +0 -1
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
- edsl/results/DatasetExportMixin.py +5 -1
- edsl/results/Result.py +1 -1
- edsl/results/Results.py +4 -1
- edsl/scenarios/FileStore.py +178 -34
- edsl/scenarios/Scenario.py +76 -37
- edsl/scenarios/ScenarioList.py +19 -2
- edsl/scenarios/ScenarioListPdfMixin.py +150 -4
- edsl/study/Study.py +32 -0
- edsl/surveys/DAG.py +62 -0
- edsl/surveys/MemoryPlan.py +26 -0
- edsl/surveys/Rule.py +34 -1
- edsl/surveys/RuleCollection.py +55 -5
- edsl/surveys/Survey.py +189 -10
- edsl/surveys/base.py +4 -0
- edsl/templates/error_reporting/interview_details.html +6 -1
- edsl/utilities/utilities.py +9 -1
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/METADATA +3 -1
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/RECORD +75 -69
- edsl/jobs/interviews/retry_management.py +0 -39
- edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
- edsl/scenarios/ScenarioImageMixin.py +0 -100
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/WHEEL +0 -0
edsl/scenarios/ScenarioList.py
CHANGED
@@ -39,6 +39,15 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
39
39
|
super().__init__([])
|
40
40
|
self.codebook = codebook or {}
|
41
41
|
|
42
|
+
@property
def has_jinja_braces(self) -> bool:
    """Report whether any scenario in this list has Jinja braces.

    :return: True as soon as one contained scenario reports a truthy
        ``has_jinja_braces``; False for an empty list.
    """
    # A generator lets any() stop at the first hit instead of building a full list.
    return any(scenario.has_jinja_braces for scenario in self)
|
46
|
+
|
47
|
+
def convert_jinja_braces(self) -> ScenarioList:
    """Build a new ScenarioList from each scenario's ``convert_jinja_braces()`` result.

    :return: A new ScenarioList; this list itself is not modified here.
    """
    converted = []
    for scenario in self:
        converted.append(scenario.convert_jinja_braces())
    return ScenarioList(converted)
|
50
|
+
|
42
51
|
def give_valid_names(self) -> ScenarioList:
|
43
52
|
"""Give valid names to the scenario keys.
|
44
53
|
|
@@ -273,6 +282,10 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
273
282
|
for s in data["scenarios"]:
|
274
283
|
_ = s.pop("edsl_version")
|
275
284
|
_ = s.pop("edsl_class_name")
|
285
|
+
for scenario in data["scenarios"]:
|
286
|
+
for key, value in scenario.items():
|
287
|
+
if hasattr(value, "to_dict"):
|
288
|
+
data[key] = value.to_dict()
|
276
289
|
return data_to_html(data)
|
277
290
|
|
278
291
|
def tally(self, field) -> dict:
|
@@ -517,7 +530,9 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
517
530
|
return ScenarioList([scenario.drop(fields) for scenario in self.data])
|
518
531
|
|
519
532
|
@classmethod
|
520
|
-
def from_list(
|
533
|
+
def from_list(
|
534
|
+
cls, name: str, values: list, func: Optional[Callable] = None
|
535
|
+
) -> ScenarioList:
|
521
536
|
"""Create a ScenarioList from a list of values.
|
522
537
|
|
523
538
|
Example:
|
@@ -525,7 +540,9 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
525
540
|
>>> ScenarioList.from_list('name', ['Alice', 'Bob'])
|
526
541
|
ScenarioList([Scenario({'name': 'Alice'}), Scenario({'name': 'Bob'})])
|
527
542
|
"""
|
528
|
-
|
543
|
+
if not func:
|
544
|
+
func = lambda x: x
|
545
|
+
return cls([Scenario({name: func(value)}) for value in values])
|
529
546
|
|
530
547
|
def to_dataset(self) -> "Dataset":
|
531
548
|
"""
|
@@ -1,15 +1,161 @@
|
|
1
1
|
import fitz # PyMuPDF
|
2
2
|
import os
|
3
|
+
import copy
|
3
4
|
import subprocess
|
5
|
+
import requests
|
6
|
+
import tempfile
|
7
|
+
import os
|
8
|
+
|
9
|
+
# import urllib.parse as urlparse
|
10
|
+
from urllib.parse import urlparse
|
4
11
|
|
5
12
|
# from edsl import Scenario
|
6
13
|
|
14
|
+
import requests
|
15
|
+
import re
|
16
|
+
import tempfile
|
17
|
+
import os
|
18
|
+
import atexit
|
19
|
+
from urllib.parse import urlparse, parse_qs
|
20
|
+
|
21
|
+
|
22
|
+
class GoogleDriveDownloader:
    """Download a Google Drive file into a temporary directory shared by the class.

    The directory is created lazily on the first download and removed by an
    ``atexit`` hook. Every download is written to ``<temp dir>/<filename>``,
    so a later download using the same filename overwrites the earlier one.
    """

    # Shared tempfile.TemporaryDirectory; None until the first download.
    _temp_dir = None
    # Path of the most recently completed download; None until one succeeds.
    _temp_file_path = None

    @classmethod
    def fetch_from_drive(cls, url, filename=None):
        """Download the file behind a Google Drive URL and return its local path.

        :param url: A Drive URL containing ``/d/<id>`` or an ``id=`` query parameter.
        :param filename: Name for the saved file; defaults to ``"downloaded_file"``.
        :return: Path of the saved file inside the shared temporary directory.
        :raises ValueError: If no file ID can be extracted from ``url``.
        """
        file_id = cls._extract_file_id(url)
        if not file_id:
            raise ValueError("Invalid Google Drive URL")

        download_url = f"https://drive.google.com/uc?export=download&id={file_id}"
        if not filename:
            filename = "downloaded_file"

        # The session is a context manager so its connections are released.
        with requests.Session() as session:
            response = session.get(download_url, stream=True)
            response.raise_for_status()

            # A "download_warning*" cookie means the request must be repeated
            # with its value passed as the "confirm" parameter.
            for key, value in response.cookies.items():
                if key.startswith("download_warning"):
                    response.close()
                    params = {"id": file_id, "confirm": value}
                    response = session.get(download_url, params=params, stream=True)
                    # Without this check an HTTP error page would be saved as the file.
                    response.raise_for_status()
                    break

            if cls._temp_dir is None:
                cls._temp_dir = tempfile.TemporaryDirectory()
                atexit.register(cls._cleanup)

            target_path = os.path.join(cls._temp_dir.name, filename)

            try:
                with open(target_path, "wb") as f:
                    for chunk in response.iter_content(32768):
                        if chunk:
                            f.write(chunk)
            finally:
                response.close()

        # Publish the path only after the file has been fully written.
        cls._temp_file_path = target_path
        print(f"File saved to: {cls._temp_file_path}")

        return cls._temp_file_path

    @staticmethod
    def _extract_file_id(url):
        """Return the Drive file ID found in ``url``, or None if there is none."""
        # '/d/<id>' form, e.g. '/file/d/<id>/view'
        file_id_match = re.search(r"/d/([a-zA-Z0-9-_]+)", url)
        if file_id_match:
            return file_id_match.group(1)

        # '?id=<id>' form, e.g. 'open?id=<id>'
        query_params = parse_qs(urlparse(url).query)
        if "id" in query_params:
            return query_params["id"][0]

        return None

    @classmethod
    def _cleanup(cls):
        """Delete the shared temporary directory, if any, and reset the class state."""
        if cls._temp_dir:
            cls._temp_dir.cleanup()
        # Reset so a later download creates a fresh directory instead of
        # writing into one that has already been removed.
        cls._temp_dir = None
        cls._temp_file_path = None

    @classmethod
    def get_temp_file_path(cls):
        """Return the path of the most recent download, or None if there is none."""
        return cls._temp_file_path
|
91
|
+
|
92
|
+
|
93
|
+
def fetch_and_save_pdf(url, filename):
    """Download ``url`` and save the response body under ``filename`` in a temp directory.

    The directory is created with ``tempfile.mkdtemp`` and removed at
    interpreter exit. A ``TemporaryDirectory`` context manager is not used
    because it would delete the file before the returned path could be read.

    :param url: URL to fetch with an HTTP GET.
    :param filename: File name to use inside the temporary directory.
    :return: Path of the saved file, which exists until the interpreter exits.
    :raises requests.HTTPError: If the response status indicates an error.
    """
    import shutil  # local import: only needed for the exit-time cleanup below

    response = requests.get(url)
    response.raise_for_status()

    temp_dir = tempfile.mkdtemp()
    # The file must outlive this call, so the directory is removed at exit.
    atexit.register(shutil.rmtree, temp_dir, ignore_errors=True)

    temp_file_path = os.path.join(temp_dir, filename)
    with open(temp_file_path, "wb") as file:
        file.write(response.content)

    print(f"PDF saved to: {temp_file_path}")

    return temp_file_path
|
115
|
+
|
116
|
+
|
117
|
+
# Example usage:
|
118
|
+
# url = "https://example.com/sample.pdf"
|
119
|
+
# fetch_and_save_pdf(url, "sample.pdf")
|
120
|
+
|
7
121
|
|
8
122
|
class ScenarioListPdfMixin:
|
9
123
|
@classmethod
def from_pdf(cls, filename_or_url, collapse_pages=False):
    """Create scenarios from the text of a PDF given by local path or URL.

    URLs are downloaded to a temporary file first: Google Drive links go
    through ``GoogleDriveDownloader``, all other URLs through
    ``fetch_and_save_pdf``.

    :param filename_or_url: Local file path, or a URL pointing at a PDF.
    :param collapse_pages: If False, return ``cls`` built from every extracted
        scenario. If True, return a single scenario: a shallow copy of the
        first one whose ``"text"`` is the concatenation of all texts.
    :raises IndexError: If ``collapse_pages`` is True and nothing was extracted.
    """
    pdf_path = filename_or_url
    if cls.is_url(filename_or_url):
        if "drive.google.com" in filename_or_url:
            pdf_path = GoogleDriveDownloader.fetch_from_drive(
                filename_or_url, "temp_pdf.pdf"
            )
        else:
            pdf_path = fetch_and_save_pdf(filename_or_url, "temp_pdf.pdf")

    scenarios = list(cls.extract_text_from_pdf(pdf_path))

    if not collapse_pages:
        return cls(scenarios)

    if not scenarios:
        # Same exception type as indexing the empty list, with a clearer message.
        raise IndexError("Cannot collapse pages: no pages were extracted from the PDF")

    # Shallow-copy the first scenario so its other fields are kept,
    # then replace its text with the text of all scenarios joined together.
    base_scenario = copy.copy(scenarios[0])
    base_scenario["text"] = "".join(scenario["text"] for scenario in scenarios)
    return base_scenario
|
151
|
+
|
152
|
+
@staticmethod
def is_url(string):
    """Return True if ``string`` parses as a URL with both a scheme and a netloc.

    Values that are neither ``str`` nor ``bytes`` (for example
    ``pathlib.Path``) are reported as not being URLs instead of raising
    inside ``urlparse``.

    :param string: Candidate URL or local path.
    :return: True for inputs such as ``"https://host/file.pdf"``; False otherwise.
    """
    if not isinstance(string, (str, bytes)):
        return False
    try:
        result = urlparse(string)
    except ValueError:
        return False
    return bool(result.scheme and result.netloc)
|
13
159
|
|
14
160
|
@classmethod
|
15
161
|
def _from_pdf_to_image(cls, pdf_path, image_format="jpeg"):
|
@@ -38,7 +184,7 @@ class ScenarioListPdfMixin:
|
|
38
184
|
scenario = Scenario._from_filepath_image(image_path)
|
39
185
|
scenarios.append(scenario)
|
40
186
|
|
41
|
-
print(f"Saved {len(images)} pages as images in {output_folder}")
|
187
|
+
# print(f"Saved {len(images)} pages as images in {output_folder}")
|
42
188
|
return cls(scenarios)
|
43
189
|
|
44
190
|
@staticmethod
|
edsl/study/Study.py
CHANGED
@@ -469,6 +469,38 @@ class Study:
|
|
469
469
|
coop = Coop()
|
470
470
|
return coop.create(self, description=self.description)
|
471
471
|
|
472
|
+
def delete_object(self, identifier: Union[str, UUID]):
    """
    Delete an EDSL object from the study.

    :param identifier: Either the variable name or the hash of the object to
        delete (as a string), or a UUID object whose string form is the hash.
    :raises ValueError: If the object is not found in the study
    :raises TypeError: If the identifier is neither a string nor a UUID
    """
    if isinstance(identifier, str):
        # A string matches either an entry's variable name or its hash key;
        # the first match in iteration order is the one deleted.
        label, shown = "identifier", identifier
        object_hash = next(
            (
                key
                for key, obj_entry in self.objects.items()
                if obj_entry.variable_name == identifier or key == identifier
            ),
            None,
        )
    elif isinstance(identifier, UUID):
        # A UUID object is only ever compared against the hash keys.
        label, shown = "hash", str(identifier)
        object_hash = shown if shown in self.objects else None
    else:
        raise TypeError(
            "Identifier must be either a string (variable name or hash) or a UUID object"
        )

    if object_hash is None:
        raise ValueError(f"No object found with {label}: {shown}")

    del self.objects[object_hash]
    self._create_mapping_dicts()  # Update internal mappings
    if self.verbose:
        print(f"Deleted object with {label}: {shown}")
|
503
|
+
|
472
504
|
@classmethod
|
473
505
|
def pull(cls, uuid: Optional[Union[str, UUID]] = None, url: Optional[str] = None):
|
474
506
|
"""Pull the object from coop."""
|
edsl/surveys/DAG.py
CHANGED
@@ -11,6 +11,7 @@ class DAG(UserDict):
|
|
11
11
|
"""Initialize the DAG class."""
|
12
12
|
super().__init__(data)
|
13
13
|
self.reverse_mapping = self._create_reverse_mapping()
|
14
|
+
self.validate_no_cycles()
|
14
15
|
|
15
16
|
def _create_reverse_mapping(self):
|
16
17
|
"""
|
@@ -73,12 +74,73 @@ class DAG(UserDict):
|
|
73
74
|
# else:
|
74
75
|
# return DAG(d)
|
75
76
|
|
77
|
+
def remove_node(self, node: int) -> None:
    """Remove a node and all its connections from the DAG.

    Connection containers may be sets or lists (the class example uses
    lists), so each is rebuilt with its own type instead of calling the
    set-only ``discard``.

    :param node: The node to remove; removing an absent node is not an error.
    """
    self.pop(node, None)
    # Snapshot the items because values are reassigned inside the loop.
    for key, connections in list(self.items()):
        if node in connections:
            self[key] = type(connections)(c for c in connections if c != node)
    # Adjust remaining nodes if necessary
    self._adjust_nodes_after_removal(node)
    # The reverse mapping was computed from the old contents in __init__.
    self.reverse_mapping = self._create_reverse_mapping()
|
84
|
+
|
85
|
+
def _adjust_nodes_after_removal(self, removed_node: int) -> None:
|
86
|
+
"""Adjust node indices after a node is removed."""
|
87
|
+
new_dag = {}
|
88
|
+
for node, connections in self.items():
|
89
|
+
new_node = node if node < removed_node else node - 1
|
90
|
+
new_connections = {c if c < removed_node else c - 1 for c in connections}
|
91
|
+
new_dag[new_node] = new_connections
|
92
|
+
self.clear()
|
93
|
+
self.update(new_dag)
|
94
|
+
|
76
95
|
@classmethod
|
77
96
|
def example(cls):
|
78
97
|
"""Return an example of the `DAG`."""
|
79
98
|
data = {"a": ["b", "c"], "b": ["d"], "c": [], "d": []}
|
80
99
|
return cls(data)
|
81
100
|
|
101
|
+
def detect_cycles(self):
    """
    Search the DAG depth-first and collect every cycle encountered.

    Each cycle is reported as the slice of the current path starting at the
    revisited node, with that node repeated at the end.

    :return: A list of cycles if any are found, otherwise an empty list.
    """
    seen = set()
    trail = []
    found = []

    def walk(current):
        if current in trail:
            # Reaching a node already on the current path closes a cycle.
            start = trail.index(current)
            found.append(trail[start:] + [current])
        elif current not in seen:
            seen.add(current)
            trail.append(current)
            for neighbour in self.get(current, []):
                walk(neighbour)
            trail.pop()

    for start_node in self:
        if start_node not in seen:
            walk(start_node)

    return found
|
133
|
+
|
134
|
+
def validate_no_cycles(self):
    """
    Ensure that ``detect_cycles()`` finds nothing.

    :raises ValueError: If cycles are detected in the DAG.
    """
    detected = self.detect_cycles()
    if not detected:
        return
    raise ValueError(f"Cycles detected in the DAG: {detected}")
|
143
|
+
|
82
144
|
|
83
145
|
if __name__ == "__main__":
|
84
146
|
import doctest
|
edsl/surveys/MemoryPlan.py
CHANGED
@@ -211,6 +211,32 @@ class MemoryPlan(UserDict):
|
|
211
211
|
mp.add_single_memory("q1", "q0")
|
212
212
|
return mp
|
213
213
|
|
214
|
+
def remove_question(self, question_name: str) -> None:
    """Remove a question from the memory plan.

    The question is dropped from ``survey_question_names`` and
    ``question_texts`` (same position in both), from the plan's own keys,
    and from the prior questions of every remaining entry.

    :param question_name: The name of the question to remove.
    """
    self._check_valid_question_name(question_name)

    # The two lists are treated as parallel: one position is popped from both.
    position = self.survey_question_names.index(question_name)
    self.survey_question_names.pop(position)
    self.question_texts.pop(position)

    # A focal question has its own entry in the plan; absent keys are ignored.
    self.pop(question_name, None)

    # NOTE(review): assumes every stored value exposes remove_prior_question - confirm.
    for memory in self.values():
        memory.remove_prior_question(question_name)

    # NOTE(review): if self.dag is rebuilt on each access, this removal acts on
    # a temporary object only - confirm against the dag definition.
    self.dag.remove_node(position)
|
235
|
+
|
236
|
+
def remove_prior_question(self, question_name: str) -> None:
    """Drop every occurrence of ``question_name`` from ``self.prior_questions``.

    ``self.prior_questions`` is rebound to a new list rather than edited in place.

    :param question_name: The prior question to remove.
    """
    remaining = []
    for prior in self.prior_questions:
        if prior != question_name:
            remaining.append(prior)
    self.prior_questions = remaining
|
239
|
+
|
214
240
|
|
215
241
|
if __name__ == "__main__":
|
216
242
|
import doctest
|
edsl/surveys/Rule.py
CHANGED
@@ -18,6 +18,7 @@ with a low (-1) priority.
|
|
18
18
|
"""
|
19
19
|
|
20
20
|
import ast
|
21
|
+
import random
|
21
22
|
from typing import Any, Union, List
|
22
23
|
|
23
24
|
from jinja2 import Template
|
@@ -37,9 +38,29 @@ from edsl.utilities.ast_utilities import extract_variable_names
|
|
37
38
|
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
38
39
|
|
39
40
|
|
41
|
+
class QuestionIndex:
    """Data descriptor that validates a question-index attribute.

    The value is stored on the instance under the attribute name prefixed
    with an underscore. Assigned values must be an ``int`` or an instance of
    ``EndOfSurvey``'s class. When the managed attribute is ``next_q``, an
    integer value must also be strictly greater than the instance's
    ``_current_q``.
    """

    def __set_name__(self, owner, name):
        # Private storage name on the instance, e.g. "next_q" -> "_next_q".
        self.name = f"_{name}"

    def __get__(self, obj, objtype=None):
        if obj is None:
            # Class-level access (e.g. ``Rule.next_q``) has no instance to
            # read from; return the descriptor itself, as is conventional.
            return self
        return getattr(obj, self.name)

    def __set__(self, obj, value):
        if not isinstance(value, (int, EndOfSurvey.__class__)):
            raise ValueError(f"{self.name} must be an integer or EndOfSurvey")
        if self.name == "_next_q" and isinstance(value, int):
            # Reads the stored current index, so current_q must be set first.
            current_q = getattr(obj, "_current_q")
            if value <= current_q:
                raise ValueError("next_q must be greater than current_q")
        setattr(obj, self.name, value)
|
56
|
+
|
57
|
+
|
40
58
|
class Rule:
|
41
59
|
"""The Rule class defines a "rule" for determining the next question presented to an agent."""
|
42
60
|
|
61
|
+
current_q = QuestionIndex()
|
62
|
+
next_q = QuestionIndex()
|
63
|
+
|
43
64
|
# Not implemented but nice to have:
|
44
65
|
# We could potentially use the question pydantic models to check for rule conflicts, as
|
45
66
|
# they define the potential trees through a survey.
|
@@ -74,6 +95,10 @@ class Rule:
|
|
74
95
|
self.priority = priority
|
75
96
|
self.before_rule = before_rule
|
76
97
|
|
98
|
+
if not self.next_q == EndOfSurvey:
|
99
|
+
if self.next_q <= self.current_q:
|
100
|
+
raise SurveyRuleSendsYouBackwardsError
|
101
|
+
|
77
102
|
if not self.next_q == EndOfSurvey and self.current_q > self.next_q:
|
78
103
|
raise SurveyRuleSendsYouBackwardsError
|
79
104
|
|
@@ -254,8 +279,16 @@ class Rule:
|
|
254
279
|
msg = f"""Exception in evaluation: {e}. The expression is: {self.expression}. The current info env trying to substitute in is: {current_info_env}. After the substition, the expression was: {to_evaluate}."""
|
255
280
|
raise SurveyRuleCannotEvaluateError(msg)
|
256
281
|
|
282
|
+
random_functions = {
|
283
|
+
"randint": random.randint,
|
284
|
+
"choice": random.choice,
|
285
|
+
"random": random.random,
|
286
|
+
"uniform": random.uniform,
|
287
|
+
# Add any other random functions you want to allow
|
288
|
+
}
|
289
|
+
|
257
290
|
try:
|
258
|
-
return EvalWithCompoundTypes().eval(to_evaluate)
|
291
|
+
return EvalWithCompoundTypes(functions=random_functions).eval(to_evaluate)
|
259
292
|
except Exception as e:
|
260
293
|
msg = f"""Exception in evaluation: {e}. The expression is: {self.expression}. The current info env trying to substitute in is: {current_info_env}. After the substition, the expression was: {to_evaluate}."""
|
261
294
|
raise SurveyRuleCannotEvaluateError(msg)
|
edsl/surveys/RuleCollection.py
CHANGED
@@ -120,13 +120,13 @@ class RuleCollection(UserList):
|
|
120
120
|
:param answers: The answers to the survey questions.
|
121
121
|
|
122
122
|
>>> rule_collection = RuleCollection()
|
123
|
-
>>> r = Rule(current_q=1, expression="True", next_q=
|
123
|
+
>>> r = Rule(current_q=1, expression="True", next_q=2, priority=1, question_name_to_index={}, before_rule = True)
|
124
124
|
>>> rule_collection.add_rule(r)
|
125
125
|
>>> rule_collection.skip_question_before_running(1, {})
|
126
126
|
True
|
127
127
|
|
128
128
|
>>> rule_collection = RuleCollection()
|
129
|
-
>>> r = Rule(current_q=1, expression="False", next_q=
|
129
|
+
>>> r = Rule(current_q=1, expression="False", next_q=2, priority=1, question_name_to_index={}, before_rule = True)
|
130
130
|
>>> rule_collection.add_rule(r)
|
131
131
|
>>> rule_collection.skip_question_before_running(1, {})
|
132
132
|
False
|
@@ -172,7 +172,8 @@ class RuleCollection(UserList):
|
|
172
172
|
|
173
173
|
def next_question(self, q_now: int, answers: dict[str, Any]) -> NextQuestion:
|
174
174
|
"""Find the next question by index, given the rule collection.
|
175
|
-
|
175
|
+
|
176
|
+
This rule is applied after the question is answered.
|
176
177
|
|
177
178
|
:param q_now: The current question index.
|
178
179
|
:param answers: The answers to the survey questions so far, including the current question.
|
@@ -182,8 +183,17 @@ class RuleCollection(UserList):
|
|
182
183
|
NextQuestion(next_q=3, num_rules_found=2, expressions_evaluating_to_true=1, priority=1)
|
183
184
|
|
184
185
|
"""
|
185
|
-
#
|
186
|
-
|
186
|
+
# # is this the first question? If it is, we need to check if it should be skipped.
|
187
|
+
# if q_now == 0:
|
188
|
+
# if self.skip_question_before_running(q_now, answers):
|
189
|
+
# return NextQuestion(
|
190
|
+
# next_q=q_now + 1,
|
191
|
+
# num_rules_found=0,
|
192
|
+
# expressions_evaluating_to_true=0,
|
193
|
+
# priority=-1,
|
194
|
+
# )
|
195
|
+
|
196
|
+
# breakpoint()
|
187
197
|
expressions_evaluating_to_true = 0
|
188
198
|
next_q = None
|
189
199
|
highest_priority = -2 # start with -2 to 'pick up' the default rule added
|
@@ -205,6 +215,12 @@ class RuleCollection(UserList):
|
|
205
215
|
f"No rules found for question {q_now}"
|
206
216
|
)
|
207
217
|
|
218
|
+
# breakpoint()
|
219
|
+
## Now we need to check if the *next question* has any 'before' rules that we should follow
|
220
|
+
for rule in self.applicable_rules(next_q, before_rule=True):
|
221
|
+
if rule.evaluate(answers): # rule evaluates to True
|
222
|
+
return self.next_question(next_q, answers)
|
223
|
+
|
208
224
|
return NextQuestion(
|
209
225
|
next_q, num_rules_found, expressions_evaluating_to_true, highest_priority
|
210
226
|
)
|
@@ -305,6 +321,40 @@ class RuleCollection(UserList):
|
|
305
321
|
|
306
322
|
return DAG(dict(sorted(children_to_parents.items())))
|
307
323
|
|
324
|
+
def detect_cycles(self):
    """
    Run a depth-first search over ``self.dag`` and collect any cycles.

    NOTE(review): DAG.__init__ calls validate_no_cycles(); if ``self.dag``
    builds a DAG on access, a cyclic rule set may raise ValueError there
    before this search runs - confirm.

    :return: A list of cycles if any are found, otherwise an empty list.
    """
    graph = self.dag
    explored = set()
    stack = []
    cycles_found = []

    def explore(vertex):
        # A vertex already on the active stack closes a cycle.
        if vertex in stack:
            cycles_found.append(stack[stack.index(vertex) :] + [vertex])
            return
        if vertex in explored:
            return

        explored.add(vertex)
        stack.append(vertex)
        for successor in graph.get(vertex, []):
            explore(successor)
        stack.pop()

    for vertex in graph:
        if vertex in explored:
            continue
        explore(vertex)

    return cycles_found
|
357
|
+
|
308
358
|
@classmethod
|
309
359
|
def example(cls):
|
310
360
|
"""Create an example RuleCollection object."""
|