edsl 0.1.33.dev2__py3-none-any.whl → 0.1.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +24 -14
- edsl/__init__.py +1 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +6 -6
- edsl/agents/Invigilator.py +28 -6
- edsl/agents/InvigilatorBase.py +8 -27
- edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +150 -182
- edsl/agents/prompt_helpers.py +129 -0
- edsl/config.py +26 -34
- edsl/coop/coop.py +14 -4
- edsl/data_transfer_models.py +26 -73
- edsl/enums.py +2 -0
- edsl/inference_services/AnthropicService.py +5 -2
- edsl/inference_services/AwsBedrock.py +5 -2
- edsl/inference_services/AzureAI.py +5 -2
- edsl/inference_services/GoogleService.py +108 -33
- edsl/inference_services/InferenceServiceABC.py +44 -13
- edsl/inference_services/MistralAIService.py +5 -2
- edsl/inference_services/OpenAIService.py +10 -6
- edsl/inference_services/TestService.py +34 -16
- edsl/inference_services/TogetherAIService.py +170 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +109 -18
- edsl/jobs/buckets/BucketCollection.py +24 -15
- edsl/jobs/buckets/TokenBucket.py +64 -10
- edsl/jobs/interviews/Interview.py +130 -49
- edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +119 -173
- edsl/jobs/runners/JobsRunnerStatus.py +332 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
- edsl/jobs/tasks/TaskHistory.py +17 -0
- edsl/language_models/LanguageModel.py +36 -38
- edsl/language_models/registry.py +13 -9
- edsl/language_models/utilities.py +5 -2
- edsl/questions/QuestionBase.py +74 -16
- edsl/questions/QuestionBaseGenMixin.py +28 -0
- edsl/questions/QuestionBudget.py +93 -41
- edsl/questions/QuestionCheckBox.py +1 -1
- edsl/questions/QuestionFreeText.py +6 -0
- edsl/questions/QuestionMultipleChoice.py +13 -24
- edsl/questions/QuestionNumerical.py +5 -4
- edsl/questions/Quick.py +41 -0
- edsl/questions/ResponseValidatorABC.py +11 -6
- edsl/questions/derived/QuestionLinearScale.py +4 -1
- edsl/questions/derived/QuestionTopK.py +4 -1
- edsl/questions/derived/QuestionYesNo.py +8 -2
- edsl/questions/descriptors.py +12 -11
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +0 -1
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
- edsl/results/DatasetExportMixin.py +5 -1
- edsl/results/Result.py +1 -1
- edsl/results/Results.py +4 -1
- edsl/scenarios/FileStore.py +178 -34
- edsl/scenarios/Scenario.py +76 -37
- edsl/scenarios/ScenarioList.py +19 -2
- edsl/scenarios/ScenarioListPdfMixin.py +150 -4
- edsl/study/Study.py +32 -0
- edsl/surveys/DAG.py +62 -0
- edsl/surveys/MemoryPlan.py +26 -0
- edsl/surveys/Rule.py +34 -1
- edsl/surveys/RuleCollection.py +55 -5
- edsl/surveys/Survey.py +189 -10
- edsl/surveys/base.py +4 -0
- edsl/templates/error_reporting/interview_details.html +6 -1
- edsl/utilities/utilities.py +9 -1
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/METADATA +3 -1
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/RECORD +75 -69
- edsl/jobs/interviews/retry_management.py +0 -39
- edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
- edsl/scenarios/ScenarioImageMixin.py +0 -100
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/WHEEL +0 -0
@@ -1,145 +1,35 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
from typing import Dict, Any, Optional
|
2
|
+
from typing import Dict, Any, Optional, Set
|
3
3
|
from collections import UserList
|
4
|
+
import pdb
|
4
5
|
|
5
|
-
|
6
|
-
from edsl.prompts.Prompt import Prompt
|
6
|
+
from jinja2 import Environment, meta
|
7
7
|
|
8
|
-
|
8
|
+
from edsl.prompts.Prompt import Prompt
|
9
|
+
from edsl.data_transfer_models import ImageInfo
|
9
10
|
from edsl.prompts.registry import get_classes as prompt_lookup
|
10
11
|
from edsl.exceptions import QuestionScenarioRenderError
|
11
12
|
|
12
|
-
import
|
13
|
-
|
14
|
-
|
15
|
-
class PromptComponent(enum.Enum):
|
16
|
-
AGENT_INSTRUCTIONS = "agent_instructions"
|
17
|
-
AGENT_PERSONA = "agent_persona"
|
18
|
-
QUESTION_INSTRUCTIONS = "question_instructions"
|
19
|
-
PRIOR_QUESTION_MEMORY = "prior_question_memory"
|
20
|
-
|
21
|
-
|
22
|
-
class PromptList(UserList):
|
23
|
-
separator = Prompt(" ")
|
24
|
-
|
25
|
-
def reduce(self):
|
26
|
-
"""Reduce the list of prompts to a single prompt.
|
27
|
-
|
28
|
-
>>> p = PromptList([Prompt("You are a happy-go lucky agent."), Prompt("You are an agent with the following persona: {'age': 22, 'hair': 'brown', 'height': 5.5}")])
|
29
|
-
>>> p.reduce()
|
30
|
-
Prompt(text=\"""You are a happy-go lucky agent. You are an agent with the following persona: {'age': 22, 'hair': 'brown', 'height': 5.5}\""")
|
31
|
-
|
32
|
-
"""
|
33
|
-
p = self[0]
|
34
|
-
for prompt in self[1:]:
|
35
|
-
if len(prompt) > 0:
|
36
|
-
p = p + self.separator + prompt
|
37
|
-
return p
|
38
|
-
|
39
|
-
|
40
|
-
class PromptPlan:
|
41
|
-
"""A plan for constructing prompts for the LLM call.
|
42
|
-
Every prompt plan has a user prompt order and a system prompt order.
|
43
|
-
It must contain each of the values in the PromptComponent enum.
|
44
|
-
|
45
|
-
|
46
|
-
>>> p = PromptPlan(user_prompt_order=(PromptComponent.AGENT_INSTRUCTIONS, PromptComponent.AGENT_PERSONA),system_prompt_order=(PromptComponent.QUESTION_INSTRUCTIONS, PromptComponent.PRIOR_QUESTION_MEMORY))
|
47
|
-
>>> p._is_valid_plan()
|
48
|
-
True
|
13
|
+
from edsl.agents.prompt_helpers import PromptComponent, PromptList, PromptPlan
|
49
14
|
|
50
|
-
>>> p.arrange_components(agent_instructions=1, agent_persona=2, question_instructions=3, prior_question_memory=4)
|
51
|
-
{'user_prompt': ..., 'system_prompt': ...}
|
52
|
-
|
53
|
-
>>> p = PromptPlan(user_prompt_order=("agent_instructions", ), system_prompt_order=("question_instructions", "prior_question_memory"))
|
54
|
-
Traceback (most recent call last):
|
55
|
-
...
|
56
|
-
ValueError: Invalid plan: must contain each value of PromptComponent exactly once.
|
57
15
|
|
16
|
+
def get_jinja2_variables(template_str: str) -> Set[str]:
|
58
17
|
"""
|
18
|
+
Extracts all variable names from a Jinja2 template using Jinja2's built-in parsing.
|
59
19
|
|
60
|
-
|
61
|
-
|
62
|
-
user_prompt_order: Optional[tuple] = None,
|
63
|
-
system_prompt_order: Optional[tuple] = None,
|
64
|
-
):
|
65
|
-
"""Initialize the PromptPlan."""
|
66
|
-
|
67
|
-
if user_prompt_order is None:
|
68
|
-
user_prompt_order = (
|
69
|
-
PromptComponent.QUESTION_INSTRUCTIONS,
|
70
|
-
PromptComponent.PRIOR_QUESTION_MEMORY,
|
71
|
-
)
|
72
|
-
if system_prompt_order is None:
|
73
|
-
system_prompt_order = (
|
74
|
-
PromptComponent.AGENT_INSTRUCTIONS,
|
75
|
-
PromptComponent.AGENT_PERSONA,
|
76
|
-
)
|
77
|
-
|
78
|
-
# very commmon way to screw this up given how python treats single strings as iterables
|
79
|
-
if isinstance(user_prompt_order, str):
|
80
|
-
user_prompt_order = (user_prompt_order,)
|
81
|
-
|
82
|
-
if isinstance(system_prompt_order, str):
|
83
|
-
system_prompt_order = (system_prompt_order,)
|
84
|
-
|
85
|
-
if not isinstance(user_prompt_order, tuple):
|
86
|
-
raise TypeError(
|
87
|
-
f"Expected a tuple, but got {type(user_prompt_order).__name__}"
|
88
|
-
)
|
89
|
-
|
90
|
-
if not isinstance(system_prompt_order, tuple):
|
91
|
-
raise TypeError(
|
92
|
-
f"Expected a tuple, but got {type(system_prompt_order).__name__}"
|
93
|
-
)
|
94
|
-
|
95
|
-
self.user_prompt_order = self._convert_to_enum(user_prompt_order)
|
96
|
-
self.system_prompt_order = self._convert_to_enum(system_prompt_order)
|
97
|
-
if not self._is_valid_plan():
|
98
|
-
raise ValueError(
|
99
|
-
"Invalid plan: must contain each value of PromptComponent exactly once."
|
100
|
-
)
|
101
|
-
|
102
|
-
def _convert_to_enum(self, prompt_order: tuple):
|
103
|
-
"""Convert string names to PromptComponent enum values."""
|
104
|
-
return tuple(
|
105
|
-
PromptComponent(component) if isinstance(component, str) else component
|
106
|
-
for component in prompt_order
|
107
|
-
)
|
108
|
-
|
109
|
-
def _is_valid_plan(self):
|
110
|
-
"""Check if the plan is valid."""
|
111
|
-
combined = self.user_prompt_order + self.system_prompt_order
|
112
|
-
return set(combined) == set(PromptComponent)
|
113
|
-
|
114
|
-
def arrange_components(self, **kwargs) -> Dict[PromptComponent, Prompt]:
|
115
|
-
"""Arrange the components in the order specified by the plan."""
|
116
|
-
# check is valid components passed
|
117
|
-
component_strings = set([pc.value for pc in PromptComponent])
|
118
|
-
if not set(kwargs.keys()) == component_strings:
|
119
|
-
raise ValueError(
|
120
|
-
f"Invalid components passed: {set(kwargs.keys())} but expected {PromptComponent}"
|
121
|
-
)
|
122
|
-
|
123
|
-
user_prompt = PromptList(
|
124
|
-
[kwargs[component.value] for component in self.user_prompt_order]
|
125
|
-
)
|
126
|
-
system_prompt = PromptList(
|
127
|
-
[kwargs[component.value] for component in self.system_prompt_order]
|
128
|
-
)
|
129
|
-
return {"user_prompt": user_prompt, "system_prompt": system_prompt}
|
130
|
-
|
131
|
-
def get_prompts(self, **kwargs) -> Dict[str, Prompt]:
|
132
|
-
"""Get both prompts for the LLM call."""
|
133
|
-
prompts = self.arrange_components(**kwargs)
|
134
|
-
return {
|
135
|
-
"user_prompt": prompts["user_prompt"].reduce(),
|
136
|
-
"system_prompt": prompts["system_prompt"].reduce(),
|
137
|
-
}
|
20
|
+
Args:
|
21
|
+
template_str (str): The Jinja2 template string
|
138
22
|
|
23
|
+
Returns:
|
24
|
+
Set[str]: A set of variable names found in the template
|
25
|
+
"""
|
26
|
+
env = Environment()
|
27
|
+
ast = env.parse(template_str)
|
28
|
+
return meta.find_undeclared_variables(ast)
|
139
29
|
|
140
|
-
class PromptConstructorMixin:
|
141
|
-
"""Mixin for constructing prompts for the LLM call.
|
142
30
|
|
31
|
+
class PromptConstructor:
|
32
|
+
"""
|
143
33
|
The pieces of a prompt are:
|
144
34
|
- The agent instructions - "You are answering questions as if you were a human. Do not break character."
|
145
35
|
- The persona prompt - "You are an agent with the following persona: {'age': 22, 'hair': 'brown', 'height': 5.5}"
|
@@ -149,16 +39,42 @@ class PromptConstructorMixin:
|
|
149
39
|
This is mixed into the Invigilator class.
|
150
40
|
"""
|
151
41
|
|
152
|
-
|
42
|
+
def __init__(self, invigilator):
|
43
|
+
self.invigilator = invigilator
|
44
|
+
self.agent = invigilator.agent
|
45
|
+
self.question = invigilator.question
|
46
|
+
self.scenario = invigilator.scenario
|
47
|
+
self.survey = invigilator.survey
|
48
|
+
self.model = invigilator.model
|
49
|
+
self.current_answers = invigilator.current_answers
|
50
|
+
self.memory_plan = invigilator.memory_plan
|
51
|
+
self.prompt_plan = PromptPlan()
|
52
|
+
|
53
|
+
@property
|
54
|
+
def scenario_file_keys(self) -> list:
|
55
|
+
"""We need to find all the keys in the scenario that refer to FileStore objects.
|
56
|
+
These will be used to append to the prompt a list of files that are part of the scenario.
|
57
|
+
"""
|
58
|
+
from edsl.scenarios.FileStore import FileStore
|
59
|
+
|
60
|
+
file_entries = []
|
61
|
+
for key, value in self.scenario.items():
|
62
|
+
if isinstance(value, FileStore):
|
63
|
+
file_entries.append(key)
|
64
|
+
return file_entries
|
153
65
|
|
154
66
|
@property
|
155
67
|
def agent_instructions_prompt(self) -> Prompt:
|
156
68
|
"""
|
157
69
|
>>> from edsl.agents.InvigilatorBase import InvigilatorBase
|
158
70
|
>>> i = InvigilatorBase.example()
|
159
|
-
>>> i.agent_instructions_prompt
|
71
|
+
>>> i.prompt_constructor.agent_instructions_prompt
|
160
72
|
Prompt(text=\"""You are answering questions as if you were a human. Do not break character.\""")
|
161
73
|
"""
|
74
|
+
from edsl import Agent
|
75
|
+
|
76
|
+
if self.agent == Agent(): # if agent is empty, then return an empty prompt
|
77
|
+
return Prompt(text="")
|
162
78
|
if not hasattr(self, "_agent_instructions_prompt"):
|
163
79
|
applicable_prompts = prompt_lookup(
|
164
80
|
component_type="agent_instructions",
|
@@ -176,47 +92,56 @@ class PromptConstructorMixin:
|
|
176
92
|
"""
|
177
93
|
>>> from edsl.agents.InvigilatorBase import InvigilatorBase
|
178
94
|
>>> i = InvigilatorBase.example()
|
179
|
-
>>> i.agent_persona_prompt
|
95
|
+
>>> i.prompt_constructor.agent_persona_prompt
|
180
96
|
Prompt(text=\"""You are an agent with the following persona:
|
181
97
|
{'age': 22, 'hair': 'brown', 'height': 5.5}\""")
|
182
98
|
|
183
99
|
"""
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
model=self.model.model,
|
189
|
-
)
|
190
|
-
persona_prompt_template = applicable_prompts[0]()
|
191
|
-
else:
|
192
|
-
persona_prompt_template = self.agent.agent_persona
|
193
|
-
|
194
|
-
# TODO: This multiple passing of agent traits - not sure if it is necessary. Not harmful.
|
195
|
-
if undefined := persona_prompt_template.undefined_template_variables(
|
196
|
-
self.agent.traits
|
197
|
-
| {"traits": self.agent.traits}
|
198
|
-
| {"codebook": self.agent.codebook}
|
199
|
-
| {"traits": self.agent.traits}
|
200
|
-
):
|
201
|
-
raise QuestionScenarioRenderError(
|
202
|
-
f"Agent persona still has variables that were not rendered: {undefined}"
|
203
|
-
)
|
100
|
+
from edsl import Agent
|
101
|
+
|
102
|
+
if hasattr(self, "_agent_persona_prompt"):
|
103
|
+
return self._agent_persona_prompt
|
204
104
|
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
105
|
+
if self.agent == Agent(): # if agent is empty, then return an empty prompt
|
106
|
+
return Prompt(text="")
|
107
|
+
|
108
|
+
if not hasattr(self.agent, "agent_persona"):
|
109
|
+
applicable_prompts = prompt_lookup(
|
110
|
+
component_type="agent_persona",
|
111
|
+
model=self.model.model,
|
209
112
|
)
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
113
|
+
persona_prompt_template = applicable_prompts[0]()
|
114
|
+
else:
|
115
|
+
persona_prompt_template = self.agent.agent_persona
|
116
|
+
|
117
|
+
# TODO: This multiple passing of agent traits - not sure if it is necessary. Not harmful.
|
118
|
+
template_parameter_dictionary = (
|
119
|
+
self.agent.traits
|
120
|
+
| {"traits": self.agent.traits}
|
121
|
+
| {"codebook": self.agent.codebook}
|
122
|
+
| {"traits": self.agent.traits}
|
123
|
+
)
|
124
|
+
|
125
|
+
if undefined := persona_prompt_template.undefined_template_variables(
|
126
|
+
template_parameter_dictionary
|
127
|
+
):
|
128
|
+
raise QuestionScenarioRenderError(
|
129
|
+
f"Agent persona still has variables that were not rendered: {undefined}"
|
130
|
+
)
|
131
|
+
|
132
|
+
persona_prompt = persona_prompt_template.render(template_parameter_dictionary)
|
133
|
+
if persona_prompt.has_variables:
|
134
|
+
raise QuestionScenarioRenderError(
|
135
|
+
"Agent persona still has variables that were not rendered."
|
136
|
+
)
|
137
|
+
|
138
|
+
self._agent_persona_prompt = persona_prompt
|
215
139
|
|
216
140
|
return self._agent_persona_prompt
|
217
141
|
|
218
142
|
def prior_answers_dict(self) -> dict:
|
219
143
|
d = self.survey.question_names_to_questions()
|
144
|
+
# This attaches the answer to the question
|
220
145
|
for question, answer in self.current_answers.items():
|
221
146
|
if question in d:
|
222
147
|
d[question].answer = answer
|
@@ -226,51 +151,70 @@ class PromptConstructorMixin:
|
|
226
151
|
d[new_question].comment = answer
|
227
152
|
return d
|
228
153
|
|
154
|
+
@property
|
155
|
+
def question_file_keys(self):
|
156
|
+
raw_question_text = self.question.question_text
|
157
|
+
variables = get_jinja2_variables(raw_question_text)
|
158
|
+
question_file_keys = []
|
159
|
+
for var in variables:
|
160
|
+
if var in self.scenario_file_keys:
|
161
|
+
question_file_keys.append(var)
|
162
|
+
return question_file_keys
|
163
|
+
|
229
164
|
@property
|
230
165
|
def question_instructions_prompt(self) -> Prompt:
|
231
166
|
"""
|
232
167
|
>>> from edsl.agents.InvigilatorBase import InvigilatorBase
|
233
168
|
>>> i = InvigilatorBase.example()
|
234
|
-
>>> i.question_instructions_prompt
|
169
|
+
>>> i.prompt_constructor.question_instructions_prompt
|
235
170
|
Prompt(text=\"""...
|
236
171
|
...
|
237
172
|
"""
|
173
|
+
# The user might have passed a custom prompt, which would be stored in _question_instructions_prompt
|
238
174
|
if not hasattr(self, "_question_instructions_prompt"):
|
175
|
+
# Gets the instructions for the question - this is how the question should be answered
|
239
176
|
question_prompt = self.question.get_instructions(model=self.model.model)
|
240
177
|
|
241
|
-
#
|
242
|
-
#
|
243
|
-
# for question, answer in self.current_answers.items():
|
244
|
-
# if question in d:
|
245
|
-
# d[question].answer = answer
|
246
|
-
# else:
|
247
|
-
# # adds a comment to the question
|
248
|
-
# if (new_question := question.split("_comment")[0]) in d:
|
249
|
-
# d[new_question].comment = answer
|
250
|
-
|
178
|
+
# Get the data for the question - this is a dictionary of the question data
|
179
|
+
# e.g., {'question_text': 'Do you like school?', 'question_name': 'q0', 'question_options': ['yes', 'no']}
|
251
180
|
question_data = self.question.data.copy()
|
252
181
|
|
253
182
|
# check to see if the question_options is actually a string
|
254
|
-
# This is used when the user is using the question_options as a variable from a
|
183
|
+
# This is used when the user is using the question_options as a variable from a scenario
|
255
184
|
# if "question_options" in question_data:
|
256
185
|
if isinstance(self.question.data.get("question_options", None), str):
|
257
|
-
from jinja2 import Environment, meta
|
258
|
-
|
259
186
|
env = Environment()
|
260
187
|
parsed_content = env.parse(self.question.data["question_options"])
|
261
188
|
question_option_key = list(
|
262
189
|
meta.find_undeclared_variables(parsed_content)
|
263
190
|
)[0]
|
264
191
|
|
192
|
+
# look to see if the question_option_key is in the scenario
|
265
193
|
if isinstance(
|
266
194
|
question_options := self.scenario.get(question_option_key), list
|
267
195
|
):
|
268
196
|
question_data["question_options"] = question_options
|
269
197
|
self.question.question_options = question_options
|
270
198
|
|
199
|
+
# might be getting it from the prior answers
|
200
|
+
if self.prior_answers_dict().get(question_option_key) is not None:
|
201
|
+
if isinstance(
|
202
|
+
question_options := self.prior_answers_dict()
|
203
|
+
.get(question_option_key)
|
204
|
+
.answer,
|
205
|
+
list,
|
206
|
+
):
|
207
|
+
question_data["question_options"] = question_options
|
208
|
+
self.question.question_options = question_options
|
209
|
+
|
271
210
|
replacement_dict = (
|
272
|
-
|
273
|
-
|
|
211
|
+
{key: f"<see file {key}>" for key in self.scenario_file_keys}
|
212
|
+
| question_data
|
213
|
+
| {
|
214
|
+
k: v
|
215
|
+
for k, v in self.scenario.items()
|
216
|
+
if k not in self.scenario_file_keys
|
217
|
+
} # don't include images in the replacement dict
|
274
218
|
| self.prior_answers_dict()
|
275
219
|
| {"agent": self.agent}
|
276
220
|
| {
|
@@ -280,9 +224,10 @@ class PromptConstructorMixin:
|
|
280
224
|
),
|
281
225
|
}
|
282
226
|
)
|
283
|
-
|
227
|
+
|
284
228
|
rendered_instructions = question_prompt.render(replacement_dict)
|
285
|
-
|
229
|
+
|
230
|
+
# is there anything left to render?
|
286
231
|
undefined_template_variables = (
|
287
232
|
rendered_instructions.undefined_template_variables({})
|
288
233
|
)
|
@@ -296,11 +241,14 @@ class PromptConstructorMixin:
|
|
296
241
|
)
|
297
242
|
|
298
243
|
if undefined_template_variables:
|
244
|
+
# breakpoint()
|
299
245
|
raise QuestionScenarioRenderError(
|
300
246
|
f"Question instructions still has variables: {undefined_template_variables}."
|
301
247
|
)
|
302
248
|
|
303
|
-
|
249
|
+
####################################
|
250
|
+
# Check if question has instructions - these are instructions in a Survey that can apply to multiple follow-on questions
|
251
|
+
####################################
|
304
252
|
relevant_instructions = self.survey.relevant_instructions(
|
305
253
|
self.question.question_name
|
306
254
|
)
|
@@ -329,6 +277,23 @@ class PromptConstructorMixin:
|
|
329
277
|
self._prior_question_memory_prompt = memory_prompt
|
330
278
|
return self._prior_question_memory_prompt
|
331
279
|
|
280
|
+
def create_memory_prompt(self, question_name: str) -> Prompt:
|
281
|
+
"""Create a memory for the agent.
|
282
|
+
|
283
|
+
The returns a memory prompt for the agent.
|
284
|
+
|
285
|
+
>>> from edsl.agents.InvigilatorBase import InvigilatorBase
|
286
|
+
>>> i = InvigilatorBase.example()
|
287
|
+
>>> i.current_answers = {"q0": "Prior answer"}
|
288
|
+
>>> i.memory_plan.add_single_memory("q1", "q0")
|
289
|
+
>>> p = i.prompt_constructor.create_memory_prompt("q1")
|
290
|
+
>>> p.text.strip().replace("\\n", " ").replace("\\t", " ")
|
291
|
+
'Before the question you are now answering, you already answered the following question(s): Question: Do you like school? Answer: Prior answer'
|
292
|
+
"""
|
293
|
+
return self.memory_plan.get_memory_prompt_fragment(
|
294
|
+
question_name, self.current_answers
|
295
|
+
)
|
296
|
+
|
332
297
|
def construct_system_prompt(self) -> Prompt:
|
333
298
|
"""Construct the system prompt for the LLM call."""
|
334
299
|
import warnings
|
@@ -357,15 +322,18 @@ class PromptConstructorMixin:
|
|
357
322
|
>>> i.get_prompts()
|
358
323
|
{'user_prompt': ..., 'system_prompt': ...}
|
359
324
|
"""
|
325
|
+
# breakpoint()
|
360
326
|
prompts = self.prompt_plan.get_prompts(
|
361
327
|
agent_instructions=self.agent_instructions_prompt,
|
362
328
|
agent_persona=self.agent_persona_prompt,
|
363
329
|
question_instructions=self.question_instructions_prompt,
|
364
330
|
prior_question_memory=self.prior_question_memory_prompt,
|
365
331
|
)
|
366
|
-
|
367
|
-
|
368
|
-
|
332
|
+
if self.question_file_keys:
|
333
|
+
files_list = []
|
334
|
+
for key in self.question_file_keys:
|
335
|
+
files_list.append(self.scenario[key])
|
336
|
+
prompts["files_list"] = files_list
|
369
337
|
return prompts
|
370
338
|
|
371
339
|
def _get_scenario_with_image(self) -> Scenario:
|
@@ -0,0 +1,129 @@
|
|
1
|
+
import enum
|
2
|
+
from typing import Dict, Optional
|
3
|
+
from collections import UserList
|
4
|
+
from edsl.prompts import Prompt
|
5
|
+
|
6
|
+
|
7
|
+
class PromptComponent(enum.Enum):
|
8
|
+
AGENT_INSTRUCTIONS = "agent_instructions"
|
9
|
+
AGENT_PERSONA = "agent_persona"
|
10
|
+
QUESTION_INSTRUCTIONS = "question_instructions"
|
11
|
+
PRIOR_QUESTION_MEMORY = "prior_question_memory"
|
12
|
+
|
13
|
+
|
14
|
+
class PromptList(UserList):
|
15
|
+
separator = Prompt(" ")
|
16
|
+
|
17
|
+
def reduce(self):
|
18
|
+
"""Reduce the list of prompts to a single prompt.
|
19
|
+
|
20
|
+
>>> p = PromptList([Prompt("You are a happy-go lucky agent."), Prompt("You are an agent with the following persona: {'age': 22, 'hair': 'brown', 'height': 5.5}")])
|
21
|
+
>>> p.reduce()
|
22
|
+
Prompt(text=\"""You are a happy-go lucky agent. You are an agent with the following persona: {'age': 22, 'hair': 'brown', 'height': 5.5}\""")
|
23
|
+
|
24
|
+
"""
|
25
|
+
p = self[0]
|
26
|
+
for prompt in self[1:]:
|
27
|
+
if len(prompt) > 0:
|
28
|
+
p = p + self.separator + prompt
|
29
|
+
return p
|
30
|
+
|
31
|
+
|
32
|
+
class PromptPlan:
|
33
|
+
"""A plan for constructing prompts for the LLM call.
|
34
|
+
Every prompt plan has a user prompt order and a system prompt order.
|
35
|
+
It must contain each of the values in the PromptComponent enum.
|
36
|
+
|
37
|
+
|
38
|
+
>>> p = PromptPlan(user_prompt_order=(PromptComponent.AGENT_INSTRUCTIONS, PromptComponent.AGENT_PERSONA),system_prompt_order=(PromptComponent.QUESTION_INSTRUCTIONS, PromptComponent.PRIOR_QUESTION_MEMORY))
|
39
|
+
>>> p._is_valid_plan()
|
40
|
+
True
|
41
|
+
|
42
|
+
>>> p.arrange_components(agent_instructions=1, agent_persona=2, question_instructions=3, prior_question_memory=4)
|
43
|
+
{'user_prompt': ..., 'system_prompt': ...}
|
44
|
+
|
45
|
+
>>> p = PromptPlan(user_prompt_order=("agent_instructions", ), system_prompt_order=("question_instructions", "prior_question_memory"))
|
46
|
+
Traceback (most recent call last):
|
47
|
+
...
|
48
|
+
ValueError: Invalid plan: must contain each value of PromptComponent exactly once.
|
49
|
+
|
50
|
+
"""
|
51
|
+
|
52
|
+
def __init__(
|
53
|
+
self,
|
54
|
+
user_prompt_order: Optional[tuple] = None,
|
55
|
+
system_prompt_order: Optional[tuple] = None,
|
56
|
+
):
|
57
|
+
"""Initialize the PromptPlan."""
|
58
|
+
|
59
|
+
if user_prompt_order is None:
|
60
|
+
user_prompt_order = (
|
61
|
+
PromptComponent.QUESTION_INSTRUCTIONS,
|
62
|
+
PromptComponent.PRIOR_QUESTION_MEMORY,
|
63
|
+
)
|
64
|
+
if system_prompt_order is None:
|
65
|
+
system_prompt_order = (
|
66
|
+
PromptComponent.AGENT_INSTRUCTIONS,
|
67
|
+
PromptComponent.AGENT_PERSONA,
|
68
|
+
)
|
69
|
+
|
70
|
+
# very commmon way to screw this up given how python treats single strings as iterables
|
71
|
+
if isinstance(user_prompt_order, str):
|
72
|
+
user_prompt_order = (user_prompt_order,)
|
73
|
+
|
74
|
+
if isinstance(system_prompt_order, str):
|
75
|
+
system_prompt_order = (system_prompt_order,)
|
76
|
+
|
77
|
+
if not isinstance(user_prompt_order, tuple):
|
78
|
+
raise TypeError(
|
79
|
+
f"Expected a tuple, but got {type(user_prompt_order).__name__}"
|
80
|
+
)
|
81
|
+
|
82
|
+
if not isinstance(system_prompt_order, tuple):
|
83
|
+
raise TypeError(
|
84
|
+
f"Expected a tuple, but got {type(system_prompt_order).__name__}"
|
85
|
+
)
|
86
|
+
|
87
|
+
self.user_prompt_order = self._convert_to_enum(user_prompt_order)
|
88
|
+
self.system_prompt_order = self._convert_to_enum(system_prompt_order)
|
89
|
+
if not self._is_valid_plan():
|
90
|
+
raise ValueError(
|
91
|
+
"Invalid plan: must contain each value of PromptComponent exactly once."
|
92
|
+
)
|
93
|
+
|
94
|
+
def _convert_to_enum(self, prompt_order: tuple):
|
95
|
+
"""Convert string names to PromptComponent enum values."""
|
96
|
+
return tuple(
|
97
|
+
PromptComponent(component) if isinstance(component, str) else component
|
98
|
+
for component in prompt_order
|
99
|
+
)
|
100
|
+
|
101
|
+
def _is_valid_plan(self):
|
102
|
+
"""Check if the plan is valid."""
|
103
|
+
combined = self.user_prompt_order + self.system_prompt_order
|
104
|
+
return set(combined) == set(PromptComponent)
|
105
|
+
|
106
|
+
def arrange_components(self, **kwargs) -> Dict[PromptComponent, Prompt]:
|
107
|
+
"""Arrange the components in the order specified by the plan."""
|
108
|
+
# check is valid components passed
|
109
|
+
component_strings = set([pc.value for pc in PromptComponent])
|
110
|
+
if not set(kwargs.keys()) == component_strings:
|
111
|
+
raise ValueError(
|
112
|
+
f"Invalid components passed: {set(kwargs.keys())} but expected {PromptComponent}"
|
113
|
+
)
|
114
|
+
|
115
|
+
user_prompt = PromptList(
|
116
|
+
[kwargs[component.value] for component in self.user_prompt_order]
|
117
|
+
)
|
118
|
+
system_prompt = PromptList(
|
119
|
+
[kwargs[component.value] for component in self.system_prompt_order]
|
120
|
+
)
|
121
|
+
return {"user_prompt": user_prompt, "system_prompt": system_prompt}
|
122
|
+
|
123
|
+
def get_prompts(self, **kwargs) -> Dict[str, Prompt]:
|
124
|
+
"""Get both prompts for the LLM call."""
|
125
|
+
prompts = self.arrange_components(**kwargs)
|
126
|
+
return {
|
127
|
+
"user_prompt": prompts["user_prompt"].reduce(),
|
128
|
+
"system_prompt": prompts["system_prompt"].reduce(),
|
129
|
+
}
|