edsl 0.1.31.dev4__py3-none-any.whl → 0.1.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188) hide show
  1. edsl/Base.py +9 -3
  2. edsl/TemplateLoader.py +24 -0
  3. edsl/__init__.py +8 -3
  4. edsl/__version__.py +1 -1
  5. edsl/agents/Agent.py +40 -8
  6. edsl/agents/AgentList.py +43 -0
  7. edsl/agents/Invigilator.py +136 -221
  8. edsl/agents/InvigilatorBase.py +148 -59
  9. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +154 -85
  10. edsl/agents/__init__.py +1 -0
  11. edsl/auto/AutoStudy.py +117 -0
  12. edsl/auto/StageBase.py +230 -0
  13. edsl/auto/StageGenerateSurvey.py +178 -0
  14. edsl/auto/StageLabelQuestions.py +125 -0
  15. edsl/auto/StagePersona.py +61 -0
  16. edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
  17. edsl/auto/StagePersonaDimensionValues.py +74 -0
  18. edsl/auto/StagePersonaDimensions.py +69 -0
  19. edsl/auto/StageQuestions.py +73 -0
  20. edsl/auto/SurveyCreatorPipeline.py +21 -0
  21. edsl/auto/utilities.py +224 -0
  22. edsl/config.py +48 -47
  23. edsl/conjure/Conjure.py +6 -0
  24. edsl/coop/PriceFetcher.py +58 -0
  25. edsl/coop/coop.py +50 -7
  26. edsl/data/Cache.py +35 -1
  27. edsl/data/CacheHandler.py +3 -4
  28. edsl/data_transfer_models.py +73 -38
  29. edsl/enums.py +8 -0
  30. edsl/exceptions/general.py +10 -8
  31. edsl/exceptions/language_models.py +25 -1
  32. edsl/exceptions/questions.py +62 -5
  33. edsl/exceptions/results.py +4 -0
  34. edsl/inference_services/AnthropicService.py +13 -11
  35. edsl/inference_services/AwsBedrock.py +112 -0
  36. edsl/inference_services/AzureAI.py +214 -0
  37. edsl/inference_services/DeepInfraService.py +4 -3
  38. edsl/inference_services/GoogleService.py +16 -12
  39. edsl/inference_services/GroqService.py +5 -4
  40. edsl/inference_services/InferenceServiceABC.py +58 -3
  41. edsl/inference_services/InferenceServicesCollection.py +13 -8
  42. edsl/inference_services/MistralAIService.py +120 -0
  43. edsl/inference_services/OllamaService.py +18 -0
  44. edsl/inference_services/OpenAIService.py +55 -56
  45. edsl/inference_services/TestService.py +80 -0
  46. edsl/inference_services/TogetherAIService.py +170 -0
  47. edsl/inference_services/models_available_cache.py +25 -0
  48. edsl/inference_services/registry.py +19 -1
  49. edsl/jobs/Answers.py +10 -12
  50. edsl/jobs/FailedQuestion.py +78 -0
  51. edsl/jobs/Jobs.py +137 -41
  52. edsl/jobs/buckets/BucketCollection.py +24 -15
  53. edsl/jobs/buckets/TokenBucket.py +105 -18
  54. edsl/jobs/interviews/Interview.py +393 -83
  55. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +22 -18
  56. edsl/jobs/interviews/InterviewExceptionEntry.py +167 -0
  57. edsl/jobs/runners/JobsRunnerAsyncio.py +152 -160
  58. edsl/jobs/runners/JobsRunnerStatus.py +331 -0
  59. edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
  60. edsl/jobs/tasks/TaskCreators.py +1 -1
  61. edsl/jobs/tasks/TaskHistory.py +205 -126
  62. edsl/language_models/LanguageModel.py +297 -177
  63. edsl/language_models/ModelList.py +2 -2
  64. edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
  65. edsl/language_models/fake_openai_call.py +15 -0
  66. edsl/language_models/fake_openai_service.py +61 -0
  67. edsl/language_models/registry.py +25 -8
  68. edsl/language_models/repair.py +0 -19
  69. edsl/language_models/utilities.py +61 -0
  70. edsl/notebooks/Notebook.py +20 -2
  71. edsl/prompts/Prompt.py +52 -2
  72. edsl/questions/AnswerValidatorMixin.py +23 -26
  73. edsl/questions/QuestionBase.py +330 -249
  74. edsl/questions/QuestionBaseGenMixin.py +133 -0
  75. edsl/questions/QuestionBasePromptsMixin.py +266 -0
  76. edsl/questions/QuestionBudget.py +99 -42
  77. edsl/questions/QuestionCheckBox.py +227 -36
  78. edsl/questions/QuestionExtract.py +98 -28
  79. edsl/questions/QuestionFreeText.py +47 -31
  80. edsl/questions/QuestionFunctional.py +7 -0
  81. edsl/questions/QuestionList.py +141 -23
  82. edsl/questions/QuestionMultipleChoice.py +159 -66
  83. edsl/questions/QuestionNumerical.py +88 -47
  84. edsl/questions/QuestionRank.py +182 -25
  85. edsl/questions/Quick.py +41 -0
  86. edsl/questions/RegisterQuestionsMeta.py +31 -12
  87. edsl/questions/ResponseValidatorABC.py +170 -0
  88. edsl/questions/__init__.py +3 -4
  89. edsl/questions/decorators.py +21 -0
  90. edsl/questions/derived/QuestionLikertFive.py +10 -5
  91. edsl/questions/derived/QuestionLinearScale.py +15 -2
  92. edsl/questions/derived/QuestionTopK.py +10 -1
  93. edsl/questions/derived/QuestionYesNo.py +24 -3
  94. edsl/questions/descriptors.py +43 -7
  95. edsl/questions/prompt_templates/question_budget.jinja +13 -0
  96. edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
  97. edsl/questions/prompt_templates/question_extract.jinja +11 -0
  98. edsl/questions/prompt_templates/question_free_text.jinja +3 -0
  99. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
  100. edsl/questions/prompt_templates/question_list.jinja +17 -0
  101. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
  102. edsl/questions/prompt_templates/question_numerical.jinja +37 -0
  103. edsl/questions/question_registry.py +6 -2
  104. edsl/questions/templates/__init__.py +0 -0
  105. edsl/questions/templates/budget/__init__.py +0 -0
  106. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  107. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  108. edsl/questions/templates/checkbox/__init__.py +0 -0
  109. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
  110. edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
  111. edsl/questions/templates/extract/__init__.py +0 -0
  112. edsl/questions/templates/extract/answering_instructions.jinja +7 -0
  113. edsl/questions/templates/extract/question_presentation.jinja +1 -0
  114. edsl/questions/templates/free_text/__init__.py +0 -0
  115. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  116. edsl/questions/templates/free_text/question_presentation.jinja +1 -0
  117. edsl/questions/templates/likert_five/__init__.py +0 -0
  118. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
  119. edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
  120. edsl/questions/templates/linear_scale/__init__.py +0 -0
  121. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
  122. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
  123. edsl/questions/templates/list/__init__.py +0 -0
  124. edsl/questions/templates/list/answering_instructions.jinja +4 -0
  125. edsl/questions/templates/list/question_presentation.jinja +5 -0
  126. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  127. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
  128. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  129. edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
  130. edsl/questions/templates/numerical/__init__.py +0 -0
  131. edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
  132. edsl/questions/templates/numerical/question_presentation.jinja +7 -0
  133. edsl/questions/templates/rank/__init__.py +0 -0
  134. edsl/questions/templates/rank/answering_instructions.jinja +11 -0
  135. edsl/questions/templates/rank/question_presentation.jinja +15 -0
  136. edsl/questions/templates/top_k/__init__.py +0 -0
  137. edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
  138. edsl/questions/templates/top_k/question_presentation.jinja +22 -0
  139. edsl/questions/templates/yes_no/__init__.py +0 -0
  140. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
  141. edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
  142. edsl/results/Dataset.py +20 -0
  143. edsl/results/DatasetExportMixin.py +58 -30
  144. edsl/results/DatasetTree.py +145 -0
  145. edsl/results/Result.py +32 -5
  146. edsl/results/Results.py +135 -46
  147. edsl/results/ResultsDBMixin.py +3 -3
  148. edsl/results/Selector.py +118 -0
  149. edsl/results/tree_explore.py +115 -0
  150. edsl/scenarios/FileStore.py +71 -10
  151. edsl/scenarios/Scenario.py +109 -24
  152. edsl/scenarios/ScenarioImageMixin.py +2 -2
  153. edsl/scenarios/ScenarioList.py +546 -21
  154. edsl/scenarios/ScenarioListExportMixin.py +24 -4
  155. edsl/scenarios/ScenarioListPdfMixin.py +153 -4
  156. edsl/study/SnapShot.py +8 -1
  157. edsl/study/Study.py +32 -0
  158. edsl/surveys/Rule.py +15 -3
  159. edsl/surveys/RuleCollection.py +21 -5
  160. edsl/surveys/Survey.py +707 -298
  161. edsl/surveys/SurveyExportMixin.py +71 -9
  162. edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
  163. edsl/surveys/SurveyQualtricsImport.py +284 -0
  164. edsl/surveys/instructions/ChangeInstruction.py +47 -0
  165. edsl/surveys/instructions/Instruction.py +34 -0
  166. edsl/surveys/instructions/InstructionCollection.py +77 -0
  167. edsl/surveys/instructions/__init__.py +0 -0
  168. edsl/templates/error_reporting/base.html +24 -0
  169. edsl/templates/error_reporting/exceptions_by_model.html +35 -0
  170. edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
  171. edsl/templates/error_reporting/exceptions_by_type.html +17 -0
  172. edsl/templates/error_reporting/interview_details.html +116 -0
  173. edsl/templates/error_reporting/interviews.html +10 -0
  174. edsl/templates/error_reporting/overview.html +5 -0
  175. edsl/templates/error_reporting/performance_plot.html +2 -0
  176. edsl/templates/error_reporting/report.css +74 -0
  177. edsl/templates/error_reporting/report.html +118 -0
  178. edsl/templates/error_reporting/report.js +25 -0
  179. edsl/utilities/utilities.py +40 -1
  180. {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/METADATA +8 -2
  181. edsl-0.1.33.dist-info/RECORD +295 -0
  182. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -271
  183. edsl/jobs/interviews/retry_management.py +0 -37
  184. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -303
  185. edsl/utilities/gcp_bucket/simple_example.py +0 -9
  186. edsl-0.1.31.dev4.dist-info/RECORD +0 -204
  187. {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/LICENSE +0 -0
  188. {edsl-0.1.31.dev4.dist-info → edsl-0.1.33.dist-info}/WHEEL +0 -0
@@ -8,34 +8,27 @@ from edsl.data_transfer_models import AgentResponseDict
8
8
 
9
9
  from edsl.data.Cache import Cache
10
10
 
11
- # from edsl.agents.Agent import Agent
12
11
  from edsl.questions.QuestionBase import QuestionBase
13
12
  from edsl.scenarios.Scenario import Scenario
14
13
  from edsl.surveys.MemoryPlan import MemoryPlan
15
14
  from edsl.language_models.LanguageModel import LanguageModel
16
15
 
16
+ from edsl.data_transfer_models import EDSLResultObjectInput
17
+ from edsl.agents.PromptConstructor import PromptConstructor
18
+
17
19
 
18
20
  class InvigilatorBase(ABC):
19
21
  """An invigiator (someone who administers an exam) is a class that is responsible for administering a question to an agent.
20
22
 
21
23
  >>> InvigilatorBase.example().answer_question()
22
- {'message': '{"answer": "SPAM!"}'}
24
+ {'message': [{'text': 'SPAM!'}], 'usage': {'prompt_tokens': 1, 'completion_tokens': 1}}
23
25
 
24
- >>> InvigilatorBase.example().get_failed_task_result()
25
- {'answer': None, 'comment': 'Failed to get response', ...
26
+ >>> InvigilatorBase.example().get_failed_task_result(failure_reason="Failed to get response").comment
27
+ 'Failed to get response'
26
28
 
27
29
  This returns an empty prompt because there is no memory the agent needs to have at q0.
28
30
 
29
- >>> InvigilatorBase.example().create_memory_prompt("q0")
30
- Prompt(text=\"""\""")
31
31
 
32
- >>> i = InvigilatorBase.example()
33
- >>> i.current_answers = {"q0": "Prior answer"}
34
- >>> i.memory_plan.add_single_memory("q1", "q0")
35
- >>> i.create_memory_prompt("q1")
36
- Prompt(text=\"""
37
- Before the question you are now answering, you already answered the following question(s):
38
- ...
39
32
  """
40
33
 
41
34
  def __init__(
@@ -51,6 +44,7 @@ class InvigilatorBase(ABC):
51
44
  iteration: Optional[int] = 1,
52
45
  additional_prompt_data: Optional[dict] = None,
53
46
  sidecar_model: Optional[LanguageModel] = None,
47
+ raise_validation_errors: Optional[bool] = True,
54
48
  ):
55
49
  """Initialize a new Invigilator."""
56
50
  self.agent = agent
@@ -64,6 +58,78 @@ class InvigilatorBase(ABC):
64
58
  self.cache = cache
65
59
  self.sidecar_model = sidecar_model
66
60
  self.survey = survey
61
+ self.raise_validation_errors = raise_validation_errors
62
+
63
+ self.raw_model_response = (
64
+ None # placeholder for the raw response from the model
65
+ )
66
+
67
+ @property
68
+ def prompt_constructor(self) -> PromptConstructor:
69
+ """Return the prompt constructor."""
70
+ return PromptConstructor(self)
71
+
72
+ def to_dict(self):
73
+ attributes = [
74
+ "agent",
75
+ "question",
76
+ "scenario",
77
+ "model",
78
+ "memory_plan",
79
+ "current_answers",
80
+ "iteration",
81
+ "additional_prompt_data",
82
+ "cache",
83
+ "sidecar_model",
84
+ "survey",
85
+ ]
86
+
87
+ def serialize_attribute(attr):
88
+ value = getattr(self, attr)
89
+ if value is None:
90
+ return None
91
+ if hasattr(value, "to_dict"):
92
+ return value.to_dict()
93
+ if isinstance(value, (int, float, str, bool, dict, list)):
94
+ return value
95
+ return str(value)
96
+
97
+ return {attr: serialize_attribute(attr) for attr in attributes}
98
+
99
+ @classmethod
100
+ def from_dict(cls, data):
101
+ from edsl.agents.Agent import Agent
102
+ from edsl.questions import QuestionBase
103
+ from edsl.scenarios.Scenario import Scenario
104
+ from edsl.surveys.MemoryPlan import MemoryPlan
105
+ from edsl.language_models.LanguageModel import LanguageModel
106
+ from edsl.surveys.Survey import Survey
107
+
108
+ agent = Agent.from_dict(data["agent"])
109
+ question = QuestionBase.from_dict(data["question"])
110
+ scenario = Scenario.from_dict(data["scenario"])
111
+ model = LanguageModel.from_dict(data["model"])
112
+ memory_plan = MemoryPlan.from_dict(data["memory_plan"])
113
+ survey = Survey.from_dict(data["survey"])
114
+ current_answers = data["current_answers"]
115
+ iteration = data["iteration"]
116
+ additional_prompt_data = data["additional_prompt_data"]
117
+ cache = Cache.from_dict(data["cache"])
118
+ sidecar_model = LanguageModel.from_dict(data["sidecar_model"])
119
+
120
+ return cls(
121
+ agent=agent,
122
+ question=question,
123
+ scenario=scenario,
124
+ model=model,
125
+ memory_plan=memory_plan,
126
+ current_answers=current_answers,
127
+ survey=survey,
128
+ iteration=iteration,
129
+ additional_prompt_data=additional_prompt_data,
130
+ cache=cache,
131
+ sidecar_model=sidecar_model,
132
+ )
67
133
 
68
134
  def __repr__(self) -> str:
69
135
  """Return a string representation of the Invigilator.
@@ -74,18 +140,45 @@ class InvigilatorBase(ABC):
74
140
  """
75
141
  return f"{self.__class__.__name__}(agent={repr(self.agent)}, question={repr(self.question)}, scneario={repr(self.scenario)}, model={repr(self.model)}, memory_plan={repr(self.memory_plan)}, current_answers={repr(self.current_answers)}, iteration{repr(self.iteration)}, additional_prompt_data={repr(self.additional_prompt_data)}, cache={repr(self.cache)}, sidecarmodel={repr(self.sidecar_model)})"
76
142
 
77
- def get_failed_task_result(self) -> AgentResponseDict:
143
+ def get_failed_task_result(self, failure_reason) -> EDSLResultObjectInput:
78
144
  """Return an AgentResponseDict used in case the question-asking fails.
79
145
 
80
- >>> InvigilatorBase.example().get_failed_task_result()
81
- {'answer': None, 'comment': 'Failed to get response', ...}
146
+ Possible reasons include:
147
+ - Legimately skipped because of skip logic
148
+ - Failed to get response from the model
149
+
82
150
  """
83
- return AgentResponseDict(
84
- answer=None,
85
- comment="Failed to get response",
86
- question_name=self.question.question_name,
87
- prompts=self.get_prompts(),
88
- )
151
+ data = {
152
+ "answer": None,
153
+ "generated_tokens": None,
154
+ "comment": failure_reason,
155
+ "question_name": self.question.question_name,
156
+ "prompts": self.get_prompts(),
157
+ "cached_response": None,
158
+ "raw_model_response": None,
159
+ "cache_used": None,
160
+ "cache_key": None,
161
+ }
162
+ return EDSLResultObjectInput(**data)
163
+
164
+ # breakpoint()
165
+ # if hasattr(self, "augmented_model_response"):
166
+ # import json
167
+
168
+ # generated_tokens = json.loads(self.augmented_model_response)["answer"][
169
+ # "generated_tokens"
170
+ # ]
171
+ # else:
172
+ # generated_tokens = "Filled in by InvigilatorBase.get_failed_task_result"
173
+ # agent_response_dict = AgentResponseDict(
174
+ # answer=None,
175
+ # comment="Failed to get usable response",
176
+ # generated_tokens=generated_tokens,
177
+ # question_name=self.question.question_name,
178
+ # prompts=self.get_prompts(),
179
+ # )
180
+ # # breakpoint()
181
+ # return agent_response_dict
89
182
 
90
183
  def get_prompts(self) -> Dict[str, Prompt]:
91
184
  """Return the prompt used."""
@@ -111,24 +204,10 @@ class InvigilatorBase(ABC):
111
204
 
112
205
  return main()
113
206
 
114
- def create_memory_prompt(self, question_name: str) -> Prompt:
115
- """Create a memory for the agent.
116
-
117
- The returns a memory prompt for the agent.
118
-
119
- >>> i = InvigilatorBase.example()
120
- >>> i.current_answers = {"q0": "Prior answer"}
121
- >>> i.memory_plan.add_single_memory("q1", "q0")
122
- >>> p = i.create_memory_prompt("q1")
123
- >>> p.text.strip().replace("\\n", " ").replace("\\t", " ")
124
- 'Before the question you are now answering, you already answered the following question(s): Question: Do you like school? Answer: Prior answer'
125
- """
126
- return self.memory_plan.get_memory_prompt_fragment(
127
- question_name, self.current_answers
128
- )
129
-
130
207
  @classmethod
131
- def example(cls, throw_an_exception=False, question=None, scenario=None):
208
+ def example(
209
+ cls, throw_an_exception=False, question=None, scenario=None, survey=None
210
+ ) -> "InvigilatorBase":
132
211
  """Return an example invigilator.
133
212
 
134
213
  >>> InvigilatorBase.example()
@@ -143,43 +222,53 @@ class InvigilatorBase(ABC):
143
222
 
144
223
  from edsl.enums import InferenceServiceType
145
224
 
146
- class TestLanguageModelGood(LanguageModel):
147
- """A test language model."""
225
+ from edsl import Model
226
+
227
+ model = Model("test", canned_response="SPAM!")
228
+ # class TestLanguageModelGood(LanguageModel):
229
+ # """A test language model."""
148
230
 
149
- _model_ = "test"
150
- _parameters_ = {"temperature": 0.5}
151
- _inference_service_ = InferenceServiceType.TEST.value
231
+ # _model_ = "test"
232
+ # _parameters_ = {"temperature": 0.5}
233
+ # _inference_service_ = InferenceServiceType.TEST.value
152
234
 
153
- async def async_execute_model_call(
154
- self, user_prompt: str, system_prompt: str
155
- ) -> dict[str, Any]:
156
- await asyncio.sleep(0.1)
157
- if hasattr(self, "throw_an_exception"):
158
- raise Exception("Error!")
159
- return {"message": """{"answer": "SPAM!"}"""}
235
+ # async def async_execute_model_call(
236
+ # self, user_prompt: str, system_prompt: str
237
+ # ) -> dict[str, Any]:
238
+ # await asyncio.sleep(0.1)
239
+ # if hasattr(self, "throw_an_exception"):
240
+ # raise Exception("Error!")
241
+ # return {"message": """{"answer": "SPAM!"}"""}
160
242
 
161
- def parse_response(self, raw_response: dict[str, Any]) -> str:
162
- """Parse the response from the model."""
163
- return raw_response["message"]
243
+ # def parse_response(self, raw_response: dict[str, Any]) -> str:
244
+ # """Parse the response from the model."""
245
+ # return raw_response["message"]
164
246
 
165
- model = TestLanguageModelGood()
166
247
  if throw_an_exception:
167
248
  model.throw_an_exception = True
168
249
  agent = Agent.example()
169
250
  # question = QuestionMultipleChoice.example()
170
251
  from edsl.surveys import Survey
171
252
 
172
- survey = Survey.example()
253
+ if not survey:
254
+ survey = Survey.example()
255
+ # if question:
256
+ # need to have the focal question name in the list of names
257
+ # survey._questions[0].question_name = question.question_name
258
+ # survey.add_question(question)
259
+ if question:
260
+ survey.add_question(question)
261
+
173
262
  question = question or survey.questions[0]
174
263
  scenario = scenario or Scenario.example()
175
264
  # memory_plan = None #memory_plan = MemoryPlan()
176
265
  from edsl import Survey
177
266
 
178
- memory_plan = MemoryPlan(survey=Survey.example())
267
+ memory_plan = MemoryPlan(survey=survey)
179
268
  current_answers = None
180
- from edsl.agents.PromptConstructionMixin import PromptConstructorMixin
269
+ from edsl.agents.PromptConstructor import PromptConstructor
181
270
 
182
- class InvigilatorExample(PromptConstructorMixin, InvigilatorBase):
271
+ class InvigilatorExample(InvigilatorBase):
183
272
  """An example invigilator."""
184
273
 
185
274
  async def async_answer_question(self):
@@ -1,15 +1,15 @@
1
- from typing import Dict, Any, Optional
1
+ from __future__ import annotations
2
+ from typing import Dict, Any, Optional, Set
2
3
  from collections import UserList
4
+ import enum
3
5
 
4
- # from functools import reduce
5
- from edsl.prompts.Prompt import Prompt
6
+ from jinja2 import Environment, meta
6
7
 
7
- # from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
8
+ from edsl.prompts.Prompt import Prompt
9
+ from edsl.data_transfer_models import ImageInfo
8
10
  from edsl.prompts.registry import get_classes as prompt_lookup
9
11
  from edsl.exceptions import QuestionScenarioRenderError
10
12
 
11
- import enum
12
-
13
13
 
14
14
  class PromptComponent(enum.Enum):
15
15
  AGENT_INSTRUCTIONS = "agent_instructions"
@@ -18,6 +18,21 @@ class PromptComponent(enum.Enum):
18
18
  PRIOR_QUESTION_MEMORY = "prior_question_memory"
19
19
 
20
20
 
21
+ def get_jinja2_variables(template_str: str) -> Set[str]:
22
+ """
23
+ Extracts all variable names from a Jinja2 template using Jinja2's built-in parsing.
24
+
25
+ Args:
26
+ template_str (str): The Jinja2 template string
27
+
28
+ Returns:
29
+ Set[str]: A set of variable names found in the template
30
+ """
31
+ env = Environment()
32
+ ast = env.parse(template_str)
33
+ return meta.find_undeclared_variables(ast)
34
+
35
+
21
36
  class PromptList(UserList):
22
37
  separator = Prompt(" ")
23
38
 
@@ -136,7 +151,7 @@ class PromptPlan:
136
151
  }
137
152
 
138
153
 
139
- class PromptConstructorMixin:
154
+ class PromptConstructor:
140
155
  """Mixin for constructing prompts for the LLM call.
141
156
 
142
157
  The pieces of a prompt are:
@@ -148,16 +163,40 @@ class PromptConstructorMixin:
148
163
  This is mixed into the Invigilator class.
149
164
  """
150
165
 
151
- prompt_plan = PromptPlan()
166
+ def __init__(self, invigilator):
167
+ self.invigilator = invigilator
168
+ self.agent = invigilator.agent
169
+ self.question = invigilator.question
170
+ self.scenario = invigilator.scenario
171
+ self.survey = invigilator.survey
172
+ self.model = invigilator.model
173
+ self.current_answers = invigilator.current_answers
174
+ self.memory_plan = invigilator.memory_plan
175
+ self.prompt_plan = PromptPlan() # Assuming PromptPlan is defined elsewhere
176
+
177
+ # prompt_plan = PromptPlan()
178
+
179
+ @property
180
+ def scenario_image_keys(self):
181
+ image_entries = []
182
+
183
+ for key, value in self.scenario.items():
184
+ if isinstance(value, ImageInfo):
185
+ image_entries.append(key)
186
+ return image_entries
152
187
 
153
188
  @property
154
189
  def agent_instructions_prompt(self) -> Prompt:
155
190
  """
156
191
  >>> from edsl.agents.InvigilatorBase import InvigilatorBase
157
192
  >>> i = InvigilatorBase.example()
158
- >>> i.agent_instructions_prompt
193
+ >>> i.prompt_constructor.agent_instructions_prompt
159
194
  Prompt(text=\"""You are answering questions as if you were a human. Do not break character.\""")
160
195
  """
196
+ from edsl import Agent
197
+
198
+ if self.agent == Agent(): # if agent is empty, then return an empty prompt
199
+ return Prompt(text="")
161
200
  if not hasattr(self, "_agent_instructions_prompt"):
162
201
  applicable_prompts = prompt_lookup(
163
202
  component_type="agent_instructions",
@@ -175,12 +214,17 @@ class PromptConstructorMixin:
175
214
  """
176
215
  >>> from edsl.agents.InvigilatorBase import InvigilatorBase
177
216
  >>> i = InvigilatorBase.example()
178
- >>> i.agent_persona_prompt
217
+ >>> i.prompt_constructor.agent_persona_prompt
179
218
  Prompt(text=\"""You are an agent with the following persona:
180
219
  {'age': 22, 'hair': 'brown', 'height': 5.5}\""")
181
220
 
182
221
  """
183
222
  if not hasattr(self, "_agent_persona_prompt"):
223
+ from edsl import Agent
224
+
225
+ if self.agent == Agent(): # if agent is empty, then return an empty prompt
226
+ return Prompt(text="")
227
+
184
228
  if not hasattr(self.agent, "agent_persona"):
185
229
  applicable_prompts = prompt_lookup(
186
230
  component_type="agent_persona",
@@ -214,83 +258,80 @@ class PromptConstructorMixin:
214
258
 
215
259
  return self._agent_persona_prompt
216
260
 
261
+ def prior_answers_dict(self) -> dict:
262
+ d = self.survey.question_names_to_questions()
263
+ for question, answer in self.current_answers.items():
264
+ if question in d:
265
+ d[question].answer = answer
266
+ else:
267
+ # adds a comment to the question
268
+ if (new_question := question.split("_comment")[0]) in d:
269
+ d[new_question].comment = answer
270
+ return d
271
+
272
+ @property
273
+ def question_image_keys(self):
274
+ raw_question_text = self.question.question_text
275
+ variables = get_jinja2_variables(raw_question_text)
276
+ question_image_keys = []
277
+ for var in variables:
278
+ if var in self.scenario_image_keys:
279
+ question_image_keys.append(var)
280
+ return question_image_keys
281
+
217
282
  @property
218
283
  def question_instructions_prompt(self) -> Prompt:
219
284
  """
220
285
  >>> from edsl.agents.InvigilatorBase import InvigilatorBase
221
286
  >>> i = InvigilatorBase.example()
222
- >>> i.question_instructions_prompt
223
- Prompt(text=\"""You are being asked the following question: Do you like school?
224
- The options are
225
- <BLANKLINE>
226
- 0: yes
227
- <BLANKLINE>
228
- 1: no
229
- <BLANKLINE>
230
- Return a valid JSON formatted like this, selecting only the number of the option:
231
- {"answer": <put answer code here>, "comment": "<put explanation here>"}
232
- Only 1 option may be selected.\""")
233
-
234
- >>> from edsl import QuestionFreeText
235
- >>> q = QuestionFreeText(question_text = "Consider {{ X }}. What is your favorite color?", question_name = "q_color")
236
- >>> from edsl.agents.InvigilatorBase import InvigilatorBase
237
- >>> i = InvigilatorBase.example(question = q)
238
- >>> i.question_instructions_prompt
239
- Traceback (most recent call last):
287
+ >>> i.prompt_constructor.question_instructions_prompt
288
+ Prompt(text=\"""...
240
289
  ...
241
- edsl.exceptions.questions.QuestionScenarioRenderError: Question instructions still has variables: ['X'].
242
-
243
-
244
- >>> from edsl import QuestionFreeText
245
- >>> q = QuestionFreeText(question_text = "You were asked the question '{{ q0.question_text }}'. What is your favorite color?", question_name = "q_color")
246
- >>> from edsl.agents.InvigilatorBase import InvigilatorBase
247
- >>> i = InvigilatorBase.example(question = q)
248
- >>> i.question_instructions_prompt
249
- Prompt(text=\"""You are being asked the following question: You were asked the question 'Do you like school?'. What is your favorite color?
250
- Return a valid JSON formatted like this:
251
- {"answer": "<put free text answer here>"}\""")
252
-
253
- >>> from edsl import QuestionFreeText
254
- >>> q = QuestionFreeText(question_text = "You stated '{{ q0.answer }}'. What is your favorite color?", question_name = "q_color")
255
- >>> from edsl.agents.InvigilatorBase import InvigilatorBase
256
- >>> i = InvigilatorBase.example(question = q)
257
- >>> i.current_answers = {"q0": "I like school"}
258
- >>> i.question_instructions_prompt
259
- Prompt(text=\"""You are being asked the following question: You stated 'I like school'. What is your favorite color?
260
- Return a valid JSON formatted like this:
261
- {"answer": "<put free text answer here>"}\""")
262
-
263
-
264
290
  """
265
291
  if not hasattr(self, "_question_instructions_prompt"):
266
292
  question_prompt = self.question.get_instructions(model=self.model.model)
267
293
 
268
- # TODO: Try to populate the answers in the question object if they are available
269
- d = self.survey.question_names_to_questions()
270
- for question, answer in self.current_answers.items():
271
- if question in d:
272
- d[question].answer = answer
273
- else:
274
- # adds a comment to the question
275
- if (new_question := question.split("_comment")[0]) in d:
276
- d[new_question].comment = answer
294
+ # Are any of the scenario values ImageInfo
277
295
 
278
296
  question_data = self.question.data.copy()
279
297
 
280
- # check to see if the questio_options is actuall a string
281
- if "question_options" in question_data:
282
- if isinstance(self.question.data["question_options"], str):
283
- from jinja2 import Environment, meta
284
- env = Environment()
285
- parsed_content = env.parse(self.question.data['question_options'])
286
- question_option_key = list(meta.find_undeclared_variables(parsed_content))[0]
287
- question_data["question_options"] = self.scenario.get(question_option_key)
288
-
289
- #breakpoint()
290
- rendered_instructions = question_prompt.render(
291
- question_data | self.scenario | d | {"agent": self.agent}
298
+ # check to see if the question_options is actually a string
299
+ # This is used when the user is using the question_options as a variable from a sceario
300
+ # if "question_options" in question_data:
301
+ if isinstance(self.question.data.get("question_options", None), str):
302
+ env = Environment()
303
+ parsed_content = env.parse(self.question.data["question_options"])
304
+ question_option_key = list(
305
+ meta.find_undeclared_variables(parsed_content)
306
+ )[0]
307
+
308
+ if isinstance(
309
+ question_options := self.scenario.get(question_option_key), list
310
+ ):
311
+ question_data["question_options"] = question_options
312
+ self.question.question_options = question_options
313
+
314
+ replacement_dict = (
315
+ {key: "<see image>" for key in self.scenario_image_keys}
316
+ | question_data
317
+ | {
318
+ k: v
319
+ for k, v in self.scenario.items()
320
+ if k not in self.scenario_image_keys
321
+ } # don't include images in the replacement dict
322
+ | self.prior_answers_dict()
323
+ | {"agent": self.agent}
324
+ | {
325
+ "use_code": getattr(self.question, "_use_code", True),
326
+ "include_comment": getattr(
327
+ self.question, "_include_comment", False
328
+ ),
329
+ }
292
330
  )
293
331
 
332
+ rendered_instructions = question_prompt.render(replacement_dict)
333
+
334
+ # is there anything left to render?
294
335
  undefined_template_variables = (
295
336
  rendered_instructions.undefined_template_variables({})
296
337
  )
@@ -304,11 +345,25 @@ class PromptConstructorMixin:
304
345
  )
305
346
 
306
347
  if undefined_template_variables:
307
- print(undefined_template_variables)
308
348
  raise QuestionScenarioRenderError(
309
349
  f"Question instructions still has variables: {undefined_template_variables}."
310
350
  )
311
351
 
352
+ ####################################
353
+ # Check if question has instructions - these are instructions in a Survey that can apply to multiple follow-on questions
354
+ ####################################
355
+ relevant_instructions = self.survey.relevant_instructions(
356
+ self.question.question_name
357
+ )
358
+
359
+ if relevant_instructions != []:
360
+ preamble_text = Prompt(
361
+ text="Before answer this question, you were given the following instructions: "
362
+ )
363
+ for instruction in relevant_instructions:
364
+ preamble_text += instruction.text
365
+ rendered_instructions = preamble_text + rendered_instructions
366
+
312
367
  self._question_instructions_prompt = rendered_instructions
313
368
  return self._question_instructions_prompt
314
369
 
@@ -321,10 +376,27 @@ class PromptConstructorMixin:
321
376
  if self.memory_plan is not None:
322
377
  memory_prompt += self.create_memory_prompt(
323
378
  self.question.question_name
324
- ).render(self.scenario)
379
+ ).render(self.scenario | self.prior_answers_dict())
325
380
  self._prior_question_memory_prompt = memory_prompt
326
381
  return self._prior_question_memory_prompt
327
382
 
383
+ def create_memory_prompt(self, question_name: str) -> Prompt:
384
+ """Create a memory for the agent.
385
+
386
+ The returns a memory prompt for the agent.
387
+
388
+ >>> from edsl.agents.InvigilatorBase import InvigilatorBase
389
+ >>> i = InvigilatorBase.example()
390
+ >>> i.current_answers = {"q0": "Prior answer"}
391
+ >>> i.memory_plan.add_single_memory("q1", "q0")
392
+ >>> p = i.prompt_constructor.create_memory_prompt("q1")
393
+ >>> p.text.strip().replace("\\n", " ").replace("\\t", " ")
394
+ 'Before the question you are now answering, you already answered the following question(s): Question: Do you like school? Answer: Prior answer'
395
+ """
396
+ return self.memory_plan.get_memory_prompt_fragment(
397
+ question_name, self.current_answers
398
+ )
399
+
328
400
  def construct_system_prompt(self) -> Prompt:
329
401
  """Construct the system prompt for the LLM call."""
330
402
  import warnings
@@ -348,17 +420,10 @@ class PromptConstructorMixin:
348
420
 
349
421
  >>> from edsl import QuestionFreeText
350
422
  >>> from edsl.agents.InvigilatorBase import InvigilatorBase
351
- >>> q = QuestionFreeText(question_text="How are you today?", question_name="q0")
423
+ >>> q = QuestionFreeText(question_text="How are you today?", question_name="q_new")
352
424
  >>> i = InvigilatorBase.example(question = q)
353
425
  >>> i.get_prompts()
354
426
  {'user_prompt': ..., 'system_prompt': ...}
355
- >>> scenario = i._get_scenario_with_image()
356
- >>> scenario.has_image
357
- True
358
- >>> q = QuestionFreeText(question_text="How are you today?", question_name="q0")
359
- >>> i = InvigilatorBase.example(question = q, scenario = scenario)
360
- >>> i.get_prompts()
361
- {'user_prompt': ..., 'system_prompt': ..., 'encoded_image': ...'}
362
427
  """
363
428
  prompts = self.prompt_plan.get_prompts(
364
429
  agent_instructions=self.agent_instructions_prompt,
@@ -366,12 +431,16 @@ class PromptConstructorMixin:
366
431
  question_instructions=self.question_instructions_prompt,
367
432
  prior_question_memory=self.prior_question_memory_prompt,
368
433
  )
434
+ if len(self.question_image_keys) > 1:
435
+ raise ValueError("We can only handle one image per question.")
436
+ elif len(self.question_image_keys) == 1:
437
+ prompts["encoded_image"] = self.scenario[
438
+ self.question_image_keys[0]
439
+ ].encoded_image
369
440
 
370
- if hasattr(self.scenario, "has_image") and self.scenario.has_image:
371
- prompts["encoded_image"] = self.scenario["encoded_image"]
372
441
  return prompts
373
442
 
374
- def _get_scenario_with_image(self) -> Dict[str, Any]:
443
+ def _get_scenario_with_image(self) -> Scenario:
375
444
  """This is a helper function to get a scenario with an image, for testing purposes."""
376
445
  from edsl import Scenario
377
446
 
edsl/agents/__init__.py CHANGED
@@ -1,2 +1,3 @@
1
1
  from edsl.agents.Agent import Agent
2
2
  from edsl.agents.AgentList import AgentList
3
+ from edsl.agents.InvigilatorBase import InvigilatorBase