edsl 0.1.33.dev1__py3-none-any.whl → 0.1.33.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +8 -4
- edsl/agents/Agent.py +46 -14
- edsl/agents/AgentList.py +43 -0
- edsl/agents/Invigilator.py +125 -212
- edsl/agents/InvigilatorBase.py +140 -32
- edsl/agents/PromptConstructionMixin.py +43 -66
- edsl/agents/__init__.py +1 -0
- edsl/auto/AutoStudy.py +117 -0
- edsl/auto/StageBase.py +230 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +73 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +224 -0
- edsl/config.py +38 -39
- edsl/coop/PriceFetcher.py +58 -0
- edsl/coop/coop.py +39 -5
- edsl/data/Cache.py +35 -1
- edsl/data_transfer_models.py +120 -38
- edsl/enums.py +2 -0
- edsl/exceptions/language_models.py +25 -1
- edsl/exceptions/questions.py +62 -5
- edsl/exceptions/results.py +4 -0
- edsl/inference_services/AnthropicService.py +13 -11
- edsl/inference_services/AwsBedrock.py +19 -17
- edsl/inference_services/AzureAI.py +37 -20
- edsl/inference_services/GoogleService.py +16 -12
- edsl/inference_services/GroqService.py +2 -0
- edsl/inference_services/InferenceServiceABC.py +24 -0
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OpenAIService.py +41 -50
- edsl/inference_services/TestService.py +71 -0
- edsl/inference_services/models_available_cache.py +0 -6
- edsl/inference_services/registry.py +4 -0
- edsl/jobs/Answers.py +10 -12
- edsl/jobs/FailedQuestion.py +78 -0
- edsl/jobs/Jobs.py +18 -13
- edsl/jobs/buckets/TokenBucket.py +39 -14
- edsl/jobs/interviews/Interview.py +297 -77
- edsl/jobs/interviews/InterviewExceptionEntry.py +83 -19
- edsl/jobs/interviews/interview_exception_tracking.py +0 -70
- edsl/jobs/interviews/retry_management.py +3 -1
- edsl/jobs/runners/JobsRunnerAsyncio.py +116 -70
- edsl/jobs/runners/JobsRunnerStatusMixin.py +1 -1
- edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
- edsl/jobs/tasks/TaskHistory.py +131 -213
- edsl/language_models/LanguageModel.py +239 -129
- edsl/language_models/ModelList.py +2 -2
- edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/registry.py +15 -2
- edsl/language_models/repair.py +0 -19
- edsl/language_models/utilities.py +61 -0
- edsl/prompts/Prompt.py +52 -2
- edsl/questions/AnswerValidatorMixin.py +23 -26
- edsl/questions/QuestionBase.py +273 -242
- edsl/questions/QuestionBaseGenMixin.py +133 -0
- edsl/questions/QuestionBasePromptsMixin.py +266 -0
- edsl/questions/QuestionBudget.py +6 -0
- edsl/questions/QuestionCheckBox.py +227 -35
- edsl/questions/QuestionExtract.py +98 -27
- edsl/questions/QuestionFreeText.py +46 -29
- edsl/questions/QuestionFunctional.py +7 -0
- edsl/questions/QuestionList.py +141 -22
- edsl/questions/QuestionMultipleChoice.py +173 -64
- edsl/questions/QuestionNumerical.py +87 -46
- edsl/questions/QuestionRank.py +182 -24
- edsl/questions/RegisterQuestionsMeta.py +31 -12
- edsl/questions/ResponseValidatorABC.py +169 -0
- edsl/questions/__init__.py +3 -4
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +10 -5
- edsl/questions/derived/QuestionLinearScale.py +11 -1
- edsl/questions/derived/QuestionTopK.py +6 -0
- edsl/questions/derived/QuestionYesNo.py +16 -1
- edsl/questions/descriptors.py +43 -7
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_registry.py +6 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/Dataset.py +20 -0
- edsl/results/DatasetExportMixin.py +41 -47
- edsl/results/DatasetTree.py +145 -0
- edsl/results/Result.py +32 -5
- edsl/results/Results.py +131 -45
- edsl/results/ResultsDBMixin.py +3 -3
- edsl/results/Selector.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/Scenario.py +10 -4
- edsl/scenarios/ScenarioList.py +348 -39
- edsl/scenarios/ScenarioListExportMixin.py +9 -0
- edsl/study/SnapShot.py +8 -1
- edsl/surveys/RuleCollection.py +2 -2
- edsl/surveys/Survey.py +634 -315
- edsl/surveys/SurveyExportMixin.py +71 -9
- edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
- edsl/surveys/SurveyQualtricsImport.py +75 -4
- edsl/surveys/instructions/ChangeInstruction.py +47 -0
- edsl/surveys/instructions/Instruction.py +34 -0
- edsl/surveys/instructions/InstructionCollection.py +77 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +111 -0
- edsl/templates/error_reporting/interviews.html +10 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/METADATA +4 -2
- edsl-0.1.33.dev2.dist-info/RECORD +289 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -286
- edsl/utilities/gcp_bucket/simple_example.py +0 -9
- edsl-0.1.33.dev1.dist-info/RECORD +0 -209
- {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dev1.dist-info → edsl-0.1.33.dev2.dist-info}/WHEEL +0 -0
edsl/TemplateLoader.py
ADDED
@@ -0,0 +1,24 @@
+from importlib import resources
+from jinja2 import BaseLoader, TemplateNotFound
+import os
+
+
+class TemplateLoader(BaseLoader):
+    def __init__(self, package_name, templates_dir):
+        self.package_name = package_name
+        self.templates_dir = templates_dir
+
+    def get_source(self, environment, template):
+        try:
+            parts = [self.templates_dir] + template.split("/")
+            template_path = os.path.join(*parts)
+
+            # Use resources.files() to get a Traversable object
+            templates = resources.files(self.package_name).joinpath(self.templates_dir)
+
+            # Use the read_text() method of the Traversable object
+            content = templates.joinpath(template).read_text()
+
+            return content, None, lambda: True
+        except FileNotFoundError:
+            raise TemplateNotFound(template)
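For context, TemplateLoader is a standard Jinja2 BaseLoader, so it plugs into an ordinary Environment and resolves templates through importlib.resources instead of the filesystem. A minimal usage sketch; the package and directory names below are illustrative choices based on the template files added elsewhere in this release, not something this diff itself demonstrates:

from jinja2 import Environment
from edsl.TemplateLoader import TemplateLoader

# Hypothetical: point the loader at one of the template directories shipped in the wheel.
env = Environment(loader=TemplateLoader("edsl", "questions/templates/numerical"))
template = env.get_template("answering_instructions.jinja")
print(template.render())  # render with whatever context the template expects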
edsl/__init__.py
CHANGED
@@ -8,9 +8,10 @@ from edsl.__version__ import __version__
 from edsl.config import Config, CONFIG
 from edsl.agents.Agent import Agent
 from edsl.agents.AgentList import AgentList
+
 from edsl.questions import QuestionBase
+from edsl.questions.question_registry import Question
 from edsl.questions import QuestionMultipleChoice
-from edsl.questions import QuestionBudget
 from edsl.questions import QuestionCheckBox
 from edsl.questions import QuestionExtract
 from edsl.questions import QuestionFreeText
@@ -19,10 +20,10 @@ from edsl.questions import QuestionLikertFive
 from edsl.questions import QuestionList
 from edsl.questions import QuestionLinearScale
 from edsl.questions import QuestionNumerical
-from edsl.questions import QuestionRank
-from edsl.questions import QuestionTopK
 from edsl.questions import QuestionYesNo
-from edsl.questions
+from edsl.questions import QuestionBudget
+from edsl.questions import QuestionRank
+
 from edsl.scenarios import Scenario
 from edsl.scenarios import ScenarioList
 
@@ -40,3 +41,6 @@ from edsl.notebooks.Notebook import Notebook
 from edsl.study.Study import Study
 from edsl.conjure.Conjure import Conjure
 from edsl.coop.coop import Coop
+
+from edsl.surveys.instructions.Instruction import Instruction
+from edsl.surveys.instructions.ChangeInstruction import ChangeInstruction
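The practical effect of this change is that the generic Question factory and the new survey Instruction classes become importable from the top-level edsl namespace. A hedged sketch; the argument names follow the usual edsl question constructors, and the exact factory and Instruction signatures live in edsl/questions/question_registry.py and edsl/surveys/instructions/:

from edsl import Question, Instruction

# The factory dispatches on the question type name rather than requiring a
# per-type import such as QuestionFreeText or QuestionMultipleChoice.
q = Question("free_text", question_name="color", question_text="What is your favorite color?")

# Instructions can be placed in a Survey alongside questions (see the new
# edsl/surveys/instructions/ modules added in this release).
intro = Instruction(name="intro", text="Please answer from the agent's perspective.")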
edsl/agents/Agent.py
CHANGED
@@ -4,7 +4,7 @@ from __future__ import annotations
 import copy
 import inspect
 import types
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, Union, Any
 from uuid import uuid4
 from edsl.Base import Base
 
@@ -228,7 +228,12 @@ class Agent(Base):
         if hasattr(self, "answer_question_directly"):
             delattr(self, "answer_question_directly")
 
-    def add_direct_question_answering_method(
+    def add_direct_question_answering_method(
+        self,
+        method: Callable,
+        validate_response: bool = False,
+        translate_response: bool = False,
+    ) -> None:
         """Add a method to the agent that can answer a particular question type.
 
         :param method: A method that can answer a question directly.
@@ -249,6 +254,9 @@ class Agent(Base):
             )
             # print("Warning: overwriting existing answer_question_directly method")
 
+        self.validate_response = validate_response
+        self.translate_response = translate_response
+
         signature = inspect.signature(method)
         for argument in ["question", "scenario", "self"]:
             if argument not in signature.parameters:
@@ -272,6 +280,7 @@ class Agent(Base):
         current_answers: Optional[dict] = None,
         iteration: int = 1,
         sidecar_model=None,
+        raise_validation_errors: bool = True,
     ) -> "InvigilatorBase":
         """Create an Invigilator.
 
@@ -303,7 +312,12 @@ class Agent(Base):
             iteration=iteration,
             cache=cache,
             sidecar_model=sidecar_model,
+            raise_validation_errors=raise_validation_errors,
         )
+        if hasattr(self, "validate_response"):
+            invigilator.validate_response = self.validate_response
+        if hasattr(self, "translate_response"):
+            invigilator.translate_response = self.translate_response
         return invigilator
 
     async def async_answer_question(
@@ -334,8 +348,8 @@ class Agent(Base):
        >>> a.add_direct_question_answering_method(lambda self, question, scenario: "I am a direct answer.")
        >>> from edsl import QuestionFreeText
        >>> q = QuestionFreeText.example()
-        >>> a.answer_question(question = q, cache = False)
- …
+        >>> a.answer_question(question = q, cache = False).answer
+        'I am a direct answer.'
 
        This is a function where an agent returns an answer to a particular question.
        However, there are several different ways an agent can answer a question, so the
@@ -369,6 +383,7 @@ class Agent(Base):
         current_answers: Optional[dict] = None,
         iteration: int = 0,
         sidecar_model=None,
+        raise_validation_errors: bool = True,
     ) -> "InvigilatorBase":
         """Create an Invigilator."""
         from edsl import Model
@@ -378,7 +393,6 @@ class Agent(Base):
         scenario = scenario or Scenario()
 
         from edsl.agents.Invigilator import (
-            InvigilatorDebug,
             InvigilatorHuman,
             InvigilatorFunctional,
             InvigilatorAI,
@@ -391,8 +405,9 @@ class Agent(Base):
             cache = Cache()
 
         if debug:
+            raise NotImplementedError("Debug mode is not yet implemented.")
             # use the question's _simulate_answer method
-            invigilator_class = InvigilatorDebug
+            # invigilator_class = InvigilatorDebug
         elif hasattr(question, "answer_question_directly"):
             # It's a functional question and the answer only depends on the agent's traits & the scenario
             invigilator_class = InvigilatorFunctional
@@ -422,6 +437,7 @@ class Agent(Base):
             iteration=iteration,
             cache=cache,
             sidecar_model=sidecar_model,
+            raise_validation_errors=raise_validation_errors,
         )
         return invigilator
 
@@ -497,8 +513,8 @@ class Agent(Base):
         if name == "has_dynamic_traits_function":
             return self.has_dynamic_traits_function
 
-        if name in self.
-            return self.
+        if name in self._traits:
+            return self._traits[name]
         raise AttributeError(
             f"'{type(self).__name__}' object has no attribute '{name}'"
         )
@@ -570,9 +586,9 @@ class Agent(Base):
         if dynamic_traits_func:
             func = inspect.getsource(dynamic_traits_func)
             raw_data["dynamic_traits_function_source_code"] = func
-            raw_data[
- …
+            raw_data["dynamic_traits_function_name"] = (
+                self.dynamic_traits_function_name
+            )
         if hasattr(self, "answer_question_directly"):
             raw_data.pop(
                 "answer_question_directly", None
@@ -588,9 +604,9 @@ class Agent(Base):
             raw_data["answer_question_directly_source_code"] = inspect.getsource(
                 answer_question_directly_func
             )
-            raw_data[
- …
+            raw_data["answer_question_directly_function_name"] = (
+                self.answer_question_directly_function_name
+            )
 
         return raw_data
 
@@ -640,6 +656,22 @@ class Agent(Base):
         column_names = ["Attribute", "Value"]
         return table_data, column_names
 
+    def add_trait(self, trait_name_or_dict: str, value: Optional[Any] = None) -> Agent:
+        """Adds a trait to an agent and returns that agent"""
+        if isinstance(trait_name_or_dict, dict) and value is None:
+            self.traits.update(trait_name_or_dict)
+            return self
+
+        if isinstance(trait_name_or_dict, dict) and value:
+            raise ValueError(f"You passed a dict: {trait_name_or_dict}")
+
+        if isinstance(trait_name_or_dict, str):
+            trait = trait_name_or_dict
+            self.traits[trait] = value
+            return self
+
+        raise Exception("Something is not right with adding")
+
     def remove_trait(self, trait: str) -> Agent:
         """Remove a trait from the agent.
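Taken together, the Agent changes add an in-place add_trait method and let a direct-answer method opt in to answer validation and code translation. A short sketch of the new surface, grounded in the signatures shown above:

from edsl import Agent

a = Agent(traits={"age": 30})
a.add_trait("hair", "brown")                   # add a single trait
a.add_trait({"height": 5.5, "eyes": "blue"})   # or merge a dict of traits

# A direct-answer method can now ask the invigilator to validate and/or
# translate the returned answer before it is recorded.
a.add_direct_question_answering_method(
    lambda self, question, scenario: "I am a direct answer.",
    validate_response=False,
    translate_response=False,
)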
edsl/agents/AgentList.py
CHANGED
@@ -21,6 +21,12 @@ from simpleeval import EvalWithCompoundTypes
 from edsl.Base import Base
 from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
 
+from collections.abc import Iterable
+
+
+def is_iterable(obj):
+    return isinstance(obj, Iterable)
+
 
 class AgentList(UserList, Base):
     """A list of Agents."""
@@ -111,6 +117,13 @@ class AgentList(UserList, Base):
 
         return AgentList(new_data)
 
+    @property
+    def all_traits(self):
+        d = {}
+        for agent in self:
+            d.update(agent.traits)
+        return list(d.keys())
+
     @classmethod
     def from_csv(cls, file_path: str):
         """Load AgentList from a CSV file.
@@ -159,6 +172,36 @@ class AgentList(UserList, Base):
             _ = agent.remove_trait(trait)
         return self
 
+    def add_trait(self, trait, values):
+        """Adds a new trait to every agent, with values taken from values.
+
+        :param trait: The name of the trait.
+        :param values: The valeues(s) of the trait. If a single value is passed, it is used for all agents.
+
+        >>> al = AgentList.example()
+        >>> al.add_trait('new_trait', 1)
+        AgentList([Agent(traits = {'age': 22, 'hair': 'brown', 'height': 5.5, 'new_trait': 1}), Agent(traits = {'age': 22, 'hair': 'brown', 'height': 5.5, 'new_trait': 1})])
+        >>> al.select('new_trait').to_scenario_list().to_list()
+        [1, 1]
+        >>> al.add_trait('new_trait', [1, 2, 3])
+        Traceback (most recent call last):
+        ...
+        ValueError: The passed values have to be the same length as the agent list.
+        """
+        if not is_iterable(values):
+            value = values
+            for agent in self.data:
+                agent.add_trait(trait, value)
+            return self
+
+        if len(values) != len(self):
+            raise ValueError(
+                "The passed values have to be the same length as the agent list."
+            )
+        for agent, value in zip(self.data, values):
+            agent.add_trait(trait, value)
+        return self
+
     @staticmethod
     def get_codebook(file_path: str):
         """Return the codebook for a CSV file.
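The docstring above already doubles as a doctest; for completeness, here is how the broadcast vs. per-agent behaviour of the new AgentList.add_trait and the all_traits property fit together (trait names are illustrative):

from edsl import AgentList

al = AgentList.example()                       # two example agents
al.add_trait("favorite_color", "green")        # a scalar is broadcast to every agent
al.add_trait("respondent_id", [0, 1])          # a list must match the number of agents
print(al.all_traits)                           # union of trait names across all agents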
edsl/agents/Invigilator.py
CHANGED
@@ -1,17 +1,22 @@
 """Module for creating Invigilators, which are objects to administer a question to an Agent."""
 
-import json
 from typing import Dict, Any, Optional
 
 from edsl.exceptions import AgentRespondedWithBadJSONError
 from edsl.prompts.Prompt import Prompt
 from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
 from edsl.prompts.registry import get_classes as prompt_lookup
-from edsl.
-from edsl.exceptions.agents import FailedTaskException
+from edsl.exceptions.questions import QuestionAnswerValidationError
 from edsl.agents.PromptConstructionMixin import PromptConstructorMixin
- …
 from edsl.agents.InvigilatorBase import InvigilatorBase
+from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
+
+
+class NotApplicable(str):
+    def __new__(cls):
+        instance = super().__new__(cls, "Not Applicable")
+        instance.literal = "Not Applicable"
+        return instance
 
 
 class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
@@ -22,231 +27,140 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
 
     >>> i = InvigilatorAI.example()
     >>> i.answer_question()
-    {'message': '
+    {'message': [{'text': 'SPAM!'}], 'usage': {'prompt_tokens': 1, 'completion_tokens': 1}}
     """
- …
-            "question": self.question,
-            "scenario": self.scenario,
-            "raw_response": raw_response,
-            "raw_model_response": raw_response["raw_model_response"],
+        prompts = self.get_prompts()
+        params = {
+            "user_prompt": prompts["user_prompt"].text,
+            "system_prompt": prompts["system_prompt"].text,
         }
- …
-        return AgentResponseDict(**response)
+        if "encoded_image" in prompts:
+            params["encoded_image"] = prompts["encoded_image"]
 
- …
-        self,
-        user_prompt: Prompt,
-        system_prompt: Prompt,
-        iteration: int = 0,
-        encoded_image=None,
-    ) -> dict:
-        """Call the LLM and gets a response. Used in the `answer_question` method."""
-        try:
-            params = {
-                "user_prompt": user_prompt.text,
-                "system_prompt": system_prompt.text,
-                "iteration": iteration,
-                "cache": self.cache,
-            }
-            if encoded_image:
-                params["encoded_image"] = encoded_image
-            response = await self.model.async_get_response(**params)
+        params.update({"iteration": self.iteration, "cache": self.cache})
 
- …
-                f"System Prompt: {system_prompt}"
-            )
+        agent_response_dict: AgentResponseDict = await self.model.async_get_response(
+            **params
+        )
+        # store to self in case validation failure
+        self.raw_model_response = agent_response_dict.model_outputs.response
+        self.generated_tokens = agent_response_dict.edsl_dict.generated_tokens
 
-        return
+        return self.extract_edsl_result_entry_and_validate(agent_response_dict)
 
-    def _remove_from_cache(self,
+    def _remove_from_cache(self, cache_key) -> None:
         """Remove an entry from the cache."""
-        cache_key = raw_response.get("cache_key", None)
         if cache_key:
             del self.cache.data[cache_key]
 
-    def
-        self, *, agent, question, scenario, raw_response, raw_model_response
-    ) -> AgentResponseDict:
-        """Return formatted raw response.
-
-        This cleans up the raw response to make it suitable to pass to AgentResponseDict.
-        """
-        _ = agent
-        try:
-            response = question._validate_answer(raw_response)
-        except Exception as e:
-            """If the response is invalid, remove it from the cache and raise the exception."""
-            self._remove_from_cache(raw_response)
-            raise e
-
+    def determine_answer(self, raw_answer: str) -> Any:
         question_dict = self.survey.question_names_to_questions()
+        # iterates through the current answers and updates the question_dict (which is all questions)
         for other_question, answer in self.current_answers.items():
             if other_question in question_dict:
                 question_dict[other_question].answer = answer
             else:
-                #
+                # it might be a comment
                 if (
                     new_question := other_question.split("_comment")[0]
                 ) in question_dict:
                     question_dict[new_question].comment = answer
 
-        combined_dict = {**question_dict, **scenario}
-        answer
- …
-        data = {
-            "answer": answer,
-            "comment": response.get(
-                "comment", ""
-            ),  # not all question have comment fields,
-            "question_name": question.question_name,
-            "prompts": self.get_prompts(),
-            "cached_response": raw_response.get("cached_response", None),
-            "usage": raw_response.get("usage", {}),
-            "raw_model_response": raw_model_response,
-            "cache_used": raw_response.get("cache_used", False),
-            "cache_key": raw_response.get("cache_key", None),
-        }
-        return AgentResponseDict(**data)
-
-    get_response = sync_wrapper(async_get_response)
-    answer_question = sync_wrapper(async_answer_question)
-
-
-class InvigilatorSidecar(InvigilatorAI):
-    """An invigilator that presents the 'raw' question to the agent
-    & uses a sidecar model to answer questions."""
-
-    async def async_answer_question(self, failed: bool = False) -> AgentResponseDict:
-        """Answer a question using the AI model."""
-        from edsl import Model
-
-        advanced_model = self.sidecar_model
-        simple_model = self.model
-        question = self.question
-        human_readable_question = (
-            "Please answer this single question: " + question.human_readable()
-        )
-        print("Getting the simple model response to: ", human_readable_question)
-        raw_simple_response = await simple_model.async_execute_model_call(
-            user_prompt=human_readable_question,
-            system_prompt="""Pretend you are a human answering a question. Do not break character.""",
-        )
-        simple_response = simple_model.parse_response(raw_simple_response)
-        instructions = question.get_instructions()
-
-        main_model_prompt = Prompt(
-            text="""
-        A simpler language model was asked this question:
-
-        To the simpel model:
-        {{ human_readable_question }}
-
-        The simple model responded:
-        <response>
-        {{ simple_response }}
-        </response>
-
-        It was suppose to respond according to these instructions:
-        <instructions>
-        {{ instructions }}
-        </instructions>
-
-        Please format the simple model's response as it should have been formmated, given the instructions.
-        Only respond in valid JSON, like so {"answer": "SPAM!"} or {"answer": "SPAM!", "comment": "I am a robot."}
-        Do not inlcude the word 'json'
-        """
-        )
-
-        d = {
-            "human_readable_question": human_readable_question,
-            "simple_response": simple_response,
-            "instructions": instructions,
-        }
-
-        print("The human-readable question is: ", human_readable_question)
-        print("The simple response is: ", simple_response)
+        combined_dict = {**question_dict, **self.scenario}
+        # sometimes the answer is a code, so we need to translate it
+        return self.question._translate_answer_code_to_answer(raw_answer, combined_dict)
 
- …
+    def extract_edsl_result_entry_and_validate(
+        self, agent_response_dict: AgentResponseDict
+    ) -> EDSLResultObjectInput:
+        edsl_dict = agent_response_dict.edsl_dict._asdict()
+        exception_occurred = None
+        validated = False
+        try:
+            validated_edsl_dict = self.question._validate_answer(edsl_dict)
+            answer = self.determine_answer(validated_edsl_dict["answer"])
+            comment = validated_edsl_dict.get("comment", "")
+            validated = True
+        except QuestionAnswerValidationError as e:
+            answer = None
+            comment = "The response was not valid."
+            if self.raise_validation_errors:
+                exception_occurred = e
+        except Exception as non_validation_error:
+            answer = None
+            comment = "Some other error occurred."
+            exception_occurred = non_validation_error
+        finally:
+            # even if validation failes, we still return the result
+            data = {
+                "answer": answer,
+                "comment": comment,
+                "generated_tokens": agent_response_dict.edsl_dict.generated_tokens,
+                "question_name": self.question.question_name,
+                "prompts": self.get_prompts(),
+                "cached_response": agent_response_dict.model_outputs.cached_response,
+                "raw_model_response": agent_response_dict.model_outputs.response,
+                "cache_used": agent_response_dict.model_outputs.cache_used,
+                "cache_key": agent_response_dict.model_outputs.cache_key,
+                "validated": validated,
+                "exception_occurred": exception_occurred,
+                "cost": agent_response_dict.model_outputs.cost,
+            }
+            result = EDSLResultObjectInput(**data)
+            return result
 
-    # get_response = sync_wrapper(async_get_response)
     answer_question = sync_wrapper(async_answer_question)
 
 
-class InvigilatorDebug(InvigilatorBase):
-    """An invigilator class for debugging purposes."""
-
-    async def async_answer_question(self, iteration: int = 0) -> AgentResponseDict:
-        """Return the answer to the question."""
-        results = self.question._simulate_answer(human_readable=True)
-        results["prompts"] = self.get_prompts()
-        results["question_name"] = self.question.question_name
-        results["comment"] = "Debug comment"
-        return AgentResponseDict(**results)
-
-    def get_prompts(self) -> Dict[str, Prompt]:
-        """Return the prompts used."""
-        return {
-            "user_prompt": Prompt("NA"),
-            "system_prompt": Prompt("NA"),
-        }
-
-
 class InvigilatorHuman(InvigilatorBase):
     """An invigilator for when a human is answering the question."""
 
+    validate_response: bool = False
+    translate_response: bool = False
+
     async def async_answer_question(self, iteration: int = 0) -> AgentResponseDict:
         """Return the answer to the question."""
- …
+        comment = "This is a real survey response from a human."
+
+        def __repr__(self):
+            return f"{self.literal}"
+
+        exception_occurred = None
+        validated = False
         try:
             answer = self.agent.answer_question_directly(self.question, self.scenario)
- …
+            self.raw_model_response = answer
+
+            if self.validate_response:
+                _ = self.question._validate_answer({"answer": answer})
+            if self.translate_response:
+                answer = self.question._translate_answer_code_to_answer(
+                    answer, self.scenario
+                )
+            validated = True
+        except QuestionAnswerValidationError as e:
+            answer = None
+            if self.raise_validation_errors:
+                exception_occurred = e
         except Exception as e:
- …
+            answer = None
+            if self.raise_validation_errors:
+                exception_occurred = e
+        finally:
+            data = {
+                "generated_tokens": NotApplicable(),
+                "question_name": self.question.question_name,
+                "prompts": self.get_prompts(),
+                "cached_response": NotApplicable(),
+                "raw_model_response": NotApplicable(),
+                "cache_used": NotApplicable(),
+                "cache_key": NotApplicable(),
+                "answer": answer,
+                "comment": comment,
+                "validated": validated,
+                "exception_occurred": exception_occurred,
+            }
+            return EDSLResultObjectInput(**data)
 
 
 class InvigilatorFunctional(InvigilatorBase):
@@ -255,22 +169,21 @@ class InvigilatorFunctional(InvigilatorBase):
     async def async_answer_question(self, iteration: int = 0) -> AgentResponseDict:
         """Return the answer to the question."""
         func = self.question.answer_question_directly
- …
-        ) from e
+        answer = func(scenario=self.scenario, agent_traits=self.agent.traits)
+
+        return EDSLResultObjectInput(
+            generated_tokens=str(answer),
+            question_name=self.question.question_name,
+            prompts=self.get_prompts(),
+            cached_response=NotApplicable(),
+            raw_model_response=NotApplicable(),
+            cache_used=NotApplicable(),
+            cache_key=NotApplicable(),
+            answer=answer["answer"],
+            comment="This is the result of a functional question.",
+            validated=True,
+            exception_occurred=None,
+        )
 
     def get_prompts(self) -> Dict[str, Prompt]:
         """Return the prompts used."""
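The net effect of the Invigilator rewrite is that every invigilator now returns an EDSLResultObjectInput and records validation failures in the result instead of letting them propagate, unless raise_validation_errors is set. A schematic of that pattern, not the exact edsl code:

# Sketch: validate a response and package it without aborting the interview.
def validate_and_package(question, edsl_dict, raise_validation_errors=False):
    answer, comment, validated, exception_occurred = None, "", False, None
    try:
        valid = question._validate_answer(edsl_dict)  # may raise a validation error
        answer = valid["answer"]
        comment = valid.get("comment", "")
        validated = True
    except Exception as e:
        comment = "The response was not valid."
        if raise_validation_errors:
            exception_occurred = e
    # Any failure travels with the result rather than being raised here.
    return {
        "answer": answer,
        "comment": comment,
        "validated": validated,
        "exception_occurred": exception_occurred,
    }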