rasa-pro 3.12.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +41 -0
- rasa/__init__.py +9 -0
- rasa/__main__.py +177 -0
- rasa/anonymization/__init__.py +2 -0
- rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
- rasa/anonymization/anonymization_pipeline.py +286 -0
- rasa/anonymization/anonymization_rule_executor.py +260 -0
- rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
- rasa/anonymization/schemas/config.yml +47 -0
- rasa/anonymization/utils.py +118 -0
- rasa/api.py +160 -0
- rasa/cli/__init__.py +5 -0
- rasa/cli/arguments/__init__.py +0 -0
- rasa/cli/arguments/data.py +106 -0
- rasa/cli/arguments/default_arguments.py +207 -0
- rasa/cli/arguments/evaluate.py +65 -0
- rasa/cli/arguments/export.py +51 -0
- rasa/cli/arguments/interactive.py +74 -0
- rasa/cli/arguments/run.py +219 -0
- rasa/cli/arguments/shell.py +17 -0
- rasa/cli/arguments/test.py +211 -0
- rasa/cli/arguments/train.py +279 -0
- rasa/cli/arguments/visualize.py +34 -0
- rasa/cli/arguments/x.py +30 -0
- rasa/cli/data.py +354 -0
- rasa/cli/dialogue_understanding_test.py +251 -0
- rasa/cli/e2e_test.py +259 -0
- rasa/cli/evaluate.py +222 -0
- rasa/cli/export.py +250 -0
- rasa/cli/inspect.py +75 -0
- rasa/cli/interactive.py +166 -0
- rasa/cli/license.py +65 -0
- rasa/cli/llm_fine_tuning.py +403 -0
- rasa/cli/markers.py +78 -0
- rasa/cli/project_templates/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/action_template.py +27 -0
- rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
- rasa/cli/project_templates/calm/actions/db.py +57 -0
- rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
- rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
- rasa/cli/project_templates/calm/config.yml +10 -0
- rasa/cli/project_templates/calm/credentials.yml +33 -0
- rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
- rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
- rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
- rasa/cli/project_templates/calm/db/contacts.json +10 -0
- rasa/cli/project_templates/calm/domain/add_contact.yml +39 -0
- rasa/cli/project_templates/calm/domain/list_contacts.yml +17 -0
- rasa/cli/project_templates/calm/domain/remove_contact.yml +38 -0
- rasa/cli/project_templates/calm/domain/shared.yml +10 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
- rasa/cli/project_templates/calm/endpoints.yml +58 -0
- rasa/cli/project_templates/default/actions/__init__.py +0 -0
- rasa/cli/project_templates/default/actions/actions.py +27 -0
- rasa/cli/project_templates/default/config.yml +44 -0
- rasa/cli/project_templates/default/credentials.yml +33 -0
- rasa/cli/project_templates/default/data/nlu.yml +91 -0
- rasa/cli/project_templates/default/data/rules.yml +13 -0
- rasa/cli/project_templates/default/data/stories.yml +30 -0
- rasa/cli/project_templates/default/domain.yml +34 -0
- rasa/cli/project_templates/default/endpoints.yml +42 -0
- rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
- rasa/cli/project_templates/tutorial/actions/__init__.py +0 -0
- rasa/cli/project_templates/tutorial/actions/actions.py +22 -0
- rasa/cli/project_templates/tutorial/config.yml +12 -0
- rasa/cli/project_templates/tutorial/credentials.yml +33 -0
- rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
- rasa/cli/project_templates/tutorial/data/patterns.yml +11 -0
- rasa/cli/project_templates/tutorial/domain.yml +35 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +55 -0
- rasa/cli/run.py +143 -0
- rasa/cli/scaffold.py +273 -0
- rasa/cli/shell.py +141 -0
- rasa/cli/studio/__init__.py +0 -0
- rasa/cli/studio/download.py +62 -0
- rasa/cli/studio/studio.py +296 -0
- rasa/cli/studio/train.py +59 -0
- rasa/cli/studio/upload.py +62 -0
- rasa/cli/telemetry.py +102 -0
- rasa/cli/test.py +280 -0
- rasa/cli/train.py +278 -0
- rasa/cli/utils.py +484 -0
- rasa/cli/visualize.py +40 -0
- rasa/cli/x.py +206 -0
- rasa/constants.py +45 -0
- rasa/core/__init__.py +17 -0
- rasa/core/actions/__init__.py +0 -0
- rasa/core/actions/action.py +1318 -0
- rasa/core/actions/action_clean_stack.py +59 -0
- rasa/core/actions/action_exceptions.py +24 -0
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/actions/action_repeat_bot_messages.py +89 -0
- rasa/core/actions/action_run_slot_rejections.py +210 -0
- rasa/core/actions/action_trigger_chitchat.py +31 -0
- rasa/core/actions/action_trigger_flow.py +109 -0
- rasa/core/actions/action_trigger_search.py +31 -0
- rasa/core/actions/constants.py +5 -0
- rasa/core/actions/custom_action_executor.py +191 -0
- rasa/core/actions/direct_custom_actions_executor.py +109 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +72 -0
- rasa/core/actions/forms.py +741 -0
- rasa/core/actions/grpc_custom_action_executor.py +251 -0
- rasa/core/actions/http_custom_action_executor.py +145 -0
- rasa/core/actions/loops.py +114 -0
- rasa/core/actions/two_stage_fallback.py +186 -0
- rasa/core/agent.py +559 -0
- rasa/core/auth_retry_tracker_store.py +122 -0
- rasa/core/brokers/__init__.py +0 -0
- rasa/core/brokers/broker.py +126 -0
- rasa/core/brokers/file.py +58 -0
- rasa/core/brokers/kafka.py +324 -0
- rasa/core/brokers/pika.py +388 -0
- rasa/core/brokers/sql.py +86 -0
- rasa/core/channels/__init__.py +61 -0
- rasa/core/channels/botframework.py +338 -0
- rasa/core/channels/callback.py +84 -0
- rasa/core/channels/channel.py +456 -0
- rasa/core/channels/console.py +241 -0
- rasa/core/channels/development_inspector.py +197 -0
- rasa/core/channels/facebook.py +419 -0
- rasa/core/channels/hangouts.py +329 -0
- rasa/core/channels/inspector/.eslintrc.cjs +25 -0
- rasa/core/channels/inspector/.gitignore +23 -0
- rasa/core/channels/inspector/README.md +54 -0
- rasa/core/channels/inspector/assets/favicon.ico +0 -0
- rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
- rasa/core/channels/inspector/custom.d.ts +3 -0
- rasa/core/channels/inspector/dist/assets/arc-861ddd57.js +1 -0
- rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
- rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-921f02db.js +10 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-b436c4f8.js +2 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-511a23cb.js +2 -0
- rasa/core/channels/inspector/dist/assets/createText-62fc7601-ef476ecd.js +7 -0
- rasa/core/channels/inspector/dist/assets/edges-f2ad444c-f1878e0a.js +4 -0
- rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-fac75185.js +51 -0
- rasa/core/channels/inspector/dist/assets/flowDb-1972c806-201c5bbc.js +6 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-f904ae41.js +4 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-b080d6f2.js +1 -0
- rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-1813da66.js +139 -0
- rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-872af172.js +266 -0
- rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-34a0af5a.js +70 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-42ba3e3d.js +1 -0
- rasa/core/channels/inspector/dist/assets/index-37817b51.js +1317 -0
- rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
- rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-6b731386.js +7 -0
- rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
- rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-e8579ac6.js +139 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
- rasa/core/channels/inspector/dist/assets/layout-89e6403a.js +1 -0
- rasa/core/channels/inspector/dist/assets/line-dc73d3fc.js +1 -0
- rasa/core/channels/inspector/dist/assets/linear-f5b1d2bc.js +1 -0
- rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-82cb74fa.js +109 -0
- rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
- rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
- rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-bdf5f29b.js +35 -0
- rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-c7a0cbe4.js +7 -0
- rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-7ec5410f.js +52 -0
- rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-caee5554.js +8 -0
- rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-2935f8db.js +122 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-8f5d9693.js +1 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-d565d1de.js +1 -0
- rasa/core/channels/inspector/dist/assets/styles-080da4f6-75ad421d.js +110 -0
- rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-7e764226.js +159 -0
- rasa/core/channels/inspector/dist/assets/styles-9c745c82-7a4e0e61.js +207 -0
- rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-4019d1bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-01ea12df.js +61 -0
- rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-89407137.js +7 -0
- rasa/core/channels/inspector/dist/index.html +42 -0
- rasa/core/channels/inspector/index.html +40 -0
- rasa/core/channels/inspector/jest.config.ts +13 -0
- rasa/core/channels/inspector/package.json +52 -0
- rasa/core/channels/inspector/setupTests.ts +2 -0
- rasa/core/channels/inspector/src/App.tsx +220 -0
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +108 -0
- rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +136 -0
- rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
- rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +22 -0
- rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
- rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
- rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
- rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
- rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
- rasa/core/channels/inspector/src/helpers/audiostream.ts +191 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +392 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +306 -0
- rasa/core/channels/inspector/src/helpers/utils.ts +127 -0
- rasa/core/channels/inspector/src/main.tsx +13 -0
- rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
- rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
- rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
- rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
- rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
- rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
- rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
- rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
- rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
- rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
- rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
- rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
- rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
- rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
- rasa/core/channels/inspector/src/theme/index.ts +101 -0
- rasa/core/channels/inspector/src/types.ts +84 -0
- rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
- rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
- rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
- rasa/core/channels/inspector/tsconfig.json +26 -0
- rasa/core/channels/inspector/tsconfig.node.json +10 -0
- rasa/core/channels/inspector/vite.config.ts +8 -0
- rasa/core/channels/inspector/yarn.lock +6249 -0
- rasa/core/channels/mattermost.py +229 -0
- rasa/core/channels/rasa_chat.py +126 -0
- rasa/core/channels/rest.py +230 -0
- rasa/core/channels/rocketchat.py +174 -0
- rasa/core/channels/slack.py +620 -0
- rasa/core/channels/socketio.py +302 -0
- rasa/core/channels/telegram.py +298 -0
- rasa/core/channels/twilio.py +169 -0
- rasa/core/channels/vier_cvg.py +374 -0
- rasa/core/channels/voice_ready/__init__.py +0 -0
- rasa/core/channels/voice_ready/audiocodes.py +501 -0
- rasa/core/channels/voice_ready/jambonz.py +121 -0
- rasa/core/channels/voice_ready/jambonz_protocol.py +396 -0
- rasa/core/channels/voice_ready/twilio_voice.py +403 -0
- rasa/core/channels/voice_ready/utils.py +37 -0
- rasa/core/channels/voice_stream/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
- rasa/core/channels/voice_stream/asr/azure.py +130 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
- rasa/core/channels/voice_stream/audio_bytes.py +8 -0
- rasa/core/channels/voice_stream/browser_audio.py +107 -0
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +106 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +427 -0
- rasa/core/channels/webexteams.py +134 -0
- rasa/core/concurrent_lock_store.py +210 -0
- rasa/core/constants.py +112 -0
- rasa/core/evaluation/__init__.py +0 -0
- rasa/core/evaluation/marker.py +267 -0
- rasa/core/evaluation/marker_base.py +923 -0
- rasa/core/evaluation/marker_stats.py +293 -0
- rasa/core/evaluation/marker_tracker_loader.py +103 -0
- rasa/core/exceptions.py +29 -0
- rasa/core/exporter.py +284 -0
- rasa/core/featurizers/__init__.py +0 -0
- rasa/core/featurizers/precomputation.py +410 -0
- rasa/core/featurizers/single_state_featurizer.py +421 -0
- rasa/core/featurizers/tracker_featurizers.py +1262 -0
- rasa/core/http_interpreter.py +89 -0
- rasa/core/information_retrieval/__init__.py +7 -0
- rasa/core/information_retrieval/faiss.py +124 -0
- rasa/core/information_retrieval/information_retrieval.py +137 -0
- rasa/core/information_retrieval/milvus.py +59 -0
- rasa/core/information_retrieval/qdrant.py +96 -0
- rasa/core/jobs.py +63 -0
- rasa/core/lock.py +139 -0
- rasa/core/lock_store.py +343 -0
- rasa/core/migrate.py +403 -0
- rasa/core/nlg/__init__.py +3 -0
- rasa/core/nlg/callback.py +146 -0
- rasa/core/nlg/contextual_response_rephraser.py +320 -0
- rasa/core/nlg/generator.py +230 -0
- rasa/core/nlg/interpolator.py +143 -0
- rasa/core/nlg/response.py +155 -0
- rasa/core/nlg/summarize.py +70 -0
- rasa/core/persistor.py +538 -0
- rasa/core/policies/__init__.py +0 -0
- rasa/core/policies/ensemble.py +329 -0
- rasa/core/policies/enterprise_search_policy.py +905 -0
- rasa/core/policies/enterprise_search_prompt_template.jinja2 +25 -0
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
- rasa/core/policies/flow_policy.py +205 -0
- rasa/core/policies/flows/__init__.py +0 -0
- rasa/core/policies/flows/flow_exceptions.py +44 -0
- rasa/core/policies/flows/flow_executor.py +754 -0
- rasa/core/policies/flows/flow_step_result.py +43 -0
- rasa/core/policies/intentless_policy.py +1031 -0
- rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
- rasa/core/policies/memoization.py +538 -0
- rasa/core/policies/policy.py +725 -0
- rasa/core/policies/rule_policy.py +1273 -0
- rasa/core/policies/ted_policy.py +2169 -0
- rasa/core/policies/unexpected_intent_policy.py +1022 -0
- rasa/core/processor.py +1465 -0
- rasa/core/run.py +342 -0
- rasa/core/secrets_manager/__init__.py +0 -0
- rasa/core/secrets_manager/constants.py +36 -0
- rasa/core/secrets_manager/endpoints.py +391 -0
- rasa/core/secrets_manager/factory.py +241 -0
- rasa/core/secrets_manager/secret_manager.py +262 -0
- rasa/core/secrets_manager/vault.py +584 -0
- rasa/core/test.py +1335 -0
- rasa/core/tracker_store.py +1703 -0
- rasa/core/train.py +105 -0
- rasa/core/training/__init__.py +89 -0
- rasa/core/training/converters/__init__.py +0 -0
- rasa/core/training/converters/responses_prefix_converter.py +119 -0
- rasa/core/training/interactive.py +1744 -0
- rasa/core/training/story_conflict.py +381 -0
- rasa/core/training/training.py +93 -0
- rasa/core/utils.py +366 -0
- rasa/core/visualize.py +70 -0
- rasa/dialogue_understanding/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/constants.py +4 -0
- rasa/dialogue_understanding/coexistence/intent_based_router.py +196 -0
- rasa/dialogue_understanding/coexistence/llm_based_router.py +327 -0
- rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
- rasa/dialogue_understanding/commands/__init__.py +61 -0
- rasa/dialogue_understanding/commands/can_not_handle_command.py +70 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +125 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +44 -0
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/clarify_command.py +86 -0
- rasa/dialogue_understanding/commands/command.py +85 -0
- rasa/dialogue_understanding/commands/correct_slots_command.py +297 -0
- rasa/dialogue_understanding/commands/error_command.py +79 -0
- rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/handle_code_change_command.py +73 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +66 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/noop_command.py +54 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
- rasa/dialogue_understanding/commands/restart_command.py +58 -0
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +160 -0
- rasa/dialogue_understanding/commands/skip_question_command.py +75 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +107 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +45 -0
- rasa/dialogue_understanding/generator/__init__.py +21 -0
- rasa/dialogue_understanding/generator/command_generator.py +464 -0
- rasa/dialogue_understanding/generator/constants.py +27 -0
- rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +466 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +500 -0
- rasa/dialogue_understanding/generator/llm_command_generator.py +67 -0
- rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
- rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +920 -0
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +261 -0
- rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +60 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +486 -0
- rasa/dialogue_understanding/patterns/__init__.py +0 -0
- rasa/dialogue_understanding/patterns/cancel.py +111 -0
- rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
- rasa/dialogue_understanding/patterns/chitchat.py +37 -0
- rasa/dialogue_understanding/patterns/clarify.py +97 -0
- rasa/dialogue_understanding/patterns/code_change.py +41 -0
- rasa/dialogue_understanding/patterns/collect_information.py +90 -0
- rasa/dialogue_understanding/patterns/completed.py +40 -0
- rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
- rasa/dialogue_understanding/patterns/correction.py +278 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +301 -0
- rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
- rasa/dialogue_understanding/patterns/internal_error.py +47 -0
- rasa/dialogue_understanding/patterns/repeat.py +37 -0
- rasa/dialogue_understanding/patterns/restart.py +37 -0
- rasa/dialogue_understanding/patterns/search.py +37 -0
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/patterns/skip_question.py +38 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/__init__.py +0 -0
- rasa/dialogue_understanding/processor/command_processor.py +720 -0
- rasa/dialogue_understanding/processor/command_processor_component.py +43 -0
- rasa/dialogue_understanding/stack/__init__.py +0 -0
- rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
- rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
- rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
- rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
- rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
- rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
- rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
- rasa/dialogue_understanding/stack/utils.py +211 -0
- rasa/dialogue_understanding/utils.py +14 -0
- rasa/dialogue_understanding_test/__init__.py +0 -0
- rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
- rasa/dialogue_understanding_test/constants.py +17 -0
- rasa/dialogue_understanding_test/du_test_case.py +118 -0
- rasa/dialogue_understanding_test/du_test_result.py +11 -0
- rasa/dialogue_understanding_test/du_test_runner.py +93 -0
- rasa/dialogue_understanding_test/io.py +54 -0
- rasa/dialogue_understanding_test/validation.py +22 -0
- rasa/e2e_test/__init__.py +0 -0
- rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
- rasa/e2e_test/assertions.py +1345 -0
- rasa/e2e_test/assertions_schema.yml +129 -0
- rasa/e2e_test/constants.py +31 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +569 -0
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +54 -0
- rasa/e2e_test/e2e_test_runner.py +1192 -0
- rasa/e2e_test/e2e_test_schema.yml +181 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +598 -0
- rasa/e2e_test/utils/validation.py +178 -0
- rasa/engine/__init__.py +0 -0
- rasa/engine/caching.py +463 -0
- rasa/engine/constants.py +17 -0
- rasa/engine/exceptions.py +14 -0
- rasa/engine/graph.py +642 -0
- rasa/engine/loader.py +48 -0
- rasa/engine/recipes/__init__.py +0 -0
- rasa/engine/recipes/config_files/default_config.yml +41 -0
- rasa/engine/recipes/default_components.py +97 -0
- rasa/engine/recipes/default_recipe.py +1272 -0
- rasa/engine/recipes/graph_recipe.py +79 -0
- rasa/engine/recipes/recipe.py +93 -0
- rasa/engine/runner/__init__.py +0 -0
- rasa/engine/runner/dask.py +250 -0
- rasa/engine/runner/interface.py +49 -0
- rasa/engine/storage/__init__.py +0 -0
- rasa/engine/storage/local_model_storage.py +244 -0
- rasa/engine/storage/resource.py +110 -0
- rasa/engine/storage/storage.py +199 -0
- rasa/engine/training/__init__.py +0 -0
- rasa/engine/training/components.py +176 -0
- rasa/engine/training/fingerprinting.py +64 -0
- rasa/engine/training/graph_trainer.py +256 -0
- rasa/engine/training/hooks.py +164 -0
- rasa/engine/validation.py +1451 -0
- rasa/env.py +14 -0
- rasa/exceptions.py +69 -0
- rasa/graph_components/__init__.py +0 -0
- rasa/graph_components/converters/__init__.py +0 -0
- rasa/graph_components/converters/nlu_message_converter.py +48 -0
- rasa/graph_components/providers/__init__.py +0 -0
- rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
- rasa/graph_components/providers/domain_provider.py +71 -0
- rasa/graph_components/providers/flows_provider.py +74 -0
- rasa/graph_components/providers/forms_provider.py +44 -0
- rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
- rasa/graph_components/providers/responses_provider.py +44 -0
- rasa/graph_components/providers/rule_only_provider.py +49 -0
- rasa/graph_components/providers/story_graph_provider.py +96 -0
- rasa/graph_components/providers/training_tracker_provider.py +55 -0
- rasa/graph_components/validators/__init__.py +0 -0
- rasa/graph_components/validators/default_recipe_validator.py +550 -0
- rasa/graph_components/validators/finetuning_validator.py +302 -0
- rasa/hooks.py +111 -0
- rasa/jupyter.py +63 -0
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/markers/__init__.py +0 -0
- rasa/markers/marker.py +269 -0
- rasa/markers/marker_base.py +828 -0
- rasa/markers/upload.py +74 -0
- rasa/markers/validate.py +21 -0
- rasa/model.py +118 -0
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +40 -0
- rasa/model_manager/model_api.py +559 -0
- rasa/model_manager/runner_service.py +286 -0
- rasa/model_manager/socket_bridge.py +146 -0
- rasa/model_manager/studio_jwt_auth.py +86 -0
- rasa/model_manager/trainer_service.py +325 -0
- rasa/model_manager/utils.py +87 -0
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +112 -0
- rasa/model_testing.py +457 -0
- rasa/model_training.py +596 -0
- rasa/nlu/__init__.py +7 -0
- rasa/nlu/classifiers/__init__.py +3 -0
- rasa/nlu/classifiers/classifier.py +5 -0
- rasa/nlu/classifiers/diet_classifier.py +1881 -0
- rasa/nlu/classifiers/fallback_classifier.py +192 -0
- rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
- rasa/nlu/classifiers/logistic_regression_classifier.py +253 -0
- rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
- rasa/nlu/classifiers/regex_message_handler.py +56 -0
- rasa/nlu/classifiers/sklearn_intent_classifier.py +330 -0
- rasa/nlu/constants.py +77 -0
- rasa/nlu/convert.py +40 -0
- rasa/nlu/emulators/__init__.py +0 -0
- rasa/nlu/emulators/dialogflow.py +55 -0
- rasa/nlu/emulators/emulator.py +49 -0
- rasa/nlu/emulators/luis.py +86 -0
- rasa/nlu/emulators/no_emulator.py +10 -0
- rasa/nlu/emulators/wit.py +56 -0
- rasa/nlu/extractors/__init__.py +0 -0
- rasa/nlu/extractors/crf_entity_extractor.py +715 -0
- rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
- rasa/nlu/extractors/entity_synonyms.py +178 -0
- rasa/nlu/extractors/extractor.py +470 -0
- rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
- rasa/nlu/extractors/regex_entity_extractor.py +220 -0
- rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
- rasa/nlu/featurizers/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +445 -0
- rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +768 -0
- rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
- rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
- rasa/nlu/featurizers/featurizer.py +89 -0
- rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +867 -0
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +571 -0
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +271 -0
- rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
- rasa/nlu/model.py +24 -0
- rasa/nlu/run.py +27 -0
- rasa/nlu/selectors/__init__.py +0 -0
- rasa/nlu/selectors/response_selector.py +987 -0
- rasa/nlu/test.py +1940 -0
- rasa/nlu/tokenizers/__init__.py +0 -0
- rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
- rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
- rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
- rasa/nlu/tokenizers/tokenizer.py +239 -0
- rasa/nlu/tokenizers/whitespace_tokenizer.py +95 -0
- rasa/nlu/utils/__init__.py +35 -0
- rasa/nlu/utils/bilou_utils.py +462 -0
- rasa/nlu/utils/hugging_face/__init__.py +0 -0
- rasa/nlu/utils/hugging_face/registry.py +108 -0
- rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
- rasa/nlu/utils/mitie_utils.py +113 -0
- rasa/nlu/utils/pattern_utils.py +168 -0
- rasa/nlu/utils/spacy_utils.py +310 -0
- rasa/plugin.py +90 -0
- rasa/server.py +1588 -0
- rasa/shared/__init__.py +0 -0
- rasa/shared/constants.py +311 -0
- rasa/shared/core/__init__.py +0 -0
- rasa/shared/core/command_payload_reader.py +109 -0
- rasa/shared/core/constants.py +180 -0
- rasa/shared/core/conversation.py +46 -0
- rasa/shared/core/domain.py +2172 -0
- rasa/shared/core/events.py +2559 -0
- rasa/shared/core/flows/__init__.py +7 -0
- rasa/shared/core/flows/flow.py +562 -0
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flow_step.py +146 -0
- rasa/shared/core/flows/flow_step_links.py +319 -0
- rasa/shared/core/flows/flow_step_sequence.py +70 -0
- rasa/shared/core/flows/flows_list.py +258 -0
- rasa/shared/core/flows/flows_yaml_schema.json +303 -0
- rasa/shared/core/flows/nlu_trigger.py +117 -0
- rasa/shared/core/flows/steps/__init__.py +24 -0
- rasa/shared/core/flows/steps/action.py +56 -0
- rasa/shared/core/flows/steps/call.py +64 -0
- rasa/shared/core/flows/steps/collect.py +112 -0
- rasa/shared/core/flows/steps/constants.py +5 -0
- rasa/shared/core/flows/steps/continuation.py +36 -0
- rasa/shared/core/flows/steps/end.py +22 -0
- rasa/shared/core/flows/steps/internal.py +44 -0
- rasa/shared/core/flows/steps/link.py +51 -0
- rasa/shared/core/flows/steps/no_operation.py +48 -0
- rasa/shared/core/flows/steps/set_slots.py +50 -0
- rasa/shared/core/flows/steps/start.py +30 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +735 -0
- rasa/shared/core/flows/yaml_flows_io.py +405 -0
- rasa/shared/core/generator.py +908 -0
- rasa/shared/core/slot_mappings.py +526 -0
- rasa/shared/core/slots.py +654 -0
- rasa/shared/core/trackers.py +1183 -0
- rasa/shared/core/training_data/__init__.py +0 -0
- rasa/shared/core/training_data/loading.py +89 -0
- rasa/shared/core/training_data/story_reader/__init__.py +0 -0
- rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
- rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
- rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
- rasa/shared/core/training_data/story_writer/__init__.py +0 -0
- rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
- rasa/shared/core/training_data/story_writer/yaml_story_writer.py +444 -0
- rasa/shared/core/training_data/structures.py +858 -0
- rasa/shared/core/training_data/visualization.html +146 -0
- rasa/shared/core/training_data/visualization.py +603 -0
- rasa/shared/data.py +249 -0
- rasa/shared/engine/__init__.py +0 -0
- rasa/shared/engine/caching.py +26 -0
- rasa/shared/exceptions.py +167 -0
- rasa/shared/importers/__init__.py +0 -0
- rasa/shared/importers/importer.py +770 -0
- rasa/shared/importers/multi_project.py +215 -0
- rasa/shared/importers/rasa.py +108 -0
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +36 -0
- rasa/shared/nlu/__init__.py +0 -0
- rasa/shared/nlu/constants.py +53 -0
- rasa/shared/nlu/interpreter.py +10 -0
- rasa/shared/nlu/training_data/__init__.py +0 -0
- rasa/shared/nlu/training_data/entities_parser.py +208 -0
- rasa/shared/nlu/training_data/features.py +492 -0
- rasa/shared/nlu/training_data/formats/__init__.py +10 -0
- rasa/shared/nlu/training_data/formats/dialogflow.py +163 -0
- rasa/shared/nlu/training_data/formats/luis.py +87 -0
- rasa/shared/nlu/training_data/formats/rasa.py +135 -0
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +618 -0
- rasa/shared/nlu/training_data/formats/readerwriter.py +244 -0
- rasa/shared/nlu/training_data/formats/wit.py +52 -0
- rasa/shared/nlu/training_data/loading.py +137 -0
- rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
- rasa/shared/nlu/training_data/message.py +490 -0
- rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
- rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
- rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
- rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
- rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
- rasa/shared/nlu/training_data/training_data.py +729 -0
- rasa/shared/nlu/training_data/util.py +223 -0
- rasa/shared/providers/__init__.py +0 -0
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +677 -0
- rasa/shared/providers/_configs/client_config.py +59 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +132 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +236 -0
- rasa/shared/providers/_configs/litellm_router_client_config.py +222 -0
- rasa/shared/providers/_configs/model_group_config.py +173 -0
- rasa/shared/providers/_configs/openai_client_config.py +177 -0
- rasa/shared/providers/_configs/rasa_llm_client_config.py +75 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +178 -0
- rasa/shared/providers/_configs/utils.py +117 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/constants.py +7 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +243 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +335 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +126 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +138 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +265 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +415 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +110 -0
- rasa/shared/providers/llm/litellm_router_llm_client.py +202 -0
- rasa/shared/providers/llm/llm_client.py +78 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +161 -0
- rasa/shared/providers/llm/rasa_llm_client.py +120 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +276 -0
- rasa/shared/providers/mappings.py +94 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +185 -0
- rasa/shared/providers/router/router_client.py +75 -0
- rasa/shared/utils/__init__.py +0 -0
- rasa/shared/utils/cli.py +102 -0
- rasa/shared/utils/common.py +324 -0
- rasa/shared/utils/constants.py +4 -0
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +258 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +499 -0
- rasa/shared/utils/llm.py +764 -0
- rasa/shared/utils/pykwalify_extensions.py +27 -0
- rasa/shared/utils/schemas/__init__.py +0 -0
- rasa/shared/utils/schemas/config.yml +2 -0
- rasa/shared/utils/schemas/domain.yml +145 -0
- rasa/shared/utils/schemas/events.py +214 -0
- rasa/shared/utils/schemas/model_config.yml +36 -0
- rasa/shared/utils/schemas/stories.yml +173 -0
- rasa/shared/utils/yaml.py +1068 -0
- rasa/studio/__init__.py +0 -0
- rasa/studio/auth.py +270 -0
- rasa/studio/config.py +136 -0
- rasa/studio/constants.py +19 -0
- rasa/studio/data_handler.py +368 -0
- rasa/studio/download.py +489 -0
- rasa/studio/results_logger.py +137 -0
- rasa/studio/train.py +134 -0
- rasa/studio/upload.py +563 -0
- rasa/telemetry.py +1876 -0
- rasa/tracing/__init__.py +0 -0
- rasa/tracing/config.py +355 -0
- rasa/tracing/constants.py +62 -0
- rasa/tracing/instrumentation/__init__.py +0 -0
- rasa/tracing/instrumentation/attribute_extractors.py +765 -0
- rasa/tracing/instrumentation/instrumentation.py +1306 -0
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +144 -0
- rasa/tracing/instrumentation/metrics.py +294 -0
- rasa/tracing/metric_instrument_provider.py +205 -0
- rasa/utils/__init__.py +0 -0
- rasa/utils/beta.py +83 -0
- rasa/utils/cli.py +28 -0
- rasa/utils/common.py +639 -0
- rasa/utils/converter.py +53 -0
- rasa/utils/endpoints.py +331 -0
- rasa/utils/io.py +252 -0
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +542 -0
- rasa/utils/log_utils.py +181 -0
- rasa/utils/mapper.py +210 -0
- rasa/utils/ml_utils.py +147 -0
- rasa/utils/plotting.py +362 -0
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/utils/singleton.py +23 -0
- rasa/utils/tensorflow/__init__.py +0 -0
- rasa/utils/tensorflow/callback.py +112 -0
- rasa/utils/tensorflow/constants.py +116 -0
- rasa/utils/tensorflow/crf.py +492 -0
- rasa/utils/tensorflow/data_generator.py +440 -0
- rasa/utils/tensorflow/environment.py +161 -0
- rasa/utils/tensorflow/exceptions.py +5 -0
- rasa/utils/tensorflow/feature_array.py +366 -0
- rasa/utils/tensorflow/layers.py +1565 -0
- rasa/utils/tensorflow/layers_utils.py +113 -0
- rasa/utils/tensorflow/metrics.py +281 -0
- rasa/utils/tensorflow/model_data.py +798 -0
- rasa/utils/tensorflow/model_data_utils.py +499 -0
- rasa/utils/tensorflow/models.py +935 -0
- rasa/utils/tensorflow/rasa_layers.py +1094 -0
- rasa/utils/tensorflow/transformer.py +640 -0
- rasa/utils/tensorflow/types.py +6 -0
- rasa/utils/train_utils.py +572 -0
- rasa/utils/url_tools.py +53 -0
- rasa/utils/yaml.py +54 -0
- rasa/validator.py +1644 -0
- rasa/version.py +3 -0
- rasa_pro-3.12.0.dev1.dist-info/METADATA +199 -0
- rasa_pro-3.12.0.dev1.dist-info/NOTICE +5 -0
- rasa_pro-3.12.0.dev1.dist-info/RECORD +790 -0
- rasa_pro-3.12.0.dev1.dist-info/WHEEL +4 -0
- rasa_pro-3.12.0.dev1.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,1273 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import copy
|
|
3
|
+
import functools
|
|
4
|
+
import logging
|
|
5
|
+
import structlog
|
|
6
|
+
from typing import Any, List, DefaultDict, Dict, Text, Optional, Set, Tuple, cast
|
|
7
|
+
|
|
8
|
+
from tqdm import tqdm
|
|
9
|
+
import numpy as np
|
|
10
|
+
import json
|
|
11
|
+
from collections import defaultdict
|
|
12
|
+
|
|
13
|
+
from rasa.engine.graph import ExecutionContext
|
|
14
|
+
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
15
|
+
from rasa.engine.storage.resource import Resource
|
|
16
|
+
from rasa.engine.storage.storage import ModelStorage
|
|
17
|
+
from rasa.shared.constants import DOCS_URL_RULES
|
|
18
|
+
from rasa.shared.exceptions import RasaException
|
|
19
|
+
import rasa.shared.utils.io
|
|
20
|
+
from rasa.shared.core.events import LoopInterrupted, UserUttered, ActionExecuted
|
|
21
|
+
from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
|
|
22
|
+
from rasa.core.policies.memoization import MemoizationPolicy
|
|
23
|
+
from rasa.core.policies.policy import SupportedData, PolicyPrediction
|
|
24
|
+
from rasa.shared.core.trackers import (
|
|
25
|
+
DialogueStateTracker,
|
|
26
|
+
get_active_loop_name,
|
|
27
|
+
is_prev_action_listen_in_state,
|
|
28
|
+
)
|
|
29
|
+
from rasa.shared.core.generator import TrackerWithCachedStates
|
|
30
|
+
from rasa.core.constants import (
|
|
31
|
+
DEFAULT_CORE_FALLBACK_THRESHOLD,
|
|
32
|
+
RULE_POLICY_PRIORITY,
|
|
33
|
+
POLICY_PRIORITY,
|
|
34
|
+
POLICY_MAX_HISTORY,
|
|
35
|
+
)
|
|
36
|
+
from rasa.shared.core.constants import (
|
|
37
|
+
USER_INTENT_RESTART,
|
|
38
|
+
USER_INTENT_BACK,
|
|
39
|
+
USER_INTENT_SESSION_START,
|
|
40
|
+
ACTION_LISTEN_NAME,
|
|
41
|
+
ACTION_RESTART_NAME,
|
|
42
|
+
ACTION_SESSION_START_NAME,
|
|
43
|
+
ACTION_DEFAULT_FALLBACK_NAME,
|
|
44
|
+
ACTION_BACK_NAME,
|
|
45
|
+
RULE_SNIPPET_ACTION_NAME,
|
|
46
|
+
SHOULD_NOT_BE_SET,
|
|
47
|
+
PREVIOUS_ACTION,
|
|
48
|
+
LOOP_NAME,
|
|
49
|
+
SLOTS,
|
|
50
|
+
ACTIVE_LOOP,
|
|
51
|
+
RULE_ONLY_SLOTS,
|
|
52
|
+
RULE_ONLY_LOOPS,
|
|
53
|
+
)
|
|
54
|
+
from rasa.shared.core.domain import InvalidDomain, State, Domain
|
|
55
|
+
from rasa.shared.nlu.constants import ACTION_NAME, INTENT_NAME_KEY
|
|
56
|
+
import rasa.core.test
|
|
57
|
+
from rasa.core.training.training import create_action_fingerprints, ActionFingerprint
|
|
58
|
+
|
|
59
|
+
logger = logging.getLogger(__name__)
structlogger = structlog.get_logger()


# These are Rasa Pro default actions and overrule everything at any time.
# Each built-in intent (key) is answered by its default action (value).
DEFAULT_ACTION_MAPPINGS = {
    USER_INTENT_RESTART: ACTION_RESTART_NAME,  # restart the conversation
    USER_INTENT_BACK: ACTION_BACK_NAME,  # go back one user turn
    USER_INTENT_SESSION_START: ACTION_SESSION_START_NAME,  # start a new session
}

# Names of the different rule lookups kept by the policy
# (presumably used as keys of the persisted lookup - not visible in this chunk).
RULES = "rules"
RULES_FOR_LOOP_UNHAPPY_PATH = "rules_for_loop_unhappy_path"
RULES_NOT_IN_STORIES = "rules_not_in_stories"

# Values written into the loop unhappy-path lookup.
LOOP_WAS_INTERRUPTED = "loop_was_interrupted"
DO_NOT_PREDICT_LOOP_ACTION = "do_not_predict_loop_action"

# Prefixes (and separator) of prediction sources produced by built-in rules;
# prediction sources are matched against these prefixes with `startswith`.
DEFAULT_RULES = "predicting default action with intent "
LOOP_RULES = "handling active loops and forms - "
LOOP_RULES_SEPARATOR = " - "
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class InvalidRule(RasaException):
    """Exception that can be raised when rules are not valid.

    Args:
        message: human-readable description of what is wrong with the rules.
    """

    def __init__(self, message: Text) -> None:
        # Forward the message to the base class so it is stored in `self.args`.
        # This keeps `repr()` informative and allows the exception to be
        # pickled/unpickled (unpickling re-invokes `__init__(*self.args)`,
        # which failed before because `message` was missing from `args`).
        super().__init__(message)
        self.message = message

    def __str__(self) -> Text:
        """Returns the message followed by a pointer to the rules docs."""
        return self.message + (
            f"\nYou can find more information about the usage of "
            f"rules at {DOCS_URL_RULES}. "
        )
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@DefaultV1Recipe.register(
|
|
97
|
+
DefaultV1Recipe.ComponentType.POLICY_WITHOUT_END_TO_END_SUPPORT, is_trainable=True
|
|
98
|
+
)
|
|
99
|
+
class RulePolicy(MemoizationPolicy):
|
|
100
|
+
"""Policy which handles all the rules."""
|
|
101
|
+
|
|
102
|
+
# rules use explicit json strings
|
|
103
|
+
ENABLE_FEATURE_STRING_COMPRESSION = False
|
|
104
|
+
|
|
105
|
+
# number of user inputs that is allowed in case rules are restricted
|
|
106
|
+
ALLOWED_NUMBER_OF_USER_INPUTS = 1
|
|
107
|
+
|
|
108
|
+
@staticmethod
def supported_data() -> SupportedData:
    """Tells which kind of training data this policy consumes.

    Returns:
        `SupportedData.ML_AND_RULE_DATA` (ML and rule data).
    """
    data_kind = SupportedData.ML_AND_RULE_DATA
    return data_kind
|
|
116
|
+
|
|
117
|
+
@staticmethod
def get_default_config() -> Dict[Text, Any]:
    """Returns the default config (see parent class for full docstring)."""
    # Priority used to break ties when several policies predict actions
    # with the same confidence.
    defaults: Dict[Text, Any] = {POLICY_PRIORITY: RULE_POLICY_PRIORITY}

    # Confidence reported when no rule matched; this is effectively the
    # threshold for a core fallback.
    defaults["core_fallback_threshold"] = DEFAULT_CORE_FALLBACK_THRESHOLD
    # Action that is predicted when no rule matched.
    defaults["core_fallback_action_name"] = ACTION_DEFAULT_FALLBACK_NAME
    # When `True`, `core_fallback_action_name` is predicted if no rule matched.
    defaults["enable_fallback_prediction"] = True
    # When `True`, a rule may contain at most one user message, which keeps
    # users from building a state machine out of rules.
    defaults["restrict_rules"] = True
    # Whether rules and stories are checked for contradictions.
    defaults["check_for_contradictions"] = True
    # When `True`, the NLU confidence of the latest user message is used as
    # the confidence of the predicted action.
    defaults["use_nlu_confidence_as_score"] = False

    return defaults
|
|
142
|
+
|
|
143
|
+
def __init__(
    self,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
    featurizer: Optional[TrackerFeaturizer] = None,
    lookup: Optional[Dict] = None,
) -> None:
    """Initializes the policy.

    Args:
        config: merged policy configuration. Note that `POLICY_MAX_HISTORY`
            is overwritten with `None` in this dict before use.
        model_storage: storage passed through to the parent policy.
        resource: resource passed through to the parent policy.
        execution_context: execution context passed through to the parent.
        featurizer: optional tracker featurizer for the parent policy.
        lookup: optional pre-built rule lookup for the parent policy.
    """
    # Rule stories can have any length, so history must never be truncated.
    config[POLICY_MAX_HISTORY] = None

    super().__init__(
        config, model_storage, resource, execution_context, featurizer, lookup
    )

    self._fallback_action_name = config["core_fallback_action_name"]
    self._enable_fallback_prediction = config["enable_fallback_prediction"]
    self._check_for_contradictions = config["check_for_contradictions"]

    # prediction source -> list of (rule name, expected action) pairs
    sources: DefaultDict[Text, List[Tuple[Text, Text]]] = defaultdict(list)
    self._rules_sources = sources
|
|
167
|
+
|
|
168
|
+
@classmethod
def raise_if_incompatible_with_domain(
    cls, config: Dict[Text, Any], domain: Domain
) -> None:
    """Checks whether the domains action names match the configured fallback.

    Args:
        config: configuration of a `RulePolicy`
        domain: a domain

    Raises:
        `InvalidDomain` if the configured `core_fallback_action_name` is set
        but is not one of the domain's action names or texts.
    """
    fallback_action_name = config.get("core_fallback_action_name")
    if (
        fallback_action_name
        and fallback_action_name not in domain.action_names_or_texts
    ):
        # Use `cls` rather than a hard-coded class so that subclasses of this
        # policy are named correctly in the error message.
        raise InvalidDomain(
            f"The fallback action '{fallback_action_name}' which was "
            f"configured for the {cls.__name__} must be "
            f"present in the domain."
        )
|
|
190
|
+
|
|
191
|
+
@staticmethod
def _is_rule_snippet_state(state: State) -> bool:
    """Checks whether the previous action recorded in `state` is the snippet marker.

    Args:
        state: a single tracker state.

    Returns:
        `True` if the state's previous action name equals
        `RULE_SNIPPET_ACTION_NAME`, `False` otherwise.
    """
    previous_action = state.get(PREVIOUS_ACTION, {})
    return previous_action.get(ACTION_NAME) == RULE_SNIPPET_ACTION_NAME
|
|
195
|
+
|
|
196
|
+
def _create_feature_key(self, states: List[State]) -> Optional[Text]:
    """Serializes the states that follow the last rule snippet marker.

    Args:
        states: the states of a tracker, oldest first.

    Returns:
        A JSON string (with sorted keys) of the states after the most recent
        rule-snippet state, or `None` if no such states remain.
    """
    new_states: List[State] = []
    # Walk backwards and stop at the most recent rule-snippet state: every
    # state before RULE_SNIPPET_ACTION_NAME is dropped. Appending and
    # reversing once is O(n); the previous `insert(0, ...)` was O(n^2), which
    # matters because rule histories are not truncated.
    for state in reversed(states):
        if self._is_rule_snippet_state(state):
            break
        new_states.append(state)

    if not new_states:
        return None

    new_states.reverse()
    # we sort keys to make sure that the same states
    # represented as dictionaries have the same json strings
    return json.dumps(new_states, sort_keys=True)
|
|
210
|
+
|
|
211
|
+
@staticmethod
def _states_for_unhappy_loop_predictions(states: List[State]) -> List[State]:
    """Reduces the states to what loop unhappy-path feature keys need.

    Only the last two dialogue turns are kept. The previous turn contributes
    only its previous action, which captures the meaningful action before
    `action_listen`. Its other features, such as the previous intent, are
    ignored.

    Args:
        states: a representation of a tracker
            as a list of dictionaries containing features

    Returns:
        modified states
    """
    last_state = states[-1]
    if len(states) > 1:
        earlier_action = states[-2].get(PREVIOUS_ACTION)
        if earlier_action:
            return [{PREVIOUS_ACTION: earlier_action}, last_state]
    return [last_state]
|
|
229
|
+
|
|
230
|
+
@staticmethod
def _remove_rule_snippet_predictions(lookup: Dict[Text, Text]) -> Dict[Text, Text]:
    """Drops lookup entries that would predict the rule snippet action.

    Args:
        lookup: mapping of feature key to predicted action name.

    Returns:
        A new dict without entries whose action is `RULE_SNIPPET_ACTION_NAME`.
    """
    kept: Dict[Text, Text] = {}
    for key, predicted in lookup.items():
        if predicted == RULE_SNIPPET_ACTION_NAME:
            continue
        kept[key] = predicted
    return kept
|
|
238
|
+
|
|
239
|
+
def _create_loop_unhappy_lookup_from_states(
    self,
    trackers_as_states: List[List[State]],
    trackers_as_actions: List[List[Text]],
) -> Dict[Text, Text]:
    """Creates lookup dictionary from the tracker represented as states.

    Only trackers whose last state has an active loop contribute entries.
    The value is `LOOP_WAS_INTERRUPTED` if the loop itself is predicted right
    after `action_listen`. It is `DO_NOT_PREDICT_LOOP_ACTION` if another
    action is predicted and the previous action was not `action_listen`.

    Args:
        trackers_as_states: representation of the trackers as a list of states
        trackers_as_actions: representation of the trackers as a list of actions

    Returns:
        lookup dictionary
    """
    lookup = {}
    for tracker_states, tracker_actions in zip(
        trackers_as_states, trackers_as_actions
    ):
        next_action = tracker_actions[0]
        loop_name = get_active_loop_name(tracker_states[-1])
        if not loop_name:
            continue

        reduced_states = self._states_for_unhappy_loop_predictions(tracker_states)
        feature_key = self._create_feature_key(reduced_states)
        if not feature_key:
            continue
        # Even if two trackers share a feature key, their active loop is the
        # same (the loop is part of the key's last state).

        # Rule snippets and stories inside the loop only contain unhappy
        # paths, so the loop is told it was predicted after an answer to a
        # different question and must not validate user input.
        after_listen = is_prev_action_listen_in_state(reduced_states[-1])
        predicts_loop = next_action == active_loop_placeholder(loop_name)
        if after_listen and predicts_loop:
            # loop re-entered after action_listen in an unhappy path
            lookup[feature_key] = LOOP_WAS_INTERRUPTED
        elif not after_listen and not predicts_loop:
            # another action runs inside the unhappy path, so the rule
            # must not predict the active loop here
            lookup[feature_key] = DO_NOT_PREDICT_LOOP_ACTION
    return lookup
|
|
284
|
+
|
|
285
|
+
def _check_rule_restriction(
    self, rule_trackers: List[TrackerWithCachedStates]
) -> None:
    """Ensures no rule contains more user messages than allowed.

    Args:
        rule_trackers: trackers built from rules; `sender_id` names the rule.

    Raises:
        InvalidRule: if any rule has more than `ALLOWED_NUMBER_OF_USER_INPUTS`
            `UserUttered` events.
    """
    limit = self.ALLOWED_NUMBER_OF_USER_INPUTS
    offending_rules = [
        tracker.sender_id
        for tracker in rule_trackers
        if sum(isinstance(event, UserUttered) for event in tracker.events) > limit
    ]
    if not offending_rules:
        return

    raise InvalidRule(
        f"Found rules '{', '.join(offending_rules)}' "
        f"that contain more than {self.ALLOWED_NUMBER_OF_USER_INPUTS} "
        f"user message. Rules are not meant to hardcode a state machine. "
        f"Please use stories for these cases."
    )
|
|
303
|
+
|
|
304
|
+
@staticmethod
def _expected_but_missing_slots(
    fingerprint: ActionFingerprint, state: State
) -> Set[Text]:
    """Lists the fingerprint slots that are not present in `state`.

    Args:
        fingerprint: fingerprint of the action (slots it sets in rules).
        state: the state to check.

    Returns:
        Slots expected by the fingerprint but missing from the state's slots.
    """
    present_slots = state.get(SLOTS, {}).keys()
    return {slot for slot in fingerprint.slots if slot not in present_slots}
|
|
312
|
+
|
|
313
|
+
@staticmethod
def _check_active_loops_fingerprint(
    fingerprint: ActionFingerprint, state: State
) -> Set[Optional[Text]]:
    """Checks whether `state` has one of the fingerprint's active loops.

    Args:
        fingerprint: fingerprint of the action (active loops it sets in rules).
        state: the state to check.

    Returns:
        An empty set if the state's active loop is among the expected ones,
        otherwise the full set of expected active loops.
    """
    expected_loops = set(fingerprint.active_loop)
    # Read the raw state instead of `tracker.active_loop_name` so that the
    # `should_not_be_set` marker is kept.
    loop_in_state = state.get(ACTIVE_LOOP, {}).get(LOOP_NAME)
    return set() if loop_in_state in expected_loops else expected_loops
|
|
326
|
+
|
|
327
|
+
@staticmethod
def _error_messages_from_fingerprints(
    action_name: Text,
    missing_fingerprint_slots: Set[Text],
    fingerprint_active_loops: Set[Text],
    rule_name: Text,
) -> List[Text]:
    """Builds user-facing errors for an incomplete step of a rule.

    Args:
        action_name: the action whose fingerprint is not satisfied.
        missing_fingerprint_slots: slots the action sets in other rules but
            which are missing in this rule.
        fingerprint_active_loops: active loops the action sets in other rules
            that are not set in this rule.
        rule_name: name of the rule being checked.

    Returns:
        Zero, one or two error messages. Nothing is reported if `action_name`
        is empty.
    """
    error_messages = []
    if action_name and missing_fingerprint_slots:
        # Sort so the message does not depend on set iteration order, which
        # varies between runs with string hash randomization.
        missing_slots_text = ", ".join(sorted(missing_fingerprint_slots))
        error_messages.append(
            f"- the action '{action_name}' in rule '{rule_name}' does not set some "
            f"of the slots that it sets in other rules. Slots not set in rule "
            f"'{rule_name}': '{missing_slots_text}'. Please "
            f"update the rule with an appropriate slot or if it is the last action "
            f"add 'wait_for_user_input: false' after this action."
        )
    if action_name and fingerprint_active_loops:
        # substitute `SHOULD_NOT_BE_SET` with `null` so that users
        # know what to put in their rules
        loops_to_report = {
            "null" if active_loop == SHOULD_NOT_BE_SET else active_loop
            for active_loop in fingerprint_active_loops
        }
        # add action_name to active loop so that users
        # know what to put in their rules
        loops_to_report.add(action_name)
        loops_text = ", ".join(sorted(loops_to_report))

        error_messages.append(
            f"- the form '{action_name}' in rule '{rule_name}' does not set "
            f"the 'active_loop', that it sets in other rules: "
            f"'{loops_text}'. Please update the rule with "
            f"the appropriate 'active loop' property or if it is the last action "
            f"add 'wait_for_user_input: false' after this action."
        )
    return error_messages
|
|
362
|
+
|
|
363
|
+
def _check_for_incomplete_rules(
    self, rule_trackers: List[TrackerWithCachedStates], domain: Domain
) -> None:
    """Verifies each rule step sets the slots/loops its action sets elsewhere.

    Fingerprints are computed from the rule trackers only and looked up by
    the previous action of every state. Steps directly around a rule snippet
    action are not checked.

    Args:
        rule_trackers: trackers built from rules; `sender_id` names the rule.
        domain: the domain used to compute the trackers' past states.

    Raises:
        InvalidRule: if at least one incomplete rule step was found.
    """
    logger.debug("Started checking if some rules are incomplete.")
    # only fingerprints coming from rules are relevant here
    rule_fingerprints = create_action_fingerprints(rule_trackers, domain)
    if not rule_fingerprints:
        return

    error_messages: List[Text] = []
    for tracker in rule_trackers:
        states = tracker.past_states(domain)
        # The action following state i is the previous action of state i + 1;
        # the last action of a rule is always action listen.
        following_actions = [
            state.get(PREVIOUS_ACTION, {}).get(ACTION_NAME) for state in states[1:]
        ]
        following_actions.append(ACTION_LISTEN_NAME)

        for state, following_action in zip(states, following_actions):
            previous_action = state.get(PREVIOUS_ACTION, {}).get(ACTION_NAME)
            fingerprint = rule_fingerprints.get(previous_action)

            # Nothing to check without an action / fingerprint. Fingerprints of
            # the rule snippet action are never checked, and a previous action
            # may leave them unsatisfied when a rule snippet action follows.
            if not previous_action or not fingerprint:
                continue
            if following_action == RULE_SNIPPET_ACTION_NAME:
                continue
            if previous_action == RULE_SNIPPET_ACTION_NAME:
                continue

            error_messages.extend(
                self._error_messages_from_fingerprints(
                    previous_action,
                    self._expected_but_missing_slots(fingerprint, state),
                    self._check_active_loops_fingerprint(fingerprint, state),
                    tracker.sender_id,
                )
            )

    if not error_messages:
        logger.debug("Found no incompletions in rules.")
        return

    error_text = "\n".join(error_messages)
    raise InvalidRule(
        f"\nIncomplete rules found🚨\n\n{error_text}\n"
        f"Please note that if some slots or active loops should not be set "
        f"during prediction you need to explicitly set them to 'null' in the "
        f"rules."
    )
|
|
419
|
+
|
|
420
|
+
@staticmethod
def _get_slots_loops_from_states(
    trackers_as_states: List[List[State]],
) -> Tuple[Set[Text], Set[Text]]:
    """Collects every slot name and active loop name found in the states.

    Args:
        trackers_as_states: trackers, each represented as a list of states.

    Returns:
        A tuple `(slot_names, active_loop_names)`; falsy loop names are skipped.
    """
    slot_names: Set[Text] = set()
    loop_names: Set[Text] = set()
    every_state = (state for states in trackers_as_states for state in states)
    for state in every_state:
        # iterating a dict yields its keys, i.e. the slot names
        slot_names.update(state.get(SLOTS, {}))
        # FIXME: ideally we have better annotation for State, TypedDict
        # could work but support in mypy is very limited. Dataclass are
        # another option
        loop_name = cast(Text, state.get(ACTIVE_LOOP, {}).get(LOOP_NAME))
        if loop_name:
            loop_names.add(loop_name)
    return slot_names, loop_names
|
|
436
|
+
|
|
437
|
+
def _find_rule_only_slots_loops(
    self,
    rule_trackers_as_states: List[List[State]],
    story_trackers_as_states: List[List[State]],
) -> Tuple[List[Text], List[Text]]:
    """Finds slots and active loops that occur in rules but never in stories.

    Args:
        rule_trackers_as_states: rule trackers represented as lists of states.
        story_trackers_as_states: story trackers represented as lists of states.

    Returns:
        A tuple `(rule_only_slots, rule_only_loops)`. Each is a sorted list
        that excludes `SHOULD_NOT_BE_SET`.
    """
    rule_slots, rule_loops = self._get_slots_loops_from_states(
        rule_trackers_as_states
    )
    story_slots, story_loops = self._get_slots_loops_from_states(
        story_trackers_as_states
    )

    # Sets are not JSON serializable, so convert to lists. Sort them so that
    # the order (and any JSON produced from it) does not depend on set
    # iteration order, which changes between runs with hash randomization.
    return (
        sorted(rule_slots - story_slots - {SHOULD_NOT_BE_SET}),
        sorted(rule_loops - story_loops - {SHOULD_NOT_BE_SET}),
    )
|
|
454
|
+
|
|
455
|
+
def _predict_next_action(
    self, tracker: TrackerWithCachedStates, domain: Domain
) -> Tuple[Optional[Text], Optional[Text]]:
    """Runs the rule prediction and maps it to an action name.

    Args:
        tracker: the tracker to predict for.
        domain: the domain whose `action_names_or_texts` index the probabilities.

    Returns:
        A tuple of the predicted action name and the prediction source. The
        name is `None` if only the default predictions were returned for a
        non-rule tracker.
    """
    prediction, prediction_source = self._predict(tracker, domain)
    probabilities = prediction.probabilities

    # Stories may legitimately get no prediction from the RulePolicy, but a
    # rule tracker must always resolve to an action.
    predicted_something = probabilities != self._default_predictions(domain)
    if not (predicted_something or tracker.is_rule_tracker):
        return None, prediction_source

    best_index = np.argmax(probabilities)
    return domain.action_names_or_texts[best_index], prediction_source
|
|
472
|
+
|
|
473
|
+
def _predicted_action_name(
    self, tracker: TrackerWithCachedStates, domain: Domain, gold_action_name: Text
) -> Tuple[Optional[Text], Optional[Text]]:
    """Predicts the next action, retrying once after rejecting an active loop.

    While a loop is active, RulePolicy predicts the loop first. If that differs
    from the gold action, the turn may be part of a loop unhappy path. In that
    case the loop rejection is emulated on `tracker` and prediction is repeated.

    Args:
        tracker: The (running) tracker to predict for; it may be modified.
        domain: The domain.
        gold_action_name: The action that the training data expects.

    Returns:
        The predicted action name (or `None`) and the prediction source.
    """
    action_name, source = self._predict_next_action(tracker, domain)

    loop_name = tracker.active_loop_name
    predicted_loop_instead_of_gold = (
        loop_name and action_name != gold_action_name and action_name == loop_name
    )
    if not predicted_loop_instead_of_gold:
        return action_name, source

    rasa.core.test.emulate_loop_rejection(tracker)
    return self._predict_next_action(tracker, domain)
|
|
493
|
+
|
|
494
|
+
def _collect_sources(
    self,
    tracker: TrackerWithCachedStates,
    predicted_action_name: Optional[Text],
    gold_action_name: Optional[Text],
    prediction_source: Text,
) -> None:
    """Remembers which rule (and expected action) produced `prediction_source`.

    The recorded `(rule name, expected action)` pairs are used later to output
    the names of contradicting rules.

    Args:
        tracker: The running rule tracker; its `sender_id` is the rule name.
        predicted_action_name: The action predicted by the policy.
        gold_action_name: The action the rule expects.
        prediction_source: The key of the rule that made the prediction.
    """
    is_builtin_source = prediction_source is not None and prediction_source.startswith(
        (DEFAULT_RULES, LOOP_RULES)
    )
    if is_builtin_source:
        # for default and loop rules the built-in prediction is the real gold
        # action, and the rule is identified by its source instead of a tracker
        entry = (prediction_source, predicted_action_name)
    else:
        entry = (tracker.sender_id, gold_action_name)

    self._rules_sources[prediction_source].append(entry)
|
|
514
|
+
|
|
515
|
+
@staticmethod
def _default_sources() -> Set[Text]:
    """Returns one prediction source per intent that maps to a default action."""
    sources = set()
    for intent in DEFAULT_ACTION_MAPPINGS:
        sources.add(DEFAULT_RULES + intent)
    return sources
|
|
521
|
+
|
|
522
|
+
@staticmethod
def _handling_loop_sources(domain: Domain) -> Set[Text]:
    """Returns the prediction sources of the built-in loop happy-path rules.

    Every form in the domain contributes two sources. One is for predicting the
    loop itself; the other is for predicting `action_listen` after the loop.
    """
    listen_suffix = LOOP_RULES_SEPARATOR + ACTION_LISTEN_NAME
    return {
        source
        for form_name in domain.form_names
        for source in (
            LOOP_RULES + form_name,
            LOOP_RULES + form_name + listen_suffix,
        )
    }
|
|
531
|
+
|
|
532
|
+
def _should_delete(
    self,
    prediction_source: Text,
    tracker: TrackerWithCachedStates,
    predicted_action_name: Text,
) -> bool:
    """Checks whether this contradiction is due to an action, intent pair.

    This only applies when all three of the following hold:
    - the contradicting tracker is a story (not a rule);
    - the prediction follows an unpredictable action, i.e. the source contains
      at most one previous action;
    - `action_listen` was predicted.

    Args:
        prediction_source: The states that result in the prediction.
        tracker: The tracker that raises the contradiction.
        predicted_action_name: The action that was predicted.

    Returns:
        `True` if the contradiction is a result of an action, intent pair in the
        rule, meaning that another learned rule continues after `action_listen`.
    """
    applies = (
        not tracker.is_rule_tracker
        and prediction_source.count(PREVIOUS_ACTION) <= 1
        and predicted_action_name == ACTION_LISTEN_NAME
    )
    if not applies:
        return False

    # rule keys are JSON-serialized lists of states (see `_rule_key_to_state`).
    # Dropping the last two characters presumably removes the closing `}]`, so
    # longer rules sharing this prefix are found. NOTE(review): confirm format
    prefix = prediction_source[:-2]
    # remove the rule only if another action is predicted after action_listen
    return any(
        source != prediction_source and source.startswith(prefix)
        for source in self.lookup[RULES]
    )
|
|
565
|
+
|
|
566
|
+
def _check_prediction(
    self,
    tracker: TrackerWithCachedStates,
    predicted_action_name: Optional[Text],
    gold_action_name: Text,
    prediction_source: Optional[Text],
) -> List[Text]:
    """Builds an error message if a prediction contradicts the training data.

    If the contradiction only comes from an action, intent pair (see
    `_should_delete`), the responsible rule is removed from the lookup instead.

    Args:
        tracker: The running tracker the prediction was made for.
        predicted_action_name: The action predicted by the policy.
        gold_action_name: The action the training data expects.
        prediction_source: The key of the rule that made the prediction.

    Returns:
        A list with at most one error message.
    """
    # FIXME: `predicted_action_name` and `prediction_source` are
    # either None together or defined together. This could be improved
    # by better typing in this class, but requires some refactoring
    if (
        not predicted_action_name
        or not prediction_source
        or predicted_action_name == gold_action_name
    ):
        return []

    if self._should_delete(prediction_source, tracker, predicted_action_name):
        self.lookup[RULES].pop(prediction_source)
        return []

    tracker_type = "rule" if tracker.is_rule_tracker else "story"
    # sort the names so the message is stable; iterating a set of strings
    # depends on hash randomization and would reorder them between runs
    contradicting_rules = sorted(
        {
            rule_name
            for rule_name, action_name in self._rules_sources[prediction_source]
            if action_name != gold_action_name
        }
    )

    if not contradicting_rules:
        return []

    error_message = (
        f"- the prediction of the action '{gold_action_name}' in {tracker_type} "
        f"'{tracker.sender_id}' "
        f"is contradicting with rule(s) '{', '.join(contradicting_rules)}'"
    )
    # outputting predicted action 'action_default_fallback' is confusing
    if predicted_action_name != self._fallback_action_name:
        error_message += f" which predicted action '{predicted_action_name}'"

    return [error_message + "."]
|
|
607
|
+
|
|
608
|
+
def _run_prediction_on_trackers(
    self,
    trackers: List[TrackerWithCachedStates],
    domain: Domain,
    collect_sources: bool,
) -> Tuple[List[Text], Set[Optional[Text]]]:
    """Replays `trackers` event by event and runs rule prediction on actions.

    Args:
        trackers: The trackers to replay.
        domain: The domain.
        collect_sources: If `True`, reset `self._rules_sources` and only record
            which rule produced which prediction source. If `False`, check every
            prediction for contradictions and remember which rules were used
            (and predicted the gold action) in stories.

    Returns:
        The error messages about contradictions and the prediction sources of
        the rules used in stories (both empty when collecting sources).
    """
    if collect_sources:
        self._rules_sources = defaultdict(list)

    errors: List[Text] = []
    used_rules: Set[Optional[Text]] = set()

    progress = tqdm(
        trackers,
        desc="Processed trackers",
        disable=rasa.shared.utils.io.is_logging_disabled(),
    )
    for tracker in progress:
        replay = tracker.init_copy()
        replay.sender_id = tracker.sender_id
        # the first action of a tracker can never be predicted
        skip_next_action = True

        for event in tracker.applied_events(True):
            if not isinstance(event, ActionExecuted):
                replay.update(event)
                continue

            if event.action_name == RULE_SNIPPET_ACTION_NAME:
                # the action following a rule snippet marker is unpredictable
                skip_next_action = True
            elif skip_next_action or event.unpredictable:
                # no prediction for unpredictable actions; reset the flag
                skip_next_action = False
            else:
                gold = event.action_name or event.action_text
                predicted, source = self._predicted_action_name(replay, domain, gold)
                if collect_sources:
                    if source:
                        self._collect_sources(replay, predicted, gold, source)
                else:
                    # ML policies may only drop rule turns from the dialogue
                    # history, so remember which rules were used in ML trackers
                    if not tracker.is_rule_tracker and predicted == gold:
                        used_rules.add(source)
                    errors.extend(
                        self._check_prediction(replay, predicted, gold, source)
                    )

            replay.update(event)

    return errors, used_rules
|
|
679
|
+
|
|
680
|
+
def _collect_rule_sources(
    self, rule_trackers: List[TrackerWithCachedStates], domain: Domain
) -> None:
    """Fills `self._rules_sources` by replaying the rule trackers.

    Only the side effect matters here; the error messages and used rules that
    the prediction run returns are discarded.
    """
    self._run_prediction_on_trackers(
        trackers=rule_trackers, domain=domain, collect_sources=True
    )
|
|
684
|
+
|
|
685
|
+
def _find_contradicting_and_used_in_stories_rules(
    self, trackers: List[TrackerWithCachedStates], domain: Domain
) -> Tuple[List[Text], Set[Optional[Text]]]:
    """Checks `trackers` against the learned rules.

    Returns:
        The error messages for contradictions and the prediction sources of the
        rules that were applied in stories.
    """
    error_messages, used_rules = self._run_prediction_on_trackers(
        trackers, domain, collect_sources=False
    )
    return error_messages, used_rules
|
|
689
|
+
|
|
690
|
+
def _analyze_rules(
    self,
    rule_trackers: List[TrackerWithCachedStates],
    all_trackers: List[TrackerWithCachedStates],
    domain: Domain,
) -> List[Text]:
    """Analyzes learned rules by running prediction on training trackers.

    This method collects error messages for contradicting rules
    and creates the lookup for rules that are not present in the stories.

    Args:
        rule_trackers: The list of the rule trackers.
        all_trackers: The list of all trackers.
        domain: The domain.

    Returns:
        Rules that are not present in the stories.

    Raises:
        InvalidRule: If rules and/or stories contradict each other.
    """
    logger.debug("Started checking rules and stories for contradictions.")
    # during training we run `predict_action_probabilities` to check for
    # contradicting rules.
    # We silence prediction debug logs to avoid too many logs during these checks.
    logger_level = logger.level
    logger.setLevel(logging.WARNING)
    try:
        # we need to run prediction on rule trackers twice, because we need to
        # collect the information about which rule snippets contributed to the
        # learned rules
        self._collect_rule_sources(rule_trackers, domain)
        (
            error_messages,
            rules_used_in_stories,
        ) = self._find_contradicting_and_used_in_stories_rules(all_trackers, domain)
    finally:
        # restore the level even if a check raises; otherwise the module logger
        # would stay silenced for the rest of the process
        logger.setLevel(logger_level)

    if error_messages:
        error_text = "\n".join(error_messages)
        raise InvalidRule(
            f"\nContradicting rules or stories found 🚨\n\n{error_text}\n"
            f"Please update your stories and rules so that they don't contradict "
            f"each other."
        )

    logger.debug("Found no contradicting rules.")
    all_rules = (
        set(self._rules_sources.keys())
        | self._default_sources()
        | self._handling_loop_sources(domain)
    )
    # set is not json serializable, so convert to list
    return list(all_rules - rules_used_in_stories)
|
|
741
|
+
|
|
742
|
+
def _create_lookup_from_trackers(
    self,
    rule_trackers: List[TrackerWithCachedStates],
    story_trackers: List[TrackerWithCachedStates],
    domain: Domain,
) -> None:
    """Fills `self.lookup` from the featurized rule and story trackers.

    Writes the learned rules, the rule-only slots and loops (only when
    contradiction checking is enabled), and the rules for loop unhappy paths.

    Args:
        rule_trackers: Trackers created from rule data.
        story_trackers: Trackers created from story data.
        domain: The domain.
    """
    featurizer = self.featurizer

    rule_states, rule_actions = featurizer.training_states_and_labels(
        rule_trackers, domain, omit_unset_slots=True
    )
    self.lookup[RULES] = self._remove_rule_snippet_predictions(
        self._create_lookup_from_states(rule_states, rule_actions)
    )

    story_states, story_actions = featurizer.training_states_and_labels(
        story_trackers, domain
    )

    if self._check_for_contradictions:
        rule_only_slots, rule_only_loops = self._find_rule_only_slots_loops(
            rule_states, story_states
        )
        self.lookup[RULE_ONLY_SLOTS] = rule_only_slots
        self.lookup[RULE_ONLY_LOOPS] = rule_only_loops

    # negative rules are not anti-rules, they are auxiliary to actual rules;
    # all trackers are used so negative rules in unhappy paths are found too
    self.lookup[RULES_FOR_LOOP_UNHAPPY_PATH] = (
        self._create_loop_unhappy_lookup_from_states(
            rule_states + story_states, rule_actions + story_actions
        )
    )
|
|
783
|
+
|
|
784
|
+
def train(
    self,
    training_trackers: List[TrackerWithCachedStates],
    domain: Domain,
    **kwargs: Any,
) -> Resource:
    """Trains the policy on the given training trackers.

    Args:
        training_trackers: The list of the trackers; augmented ones are ignored.
        domain: The domain.
        **kwargs: Additional arguments.

    Returns:
        The resource which can be used to load the trained policy.
    """
    self.raise_if_incompatible_with_domain(self.config, domain)

    # rules are memorized only from original (non-augmented) trackers
    original_trackers = [
        tracker
        for tracker in training_trackers
        if not getattr(tracker, "is_augmented", False)
    ]

    # trackers from rule-based training data
    rule_trackers = list(filter(lambda t: t.is_rule_tracker, original_trackers))
    if self.config["restrict_rules"]:
        self._check_rule_restriction(rule_trackers)
    if self._check_for_contradictions:
        self._check_for_incomplete_rules(rule_trackers, domain)

    # trackers from ML-based training data
    story_trackers = list(
        filter(lambda t: not t.is_rule_tracker, original_trackers)
    )

    self._create_lookup_from_trackers(rule_trackers, story_trackers, domain)

    # checking is configurable because it might take a lot of time; running it
    # on trackers is not the fastest way, but it exercises the real
    # `predict_action_probabilities` code path
    if self._check_for_contradictions:
        self.lookup[RULES_NOT_IN_STORIES] = self._analyze_rules(
            rule_trackers, original_trackers, domain
        )

    logger.debug(f"Memorized '{len(self.lookup[RULES])}' unique rules.")

    self.persist()

    return self._resource
|
|
831
|
+
|
|
832
|
+
@staticmethod
def _does_rule_match_state(rule_state: State, conversation_state: State) -> bool:
    """Checks whether every feature required by `rule_state` holds in the conversation.

    For each feature the rule may demand:
    - a concrete value, which the conversation must have; or
    - `SHOULD_NOT_BE_SET`, meaning the conversation's value must be falsy or
      itself `SHOULD_NOT_BE_SET`.

    Falsy rule values impose no constraint.
    """
    for state_type, rule_sub_state in rule_state.items():
        conversation_sub_state = conversation_state.get(state_type, {})
        for key, expected in rule_sub_state.items():
            # json dumps and loads tuples as lists, so convert them back
            if isinstance(expected, list):
                expected = tuple(expected)
            actual = conversation_sub_state.get(key)

            if expected == SHOULD_NOT_BE_SET:
                # during training `SHOULD_NOT_BE_SET` is provided in the
                # conversation state as well, so it also counts as "not set"
                if actual and actual != SHOULD_NOT_BE_SET:
                    return False
            elif expected and actual != expected:
                return False

    return True
|
|
860
|
+
|
|
861
|
+
@staticmethod
# Rule keys are parsed very often (e.g. for every candidate rule while checking
# contradictions), so parsed results are cached.
@functools.lru_cache(maxsize=1000)
def _rule_key_to_state(rule_key: Text) -> List[State]:
    """Deserializes a rule key (a JSON string) into the rule's list of states.

    Because of the cache, callers receive a shared list and must not mutate it.
    """
    rule_states = json.loads(rule_key)
    return rule_states
|
|
867
|
+
|
|
868
|
+
def _is_rule_applicable(
    self, rule_key: Text, turn_index: int, conversation_state: State
) -> bool:
    """Checks whether a rule is satisfied by the conversation state at a turn.

    Args:
        rule_key: The textual representation of the learned rule.
        turn_index: The index of the current dialogue turn, counted backwards
            from the most recent one.
        conversation_state: The state that corresponds to `turn_index`.

    Returns:
        Whether the rule is applicable to the current state.
    """
    rule_states = self._rule_key_to_state(rule_key)

    # all earlier turns matched and we are already further back in the history
    # than the rule is long, so the rule must be applicable
    if turn_index >= len(rule_states):
        return True

    # counting from the end of the rule, as `turn_index` goes back in time
    rule_state = rule_states[-1 - turn_index]

    # a state has a previous action if and only if it is not a conversation start
    conversation_has_previous_action = bool(conversation_state.get(PREVIOUS_ACTION))
    rule_has_previous_action = bool(rule_state.get(PREVIOUS_ACTION))

    # exactly one of them is a conversation start (e.g. a rule with
    # `conversation_start: true` against a later conversation turn)
    if rule_has_previous_action != conversation_has_previous_action:
        return False

    # both are conversation starters; slots with an initial value are
    # necessarily in both states and need no check
    if not rule_has_previous_action:
        return True

    # the rule state's features must be present in the conversation state
    return self._does_rule_match_state(rule_state, conversation_state)
|
|
912
|
+
|
|
913
|
+
def _get_possible_keys(
    self, lookup: Dict[Text, Text], states: List[State]
) -> Set[Text]:
    """Returns the keys of `lookup` whose rules are applicable to `states`.

    Walks the conversation from the most recent state backwards. Only rule keys
    that remain applicable at every turn are kept.
    """
    candidates = set(lookup.keys())
    for turn_index, state in enumerate(reversed(states)):
        candidates = {
            key
            for key in candidates
            if self._is_rule_applicable(key, turn_index, state)
        }
    return candidates
|
|
925
|
+
|
|
926
|
+
@staticmethod
def _find_action_from_default_actions(
    tracker: DialogueStateTracker,
) -> Tuple[Optional[Text], Optional[Text]]:
    """Predicts the default action mapped to the latest user intent, if any.

    This only applies directly after `action_listen` and when there is a latest
    user message.

    Returns:
        The default action name and its prediction source, or `(None, None)`.
    """
    if tracker.latest_action_name != ACTION_LISTEN_NAME or not tracker.latest_message:
        return None, None

    intent_name = tracker.latest_message.intent.get(INTENT_NAME_KEY)
    default_action_name = (
        None if intent_name is None else DEFAULT_ACTION_MAPPINGS.get(intent_name)
    )
    if default_action_name is None:
        return None, None

    logger.debug(f"Predicted default action '{default_action_name}'.")
    # the source matches one of the sources created by `_default_sources()`
    source = DEFAULT_RULES + intent_name
    return default_action_name, source
|
|
951
|
+
|
|
952
|
+
@staticmethod
def _find_action_from_loop_happy_path(
    tracker: DialogueStateTracker,
) -> Tuple[Optional[Text], Optional[Text]]:
    """Predicts the next action on the happy path of an active loop.

    Nothing is predicted when no loop is active or the loop was rejected.
    Otherwise:
    - the loop itself is predicted if it did not run last;
    - `action_listen` is predicted right after the loop ran.

    Returns:
        The predicted action name and its prediction source, or `(None, None)`.
    """
    loop_name = tracker.active_loop_name
    if loop_name is None or tracker.is_active_loop_rejected:
        return None, None

    latest_action = tracker.latest_action
    if latest_action and latest_action.get(ACTION_NAME) != loop_name:
        logger.debug(f"Predicted loop '{loop_name}'.")
        return loop_name, LOOP_RULES + loop_name

    # the loop action was run successfully, so wait for the user
    if tracker.latest_action_name == loop_name:
        logger.debug(f"Predicted '{ACTION_LISTEN_NAME}' after loop '{loop_name}'.")
        listen_source = (
            f"{LOOP_RULES}{loop_name}{LOOP_RULES_SEPARATOR}{ACTION_LISTEN_NAME}"
        )
        return ACTION_LISTEN_NAME, listen_source

    return None, None
|
|
988
|
+
|
|
989
|
+
def _find_action_from_rules(
    self,
    tracker: DialogueStateTracker,
    domain: Domain,
    use_text_for_last_user_input: bool,
) -> Tuple[Optional[Text], Optional[Text], bool]:
    """Predicts the next action based on the memoized rules.

    Args:
        tracker: The current conversation tracker.
        domain: The domain of the current model.
        use_text_for_last_user_input: `True` if text of last user message
            should be used for the prediction. `False` if intent should be used.

    Returns:
        A tuple of:
        - the predicted action name or text, or `None` if no matching rule was
          found;
        - a description of the matching rule;
        - `True` if a loop action was predicted after the loop had been on an
          unhappy path.
    """
    # text is only used right after a user utterance, since at any later point
    # it has already been decided whether to use the text or the intent
    if (
        use_text_for_last_user_input
        and tracker.latest_action_name != ACTION_LISTEN_NAME
    ):
        return None, None, False

    states = self._prediction_states(
        tracker,
        domain,
        use_text_for_last_user_input,
        rule_only_data=self._get_rule_only_data(),
    )

    current_states = self.format_tracker_states(states)
    structlogger.debug(
        "rule_policy.actions.find", current_states=copy.deepcopy(current_states)
    )

    # Becomes `True` if we are returning after an unhappy loop path. The policy
    # then emits an event notifying the loop action about it; e.g. `FormAction`
    # skips slot validation on its first execution after an unhappy path.
    returning_from_unhappy_path = False

    rule_lookup = self.lookup[RULES]
    matching_keys = self._get_possible_keys(rule_lookup, states)
    best_rule_key = ""
    predicted_action_name = None
    if matching_keys:
        # several matching rules mean that one rule is a subset of another,
        # hence the longest (most specific) rule wins
        best_rule_key = max(matching_keys, key=len)
        predicted_action_name = rule_lookup.get(best_rule_key)

    active_loop_name = tracker.active_loop_name
    if active_loop_name:
        unhappy_lookup = self.lookup[RULES_FOR_LOOP_UNHAPPY_PATH]
        # several unhappy path conditions may apply at the same time
        unhappy_path_conditions = [
            unhappy_lookup.get(key)
            for key in self._get_possible_keys(unhappy_lookup, states)
        ]

        # A general rule (one whose last state has no active loop) may have
        # predicted action_listen inside the loop. Such rules don't necessarily
        # switch back to the loop explicitly, so that is handled here.
        if predicted_action_name == ACTION_LISTEN_NAME and not get_active_loop_name(
            self._rule_key_to_state(best_rule_key)[-1]
        ):
            if DO_NOT_PREDICT_LOOP_ACTION not in unhappy_path_conditions:
                # no negative rule says that the active loop shouldn't be
                # predicted, so switch back to the loop
                logger.debug(
                    f"Predicted loop '{active_loop_name}' by overwriting "
                    f"'{ACTION_LISTEN_NAME}' predicted by general rule."
                )
                return active_loop_name, best_rule_key, returning_from_unhappy_path

            # a negative rule forbids the loop, so predict nothing
            predicted_action_name = None

        if LOOP_WAS_INTERRUPTED in unhappy_path_conditions:
            logger.debug(
                "Returning from unhappy path. Loop will be notified that "
                "it was interrupted."
            )
            returning_from_unhappy_path = True

    if predicted_action_name is None:
        logger.debug("There is no applicable rule.")
    else:
        logger.debug(f"There is a rule for the next action '{predicted_action_name}'.")

    # without a rule prediction, the feature key created from the states marks
    # this state as one that leads to fallback
    source = best_rule_key or self._create_feature_key(states)
    return predicted_action_name, source, returning_from_unhappy_path
|
|
1103
|
+
|
|
1104
|
+
async def predict_action_probabilities(
    self,
    tracker: DialogueStateTracker,
    domain: Domain,
    rule_only_data: Optional[Dict[Text, Any]] = None,
    **kwargs: Any,
) -> PolicyPrediction:
    """Predicts the next action (see parent class for more information)."""
    if not self.should_abstain_in_coexistence(tracker, False):
        prediction, _source = self._predict(tracker, domain)
        return prediction

    # When abstaining, every probability must be 0.0. `self._default_predictions`
    # might assign a probability to the fallback action, so the parent's
    # implementation is used instead.
    return self._prediction(super()._default_predictions(domain))
|
|
1120
|
+
|
|
1121
|
+
def _predict(
    self, tracker: DialogueStateTracker, domain: Domain
) -> Tuple[PolicyPrediction, Optional[Text]]:
    """Predicts the next action and returns it together with its source.

    Priority order:
    1. a default action, unless a text-based rule predicted something;
    2. the active loop's happy path;
    3. rules matched on the message text;
    4. rules matched on the intent (or the default predictions).
    """
    (
        text_action,
        text_source,
        text_returning,
    ) = self._find_action_from_rules(
        tracker, domain, use_text_for_last_user_input=True
    )

    # Rasa Pro default actions overrule anything. If users want to achieve
    # the same, they need to write a rule or make sure that their loop rejects
    # accordingly. Only text-based rules take precedence.
    default_action, default_source = self._find_action_from_default_actions(tracker)
    if default_action and not text_action:
        default_prediction = self._rule_prediction(
            self._prediction_result(default_action, tracker, domain),
            default_source,
        )
        return default_prediction, default_source

    # A loop has priority over any other rule except defaults; other rules only
    # apply once the loop was rejected. If a loop is active and didn't just run
    # or get rejected, it is simply force-predicted.
    loop_action, loop_source = self._find_action_from_loop_happy_path(tracker)
    if loop_action:
        # this prediction doesn't use user input, and happy-path user input
        # should be ignored during featurization anyway
        loop_prediction = self._rule_prediction(
            self._prediction_result(loop_action, tracker, domain),
            loop_source,
            is_no_user_prediction=True,
        )
        return loop_prediction, loop_source

    if text_action:
        text_prediction = self._rule_prediction(
            self._prediction_result(text_action, tracker, domain),
            text_source,
            returning_from_unhappy_path=text_returning,
            is_end_to_end_prediction=True,
        )
        return text_prediction, text_source

    # the source is kept even if the rules didn't predict any action
    (
        intent_action,
        intent_source,
        intent_returning,
    ) = self._find_action_from_rules(
        tracker, domain, use_text_for_last_user_input=False
    )
    probabilities = (
        self._prediction_result(intent_action, tracker, domain)
        if intent_action
        else self._default_predictions(domain)
    )
    intent_prediction = self._rule_prediction(
        probabilities,
        intent_source,
        # returning_from_unhappy_path is a negative condition, so `or` them
        returning_from_unhappy_path=text_returning or intent_returning,
        is_end_to_end_prediction=False,
    )
    return intent_prediction, intent_source
|
|
1216
|
+
|
|
1217
|
+
def _rule_prediction(
    self,
    probabilities: List[float],
    prediction_source: Text,
    returning_from_unhappy_path: bool = False,
    is_end_to_end_prediction: bool = False,
    is_no_user_prediction: bool = False,
) -> PolicyPrediction:
    """Wraps rule-based action probabilities into a `PolicyPrediction`.

    Args:
        probabilities: Action probabilities to report for this policy.
        prediction_source: The rule lookup key that produced the prediction;
            it is checked against the `RULES_NOT_IN_STORIES` lookup entry to
            set `hide_rule_turn`.
        returning_from_unhappy_path: If `True`, a `LoopInterrupted(True)`
            event is attached to the prediction; otherwise no events are.
        is_end_to_end_prediction: Passed through to `PolicyPrediction`.
        is_no_user_prediction: Passed through to `PolicyPrediction`.

    Returns:
        A `PolicyPrediction` tagged with this policy's class name and
        priority.
    """
    return PolicyPrediction(
        probabilities,
        self.__class__.__name__,
        self.priority,
        events=[LoopInterrupted(True)] if returning_from_unhappy_path else [],
        is_end_to_end_prediction=is_end_to_end_prediction,
        is_no_user_prediction=is_no_user_prediction,
        # The membership test is already a bool; no `True if ... else False`
        # wrapper is needed.
        hide_rule_turn=prediction_source
        in self.lookup.get(RULES_NOT_IN_STORIES, []),
    )
|
|
1238
|
+
|
|
1239
|
+
def _default_predictions(self, domain: Domain) -> List[float]:
    """Returns the parent's default probabilities, optionally adding fallback.

    When fallback prediction is enabled, the entry for the configured
    fallback action is overwritten with the configured
    `core_fallback_threshold` value.

    Args:
        domain: The domain used to resolve the fallback action's index.

    Returns:
        The list of default action probabilities.
    """
    probabilities = super()._default_predictions(domain)
    if not self._enable_fallback_prediction:
        return probabilities

    # Read the threshold before resolving the index; this matches the
    # right-hand-side-first evaluation order of the original assignment.
    threshold = self.config["core_fallback_threshold"]
    probabilities[domain.index_for_action(self._fallback_action_name)] = threshold
    return probabilities
|
|
1248
|
+
|
|
1249
|
+
def persist(self) -> None:
    """Persists trained `RulePolicy`.

    First delegates to the parent's `persist()`. Then it writes the result
    of `_get_rule_only_data()` as JSON to `rule_only_data.json` inside the
    directory opened for this policy's resource.
    """
    super().persist()
    with self._model_storage.write_to(self._resource) as directory:
        payload = self._get_rule_only_data()
        json_path = directory / "rule_only_data.json"
        rasa.shared.utils.io.dump_obj_as_json_to_file(json_path, payload)
|
|
1257
|
+
|
|
1258
|
+
def _metadata(self) -> Dict[Text, Any]:
    """Returns this policy's metadata: a dict holding its rule `lookup`."""
    metadata: Dict[Text, Any] = dict(lookup=self.lookup)
    return metadata
|
|
1260
|
+
|
|
1261
|
+
@classmethod
def _metadata_filename(cls) -> Text:
    """Returns the metadata file name used by `RulePolicy`."""
    file_name = "rule_policy.json"
    return file_name
|
|
1264
|
+
|
|
1265
|
+
def _get_rule_only_data(self) -> Dict[Text, Any]:
    """Collects the slots and loops that are used only in rule data.

    Returns:
        A dict with one entry each for `RULE_ONLY_SLOTS` and
        `RULE_ONLY_LOOPS`. Each entry holds the stored lookup value, or a new
        empty list when the key is absent.
    """
    rule_only: Dict[Text, Any] = {}
    for lookup_key in (RULE_ONLY_SLOTS, RULE_ONLY_LOOPS):
        rule_only[lookup_key] = self.lookup.get(lookup_key, [])
    return rule_only
|