rasa-pro 3.12.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +41 -0
- rasa/__init__.py +9 -0
- rasa/__main__.py +177 -0
- rasa/anonymization/__init__.py +2 -0
- rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
- rasa/anonymization/anonymization_pipeline.py +286 -0
- rasa/anonymization/anonymization_rule_executor.py +260 -0
- rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
- rasa/anonymization/schemas/config.yml +47 -0
- rasa/anonymization/utils.py +118 -0
- rasa/api.py +160 -0
- rasa/cli/__init__.py +5 -0
- rasa/cli/arguments/__init__.py +0 -0
- rasa/cli/arguments/data.py +106 -0
- rasa/cli/arguments/default_arguments.py +207 -0
- rasa/cli/arguments/evaluate.py +65 -0
- rasa/cli/arguments/export.py +51 -0
- rasa/cli/arguments/interactive.py +74 -0
- rasa/cli/arguments/run.py +219 -0
- rasa/cli/arguments/shell.py +17 -0
- rasa/cli/arguments/test.py +211 -0
- rasa/cli/arguments/train.py +279 -0
- rasa/cli/arguments/visualize.py +34 -0
- rasa/cli/arguments/x.py +30 -0
- rasa/cli/data.py +354 -0
- rasa/cli/dialogue_understanding_test.py +251 -0
- rasa/cli/e2e_test.py +259 -0
- rasa/cli/evaluate.py +222 -0
- rasa/cli/export.py +250 -0
- rasa/cli/inspect.py +75 -0
- rasa/cli/interactive.py +166 -0
- rasa/cli/license.py +65 -0
- rasa/cli/llm_fine_tuning.py +403 -0
- rasa/cli/markers.py +78 -0
- rasa/cli/project_templates/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/action_template.py +27 -0
- rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
- rasa/cli/project_templates/calm/actions/db.py +57 -0
- rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
- rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
- rasa/cli/project_templates/calm/config.yml +10 -0
- rasa/cli/project_templates/calm/credentials.yml +33 -0
- rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
- rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
- rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
- rasa/cli/project_templates/calm/db/contacts.json +10 -0
- rasa/cli/project_templates/calm/domain/add_contact.yml +39 -0
- rasa/cli/project_templates/calm/domain/list_contacts.yml +17 -0
- rasa/cli/project_templates/calm/domain/remove_contact.yml +38 -0
- rasa/cli/project_templates/calm/domain/shared.yml +10 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
- rasa/cli/project_templates/calm/endpoints.yml +58 -0
- rasa/cli/project_templates/default/actions/__init__.py +0 -0
- rasa/cli/project_templates/default/actions/actions.py +27 -0
- rasa/cli/project_templates/default/config.yml +44 -0
- rasa/cli/project_templates/default/credentials.yml +33 -0
- rasa/cli/project_templates/default/data/nlu.yml +91 -0
- rasa/cli/project_templates/default/data/rules.yml +13 -0
- rasa/cli/project_templates/default/data/stories.yml +30 -0
- rasa/cli/project_templates/default/domain.yml +34 -0
- rasa/cli/project_templates/default/endpoints.yml +42 -0
- rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
- rasa/cli/project_templates/tutorial/actions/__init__.py +0 -0
- rasa/cli/project_templates/tutorial/actions/actions.py +22 -0
- rasa/cli/project_templates/tutorial/config.yml +12 -0
- rasa/cli/project_templates/tutorial/credentials.yml +33 -0
- rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
- rasa/cli/project_templates/tutorial/data/patterns.yml +11 -0
- rasa/cli/project_templates/tutorial/domain.yml +35 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +55 -0
- rasa/cli/run.py +143 -0
- rasa/cli/scaffold.py +273 -0
- rasa/cli/shell.py +141 -0
- rasa/cli/studio/__init__.py +0 -0
- rasa/cli/studio/download.py +62 -0
- rasa/cli/studio/studio.py +296 -0
- rasa/cli/studio/train.py +59 -0
- rasa/cli/studio/upload.py +62 -0
- rasa/cli/telemetry.py +102 -0
- rasa/cli/test.py +280 -0
- rasa/cli/train.py +278 -0
- rasa/cli/utils.py +484 -0
- rasa/cli/visualize.py +40 -0
- rasa/cli/x.py +206 -0
- rasa/constants.py +45 -0
- rasa/core/__init__.py +17 -0
- rasa/core/actions/__init__.py +0 -0
- rasa/core/actions/action.py +1318 -0
- rasa/core/actions/action_clean_stack.py +59 -0
- rasa/core/actions/action_exceptions.py +24 -0
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/actions/action_repeat_bot_messages.py +89 -0
- rasa/core/actions/action_run_slot_rejections.py +210 -0
- rasa/core/actions/action_trigger_chitchat.py +31 -0
- rasa/core/actions/action_trigger_flow.py +109 -0
- rasa/core/actions/action_trigger_search.py +31 -0
- rasa/core/actions/constants.py +5 -0
- rasa/core/actions/custom_action_executor.py +191 -0
- rasa/core/actions/direct_custom_actions_executor.py +109 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +72 -0
- rasa/core/actions/forms.py +741 -0
- rasa/core/actions/grpc_custom_action_executor.py +251 -0
- rasa/core/actions/http_custom_action_executor.py +145 -0
- rasa/core/actions/loops.py +114 -0
- rasa/core/actions/two_stage_fallback.py +186 -0
- rasa/core/agent.py +559 -0
- rasa/core/auth_retry_tracker_store.py +122 -0
- rasa/core/brokers/__init__.py +0 -0
- rasa/core/brokers/broker.py +126 -0
- rasa/core/brokers/file.py +58 -0
- rasa/core/brokers/kafka.py +324 -0
- rasa/core/brokers/pika.py +388 -0
- rasa/core/brokers/sql.py +86 -0
- rasa/core/channels/__init__.py +61 -0
- rasa/core/channels/botframework.py +338 -0
- rasa/core/channels/callback.py +84 -0
- rasa/core/channels/channel.py +456 -0
- rasa/core/channels/console.py +241 -0
- rasa/core/channels/development_inspector.py +197 -0
- rasa/core/channels/facebook.py +419 -0
- rasa/core/channels/hangouts.py +329 -0
- rasa/core/channels/inspector/.eslintrc.cjs +25 -0
- rasa/core/channels/inspector/.gitignore +23 -0
- rasa/core/channels/inspector/README.md +54 -0
- rasa/core/channels/inspector/assets/favicon.ico +0 -0
- rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
- rasa/core/channels/inspector/custom.d.ts +3 -0
- rasa/core/channels/inspector/dist/assets/arc-861ddd57.js +1 -0
- rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
- rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-921f02db.js +10 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-b436c4f8.js +2 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-511a23cb.js +2 -0
- rasa/core/channels/inspector/dist/assets/createText-62fc7601-ef476ecd.js +7 -0
- rasa/core/channels/inspector/dist/assets/edges-f2ad444c-f1878e0a.js +4 -0
- rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-fac75185.js +51 -0
- rasa/core/channels/inspector/dist/assets/flowDb-1972c806-201c5bbc.js +6 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-f904ae41.js +4 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-b080d6f2.js +1 -0
- rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-1813da66.js +139 -0
- rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-872af172.js +266 -0
- rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-34a0af5a.js +70 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-42ba3e3d.js +1 -0
- rasa/core/channels/inspector/dist/assets/index-37817b51.js +1317 -0
- rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
- rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-6b731386.js +7 -0
- rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
- rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-e8579ac6.js +139 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
- rasa/core/channels/inspector/dist/assets/layout-89e6403a.js +1 -0
- rasa/core/channels/inspector/dist/assets/line-dc73d3fc.js +1 -0
- rasa/core/channels/inspector/dist/assets/linear-f5b1d2bc.js +1 -0
- rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-82cb74fa.js +109 -0
- rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
- rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
- rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-bdf5f29b.js +35 -0
- rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-c7a0cbe4.js +7 -0
- rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-7ec5410f.js +52 -0
- rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-caee5554.js +8 -0
- rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-2935f8db.js +122 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-8f5d9693.js +1 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-d565d1de.js +1 -0
- rasa/core/channels/inspector/dist/assets/styles-080da4f6-75ad421d.js +110 -0
- rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-7e764226.js +159 -0
- rasa/core/channels/inspector/dist/assets/styles-9c745c82-7a4e0e61.js +207 -0
- rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-4019d1bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-01ea12df.js +61 -0
- rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-89407137.js +7 -0
- rasa/core/channels/inspector/dist/index.html +42 -0
- rasa/core/channels/inspector/index.html +40 -0
- rasa/core/channels/inspector/jest.config.ts +13 -0
- rasa/core/channels/inspector/package.json +52 -0
- rasa/core/channels/inspector/setupTests.ts +2 -0
- rasa/core/channels/inspector/src/App.tsx +220 -0
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +108 -0
- rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +136 -0
- rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
- rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +22 -0
- rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
- rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
- rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
- rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
- rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
- rasa/core/channels/inspector/src/helpers/audiostream.ts +191 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +392 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +306 -0
- rasa/core/channels/inspector/src/helpers/utils.ts +127 -0
- rasa/core/channels/inspector/src/main.tsx +13 -0
- rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
- rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
- rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
- rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
- rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
- rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
- rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
- rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
- rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
- rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
- rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
- rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
- rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
- rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
- rasa/core/channels/inspector/src/theme/index.ts +101 -0
- rasa/core/channels/inspector/src/types.ts +84 -0
- rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
- rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
- rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
- rasa/core/channels/inspector/tsconfig.json +26 -0
- rasa/core/channels/inspector/tsconfig.node.json +10 -0
- rasa/core/channels/inspector/vite.config.ts +8 -0
- rasa/core/channels/inspector/yarn.lock +6249 -0
- rasa/core/channels/mattermost.py +229 -0
- rasa/core/channels/rasa_chat.py +126 -0
- rasa/core/channels/rest.py +230 -0
- rasa/core/channels/rocketchat.py +174 -0
- rasa/core/channels/slack.py +620 -0
- rasa/core/channels/socketio.py +302 -0
- rasa/core/channels/telegram.py +298 -0
- rasa/core/channels/twilio.py +169 -0
- rasa/core/channels/vier_cvg.py +374 -0
- rasa/core/channels/voice_ready/__init__.py +0 -0
- rasa/core/channels/voice_ready/audiocodes.py +501 -0
- rasa/core/channels/voice_ready/jambonz.py +121 -0
- rasa/core/channels/voice_ready/jambonz_protocol.py +396 -0
- rasa/core/channels/voice_ready/twilio_voice.py +403 -0
- rasa/core/channels/voice_ready/utils.py +37 -0
- rasa/core/channels/voice_stream/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
- rasa/core/channels/voice_stream/asr/azure.py +130 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
- rasa/core/channels/voice_stream/audio_bytes.py +8 -0
- rasa/core/channels/voice_stream/browser_audio.py +107 -0
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +106 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +427 -0
- rasa/core/channels/webexteams.py +134 -0
- rasa/core/concurrent_lock_store.py +210 -0
- rasa/core/constants.py +112 -0
- rasa/core/evaluation/__init__.py +0 -0
- rasa/core/evaluation/marker.py +267 -0
- rasa/core/evaluation/marker_base.py +923 -0
- rasa/core/evaluation/marker_stats.py +293 -0
- rasa/core/evaluation/marker_tracker_loader.py +103 -0
- rasa/core/exceptions.py +29 -0
- rasa/core/exporter.py +284 -0
- rasa/core/featurizers/__init__.py +0 -0
- rasa/core/featurizers/precomputation.py +410 -0
- rasa/core/featurizers/single_state_featurizer.py +421 -0
- rasa/core/featurizers/tracker_featurizers.py +1262 -0
- rasa/core/http_interpreter.py +89 -0
- rasa/core/information_retrieval/__init__.py +7 -0
- rasa/core/information_retrieval/faiss.py +124 -0
- rasa/core/information_retrieval/information_retrieval.py +137 -0
- rasa/core/information_retrieval/milvus.py +59 -0
- rasa/core/information_retrieval/qdrant.py +96 -0
- rasa/core/jobs.py +63 -0
- rasa/core/lock.py +139 -0
- rasa/core/lock_store.py +343 -0
- rasa/core/migrate.py +403 -0
- rasa/core/nlg/__init__.py +3 -0
- rasa/core/nlg/callback.py +146 -0
- rasa/core/nlg/contextual_response_rephraser.py +320 -0
- rasa/core/nlg/generator.py +230 -0
- rasa/core/nlg/interpolator.py +143 -0
- rasa/core/nlg/response.py +155 -0
- rasa/core/nlg/summarize.py +70 -0
- rasa/core/persistor.py +538 -0
- rasa/core/policies/__init__.py +0 -0
- rasa/core/policies/ensemble.py +329 -0
- rasa/core/policies/enterprise_search_policy.py +905 -0
- rasa/core/policies/enterprise_search_prompt_template.jinja2 +25 -0
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
- rasa/core/policies/flow_policy.py +205 -0
- rasa/core/policies/flows/__init__.py +0 -0
- rasa/core/policies/flows/flow_exceptions.py +44 -0
- rasa/core/policies/flows/flow_executor.py +754 -0
- rasa/core/policies/flows/flow_step_result.py +43 -0
- rasa/core/policies/intentless_policy.py +1031 -0
- rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
- rasa/core/policies/memoization.py +538 -0
- rasa/core/policies/policy.py +725 -0
- rasa/core/policies/rule_policy.py +1273 -0
- rasa/core/policies/ted_policy.py +2169 -0
- rasa/core/policies/unexpected_intent_policy.py +1022 -0
- rasa/core/processor.py +1465 -0
- rasa/core/run.py +342 -0
- rasa/core/secrets_manager/__init__.py +0 -0
- rasa/core/secrets_manager/constants.py +36 -0
- rasa/core/secrets_manager/endpoints.py +391 -0
- rasa/core/secrets_manager/factory.py +241 -0
- rasa/core/secrets_manager/secret_manager.py +262 -0
- rasa/core/secrets_manager/vault.py +584 -0
- rasa/core/test.py +1335 -0
- rasa/core/tracker_store.py +1703 -0
- rasa/core/train.py +105 -0
- rasa/core/training/__init__.py +89 -0
- rasa/core/training/converters/__init__.py +0 -0
- rasa/core/training/converters/responses_prefix_converter.py +119 -0
- rasa/core/training/interactive.py +1744 -0
- rasa/core/training/story_conflict.py +381 -0
- rasa/core/training/training.py +93 -0
- rasa/core/utils.py +366 -0
- rasa/core/visualize.py +70 -0
- rasa/dialogue_understanding/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/constants.py +4 -0
- rasa/dialogue_understanding/coexistence/intent_based_router.py +196 -0
- rasa/dialogue_understanding/coexistence/llm_based_router.py +327 -0
- rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
- rasa/dialogue_understanding/commands/__init__.py +61 -0
- rasa/dialogue_understanding/commands/can_not_handle_command.py +70 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +125 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +44 -0
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/clarify_command.py +86 -0
- rasa/dialogue_understanding/commands/command.py +85 -0
- rasa/dialogue_understanding/commands/correct_slots_command.py +297 -0
- rasa/dialogue_understanding/commands/error_command.py +79 -0
- rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/handle_code_change_command.py +73 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +66 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/noop_command.py +54 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
- rasa/dialogue_understanding/commands/restart_command.py +58 -0
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +160 -0
- rasa/dialogue_understanding/commands/skip_question_command.py +75 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +107 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +45 -0
- rasa/dialogue_understanding/generator/__init__.py +21 -0
- rasa/dialogue_understanding/generator/command_generator.py +464 -0
- rasa/dialogue_understanding/generator/constants.py +27 -0
- rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +466 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +500 -0
- rasa/dialogue_understanding/generator/llm_command_generator.py +67 -0
- rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
- rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +920 -0
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +261 -0
- rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +60 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +486 -0
- rasa/dialogue_understanding/patterns/__init__.py +0 -0
- rasa/dialogue_understanding/patterns/cancel.py +111 -0
- rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
- rasa/dialogue_understanding/patterns/chitchat.py +37 -0
- rasa/dialogue_understanding/patterns/clarify.py +97 -0
- rasa/dialogue_understanding/patterns/code_change.py +41 -0
- rasa/dialogue_understanding/patterns/collect_information.py +90 -0
- rasa/dialogue_understanding/patterns/completed.py +40 -0
- rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
- rasa/dialogue_understanding/patterns/correction.py +278 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +301 -0
- rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
- rasa/dialogue_understanding/patterns/internal_error.py +47 -0
- rasa/dialogue_understanding/patterns/repeat.py +37 -0
- rasa/dialogue_understanding/patterns/restart.py +37 -0
- rasa/dialogue_understanding/patterns/search.py +37 -0
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/patterns/skip_question.py +38 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/__init__.py +0 -0
- rasa/dialogue_understanding/processor/command_processor.py +720 -0
- rasa/dialogue_understanding/processor/command_processor_component.py +43 -0
- rasa/dialogue_understanding/stack/__init__.py +0 -0
- rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
- rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
- rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
- rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
- rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
- rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
- rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
- rasa/dialogue_understanding/stack/utils.py +211 -0
- rasa/dialogue_understanding/utils.py +14 -0
- rasa/dialogue_understanding_test/__init__.py +0 -0
- rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
- rasa/dialogue_understanding_test/constants.py +17 -0
- rasa/dialogue_understanding_test/du_test_case.py +118 -0
- rasa/dialogue_understanding_test/du_test_result.py +11 -0
- rasa/dialogue_understanding_test/du_test_runner.py +93 -0
- rasa/dialogue_understanding_test/io.py +54 -0
- rasa/dialogue_understanding_test/validation.py +22 -0
- rasa/e2e_test/__init__.py +0 -0
- rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
- rasa/e2e_test/assertions.py +1345 -0
- rasa/e2e_test/assertions_schema.yml +129 -0
- rasa/e2e_test/constants.py +31 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +569 -0
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +54 -0
- rasa/e2e_test/e2e_test_runner.py +1192 -0
- rasa/e2e_test/e2e_test_schema.yml +181 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +598 -0
- rasa/e2e_test/utils/validation.py +178 -0
- rasa/engine/__init__.py +0 -0
- rasa/engine/caching.py +463 -0
- rasa/engine/constants.py +17 -0
- rasa/engine/exceptions.py +14 -0
- rasa/engine/graph.py +642 -0
- rasa/engine/loader.py +48 -0
- rasa/engine/recipes/__init__.py +0 -0
- rasa/engine/recipes/config_files/default_config.yml +41 -0
- rasa/engine/recipes/default_components.py +97 -0
- rasa/engine/recipes/default_recipe.py +1272 -0
- rasa/engine/recipes/graph_recipe.py +79 -0
- rasa/engine/recipes/recipe.py +93 -0
- rasa/engine/runner/__init__.py +0 -0
- rasa/engine/runner/dask.py +250 -0
- rasa/engine/runner/interface.py +49 -0
- rasa/engine/storage/__init__.py +0 -0
- rasa/engine/storage/local_model_storage.py +244 -0
- rasa/engine/storage/resource.py +110 -0
- rasa/engine/storage/storage.py +199 -0
- rasa/engine/training/__init__.py +0 -0
- rasa/engine/training/components.py +176 -0
- rasa/engine/training/fingerprinting.py +64 -0
- rasa/engine/training/graph_trainer.py +256 -0
- rasa/engine/training/hooks.py +164 -0
- rasa/engine/validation.py +1451 -0
- rasa/env.py +14 -0
- rasa/exceptions.py +69 -0
- rasa/graph_components/__init__.py +0 -0
- rasa/graph_components/converters/__init__.py +0 -0
- rasa/graph_components/converters/nlu_message_converter.py +48 -0
- rasa/graph_components/providers/__init__.py +0 -0
- rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
- rasa/graph_components/providers/domain_provider.py +71 -0
- rasa/graph_components/providers/flows_provider.py +74 -0
- rasa/graph_components/providers/forms_provider.py +44 -0
- rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
- rasa/graph_components/providers/responses_provider.py +44 -0
- rasa/graph_components/providers/rule_only_provider.py +49 -0
- rasa/graph_components/providers/story_graph_provider.py +96 -0
- rasa/graph_components/providers/training_tracker_provider.py +55 -0
- rasa/graph_components/validators/__init__.py +0 -0
- rasa/graph_components/validators/default_recipe_validator.py +550 -0
- rasa/graph_components/validators/finetuning_validator.py +302 -0
- rasa/hooks.py +111 -0
- rasa/jupyter.py +63 -0
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/markers/__init__.py +0 -0
- rasa/markers/marker.py +269 -0
- rasa/markers/marker_base.py +828 -0
- rasa/markers/upload.py +74 -0
- rasa/markers/validate.py +21 -0
- rasa/model.py +118 -0
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +40 -0
- rasa/model_manager/model_api.py +559 -0
- rasa/model_manager/runner_service.py +286 -0
- rasa/model_manager/socket_bridge.py +146 -0
- rasa/model_manager/studio_jwt_auth.py +86 -0
- rasa/model_manager/trainer_service.py +325 -0
- rasa/model_manager/utils.py +87 -0
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +112 -0
- rasa/model_testing.py +457 -0
- rasa/model_training.py +596 -0
- rasa/nlu/__init__.py +7 -0
- rasa/nlu/classifiers/__init__.py +3 -0
- rasa/nlu/classifiers/classifier.py +5 -0
- rasa/nlu/classifiers/diet_classifier.py +1881 -0
- rasa/nlu/classifiers/fallback_classifier.py +192 -0
- rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
- rasa/nlu/classifiers/logistic_regression_classifier.py +253 -0
- rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
- rasa/nlu/classifiers/regex_message_handler.py +56 -0
- rasa/nlu/classifiers/sklearn_intent_classifier.py +330 -0
- rasa/nlu/constants.py +77 -0
- rasa/nlu/convert.py +40 -0
- rasa/nlu/emulators/__init__.py +0 -0
- rasa/nlu/emulators/dialogflow.py +55 -0
- rasa/nlu/emulators/emulator.py +49 -0
- rasa/nlu/emulators/luis.py +86 -0
- rasa/nlu/emulators/no_emulator.py +10 -0
- rasa/nlu/emulators/wit.py +56 -0
- rasa/nlu/extractors/__init__.py +0 -0
- rasa/nlu/extractors/crf_entity_extractor.py +715 -0
- rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
- rasa/nlu/extractors/entity_synonyms.py +178 -0
- rasa/nlu/extractors/extractor.py +470 -0
- rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
- rasa/nlu/extractors/regex_entity_extractor.py +220 -0
- rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
- rasa/nlu/featurizers/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +445 -0
- rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +768 -0
- rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
- rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
- rasa/nlu/featurizers/featurizer.py +89 -0
- rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +867 -0
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +571 -0
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +271 -0
- rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
- rasa/nlu/model.py +24 -0
- rasa/nlu/run.py +27 -0
- rasa/nlu/selectors/__init__.py +0 -0
- rasa/nlu/selectors/response_selector.py +987 -0
- rasa/nlu/test.py +1940 -0
- rasa/nlu/tokenizers/__init__.py +0 -0
- rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
- rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
- rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
- rasa/nlu/tokenizers/tokenizer.py +239 -0
- rasa/nlu/tokenizers/whitespace_tokenizer.py +95 -0
- rasa/nlu/utils/__init__.py +35 -0
- rasa/nlu/utils/bilou_utils.py +462 -0
- rasa/nlu/utils/hugging_face/__init__.py +0 -0
- rasa/nlu/utils/hugging_face/registry.py +108 -0
- rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
- rasa/nlu/utils/mitie_utils.py +113 -0
- rasa/nlu/utils/pattern_utils.py +168 -0
- rasa/nlu/utils/spacy_utils.py +310 -0
- rasa/plugin.py +90 -0
- rasa/server.py +1588 -0
- rasa/shared/__init__.py +0 -0
- rasa/shared/constants.py +311 -0
- rasa/shared/core/__init__.py +0 -0
- rasa/shared/core/command_payload_reader.py +109 -0
- rasa/shared/core/constants.py +180 -0
- rasa/shared/core/conversation.py +46 -0
- rasa/shared/core/domain.py +2172 -0
- rasa/shared/core/events.py +2559 -0
- rasa/shared/core/flows/__init__.py +7 -0
- rasa/shared/core/flows/flow.py +562 -0
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flow_step.py +146 -0
- rasa/shared/core/flows/flow_step_links.py +319 -0
- rasa/shared/core/flows/flow_step_sequence.py +70 -0
- rasa/shared/core/flows/flows_list.py +258 -0
- rasa/shared/core/flows/flows_yaml_schema.json +303 -0
- rasa/shared/core/flows/nlu_trigger.py +117 -0
- rasa/shared/core/flows/steps/__init__.py +24 -0
- rasa/shared/core/flows/steps/action.py +56 -0
- rasa/shared/core/flows/steps/call.py +64 -0
- rasa/shared/core/flows/steps/collect.py +112 -0
- rasa/shared/core/flows/steps/constants.py +5 -0
- rasa/shared/core/flows/steps/continuation.py +36 -0
- rasa/shared/core/flows/steps/end.py +22 -0
- rasa/shared/core/flows/steps/internal.py +44 -0
- rasa/shared/core/flows/steps/link.py +51 -0
- rasa/shared/core/flows/steps/no_operation.py +48 -0
- rasa/shared/core/flows/steps/set_slots.py +50 -0
- rasa/shared/core/flows/steps/start.py +30 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +735 -0
- rasa/shared/core/flows/yaml_flows_io.py +405 -0
- rasa/shared/core/generator.py +908 -0
- rasa/shared/core/slot_mappings.py +526 -0
- rasa/shared/core/slots.py +654 -0
- rasa/shared/core/trackers.py +1183 -0
- rasa/shared/core/training_data/__init__.py +0 -0
- rasa/shared/core/training_data/loading.py +89 -0
- rasa/shared/core/training_data/story_reader/__init__.py +0 -0
- rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
- rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
- rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
- rasa/shared/core/training_data/story_writer/__init__.py +0 -0
- rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
- rasa/shared/core/training_data/story_writer/yaml_story_writer.py +444 -0
- rasa/shared/core/training_data/structures.py +858 -0
- rasa/shared/core/training_data/visualization.html +146 -0
- rasa/shared/core/training_data/visualization.py +603 -0
- rasa/shared/data.py +249 -0
- rasa/shared/engine/__init__.py +0 -0
- rasa/shared/engine/caching.py +26 -0
- rasa/shared/exceptions.py +167 -0
- rasa/shared/importers/__init__.py +0 -0
- rasa/shared/importers/importer.py +770 -0
- rasa/shared/importers/multi_project.py +215 -0
- rasa/shared/importers/rasa.py +108 -0
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +36 -0
- rasa/shared/nlu/__init__.py +0 -0
- rasa/shared/nlu/constants.py +53 -0
- rasa/shared/nlu/interpreter.py +10 -0
- rasa/shared/nlu/training_data/__init__.py +0 -0
- rasa/shared/nlu/training_data/entities_parser.py +208 -0
- rasa/shared/nlu/training_data/features.py +492 -0
- rasa/shared/nlu/training_data/formats/__init__.py +10 -0
- rasa/shared/nlu/training_data/formats/dialogflow.py +163 -0
- rasa/shared/nlu/training_data/formats/luis.py +87 -0
- rasa/shared/nlu/training_data/formats/rasa.py +135 -0
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +618 -0
- rasa/shared/nlu/training_data/formats/readerwriter.py +244 -0
- rasa/shared/nlu/training_data/formats/wit.py +52 -0
- rasa/shared/nlu/training_data/loading.py +137 -0
- rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
- rasa/shared/nlu/training_data/message.py +490 -0
- rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
- rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
- rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
- rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
- rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
- rasa/shared/nlu/training_data/training_data.py +729 -0
- rasa/shared/nlu/training_data/util.py +223 -0
- rasa/shared/providers/__init__.py +0 -0
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +677 -0
- rasa/shared/providers/_configs/client_config.py +59 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +132 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +236 -0
- rasa/shared/providers/_configs/litellm_router_client_config.py +222 -0
- rasa/shared/providers/_configs/model_group_config.py +173 -0
- rasa/shared/providers/_configs/openai_client_config.py +177 -0
- rasa/shared/providers/_configs/rasa_llm_client_config.py +75 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +178 -0
- rasa/shared/providers/_configs/utils.py +117 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/constants.py +7 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +243 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +335 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +126 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +138 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +265 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +415 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +110 -0
- rasa/shared/providers/llm/litellm_router_llm_client.py +202 -0
- rasa/shared/providers/llm/llm_client.py +78 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +161 -0
- rasa/shared/providers/llm/rasa_llm_client.py +120 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +276 -0
- rasa/shared/providers/mappings.py +94 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +185 -0
- rasa/shared/providers/router/router_client.py +75 -0
- rasa/shared/utils/__init__.py +0 -0
- rasa/shared/utils/cli.py +102 -0
- rasa/shared/utils/common.py +324 -0
- rasa/shared/utils/constants.py +4 -0
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +258 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +499 -0
- rasa/shared/utils/llm.py +764 -0
- rasa/shared/utils/pykwalify_extensions.py +27 -0
- rasa/shared/utils/schemas/__init__.py +0 -0
- rasa/shared/utils/schemas/config.yml +2 -0
- rasa/shared/utils/schemas/domain.yml +145 -0
- rasa/shared/utils/schemas/events.py +214 -0
- rasa/shared/utils/schemas/model_config.yml +36 -0
- rasa/shared/utils/schemas/stories.yml +173 -0
- rasa/shared/utils/yaml.py +1068 -0
- rasa/studio/__init__.py +0 -0
- rasa/studio/auth.py +270 -0
- rasa/studio/config.py +136 -0
- rasa/studio/constants.py +19 -0
- rasa/studio/data_handler.py +368 -0
- rasa/studio/download.py +489 -0
- rasa/studio/results_logger.py +137 -0
- rasa/studio/train.py +134 -0
- rasa/studio/upload.py +563 -0
- rasa/telemetry.py +1876 -0
- rasa/tracing/__init__.py +0 -0
- rasa/tracing/config.py +355 -0
- rasa/tracing/constants.py +62 -0
- rasa/tracing/instrumentation/__init__.py +0 -0
- rasa/tracing/instrumentation/attribute_extractors.py +765 -0
- rasa/tracing/instrumentation/instrumentation.py +1306 -0
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +144 -0
- rasa/tracing/instrumentation/metrics.py +294 -0
- rasa/tracing/metric_instrument_provider.py +205 -0
- rasa/utils/__init__.py +0 -0
- rasa/utils/beta.py +83 -0
- rasa/utils/cli.py +28 -0
- rasa/utils/common.py +639 -0
- rasa/utils/converter.py +53 -0
- rasa/utils/endpoints.py +331 -0
- rasa/utils/io.py +252 -0
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +542 -0
- rasa/utils/log_utils.py +181 -0
- rasa/utils/mapper.py +210 -0
- rasa/utils/ml_utils.py +147 -0
- rasa/utils/plotting.py +362 -0
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/utils/singleton.py +23 -0
- rasa/utils/tensorflow/__init__.py +0 -0
- rasa/utils/tensorflow/callback.py +112 -0
- rasa/utils/tensorflow/constants.py +116 -0
- rasa/utils/tensorflow/crf.py +492 -0
- rasa/utils/tensorflow/data_generator.py +440 -0
- rasa/utils/tensorflow/environment.py +161 -0
- rasa/utils/tensorflow/exceptions.py +5 -0
- rasa/utils/tensorflow/feature_array.py +366 -0
- rasa/utils/tensorflow/layers.py +1565 -0
- rasa/utils/tensorflow/layers_utils.py +113 -0
- rasa/utils/tensorflow/metrics.py +281 -0
- rasa/utils/tensorflow/model_data.py +798 -0
- rasa/utils/tensorflow/model_data_utils.py +499 -0
- rasa/utils/tensorflow/models.py +935 -0
- rasa/utils/tensorflow/rasa_layers.py +1094 -0
- rasa/utils/tensorflow/transformer.py +640 -0
- rasa/utils/tensorflow/types.py +6 -0
- rasa/utils/train_utils.py +572 -0
- rasa/utils/url_tools.py +53 -0
- rasa/utils/yaml.py +54 -0
- rasa/validator.py +1644 -0
- rasa/version.py +3 -0
- rasa_pro-3.12.0.dev1.dist-info/METADATA +199 -0
- rasa_pro-3.12.0.dev1.dist-info/NOTICE +5 -0
- rasa_pro-3.12.0.dev1.dist-info/RECORD +790 -0
- rasa_pro-3.12.0.dev1.dist-info/WHEEL +4 -0
- rasa_pro-3.12.0.dev1.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,1022 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import logging
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, List, Optional, Text, Dict, Type, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import tensorflow as tf
|
|
8
|
+
|
|
9
|
+
import rasa.utils.common
|
|
10
|
+
from rasa.engine.graph import ExecutionContext
|
|
11
|
+
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
12
|
+
from rasa.engine.storage.resource import Resource
|
|
13
|
+
from rasa.engine.storage.storage import ModelStorage
|
|
14
|
+
from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
|
|
15
|
+
from rasa.shared.nlu.training_data.features import Features
|
|
16
|
+
from rasa.shared.core.domain import Domain
|
|
17
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
18
|
+
from rasa.shared.core.constants import SLOTS, ACTIVE_LOOP, ACTION_UNLIKELY_INTENT_NAME
|
|
19
|
+
from rasa.shared.core.events import UserUttered, ActionExecuted
|
|
20
|
+
import rasa.shared.utils.io
|
|
21
|
+
from rasa.shared.nlu.constants import (
|
|
22
|
+
INTENT,
|
|
23
|
+
TEXT,
|
|
24
|
+
ENTITIES,
|
|
25
|
+
ACTION_NAME,
|
|
26
|
+
SPLIT_ENTITIES_BY_COMMA,
|
|
27
|
+
SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
|
|
28
|
+
)
|
|
29
|
+
from rasa.nlu.extractors.extractor import EntityTagSpec
|
|
30
|
+
from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
|
|
31
|
+
from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
|
|
32
|
+
from rasa.core.featurizers.tracker_featurizers import IntentMaxHistoryTrackerFeaturizer
|
|
33
|
+
from rasa.core.featurizers.single_state_featurizer import (
|
|
34
|
+
IntentTokenizerSingleStateFeaturizer,
|
|
35
|
+
)
|
|
36
|
+
from rasa.shared.core.generator import TrackerWithCachedStates
|
|
37
|
+
from rasa.core.constants import (
|
|
38
|
+
DIALOGUE,
|
|
39
|
+
POLICY_MAX_HISTORY,
|
|
40
|
+
POLICY_PRIORITY,
|
|
41
|
+
UNLIKELY_INTENT_POLICY_PRIORITY,
|
|
42
|
+
)
|
|
43
|
+
from rasa.core.policies.policy import PolicyPrediction
|
|
44
|
+
from rasa.core.policies.ted_policy import (
|
|
45
|
+
LABEL_KEY,
|
|
46
|
+
LABEL_SUB_KEY,
|
|
47
|
+
TEDPolicy,
|
|
48
|
+
TED,
|
|
49
|
+
SEQUENCE_LENGTH,
|
|
50
|
+
SEQUENCE,
|
|
51
|
+
PREDICTION_FEATURES,
|
|
52
|
+
)
|
|
53
|
+
from rasa.utils import train_utils
|
|
54
|
+
from rasa.utils.tensorflow.models import RasaModel
|
|
55
|
+
from rasa.utils.tensorflow.constants import (
|
|
56
|
+
LABEL,
|
|
57
|
+
DENSE_DIMENSION,
|
|
58
|
+
ENCODING_DIMENSION,
|
|
59
|
+
UNIDIRECTIONAL_ENCODER,
|
|
60
|
+
TRANSFORMER_SIZE,
|
|
61
|
+
NUM_TRANSFORMER_LAYERS,
|
|
62
|
+
NUM_HEADS,
|
|
63
|
+
BATCH_SIZES,
|
|
64
|
+
BATCH_STRATEGY,
|
|
65
|
+
EPOCHS,
|
|
66
|
+
RANDOM_SEED,
|
|
67
|
+
RANKING_LENGTH,
|
|
68
|
+
LOSS_TYPE,
|
|
69
|
+
SIMILARITY_TYPE,
|
|
70
|
+
NUM_NEG,
|
|
71
|
+
EVAL_NUM_EXAMPLES,
|
|
72
|
+
EVAL_NUM_EPOCHS,
|
|
73
|
+
REGULARIZATION_CONSTANT,
|
|
74
|
+
SCALE_LOSS,
|
|
75
|
+
EMBEDDING_DIMENSION,
|
|
76
|
+
DROP_RATE_DIALOGUE,
|
|
77
|
+
DROP_RATE_LABEL,
|
|
78
|
+
DROP_RATE,
|
|
79
|
+
DROP_RATE_ATTENTION,
|
|
80
|
+
CONNECTION_DENSITY,
|
|
81
|
+
KEY_RELATIVE_ATTENTION,
|
|
82
|
+
VALUE_RELATIVE_ATTENTION,
|
|
83
|
+
MAX_RELATIVE_POSITION,
|
|
84
|
+
INNER,
|
|
85
|
+
BALANCED,
|
|
86
|
+
TENSORBOARD_LOG_DIR,
|
|
87
|
+
TENSORBOARD_LOG_LEVEL,
|
|
88
|
+
CHECKPOINT_MODEL,
|
|
89
|
+
FEATURIZERS,
|
|
90
|
+
ENTITY_RECOGNITION,
|
|
91
|
+
IGNORE_INTENTS_LIST,
|
|
92
|
+
BILOU_FLAG,
|
|
93
|
+
LEARNING_RATE,
|
|
94
|
+
CROSS_ENTROPY,
|
|
95
|
+
SPARSE_INPUT_DROPOUT,
|
|
96
|
+
DENSE_INPUT_DROPOUT,
|
|
97
|
+
MASKED_LM,
|
|
98
|
+
HIDDEN_LAYERS_SIZES,
|
|
99
|
+
CONCAT_DIMENSION,
|
|
100
|
+
TOLERANCE,
|
|
101
|
+
LABEL_PAD_ID,
|
|
102
|
+
POSITIVE_SCORES_KEY,
|
|
103
|
+
NEGATIVE_SCORES_KEY,
|
|
104
|
+
USE_GPU,
|
|
105
|
+
)
|
|
106
|
+
from rasa.utils.tensorflow import layers
|
|
107
|
+
from rasa.utils.tensorflow.model_data import RasaModelData, FeatureArray, Data
|
|
108
|
+
from rasa.core.exceptions import RasaCoreException
|
|
109
|
+
from rasa.shared.utils import common
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@dataclasses.dataclass
class RankingCandidateMetadata:
    """Dataclass to represent metadata for a candidate intent.

    One instance describes a single intent inside the metadata that
    `UnexpecTEDIntentPolicy` attaches to its predicted action.
    """

    # Name of the intent.
    name: Text
    # Similarity score predicted by the model for this intent.
    score: float
    # Threshold used for this intent; may be `None`.
    threshold: Optional[float]
    # Numerical difference between threshold and score; may be `None`.
    severity: Optional[float]
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
@dataclasses.dataclass(init=True, repr=True, eq=True)
class UnexpecTEDIntentPolicyMetadata:
    """Container for the metadata that the policy attaches to a prediction.

    It bundles the metadata of the intent that was queried together with a
    ranking of candidate intents.
    """

    # Metadata of the intent that was queried.
    query_intent: RankingCandidateMetadata
    # Candidate intents sorted by predicted similarity.
    ranking: List[RankingCandidateMetadata]
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
# Module-level logger, named after this module via `__name__`.
logger = logging.getLogger(__name__)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
@DefaultV1Recipe.register(
|
|
134
|
+
DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
|
|
135
|
+
)
|
|
136
|
+
class UnexpecTEDIntentPolicy(TEDPolicy):
|
|
137
|
+
"""`UnexpecTEDIntentPolicy` has the same model architecture as `TEDPolicy`.
|
|
138
|
+
|
|
139
|
+
The difference is at a task level.
|
|
140
|
+
Instead of predicting the next probable action, this policy
|
|
141
|
+
predicts whether the last predicted intent is a likely intent
|
|
142
|
+
according to the training stories and conversation context.
|
|
143
|
+
"""
|
|
144
|
+
|
|
145
|
+
@staticmethod
def get_default_config() -> Dict[Text, Any]:
    """Returns the default config (see parent class for full docstring).

    Note: `SIMILARITY_TYPE`, `ENTITY_RECOGNITION`, `BILOU_FLAG` and
    `LOSS_TYPE` are listed for completeness only. `__init__` always
    overwrites them with the values shown here, so user-provided values for
    these keys have no effect.
    """
    return {
        # ## Architecture of the used neural network
        # Hidden layer sizes for layers before the embedding layers for user message
        # and labels.
        # The number of hidden layers is equal to the length
        # of the corresponding list.
        HIDDEN_LAYERS_SIZES: {TEXT: []},
        # Dense dimension to use for sparse features.
        DENSE_DIMENSION: {
            TEXT: 128,
            INTENT: 20,
            ACTION_NAME: 20,
            ENTITIES: 20,
            SLOTS: 20,
            ACTIVE_LOOP: 20,
            f"{LABEL}_{INTENT}": 20,
        },
        # Default dimension to use for concatenating sequence and sentence features.
        CONCAT_DIMENSION: {TEXT: 128},
        # Dimension size of embedding vectors before
        # the dialogue transformer encoder.
        ENCODING_DIMENSION: 50,
        # Number of units in transformer encoders
        TRANSFORMER_SIZE: {TEXT: 128, DIALOGUE: 128},
        # Number of layers in transformer encoders
        NUM_TRANSFORMER_LAYERS: {TEXT: 1, DIALOGUE: 1},
        # Number of attention heads in transformer
        NUM_HEADS: 4,
        # If 'True' use key relative embeddings in attention
        KEY_RELATIVE_ATTENTION: False,
        # If 'True' use value relative embeddings in attention
        VALUE_RELATIVE_ATTENTION: False,
        # Max position for relative embeddings. Only in effect
        # if key- or value relative attention are turned on
        MAX_RELATIVE_POSITION: 5,
        # Use a unidirectional or bidirectional encoder
        # for `text`, `action_text`, and `label_action_text`.
        UNIDIRECTIONAL_ENCODER: False,
        # ## Training parameters
        # Initial and final batch sizes:
        # Batch size will be linearly increased for each epoch.
        BATCH_SIZES: [64, 256],
        # Strategy used when creating batches.
        # Can be either 'sequence' or 'balanced'.
        BATCH_STRATEGY: BALANCED,
        # Number of epochs to train
        EPOCHS: 1,
        # Set random seed to any 'int' to get reproducible results
        RANDOM_SEED: None,
        # Initial learning rate for the optimizer
        LEARNING_RATE: 0.001,
        # ## Parameters for embeddings
        # Dimension size of embedding vectors
        EMBEDDING_DIMENSION: 20,
        # The number of incorrect labels. The algorithm will minimize
        # their similarity to the user input during training.
        NUM_NEG: 20,
        # Number of intents to store in ranking key of predicted action metadata.
        # Set this to `0` to include all intents.
        RANKING_LENGTH: LABEL_RANKING_LENGTH,
        # If 'True' scale loss inverse proportionally to the confidence
        # of the correct prediction
        SCALE_LOSS: True,
        # ## Regularization parameters
        # The scale of regularization
        REGULARIZATION_CONSTANT: 0.001,
        # Dropout rate for embedding layers of dialogue features.
        DROP_RATE_DIALOGUE: 0.1,
        # Dropout rate for embedding layers of utterance level features.
        DROP_RATE: 0.0,
        # Dropout rate for embedding layers of label (here: intent) features.
        DROP_RATE_LABEL: 0.0,
        # Dropout rate for attention.
        DROP_RATE_ATTENTION: 0.0,
        # Fraction of trainable weights in internal layers.
        CONNECTION_DENSITY: 0.2,
        # If 'True' apply dropout to sparse input tensors
        SPARSE_INPUT_DROPOUT: True,
        # If 'True' apply dropout to dense input tensors
        DENSE_INPUT_DROPOUT: True,
        # If 'True' random tokens of the input message will be masked.
        # Since there is no related loss term used inside TED, the masking
        # effectively becomes just input dropout applied to the text of user
        # utterances.
        MASKED_LM: False,
        # ## Evaluation parameters
        # How often to calculate validation accuracy.
        # Small values may hurt performance, e.g. model accuracy.
        EVAL_NUM_EPOCHS: 20,
        # How many examples to use for hold out validation set
        # Large values may hurt performance, e.g. model accuracy.
        EVAL_NUM_EXAMPLES: 0,
        # If you want to use tensorboard to visualize training and validation
        # metrics, set this option to a valid output directory.
        TENSORBOARD_LOG_DIR: None,
        # Define when training metrics for tensorboard should be logged.
        # Either after every epoch or for every training step.
        # Valid values: 'epoch' and 'batch'
        TENSORBOARD_LOG_LEVEL: "epoch",
        # Perform model checkpointing
        CHECKPOINT_MODEL: False,
        # Specify what features to use as sequence and sentence features.
        # By default all features in the pipeline are used.
        FEATURIZERS: [],
        # List of intents to ignore for `action_unlikely_intent` prediction.
        IGNORE_INTENTS_LIST: [],
        # Tolerance for prediction of `action_unlikely_intent`.
        # For each intent, the tolerance is the percentage of
        # negative training instances (trackers for which
        # the corresponding intent is not the correct label) that
        # would be ignored by `UnexpecTEDIntentPolicy`. This is converted
        # into a similarity threshold by identifying the similarity
        # score for the (1 - tolerance) percentile of negative
        # examples. Any tracker with a similarity score below this
        # threshold will trigger an `action_unlikely_intent`.
        # Higher values of `tolerance` mean the policy is more
        # "tolerant" to surprising paths in conversations and
        # hence will result in fewer `action_unlikely_intent`
        # triggers. Acceptable values are between 0.0 and 1.0 (inclusive).
        TOLERANCE: 0.0,
        # Split entities by comma, this makes sense e.g. for a list of
        # ingredients in a recipe, but it doesn't make sense for the parts of
        # an address
        SPLIT_ENTITIES_BY_COMMA: SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
        # Type of similarity measure. Not configurable for this policy:
        # `__init__` always sets it to `INNER`.
        SIMILARITY_TYPE: INNER,
        # Entity recognition is not supported by this policy:
        # `__init__` always sets this to `False`.
        ENTITY_RECOGNITION: False,
        # BILOU tagging only applies to entity recognition, which is disabled;
        # `__init__` always sets this to `False`.
        BILOU_FLAG: False,
        # Type of the loss function. Not configurable for this policy:
        # `__init__` always sets it to `CROSS_ENTROPY`.
        LOSS_TYPE: CROSS_ENTROPY,
        # Determines the importance of policies, higher values take precedence
        POLICY_PRIORITY: UNLIKELY_INTENT_POLICY_PRIORITY,
        USE_GPU: True,
    }
|
|
287
|
+
|
|
288
|
+
def __init__(
    self,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
    model: Optional[RasaModel] = None,
    featurizer: Optional[TrackerFeaturizer] = None,
    fake_features: Optional[Dict[Text, List[Features]]] = None,
    entity_tag_specs: Optional[List[EntityTagSpec]] = None,
    label_quantiles: Optional[Dict[int, List[float]]] = None,
):
    """Declares instance variables with default values.

    Args:
        config: Policy configuration. It is copied before the non-configurable
            keys are overridden, so the caller's dict is not modified.
        model_storage: Forwarded unchanged to `TEDPolicy.__init__`.
        resource: Forwarded unchanged to `TEDPolicy.__init__`.
        execution_context: Forwarded unchanged to `TEDPolicy.__init__`.
        model: Forwarded unchanged to `TEDPolicy.__init__`.
        featurizer: Forwarded unchanged to `TEDPolicy.__init__`.
        fake_features: Forwarded unchanged to `TEDPolicy.__init__`.
        entity_tag_specs: Forwarded unchanged to `TEDPolicy.__init__`.
        label_quantiles: Quantile scores per label id, as produced by
            `compute_label_quantiles_post_training`. They are used, together
            with the configured `tolerance`, to pick per-label thresholds.
    """
    # Set all invalid / non configurable parameters. This is done on a copy:
    # assigning into `config` directly would silently mutate the caller's dict.
    # Overriding existing keys in a dict display keeps their original position.
    config = {
        **config,
        ENTITY_RECOGNITION: False,
        BILOU_FLAG: False,
        SIMILARITY_TYPE: INNER,
        LOSS_TYPE: CROSS_ENTROPY,
    }
    self.config = config

    super().__init__(
        self.config,
        model_storage,
        resource,
        execution_context,
        model,
        featurizer,
        fake_features,
        entity_tag_specs,
    )

    self.label_quantiles = label_quantiles or {}
    # Thresholds can only be derived once quantiles exist (i.e. after training
    # or when loaded from a trained model).
    self.label_thresholds = (
        self._pick_thresholds(self.label_quantiles, self.config[TOLERANCE])
        if self.label_quantiles
        else {}
    )
    self.ignore_intent_list = self.config[IGNORE_INTENTS_LIST]

    # Flag this policy as an experimental feature.
    common.mark_as_experimental_feature("UnexpecTED Intent Policy")
|
|
328
|
+
|
|
329
|
+
def _standard_featurizer(self) -> IntentMaxHistoryTrackerFeaturizer:
    """Builds the default tracker featurizer for this policy.

    Each state is featurized with an `IntentTokenizerSingleStateFeaturizer`.
    The history window comes from the `POLICY_MAX_HISTORY` config entry,
    which is `None` when the key is absent.
    """
    state_encoder = IntentTokenizerSingleStateFeaturizer()
    history_window = self.config.get(POLICY_MAX_HISTORY)
    return IntentMaxHistoryTrackerFeaturizer(
        state_encoder, max_history=history_window
    )
|
|
334
|
+
|
|
335
|
+
@staticmethod
def model_class() -> Type["IntentTED"]:
    """Returns the model architecture class used by this policy.

    Returns:
        The `IntentTED` class itself (not an instance).
    """
    architecture = IntentTED
    return architecture
|
|
343
|
+
|
|
344
|
+
def _auto_update_configuration(self) -> None:
    """Replaces `self.config` with its evaluation-parameter-updated version.

    Only `train_utils.update_evaluation_parameters` is applied here.
    """
    refreshed_config = train_utils.update_evaluation_parameters(self.config)
    self.config = refreshed_config
|
|
346
|
+
|
|
347
|
+
@classmethod
def _metadata_filename(cls) -> Optional[Text]:
    """Returns the name of the metadata file used by this policy.

    The value is a fixed string; `cls` is not consulted.
    """
    policy_file_name = "unexpected_intent_policy"
    return policy_file_name
|
|
350
|
+
|
|
351
|
+
def _assemble_label_data(
    self, attribute_data: Data, domain: Domain
) -> RasaModelData:
    """Builds the label-side model data for every intent in the domain.

    The result holds the keys `label_intent` and `label`. `label_intent`
    contains the sequence, sentence and mask features of all intent labels.
    `label` contains the numerical label ids.

    Args:
        attribute_data: Feature data for all intent labels.
        domain: Domain of the assistant; only `domain.intents` is used.

    Returns:
        Label features ready to be fed to the model.
    """
    intent_label_key = f"{LABEL}_{INTENT}"

    data_for_labels = RasaModelData()
    data_for_labels.add_data(attribute_data, key_prefix=f"{LABEL_KEY}_")
    data_for_labels.add_lengths(
        intent_label_key, SEQUENCE_LENGTH, intent_label_key, SEQUENCE
    )

    # One id per intent: 0 .. len(domain.intents) - 1, as a column vector.
    intent_ids = np.arange(len(domain.intents))
    id_column = np.expand_dims(intent_ids, -1)
    id_features = FeatureArray(id_column, number_of_dimensions=2)
    data_for_labels.add_features(LABEL_KEY, LABEL_SUB_KEY, [id_features])

    return data_for_labels
|
|
384
|
+
|
|
385
|
+
@staticmethod
def _prepare_data_for_prediction(model_data: RasaModelData) -> RasaModelData:
    """Keeps only the parts of training data that the model can predict on.

    Every attribute whose key is not in `PREDICTION_FEATURES` is dropped.
    Such attributes are absent from the prediction signature, so passing
    them through the model would break prediction.

    Args:
        model_data: Data used during model training.

    Returns:
        A new `RasaModelData` holding only the prediction-time attributes.
    """
    filtered_data: Dict[Text, Dict[Text, Any]] = {}
    for attribute, attribute_features in model_data.data.items():
        if attribute not in PREDICTION_FEATURES:
            # Training-only attribute (e.g. labels): skip it.
            continue
        filtered_data[attribute] = attribute_features
    return RasaModelData(data=filtered_data)
|
|
406
|
+
|
|
407
|
+
def compute_label_quantiles_post_training(
    self, model_data: RasaModelData, label_ids: np.ndarray
) -> None:
    """Computes per-label quantile scores used for `action_unlikely_intent`.

    Several quantiles are stored for each label. At inference time, a
    threshold can then be chosen according to the configured `tolerance`.
    The result is written to `self.label_quantiles`.

    Args:
        model_data: Data used for training the model.
        label_ids: Numerical IDs of labels for each data point used during training.
    """
    # `model_data` still carries training-only attributes such as `label`.
    # Those are not part of the prediction signature, so they must be
    # filtered out before running the data through the model.
    inference_input = self._prepare_data_for_prediction(model_data)

    scores: Dict[Text, Any]
    if self.model is None:
        scores = {}
    else:
        scores = self.model.run_bulk_inference(inference_input)

    scores_per_label = self._collect_label_id_grouped_scores(scores, label_ids)
    # These quantiles are looked up later to select a threshold that matches
    # the configured `tolerance`.
    self.label_quantiles = self._compute_label_quantiles(scores_per_label)
|
|
440
|
+
|
|
441
|
+
@staticmethod
def _get_trackers_for_training(
    trackers: List[TrackerWithCachedStates],
) -> List[TrackerWithCachedStates]:
    """Selects the trackers that `UnexpecTEDIntentPolicy` can be trained on.

    A tracker is rejected as soon as one of its applied events is either a
    `UserUttered` without an intent name or an `ActionExecuted` without an
    action name.

    Args:
        trackers: All trackers available for training.

    Returns:
        The compatible trackers, in their original order.
    """

    def is_incompatible(event: Any) -> bool:
        # Unnamed user intent or unnamed executed action.
        return (isinstance(event, UserUttered) and event.intent_name is None) or (
            isinstance(event, ActionExecuted) and event.action_name is None
        )

    return [
        tracker
        for tracker in trackers
        if not any(is_incompatible(event) for event in tracker.applied_events(True))
    ]
|
|
471
|
+
|
|
472
|
+
def run_training(
    self, model_data: RasaModelData, label_ids: Optional[np.ndarray] = None
) -> None:
    """Feeds the featurized training data to the model.

    After the parent class has trained the model, label quantiles are computed
    from the model's predictions on the same data via
    `compute_label_quantiles_post_training`.

    Args:
        model_data: Featurized training data.
        label_ids: Label ids corresponding to the data points in `model_data`.

    Raises:
        `RasaCoreException` if `label_ids` is None as it's needed for
        running post training procedures.
    """
    if label_ids is None:
        # The pieces below are joined by implicit string concatenation, so
        # each sentence must carry its own trailing space.
        raise RasaCoreException(
            "Incorrect usage of `run_training` "
            f"method of `{self.__class__.__name__}`. "
            "`label_ids` cannot be left to `None`."
        )
    super().run_training(model_data, label_ids)
    self.compute_label_quantiles_post_training(model_data, label_ids)
|
|
493
|
+
|
|
494
|
+
def _collect_action_metadata(
    self, domain: Domain, similarities: np.ndarray, query_intent: Text
) -> UnexpecTEDIntentPolicyMetadata:
    """Collects metadata to be attached to the predicted action.

    Metadata schema looks like this:

    {
        "query_intent": <metadata of intent that was queried>,
        "ranking": <sorted list of metadata corresponding to all intents
                    (truncated by `ranking_length` parameter when it is > 0).
                    It includes the `query_intent` unless truncation
                    removed it. Sorting is based on predicted similarities,
                    highest first.>
    }

    Each metadata dictionary looks like this:

    {
        "name": <name of intent>,
        "score": <predicted similarity score>,
        "threshold": <threshold used for intent, or None if it has none>,
        "severity": <threshold minus score, or None if there is no threshold>
    }

    Args:
        domain: Domain of the assistant.
        similarities: Predicted similarities for each intent; only the first
            row (`similarities[0]`) is read.
        query_intent: Name of intent queried in this round of inference.

    Returns:
        Metadata to be attached.
    """
    query_intent_index = domain.intents.index(query_intent)

    def _compile_metadata_for_label(
        label_name: Text, similarity_score: float, threshold: Optional[float]
    ) -> RankingCandidateMetadata:
        # Compare against `None` explicitly: a threshold of exactly 0.0 is a
        # real threshold and must not be reported as missing.
        if threshold is None:
            return RankingCandidateMetadata(
                label_name, float(similarity_score), None, None
            )
        return RankingCandidateMetadata(
            label_name,
            float(similarity_score),
            float(threshold),
            float(threshold - similarity_score),
        )

    query_intent_metadata = _compile_metadata_for_label(
        query_intent,
        similarities[0][query_intent_index],
        self.label_thresholds.get(query_intent_index),
    )

    # Ranking in descending order of predicted similarities; the sort is
    # stable, so ties keep their index order.
    sorted_similarities = sorted(
        enumerate(similarities[0]), key=lambda index_and_score: -index_and_score[1]
    )

    ranking_length = self.config[RANKING_LENGTH]
    if ranking_length > 0:
        sorted_similarities = sorted_similarities[:ranking_length]

    ranking_metadata = [
        _compile_metadata_for_label(
            domain.intents[intent_index],
            similarity,
            self.label_thresholds.get(intent_index),
        )
        for intent_index, similarity in sorted_similarities
    ]

    return UnexpecTEDIntentPolicyMetadata(query_intent_metadata, ranking_metadata)
|
|
564
|
+
|
|
565
|
+
async def predict_action_probabilities(
    self,
    tracker: DialogueStateTracker,
    domain: Domain,
    rule_only_data: Optional[Dict[Text, Any]] = None,
    precomputations: Optional[MessageContainerForCoreFeaturization] = None,
    **kwargs: Any,
) -> PolicyPrediction:
    """Predicts whether `action_unlikely_intent` should be the next action.

    Default predictions are returned when no model is loaded, when the policy
    should abstain in coexistence mode, or when `_should_skip_prediction`
    says so. Otherwise the model scores every intent for the last step of the
    conversation. If the intent of the last user message is judged unlikely,
    `ACTION_UNLIKELY_INTENT_NAME` gets confidence 1.0 and every other action
    gets 0.0. Ranking metadata is attached to the prediction in both cases.

    Args:
        tracker: Tracker containing past conversation events.
        domain: Domain of the assistant.
        rule_only_data: Slots and loops which are specific to rules and hence
            should be ignored by this policy.
        precomputations: Contains precomputed features and attributes.
        **kwargs: Additional arguments.

    Returns:
        The policy's prediction (e.g. the probabilities for the actions).

    Raises:
        TypeError: If the model's `similarities` output is not a numpy array.
    """
    if self.model is None or self.should_abstain_in_coexistence(tracker, False):
        return self._prediction(self._default_predictions(domain))

    # Nothing to judge when there is no user message yet, its intent is not
    # part of the domain, or an action already ran after the last user
    # message (the latter prevents prediction loops).
    if self._should_skip_prediction(tracker, domain):
        logger.debug(
            f"Skipping predictions for {self.__class__.__name__} "
            "as either there is no event of type `UserUttered`, "
            "event's intent is new and not in domain or "
            "there is an event of type `ActionExecuted` after "
            "the last `UserUttered`."
        )
        return self._prediction(self._default_predictions(domain))

    tracker_state_features = self._featurize_for_prediction(
        tracker, domain, precomputations, rule_only_data=rule_only_data
    )
    model_data = self._create_model_data(tracker_state_features)
    similarities = self.model.run_inference(model_data)["similarities"]

    if not isinstance(similarities, np.ndarray):
        raise TypeError("model output for `similarities` should be a numpy array")
    # Only the prediction for the final position of the sequence is used.
    sequence_similarities = similarities[:, -1, :]

    last_user_uttered_event = tracker.get_last_event_for(UserUttered)
    if last_user_uttered_event is None:
        query_intent = ""
    else:
        query_intent = last_user_uttered_event.intent_name

    is_unlikely_intent = self._check_unlikely_intent(
        domain, sequence_similarities, query_intent
    )

    confidences = list(np.zeros(domain.num_actions))
    if is_unlikely_intent:
        confidences[domain.index_for_action(ACTION_UNLIKELY_INTENT_NAME)] = 1.0

    action_metadata = self._collect_action_metadata(
        domain, sequence_similarities, query_intent
    )
    return self._prediction(
        confidences, action_metadata=dataclasses.asdict(action_metadata)
    )
|
|
644
|
+
|
|
645
|
+
@staticmethod
def _should_skip_prediction(tracker: DialogueStateTracker, domain: Domain) -> bool:
    """Decides whether the policy should refrain from predicting.

    The decision is based on the most recent applied event that is either a
    `UserUttered` or an `ActionExecuted` event:
    1. If there is no such event, there is nothing to predict on -> skip.
    2. If it is a `UserUttered` event whose intent is not in the domain
       (e.g. a new intent created via rasa interactive and not yet added
       to the domain) -> skip.
    3. If it is an `ActionExecuted` event, an action already ran after the
       last user message -> skip. Otherwise, after the policy predicted
       `action_unlikely_intent`, running inference on the same tracker would
       predict it again and the dialogue manager would loop forever.

    Returns:
        Whether prediction should be skipped.
    """
    for event in reversed(tracker.applied_events(True)):
        if isinstance(event, ActionExecuted):
            return True
        if isinstance(event, UserUttered):
            return event.intent_name not in domain.intents
    # Neither a `UserUttered` nor an `ActionExecuted` event was found.
    return True
|
|
680
|
+
|
|
681
|
+
def _should_check_for_intent(self, intent: Text, domain: Domain) -> bool:
    """Tells whether `intent` is eligible to trigger `action_unlikely_intent`.

    An intent is not eligible when it has no entry in `self.label_thresholds`
    or when it is listed in the configured `IGNORE_INTENTS_LIST`. Both cases
    are logged at debug level.

    Args:
        intent: Intent to be queried.
        domain: Domain of the assistant.

    Returns:
        Whether intent should raise `action_unlikely_intent` or not.
    """
    intent_index = domain.intents.index(intent)

    if intent_index not in self.label_thresholds:
        # No threshold was computed for this intent, presumably because it
        # never appeared as a label in the training stories.
        logger.debug(
            f"Query intent index {intent_index} not "
            f"found in label thresholds - {self.label_thresholds}. "
            f"Check for `{ACTION_UNLIKELY_INTENT_NAME}` prediction will be skipped."
        )
        return False

    ignored_intents = self.config[IGNORE_INTENTS_LIST]
    if intent in ignored_intents:
        logger.debug(
            f"Query intent `{intent}` found in "
            f"`{IGNORE_INTENTS_LIST}={ignored_intents}`. "
            f"Check for `{ACTION_UNLIKELY_INTENT_NAME}` prediction will be skipped."
        )
        return False

    return True
|
|
708
|
+
|
|
709
|
+
def _check_unlikely_intent(
    self, domain: Domain, similarities: np.ndarray, query_intent: Text
) -> bool:
    """Decides whether the query intent is unlikely given the model's scores.

    The intent is unlikely when its predicted similarity is below the
    threshold computed for it during training AND it is not the
    top-scoring intent. Intents rejected by `_should_check_for_intent` are
    never reported as unlikely.

    Args:
        domain: Domain of the assistant.
        similarities: Predicted similarities for all intents; only the first
            row (`similarities[0]`) is read.
        query_intent: Intent to be queried.

    Returns:
        `True` if the query intent is unlikely, `False` otherwise.
    """
    logger.debug(f"Querying for intent `{query_intent}`.")

    if not self._should_check_for_intent(query_intent, domain):
        return False

    scores = similarities[0]
    # (intent name, score) pairs in ascending order of score; the sort is
    # stable, so ties keep domain order.
    sorted_intent_scores = sorted(
        ((intent_name, scores[position]) for position, intent_name in enumerate(domain.intents)),
        key=lambda name_and_score: name_and_score[1],
    )

    query_intent_id = domain.intents.index(query_intent)
    query_intent_similarity = scores[query_intent_id]
    highest_likely_intent_id = domain.intents.index(sorted_intent_scores[-1][0])
    threshold = self.label_thresholds[query_intent_id]

    logger.debug(
        f"Score for intent `{query_intent}` is "
        f"`{query_intent_similarity}`, while "
        f"threshold is `{threshold}`."
    )
    logger.debug(
        f"Top 5 intents (in ascending order) that "
        f"are likely here are: `{sorted_intent_scores[-5:]}`."
    )

    below_threshold = query_intent_similarity < threshold
    is_top_intent = query_intent_id == highest_likely_intent_id
    if below_threshold and not is_top_intent:
        logger.debug(
            f"Intent `{query_intent}-{query_intent_id}` unlikely to occur here."
        )
        return True

    return False
|
|
768
|
+
|
|
769
|
+
@staticmethod
def _collect_label_id_grouped_scores(
    output_scores: Dict[Text, np.ndarray], label_ids: np.ndarray
) -> Dict[int, Dict[Text, List[float]]]:
    """Groups predicted similarities by label id.

    Every label id occurring in `label_ids` (except `LABEL_PAD_ID`) is
    treated as a candidate. For each data point, the similarity predicted for
    every candidate is filed into one of two buckets:
    1. `POSITIVE_SCORES_KEY` when the candidate is among that data point's
       labels.
    2. `NEGATIVE_SCORES_KEY` otherwise.

    Args:
        output_scores: Model's predictions for each data point; the
            `"similarities"` entry is indexed as `[data_point, 0, label_id]`.
        label_ids: Numerical IDs of labels for each data point.

    Returns:
        Both buckets of similarity scores, keyed by label id.
    """
    candidate_label_ids = [
        label_id
        for label_id in np.unique(label_ids).tolist()
        if label_id != LABEL_PAD_ID
    ]

    label_id_scores: Dict[int, Dict[Text, List[float]]] = {
        label_id: {POSITIVE_SCORES_KEY: [], NEGATIVE_SCORES_KEY: []}
        for label_id in candidate_label_ids
    }

    for data_point_index, correct_label_ids in enumerate(label_ids):
        for label_id in candidate_label_ids:
            bucket_key = (
                POSITIVE_SCORES_KEY
                if label_id in correct_label_ids
                else NEGATIVE_SCORES_KEY
            )
            label_id_scores[label_id][bucket_key].append(
                output_scores["similarities"][data_point_index, 0, label_id]
            )

    return label_id_scores
|
|
808
|
+
|
|
809
|
+
@staticmethod
def _compute_label_quantiles(
    label_id_scores: Dict[int, Dict[Text, List[float]]],
) -> Dict[int, List[float]]:
    """Computes multiple quantiles for each label id.

    The quantiles are computed over the negative scores
    collected for each label id. However, no quantile score
    can be greater than the minimum positive score collected
    for the corresponding label id.

    Twenty quantile levels are used, 1.0, 0.95, ..., 0.05 (one per 5% step),
    with "lower" interpolation, so every quantile is an observed score. When a
    label id has no negative scores, every quantile equals its minimum
    positive score.

    Args:
        label_id_scores: Scores collected for each label id
            over positive and negative trackers.

    Returns:
        Computed quantiles for each label id, in the level order above.
    """

    def _lower_quantiles(values: List[float], levels: List[float]) -> np.ndarray:
        # NumPy 1.22 renamed `interpolation` to `method` and deprecated the
        # old keyword. Use the new keyword and fall back only on NumPy
        # versions that do not understand it.
        try:
            return np.quantile(values, levels, method="lower")
        except TypeError:
            return np.quantile(  # type: ignore[call-overload]
                values, levels, interpolation="lower"
            )

    quantile_indices = [
        1 - tolerance_value / 100.0 for tolerance_value in range(0, 100, 5)
    ]

    label_quantiles: Dict[int, List[float]] = {}
    for label_id, prediction_scores in label_id_scores.items():
        positive_scores = prediction_scores[POSITIVE_SCORES_KEY]
        negative_scores = prediction_scores[NEGATIVE_SCORES_KEY]
        # `min` raises on an empty list. Label ids gathered by
        # `_collect_label_id_grouped_scores` always have a positive score,
        # because each one occurs in at least one data point's labels.
        minimum_positive_score = min(positive_scores)
        if negative_scores:
            quantile_values = _lower_quantiles(negative_scores, quantile_indices)
            label_quantiles[label_id] = [
                min(minimum_positive_score, value) for value in quantile_values
            ]
        else:
            label_quantiles[label_id] = [minimum_positive_score] * len(
                quantile_indices
            )

    return label_quantiles
|
|
851
|
+
|
|
852
|
+
@staticmethod
|
|
853
|
+
def _pick_thresholds(
|
|
854
|
+
label_quantiles: Dict[int, List[float]], tolerance: float
|
|
855
|
+
) -> Dict[int, float]:
|
|
856
|
+
"""Computes a threshold for each label id.
|
|
857
|
+
|
|
858
|
+
Uses tolerance which is the percentage of negative
|
|
859
|
+
trackers for which predicted score should be equal
|
|
860
|
+
to or above the threshold.
|
|
861
|
+
|
|
862
|
+
Args:
|
|
863
|
+
label_quantiles: Quantiles computed for each label id
|
|
864
|
+
tolerance: Specified tolerance value from the configuration.
|
|
865
|
+
|
|
866
|
+
Returns:
|
|
867
|
+
Computed thresholds
|
|
868
|
+
"""
|
|
869
|
+
label_thresholds = {}
|
|
870
|
+
for label_id in label_quantiles:
|
|
871
|
+
num_thresholds = len(label_quantiles[label_id])
|
|
872
|
+
label_thresholds[label_id] = label_quantiles[label_id][
|
|
873
|
+
min(int(tolerance * num_thresholds), num_thresholds - 1)
|
|
874
|
+
]
|
|
875
|
+
return label_thresholds
|
|
876
|
+
|
|
877
|
+
def persist_model_utilities(self, model_path: Path) -> None:
    """Persists model's utility attributes, including the label quantiles.

    The parent class persists its own utilities first. Afterwards
    `self.label_quantiles` is written with safetensors' `save_file` to
    `<metadata filename>.label_quantiles.st` inside `model_path`.

    Args:
        model_path: Path where model is to be persisted.
    """
    super().persist_model_utilities(model_path)

    from safetensors.numpy import save_file

    # Label ids are stringified and quantile lists converted to arrays
    # before being handed to `save_file`.
    serializable_quantiles = {
        str(label_id): np.array(quantiles)
        for label_id, quantiles in self.label_quantiles.items()
    }
    quantiles_file = model_path / f"{self._metadata_filename()}.label_quantiles.st"
    save_file(serializable_quantiles, quantiles_file)
|
|
891
|
+
|
|
892
|
+
@classmethod
def _load_model_utilities(cls, model_path: Path) -> Dict[Text, Any]:
    """Loads model's utility attributes, including the label quantiles.

    Args:
        model_path: Path where the model was persisted.

    Returns:
        The utilities loaded by the parent class, extended with a
        `"label_quantiles"` entry that maps integer label ids to lists of
        quantile scores.
    """
    model_utilities = super()._load_model_utilities(model_path)

    from safetensors.numpy import load_file

    stored_quantiles = load_file(
        model_path / f"{cls._metadata_filename()}.label_quantiles.st"
    )
    # Label ids are stored as strings; convert them back to integers.
    model_utilities["label_quantiles"] = {
        int(label_id): list(quantiles)
        for label_id, quantiles in stored_quantiles.items()
    }
    return model_utilities
|
|
910
|
+
|
|
911
|
+
@classmethod
def _update_loaded_params(cls, meta: Dict[Text, Any]) -> Dict[Text, Any]:
    """Combines the policy's default config with the loaded `meta`.

    The merge is delegated to `rasa.utils.common.override_defaults`, with
    `cls.get_default_config()` as the defaults.

    Args:
        meta: Configuration loaded together with the persisted model.

    Returns:
        The merged configuration.
    """
    defaults = cls.get_default_config()
    return rasa.utils.common.override_defaults(defaults, meta)
|
|
915
|
+
|
|
916
|
+
@classmethod
def _load_policy_with_model(
    cls,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
    featurizer: TrackerFeaturizer,
    model: "IntentTED",
    model_utilities: Dict[Text, Any],
) -> "UnexpecTEDIntentPolicy":
    """Instantiates the policy around an already loaded model.

    Args:
        config: Policy configuration.
        model_storage: Storage the model was loaded from.
        resource: Resource of the persisted policy.
        execution_context: Execution context of the graph.
        featurizer: Tracker featurizer to use.
        model: The loaded `IntentTED` model.
        model_utilities: Loaded utilities; must contain `"fake_features"`,
            `"entity_tag_specs"` and `"label_quantiles"`.

    Returns:
        The restored policy instance.
    """
    restored_utilities = {
        "fake_features": model_utilities["fake_features"],
        "entity_tag_specs": model_utilities["entity_tag_specs"],
        "label_quantiles": model_utilities["label_quantiles"],
    }
    return cls(
        config,
        model_storage,
        resource,
        execution_context,
        model=model,
        featurizer=featurizer,
        **restored_utilities,
    )
|
|
938
|
+
|
|
939
|
+
|
|
940
|
+
class IntentTED(TED):
|
|
941
|
+
"""Follows TED's model architecture from https://arxiv.org/abs/1910.00486.
|
|
942
|
+
|
|
943
|
+
However, it has been re-purposed to predict multiple
|
|
944
|
+
labels (intents) instead of a single label (action).
|
|
945
|
+
"""
|
|
946
|
+
|
|
947
|
+
def _prepare_dot_product_loss(
    self, name: Text, scale_loss: bool, prefix: Text = "loss"
) -> None:
    """Registers a dot-product loss layer under the key `"<prefix>.<name>"`.

    The layer class comes from `self.dot_product_loss_layer` and is built with
    the configured number of negatives and similarity type.

    Args:
        name: Name of the loss.
        scale_loss: Passed through to the loss layer as its `scale_loss`
            argument.
        prefix: Prefix of the key under which the layer is stored in
            `self._tf_layers`.
    """
    loss_layer_class = self.dot_product_loss_layer
    self._tf_layers[f"{prefix}.{name}"] = loss_layer_class(
        self.config[NUM_NEG],
        scale_loss,
        similarity_type=self.config[SIMILARITY_TYPE],
    )
|
|
955
|
+
|
|
956
|
+
@property
def dot_product_loss_layer(self) -> tf.keras.layers.Layer:
    """The loss layer class instantiated by `_prepare_dot_product_loss`.

    `IntentTED` predicts several intents that can be valid at the same time,
    so it uses the multi-label variant `MultiLabelDotProductLoss`.

    Returns:
        The `layers.MultiLabelDotProductLoss` class (NOTE(review): the class
        itself is returned, not an instance, although the annotation says
        `Layer`).
    """
    multi_label_loss = layers.MultiLabelDotProductLoss
    return multi_label_loss
|
|
967
|
+
|
|
968
|
+
@staticmethod
def _get_labels_embed(
    label_ids: tf.Tensor, all_labels_embed: tf.Tensor
) -> tf.Tensor:
    """Looks up label embeddings instead of embedding the labels again.

    Ids are taken from `label_ids[:, :, 0]`, and rows are gathered from
    `all_labels_embed`. Positions holding `LABEL_PAD_ID` are redirected to
    index 0, so `tf.gather` only ever sees valid indices. The embedding
    gathered for a padding position is therefore the one of label 0.

    Args:
        label_ids: Label ids, with the id in the first entry of the last axis.
        all_labels_embed: Precomputed embeddings of all labels.

    Returns:
        Gathered label embeddings.
    """
    indices = tf.cast(label_ids[:, :, 0], tf.int32)

    # Replace padding ids with 0 and keep real label ids untouched.
    is_padding = tf.equal(indices, LABEL_PAD_ID)
    indices_to_gather = tf.where(is_padding, tf.zeros_like(indices), indices)

    return tf.gather(all_labels_embed, indices_to_gather)
|
|
1000
|
+
|
|
1001
|
+
def run_bulk_inference(
    self, model_data: RasaModelData
) -> Dict[Text, Union[np.ndarray, Dict[Text, Any]]]:
    """Runs inference on the whole of `model_data` and returns similarities.

    `self._training` is set to `False` before inference. The batch size is the
    configured `BATCH_SIZES` value when it is an int; otherwise its first
    entry is used.

    Args:
        model_data: Data to be passed as input.

    Returns:
        Predictions for the input data (only the `"similarities"` output
        key is requested).
    """
    self._training = False

    configured_batch_size = self.config[BATCH_SIZES]
    if not isinstance(configured_batch_size, int):
        configured_batch_size = configured_batch_size[0]

    return self.run_inference(
        model_data,
        batch_size=configured_batch_size,
        output_keys_expected=["similarities"],
    )
|