edsl 0.1.33.dev1__py3-none-any.whl → 0.1.33.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +8 -4
- edsl/agents/Agent.py +46 -14
- edsl/agents/AgentList.py +43 -0
- edsl/agents/Invigilator.py +125 -212
- edsl/agents/InvigilatorBase.py +140 -32
- edsl/agents/PromptConstructionMixin.py +43 -66
- edsl/agents/__init__.py +1 -0
- edsl/auto/AutoStudy.py +117 -0
- edsl/auto/StageBase.py +230 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +73 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +224 -0
- edsl/config.py +38 -39
- edsl/coop/PriceFetcher.py +58 -0
- edsl/coop/coop.py +39 -5
- edsl/data/Cache.py +35 -1
- edsl/data_transfer_models.py +120 -38
- edsl/enums.py +2 -0
- edsl/exceptions/language_models.py +25 -1
- edsl/exceptions/questions.py +62 -5
- edsl/exceptions/results.py +4 -0
- edsl/inference_services/AnthropicService.py +13 -11
- edsl/inference_services/AwsBedrock.py +19 -17
- edsl/inference_services/AzureAI.py +37 -20
- edsl/inference_services/GoogleService.py +16 -12
- edsl/inference_services/GroqService.py +2 -0
- edsl/inference_services/InferenceServiceABC.py +24 -0
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OpenAIService.py +41 -50
- edsl/inference_services/TestService.py +71 -0
- edsl/inference_services/models_available_cache.py +0 -6
- edsl/inference_services/registry.py +4 -0
- edsl/jobs/Answers.py +10 -12
- edsl/jobs/FailedQuestion.py +78 -0
- edsl/jobs/Jobs.py +18 -13
- edsl/jobs/buckets/TokenBucket.py +39 -14
- edsl/jobs/interviews/Interview.py +297 -77
- edsl/jobs/interviews/InterviewExceptionEntry.py +83 -19
- edsl/jobs/interviews/interview_exception_tracking.py +0 -70
- edsl/jobs/interviews/retry_management.py +3 -1
- edsl/jobs/runners/JobsRunnerAsyncio.py +116 -70
- edsl/jobs/runners/JobsRunnerStatusMixin.py +1 -1
- edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
- edsl/jobs/tasks/TaskHistory.py +131 -213
- edsl/language_models/LanguageModel.py +239 -129
- edsl/language_models/ModelList.py +2 -2
- edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/registry.py +15 -2
- edsl/language_models/repair.py +0 -19
- edsl/language_models/utilities.py +61 -0
- edsl/prompts/Prompt.py +52 -2
- edsl/questions/AnswerValidatorMixin.py +23 -26
- edsl/questions/QuestionBase.py +273 -242
- edsl/questions/QuestionBaseGenMixin.py +133 -0
- edsl/questions/QuestionBasePromptsMixin.py +266 -0
- edsl/questions/QuestionBudget.py +6 -0
- edsl/questions/QuestionCheckBox.py +227 -35
- edsl/questions/QuestionExtract.py +98 -27
- edsl/questions/QuestionFreeText.py +46 -29
- edsl/questions/QuestionFunctional.py +7 -0
- edsl/questions/QuestionList.py +141 -22
- edsl/questions/QuestionMultipleChoice.py +173 -64
- edsl/questions/QuestionNumerical.py +87 -46
- edsl/questions/QuestionRank.py +182 -24
- edsl/questions/RegisterQuestionsMeta.py +31 -12
- edsl/questions/ResponseValidatorABC.py +169 -0
- edsl/questions/__init__.py +3 -4
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +10 -5
- edsl/questions/derived/QuestionLinearScale.py +11 -1
- edsl/questions/derived/QuestionTopK.py +6 -0
- edsl/questions/derived/QuestionYesNo.py +16 -1
- edsl/questions/descriptors.py +43 -7
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_registry.py +6 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/Dataset.py +20 -0
- edsl/results/DatasetExportMixin.py +41 -47
- edsl/results/DatasetTree.py +145 -0
- edsl/results/Result.py +32 -5
- edsl/results/Results.py +131 -45
- edsl/results/ResultsDBMixin.py +3 -3
- edsl/results/Selector.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/Scenario.py +10 -4
- edsl/scenarios/ScenarioList.py +348 -39
- edsl/scenarios/ScenarioListExportMixin.py +9 -0
- edsl/study/SnapShot.py +8 -1
- edsl/surveys/RuleCollection.py +2 -2
- edsl/surveys/Survey.py +634 -315
- edsl/surveys/SurveyExportMixin.py +71 -9
- edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
- edsl/surveys/SurveyQualtricsImport.py +75 -4
- edsl/surveys/instructions/ChangeInstruction.py +47 -0
- edsl/surveys/instructions/Instruction.py +34 -0
- edsl/surveys/instructions/InstructionCollection.py +77 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +111 -0
- edsl/templates/error_reporting/interviews.html +10 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/METADATA +4 -2
- edsl-0.1.33.dev2.dist-info/RECORD +289 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -286
- edsl/utilities/gcp_bucket/simple_example.py +0 -9
- edsl-0.1.33.dev1.dist-info/RECORD +0 -209
- {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/WHEEL +0 -0
@@ -10,6 +10,165 @@ from edsl.questions.descriptors import (
|
|
10
10
|
QuestionOptionsDescriptor,
|
11
11
|
)
|
12
12
|
|
13
|
+
from edsl.questions.decorators import inject_exception
|
14
|
+
|
15
|
+
from pydantic import field_validator
|
16
|
+
from edsl.questions.ResponseValidatorABC import ResponseValidatorABC
|
17
|
+
from edsl.questions.ResponseValidatorABC import BaseResponse
|
18
|
+
|
19
|
+
from edsl.exceptions import QuestionAnswerValidationError
|
20
|
+
|
21
|
+
from pydantic import BaseModel, Field, conlist
|
22
|
+
from typing import List, Literal, Optional, Annotated
|
23
|
+
|
24
|
+
|
25
|
+
def create_checkbox_response_model(
    choices: list,
    min_selections: Optional[int] = None,
    max_selections: Optional[int] = None,
    permissive: bool = False,
):
    """
    Dynamically create a CheckboxResponse model with a predefined list of choices.

    :param choices: The allowed values for each element of the ``answer`` list
        (option texts, or option indices when answer codes are used).
    :param min_selections: Minimum number of selected choices; not enforced
        when ``permissive`` is True.
    :param max_selections: Maximum number of selected choices; not enforced
        when ``permissive`` is True.
    :param permissive: If True, skip the selection-count bounds entirely.
    :return: A new Pydantic model class.
    """
    from pydantic import ConfigDict

    # Literal[...] is subscripted with a tuple of the allowed values.
    choice_tuple = tuple(choices)

    # Pydantic v2 spells list-length bounds min_length/max_length;
    # min_items/max_items are deprecated aliases of the same constraint.
    field_params = {}
    if min_selections is not None and not permissive:
        field_params["min_length"] = min_selections
    if max_selections is not None and not permissive:
        field_params["max_length"] = max_selections

    def _add_choices_to_schema(schema: dict) -> None:
        # Document the allowed choices on the answer items. Properties are
        # keyed by field name; the generated "title" is "Answer" (capitalized),
        # so matching on the title would never find the field.
        answer_schema = schema.get("properties", {}).get("answer")
        if answer_schema is not None:
            items_schema = answer_schema.setdefault("items", {})
            items_schema["enum"] = list(choices)

    class CheckboxResponse(BaseModel):
        model_config = ConfigDict(json_schema_extra=_add_choices_to_schema)

        answer: List[Literal[choice_tuple]] = Field(
            ..., description="List of selected choices", **field_params
        )
        comment: Optional[str] = Field(None, description="Optional comment field")
        generated_tokens: Optional[Any] = Field(default=None)

    return CheckboxResponse
|
64
|
+
|
65
|
+
|
66
|
+
class CheckBoxResponseValidator(ResponseValidatorABC):
    """Validate, and where possible repair, answers to checkbox questions."""

    required_params = [
        "question_options",
        "min_selections",
        "max_selections",
        "use_code",
        "permissive",
    ]

    valid_examples = [
        ({"answer": [1, 2]}, {"question_options": ["Good", "Great", "OK", "Bad"]})
    ]

    invalid_examples = [
        (
            {"answer": [-1]},
            {"question_options": ["Good", "Great", "OK", "Bad"]},
            "Answer code must be a non-negative integer",
        ),
        (
            {"answer": 1},
            {"question_options": ["Good", "Great", "OK", "Bad"]},
            "Answer code must be a list",
        ),
        (
            {"answer": [1, 2, 3, 4]},
            {
                "question_options": ["Good", "Great", "OK", "Bad"],
                "min_selections": 1,
                "max_selections": 2,
            },
            "Too many options selected",
        ),
    ]

    def fix(self, response, verbose=False):
        """Try to coerce an invalid checkbox response into a valid one.

        Two strategies are tried in order, each checked against
        ``self.response_model``:

        1. Read ``generated_tokens`` as a comma-separated list, converting
           the items to ints when answer codes are used.
        2. Collect every option that appears in ``generated_tokens``. With
           answer codes this collects the option's index, matched as a whole
           number. Without them it collects the option text itself, because
           that is what the response model expects.

        :param response: The raw response dict. ``generated_tokens`` is read,
            and ``comment`` is carried over.
        :param verbose: Print diagnostic information.
        :return: The first proposed dict that validates; otherwise ``response``.
        """
        import re

        if verbose:
            print("Invalid response of QuestionCheckBox was: ", response)
        response_text = response.get("generated_tokens")
        if response_text is None or response_text == "":  # nothing to be done
            return response

        # Strategy 1: maybe it's a comma separated list?
        proposed_list = [item.strip() for item in response_text.split(",")]
        if verbose:
            print("Using code? ", self.use_code)
        if self.use_code:
            try:
                proposed_list = [int(i) for i in proposed_list]
            except ValueError:
                # Keep the strings; validation below rejects them and we fall
                # through to the second strategy.
                pass
        if verbose:
            print("Proposed solution is: ", proposed_list)

        # "comment" is only forwarded when the original response carried one.
        proposed_data = {"answer": proposed_list}
        if "comment" in response:
            proposed_data["comment"] = response["comment"]
        proposed_data["generated_tokens"] = response_text

        try:
            self.response_model(**proposed_data)
        except Exception as e:  # best-effort repair: any rejection -> next strategy
            if verbose:
                print(f"Proposed solution {proposed_data} is invalid. Error: {e}")
        else:
            if verbose:
                print("Proposed solution is valid")
                print("Returning proposed data: ", proposed_data)
            return proposed_data

        # Strategy 2: look for each option (or its index) in the raw text.
        if verbose:
            print("Now seeing if responses show up in the answer")
        matches = []
        for index, option in enumerate(self.question_options):
            if self.use_code:
                # Whole-number match so that "10" does not also select 1 and 0.
                if re.search(rf"(?<!\d){index}(?!\d)", response_text):
                    matches.append(index)
            elif str(option) in response_text:
                # Without answer codes the response model accepts option text.
                matches.append(option)

        proposed_data = {
            "answer": matches,
            "comment": response.get("comment", None),
            "generated_tokens": response_text,
        }
        try:
            self.response_model(**proposed_data)
            return proposed_data
        except Exception as e:  # best-effort repair: give back the original
            if verbose:
                print(f"Proposed solution {proposed_data} is invalid. Error: {e}")
            return response

    def custom_validate(self, response) -> BaseResponse:
        """Reject a missing answer; otherwise return the response as a dict.

        :raises QuestionAnswerValidationError: if ``response.answer`` is None.
        """
        if response.answer is None:
            raise QuestionAnswerValidationError("Answer is missing.")
        return response.dict()
|
171
|
+
|
13
172
|
|
14
173
|
class QuestionCheckBox(QuestionBase):
|
15
174
|
"""This question prompts the agent to select options from a list."""
|
@@ -20,6 +179,9 @@ class QuestionCheckBox(QuestionBase):
|
|
20
179
|
min_selections = IntegerDescriptor(none_allowed=True)
|
21
180
|
max_selections = IntegerDescriptor(none_allowed=True)
|
22
181
|
|
182
|
+
_response_model = None
|
183
|
+
response_validator_class = CheckBoxResponseValidator
|
184
|
+
|
23
185
|
def __init__(
|
24
186
|
self,
|
25
187
|
question_name: str,
|
@@ -27,6 +189,11 @@ class QuestionCheckBox(QuestionBase):
|
|
27
189
|
question_options: list[str],
|
28
190
|
min_selections: Optional[int] = None,
|
29
191
|
max_selections: Optional[int] = None,
|
192
|
+
include_comment: bool = True,
|
193
|
+
use_code: bool = True,
|
194
|
+
question_presentation: Optional[str] = None,
|
195
|
+
answering_instructions: Optional[str] = None,
|
196
|
+
permissive: bool = False,
|
30
197
|
):
|
31
198
|
"""Instantiate a new QuestionCheckBox.
|
32
199
|
|
@@ -42,15 +209,28 @@ class QuestionCheckBox(QuestionBase):
|
|
42
209
|
self.max_selections = max_selections
|
43
210
|
self.question_options = question_options
|
44
211
|
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
self.
|
51
|
-
|
52
|
-
|
53
|
-
|
212
|
+
self._include_comment = include_comment
|
213
|
+
self._use_code = use_code
|
214
|
+
self.permissive = permissive
|
215
|
+
|
216
|
+
self.question_presentation = question_presentation
|
217
|
+
self.answering_instructions = answering_instructions
|
218
|
+
|
219
|
+
def create_response_model(self):
    """Build the Pydantic model that validates answers to this question.

    With answer codes enabled, the allowed answer values are the option
    indices ``0 .. len(question_options) - 1``. Otherwise they are the
    option texts themselves. Selection bounds and ``permissive`` are
    forwarded unchanged.
    """
    if self._use_code:
        allowed_values = list(range(len(self.question_options)))
    else:
        allowed_values = self.question_options
    return create_checkbox_response_model(
        allowed_values,
        min_selections=self.min_selections,
        max_selections=self.max_selections,
        permissive=self.permissive,
    )
|
54
234
|
|
55
235
|
def _translate_answer_code_to_answer(
|
56
236
|
self, answer_codes, scenario: "Scenario" = None
|
@@ -69,33 +249,36 @@ class QuestionCheckBox(QuestionBase):
|
|
69
249
|
]
|
70
250
|
translated_codes = []
|
71
251
|
for answer_code in answer_codes:
|
72
|
-
|
252
|
+
if self._use_code:
|
253
|
+
translated_codes.append(translated_options[int(answer_code)])
|
254
|
+
else:
|
255
|
+
translated_codes.append(answer_code)
|
73
256
|
return translated_codes
|
74
257
|
|
75
|
-
def _simulate_answer(self, human_readable=True) -> dict[str, Union[int, str]]:
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
258
|
+
# def _simulate_answer(self, human_readable=True) -> dict[str, Union[int, str]]:
|
259
|
+
# """Simulate a valid answer for debugging purposes."""
|
260
|
+
# from edsl.utilities.utilities import random_string
|
261
|
+
|
262
|
+
# min_selections = self.min_selections or 1
|
263
|
+
# max_selections = self.max_selections or len(self.question_options)
|
264
|
+
# num_selections = random.randint(min_selections, max_selections)
|
265
|
+
# if human_readable:
|
266
|
+
# # Select a random number of options from self.question_options
|
267
|
+
# selected_options = random.sample(self.question_options, num_selections)
|
268
|
+
# answer = {
|
269
|
+
# "answer": selected_options,
|
270
|
+
# "comment": random_string(),
|
271
|
+
# }
|
272
|
+
# else:
|
273
|
+
# # Select a random number of indices from the range of self.question_options
|
274
|
+
# selected_indices = random.sample(
|
275
|
+
# range(len(self.question_options)), num_selections
|
276
|
+
# )
|
277
|
+
# answer = {
|
278
|
+
# "answer": selected_indices,
|
279
|
+
# "comment": random_string(),
|
280
|
+
# }
|
281
|
+
# return answer
|
99
282
|
|
100
283
|
@property
|
101
284
|
def question_html_content(self) -> str:
|
@@ -125,7 +308,8 @@ class QuestionCheckBox(QuestionBase):
|
|
125
308
|
# Helpful methods
|
126
309
|
################
|
127
310
|
@classmethod
|
128
|
-
|
311
|
+
@inject_exception
|
312
|
+
def example(cls, include_comment=False, use_code=True) -> QuestionCheckBox:
|
129
313
|
"""Return an example checkbox question."""
|
130
314
|
return cls(
|
131
315
|
question_name="never_eat",
|
@@ -139,6 +323,8 @@ class QuestionCheckBox(QuestionBase):
|
|
139
323
|
],
|
140
324
|
min_selections=2,
|
141
325
|
max_selections=5,
|
326
|
+
use_code=use_code,
|
327
|
+
include_comment=include_comment,
|
142
328
|
)
|
143
329
|
|
144
330
|
|
@@ -165,3 +351,9 @@ def main():
|
|
165
351
|
import doctest
|
166
352
|
|
167
353
|
doctest.testmod(optionflags=doctest.ELLIPSIS)
|
354
|
+
|
355
|
+
|
356
|
+
if __name__ == "__main__":
    # Run this module's doctests when the file is executed directly.
    # NOTE(review): the partially visible main() above appears to run the same
    # doctest call — confirm whether both entry points are needed.
    import doctest

    doctest.testmod(optionflags=doctest.ELLIPSIS)
|
@@ -1,20 +1,112 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
from typing import Any
|
2
|
+
from typing import Any, Optional, Dict
|
3
3
|
from edsl.questions.QuestionBase import QuestionBase
|
4
4
|
from edsl.questions.descriptors import AnswerTemplateDescriptor
|
5
5
|
|
6
|
+
from edsl.questions.ResponseValidatorABC import ResponseValidatorABC
|
7
|
+
from edsl.questions.ResponseValidatorABC import BaseResponse
|
8
|
+
from edsl.exceptions import QuestionAnswerValidationError
|
9
|
+
from edsl.questions.decorators import inject_exception
|
10
|
+
|
11
|
+
from typing import Dict, Any
|
12
|
+
from pydantic import create_model, Field
|
13
|
+
|
14
|
+
import json
|
15
|
+
import re
|
16
|
+
|
17
|
+
|
18
|
+
def extract_json(text, expected_keys, verbose=False):
    """Return the first JSON object embedded in ``text`` that has every expected key.

    Every ``{`` in ``text`` is tried as the start of a JSON object. As a
    result, the keys may appear in any order, and the wanted object may be
    nested inside another one (e.g. ``{"answer": {...}}``) or be preceded by
    other brace-delimited text.

    :param text: Free text that may contain JSON. ``None`` or an empty
        string yields ``None``.
    :param expected_keys: Iterable of keys that must all be present in the
        returned object (e.g. ``dict.keys()``).
    :param verbose: If True, print why no object was returned.
    :return: The decoded ``dict``, or ``None`` if no suitable object exists.
    """
    if not text:
        if verbose:
            print("Error: No JSON-like structure found with all expected keys.")
        return None

    # Materialize once so one-shot iterables can be checked repeatedly.
    keys = list(expected_keys)
    decoder = json.JSONDecoder()

    start = text.find("{")
    saw_brace = start != -1
    saw_object = False  # something decoded, but lacked the expected keys

    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            saw_object = True
            if all(key in candidate for key in keys):
                return candidate
        # Continue from the next brace (not the end of the decoded object):
        # the wanted object may be nested inside the one just tried.
        start = text.find("{", start + 1)

    if verbose:
        if saw_object:
            print("Error: Not all expected keys were found in the extracted JSON.")
        elif saw_brace:
            print("Error: The extracted content is not valid JSON.")
        else:
            print("Error: No JSON-like structure found with all expected keys.")
    return None
|
50
|
+
|
51
|
+
|
52
|
+
def dict_to_pydantic_model(input_dict: Dict[str, Any]) -> Any:
    """Build a response model whose ``answer`` is shaped like ``input_dict``.

    Each key of ``input_dict`` becomes a ``str`` field of a nested model named
    ``DynamicModel``, with ``str(value)`` as that field's default. The
    returned ``AnswerModel`` subclasses ``BaseResponse``. It requires
    ``answer`` and adds optional ``generated_tokens`` and ``comment``
    (both defaulting to None).

    NOTE(review): because every template field has a default, an answer that
    omits a key validates with the template's example value — confirm this
    is intended.
    """
    template_fields = {}
    for field_name, example_value in input_dict.items():
        template_fields[field_name] = (str, Field(default=str(example_value)))
    answer_type = create_model("DynamicModel", **template_fields)

    return create_model(
        "AnswerModel",
        __base__=BaseResponse,
        answer=(answer_type, ...),
        generated_tokens=(Optional[str], None),
        comment=(Optional[str], None),
    )
|
65
|
+
|
66
|
+
|
67
|
+
class ExtractResponseValidator(ResponseValidatorABC):
    """Validate and repair answers to extract questions.

    An answer is expected to be an object keyed like ``answer_template``.
    """

    required_params = ["answer_template"]
    # The answer must follow the template's keys, so the valid example
    # supplies both a template and a matching dict answer.
    valid_examples = [
        (
            {"answer": {"name": "John Doe", "profession": "Carpenter"}},
            {"answer_template": {"name": "John Doe", "profession": "Carpenter"}},
        )
    ]
    invalid_examples = [
        (
            {"answer": None},
            {"answer_template": {"name": "John Doe", "profession": "Carpenter"}},
            "Result cannot be empty",
        ),
    ]

    def custom_validate(self, response) -> BaseResponse:
        """Return the already-parsed response as a plain dict."""
        return response.dict()

    def fix(self, response, verbose=False):
        """Try to recover a template-shaped answer from the raw generated text.

        :param response: The raw response dict. ``generated_tokens`` is
            searched for a JSON object containing every key of
            ``self.answer_template``, and ``comment`` is carried over.
        :param verbose: Print diagnostic information.
        :return: ``response`` unchanged when there is no generated text.
            Otherwise, a proposed dict whose ``answer`` is the extracted
            object (``None`` if no object was found).
        """
        # .get: a response without generated text has nothing to repair from.
        raw_tokens = response.get("generated_tokens")
        if verbose:
            print(f"Invalid response of QuestionExtract was: {raw_tokens}")
        if raw_tokens is None or raw_tokens == "":  # nothing to be done
            return response
        extracted_json = extract_json(raw_tokens, self.answer_template.keys(), verbose)
        if verbose:
            print("Proposed solution is: ", extracted_json)
        return {
            "answer": extracted_json,
            "comment": response.get("comment", None),
            "generated_tokens": raw_tokens,
        }
|
93
|
+
|
6
94
|
|
7
95
|
class QuestionExtract(QuestionBase):
|
8
96
|
"""This question prompts the agent to extract information from a string and return it in a given template."""
|
9
97
|
|
10
98
|
question_type = "extract"
|
11
99
|
answer_template: dict[str, Any] = AnswerTemplateDescriptor()
|
100
|
+
_response_model = None
|
101
|
+
response_validator_class = ExtractResponseValidator
|
12
102
|
|
13
103
|
def __init__(
|
14
104
|
self,
|
15
105
|
question_text: str,
|
16
106
|
answer_template: dict[str, Any],
|
17
107
|
question_name: str,
|
108
|
+
answering_instructions: str = None,
|
109
|
+
question_presentation: str = None,
|
18
110
|
):
|
19
111
|
"""Initialize the question.
|
20
112
|
|
@@ -26,33 +118,11 @@ class QuestionExtract(QuestionBase):
|
|
26
118
|
self.question_name = question_name
|
27
119
|
self.question_text = question_text
|
28
120
|
self.answer_template = answer_template
|
121
|
+
self.answering_instructions = answering_instructions
|
122
|
+
self.question_presentation = question_presentation
|
29
123
|
|
30
|
-
|
31
|
-
|
32
|
-
################
|
33
|
-
def _validate_answer(self, answer: Any) -> dict[str, Any]:
|
34
|
-
"""Validate the answer."""
|
35
|
-
# raw_json = answer["answer"]
|
36
|
-
# fixed_json_data = re.sub(r"\'", '"', raw_json)
|
37
|
-
# answer["answer"] = json.loads(fixed_json_data)
|
38
|
-
self._validate_answer_template_basic(answer)
|
39
|
-
# self._validate_answer_key_value(answer, "answer", dict)
|
40
|
-
|
41
|
-
self._validate_answer_extract(answer)
|
42
|
-
return answer
|
43
|
-
|
44
|
-
def _translate_answer_code_to_answer(self, answer, scenario: "Scenario" = None):
|
45
|
-
"""Return the answer in a human-readable format."""
|
46
|
-
return answer
|
47
|
-
|
48
|
-
def _simulate_answer(self, human_readable: bool = True) -> dict[str, str]:
|
49
|
-
"""Simulate a valid answer for debugging purposes."""
|
50
|
-
from edsl.utilities.utilities import random_string
|
51
|
-
|
52
|
-
return {
|
53
|
-
"answer": {key: random_string() for key in self.answer_template.keys()},
|
54
|
-
"comment": random_string(),
|
55
|
-
}
|
124
|
+
def create_response_model(self):
    """Return a response model whose ``answer`` mirrors ``self.answer_template``."""
    template = self.answer_template
    model = dict_to_pydantic_model(template)
    return model
|
56
126
|
|
57
127
|
@property
|
58
128
|
def question_html_content(self) -> str:
|
@@ -77,6 +147,7 @@ class QuestionExtract(QuestionBase):
|
|
77
147
|
# Helpful methods
|
78
148
|
################
|
79
149
|
@classmethod
|
150
|
+
@inject_exception
|
80
151
|
def example(cls) -> QuestionExtract:
|
81
152
|
"""Return an example question."""
|
82
153
|
return cls(
|
@@ -1,23 +1,56 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
import textwrap
|
3
2
|
from typing import Any, Optional
|
4
3
|
from uuid import uuid4
|
4
|
+
|
5
|
+
from pydantic import field_validator
|
6
|
+
|
5
7
|
from edsl.questions.QuestionBase import QuestionBase
|
8
|
+
from edsl.questions.ResponseValidatorABC import ResponseValidatorABC
|
9
|
+
|
10
|
+
from edsl.exceptions import QuestionAnswerValidationError
|
11
|
+
from edsl.questions.decorators import inject_exception
|
12
|
+
|
13
|
+
from pydantic import BaseModel
|
14
|
+
from typing import Optional, Any, List
|
15
|
+
|
16
|
+
from edsl.exceptions import QuestionAnswerValidationError
|
17
|
+
|
18
|
+
|
19
|
+
class FreeTextResponse(BaseModel):
    """
    Validator for free text response questions.

    ``answer`` must be a string. ``generated_tokens`` is optional and
    defaults to None.
    """

    # The free-text answer itself.
    answer: str
    # Optional raw generated text accompanying the answer.
    generated_tokens: Optional[str] = None
|
26
|
+
|
27
|
+
|
28
|
+
class FreeTextResponseValidator(ResponseValidatorABC):
    """Response validator used by ``QuestionFreeText``."""

    # No question parameters are needed to validate a free-text answer.
    required_params = []
    # (response, params) examples that should validate — presumably exercised
    # by ResponseValidatorABC; confirm against the base class.
    valid_examples = [({"answer": "This is great"}, {})]
    # (response, params, reason) examples that should fail validation.
    invalid_examples = [
        (
            {"answer": None},
            {},
            "Answer code must not be missing.",
        ),
    ]
|
6
38
|
|
7
39
|
|
8
40
|
class QuestionFreeText(QuestionBase):
|
9
41
|
"""This question prompts the agent to respond with free text."""
|
10
42
|
|
11
43
|
question_type = "free_text"
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
44
|
+
_response_model = FreeTextResponse
|
45
|
+
response_validator_class = FreeTextResponseValidator
|
46
|
+
|
47
|
+
def __init__(
|
48
|
+
self,
|
49
|
+
question_name: str,
|
50
|
+
question_text: str,
|
51
|
+
answering_instructions: Optional[str] = None,
|
52
|
+
question_presentation: Optional[str] = None,
|
53
|
+
):
|
21
54
|
"""Instantiate a new QuestionFreeText.
|
22
55
|
|
23
56
|
:param question_name: The name of the question.
|
@@ -25,25 +58,8 @@ class QuestionFreeText(QuestionBase):
|
|
25
58
|
"""
|
26
59
|
self.question_name = question_name
|
27
60
|
self.question_text = question_text
|
28
|
-
|
29
|
-
|
30
|
-
# Answer methods
|
31
|
-
################
|
32
|
-
def _validate_answer(self, answer: Any) -> dict[str, str]:
|
33
|
-
"""Validate the answer."""
|
34
|
-
self._validate_answer_template_basic(answer)
|
35
|
-
self._validate_answer_key_value(answer, "answer", str)
|
36
|
-
return answer
|
37
|
-
|
38
|
-
def _translate_answer_code_to_answer(self, answer, scenario: "Scenario" = None):
|
39
|
-
"""Do nothing, because the answer is already in a human-readable format."""
|
40
|
-
return answer
|
41
|
-
|
42
|
-
def _simulate_answer(self, human_readable: bool = True) -> dict[str, str]:
|
43
|
-
"""Simulate a valid answer for debugging purposes."""
|
44
|
-
from edsl.utilities.utilities import random_string
|
45
|
-
|
46
|
-
return {"answer": random_string()}
|
61
|
+
self.answering_instructions = answering_instructions
|
62
|
+
self.question_presentation = question_presentation
|
47
63
|
|
48
64
|
@property
|
49
65
|
def question_html_content(self) -> str:
|
@@ -59,6 +75,7 @@ class QuestionFreeText(QuestionBase):
|
|
59
75
|
return question_html_content
|
60
76
|
|
61
77
|
@classmethod
|
78
|
+
@inject_exception
|
62
79
|
def example(cls, randomize: bool = False) -> QuestionFreeText:
|
63
80
|
"""Return an example instance of a free text question."""
|
64
81
|
addition = "" if not randomize else str(uuid4())
|
@@ -39,6 +39,9 @@ class QuestionFunctional(QuestionBase):
|
|
39
39
|
function_source_code = ""
|
40
40
|
function_name = ""
|
41
41
|
|
42
|
+
_response_model = None
|
43
|
+
response_validator_class = None
|
44
|
+
|
42
45
|
def __init__(
|
43
46
|
self,
|
44
47
|
question_name: str,
|
@@ -97,6 +100,10 @@ class QuestionFunctional(QuestionBase):
|
|
97
100
|
"""Required by Question, but not used by QuestionFunctional."""
|
98
101
|
raise NotImplementedError
|
99
102
|
|
103
|
+
@property
def question_html_content(self) -> str:
    """Functional questions have no HTML rendering; return a fixed placeholder."""
    placeholder = "NA for QuestionFunctional"
    return placeholder
|
106
|
+
|
100
107
|
@add_edsl_version
|
101
108
|
def to_dict(self):
|
102
109
|
return {
|