edsl 0.1.33__py3-none-any.whl → 0.1.33.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +3 -9
- edsl/__init__.py +3 -8
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +8 -40
- edsl/agents/AgentList.py +0 -43
- edsl/agents/Invigilator.py +219 -135
- edsl/agents/InvigilatorBase.py +59 -148
- edsl/agents/{PromptConstructor.py → PromptConstructionMixin.py} +89 -138
- edsl/agents/__init__.py +0 -1
- edsl/config.py +56 -47
- edsl/coop/coop.py +7 -50
- edsl/data/Cache.py +1 -35
- edsl/data_transfer_models.py +38 -73
- edsl/enums.py +0 -4
- edsl/exceptions/language_models.py +1 -25
- edsl/exceptions/questions.py +5 -62
- edsl/exceptions/results.py +0 -4
- edsl/inference_services/AnthropicService.py +11 -13
- edsl/inference_services/AwsBedrock.py +17 -19
- edsl/inference_services/AzureAI.py +20 -37
- edsl/inference_services/GoogleService.py +12 -16
- edsl/inference_services/GroqService.py +0 -2
- edsl/inference_services/InferenceServiceABC.py +3 -58
- edsl/inference_services/OpenAIService.py +54 -48
- edsl/inference_services/models_available_cache.py +6 -0
- edsl/inference_services/registry.py +0 -6
- edsl/jobs/Answers.py +12 -10
- edsl/jobs/Jobs.py +21 -36
- edsl/jobs/buckets/BucketCollection.py +15 -24
- edsl/jobs/buckets/TokenBucket.py +14 -93
- edsl/jobs/interviews/Interview.py +78 -366
- edsl/jobs/interviews/InterviewExceptionEntry.py +19 -85
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +286 -0
- edsl/jobs/interviews/{InterviewExceptionCollection.py → interview_exception_tracking.py} +68 -14
- edsl/jobs/interviews/retry_management.py +37 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +175 -146
- edsl/jobs/runners/JobsRunnerStatusMixin.py +333 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +23 -30
- edsl/jobs/tasks/TaskHistory.py +213 -148
- edsl/language_models/LanguageModel.py +156 -261
- edsl/language_models/ModelList.py +2 -2
- edsl/language_models/RegisterLanguageModelsMeta.py +29 -14
- edsl/language_models/registry.py +6 -23
- edsl/language_models/repair.py +19 -0
- edsl/prompts/Prompt.py +2 -52
- edsl/questions/AnswerValidatorMixin.py +26 -23
- edsl/questions/QuestionBase.py +249 -329
- edsl/questions/QuestionBudget.py +41 -99
- edsl/questions/QuestionCheckBox.py +35 -227
- edsl/questions/QuestionExtract.py +27 -98
- edsl/questions/QuestionFreeText.py +29 -52
- edsl/questions/QuestionFunctional.py +0 -7
- edsl/questions/QuestionList.py +22 -141
- edsl/questions/QuestionMultipleChoice.py +65 -159
- edsl/questions/QuestionNumerical.py +46 -88
- edsl/questions/QuestionRank.py +24 -182
- edsl/questions/RegisterQuestionsMeta.py +12 -31
- edsl/questions/__init__.py +4 -3
- edsl/questions/derived/QuestionLikertFive.py +5 -10
- edsl/questions/derived/QuestionLinearScale.py +2 -15
- edsl/questions/derived/QuestionTopK.py +1 -10
- edsl/questions/derived/QuestionYesNo.py +3 -24
- edsl/questions/descriptors.py +7 -43
- edsl/questions/question_registry.py +2 -6
- edsl/results/Dataset.py +0 -20
- edsl/results/DatasetExportMixin.py +48 -46
- edsl/results/Result.py +5 -32
- edsl/results/Results.py +46 -135
- edsl/results/ResultsDBMixin.py +3 -3
- edsl/scenarios/FileStore.py +10 -71
- edsl/scenarios/Scenario.py +25 -96
- edsl/scenarios/ScenarioImageMixin.py +2 -2
- edsl/scenarios/ScenarioList.py +39 -361
- edsl/scenarios/ScenarioListExportMixin.py +0 -9
- edsl/scenarios/ScenarioListPdfMixin.py +4 -150
- edsl/study/SnapShot.py +1 -8
- edsl/study/Study.py +0 -32
- edsl/surveys/Rule.py +1 -10
- edsl/surveys/RuleCollection.py +5 -21
- edsl/surveys/Survey.py +310 -636
- edsl/surveys/SurveyExportMixin.py +9 -71
- edsl/surveys/SurveyFlowVisualizationMixin.py +1 -2
- edsl/surveys/SurveyQualtricsImport.py +4 -75
- edsl/utilities/gcp_bucket/simple_example.py +9 -0
- edsl/utilities/utilities.py +1 -9
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/METADATA +2 -5
- edsl-0.1.33.dev1.dist-info/RECORD +209 -0
- edsl/TemplateLoader.py +0 -24
- edsl/auto/AutoStudy.py +0 -117
- edsl/auto/StageBase.py +0 -230
- edsl/auto/StageGenerateSurvey.py +0 -178
- edsl/auto/StageLabelQuestions.py +0 -125
- edsl/auto/StagePersona.py +0 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +0 -88
- edsl/auto/StagePersonaDimensionValues.py +0 -74
- edsl/auto/StagePersonaDimensions.py +0 -69
- edsl/auto/StageQuestions.py +0 -73
- edsl/auto/SurveyCreatorPipeline.py +0 -21
- edsl/auto/utilities.py +0 -224
- edsl/coop/PriceFetcher.py +0 -58
- edsl/inference_services/MistralAIService.py +0 -120
- edsl/inference_services/TestService.py +0 -80
- edsl/inference_services/TogetherAIService.py +0 -170
- edsl/jobs/FailedQuestion.py +0 -78
- edsl/jobs/runners/JobsRunnerStatus.py +0 -331
- edsl/language_models/fake_openai_call.py +0 -15
- edsl/language_models/fake_openai_service.py +0 -61
- edsl/language_models/utilities.py +0 -61
- edsl/questions/QuestionBaseGenMixin.py +0 -133
- edsl/questions/QuestionBasePromptsMixin.py +0 -266
- edsl/questions/Quick.py +0 -41
- edsl/questions/ResponseValidatorABC.py +0 -170
- edsl/questions/decorators.py +0 -21
- edsl/questions/prompt_templates/question_budget.jinja +0 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +0 -32
- edsl/questions/prompt_templates/question_extract.jinja +0 -11
- edsl/questions/prompt_templates/question_free_text.jinja +0 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +0 -11
- edsl/questions/prompt_templates/question_list.jinja +0 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +0 -33
- edsl/questions/prompt_templates/question_numerical.jinja +0 -37
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +0 -7
- edsl/questions/templates/budget/question_presentation.jinja +0 -7
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +0 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +0 -22
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +0 -7
- edsl/questions/templates/extract/question_presentation.jinja +0 -1
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +0 -1
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +0 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +0 -12
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +0 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +0 -5
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +0 -4
- edsl/questions/templates/list/question_presentation.jinja +0 -5
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +0 -9
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +0 -12
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +0 -8
- edsl/questions/templates/numerical/question_presentation.jinja +0 -7
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +0 -11
- edsl/questions/templates/rank/question_presentation.jinja +0 -15
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +0 -8
- edsl/questions/templates/top_k/question_presentation.jinja +0 -22
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +0 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +0 -12
- edsl/results/DatasetTree.py +0 -145
- edsl/results/Selector.py +0 -118
- edsl/results/tree_explore.py +0 -115
- edsl/surveys/instructions/ChangeInstruction.py +0 -47
- edsl/surveys/instructions/Instruction.py +0 -34
- edsl/surveys/instructions/InstructionCollection.py +0 -77
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +0 -24
- edsl/templates/error_reporting/exceptions_by_model.html +0 -35
- edsl/templates/error_reporting/exceptions_by_question_name.html +0 -17
- edsl/templates/error_reporting/exceptions_by_type.html +0 -17
- edsl/templates/error_reporting/interview_details.html +0 -116
- edsl/templates/error_reporting/interviews.html +0 -10
- edsl/templates/error_reporting/overview.html +0 -5
- edsl/templates/error_reporting/performance_plot.html +0 -2
- edsl/templates/error_reporting/report.css +0 -74
- edsl/templates/error_reporting/report.html +0 -118
- edsl/templates/error_reporting/report.js +0 -25
- edsl-0.1.33.dist-info/RECORD +0 -295
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dist-info → edsl-0.1.33.dev1.dist-info}/WHEEL +0 -0
edsl/Base.py
CHANGED
@@ -47,27 +47,21 @@ class PersistenceMixin:
|
|
47
47
|
self,
|
48
48
|
description: Optional[str] = None,
|
49
49
|
visibility: Optional[str] = "unlisted",
|
50
|
-
expected_parrot_url: Optional[str] = None,
|
51
50
|
):
|
52
51
|
"""Post the object to coop."""
|
53
52
|
from edsl.coop import Coop
|
54
53
|
|
55
|
-
c = Coop(
|
54
|
+
c = Coop()
|
56
55
|
return c.create(self, description, visibility)
|
57
56
|
|
58
57
|
@classmethod
|
59
|
-
def pull(
|
60
|
-
cls,
|
61
|
-
uuid: Optional[Union[str, UUID]] = None,
|
62
|
-
url: Optional[str] = None,
|
63
|
-
expected_parrot_url: Optional[str] = None,
|
64
|
-
):
|
58
|
+
def pull(cls, uuid: Optional[Union[str, UUID]] = None, url: Optional[str] = None):
|
65
59
|
"""Pull the object from coop."""
|
66
60
|
from edsl.coop import Coop
|
67
61
|
from edsl.coop.utils import ObjectRegistry
|
68
62
|
|
69
63
|
object_type = ObjectRegistry.get_object_type_by_edsl_class(cls)
|
70
|
-
coop = Coop(
|
64
|
+
coop = Coop()
|
71
65
|
return coop.get(uuid, url, object_type)
|
72
66
|
|
73
67
|
@classmethod
|
edsl/__init__.py
CHANGED
@@ -8,10 +8,9 @@ from edsl.__version__ import __version__
|
|
8
8
|
from edsl.config import Config, CONFIG
|
9
9
|
from edsl.agents.Agent import Agent
|
10
10
|
from edsl.agents.AgentList import AgentList
|
11
|
-
|
12
11
|
from edsl.questions import QuestionBase
|
13
|
-
from edsl.questions.question_registry import Question
|
14
12
|
from edsl.questions import QuestionMultipleChoice
|
13
|
+
from edsl.questions import QuestionBudget
|
15
14
|
from edsl.questions import QuestionCheckBox
|
16
15
|
from edsl.questions import QuestionExtract
|
17
16
|
from edsl.questions import QuestionFreeText
|
@@ -20,11 +19,10 @@ from edsl.questions import QuestionLikertFive
|
|
20
19
|
from edsl.questions import QuestionList
|
21
20
|
from edsl.questions import QuestionLinearScale
|
22
21
|
from edsl.questions import QuestionNumerical
|
23
|
-
from edsl.questions import QuestionYesNo
|
24
|
-
from edsl.questions import QuestionBudget
|
25
22
|
from edsl.questions import QuestionRank
|
26
23
|
from edsl.questions import QuestionTopK
|
27
|
-
|
24
|
+
from edsl.questions import QuestionYesNo
|
25
|
+
from edsl.questions.question_registry import Question
|
28
26
|
from edsl.scenarios import Scenario
|
29
27
|
from edsl.scenarios import ScenarioList
|
30
28
|
|
@@ -42,6 +40,3 @@ from edsl.notebooks.Notebook import Notebook
|
|
42
40
|
from edsl.study.Study import Study
|
43
41
|
from edsl.conjure.Conjure import Conjure
|
44
42
|
from edsl.coop.coop import Coop
|
45
|
-
|
46
|
-
from edsl.surveys.instructions.Instruction import Instruction
|
47
|
-
from edsl.surveys.instructions.ChangeInstruction import ChangeInstruction
|
edsl/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.1.33"
|
1
|
+
__version__ = "0.1.33.dev1"
|
edsl/agents/Agent.py
CHANGED
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|
4
4
|
import copy
|
5
5
|
import inspect
|
6
6
|
import types
|
7
|
-
from typing import Callable, Optional, Union
|
7
|
+
from typing import Callable, Optional, Union
|
8
8
|
from uuid import uuid4
|
9
9
|
from edsl.Base import Base
|
10
10
|
|
@@ -228,12 +228,7 @@ class Agent(Base):
|
|
228
228
|
if hasattr(self, "answer_question_directly"):
|
229
229
|
delattr(self, "answer_question_directly")
|
230
230
|
|
231
|
-
def add_direct_question_answering_method(
|
232
|
-
self,
|
233
|
-
method: Callable,
|
234
|
-
validate_response: bool = False,
|
235
|
-
translate_response: bool = False,
|
236
|
-
) -> None:
|
231
|
+
def add_direct_question_answering_method(self, method: Callable) -> None:
|
237
232
|
"""Add a method to the agent that can answer a particular question type.
|
238
233
|
|
239
234
|
:param method: A method that can answer a question directly.
|
@@ -254,9 +249,6 @@ class Agent(Base):
|
|
254
249
|
)
|
255
250
|
# print("Warning: overwriting existing answer_question_directly method")
|
256
251
|
|
257
|
-
self.validate_response = validate_response
|
258
|
-
self.translate_response = translate_response
|
259
|
-
|
260
252
|
signature = inspect.signature(method)
|
261
253
|
for argument in ["question", "scenario", "self"]:
|
262
254
|
if argument not in signature.parameters:
|
@@ -280,7 +272,6 @@ class Agent(Base):
|
|
280
272
|
current_answers: Optional[dict] = None,
|
281
273
|
iteration: int = 1,
|
282
274
|
sidecar_model=None,
|
283
|
-
raise_validation_errors: bool = True,
|
284
275
|
) -> "InvigilatorBase":
|
285
276
|
"""Create an Invigilator.
|
286
277
|
|
@@ -312,12 +303,7 @@ class Agent(Base):
|
|
312
303
|
iteration=iteration,
|
313
304
|
cache=cache,
|
314
305
|
sidecar_model=sidecar_model,
|
315
|
-
raise_validation_errors=raise_validation_errors,
|
316
306
|
)
|
317
|
-
if hasattr(self, "validate_response"):
|
318
|
-
invigilator.validate_response = self.validate_response
|
319
|
-
if hasattr(self, "translate_response"):
|
320
|
-
invigilator.translate_response = self.translate_response
|
321
307
|
return invigilator
|
322
308
|
|
323
309
|
async def async_answer_question(
|
@@ -348,8 +334,8 @@ class Agent(Base):
|
|
348
334
|
>>> a.add_direct_question_answering_method(lambda self, question, scenario: "I am a direct answer.")
|
349
335
|
>>> from edsl import QuestionFreeText
|
350
336
|
>>> q = QuestionFreeText.example()
|
351
|
-
>>> a.answer_question(question = q, cache = False)
|
352
|
-
'I am a direct answer.'
|
337
|
+
>>> a.answer_question(question = q, cache = False)
|
338
|
+
{'answer': 'I am a direct answer.', 'comment': 'This is a real survey response from a human.', ...}
|
353
339
|
|
354
340
|
This is a function where an agent returns an answer to a particular question.
|
355
341
|
However, there are several different ways an agent can answer a question, so the
|
@@ -383,7 +369,6 @@ class Agent(Base):
|
|
383
369
|
current_answers: Optional[dict] = None,
|
384
370
|
iteration: int = 0,
|
385
371
|
sidecar_model=None,
|
386
|
-
raise_validation_errors: bool = True,
|
387
372
|
) -> "InvigilatorBase":
|
388
373
|
"""Create an Invigilator."""
|
389
374
|
from edsl import Model
|
@@ -393,6 +378,7 @@ class Agent(Base):
|
|
393
378
|
scenario = scenario or Scenario()
|
394
379
|
|
395
380
|
from edsl.agents.Invigilator import (
|
381
|
+
InvigilatorDebug,
|
396
382
|
InvigilatorHuman,
|
397
383
|
InvigilatorFunctional,
|
398
384
|
InvigilatorAI,
|
@@ -405,9 +391,8 @@ class Agent(Base):
|
|
405
391
|
cache = Cache()
|
406
392
|
|
407
393
|
if debug:
|
408
|
-
raise NotImplementedError("Debug mode is not yet implemented.")
|
409
394
|
# use the question's _simulate_answer method
|
410
|
-
|
395
|
+
invigilator_class = InvigilatorDebug
|
411
396
|
elif hasattr(question, "answer_question_directly"):
|
412
397
|
# It's a functional question and the answer only depends on the agent's traits & the scenario
|
413
398
|
invigilator_class = InvigilatorFunctional
|
@@ -437,7 +422,6 @@ class Agent(Base):
|
|
437
422
|
iteration=iteration,
|
438
423
|
cache=cache,
|
439
424
|
sidecar_model=sidecar_model,
|
440
|
-
raise_validation_errors=raise_validation_errors,
|
441
425
|
)
|
442
426
|
return invigilator
|
443
427
|
|
@@ -513,8 +497,8 @@ class Agent(Base):
|
|
513
497
|
if name == "has_dynamic_traits_function":
|
514
498
|
return self.has_dynamic_traits_function
|
515
499
|
|
516
|
-
if name in self.
|
517
|
-
return self.
|
500
|
+
if name in self.traits:
|
501
|
+
return self.traits[name]
|
518
502
|
raise AttributeError(
|
519
503
|
f"'{type(self).__name__}' object has no attribute '{name}'"
|
520
504
|
)
|
@@ -656,22 +640,6 @@ class Agent(Base):
|
|
656
640
|
column_names = ["Attribute", "Value"]
|
657
641
|
return table_data, column_names
|
658
642
|
|
659
|
-
def add_trait(self, trait_name_or_dict: str, value: Optional[Any] = None) -> Agent:
|
660
|
-
"""Adds a trait to an agent and returns that agent"""
|
661
|
-
if isinstance(trait_name_or_dict, dict) and value is None:
|
662
|
-
self.traits.update(trait_name_or_dict)
|
663
|
-
return self
|
664
|
-
|
665
|
-
if isinstance(trait_name_or_dict, dict) and value:
|
666
|
-
raise ValueError(f"You passed a dict: {trait_name_or_dict}")
|
667
|
-
|
668
|
-
if isinstance(trait_name_or_dict, str):
|
669
|
-
trait = trait_name_or_dict
|
670
|
-
self.traits[trait] = value
|
671
|
-
return self
|
672
|
-
|
673
|
-
raise Exception("Something is not right with adding")
|
674
|
-
|
675
643
|
def remove_trait(self, trait: str) -> Agent:
|
676
644
|
"""Remove a trait from the agent.
|
677
645
|
|
edsl/agents/AgentList.py
CHANGED
@@ -21,12 +21,6 @@ from simpleeval import EvalWithCompoundTypes
|
|
21
21
|
from edsl.Base import Base
|
22
22
|
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
23
23
|
|
24
|
-
from collections.abc import Iterable
|
25
|
-
|
26
|
-
|
27
|
-
def is_iterable(obj):
|
28
|
-
return isinstance(obj, Iterable)
|
29
|
-
|
30
24
|
|
31
25
|
class AgentList(UserList, Base):
|
32
26
|
"""A list of Agents."""
|
@@ -117,13 +111,6 @@ class AgentList(UserList, Base):
|
|
117
111
|
|
118
112
|
return AgentList(new_data)
|
119
113
|
|
120
|
-
@property
|
121
|
-
def all_traits(self):
|
122
|
-
d = {}
|
123
|
-
for agent in self:
|
124
|
-
d.update(agent.traits)
|
125
|
-
return list(d.keys())
|
126
|
-
|
127
114
|
@classmethod
|
128
115
|
def from_csv(cls, file_path: str):
|
129
116
|
"""Load AgentList from a CSV file.
|
@@ -172,36 +159,6 @@ class AgentList(UserList, Base):
|
|
172
159
|
_ = agent.remove_trait(trait)
|
173
160
|
return self
|
174
161
|
|
175
|
-
def add_trait(self, trait, values):
|
176
|
-
"""Adds a new trait to every agent, with values taken from values.
|
177
|
-
|
178
|
-
:param trait: The name of the trait.
|
179
|
-
:param values: The valeues(s) of the trait. If a single value is passed, it is used for all agents.
|
180
|
-
|
181
|
-
>>> al = AgentList.example()
|
182
|
-
>>> al.add_trait('new_trait', 1)
|
183
|
-
AgentList([Agent(traits = {'age': 22, 'hair': 'brown', 'height': 5.5, 'new_trait': 1}), Agent(traits = {'age': 22, 'hair': 'brown', 'height': 5.5, 'new_trait': 1})])
|
184
|
-
>>> al.select('new_trait').to_scenario_list().to_list()
|
185
|
-
[1, 1]
|
186
|
-
>>> al.add_trait('new_trait', [1, 2, 3])
|
187
|
-
Traceback (most recent call last):
|
188
|
-
...
|
189
|
-
ValueError: The passed values have to be the same length as the agent list.
|
190
|
-
"""
|
191
|
-
if not is_iterable(values):
|
192
|
-
value = values
|
193
|
-
for agent in self.data:
|
194
|
-
agent.add_trait(trait, value)
|
195
|
-
return self
|
196
|
-
|
197
|
-
if len(values) != len(self):
|
198
|
-
raise ValueError(
|
199
|
-
"The passed values have to be the same length as the agent list."
|
200
|
-
)
|
201
|
-
for agent, value in zip(self.data, values):
|
202
|
-
agent.add_trait(trait, value)
|
203
|
-
return self
|
204
|
-
|
205
162
|
@staticmethod
|
206
163
|
def get_codebook(file_path: str):
|
207
164
|
"""Return the codebook for a CSV file.
|
edsl/agents/Invigilator.py
CHANGED
@@ -1,169 +1,252 @@
|
|
1
1
|
"""Module for creating Invigilators, which are objects to administer a question to an Agent."""
|
2
2
|
|
3
|
+
import json
|
3
4
|
from typing import Dict, Any, Optional
|
4
5
|
|
6
|
+
from edsl.exceptions import AgentRespondedWithBadJSONError
|
5
7
|
from edsl.prompts.Prompt import Prompt
|
6
8
|
from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
|
7
9
|
from edsl.prompts.registry import get_classes as prompt_lookup
|
8
|
-
from edsl.
|
9
|
-
from edsl.agents
|
10
|
-
from edsl.
|
11
|
-
from edsl.agents.PromptConstructor import PromptConstructor
|
12
|
-
|
10
|
+
from edsl.data_transfer_models import AgentResponseDict
|
11
|
+
from edsl.exceptions.agents import FailedTaskException
|
12
|
+
from edsl.agents.PromptConstructionMixin import PromptConstructorMixin
|
13
13
|
|
14
|
-
|
15
|
-
def __new__(cls):
|
16
|
-
instance = super().__new__(cls, "Not Applicable")
|
17
|
-
instance.literal = "Not Applicable"
|
18
|
-
return instance
|
14
|
+
from edsl.agents.InvigilatorBase import InvigilatorBase
|
19
15
|
|
20
16
|
|
21
|
-
class InvigilatorAI(InvigilatorBase):
|
17
|
+
class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
|
22
18
|
"""An invigilator that uses an AI model to answer questions."""
|
23
19
|
|
24
|
-
def get_prompts(self) -> Dict[str, Prompt]:
|
25
|
-
"""Return the prompts used."""
|
26
|
-
return self.prompt_constructor.get_prompts()
|
27
|
-
|
28
20
|
async def async_answer_question(self) -> AgentResponseDict:
|
29
21
|
"""Answer a question using the AI model.
|
30
22
|
|
31
23
|
>>> i = InvigilatorAI.example()
|
32
24
|
>>> i.answer_question()
|
33
|
-
{'message':
|
25
|
+
{'message': '{"answer": "SPAM!"}'}
|
34
26
|
"""
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
"
|
27
|
+
params = self.get_prompts() | {"iteration": self.iteration}
|
28
|
+
raw_response = await self.async_get_response(**params)
|
29
|
+
data = {
|
30
|
+
"agent": self.agent,
|
31
|
+
"question": self.question,
|
32
|
+
"scenario": self.scenario,
|
33
|
+
"raw_response": raw_response,
|
34
|
+
"raw_model_response": raw_response["raw_model_response"],
|
39
35
|
}
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
36
|
+
response = self._format_raw_response(**data)
|
37
|
+
# breakpoint()
|
38
|
+
return AgentResponseDict(**response)
|
39
|
+
|
40
|
+
async def async_get_response(
|
41
|
+
self,
|
42
|
+
user_prompt: Prompt,
|
43
|
+
system_prompt: Prompt,
|
44
|
+
iteration: int = 0,
|
45
|
+
encoded_image=None,
|
46
|
+
) -> dict:
|
47
|
+
"""Call the LLM and gets a response. Used in the `answer_question` method."""
|
48
|
+
try:
|
49
|
+
params = {
|
50
|
+
"user_prompt": user_prompt.text,
|
51
|
+
"system_prompt": system_prompt.text,
|
52
|
+
"iteration": iteration,
|
53
|
+
"cache": self.cache,
|
54
|
+
}
|
55
|
+
if encoded_image:
|
56
|
+
params["encoded_image"] = encoded_image
|
57
|
+
response = await self.model.async_get_response(**params)
|
58
|
+
|
59
|
+
# TODO: I *don't* think we need to delete the cache key here because I think
|
60
|
+
# it will not have been set yet; the exception would have been raised before.
|
61
|
+
except json.JSONDecodeError as e:
|
62
|
+
raise AgentRespondedWithBadJSONError(
|
63
|
+
f"Returned bad JSON: {e}"
|
64
|
+
f"Prompt: {user_prompt}"
|
65
|
+
f"System Prompt: {system_prompt}"
|
66
|
+
)
|
67
|
+
|
68
|
+
return response
|
69
|
+
|
70
|
+
def _remove_from_cache(self, raw_response) -> None:
|
55
71
|
"""Remove an entry from the cache."""
|
72
|
+
cache_key = raw_response.get("cache_key", None)
|
56
73
|
if cache_key:
|
57
74
|
del self.cache.data[cache_key]
|
58
75
|
|
59
|
-
def
|
76
|
+
def _format_raw_response(
|
77
|
+
self, *, agent, question, scenario, raw_response, raw_model_response
|
78
|
+
) -> AgentResponseDict:
|
79
|
+
"""Return formatted raw response.
|
80
|
+
|
81
|
+
This cleans up the raw response to make it suitable to pass to AgentResponseDict.
|
82
|
+
"""
|
83
|
+
_ = agent
|
84
|
+
try:
|
85
|
+
response = question._validate_answer(raw_response)
|
86
|
+
except Exception as e:
|
87
|
+
"""If the response is invalid, remove it from the cache and raise the exception."""
|
88
|
+
self._remove_from_cache(raw_response)
|
89
|
+
raise e
|
90
|
+
|
60
91
|
question_dict = self.survey.question_names_to_questions()
|
61
|
-
# iterates through the current answers and updates the question_dict (which is all questions)
|
62
92
|
for other_question, answer in self.current_answers.items():
|
63
93
|
if other_question in question_dict:
|
64
94
|
question_dict[other_question].answer = answer
|
65
95
|
else:
|
66
|
-
#
|
96
|
+
# adds a comment to the question
|
67
97
|
if (
|
68
98
|
new_question := other_question.split("_comment")[0]
|
69
99
|
) in question_dict:
|
70
100
|
question_dict[new_question].comment = answer
|
71
101
|
|
72
|
-
combined_dict = {**question_dict, **
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
if self.raise_validation_errors:
|
91
|
-
exception_occurred = e
|
92
|
-
except Exception as non_validation_error:
|
93
|
-
answer = None
|
94
|
-
comment = "Some other error occurred."
|
95
|
-
exception_occurred = non_validation_error
|
96
|
-
finally:
|
97
|
-
# even if validation failes, we still return the result
|
98
|
-
data = {
|
99
|
-
"answer": answer,
|
100
|
-
"comment": comment,
|
101
|
-
"generated_tokens": agent_response_dict.edsl_dict.generated_tokens,
|
102
|
-
"question_name": self.question.question_name,
|
103
|
-
"prompts": self.get_prompts(),
|
104
|
-
"cached_response": agent_response_dict.model_outputs.cached_response,
|
105
|
-
"raw_model_response": agent_response_dict.model_outputs.response,
|
106
|
-
"cache_used": agent_response_dict.model_outputs.cache_used,
|
107
|
-
"cache_key": agent_response_dict.model_outputs.cache_key,
|
108
|
-
"validated": validated,
|
109
|
-
"exception_occurred": exception_occurred,
|
110
|
-
"cost": agent_response_dict.model_outputs.cost,
|
111
|
-
}
|
112
|
-
result = EDSLResultObjectInput(**data)
|
113
|
-
return result
|
102
|
+
combined_dict = {**question_dict, **scenario}
|
103
|
+
answer = question._translate_answer_code_to_answer(
|
104
|
+
response["answer"], combined_dict
|
105
|
+
)
|
106
|
+
data = {
|
107
|
+
"answer": answer,
|
108
|
+
"comment": response.get(
|
109
|
+
"comment", ""
|
110
|
+
), # not all question have comment fields,
|
111
|
+
"question_name": question.question_name,
|
112
|
+
"prompts": self.get_prompts(),
|
113
|
+
"cached_response": raw_response.get("cached_response", None),
|
114
|
+
"usage": raw_response.get("usage", {}),
|
115
|
+
"raw_model_response": raw_model_response,
|
116
|
+
"cache_used": raw_response.get("cache_used", False),
|
117
|
+
"cache_key": raw_response.get("cache_key", None),
|
118
|
+
}
|
119
|
+
return AgentResponseDict(**data)
|
114
120
|
|
121
|
+
get_response = sync_wrapper(async_get_response)
|
115
122
|
answer_question = sync_wrapper(async_answer_question)
|
116
123
|
|
117
124
|
|
118
|
-
class
|
119
|
-
"""An invigilator
|
125
|
+
class InvigilatorSidecar(InvigilatorAI):
|
126
|
+
"""An invigilator that presents the 'raw' question to the agent
|
127
|
+
& uses a sidecar model to answer questions."""
|
128
|
+
|
129
|
+
async def async_answer_question(self, failed: bool = False) -> AgentResponseDict:
|
130
|
+
"""Answer a question using the AI model."""
|
131
|
+
from edsl import Model
|
132
|
+
|
133
|
+
advanced_model = self.sidecar_model
|
134
|
+
simple_model = self.model
|
135
|
+
question = self.question
|
136
|
+
human_readable_question = (
|
137
|
+
"Please answer this single question: " + question.human_readable()
|
138
|
+
)
|
139
|
+
print("Getting the simple model response to: ", human_readable_question)
|
140
|
+
raw_simple_response = await simple_model.async_execute_model_call(
|
141
|
+
user_prompt=human_readable_question,
|
142
|
+
system_prompt="""Pretend you are a human answering a question. Do not break character.""",
|
143
|
+
)
|
144
|
+
simple_response = simple_model.parse_response(raw_simple_response)
|
145
|
+
instructions = question.get_instructions()
|
146
|
+
|
147
|
+
main_model_prompt = Prompt(
|
148
|
+
text="""
|
149
|
+
A simpler language model was asked this question:
|
150
|
+
|
151
|
+
To the simpel model:
|
152
|
+
{{ human_readable_question }}
|
153
|
+
|
154
|
+
The simple model responded:
|
155
|
+
<response>
|
156
|
+
{{ simple_response }}
|
157
|
+
</response>
|
158
|
+
|
159
|
+
It was suppose to respond according to these instructions:
|
160
|
+
<instructions>
|
161
|
+
{{ instructions }}
|
162
|
+
</instructions>
|
163
|
+
|
164
|
+
Please format the simple model's response as it should have been formmated, given the instructions.
|
165
|
+
Only respond in valid JSON, like so {"answer": "SPAM!"} or {"answer": "SPAM!", "comment": "I am a robot."}
|
166
|
+
Do not inlcude the word 'json'
|
167
|
+
"""
|
168
|
+
)
|
169
|
+
|
170
|
+
d = {
|
171
|
+
"human_readable_question": human_readable_question,
|
172
|
+
"simple_response": simple_response,
|
173
|
+
"instructions": instructions,
|
174
|
+
}
|
175
|
+
|
176
|
+
print("The human-readable question is: ", human_readable_question)
|
177
|
+
print("The simple response is: ", simple_response)
|
178
|
+
|
179
|
+
raw_response_data = await advanced_model.async_execute_model_call(
|
180
|
+
user_prompt=main_model_prompt.render(d).text,
|
181
|
+
system_prompt="You are a helpful assistant.",
|
182
|
+
)
|
183
|
+
|
184
|
+
raw_response = await advanced_model.async_get_response(
|
185
|
+
user_prompt=main_model_prompt.render(d).text,
|
186
|
+
system_prompt="You are a helpful assistant.",
|
187
|
+
iteration=0,
|
188
|
+
cache=self.cache,
|
189
|
+
)
|
190
|
+
|
191
|
+
data = {
|
192
|
+
"agent": self.agent,
|
193
|
+
"question": self.question,
|
194
|
+
"scenario": self.scenario,
|
195
|
+
}
|
196
|
+
raw_response_data = {
|
197
|
+
"raw_response": raw_response,
|
198
|
+
"raw_model_response": raw_response["raw_model_response"],
|
199
|
+
}
|
200
|
+
params = data | raw_response_data
|
201
|
+
response = self._format_raw_response(**params)
|
202
|
+
response.update({"simple_model_raw_response": simple_response})
|
203
|
+
return AgentResponseDict(**response)
|
204
|
+
|
205
|
+
# get_response = sync_wrapper(async_get_response)
|
206
|
+
answer_question = sync_wrapper(async_answer_question)
|
120
207
|
|
121
|
-
|
122
|
-
|
208
|
+
|
209
|
+
class InvigilatorDebug(InvigilatorBase):
|
210
|
+
"""An invigilator class for debugging purposes."""
|
123
211
|
|
124
212
|
async def async_answer_question(self, iteration: int = 0) -> AgentResponseDict:
|
125
213
|
"""Return the answer to the question."""
|
126
|
-
|
214
|
+
results = self.question._simulate_answer(human_readable=True)
|
215
|
+
results["prompts"] = self.get_prompts()
|
216
|
+
results["question_name"] = self.question.question_name
|
217
|
+
results["comment"] = "Debug comment"
|
218
|
+
return AgentResponseDict(**results)
|
219
|
+
|
220
|
+
def get_prompts(self) -> Dict[str, Prompt]:
|
221
|
+
"""Return the prompts used."""
|
222
|
+
return {
|
223
|
+
"user_prompt": Prompt("NA"),
|
224
|
+
"system_prompt": Prompt("NA"),
|
225
|
+
}
|
127
226
|
|
128
|
-
def __repr__(self):
|
129
|
-
return f"{self.literal}"
|
130
227
|
|
131
|
-
|
132
|
-
|
228
|
+
class InvigilatorHuman(InvigilatorBase):
|
229
|
+
"""An invigilator for when a human is answering the question."""
|
230
|
+
|
231
|
+
async def async_answer_question(self, iteration: int = 0) -> AgentResponseDict:
|
232
|
+
"""Return the answer to the question."""
|
233
|
+
data = {
|
234
|
+
"comment": "This is a real survey response from a human.",
|
235
|
+
"answer": None,
|
236
|
+
"prompts": self.get_prompts(),
|
237
|
+
"question_name": self.question.question_name,
|
238
|
+
}
|
133
239
|
try:
|
134
240
|
answer = self.agent.answer_question_directly(self.question, self.scenario)
|
135
|
-
|
136
|
-
|
137
|
-
if self.validate_response:
|
138
|
-
_ = self.question._validate_answer({"answer": answer})
|
139
|
-
if self.translate_response:
|
140
|
-
answer = self.question._translate_answer_code_to_answer(
|
141
|
-
answer, self.scenario
|
142
|
-
)
|
143
|
-
validated = True
|
144
|
-
except QuestionAnswerValidationError as e:
|
145
|
-
answer = None
|
146
|
-
if self.raise_validation_errors:
|
147
|
-
exception_occurred = e
|
241
|
+
return AgentResponseDict(**(data | {"answer": answer}))
|
148
242
|
except Exception as e:
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
"prompts": self.get_prompts(),
|
157
|
-
"cached_response": NotApplicable(),
|
158
|
-
"raw_model_response": NotApplicable(),
|
159
|
-
"cache_used": NotApplicable(),
|
160
|
-
"cache_key": NotApplicable(),
|
161
|
-
"answer": answer,
|
162
|
-
"comment": comment,
|
163
|
-
"validated": validated,
|
164
|
-
"exception_occurred": exception_occurred,
|
165
|
-
}
|
166
|
-
return EDSLResultObjectInput(**data)
|
243
|
+
agent_response_dict = AgentResponseDict(
|
244
|
+
**(data | {"answer": None, "comment": str(e)})
|
245
|
+
)
|
246
|
+
raise FailedTaskException(
|
247
|
+
f"Failed to get response. The exception is {str(e)}",
|
248
|
+
agent_response_dict,
|
249
|
+
) from e
|
167
250
|
|
168
251
|
|
169
252
|
class InvigilatorFunctional(InvigilatorBase):
|
@@ -172,21 +255,22 @@ class InvigilatorFunctional(InvigilatorBase):
|
|
172
255
|
async def async_answer_question(self, iteration: int = 0) -> AgentResponseDict:
|
173
256
|
"""Return the answer to the question."""
|
174
257
|
func = self.question.answer_question_directly
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
258
|
+
data = {
|
259
|
+
"comment": "Functional.",
|
260
|
+
"prompts": self.get_prompts(),
|
261
|
+
"question_name": self.question.question_name,
|
262
|
+
}
|
263
|
+
try:
|
264
|
+
answer = func(scenario=self.scenario, agent_traits=self.agent.traits)
|
265
|
+
return AgentResponseDict(**(data | answer))
|
266
|
+
except Exception as e:
|
267
|
+
agent_response_dict = AgentResponseDict(
|
268
|
+
**(data | {"answer": None, "comment": str(e)})
|
269
|
+
)
|
270
|
+
raise FailedTaskException(
|
271
|
+
f"Failed to get response. The exception is {str(e)}",
|
272
|
+
agent_response_dict,
|
273
|
+
) from e
|
190
274
|
|
191
275
|
def get_prompts(self) -> Dict[str, Prompt]:
|
192
276
|
"""Return the prompts used."""
|