edsl 0.1.36.dev5__py3-none-any.whl → 0.1.36.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (257)
  1. edsl/Base.py +303 -303
  2. edsl/BaseDiff.py +260 -260
  3. edsl/TemplateLoader.py +24 -24
  4. edsl/__init__.py +48 -47
  5. edsl/__version__.py +1 -1
  6. edsl/agents/Agent.py +804 -804
  7. edsl/agents/AgentList.py +337 -337
  8. edsl/agents/Invigilator.py +222 -222
  9. edsl/agents/InvigilatorBase.py +298 -294
  10. edsl/agents/PromptConstructor.py +320 -312
  11. edsl/agents/__init__.py +3 -3
  12. edsl/agents/descriptors.py +86 -86
  13. edsl/agents/prompt_helpers.py +129 -129
  14. edsl/auto/AutoStudy.py +117 -117
  15. edsl/auto/StageBase.py +230 -230
  16. edsl/auto/StageGenerateSurvey.py +178 -178
  17. edsl/auto/StageLabelQuestions.py +125 -125
  18. edsl/auto/StagePersona.py +61 -61
  19. edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
  20. edsl/auto/StagePersonaDimensionValues.py +74 -74
  21. edsl/auto/StagePersonaDimensions.py +69 -69
  22. edsl/auto/StageQuestions.py +73 -73
  23. edsl/auto/SurveyCreatorPipeline.py +21 -21
  24. edsl/auto/utilities.py +224 -224
  25. edsl/base/Base.py +289 -289
  26. edsl/config.py +149 -149
  27. edsl/conjure/AgentConstructionMixin.py +152 -152
  28. edsl/conjure/Conjure.py +62 -62
  29. edsl/conjure/InputData.py +659 -659
  30. edsl/conjure/InputDataCSV.py +48 -48
  31. edsl/conjure/InputDataMixinQuestionStats.py +182 -182
  32. edsl/conjure/InputDataPyRead.py +91 -91
  33. edsl/conjure/InputDataSPSS.py +8 -8
  34. edsl/conjure/InputDataStata.py +8 -8
  35. edsl/conjure/QuestionOptionMixin.py +76 -76
  36. edsl/conjure/QuestionTypeMixin.py +23 -23
  37. edsl/conjure/RawQuestion.py +65 -65
  38. edsl/conjure/SurveyResponses.py +7 -7
  39. edsl/conjure/__init__.py +9 -9
  40. edsl/conjure/naming_utilities.py +263 -263
  41. edsl/conjure/utilities.py +201 -201
  42. edsl/conversation/Conversation.py +238 -238
  43. edsl/conversation/car_buying.py +58 -58
  44. edsl/conversation/mug_negotiation.py +81 -81
  45. edsl/conversation/next_speaker_utilities.py +93 -93
  46. edsl/coop/PriceFetcher.py +54 -54
  47. edsl/coop/__init__.py +2 -2
  48. edsl/coop/coop.py +849 -849
  49. edsl/coop/utils.py +131 -131
  50. edsl/data/Cache.py +527 -527
  51. edsl/data/CacheEntry.py +228 -228
  52. edsl/data/CacheHandler.py +149 -149
  53. edsl/data/RemoteCacheSync.py +83 -83
  54. edsl/data/SQLiteDict.py +292 -292
  55. edsl/data/__init__.py +4 -4
  56. edsl/data/orm.py +10 -10
  57. edsl/data_transfer_models.py +73 -73
  58. edsl/enums.py +173 -173
  59. edsl/exceptions/__init__.py +50 -50
  60. edsl/exceptions/agents.py +40 -40
  61. edsl/exceptions/configuration.py +16 -16
  62. edsl/exceptions/coop.py +10 -10
  63. edsl/exceptions/data.py +14 -14
  64. edsl/exceptions/general.py +34 -34
  65. edsl/exceptions/jobs.py +33 -33
  66. edsl/exceptions/language_models.py +63 -63
  67. edsl/exceptions/prompts.py +15 -15
  68. edsl/exceptions/questions.py +91 -91
  69. edsl/exceptions/results.py +26 -26
  70. edsl/exceptions/surveys.py +34 -34
  71. edsl/inference_services/AnthropicService.py +87 -87
  72. edsl/inference_services/AwsBedrock.py +115 -115
  73. edsl/inference_services/AzureAI.py +217 -217
  74. edsl/inference_services/DeepInfraService.py +18 -18
  75. edsl/inference_services/GoogleService.py +156 -156
  76. edsl/inference_services/GroqService.py +20 -20
  77. edsl/inference_services/InferenceServiceABC.py +147 -147
  78. edsl/inference_services/InferenceServicesCollection.py +74 -68
  79. edsl/inference_services/MistralAIService.py +123 -123
  80. edsl/inference_services/OllamaService.py +18 -18
  81. edsl/inference_services/OpenAIService.py +224 -224
  82. edsl/inference_services/TestService.py +89 -89
  83. edsl/inference_services/TogetherAIService.py +170 -170
  84. edsl/inference_services/models_available_cache.py +118 -94
  85. edsl/inference_services/rate_limits_cache.py +25 -25
  86. edsl/inference_services/registry.py +39 -39
  87. edsl/inference_services/write_available.py +10 -10
  88. edsl/jobs/Answers.py +56 -56
  89. edsl/jobs/Jobs.py +1112 -1112
  90. edsl/jobs/__init__.py +1 -1
  91. edsl/jobs/buckets/BucketCollection.py +63 -63
  92. edsl/jobs/buckets/ModelBuckets.py +65 -65
  93. edsl/jobs/buckets/TokenBucket.py +248 -248
  94. edsl/jobs/interviews/Interview.py +661 -651
  95. edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
  96. edsl/jobs/interviews/InterviewExceptionEntry.py +189 -182
  97. edsl/jobs/interviews/InterviewStatistic.py +63 -63
  98. edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
  99. edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
  100. edsl/jobs/interviews/InterviewStatusLog.py +92 -92
  101. edsl/jobs/interviews/ReportErrors.py +66 -66
  102. edsl/jobs/interviews/interview_status_enum.py +9 -9
  103. edsl/jobs/runners/JobsRunnerAsyncio.py +337 -337
  104. edsl/jobs/runners/JobsRunnerStatus.py +332 -332
  105. edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
  106. edsl/jobs/tasks/TaskCreators.py +64 -64
  107. edsl/jobs/tasks/TaskHistory.py +441 -441
  108. edsl/jobs/tasks/TaskStatusLog.py +23 -23
  109. edsl/jobs/tasks/task_status_enum.py +163 -163
  110. edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
  111. edsl/jobs/tokens/TokenUsage.py +34 -34
  112. edsl/language_models/LanguageModel.py +718 -718
  113. edsl/language_models/ModelList.py +102 -102
  114. edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
  115. edsl/language_models/__init__.py +2 -2
  116. edsl/language_models/fake_openai_call.py +15 -15
  117. edsl/language_models/fake_openai_service.py +61 -61
  118. edsl/language_models/registry.py +137 -137
  119. edsl/language_models/repair.py +156 -156
  120. edsl/language_models/unused/ReplicateBase.py +83 -83
  121. edsl/language_models/utilities.py +64 -64
  122. edsl/notebooks/Notebook.py +259 -259
  123. edsl/notebooks/__init__.py +1 -1
  124. edsl/prompts/Prompt.py +358 -358
  125. edsl/prompts/__init__.py +2 -2
  126. edsl/questions/AnswerValidatorMixin.py +289 -289
  127. edsl/questions/QuestionBase.py +616 -616
  128. edsl/questions/QuestionBaseGenMixin.py +161 -161
  129. edsl/questions/QuestionBasePromptsMixin.py +266 -266
  130. edsl/questions/QuestionBudget.py +227 -227
  131. edsl/questions/QuestionCheckBox.py +359 -359
  132. edsl/questions/QuestionExtract.py +183 -183
  133. edsl/questions/QuestionFreeText.py +113 -113
  134. edsl/questions/QuestionFunctional.py +159 -159
  135. edsl/questions/QuestionList.py +231 -231
  136. edsl/questions/QuestionMultipleChoice.py +286 -286
  137. edsl/questions/QuestionNumerical.py +153 -153
  138. edsl/questions/QuestionRank.py +324 -324
  139. edsl/questions/Quick.py +41 -41
  140. edsl/questions/RegisterQuestionsMeta.py +71 -71
  141. edsl/questions/ResponseValidatorABC.py +174 -174
  142. edsl/questions/SimpleAskMixin.py +73 -73
  143. edsl/questions/__init__.py +26 -26
  144. edsl/questions/compose_questions.py +98 -98
  145. edsl/questions/decorators.py +21 -21
  146. edsl/questions/derived/QuestionLikertFive.py +76 -76
  147. edsl/questions/derived/QuestionLinearScale.py +87 -87
  148. edsl/questions/derived/QuestionTopK.py +91 -91
  149. edsl/questions/derived/QuestionYesNo.py +82 -82
  150. edsl/questions/descriptors.py +418 -418
  151. edsl/questions/prompt_templates/question_budget.jinja +13 -13
  152. edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
  153. edsl/questions/prompt_templates/question_extract.jinja +11 -11
  154. edsl/questions/prompt_templates/question_free_text.jinja +3 -3
  155. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
  156. edsl/questions/prompt_templates/question_list.jinja +17 -17
  157. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
  158. edsl/questions/prompt_templates/question_numerical.jinja +36 -36
  159. edsl/questions/question_registry.py +147 -147
  160. edsl/questions/settings.py +12 -12
  161. edsl/questions/templates/budget/answering_instructions.jinja +7 -7
  162. edsl/questions/templates/budget/question_presentation.jinja +7 -7
  163. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
  164. edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
  165. edsl/questions/templates/extract/answering_instructions.jinja +7 -7
  166. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
  167. edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
  168. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
  169. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
  170. edsl/questions/templates/list/answering_instructions.jinja +3 -3
  171. edsl/questions/templates/list/question_presentation.jinja +5 -5
  172. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
  173. edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
  174. edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
  175. edsl/questions/templates/numerical/question_presentation.jinja +6 -6
  176. edsl/questions/templates/rank/answering_instructions.jinja +11 -11
  177. edsl/questions/templates/rank/question_presentation.jinja +15 -15
  178. edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
  179. edsl/questions/templates/top_k/question_presentation.jinja +22 -22
  180. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
  181. edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
  182. edsl/results/Dataset.py +293 -293
  183. edsl/results/DatasetExportMixin.py +693 -693
  184. edsl/results/DatasetTree.py +145 -145
  185. edsl/results/Result.py +433 -433
  186. edsl/results/Results.py +1158 -1158
  187. edsl/results/ResultsDBMixin.py +238 -238
  188. edsl/results/ResultsExportMixin.py +43 -43
  189. edsl/results/ResultsFetchMixin.py +33 -33
  190. edsl/results/ResultsGGMixin.py +121 -121
  191. edsl/results/ResultsToolsMixin.py +98 -98
  192. edsl/results/Selector.py +118 -118
  193. edsl/results/__init__.py +2 -2
  194. edsl/results/tree_explore.py +115 -115
  195. edsl/scenarios/FileStore.py +458 -443
  196. edsl/scenarios/Scenario.py +510 -507
  197. edsl/scenarios/ScenarioHtmlMixin.py +59 -59
  198. edsl/scenarios/ScenarioList.py +1101 -1101
  199. edsl/scenarios/ScenarioListExportMixin.py +52 -52
  200. edsl/scenarios/ScenarioListPdfMixin.py +261 -261
  201. edsl/scenarios/__init__.py +4 -2
  202. edsl/shared.py +1 -1
  203. edsl/study/ObjectEntry.py +173 -173
  204. edsl/study/ProofOfWork.py +113 -113
  205. edsl/study/SnapShot.py +80 -80
  206. edsl/study/Study.py +528 -528
  207. edsl/study/__init__.py +4 -4
  208. edsl/surveys/DAG.py +148 -148
  209. edsl/surveys/Memory.py +31 -31
  210. edsl/surveys/MemoryPlan.py +244 -244
  211. edsl/surveys/Rule.py +324 -324
  212. edsl/surveys/RuleCollection.py +387 -387
  213. edsl/surveys/Survey.py +1772 -1772
  214. edsl/surveys/SurveyCSS.py +261 -261
  215. edsl/surveys/SurveyExportMixin.py +259 -259
  216. edsl/surveys/SurveyFlowVisualizationMixin.py +121 -121
  217. edsl/surveys/SurveyQualtricsImport.py +284 -284
  218. edsl/surveys/__init__.py +3 -3
  219. edsl/surveys/base.py +53 -53
  220. edsl/surveys/descriptors.py +56 -56
  221. edsl/surveys/instructions/ChangeInstruction.py +47 -47
  222. edsl/surveys/instructions/Instruction.py +51 -51
  223. edsl/surveys/instructions/InstructionCollection.py +77 -77
  224. edsl/templates/error_reporting/base.html +23 -23
  225. edsl/templates/error_reporting/exceptions_by_model.html +34 -34
  226. edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
  227. edsl/templates/error_reporting/exceptions_by_type.html +16 -16
  228. edsl/templates/error_reporting/interview_details.html +115 -115
  229. edsl/templates/error_reporting/interviews.html +9 -9
  230. edsl/templates/error_reporting/overview.html +4 -4
  231. edsl/templates/error_reporting/performance_plot.html +1 -1
  232. edsl/templates/error_reporting/report.css +73 -73
  233. edsl/templates/error_reporting/report.html +117 -117
  234. edsl/templates/error_reporting/report.js +25 -25
  235. edsl/tools/__init__.py +1 -1
  236. edsl/tools/clusters.py +192 -192
  237. edsl/tools/embeddings.py +27 -27
  238. edsl/tools/embeddings_plotting.py +118 -118
  239. edsl/tools/plotting.py +112 -112
  240. edsl/tools/summarize.py +18 -18
  241. edsl/utilities/SystemInfo.py +28 -28
  242. edsl/utilities/__init__.py +22 -22
  243. edsl/utilities/ast_utilities.py +25 -25
  244. edsl/utilities/data/Registry.py +6 -6
  245. edsl/utilities/data/__init__.py +1 -1
  246. edsl/utilities/data/scooter_results.json +1 -1
  247. edsl/utilities/decorators.py +77 -77
  248. edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
  249. edsl/utilities/interface.py +627 -627
  250. edsl/utilities/repair_functions.py +28 -28
  251. edsl/utilities/restricted_python.py +70 -70
  252. edsl/utilities/utilities.py +391 -391
  253. {edsl-0.1.36.dev5.dist-info → edsl-0.1.36.dev7.dist-info}/LICENSE +21 -21
  254. {edsl-0.1.36.dev5.dist-info → edsl-0.1.36.dev7.dist-info}/METADATA +1 -1
  255. edsl-0.1.36.dev7.dist-info/RECORD +279 -0
  256. edsl-0.1.36.dev5.dist-info/RECORD +0 -279
  257. {edsl-0.1.36.dev5.dist-info → edsl-0.1.36.dev7.dist-info}/WHEEL +0 -0
@@ -1,294 +1,298 @@
1
- from abc import ABC, abstractmethod
2
- import asyncio
3
- from typing import Coroutine, Dict, Any, Optional
4
-
5
- from edsl.prompts.Prompt import Prompt
6
- from edsl.utilities.decorators import jupyter_nb_handler
7
- from edsl.data_transfer_models import AgentResponseDict
8
-
9
- from edsl.data.Cache import Cache
10
-
11
- from edsl.questions.QuestionBase import QuestionBase
12
- from edsl.scenarios.Scenario import Scenario
13
- from edsl.surveys.MemoryPlan import MemoryPlan
14
- from edsl.language_models.LanguageModel import LanguageModel
15
-
16
- from edsl.data_transfer_models import EDSLResultObjectInput
17
- from edsl.agents.PromptConstructor import PromptConstructor
18
-
19
-
20
- class InvigilatorBase(ABC):
21
- """An invigiator (someone who administers an exam) is a class that is responsible for administering a question to an agent.
22
-
23
- >>> InvigilatorBase.example().answer_question()
24
- {'message': [{'text': 'SPAM!'}], 'usage': {'prompt_tokens': 1, 'completion_tokens': 1}}
25
-
26
- >>> InvigilatorBase.example().get_failed_task_result(failure_reason="Failed to get response").comment
27
- 'Failed to get response'
28
-
29
- This returns an empty prompt because there is no memory the agent needs to have at q0.
30
-
31
-
32
- """
33
-
34
- def __init__(
35
- self,
36
- agent: "Agent",
37
- question: QuestionBase,
38
- scenario: Scenario,
39
- model: LanguageModel,
40
- memory_plan: MemoryPlan,
41
- current_answers: dict,
42
- survey: Optional["Survey"],
43
- cache: Optional[Cache] = None,
44
- iteration: Optional[int] = 1,
45
- additional_prompt_data: Optional[dict] = None,
46
- sidecar_model: Optional[LanguageModel] = None,
47
- raise_validation_errors: Optional[bool] = True,
48
- ):
49
- """Initialize a new Invigilator."""
50
- self.agent = agent
51
- self.question = question
52
- self.scenario = scenario
53
- self.model = model
54
- self.memory_plan = memory_plan
55
- self.current_answers = current_answers or {}
56
- self.iteration = iteration
57
- self.additional_prompt_data = additional_prompt_data
58
- self.cache = cache
59
- self.sidecar_model = sidecar_model
60
- self.survey = survey
61
- self.raise_validation_errors = raise_validation_errors
62
-
63
- self.raw_model_response = (
64
- None # placeholder for the raw response from the model
65
- )
66
-
67
- @property
68
- def prompt_constructor(self) -> PromptConstructor:
69
- """Return the prompt constructor."""
70
- return PromptConstructor(self)
71
-
72
- def to_dict(self):
73
- attributes = [
74
- "agent",
75
- "question",
76
- "scenario",
77
- "model",
78
- "memory_plan",
79
- "current_answers",
80
- "iteration",
81
- "additional_prompt_data",
82
- "cache",
83
- "sidecar_model",
84
- "survey",
85
- ]
86
-
87
- def serialize_attribute(attr):
88
- value = getattr(self, attr)
89
- if value is None:
90
- return None
91
- if hasattr(value, "to_dict"):
92
- return value.to_dict()
93
- if isinstance(value, (int, float, str, bool, dict, list)):
94
- return value
95
- return str(value)
96
-
97
- return {attr: serialize_attribute(attr) for attr in attributes}
98
-
99
- @classmethod
100
- def from_dict(cls, data):
101
- from edsl.agents.Agent import Agent
102
- from edsl.questions import QuestionBase
103
- from edsl.scenarios.Scenario import Scenario
104
- from edsl.surveys.MemoryPlan import MemoryPlan
105
- from edsl.language_models.LanguageModel import LanguageModel
106
- from edsl.surveys.Survey import Survey
107
-
108
- agent = Agent.from_dict(data["agent"])
109
- question = QuestionBase.from_dict(data["question"])
110
- scenario = Scenario.from_dict(data["scenario"])
111
- model = LanguageModel.from_dict(data["model"])
112
- memory_plan = MemoryPlan.from_dict(data["memory_plan"])
113
- survey = Survey.from_dict(data["survey"])
114
- current_answers = data["current_answers"]
115
- iteration = data["iteration"]
116
- additional_prompt_data = data["additional_prompt_data"]
117
- cache = Cache.from_dict(data["cache"])
118
- sidecar_model = LanguageModel.from_dict(data["sidecar_model"])
119
-
120
- return cls(
121
- agent=agent,
122
- question=question,
123
- scenario=scenario,
124
- model=model,
125
- memory_plan=memory_plan,
126
- current_answers=current_answers,
127
- survey=survey,
128
- iteration=iteration,
129
- additional_prompt_data=additional_prompt_data,
130
- cache=cache,
131
- sidecar_model=sidecar_model,
132
- )
133
-
134
- def __repr__(self) -> str:
135
- """Return a string representation of the Invigilator.
136
-
137
- >>> InvigilatorBase.example().__repr__()
138
- 'InvigilatorExample(...)'
139
-
140
- """
141
- return f"{self.__class__.__name__}(agent={repr(self.agent)}, question={repr(self.question)}, scneario={repr(self.scenario)}, model={repr(self.model)}, memory_plan={repr(self.memory_plan)}, current_answers={repr(self.current_answers)}, iteration{repr(self.iteration)}, additional_prompt_data={repr(self.additional_prompt_data)}, cache={repr(self.cache)}, sidecarmodel={repr(self.sidecar_model)})"
142
-
143
- def get_failed_task_result(self, failure_reason) -> EDSLResultObjectInput:
144
- """Return an AgentResponseDict used in case the question-asking fails.
145
-
146
- Possible reasons include:
147
- - Legimately skipped because of skip logic
148
- - Failed to get response from the model
149
-
150
- """
151
- data = {
152
- "answer": None,
153
- "generated_tokens": None,
154
- "comment": failure_reason,
155
- "question_name": self.question.question_name,
156
- "prompts": self.get_prompts(),
157
- "cached_response": None,
158
- "raw_model_response": None,
159
- "cache_used": None,
160
- "cache_key": None,
161
- }
162
- return EDSLResultObjectInput(**data)
163
-
164
- # breakpoint()
165
- # if hasattr(self, "augmented_model_response"):
166
- # import json
167
-
168
- # generated_tokens = json.loads(self.augmented_model_response)["answer"][
169
- # "generated_tokens"
170
- # ]
171
- # else:
172
- # generated_tokens = "Filled in by InvigilatorBase.get_failed_task_result"
173
- # agent_response_dict = AgentResponseDict(
174
- # answer=None,
175
- # comment="Failed to get usable response",
176
- # generated_tokens=generated_tokens,
177
- # question_name=self.question.question_name,
178
- # prompts=self.get_prompts(),
179
- # )
180
- # # breakpoint()
181
- # return agent_response_dict
182
-
183
- def get_prompts(self) -> Dict[str, Prompt]:
184
- """Return the prompt used."""
185
-
186
- return {
187
- "user_prompt": Prompt("NA"),
188
- "system_prompt": Prompt("NA"),
189
- }
190
-
191
- @abstractmethod
192
- async def async_answer_question(self):
193
- """Asnwer a question."""
194
- pass
195
-
196
- @jupyter_nb_handler
197
- def answer_question(self) -> Coroutine:
198
- """Return a function that gets the answers to the question."""
199
-
200
- async def main():
201
- """Return the answer to the question."""
202
- results = await asyncio.gather(self.async_answer_question())
203
- return results[0] # Since there's only one task, return its result
204
-
205
- return main()
206
-
207
- @classmethod
208
- def example(
209
- cls, throw_an_exception=False, question=None, scenario=None, survey=None
210
- ) -> "InvigilatorBase":
211
- """Return an example invigilator.
212
-
213
- >>> InvigilatorBase.example()
214
- InvigilatorExample(...)
215
-
216
- """
217
- from edsl.agents.Agent import Agent
218
- from edsl.questions import QuestionMultipleChoice
219
- from edsl.scenarios.Scenario import Scenario
220
- from edsl.language_models import LanguageModel
221
- from edsl.surveys.MemoryPlan import MemoryPlan
222
-
223
- from edsl.enums import InferenceServiceType
224
-
225
- from edsl import Model
226
-
227
- model = Model("test", canned_response="SPAM!")
228
- # class TestLanguageModelGood(LanguageModel):
229
- # """A test language model."""
230
-
231
- # _model_ = "test"
232
- # _parameters_ = {"temperature": 0.5}
233
- # _inference_service_ = InferenceServiceType.TEST.value
234
-
235
- # async def async_execute_model_call(
236
- # self, user_prompt: str, system_prompt: str
237
- # ) -> dict[str, Any]:
238
- # await asyncio.sleep(0.1)
239
- # if hasattr(self, "throw_an_exception"):
240
- # raise Exception("Error!")
241
- # return {"message": """{"answer": "SPAM!"}"""}
242
-
243
- # def parse_response(self, raw_response: dict[str, Any]) -> str:
244
- # """Parse the response from the model."""
245
- # return raw_response["message"]
246
-
247
- if throw_an_exception:
248
- model.throw_an_exception = True
249
- agent = Agent.example()
250
- # question = QuestionMultipleChoice.example()
251
- from edsl.surveys import Survey
252
-
253
- if not survey:
254
- survey = Survey.example()
255
- # if question:
256
- # need to have the focal question name in the list of names
257
- # survey._questions[0].question_name = question.question_name
258
- # survey.add_question(question)
259
- if question:
260
- survey.add_question(question)
261
-
262
- question = question or survey.questions[0]
263
- scenario = scenario or Scenario.example()
264
- # memory_plan = None #memory_plan = MemoryPlan()
265
- from edsl import Survey
266
-
267
- memory_plan = MemoryPlan(survey=survey)
268
- current_answers = None
269
- from edsl.agents.PromptConstructor import PromptConstructor
270
-
271
- class InvigilatorExample(InvigilatorBase):
272
- """An example invigilator."""
273
-
274
- async def async_answer_question(self):
275
- """Answer a question."""
276
- return await self.model.async_execute_model_call(
277
- user_prompt="Hello", system_prompt="Hi"
278
- )
279
-
280
- return InvigilatorExample(
281
- agent=agent,
282
- question=question,
283
- scenario=scenario,
284
- survey=survey,
285
- model=model,
286
- memory_plan=memory_plan,
287
- current_answers=current_answers,
288
- )
289
-
290
-
291
- if __name__ == "__main__":
292
- import doctest
293
-
294
- doctest.testmod(optionflags=doctest.ELLIPSIS)
1
+ from abc import ABC, abstractmethod
2
+ import asyncio
3
+ from typing import Coroutine, Dict, Any, Optional
4
+
5
+ from edsl.prompts.Prompt import Prompt
6
+ from edsl.utilities.decorators import jupyter_nb_handler
7
+ from edsl.data_transfer_models import AgentResponseDict
8
+
9
+ from edsl.data.Cache import Cache
10
+
11
+ from edsl.questions.QuestionBase import QuestionBase
12
+ from edsl.scenarios.Scenario import Scenario
13
+ from edsl.surveys.MemoryPlan import MemoryPlan
14
+ from edsl.language_models.LanguageModel import LanguageModel
15
+
16
+ from edsl.data_transfer_models import EDSLResultObjectInput
17
+ from edsl.agents.PromptConstructor import PromptConstructor
18
+
19
+
20
+ class InvigilatorBase(ABC):
21
+ """An invigiator (someone who administers an exam) is a class that is responsible for administering a question to an agent.
22
+
23
+ >>> InvigilatorBase.example().answer_question()
24
+ {'message': [{'text': 'SPAM!'}], 'usage': {'prompt_tokens': 1, 'completion_tokens': 1}}
25
+
26
+ >>> InvigilatorBase.example().get_failed_task_result(failure_reason="Failed to get response").comment
27
+ 'Failed to get response'
28
+
29
+ This returns an empty prompt because there is no memory the agent needs to have at q0.
30
+
31
+
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ agent: "Agent",
37
+ question: QuestionBase,
38
+ scenario: Scenario,
39
+ model: LanguageModel,
40
+ memory_plan: MemoryPlan,
41
+ current_answers: dict,
42
+ survey: Optional["Survey"],
43
+ cache: Optional[Cache] = None,
44
+ iteration: Optional[int] = 1,
45
+ additional_prompt_data: Optional[dict] = None,
46
+ sidecar_model: Optional[LanguageModel] = None,
47
+ raise_validation_errors: Optional[bool] = True,
48
+ ):
49
+ """Initialize a new Invigilator."""
50
+ self.agent = agent
51
+ self.question = question
52
+ self.scenario = scenario
53
+ self.model = model
54
+ self.memory_plan = memory_plan
55
+ self.current_answers = current_answers or {}
56
+ self.iteration = iteration
57
+ self.additional_prompt_data = additional_prompt_data
58
+ self.cache = cache
59
+ self.sidecar_model = sidecar_model
60
+ self.survey = survey
61
+ self.raise_validation_errors = raise_validation_errors
62
+
63
+ self.raw_model_response = (
64
+ None # placeholder for the raw response from the model
65
+ )
66
+
67
+ @property
68
+ def prompt_constructor(self) -> PromptConstructor:
69
+ """Return the prompt constructor."""
70
+ return PromptConstructor(self)
71
+
72
+ def to_dict(self):
73
+ attributes = [
74
+ "agent",
75
+ "question",
76
+ "scenario",
77
+ "model",
78
+ "memory_plan",
79
+ "current_answers",
80
+ "iteration",
81
+ "additional_prompt_data",
82
+ "cache",
83
+ "sidecar_model",
84
+ "survey",
85
+ ]
86
+
87
+ def serialize_attribute(attr):
88
+ value = getattr(self, attr)
89
+ if value is None:
90
+ return None
91
+ if hasattr(value, "to_dict"):
92
+ return value.to_dict()
93
+ if isinstance(value, (int, float, str, bool, dict, list)):
94
+ return value
95
+ return str(value)
96
+
97
+ return {attr: serialize_attribute(attr) for attr in attributes}
98
+
99
+ @classmethod
100
+ def from_dict(cls, data):
101
+ from edsl.agents.Agent import Agent
102
+ from edsl.questions import QuestionBase
103
+ from edsl.scenarios.Scenario import Scenario
104
+ from edsl.surveys.MemoryPlan import MemoryPlan
105
+ from edsl.language_models.LanguageModel import LanguageModel
106
+ from edsl.surveys.Survey import Survey
107
+
108
+ agent = Agent.from_dict(data["agent"])
109
+ question = QuestionBase.from_dict(data["question"])
110
+ scenario = Scenario.from_dict(data["scenario"])
111
+ model = LanguageModel.from_dict(data["model"])
112
+ memory_plan = MemoryPlan.from_dict(data["memory_plan"])
113
+ survey = Survey.from_dict(data["survey"])
114
+ current_answers = data["current_answers"]
115
+ iteration = data["iteration"]
116
+ additional_prompt_data = data["additional_prompt_data"]
117
+ cache = Cache.from_dict(data["cache"])
118
+
119
+ if data["sidecar_model"] is None:
120
+ sidecar_model = None
121
+ else:
122
+ sidecar_model = LanguageModel.from_dict(data["sidecar_model"])
123
+
124
+ return cls(
125
+ agent=agent,
126
+ question=question,
127
+ scenario=scenario,
128
+ model=model,
129
+ memory_plan=memory_plan,
130
+ current_answers=current_answers,
131
+ survey=survey,
132
+ iteration=iteration,
133
+ additional_prompt_data=additional_prompt_data,
134
+ cache=cache,
135
+ sidecar_model=sidecar_model,
136
+ )
137
+
138
+ def __repr__(self) -> str:
139
+ """Return a string representation of the Invigilator.
140
+
141
+ >>> InvigilatorBase.example().__repr__()
142
+ 'InvigilatorExample(...)'
143
+
144
+ """
145
+ return f"{self.__class__.__name__}(agent={repr(self.agent)}, question={repr(self.question)}, scneario={repr(self.scenario)}, model={repr(self.model)}, memory_plan={repr(self.memory_plan)}, current_answers={repr(self.current_answers)}, iteration{repr(self.iteration)}, additional_prompt_data={repr(self.additional_prompt_data)}, cache={repr(self.cache)}, sidecarmodel={repr(self.sidecar_model)})"
146
+
147
+ def get_failed_task_result(self, failure_reason) -> EDSLResultObjectInput:
148
+ """Return an AgentResponseDict used in case the question-asking fails.
149
+
150
+ Possible reasons include:
151
+ - Legimately skipped because of skip logic
152
+ - Failed to get response from the model
153
+
154
+ """
155
+ data = {
156
+ "answer": None,
157
+ "generated_tokens": None,
158
+ "comment": failure_reason,
159
+ "question_name": self.question.question_name,
160
+ "prompts": self.get_prompts(),
161
+ "cached_response": None,
162
+ "raw_model_response": None,
163
+ "cache_used": None,
164
+ "cache_key": None,
165
+ }
166
+ return EDSLResultObjectInput(**data)
167
+
168
+ # breakpoint()
169
+ # if hasattr(self, "augmented_model_response"):
170
+ # import json
171
+
172
+ # generated_tokens = json.loads(self.augmented_model_response)["answer"][
173
+ # "generated_tokens"
174
+ # ]
175
+ # else:
176
+ # generated_tokens = "Filled in by InvigilatorBase.get_failed_task_result"
177
+ # agent_response_dict = AgentResponseDict(
178
+ # answer=None,
179
+ # comment="Failed to get usable response",
180
+ # generated_tokens=generated_tokens,
181
+ # question_name=self.question.question_name,
182
+ # prompts=self.get_prompts(),
183
+ # )
184
+ # # breakpoint()
185
+ # return agent_response_dict
186
+
187
+ def get_prompts(self) -> Dict[str, Prompt]:
188
+ """Return the prompt used."""
189
+
190
+ return {
191
+ "user_prompt": Prompt("NA"),
192
+ "system_prompt": Prompt("NA"),
193
+ }
194
+
195
+ @abstractmethod
196
+ async def async_answer_question(self):
197
+ """Asnwer a question."""
198
+ pass
199
+
200
+ @jupyter_nb_handler
201
+ def answer_question(self) -> Coroutine:
202
+ """Return a function that gets the answers to the question."""
203
+
204
+ async def main():
205
+ """Return the answer to the question."""
206
+ results = await asyncio.gather(self.async_answer_question())
207
+ return results[0] # Since there's only one task, return its result
208
+
209
+ return main()
210
+
211
+ @classmethod
212
+ def example(
213
+ cls, throw_an_exception=False, question=None, scenario=None, survey=None
214
+ ) -> "InvigilatorBase":
215
+ """Return an example invigilator.
216
+
217
+ >>> InvigilatorBase.example()
218
+ InvigilatorExample(...)
219
+
220
+ """
221
+ from edsl.agents.Agent import Agent
222
+ from edsl.questions import QuestionMultipleChoice
223
+ from edsl.scenarios.Scenario import Scenario
224
+ from edsl.language_models import LanguageModel
225
+ from edsl.surveys.MemoryPlan import MemoryPlan
226
+
227
+ from edsl.enums import InferenceServiceType
228
+
229
+ from edsl import Model
230
+
231
+ model = Model("test", canned_response="SPAM!")
232
+ # class TestLanguageModelGood(LanguageModel):
233
+ # """A test language model."""
234
+
235
+ # _model_ = "test"
236
+ # _parameters_ = {"temperature": 0.5}
237
+ # _inference_service_ = InferenceServiceType.TEST.value
238
+
239
+ # async def async_execute_model_call(
240
+ # self, user_prompt: str, system_prompt: str
241
+ # ) -> dict[str, Any]:
242
+ # await asyncio.sleep(0.1)
243
+ # if hasattr(self, "throw_an_exception"):
244
+ # raise Exception("Error!")
245
+ # return {"message": """{"answer": "SPAM!"}"""}
246
+
247
+ # def parse_response(self, raw_response: dict[str, Any]) -> str:
248
+ # """Parse the response from the model."""
249
+ # return raw_response["message"]
250
+
251
+ if throw_an_exception:
252
+ model.throw_an_exception = True
253
+ agent = Agent.example()
254
+ # question = QuestionMultipleChoice.example()
255
+ from edsl.surveys import Survey
256
+
257
+ if not survey:
258
+ survey = Survey.example()
259
+ # if question:
260
+ # need to have the focal question name in the list of names
261
+ # survey._questions[0].question_name = question.question_name
262
+ # survey.add_question(question)
263
+ if question:
264
+ survey.add_question(question)
265
+
266
+ question = question or survey.questions[0]
267
+ scenario = scenario or Scenario.example()
268
+ # memory_plan = None #memory_plan = MemoryPlan()
269
+ from edsl import Survey
270
+
271
+ memory_plan = MemoryPlan(survey=survey)
272
+ current_answers = None
273
+ from edsl.agents.PromptConstructor import PromptConstructor
274
+
275
+ class InvigilatorExample(InvigilatorBase):
276
+ """An example invigilator."""
277
+
278
+ async def async_answer_question(self):
279
+ """Answer a question."""
280
+ return await self.model.async_execute_model_call(
281
+ user_prompt="Hello", system_prompt="Hi"
282
+ )
283
+
284
+ return InvigilatorExample(
285
+ agent=agent,
286
+ question=question,
287
+ scenario=scenario,
288
+ survey=survey,
289
+ model=model,
290
+ memory_plan=memory_plan,
291
+ current_answers=current_answers,
292
+ )
293
+
294
+
295
+ if __name__ == "__main__":
296
+ import doctest
297
+
298
+ doctest.testmod(optionflags=doctest.ELLIPSIS)