edsl 0.1.33.dev2__py3-none-any.whl → 0.1.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. edsl/Base.py +24 -14
  2. edsl/__init__.py +1 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +6 -6
  5. edsl/agents/Invigilator.py +28 -6
  6. edsl/agents/InvigilatorBase.py +8 -27
  7. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +150 -182
  8. edsl/agents/prompt_helpers.py +129 -0
  9. edsl/config.py +26 -34
  10. edsl/coop/coop.py +14 -4
  11. edsl/data_transfer_models.py +26 -73
  12. edsl/enums.py +2 -0
  13. edsl/inference_services/AnthropicService.py +5 -2
  14. edsl/inference_services/AwsBedrock.py +5 -2
  15. edsl/inference_services/AzureAI.py +5 -2
  16. edsl/inference_services/GoogleService.py +108 -33
  17. edsl/inference_services/InferenceServiceABC.py +44 -13
  18. edsl/inference_services/MistralAIService.py +5 -2
  19. edsl/inference_services/OpenAIService.py +10 -6
  20. edsl/inference_services/TestService.py +34 -16
  21. edsl/inference_services/TogetherAIService.py +170 -0
  22. edsl/inference_services/registry.py +2 -0
  23. edsl/jobs/Jobs.py +109 -18
  24. edsl/jobs/buckets/BucketCollection.py +24 -15
  25. edsl/jobs/buckets/TokenBucket.py +64 -10
  26. edsl/jobs/interviews/Interview.py +130 -49
  27. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
  28. edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
  29. edsl/jobs/runners/JobsRunnerAsyncio.py +119 -173
  30. edsl/jobs/runners/JobsRunnerStatus.py +332 -0
  31. edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
  32. edsl/jobs/tasks/TaskHistory.py +17 -0
  33. edsl/language_models/LanguageModel.py +36 -38
  34. edsl/language_models/registry.py +13 -9
  35. edsl/language_models/utilities.py +5 -2
  36. edsl/questions/QuestionBase.py +74 -16
  37. edsl/questions/QuestionBaseGenMixin.py +28 -0
  38. edsl/questions/QuestionBudget.py +93 -41
  39. edsl/questions/QuestionCheckBox.py +1 -1
  40. edsl/questions/QuestionFreeText.py +6 -0
  41. edsl/questions/QuestionMultipleChoice.py +13 -24
  42. edsl/questions/QuestionNumerical.py +5 -4
  43. edsl/questions/Quick.py +41 -0
  44. edsl/questions/ResponseValidatorABC.py +11 -6
  45. edsl/questions/derived/QuestionLinearScale.py +4 -1
  46. edsl/questions/derived/QuestionTopK.py +4 -1
  47. edsl/questions/derived/QuestionYesNo.py +8 -2
  48. edsl/questions/descriptors.py +12 -11
  49. edsl/questions/templates/budget/__init__.py +0 -0
  50. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  51. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  52. edsl/questions/templates/extract/__init__.py +0 -0
  53. edsl/questions/templates/numerical/answering_instructions.jinja +0 -1
  54. edsl/questions/templates/rank/__init__.py +0 -0
  55. edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
  56. edsl/results/DatasetExportMixin.py +5 -1
  57. edsl/results/Result.py +1 -1
  58. edsl/results/Results.py +4 -1
  59. edsl/scenarios/FileStore.py +178 -34
  60. edsl/scenarios/Scenario.py +76 -37
  61. edsl/scenarios/ScenarioList.py +19 -2
  62. edsl/scenarios/ScenarioListPdfMixin.py +150 -4
  63. edsl/study/Study.py +32 -0
  64. edsl/surveys/DAG.py +62 -0
  65. edsl/surveys/MemoryPlan.py +26 -0
  66. edsl/surveys/Rule.py +34 -1
  67. edsl/surveys/RuleCollection.py +55 -5
  68. edsl/surveys/Survey.py +189 -10
  69. edsl/surveys/base.py +4 -0
  70. edsl/templates/error_reporting/interview_details.html +6 -1
  71. edsl/utilities/utilities.py +9 -1
  72. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/METADATA +3 -1
  73. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/RECORD +75 -69
  74. edsl/jobs/interviews/retry_management.py +0 -39
  75. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
  76. edsl/scenarios/ScenarioImageMixin.py +0 -100
  77. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/LICENSE +0 -0
  78. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/WHEEL +0 -0
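The headline refactor below moves prompt construction out of a mixin: edsl/agents/PromptConstructionMixin.py becomes edsl/agents/PromptConstructor.py (entry 7) and the PromptComponent/PromptList/PromptPlan plumbing moves to the new edsl/agents/prompt_helpers.py (entry 8). A minimal sketch of the resulting access pattern, taken from the doctests in the hunks below and not re-run against the wheel:

from edsl.agents.InvigilatorBase import InvigilatorBase

i = InvigilatorBase.example()
i.prompt_constructor.agent_instructions_prompt     # was i.agent_instructions_prompt
i.prompt_constructor.agent_persona_prompt          # was i.agent_persona_prompt
i.prompt_constructor.question_instructions_prompt  # was i.question_instructions_prompt
i.get_prompts()  # still on the invigilator; returns {'user_prompt': ..., 'system_prompt': ...}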
@@ -1,145 +1,35 @@
  from __future__ import annotations
- from typing import Dict, Any, Optional
+ from typing import Dict, Any, Optional, Set
  from collections import UserList
+ import pdb

- # from functools import reduce
- from edsl.prompts.Prompt import Prompt
+ from jinja2 import Environment, meta

- # from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
+ from edsl.prompts.Prompt import Prompt
+ from edsl.data_transfer_models import ImageInfo
  from edsl.prompts.registry import get_classes as prompt_lookup
  from edsl.exceptions import QuestionScenarioRenderError

- import enum
-
-
- class PromptComponent(enum.Enum):
-     AGENT_INSTRUCTIONS = "agent_instructions"
-     AGENT_PERSONA = "agent_persona"
-     QUESTION_INSTRUCTIONS = "question_instructions"
-     PRIOR_QUESTION_MEMORY = "prior_question_memory"
-
-
- class PromptList(UserList):
-     separator = Prompt(" ")
-
-     def reduce(self):
-         """Reduce the list of prompts to a single prompt.
-
-         >>> p = PromptList([Prompt("You are a happy-go lucky agent."), Prompt("You are an agent with the following persona: {'age': 22, 'hair': 'brown', 'height': 5.5}")])
-         >>> p.reduce()
-         Prompt(text=\"""You are a happy-go lucky agent. You are an agent with the following persona: {'age': 22, 'hair': 'brown', 'height': 5.5}\""")
-
-         """
-         p = self[0]
-         for prompt in self[1:]:
-             if len(prompt) > 0:
-                 p = p + self.separator + prompt
-         return p
-
-
- class PromptPlan:
-     """A plan for constructing prompts for the LLM call.
-     Every prompt plan has a user prompt order and a system prompt order.
-     It must contain each of the values in the PromptComponent enum.
-
-
-     >>> p = PromptPlan(user_prompt_order=(PromptComponent.AGENT_INSTRUCTIONS, PromptComponent.AGENT_PERSONA),system_prompt_order=(PromptComponent.QUESTION_INSTRUCTIONS, PromptComponent.PRIOR_QUESTION_MEMORY))
-     >>> p._is_valid_plan()
-     True
+ from edsl.agents.prompt_helpers import PromptComponent, PromptList, PromptPlan

-     >>> p.arrange_components(agent_instructions=1, agent_persona=2, question_instructions=3, prior_question_memory=4)
-     {'user_prompt': ..., 'system_prompt': ...}
-
-     >>> p = PromptPlan(user_prompt_order=("agent_instructions", ), system_prompt_order=("question_instructions", "prior_question_memory"))
-     Traceback (most recent call last):
-     ...
-     ValueError: Invalid plan: must contain each value of PromptComponent exactly once.

+ def get_jinja2_variables(template_str: str) -> Set[str]:
      """
+     Extracts all variable names from a Jinja2 template using Jinja2's built-in parsing.

-     def __init__(
-         self,
-         user_prompt_order: Optional[tuple] = None,
-         system_prompt_order: Optional[tuple] = None,
-     ):
-         """Initialize the PromptPlan."""
-
-         if user_prompt_order is None:
-             user_prompt_order = (
-                 PromptComponent.QUESTION_INSTRUCTIONS,
-                 PromptComponent.PRIOR_QUESTION_MEMORY,
-             )
-         if system_prompt_order is None:
-             system_prompt_order = (
-                 PromptComponent.AGENT_INSTRUCTIONS,
-                 PromptComponent.AGENT_PERSONA,
-             )
-
-         # very commmon way to screw this up given how python treats single strings as iterables
-         if isinstance(user_prompt_order, str):
-             user_prompt_order = (user_prompt_order,)
-
-         if isinstance(system_prompt_order, str):
-             system_prompt_order = (system_prompt_order,)
-
-         if not isinstance(user_prompt_order, tuple):
-             raise TypeError(
-                 f"Expected a tuple, but got {type(user_prompt_order).__name__}"
-             )
-
-         if not isinstance(system_prompt_order, tuple):
-             raise TypeError(
-                 f"Expected a tuple, but got {type(system_prompt_order).__name__}"
-             )
-
-         self.user_prompt_order = self._convert_to_enum(user_prompt_order)
-         self.system_prompt_order = self._convert_to_enum(system_prompt_order)
-         if not self._is_valid_plan():
-             raise ValueError(
-                 "Invalid plan: must contain each value of PromptComponent exactly once."
-             )
-
-     def _convert_to_enum(self, prompt_order: tuple):
-         """Convert string names to PromptComponent enum values."""
-         return tuple(
-             PromptComponent(component) if isinstance(component, str) else component
-             for component in prompt_order
-         )
-
-     def _is_valid_plan(self):
-         """Check if the plan is valid."""
-         combined = self.user_prompt_order + self.system_prompt_order
-         return set(combined) == set(PromptComponent)
-
-     def arrange_components(self, **kwargs) -> Dict[PromptComponent, Prompt]:
-         """Arrange the components in the order specified by the plan."""
-         # check is valid components passed
-         component_strings = set([pc.value for pc in PromptComponent])
-         if not set(kwargs.keys()) == component_strings:
-             raise ValueError(
-                 f"Invalid components passed: {set(kwargs.keys())} but expected {PromptComponent}"
-             )
-
-         user_prompt = PromptList(
-             [kwargs[component.value] for component in self.user_prompt_order]
-         )
-         system_prompt = PromptList(
-             [kwargs[component.value] for component in self.system_prompt_order]
-         )
-         return {"user_prompt": user_prompt, "system_prompt": system_prompt}
-
-     def get_prompts(self, **kwargs) -> Dict[str, Prompt]:
-         """Get both prompts for the LLM call."""
-         prompts = self.arrange_components(**kwargs)
-         return {
-             "user_prompt": prompts["user_prompt"].reduce(),
-             "system_prompt": prompts["system_prompt"].reduce(),
-         }
+     Args:
+         template_str (str): The Jinja2 template string

+     Returns:
+         Set[str]: A set of variable names found in the template
+     """
+     env = Environment()
+     ast = env.parse(template_str)
+     return meta.find_undeclared_variables(ast)

- class PromptConstructorMixin:
-     """Mixin for constructing prompts for the LLM call.

+ class PromptConstructor:
+     """
      The pieces of a prompt are:
      - The agent instructions - "You are answering questions as if you were a human. Do not break character."
      - The persona prompt - "You are an agent with the following persona: {'age': 22, 'hair': 'brown', 'height': 5.5}"
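The first hunk removes the prompt-plan plumbing (it reappears in prompt_helpers.py at the end of this diff) and adds the module-level helper get_jinja2_variables. The helper is self-contained; a small usage sketch, assuming only that jinja2 is installed:

from jinja2 import Environment, meta

def get_jinja2_variables(template_str):
    # mirror of the helper added above
    env = Environment()
    ast = env.parse(template_str)
    return meta.find_undeclared_variables(ast)

print(get_jinja2_variables("Do you like {{ topic }}, {{ name }}?"))
# {'topic', 'name'}  (a set, so order is not guaranteed)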
@@ -149,16 +39,42 @@ class PromptConstructorMixin:
      This is mixed into the Invigilator class.
      """

-     prompt_plan = PromptPlan()
+     def __init__(self, invigilator):
+         self.invigilator = invigilator
+         self.agent = invigilator.agent
+         self.question = invigilator.question
+         self.scenario = invigilator.scenario
+         self.survey = invigilator.survey
+         self.model = invigilator.model
+         self.current_answers = invigilator.current_answers
+         self.memory_plan = invigilator.memory_plan
+         self.prompt_plan = PromptPlan()
+
+     @property
+     def scenario_file_keys(self) -> list:
+         """We need to find all the keys in the scenario that refer to FileStore objects.
+         These will be used to append to the prompt a list of files that are part of the scenario.
+         """
+         from edsl.scenarios.FileStore import FileStore
+
+         file_entries = []
+         for key, value in self.scenario.items():
+             if isinstance(value, FileStore):
+                 file_entries.append(key)
+         return file_entries

      @property
      def agent_instructions_prompt(self) -> Prompt:
          """
          >>> from edsl.agents.InvigilatorBase import InvigilatorBase
          >>> i = InvigilatorBase.example()
-         >>> i.agent_instructions_prompt
+         >>> i.prompt_constructor.agent_instructions_prompt
          Prompt(text=\"""You are answering questions as if you were a human. Do not break character.\""")
          """
+         from edsl import Agent
+
+         if self.agent == Agent():  # if agent is empty, then return an empty prompt
+             return Prompt(text="")
          if not hasattr(self, "_agent_instructions_prompt"):
              applicable_prompts = prompt_lookup(
                  component_type="agent_instructions",
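The hunk above gives PromptConstructor its own __init__ (snapshotting the invigilator's agent, question, scenario, survey, model, answers, and memory plan), adds scenario_file_keys for FileStore-backed scenario entries, and short-circuits agent instructions for a default agent. A rough sketch of that guard, hedged because InvigilatorBase.example() normally carries a non-empty agent:

from edsl import Agent
from edsl.agents.InvigilatorBase import InvigilatorBase

i = InvigilatorBase.example()
pc = i.prompt_constructor
# New guard from the hunk above: a default, trait-less Agent() now yields an
# empty agent-instructions prompt rather than the stock instructions.
if pc.agent == Agent():
    assert pc.agent_instructions_prompt.text == ""
else:
    print(pc.agent_instructions_prompt)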
@@ -176,47 +92,56 @@ class PromptConstructorMixin:
          """
          >>> from edsl.agents.InvigilatorBase import InvigilatorBase
          >>> i = InvigilatorBase.example()
-         >>> i.agent_persona_prompt
+         >>> i.prompt_constructor.agent_persona_prompt
          Prompt(text=\"""You are an agent with the following persona:
          {'age': 22, 'hair': 'brown', 'height': 5.5}\""")

          """
-         if not hasattr(self, "_agent_persona_prompt"):
-             if not hasattr(self.agent, "agent_persona"):
-                 applicable_prompts = prompt_lookup(
-                     component_type="agent_persona",
-                     model=self.model.model,
-                 )
-                 persona_prompt_template = applicable_prompts[0]()
-             else:
-                 persona_prompt_template = self.agent.agent_persona
-
-             # TODO: This multiple passing of agent traits - not sure if it is necessary. Not harmful.
-             if undefined := persona_prompt_template.undefined_template_variables(
-                 self.agent.traits
-                 | {"traits": self.agent.traits}
-                 | {"codebook": self.agent.codebook}
-                 | {"traits": self.agent.traits}
-             ):
-                 raise QuestionScenarioRenderError(
-                     f"Agent persona still has variables that were not rendered: {undefined}"
-                 )
+         from edsl import Agent
+
+         if hasattr(self, "_agent_persona_prompt"):
+             return self._agent_persona_prompt

-             persona_prompt = persona_prompt_template.render(
-                 self.agent.traits | {"traits": self.agent.traits},
-                 codebook=self.agent.codebook,
-                 traits=self.agent.traits,
+         if self.agent == Agent():  # if agent is empty, then return an empty prompt
+             return Prompt(text="")
+
+         if not hasattr(self.agent, "agent_persona"):
+             applicable_prompts = prompt_lookup(
+                 component_type="agent_persona",
+                 model=self.model.model,
              )
-             if persona_prompt.has_variables:
-                 raise QuestionScenarioRenderError(
-                     "Agent persona still has variables that were not rendered."
-                 )
-             self._agent_persona_prompt = persona_prompt
+             persona_prompt_template = applicable_prompts[0]()
+         else:
+             persona_prompt_template = self.agent.agent_persona
+
+         # TODO: This multiple passing of agent traits - not sure if it is necessary. Not harmful.
+         template_parameter_dictionary = (
+             self.agent.traits
+             | {"traits": self.agent.traits}
+             | {"codebook": self.agent.codebook}
+             | {"traits": self.agent.traits}
+         )
+
+         if undefined := persona_prompt_template.undefined_template_variables(
+             template_parameter_dictionary
+         ):
+             raise QuestionScenarioRenderError(
+                 f"Agent persona still has variables that were not rendered: {undefined}"
+             )
+
+         persona_prompt = persona_prompt_template.render(template_parameter_dictionary)
+         if persona_prompt.has_variables:
+             raise QuestionScenarioRenderError(
+                 "Agent persona still has variables that were not rendered."
+             )
+
+         self._agent_persona_prompt = persona_prompt

          return self._agent_persona_prompt

      def prior_answers_dict(self) -> dict:
          d = self.survey.question_names_to_questions()
+         # This attaches the answer to the question
          for question, answer in self.current_answers.items():
              if question in d:
                  d[question].answer = answer
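The persona hunk collects the render arguments into a single template_parameter_dictionary instead of passing traits and codebook separately. The merge is ordinary dict union; a stand-alone illustration with made-up trait values:

traits = {"age": 22, "hair": "brown", "height": 5.5}
codebook = {}
template_parameter_dictionary = (
    traits | {"traits": traits} | {"codebook": codebook} | {"traits": traits}
)
# Individual traits are exposed as top-level keys, and the full dicts are also
# available under "traits" and "codebook" for templates that reference them.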
@@ -226,51 +151,70 @@ class PromptConstructorMixin:
                      d[new_question].comment = answer
          return d

+     @property
+     def question_file_keys(self):
+         raw_question_text = self.question.question_text
+         variables = get_jinja2_variables(raw_question_text)
+         question_file_keys = []
+         for var in variables:
+             if var in self.scenario_file_keys:
+                 question_file_keys.append(var)
+         return question_file_keys
+
      @property
      def question_instructions_prompt(self) -> Prompt:
          """
          >>> from edsl.agents.InvigilatorBase import InvigilatorBase
          >>> i = InvigilatorBase.example()
-         >>> i.question_instructions_prompt
+         >>> i.prompt_constructor.question_instructions_prompt
          Prompt(text=\"""...
          ...
          """
+         # The user might have passed a custom prompt, which would be stored in _question_instructions_prompt
          if not hasattr(self, "_question_instructions_prompt"):
+             # Gets the instructions for the question - this is how the question should be answered
              question_prompt = self.question.get_instructions(model=self.model.model)

-             # TODO: Try to populate the answers in the question object if they are available
-             # d = self.survey.question_names_to_questions()
-             # for question, answer in self.current_answers.items():
-             # if question in d:
-             # d[question].answer = answer
-             # else:
-             # # adds a comment to the question
-             # if (new_question := question.split("_comment")[0]) in d:
-             # d[new_question].comment = answer
-
+             # Get the data for the question - this is a dictionary of the question data
+             # e.g., {'question_text': 'Do you like school?', 'question_name': 'q0', 'question_options': ['yes', 'no']}
              question_data = self.question.data.copy()

              # check to see if the question_options is actually a string
-             # This is used when the user is using the question_options as a variable from a sceario
+             # This is used when the user is using the question_options as a variable from a scenario
              # if "question_options" in question_data:
              if isinstance(self.question.data.get("question_options", None), str):
-                 from jinja2 import Environment, meta
-
                  env = Environment()
                  parsed_content = env.parse(self.question.data["question_options"])
                  question_option_key = list(
                      meta.find_undeclared_variables(parsed_content)
                  )[0]

+                 # look to see if the question_option_key is in the scenario
                  if isinstance(
                      question_options := self.scenario.get(question_option_key), list
                  ):
                      question_data["question_options"] = question_options
                      self.question.question_options = question_options

+                 # might be getting it from the prior answers
+                 if self.prior_answers_dict().get(question_option_key) is not None:
+                     if isinstance(
+                         question_options := self.prior_answers_dict()
+                         .get(question_option_key)
+                         .answer,
+                         list,
+                     ):
+                         question_data["question_options"] = question_options
+                         self.question.question_options = question_options
+
              replacement_dict = (
-                 question_data
-                 | self.scenario
+                 {key: f"<see file {key}>" for key in self.scenario_file_keys}
+                 | question_data
+                 | {
+                     k: v
+                     for k, v in self.scenario.items()
+                     if k not in self.scenario_file_keys
+                 }  # don't include images in the replacement dict
                  | self.prior_answers_dict()
                  | {"agent": self.agent}
                  | {
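In the rebuilt replacement_dict above, any scenario key that names a FileStore is replaced by a "<see file key>" marker and excluded from the plain scenario merge, so file contents never leak into the rendered question text. A stand-alone illustration of that masking step with plain-dict stand-ins (not the real edsl objects):

scenario = {"resume": "<FileStore: resume.pdf>", "topic": "careers"}   # hypothetical values
scenario_file_keys = ["resume"]
question_data = {"question_text": "Summarize {{ resume }} with respect to {{ topic }}"}

replacement_dict = (
    {key: f"<see file {key}>" for key in scenario_file_keys}
    | question_data
    | {k: v for k, v in scenario.items() if k not in scenario_file_keys}
)
# replacement_dict["resume"] == "<see file resume>"; "topic" passes through unchanged
# because file keys are dropped from the last operand and so never overwrite the marker.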
@@ -280,9 +224,10 @@ class PromptConstructorMixin:
                      ),
                  }
              )
-             # breakpoint()
+
              rendered_instructions = question_prompt.render(replacement_dict)
-             # breakpoint()
+
+             # is there anything left to render?
              undefined_template_variables = (
                  rendered_instructions.undefined_template_variables({})
              )
@@ -296,11 +241,14 @@ class PromptConstructorMixin:
              )

              if undefined_template_variables:
+                 # breakpoint()
                  raise QuestionScenarioRenderError(
                      f"Question instructions still has variables: {undefined_template_variables}."
                  )

-             # Check if question has an instructions
+             ####################################
+             # Check if question has instructions - these are instructions in a Survey that can apply to multiple follow-on questions
+             ####################################
              relevant_instructions = self.survey.relevant_instructions(
                  self.question.question_name
              )
@@ -329,6 +277,23 @@ class PromptConstructorMixin:
              self._prior_question_memory_prompt = memory_prompt
          return self._prior_question_memory_prompt

+     def create_memory_prompt(self, question_name: str) -> Prompt:
+         """Create a memory for the agent.
+
+         The returns a memory prompt for the agent.
+
+         >>> from edsl.agents.InvigilatorBase import InvigilatorBase
+         >>> i = InvigilatorBase.example()
+         >>> i.current_answers = {"q0": "Prior answer"}
+         >>> i.memory_plan.add_single_memory("q1", "q0")
+         >>> p = i.prompt_constructor.create_memory_prompt("q1")
+         >>> p.text.strip().replace("\\n", " ").replace("\\t", " ")
+         'Before the question you are now answering, you already answered the following question(s): Question: Do you like school? Answer: Prior answer'
+         """
+         return self.memory_plan.get_memory_prompt_fragment(
+             question_name, self.current_answers
+         )
+
      def construct_system_prompt(self) -> Prompt:
          """Construct the system prompt for the LLM call."""
          import warnings
@@ -357,15 +322,18 @@ class PromptConstructorMixin:
          >>> i.get_prompts()
          {'user_prompt': ..., 'system_prompt': ...}
          """
+         # breakpoint()
          prompts = self.prompt_plan.get_prompts(
              agent_instructions=self.agent_instructions_prompt,
              agent_persona=self.agent_persona_prompt,
              question_instructions=self.question_instructions_prompt,
              prior_question_memory=self.prior_question_memory_prompt,
          )
-
-         if hasattr(self.scenario, "has_image") and self.scenario.has_image:
-             prompts["encoded_image"] = self.scenario["encoded_image"]
+         if self.question_file_keys:
+             files_list = []
+             for key in self.question_file_keys:
+                 files_list.append(self.scenario[key])
+             prompts["files_list"] = files_list
          return prompts

      def _get_scenario_with_image(self) -> Scenario:
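get_prompts() now replaces the old encoded_image entry with a files_list key holding the FileStore objects referenced by the question text (see the removed has_image branch above). The hunk that follows is the brand-new edsl/agents/prompt_helpers.py from the file list (entry 8). A rough sketch of consuming the new key, assuming the example invigilator from the doctests:

prompts = i.prompt_constructor.get_prompts()
user_prompt = prompts["user_prompt"]      # reduced Prompt for the user turn
system_prompt = prompts["system_prompt"]  # reduced Prompt for the system turn
files = prompts.get("files_list", [])     # FileStore objects, present only when the
                                          # question text references a scenario file key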
@@ -0,0 +1,129 @@
+ import enum
+ from typing import Dict, Optional
+ from collections import UserList
+ from edsl.prompts import Prompt
+
+
+ class PromptComponent(enum.Enum):
+     AGENT_INSTRUCTIONS = "agent_instructions"
+     AGENT_PERSONA = "agent_persona"
+     QUESTION_INSTRUCTIONS = "question_instructions"
+     PRIOR_QUESTION_MEMORY = "prior_question_memory"
+
+
+ class PromptList(UserList):
+     separator = Prompt(" ")
+
+     def reduce(self):
+         """Reduce the list of prompts to a single prompt.
+
+         >>> p = PromptList([Prompt("You are a happy-go lucky agent."), Prompt("You are an agent with the following persona: {'age': 22, 'hair': 'brown', 'height': 5.5}")])
+         >>> p.reduce()
+         Prompt(text=\"""You are a happy-go lucky agent. You are an agent with the following persona: {'age': 22, 'hair': 'brown', 'height': 5.5}\""")
+
+         """
+         p = self[0]
+         for prompt in self[1:]:
+             if len(prompt) > 0:
+                 p = p + self.separator + prompt
+         return p
+
+
+ class PromptPlan:
+     """A plan for constructing prompts for the LLM call.
+     Every prompt plan has a user prompt order and a system prompt order.
+     It must contain each of the values in the PromptComponent enum.
+
+
+     >>> p = PromptPlan(user_prompt_order=(PromptComponent.AGENT_INSTRUCTIONS, PromptComponent.AGENT_PERSONA),system_prompt_order=(PromptComponent.QUESTION_INSTRUCTIONS, PromptComponent.PRIOR_QUESTION_MEMORY))
+     >>> p._is_valid_plan()
+     True
+
+     >>> p.arrange_components(agent_instructions=1, agent_persona=2, question_instructions=3, prior_question_memory=4)
+     {'user_prompt': ..., 'system_prompt': ...}
+
+     >>> p = PromptPlan(user_prompt_order=("agent_instructions", ), system_prompt_order=("question_instructions", "prior_question_memory"))
+     Traceback (most recent call last):
+     ...
+     ValueError: Invalid plan: must contain each value of PromptComponent exactly once.
+
+     """
+
+     def __init__(
+         self,
+         user_prompt_order: Optional[tuple] = None,
+         system_prompt_order: Optional[tuple] = None,
+     ):
+         """Initialize the PromptPlan."""
+
+         if user_prompt_order is None:
+             user_prompt_order = (
+                 PromptComponent.QUESTION_INSTRUCTIONS,
+                 PromptComponent.PRIOR_QUESTION_MEMORY,
+             )
+         if system_prompt_order is None:
+             system_prompt_order = (
+                 PromptComponent.AGENT_INSTRUCTIONS,
+                 PromptComponent.AGENT_PERSONA,
+             )
+
+         # very commmon way to screw this up given how python treats single strings as iterables
+         if isinstance(user_prompt_order, str):
+             user_prompt_order = (user_prompt_order,)
+
+         if isinstance(system_prompt_order, str):
+             system_prompt_order = (system_prompt_order,)
+
+         if not isinstance(user_prompt_order, tuple):
+             raise TypeError(
+                 f"Expected a tuple, but got {type(user_prompt_order).__name__}"
+             )
+
+         if not isinstance(system_prompt_order, tuple):
+             raise TypeError(
+                 f"Expected a tuple, but got {type(system_prompt_order).__name__}"
+             )
+
+         self.user_prompt_order = self._convert_to_enum(user_prompt_order)
+         self.system_prompt_order = self._convert_to_enum(system_prompt_order)
+         if not self._is_valid_plan():
+             raise ValueError(
+                 "Invalid plan: must contain each value of PromptComponent exactly once."
+             )
+
+     def _convert_to_enum(self, prompt_order: tuple):
+         """Convert string names to PromptComponent enum values."""
+         return tuple(
+             PromptComponent(component) if isinstance(component, str) else component
+             for component in prompt_order
+         )
+
+     def _is_valid_plan(self):
+         """Check if the plan is valid."""
+         combined = self.user_prompt_order + self.system_prompt_order
+         return set(combined) == set(PromptComponent)
+
+     def arrange_components(self, **kwargs) -> Dict[PromptComponent, Prompt]:
+         """Arrange the components in the order specified by the plan."""
+         # check is valid components passed
+         component_strings = set([pc.value for pc in PromptComponent])
+         if not set(kwargs.keys()) == component_strings:
+             raise ValueError(
+                 f"Invalid components passed: {set(kwargs.keys())} but expected {PromptComponent}"
+             )
+
+         user_prompt = PromptList(
+             [kwargs[component.value] for component in self.user_prompt_order]
+         )
+         system_prompt = PromptList(
+             [kwargs[component.value] for component in self.system_prompt_order]
+         )
+         return {"user_prompt": user_prompt, "system_prompt": system_prompt}
+
+     def get_prompts(self, **kwargs) -> Dict[str, Prompt]:
+         """Get both prompts for the LLM call."""
+         prompts = self.arrange_components(**kwargs)
+         return {
+             "user_prompt": prompts["user_prompt"].reduce(),
+             "system_prompt": prompts["system_prompt"].reduce(),
+         }
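prompt_helpers.py carries PromptComponent, PromptList, and PromptPlan over essentially unchanged; the doctests above already define the contract. A compact usage sketch, assuming the default ordering the constructor sets up:

from edsl.agents.prompt_helpers import PromptPlan
from edsl.prompts import Prompt

plan = PromptPlan()  # defaults: question + memory -> user, instructions + persona -> system
prompts = plan.get_prompts(
    agent_instructions=Prompt("You are answering as a human."),
    agent_persona=Prompt("Persona: {'age': 22}"),
    question_instructions=Prompt("Answer yes or no."),
    prior_question_memory=Prompt(""),
)
prompts["system_prompt"]  # instructions + persona, joined with a single space
prompts["user_prompt"]    # question instructions (memory is skipped when empty)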