edsl 0.1.33.dev2__py3-none-any.whl → 0.1.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. edsl/Base.py +24 -14
  2. edsl/__init__.py +1 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +6 -6
  5. edsl/agents/Invigilator.py +28 -6
  6. edsl/agents/InvigilatorBase.py +8 -27
  7. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +150 -182
  8. edsl/agents/prompt_helpers.py +129 -0
  9. edsl/config.py +26 -34
  10. edsl/coop/coop.py +14 -4
  11. edsl/data_transfer_models.py +26 -73
  12. edsl/enums.py +2 -0
  13. edsl/inference_services/AnthropicService.py +5 -2
  14. edsl/inference_services/AwsBedrock.py +5 -2
  15. edsl/inference_services/AzureAI.py +5 -2
  16. edsl/inference_services/GoogleService.py +108 -33
  17. edsl/inference_services/InferenceServiceABC.py +44 -13
  18. edsl/inference_services/MistralAIService.py +5 -2
  19. edsl/inference_services/OpenAIService.py +10 -6
  20. edsl/inference_services/TestService.py +34 -16
  21. edsl/inference_services/TogetherAIService.py +170 -0
  22. edsl/inference_services/registry.py +2 -0
  23. edsl/jobs/Jobs.py +109 -18
  24. edsl/jobs/buckets/BucketCollection.py +24 -15
  25. edsl/jobs/buckets/TokenBucket.py +64 -10
  26. edsl/jobs/interviews/Interview.py +130 -49
  27. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
  28. edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
  29. edsl/jobs/runners/JobsRunnerAsyncio.py +119 -173
  30. edsl/jobs/runners/JobsRunnerStatus.py +332 -0
  31. edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
  32. edsl/jobs/tasks/TaskHistory.py +17 -0
  33. edsl/language_models/LanguageModel.py +36 -38
  34. edsl/language_models/registry.py +13 -9
  35. edsl/language_models/utilities.py +5 -2
  36. edsl/questions/QuestionBase.py +74 -16
  37. edsl/questions/QuestionBaseGenMixin.py +28 -0
  38. edsl/questions/QuestionBudget.py +93 -41
  39. edsl/questions/QuestionCheckBox.py +1 -1
  40. edsl/questions/QuestionFreeText.py +6 -0
  41. edsl/questions/QuestionMultipleChoice.py +13 -24
  42. edsl/questions/QuestionNumerical.py +5 -4
  43. edsl/questions/Quick.py +41 -0
  44. edsl/questions/ResponseValidatorABC.py +11 -6
  45. edsl/questions/derived/QuestionLinearScale.py +4 -1
  46. edsl/questions/derived/QuestionTopK.py +4 -1
  47. edsl/questions/derived/QuestionYesNo.py +8 -2
  48. edsl/questions/descriptors.py +12 -11
  49. edsl/questions/templates/budget/__init__.py +0 -0
  50. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  51. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  52. edsl/questions/templates/extract/__init__.py +0 -0
  53. edsl/questions/templates/numerical/answering_instructions.jinja +0 -1
  54. edsl/questions/templates/rank/__init__.py +0 -0
  55. edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
  56. edsl/results/DatasetExportMixin.py +5 -1
  57. edsl/results/Result.py +1 -1
  58. edsl/results/Results.py +4 -1
  59. edsl/scenarios/FileStore.py +178 -34
  60. edsl/scenarios/Scenario.py +76 -37
  61. edsl/scenarios/ScenarioList.py +19 -2
  62. edsl/scenarios/ScenarioListPdfMixin.py +150 -4
  63. edsl/study/Study.py +32 -0
  64. edsl/surveys/DAG.py +62 -0
  65. edsl/surveys/MemoryPlan.py +26 -0
  66. edsl/surveys/Rule.py +34 -1
  67. edsl/surveys/RuleCollection.py +55 -5
  68. edsl/surveys/Survey.py +189 -10
  69. edsl/surveys/base.py +4 -0
  70. edsl/templates/error_reporting/interview_details.html +6 -1
  71. edsl/utilities/utilities.py +9 -1
  72. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/METADATA +3 -1
  73. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/RECORD +75 -69
  74. edsl/jobs/interviews/retry_management.py +0 -39
  75. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
  76. edsl/scenarios/ScenarioImageMixin.py +0 -100
  77. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/LICENSE +0 -0
  78. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/WHEEL +0 -0
@@ -39,6 +39,15 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
39
39
  super().__init__([])
40
40
  self.codebook = codebook or {}
41
41
 
42
@property
def has_jinja_braces(self) -> bool:
    """Return True if any scenario in the list contains Jinja braces.

    Delegates to each scenario's own ``has_jinja_braces`` attribute and stops
    at the first scenario that reports a truthy value.
    """
    # A generator (not a list) lets any() short-circuit on the first hit
    # instead of evaluating every scenario up front.
    return any(scenario.has_jinja_braces for scenario in self)
46
+
47
def convert_jinja_braces(self) -> ScenarioList:
    """Return a new ScenarioList built from each scenario's converted copy.

    Calls ``convert_jinja_braces()`` on every scenario in order and wraps the
    results in a fresh ``ScenarioList``.
    """
    converted = []
    for scenario in self:
        converted.append(scenario.convert_jinja_braces())
    return ScenarioList(converted)
50
+
42
51
  def give_valid_names(self) -> ScenarioList:
43
52
  """Give valid names to the scenario keys.
44
53
 
@@ -273,6 +282,10 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
273
282
  for s in data["scenarios"]:
274
283
  _ = s.pop("edsl_version")
275
284
  _ = s.pop("edsl_class_name")
285
+ for scenario in data["scenarios"]:
286
+ for key, value in scenario.items():
287
+ if hasattr(value, "to_dict"):
288
+ data[key] = value.to_dict()
276
289
  return data_to_html(data)
277
290
 
278
291
  def tally(self, field) -> dict:
@@ -517,7 +530,9 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
517
530
  return ScenarioList([scenario.drop(fields) for scenario in self.data])
518
531
 
519
532
  @classmethod
520
- def from_list(cls, name, values) -> ScenarioList:
533
+ def from_list(
534
+ cls, name: str, values: list, func: Optional[Callable] = None
535
+ ) -> ScenarioList:
521
536
  """Create a ScenarioList from a list of values.
522
537
 
523
538
  Example:
@@ -525,7 +540,9 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
525
540
  >>> ScenarioList.from_list('name', ['Alice', 'Bob'])
526
541
  ScenarioList([Scenario({'name': 'Alice'}), Scenario({'name': 'Bob'})])
527
542
  """
528
- return cls([Scenario({name: value}) for value in values])
543
+ if not func:
544
+ func = lambda x: x
545
+ return cls([Scenario({name: func(value)}) for value in values])
529
546
 
530
547
  def to_dataset(self) -> "Dataset":
531
548
  """
@@ -1,15 +1,161 @@
1
1
  import fitz # PyMuPDF
2
2
  import os
3
+ import copy
3
4
  import subprocess
5
+ import requests
6
+ import tempfile
7
+ import os
8
+
9
+ # import urllib.parse as urlparse
10
+ from urllib.parse import urlparse
4
11
 
5
12
  # from edsl import Scenario
6
13
 
14
+ import requests
15
+ import re
16
+ import tempfile
17
+ import os
18
+ import atexit
19
+ from urllib.parse import urlparse, parse_qs
20
+
21
+
22
class GoogleDriveDownloader:
    """Download Google Drive files into a shared, process-lifetime temp directory.

    The directory is created lazily on the first download and removed at
    interpreter exit through an ``atexit`` hook. The path of the most recent
    download is remembered in ``_temp_file_path``.
    """

    _temp_dir = None  # tempfile.TemporaryDirectory shared by all downloads
    _temp_file_path = None  # path of the most recently downloaded file

    @classmethod
    def fetch_from_drive(cls, url, filename=None, timeout=60):
        """Download the Drive file referenced by ``url`` and return its local path.

        :param url: A Google Drive URL in ``.../d/<id>/...`` or ``...?id=<id>`` form.
        :param filename: Name for the saved file (only its final path component
            is used); defaults to ``"downloaded_file"``.
        :param timeout: Seconds to wait for the server on each request.
        :return: Path of the saved file inside the shared temporary directory.
        :raises ValueError: If no file ID can be extracted from ``url``.
        :raises requests.HTTPError: If either download request fails.
        """
        file_id = cls._extract_file_id(url)
        if not file_id:
            raise ValueError("Invalid Google Drive URL")

        download_url = f"https://drive.google.com/uc?export=download&id={file_id}"

        # The session (and its pooled connections) is closed on exit.
        with requests.Session() as session:
            response = session.get(download_url, stream=True, timeout=timeout)
            response.raise_for_status()

            # Large files come back with a download prompt; the value of the
            # "download_warning*" cookie is sent back as the confirm token.
            for key, value in response.cookies.items():
                if key.startswith("download_warning"):
                    response.close()
                    params = {"id": file_id, "confirm": value}
                    response = session.get(
                        download_url, params=params, stream=True, timeout=timeout
                    )
                    # The confirmation request can fail too; never save an error page.
                    response.raise_for_status()
                    break

            # Keep only the final path component so the file always lands
            # inside the temporary directory.
            filename = os.path.basename(filename or "") or "downloaded_file"

            if cls._temp_dir is None:
                cls._temp_dir = tempfile.TemporaryDirectory()
                atexit.register(cls._cleanup)

            cls._temp_file_path = os.path.join(cls._temp_dir.name, filename)

            # Closing the response releases the streamed connection.
            with response, open(cls._temp_file_path, "wb") as f:
                for chunk in response.iter_content(32768):
                    if chunk:  # skip empty chunks
                        f.write(chunk)

        print(f"File saved to: {cls._temp_file_path}")

        return cls._temp_file_path

    @staticmethod
    def _extract_file_id(url):
        """Return the Drive file ID embedded in ``url``, or None if there is none.

        Two URL shapes are recognised: a ``/d/<id>`` path segment (e.g.
        ``/file/d/<id>/view``) and an ``id=<id>`` query parameter
        (e.g. ``open?id=<id>``).
        """
        # Try to extract file ID from '/file/d/' format
        file_id_match = re.search(r"/d/([a-zA-Z0-9-_]+)", url)
        if file_id_match:
            return file_id_match.group(1)

        # If not found, try to extract from 'open?id=' format
        query_params = parse_qs(urlparse(url).query)
        if "id" in query_params:
            return query_params["id"][0]

        return None

    @classmethod
    def _cleanup(cls):
        """Delete the shared temporary directory, if one was created."""
        if cls._temp_dir:
            cls._temp_dir.cleanup()
            # Forget the removed directory so a later download creates a new one
            # and get_temp_file_path() no longer reports a deleted file.
            cls._temp_dir = None
            cls._temp_file_path = None

    @classmethod
    def get_temp_file_path(cls):
        """Return the path of the most recent download, or None if there is none."""
        return cls._temp_file_path
92
+
93
def fetch_and_save_pdf(url, filename, timeout=60):
    """Download ``url`` and save it as ``filename`` in a new temporary directory.

    The directory is kept until interpreter exit (an ``atexit`` hook removes
    it), so the returned path is still valid when the caller reads the file.

    :param url: URL of the document to download.
    :param filename: Name for the saved file; only its final path component is used.
    :param timeout: Seconds to wait for the server.
    :return: Path of the saved file.
    :raises requests.HTTPError: If the request returns an error status.
    """
    import shutil

    response = requests.get(url, timeout=timeout)

    # Check if the request was successful
    response.raise_for_status()

    # Do NOT use ``with TemporaryDirectory()`` here: the directory (and the
    # file) would be deleted before the path is returned to the caller.
    temp_dir = tempfile.mkdtemp()
    atexit.register(shutil.rmtree, temp_dir, ignore_errors=True)

    # Keep only the final path component so the file stays inside temp_dir.
    safe_name = os.path.basename(filename) or "downloaded.pdf"
    temp_file_path = os.path.join(temp_dir, safe_name)

    with open(temp_file_path, "wb") as file:
        file.write(response.content)

    print(f"PDF saved to: {temp_file_path}")

    return temp_file_path
115
+
116
+
117
+ # Example usage:
118
+ # url = "https://example.com/sample.pdf"
119
+ # fetch_and_save_pdf(url, "sample.pdf")
120
+
7
121
 
8
122
  class ScenarioListPdfMixin:
9
123
  @classmethod
10
def from_pdf(cls, filename_or_url, collapse_pages=False):
    """Create scenarios from the text of a PDF given by local path or URL.

    Google Drive links are downloaded with ``GoogleDriveDownloader``, any other
    URL with ``fetch_and_save_pdf``, and anything else is treated as a local
    path. Text is extracted with ``cls.extract_text_from_pdf``.

    :param filename_or_url: Local file path or URL of the PDF.
    :param collapse_pages: If False (default), return ``cls`` built from every
        scenario yielded by ``extract_text_from_pdf`` (presumably one per page).
        If True, return a single scenario: a shallow copy of the first one whose
        ``"text"`` is the concatenation of every scenario's ``"text"``.
    :raises ValueError: If ``collapse_pages`` is True and nothing was extracted.
    """
    if cls.is_url(filename_or_url):
        if "drive.google.com" in filename_or_url:
            pdf_path = GoogleDriveDownloader.fetch_from_drive(
                filename_or_url, "temp_pdf.pdf"
            )
        else:
            pdf_path = fetch_and_save_pdf(filename_or_url, "temp_pdf.pdf")
    else:
        # Not a URL: assume it is a local file path.
        pdf_path = filename_or_url

    scenarios = list(cls.extract_text_from_pdf(pdf_path))
    if not collapse_pages:
        return cls(scenarios)

    if not scenarios:
        # Previously this surfaced as an opaque IndexError from scenarios[0].
        raise ValueError(f"No text could be extracted from {filename_or_url!r}")

    # Keep the first page's other fields; only "text" is replaced.
    base_scenario = copy.copy(scenarios[0])
    base_scenario["text"] = "".join(scenario["text"] for scenario in scenarios)
    return base_scenario
151
+
152
@staticmethod
def is_url(string):
    """Return True if ``string`` parses as a URL with both a scheme and a host part.

    Path objects (``os.PathLike``, e.g. ``pathlib.Path``) are local paths, not
    URLs. They are rejected up front because ``urlparse`` does not accept them
    and would raise instead of returning False.
    """
    if isinstance(string, os.PathLike):
        return False
    try:
        result = urlparse(string)
    except ValueError:
        # e.g. a malformed IPv6 host such as "http://["
        return False
    return all([result.scheme, result.netloc])
13
159
 
14
160
  @classmethod
15
161
  def _from_pdf_to_image(cls, pdf_path, image_format="jpeg"):
@@ -38,7 +184,7 @@ class ScenarioListPdfMixin:
38
184
  scenario = Scenario._from_filepath_image(image_path)
39
185
  scenarios.append(scenario)
40
186
 
41
- print(f"Saved {len(images)} pages as images in {output_folder}")
187
+ # print(f"Saved {len(images)} pages as images in {output_folder}")
42
188
  return cls(scenarios)
43
189
 
44
190
  @staticmethod
edsl/study/Study.py CHANGED
@@ -469,6 +469,38 @@ class Study:
469
469
  coop = Coop()
470
470
  return coop.create(self, description=self.description)
471
471
 
472
def delete_object(self, identifier: Union[str, UUID]):
    """
    Delete an EDSL object from the study.

    :param identifier: Either the variable name or the hash of the object to delete.
        A string is compared with each entry's ``variable_name`` and with its
        hash key in ``self.objects`` (the first match is deleted). A ``UUID`` is
        converted with ``str()`` and looked up as a hash key only.
    :raises ValueError: If the object is not found in the study
    :raises TypeError: If ``identifier`` is neither a ``str`` nor a ``UUID``.
    """
    if isinstance(identifier, str):
        # ``obj_hash`` rather than ``hash`` so the builtin is not shadowed.
        for obj_hash, obj_entry in list(self.objects.items()):
            if obj_entry.variable_name == identifier or obj_hash == identifier:
                del self.objects[obj_hash]
                self._create_mapping_dicts()  # Update internal mappings
                if self.verbose:
                    print(f"Deleted object with identifier: {identifier}")
                return
        raise ValueError(f"No object found with identifier: {identifier}")

    if isinstance(identifier, UUID):
        hash_str = str(identifier)
        if hash_str not in self.objects:
            raise ValueError(f"No object found with hash: {hash_str}")
        del self.objects[hash_str]
        self._create_mapping_dicts()  # Update internal mappings
        if self.verbose:
            print(f"Deleted object with hash: {hash_str}")
        return

    raise TypeError(
        "Identifier must be either a string (variable name or hash) or a UUID object"
    )
503
+
472
504
  @classmethod
473
505
  def pull(cls, uuid: Optional[Union[str, UUID]] = None, url: Optional[str] = None):
474
506
  """Pull the object from coop."""
edsl/surveys/DAG.py CHANGED
@@ -11,6 +11,7 @@ class DAG(UserDict):
11
11
  """Initialize the DAG class."""
12
12
  super().__init__(data)
13
13
  self.reverse_mapping = self._create_reverse_mapping()
14
+ self.validate_no_cycles()
14
15
 
15
16
  def _create_reverse_mapping(self):
16
17
  """
@@ -73,12 +74,73 @@ class DAG(UserDict):
73
74
  # else:
74
75
  # return DAG(d)
75
76
 
77
def remove_node(self, node: int) -> None:
    """Remove ``node`` and every connection to it, then renumber the rest.

    Node indices above ``node`` are shifted down by one through
    ``_adjust_nodes_after_removal``. ``reverse_mapping`` is rebuilt afterwards
    so it matches the updated graph.
    """
    self.pop(node, None)
    # Rebuild each connection collection instead of calling ``.discard`` so
    # list-valued connections work as well as set-valued ones. The renumbering
    # step stores sets anyway.
    for key, connections in list(self.items()):
        self[key] = {c for c in connections if c != node}
    # Adjust remaining nodes if necessary
    self._adjust_nodes_after_removal(node)
    # reverse_mapping is computed once in __init__; refresh it so it is not stale.
    self.reverse_mapping = self._create_reverse_mapping()
84
+
85
+ def _adjust_nodes_after_removal(self, removed_node: int) -> None:
86
+ """Adjust node indices after a node is removed."""
87
+ new_dag = {}
88
+ for node, connections in self.items():
89
+ new_node = node if node < removed_node else node - 1
90
+ new_connections = {c if c < removed_node else c - 1 for c in connections}
91
+ new_dag[new_node] = new_connections
92
+ self.clear()
93
+ self.update(new_dag)
94
+
76
95
  @classmethod
77
96
  def example(cls):
78
97
  """Return an example of the `DAG`."""
79
98
  data = {"a": ["b", "c"], "b": ["d"], "c": [], "d": []}
80
99
  return cls(data)
81
100
 
101
def detect_cycles(self):
    """
    Detect cycles in the DAG using depth-first search.

    The search is iterative, using an explicit stack of child iterators, so
    long chains cannot exceed Python's recursion limit. Membership in the
    current path is tracked in a set for O(1) checks. Roots are taken in the
    mapping's iteration order, and children in the order ``self.get`` yields them.

    :return: A list of cycles if any are found, otherwise an empty list. Each
        cycle is a list of nodes that starts and ends with the same node.
    """
    visited = set()
    path = []
    on_path = set()
    cycles = []
    exhausted = object()  # sentinel from next() once a node has no more children

    for root in self:
        if root in visited:
            continue
        visited.add(root)
        path.append(root)
        on_path.add(root)
        stack = [iter(self.get(root, []))]

        while stack:
            child = next(stack[-1], exhausted)
            if child is exhausted:
                # All children of the deepest node are done: backtrack.
                stack.pop()
                on_path.discard(path.pop())
                continue
            if child in on_path:
                # Back edge: the path from ``child`` to here closes a cycle.
                cycles.append(path[path.index(child) :] + [child])
                continue
            if child in visited:
                continue
            visited.add(child)
            path.append(child)
            on_path.add(child)
            stack.append(iter(self.get(child, [])))

    return cycles
133
+
134
def validate_no_cycles(self):
    """
    Ensure the graph is acyclic.

    :raises ValueError: If ``detect_cycles`` reports any cycle; the message
        lists the reported cycles.
    """
    found = self.detect_cycles()
    if not found:
        return
    raise ValueError(f"Cycles detected in the DAG: {found}")
143
+
82
144
 
83
145
  if __name__ == "__main__":
84
146
  import doctest
@@ -211,6 +211,32 @@ class MemoryPlan(UserDict):
211
211
  mp.add_single_memory("q1", "q0")
212
212
  return mp
213
213
 
214
def remove_question(self, question_name: str) -> None:
    """Remove a question from the memory plan.

    The method:

    * drops the question from ``survey_question_names`` and ``question_texts``;
    * deletes its own memory entry, if it is a focal question;
    * asks every remaining memory to drop it as a prior question;
    * removes its node from ``self.dag``.

    :param question_name: The name of the question to remove.
    """
    self._check_valid_question_name(question_name)

    # Position of the question in the survey; also used as its DAG node index.
    index = self.survey_question_names.index(question_name)
    self.survey_question_names.pop(index)
    self.question_texts.pop(index)

    # Remove the question from the memory plan if it's a focal question
    self.pop(question_name, None)

    # Only the memories are needed here; the focal-question keys are unused.
    for memory in self.values():
        memory.remove_prior_question(question_name)

    # NOTE(review): if ``dag`` is a property that rebuilds the graph from the
    # current plan, this mutates a temporary object. That rebuilt graph may
    # already reflect the removal and shifted indices — confirm against
    # MemoryPlan.dag.
    self.dag.remove_node(index)
235
+
236
def remove_prior_question(self, question_name: str) -> None:
    """Remove a prior question from the memory.

    Rebinds ``self.prior_questions`` to a new list without any occurrence of
    ``question_name``. The remaining entries keep their order.
    """
    # NOTE(review): this is defined on MemoryPlan, which has no visible
    # ``prior_questions`` attribute. ``remove_question`` calls
    # ``remove_prior_question`` on each memory value, so this method may
    # belong on the Memory class — confirm.
    kept = []
    for prior in self.prior_questions:
        if prior != question_name:
            kept.append(prior)
    self.prior_questions = kept
239
+
214
240
 
215
241
  if __name__ == "__main__":
216
242
  import doctest
edsl/surveys/Rule.py CHANGED
@@ -18,6 +18,7 @@ with a low (-1) priority.
18
18
  """
19
19
 
20
20
  import ast
21
+ import random
21
22
  from typing import Any, Union, List
22
23
 
23
24
  from jinja2 import Template
@@ -37,9 +38,29 @@ from edsl.utilities.ast_utilities import extract_variable_names
37
38
  from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
38
39
 
39
40
 
41
class QuestionIndex:
    """Data descriptor that validates question-index attributes on ``Rule``.

    Values are stored on the instance under ``_<attribute name>``. An assigned
    value must be an ``int`` or an instance of ``EndOfSurvey``'s class. For the
    ``next_q`` attribute, an integer must also be greater than the instance's
    ``current_q`` (read from ``_current_q``).
    """

    def __set_name__(self, owner, name):
        # Backing attribute name, e.g. "current_q" -> "_current_q".
        self.name = f"_{name}"

    def __get__(self, obj, objtype=None):
        # Accessed on the class itself (e.g. ``Rule.current_q``): return the
        # descriptor, per the descriptor protocol, instead of failing on
        # getattr(None, ...).
        if obj is None:
            return self
        return getattr(obj, self.name)

    def __set__(self, obj, value):
        if not isinstance(value, (int, EndOfSurvey.__class__)):
            raise ValueError(f"{self.name} must be an integer or EndOfSurvey")
        if self.name == "_next_q" and isinstance(value, int):
            # Assumes current_q was assigned first; otherwise this raises AttributeError.
            current_q = getattr(obj, "_current_q")
            if value <= current_q:
                raise ValueError("next_q must be greater than current_q")
        setattr(obj, self.name, value)
56
+
57
+
40
58
  class Rule:
41
59
  """The Rule class defines a "rule" for determining the next question presented to an agent."""
42
60
 
61
+ current_q = QuestionIndex()
62
+ next_q = QuestionIndex()
63
+
43
64
  # Not implemented but nice to have:
44
65
  # We could potentially use the question pydantic models to check for rule conflicts, as
45
66
  # they define the potential trees through a survey.
@@ -74,6 +95,10 @@ class Rule:
74
95
  self.priority = priority
75
96
  self.before_rule = before_rule
76
97
 
98
+ if not self.next_q == EndOfSurvey:
99
+ if self.next_q <= self.current_q:
100
+ raise SurveyRuleSendsYouBackwardsError
101
+
77
102
  if not self.next_q == EndOfSurvey and self.current_q > self.next_q:
78
103
  raise SurveyRuleSendsYouBackwardsError
79
104
 
@@ -254,8 +279,16 @@ class Rule:
254
279
  msg = f"""Exception in evaluation: {e}. The expression is: {self.expression}. The current info env trying to substitute in is: {current_info_env}. After the substition, the expression was: {to_evaluate}."""
255
280
  raise SurveyRuleCannotEvaluateError(msg)
256
281
 
282
+ random_functions = {
283
+ "randint": random.randint,
284
+ "choice": random.choice,
285
+ "random": random.random,
286
+ "uniform": random.uniform,
287
+ # Add any other random functions you want to allow
288
+ }
289
+
257
290
  try:
258
- return EvalWithCompoundTypes().eval(to_evaluate)
291
+ return EvalWithCompoundTypes(functions=random_functions).eval(to_evaluate)
259
292
  except Exception as e:
260
293
  msg = f"""Exception in evaluation: {e}. The expression is: {self.expression}. The current info env trying to substitute in is: {current_info_env}. After the substition, the expression was: {to_evaluate}."""
261
294
  raise SurveyRuleCannotEvaluateError(msg)
@@ -120,13 +120,13 @@ class RuleCollection(UserList):
120
120
  :param answers: The answers to the survey questions.
121
121
 
122
122
  >>> rule_collection = RuleCollection()
123
- >>> r = Rule(current_q=1, expression="True", next_q=1, priority=1, question_name_to_index={}, before_rule = True)
123
+ >>> r = Rule(current_q=1, expression="True", next_q=2, priority=1, question_name_to_index={}, before_rule = True)
124
124
  >>> rule_collection.add_rule(r)
125
125
  >>> rule_collection.skip_question_before_running(1, {})
126
126
  True
127
127
 
128
128
  >>> rule_collection = RuleCollection()
129
- >>> r = Rule(current_q=1, expression="False", next_q=1, priority=1, question_name_to_index={}, before_rule = True)
129
+ >>> r = Rule(current_q=1, expression="False", next_q=2, priority=1, question_name_to_index={}, before_rule = True)
130
130
  >>> rule_collection.add_rule(r)
131
131
  >>> rule_collection.skip_question_before_running(1, {})
132
132
  False
@@ -172,7 +172,8 @@ class RuleCollection(UserList):
172
172
 
173
173
  def next_question(self, q_now: int, answers: dict[str, Any]) -> NextQuestion:
174
174
  """Find the next question by index, given the rule collection.
175
- This rule is applied after the question is asked.
175
+
176
+ This rule is applied after the question is answered.
176
177
 
177
178
  :param q_now: The current question index.
178
179
  :param answers: The answers to the survey questions so far, including the current question.
@@ -182,8 +183,17 @@ class RuleCollection(UserList):
182
183
  NextQuestion(next_q=3, num_rules_found=2, expressions_evaluating_to_true=1, priority=1)
183
184
 
184
185
  """
185
- # What rules apply at the current node?
186
-
186
+ # # is this the first question? If it is, we need to check if it should be skipped.
187
+ # if q_now == 0:
188
+ # if self.skip_question_before_running(q_now, answers):
189
+ # return NextQuestion(
190
+ # next_q=q_now + 1,
191
+ # num_rules_found=0,
192
+ # expressions_evaluating_to_true=0,
193
+ # priority=-1,
194
+ # )
195
+
196
+ # breakpoint()
187
197
  expressions_evaluating_to_true = 0
188
198
  next_q = None
189
199
  highest_priority = -2 # start with -2 to 'pick up' the default rule added
@@ -205,6 +215,12 @@ class RuleCollection(UserList):
205
215
  f"No rules found for question {q_now}"
206
216
  )
207
217
 
218
+ # breakpoint()
219
+ ## Now we need to check if the *next question* has any 'before; rules that we should follow
220
+ for rule in self.applicable_rules(next_q, before_rule=True):
221
+ if rule.evaluate(answers): # rule evaluates to True
222
+ return self.next_question(next_q, answers)
223
+
208
224
  return NextQuestion(
209
225
  next_q, num_rules_found, expressions_evaluating_to_true, highest_priority
210
226
  )
@@ -305,6 +321,40 @@ class RuleCollection(UserList):
305
321
 
306
322
  return DAG(dict(sorted(children_to_parents.items())))
307
323
 
324
def detect_cycles(self):
    """
    Detect cycles in the survey rules using depth-first search.

    The graph in ``self.dag`` is walked iteratively, using an explicit stack of
    child iterators, so long question chains cannot exceed Python's recursion
    limit. The current path is tracked in a set for O(1) membership checks.

    NOTE(review): ``DAG.__init__`` in this release calls ``validate_no_cycles``.
    If ``self.dag`` builds a fresh ``DAG`` (the method ending in
    ``return DAG(...)`` suggests it does), a cyclic rule graph raises
    ``ValueError`` while it is being built, before this search runs. Confirm
    whether that is intended.

    :return: A list of cycles if any are found, otherwise an empty list. Each
        cycle is a list of nodes that starts and ends with the same node.
    """
    dag = self.dag
    visited = set()
    path = []
    on_path = set()
    cycles = []
    exhausted = object()  # sentinel from next() once a node has no more children

    for root in dag:
        if root in visited:
            continue
        visited.add(root)
        path.append(root)
        on_path.add(root)
        stack = [iter(dag.get(root, []))]

        while stack:
            child = next(stack[-1], exhausted)
            if child is exhausted:
                # All children of the deepest node are done: backtrack.
                stack.pop()
                on_path.discard(path.pop())
                continue
            if child in on_path:
                # Back edge: the path from ``child`` to here closes a cycle.
                cycles.append(path[path.index(child) :] + [child])
                continue
            if child in visited:
                continue
            visited.add(child)
            path.append(child)
            on_path.add(child)
            stack.append(iter(dag.get(child, [])))

    return cycles
357
+
308
358
  @classmethod
309
359
  def example(cls):
310
360
  """Create an example RuleCollection object."""