edsl 0.1.33.dev3__py3-none-any.whl → 0.1.34.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. edsl/Base.py +15 -11
  2. edsl/__version__.py +1 -1
  3. edsl/agents/Invigilator.py +22 -3
  4. edsl/agents/PromptConstructor.py +79 -183
  5. edsl/agents/prompt_helpers.py +129 -0
  6. edsl/coop/coop.py +3 -2
  7. edsl/data_transfer_models.py +0 -1
  8. edsl/inference_services/AnthropicService.py +5 -2
  9. edsl/inference_services/AwsBedrock.py +5 -2
  10. edsl/inference_services/AzureAI.py +5 -2
  11. edsl/inference_services/GoogleService.py +108 -33
  12. edsl/inference_services/MistralAIService.py +5 -2
  13. edsl/inference_services/OpenAIService.py +3 -2
  14. edsl/inference_services/TestService.py +11 -2
  15. edsl/inference_services/TogetherAIService.py +1 -1
  16. edsl/jobs/interviews/Interview.py +19 -9
  17. edsl/jobs/runners/JobsRunnerAsyncio.py +37 -16
  18. edsl/jobs/runners/JobsRunnerStatus.py +4 -3
  19. edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
  20. edsl/language_models/LanguageModel.py +12 -9
  21. edsl/language_models/utilities.py +3 -2
  22. edsl/questions/QuestionBase.py +11 -2
  23. edsl/questions/QuestionBaseGenMixin.py +28 -0
  24. edsl/questions/QuestionCheckBox.py +1 -1
  25. edsl/questions/QuestionMultipleChoice.py +5 -1
  26. edsl/questions/ResponseValidatorABC.py +5 -1
  27. edsl/questions/descriptors.py +12 -11
  28. edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
  29. edsl/scenarios/FileStore.py +159 -71
  30. edsl/scenarios/Scenario.py +23 -49
  31. edsl/scenarios/ScenarioList.py +6 -2
  32. edsl/surveys/DAG.py +62 -0
  33. edsl/surveys/MemoryPlan.py +26 -0
  34. edsl/surveys/Rule.py +24 -0
  35. edsl/surveys/RuleCollection.py +36 -2
  36. edsl/surveys/Survey.py +182 -10
  37. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dev1.dist-info}/METADATA +2 -1
  38. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dev1.dist-info}/RECORD +40 -40
  39. edsl/scenarios/ScenarioImageMixin.py +0 -100
  40. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dev1.dist-info}/LICENSE +0 -0
  41. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dev1.dist-info}/WHEEL +0 -0
@@ -303,7 +303,7 @@ class QuestionOptionsDescriptor(BaseDescriptor):
303
303
  return None
304
304
  else:
305
305
  raise QuestionCreationValidationError(
306
- f"Dynamic question options must have jina2 braces - instead received: {value}."
306
+ f"Dynamic question options must have jinja2 braces - instead received: {value}."
307
307
  )
308
308
  if not isinstance(value, list):
309
309
  raise QuestionCreationValidationError(
@@ -325,14 +325,15 @@ class QuestionOptionsDescriptor(BaseDescriptor):
325
325
  )
326
326
  if not self.linear_scale:
327
327
  if not self.q_budget:
328
- if not (
329
- value
330
- and all(type(x) == type(value[0]) for x in value)
331
- and isinstance(value[0], (str, list, int, float))
332
- ):
333
- raise QuestionCreationValidationError(
334
- f"Question options must be all same type (got {value}).)"
335
- )
328
+ pass
329
+ # if not (
330
+ # value
331
+ # and all(type(x) == type(value[0]) for x in value)
332
+ # and isinstance(value[0], (str, list, int, float))
333
+ # ):
334
+ # raise QuestionCreationValidationError(
335
+ # f"Question options must be all same type (got {value}).)"
336
+ # )
336
337
  else:
337
338
  if not all(isinstance(x, (str)) for x in value):
338
339
  raise QuestionCreationValidationError(
@@ -390,8 +391,8 @@ class QuestionTextDescriptor(BaseDescriptor):
390
391
 
391
392
  def validate(self, value, instance):
392
393
  """Validate the value is a string."""
393
- if len(value) > Settings.MAX_QUESTION_LENGTH:
394
- raise Exception("Question is too long!")
394
+ # if len(value) > Settings.MAX_QUESTION_LENGTH:
395
+ # raise Exception("Question is too long!")
395
396
  if len(value) < 1:
396
397
  raise Exception("Question is too short!")
397
398
  if not isinstance(value, str):
@@ -1,6 +1,6 @@
1
1
  {# Answering Instructions #}
2
- Please reponse with just your answer.
2
+ Please respond with just your answer.
3
3
 
4
4
  {% if include_comment %}
5
- After the answer, you can put a comment explaining your reponse.
5
+ After the answer, you can put a comment explaining your response.
6
6
  {% endif %}
@@ -1,41 +1,101 @@
1
- from edsl import Scenario
2
1
  import base64
3
2
  import io
4
3
  import tempfile
5
- from typing import Optional
4
+ import mimetypes
5
+ import os
6
+ from typing import Dict, Any, IO, Optional
7
+ import requests
8
+ from urllib.parse import urlparse
9
+
10
+ import google.generativeai as genai
11
+
12
+ from edsl import Scenario
13
+ from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
14
+ from edsl.utilities.utilities import is_notebook
15
+
16
+
17
+ def view_pdf(pdf_path):
18
+ import os
19
+ import subprocess
20
+
21
+ if is_notebook():
22
+ from IPython.display import IFrame
23
+ from IPython.display import display, HTML
24
+
25
+ # Replace 'path/to/your/file.pdf' with the actual path to your PDF file
26
+ IFrame(pdf_path, width=700, height=600)
27
+ display(HTML(f'<a href="{pdf_path}" target="_blank">Open PDF</a>'))
28
+ return
29
+
30
+ if os.path.exists(pdf_path):
31
+ try:
32
+ if (os_name := os.name) == "posix":
33
+ # for cool kids
34
+ subprocess.run(["open", pdf_path], check=True) # macOS
35
+ elif os_name == "nt":
36
+ os.startfile(pdf_path) # Windows
37
+ else:
38
+ subprocess.run(["xdg-open", pdf_path], check=True) # Linux
39
+ except Exception as e:
40
+ print(f"Error opening PDF: {e}")
41
+ else:
42
+ print("PDF file was not created successfully.")
6
43
 
7
44
 
8
45
  class FileStore(Scenario):
9
46
  def __init__(
10
47
  self,
11
- filename: str,
48
+ path: Optional[str] = None,
49
+ mime_type: Optional[str] = None,
12
50
  binary: Optional[bool] = None,
13
51
  suffix: Optional[str] = None,
14
52
  base64_string: Optional[str] = None,
53
+ external_locations: Optional[Dict[str, str]] = None,
54
+ **kwargs,
15
55
  ):
16
- self.filename = filename
17
- self.suffix = suffix or "." + filename.split(".")[-1]
56
+ if path is None and "filename" in kwargs:
57
+ path = kwargs["filename"]
58
+ self.path = path
59
+ self.suffix = suffix or path.split(".")[-1]
18
60
  self.binary = binary or False
19
- self.base64_string = base64_string or self.encode_file_to_base64_string(
20
- filename
61
+ self.mime_type = (
62
+ mime_type or mimetypes.guess_type(path)[0] or "application/octet-stream"
21
63
  )
64
+ self.base64_string = base64_string or self.encode_file_to_base64_string(path)
65
+ self.external_locations = external_locations or {}
22
66
  super().__init__(
23
67
  {
24
- "filename": self.filename,
68
+ "path": self.path,
25
69
  "base64_string": self.base64_string,
26
70
  "binary": self.binary,
27
71
  "suffix": self.suffix,
72
+ "mime_type": self.mime_type,
73
+ "external_locations": self.external_locations,
28
74
  }
29
75
  )
30
76
 
77
+ def __str__(self):
78
+ return "FileStore: self.path"
79
+
80
+ @property
81
+ def size(self) -> int:
82
+ return os.path.getsize(self.path)
83
+
84
+ def upload_google(self, refresh: bool = False) -> None:
85
+ genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
86
+ google_info = genai.upload_file(self.path, mime_type=self.mime_type)
87
+ self.external_locations["google"] = google_info.to_dict()
88
+
31
89
  @classmethod
90
+ @remove_edsl_version
32
91
  def from_dict(cls, d):
33
- return cls(d["filename"], d["binary"], d["suffix"], d["base64_string"])
92
+ # return cls(d["filename"], d["binary"], d["suffix"], d["base64_string"])
93
+ return cls(**d)
34
94
 
35
95
  def __repr__(self):
36
- return f"FileStore(filename='{self.filename}', binary='{self.binary}', 'suffix'={self.suffix})"
96
+ return f"FileStore({self.path})"
37
97
 
38
- def encode_file_to_base64_string(self, file_path):
98
+ def encode_file_to_base64_string(self, file_path: str):
39
99
  try:
40
100
  # Attempt to open the file in text mode
41
101
  with open(file_path, "r") as text_file:
@@ -56,14 +116,14 @@ class FileStore(Scenario):
56
116
 
57
117
  return base64_string
58
118
 
59
- def open(self):
119
+ def open(self) -> "IO":
60
120
  if self.binary:
61
121
  return self.base64_to_file(self["base64_string"], is_binary=True)
62
122
  else:
63
123
  return self.base64_to_text_file(self["base64_string"])
64
124
 
65
125
  @staticmethod
66
- def base64_to_text_file(base64_string):
126
+ def base64_to_text_file(base64_string) -> "IO":
67
127
  # Decode the base64 string to bytes
68
128
  text_data_bytes = base64.b64decode(base64_string)
69
129
 
@@ -101,7 +161,9 @@ class FileStore(Scenario):
101
161
 
102
162
  # Create a named temporary file
103
163
  mode = "wb" if self.binary else "w"
104
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix, mode=mode)
164
+ temp_file = tempfile.NamedTemporaryFile(
165
+ delete=False, suffix="." + suffix, mode=mode
166
+ )
105
167
 
106
168
  if self.binary:
107
169
  temp_file.write(file_like_object.read())
@@ -112,30 +174,95 @@ class FileStore(Scenario):
112
174
 
113
175
  return temp_file.name
114
176
 
115
- def push(self, description=None):
177
+ def view(self, max_size: int = 300) -> None:
178
+ if self.suffix == "pdf":
179
+ view_pdf(self.path)
180
+
181
+ if self.suffix == "png" or self.suffix == "jpg" or self.suffix == "jpeg":
182
+ if is_notebook():
183
+ from IPython.display import Image
184
+ from PIL import Image as PILImage
185
+
186
+ if max_size:
187
+ # Open the image using Pillow
188
+ with PILImage.open(self.path) as img:
189
+ # Get original width and height
190
+ original_width, original_height = img.size
191
+
192
+ # Calculate the scaling factor
193
+ scale = min(
194
+ max_size / original_width, max_size / original_height
195
+ )
196
+
197
+ # Calculate new dimensions
198
+ new_width = int(original_width * scale)
199
+ new_height = int(original_height * scale)
200
+
201
+ return Image(self.path, width=new_width, height=new_height)
202
+ else:
203
+ return Image(self.path)
204
+
205
+ def push(
206
+ self, description: Optional[str] = None, visibility: str = "unlisted"
207
+ ) -> dict:
208
+ """
209
+ Push the object to Coop.
210
+ :param description: The description of the object to push.
211
+ :param visibility: The visibility of the object to push.
212
+ """
116
213
  scenario_version = Scenario.from_dict(self.to_dict())
117
214
  if description is None:
118
- description = "File: " + self["filename"]
119
- info = scenario_version.push(description=description)
215
+ description = "File: " + self.path
216
+ info = scenario_version.push(description=description, visibility=visibility)
120
217
  return info
121
218
 
122
219
  @classmethod
123
- def pull(cls, uuid, expected_parrot_url: Optional[str] = None):
220
+ def pull(cls, uuid: str, expected_parrot_url: Optional[str] = None) -> "FileStore":
221
+ """
222
+ :param uuid: The UUID of the object to pull.
223
+ :param expected_parrot_url: The URL of the Parrot server to use.
224
+ :return: The object pulled from the Parrot server.
225
+ """
124
226
  scenario_version = Scenario.pull(uuid, expected_parrot_url=expected_parrot_url)
125
227
  return cls.from_dict(scenario_version.to_dict())
126
228
 
229
+ @classmethod
230
+ def from_url(
231
+ cls,
232
+ url: str,
233
+ download_path: Optional[str] = None,
234
+ mime_type: Optional[str] = None,
235
+ ) -> "FileStore":
236
+ """
237
+ :param url: The URL of the file to download.
238
+ :param download_path: The path to save the downloaded file.
239
+ :param mime_type: The MIME type of the file. If None, it will be guessed from the file extension.
240
+ """
241
+
242
+ response = requests.get(url, stream=True)
243
+ response.raise_for_status() # Raises an HTTPError for bad responses
244
+
245
+ # Get the filename from the URL if download_path is not provided
246
+ if download_path is None:
247
+ filename = os.path.basename(urlparse(url).path)
248
+ if not filename:
249
+ filename = "downloaded_file"
250
+ # download_path = filename
251
+ download_path = os.path.join(os.getcwd(), filename)
252
+
253
+ # Ensure the directory exists
254
+ os.makedirs(os.path.dirname(download_path), exist_ok=True)
255
+
256
+ # Write the file
257
+ with open(download_path, "wb") as file:
258
+ for chunk in response.iter_content(chunk_size=8192):
259
+ file.write(chunk)
260
+
261
+ # Create and return a new File instance
262
+ return cls(download_path, mime_type=mime_type)
263
+
127
264
 
128
265
  class CSVFileStore(FileStore):
129
- def __init__(
130
- self,
131
- filename,
132
- binary: Optional[bool] = None,
133
- suffix: Optional[str] = None,
134
- base64_string: Optional[str] = None,
135
- ):
136
- super().__init__(
137
- filename, binary=binary, base64_string=base64_string, suffix=".csv"
138
- )
139
266
 
140
267
  @classmethod
141
268
  def example(cls):
@@ -155,16 +282,6 @@ class CSVFileStore(FileStore):
155
282
 
156
283
 
157
284
  class PDFFileStore(FileStore):
158
- def __init__(
159
- self,
160
- filename,
161
- binary: Optional[bool] = None,
162
- suffix: Optional[str] = None,
163
- base64_string: Optional[str] = None,
164
- ):
165
- super().__init__(
166
- filename, binary=binary, base64_string=base64_string, suffix=".pdf"
167
- )
168
285
 
169
286
  def view(self):
170
287
  pdf_path = self.to_tempfile()
@@ -241,16 +358,6 @@ class PDFFileStore(FileStore):
241
358
 
242
359
 
243
360
  class PNGFileStore(FileStore):
244
- def __init__(
245
- self,
246
- filename,
247
- binary: Optional[bool] = None,
248
- suffix: Optional[str] = None,
249
- base64_string: Optional[str] = None,
250
- ):
251
- super().__init__(
252
- filename, binary=binary, base64_string=base64_string, suffix=".png"
253
- )
254
361
 
255
362
  @classmethod
256
363
  def example(cls):
@@ -275,16 +382,6 @@ class PNGFileStore(FileStore):
275
382
 
276
383
 
277
384
  class SQLiteFileStore(FileStore):
278
- def __init__(
279
- self,
280
- filename,
281
- binary: Optional[bool] = None,
282
- suffix: Optional[str] = None,
283
- base64_string: Optional[str] = None,
284
- ):
285
- super().__init__(
286
- filename, binary=binary, base64_string=base64_string, suffix=".sqlite"
287
- )
288
385
 
289
386
  @classmethod
290
387
  def example(cls):
@@ -308,16 +405,6 @@ class SQLiteFileStore(FileStore):
308
405
 
309
406
 
310
407
  class HTMLFileStore(FileStore):
311
- def __init__(
312
- self,
313
- filename,
314
- binary: Optional[bool] = None,
315
- suffix: Optional[str] = None,
316
- base64_string: Optional[str] = None,
317
- ):
318
- super().__init__(
319
- filename, binary=binary, base64_string=base64_string, suffix=".html"
320
- )
321
408
 
322
409
  @classmethod
323
410
  def example(cls):
@@ -350,9 +437,10 @@ if __name__ == "__main__":
350
437
  # fs = PDFFileStore("paper.pdf")
351
438
  # fs.view()
352
439
  # from edsl import Conjure
353
-
354
- fs = PNGFileStore("robot.png")
355
- fs.view()
440
+ pass
441
+ # fs = PNGFileStore("logo.png")
442
+ # fs.view()
443
+ # fs.upload_google()
356
444
 
357
445
  # c = Conjure(datafile_name=fs.to_tempfile())
358
446
  # f = PDFFileStore("paper.pdf")
@@ -2,25 +2,18 @@
2
2
 
3
3
  from __future__ import annotations
4
4
  import copy
5
- import base64
6
5
  import hashlib
7
6
  import os
8
- import reprlib
9
- import imghdr
10
-
11
-
12
7
  from collections import UserDict
13
8
  from typing import Union, List, Optional, Generator
14
9
  from uuid import uuid4
10
+
15
11
  from edsl.Base import Base
16
- from edsl.scenarios.ScenarioImageMixin import ScenarioImageMixin
17
12
  from edsl.scenarios.ScenarioHtmlMixin import ScenarioHtmlMixin
18
13
  from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
19
14
 
20
- from edsl.data_transfer_models import ImageInfo
21
-
22
15
 
23
- class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
16
+ class Scenario(Base, UserDict, ScenarioHtmlMixin):
24
17
  """A Scenario is a dictionary of keys/values.
25
18
 
26
19
  They can be used parameterize edsl questions."""
@@ -48,12 +41,12 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
48
41
 
49
42
  return ScenarioList([copy.deepcopy(self) for _ in range(n)])
50
43
 
51
- @property
52
- def has_image(self) -> bool:
53
- """Return whether the scenario has an image."""
54
- if not hasattr(self, "_has_image"):
55
- self._has_image = False
56
- return self._has_image
44
+ # @property
45
+ # def has_image(self) -> bool:
46
+ # """Return whether the scenario has an image."""
47
+ # if not hasattr(self, "_has_image"):
48
+ # self._has_image = False
49
+ # return self._has_image
57
50
 
58
51
  @property
59
52
  def has_jinja_braces(self) -> bool:
@@ -63,9 +56,10 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
63
56
  >>> s.has_jinja_braces
64
57
  True
65
58
  """
66
- for key, value in self.items():
67
- if "{{" in str(value) and "}}" in value:
68
- return True
59
+ for _, value in self.items():
60
+ if isinstance(value, str):
61
+ if "{{" in value and "}}" in value:
62
+ return True
69
63
  return False
70
64
 
71
65
  def convert_jinja_braces(
@@ -88,10 +82,6 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
88
82
  new_scenario[key] = value
89
83
  return new_scenario
90
84
 
91
- @has_image.setter
92
- def has_image(self, value):
93
- self._has_image = value
94
-
95
85
  def __add__(self, other_scenario: "Scenario") -> "Scenario":
96
86
  """Combine two scenarios by taking the union of their keys
97
87
 
@@ -114,8 +104,6 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
114
104
  data1 = copy.deepcopy(self.data)
115
105
  data2 = copy.deepcopy(other_scenario.data)
116
106
  s = Scenario(data1 | data2)
117
- if self.has_image or other_scenario.has_image:
118
- s._has_image = True
119
107
  return s
120
108
 
121
109
  def rename(self, replacement_dict: dict) -> "Scenario":
@@ -235,6 +223,14 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
235
223
  text = requests.get(url).text
236
224
  return cls({"url": url, field_name: text})
237
225
 
226
+ @classmethod
227
+ def from_file(cls, file_path: str, field_name: str) -> "Scenario":
228
+ """Creates a scenario from a file."""
229
+ from edsl.scenarios.FileStore import FileStore
230
+
231
+ fs = FileStore(file_path)
232
+ return cls({field_name: fs})
233
+
238
234
  @classmethod
239
235
  def from_image(
240
236
  cls, image_path: str, image_name: Optional[str] = None
@@ -248,36 +244,14 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
248
244
  Returns:
249
245
  Scenario: A new Scenario instance with image information.
250
246
 
251
- Example:
252
- >>> s = Scenario.from_image(Scenario.example_image())
253
- >>> s
254
- Scenario({'logo': ...})
255
247
  """
256
248
  if not os.path.exists(image_path):
257
249
  raise FileNotFoundError(f"Image file not found: {image_path}")
258
250
 
259
- with open(image_path, "rb") as image_file:
260
- file_content = image_file.read()
261
-
262
- file_name = os.path.basename(image_path)
263
- file_size = os.path.getsize(image_path)
264
- image_format = imghdr.what(image_path) or "unknown"
265
-
266
251
  if image_name is None:
267
- image_name = file_name.split(".")[0]
268
-
269
- image_info = ImageInfo(
270
- file_path=image_path,
271
- file_name=file_name,
272
- image_format=image_format,
273
- file_size=file_size,
274
- encoded_image=base64.b64encode(file_content).decode("utf-8"),
275
- )
276
-
277
- scenario_data = {image_name: image_info}
278
- s = cls(scenario_data)
279
- s.has_image = True
280
- return s
252
+ image_name = os.path.basename(image_path).split(".")[0]
253
+
254
+ return cls.from_file(image_path, image_name)
281
255
 
282
256
  @classmethod
283
257
  def from_pdf(cls, pdf_path):
@@ -530,7 +530,9 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
530
530
  return ScenarioList([scenario.drop(fields) for scenario in self.data])
531
531
 
532
532
  @classmethod
533
- def from_list(cls, name, values) -> ScenarioList:
533
+ def from_list(
534
+ cls, name: str, values: list, func: Optional[Callable] = None
535
+ ) -> ScenarioList:
534
536
  """Create a ScenarioList from a list of values.
535
537
 
536
538
  Example:
@@ -538,7 +540,9 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
538
540
  >>> ScenarioList.from_list('name', ['Alice', 'Bob'])
539
541
  ScenarioList([Scenario({'name': 'Alice'}), Scenario({'name': 'Bob'})])
540
542
  """
541
- return cls([Scenario({name: value}) for value in values])
543
+ if not func:
544
+ func = lambda x: x
545
+ return cls([Scenario({name: func(value)}) for value in values])
542
546
 
543
547
  def to_dataset(self) -> "Dataset":
544
548
  """
edsl/surveys/DAG.py CHANGED
@@ -11,6 +11,7 @@ class DAG(UserDict):
11
11
  """Initialize the DAG class."""
12
12
  super().__init__(data)
13
13
  self.reverse_mapping = self._create_reverse_mapping()
14
+ self.validate_no_cycles()
14
15
 
15
16
  def _create_reverse_mapping(self):
16
17
  """
@@ -73,12 +74,73 @@ class DAG(UserDict):
73
74
  # else:
74
75
  # return DAG(d)
75
76
 
77
+ def remove_node(self, node: int) -> None:
78
+ """Remove a node and all its connections from the DAG."""
79
+ self.pop(node, None)
80
+ for connections in self.values():
81
+ connections.discard(node)
82
+ # Adjust remaining nodes if necessary
83
+ self._adjust_nodes_after_removal(node)
84
+
85
+ def _adjust_nodes_after_removal(self, removed_node: int) -> None:
86
+ """Adjust node indices after a node is removed."""
87
+ new_dag = {}
88
+ for node, connections in self.items():
89
+ new_node = node if node < removed_node else node - 1
90
+ new_connections = {c if c < removed_node else c - 1 for c in connections}
91
+ new_dag[new_node] = new_connections
92
+ self.clear()
93
+ self.update(new_dag)
94
+
76
95
  @classmethod
77
96
  def example(cls):
78
97
  """Return an example of the `DAG`."""
79
98
  data = {"a": ["b", "c"], "b": ["d"], "c": [], "d": []}
80
99
  return cls(data)
81
100
 
101
+ def detect_cycles(self):
102
+ """
103
+ Detect cycles in the DAG using depth-first search.
104
+
105
+ :return: A list of cycles if any are found, otherwise an empty list.
106
+ """
107
+ visited = set()
108
+ path = []
109
+ cycles = []
110
+
111
+ def dfs(node):
112
+ if node in path:
113
+ cycle = path[path.index(node) :]
114
+ cycles.append(cycle + [node])
115
+ return
116
+
117
+ if node in visited:
118
+ return
119
+
120
+ visited.add(node)
121
+ path.append(node)
122
+
123
+ for child in self.get(node, []):
124
+ dfs(child)
125
+
126
+ path.pop()
127
+
128
+ for node in self:
129
+ if node not in visited:
130
+ dfs(node)
131
+
132
+ return cycles
133
+
134
+ def validate_no_cycles(self):
135
+ """
136
+ Validate that the DAG does not contain any cycles.
137
+
138
+ :raises ValueError: If cycles are detected in the DAG.
139
+ """
140
+ cycles = self.detect_cycles()
141
+ if cycles:
142
+ raise ValueError(f"Cycles detected in the DAG: {cycles}")
143
+
82
144
 
83
145
  if __name__ == "__main__":
84
146
  import doctest
@@ -211,6 +211,32 @@ class MemoryPlan(UserDict):
211
211
  mp.add_single_memory("q1", "q0")
212
212
  return mp
213
213
 
214
+ def remove_question(self, question_name: str) -> None:
215
+ """Remove a question from the memory plan.
216
+
217
+ :param question_name: The name of the question to remove.
218
+ """
219
+ self._check_valid_question_name(question_name)
220
+
221
+ # Remove the question from survey_question_names and question_texts
222
+ index = self.survey_question_names.index(question_name)
223
+ self.survey_question_names.pop(index)
224
+ self.question_texts.pop(index)
225
+
226
+ # Remove the question from the memory plan if it's a focal question
227
+ self.pop(question_name, None)
228
+
229
+ # Remove the question from all memories where it appears as a prior question
230
+ for focal_question, memory in self.items():
231
+ memory.remove_prior_question(question_name)
232
+
233
+ # Update the DAG
234
+ self.dag.remove_node(index)
235
+
236
+ def remove_prior_question(self, question_name: str) -> None:
237
+ """Remove a prior question from the memory."""
238
+ self.prior_questions = [q for q in self.prior_questions if q != question_name]
239
+
214
240
 
215
241
  if __name__ == "__main__":
216
242
  import doctest