edsl 0.1.47__py3-none-any.whl → 0.1.49__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__init__.py +44 -39
- edsl/__version__.py +1 -1
- edsl/agents/__init__.py +4 -2
- edsl/agents/{Agent.py → agent.py} +442 -152
- edsl/agents/{AgentList.py → agent_list.py} +220 -162
- edsl/agents/descriptors.py +46 -7
- edsl/{exceptions/agents.py → agents/exceptions.py} +3 -12
- edsl/base/__init__.py +75 -0
- edsl/base/base_class.py +1303 -0
- edsl/base/data_transfer_models.py +114 -0
- edsl/base/enums.py +215 -0
- edsl/base.py +8 -0
- edsl/buckets/__init__.py +25 -0
- edsl/buckets/bucket_collection.py +324 -0
- edsl/buckets/model_buckets.py +206 -0
- edsl/buckets/token_bucket.py +502 -0
- edsl/{jobs/buckets/TokenBucketAPI.py → buckets/token_bucket_api.py} +1 -1
- edsl/buckets/token_bucket_client.py +509 -0
- edsl/caching/__init__.py +20 -0
- edsl/caching/cache.py +814 -0
- edsl/caching/cache_entry.py +427 -0
- edsl/{data/CacheHandler.py → caching/cache_handler.py} +14 -15
- edsl/caching/exceptions.py +24 -0
- edsl/caching/orm.py +30 -0
- edsl/{data/RemoteCacheSync.py → caching/remote_cache_sync.py} +3 -3
- edsl/caching/sql_dict.py +441 -0
- edsl/config/__init__.py +8 -0
- edsl/config/config_class.py +177 -0
- edsl/config.py +4 -176
- edsl/conversation/Conversation.py +7 -7
- edsl/conversation/car_buying.py +4 -4
- edsl/conversation/chips.py +6 -6
- edsl/coop/__init__.py +25 -2
- edsl/coop/coop.py +311 -75
- edsl/coop/{ExpectedParrotKeyHandler.py → ep_key_handling.py} +86 -10
- edsl/coop/exceptions.py +62 -0
- edsl/coop/price_fetcher.py +126 -0
- edsl/coop/utils.py +89 -24
- edsl/data_transfer_models.py +5 -72
- edsl/dataset/__init__.py +10 -0
- edsl/{results/Dataset.py → dataset/dataset.py} +116 -36
- edsl/{results/DatasetExportMixin.py → dataset/dataset_operations_mixin.py} +606 -122
- edsl/{results/DatasetTree.py → dataset/dataset_tree.py} +156 -75
- edsl/{results/TableDisplay.py → dataset/display/table_display.py} +18 -7
- edsl/{results → dataset/display}/table_renderers.py +58 -2
- edsl/{results → dataset}/file_exports.py +4 -5
- edsl/{results → dataset}/smart_objects.py +2 -2
- edsl/enums.py +5 -205
- edsl/inference_services/__init__.py +5 -0
- edsl/inference_services/{AvailableModelCacheHandler.py → available_model_cache_handler.py} +2 -3
- edsl/inference_services/{AvailableModelFetcher.py → available_model_fetcher.py} +8 -14
- edsl/inference_services/data_structures.py +3 -2
- edsl/{exceptions/inference_services.py → inference_services/exceptions.py} +1 -1
- edsl/inference_services/{InferenceServiceABC.py → inference_service_abc.py} +1 -1
- edsl/inference_services/{InferenceServicesCollection.py → inference_services_collection.py} +8 -7
- edsl/inference_services/registry.py +4 -41
- edsl/inference_services/{ServiceAvailability.py → service_availability.py} +5 -25
- edsl/inference_services/services/__init__.py +31 -0
- edsl/inference_services/{AnthropicService.py → services/anthropic_service.py} +3 -3
- edsl/inference_services/{AwsBedrock.py → services/aws_bedrock.py} +2 -2
- edsl/inference_services/{AzureAI.py → services/azure_ai.py} +2 -2
- edsl/inference_services/{DeepInfraService.py → services/deep_infra_service.py} +1 -3
- edsl/inference_services/{DeepSeekService.py → services/deep_seek_service.py} +2 -4
- edsl/inference_services/{GoogleService.py → services/google_service.py} +5 -4
- edsl/inference_services/{GroqService.py → services/groq_service.py} +1 -1
- edsl/inference_services/{MistralAIService.py → services/mistral_ai_service.py} +3 -3
- edsl/inference_services/{OllamaService.py → services/ollama_service.py} +1 -7
- edsl/inference_services/{OpenAIService.py → services/open_ai_service.py} +5 -6
- edsl/inference_services/{PerplexityService.py → services/perplexity_service.py} +3 -7
- edsl/inference_services/{TestService.py → services/test_service.py} +7 -6
- edsl/inference_services/{TogetherAIService.py → services/together_ai_service.py} +2 -6
- edsl/inference_services/{XAIService.py → services/xai_service.py} +1 -1
- edsl/inference_services/write_available.py +1 -2
- edsl/instructions/__init__.py +6 -0
- edsl/{surveys/instructions/Instruction.py → instructions/instruction.py} +11 -6
- edsl/{surveys/instructions/InstructionCollection.py → instructions/instruction_collection.py} +10 -5
- edsl/{surveys/InstructionHandler.py → instructions/instruction_handler.py} +3 -3
- edsl/{jobs/interviews → interviews}/ReportErrors.py +2 -2
- edsl/interviews/__init__.py +4 -0
- edsl/{jobs/AnswerQuestionFunctionConstructor.py → interviews/answering_function.py} +45 -18
- edsl/{jobs/interviews/InterviewExceptionEntry.py → interviews/exception_tracking.py} +107 -22
- edsl/interviews/interview.py +638 -0
- edsl/{jobs/interviews/InterviewStatusDictionary.py → interviews/interview_status_dictionary.py} +21 -12
- edsl/{jobs/interviews/InterviewStatusLog.py → interviews/interview_status_log.py} +16 -7
- edsl/{jobs/InterviewTaskManager.py → interviews/interview_task_manager.py} +12 -7
- edsl/{jobs/RequestTokenEstimator.py → interviews/request_token_estimator.py} +8 -3
- edsl/{jobs/interviews/InterviewStatistic.py → interviews/statistics.py} +36 -10
- edsl/invigilators/__init__.py +38 -0
- edsl/invigilators/invigilator_base.py +477 -0
- edsl/{agents/Invigilator.py → invigilators/invigilators.py} +263 -10
- edsl/invigilators/prompt_constructor.py +476 -0
- edsl/{agents → invigilators}/prompt_helpers.py +2 -1
- edsl/{agents/QuestionInstructionPromptBuilder.py → invigilators/question_instructions_prompt_builder.py} +18 -13
- edsl/{agents → invigilators}/question_option_processor.py +96 -21
- edsl/{agents/QuestionTemplateReplacementsBuilder.py → invigilators/question_template_replacements_builder.py} +64 -12
- edsl/jobs/__init__.py +7 -1
- edsl/jobs/async_interview_runner.py +99 -35
- edsl/jobs/check_survey_scenario_compatibility.py +7 -5
- edsl/jobs/data_structures.py +153 -22
- edsl/{exceptions/jobs.py → jobs/exceptions.py} +2 -1
- edsl/jobs/{FetchInvigilator.py → fetch_invigilator.py} +4 -4
- edsl/jobs/{loggers/HTMLTableJobLogger.py → html_table_job_logger.py} +6 -2
- edsl/jobs/{Jobs.py → jobs.py} +313 -167
- edsl/jobs/{JobsChecks.py → jobs_checks.py} +15 -7
- edsl/jobs/{JobsComponentConstructor.py → jobs_component_constructor.py} +19 -17
- edsl/jobs/{InterviewsConstructor.py → jobs_interview_constructor.py} +10 -5
- edsl/jobs/jobs_pricing_estimation.py +347 -0
- edsl/jobs/{JobsRemoteInferenceLogger.py → jobs_remote_inference_logger.py} +4 -3
- edsl/jobs/jobs_runner_asyncio.py +282 -0
- edsl/jobs/{JobsRemoteInferenceHandler.py → remote_inference.py} +19 -22
- edsl/jobs/results_exceptions_handler.py +2 -2
- edsl/key_management/__init__.py +28 -0
- edsl/key_management/key_lookup.py +161 -0
- edsl/{language_models/key_management/KeyLookupBuilder.py → key_management/key_lookup_builder.py} +118 -47
- edsl/key_management/key_lookup_collection.py +82 -0
- edsl/key_management/models.py +218 -0
- edsl/language_models/__init__.py +7 -2
- edsl/language_models/{ComputeCost.py → compute_cost.py} +18 -3
- edsl/{exceptions/language_models.py → language_models/exceptions.py} +2 -1
- edsl/language_models/language_model.py +1080 -0
- edsl/language_models/model.py +10 -25
- edsl/language_models/{ModelList.py → model_list.py} +9 -14
- edsl/language_models/{RawResponseHandler.py → raw_response_handler.py} +1 -1
- edsl/language_models/{RegisterLanguageModelsMeta.py → registry.py} +1 -1
- edsl/language_models/repair.py +4 -4
- edsl/language_models/utilities.py +4 -4
- edsl/notebooks/__init__.py +3 -1
- edsl/notebooks/{Notebook.py → notebook.py} +7 -8
- edsl/prompts/__init__.py +1 -1
- edsl/{exceptions/prompts.py → prompts/exceptions.py} +3 -1
- edsl/prompts/{Prompt.py → prompt.py} +101 -95
- edsl/questions/HTMLQuestion.py +1 -1
- edsl/questions/__init__.py +154 -25
- edsl/questions/answer_validator_mixin.py +1 -1
- edsl/questions/compose_questions.py +4 -3
- edsl/questions/derived/question_likert_five.py +166 -0
- edsl/questions/derived/{QuestionLinearScale.py → question_linear_scale.py} +4 -4
- edsl/questions/derived/{QuestionTopK.py → question_top_k.py} +4 -4
- edsl/questions/derived/{QuestionYesNo.py → question_yes_no.py} +4 -5
- edsl/questions/descriptors.py +24 -30
- edsl/questions/loop_processor.py +65 -19
- edsl/questions/question_base.py +881 -0
- edsl/questions/question_base_gen_mixin.py +15 -16
- edsl/questions/{QuestionBasePromptsMixin.py → question_base_prompts_mixin.py} +2 -2
- edsl/questions/{QuestionBudget.py → question_budget.py} +3 -4
- edsl/questions/{QuestionCheckBox.py → question_check_box.py} +16 -16
- edsl/questions/{QuestionDict.py → question_dict.py} +39 -5
- edsl/questions/{QuestionExtract.py → question_extract.py} +9 -9
- edsl/questions/question_free_text.py +282 -0
- edsl/questions/{QuestionFunctional.py → question_functional.py} +6 -5
- edsl/questions/{QuestionList.py → question_list.py} +6 -7
- edsl/questions/{QuestionMatrix.py → question_matrix.py} +6 -5
- edsl/questions/{QuestionMultipleChoice.py → question_multiple_choice.py} +126 -21
- edsl/questions/{QuestionNumerical.py → question_numerical.py} +5 -5
- edsl/questions/{QuestionRank.py → question_rank.py} +6 -6
- edsl/questions/question_registry.py +4 -9
- edsl/questions/register_questions_meta.py +8 -4
- edsl/questions/response_validator_abc.py +17 -16
- edsl/results/__init__.py +4 -1
- edsl/{exceptions/results.py → results/exceptions.py} +1 -1
- edsl/results/report.py +197 -0
- edsl/results/{Result.py → result.py} +131 -45
- edsl/results/{Results.py → results.py} +365 -220
- edsl/results/results_selector.py +344 -25
- edsl/scenarios/__init__.py +30 -3
- edsl/scenarios/{ConstructDownloadLink.py → construct_download_link.py} +7 -0
- edsl/scenarios/directory_scanner.py +156 -13
- edsl/scenarios/document_chunker.py +186 -0
- edsl/scenarios/exceptions.py +101 -0
- edsl/scenarios/file_methods.py +2 -3
- edsl/scenarios/{FileStore.py → file_store.py} +275 -189
- edsl/scenarios/handlers/__init__.py +14 -14
- edsl/scenarios/handlers/{csv.py → csv_file_store.py} +1 -2
- edsl/scenarios/handlers/{docx.py → docx_file_store.py} +8 -7
- edsl/scenarios/handlers/{html.py → html_file_store.py} +1 -2
- edsl/scenarios/handlers/{jpeg.py → jpeg_file_store.py} +1 -1
- edsl/scenarios/handlers/{json.py → json_file_store.py} +1 -1
- edsl/scenarios/handlers/latex_file_store.py +5 -0
- edsl/scenarios/handlers/{md.py → md_file_store.py} +1 -1
- edsl/scenarios/handlers/{pdf.py → pdf_file_store.py} +2 -2
- edsl/scenarios/handlers/{png.py → png_file_store.py} +1 -1
- edsl/scenarios/handlers/{pptx.py → pptx_file_store.py} +8 -7
- edsl/scenarios/handlers/{py.py → py_file_store.py} +1 -3
- edsl/scenarios/handlers/{sql.py → sql_file_store.py} +2 -1
- edsl/scenarios/handlers/{sqlite.py → sqlite_file_store.py} +2 -3
- edsl/scenarios/handlers/{txt.py → txt_file_store.py} +1 -1
- edsl/scenarios/scenario.py +928 -0
- edsl/scenarios/scenario_join.py +18 -5
- edsl/scenarios/{ScenarioList.py → scenario_list.py} +294 -106
- edsl/scenarios/{ScenarioListPdfMixin.py → scenario_list_pdf_tools.py} +16 -15
- edsl/scenarios/scenario_selector.py +5 -1
- edsl/study/ObjectEntry.py +2 -2
- edsl/study/SnapShot.py +5 -5
- edsl/study/Study.py +18 -19
- edsl/study/__init__.py +6 -4
- edsl/surveys/__init__.py +7 -4
- edsl/surveys/dag/__init__.py +2 -0
- edsl/surveys/{ConstructDAG.py → dag/construct_dag.py} +3 -3
- edsl/surveys/{DAG.py → dag/dag.py} +13 -10
- edsl/surveys/descriptors.py +1 -1
- edsl/surveys/{EditSurvey.py → edit_survey.py} +9 -9
- edsl/{exceptions/surveys.py → surveys/exceptions.py} +1 -2
- edsl/surveys/memory/__init__.py +3 -0
- edsl/surveys/{MemoryPlan.py → memory/memory_plan.py} +10 -9
- edsl/surveys/rules/__init__.py +3 -0
- edsl/surveys/{Rule.py → rules/rule.py} +103 -43
- edsl/surveys/{RuleCollection.py → rules/rule_collection.py} +21 -30
- edsl/surveys/{RuleManager.py → rules/rule_manager.py} +19 -13
- edsl/surveys/survey.py +1743 -0
- edsl/surveys/{SurveyExportMixin.py → survey_export.py} +22 -27
- edsl/surveys/{SurveyFlowVisualization.py → survey_flow_visualization.py} +11 -2
- edsl/surveys/{Simulator.py → survey_simulator.py} +10 -3
- edsl/tasks/__init__.py +32 -0
- edsl/{jobs/tasks/QuestionTaskCreator.py → tasks/question_task_creator.py} +115 -57
- edsl/tasks/task_creators.py +135 -0
- edsl/{jobs/tasks/TaskHistory.py → tasks/task_history.py} +86 -47
- edsl/{jobs/tasks → tasks}/task_status_enum.py +91 -7
- edsl/tasks/task_status_log.py +85 -0
- edsl/tokens/__init__.py +2 -0
- edsl/tokens/interview_token_usage.py +53 -0
- edsl/utilities/PrettyList.py +1 -1
- edsl/utilities/SystemInfo.py +25 -22
- edsl/utilities/__init__.py +29 -21
- edsl/utilities/gcp_bucket/__init__.py +2 -0
- edsl/utilities/gcp_bucket/cloud_storage.py +99 -96
- edsl/utilities/interface.py +44 -536
- edsl/{results/MarkdownToPDF.py → utilities/markdown_to_pdf.py} +13 -5
- edsl/utilities/repair_functions.py +1 -1
- {edsl-0.1.47.dist-info → edsl-0.1.49.dist-info}/METADATA +1 -1
- edsl-0.1.49.dist-info/RECORD +347 -0
- edsl/Base.py +0 -493
- edsl/BaseDiff.py +0 -260
- edsl/agents/InvigilatorBase.py +0 -260
- edsl/agents/PromptConstructor.py +0 -318
- edsl/coop/PriceFetcher.py +0 -54
- edsl/data/Cache.py +0 -582
- edsl/data/CacheEntry.py +0 -238
- edsl/data/SQLiteDict.py +0 -292
- edsl/data/__init__.py +0 -5
- edsl/data/orm.py +0 -10
- edsl/exceptions/cache.py +0 -5
- edsl/exceptions/coop.py +0 -14
- edsl/exceptions/data.py +0 -14
- edsl/exceptions/scenarios.py +0 -29
- edsl/jobs/Answers.py +0 -43
- edsl/jobs/JobsPrompts.py +0 -354
- edsl/jobs/buckets/BucketCollection.py +0 -134
- edsl/jobs/buckets/ModelBuckets.py +0 -65
- edsl/jobs/buckets/TokenBucket.py +0 -283
- edsl/jobs/buckets/TokenBucketClient.py +0 -191
- edsl/jobs/interviews/Interview.py +0 -395
- edsl/jobs/interviews/InterviewExceptionCollection.py +0 -99
- edsl/jobs/interviews/InterviewStatisticsCollection.py +0 -25
- edsl/jobs/runners/JobsRunnerAsyncio.py +0 -163
- edsl/jobs/runners/JobsRunnerStatusData.py +0 -0
- edsl/jobs/tasks/TaskCreators.py +0 -64
- edsl/jobs/tasks/TaskStatusLog.py +0 -23
- edsl/jobs/tokens/InterviewTokenUsage.py +0 -27
- edsl/language_models/LanguageModel.py +0 -635
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/key_management/KeyLookup.py +0 -63
- edsl/language_models/key_management/KeyLookupCollection.py +0 -38
- edsl/language_models/key_management/models.py +0 -137
- edsl/questions/QuestionBase.py +0 -544
- edsl/questions/QuestionFreeText.py +0 -130
- edsl/questions/derived/QuestionLikertFive.py +0 -76
- edsl/results/ResultsExportMixin.py +0 -45
- edsl/results/TextEditor.py +0 -50
- edsl/results/results_fetch_mixin.py +0 -33
- edsl/results/results_tools_mixin.py +0 -98
- edsl/scenarios/DocumentChunker.py +0 -104
- edsl/scenarios/Scenario.py +0 -548
- edsl/scenarios/ScenarioHtmlMixin.py +0 -65
- edsl/scenarios/ScenarioListExportMixin.py +0 -45
- edsl/scenarios/handlers/latex.py +0 -5
- edsl/shared.py +0 -1
- edsl/surveys/Survey.py +0 -1301
- edsl/surveys/SurveyQualtricsImport.py +0 -284
- edsl/surveys/SurveyToApp.py +0 -141
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/tools/__init__.py +0 -1
- edsl/tools/clusters.py +0 -192
- edsl/tools/embeddings.py +0 -27
- edsl/tools/embeddings_plotting.py +0 -118
- edsl/tools/plotting.py +0 -112
- edsl/tools/summarize.py +0 -18
- edsl/utilities/data/Registry.py +0 -6
- edsl/utilities/data/__init__.py +0 -1
- edsl/utilities/data/scooter_results.json +0 -1
- edsl-0.1.47.dist-info/RECORD +0 -354
- /edsl/coop/{CoopFunctionsMixin.py → coop_functions.py} +0 -0
- /edsl/{results → dataset/display}/CSSParameterizer.py +0 -0
- /edsl/{language_models/key_management → dataset/display}/__init__.py +0 -0
- /edsl/{results → dataset/display}/table_data_class.py +0 -0
- /edsl/{results → dataset/display}/table_display.css +0 -0
- /edsl/{results/ResultsGGMixin.py → dataset/r/ggplot.py} +0 -0
- /edsl/{results → dataset}/tree_explore.py +0 -0
- /edsl/{surveys/instructions/ChangeInstruction.py → instructions/change_instruction.py} +0 -0
- /edsl/{jobs/interviews → interviews}/interview_status_enum.py +0 -0
- /edsl/jobs/{runners/JobsRunnerStatus.py → jobs_runner_status.py} +0 -0
- /edsl/language_models/{PriceManager.py → price_manager.py} +0 -0
- /edsl/language_models/{fake_openai_call.py → unused/fake_openai_call.py} +0 -0
- /edsl/language_models/{fake_openai_service.py → unused/fake_openai_service.py} +0 -0
- /edsl/notebooks/{NotebookToLaTeX.py → notebook_to_latex.py} +0 -0
- /edsl/{exceptions/questions.py → questions/exceptions.py} +0 -0
- /edsl/questions/{SimpleAskMixin.py → simple_ask_mixin.py} +0 -0
- /edsl/surveys/{Memory.py → memory/memory.py} +0 -0
- /edsl/surveys/{MemoryManagement.py → memory/memory_management.py} +0 -0
- /edsl/surveys/{SurveyCSS.py → survey_css.py} +0 -0
- /edsl/{jobs/tokens/TokenUsage.py → tokens/token_usage.py} +0 -0
- /edsl/{results/MarkdownToDocx.py → utilities/markdown_to_docx.py} +0 -0
- /edsl/{TemplateLoader.py → utilities/template_loader.py} +0 -0
- {edsl-0.1.47.dist-info → edsl-0.1.49.dist-info}/LICENSE +0 -0
- {edsl-0.1.47.dist-info → edsl-0.1.49.dist-info}/WHEEL +0 -0
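Most of the churn in this release is a reorganization of internal modules from CamelCase file names into snake_case subpackages (for example edsl/data/Cache.py → edsl/caching/cache.py and edsl/jobs/buckets/TokenBucket.py → edsl/buckets/token_bucket.py). The sketch below is a hypothetical compatibility shim for code that imported those internal paths directly; the paths are taken from the rename listing above and are assumptions about the module layout, not a documented public API, and the top-level re-exports (from edsl import ...) may make such a shim unnecessary.

    # Hypothetical shim for code that imported edsl's internal modules directly.
    # Paths follow the renames in the listing above (an assumption, not a
    # documented public API); prefer the top-level edsl exports where possible.
    try:
        # 0.1.49 layout: snake_case modules grouped into new subpackages
        from edsl.caching.cache import Cache
        from edsl.buckets.token_bucket import TokenBucket
    except ImportError:
        # 0.1.47 layout: CamelCase module files under edsl.data / edsl.jobs.buckets
        from edsl.data.Cache import Cache
        from edsl.jobs.buckets.TokenBucket import TokenBucket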
@@ -1,21 +1,270 @@
 """Module for creating Invigilators, which are objects to administer a question to an Agent."""
-
+from abc import ABC, abstractmethod
+import asyncio
+from typing import Coroutine, Dict, Any, Optional, TYPE_CHECKING
 from typing import Dict, Any, Optional, TYPE_CHECKING, Literal
 
-from
-from
-from
-from
+from ..utilities.decorators import sync_wrapper
+from ..questions.exceptions import QuestionAnswerValidationError
+from ..data_transfer_models import AgentResponseDict, EDSLResultObjectInput
+from ..utilities.decorators import jupyter_nb_handler
+from ..data_transfer_models import AgentResponseDict
+from ..data_transfer_models import EDSLResultObjectInput
+
+from .prompt_constructor import PromptConstructor
+from .prompt_helpers import PromptPlan
 
 if TYPE_CHECKING:
-    from
-    from
-    from
+    from ..prompts import Prompt
+    from ..scenarios import Scenario
+    from ..surveys import Survey
+    from ..prompts import Prompt
+    from ..caching import Cache
+    from ..questions import QuestionBase
+    from ..scenarios import Scenario
+    from ..surveys.memory import MemoryPlan
+    from ..language_models import LanguageModel
+    from ..surveys import Survey
+    from ..agents import Agent
+    from ..key_management import KeyLookup
+
+
 
 PromptType = Literal["user_prompt", "system_prompt", "encoded_image", "files_list"]
 
 NA = "Not Applicable"
 
+class InvigilatorBase(ABC):
+    """An invigiator (someone who administers an exam) is a class that is responsible for administering a question to an agent.
+
+    >>> InvigilatorBase.example().answer_question()
+    {'message': [{'text': 'SPAM!'}], 'usage': {'prompt_tokens': 1, 'completion_tokens': 1}}
+
+    >>> InvigilatorBase.example().get_failed_task_result(failure_reason="Failed to get response").comment
+    'Failed to get response'
+
+    This returns an empty prompt because there is no memory the agent needs to have at q0.
+    """
+
+    def __init__(
+        self,
+        agent: "Agent",
+        question: "QuestionBase",
+        scenario: "Scenario",
+        model: "LanguageModel",
+        memory_plan: "MemoryPlan",
+        current_answers: dict,
+        survey: Optional["Survey"],
+        cache: Optional["Cache"] = None,
+        iteration: Optional[int] = 1,
+        additional_prompt_data: Optional[dict] = None,
+        raise_validation_errors: Optional[bool] = True,
+        prompt_plan: Optional["PromptPlan"] = None,
+        key_lookup: Optional["KeyLookup"] = None,
+    ):
+        """Initialize a new Invigilator."""
+        self.agent = agent
+        self.question = question
+        self.scenario = scenario
+        self.model = model
+        self.memory_plan = memory_plan
+        self.current_answers = current_answers or {}
+        self.iteration = iteration
+        self.additional_prompt_data = additional_prompt_data
+        self.cache = cache
+        self.survey = survey
+        self.raise_validation_errors = raise_validation_errors
+        self.key_lookup = key_lookup
+
+        if prompt_plan is None:
+            self.prompt_plan = PromptPlan()
+        else:
+            self.prompt_plan = prompt_plan
+
+        # placeholder to store the raw model response
+        self.raw_model_response = None
+
+    @property
+    def prompt_constructor(self) -> PromptConstructor:
+        """Return the prompt constructor."""
+        return PromptConstructor.from_invigilator(self, prompt_plan=self.prompt_plan)
+
+    def to_dict(self, include_cache=False) -> Dict[str, Any]:
+        attributes = [
+            "agent",
+            "question",
+            "scenario",
+            "model",
+            "memory_plan",
+            "current_answers",
+            "iteration",
+            "additional_prompt_data",
+            "survey",
+            "raw_model_response",
+        ]
+        if include_cache:
+            attributes.append("cache")
+
+        def serialize_attribute(attr):
+            value = getattr(self, attr)
+            if value is None:
+                return None
+            if hasattr(value, "to_dict"):
+                return value.to_dict()
+            if isinstance(value, (int, float, str, bool, dict, list)):
+                return value
+            return str(value)
+
+        return {attr: serialize_attribute(attr) for attr in attributes}
+
+    @classmethod
+    def from_dict(cls, data) -> "InvigilatorBase":
+        from ..agents import Agent
+        from ..questions import QuestionBase
+        from ..scenarios import Scenario
+        from ..surveys.memory import MemoryPlan
+        from ..language_models import LanguageModel
+        from ..surveys import Survey
+        from ..caching import Cache
+
+        attributes_to_classes = {
+            "agent": Agent,
+            "question": QuestionBase,
+            "scenario": Scenario,
+            "model": LanguageModel,
+            "memory_plan": MemoryPlan,
+            "survey": Survey,
+            "cache": Cache,
+        }
+        d = {}
+        for attr, cls_ in attributes_to_classes.items():
+            if attr in data and data[attr] is not None:
+                if attr not in data:
+                    d[attr] = {}
+                else:
+                    d[attr] = cls_.from_dict(data[attr])
+
+        d["current_answers"] = data["current_answers"]
+        d["iteration"] = data["iteration"]
+        d["additional_prompt_data"] = data["additional_prompt_data"]
+
+        d = cls(**d)
+        d.raw_model_response = data.get("raw_model_response")
+        return d
+
+    def __repr__(self) -> str:
+        """Return a string representation of the Invigilator.
+
+        >>> InvigilatorBase.example().__repr__()
+        'InvigilatorExample(...)'
+
+        """
+        return f"{self.__class__.__name__}(agent={repr(self.agent)}, question={repr(self.question)}, scenario={repr(self.scenario)}, model={repr(self.model)}, memory_plan={repr(self.memory_plan)}, current_answers={repr(self.current_answers)}, iteration={repr(self.iteration)}, additional_prompt_data={repr(self.additional_prompt_data)}, cache={repr(self.cache)})"
+
+    def get_failed_task_result(self, failure_reason: str) -> EDSLResultObjectInput:
+        """Return an AgentResponseDict used in case the question-asking fails.
+
+        Possible reasons include:
+        - Legimately skipped because of skip logic
+        - Failed to get response from the model
+
+        """
+        data = {
+            "answer": None,
+            "generated_tokens": None,
+            "comment": failure_reason,
+            "question_name": self.question.question_name,
+            "prompts": self.get_prompts(),
+            "cached_response": None,
+            "raw_model_response": None,
+            "cache_used": None,
+            "cache_key": None,
+        }
+        return EDSLResultObjectInput(**data)
+
+    def get_prompts(self) -> Dict[str, "Prompt"]:
+        """Return the prompt used."""
+        from ..prompts import Prompt
+
+        return {
+            "user_prompt": Prompt("NA"),
+            "system_prompt": Prompt("NA"),
+        }
+
+    @abstractmethod
+    async def async_answer_question(self):
+        """Asnwer a question."""
+        pass
+
+    @jupyter_nb_handler
+    def answer_question(self) -> Coroutine:
+        """Return a function that gets the answers to the question."""
+
+        async def main():
+            """Return the answer to the question."""
+            results = await asyncio.gather(self.async_answer_question())
+            return results[0]  # Since there's only one task, return its result
+
+        return main()
+
+    @classmethod
+    def example(
+        cls, throw_an_exception=False, question=None, scenario=None, survey=None
+    ) -> "InvigilatorBase":
+        """Return an example invigilator.
+
+        >>> InvigilatorBase.example()
+        InvigilatorExample(...)
+
+        >>> InvigilatorBase.example().answer_question()
+        {'message': [{'text': 'SPAM!'}], 'usage': {'prompt_tokens': 1, 'completion_tokens': 1}}
+
+        >>> InvigilatorBase.example(throw_an_exception=True).answer_question()
+        Traceback (most recent call last):
+        ...
+        Exception: This is a test error
+        """
+        from ..agents import Agent
+        from ..scenarios import Scenario
+        from ..surveys.memory import MemoryPlan
+        from ..language_models import Model
+        from ..surveys import Survey
+
+        model = Model("test", canned_response="SPAM!")
+
+        if throw_an_exception:
+            model.throw_exception = True
+        agent = Agent.example()
+
+        if not survey:
+            survey = Survey.example()
+
+        if question not in survey.questions and question is not None:
+            survey.add_question(question)
+
+        question = question or survey.questions[0]
+        scenario = scenario or Scenario.example()
+        memory_plan = MemoryPlan(survey=survey)
+        current_answers = None
+
+        class InvigilatorExample(cls):
+            """An example invigilator."""
+
+            async def async_answer_question(self):
+                """Answer a question."""
+                return await self.model.async_execute_model_call(
+                    user_prompt="Hello", system_prompt="Hi"
+                )
+
+        return InvigilatorExample(
+            agent=agent,
+            question=question,
+            scenario=scenario,
+            survey=survey,
+            model=model,
+            memory_plan=memory_plan,
+            current_answers=current_answers,
+        )
 
 class InvigilatorAI(InvigilatorBase):
     """An invigilator that uses an AI model to answer questions."""
@@ -23,6 +272,10 @@ class InvigilatorAI(InvigilatorBase):
     def get_prompts(self) -> Dict[PromptType, "Prompt"]:
         """Return the prompts used."""
         return self.prompt_constructor.get_prompts()
+
+    def get_captured_variables(self) -> dict:
+        """Get the captured variables."""
+        return self.prompt_constructor.get_captured_variables()
 
     async def async_get_agent_response(self) -> AgentResponseDict:
         prompts = self.get_prompts()
@@ -249,7 +502,7 @@ class InvigilatorHuman(InvigilatorBase):
 
 
 class InvigilatorFunctional(InvigilatorBase):
-    """A Invigilator for when the question has
+    """A Invigilator for when the question has an answer_question_directly function."""
 
     async def async_answer_question(self, iteration: int = 0) -> AgentResponseDict:
         """Return the answer to the question."""
@@ -271,7 +524,7 @@ class InvigilatorFunctional(InvigilatorBase):
         )
 
     def get_prompts(self) -> Dict[str, "Prompt"]:
-        from
+        from ..prompts import Prompt
 
         """Return the prompts used."""
         return {