edsl 0.1.33.dev1__py3-none-any.whl → 0.1.33.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +8 -4
- edsl/agents/Agent.py +46 -14
- edsl/agents/AgentList.py +43 -0
- edsl/agents/Invigilator.py +125 -212
- edsl/agents/InvigilatorBase.py +140 -32
- edsl/agents/PromptConstructionMixin.py +43 -66
- edsl/agents/__init__.py +1 -0
- edsl/auto/AutoStudy.py +117 -0
- edsl/auto/StageBase.py +230 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +73 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +224 -0
- edsl/config.py +38 -39
- edsl/coop/PriceFetcher.py +58 -0
- edsl/coop/coop.py +39 -5
- edsl/data/Cache.py +35 -1
- edsl/data_transfer_models.py +120 -38
- edsl/enums.py +2 -0
- edsl/exceptions/language_models.py +25 -1
- edsl/exceptions/questions.py +62 -5
- edsl/exceptions/results.py +4 -0
- edsl/inference_services/AnthropicService.py +13 -11
- edsl/inference_services/AwsBedrock.py +19 -17
- edsl/inference_services/AzureAI.py +37 -20
- edsl/inference_services/GoogleService.py +16 -12
- edsl/inference_services/GroqService.py +2 -0
- edsl/inference_services/InferenceServiceABC.py +24 -0
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OpenAIService.py +41 -50
- edsl/inference_services/TestService.py +71 -0
- edsl/inference_services/models_available_cache.py +0 -6
- edsl/inference_services/registry.py +4 -0
- edsl/jobs/Answers.py +10 -12
- edsl/jobs/FailedQuestion.py +78 -0
- edsl/jobs/Jobs.py +18 -13
- edsl/jobs/buckets/TokenBucket.py +39 -14
- edsl/jobs/interviews/Interview.py +297 -77
- edsl/jobs/interviews/InterviewExceptionEntry.py +83 -19
- edsl/jobs/interviews/interview_exception_tracking.py +0 -70
- edsl/jobs/interviews/retry_management.py +3 -1
- edsl/jobs/runners/JobsRunnerAsyncio.py +116 -70
- edsl/jobs/runners/JobsRunnerStatusMixin.py +1 -1
- edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
- edsl/jobs/tasks/TaskHistory.py +131 -213
- edsl/language_models/LanguageModel.py +239 -129
- edsl/language_models/ModelList.py +2 -2
- edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/registry.py +15 -2
- edsl/language_models/repair.py +0 -19
- edsl/language_models/utilities.py +61 -0
- edsl/prompts/Prompt.py +52 -2
- edsl/questions/AnswerValidatorMixin.py +23 -26
- edsl/questions/QuestionBase.py +273 -242
- edsl/questions/QuestionBaseGenMixin.py +133 -0
- edsl/questions/QuestionBasePromptsMixin.py +266 -0
- edsl/questions/QuestionBudget.py +6 -0
- edsl/questions/QuestionCheckBox.py +227 -35
- edsl/questions/QuestionExtract.py +98 -27
- edsl/questions/QuestionFreeText.py +46 -29
- edsl/questions/QuestionFunctional.py +7 -0
- edsl/questions/QuestionList.py +141 -22
- edsl/questions/QuestionMultipleChoice.py +173 -64
- edsl/questions/QuestionNumerical.py +87 -46
- edsl/questions/QuestionRank.py +182 -24
- edsl/questions/RegisterQuestionsMeta.py +31 -12
- edsl/questions/ResponseValidatorABC.py +169 -0
- edsl/questions/__init__.py +3 -4
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +10 -5
- edsl/questions/derived/QuestionLinearScale.py +11 -1
- edsl/questions/derived/QuestionTopK.py +6 -0
- edsl/questions/derived/QuestionYesNo.py +16 -1
- edsl/questions/descriptors.py +43 -7
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_registry.py +6 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/Dataset.py +20 -0
- edsl/results/DatasetExportMixin.py +41 -47
- edsl/results/DatasetTree.py +145 -0
- edsl/results/Result.py +32 -5
- edsl/results/Results.py +131 -45
- edsl/results/ResultsDBMixin.py +3 -3
- edsl/results/Selector.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/Scenario.py +10 -4
- edsl/scenarios/ScenarioList.py +348 -39
- edsl/scenarios/ScenarioListExportMixin.py +9 -0
- edsl/study/SnapShot.py +8 -1
- edsl/surveys/RuleCollection.py +2 -2
- edsl/surveys/Survey.py +634 -315
- edsl/surveys/SurveyExportMixin.py +71 -9
- edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
- edsl/surveys/SurveyQualtricsImport.py +75 -4
- edsl/surveys/instructions/ChangeInstruction.py +47 -0
- edsl/surveys/instructions/Instruction.py +34 -0
- edsl/surveys/instructions/InstructionCollection.py +77 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +111 -0
- edsl/templates/error_reporting/interviews.html +10 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/METADATA +4 -2
- edsl-0.1.33.dev2.dist-info/RECORD +289 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -286
- edsl/utilities/gcp_bucket/simple_example.py +0 -9
- edsl-0.1.33.dev1.dist-info/RECORD +0 -209
- {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/WHEEL +0 -0
edsl/agents/InvigilatorBase.py
CHANGED
@@ -8,21 +8,22 @@ from edsl.data_transfer_models import AgentResponseDict
|
|
8
8
|
|
9
9
|
from edsl.data.Cache import Cache
|
10
10
|
|
11
|
-
# from edsl.agents.Agent import Agent
|
12
11
|
from edsl.questions.QuestionBase import QuestionBase
|
13
12
|
from edsl.scenarios.Scenario import Scenario
|
14
13
|
from edsl.surveys.MemoryPlan import MemoryPlan
|
15
14
|
from edsl.language_models.LanguageModel import LanguageModel
|
16
15
|
|
16
|
+
from edsl.data_transfer_models import EDSLResultObjectInput
|
17
|
+
|
17
18
|
|
18
19
|
class InvigilatorBase(ABC):
|
19
20
|
"""An invigiator (someone who administers an exam) is a class that is responsible for administering a question to an agent.
|
20
21
|
|
21
22
|
>>> InvigilatorBase.example().answer_question()
|
22
|
-
{'message': '
|
23
|
+
{'message': [{'text': 'SPAM!'}], 'usage': {'prompt_tokens': 1, 'completion_tokens': 1}}
|
23
24
|
|
24
|
-
>>> InvigilatorBase.example().get_failed_task_result()
|
25
|
-
|
25
|
+
>>> InvigilatorBase.example().get_failed_task_result(failure_reason="Failed to get response").comment
|
26
|
+
'Failed to get response'
|
26
27
|
|
27
28
|
This returns an empty prompt because there is no memory the agent needs to have at q0.
|
28
29
|
|
@@ -51,6 +52,7 @@ class InvigilatorBase(ABC):
|
|
51
52
|
iteration: Optional[int] = 1,
|
52
53
|
additional_prompt_data: Optional[dict] = None,
|
53
54
|
sidecar_model: Optional[LanguageModel] = None,
|
55
|
+
raise_validation_errors: Optional[bool] = True,
|
54
56
|
):
|
55
57
|
"""Initialize a new Invigilator."""
|
56
58
|
self.agent = agent
|
@@ -64,6 +66,73 @@ class InvigilatorBase(ABC):
|
|
64
66
|
self.cache = cache
|
65
67
|
self.sidecar_model = sidecar_model
|
66
68
|
self.survey = survey
|
69
|
+
self.raise_validation_errors = raise_validation_errors
|
70
|
+
|
71
|
+
self.raw_model_response = (
|
72
|
+
None # placeholder for the raw response from the model
|
73
|
+
)
|
74
|
+
|
75
|
+
def to_dict(self):
|
76
|
+
attributes = [
|
77
|
+
"agent",
|
78
|
+
"question",
|
79
|
+
"scenario",
|
80
|
+
"model",
|
81
|
+
"memory_plan",
|
82
|
+
"current_answers",
|
83
|
+
"iteration",
|
84
|
+
"additional_prompt_data",
|
85
|
+
"cache",
|
86
|
+
"sidecar_model",
|
87
|
+
"survey",
|
88
|
+
]
|
89
|
+
|
90
|
+
def serialize_attribute(attr):
|
91
|
+
value = getattr(self, attr)
|
92
|
+
if value is None:
|
93
|
+
return None
|
94
|
+
if hasattr(value, "to_dict"):
|
95
|
+
return value.to_dict()
|
96
|
+
if isinstance(value, (int, float, str, bool, dict, list)):
|
97
|
+
return value
|
98
|
+
return str(value)
|
99
|
+
|
100
|
+
return {attr: serialize_attribute(attr) for attr in attributes}
|
101
|
+
|
102
|
+
@classmethod
|
103
|
+
def from_dict(cls, data):
|
104
|
+
from edsl.agents.Agent import Agent
|
105
|
+
from edsl.questions import QuestionBase
|
106
|
+
from edsl.scenarios.Scenario import Scenario
|
107
|
+
from edsl.surveys.MemoryPlan import MemoryPlan
|
108
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
109
|
+
from edsl.surveys.Survey import Survey
|
110
|
+
|
111
|
+
agent = Agent.from_dict(data["agent"])
|
112
|
+
question = QuestionBase.from_dict(data["question"])
|
113
|
+
scenario = Scenario.from_dict(data["scenario"])
|
114
|
+
model = LanguageModel.from_dict(data["model"])
|
115
|
+
memory_plan = MemoryPlan.from_dict(data["memory_plan"])
|
116
|
+
survey = Survey.from_dict(data["survey"])
|
117
|
+
current_answers = data["current_answers"]
|
118
|
+
iteration = data["iteration"]
|
119
|
+
additional_prompt_data = data["additional_prompt_data"]
|
120
|
+
cache = Cache.from_dict(data["cache"])
|
121
|
+
sidecar_model = LanguageModel.from_dict(data["sidecar_model"])
|
122
|
+
|
123
|
+
return cls(
|
124
|
+
agent=agent,
|
125
|
+
question=question,
|
126
|
+
scenario=scenario,
|
127
|
+
model=model,
|
128
|
+
memory_plan=memory_plan,
|
129
|
+
current_answers=current_answers,
|
130
|
+
survey=survey,
|
131
|
+
iteration=iteration,
|
132
|
+
additional_prompt_data=additional_prompt_data,
|
133
|
+
cache=cache,
|
134
|
+
sidecar_model=sidecar_model,
|
135
|
+
)
|
67
136
|
|
68
137
|
def __repr__(self) -> str:
|
69
138
|
"""Return a string representation of the Invigilator.
|
@@ -74,18 +143,45 @@ class InvigilatorBase(ABC):
|
|
74
143
|
"""
|
75
144
|
return f"{self.__class__.__name__}(agent={repr(self.agent)}, question={repr(self.question)}, scneario={repr(self.scenario)}, model={repr(self.model)}, memory_plan={repr(self.memory_plan)}, current_answers={repr(self.current_answers)}, iteration{repr(self.iteration)}, additional_prompt_data={repr(self.additional_prompt_data)}, cache={repr(self.cache)}, sidecarmodel={repr(self.sidecar_model)})"
|
76
145
|
|
77
|
-
def get_failed_task_result(self) ->
|
146
|
+
def get_failed_task_result(self, failure_reason) -> EDSLResultObjectInput:
|
78
147
|
"""Return an AgentResponseDict used in case the question-asking fails.
|
79
148
|
|
80
|
-
|
81
|
-
|
149
|
+
Possible reasons include:
|
150
|
+
- Legimately skipped because of skip logic
|
151
|
+
- Failed to get response from the model
|
152
|
+
|
82
153
|
"""
|
83
|
-
|
84
|
-
answer
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
154
|
+
data = {
|
155
|
+
"answer": None,
|
156
|
+
"generated_tokens": None,
|
157
|
+
"comment": failure_reason,
|
158
|
+
"question_name": self.question.question_name,
|
159
|
+
"prompts": self.get_prompts(),
|
160
|
+
"cached_response": None,
|
161
|
+
"raw_model_response": None,
|
162
|
+
"cache_used": None,
|
163
|
+
"cache_key": None,
|
164
|
+
}
|
165
|
+
return EDSLResultObjectInput(**data)
|
166
|
+
|
167
|
+
# breakpoint()
|
168
|
+
# if hasattr(self, "augmented_model_response"):
|
169
|
+
# import json
|
170
|
+
|
171
|
+
# generated_tokens = json.loads(self.augmented_model_response)["answer"][
|
172
|
+
# "generated_tokens"
|
173
|
+
# ]
|
174
|
+
# else:
|
175
|
+
# generated_tokens = "Filled in by InvigilatorBase.get_failed_task_result"
|
176
|
+
# agent_response_dict = AgentResponseDict(
|
177
|
+
# answer=None,
|
178
|
+
# comment="Failed to get usable response",
|
179
|
+
# generated_tokens=generated_tokens,
|
180
|
+
# question_name=self.question.question_name,
|
181
|
+
# prompts=self.get_prompts(),
|
182
|
+
# )
|
183
|
+
# # breakpoint()
|
184
|
+
# return agent_response_dict
|
89
185
|
|
90
186
|
def get_prompts(self) -> Dict[str, Prompt]:
|
91
187
|
"""Return the prompt used."""
|
@@ -128,7 +224,9 @@ class InvigilatorBase(ABC):
|
|
128
224
|
)
|
129
225
|
|
130
226
|
@classmethod
|
131
|
-
def example(
|
227
|
+
def example(
|
228
|
+
cls, throw_an_exception=False, question=None, scenario=None, survey=None
|
229
|
+
) -> "InvigilatorBase":
|
132
230
|
"""Return an example invigilator.
|
133
231
|
|
134
232
|
>>> InvigilatorBase.example()
|
@@ -143,39 +241,49 @@ class InvigilatorBase(ABC):
|
|
143
241
|
|
144
242
|
from edsl.enums import InferenceServiceType
|
145
243
|
|
146
|
-
|
147
|
-
"""A test language model."""
|
244
|
+
from edsl import Model
|
148
245
|
|
149
|
-
|
150
|
-
|
151
|
-
|
246
|
+
model = Model("test", canned_response="SPAM!")
|
247
|
+
# class TestLanguageModelGood(LanguageModel):
|
248
|
+
# """A test language model."""
|
152
249
|
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
await asyncio.sleep(0.1)
|
157
|
-
if hasattr(self, "throw_an_exception"):
|
158
|
-
raise Exception("Error!")
|
159
|
-
return {"message": """{"answer": "SPAM!"}"""}
|
250
|
+
# _model_ = "test"
|
251
|
+
# _parameters_ = {"temperature": 0.5}
|
252
|
+
# _inference_service_ = InferenceServiceType.TEST.value
|
160
253
|
|
161
|
-
|
162
|
-
|
163
|
-
|
254
|
+
# async def async_execute_model_call(
|
255
|
+
# self, user_prompt: str, system_prompt: str
|
256
|
+
# ) -> dict[str, Any]:
|
257
|
+
# await asyncio.sleep(0.1)
|
258
|
+
# if hasattr(self, "throw_an_exception"):
|
259
|
+
# raise Exception("Error!")
|
260
|
+
# return {"message": """{"answer": "SPAM!"}"""}
|
261
|
+
|
262
|
+
# def parse_response(self, raw_response: dict[str, Any]) -> str:
|
263
|
+
# """Parse the response from the model."""
|
264
|
+
# return raw_response["message"]
|
164
265
|
|
165
|
-
model = TestLanguageModelGood()
|
166
266
|
if throw_an_exception:
|
167
267
|
model.throw_an_exception = True
|
168
268
|
agent = Agent.example()
|
169
269
|
# question = QuestionMultipleChoice.example()
|
170
270
|
from edsl.surveys import Survey
|
171
271
|
|
172
|
-
|
272
|
+
if not survey:
|
273
|
+
survey = Survey.example()
|
274
|
+
# if question:
|
275
|
+
# need to have the focal question name in the list of names
|
276
|
+
# survey._questions[0].question_name = question.question_name
|
277
|
+
# survey.add_question(question)
|
278
|
+
if question:
|
279
|
+
survey.add_question(question)
|
280
|
+
|
173
281
|
question = question or survey.questions[0]
|
174
282
|
scenario = scenario or Scenario.example()
|
175
283
|
# memory_plan = None #memory_plan = MemoryPlan()
|
176
284
|
from edsl import Survey
|
177
285
|
|
178
|
-
memory_plan = MemoryPlan(survey=
|
286
|
+
memory_plan = MemoryPlan(survey=survey)
|
179
287
|
current_answers = None
|
180
288
|
from edsl.agents.PromptConstructionMixin import PromptConstructorMixin
|
181
289
|
|
@@ -1,3 +1,4 @@
|
|
1
|
+
from __future__ import annotations
|
1
2
|
from typing import Dict, Any, Optional
|
2
3
|
from collections import UserList
|
3
4
|
|
@@ -231,47 +232,8 @@ class PromptConstructorMixin:
|
|
231
232
|
>>> from edsl.agents.InvigilatorBase import InvigilatorBase
|
232
233
|
>>> i = InvigilatorBase.example()
|
233
234
|
>>> i.question_instructions_prompt
|
234
|
-
Prompt(text=\"""
|
235
|
-
The options are
|
236
|
-
<BLANKLINE>
|
237
|
-
0: yes
|
238
|
-
<BLANKLINE>
|
239
|
-
1: no
|
240
|
-
<BLANKLINE>
|
241
|
-
Return a valid JSON formatted like this, selecting only the number of the option:
|
242
|
-
{"answer": <put answer code here>, "comment": "<put explanation here>"}
|
243
|
-
Only 1 option may be selected.\""")
|
244
|
-
|
245
|
-
>>> from edsl import QuestionFreeText
|
246
|
-
>>> q = QuestionFreeText(question_text = "Consider {{ X }}. What is your favorite color?", question_name = "q_color")
|
247
|
-
>>> from edsl.agents.InvigilatorBase import InvigilatorBase
|
248
|
-
>>> i = InvigilatorBase.example(question = q)
|
249
|
-
>>> i.question_instructions_prompt
|
250
|
-
Traceback (most recent call last):
|
235
|
+
Prompt(text=\"""...
|
251
236
|
...
|
252
|
-
edsl.exceptions.questions.QuestionScenarioRenderError: Question instructions still has variables: ['X'].
|
253
|
-
|
254
|
-
|
255
|
-
>>> from edsl import QuestionFreeText
|
256
|
-
>>> q = QuestionFreeText(question_text = "You were asked the question '{{ q0.question_text }}'. What is your favorite color?", question_name = "q_color")
|
257
|
-
>>> from edsl.agents.InvigilatorBase import InvigilatorBase
|
258
|
-
>>> i = InvigilatorBase.example(question = q)
|
259
|
-
>>> i.question_instructions_prompt
|
260
|
-
Prompt(text=\"""You are being asked the following question: You were asked the question 'Do you like school?'. What is your favorite color?
|
261
|
-
Return a valid JSON formatted like this:
|
262
|
-
{"answer": "<put free text answer here>"}\""")
|
263
|
-
|
264
|
-
>>> from edsl import QuestionFreeText
|
265
|
-
>>> q = QuestionFreeText(question_text = "You stated '{{ q0.answer }}'. What is your favorite color?", question_name = "q_color")
|
266
|
-
>>> from edsl.agents.InvigilatorBase import InvigilatorBase
|
267
|
-
>>> i = InvigilatorBase.example(question = q)
|
268
|
-
>>> i.current_answers = {"q0": "I like school"}
|
269
|
-
>>> i.question_instructions_prompt
|
270
|
-
Prompt(text=\"""You are being asked the following question: You stated 'I like school'. What is your favorite color?
|
271
|
-
Return a valid JSON formatted like this:
|
272
|
-
{"answer": "<put free text answer here>"}\""")
|
273
|
-
|
274
|
-
|
275
237
|
"""
|
276
238
|
if not hasattr(self, "_question_instructions_prompt"):
|
277
239
|
question_prompt = self.question.get_instructions(model=self.model.model)
|
@@ -290,27 +252,37 @@ class PromptConstructorMixin:
|
|
290
252
|
|
291
253
|
# check to see if the question_options is actually a string
|
292
254
|
# This is used when the user is using the question_options as a variable from a sceario
|
293
|
-
if "question_options" in question_data:
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
)
|
305
|
-
|
306
|
-
|
307
|
-
|
255
|
+
# if "question_options" in question_data:
|
256
|
+
if isinstance(self.question.data.get("question_options", None), str):
|
257
|
+
from jinja2 import Environment, meta
|
258
|
+
|
259
|
+
env = Environment()
|
260
|
+
parsed_content = env.parse(self.question.data["question_options"])
|
261
|
+
question_option_key = list(
|
262
|
+
meta.find_undeclared_variables(parsed_content)
|
263
|
+
)[0]
|
264
|
+
|
265
|
+
if isinstance(
|
266
|
+
question_options := self.scenario.get(question_option_key), list
|
267
|
+
):
|
268
|
+
question_data["question_options"] = question_options
|
269
|
+
self.question.question_options = question_options
|
270
|
+
|
271
|
+
replacement_dict = (
|
308
272
|
question_data
|
309
273
|
| self.scenario
|
310
274
|
| self.prior_answers_dict()
|
311
275
|
| {"agent": self.agent}
|
276
|
+
| {
|
277
|
+
"use_code": getattr(self.question, "_use_code", True),
|
278
|
+
"include_comment": getattr(
|
279
|
+
self.question, "_include_comment", False
|
280
|
+
),
|
281
|
+
}
|
312
282
|
)
|
313
|
-
|
283
|
+
# breakpoint()
|
284
|
+
rendered_instructions = question_prompt.render(replacement_dict)
|
285
|
+
# breakpoint()
|
314
286
|
undefined_template_variables = (
|
315
287
|
rendered_instructions.undefined_template_variables({})
|
316
288
|
)
|
@@ -324,11 +296,23 @@ class PromptConstructorMixin:
|
|
324
296
|
)
|
325
297
|
|
326
298
|
if undefined_template_variables:
|
327
|
-
print(undefined_template_variables)
|
328
299
|
raise QuestionScenarioRenderError(
|
329
300
|
f"Question instructions still has variables: {undefined_template_variables}."
|
330
301
|
)
|
331
302
|
|
303
|
+
# Check if question has an instructions
|
304
|
+
relevant_instructions = self.survey.relevant_instructions(
|
305
|
+
self.question.question_name
|
306
|
+
)
|
307
|
+
|
308
|
+
if relevant_instructions != []:
|
309
|
+
preamble_text = Prompt(
|
310
|
+
text="Before answer this question, you were given the following instructions: "
|
311
|
+
)
|
312
|
+
for instruction in relevant_instructions:
|
313
|
+
preamble_text += instruction.text
|
314
|
+
rendered_instructions = preamble_text + rendered_instructions
|
315
|
+
|
332
316
|
self._question_instructions_prompt = rendered_instructions
|
333
317
|
return self._question_instructions_prompt
|
334
318
|
|
@@ -368,17 +352,10 @@ class PromptConstructorMixin:
|
|
368
352
|
|
369
353
|
>>> from edsl import QuestionFreeText
|
370
354
|
>>> from edsl.agents.InvigilatorBase import InvigilatorBase
|
371
|
-
>>> q = QuestionFreeText(question_text="How are you today?", question_name="
|
355
|
+
>>> q = QuestionFreeText(question_text="How are you today?", question_name="q_new")
|
372
356
|
>>> i = InvigilatorBase.example(question = q)
|
373
357
|
>>> i.get_prompts()
|
374
358
|
{'user_prompt': ..., 'system_prompt': ...}
|
375
|
-
>>> scenario = i._get_scenario_with_image()
|
376
|
-
>>> scenario.has_image
|
377
|
-
True
|
378
|
-
>>> q = QuestionFreeText(question_text="How are you today?", question_name="q0")
|
379
|
-
>>> i = InvigilatorBase.example(question = q, scenario = scenario)
|
380
|
-
>>> i.get_prompts()
|
381
|
-
{'user_prompt': ..., 'system_prompt': ..., 'encoded_image': ...'}
|
382
359
|
"""
|
383
360
|
prompts = self.prompt_plan.get_prompts(
|
384
361
|
agent_instructions=self.agent_instructions_prompt,
|
@@ -391,7 +368,7 @@ class PromptConstructorMixin:
|
|
391
368
|
prompts["encoded_image"] = self.scenario["encoded_image"]
|
392
369
|
return prompts
|
393
370
|
|
394
|
-
def _get_scenario_with_image(self) ->
|
371
|
+
def _get_scenario_with_image(self) -> Scenario:
|
395
372
|
"""This is a helper function to get a scenario with an image, for testing purposes."""
|
396
373
|
from edsl import Scenario
|
397
374
|
|
edsl/agents/__init__.py
CHANGED
edsl/auto/AutoStudy.py
ADDED
@@ -0,0 +1,117 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
from edsl import Model
|
4
|
+
from edsl.auto.StageQuestions import StageQuestions
|
5
|
+
from edsl.auto.StagePersona import StagePersona
|
6
|
+
from edsl.auto.StagePersonaDimensions import StagePersonaDimensions
|
7
|
+
from edsl.auto.StagePersonaDimensionValues import StagePersonaDimensionValues
|
8
|
+
from edsl.auto.StagePersonaDimensionValueRanges import (
|
9
|
+
StagePersonaDimensionValueRanges,
|
10
|
+
)
|
11
|
+
from edsl.auto.StageLabelQuestions import StageLabelQuestions
|
12
|
+
from edsl.auto.StageGenerateSurvey import StageGenerateSurvey
|
13
|
+
|
14
|
+
# from edsl.auto.StageBase import gen_pipeline
|
15
|
+
|
16
|
+
from edsl.auto.utilities import agent_generator, create_agents, gen_pipeline
|
17
|
+
|
18
|
+
|
19
|
+
class AutoStudy:
|
20
|
+
def __init__(
|
21
|
+
self,
|
22
|
+
overall_question: str,
|
23
|
+
population: str,
|
24
|
+
model: Optional["Model"] = None,
|
25
|
+
survey: Optional["Survey"] = None,
|
26
|
+
agent_list: Optional["AgentList"] = None,
|
27
|
+
default_num_agents=11,
|
28
|
+
):
|
29
|
+
self.overall_question = overall_question
|
30
|
+
self.population = population
|
31
|
+
self._survey = survey
|
32
|
+
self._agent_list = agent_list
|
33
|
+
self._agent_list_generator = None
|
34
|
+
self._persona_mapping = None
|
35
|
+
self._results = None
|
36
|
+
self.default_num_agents = default_num_agents
|
37
|
+
self.model = model or Model()
|
38
|
+
|
39
|
+
@property
|
40
|
+
def survey(self):
|
41
|
+
if self._survey is None:
|
42
|
+
self._survey = self._create_survey()
|
43
|
+
return self._survey
|
44
|
+
|
45
|
+
@property
|
46
|
+
def persona_mapping(self):
|
47
|
+
if self._persona_mapping is None:
|
48
|
+
self._persona_mapping = self._create_persona_mapping()
|
49
|
+
return self._persona_mapping
|
50
|
+
|
51
|
+
@property
|
52
|
+
def agent_list_generator(self):
|
53
|
+
if self._agent_list_generator is None:
|
54
|
+
self._agent_list_generator = self._create_agent_list_generator()
|
55
|
+
return self._agent_list_generator
|
56
|
+
|
57
|
+
@property
|
58
|
+
def results(self):
|
59
|
+
if self._results is None:
|
60
|
+
self._results = self._create_results()
|
61
|
+
return self._results
|
62
|
+
|
63
|
+
def _create_survey(self):
|
64
|
+
survey_pipline_stages = [
|
65
|
+
StageQuestions,
|
66
|
+
StageLabelQuestions,
|
67
|
+
StageGenerateSurvey,
|
68
|
+
]
|
69
|
+
survey_pipeline = gen_pipeline(survey_pipline_stages)
|
70
|
+
return survey_pipeline.process(
|
71
|
+
data=survey_pipeline.input(
|
72
|
+
overall_question=self.overall_question, population=self.population
|
73
|
+
)
|
74
|
+
).survey
|
75
|
+
|
76
|
+
def _create_persona_mapping(self):
|
77
|
+
persona_pipeline_stages = [
|
78
|
+
StageQuestions,
|
79
|
+
StagePersona,
|
80
|
+
StagePersonaDimensions,
|
81
|
+
StagePersonaDimensionValues,
|
82
|
+
StagePersonaDimensionValueRanges,
|
83
|
+
]
|
84
|
+
|
85
|
+
persona_pipeline = gen_pipeline(persona_pipeline_stages)
|
86
|
+
sample_agent_results = persona_pipeline.process(
|
87
|
+
persona_pipeline.input(
|
88
|
+
overall_question=overall_question, population=self.population
|
89
|
+
)
|
90
|
+
)
|
91
|
+
return sample_agent_results
|
92
|
+
|
93
|
+
def _create_agent_list_generator(self):
|
94
|
+
return agent_generator(
|
95
|
+
persona=self.persona_mapping.persona,
|
96
|
+
dimension_dict=self.persona_mapping.mapping,
|
97
|
+
)
|
98
|
+
|
99
|
+
def agent_list(self, num_agents):
|
100
|
+
return create_agents(
|
101
|
+
agent_generator=self.agent_list_generator,
|
102
|
+
survey=self.survey,
|
103
|
+
num_agents=num_agents,
|
104
|
+
)
|
105
|
+
|
106
|
+
def _create_results(self, num_agents=None):
|
107
|
+
if num_agents is None:
|
108
|
+
num_agents = self.default_num_agents
|
109
|
+
agent_list = self.agent_list(num_agents)
|
110
|
+
return self.survey.by(agent_list).by(self.model).run()
|
111
|
+
|
112
|
+
|
113
|
+
if __name__ == "__main__":
|
114
|
+
overall_question = "Should online platforms be regulated with respect to selling electric scooters?"
|
115
|
+
auto_study = AutoStudy(overall_question, population="US Adults")
|
116
|
+
|
117
|
+
results = auto_study.results
|