edsl 0.1.33.dev3__py3-none-any.whl → 0.1.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +15 -11
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +22 -3
- edsl/agents/PromptConstructor.py +80 -184
- edsl/agents/prompt_helpers.py +129 -0
- edsl/coop/coop.py +3 -2
- edsl/data_transfer_models.py +0 -1
- edsl/inference_services/AnthropicService.py +5 -2
- edsl/inference_services/AwsBedrock.py +5 -2
- edsl/inference_services/AzureAI.py +5 -2
- edsl/inference_services/GoogleService.py +108 -33
- edsl/inference_services/MistralAIService.py +5 -2
- edsl/inference_services/OpenAIService.py +3 -2
- edsl/inference_services/TestService.py +11 -2
- edsl/inference_services/TogetherAIService.py +1 -1
- edsl/jobs/Jobs.py +91 -10
- edsl/jobs/interviews/Interview.py +15 -2
- edsl/jobs/runners/JobsRunnerAsyncio.py +46 -25
- edsl/jobs/runners/JobsRunnerStatus.py +4 -3
- edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
- edsl/language_models/LanguageModel.py +12 -9
- edsl/language_models/utilities.py +5 -2
- edsl/questions/QuestionBase.py +13 -3
- edsl/questions/QuestionBaseGenMixin.py +28 -0
- edsl/questions/QuestionCheckBox.py +1 -1
- edsl/questions/QuestionMultipleChoice.py +8 -4
- edsl/questions/ResponseValidatorABC.py +5 -1
- edsl/questions/descriptors.py +12 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +0 -1
- edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
- edsl/scenarios/FileStore.py +159 -76
- edsl/scenarios/Scenario.py +23 -49
- edsl/scenarios/ScenarioList.py +6 -2
- edsl/surveys/DAG.py +62 -0
- edsl/surveys/MemoryPlan.py +26 -0
- edsl/surveys/Rule.py +24 -0
- edsl/surveys/RuleCollection.py +36 -2
- edsl/surveys/Survey.py +182 -10
- edsl/surveys/base.py +4 -0
- {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/METADATA +2 -1
- {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/RECORD +43 -43
- edsl/scenarios/ScenarioImageMixin.py +0 -100
- {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/WHEEL +0 -0
edsl/scenarios/Scenario.py
CHANGED
@@ -2,25 +2,18 @@
|
|
2
2
|
|
3
3
|
from __future__ import annotations
|
4
4
|
import copy
|
5
|
-
import base64
|
6
5
|
import hashlib
|
7
6
|
import os
|
8
|
-
import reprlib
|
9
|
-
import imghdr
|
10
|
-
|
11
|
-
|
12
7
|
from collections import UserDict
|
13
8
|
from typing import Union, List, Optional, Generator
|
14
9
|
from uuid import uuid4
|
10
|
+
|
15
11
|
from edsl.Base import Base
|
16
|
-
from edsl.scenarios.ScenarioImageMixin import ScenarioImageMixin
|
17
12
|
from edsl.scenarios.ScenarioHtmlMixin import ScenarioHtmlMixin
|
18
13
|
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
19
14
|
|
20
|
-
from edsl.data_transfer_models import ImageInfo
|
21
|
-
|
22
15
|
|
23
|
-
class Scenario(Base, UserDict,
|
16
|
+
class Scenario(Base, UserDict, ScenarioHtmlMixin):
|
24
17
|
"""A Scenario is a dictionary of keys/values.
|
25
18
|
|
26
19
|
They can be used parameterize edsl questions."""
|
@@ -48,12 +41,12 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
|
|
48
41
|
|
49
42
|
return ScenarioList([copy.deepcopy(self) for _ in range(n)])
|
50
43
|
|
51
|
-
@property
|
52
|
-
def has_image(self) -> bool:
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
44
|
+
# @property
|
45
|
+
# def has_image(self) -> bool:
|
46
|
+
# """Return whether the scenario has an image."""
|
47
|
+
# if not hasattr(self, "_has_image"):
|
48
|
+
# self._has_image = False
|
49
|
+
# return self._has_image
|
57
50
|
|
58
51
|
@property
|
59
52
|
def has_jinja_braces(self) -> bool:
|
@@ -63,9 +56,10 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
|
|
63
56
|
>>> s.has_jinja_braces
|
64
57
|
True
|
65
58
|
"""
|
66
|
-
for
|
67
|
-
if
|
68
|
-
|
59
|
+
for _, value in self.items():
|
60
|
+
if isinstance(value, str):
|
61
|
+
if "{{" in value and "}}" in value:
|
62
|
+
return True
|
69
63
|
return False
|
70
64
|
|
71
65
|
def convert_jinja_braces(
|
@@ -88,10 +82,6 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
|
|
88
82
|
new_scenario[key] = value
|
89
83
|
return new_scenario
|
90
84
|
|
91
|
-
@has_image.setter
|
92
|
-
def has_image(self, value):
|
93
|
-
self._has_image = value
|
94
|
-
|
95
85
|
def __add__(self, other_scenario: "Scenario") -> "Scenario":
|
96
86
|
"""Combine two scenarios by taking the union of their keys
|
97
87
|
|
@@ -114,8 +104,6 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
|
|
114
104
|
data1 = copy.deepcopy(self.data)
|
115
105
|
data2 = copy.deepcopy(other_scenario.data)
|
116
106
|
s = Scenario(data1 | data2)
|
117
|
-
if self.has_image or other_scenario.has_image:
|
118
|
-
s._has_image = True
|
119
107
|
return s
|
120
108
|
|
121
109
|
def rename(self, replacement_dict: dict) -> "Scenario":
|
@@ -235,6 +223,14 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
|
|
235
223
|
text = requests.get(url).text
|
236
224
|
return cls({"url": url, field_name: text})
|
237
225
|
|
226
|
+
@classmethod
|
227
|
+
def from_file(cls, file_path: str, field_name: str) -> "Scenario":
|
228
|
+
"""Creates a scenario from a file."""
|
229
|
+
from edsl.scenarios.FileStore import FileStore
|
230
|
+
|
231
|
+
fs = FileStore(file_path)
|
232
|
+
return cls({field_name: fs})
|
233
|
+
|
238
234
|
@classmethod
|
239
235
|
def from_image(
|
240
236
|
cls, image_path: str, image_name: Optional[str] = None
|
@@ -248,36 +244,14 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
|
|
248
244
|
Returns:
|
249
245
|
Scenario: A new Scenario instance with image information.
|
250
246
|
|
251
|
-
Example:
|
252
|
-
>>> s = Scenario.from_image(Scenario.example_image())
|
253
|
-
>>> s
|
254
|
-
Scenario({'logo': ...})
|
255
247
|
"""
|
256
248
|
if not os.path.exists(image_path):
|
257
249
|
raise FileNotFoundError(f"Image file not found: {image_path}")
|
258
250
|
|
259
|
-
with open(image_path, "rb") as image_file:
|
260
|
-
file_content = image_file.read()
|
261
|
-
|
262
|
-
file_name = os.path.basename(image_path)
|
263
|
-
file_size = os.path.getsize(image_path)
|
264
|
-
image_format = imghdr.what(image_path) or "unknown"
|
265
|
-
|
266
251
|
if image_name is None:
|
267
|
-
image_name =
|
268
|
-
|
269
|
-
|
270
|
-
file_path=image_path,
|
271
|
-
file_name=file_name,
|
272
|
-
image_format=image_format,
|
273
|
-
file_size=file_size,
|
274
|
-
encoded_image=base64.b64encode(file_content).decode("utf-8"),
|
275
|
-
)
|
276
|
-
|
277
|
-
scenario_data = {image_name: image_info}
|
278
|
-
s = cls(scenario_data)
|
279
|
-
s.has_image = True
|
280
|
-
return s
|
252
|
+
image_name = os.path.basename(image_path).split(".")[0]
|
253
|
+
|
254
|
+
return cls.from_file(image_path, image_name)
|
281
255
|
|
282
256
|
@classmethod
|
283
257
|
def from_pdf(cls, pdf_path):
|
edsl/scenarios/ScenarioList.py
CHANGED
@@ -530,7 +530,9 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
530
530
|
return ScenarioList([scenario.drop(fields) for scenario in self.data])
|
531
531
|
|
532
532
|
@classmethod
|
533
|
-
def from_list(
|
533
|
+
def from_list(
|
534
|
+
cls, name: str, values: list, func: Optional[Callable] = None
|
535
|
+
) -> ScenarioList:
|
534
536
|
"""Create a ScenarioList from a list of values.
|
535
537
|
|
536
538
|
Example:
|
@@ -538,7 +540,9 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
|
|
538
540
|
>>> ScenarioList.from_list('name', ['Alice', 'Bob'])
|
539
541
|
ScenarioList([Scenario({'name': 'Alice'}), Scenario({'name': 'Bob'})])
|
540
542
|
"""
|
541
|
-
|
543
|
+
if not func:
|
544
|
+
func = lambda x: x
|
545
|
+
return cls([Scenario({name: func(value)}) for value in values])
|
542
546
|
|
543
547
|
def to_dataset(self) -> "Dataset":
|
544
548
|
"""
|
edsl/surveys/DAG.py
CHANGED
@@ -11,6 +11,7 @@ class DAG(UserDict):
|
|
11
11
|
"""Initialize the DAG class."""
|
12
12
|
super().__init__(data)
|
13
13
|
self.reverse_mapping = self._create_reverse_mapping()
|
14
|
+
self.validate_no_cycles()
|
14
15
|
|
15
16
|
def _create_reverse_mapping(self):
|
16
17
|
"""
|
@@ -73,12 +74,73 @@ class DAG(UserDict):
|
|
73
74
|
# else:
|
74
75
|
# return DAG(d)
|
75
76
|
|
77
|
+
def remove_node(self, node: int) -> None:
|
78
|
+
"""Remove a node and all its connections from the DAG."""
|
79
|
+
self.pop(node, None)
|
80
|
+
for connections in self.values():
|
81
|
+
connections.discard(node)
|
82
|
+
# Adjust remaining nodes if necessary
|
83
|
+
self._adjust_nodes_after_removal(node)
|
84
|
+
|
85
|
+
def _adjust_nodes_after_removal(self, removed_node: int) -> None:
|
86
|
+
"""Adjust node indices after a node is removed."""
|
87
|
+
new_dag = {}
|
88
|
+
for node, connections in self.items():
|
89
|
+
new_node = node if node < removed_node else node - 1
|
90
|
+
new_connections = {c if c < removed_node else c - 1 for c in connections}
|
91
|
+
new_dag[new_node] = new_connections
|
92
|
+
self.clear()
|
93
|
+
self.update(new_dag)
|
94
|
+
|
76
95
|
@classmethod
|
77
96
|
def example(cls):
|
78
97
|
"""Return an example of the `DAG`."""
|
79
98
|
data = {"a": ["b", "c"], "b": ["d"], "c": [], "d": []}
|
80
99
|
return cls(data)
|
81
100
|
|
101
|
+
def detect_cycles(self):
|
102
|
+
"""
|
103
|
+
Detect cycles in the DAG using depth-first search.
|
104
|
+
|
105
|
+
:return: A list of cycles if any are found, otherwise an empty list.
|
106
|
+
"""
|
107
|
+
visited = set()
|
108
|
+
path = []
|
109
|
+
cycles = []
|
110
|
+
|
111
|
+
def dfs(node):
|
112
|
+
if node in path:
|
113
|
+
cycle = path[path.index(node) :]
|
114
|
+
cycles.append(cycle + [node])
|
115
|
+
return
|
116
|
+
|
117
|
+
if node in visited:
|
118
|
+
return
|
119
|
+
|
120
|
+
visited.add(node)
|
121
|
+
path.append(node)
|
122
|
+
|
123
|
+
for child in self.get(node, []):
|
124
|
+
dfs(child)
|
125
|
+
|
126
|
+
path.pop()
|
127
|
+
|
128
|
+
for node in self:
|
129
|
+
if node not in visited:
|
130
|
+
dfs(node)
|
131
|
+
|
132
|
+
return cycles
|
133
|
+
|
134
|
+
def validate_no_cycles(self):
|
135
|
+
"""
|
136
|
+
Validate that the DAG does not contain any cycles.
|
137
|
+
|
138
|
+
:raises ValueError: If cycles are detected in the DAG.
|
139
|
+
"""
|
140
|
+
cycles = self.detect_cycles()
|
141
|
+
if cycles:
|
142
|
+
raise ValueError(f"Cycles detected in the DAG: {cycles}")
|
143
|
+
|
82
144
|
|
83
145
|
if __name__ == "__main__":
|
84
146
|
import doctest
|
edsl/surveys/MemoryPlan.py
CHANGED
@@ -211,6 +211,32 @@ class MemoryPlan(UserDict):
|
|
211
211
|
mp.add_single_memory("q1", "q0")
|
212
212
|
return mp
|
213
213
|
|
214
|
+
def remove_question(self, question_name: str) -> None:
|
215
|
+
"""Remove a question from the memory plan.
|
216
|
+
|
217
|
+
:param question_name: The name of the question to remove.
|
218
|
+
"""
|
219
|
+
self._check_valid_question_name(question_name)
|
220
|
+
|
221
|
+
# Remove the question from survey_question_names and question_texts
|
222
|
+
index = self.survey_question_names.index(question_name)
|
223
|
+
self.survey_question_names.pop(index)
|
224
|
+
self.question_texts.pop(index)
|
225
|
+
|
226
|
+
# Remove the question from the memory plan if it's a focal question
|
227
|
+
self.pop(question_name, None)
|
228
|
+
|
229
|
+
# Remove the question from all memories where it appears as a prior question
|
230
|
+
for focal_question, memory in self.items():
|
231
|
+
memory.remove_prior_question(question_name)
|
232
|
+
|
233
|
+
# Update the DAG
|
234
|
+
self.dag.remove_node(index)
|
235
|
+
|
236
|
+
def remove_prior_question(self, question_name: str) -> None:
|
237
|
+
"""Remove a prior question from the memory."""
|
238
|
+
self.prior_questions = [q for q in self.prior_questions if q != question_name]
|
239
|
+
|
214
240
|
|
215
241
|
if __name__ == "__main__":
|
216
242
|
import doctest
|
edsl/surveys/Rule.py
CHANGED
@@ -38,9 +38,29 @@ from edsl.utilities.ast_utilities import extract_variable_names
|
|
38
38
|
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
39
39
|
|
40
40
|
|
41
|
+
class QuestionIndex:
|
42
|
+
def __set_name__(self, owner, name):
|
43
|
+
self.name = f"_{name}"
|
44
|
+
|
45
|
+
def __get__(self, obj, objtype=None):
|
46
|
+
return getattr(obj, self.name)
|
47
|
+
|
48
|
+
def __set__(self, obj, value):
|
49
|
+
if not isinstance(value, (int, EndOfSurvey.__class__)):
|
50
|
+
raise ValueError(f"{self.name} must be an integer or EndOfSurvey")
|
51
|
+
if self.name == "_next_q" and isinstance(value, int):
|
52
|
+
current_q = getattr(obj, "_current_q")
|
53
|
+
if value <= current_q:
|
54
|
+
raise ValueError("next_q must be greater than current_q")
|
55
|
+
setattr(obj, self.name, value)
|
56
|
+
|
57
|
+
|
41
58
|
class Rule:
|
42
59
|
"""The Rule class defines a "rule" for determining the next question presented to an agent."""
|
43
60
|
|
61
|
+
current_q = QuestionIndex()
|
62
|
+
next_q = QuestionIndex()
|
63
|
+
|
44
64
|
# Not implemented but nice to have:
|
45
65
|
# We could potentially use the question pydantic models to check for rule conflicts, as
|
46
66
|
# they define the potential trees through a survey.
|
@@ -75,6 +95,10 @@ class Rule:
|
|
75
95
|
self.priority = priority
|
76
96
|
self.before_rule = before_rule
|
77
97
|
|
98
|
+
if not self.next_q == EndOfSurvey:
|
99
|
+
if self.next_q <= self.current_q:
|
100
|
+
raise SurveyRuleSendsYouBackwardsError
|
101
|
+
|
78
102
|
if not self.next_q == EndOfSurvey and self.current_q > self.next_q:
|
79
103
|
raise SurveyRuleSendsYouBackwardsError
|
80
104
|
|
edsl/surveys/RuleCollection.py
CHANGED
@@ -120,13 +120,13 @@ class RuleCollection(UserList):
|
|
120
120
|
:param answers: The answers to the survey questions.
|
121
121
|
|
122
122
|
>>> rule_collection = RuleCollection()
|
123
|
-
>>> r = Rule(current_q=1, expression="True", next_q=
|
123
|
+
>>> r = Rule(current_q=1, expression="True", next_q=2, priority=1, question_name_to_index={}, before_rule = True)
|
124
124
|
>>> rule_collection.add_rule(r)
|
125
125
|
>>> rule_collection.skip_question_before_running(1, {})
|
126
126
|
True
|
127
127
|
|
128
128
|
>>> rule_collection = RuleCollection()
|
129
|
-
>>> r = Rule(current_q=1, expression="False", next_q=
|
129
|
+
>>> r = Rule(current_q=1, expression="False", next_q=2, priority=1, question_name_to_index={}, before_rule = True)
|
130
130
|
>>> rule_collection.add_rule(r)
|
131
131
|
>>> rule_collection.skip_question_before_running(1, {})
|
132
132
|
False
|
@@ -321,6 +321,40 @@ class RuleCollection(UserList):
|
|
321
321
|
|
322
322
|
return DAG(dict(sorted(children_to_parents.items())))
|
323
323
|
|
324
|
+
def detect_cycles(self):
|
325
|
+
"""
|
326
|
+
Detect cycles in the survey rules using depth-first search.
|
327
|
+
|
328
|
+
:return: A list of cycles if any are found, otherwise an empty list.
|
329
|
+
"""
|
330
|
+
dag = self.dag
|
331
|
+
visited = set()
|
332
|
+
path = []
|
333
|
+
cycles = []
|
334
|
+
|
335
|
+
def dfs(node):
|
336
|
+
if node in path:
|
337
|
+
cycle = path[path.index(node) :]
|
338
|
+
cycles.append(cycle + [node])
|
339
|
+
return
|
340
|
+
|
341
|
+
if node in visited:
|
342
|
+
return
|
343
|
+
|
344
|
+
visited.add(node)
|
345
|
+
path.append(node)
|
346
|
+
|
347
|
+
for child in dag.get(node, []):
|
348
|
+
dfs(child)
|
349
|
+
|
350
|
+
path.pop()
|
351
|
+
|
352
|
+
for node in dag:
|
353
|
+
if node not in visited:
|
354
|
+
dfs(node)
|
355
|
+
|
356
|
+
return cycles
|
357
|
+
|
324
358
|
@classmethod
|
325
359
|
def example(cls):
|
326
360
|
"""Create an example RuleCollection object."""
|
edsl/surveys/Survey.py
CHANGED
@@ -22,6 +22,10 @@ from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
|
22
22
|
|
23
23
|
from edsl.agents.Agent import Agent
|
24
24
|
|
25
|
+
from edsl.surveys.instructions.InstructionCollection import InstructionCollection
|
26
|
+
from edsl.surveys.instructions.Instruction import Instruction
|
27
|
+
from edsl.surveys.instructions.ChangeInstruction import ChangeInstruction
|
28
|
+
|
25
29
|
|
26
30
|
class ValidatedString(str):
|
27
31
|
def __new__(cls, content):
|
@@ -32,13 +36,6 @@ class ValidatedString(str):
|
|
32
36
|
return super().__new__(cls, content)
|
33
37
|
|
34
38
|
|
35
|
-
# from edsl.surveys.Instruction import Instruction
|
36
|
-
# from edsl.surveys.Instruction import ChangeInstruction
|
37
|
-
from edsl.surveys.instructions.InstructionCollection import InstructionCollection
|
38
|
-
from edsl.surveys.instructions.Instruction import Instruction
|
39
|
-
from edsl.surveys.instructions.ChangeInstruction import ChangeInstruction
|
40
|
-
|
41
|
-
|
42
39
|
class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
|
43
40
|
"""A collection of questions that supports skip logic."""
|
44
41
|
|
@@ -289,16 +286,52 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
|
|
289
286
|
|
290
287
|
# region: Simulation methods
|
291
288
|
|
289
|
+
@classmethod
|
290
|
+
def random_survey(self):
|
291
|
+
"""Create a random survey."""
|
292
|
+
from edsl.questions import QuestionMultipleChoice, QuestionFreeText
|
293
|
+
from random import choice
|
294
|
+
|
295
|
+
num_questions = 10
|
296
|
+
questions = []
|
297
|
+
for i in range(num_questions):
|
298
|
+
if choice([True, False]):
|
299
|
+
q = QuestionMultipleChoice(
|
300
|
+
question_text="nothing",
|
301
|
+
question_name="q_" + str(i),
|
302
|
+
question_options=list(range(3)),
|
303
|
+
)
|
304
|
+
questions.append(q)
|
305
|
+
else:
|
306
|
+
questions.append(
|
307
|
+
QuestionFreeText(
|
308
|
+
question_text="nothing", question_name="q_" + str(i)
|
309
|
+
)
|
310
|
+
)
|
311
|
+
s = Survey(questions)
|
312
|
+
start_index = choice(range(num_questions - 1))
|
313
|
+
end_index = choice(range(start_index + 1, 10))
|
314
|
+
s = s.add_rule(f"q_{start_index}", "True", f"q_{end_index}")
|
315
|
+
question_to_delete = choice(range(num_questions))
|
316
|
+
s.delete_question(f"q_{question_to_delete}")
|
317
|
+
return s
|
318
|
+
|
292
319
|
def simulate(self) -> dict:
|
293
320
|
"""Simulate the survey and return the answers."""
|
294
321
|
i = self.gen_path_through_survey()
|
295
322
|
q = next(i)
|
323
|
+
num_passes = 0
|
296
324
|
while True:
|
325
|
+
num_passes += 1
|
297
326
|
try:
|
298
327
|
answer = q._simulate_answer()
|
299
328
|
q = i.send({q.question_name: answer["answer"]})
|
300
329
|
except StopIteration:
|
301
330
|
break
|
331
|
+
|
332
|
+
if num_passes > 100:
|
333
|
+
print("Too many passes.")
|
334
|
+
raise Exception("Too many passes.")
|
302
335
|
return self.answers
|
303
336
|
|
304
337
|
def create_agent(self) -> "Agent":
|
@@ -573,7 +606,110 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
|
|
573
606
|
|
574
607
|
return Survey(questions=self.questions + other.questions)
|
575
608
|
|
576
|
-
def
|
609
|
+
def move_question(self, identifier: Union[str, int], new_index: int):
|
610
|
+
if isinstance(identifier, str):
|
611
|
+
if identifier not in self.question_names:
|
612
|
+
raise ValueError(
|
613
|
+
f"Question name '{identifier}' does not exist in the survey."
|
614
|
+
)
|
615
|
+
index = self.question_name_to_index[identifier]
|
616
|
+
elif isinstance(identifier, int):
|
617
|
+
if identifier < 0 or identifier >= len(self.questions):
|
618
|
+
raise ValueError(f"Index {identifier} is out of range.")
|
619
|
+
index = identifier
|
620
|
+
else:
|
621
|
+
raise TypeError(
|
622
|
+
"Identifier must be either a string (question name) or an integer (question index)."
|
623
|
+
)
|
624
|
+
|
625
|
+
moving_question = self._questions[index]
|
626
|
+
|
627
|
+
new_survey = self.delete_question(index)
|
628
|
+
new_survey.add_question(moving_question, new_index)
|
629
|
+
return new_survey
|
630
|
+
|
631
|
+
def delete_question(self, identifier: Union[str, int]) -> Survey:
|
632
|
+
"""
|
633
|
+
Delete a question from the survey.
|
634
|
+
|
635
|
+
:param identifier: The name or index of the question to delete.
|
636
|
+
:return: The updated Survey object.
|
637
|
+
|
638
|
+
>>> from edsl import QuestionMultipleChoice, Survey
|
639
|
+
>>> q1 = QuestionMultipleChoice(question_text="Q1", question_options=["A", "B"], question_name="q1")
|
640
|
+
>>> q2 = QuestionMultipleChoice(question_text="Q2", question_options=["C", "D"], question_name="q2")
|
641
|
+
>>> s = Survey().add_question(q1).add_question(q2)
|
642
|
+
>>> _ = s.delete_question("q1")
|
643
|
+
>>> len(s.questions)
|
644
|
+
1
|
645
|
+
>>> _ = s.delete_question(0)
|
646
|
+
>>> len(s.questions)
|
647
|
+
0
|
648
|
+
"""
|
649
|
+
if isinstance(identifier, str):
|
650
|
+
if identifier not in self.question_names:
|
651
|
+
raise ValueError(
|
652
|
+
f"Question name '{identifier}' does not exist in the survey."
|
653
|
+
)
|
654
|
+
index = self.question_name_to_index[identifier]
|
655
|
+
elif isinstance(identifier, int):
|
656
|
+
if identifier < 0 or identifier >= len(self.questions):
|
657
|
+
raise ValueError(f"Index {identifier} is out of range.")
|
658
|
+
index = identifier
|
659
|
+
else:
|
660
|
+
raise TypeError(
|
661
|
+
"Identifier must be either a string (question name) or an integer (question index)."
|
662
|
+
)
|
663
|
+
|
664
|
+
# Remove the question
|
665
|
+
deleted_question = self._questions.pop(index)
|
666
|
+
del self.pseudo_indices[deleted_question.question_name]
|
667
|
+
# del self.question_name_to_index[deleted_question.question_name]
|
668
|
+
|
669
|
+
# Update indices
|
670
|
+
for question_name, old_index in self.pseudo_indices.items():
|
671
|
+
if old_index > index:
|
672
|
+
self.pseudo_indices[question_name] = old_index - 1
|
673
|
+
|
674
|
+
# for question_name, old_index in self.question_name_to_index.items():
|
675
|
+
# if old_index > index:
|
676
|
+
# self.question_name_to_index[question_name] = old_index - 1
|
677
|
+
|
678
|
+
# Update rules
|
679
|
+
new_rule_collection = RuleCollection()
|
680
|
+
for rule in self.rule_collection:
|
681
|
+
if rule.current_q == index:
|
682
|
+
continue # Remove rules associated with the deleted question
|
683
|
+
if rule.current_q > index:
|
684
|
+
rule.current_q -= 1
|
685
|
+
if rule.next_q > index:
|
686
|
+
rule.next_q -= 1
|
687
|
+
|
688
|
+
if rule.next_q == index:
|
689
|
+
if index == len(self.questions):
|
690
|
+
rule.next_q = EndOfSurvey
|
691
|
+
else:
|
692
|
+
rule.next_q = index
|
693
|
+
# rule.next_q = min(index, len(self.questions) - 1)
|
694
|
+
# continue
|
695
|
+
|
696
|
+
# if rule.next_q == index:
|
697
|
+
# rule.next_q = min(
|
698
|
+
# rule.next_q, len(self.questions) - 1
|
699
|
+
# ) # Adjust to last question if necessary
|
700
|
+
|
701
|
+
new_rule_collection.add_rule(rule)
|
702
|
+
self.rule_collection = new_rule_collection
|
703
|
+
|
704
|
+
# Update memory plan if it exists
|
705
|
+
if hasattr(self, "memory_plan"):
|
706
|
+
self.memory_plan.remove_question(deleted_question.question_name)
|
707
|
+
|
708
|
+
return self
|
709
|
+
|
710
|
+
def add_question(
|
711
|
+
self, question: QuestionBase, index: Optional[int] = None
|
712
|
+
) -> Survey:
|
577
713
|
"""
|
578
714
|
Add a question to survey.
|
579
715
|
|
@@ -596,15 +732,51 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
|
|
596
732
|
raise SurveyCreationError(
|
597
733
|
f"""Question name '{question.question_name}' already exists in survey. Existing names are {self.question_names}."""
|
598
734
|
)
|
599
|
-
index
|
735
|
+
if index is None:
|
736
|
+
index = len(self.questions)
|
737
|
+
|
738
|
+
if index > len(self.questions):
|
739
|
+
raise ValueError(
|
740
|
+
f"Index {index} is greater than the number of questions in the survey."
|
741
|
+
)
|
742
|
+
if index < 0:
|
743
|
+
raise ValueError(f"Index {index} is less than 0.")
|
744
|
+
|
745
|
+
interior_insertion = index != len(self.questions)
|
746
|
+
|
747
|
+
# index = len(self.questions)
|
600
748
|
# TODO: This is a bit ugly because the user
|
601
749
|
# doesn't "know" about _questions - it's generated by the
|
602
750
|
# descriptor.
|
603
|
-
self._questions.
|
751
|
+
self._questions.insert(index, question)
|
752
|
+
|
753
|
+
if interior_insertion:
|
754
|
+
for question_name, old_index in self.pseudo_indices.items():
|
755
|
+
if old_index >= index:
|
756
|
+
self.pseudo_indices[question_name] = old_index + 1
|
604
757
|
|
605
758
|
self.pseudo_indices[question.question_name] = index
|
606
759
|
|
760
|
+
## Re-do question_name to index - this is done automatically
|
761
|
+
# for question_name, old_index in self.question_name_to_index.items():
|
762
|
+
# if old_index >= index:
|
763
|
+
# self.question_name_to_index[question_name] = old_index + 1
|
764
|
+
|
765
|
+
## Need to re-do the rule collection and the indices of the questions
|
766
|
+
|
767
|
+
## If a rule is before the insertion index and next_q is also before the insertion index, no change needed.
|
768
|
+
## If the rule is before the insertion index but next_q is after the insertion index, increment the next_q by 1
|
769
|
+
## If the rule is after the insertion index, increment the current_q by 1 and the next_q by 1
|
770
|
+
|
607
771
|
# using index + 1 presumes there is a next question
|
772
|
+
if interior_insertion:
|
773
|
+
for rule in self.rule_collection:
|
774
|
+
if rule.current_q >= index:
|
775
|
+
rule.current_q += 1
|
776
|
+
if rule.next_q >= index:
|
777
|
+
rule.next_q += 1
|
778
|
+
|
779
|
+
# add a new rule
|
608
780
|
self.rule_collection.add_rule(
|
609
781
|
Rule(
|
610
782
|
current_q=index,
|
edsl/surveys/base.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: edsl
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.34
|
4
4
|
Summary: Create and analyze LLM-based surveys
|
5
5
|
Home-page: https://www.expectedparrot.com/
|
6
6
|
License: MIT
|
@@ -21,6 +21,7 @@ Requires-Dist: anthropic (>=0.23.1,<0.24.0)
|
|
21
21
|
Requires-Dist: azure-ai-inference (>=1.0.0b3,<2.0.0)
|
22
22
|
Requires-Dist: black[jupyter] (>=24.4.2,<25.0.0)
|
23
23
|
Requires-Dist: boto3 (>=1.34.161,<2.0.0)
|
24
|
+
Requires-Dist: google-generativeai (>=0.8.2,<0.9.0)
|
24
25
|
Requires-Dist: groq (>=0.9.0,<0.10.0)
|
25
26
|
Requires-Dist: jinja2 (>=3.1.2,<4.0.0)
|
26
27
|
Requires-Dist: json-repair (>=0.28.4,<0.29.0)
|