rasa-pro 3.12.22__py3-none-any.whl → 3.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +3 -4
- rasa/api.py +1 -1
- rasa/cli/dialogue_understanding_test.py +1 -1
- rasa/cli/e2e_test.py +1 -8
- rasa/cli/evaluate.py +2 -2
- rasa/cli/export.py +5 -3
- rasa/cli/inspect.py +7 -0
- rasa/cli/llm_fine_tuning.py +1 -1
- rasa/cli/project_templates/default/config.yml +5 -32
- rasa/cli/project_templates/{calm → default}/e2e_tests/cancelations/user_cancels_during_a_correction.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/corrections/user_corrects_contact_handle.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/corrections/user_corrects_contact_name.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/happy_paths/user_lists_contacts.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/happy_paths/user_removes_contact.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/happy_paths/user_removes_contact_from_list.yml +1 -1
- rasa/cli/project_templates/default/endpoints.yml +18 -2
- rasa/cli/project_templates/defaults.py +133 -0
- rasa/cli/project_templates/tutorial/config.yml +1 -1
- rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
- rasa/cli/run.py +1 -1
- rasa/cli/scaffold.py +2 -3
- rasa/cli/shell.py +6 -1
- rasa/cli/studio/download.py +0 -22
- rasa/cli/studio/link.py +36 -0
- rasa/cli/studio/pull.py +79 -0
- rasa/cli/studio/push.py +78 -0
- rasa/cli/studio/studio.py +12 -0
- rasa/cli/studio/train.py +1 -5
- rasa/cli/studio/upload.py +6 -4
- rasa/cli/train.py +5 -1
- rasa/cli/utils.py +1 -1
- rasa/cli/x.py +1 -1
- rasa/constants.py +2 -0
- rasa/core/__init__.py +0 -16
- rasa/core/actions/action.py +43 -29
- rasa/core/actions/action_repeat_bot_messages.py +18 -22
- rasa/core/actions/action_run_slot_rejections.py +1 -2
- rasa/core/agent.py +24 -3
- rasa/core/available_endpoints.py +146 -0
- rasa/core/brokers/kafka.py +4 -0
- rasa/core/brokers/pika.py +5 -2
- rasa/core/brokers/sql.py +1 -1
- rasa/core/channels/__init__.py +3 -0
- rasa/core/channels/botframework.py +2 -2
- rasa/core/channels/channel.py +2 -2
- rasa/core/channels/development_inspector.py +1 -1
- rasa/core/channels/facebook.py +1 -4
- rasa/core/channels/hangouts.py +8 -5
- rasa/core/channels/inspector/.eslintrc.cjs +12 -6
- rasa/core/channels/inspector/.prettierrc +5 -0
- rasa/core/channels/inspector/README.md +11 -5
- rasa/core/channels/inspector/dist/assets/{arc-9f75cc3b.js → arc-371401b1.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{blockDiagram-38ab4fdb-7f34db23.js → blockDiagram-38ab4fdb-3f126156.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-3d4e48cf-948bab2c.js → c4Diagram-3d4e48cf-12f22eb7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/channel-f1efda17.js +1 -0
- rasa/core/channels/inspector/dist/assets/{classDiagram-70f12bd4-53b0dd0e.js → classDiagram-70f12bd4-03b1d386.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-f2320105-fdf789e7.js → classDiagram-v2-f2320105-84f69d63.js} +1 -1
- rasa/core/channels/inspector/dist/assets/clone-fdf164e2.js +1 -0
- rasa/core/channels/inspector/dist/assets/{createText-2e5e7dd3-87c4ece5.js → createText-2e5e7dd3-ca47fd38.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-e0da2a9e-5a8b0749.js → edges-e0da2a9e-f837ca8a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9861fffd-66da90e2.js → erDiagram-9861fffd-8717ac54.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-956e92f1-10044f05.js → flowDb-956e92f1-94f38b83.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-66a62f08-f338f66a.js → flowDiagram-66a62f08-b616f9fb.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-7d7a1629.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-4a651766-b13140aa.js → flowchart-elk-definition-4a651766-f5d24bb8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-c361ad54-f2b4a55a.js → ganttDiagram-c361ad54-b43ba8d9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-72cf32ee-dedc298d.js → gitGraphDiagram-72cf32ee-c3aafaa5.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{graph-4ede11ff.js → graph-0d0a2c10.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-3862675e-65549d37.js → index-3862675e-58ea0305.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-3a23e736.js → index-cce6f8a1.js} +123 -123
- rasa/core/channels/inspector/dist/assets/{infoDiagram-f8f76790-65439671.js → infoDiagram-f8f76790-b8f60461.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-49397b02-56d03d98.js → journeyDiagram-49397b02-95be5545.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-dd48f7f4.js → layout-da885b9b.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-1569ad2c.js → line-f1c817d3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-48bf4935.js → linear-d42801e6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-fc14e90a-688504c1.js → mindmap-definition-fc14e90a-a38923a6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-8a3498a8-78b6d7e6.js → pieDiagram-8a3498a8-ca6e71e9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-120e2f19-048b84b3.js → quadrantDiagram-120e2f19-b290dae9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-deff3bca-dd67f107.js → requirementDiagram-deff3bca-03f02ceb.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-04a897e0-8128436e.js → sankeyDiagram-04a897e0-c49eee40.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-704730f1-1a0d1461.js → sequenceDiagram-704730f1-b2cd6a3d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-587899a1-46d388ed.js → stateDiagram-587899a1-e53a2028.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-d93cdb3a-ea42951a.js → stateDiagram-v2-d93cdb3a-e1982a03.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-6aaf32cf-7427ed0c.js → styles-6aaf32cf-d0226ca5.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9a916d00-ff5e5a16.js → styles-9a916d00-0e21dc00.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-c10674c1-7b3680cf.js → styles-c10674c1-9588494e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-08f97a94-f860f2ad.js → svgDrawCommon-08f97a94-be478d4f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-85554ec2-2eebf0c8.js → timeline-definition-85554ec2-74631749.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-e933f94c-5d7f4e96.js → xychartDiagram-e933f94c-a043552f.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/package.json +3 -1
- rasa/core/channels/inspector/src/App.tsx +91 -90
- rasa/core/channels/inspector/src/components/Chat.tsx +45 -41
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +40 -40
- rasa/core/channels/inspector/src/components/DialogueInformation.tsx +57 -57
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +36 -27
- rasa/core/channels/inspector/src/components/ExpandIcon.tsx +4 -4
- rasa/core/channels/inspector/src/components/FullscreenButton.tsx +7 -7
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +28 -12
- rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +9 -9
- rasa/core/channels/inspector/src/components/RasaLogo.tsx +5 -5
- rasa/core/channels/inspector/src/components/RecruitmentPanel.tsx +55 -60
- rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +5 -5
- rasa/core/channels/inspector/src/components/Slots.tsx +22 -22
- rasa/core/channels/inspector/src/components/Welcome.tsx +28 -31
- rasa/core/channels/inspector/src/helpers/audio/audiostream.ts +245 -0
- rasa/core/channels/inspector/src/helpers/audio/microphone-processor.js +12 -0
- rasa/core/channels/inspector/src/helpers/audio/playback-processor.js +36 -0
- rasa/core/channels/inspector/src/helpers/conversation.ts +7 -7
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +181 -181
- rasa/core/channels/inspector/src/helpers/formatters.ts +111 -111
- rasa/core/channels/inspector/src/helpers/utils.ts +78 -61
- rasa/core/channels/inspector/src/main.tsx +8 -8
- rasa/core/channels/inspector/src/theme/Button/Button.ts +8 -8
- rasa/core/channels/inspector/src/theme/Heading/Heading.ts +7 -7
- rasa/core/channels/inspector/src/theme/Input/Input.ts +9 -9
- rasa/core/channels/inspector/src/theme/Link/Link.ts +6 -6
- rasa/core/channels/inspector/src/theme/Modal/Modal.ts +13 -13
- rasa/core/channels/inspector/src/theme/Table/Table.tsx +10 -10
- rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +5 -5
- rasa/core/channels/inspector/src/theme/base/breakpoints.ts +7 -7
- rasa/core/channels/inspector/src/theme/base/colors.ts +64 -64
- rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +21 -18
- rasa/core/channels/inspector/src/theme/base/radii.ts +8 -8
- rasa/core/channels/inspector/src/theme/base/shadows.ts +5 -5
- rasa/core/channels/inspector/src/theme/base/sizes.ts +5 -5
- rasa/core/channels/inspector/src/theme/base/space.ts +12 -12
- rasa/core/channels/inspector/src/theme/base/styles.ts +5 -5
- rasa/core/channels/inspector/src/theme/base/typography.ts +12 -12
- rasa/core/channels/inspector/src/theme/base/zIndices.ts +3 -3
- rasa/core/channels/inspector/src/theme/index.ts +38 -38
- rasa/core/channels/inspector/src/types.ts +56 -50
- rasa/core/channels/inspector/yarn.lock +5 -0
- rasa/core/channels/mattermost.py +1 -1
- rasa/core/channels/rasa_chat.py +2 -4
- rasa/core/channels/rest.py +5 -4
- rasa/core/channels/socketio.py +56 -41
- rasa/core/channels/studio_chat.py +329 -68
- rasa/core/channels/vier_cvg.py +1 -2
- rasa/core/channels/voice_ready/audiocodes.py +4 -11
- rasa/core/channels/voice_ready/jambonz.py +5 -6
- rasa/core/channels/voice_ready/twilio_voice.py +13 -12
- rasa/core/channels/voice_ready/utils.py +22 -0
- rasa/core/channels/voice_stream/audiocodes.py +13 -16
- rasa/core/channels/voice_stream/browser_audio.py +1 -1
- rasa/core/channels/voice_stream/genesys.py +37 -18
- rasa/core/channels/voice_stream/jambonz.py +232 -0
- rasa/core/channels/voice_stream/tts/__init__.py +8 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +15 -12
- rasa/core/channels/voice_stream/voice_channel.py +71 -27
- rasa/core/concurrent_lock_store.py +24 -10
- rasa/core/evaluation/marker_tracker_loader.py +1 -1
- rasa/core/exporter.py +37 -1
- rasa/core/http_interpreter.py +3 -7
- rasa/core/information_retrieval/faiss.py +18 -11
- rasa/core/information_retrieval/ingestion/faq_parser.py +158 -0
- rasa/core/jobs.py +2 -1
- rasa/core/lock_store.py +151 -60
- rasa/core/nlg/contextual_response_rephraser.py +17 -7
- rasa/core/nlg/generator.py +5 -22
- rasa/core/nlg/interpolator.py +2 -3
- rasa/core/nlg/response.py +6 -43
- rasa/core/nlg/summarize.py +1 -1
- rasa/core/nlg/translate.py +0 -8
- rasa/core/policies/enterprise_search_policy.py +305 -189
- rasa/core/policies/enterprise_search_policy_config.py +241 -0
- rasa/core/policies/enterprise_search_prompt_with_relevancy_check_and_citation_template.jinja2 +67 -0
- rasa/core/policies/flow_policy.py +1 -1
- rasa/core/policies/flows/flow_executor.py +102 -17
- rasa/core/policies/intentless_policy.py +56 -17
- rasa/core/processor.py +70 -49
- rasa/core/run.py +33 -11
- rasa/core/tracker_stores/__init__.py +0 -0
- rasa/core/{auth_retry_tracker_store.py → tracker_stores/auth_retry_tracker_store.py} +66 -1
- rasa/core/tracker_stores/dynamo_tracker_store.py +256 -0
- rasa/core/tracker_stores/mongo_tracker_store.py +223 -0
- rasa/core/tracker_stores/redis_tracker_store.py +252 -0
- rasa/core/tracker_stores/sql_tracker_store.py +582 -0
- rasa/core/tracker_stores/tracker_store.py +839 -0
- rasa/core/training/interactive.py +1 -1
- rasa/core/utils.py +24 -95
- rasa/dialogue_understanding/coexistence/intent_based_router.py +2 -1
- rasa/dialogue_understanding/coexistence/llm_based_router.py +13 -11
- rasa/dialogue_understanding/commands/can_not_handle_command.py +2 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +3 -1
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +2 -0
- rasa/dialogue_understanding/commands/clarify_command.py +6 -2
- rasa/dialogue_understanding/commands/command_syntax_manager.py +1 -0
- rasa/dialogue_understanding/commands/correct_slots_command.py +5 -6
- rasa/dialogue_understanding/commands/error_command.py +1 -1
- rasa/dialogue_understanding/commands/human_handoff_command.py +3 -3
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +2 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +2 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +8 -4
- rasa/dialogue_understanding/commands/skip_question_command.py +3 -3
- rasa/dialogue_understanding/commands/start_flow_command.py +7 -3
- rasa/dialogue_understanding/generator/__init__.py +7 -1
- rasa/dialogue_understanding/generator/command_generator.py +4 -2
- rasa/dialogue_understanding/generator/command_parser.py +2 -2
- rasa/dialogue_understanding/generator/command_parser_validator.py +63 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +1 -2
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +3 -2
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +2 -2
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_template.jinja2 +0 -2
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +1 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +1 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v3_claude_3_5_sonnet_20240620_template.jinja2 +79 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v3_gpt_4o_2024_11_20_template.jinja2 +79 -0
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +26 -461
- rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +147 -0
- rasa/dialogue_understanding/generator/single_step/single_step_based_llm_command_generator.py +461 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +20 -64
- rasa/dialogue_understanding/patterns/cancel.py +1 -2
- rasa/dialogue_understanding/patterns/clarify.py +1 -1
- rasa/dialogue_understanding/patterns/correction.py +2 -2
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +42 -27
- rasa/dialogue_understanding/patterns/domain_for_patterns.py +190 -0
- rasa/dialogue_understanding/processor/command_processor.py +6 -7
- rasa/dialogue_understanding_test/command_metric_calculation.py +7 -40
- rasa/dialogue_understanding_test/command_metrics.py +38 -0
- rasa/dialogue_understanding_test/du_test_case.py +58 -25
- rasa/dialogue_understanding_test/du_test_result.py +228 -132
- rasa/dialogue_understanding_test/du_test_runner.py +11 -2
- rasa/dialogue_understanding_test/du_test_schema.yml +3 -3
- rasa/dialogue_understanding_test/io.py +35 -8
- rasa/e2e_test/constants.py +1 -1
- rasa/e2e_test/e2e_test_runner.py +1 -1
- rasa/e2e_test/e2e_test_schema.yml +3 -3
- rasa/engine/constants.py +1 -1
- rasa/engine/graph.py +2 -2
- rasa/engine/recipes/default_recipe.py +1 -1
- rasa/engine/validation.py +3 -2
- rasa/hooks.py +2 -30
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +2 -6
- rasa/model_manager/model_api.py +89 -1
- rasa/model_manager/runner_service.py +20 -4
- rasa/model_manager/socket_bridge.py +0 -7
- rasa/model_manager/trainer_service.py +10 -4
- rasa/plugin.py +2 -15
- rasa/privacy/__init__.py +0 -0
- rasa/privacy/constants.py +83 -0
- rasa/privacy/event_broker_utils.py +77 -0
- rasa/privacy/privacy_config.py +281 -0
- rasa/privacy/privacy_config_schema.json +86 -0
- rasa/privacy/privacy_filter.py +393 -0
- rasa/privacy/privacy_manager.py +594 -0
- rasa/server.py +23 -2
- rasa/shared/constants.py +17 -0
- rasa/shared/core/command_payload_reader.py +1 -5
- rasa/shared/core/constants.py +4 -3
- rasa/shared/core/domain.py +172 -11
- rasa/shared/core/events.py +100 -6
- rasa/shared/core/flows/flow.py +30 -5
- rasa/shared/core/flows/flow_step.py +19 -3
- rasa/shared/core/flows/flow_step_links.py +15 -0
- rasa/shared/core/flows/flow_step_sequence.py +6 -0
- rasa/shared/core/flows/flows_yaml_schema.json +3 -0
- rasa/shared/core/flows/nlu_trigger.py +13 -0
- rasa/shared/core/flows/steps/action.py +7 -4
- rasa/shared/core/flows/steps/call.py +11 -4
- rasa/shared/core/flows/steps/collect.py +71 -6
- rasa/shared/core/flows/steps/internal.py +6 -1
- rasa/shared/core/flows/steps/link.py +7 -4
- rasa/shared/core/flows/steps/no_operation.py +7 -4
- rasa/shared/core/flows/steps/set_slots.py +8 -4
- rasa/shared/core/flows/validation.py +25 -5
- rasa/shared/core/flows/yaml_flows_io.py +106 -5
- rasa/shared/core/slots.py +29 -1
- rasa/shared/core/trackers.py +21 -10
- rasa/shared/core/training_data/story_reader/yaml_story_reader.py +1 -4
- rasa/shared/importers/importer.py +8 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +2 -2
- rasa/shared/providers/_configs/default_litellm_client_config.py +1 -1
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +1 -1
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +1 -1
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -1
- rasa/shared/providers/_configs/utils.py +0 -99
- rasa/shared/providers/llm/default_litellm_llm_client.py +2 -2
- rasa/shared/utils/common.py +43 -1
- rasa/shared/utils/configs.py +110 -0
- rasa/shared/utils/constants.py +0 -3
- rasa/shared/utils/llm.py +245 -8
- rasa/shared/utils/pykwalify_extensions.py +0 -9
- rasa/shared/utils/yaml.py +32 -0
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +33 -12
- rasa/studio/download.py +117 -435
- rasa/studio/link.py +211 -0
- rasa/studio/prompts.py +221 -0
- rasa/studio/pull/__init__.py +0 -0
- rasa/studio/pull/data.py +222 -0
- rasa/studio/pull/domains.py +60 -0
- rasa/studio/pull/pull.py +239 -0
- rasa/studio/push.py +138 -0
- rasa/studio/results_logger.py +6 -1
- rasa/studio/train.py +1 -1
- rasa/studio/upload.py +243 -72
- rasa/studio/utils.py +33 -0
- rasa/telemetry.py +83 -26
- rasa/tracing/config.py +4 -5
- rasa/tracing/constants.py +19 -1
- rasa/tracing/instrumentation/attribute_extractors.py +68 -16
- rasa/tracing/instrumentation/instrumentation.py +54 -3
- rasa/tracing/instrumentation/metrics.py +98 -15
- rasa/tracing/metric_instrument_provider.py +75 -3
- rasa/utils/common.py +43 -22
- rasa/utils/endpoints.py +22 -1
- rasa/utils/licensing.py +2 -3
- rasa/utils/log_utils.py +1 -45
- rasa/validator.py +2 -8
- rasa/version.py +1 -1
- {rasa_pro-3.12.22.dist-info → rasa_pro-3.13.0.dist-info}/METADATA +11 -12
- {rasa_pro-3.12.22.dist-info → rasa_pro-3.13.0.dist-info}/RECORD +333 -309
- rasa/anonymization/__init__.py +0 -2
- rasa/anonymization/anonymisation_rule_yaml_reader.py +0 -91
- rasa/anonymization/anonymization_pipeline.py +0 -286
- rasa/anonymization/anonymization_rule_executor.py +0 -266
- rasa/anonymization/anonymization_rule_orchestrator.py +0 -119
- rasa/anonymization/schemas/config.yml +0 -47
- rasa/anonymization/utils.py +0 -118
- rasa/cli/project_templates/calm/config.yml +0 -10
- rasa/cli/project_templates/calm/credentials.yml +0 -33
- rasa/cli/project_templates/calm/endpoints.yml +0 -58
- rasa/cli/project_templates/default/actions/actions.py +0 -27
- rasa/cli/project_templates/default/data/nlu.yml +0 -91
- rasa/cli/project_templates/default/data/rules.yml +0 -13
- rasa/cli/project_templates/default/data/stories.yml +0 -30
- rasa/cli/project_templates/default/domain.yml +0 -34
- rasa/cli/project_templates/default/tests/test_stories.yml +0 -91
- rasa/core/channels/inspector/dist/assets/channel-dfa68278.js +0 -1
- rasa/core/channels/inspector/dist/assets/clone-edb7f119.js +0 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-65e7c670.js +0 -1
- rasa/core/channels/inspector/src/helpers/audiostream.ts +0 -191
- rasa/core/tracker_store.py +0 -1792
- /rasa/cli/project_templates/{calm → default}/actions/action_template.py +0 -0
- /rasa/cli/project_templates/{calm → default}/actions/add_contact.py +0 -0
- /rasa/cli/project_templates/{calm → default}/actions/db.py +0 -0
- /rasa/cli/project_templates/{calm → default}/actions/list_contacts.py +0 -0
- /rasa/cli/project_templates/{calm → default}/actions/remove_contact.py +0 -0
- /rasa/cli/project_templates/{calm → default}/data/flows/add_contact.yml +0 -0
- /rasa/cli/project_templates/{calm → default}/data/flows/list_contacts.yml +0 -0
- /rasa/cli/project_templates/{calm → default}/data/flows/remove_contact.yml +0 -0
- /rasa/cli/project_templates/{calm → default}/db/contacts.json +0 -0
- /rasa/cli/project_templates/{calm → default}/domain/add_contact.yml +0 -0
- /rasa/cli/project_templates/{calm → default}/domain/list_contacts.yml +0 -0
- /rasa/cli/project_templates/{calm → default}/domain/remove_contact.yml +0 -0
- /rasa/cli/project_templates/{calm → default}/domain/shared.yml +0 -0
- /rasa/{cli/project_templates/calm/actions → core/information_retrieval/ingestion}/__init__.py +0 -0
- {rasa_pro-3.12.22.dist-info → rasa_pro-3.13.0.dist-info}/NOTICE +0 -0
- {rasa_pro-3.12.22.dist-info → rasa_pro-3.13.0.dist-info}/WHEEL +0 -0
- {rasa_pro-3.12.22.dist-info → rasa_pro-3.13.0.dist-info}/entry_points.txt +0 -0
rasa/core/tracker_store.py
DELETED
|
@@ -1,1792 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import contextlib
|
|
4
|
-
import itertools
|
|
5
|
-
import json
|
|
6
|
-
import os
|
|
7
|
-
from inspect import isawaitable, iscoroutinefunction
|
|
8
|
-
from time import sleep
|
|
9
|
-
from typing import (
|
|
10
|
-
TYPE_CHECKING,
|
|
11
|
-
Any,
|
|
12
|
-
Callable,
|
|
13
|
-
Dict,
|
|
14
|
-
Generator,
|
|
15
|
-
Generic,
|
|
16
|
-
Iterable,
|
|
17
|
-
Iterator,
|
|
18
|
-
List,
|
|
19
|
-
Optional,
|
|
20
|
-
Text,
|
|
21
|
-
TypeVar,
|
|
22
|
-
Union,
|
|
23
|
-
)
|
|
24
|
-
|
|
25
|
-
import sqlalchemy as sa
|
|
26
|
-
import structlog
|
|
27
|
-
from boto3.dynamodb.conditions import Key
|
|
28
|
-
from pymongo.collection import Collection
|
|
29
|
-
|
|
30
|
-
import rasa.shared.utils.cli
|
|
31
|
-
import rasa.shared.utils.common
|
|
32
|
-
import rasa.shared.utils.io
|
|
33
|
-
import rasa.utils.json_utils
|
|
34
|
-
from rasa.constants import DEFAULT_SANIC_WORKERS, ENV_SANIC_WORKERS
|
|
35
|
-
from rasa.core.brokers.broker import EventBroker
|
|
36
|
-
from rasa.core.constants import (
|
|
37
|
-
POSTGRESQL_MAX_OVERFLOW,
|
|
38
|
-
POSTGRESQL_POOL_SIZE,
|
|
39
|
-
POSTGRESQL_SCHEMA,
|
|
40
|
-
)
|
|
41
|
-
from rasa.plugin import plugin_manager
|
|
42
|
-
from rasa.shared.core.constants import ACTION_LISTEN_NAME
|
|
43
|
-
from rasa.shared.core.conversation import Dialogue
|
|
44
|
-
from rasa.shared.core.domain import Domain
|
|
45
|
-
from rasa.shared.core.events import Event, SessionStarted
|
|
46
|
-
from rasa.shared.core.trackers import (
|
|
47
|
-
ActionExecuted,
|
|
48
|
-
DialogueStateTracker,
|
|
49
|
-
EventVerbosity,
|
|
50
|
-
TrackerEventDiffEngine,
|
|
51
|
-
)
|
|
52
|
-
from rasa.shared.exceptions import ConnectionException, RasaException
|
|
53
|
-
from rasa.shared.nlu.constants import INTENT_NAME_KEY
|
|
54
|
-
from rasa.utils.endpoints import EndpointConfig
|
|
55
|
-
|
|
56
|
-
if TYPE_CHECKING:
|
|
57
|
-
import boto3.resources.factory.dynamodb.Table
|
|
58
|
-
from sqlalchemy import Sequence
|
|
59
|
-
from sqlalchemy.engine.base import Engine
|
|
60
|
-
from sqlalchemy.engine.url import URL
|
|
61
|
-
from sqlalchemy.orm import Query, Session
|
|
62
|
-
|
|
63
|
-
structlogger = structlog.get_logger(__name__)
|
|
64
|
-
|
|
65
|
-
# default values of PostgreSQL pool size and max overflow
|
|
66
|
-
POSTGRESQL_DEFAULT_MAX_OVERFLOW = 100
|
|
67
|
-
POSTGRESQL_DEFAULT_POOL_SIZE = 50
|
|
68
|
-
|
|
69
|
-
# default value for key prefix in RedisTrackerStore
|
|
70
|
-
DEFAULT_REDIS_TRACKER_STORE_KEY_PREFIX = "tracker:"
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
def check_if_tracker_store_async(tracker_store: TrackerStore) -> bool:
|
|
74
|
-
"""Evaluates if a tracker store object is async based on implementation of methods.
|
|
75
|
-
|
|
76
|
-
:param tracker_store: tracker store object we're evaluating
|
|
77
|
-
:return: if the tracker store correctly implements all async methods
|
|
78
|
-
"""
|
|
79
|
-
return all(
|
|
80
|
-
iscoroutinefunction(getattr(tracker_store, method))
|
|
81
|
-
for method in _get_async_tracker_store_methods()
|
|
82
|
-
)
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
def _get_async_tracker_store_methods() -> List[str]:
|
|
86
|
-
return [
|
|
87
|
-
attribute
|
|
88
|
-
for attribute in dir(TrackerStore)
|
|
89
|
-
if iscoroutinefunction(getattr(TrackerStore, attribute))
|
|
90
|
-
]
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
class TrackerDeserialisationException(RasaException):
|
|
94
|
-
"""Raised when an error is encountered while deserialising a tracker."""
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
SerializationType = TypeVar("SerializationType")
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
class SerializedTrackerRepresentation(Generic[SerializationType]):
|
|
101
|
-
"""Mixin class for specifying different serialization methods per tracker store."""
|
|
102
|
-
|
|
103
|
-
@staticmethod
|
|
104
|
-
def serialise_tracker(tracker: DialogueStateTracker) -> SerializationType:
|
|
105
|
-
"""Requires implementation to return representation of tracker."""
|
|
106
|
-
raise NotImplementedError()
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
class SerializedTrackerAsText(SerializedTrackerRepresentation[Text]):
|
|
110
|
-
"""Mixin class that returns the serialized tracker as string."""
|
|
111
|
-
|
|
112
|
-
@staticmethod
|
|
113
|
-
def serialise_tracker(tracker: DialogueStateTracker) -> Text:
|
|
114
|
-
"""Serializes the tracker, returns representation of the tracker."""
|
|
115
|
-
dialogue = tracker.as_dialogue()
|
|
116
|
-
|
|
117
|
-
return json.dumps(dialogue.as_dict())
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
class SerializedTrackerAsDict(SerializedTrackerRepresentation[Dict]):
|
|
121
|
-
"""Mixin class that returns the serialized tracker as dictionary."""
|
|
122
|
-
|
|
123
|
-
@staticmethod
|
|
124
|
-
def serialise_tracker(tracker: DialogueStateTracker) -> Dict:
|
|
125
|
-
"""Serializes the tracker, returns representation of the tracker."""
|
|
126
|
-
d = tracker.as_dialogue().as_dict()
|
|
127
|
-
d.update({"sender_id": tracker.sender_id})
|
|
128
|
-
return d
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
class TrackerStore:
|
|
132
|
-
"""Represents common behavior and interface for all `TrackerStore`s."""
|
|
133
|
-
|
|
134
|
-
def __init__(
|
|
135
|
-
self,
|
|
136
|
-
domain: Optional[Domain],
|
|
137
|
-
event_broker: Optional[EventBroker] = None,
|
|
138
|
-
**kwargs: Dict[Text, Any],
|
|
139
|
-
) -> None:
|
|
140
|
-
"""Create a TrackerStore.
|
|
141
|
-
|
|
142
|
-
Args:
|
|
143
|
-
domain: The `Domain` to initialize the `DialogueStateTracker`.
|
|
144
|
-
event_broker: An event broker to publish any new events to another
|
|
145
|
-
destination.
|
|
146
|
-
kwargs: Additional kwargs.
|
|
147
|
-
"""
|
|
148
|
-
self._domain = domain or Domain.empty()
|
|
149
|
-
self.event_broker = event_broker
|
|
150
|
-
self.max_event_history: Optional[int] = None
|
|
151
|
-
|
|
152
|
-
@staticmethod
|
|
153
|
-
def create(
|
|
154
|
-
obj: Union[TrackerStore, EndpointConfig, None],
|
|
155
|
-
domain: Optional[Domain] = None,
|
|
156
|
-
event_broker: Optional[EventBroker] = None,
|
|
157
|
-
) -> TrackerStore:
|
|
158
|
-
"""Factory to create a tracker store."""
|
|
159
|
-
if isinstance(obj, TrackerStore):
|
|
160
|
-
return obj
|
|
161
|
-
|
|
162
|
-
import pymongo.errors
|
|
163
|
-
import sqlalchemy.exc
|
|
164
|
-
from botocore.exceptions import BotoCoreError
|
|
165
|
-
|
|
166
|
-
try:
|
|
167
|
-
_tracker_store = plugin_manager().hook.create_tracker_store(
|
|
168
|
-
endpoint_config=obj,
|
|
169
|
-
domain=domain,
|
|
170
|
-
event_broker=event_broker,
|
|
171
|
-
)
|
|
172
|
-
|
|
173
|
-
tracker_store = (
|
|
174
|
-
_tracker_store
|
|
175
|
-
if _tracker_store
|
|
176
|
-
else create_tracker_store(obj, domain, event_broker)
|
|
177
|
-
)
|
|
178
|
-
|
|
179
|
-
return tracker_store
|
|
180
|
-
except (
|
|
181
|
-
BotoCoreError,
|
|
182
|
-
pymongo.errors.ConnectionFailure,
|
|
183
|
-
sqlalchemy.exc.OperationalError,
|
|
184
|
-
ConnectionError,
|
|
185
|
-
pymongo.errors.OperationFailure,
|
|
186
|
-
) as error:
|
|
187
|
-
raise ConnectionException(
|
|
188
|
-
"Cannot connect to tracker store." + str(error)
|
|
189
|
-
) from error
|
|
190
|
-
|
|
191
|
-
async def get_or_create_tracker(
|
|
192
|
-
self,
|
|
193
|
-
sender_id: Text,
|
|
194
|
-
max_event_history: Optional[int] = None,
|
|
195
|
-
append_action_listen: bool = True,
|
|
196
|
-
) -> "DialogueStateTracker":
|
|
197
|
-
"""Returns tracker or creates one if the retrieval returns None.
|
|
198
|
-
|
|
199
|
-
Args:
|
|
200
|
-
sender_id: Conversation ID associated with the requested tracker.
|
|
201
|
-
max_event_history: Value to update the tracker store's max event history to.
|
|
202
|
-
append_action_listen: Whether or not to append an initial `action_listen`.
|
|
203
|
-
"""
|
|
204
|
-
self.max_event_history = max_event_history
|
|
205
|
-
|
|
206
|
-
tracker = await self.retrieve(sender_id)
|
|
207
|
-
|
|
208
|
-
if tracker is None:
|
|
209
|
-
tracker = await self.create_tracker(
|
|
210
|
-
sender_id, append_action_listen=append_action_listen
|
|
211
|
-
)
|
|
212
|
-
|
|
213
|
-
return tracker
|
|
214
|
-
|
|
215
|
-
def init_tracker(self, sender_id: Text) -> "DialogueStateTracker":
|
|
216
|
-
"""Returns a Dialogue State Tracker."""
|
|
217
|
-
return DialogueStateTracker(
|
|
218
|
-
sender_id,
|
|
219
|
-
self.domain.slots,
|
|
220
|
-
max_event_history=self.max_event_history,
|
|
221
|
-
)
|
|
222
|
-
|
|
223
|
-
async def create_tracker(
|
|
224
|
-
self, sender_id: Text, append_action_listen: bool = True
|
|
225
|
-
) -> DialogueStateTracker:
|
|
226
|
-
"""Creates a new tracker for `sender_id`.
|
|
227
|
-
|
|
228
|
-
The tracker begins with a `SessionStarted` event and is initially listening.
|
|
229
|
-
|
|
230
|
-
Args:
|
|
231
|
-
sender_id: Conversation ID associated with the tracker.
|
|
232
|
-
append_action_listen: Whether or not to append an initial `action_listen`.
|
|
233
|
-
|
|
234
|
-
Returns:
|
|
235
|
-
The newly created tracker for `sender_id`.
|
|
236
|
-
"""
|
|
237
|
-
tracker = self.init_tracker(sender_id)
|
|
238
|
-
|
|
239
|
-
if append_action_listen:
|
|
240
|
-
tracker.update(ActionExecuted(ACTION_LISTEN_NAME))
|
|
241
|
-
|
|
242
|
-
await self.save(tracker)
|
|
243
|
-
|
|
244
|
-
return tracker
|
|
245
|
-
|
|
246
|
-
async def save(self, tracker: DialogueStateTracker) -> None:
|
|
247
|
-
"""Save method that will be overridden by specific tracker."""
|
|
248
|
-
raise NotImplementedError()
|
|
249
|
-
|
|
250
|
-
async def exists(self, conversation_id: Text) -> bool:
    """Tell whether a tracker is stored for `conversation_id`.

    The default implementation simply attempts a `retrieve`; subclasses may
    provide a cheaper check.

    Args:
        conversation_id: Conversation ID to look up.

    Returns:
        `True` when `retrieve` finds a tracker, `False` otherwise.
    """
    found = await self.retrieve(conversation_id)
    return found is not None
|
|
263
|
-
|
|
264
|
-
async def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
    """Load the tracker holding the most recent conversation session.

    Concrete tracker stores must override this method.

    Args:
        sender_id: ID of the conversation whose tracker is requested.

    Returns:
        The tracker for the latest session, or `None` if there is none.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()
|
|
276
|
-
|
|
277
|
-
async def retrieve_full_tracker(
    self, conversation_id: Text
) -> Optional[DialogueStateTracker]:
    """Load the tracker with events from every conversation session.

    This base implementation just delegates to `self.retrieve()`; stores
    that keep several sessions should override it.

    Args:
        conversation_id: ID of the conversation whose tracker is requested.

    Returns:
        Whatever `self.retrieve()` returns for `conversation_id`.
    """
    full_tracker = await self.retrieve(conversation_id)
    return full_tracker
|
|
292
|
-
|
|
293
|
-
async def get_or_create_full_tracker(
    self,
    sender_id: Text,
    append_action_listen: bool = True,
) -> "DialogueStateTracker":
    """Return the all-sessions tracker for `sender_id`, creating it if absent.

    Args:
        sender_id: Conversation ID associated with the requested tracker.
        append_action_listen: Forwarded to `create_tracker` when a new
            tracker has to be created.

    Returns:
        The existing full tracker, or the newly created one.
    """
    existing = await self.retrieve_full_tracker(sender_id)
    if existing is not None:
        return existing

    return await self.create_tracker(
        sender_id, append_action_listen=append_action_listen
    )
|
|
315
|
-
|
|
316
|
-
async def stream_events(self, tracker: DialogueStateTracker) -> None:
    """Publish the events of `tracker` that are not yet stored.

    The stored tracker is loaded with `retrieve` and diffed against
    `tracker`; only the difference is published. Nothing happens when no
    event broker is configured.
    """
    if self.event_broker is not None:
        stored_tracker = await self.retrieve(tracker.sender_id)
        unseen_events = TrackerEventDiffEngine.event_difference(
            stored_tracker, tracker
        )
        await self._stream_new_events(
            self.event_broker, unseen_events, tracker.sender_id
        )
        return None

    structlogger.debug(
        "tracker_store.stream_events.no_broker_configured",
        event_info="No event broker configured. Skipping streaming events.",
    )
    return None
|
|
329
|
-
|
|
330
|
-
async def _stream_new_events(
    self,
    event_broker: EventBroker,
    new_events: List[Event],
    sender_id: Text,
) -> None:
    """Publish each event in `new_events` to `event_broker`.

    Each message is the event's `as_dict()` payload plus a `sender_id` key.
    Keys coming from `as_dict()` take precedence over that `sender_id`.
    """
    for new_event in new_events:
        event_broker.publish({"sender_id": sender_id, **new_event.as_dict()})
|
|
341
|
-
|
|
342
|
-
async def keys(self) -> Iterable[Text]:
    """Return every value of the store's primary key (conversation IDs).

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()
|
|
345
|
-
|
|
346
|
-
async def count_conversations(self, after_timestamp: float = 0.0) -> int:
    """Count conversations whose last event is not older than a timestamp.

    Every key from `self.keys()` is loaded with `self.retrieve()`. A
    conversation counts when its tracker exists, has events, and the
    timestamp of its last event is `>= after_timestamp`. The default of `0.0`
    (the Unix epoch) counts every non-empty conversation with a
    non-negative timestamp.

    Args:
        after_timestamp: Lower bound (inclusive) for the last event's timestamp.

    Returns:
        The number of matching conversations.
    """
    matching = 0
    for conversation_id in await self.keys():
        candidate = await self.retrieve(conversation_id)
        if candidate is None or not candidate.events:
            continue
        if candidate.events[-1].timestamp >= after_timestamp:
            matching += 1
    return matching
|
|
367
|
-
|
|
368
|
-
def deserialise_tracker(
    self, sender_id: Text, serialised_tracker: Union[Text, bytes]
) -> Optional[DialogueStateTracker]:
    """Rebuild a tracker for `sender_id` from its JSON serialisation.

    Args:
        sender_id: Conversation ID of the tracker being restored.
        serialised_tracker: JSON produced by a previous serialisation.

    Returns:
        A tracker recreated from the stored dialogue.

    Raises:
        TrackerDeserialisationException: If the payload cannot be decoded as
            text, e.g. a legacy pickled tracker.
    """
    restored = self.init_tracker(sender_id)

    try:
        parameters = json.loads(serialised_tracker)
        dialogue = Dialogue.from_parameters(parameters)
    except UnicodeDecodeError as e:
        raise TrackerDeserialisationException(
            "Tracker cannot be deserialised. "
            "Trackers must be serialised as json. "
            "Support for deserialising pickled trackers has been removed."
        ) from e

    restored.recreate_from_dialogue(dialogue)
    return restored
|
|
386
|
-
|
|
387
|
-
@property
def domain(self) -> Domain:
    """The domain whose slots are used when trackers are initialised."""
    return self._domain

@domain.setter
def domain(self, domain: Optional[Domain]) -> None:
    """Store `domain`, substituting an empty domain for a falsy value."""
    if not domain:
        domain = Domain.empty()
    self._domain = domain
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
class InMemoryTrackerStore(TrackerStore, SerializedTrackerAsText):
    """Stores conversation history in memory.

    Trackers are kept serialised as text in a dictionary keyed by
    conversation ID. They only live as long as this store instance.
    """

    def __init__(
        self,
        domain: Domain,
        event_broker: Optional[EventBroker] = None,
        **kwargs: Dict[Text, Any],
    ) -> None:
        """Initializes the tracker store.

        Args:
            domain: Domain associated with this tracker store.
            event_broker: Optional event broker that new events are streamed to.
            kwargs: Additional arguments forwarded to `TrackerStore.__init__`.
        """
        self.store: Dict[Text, Text] = {}
        super().__init__(domain, event_broker, **kwargs)

    async def save(self, tracker: DialogueStateTracker) -> None:
        """Streams new events, then stores the serialised tracker.

        Args:
            tracker: Tracker to persist. It replaces any tracker stored under
                the same `sender_id`.
        """
        await self.stream_events(tracker)
        serialised = InMemoryTrackerStore.serialise_tracker(tracker)
        self.store[tracker.sender_id] = serialised

    async def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
        """Returns the tracker for the latest session of `sender_id`, if any."""
        return await self._retrieve(sender_id, fetch_all_sessions=False)

    async def keys(self) -> Iterable[Text]:
        """Returns sender_ids of the Tracker Store in memory."""
        return self.store.keys()

    async def retrieve_full_tracker(
        self, sender_id: Text
    ) -> Optional[DialogueStateTracker]:
        """Returns the tracker for `sender_id` containing all sessions.

        Args:
            sender_id: Conversation ID to fetch the tracker for.

        Returns:
            The tracker, or `None` if no tracker is stored for `sender_id`.
        """
        return await self._retrieve(sender_id, fetch_all_sessions=True)

    async def _retrieve(
        self, sender_id: Text, fetch_all_sessions: bool
    ) -> Optional[DialogueStateTracker]:
        """Returns tracker matching sender_id.

        Args:
            sender_id: Conversation ID to fetch the tracker for.
            fetch_all_sessions: Whether to fetch all sessions or only the last one.

        Returns:
            The requested tracker, or `None` if nothing is stored for
            `sender_id` or deserialisation produced no tracker.
        """
        if sender_id not in self.store:
            structlogger.debug(
                "in_memory_tracker_store.retrieve.no_tracker_for_sender_id",
                event_info=f"Could not find tracker for conversation ID '{sender_id}'.",
            )
            return None

        tracker = self.deserialise_tracker(sender_id, self.store[sender_id])

        if not tracker:
            structlogger.debug(
                "in_memory_tracker_store.retrieve.failed_to_deserialize_tracker",
                # No trailing comma inside the parentheses: it would turn the
                # message into a one-element tuple instead of a string.
                event_info=(
                    f"Could not deserialize tracker "
                    f"for conversation ID '{sender_id}'."
                ),
            )
            return None

        if fetch_all_sessions:
            return tracker

        # only return the last session
        multiple_tracker_sessions = (
            rasa.shared.core.trackers.get_trackers_for_conversation_sessions(tracker)
        )

        # With zero or one session, the full tracker already is the last session.
        if len(multiple_tracker_sessions) <= 1:
            return tracker

        return multiple_tracker_sessions[-1]
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
class RedisTrackerStore(TrackerStore, SerializedTrackerAsText):
    """Stores conversation history in Redis.

    Each conversation is stored as one serialised tracker under the Redis key
    `<key_prefix><sender_id>`.
    """

    def __init__(
        self,
        domain: Domain,
        host: Text = "localhost",
        port: int = 6379,
        db: int = 0,
        username: Optional[Text] = None,
        password: Optional[Text] = None,
        event_broker: Optional[EventBroker] = None,
        record_exp: Optional[float] = None,
        key_prefix: Optional[Text] = None,
        use_ssl: bool = False,
        ssl_keyfile: Optional[Text] = None,
        ssl_certfile: Optional[Text] = None,
        ssl_ca_certs: Optional[Text] = None,
        **kwargs: Dict[Text, Any],
    ) -> None:
        """Initializes the tracker store.

        Args:
            domain: Domain associated with this tracker store.
            host: Redis host.
            port: Redis port.
            db: Redis database number.
            username: Optional Redis username.
            password: Optional Redis password.
            event_broker: Optional event broker that new events are streamed to.
            record_exp: Expiry passed as `ex` to Redis `SET` when `save` is
                called without its own timeout.
            key_prefix: Optional alphanumeric prefix placed in front of the
                default key prefix. Non-alphanumeric values are ignored.
            use_ssl: Whether to connect to Redis over SSL.
            ssl_keyfile: Path to the SSL key file.
            ssl_certfile: Path to the SSL certificate file.
            ssl_ca_certs: Path to the CA certificates file.
            kwargs: Additional arguments forwarded to `TrackerStore.__init__`.
        """
        import redis

        self.red = redis.StrictRedis(
            host=host,
            port=port,
            db=db,
            username=username,
            password=password,
            ssl=use_ssl,
            ssl_keyfile=ssl_keyfile,
            ssl_certfile=ssl_certfile,
            ssl_ca_certs=ssl_ca_certs,
            decode_responses=True,
        )
        self.record_exp = record_exp

        self.key_prefix = DEFAULT_REDIS_TRACKER_STORE_KEY_PREFIX
        if key_prefix:
            structlogger.debug(
                "redis_tracker_store.init.custom_key_prefix",
                event_info=f"Setting non-default redis key prefix: '{key_prefix}'.",
            )
            self._set_key_prefix(key_prefix)

        super().__init__(domain, event_broker, **kwargs)

    def _set_key_prefix(self, key_prefix: Text) -> None:
        """Uses `<key_prefix>:<default prefix>` if `key_prefix` is alphanumeric.

        Otherwise a warning is logged and the current prefix is kept.
        """
        if isinstance(key_prefix, str) and key_prefix.isalnum():
            self.key_prefix = key_prefix + ":" + DEFAULT_REDIS_TRACKER_STORE_KEY_PREFIX
        else:
            structlogger.warning(
                "redis_tracker_store.init.invalid_key_prefix",
                event_info=(
                    f"Omitting provided non-alphanumeric "
                    f"redis key prefix: '{key_prefix}'. "
                    f"Using default '{self.key_prefix}' instead."
                ),
            )

    def _get_key_prefix(self) -> Text:
        """Returns the prefix placed in front of every sender_id key."""
        return self.key_prefix

    async def save(
        self, tracker: DialogueStateTracker, timeout: Optional[float] = None
    ) -> None:
        """Saves the current conversation state.

        If a tracker is already stored for the sender, it is merged with
        `tracker` first (see `_merge_trackers`).

        Args:
            tracker: Tracker to persist.
            timeout: Expiry passed as `ex` to Redis `SET`. Falls back to
                `record_exp` when not given.
        """
        await self.stream_events(tracker)

        if not timeout and self.record_exp:
            timeout = self.record_exp

        stored = self.red.get(self.key_prefix + tracker.sender_id)

        if stored is not None:
            prior_tracker = self.deserialise_tracker(tracker.sender_id, stored)

            tracker = self._merge_trackers(prior_tracker, tracker)

        serialised_tracker = self.serialise_tracker(tracker)
        self.red.set(
            self.key_prefix + tracker.sender_id, serialised_tracker, ex=timeout
        )

    async def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
        """Retrieves tracker for the latest conversation session.

        The Redis key is formed by prepending the key prefix to sender_id.

        Args:
            sender_id: Conversation ID to fetch the tracker for.

        Returns:
            Tracker containing events from the latest conversation session,
            or `None` if nothing is stored.
        """
        return await self._retrieve(sender_id, fetch_all_sessions=False)

    async def retrieve_full_tracker(
        self, sender_id: Text
    ) -> Optional[DialogueStateTracker]:
        """Retrieves tracker for all conversation sessions.

        The Redis key is formed by prepending the key prefix to sender_id.

        Args:
            sender_id: Conversation ID to fetch the tracker for.

        Returns:
            Tracker containing events from all conversation sessions, or
            `None` if nothing is stored.
        """
        return await self._retrieve(sender_id, fetch_all_sessions=True)

    async def _retrieve(
        self, sender_id: Text, fetch_all_sessions: bool
    ) -> Optional[DialogueStateTracker]:
        """Returns tracker matching sender_id.

        Args:
            sender_id: Conversation ID to fetch the tracker for.
            fetch_all_sessions: Whether to fetch all sessions or only the last one.
        """
        stored = self.red.get(self.key_prefix + sender_id)
        if stored is None:
            structlogger.debug(
                "redis_tracker_store.retrieve.no_tracker_for_sender_id",
                event_info=f"Could not find tracker for conversation ID '{sender_id}'.",
            )
            return None

        tracker = self.deserialise_tracker(sender_id, stored)
        if fetch_all_sessions:
            return tracker

        # only return the last session
        multiple_tracker_sessions = (
            rasa.shared.core.trackers.get_trackers_for_conversation_sessions(tracker)
        )

        # With zero or one session, the full tracker already is the last session.
        if len(multiple_tracker_sessions) <= 1:
            return tracker

        return multiple_tracker_sessions[-1]

    async def keys(self) -> Iterable[Text]:
        """Returns the conversation IDs (sender_ids) of the Redis Tracker Store.

        Redis keys are stored as `<key_prefix><sender_id>`. The prefix is
        stripped so the returned IDs can be passed straight back to
        `retrieve` / `retrieve_full_tracker`, like the other tracker stores.
        """
        prefix_length = len(self.key_prefix)
        return [key[prefix_length:] for key in self.red.keys(self.key_prefix + "*")]

    @staticmethod
    def _merge_trackers(
        prior_tracker: DialogueStateTracker, tracker: DialogueStateTracker
    ) -> DialogueStateTracker:
        """Merges two trackers.

        Args:
            prior_tracker: Tracker containing events from the previous conversation
                sessions.
            tracker: Tracker containing events from the current conversation session.

        Returns:
            `tracker` unchanged if it already starts with the prior events.
            Otherwise, a copy holding the prior events followed by every event
            of `tracker` that is not already present.
        """
        if not prior_tracker.events:
            return tracker

        last_event_timestamp = prior_tracker.events[-1].timestamp
        past_tracker = tracker.travel_back_in_time(target_time=last_event_timestamp)

        if past_tracker.events == prior_tracker.events:
            return tracker

        merged = tracker.init_copy()
        merged.update_with_events(list(prior_tracker.events), override_timestamp=False)

        for new_event in tracker.events:
            # Event subclasses implement `__eq__` method that make it difficult
            # to compare events. We use `as_dict` to compare events.
            new_event_dict = new_event.as_dict()
            if not any(
                new_event_dict == existing_event.as_dict()
                for existing_event in merged.events
            ):
                merged.update(new_event)

        return merged
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
class DynamoTrackerStore(TrackerStore, SerializedTrackerAsDict):
    """Stores conversation history in DynamoDB."""

    def __init__(
        self,
        domain: Domain,
        table_name: Text = "states",
        region: Text = "us-east-1",
        event_broker: Optional[EventBroker] = None,
        **kwargs: Dict[Text, Any],
    ) -> None:
        """Initialize `DynamoTrackerStore`.

        Args:
            domain: Domain associated with this tracker store.
            table_name: The name of the DynamoDB table, does not need to be present a
                priori.
            region: The name of the region associated with the client.
                A client is associated with a single region.
            event_broker: An event broker used to publish events.
            kwargs: Additional kwargs.
        """
        import boto3

        self.client = boto3.client("dynamodb", region_name=region)
        self.region = region
        self.table_name = table_name
        self.db = self.get_or_create_table(table_name)
        super().__init__(domain, event_broker, **kwargs)

    def get_or_create_table(
        self, table_name: Text
    ) -> "boto3.resources.factory.dynamodb.Table":
        """Returns table or creates one if the table name is not in the table list.

        Args:
            table_name: Name of the table to look up or create.

        Raises:
            RasaException: If the table is missing and more than one Sanic
                worker is configured. Creating the table in that case is not
                supported.
        """
        import boto3

        dynamo = boto3.resource("dynamodb", region_name=self.region)
        try:
            self.client.describe_table(TableName=table_name)
        except self.client.exceptions.ResourceNotFoundException:
            sanic_workers_count = int(
                os.environ.get(ENV_SANIC_WORKERS, DEFAULT_SANIC_WORKERS)
            )

            if sanic_workers_count > 1:
                structlogger.error(
                    "dynamo_tracker_store.table_creation_not_supported_in_multi_worker_mode",
                    # No trailing comma inside the parentheses, so this stays a
                    # string rather than a one-element tuple.
                    event_info=(
                        "DynamoDB table creation is not "
                        "supported in multi-worker mode. "
                        "Table should already exist."
                    ),
                )
                raise RasaException(
                    "DynamoDB table creation is not supported in "
                    "case of multiple sanic workers. To create the table either "
                    "run Rasa with a single worker or create the table manually. "
                    "Here are the defaults which can be used to "
                    "create the table manually: "
                    f"Table name: {table_name}, Primary key: sender_id, "
                    f"key type `HASH`, attribute type `S` (String), "
                    "Provisioned throughput: Read capacity units: 5, "
                    "Write capacity units: 5"
                )

            # Use the `table_name` argument throughout so the created and the
            # awaited table are always the same one.
            table = dynamo.create_table(
                TableName=table_name,
                KeySchema=[{"AttributeName": "sender_id", "KeyType": "HASH"}],
                AttributeDefinitions=[
                    {"AttributeName": "sender_id", "AttributeType": "S"}
                ],
                ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5},
            )

            # Wait until the table exists.
            table.meta.client.get_waiter("table_exists").wait(TableName=table_name)
        else:
            table = dynamo.Table(table_name)

        return table

    async def save(self, tracker: DialogueStateTracker) -> None:
        """Saves the current conversation state."""
        await self.stream_events(tracker)
        serialized = self.serialise_tracker(tracker)

        self.db.put_item(Item=serialized)

    @staticmethod
    def serialise_tracker(
        tracker: "DialogueStateTracker",
    ) -> Dict:
        """Serializes the tracker, returns object with decimal types.

        DynamoDB cannot store `float`s, so we'll convert them to `Decimal`s.
        """
        return rasa.utils.json_utils.replace_floats_with_decimals(
            SerializedTrackerAsDict.serialise_tracker(tracker)
        )

    async def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
        """Retrieves the tracker for the latest conversation session.

        Items are queried by the `sender_id` hash key with
        `ScanIndexForward=False`. That only affects ordering if the table has a
        sort key; the table created by `get_or_create_table` has none.
        """
        return await self._retrieve(sender_id, fetch_all_sessions=False)

    async def retrieve_full_tracker(
        self, sender_id: Text
    ) -> Optional[DialogueStateTracker]:
        """Retrieves tracker for all conversation sessions.

        Args:
            sender_id: Conversation ID to fetch the tracker for.
        """
        return await self._retrieve(sender_id, fetch_all_sessions=True)

    async def _retrieve(
        self, sender_id: Text, fetch_all_sessions: bool
    ) -> Optional[DialogueStateTracker]:
        """Returns tracker matching sender_id.

        Args:
            sender_id: Conversation ID to fetch the tracker for.
            fetch_all_sessions: If `True`, concatenate the events of every item
                returned by the query. Otherwise use only the first item.

        Returns:
            The tracker, or `None` if the query returns no items.
        """
        dialogues = self.db.query(
            KeyConditionExpression=Key("sender_id").eq(sender_id),
            ScanIndexForward=False,
        )["Items"]

        if not dialogues:
            return None

        if fetch_all_sessions:
            events_with_floats = []
            for dialogue in dialogues:
                if dialogue.get("events"):
                    events = rasa.utils.json_utils.replace_decimals_with_floats(
                        dialogue["events"]
                    )
                    events_with_floats += events
        else:
            events = dialogues[0].get("events", [])
            # `float`s are stored as `Decimal` objects - we need to convert them back
            events_with_floats = rasa.utils.json_utils.replace_decimals_with_floats(
                events
            )

        if self.domain is None:
            slots = []
        else:
            slots = self.domain.slots

        return DialogueStateTracker.from_dict(sender_id, events_with_floats, slots)

    async def keys(self) -> Iterable[Text]:
        """Returns sender_ids of the `DynamoTrackerStore`.

        Follows `LastEvaluatedKey` so that every page of the scan is read.
        """
        response = self.db.scan(ProjectionExpression="sender_id")
        sender_ids = [i["sender_id"] for i in response["Items"]]

        while response.get("LastEvaluatedKey"):
            response = self.db.scan(
                ProjectionExpression="sender_id",
                ExclusiveStartKey=response["LastEvaluatedKey"],
            )
            sender_ids.extend([i["sender_id"] for i in response["Items"]])

        return sender_ids
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
class MongoTrackerStore(TrackerStore, SerializedTrackerAsText):
    """Stores conversation history in Mongo.

    Property methods:
        conversations: returns the current conversation
    """

    def __init__(
        self,
        domain: Domain,
        host: Optional[Text] = "mongodb://localhost:27017",
        db: Optional[Text] = "rasa",
        username: Optional[Text] = None,
        password: Optional[Text] = None,
        auth_source: Optional[Text] = "admin",
        collection: Text = "conversations",
        event_broker: Optional[EventBroker] = None,
        **kwargs: Dict[Text, Any],
    ) -> None:
        """Initializes the tracker store and ensures the `sender_id` index.

        Args:
            domain: Domain associated with this tracker store.
            host: MongoDB connection URL.
            db: Name of the database.
            username: Optional username used to authenticate.
            password: Optional password used to authenticate.
            auth_source: Database used for authentication.
            collection: Name of the collection holding conversations.
            event_broker: Optional event broker that new events are streamed to.
            kwargs: Additional arguments forwarded to `TrackerStore.__init__`.
        """
        from pymongo import MongoClient
        from pymongo.database import Database

        self.client: MongoClient = MongoClient(
            host,
            username=username,
            password=password,
            authSource=auth_source,
            # delay connect until process forking is done
            connect=False,
        )

        self.db = Database(self.client, db)
        self.collection = collection
        super().__init__(domain, event_broker, **kwargs)

        self._ensure_indices()

    @property
    def conversations(self) -> Collection:
        """Returns the current conversation."""
        return self.db[self.collection]

    def _ensure_indices(self) -> None:
        """Create an index on the sender_id."""
        self.conversations.create_index("sender_id")

    @staticmethod
    def _current_tracker_state_without_events(tracker: DialogueStateTracker) -> Dict:
        # get current tracker state and remove `events` key from state
        # since events are pushed separately in the `update_one()` operation
        state = tracker.current_state(EventVerbosity.ALL)
        state.pop("events", None)

        return state

    async def save(self, tracker: DialogueStateTracker) -> None:
        """Saves the current conversation state.

        The non-event state is overwritten with `$set`. Events not yet stored
        are appended with `$push`. The document is created if it is missing.
        """
        await self.stream_events(tracker)

        additional_events = self._additional_events(tracker)

        self.conversations.update_one(
            {"sender_id": tracker.sender_id},
            {
                "$set": self._current_tracker_state_without_events(tracker),
                "$push": {
                    "events": {"$each": [e.as_dict() for e in additional_events]}
                },
            },
            upsert=True,
        )

    def _additional_events(self, tracker: DialogueStateTracker) -> Iterator:
        """Return events from the tracker which aren't currently stored.

        The events stored since (and including) the last `SessionStarted` are
        counted. That many leading events of `tracker.events` are skipped.
        This presumably relies on `tracker` holding only the latest session,
        as returned by `retrieve`.

        Args:
            tracker: Tracker to inspect.

        Returns:
            Iterator over the tracker's `Event` objects that aren't currently
            stored.
        """
        stored = self.conversations.find_one({"sender_id": tracker.sender_id}) or {}
        all_events = self._events_from_serialized_tracker(stored)

        number_events_since_last_session = len(
            self._events_since_last_session_start(all_events)
        )

        return itertools.islice(
            tracker.events, number_events_since_last_session, len(tracker.events)
        )

    @staticmethod
    def _events_from_serialized_tracker(serialised: Dict) -> List[Dict]:
        """Returns the stored `events` list of a document (empty if absent)."""
        return serialised.get("events", [])

    @staticmethod
    def _events_since_last_session_start(events: List[Dict]) -> List[Dict]:
        """Retrieve events since and including the latest `SessionStart` event.

        Args:
            events: All events for a conversation ID.

        Returns:
            List of serialised events since and including the latest `SessionStarted`
            event. Returns all events if no such event is found.
        """
        events_after_session_start = []
        for event in reversed(events):
            events_after_session_start.append(event)
            if event["event"] == SessionStarted.type_name:
                break

        return list(reversed(events_after_session_start))

    async def _retrieve(
        self, sender_id: Text, fetch_events_from_all_sessions: bool
    ) -> Optional[List[Dict[Text, Any]]]:
        """Returns the serialised events stored for `sender_id`.

        Args:
            sender_id: Conversation ID to fetch events for.
            fetch_events_from_all_sessions: If `False`, return only the events
                since the last `SessionStarted` event.

        Returns:
            The events, or `None` if no document exists for `sender_id`.
        """
        stored = self.conversations.find_one({"sender_id": sender_id})

        # look for conversations which have used an `int` sender_id in the past
        # and update them.
        if not stored and sender_id.isdigit():
            from pymongo import ReturnDocument

            stored = self.conversations.find_one_and_update(
                {"sender_id": int(sender_id)},
                {"$set": {"sender_id": str(sender_id)}},
                return_document=ReturnDocument.AFTER,
            )

        if not stored:
            return None

        events = self._events_from_serialized_tracker(stored)

        if not fetch_events_from_all_sessions:
            events = self._events_since_last_session_start(events)

        return events

    async def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
        """Retrieves tracker for the latest conversation session."""
        events = await self._retrieve(sender_id, fetch_events_from_all_sessions=False)

        if not events:
            return None

        return DialogueStateTracker.from_dict(sender_id, events, self.domain.slots)

    async def retrieve_full_tracker(
        self, conversation_id: Text
    ) -> Optional[DialogueStateTracker]:
        """Fetching all tracker events across conversation sessions."""
        events = await self._retrieve(
            conversation_id, fetch_events_from_all_sessions=True
        )

        if not events:
            return None

        return DialogueStateTracker.from_dict(
            conversation_id, events, self.domain.slots
        )

    async def keys(self) -> Iterable[Text]:
        """Returns sender_ids of the Mongo Tracker Store.

        Only the `sender_id` field is projected, so the (potentially large)
        event lists are not transferred.
        """
        return [
            c["sender_id"] for c in self.conversations.find({}, {"sender_id": 1})
        ]
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
def _create_sequence(table_name: Text) -> "Sequence":
    """Build an optional SQLAlchemy `Sequence` named `<table_name>_seq`.

    If using Oracle you will need to create a sequence in your database,
    as described here: https://rasa.com/docs/rasa-pro/production/tracker-stores#sqltrackerstore

    Args:
        table_name: Name of the table the sequence is assigned to.

    Returns:
        A `Sequence` attached to the metadata of a new declarative base.
    """
    from sqlalchemy.orm import declarative_base

    metadata = declarative_base().metadata
    return sa.Sequence(f"{table_name}_seq", metadata=metadata, optional=True)
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
def is_postgresql_url(url: Union[Text, "URL"]) -> bool:
    """Determine whether `url` configures a PostgreSQL connection.

    Args:
        url: SQL connection URL, either as a string or as an object exposing
            a `drivername` attribute (e.g. a SQLAlchemy `URL`).

    Returns:
        `True` if `url` is a PostgreSQL connection URL.
    """
    if isinstance(url, str):
        return "postgresql" in url

    # `drivername` may name an explicit DBAPI driver, e.g.
    # "postgresql+psycopg2". Only the backend part before "+" matters here.
    backend, _, _driver = url.drivername.partition("+")
    return backend == "postgresql"
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
def create_engine_kwargs(url: Union[Text, "URL"]) -> Dict[Text, Any]:
    """Get `sqlalchemy.create_engine()` kwargs.

    For non-PostgreSQL URLs no extra kwargs are returned. For PostgreSQL:
    - a `search_path` connect option is set if the `POSTGRESQL_SCHEMA`
      environment variable is set;
    - `pool_size` and `max_overflow` are read from their environment
      variables, falling back to the module defaults.

    Args:
        url: SQL connection URL.

    Returns:
        kwargs to be passed into `sqlalchemy.create_engine()`.
    """
    if not is_postgresql_url(url):
        return {}

    kwargs: Dict[Text, Any] = {}

    schema_name = os.environ.get(POSTGRESQL_SCHEMA)

    if schema_name:
        structlogger.debug(
            "postgresql_tracker_store.schema_name",
            event_info=f"Using PostgreSQL schema '{schema_name}'.",
        )
        kwargs["connect_args"] = {"options": f"-csearch_path={schema_name}"}

    # pool_size and max_overflow can be set to control the number of
    # connections that are kept in the connection pool. Not available
    # for SQLite, and only tested for PostgreSQL. See
    # https://docs.sqlalchemy.org/en/13/core/pooling.html#sqlalchemy.pool.QueuePool
    kwargs["pool_size"] = int(
        os.environ.get(POSTGRESQL_POOL_SIZE, POSTGRESQL_DEFAULT_POOL_SIZE)
    )
    kwargs["max_overflow"] = int(
        os.environ.get(POSTGRESQL_MAX_OVERFLOW, POSTGRESQL_DEFAULT_MAX_OVERFLOW)
    )

    return kwargs
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
def ensure_schema_exists(session: "Session") -> None:
    """Ensure that the requested PostgreSQL schema exists in the database.

    The schema is read from the `POSTGRESQL_SCHEMA` environment variable.
    Nothing is checked if that variable is unset/empty, or if the session's
    engine is not connected to PostgreSQL.

    Args:
        session: Session used to inspect the database.

    Raises:
        `ValueError` if the requested schema does not exist. The exception's
            argument is the schema name.
        RasaException if no engine can be obtained from session.
    """
    schema_name = os.environ.get(POSTGRESQL_SCHEMA)

    if not schema_name:
        return

    engine = session.get_bind()

    if not isinstance(engine, sa.engine.base.Engine):
        # The "bind" is usually an instance of Engine, except in the case
        # where the session has been explicitly bound directly to a connection.
        raise RasaException("Cannot ensure schema exists as no engine exists.")

    if not is_postgresql_url(engine.url):
        return

    # Bind the schema name as a parameter instead of formatting it into the
    # SQL text, so names containing quotes cannot break or alter the query.
    query = sa.exists(
        sa.select(sa.text("schema_name"))
        .select_from(sa.text("information_schema.schemata"))
        .where(
            sa.text("schema_name = :schema_name").bindparams(schema_name=schema_name)
        )
    )
    if not session.query(query).scalar():
        raise ValueError(schema_name)
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
def validate_port(port: Any) -> Optional[int]:
    """Ensure that port can be converted to integer.

    Args:
        port: Port value, e.g. an `int`, a numeric string such as `"5432"`,
            or `None`.

    Returns:
        `None` if `port` is `None`, `port` itself if it already is an `int`,
        otherwise `int(port)`.

    Raises:
        RasaException if port cannot be cast to integer.
    """
    if port is None or isinstance(port, int):
        return port

    try:
        return int(port)
    except (TypeError, ValueError) as e:
        # `int()` raises `TypeError` (not `ValueError`) for objects such as
        # lists or dicts; both must surface as the documented RasaException.
        raise RasaException(f"The port '{port}' cannot be cast to integer.") from e
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
class SQLTrackerStore(TrackerStore, SerializedTrackerAsText):
    """Store which can save and retrieve trackers from an SQL database."""

    from sqlalchemy.orm import DeclarativeBase

    class Base(DeclarativeBase):
        """Base class for all tracker store tables."""

        pass

    class SQLEvent(Base):
        """Represents an event in the SQL Tracker Store."""

        __tablename__ = "events"

        # `create_sequence` is needed to create a sequence for databases that
        # don't autoincrement Integer primary keys (e.g. Oracle)
        id = sa.Column(sa.Integer, _create_sequence(__tablename__), primary_key=True)
        sender_id = sa.Column(sa.String(255), nullable=False, index=True)
        type_name = sa.Column(sa.String(255), nullable=False)
        timestamp = sa.Column(sa.Float)
        intent_name = sa.Column(sa.String(255))
        action_name = sa.Column(sa.String(255))
        # JSON dump of `event.as_dict()`: written in `save`, read in `_retrieve`.
        data = sa.Column(sa.Text)

    def __init__(
        self,
        domain: Optional[Domain] = None,
        dialect: Text = "sqlite",
        host: Optional[Text] = None,
        port: Optional[int] = None,
        db: Text = "rasa.db",
        username: Optional[Text] = None,
        password: Optional[Text] = None,
        event_broker: Optional[EventBroker] = None,
        login_db: Optional[Text] = None,
        query: Optional[Dict] = None,
        **kwargs: Dict[Text, Any],
    ) -> None:
        """Create an `SQLTrackerStore` and make sure its tables exist.

        Blocks until the database can be reached: connection errors
        (`OperationalError` / `IntegrityError`) are logged and retried every
        5 seconds.

        Args:
            domain: Domain used to recreate trackers from stored events.
            dialect: SQLAlchemy dialect / driver name, e.g. `sqlite`.
            host: Database host (optionally including a port), or a complete
                database URL containing `://`.
            port: Database port; anything that can be cast to `int`.
            db: Name of the database holding the tracker tables.
            username: User name to use when connecting to the database.
            password: Password for the database user.
            event_broker: Optional event broker, passed on to `TrackerStore`.
            login_db: Database to connect to initially in order to create
                `db` (PostgreSQL only).
            query: Options passed to the dialect and/or the DBAPI on connect.
            **kwargs: Passed on to `TrackerStore.__init__`.

        Raises:
            RasaException: If `port` cannot be cast to an integer.
        """
        import sqlalchemy.exc

        port = validate_port(port)

        engine_url = self.get_db_url(
            dialect, host, port, db, username, password, login_db, query
        )

        self.engine = sa.create_engine(engine_url, **create_engine_kwargs(engine_url))

        structlogger.debug(
            "sql_tracker_store.connect_to_sql_database",
            event_info=f"Attempting to connect to database via '{self.engine.url!r}'.",
        )

        # Database might take a while to come up
        while True:
            try:
                # if `login_db` has been provided, use current channel with
                # that database to create working database `db`
                if login_db:
                    self._create_database_and_update_engine(db, engine_url)

                try:
                    self.Base.metadata.create_all(self.engine)
                except (
                    sqlalchemy.exc.OperationalError,
                    sqlalchemy.exc.ProgrammingError,
                ) as e:
                    # Several Rasa services started in parallel may attempt to
                    # create tables at the same time. That is okay so long as
                    # the first services finishes the table creation.
                    structlogger.error(
                        "sql_tracker_store.create_tables_failed",
                        event_info="Could not create tables",
                        exc_info=e,
                    )

                self.sessionmaker = sa.orm.session.sessionmaker(bind=self.engine)
                break
            except (
                sqlalchemy.exc.OperationalError,
                sqlalchemy.exc.IntegrityError,
            ) as error:
                structlogger.warning(
                    "sql_tracker_store.initialisation_error",
                    event_info="Failed to establish a connection to the SQL database. ",
                    exc_info=error,
                )
                sleep(5)

        structlogger.debug(
            "sql_tracker_store.connected_to_sql_database",
            event_info=f"Connection to SQL database '{db}' successful.",
        )

        super().__init__(domain, event_broker, **kwargs)

    @staticmethod
    def get_db_url(
        dialect: Text = "sqlite",
        host: Optional[Text] = None,
        port: Optional[int] = None,
        db: Text = "rasa.db",
        username: Optional[Text] = None,
        password: Optional[Text] = None,
        login_db: Optional[Text] = None,
        query: Optional[Dict] = None,
    ) -> Union[Text, "URL"]:
        """Build an SQLAlchemy `URL` object.

        The URL object represents the parameters needed to connect to an
        SQL database. If `host` already is a complete URL (contains `://`),
        it is returned unchanged as a string and all other arguments are
        ignored.

        Args:
            dialect: SQL database type.
            host: Database network host.
            port: Database network port.
            db: Database name.
            username: User name to use when connecting to the database.
            password: Password for database user.
            login_db: Alternative database name to which initially connect, and create
                the database specified by `db` (PostgreSQL only).
            query: Dictionary of options to be passed to the dialect and/or the
                DBAPI upon connect.

        Returns:
            URL ready to be used with an SQLAlchemy `Engine` object.
        """
        from urllib import parse

        # Users might specify a url in the host
        if host and "://" in host:
            # assumes this is a complete database host name including
            # e.g. `postgres://...`
            return host
        elif host:
            # add fake scheme to properly parse components
            parsed = parse.urlsplit(f"scheme://{host}")

            # users might include the port in the url
            port = parsed.port or port
            host = parsed.hostname or host

        if not query:
            # query needs to be set in order to create a URL
            query = {}

        # `URL.create` is SQLAlchemy's public constructor; it validates the
        # components and converts `query` into the immutable mapping `URL`
        # expects (calling `URL(...)` directly bypasses that).
        return sa.engine.url.URL.create(
            drivername=dialect,
            username=username,
            password=password,
            host=host,
            port=port,
            database=login_db if login_db else db,
            query=query,
        )

    def _create_database_and_update_engine(
        self, db: Text, engine_url: Union[Text, "URL"]
    ) -> None:
        """Creates database `db` and updates engine accordingly.

        Only supported for PostgreSQL; for other dialects a warning is raised
        and the current engine is kept.

        Args:
            db: Name of the working database to create (if missing) and use.
            engine_url: URL the current engine was created from (the one
                pointing at `login_db`).
        """
        if self.engine.dialect.name != "postgresql":
            rasa.shared.utils.io.raise_warning(
                "The parameter 'login_db' can only be used with a postgres database."
            )
            return

        self._create_database(self.engine, db)
        self.engine.dispose()

        # Same connection parameters, but pointing at the working database.
        # `make_url` also accepts a URL string (e.g. a full URL given as `host`).
        db_url = sa.engine.url.make_url(engine_url).set(database=db)
        # Re-apply the engine options (schema search path, pool sizing) so the
        # working engine is configured like the initial one.
        self.engine = sa.create_engine(db_url, **create_engine_kwargs(db_url))

    @staticmethod
    def _create_database(engine: "Engine", database_name: Text) -> None:
        """Create database `database_name` on `engine` if it does not exist.

        Failures while creating the database are logged, not raised.
        """
        import sqlalchemy.exc

        with engine.connect() as connection:
            # PostgreSQL refuses to run `CREATE DATABASE` inside a transaction.
            connection.execution_options(isolation_level="AUTOCOMMIT")
            # Bind the name as a parameter and check for an actual row instead of
            # relying on `rowcount`, which DBAPIs need not report for SELECTs.
            database_exists = (
                connection.execute(
                    sa.text(
                        "SELECT 1 FROM pg_catalog.pg_database WHERE datname = :name"
                    ),
                    {"name": database_name},
                ).first()
                is not None
            )

            if not database_exists:
                # Identifiers cannot be bound as parameters; quote them instead
                # so the created name matches the one looked up above.
                quoted_name = connection.dialect.identifier_preparer.quote(
                    database_name
                )
                try:
                    connection.execute(sa.text(f"CREATE DATABASE {quoted_name}"))
                except (
                    sqlalchemy.exc.ProgrammingError,
                    sqlalchemy.exc.IntegrityError,
                ) as e:
                    structlogger.error(
                        "sql_tracker_store.create_database_failed",
                        event_info=f"Could not create database '{database_name}'",
                        exc_info=e,
                    )

    @contextlib.contextmanager
    def session_scope(self) -> Generator["Session", None, None]:
        """Provide a transactional scope around a series of operations.

        Exits the application if the configured PostgreSQL schema is missing.
        The session is always closed when the scope ends.
        """
        session = self.sessionmaker()
        try:
            # Only the schema check may signal "schema missing" via ValueError;
            # ValueErrors raised by the caller's code inside the `with` block
            # (e.g. JSON decoding errors) must propagate unchanged.
            try:
                ensure_schema_exists(session)
            except ValueError as e:
                rasa.shared.utils.cli.print_error_and_exit(
                    f"Requested PostgreSQL schema '{e}' was not found in the database. To "
                    f"continue, please create the schema by running 'CREATE SCHEMA {e};' "
                    f"or unset the '{POSTGRESQL_SCHEMA}' environment variable in order to "
                    f"use the default schema. Exiting application."
                )
            yield session
        finally:
            session.close()

    async def keys(self) -> Iterable[Text]:
        """Returns sender_ids of the SQLTrackerStore."""
        with self.session_scope() as session:
            sender_ids = session.query(self.SQLEvent.sender_id).distinct().all()
            return [sender_id for (sender_id,) in sender_ids]

    async def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
        """Retrieves tracker for the latest conversation session."""
        return await self._retrieve(sender_id, fetch_events_from_all_sessions=False)

    async def retrieve_full_tracker(
        self, conversation_id: Text
    ) -> Optional[DialogueStateTracker]:
        """Fetching all tracker events across conversation sessions."""
        return await self._retrieve(
            conversation_id, fetch_events_from_all_sessions=True
        )

    async def count_conversations(self, after_timestamp: float = 0.0) -> int:
        """Returns the number of conversations that have occurred after a timestamp.

        By default, this method returns the number of conversations that
        have occurred after the Unix epoch (i.e. timestamp 0).
        """
        with self.session_scope() as session:
            query = (
                session.query(self.SQLEvent.sender_id)
                .distinct()
                .filter(self.SQLEvent.timestamp >= after_timestamp)
            )
            return query.count()

    async def _retrieve(
        self, sender_id: Text, fetch_events_from_all_sessions: bool
    ) -> Optional[DialogueStateTracker]:
        """Load the stored events of `sender_id` and rebuild the tracker.

        Returns `None` if no domain is set or no events are stored.
        """
        with self.session_scope() as session:
            serialised_events = self._event_query(
                session,
                sender_id,
                fetch_events_from_all_sessions=fetch_events_from_all_sessions,
            ).all()

            events = [json.loads(event.data) for event in serialised_events]

            if self.domain and len(events) > 0:
                structlogger.debug(
                    "sql_tracker_store.recreating_tracker",
                    event_info=f"Recreating tracker from sender id '{sender_id}'",
                )
                return DialogueStateTracker.from_dict(
                    sender_id, events, self.domain.slots
                )
            else:
                structlogger.debug(
                    "sql_tracker_store._retrieve.no_tracker_for_sender_id",
                    event_info=(
                        f"Can't retrieve tracker matching "
                        f"sender id '{sender_id}' from SQL storage. "
                        f"Returning `None` instead."
                    ),
                )
                return None

    def _event_query(
        self, session: "Session", sender_id: Text, fetch_events_from_all_sessions: bool
    ) -> "Query":
        """Provide the query to retrieve the conversation events for a specific sender.

        The events are ordered by ID to ensure correct sequence of events.
        As `timestamp` is not guaranteed to be unique and low-precision (float), it
        cannot be used to order the events.

        Args:
            session: Current database session.
            sender_id: Sender id whose conversation events should be retrieved.
            fetch_events_from_all_sessions: Whether to fetch events from all
                conversation sessions. If `False`, only fetch events from the
                latest conversation session.

        Returns:
            Query to get the conversation events.
        """
        # Subquery to find the timestamp of the latest `SessionStarted` event
        session_start_sub_query = (
            session.query(sa.func.max(self.SQLEvent.timestamp).label("session_start"))
            .filter(
                self.SQLEvent.sender_id == sender_id,
                self.SQLEvent.type_name == SessionStarted.type_name,
            )
            .subquery()
        )

        event_query = session.query(self.SQLEvent).filter(
            self.SQLEvent.sender_id == sender_id
        )
        if not fetch_events_from_all_sessions:
            event_query = event_query.filter(
                # Find events after the latest `SessionStarted` event or return all
                # events
                sa.or_(
                    self.SQLEvent.timestamp >= session_start_sub_query.c.session_start,
                    session_start_sub_query.c.session_start.is_(None),
                )
            )

        return event_query.order_by(self.SQLEvent.id)

    async def save(self, tracker: DialogueStateTracker) -> None:
        """Update database with events from the current conversation."""
        await self.stream_events(tracker)

        with self.session_scope() as session:
            # only store recent events
            events = self._additional_events(session, tracker)

            for event in events:
                data = event.as_dict()
                intent = (
                    data.get("parse_data", {}).get("intent", {}).get(INTENT_NAME_KEY)
                )
                action = data.get("name")
                timestamp = data.get("timestamp")

                # noinspection PyArgumentList
                session.add(
                    self.SQLEvent(
                        sender_id=tracker.sender_id,
                        type_name=event.type_name,
                        timestamp=timestamp,
                        intent_name=intent,
                        action_name=action,
                        data=json.dumps(data),
                    )
                )
            session.commit()

        structlogger.debug(
            "sql_tracker_store.save_tracker",
            event_info=(
                f"Tracker with sender_id '{tracker.sender_id}' stored to database"
            ),
        )

    def _additional_events(
        self, session: "Session", tracker: DialogueStateTracker
    ) -> Iterator:
        """Return events from the tracker which aren't currently stored.

        Events of the latest conversation session already in the database are
        counted, and that many leading events of `tracker.events` are skipped.
        """
        number_of_events_since_last_session = self._event_query(
            session, tracker.sender_id, fetch_events_from_all_sessions=False
        ).count()

        return itertools.islice(
            tracker.events, number_of_events_since_last_session, len(tracker.events)
        )
|
|
1497
|
-
|
|
1498
|
-
|
|
1499
|
-
class FailSafeTrackerStore(TrackerStore):
    """Tracker store wrapper.

    Allows a fallback to a different tracker store in case of errors.
    """

    def __init__(
        self,
        tracker_store: TrackerStore,
        on_tracker_store_error: Optional[Callable[[Exception], None]] = None,
        fallback_tracker_store: Optional[TrackerStore] = None,
    ) -> None:
        """Create a `FailSafeTrackerStore`.

        Args:
            tracker_store: Primary tracker store.
            on_tracker_store_error: Callback which is called when there is an error
                in the primary tracker store.
            fallback_tracker_store: Fallback tracker store. If not given, an
                `InMemoryTrackerStore` is created lazily on first use.
        """
        self._fallback_tracker_store: Optional[TrackerStore] = fallback_tracker_store
        self._tracker_store = tracker_store
        self._on_tracker_store_error = on_tracker_store_error

        super().__init__(tracker_store.domain, tracker_store.event_broker)

    @property
    def domain(self) -> Domain:
        """Returns the domain of the primary tracker store."""
        return self._tracker_store.domain

    @domain.setter
    def domain(self, domain: Domain) -> None:
        """Set `domain` on the primary and, if already present, the fallback store."""
        self._tracker_store.domain = domain

        if self._fallback_tracker_store is not None:
            self._fallback_tracker_store.domain = domain

    @property
    def fallback_tracker_store(self) -> TrackerStore:
        """Returns the fallback tracker store.

        An `InMemoryTrackerStore` sharing the primary store's domain and event
        broker is created on first access if no fallback was provided.
        """
        if self._fallback_tracker_store is None:
            self._fallback_tracker_store = InMemoryTrackerStore(
                self._tracker_store.domain, self._tracker_store.event_broker
            )

        return self._fallback_tracker_store

    def on_tracker_store_error(self, error: Exception) -> None:
        """Calls the callback when there is an error in the primary tracker store.

        Without a callback, the error is logged instead.
        """
        if self._on_tracker_store_error:
            self._on_tracker_store_error(error)
        else:
            structlogger.error(
                "fail_safe_tracker_store.tracker_store_error",
                event_info=(
                    f"Error happened when trying to save conversation tracker to "
                    f"'{self._tracker_store.__class__.__name__}'. Falling back to use "
                    f"the '{InMemoryTrackerStore.__name__}'. Please "
                    f"investigate the following error: {error}."
                ),
            )

    async def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
        """Calls `retrieve` method of primary tracker store.

        Returns `None` if the primary tracker store raises.
        """
        try:
            return await self._tracker_store.retrieve(sender_id)
        except Exception as e:
            self.on_tracker_store_retrieve_error(e)
            return None

    async def keys(self) -> Iterable[Text]:
        """Calls `keys` method of primary tracker store.

        Returns an empty list if the primary tracker store raises.
        """
        try:
            return await self._tracker_store.keys()
        except Exception as e:
            self.on_tracker_store_error(e)
            return []

    async def save(self, tracker: DialogueStateTracker) -> None:
        """Calls `save` method of primary tracker store.

        If the primary tracker store raises, the tracker is saved to the
        fallback tracker store instead.
        """
        try:
            await self._tracker_store.save(tracker)
        except Exception as e:
            self.on_tracker_store_error(e)
            await self.fallback_tracker_store.save(tracker)

    async def retrieve_full_tracker(
        self, sender_id: Text
    ) -> Optional[DialogueStateTracker]:
        """Calls `retrieve_full_tracker` method of primary tracker store.

        Args:
            sender_id: The sender id of the tracker to retrieve.

        Returns:
            The tracker, or `None` if the primary tracker store raises.
        """
        try:
            return await self._tracker_store.retrieve_full_tracker(sender_id)
        except Exception as e:
            self.on_tracker_store_retrieve_error(e)
            return None

    def on_tracker_store_retrieve_error(self, error: Exception) -> None:
        """Calls `_on_tracker_store_error` callable attribute if set.

        Otherwise, logs the error.

        Args:
            error: The error that occurred.
        """
        if self._on_tracker_store_error:
            self._on_tracker_store_error(error)
        else:
            structlogger.error(
                "fail_safe_tracker_store.tracker_store_retrieve_error",
                event_info=(
                    f"Error happened when trying to retrieve conversation tracker from "
                    f"'{self._tracker_store.__class__.__name__}'. Falling back to use "
                    f"the '{InMemoryTrackerStore.__name__}'."
                ),
                exc_info=error,
            )
|
|
1620
|
-
|
|
1621
|
-
|
|
1622
|
-
def _create_from_endpoint_config(
    endpoint_config: Optional[EndpointConfig] = None,
    domain: Optional[Domain] = None,
    event_broker: Optional[EventBroker] = None,
) -> TrackerStore:
    """Given an endpoint configuration, create a proper tracker store object.

    The built-in types `redis`, `mongod`, `sql` and `dynamo` are matched
    case-insensitively. Any other type is treated as a module path to a
    custom tracker store class. Without a config (or without a type), an
    `InMemoryTrackerStore` is used.

    Args:
        endpoint_config: Tracker store endpoint configuration.
        domain: Domain for the tracker store; an empty domain if not given.
        event_broker: Optional event broker passed to the tracker store.

    Returns:
        The created tracker store.
    """
    domain = domain or Domain.empty()

    if endpoint_config is None or endpoint_config.type is None:
        # default tracker store if no type is set
        tracker_store: TrackerStore = InMemoryTrackerStore(domain, event_broker)
    elif endpoint_config.type.lower() == "redis":
        tracker_store = RedisTrackerStore(
            domain=domain,
            host=endpoint_config.url,
            event_broker=event_broker,
            **endpoint_config.kwargs,
        )
    elif endpoint_config.type.lower() == "mongod":
        tracker_store = MongoTrackerStore(
            domain=domain,
            host=endpoint_config.url,
            event_broker=event_broker,
            **endpoint_config.kwargs,
        )
    elif endpoint_config.type.lower() == "sql":
        tracker_store = SQLTrackerStore(
            domain=domain,
            host=endpoint_config.url,
            event_broker=event_broker,
            **endpoint_config.kwargs,
        )
    elif endpoint_config.type.lower() == "dynamo":
        tracker_store = DynamoTrackerStore(
            domain=domain, event_broker=event_broker, **endpoint_config.kwargs
        )
    else:
        tracker_store = _load_from_module_name_in_endpoint_config(
            domain, endpoint_config, event_broker
        )

    structlogger.debug(
        "tracker_store.create_tracker_store_from_endpoint_config",
        event_info=f"Connected to {tracker_store.__class__.__name__}.",
    )

    return tracker_store
|
|
1669
|
-
|
|
1670
|
-
|
|
1671
|
-
def _load_from_module_name_in_endpoint_config(
    domain: Domain, store: EndpointConfig, event_broker: Optional[EventBroker] = None
) -> TrackerStore:
    """Initializes a custom tracker.

    Defaults to the InMemoryTrackerStore if the module path can not be found.

    Args:
        domain: defines the universe in which the assistant operates
        store: the specific tracker store
        event_broker: an event broker to publish events

    Returns:
        a tracker store from a specified type in a stores endpoint configuration
    """
    try:
        tracker_store_class = rasa.shared.utils.common.class_from_module_path(
            store.type
        )

        return tracker_store_class(
            host=store.url, domain=domain, event_broker=event_broker, **store.kwargs
        )
    except (AttributeError, ImportError):
        rasa.shared.utils.io.raise_warning(
            f"Tracker store with type '{store.type}' not found. "
            f"Using `InMemoryTrackerStore` instead."
        )
        # Keep the event broker so events are still published when falling back,
        # matching the default in-memory store created for configs without a type.
        return InMemoryTrackerStore(domain, event_broker)
|
|
1700
|
-
|
|
1701
|
-
|
|
1702
|
-
def create_tracker_store(
    endpoint_config: Optional[EndpointConfig],
    domain: Optional[Domain] = None,
    event_broker: Optional[EventBroker] = None,
) -> TrackerStore:
    """Build the tracker store described by `endpoint_config`.

    Tracker stores whose methods are not asynchronous are still accepted:
    a deprecation warning is raised and the store is wrapped in an
    `AwaitableTrackerStore`.

    Args:
        endpoint_config: Tracker store endpoint configuration (may be `None`).
        domain: Optional domain for the tracker store.
        event_broker: Optional event broker for the tracker store.

    Returns:
        The (possibly wrapped) tracker store.
    """
    store = _create_from_endpoint_config(endpoint_config, domain, event_broker)

    if check_if_tracker_store_async(store):
        return store

    rasa.shared.utils.io.raise_deprecation_warning(
        f"Tracker store implementation "
        f"{store.__class__.__name__} "
        f"is not asynchronous. Non-asynchronous tracker stores "
        f"are currently deprecated and will be removed in 4.0. "
        f"Please make the following methods async: "
        f"{_get_async_tracker_store_methods()}"
    )
    return AwaitableTrackerStore(store)
|
|
1722
|
-
|
|
1723
|
-
|
|
1724
|
-
class AwaitableTrackerStore(TrackerStore):
|
|
1725
|
-
"""Wraps a tracker store so it can be implemented with async overrides."""
|
|
1726
|
-
|
|
1727
|
-
def __init__(
    self,
    tracker_store: TrackerStore,
) -> None:
    """Wrap `tracker_store` so its methods can always be awaited.

    Args:
        tracker_store: the wrapped tracker store.
    """
    # NOTE(review): kept before the base initialiser on purpose; presumably
    # `TrackerStore.__init__` assigns `self.domain`, whose setter writes to
    # `self._tracker_store` — confirm against the base class.
    self._tracker_store = tracker_store

    wrapped_domain = tracker_store.domain
    wrapped_broker = tracker_store.event_broker
    super().__init__(wrapped_domain, wrapped_broker)
|
|
1739
|
-
|
|
1740
|
-
@property
def domain(self) -> Domain:
    """Domain of the wrapped (primary) tracker store."""
    wrapped = self._tracker_store
    return wrapped.domain
|
|
1744
|
-
|
|
1745
|
-
@domain.setter
def domain(self, domain: Optional[Domain]) -> None:
    """Forward `domain` to the wrapped store, substituting an empty domain if falsy."""
    effective_domain = domain if domain else Domain.empty()
    self._tracker_store.domain = effective_domain
|
|
1749
|
-
|
|
1750
|
-
@staticmethod
def create(
    obj: Union[TrackerStore, EndpointConfig, None],
    domain: Optional[Domain] = None,
    event_broker: Optional[EventBroker] = None,
) -> TrackerStore:
    """Wrap a tracker store, or build one from an endpoint config and wrap it.

    Args:
        obj: Tracker store to wrap, or endpoint configuration describing the
            tracker store to create.
        domain: Domain for the tracker store created from an endpoint config.
            Ignored when `obj` already is a tracker store.
        event_broker: Event broker for the tracker store created from an
            endpoint config. Ignored when `obj` already is a tracker store.

    Returns:
        An `AwaitableTrackerStore` wrapping the tracker store.

    Raises:
        ValueError: If `obj` is neither a `TrackerStore` nor an `EndpointConfig`.
    """
    if isinstance(obj, TrackerStore):
        return AwaitableTrackerStore(obj)
    elif isinstance(obj, EndpointConfig):
        # Forward `domain` and `event_broker`; they used to be accepted here
        # but silently dropped when building the store.
        return AwaitableTrackerStore(
            _create_from_endpoint_config(obj, domain, event_broker)
        )
    else:
        raise ValueError(
            f"{type(obj).__name__} supplied "
            f"but expected object of type {TrackerStore.__name__} or "
            f"of type {EndpointConfig.__name__}."
        )
|
|
1767
|
-
|
|
1768
|
-
async def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
    """Delegate `retrieve` to the wrapped store, awaiting only if it is async."""
    outcome = self._tracker_store.retrieve(sender_id)
    if isawaitable(outcome):
        outcome = await outcome  # type: ignore[misc]
    return outcome  # type: ignore[return-value]
|
|
1774
|
-
|
|
1775
|
-
async def keys(self) -> Iterable[Text]:
    """Delegate `keys` to the wrapped store, awaiting only if it is async."""
    sender_ids = self._tracker_store.keys()
    if not isawaitable(sender_ids):
        return sender_ids
    return await sender_ids
|
|
1779
|
-
|
|
1780
|
-
async def save(self, tracker: DialogueStateTracker) -> None:
    """Delegate `save` to the wrapped store, awaiting only if it is async."""
    pending = self._tracker_store.save(tracker)
    if not isawaitable(pending):
        return pending
    return await pending
|
|
1784
|
-
|
|
1785
|
-
async def retrieve_full_tracker(
    self, conversation_id: Text
) -> Optional[DialogueStateTracker]:
    """Delegate `retrieve_full_tracker` to the wrapped store, awaiting if async."""
    outcome = self._tracker_store.retrieve_full_tracker(conversation_id)
    if isawaitable(outcome):
        outcome = await outcome  # type: ignore[misc]
    return outcome  # type: ignore[return-value]
|