edsl 0.1.14__py3-none-any.whl → 0.1.40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +348 -38
- edsl/BaseDiff.py +260 -0
- edsl/TemplateLoader.py +24 -0
- edsl/__init__.py +46 -10
- edsl/__version__.py +1 -0
- edsl/agents/Agent.py +842 -144
- edsl/agents/AgentList.py +521 -25
- edsl/agents/Invigilator.py +250 -374
- edsl/agents/InvigilatorBase.py +257 -0
- edsl/agents/PromptConstructor.py +272 -0
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/descriptors.py +43 -13
- edsl/agents/prompt_helpers.py +129 -0
- edsl/agents/question_option_processor.py +172 -0
- edsl/auto/AutoStudy.py +130 -0
- edsl/auto/StageBase.py +243 -0
- edsl/auto/StageGenerateSurvey.py +178 -0
- edsl/auto/StageLabelQuestions.py +125 -0
- edsl/auto/StagePersona.py +61 -0
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
- edsl/auto/StagePersonaDimensionValues.py +74 -0
- edsl/auto/StagePersonaDimensions.py +69 -0
- edsl/auto/StageQuestions.py +74 -0
- edsl/auto/SurveyCreatorPipeline.py +21 -0
- edsl/auto/utilities.py +218 -0
- edsl/base/Base.py +279 -0
- edsl/config.py +121 -104
- edsl/conversation/Conversation.py +290 -0
- edsl/conversation/car_buying.py +59 -0
- edsl/conversation/chips.py +95 -0
- edsl/conversation/mug_negotiation.py +81 -0
- edsl/conversation/next_speaker_utilities.py +93 -0
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +54 -0
- edsl/coop/__init__.py +1 -0
- edsl/coop/coop.py +1029 -134
- edsl/coop/utils.py +131 -0
- edsl/data/Cache.py +560 -89
- edsl/data/CacheEntry.py +230 -0
- edsl/data/CacheHandler.py +168 -0
- edsl/data/RemoteCacheSync.py +186 -0
- edsl/data/SQLiteDict.py +292 -0
- edsl/data/__init__.py +5 -3
- edsl/data/orm.py +6 -33
- edsl/data_transfer_models.py +74 -27
- edsl/enums.py +165 -8
- edsl/exceptions/BaseException.py +21 -0
- edsl/exceptions/__init__.py +52 -46
- edsl/exceptions/agents.py +33 -15
- edsl/exceptions/cache.py +5 -0
- edsl/exceptions/coop.py +8 -0
- edsl/exceptions/general.py +34 -0
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/jobs.py +15 -0
- edsl/exceptions/language_models.py +46 -1
- edsl/exceptions/questions.py +80 -5
- edsl/exceptions/results.py +16 -5
- edsl/exceptions/scenarios.py +29 -0
- edsl/exceptions/surveys.py +13 -10
- edsl/inference_services/AnthropicService.py +106 -0
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +215 -0
- edsl/inference_services/AwsBedrock.py +118 -0
- edsl/inference_services/AzureAI.py +215 -0
- edsl/inference_services/DeepInfraService.py +18 -0
- edsl/inference_services/GoogleService.py +143 -0
- edsl/inference_services/GroqService.py +20 -0
- edsl/inference_services/InferenceServiceABC.py +80 -0
- edsl/inference_services/InferenceServicesCollection.py +138 -0
- edsl/inference_services/MistralAIService.py +120 -0
- edsl/inference_services/OllamaService.py +18 -0
- edsl/inference_services/OpenAIService.py +236 -0
- edsl/inference_services/PerplexityService.py +160 -0
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +90 -0
- edsl/inference_services/TogetherAIService.py +172 -0
- edsl/inference_services/data_structures.py +134 -0
- edsl/inference_services/models_available_cache.py +118 -0
- edsl/inference_services/rate_limits_cache.py +25 -0
- edsl/inference_services/registry.py +41 -0
- edsl/inference_services/write_available.py +10 -0
- edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
- edsl/jobs/Answers.py +21 -20
- edsl/jobs/FetchInvigilator.py +47 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +50 -0
- edsl/jobs/Jobs.py +684 -204
- edsl/jobs/JobsChecks.py +172 -0
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +270 -0
- edsl/jobs/JobsRemoteInferenceHandler.py +311 -0
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/buckets/BucketCollection.py +104 -0
- edsl/jobs/buckets/ModelBuckets.py +65 -0
- edsl/jobs/buckets/TokenBucket.py +283 -0
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +392 -0
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +186 -0
- edsl/jobs/interviews/InterviewStatistic.py +63 -0
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -0
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -0
- edsl/jobs/interviews/InterviewStatusLog.py +92 -0
- edsl/jobs/interviews/ReportErrors.py +66 -0
- edsl/jobs/interviews/interview_status_enum.py +9 -0
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +151 -110
- edsl/jobs/runners/JobsRunnerStatus.py +298 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +244 -0
- edsl/jobs/tasks/TaskCreators.py +64 -0
- edsl/jobs/tasks/TaskHistory.py +470 -0
- edsl/jobs/tasks/TaskStatusLog.py +23 -0
- edsl/jobs/tasks/task_status_enum.py +161 -0
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -0
- edsl/jobs/tokens/TokenUsage.py +34 -0
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +507 -386
- edsl/language_models/ModelList.py +164 -0
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -0
- edsl/language_models/__init__.py +1 -8
- edsl/language_models/fake_openai_call.py +15 -0
- edsl/language_models/fake_openai_service.py +61 -0
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +109 -41
- edsl/language_models/utilities.py +65 -0
- edsl/notebooks/Notebook.py +263 -0
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/notebooks/__init__.py +1 -0
- edsl/prompts/Prompt.py +222 -93
- edsl/prompts/__init__.py +1 -1
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/QuestionBase.py +518 -0
- edsl/questions/QuestionBasePromptsMixin.py +221 -0
- edsl/questions/QuestionBudget.py +164 -67
- edsl/questions/QuestionCheckBox.py +281 -62
- edsl/questions/QuestionDict.py +343 -0
- edsl/questions/QuestionExtract.py +136 -50
- edsl/questions/QuestionFreeText.py +79 -55
- edsl/questions/QuestionFunctional.py +138 -41
- edsl/questions/QuestionList.py +184 -57
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +293 -69
- edsl/questions/QuestionNumerical.py +109 -56
- edsl/questions/QuestionRank.py +244 -49
- edsl/questions/Quick.py +41 -0
- edsl/questions/SimpleAskMixin.py +74 -0
- edsl/questions/__init__.py +9 -6
- edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +153 -38
- edsl/questions/compose_questions.py +13 -7
- edsl/questions/data_structures.py +20 -0
- edsl/questions/decorators.py +21 -0
- edsl/questions/derived/QuestionLikertFive.py +28 -26
- edsl/questions/derived/QuestionLinearScale.py +41 -28
- edsl/questions/derived/QuestionTopK.py +34 -26
- edsl/questions/derived/QuestionYesNo.py +40 -27
- edsl/questions/descriptors.py +228 -74
- edsl/questions/loop_processor.py +149 -0
- edsl/questions/prompt_templates/question_budget.jinja +13 -0
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
- edsl/questions/prompt_templates/question_extract.jinja +11 -0
- edsl/questions/prompt_templates/question_free_text.jinja +3 -0
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
- edsl/questions/prompt_templates/question_list.jinja +17 -0
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
- edsl/questions/prompt_templates/question_numerical.jinja +37 -0
- edsl/questions/question_base_gen_mixin.py +168 -0
- edsl/questions/question_registry.py +130 -46
- edsl/questions/register_questions_meta.py +71 -0
- edsl/questions/response_validator_abc.py +188 -0
- edsl/questions/response_validator_factory.py +34 -0
- edsl/questions/settings.py +5 -2
- edsl/questions/templates/__init__.py +0 -0
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/checkbox/__init__.py +0 -0
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
- edsl/questions/templates/dict/__init__.py +0 -0
- edsl/questions/templates/dict/answering_instructions.jinja +21 -0
- edsl/questions/templates/dict/question_presentation.jinja +1 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/extract/answering_instructions.jinja +7 -0
- edsl/questions/templates/extract/question_presentation.jinja +1 -0
- edsl/questions/templates/free_text/__init__.py +0 -0
- edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
- edsl/questions/templates/free_text/question_presentation.jinja +1 -0
- edsl/questions/templates/likert_five/__init__.py +0 -0
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
- edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
- edsl/questions/templates/linear_scale/__init__.py +0 -0
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
- edsl/questions/templates/list/__init__.py +0 -0
- edsl/questions/templates/list/answering_instructions.jinja +4 -0
- edsl/questions/templates/list/question_presentation.jinja +5 -0
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/questions/templates/multiple_choice/__init__.py +0 -0
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
- edsl/questions/templates/multiple_choice/html.jinja +0 -0
- edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
- edsl/questions/templates/numerical/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +7 -0
- edsl/questions/templates/numerical/question_presentation.jinja +7 -0
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/rank/answering_instructions.jinja +11 -0
- edsl/questions/templates/rank/question_presentation.jinja +15 -0
- edsl/questions/templates/top_k/__init__.py +0 -0
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
- edsl/questions/templates/top_k/question_presentation.jinja +22 -0
- edsl/questions/templates/yes_no/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
- edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
- edsl/results/CSSParameterizer.py +108 -0
- edsl/results/Dataset.py +550 -19
- edsl/results/DatasetExportMixin.py +594 -0
- edsl/results/DatasetTree.py +295 -0
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +477 -173
- edsl/results/Results.py +987 -269
- edsl/results/ResultsExportMixin.py +28 -125
- edsl/results/ResultsGGMixin.py +83 -15
- edsl/results/TableDisplay.py +125 -0
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +1 -1
- edsl/results/file_exports.py +252 -0
- edsl/results/results_fetch_mixin.py +33 -0
- edsl/results/results_selector.py +145 -0
- edsl/results/results_tools_mixin.py +98 -0
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_display.css +78 -0
- edsl/results/table_renderers.py +118 -0
- edsl/results/tree_explore.py +115 -0
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +543 -0
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +431 -62
- edsl/scenarios/ScenarioHtmlMixin.py +65 -0
- edsl/scenarios/ScenarioList.py +1415 -45
- edsl/scenarios/ScenarioListExportMixin.py +45 -0
- edsl/scenarios/ScenarioListPdfMixin.py +239 -0
- edsl/scenarios/__init__.py +2 -0
- edsl/scenarios/directory_scanner.py +96 -0
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +49 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/scenarios/scenario_join.py +131 -0
- edsl/scenarios/scenario_selector.py +156 -0
- edsl/shared.py +1 -0
- edsl/study/ObjectEntry.py +173 -0
- edsl/study/ProofOfWork.py +113 -0
- edsl/study/SnapShot.py +80 -0
- edsl/study/Study.py +521 -0
- edsl/study/__init__.py +4 -0
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/DAG.py +92 -11
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/Memory.py +9 -4
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/MemoryPlan.py +156 -35
- edsl/surveys/Rule.py +221 -74
- edsl/surveys/RuleCollection.py +241 -61
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +1079 -339
- edsl/surveys/SurveyCSS.py +273 -0
- edsl/surveys/SurveyExportMixin.py +235 -40
- edsl/surveys/SurveyFlowVisualization.py +181 -0
- edsl/surveys/SurveyQualtricsImport.py +284 -0
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +4 -2
- edsl/surveys/base.py +19 -3
- edsl/surveys/descriptors.py +17 -6
- edsl/surveys/instructions/ChangeInstruction.py +48 -0
- edsl/surveys/instructions/Instruction.py +56 -0
- edsl/surveys/instructions/InstructionCollection.py +82 -0
- edsl/surveys/instructions/__init__.py +0 -0
- edsl/templates/error_reporting/base.html +24 -0
- edsl/templates/error_reporting/exceptions_by_model.html +35 -0
- edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
- edsl/templates/error_reporting/exceptions_by_type.html +17 -0
- edsl/templates/error_reporting/interview_details.html +116 -0
- edsl/templates/error_reporting/interviews.html +19 -0
- edsl/templates/error_reporting/overview.html +5 -0
- edsl/templates/error_reporting/performance_plot.html +2 -0
- edsl/templates/error_reporting/report.css +74 -0
- edsl/templates/error_reporting/report.html +118 -0
- edsl/templates/error_reporting/report.js +25 -0
- edsl/tools/__init__.py +1 -0
- edsl/tools/clusters.py +192 -0
- edsl/tools/embeddings.py +27 -0
- edsl/tools/embeddings_plotting.py +118 -0
- edsl/tools/plotting.py +112 -0
- edsl/tools/summarize.py +18 -0
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/SystemInfo.py +5 -0
- edsl/utilities/__init__.py +21 -20
- edsl/utilities/ast_utilities.py +3 -0
- edsl/utilities/data/Registry.py +2 -0
- edsl/utilities/decorators.py +41 -0
- edsl/utilities/gcp_bucket/__init__.py +0 -0
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
- edsl/utilities/interface.py +310 -60
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/naming_utilities.py +263 -0
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/repair_functions.py +28 -0
- edsl/utilities/restricted_python.py +70 -0
- edsl/utilities/utilities.py +203 -13
- edsl-0.1.40.dist-info/METADATA +111 -0
- edsl-0.1.40.dist-info/RECORD +362 -0
- {edsl-0.1.14.dist-info → edsl-0.1.40.dist-info}/WHEEL +1 -1
- edsl/agents/AgentListExportMixin.py +0 -24
- edsl/coop/old.py +0 -31
- edsl/data/Database.py +0 -141
- edsl/data/crud.py +0 -121
- edsl/jobs/Interview.py +0 -417
- edsl/jobs/JobsRunner.py +0 -63
- edsl/jobs/JobsRunnerStatusMixin.py +0 -115
- edsl/jobs/base.py +0 -47
- edsl/jobs/buckets.py +0 -166
- edsl/jobs/runners/JobsRunnerDryRun.py +0 -19
- edsl/jobs/runners/JobsRunnerStreaming.py +0 -54
- edsl/jobs/task_management.py +0 -218
- edsl/jobs/token_tracking.py +0 -78
- edsl/language_models/DeepInfra.py +0 -69
- edsl/language_models/OpenAI.py +0 -98
- edsl/language_models/model_interfaces/GeminiPro.py +0 -66
- edsl/language_models/model_interfaces/LanguageModelOpenAIFour.py +0 -8
- edsl/language_models/model_interfaces/LanguageModelOpenAIThreeFiveTurbo.py +0 -8
- edsl/language_models/model_interfaces/LlamaTwo13B.py +0 -21
- edsl/language_models/model_interfaces/LlamaTwo70B.py +0 -21
- edsl/language_models/model_interfaces/Mixtral8x7B.py +0 -24
- edsl/language_models/registry.py +0 -81
- edsl/language_models/schemas.py +0 -15
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/prompts/QuestionInstructionsBase.py +0 -6
- edsl/prompts/library/agent_instructions.py +0 -29
- edsl/prompts/library/agent_persona.py +0 -17
- edsl/prompts/library/question_budget.py +0 -26
- edsl/prompts/library/question_checkbox.py +0 -32
- edsl/prompts/library/question_extract.py +0 -19
- edsl/prompts/library/question_freetext.py +0 -14
- edsl/prompts/library/question_linear_scale.py +0 -20
- edsl/prompts/library/question_list.py +0 -22
- edsl/prompts/library/question_multiple_choice.py +0 -44
- edsl/prompts/library/question_numerical.py +0 -31
- edsl/prompts/library/question_rank.py +0 -21
- edsl/prompts/prompt_config.py +0 -33
- edsl/prompts/registry.py +0 -185
- edsl/questions/Question.py +0 -240
- edsl/report/InputOutputDataTypes.py +0 -134
- edsl/report/RegressionMixin.py +0 -28
- edsl/report/ReportOutputs.py +0 -1228
- edsl/report/ResultsFetchMixin.py +0 -106
- edsl/report/ResultsOutputMixin.py +0 -14
- edsl/report/demo.ipynb +0 -645
- edsl/results/ResultsDBMixin.py +0 -184
- edsl/surveys/SurveyFlowVisualizationMixin.py +0 -92
- edsl/trackers/Tracker.py +0 -91
- edsl/trackers/TrackerAPI.py +0 -196
- edsl/trackers/TrackerTasks.py +0 -70
- edsl/utilities/pastebin.py +0 -141
- edsl-0.1.14.dist-info/METADATA +0 -69
- edsl-0.1.14.dist-info/RECORD +0 -141
- /edsl/{language_models/model_interfaces → inference_services}/__init__.py +0 -0
- /edsl/{report/__init__.py → jobs/runners/JobsRunnerStatusData.py} +0 -0
- /edsl/{trackers/__init__.py → language_models/ServiceDataSources.py} +0 -0
- {edsl-0.1.14.dist-info → edsl-0.1.40.dist-info}/LICENSE +0 -0
edsl/surveys/MemoryPlan.py
CHANGED
@@ -1,16 +1,26 @@
|
|
1
|
+
"""A survey has a memory plan that specifies what the agent should remember when answering a question."""
|
2
|
+
|
1
3
|
from collections import UserDict, defaultdict
|
2
|
-
from
|
4
|
+
from typing import Optional
|
3
5
|
|
4
|
-
from edsl.
|
5
|
-
from edsl.
|
6
|
+
# from edsl.surveys.Memory import Memory
|
7
|
+
# from edsl.prompts.Prompt import Prompt
|
8
|
+
# from edsl.surveys.DAG import DAG
|
6
9
|
|
7
10
|
|
8
11
|
class MemoryPlan(UserDict):
|
9
12
|
"""A survey has a memory plan that specifies what the agent should remember when answering a question.
|
10
|
-
|
13
|
+
|
14
|
+
The basic structure of a memory plan is a dictionary of focal questions to memories.
|
15
|
+
|
16
|
+
{focal_question1: [prior_question1, prior_question2, ...], focal_question: [prior_question3]}
|
11
17
|
"""
|
12
18
|
|
13
|
-
def __init__(self, survey: "Survey" = None, data=None):
|
19
|
+
def __init__(self, survey: Optional["Survey"] = None, data: Optional[dict] = None):
|
20
|
+
"""Initialize a memory plan.
|
21
|
+
|
22
|
+
The actual 'data' attributes of the memory plan are a dictionary of focal questions to memories.
|
23
|
+
"""
|
14
24
|
if survey is not None:
|
15
25
|
self.survey = survey
|
16
26
|
self.survey_question_names = [q.question_name for q in survey.questions]
|
@@ -18,24 +28,42 @@ class MemoryPlan(UserDict):
|
|
18
28
|
super().__init__(data or {})
|
19
29
|
|
20
30
|
@property
|
21
|
-
def name_to_text(self):
|
22
|
-
"
|
31
|
+
def name_to_text(self) -> dict:
|
32
|
+
"""Return a dictionary mapping question names to question texts."""
|
23
33
|
return dict(zip(self.survey_question_names, self.question_texts))
|
24
34
|
|
25
|
-
def add_question(self, question):
|
35
|
+
def add_question(self, question: "QuestionBase") -> None:
|
36
|
+
"""Add a question to the survey.
|
37
|
+
|
38
|
+
:param question: A question to add to the survey
|
39
|
+
|
40
|
+
"""
|
26
41
|
self.survey_question_names.append(question.question_name)
|
27
42
|
self.question_texts.append(question.question_text)
|
28
43
|
|
29
|
-
def
|
30
|
-
"
|
44
|
+
def _check_valid_question_name(self, question_name: str) -> None:
|
45
|
+
"""Ensure a passed question name is valid.
|
46
|
+
|
47
|
+
:param question_name: The name of the question to check.
|
48
|
+
|
49
|
+
"""
|
31
50
|
if question_name not in self.survey_question_names:
|
32
51
|
raise ValueError(
|
33
52
|
f"{question_name} is not in the survey. Current names are {self.survey_question_names}"
|
34
53
|
)
|
35
54
|
|
36
|
-
def get_memory_prompt_fragment(
|
37
|
-
|
38
|
-
|
55
|
+
def get_memory_prompt_fragment(
|
56
|
+
self, focal_question: str, answers: dict
|
57
|
+
) -> "Prompt":
|
58
|
+
"""Generate the prompt fragment descripting that past question and answer.
|
59
|
+
|
60
|
+
:param focal_question: The current question being answered.
|
61
|
+
:param answers: A dictionary of question names to answers.
|
62
|
+
|
63
|
+
"""
|
64
|
+
from edsl.prompts.Prompt import Prompt
|
65
|
+
|
66
|
+
self._check_valid_question_name(focal_question)
|
39
67
|
|
40
68
|
if focal_question not in self:
|
41
69
|
return Prompt("")
|
@@ -50,7 +78,7 @@ class MemoryPlan(UserDict):
|
|
50
78
|
"""
|
51
79
|
|
52
80
|
def gen_line(question_text, answer):
|
53
|
-
"
|
81
|
+
"""Return a line of memory."""
|
54
82
|
return f"\tQuestion: {question_text}\n\tAnswer: {answer}\n"
|
55
83
|
|
56
84
|
lines = [gen_line(*pair) for pair in q_and_a_pairs]
|
@@ -61,16 +89,41 @@ class MemoryPlan(UserDict):
|
|
61
89
|
else:
|
62
90
|
return Prompt("")
|
63
91
|
|
64
|
-
def
|
92
|
+
def _check_order(self, focal_question: str, prior_question: str) -> None:
|
93
|
+
"""Ensure the prior question comes before the focal question."""
|
65
94
|
focal_index = self.survey_question_names.index(focal_question)
|
66
95
|
prior_index = self.survey_question_names.index(prior_question)
|
67
96
|
if focal_index <= prior_index:
|
68
97
|
raise ValueError(f"{prior_question} must come before {focal_question}.")
|
69
98
|
|
70
|
-
def add_single_memory(self, focal_question: str, prior_question: str):
|
71
|
-
|
72
|
-
|
73
|
-
|
99
|
+
def add_single_memory(self, focal_question: str, prior_question: str) -> None:
|
100
|
+
"""Add a single memory to the memory plan.
|
101
|
+
|
102
|
+
:param focal_question: The current question being answered.
|
103
|
+
:param prior_question: The question that was answered before the focal question that should be remembered.
|
104
|
+
|
105
|
+
>>> mp = MemoryPlan.example()
|
106
|
+
>>> mp.add_single_memory("q0", "q1")
|
107
|
+
Traceback (most recent call last):
|
108
|
+
...
|
109
|
+
ValueError: q1 must come before q0.
|
110
|
+
|
111
|
+
>>> mp = MemoryPlan.example()
|
112
|
+
>>> mp.add_single_memory("q0", "crap")
|
113
|
+
Traceback (most recent call last):
|
114
|
+
...
|
115
|
+
ValueError: crap is not in the survey. Current names are ['q0', 'q1', 'q2']
|
116
|
+
|
117
|
+
>>> mp = MemoryPlan.example()
|
118
|
+
>>> mp.add_single_memory("crap", "q0")
|
119
|
+
Traceback (most recent call last):
|
120
|
+
...
|
121
|
+
ValueError: crap is not in the survey. Current names are ['q0', 'q1', 'q2']
|
122
|
+
"""
|
123
|
+
self._check_valid_question_name(focal_question)
|
124
|
+
self._check_valid_question_name(prior_question)
|
125
|
+
self._check_order(focal_question, prior_question)
|
126
|
+
from edsl.surveys.Memory import Memory
|
74
127
|
|
75
128
|
if focal_question not in self:
|
76
129
|
memory = Memory()
|
@@ -79,27 +132,53 @@ class MemoryPlan(UserDict):
|
|
79
132
|
else:
|
80
133
|
self[focal_question].add_prior_question(prior_question)
|
81
134
|
|
82
|
-
def add_memory_collection(
|
135
|
+
def add_memory_collection(
|
136
|
+
self, focal_question: str, prior_questions: list[str]
|
137
|
+
) -> None:
|
138
|
+
"""Add a collection of prior questions to the memory plan.
|
139
|
+
|
140
|
+
:param focal_question: The current question being answered.
|
141
|
+
:param prior_questions: A list of questions that were answered before the focal question that should be remembered.
|
142
|
+
"""
|
83
143
|
for question in prior_questions:
|
84
144
|
self.add_single_memory(focal_question, question)
|
85
145
|
|
86
|
-
def to_dict(self):
|
146
|
+
def to_dict(self, add_edsl_version=True) -> dict:
|
147
|
+
"""Serialize the memory plan to a dictionary.
|
148
|
+
|
149
|
+
>>> mp = MemoryPlan.example()
|
150
|
+
>>> mp.to_dict()
|
151
|
+
{'survey_question_names': ['q0', 'q1', 'q2'], 'survey_question_texts': ['Do you like school?', 'Why not?', 'Why?'], 'data': {'q1': {'prior_questions': ['q0']}}}
|
152
|
+
"""
|
153
|
+
newdata = {}
|
154
|
+
for question_name, memory in self.items():
|
155
|
+
newdata[question_name] = memory.to_dict()
|
156
|
+
|
87
157
|
return {
|
88
158
|
"survey_question_names": self.survey_question_names,
|
89
159
|
"survey_question_texts": self.question_texts,
|
90
|
-
"data":
|
160
|
+
"data": newdata,
|
91
161
|
}
|
92
162
|
|
93
163
|
@classmethod
|
94
|
-
def from_dict(cls, data):
|
95
|
-
|
96
|
-
|
164
|
+
def from_dict(cls, data) -> "MemoryPlan":
|
165
|
+
"""Deserialize a memory plan from a dictionary."""
|
166
|
+
from edsl.surveys.Memory import Memory
|
167
|
+
|
168
|
+
newdata = {}
|
169
|
+
for question_name, memory in data["data"].items():
|
170
|
+
newdata[question_name] = Memory.from_dict(memory)
|
171
|
+
|
172
|
+
memory_plan = cls(survey=None, data=newdata)
|
97
173
|
memory_plan.survey_question_names = data["survey_question_names"]
|
98
174
|
memory_plan.question_texts = data["survey_question_texts"]
|
99
|
-
# memory_plan.data = data
|
100
175
|
return memory_plan
|
101
176
|
|
102
|
-
def _indexify(self, d):
|
177
|
+
def _indexify(self, d: dict):
|
178
|
+
"""Convert a dictionary of question names to a dictionary of question indices.
|
179
|
+
|
180
|
+
:param d: A dictionary of question names to indices.
|
181
|
+
"""
|
103
182
|
new_d = {}
|
104
183
|
for k, v in d.items():
|
105
184
|
key = self.survey_question_names.index(k)
|
@@ -108,16 +187,58 @@ class MemoryPlan(UserDict):
|
|
108
187
|
return new_d
|
109
188
|
|
110
189
|
@property
|
111
|
-
def dag(self):
|
190
|
+
def dag(self) -> "DAG":
|
191
|
+
"""Return a directed acyclic graph of the memory plan.
|
192
|
+
|
193
|
+
>>> mp = MemoryPlan.example()
|
194
|
+
>>> mp.dag
|
195
|
+
{1: {0}}
|
196
|
+
"""
|
197
|
+
from edsl.surveys.DAG import DAG
|
198
|
+
|
112
199
|
d = defaultdict(set)
|
113
|
-
"Returns a directed acyclic graph of the memory plan"
|
114
200
|
for focal_question, memory in self.items():
|
115
201
|
for prior_question in memory:
|
116
202
|
d[focal_question].add(prior_question)
|
117
|
-
# if focal_question not in d:
|
118
|
-
# d[focal_question] = set({prior_question})
|
119
|
-
# else:
|
120
|
-
# d[focal_question].add(prior_question)
|
121
|
-
# focal_index = self.survey_question_names.index(focal_question)
|
122
|
-
# prior_index = self.survey_question_names.index(prior_question)
|
123
203
|
return DAG(self._indexify(d))
|
204
|
+
|
205
|
+
@classmethod
|
206
|
+
def example(cls):
|
207
|
+
"""Return an example memory plan."""
|
208
|
+
from edsl import Survey
|
209
|
+
|
210
|
+
mp = cls(survey=Survey.example())
|
211
|
+
mp.add_single_memory("q1", "q0")
|
212
|
+
return mp
|
213
|
+
|
214
|
+
def remove_question(self, question_name: str) -> None:
|
215
|
+
"""Remove a question from the memory plan.
|
216
|
+
|
217
|
+
:param question_name: The name of the question to remove.
|
218
|
+
"""
|
219
|
+
self._check_valid_question_name(question_name)
|
220
|
+
|
221
|
+
# Remove the question from survey_question_names and question_texts
|
222
|
+
index = self.survey_question_names.index(question_name)
|
223
|
+
self.survey_question_names.pop(index)
|
224
|
+
self.question_texts.pop(index)
|
225
|
+
|
226
|
+
# Remove the question from the memory plan if it's a focal question
|
227
|
+
self.pop(question_name, None)
|
228
|
+
|
229
|
+
# Remove the question from all memories where it appears as a prior question
|
230
|
+
for focal_question, memory in self.items():
|
231
|
+
memory.remove_prior_question(question_name)
|
232
|
+
|
233
|
+
# Update the DAG
|
234
|
+
self.dag.remove_node(index)
|
235
|
+
|
236
|
+
def remove_prior_question(self, question_name: str) -> None:
|
237
|
+
"""Remove a prior question from the memory."""
|
238
|
+
self.prior_questions = [q for q in self.prior_questions if q != question_name]
|
239
|
+
|
240
|
+
|
241
|
+
if __name__ == "__main__":
|
242
|
+
import doctest
|
243
|
+
|
244
|
+
doctest.testmod(optionflags=doctest.ELLIPSIS)
|
edsl/surveys/Rule.py
CHANGED
@@ -1,9 +1,33 @@
|
|
1
|
+
"""The Rule class defines a rule for determining the next question presented to an agent.
|
2
|
+
|
3
|
+
The key component is an expression specifiying the logic of the rule, which can include any combination of logical operators ('and', 'or', 'not'), e.g.:
|
4
|
+
|
5
|
+
.. code-block:: python
|
6
|
+
|
7
|
+
"q1 == 'yes' or q2 == 'no'"
|
8
|
+
|
9
|
+
The expression must be about questions "before" the current question.
|
10
|
+
|
11
|
+
Only one rule should apply at each priority level.
|
12
|
+
If multiple rules apply, the one with the highest priority is used.
|
13
|
+
If there are conflicting rules, an exception is raised.
|
14
|
+
|
15
|
+
If no rule is specified, the next question is given as the default.
|
16
|
+
When a question is added with index, it is always given a rule the next question is index + 1, but
|
17
|
+
with a low (-1) priority.
|
18
|
+
"""
|
19
|
+
|
1
20
|
import ast
|
2
|
-
|
3
|
-
from rich import print
|
4
|
-
from simpleeval import EvalWithCompoundTypes
|
21
|
+
import random
|
5
22
|
from typing import Any, Union, List
|
6
|
-
|
23
|
+
|
24
|
+
|
25
|
+
# from rich import print
|
26
|
+
from simpleeval import EvalWithCompoundTypes
|
27
|
+
|
28
|
+
from edsl.exceptions.surveys import SurveyError
|
29
|
+
|
30
|
+
from edsl.exceptions.surveys import (
|
7
31
|
SurveyRuleCannotEvaluateError,
|
8
32
|
SurveyRuleCollectionHasNoRulesAtNodeError,
|
9
33
|
SurveyRuleRefersToFutureStateError,
|
@@ -13,48 +37,39 @@ from edsl.exceptions import (
|
|
13
37
|
)
|
14
38
|
from edsl.surveys.base import EndOfSurvey
|
15
39
|
from edsl.utilities.ast_utilities import extract_variable_names
|
16
|
-
from edsl.utilities.
|
40
|
+
from edsl.utilities.remove_edsl_version import remove_edsl_version
|
17
41
|
|
18
42
|
|
19
|
-
class
|
20
|
-
|
21
|
-
|
43
|
+
class QuestionIndex:
|
44
|
+
def __set_name__(self, owner, name):
|
45
|
+
self.name = f"_{name}"
|
22
46
|
|
23
|
-
|
24
|
-
|
25
|
-
- expression: A string that if evaluates to true, then next_q (an index) is next
|
26
|
-
- next_q: The question is true
|
27
|
-
- priority: an integer that determines which rule is applied if multiple rules apply
|
47
|
+
def __get__(self, obj, objtype=None):
|
48
|
+
return getattr(obj, self.name)
|
28
49
|
|
29
|
-
|
30
|
-
|
50
|
+
def __set__(self, obj, value):
|
51
|
+
if not isinstance(value, (int, EndOfSurvey.__class__)):
|
52
|
+
raise SurveyError(f"{self.name} must be an integer or EndOfSurvey")
|
53
|
+
if self.name == "_next_q" and isinstance(value, int):
|
54
|
+
current_q = getattr(obj, "_current_q")
|
55
|
+
if value <= current_q:
|
56
|
+
raise SurveyError("next_q must be greater than current_q")
|
57
|
+
setattr(obj, self.name, value)
|
31
58
|
|
32
|
-
"q1 == 'yes' or q2 == 'no'"
|
33
59
|
|
34
|
-
|
35
|
-
|
36
|
-
Eventually, we'll use the AST of the expression to make it safer.
|
37
|
-
|
38
|
-
If multiple rules apply, the one with the highest priority is used.
|
39
|
-
This is to deal with the fact that when we create a survey, we give the
|
40
|
-
next question as the default.
|
41
|
-
So when a question is added with index, it is always given a rule the next question is index + 1, but
|
42
|
-
given a low (-1) priority.
|
60
|
+
class Rule:
|
61
|
+
"""The Rule class defines a "rule" for determining the next question presented to an agent."""
|
43
62
|
|
44
|
-
|
45
|
-
|
46
|
-
Ideally, we'd have a way to resolve this ex ante, perhaps to traversing the implied tree
|
47
|
-
each time a rule is added, but for now, we'll let the error emerge at run-time.
|
63
|
+
current_q = QuestionIndex()
|
64
|
+
next_q = QuestionIndex()
|
48
65
|
|
49
|
-
|
50
|
-
We could potentially use the question pydantic models to check for rule conflicts, as
|
51
|
-
they define the potential trees through a survey.
|
66
|
+
# Not implemented but nice to have:
|
67
|
+
# We could potentially use the question pydantic models to check for rule conflicts, as
|
68
|
+
# they define the potential trees through a survey.
|
52
69
|
|
53
|
-
We could also use the AST to check for conflicts by inspecting the types of a rule.
|
54
|
-
For example, if we know the answer to a question is a string, we could check that
|
55
|
-
the expression only contains string comparisons.
|
56
|
-
This would be a lot of work.
|
57
|
-
"""
|
70
|
+
# We could also use the AST to check for conflicts by inspecting the types of a rule.
|
71
|
+
# For example, if we know the answer to a question is a string, we could check that
|
72
|
+
# the expression only contains string comparisons.
|
58
73
|
|
59
74
|
def __init__(
|
60
75
|
self,
|
@@ -63,24 +78,41 @@ class Rule:
|
|
63
78
|
next_q: Union[int, EndOfSurvey.__class__],
|
64
79
|
question_name_to_index: dict[str, int],
|
65
80
|
priority: int,
|
81
|
+
before_rule: bool = False,
|
66
82
|
):
|
67
|
-
"""
|
83
|
+
"""Represent a rule for determining the next question presented to an agent.
|
68
84
|
|
85
|
+
Questions are represented by int indices.
|
86
|
+
|
87
|
+
:param current_q: The question at which the rule is potentially applied.
|
88
|
+
:param expression: A string that evaluates to true or false. If true, then next_q is next.
|
89
|
+
:param next_q: The next question if the expression is true.
|
90
|
+
:param question_name_to_index: A dictionary mapping question names to indices.
|
91
|
+
:param priority: An integer that determines which rule is applied, if multiple rules apply.
|
92
|
+
"""
|
69
93
|
self.current_q = current_q
|
70
94
|
self.expression = expression
|
71
95
|
self.next_q = next_q
|
72
|
-
self.priority = priority
|
73
96
|
self.question_name_to_index = question_name_to_index
|
97
|
+
self.priority = priority
|
98
|
+
self.before_rule = before_rule
|
74
99
|
|
75
|
-
if not next_q == EndOfSurvey
|
76
|
-
|
100
|
+
if not self.next_q == EndOfSurvey:
|
101
|
+
if self.next_q <= self.current_q:
|
102
|
+
raise SurveyRuleSendsYouBackwardsError
|
77
103
|
|
78
|
-
|
79
|
-
|
104
|
+
if not self.next_q == EndOfSurvey and self.current_q > self.next_q:
|
105
|
+
raise SurveyRuleSendsYouBackwardsError(
|
106
|
+
f"current_q: {self.current_q}, next_q: {self.next_q}"
|
107
|
+
)
|
108
|
+
|
109
|
+
# get the AST for the expression - used to extract the variables referenced in the expression
|
80
110
|
try:
|
81
111
|
self.ast_tree = ast.parse(self.expression)
|
82
112
|
except SyntaxError:
|
83
|
-
raise SurveyRuleSkipLogicSyntaxError
|
113
|
+
raise SurveyRuleSkipLogicSyntaxError(
|
114
|
+
f"The expression {self.expression} is not valid Python syntax."
|
115
|
+
)
|
84
116
|
|
85
117
|
# get the names of the variables in the expression
|
86
118
|
# e.g., q1 == 'yes' -> ['q1']
|
@@ -90,91 +122,206 @@ class Rule:
|
|
90
122
|
try:
|
91
123
|
assert all([q in question_name_to_index for q in extracted_question_names])
|
92
124
|
except AssertionError:
|
93
|
-
|
125
|
+
pass
|
126
|
+
# import warnings
|
127
|
+
# warnings.warn(f"There is an extracted field in the expression that is not a known question. It could be a scenario variable. That's fine! But it also could be a typo or mistake.")
|
128
|
+
# print(f"Question name to index: {question_name_to_index}")
|
129
|
+
# print(f"Extracted question names: {extracted_question_names}")
|
130
|
+
# raise SurveyRuleReferenceInRuleToUnknownQuestionError
|
94
131
|
|
95
132
|
# get the indices of the questions mentioned in the expression
|
96
133
|
self.named_questions_by_index = [
|
97
|
-
question_name_to_index[q]
|
134
|
+
question_name_to_index[q]
|
135
|
+
for q in extracted_question_names
|
136
|
+
if q in question_name_to_index
|
98
137
|
]
|
99
138
|
|
100
139
|
# A rule should only refer to questions that have already been asked.
|
101
140
|
# so the named questions in the expression should not be higher than the current question
|
102
141
|
if self.named_questions_by_index:
|
103
142
|
if max(self.named_questions_by_index) > self.current_q:
|
143
|
+
print(
|
144
|
+
"A rule refers to a future question, the answer to which would not be available here."
|
145
|
+
)
|
104
146
|
raise SurveyRuleRefersToFutureStateError
|
105
147
|
|
106
|
-
def
|
148
|
+
def _checks(self):
|
149
|
+
pass
|
150
|
+
|
151
|
+
def to_dict(self, add_edsl_version=True):
|
152
|
+
"""Convert the rule to a dictionary for serialization.
|
153
|
+
|
154
|
+
>>> r = Rule.example()
|
155
|
+
>>> r.to_dict()
|
156
|
+
{'current_q': 1, 'expression': "q1 == 'yes'", 'next_q': 2, 'priority': 0, 'question_name_to_index': {'q1': 1}, 'before_rule': False}
|
157
|
+
"""
|
107
158
|
return {
|
108
159
|
"current_q": self.current_q,
|
109
160
|
"expression": self.expression,
|
110
161
|
"next_q": "EndOfSurvey" if self.next_q == EndOfSurvey else self.next_q,
|
111
162
|
"priority": self.priority,
|
112
163
|
"question_name_to_index": self.question_name_to_index,
|
164
|
+
"before_rule": self.before_rule,
|
113
165
|
}
|
114
166
|
|
115
167
|
@classmethod
|
168
|
+
@remove_edsl_version
|
116
169
|
def from_dict(self, rule_dict):
|
170
|
+
"""Create a rule from a dictionary."""
|
117
171
|
if rule_dict["next_q"] == "EndOfSurvey":
|
118
172
|
rule_dict["next_q"] = EndOfSurvey
|
119
173
|
|
174
|
+
if "before_rule" not in rule_dict:
|
175
|
+
rule_dict["before_rule"] = False
|
176
|
+
|
120
177
|
return Rule(
|
121
178
|
current_q=rule_dict["current_q"],
|
122
179
|
expression=rule_dict["expression"],
|
123
180
|
next_q=rule_dict["next_q"],
|
124
181
|
priority=rule_dict["priority"],
|
125
182
|
question_name_to_index=rule_dict["question_name_to_index"],
|
183
|
+
before_rule=rule_dict["before_rule"],
|
126
184
|
)
|
127
185
|
|
128
186
|
def __repr__(self):
|
129
|
-
|
187
|
+
"""Pretty-print the rule."""
|
188
|
+
return f'Rule(current_q={self.current_q}, expression="{self.expression}", next_q={self.next_q}, priority={self.priority}, question_name_to_index={self.question_name_to_index}, before_rule={self.before_rule})'
|
130
189
|
|
131
190
|
def __str__(self):
|
191
|
+
"""Return a string representation of the rule."""
|
132
192
|
return self.__repr__()
|
133
193
|
|
134
194
|
@property
|
135
195
|
def question_index_to_name(self):
|
136
|
-
"""
|
196
|
+
"""Reverse the dictionary so we can look up question names by index.
|
197
|
+
|
198
|
+
>>> r = Rule.example()
|
199
|
+
>>> r.question_index_to_name
|
200
|
+
{1: 'q1'}
|
201
|
+
|
202
|
+
"""
|
137
203
|
return {v: k for k, v in self.question_name_to_index.items()}
|
138
204
|
|
139
205
|
def show_ast_tree(self):
|
140
|
-
"""Pretty-
|
206
|
+
"""Pretty-print the AST tree to the terminal.
|
207
|
+
|
208
|
+
>>> r = Rule.example()
|
209
|
+
>>> r.show_ast_tree()
|
210
|
+
Module(...)
|
211
|
+
"""
|
141
212
|
print(
|
142
213
|
ast.dump(
|
143
214
|
self.ast_tree, annotate_fields=True, indent=4, include_attributes=True
|
144
215
|
)
|
145
216
|
)
|
146
217
|
|
147
|
-
|
148
|
-
|
149
|
-
|
218
|
+
@staticmethod
|
219
|
+
def _prepare_replacement(current_info_env: dict[int, Any]):
|
220
|
+
d = {}
|
221
|
+
for var, value in current_info_env.items():
|
222
|
+
if isinstance(value, str):
|
223
|
+
replacement = f"'{value}'"
|
224
|
+
else:
|
225
|
+
replacement = str(value)
|
226
|
+
d[var] = replacement
|
227
|
+
return d
|
228
|
+
|
229
|
+
def evaluate(self, current_info_env: dict[int, Any]):
|
230
|
+
"""Compute the value of the expression, given a dictionary of known questions answers.
|
231
|
+
|
232
|
+
:param current_info_env: A dictionary mapping question, scenario, and agent names to their values.
|
233
|
+
|
150
234
|
If the expression cannot be evaluated, it raises a SurveyRuleCannotEvaluateError exception.
|
235
|
+
|
236
|
+
>>> r = Rule.example()
|
237
|
+
>>> r.evaluate({'q1' : 'yes'})
|
238
|
+
True
|
239
|
+
>>> r.evaluate({'q1' : 'no'})
|
240
|
+
False
|
241
|
+
|
242
|
+
>>> r = Rule.example(jinja2=True)
|
243
|
+
>>> r.evaluate({'q1' : 'yes'})
|
244
|
+
True
|
245
|
+
|
246
|
+
>>> r = Rule.example(jinja2=True)
|
247
|
+
>>> r.evaluate({'q1' : 'This is q1'})
|
248
|
+
False
|
249
|
+
|
250
|
+
>>> r = Rule.example(jinja2=False, bad = True)
|
251
|
+
>>> r.evaluate({'q1' : 'yes'})
|
252
|
+
Traceback (most recent call last):
|
253
|
+
...
|
254
|
+
edsl.exceptions.surveys.SurveyRuleCannotEvaluateError...
|
151
255
|
"""
|
256
|
+
from jinja2 import Template
|
257
|
+
|
258
|
+
def substitute_in_answers(expression, current_info_env):
|
259
|
+
"""Take the dictionary of answers and substitute them into the expression."""
|
152
260
|
|
153
|
-
|
154
|
-
"Take the dictionary of answers and substitute them into the expression"
|
155
|
-
for var, value in answers.items():
|
156
|
-
# If it's a string, add quotes; otherwise, just convert to string
|
157
|
-
if isinstance(value, str):
|
158
|
-
replacement = f"'{value}'"
|
159
|
-
else:
|
160
|
-
replacement = str(value)
|
261
|
+
current_info = self._prepare_replacement(current_info_env)
|
161
262
|
|
162
|
-
|
163
|
-
|
263
|
+
if "{{" in expression and "}}" in expression:
|
264
|
+
template_expression = Template(self.expression)
|
265
|
+
to_evaluate = template_expression.render(current_info)
|
266
|
+
else:
|
267
|
+
# import warnings
|
268
|
+
# import textwrap
|
269
|
+
# warnings.warn(textwrap.dedent("""\
|
270
|
+
# The expression is not a Jinja2 template with {{ }}. This is not recommended.
|
271
|
+
# You can re-write your expression say "q1 == 'yes'" as "{{ q1 }} == 'yes'".
|
272
|
+
# """))
|
273
|
+
to_evaluate = expression
|
274
|
+
for var, value in current_info.items():
|
275
|
+
to_evaluate = to_evaluate.replace(var, value)
|
276
|
+
|
277
|
+
return to_evaluate
|
164
278
|
|
165
279
|
try:
|
166
|
-
to_evaluate = substitute_in_answers(self.expression,
|
167
|
-
return EvalWithCompoundTypes().eval(to_evaluate)
|
280
|
+
to_evaluate = substitute_in_answers(self.expression, current_info_env)
|
168
281
|
except Exception as e:
|
169
|
-
|
170
|
-
raise SurveyRuleCannotEvaluateError
|
282
|
+
msg = f"""Exception in evaluation: {e}. The expression is: {self.expression}. The current info env trying to substitute in is: {current_info_env}. After the substition, the expression was: {to_evaluate}."""
|
283
|
+
raise SurveyRuleCannotEvaluateError(msg)
|
284
|
+
|
285
|
+
random_functions = {
|
286
|
+
"randint": random.randint,
|
287
|
+
"choice": random.choice,
|
288
|
+
"random": random.random,
|
289
|
+
"uniform": random.uniform,
|
290
|
+
# Add any other random functions you want to allow
|
291
|
+
}
|
292
|
+
|
293
|
+
try:
|
294
|
+
return EvalWithCompoundTypes(functions=random_functions).eval(to_evaluate)
|
295
|
+
except Exception as e:
|
296
|
+
msg = f"""Exception in evaluation: {e}. The expression is: {self.expression}. The current info env trying to substitute in is: {current_info_env}. After the substition, the expression was: {to_evaluate}."""
|
297
|
+
raise SurveyRuleCannotEvaluateError(msg)
|
298
|
+
|
299
|
+
@classmethod
|
300
|
+
def example(cls, jinja2=False, bad=False):
|
301
|
+
if jinja2:
|
302
|
+
# a rule written in jinja2 style with {{ }}
|
303
|
+
expression = "{{ q1 }} == 'yes'"
|
304
|
+
else:
|
305
|
+
expression = "q1 == 'yes'"
|
306
|
+
|
307
|
+
if bad and jinja2:
|
308
|
+
# a rule written in jinja2 style with {{ }} but with a 'bad' expression
|
309
|
+
expression = "{{ q1 }} == 'This is q1'"
|
310
|
+
|
311
|
+
if bad and not jinja2:
|
312
|
+
expression = "q1 == 'This is q1'"
|
313
|
+
|
314
|
+
r = Rule(
|
315
|
+
current_q=1,
|
316
|
+
expression=expression,
|
317
|
+
next_q=2,
|
318
|
+
question_name_to_index={"q1": 1},
|
319
|
+
priority=0,
|
320
|
+
)
|
321
|
+
return r
|
171
322
|
|
172
323
|
|
173
324
|
if __name__ == "__main__":
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
next_q=2,
|
178
|
-
question_name_to_index={"q1": 1},
|
179
|
-
priority=0,
|
180
|
-
)
|
325
|
+
import doctest
|
326
|
+
|
327
|
+
doctest.testmod(optionflags=doctest.ELLIPSIS)
|