edsl 0.1.33.dev3__py3-none-any.whl → 0.1.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. edsl/Base.py +15 -11
  2. edsl/__version__.py +1 -1
  3. edsl/agents/Invigilator.py +22 -3
  4. edsl/agents/PromptConstructor.py +80 -184
  5. edsl/agents/prompt_helpers.py +129 -0
  6. edsl/coop/coop.py +3 -2
  7. edsl/data_transfer_models.py +0 -1
  8. edsl/inference_services/AnthropicService.py +5 -2
  9. edsl/inference_services/AwsBedrock.py +5 -2
  10. edsl/inference_services/AzureAI.py +5 -2
  11. edsl/inference_services/GoogleService.py +108 -33
  12. edsl/inference_services/MistralAIService.py +5 -2
  13. edsl/inference_services/OpenAIService.py +3 -2
  14. edsl/inference_services/TestService.py +11 -2
  15. edsl/inference_services/TogetherAIService.py +1 -1
  16. edsl/jobs/Jobs.py +91 -10
  17. edsl/jobs/interviews/Interview.py +15 -2
  18. edsl/jobs/runners/JobsRunnerAsyncio.py +46 -25
  19. edsl/jobs/runners/JobsRunnerStatus.py +4 -3
  20. edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
  21. edsl/language_models/LanguageModel.py +12 -9
  22. edsl/language_models/utilities.py +5 -2
  23. edsl/questions/QuestionBase.py +13 -3
  24. edsl/questions/QuestionBaseGenMixin.py +28 -0
  25. edsl/questions/QuestionCheckBox.py +1 -1
  26. edsl/questions/QuestionMultipleChoice.py +8 -4
  27. edsl/questions/ResponseValidatorABC.py +5 -1
  28. edsl/questions/descriptors.py +12 -11
  29. edsl/questions/templates/numerical/answering_instructions.jinja +0 -1
  30. edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
  31. edsl/scenarios/FileStore.py +159 -76
  32. edsl/scenarios/Scenario.py +23 -49
  33. edsl/scenarios/ScenarioList.py +6 -2
  34. edsl/surveys/DAG.py +62 -0
  35. edsl/surveys/MemoryPlan.py +26 -0
  36. edsl/surveys/Rule.py +24 -0
  37. edsl/surveys/RuleCollection.py +36 -2
  38. edsl/surveys/Survey.py +182 -10
  39. edsl/surveys/base.py +4 -0
  40. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/METADATA +2 -1
  41. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/RECORD +43 -43
  42. edsl/scenarios/ScenarioImageMixin.py +0 -100
  43. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/LICENSE +0 -0
  44. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/WHEEL +0 -0
edsl/scenarios/Scenario.py CHANGED
@@ -2,25 +2,18 @@
2
2
 
3
3
  from __future__ import annotations
4
4
  import copy
5
- import base64
6
5
  import hashlib
7
6
  import os
8
- import reprlib
9
- import imghdr
10
-
11
-
12
7
  from collections import UserDict
13
8
  from typing import Union, List, Optional, Generator
14
9
  from uuid import uuid4
10
+
15
11
  from edsl.Base import Base
16
- from edsl.scenarios.ScenarioImageMixin import ScenarioImageMixin
17
12
  from edsl.scenarios.ScenarioHtmlMixin import ScenarioHtmlMixin
18
13
  from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
19
14
 
20
- from edsl.data_transfer_models import ImageInfo
21
-
22
15
 
23
- class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
16
+ class Scenario(Base, UserDict, ScenarioHtmlMixin):
24
17
  """A Scenario is a dictionary of keys/values.
25
18
 
26
19
  They can be used parameterize edsl questions."""
@@ -48,12 +41,12 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
48
41
 
49
42
  return ScenarioList([copy.deepcopy(self) for _ in range(n)])
50
43
 
51
- @property
52
- def has_image(self) -> bool:
53
- """Return whether the scenario has an image."""
54
- if not hasattr(self, "_has_image"):
55
- self._has_image = False
56
- return self._has_image
44
+ # @property
45
+ # def has_image(self) -> bool:
46
+ # """Return whether the scenario has an image."""
47
+ # if not hasattr(self, "_has_image"):
48
+ # self._has_image = False
49
+ # return self._has_image
57
50
 
58
51
  @property
59
52
  def has_jinja_braces(self) -> bool:
@@ -63,9 +56,10 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
63
56
  >>> s.has_jinja_braces
64
57
  True
65
58
  """
66
- for key, value in self.items():
67
- if "{{" in str(value) and "}}" in value:
68
- return True
59
+ for _, value in self.items():
60
+ if isinstance(value, str):
61
+ if "{{" in value and "}}" in value:
62
+ return True
69
63
  return False
70
64
 
71
65
  def convert_jinja_braces(
@@ -88,10 +82,6 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
88
82
  new_scenario[key] = value
89
83
  return new_scenario
90
84
 
91
- @has_image.setter
92
- def has_image(self, value):
93
- self._has_image = value
94
-
95
85
  def __add__(self, other_scenario: "Scenario") -> "Scenario":
96
86
  """Combine two scenarios by taking the union of their keys
97
87
 
@@ -114,8 +104,6 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
114
104
  data1 = copy.deepcopy(self.data)
115
105
  data2 = copy.deepcopy(other_scenario.data)
116
106
  s = Scenario(data1 | data2)
117
- if self.has_image or other_scenario.has_image:
118
- s._has_image = True
119
107
  return s
120
108
 
121
109
  def rename(self, replacement_dict: dict) -> "Scenario":
@@ -235,6 +223,14 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
235
223
  text = requests.get(url).text
236
224
  return cls({"url": url, field_name: text})
237
225
 
226
+ @classmethod
227
+ def from_file(cls, file_path: str, field_name: str) -> "Scenario":
228
+ """Creates a scenario from a file."""
229
+ from edsl.scenarios.FileStore import FileStore
230
+
231
+ fs = FileStore(file_path)
232
+ return cls({field_name: fs})
233
+
238
234
  @classmethod
239
235
  def from_image(
240
236
  cls, image_path: str, image_name: Optional[str] = None
@@ -248,36 +244,14 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
248
244
  Returns:
249
245
  Scenario: A new Scenario instance with image information.
250
246
 
251
- Example:
252
- >>> s = Scenario.from_image(Scenario.example_image())
253
- >>> s
254
- Scenario({'logo': ...})
255
247
  """
256
248
  if not os.path.exists(image_path):
257
249
  raise FileNotFoundError(f"Image file not found: {image_path}")
258
250
 
259
- with open(image_path, "rb") as image_file:
260
- file_content = image_file.read()
261
-
262
- file_name = os.path.basename(image_path)
263
- file_size = os.path.getsize(image_path)
264
- image_format = imghdr.what(image_path) or "unknown"
265
-
266
251
  if image_name is None:
267
- image_name = file_name.split(".")[0]
268
-
269
- image_info = ImageInfo(
270
- file_path=image_path,
271
- file_name=file_name,
272
- image_format=image_format,
273
- file_size=file_size,
274
- encoded_image=base64.b64encode(file_content).decode("utf-8"),
275
- )
276
-
277
- scenario_data = {image_name: image_info}
278
- s = cls(scenario_data)
279
- s.has_image = True
280
- return s
252
+ image_name = os.path.basename(image_path).split(".")[0]
253
+
254
+ return cls.from_file(image_path, image_name)
281
255
 
282
256
  @classmethod
283
257
  def from_pdf(cls, pdf_path):
edsl/scenarios/ScenarioList.py CHANGED
@@ -530,7 +530,9 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
530
530
  return ScenarioList([scenario.drop(fields) for scenario in self.data])
531
531
 
532
532
  @classmethod
533
- def from_list(cls, name, values) -> ScenarioList:
533
+ def from_list(
534
+ cls, name: str, values: list, func: Optional[Callable] = None
535
+ ) -> ScenarioList:
534
536
  """Create a ScenarioList from a list of values.
535
537
 
536
538
  Example:
@@ -538,7 +540,9 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
538
540
  >>> ScenarioList.from_list('name', ['Alice', 'Bob'])
539
541
  ScenarioList([Scenario({'name': 'Alice'}), Scenario({'name': 'Bob'})])
540
542
  """
541
- return cls([Scenario({name: value}) for value in values])
543
+ if not func:
544
+ func = lambda x: x
545
+ return cls([Scenario({name: func(value)}) for value in values])
542
546
 
543
547
  def to_dataset(self) -> "Dataset":
544
548
  """
edsl/surveys/DAG.py CHANGED
@@ -11,6 +11,7 @@ class DAG(UserDict):
11
11
  """Initialize the DAG class."""
12
12
  super().__init__(data)
13
13
  self.reverse_mapping = self._create_reverse_mapping()
14
+ self.validate_no_cycles()
14
15
 
15
16
  def _create_reverse_mapping(self):
16
17
  """
@@ -73,12 +74,73 @@ class DAG(UserDict):
73
74
  # else:
74
75
  # return DAG(d)
75
76
 
77
+ def remove_node(self, node: int) -> None:
78
+ """Remove a node and all its connections from the DAG."""
79
+ self.pop(node, None)
80
+ for connections in self.values():
81
+ connections.discard(node)
82
+ # Adjust remaining nodes if necessary
83
+ self._adjust_nodes_after_removal(node)
84
+
85
+ def _adjust_nodes_after_removal(self, removed_node: int) -> None:
86
+ """Adjust node indices after a node is removed."""
87
+ new_dag = {}
88
+ for node, connections in self.items():
89
+ new_node = node if node < removed_node else node - 1
90
+ new_connections = {c if c < removed_node else c - 1 for c in connections}
91
+ new_dag[new_node] = new_connections
92
+ self.clear()
93
+ self.update(new_dag)
94
+
76
95
  @classmethod
77
96
  def example(cls):
78
97
  """Return an example of the `DAG`."""
79
98
  data = {"a": ["b", "c"], "b": ["d"], "c": [], "d": []}
80
99
  return cls(data)
81
100
 
101
+ def detect_cycles(self):
102
+ """
103
+ Detect cycles in the DAG using depth-first search.
104
+
105
+ :return: A list of cycles if any are found, otherwise an empty list.
106
+ """
107
+ visited = set()
108
+ path = []
109
+ cycles = []
110
+
111
+ def dfs(node):
112
+ if node in path:
113
+ cycle = path[path.index(node) :]
114
+ cycles.append(cycle + [node])
115
+ return
116
+
117
+ if node in visited:
118
+ return
119
+
120
+ visited.add(node)
121
+ path.append(node)
122
+
123
+ for child in self.get(node, []):
124
+ dfs(child)
125
+
126
+ path.pop()
127
+
128
+ for node in self:
129
+ if node not in visited:
130
+ dfs(node)
131
+
132
+ return cycles
133
+
134
+ def validate_no_cycles(self):
135
+ """
136
+ Validate that the DAG does not contain any cycles.
137
+
138
+ :raises ValueError: If cycles are detected in the DAG.
139
+ """
140
+ cycles = self.detect_cycles()
141
+ if cycles:
142
+ raise ValueError(f"Cycles detected in the DAG: {cycles}")
143
+
82
144
 
83
145
  if __name__ == "__main__":
84
146
  import doctest
edsl/surveys/MemoryPlan.py CHANGED
@@ -211,6 +211,32 @@ class MemoryPlan(UserDict):
211
211
  mp.add_single_memory("q1", "q0")
212
212
  return mp
213
213
 
214
+ def remove_question(self, question_name: str) -> None:
215
+ """Remove a question from the memory plan.
216
+
217
+ :param question_name: The name of the question to remove.
218
+ """
219
+ self._check_valid_question_name(question_name)
220
+
221
+ # Remove the question from survey_question_names and question_texts
222
+ index = self.survey_question_names.index(question_name)
223
+ self.survey_question_names.pop(index)
224
+ self.question_texts.pop(index)
225
+
226
+ # Remove the question from the memory plan if it's a focal question
227
+ self.pop(question_name, None)
228
+
229
+ # Remove the question from all memories where it appears as a prior question
230
+ for focal_question, memory in self.items():
231
+ memory.remove_prior_question(question_name)
232
+
233
+ # Update the DAG
234
+ self.dag.remove_node(index)
235
+
236
+ def remove_prior_question(self, question_name: str) -> None:
237
+ """Remove a prior question from the memory."""
238
+ self.prior_questions = [q for q in self.prior_questions if q != question_name]
239
+
214
240
 
215
241
  if __name__ == "__main__":
216
242
  import doctest
edsl/surveys/Rule.py CHANGED
@@ -38,9 +38,29 @@ from edsl.utilities.ast_utilities import extract_variable_names
38
38
  from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
39
39
 
40
40
 
41
+ class QuestionIndex:
42
+ def __set_name__(self, owner, name):
43
+ self.name = f"_{name}"
44
+
45
+ def __get__(self, obj, objtype=None):
46
+ return getattr(obj, self.name)
47
+
48
+ def __set__(self, obj, value):
49
+ if not isinstance(value, (int, EndOfSurvey.__class__)):
50
+ raise ValueError(f"{self.name} must be an integer or EndOfSurvey")
51
+ if self.name == "_next_q" and isinstance(value, int):
52
+ current_q = getattr(obj, "_current_q")
53
+ if value <= current_q:
54
+ raise ValueError("next_q must be greater than current_q")
55
+ setattr(obj, self.name, value)
56
+
57
+
41
58
  class Rule:
42
59
  """The Rule class defines a "rule" for determining the next question presented to an agent."""
43
60
 
61
+ current_q = QuestionIndex()
62
+ next_q = QuestionIndex()
63
+
44
64
  # Not implemented but nice to have:
45
65
  # We could potentially use the question pydantic models to check for rule conflicts, as
46
66
  # they define the potential trees through a survey.
@@ -75,6 +95,10 @@ class Rule:
75
95
  self.priority = priority
76
96
  self.before_rule = before_rule
77
97
 
98
+ if not self.next_q == EndOfSurvey:
99
+ if self.next_q <= self.current_q:
100
+ raise SurveyRuleSendsYouBackwardsError
101
+
78
102
  if not self.next_q == EndOfSurvey and self.current_q > self.next_q:
79
103
  raise SurveyRuleSendsYouBackwardsError
80
104
 
edsl/surveys/RuleCollection.py CHANGED
@@ -120,13 +120,13 @@ class RuleCollection(UserList):
120
120
  :param answers: The answers to the survey questions.
121
121
 
122
122
  >>> rule_collection = RuleCollection()
123
- >>> r = Rule(current_q=1, expression="True", next_q=1, priority=1, question_name_to_index={}, before_rule = True)
123
+ >>> r = Rule(current_q=1, expression="True", next_q=2, priority=1, question_name_to_index={}, before_rule = True)
124
124
  >>> rule_collection.add_rule(r)
125
125
  >>> rule_collection.skip_question_before_running(1, {})
126
126
  True
127
127
 
128
128
  >>> rule_collection = RuleCollection()
129
- >>> r = Rule(current_q=1, expression="False", next_q=1, priority=1, question_name_to_index={}, before_rule = True)
129
+ >>> r = Rule(current_q=1, expression="False", next_q=2, priority=1, question_name_to_index={}, before_rule = True)
130
130
  >>> rule_collection.add_rule(r)
131
131
  >>> rule_collection.skip_question_before_running(1, {})
132
132
  False
@@ -321,6 +321,40 @@ class RuleCollection(UserList):
321
321
 
322
322
  return DAG(dict(sorted(children_to_parents.items())))
323
323
 
324
+ def detect_cycles(self):
325
+ """
326
+ Detect cycles in the survey rules using depth-first search.
327
+
328
+ :return: A list of cycles if any are found, otherwise an empty list.
329
+ """
330
+ dag = self.dag
331
+ visited = set()
332
+ path = []
333
+ cycles = []
334
+
335
+ def dfs(node):
336
+ if node in path:
337
+ cycle = path[path.index(node) :]
338
+ cycles.append(cycle + [node])
339
+ return
340
+
341
+ if node in visited:
342
+ return
343
+
344
+ visited.add(node)
345
+ path.append(node)
346
+
347
+ for child in dag.get(node, []):
348
+ dfs(child)
349
+
350
+ path.pop()
351
+
352
+ for node in dag:
353
+ if node not in visited:
354
+ dfs(node)
355
+
356
+ return cycles
357
+
324
358
  @classmethod
325
359
  def example(cls):
326
360
  """Create an example RuleCollection object."""
edsl/surveys/Survey.py CHANGED
@@ -22,6 +22,10 @@ from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
22
22
 
23
23
  from edsl.agents.Agent import Agent
24
24
 
25
+ from edsl.surveys.instructions.InstructionCollection import InstructionCollection
26
+ from edsl.surveys.instructions.Instruction import Instruction
27
+ from edsl.surveys.instructions.ChangeInstruction import ChangeInstruction
28
+
25
29
 
26
30
  class ValidatedString(str):
27
31
  def __new__(cls, content):
@@ -32,13 +36,6 @@ class ValidatedString(str):
32
36
  return super().__new__(cls, content)
33
37
 
34
38
 
35
- # from edsl.surveys.Instruction import Instruction
36
- # from edsl.surveys.Instruction import ChangeInstruction
37
- from edsl.surveys.instructions.InstructionCollection import InstructionCollection
38
- from edsl.surveys.instructions.Instruction import Instruction
39
- from edsl.surveys.instructions.ChangeInstruction import ChangeInstruction
40
-
41
-
42
39
  class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
43
40
  """A collection of questions that supports skip logic."""
44
41
 
@@ -289,16 +286,52 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
289
286
 
290
287
  # region: Simulation methods
291
288
 
289
+ @classmethod
290
+ def random_survey(self):
291
+ """Create a random survey."""
292
+ from edsl.questions import QuestionMultipleChoice, QuestionFreeText
293
+ from random import choice
294
+
295
+ num_questions = 10
296
+ questions = []
297
+ for i in range(num_questions):
298
+ if choice([True, False]):
299
+ q = QuestionMultipleChoice(
300
+ question_text="nothing",
301
+ question_name="q_" + str(i),
302
+ question_options=list(range(3)),
303
+ )
304
+ questions.append(q)
305
+ else:
306
+ questions.append(
307
+ QuestionFreeText(
308
+ question_text="nothing", question_name="q_" + str(i)
309
+ )
310
+ )
311
+ s = Survey(questions)
312
+ start_index = choice(range(num_questions - 1))
313
+ end_index = choice(range(start_index + 1, 10))
314
+ s = s.add_rule(f"q_{start_index}", "True", f"q_{end_index}")
315
+ question_to_delete = choice(range(num_questions))
316
+ s.delete_question(f"q_{question_to_delete}")
317
+ return s
318
+
292
319
  def simulate(self) -> dict:
293
320
  """Simulate the survey and return the answers."""
294
321
  i = self.gen_path_through_survey()
295
322
  q = next(i)
323
+ num_passes = 0
296
324
  while True:
325
+ num_passes += 1
297
326
  try:
298
327
  answer = q._simulate_answer()
299
328
  q = i.send({q.question_name: answer["answer"]})
300
329
  except StopIteration:
301
330
  break
331
+
332
+ if num_passes > 100:
333
+ print("Too many passes.")
334
+ raise Exception("Too many passes.")
302
335
  return self.answers
303
336
 
304
337
  def create_agent(self) -> "Agent":
@@ -573,7 +606,110 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
573
606
 
574
607
  return Survey(questions=self.questions + other.questions)
575
608
 
576
- def add_question(self, question: QuestionBase) -> Survey:
609
+ def move_question(self, identifier: Union[str, int], new_index: int):
610
+ if isinstance(identifier, str):
611
+ if identifier not in self.question_names:
612
+ raise ValueError(
613
+ f"Question name '{identifier}' does not exist in the survey."
614
+ )
615
+ index = self.question_name_to_index[identifier]
616
+ elif isinstance(identifier, int):
617
+ if identifier < 0 or identifier >= len(self.questions):
618
+ raise ValueError(f"Index {identifier} is out of range.")
619
+ index = identifier
620
+ else:
621
+ raise TypeError(
622
+ "Identifier must be either a string (question name) or an integer (question index)."
623
+ )
624
+
625
+ moving_question = self._questions[index]
626
+
627
+ new_survey = self.delete_question(index)
628
+ new_survey.add_question(moving_question, new_index)
629
+ return new_survey
630
+
631
+ def delete_question(self, identifier: Union[str, int]) -> Survey:
632
+ """
633
+ Delete a question from the survey.
634
+
635
+ :param identifier: The name or index of the question to delete.
636
+ :return: The updated Survey object.
637
+
638
+ >>> from edsl import QuestionMultipleChoice, Survey
639
+ >>> q1 = QuestionMultipleChoice(question_text="Q1", question_options=["A", "B"], question_name="q1")
640
+ >>> q2 = QuestionMultipleChoice(question_text="Q2", question_options=["C", "D"], question_name="q2")
641
+ >>> s = Survey().add_question(q1).add_question(q2)
642
+ >>> _ = s.delete_question("q1")
643
+ >>> len(s.questions)
644
+ 1
645
+ >>> _ = s.delete_question(0)
646
+ >>> len(s.questions)
647
+ 0
648
+ """
649
+ if isinstance(identifier, str):
650
+ if identifier not in self.question_names:
651
+ raise ValueError(
652
+ f"Question name '{identifier}' does not exist in the survey."
653
+ )
654
+ index = self.question_name_to_index[identifier]
655
+ elif isinstance(identifier, int):
656
+ if identifier < 0 or identifier >= len(self.questions):
657
+ raise ValueError(f"Index {identifier} is out of range.")
658
+ index = identifier
659
+ else:
660
+ raise TypeError(
661
+ "Identifier must be either a string (question name) or an integer (question index)."
662
+ )
663
+
664
+ # Remove the question
665
+ deleted_question = self._questions.pop(index)
666
+ del self.pseudo_indices[deleted_question.question_name]
667
+ # del self.question_name_to_index[deleted_question.question_name]
668
+
669
+ # Update indices
670
+ for question_name, old_index in self.pseudo_indices.items():
671
+ if old_index > index:
672
+ self.pseudo_indices[question_name] = old_index - 1
673
+
674
+ # for question_name, old_index in self.question_name_to_index.items():
675
+ # if old_index > index:
676
+ # self.question_name_to_index[question_name] = old_index - 1
677
+
678
+ # Update rules
679
+ new_rule_collection = RuleCollection()
680
+ for rule in self.rule_collection:
681
+ if rule.current_q == index:
682
+ continue # Remove rules associated with the deleted question
683
+ if rule.current_q > index:
684
+ rule.current_q -= 1
685
+ if rule.next_q > index:
686
+ rule.next_q -= 1
687
+
688
+ if rule.next_q == index:
689
+ if index == len(self.questions):
690
+ rule.next_q = EndOfSurvey
691
+ else:
692
+ rule.next_q = index
693
+ # rule.next_q = min(index, len(self.questions) - 1)
694
+ # continue
695
+
696
+ # if rule.next_q == index:
697
+ # rule.next_q = min(
698
+ # rule.next_q, len(self.questions) - 1
699
+ # ) # Adjust to last question if necessary
700
+
701
+ new_rule_collection.add_rule(rule)
702
+ self.rule_collection = new_rule_collection
703
+
704
+ # Update memory plan if it exists
705
+ if hasattr(self, "memory_plan"):
706
+ self.memory_plan.remove_question(deleted_question.question_name)
707
+
708
+ return self
709
+
710
+ def add_question(
711
+ self, question: QuestionBase, index: Optional[int] = None
712
+ ) -> Survey:
577
713
  """
578
714
  Add a question to survey.
579
715
 
@@ -596,15 +732,51 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
596
732
  raise SurveyCreationError(
597
733
  f"""Question name '{question.question_name}' already exists in survey. Existing names are {self.question_names}."""
598
734
  )
599
- index = len(self.questions)
735
+ if index is None:
736
+ index = len(self.questions)
737
+
738
+ if index > len(self.questions):
739
+ raise ValueError(
740
+ f"Index {index} is greater than the number of questions in the survey."
741
+ )
742
+ if index < 0:
743
+ raise ValueError(f"Index {index} is less than 0.")
744
+
745
+ interior_insertion = index != len(self.questions)
746
+
747
+ # index = len(self.questions)
600
748
  # TODO: This is a bit ugly because the user
601
749
  # doesn't "know" about _questions - it's generated by the
602
750
  # descriptor.
603
- self._questions.append(question)
751
+ self._questions.insert(index, question)
752
+
753
+ if interior_insertion:
754
+ for question_name, old_index in self.pseudo_indices.items():
755
+ if old_index >= index:
756
+ self.pseudo_indices[question_name] = old_index + 1
604
757
 
605
758
  self.pseudo_indices[question.question_name] = index
606
759
 
760
+ ## Re-do question_name to index - this is done automatically
761
+ # for question_name, old_index in self.question_name_to_index.items():
762
+ # if old_index >= index:
763
+ # self.question_name_to_index[question_name] = old_index + 1
764
+
765
+ ## Need to re-do the rule collection and the indices of the questions
766
+
767
+ ## If a rule is before the insertion index and next_q is also before the insertion index, no change needed.
768
+ ## If the rule is before the insertion index but next_q is after the insertion index, increment the next_q by 1
769
+ ## If the rule is after the insertion index, increment the current_q by 1 and the next_q by 1
770
+
607
771
  # using index + 1 presumes there is a next question
772
+ if interior_insertion:
773
+ for rule in self.rule_collection:
774
+ if rule.current_q >= index:
775
+ rule.current_q += 1
776
+ if rule.next_q >= index:
777
+ rule.next_q += 1
778
+
779
+ # add a new rule
608
780
  self.rule_collection.add_rule(
609
781
  Rule(
610
782
  current_q=index,
edsl/surveys/base.py CHANGED
@@ -36,6 +36,10 @@ class EndOfSurveyParent:
36
36
  """
37
37
  return self
38
38
 
39
+ def __deepcopy__(self, memo):
40
+ # Return the same instance when deepcopy is called
41
+ return self
42
+
39
43
  def __radd__(self, other):
40
44
  """Add the object to another object.
41
45
 
{edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: edsl
3
- Version: 0.1.33.dev3
3
+ Version: 0.1.34
4
4
  Summary: Create and analyze LLM-based surveys
5
5
  Home-page: https://www.expectedparrot.com/
6
6
  License: MIT
@@ -21,6 +21,7 @@ Requires-Dist: anthropic (>=0.23.1,<0.24.0)
21
21
  Requires-Dist: azure-ai-inference (>=1.0.0b3,<2.0.0)
22
22
  Requires-Dist: black[jupyter] (>=24.4.2,<25.0.0)
23
23
  Requires-Dist: boto3 (>=1.34.161,<2.0.0)
24
+ Requires-Dist: google-generativeai (>=0.8.2,<0.9.0)
24
25
  Requires-Dist: groq (>=0.9.0,<0.10.0)
25
26
  Requires-Dist: jinja2 (>=3.1.2,<4.0.0)
26
27
  Requires-Dist: json-repair (>=0.28.4,<0.29.0)