rasa-pro 3.12.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +41 -0
- rasa/__init__.py +9 -0
- rasa/__main__.py +177 -0
- rasa/anonymization/__init__.py +2 -0
- rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
- rasa/anonymization/anonymization_pipeline.py +286 -0
- rasa/anonymization/anonymization_rule_executor.py +260 -0
- rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
- rasa/anonymization/schemas/config.yml +47 -0
- rasa/anonymization/utils.py +118 -0
- rasa/api.py +160 -0
- rasa/cli/__init__.py +5 -0
- rasa/cli/arguments/__init__.py +0 -0
- rasa/cli/arguments/data.py +106 -0
- rasa/cli/arguments/default_arguments.py +207 -0
- rasa/cli/arguments/evaluate.py +65 -0
- rasa/cli/arguments/export.py +51 -0
- rasa/cli/arguments/interactive.py +74 -0
- rasa/cli/arguments/run.py +219 -0
- rasa/cli/arguments/shell.py +17 -0
- rasa/cli/arguments/test.py +211 -0
- rasa/cli/arguments/train.py +279 -0
- rasa/cli/arguments/visualize.py +34 -0
- rasa/cli/arguments/x.py +30 -0
- rasa/cli/data.py +354 -0
- rasa/cli/dialogue_understanding_test.py +251 -0
- rasa/cli/e2e_test.py +259 -0
- rasa/cli/evaluate.py +222 -0
- rasa/cli/export.py +250 -0
- rasa/cli/inspect.py +75 -0
- rasa/cli/interactive.py +166 -0
- rasa/cli/license.py +65 -0
- rasa/cli/llm_fine_tuning.py +403 -0
- rasa/cli/markers.py +78 -0
- rasa/cli/project_templates/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/action_template.py +27 -0
- rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
- rasa/cli/project_templates/calm/actions/db.py +57 -0
- rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
- rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
- rasa/cli/project_templates/calm/config.yml +10 -0
- rasa/cli/project_templates/calm/credentials.yml +33 -0
- rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
- rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
- rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
- rasa/cli/project_templates/calm/db/contacts.json +10 -0
- rasa/cli/project_templates/calm/domain/add_contact.yml +39 -0
- rasa/cli/project_templates/calm/domain/list_contacts.yml +17 -0
- rasa/cli/project_templates/calm/domain/remove_contact.yml +38 -0
- rasa/cli/project_templates/calm/domain/shared.yml +10 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
- rasa/cli/project_templates/calm/endpoints.yml +58 -0
- rasa/cli/project_templates/default/actions/__init__.py +0 -0
- rasa/cli/project_templates/default/actions/actions.py +27 -0
- rasa/cli/project_templates/default/config.yml +44 -0
- rasa/cli/project_templates/default/credentials.yml +33 -0
- rasa/cli/project_templates/default/data/nlu.yml +91 -0
- rasa/cli/project_templates/default/data/rules.yml +13 -0
- rasa/cli/project_templates/default/data/stories.yml +30 -0
- rasa/cli/project_templates/default/domain.yml +34 -0
- rasa/cli/project_templates/default/endpoints.yml +42 -0
- rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
- rasa/cli/project_templates/tutorial/actions/__init__.py +0 -0
- rasa/cli/project_templates/tutorial/actions/actions.py +22 -0
- rasa/cli/project_templates/tutorial/config.yml +12 -0
- rasa/cli/project_templates/tutorial/credentials.yml +33 -0
- rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
- rasa/cli/project_templates/tutorial/data/patterns.yml +11 -0
- rasa/cli/project_templates/tutorial/domain.yml +35 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +55 -0
- rasa/cli/run.py +143 -0
- rasa/cli/scaffold.py +273 -0
- rasa/cli/shell.py +141 -0
- rasa/cli/studio/__init__.py +0 -0
- rasa/cli/studio/download.py +62 -0
- rasa/cli/studio/studio.py +296 -0
- rasa/cli/studio/train.py +59 -0
- rasa/cli/studio/upload.py +62 -0
- rasa/cli/telemetry.py +102 -0
- rasa/cli/test.py +280 -0
- rasa/cli/train.py +278 -0
- rasa/cli/utils.py +484 -0
- rasa/cli/visualize.py +40 -0
- rasa/cli/x.py +206 -0
- rasa/constants.py +45 -0
- rasa/core/__init__.py +17 -0
- rasa/core/actions/__init__.py +0 -0
- rasa/core/actions/action.py +1318 -0
- rasa/core/actions/action_clean_stack.py +59 -0
- rasa/core/actions/action_exceptions.py +24 -0
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/actions/action_repeat_bot_messages.py +89 -0
- rasa/core/actions/action_run_slot_rejections.py +210 -0
- rasa/core/actions/action_trigger_chitchat.py +31 -0
- rasa/core/actions/action_trigger_flow.py +109 -0
- rasa/core/actions/action_trigger_search.py +31 -0
- rasa/core/actions/constants.py +5 -0
- rasa/core/actions/custom_action_executor.py +191 -0
- rasa/core/actions/direct_custom_actions_executor.py +109 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +72 -0
- rasa/core/actions/forms.py +741 -0
- rasa/core/actions/grpc_custom_action_executor.py +251 -0
- rasa/core/actions/http_custom_action_executor.py +145 -0
- rasa/core/actions/loops.py +114 -0
- rasa/core/actions/two_stage_fallback.py +186 -0
- rasa/core/agent.py +559 -0
- rasa/core/auth_retry_tracker_store.py +122 -0
- rasa/core/brokers/__init__.py +0 -0
- rasa/core/brokers/broker.py +126 -0
- rasa/core/brokers/file.py +58 -0
- rasa/core/brokers/kafka.py +324 -0
- rasa/core/brokers/pika.py +388 -0
- rasa/core/brokers/sql.py +86 -0
- rasa/core/channels/__init__.py +61 -0
- rasa/core/channels/botframework.py +338 -0
- rasa/core/channels/callback.py +84 -0
- rasa/core/channels/channel.py +456 -0
- rasa/core/channels/console.py +241 -0
- rasa/core/channels/development_inspector.py +197 -0
- rasa/core/channels/facebook.py +419 -0
- rasa/core/channels/hangouts.py +329 -0
- rasa/core/channels/inspector/.eslintrc.cjs +25 -0
- rasa/core/channels/inspector/.gitignore +23 -0
- rasa/core/channels/inspector/README.md +54 -0
- rasa/core/channels/inspector/assets/favicon.ico +0 -0
- rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
- rasa/core/channels/inspector/custom.d.ts +3 -0
- rasa/core/channels/inspector/dist/assets/arc-861ddd57.js +1 -0
- rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
- rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-921f02db.js +10 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-b436c4f8.js +2 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-511a23cb.js +2 -0
- rasa/core/channels/inspector/dist/assets/createText-62fc7601-ef476ecd.js +7 -0
- rasa/core/channels/inspector/dist/assets/edges-f2ad444c-f1878e0a.js +4 -0
- rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-fac75185.js +51 -0
- rasa/core/channels/inspector/dist/assets/flowDb-1972c806-201c5bbc.js +6 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-f904ae41.js +4 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-b080d6f2.js +1 -0
- rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-1813da66.js +139 -0
- rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-872af172.js +266 -0
- rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-34a0af5a.js +70 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-42ba3e3d.js +1 -0
- rasa/core/channels/inspector/dist/assets/index-37817b51.js +1317 -0
- rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
- rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-6b731386.js +7 -0
- rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
- rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-e8579ac6.js +139 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
- rasa/core/channels/inspector/dist/assets/layout-89e6403a.js +1 -0
- rasa/core/channels/inspector/dist/assets/line-dc73d3fc.js +1 -0
- rasa/core/channels/inspector/dist/assets/linear-f5b1d2bc.js +1 -0
- rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-82cb74fa.js +109 -0
- rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
- rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
- rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-bdf5f29b.js +35 -0
- rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-c7a0cbe4.js +7 -0
- rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-7ec5410f.js +52 -0
- rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-caee5554.js +8 -0
- rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-2935f8db.js +122 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-8f5d9693.js +1 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-d565d1de.js +1 -0
- rasa/core/channels/inspector/dist/assets/styles-080da4f6-75ad421d.js +110 -0
- rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-7e764226.js +159 -0
- rasa/core/channels/inspector/dist/assets/styles-9c745c82-7a4e0e61.js +207 -0
- rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-4019d1bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-01ea12df.js +61 -0
- rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-89407137.js +7 -0
- rasa/core/channels/inspector/dist/index.html +42 -0
- rasa/core/channels/inspector/index.html +40 -0
- rasa/core/channels/inspector/jest.config.ts +13 -0
- rasa/core/channels/inspector/package.json +52 -0
- rasa/core/channels/inspector/setupTests.ts +2 -0
- rasa/core/channels/inspector/src/App.tsx +220 -0
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +108 -0
- rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +136 -0
- rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
- rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +22 -0
- rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
- rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
- rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
- rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
- rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
- rasa/core/channels/inspector/src/helpers/audiostream.ts +191 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +392 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +306 -0
- rasa/core/channels/inspector/src/helpers/utils.ts +127 -0
- rasa/core/channels/inspector/src/main.tsx +13 -0
- rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
- rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
- rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
- rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
- rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
- rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
- rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
- rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
- rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
- rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
- rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
- rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
- rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
- rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
- rasa/core/channels/inspector/src/theme/index.ts +101 -0
- rasa/core/channels/inspector/src/types.ts +84 -0
- rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
- rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
- rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
- rasa/core/channels/inspector/tsconfig.json +26 -0
- rasa/core/channels/inspector/tsconfig.node.json +10 -0
- rasa/core/channels/inspector/vite.config.ts +8 -0
- rasa/core/channels/inspector/yarn.lock +6249 -0
- rasa/core/channels/mattermost.py +229 -0
- rasa/core/channels/rasa_chat.py +126 -0
- rasa/core/channels/rest.py +230 -0
- rasa/core/channels/rocketchat.py +174 -0
- rasa/core/channels/slack.py +620 -0
- rasa/core/channels/socketio.py +302 -0
- rasa/core/channels/telegram.py +298 -0
- rasa/core/channels/twilio.py +169 -0
- rasa/core/channels/vier_cvg.py +374 -0
- rasa/core/channels/voice_ready/__init__.py +0 -0
- rasa/core/channels/voice_ready/audiocodes.py +501 -0
- rasa/core/channels/voice_ready/jambonz.py +121 -0
- rasa/core/channels/voice_ready/jambonz_protocol.py +396 -0
- rasa/core/channels/voice_ready/twilio_voice.py +403 -0
- rasa/core/channels/voice_ready/utils.py +37 -0
- rasa/core/channels/voice_stream/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
- rasa/core/channels/voice_stream/asr/azure.py +130 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
- rasa/core/channels/voice_stream/audio_bytes.py +8 -0
- rasa/core/channels/voice_stream/browser_audio.py +107 -0
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +106 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +427 -0
- rasa/core/channels/webexteams.py +134 -0
- rasa/core/concurrent_lock_store.py +210 -0
- rasa/core/constants.py +112 -0
- rasa/core/evaluation/__init__.py +0 -0
- rasa/core/evaluation/marker.py +267 -0
- rasa/core/evaluation/marker_base.py +923 -0
- rasa/core/evaluation/marker_stats.py +293 -0
- rasa/core/evaluation/marker_tracker_loader.py +103 -0
- rasa/core/exceptions.py +29 -0
- rasa/core/exporter.py +284 -0
- rasa/core/featurizers/__init__.py +0 -0
- rasa/core/featurizers/precomputation.py +410 -0
- rasa/core/featurizers/single_state_featurizer.py +421 -0
- rasa/core/featurizers/tracker_featurizers.py +1262 -0
- rasa/core/http_interpreter.py +89 -0
- rasa/core/information_retrieval/__init__.py +7 -0
- rasa/core/information_retrieval/faiss.py +124 -0
- rasa/core/information_retrieval/information_retrieval.py +137 -0
- rasa/core/information_retrieval/milvus.py +59 -0
- rasa/core/information_retrieval/qdrant.py +96 -0
- rasa/core/jobs.py +63 -0
- rasa/core/lock.py +139 -0
- rasa/core/lock_store.py +343 -0
- rasa/core/migrate.py +403 -0
- rasa/core/nlg/__init__.py +3 -0
- rasa/core/nlg/callback.py +146 -0
- rasa/core/nlg/contextual_response_rephraser.py +320 -0
- rasa/core/nlg/generator.py +230 -0
- rasa/core/nlg/interpolator.py +143 -0
- rasa/core/nlg/response.py +155 -0
- rasa/core/nlg/summarize.py +70 -0
- rasa/core/persistor.py +538 -0
- rasa/core/policies/__init__.py +0 -0
- rasa/core/policies/ensemble.py +329 -0
- rasa/core/policies/enterprise_search_policy.py +905 -0
- rasa/core/policies/enterprise_search_prompt_template.jinja2 +25 -0
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
- rasa/core/policies/flow_policy.py +205 -0
- rasa/core/policies/flows/__init__.py +0 -0
- rasa/core/policies/flows/flow_exceptions.py +44 -0
- rasa/core/policies/flows/flow_executor.py +754 -0
- rasa/core/policies/flows/flow_step_result.py +43 -0
- rasa/core/policies/intentless_policy.py +1031 -0
- rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
- rasa/core/policies/memoization.py +538 -0
- rasa/core/policies/policy.py +725 -0
- rasa/core/policies/rule_policy.py +1273 -0
- rasa/core/policies/ted_policy.py +2169 -0
- rasa/core/policies/unexpected_intent_policy.py +1022 -0
- rasa/core/processor.py +1465 -0
- rasa/core/run.py +342 -0
- rasa/core/secrets_manager/__init__.py +0 -0
- rasa/core/secrets_manager/constants.py +36 -0
- rasa/core/secrets_manager/endpoints.py +391 -0
- rasa/core/secrets_manager/factory.py +241 -0
- rasa/core/secrets_manager/secret_manager.py +262 -0
- rasa/core/secrets_manager/vault.py +584 -0
- rasa/core/test.py +1335 -0
- rasa/core/tracker_store.py +1703 -0
- rasa/core/train.py +105 -0
- rasa/core/training/__init__.py +89 -0
- rasa/core/training/converters/__init__.py +0 -0
- rasa/core/training/converters/responses_prefix_converter.py +119 -0
- rasa/core/training/interactive.py +1744 -0
- rasa/core/training/story_conflict.py +381 -0
- rasa/core/training/training.py +93 -0
- rasa/core/utils.py +366 -0
- rasa/core/visualize.py +70 -0
- rasa/dialogue_understanding/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/constants.py +4 -0
- rasa/dialogue_understanding/coexistence/intent_based_router.py +196 -0
- rasa/dialogue_understanding/coexistence/llm_based_router.py +327 -0
- rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
- rasa/dialogue_understanding/commands/__init__.py +61 -0
- rasa/dialogue_understanding/commands/can_not_handle_command.py +70 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +125 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +44 -0
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/clarify_command.py +86 -0
- rasa/dialogue_understanding/commands/command.py +85 -0
- rasa/dialogue_understanding/commands/correct_slots_command.py +297 -0
- rasa/dialogue_understanding/commands/error_command.py +79 -0
- rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/handle_code_change_command.py +73 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +66 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/noop_command.py +54 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
- rasa/dialogue_understanding/commands/restart_command.py +58 -0
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +160 -0
- rasa/dialogue_understanding/commands/skip_question_command.py +75 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +107 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +45 -0
- rasa/dialogue_understanding/generator/__init__.py +21 -0
- rasa/dialogue_understanding/generator/command_generator.py +464 -0
- rasa/dialogue_understanding/generator/constants.py +27 -0
- rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +466 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +500 -0
- rasa/dialogue_understanding/generator/llm_command_generator.py +67 -0
- rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
- rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +920 -0
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +261 -0
- rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +60 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +486 -0
- rasa/dialogue_understanding/patterns/__init__.py +0 -0
- rasa/dialogue_understanding/patterns/cancel.py +111 -0
- rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
- rasa/dialogue_understanding/patterns/chitchat.py +37 -0
- rasa/dialogue_understanding/patterns/clarify.py +97 -0
- rasa/dialogue_understanding/patterns/code_change.py +41 -0
- rasa/dialogue_understanding/patterns/collect_information.py +90 -0
- rasa/dialogue_understanding/patterns/completed.py +40 -0
- rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
- rasa/dialogue_understanding/patterns/correction.py +278 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +301 -0
- rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
- rasa/dialogue_understanding/patterns/internal_error.py +47 -0
- rasa/dialogue_understanding/patterns/repeat.py +37 -0
- rasa/dialogue_understanding/patterns/restart.py +37 -0
- rasa/dialogue_understanding/patterns/search.py +37 -0
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/patterns/skip_question.py +38 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/__init__.py +0 -0
- rasa/dialogue_understanding/processor/command_processor.py +720 -0
- rasa/dialogue_understanding/processor/command_processor_component.py +43 -0
- rasa/dialogue_understanding/stack/__init__.py +0 -0
- rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
- rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
- rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
- rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
- rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
- rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
- rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
- rasa/dialogue_understanding/stack/utils.py +211 -0
- rasa/dialogue_understanding/utils.py +14 -0
- rasa/dialogue_understanding_test/__init__.py +0 -0
- rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
- rasa/dialogue_understanding_test/constants.py +17 -0
- rasa/dialogue_understanding_test/du_test_case.py +118 -0
- rasa/dialogue_understanding_test/du_test_result.py +11 -0
- rasa/dialogue_understanding_test/du_test_runner.py +93 -0
- rasa/dialogue_understanding_test/io.py +54 -0
- rasa/dialogue_understanding_test/validation.py +22 -0
- rasa/e2e_test/__init__.py +0 -0
- rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
- rasa/e2e_test/assertions.py +1345 -0
- rasa/e2e_test/assertions_schema.yml +129 -0
- rasa/e2e_test/constants.py +31 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +569 -0
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +54 -0
- rasa/e2e_test/e2e_test_runner.py +1192 -0
- rasa/e2e_test/e2e_test_schema.yml +181 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +598 -0
- rasa/e2e_test/utils/validation.py +178 -0
- rasa/engine/__init__.py +0 -0
- rasa/engine/caching.py +463 -0
- rasa/engine/constants.py +17 -0
- rasa/engine/exceptions.py +14 -0
- rasa/engine/graph.py +642 -0
- rasa/engine/loader.py +48 -0
- rasa/engine/recipes/__init__.py +0 -0
- rasa/engine/recipes/config_files/default_config.yml +41 -0
- rasa/engine/recipes/default_components.py +97 -0
- rasa/engine/recipes/default_recipe.py +1272 -0
- rasa/engine/recipes/graph_recipe.py +79 -0
- rasa/engine/recipes/recipe.py +93 -0
- rasa/engine/runner/__init__.py +0 -0
- rasa/engine/runner/dask.py +250 -0
- rasa/engine/runner/interface.py +49 -0
- rasa/engine/storage/__init__.py +0 -0
- rasa/engine/storage/local_model_storage.py +244 -0
- rasa/engine/storage/resource.py +110 -0
- rasa/engine/storage/storage.py +199 -0
- rasa/engine/training/__init__.py +0 -0
- rasa/engine/training/components.py +176 -0
- rasa/engine/training/fingerprinting.py +64 -0
- rasa/engine/training/graph_trainer.py +256 -0
- rasa/engine/training/hooks.py +164 -0
- rasa/engine/validation.py +1451 -0
- rasa/env.py +14 -0
- rasa/exceptions.py +69 -0
- rasa/graph_components/__init__.py +0 -0
- rasa/graph_components/converters/__init__.py +0 -0
- rasa/graph_components/converters/nlu_message_converter.py +48 -0
- rasa/graph_components/providers/__init__.py +0 -0
- rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
- rasa/graph_components/providers/domain_provider.py +71 -0
- rasa/graph_components/providers/flows_provider.py +74 -0
- rasa/graph_components/providers/forms_provider.py +44 -0
- rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
- rasa/graph_components/providers/responses_provider.py +44 -0
- rasa/graph_components/providers/rule_only_provider.py +49 -0
- rasa/graph_components/providers/story_graph_provider.py +96 -0
- rasa/graph_components/providers/training_tracker_provider.py +55 -0
- rasa/graph_components/validators/__init__.py +0 -0
- rasa/graph_components/validators/default_recipe_validator.py +550 -0
- rasa/graph_components/validators/finetuning_validator.py +302 -0
- rasa/hooks.py +111 -0
- rasa/jupyter.py +63 -0
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/markers/__init__.py +0 -0
- rasa/markers/marker.py +269 -0
- rasa/markers/marker_base.py +828 -0
- rasa/markers/upload.py +74 -0
- rasa/markers/validate.py +21 -0
- rasa/model.py +118 -0
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +40 -0
- rasa/model_manager/model_api.py +559 -0
- rasa/model_manager/runner_service.py +286 -0
- rasa/model_manager/socket_bridge.py +146 -0
- rasa/model_manager/studio_jwt_auth.py +86 -0
- rasa/model_manager/trainer_service.py +325 -0
- rasa/model_manager/utils.py +87 -0
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +112 -0
- rasa/model_testing.py +457 -0
- rasa/model_training.py +596 -0
- rasa/nlu/__init__.py +7 -0
- rasa/nlu/classifiers/__init__.py +3 -0
- rasa/nlu/classifiers/classifier.py +5 -0
- rasa/nlu/classifiers/diet_classifier.py +1881 -0
- rasa/nlu/classifiers/fallback_classifier.py +192 -0
- rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
- rasa/nlu/classifiers/logistic_regression_classifier.py +253 -0
- rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
- rasa/nlu/classifiers/regex_message_handler.py +56 -0
- rasa/nlu/classifiers/sklearn_intent_classifier.py +330 -0
- rasa/nlu/constants.py +77 -0
- rasa/nlu/convert.py +40 -0
- rasa/nlu/emulators/__init__.py +0 -0
- rasa/nlu/emulators/dialogflow.py +55 -0
- rasa/nlu/emulators/emulator.py +49 -0
- rasa/nlu/emulators/luis.py +86 -0
- rasa/nlu/emulators/no_emulator.py +10 -0
- rasa/nlu/emulators/wit.py +56 -0
- rasa/nlu/extractors/__init__.py +0 -0
- rasa/nlu/extractors/crf_entity_extractor.py +715 -0
- rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
- rasa/nlu/extractors/entity_synonyms.py +178 -0
- rasa/nlu/extractors/extractor.py +470 -0
- rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
- rasa/nlu/extractors/regex_entity_extractor.py +220 -0
- rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
- rasa/nlu/featurizers/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +445 -0
- rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +768 -0
- rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
- rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
- rasa/nlu/featurizers/featurizer.py +89 -0
- rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +867 -0
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +571 -0
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +271 -0
- rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
- rasa/nlu/model.py +24 -0
- rasa/nlu/run.py +27 -0
- rasa/nlu/selectors/__init__.py +0 -0
- rasa/nlu/selectors/response_selector.py +987 -0
- rasa/nlu/test.py +1940 -0
- rasa/nlu/tokenizers/__init__.py +0 -0
- rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
- rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
- rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
- rasa/nlu/tokenizers/tokenizer.py +239 -0
- rasa/nlu/tokenizers/whitespace_tokenizer.py +95 -0
- rasa/nlu/utils/__init__.py +35 -0
- rasa/nlu/utils/bilou_utils.py +462 -0
- rasa/nlu/utils/hugging_face/__init__.py +0 -0
- rasa/nlu/utils/hugging_face/registry.py +108 -0
- rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
- rasa/nlu/utils/mitie_utils.py +113 -0
- rasa/nlu/utils/pattern_utils.py +168 -0
- rasa/nlu/utils/spacy_utils.py +310 -0
- rasa/plugin.py +90 -0
- rasa/server.py +1588 -0
- rasa/shared/__init__.py +0 -0
- rasa/shared/constants.py +311 -0
- rasa/shared/core/__init__.py +0 -0
- rasa/shared/core/command_payload_reader.py +109 -0
- rasa/shared/core/constants.py +180 -0
- rasa/shared/core/conversation.py +46 -0
- rasa/shared/core/domain.py +2172 -0
- rasa/shared/core/events.py +2559 -0
- rasa/shared/core/flows/__init__.py +7 -0
- rasa/shared/core/flows/flow.py +562 -0
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flow_step.py +146 -0
- rasa/shared/core/flows/flow_step_links.py +319 -0
- rasa/shared/core/flows/flow_step_sequence.py +70 -0
- rasa/shared/core/flows/flows_list.py +258 -0
- rasa/shared/core/flows/flows_yaml_schema.json +303 -0
- rasa/shared/core/flows/nlu_trigger.py +117 -0
- rasa/shared/core/flows/steps/__init__.py +24 -0
- rasa/shared/core/flows/steps/action.py +56 -0
- rasa/shared/core/flows/steps/call.py +64 -0
- rasa/shared/core/flows/steps/collect.py +112 -0
- rasa/shared/core/flows/steps/constants.py +5 -0
- rasa/shared/core/flows/steps/continuation.py +36 -0
- rasa/shared/core/flows/steps/end.py +22 -0
- rasa/shared/core/flows/steps/internal.py +44 -0
- rasa/shared/core/flows/steps/link.py +51 -0
- rasa/shared/core/flows/steps/no_operation.py +48 -0
- rasa/shared/core/flows/steps/set_slots.py +50 -0
- rasa/shared/core/flows/steps/start.py +30 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +735 -0
- rasa/shared/core/flows/yaml_flows_io.py +405 -0
- rasa/shared/core/generator.py +908 -0
- rasa/shared/core/slot_mappings.py +526 -0
- rasa/shared/core/slots.py +654 -0
- rasa/shared/core/trackers.py +1183 -0
- rasa/shared/core/training_data/__init__.py +0 -0
- rasa/shared/core/training_data/loading.py +89 -0
- rasa/shared/core/training_data/story_reader/__init__.py +0 -0
- rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
- rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
- rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
- rasa/shared/core/training_data/story_writer/__init__.py +0 -0
- rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
- rasa/shared/core/training_data/story_writer/yaml_story_writer.py +444 -0
- rasa/shared/core/training_data/structures.py +858 -0
- rasa/shared/core/training_data/visualization.html +146 -0
- rasa/shared/core/training_data/visualization.py +603 -0
- rasa/shared/data.py +249 -0
- rasa/shared/engine/__init__.py +0 -0
- rasa/shared/engine/caching.py +26 -0
- rasa/shared/exceptions.py +167 -0
- rasa/shared/importers/__init__.py +0 -0
- rasa/shared/importers/importer.py +770 -0
- rasa/shared/importers/multi_project.py +215 -0
- rasa/shared/importers/rasa.py +108 -0
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +36 -0
- rasa/shared/nlu/__init__.py +0 -0
- rasa/shared/nlu/constants.py +53 -0
- rasa/shared/nlu/interpreter.py +10 -0
- rasa/shared/nlu/training_data/__init__.py +0 -0
- rasa/shared/nlu/training_data/entities_parser.py +208 -0
- rasa/shared/nlu/training_data/features.py +492 -0
- rasa/shared/nlu/training_data/formats/__init__.py +10 -0
- rasa/shared/nlu/training_data/formats/dialogflow.py +163 -0
- rasa/shared/nlu/training_data/formats/luis.py +87 -0
- rasa/shared/nlu/training_data/formats/rasa.py +135 -0
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +618 -0
- rasa/shared/nlu/training_data/formats/readerwriter.py +244 -0
- rasa/shared/nlu/training_data/formats/wit.py +52 -0
- rasa/shared/nlu/training_data/loading.py +137 -0
- rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
- rasa/shared/nlu/training_data/message.py +490 -0
- rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
- rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
- rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
- rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
- rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
- rasa/shared/nlu/training_data/training_data.py +729 -0
- rasa/shared/nlu/training_data/util.py +223 -0
- rasa/shared/providers/__init__.py +0 -0
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +677 -0
- rasa/shared/providers/_configs/client_config.py +59 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +132 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +236 -0
- rasa/shared/providers/_configs/litellm_router_client_config.py +222 -0
- rasa/shared/providers/_configs/model_group_config.py +173 -0
- rasa/shared/providers/_configs/openai_client_config.py +177 -0
- rasa/shared/providers/_configs/rasa_llm_client_config.py +75 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +178 -0
- rasa/shared/providers/_configs/utils.py +117 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/constants.py +7 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +243 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +335 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +126 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +138 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +265 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +415 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +110 -0
- rasa/shared/providers/llm/litellm_router_llm_client.py +202 -0
- rasa/shared/providers/llm/llm_client.py +78 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +161 -0
- rasa/shared/providers/llm/rasa_llm_client.py +120 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +276 -0
- rasa/shared/providers/mappings.py +94 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +185 -0
- rasa/shared/providers/router/router_client.py +75 -0
- rasa/shared/utils/__init__.py +0 -0
- rasa/shared/utils/cli.py +102 -0
- rasa/shared/utils/common.py +324 -0
- rasa/shared/utils/constants.py +4 -0
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +258 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +499 -0
- rasa/shared/utils/llm.py +764 -0
- rasa/shared/utils/pykwalify_extensions.py +27 -0
- rasa/shared/utils/schemas/__init__.py +0 -0
- rasa/shared/utils/schemas/config.yml +2 -0
- rasa/shared/utils/schemas/domain.yml +145 -0
- rasa/shared/utils/schemas/events.py +214 -0
- rasa/shared/utils/schemas/model_config.yml +36 -0
- rasa/shared/utils/schemas/stories.yml +173 -0
- rasa/shared/utils/yaml.py +1068 -0
- rasa/studio/__init__.py +0 -0
- rasa/studio/auth.py +270 -0
- rasa/studio/config.py +136 -0
- rasa/studio/constants.py +19 -0
- rasa/studio/data_handler.py +368 -0
- rasa/studio/download.py +489 -0
- rasa/studio/results_logger.py +137 -0
- rasa/studio/train.py +134 -0
- rasa/studio/upload.py +563 -0
- rasa/telemetry.py +1876 -0
- rasa/tracing/__init__.py +0 -0
- rasa/tracing/config.py +355 -0
- rasa/tracing/constants.py +62 -0
- rasa/tracing/instrumentation/__init__.py +0 -0
- rasa/tracing/instrumentation/attribute_extractors.py +765 -0
- rasa/tracing/instrumentation/instrumentation.py +1306 -0
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +144 -0
- rasa/tracing/instrumentation/metrics.py +294 -0
- rasa/tracing/metric_instrument_provider.py +205 -0
- rasa/utils/__init__.py +0 -0
- rasa/utils/beta.py +83 -0
- rasa/utils/cli.py +28 -0
- rasa/utils/common.py +639 -0
- rasa/utils/converter.py +53 -0
- rasa/utils/endpoints.py +331 -0
- rasa/utils/io.py +252 -0
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +542 -0
- rasa/utils/log_utils.py +181 -0
- rasa/utils/mapper.py +210 -0
- rasa/utils/ml_utils.py +147 -0
- rasa/utils/plotting.py +362 -0
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/utils/singleton.py +23 -0
- rasa/utils/tensorflow/__init__.py +0 -0
- rasa/utils/tensorflow/callback.py +112 -0
- rasa/utils/tensorflow/constants.py +116 -0
- rasa/utils/tensorflow/crf.py +492 -0
- rasa/utils/tensorflow/data_generator.py +440 -0
- rasa/utils/tensorflow/environment.py +161 -0
- rasa/utils/tensorflow/exceptions.py +5 -0
- rasa/utils/tensorflow/feature_array.py +366 -0
- rasa/utils/tensorflow/layers.py +1565 -0
- rasa/utils/tensorflow/layers_utils.py +113 -0
- rasa/utils/tensorflow/metrics.py +281 -0
- rasa/utils/tensorflow/model_data.py +798 -0
- rasa/utils/tensorflow/model_data_utils.py +499 -0
- rasa/utils/tensorflow/models.py +935 -0
- rasa/utils/tensorflow/rasa_layers.py +1094 -0
- rasa/utils/tensorflow/transformer.py +640 -0
- rasa/utils/tensorflow/types.py +6 -0
- rasa/utils/train_utils.py +572 -0
- rasa/utils/url_tools.py +53 -0
- rasa/utils/yaml.py +54 -0
- rasa/validator.py +1644 -0
- rasa/version.py +3 -0
- rasa_pro-3.12.0.dev1.dist-info/METADATA +199 -0
- rasa_pro-3.12.0.dev1.dist-info/NOTICE +5 -0
- rasa_pro-3.12.0.dev1.dist-info/RECORD +790 -0
- rasa_pro-3.12.0.dev1.dist-info/WHEEL +4 -0
- rasa_pro-3.12.0.dev1.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,987 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import copy
|
|
3
|
+
import logging
|
|
4
|
+
from rasa.nlu.featurizers.featurizer import Featurizer
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import tensorflow as tf
|
|
8
|
+
|
|
9
|
+
from typing import Any, Dict, Optional, Text, Tuple, Union, List, Type
|
|
10
|
+
|
|
11
|
+
from rasa.engine.graph import ExecutionContext
|
|
12
|
+
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
13
|
+
from rasa.engine.storage.resource import Resource
|
|
14
|
+
from rasa.engine.storage.storage import ModelStorage
|
|
15
|
+
from rasa.shared.constants import DIAGNOSTIC_DATA
|
|
16
|
+
from rasa.shared.nlu.training_data import util
|
|
17
|
+
import rasa.shared.utils.io
|
|
18
|
+
from rasa.shared.exceptions import InvalidConfigException
|
|
19
|
+
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
20
|
+
from rasa.shared.nlu.training_data.message import Message
|
|
21
|
+
from rasa.nlu.classifiers.diet_classifier import (
|
|
22
|
+
DIET,
|
|
23
|
+
LABEL_KEY,
|
|
24
|
+
LABEL_SUB_KEY,
|
|
25
|
+
SENTENCE,
|
|
26
|
+
SEQUENCE,
|
|
27
|
+
DIETClassifier,
|
|
28
|
+
)
|
|
29
|
+
from rasa.nlu.extractors.extractor import EntityTagSpec
|
|
30
|
+
from rasa.utils.tensorflow import rasa_layers
|
|
31
|
+
from rasa.utils.tensorflow.constants import (
|
|
32
|
+
LABEL,
|
|
33
|
+
HIDDEN_LAYERS_SIZES,
|
|
34
|
+
SHARE_HIDDEN_LAYERS,
|
|
35
|
+
TRANSFORMER_SIZE,
|
|
36
|
+
NUM_TRANSFORMER_LAYERS,
|
|
37
|
+
NUM_HEADS,
|
|
38
|
+
BATCH_SIZES,
|
|
39
|
+
BATCH_STRATEGY,
|
|
40
|
+
EPOCHS,
|
|
41
|
+
RANDOM_SEED,
|
|
42
|
+
LEARNING_RATE,
|
|
43
|
+
RANKING_LENGTH,
|
|
44
|
+
RENORMALIZE_CONFIDENCES,
|
|
45
|
+
LOSS_TYPE,
|
|
46
|
+
SIMILARITY_TYPE,
|
|
47
|
+
NUM_NEG,
|
|
48
|
+
SPARSE_INPUT_DROPOUT,
|
|
49
|
+
DENSE_INPUT_DROPOUT,
|
|
50
|
+
MASKED_LM,
|
|
51
|
+
ENTITY_RECOGNITION,
|
|
52
|
+
INTENT_CLASSIFICATION,
|
|
53
|
+
EVAL_NUM_EXAMPLES,
|
|
54
|
+
EVAL_NUM_EPOCHS,
|
|
55
|
+
UNIDIRECTIONAL_ENCODER,
|
|
56
|
+
DROP_RATE,
|
|
57
|
+
DROP_RATE_ATTENTION,
|
|
58
|
+
CONNECTION_DENSITY,
|
|
59
|
+
NEGATIVE_MARGIN_SCALE,
|
|
60
|
+
REGULARIZATION_CONSTANT,
|
|
61
|
+
SCALE_LOSS,
|
|
62
|
+
USE_MAX_NEG_SIM,
|
|
63
|
+
MAX_NEG_SIM,
|
|
64
|
+
MAX_POS_SIM,
|
|
65
|
+
EMBEDDING_DIMENSION,
|
|
66
|
+
BILOU_FLAG,
|
|
67
|
+
KEY_RELATIVE_ATTENTION,
|
|
68
|
+
VALUE_RELATIVE_ATTENTION,
|
|
69
|
+
MAX_RELATIVE_POSITION,
|
|
70
|
+
RETRIEVAL_INTENT,
|
|
71
|
+
USE_TEXT_AS_LABEL,
|
|
72
|
+
CROSS_ENTROPY,
|
|
73
|
+
AUTO,
|
|
74
|
+
BALANCED,
|
|
75
|
+
TENSORBOARD_LOG_DIR,
|
|
76
|
+
TENSORBOARD_LOG_LEVEL,
|
|
77
|
+
CONCAT_DIMENSION,
|
|
78
|
+
FEATURIZERS,
|
|
79
|
+
CHECKPOINT_MODEL,
|
|
80
|
+
DENSE_DIMENSION,
|
|
81
|
+
CONSTRAIN_SIMILARITIES,
|
|
82
|
+
MODEL_CONFIDENCE,
|
|
83
|
+
SOFTMAX,
|
|
84
|
+
)
|
|
85
|
+
from rasa.nlu.constants import (
|
|
86
|
+
RESPONSE_SELECTOR_PROPERTY_NAME,
|
|
87
|
+
RESPONSE_SELECTOR_RETRIEVAL_INTENTS,
|
|
88
|
+
RESPONSE_SELECTOR_RESPONSES_KEY,
|
|
89
|
+
RESPONSE_SELECTOR_PREDICTION_KEY,
|
|
90
|
+
RESPONSE_SELECTOR_RANKING_KEY,
|
|
91
|
+
RESPONSE_SELECTOR_UTTER_ACTION_KEY,
|
|
92
|
+
RESPONSE_SELECTOR_DEFAULT_INTENT,
|
|
93
|
+
DEFAULT_TRANSFORMER_SIZE,
|
|
94
|
+
)
|
|
95
|
+
from rasa.shared.nlu.constants import (
|
|
96
|
+
TEXT,
|
|
97
|
+
INTENT,
|
|
98
|
+
RESPONSE,
|
|
99
|
+
INTENT_RESPONSE_KEY,
|
|
100
|
+
INTENT_NAME_KEY,
|
|
101
|
+
PREDICTED_CONFIDENCE_KEY,
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
from rasa.utils.tensorflow.model_data import RasaModelData
|
|
105
|
+
from rasa.utils.tensorflow.models import RasaModel
|
|
106
|
+
|
|
107
|
+
logger = logging.getLogger(__name__)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@DefaultV1Recipe.register(
|
|
111
|
+
DefaultV1Recipe.ComponentType.INTENT_CLASSIFIER, is_trainable=True
|
|
112
|
+
)
|
|
113
|
+
class ResponseSelector(DIETClassifier):
|
|
114
|
+
"""Response selector using supervised embeddings.
|
|
115
|
+
|
|
116
|
+
The response selector embeds user inputs
|
|
117
|
+
and candidate response into the same space.
|
|
118
|
+
Supervised embeddings are trained by maximizing similarity between them.
|
|
119
|
+
It also provides rankings of the response that did not "win".
|
|
120
|
+
|
|
121
|
+
The supervised response selector needs to be preceded by
|
|
122
|
+
a featurizer in the pipeline.
|
|
123
|
+
This featurizer creates the features used for the embeddings.
|
|
124
|
+
It is recommended to use ``CountVectorsFeaturizer`` that
|
|
125
|
+
can be optionally preceded by ``SpacyNLP`` and ``SpacyTokenizer``.
|
|
126
|
+
|
|
127
|
+
Based on the starspace idea from: https://arxiv.org/abs/1709.03856.
|
|
128
|
+
However, in this implementation the `mu` parameter is treated differently
|
|
129
|
+
and additional hidden layers are added together with dropout.
|
|
130
|
+
"""
|
|
131
|
+
|
|
132
|
+
@classmethod
def required_components(cls) -> List[Type]:
    """Lists the component types that must precede this one in the pipeline.

    Returns:
        A list holding the `Featurizer` type only.
    """
    prerequisites: List[Type] = [Featurizer]
    return prerequisites
|
|
136
|
+
|
|
137
|
+
@staticmethod
def get_default_config() -> Dict[Text, Any]:
    """Returns the component's default config.

    The result starts from a copy of `DIETClassifier.get_default_config()`; the
    selector-specific values below are then written on top of it.
    """
    defaults: Dict[Text, Any] = dict(DIETClassifier.get_default_config())

    selector_overrides: Dict[Text, Any] = {
        # --- Neural network architecture ---
        # Sizes of the hidden layers placed before the embedding layers, given
        # separately for the user message and for the labels. One hidden layer
        # is created per list entry.
        HIDDEN_LAYERS_SIZES: {TEXT: [256, 128], LABEL: [256, 128]},
        # Share hidden layer weights between input words and responses?
        SHARE_HIDDEN_LAYERS: False,
        # Number of units in the transformer.
        TRANSFORMER_SIZE: None,
        # Number of transformer layers.
        NUM_TRANSFORMER_LAYERS: 0,
        # Number of attention heads in the transformer.
        NUM_HEADS: 4,
        # If 'True', use key relative embeddings in attention.
        KEY_RELATIVE_ATTENTION: False,
        # If 'True', use value relative embeddings in attention.
        VALUE_RELATIVE_ATTENTION: False,
        # Max position for relative embeddings. Only has an effect when key or
        # value relative attention is switched on.
        MAX_RELATIVE_POSITION: 5,
        # Choose between a unidirectional and a bidirectional encoder.
        UNIDIRECTIONAL_ENCODER: False,
        # --- Training ---
        # Initial and final batch size; the batch size is increased linearly
        # for each epoch.
        BATCH_SIZES: [64, 256],
        # Batch creation strategy: either 'sequence' or 'balanced'.
        BATCH_STRATEGY: BALANCED,
        # Number of training epochs.
        EPOCHS: 300,
        # Set to any 'int' to get reproducible results.
        RANDOM_SEED: None,
        # Initial learning rate of the optimizer.
        LEARNING_RATE: 0.001,
        # --- Embeddings ---
        # Dimension size of the embedding vectors.
        EMBEDDING_DIMENSION: 20,
        # Dense dimension used when no dense features are present.
        DENSE_DIMENSION: {TEXT: 512, LABEL: 512},
        # Dimension used for concatenating sequence and sentence features.
        CONCAT_DIMENSION: {TEXT: 512, LABEL: 512},
        # Number of incorrect labels; during training their similarity to the
        # user input is minimized.
        NUM_NEG: 20,
        # Similarity measure: 'auto', 'cosine' or 'inner'.
        SIMILARITY_TYPE: AUTO,
        # Loss function: 'cross_entropy' or 'margin'.
        LOSS_TYPE: CROSS_ENTROPY,
        # Number of top actions for which confidences are predicted.
        # Use 0 to report confidences for all intents.
        RANKING_LENGTH: 10,
        # Whether the confidences of the chosen top actions get renormalized to
        # sum up to 1. Off by default, i.e. the confidences are returned as is.
        # Renormalization only makes sense for confidences generated via
        # `softmax`.
        RENORMALIZE_CONFIDENCES: False,
        # How similar the embedding vectors of correct labels should become.
        # Should be 0.0 < ... < 1.0 for 'cosine' similarity type.
        MAX_POS_SIM: 0.8,
        # Maximum negative similarity for incorrect labels.
        # Should be -1.0 < ... < 1.0 for 'cosine' similarity type.
        MAX_NEG_SIM: -0.4,
        # If 'True', only the maximum similarity over incorrect intent labels
        # is minimized. Used only if 'loss_type' is set to 'margin'.
        USE_MAX_NEG_SIM: True,
        # Scale the loss inverse proportionally to the confidence of the
        # correct prediction.
        SCALE_LOSS: True,
        # --- Regularization ---
        # Scale of the regularization.
        REGULARIZATION_CONSTANT: 0.002,
        # Fraction of trainable weights in internal layers.
        CONNECTION_DENSITY: 1.0,
        # How important it is to minimize the maximum similarity between
        # embeddings of different labels.
        NEGATIVE_MARGIN_SCALE: 0.8,
        # Dropout rate for the encoder.
        DROP_RATE: 0.2,
        # Dropout rate for attention.
        DROP_RATE_ATTENTION: 0,
        # If 'True', apply dropout to sparse input tensors.
        SPARSE_INPUT_DROPOUT: False,
        # If 'True', apply dropout to dense input tensors.
        DENSE_INPUT_DROPOUT: False,
        # --- Evaluation ---
        # How often validation accuracy is calculated.
        # Small values may hurt performance, e.g. model accuracy.
        EVAL_NUM_EPOCHS: 20,
        # Number of examples in the hold out validation set.
        # Large values may hurt performance, e.g. model accuracy.
        EVAL_NUM_EXAMPLES: 0,
        # --- Selector ---
        # If 'True', random tokens of the input message are masked and the
        # model should predict them.
        MASKED_LM: False,
        # Name of the intent this response selector is to be trained for.
        RETRIEVAL_INTENT: None,
        # Whether the actual response text should serve as the ground truth
        # label when training the model.
        USE_TEXT_AS_LABEL: False,
        # Set to a valid output directory to visualize training and validation
        # metrics with tensorboard.
        TENSORBOARD_LOG_DIR: None,
        # When training metrics for tensorboard are logged: after every epoch
        # or for every training step. Valid values: 'epoch' and 'batch'.
        TENSORBOARD_LOG_LEVEL: "epoch",
        # Which features to use as sequence and sentence features.
        # By default all features in the pipeline are used.
        FEATURIZERS: [],
        # Perform model checkpointing.
        CHECKPOINT_MODEL: False,
        # If 'True', sigmoid is applied on all similarity terms and added to
        # the loss function so that similarity values stay approximately
        # bounded. Used inside cross-entropy loss only.
        CONSTRAIN_SIMILARITIES: False,
        # Model confidence returned during inference. Currently, the only
        # possible value is `softmax`.
        MODEL_CONFIDENCE: SOFTMAX,
    }

    defaults.update(selector_overrides)
    return defaults
|
|
267
|
+
|
|
268
|
+
def __init__(
    self,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
    index_label_id_mapping: Optional[Dict[int, Text]] = None,
    entity_tag_specs: Optional[List[EntityTagSpec]] = None,
    model: Optional[RasaModel] = None,
    all_retrieval_intents: Optional[List[Text]] = None,
    responses: Optional[Dict[Text, List[Dict[Text, Any]]]] = None,
    sparse_feature_sizes: Optional[Dict[Text, Dict[Text, List[int]]]] = None,
) -> None:
    """Sets the selector's own attributes, then initializes the parent class.

    Note that `config` is modified in place: the intent classification, entity
    recognition and BILOU settings are overwritten before it is handed on.

    Args:
        config: Configuration for the component.
        model_storage: Storage which graph components can use to persist and load
            themselves.
        resource: Resource locator for this component which can be used to persist
            and load itself from the `model_storage`.
        execution_context: Information about the current graph run.
        index_label_id_mapping: Mapping between label and index used for encoding.
        entity_tag_specs: Format specification all entity tags.
        model: Model architecture.
        all_retrieval_intents: All retrieval intents defined in the data.
        responses: All responses defined in the data.
        sparse_feature_sizes: Sizes of the sparse features the model was trained on.
    """
    # These settings are fixed for the ResponseSelector, so any configured
    # value is replaced.
    config.update(
        {
            INTENT_CLASSIFICATION: True,
            ENTITY_RECOGNITION: False,
            BILOU_FLAG: None,
        }
    )

    # Defaults are assigned before the parent constructor runs.
    # NOTE(review): presumably the parent constructor triggers
    # `_check_config_parameters`, which fills in the two selector parameters;
    # confirm against `DIETClassifier.__init__`.
    self.responses = responses if responses else {}
    self.all_retrieval_intents = (
        all_retrieval_intents if all_retrieval_intents else []
    )
    self.retrieval_intent = None
    self.use_text_as_label = False

    super().__init__(
        config,
        model_storage,
        resource,
        execution_context,
        index_label_id_mapping,
        entity_tag_specs,
        model,
        sparse_feature_sizes=sparse_feature_sizes,
    )
|
|
322
|
+
|
|
323
|
+
@property
def label_key(self) -> Text:
    """The key under which labels are stored, i.e. `LABEL_KEY`."""
    key: Text = LABEL_KEY
    return key
|
|
327
|
+
|
|
328
|
+
@property
def label_sub_key(self) -> Text:
    """The sub key under which labels are stored, i.e. `LABEL_SUB_KEY`."""
    sub_key: Text = LABEL_SUB_KEY
    return sub_key
|
|
332
|
+
|
|
333
|
+
@staticmethod
def model_class(  # type: ignore[override]
    use_text_as_label: bool,
) -> Type[RasaModel]:
    """Picks the model class that matches the kind of label in use.

    Args:
        use_text_as_label: Flag deciding which of the two model classes is used.

    Returns:
        `DIET2DIET` if `use_text_as_label` is truthy, `DIET2BOW` otherwise.
    """
    return DIET2DIET if use_text_as_label else DIET2BOW
|
|
342
|
+
|
|
343
|
+
def _load_selector_params(self) -> None:
    """Copies the selector-specific settings from the config onto the instance."""
    settings = self.component_config
    self.retrieval_intent = settings[RETRIEVAL_INTENT]
    self.use_text_as_label = settings[USE_TEXT_AS_LABEL]
|
|
346
|
+
|
|
347
|
+
def _warn_about_transformer_and_hidden_layers_enabled(
    self, selector_name: Text
) -> None:
    """Warns if hidden layers are still configured while a transformer is in use.

    The ResponseSelector defaults define sizeable hidden layers, which is meant
    for setups without a transformer. The warning recommends disabling the
    hidden layers (an empty list for every key) when a transformer is used.

    Args:
        selector_name: Name of the selector as it should appear in the warning.
    """
    default_sizes = self.get_default_config()[HIDDEN_LAYERS_SIZES]
    configured_sizes = self.component_config[HIDDEN_LAYERS_SIZES]

    # The recommended setting: one empty list per key of the default config.
    config_for_disabling_hidden_layers: Dict[Text, List[Any]] = {
        key: [] for key in default_sizes
    }

    if configured_sizes == config_for_disabling_hidden_layers:
        # Hidden layers are already disabled, nothing to warn about.
        return

    # Tailor the wording to what the user did with the hidden layers' config,
    # so that it is clear what they should change.
    what_user_did = (
        "left the hidden layer sizes at their default value:"
        if configured_sizes == default_sizes
        else "set the hidden layer sizes to be non-empty by setting"
    )

    rasa.shared.utils.io.raise_warning(
        f"You have enabled a transformer inside {selector_name} by"
        f" setting a positive value for `{NUM_TRANSFORMER_LAYERS}`, but you "
        f"{what_user_did} `{HIDDEN_LAYERS_SIZES}="
        f"{configured_sizes}`. We recommend to "
        f"disable the hidden layers when using a transformer, by specifying "
        f"`{HIDDEN_LAYERS_SIZES}={config_for_disabling_hidden_layers}`.",
        category=UserWarning,
    )
|
|
386
|
+
|
|
387
|
+
def _warn_and_correct_transformer_size(self, selector_name: Text) -> None:
    """Replaces an unusable transformer size with a default and warns about it.

    A configured `transformer_size` of `None` or below 1 is overwritten with
    `DEFAULT_TRANSFORMER_SIZE` after a warning has been raised. Any other value
    is left alone.

    Args:
        selector_name: Name of the selector as it should appear in the warning.
    """
    configured_size = self.component_config[TRANSFORMER_SIZE]
    size_is_unusable = configured_size is None or configured_size < 1
    if not size_is_unusable:
        return

    rasa.shared.utils.io.raise_warning(
        f"`{TRANSFORMER_SIZE}` is set to "
        f"`{configured_size}` for "
        f"{selector_name}, but a positive size is required when using "
        f"`{NUM_TRANSFORMER_LAYERS} > 0`. {selector_name} will proceed, using "
        f"`{TRANSFORMER_SIZE}={DEFAULT_TRANSFORMER_SIZE}`. "
        f"Alternatively, specify a different value in the component's config.",
        category=UserWarning,
    )
    self.component_config[TRANSFORMER_SIZE] = DEFAULT_TRANSFORMER_SIZE
|
|
407
|
+
|
|
408
|
+
def _check_config_params_when_transformer_enabled(self) -> None:
    """Runs the transformer-related config checks if a transformer is enabled.

    The defaults of individual config parameters depend on each other, and some
    of them should change once the transformer is enabled. Nothing happens
    unless the configured number of transformer layers is greater than zero.
    """
    transformer_enabled = self.component_config[NUM_TRANSFORMER_LAYERS] > 0
    if not transformer_enabled:
        return

    # Name used in the warnings; the retrieval intent is appended when set.
    selector_name = "ResponseSelector"
    if self.retrieval_intent:
        selector_name += f"({self.retrieval_intent})"

    self._warn_about_transformer_and_hidden_layers_enabled(selector_name)
    self._warn_and_correct_transformer_size(selector_name)
|
|
420
|
+
|
|
421
|
+
def _check_config_parameters(self) -> None:
    """Validates the component configuration and corrects it where needed."""
    # First the checks inherited from the parent class.
    super()._check_config_parameters()

    # The selector parameters must be loaded before the transformer checks,
    # because those read `self.retrieval_intent`.
    self._load_selector_params()

    # Finally the checks that are specific to the ResponseSelector.
    self._check_config_params_when_transformer_enabled()
|
|
428
|
+
|
|
429
|
+
def _set_message_property(
    self, message: Message, prediction_dict: Dict[Text, Any], selector_key: Text
) -> None:
    """Stores a prediction in the message's response selector property.

    Args:
        message: The message to attach the prediction to.
        prediction_dict: The prediction to store.
        selector_key: Key under which the prediction is stored.
    """
    # Start from whatever is already stored on the message, if anything.
    properties = message.get(RESPONSE_SELECTOR_PROPERTY_NAME, {})
    properties[RESPONSE_SELECTOR_RETRIEVAL_INTENTS] = self.all_retrieval_intents
    properties[selector_key] = prediction_dict
    message.set(RESPONSE_SELECTOR_PROPERTY_NAME, properties, add_to_output=True)
|
|
442
|
+
|
|
443
|
+
def preprocess_train_data(self, training_data: TrainingData) -> RasaModelData:
    """Turn training data into model data for this selector.

    Records all retrieval intents and the responses on the instance, builds
    the label mappings and label data, and checks input dimension
    consistency of the resulting model data.

    Args:
        training_data: The training data to preprocess.

    Returns:
        The model data built from the intent examples, or an empty
        ``RasaModelData`` when no labels were found.
    """
    # Remember every retrieval intent before the data is narrowed down.
    self.all_retrieval_intents = list(training_data.retrieval_intents)

    if not self.retrieval_intent:
        # No specific retrieval intent configured: train on all of them.
        logger.info(
            "Retrieval intent parameter was left to its default value. This "
            "response selector will be trained on training examples combining "
            "all retrieval intents."
        )
    else:
        training_data = training_data.filter_training_examples(
            lambda ex: self.retrieval_intent == ex.get(INTENT)
        )

    if self.use_text_as_label:
        label_attribute = RESPONSE
    else:
        label_attribute = INTENT_RESPONSE_KEY

    label_id_index_mapping = self._label_id_index_mapping(
        training_data, attribute=label_attribute
    )

    # Responses are stored even when there turns out to be nothing to train.
    self.responses = training_data.responses

    if not label_id_index_mapping:
        return RasaModelData()

    self.index_label_id_mapping = self._invert_mapping(label_id_index_mapping)
    self._label_data = self._create_label_data(
        training_data, label_id_index_mapping, attribute=label_attribute
    )

    prepared = self._create_model_data(
        training_data.intent_examples,
        label_id_index_mapping,
        label_attribute=label_attribute,
    )
    self._check_input_dimension_consistency(prepared)
    return prepared
|
|
493
|
+
|
|
494
|
+
def _resolve_intent_response_key(
    self, label: Dict[Text, Optional[Text]]
) -> Optional[Text]:
    """Map a predicted label onto an intent response key of a known response.

    Args:
        label: Label predicted by the selector; its ``"name"`` entry is
            compared against the known responses.

    Returns:
        The intent response key of the first known response whose key, or
        one of whose response texts, equals the label name. ``None`` when
        nothing matches.
    """
    for template_key, candidates in self.responses.items():
        candidate_key = util.template_key_to_intent_response_key(template_key)

        # Match either on the key itself or, failing that, on the text of
        # any of the responses registered under that key.
        if candidate_key == label.get("name") or any(
            candidate.get(TEXT, "") == label.get("name") for candidate in candidates
        ):
            return candidate_key

    return None
|
|
518
|
+
|
|
519
|
+
def process(self, messages: List[Message]) -> List[Message]:
    """Selects the most likely response for each message.

    For every message the prediction is run, the top label is resolved to an
    intent response key (falling back to the raw label name), the matching
    responses are looked up, and the result is stored on the message. When
    no responses are found for the resolved key, a warning is raised and the
    key itself is used as the response text.

    Args:
        messages: List containing latest user message.

    Returns:
        List containing the message augmented with the most likely response,
        the associated intent_response_key and its similarity to the input.
    """
    for message in messages:
        out = self._predict(message)
        top_label, label_ranking = self._predict_label(out)

        # Get the exact intent_response_key and the associated
        # responses for the top predicted label
        label_intent_response_key = (
            self._resolve_intent_response_key(top_label)
            or top_label[INTENT_NAME_KEY]
        )
        # Computed once; used both for the lookup and in the prediction dict.
        utter_action_key = util.intent_response_key_to_template_key(
            label_intent_response_key
        )
        label_responses = self.responses.get(utter_action_key)

        if label_intent_response_key and not label_responses:
            # responses seem to be unavailable,
            # likely an issue with the training data
            # we'll use a fallback instead
            rasa.shared.utils.io.raise_warning(
                f"Unable to fetch responses for {label_intent_response_key}. "
                "This means that there is likely an issue with the training "
                "data. Please make sure you have added responses for this "
                "intent."
            )
            label_responses = [{TEXT: label_intent_response_key}]

        for label in label_ranking:
            label[INTENT_RESPONSE_KEY] = (
                self._resolve_intent_response_key(label) or label[INTENT_NAME_KEY]
            )
            # Remove the "name" key since it is either the same as
            # "intent_response_key" or it is the response text which
            # is not needed in the ranking.
            label.pop(INTENT_NAME_KEY)

        selector_key = (
            self.retrieval_intent
            if self.retrieval_intent
            else RESPONSE_SELECTOR_DEFAULT_INTENT
        )

        logger.debug(
            f"Adding following selector key to message property: {selector_key}"
        )

        prediction_dict = {
            RESPONSE_SELECTOR_PREDICTION_KEY: {
                RESPONSE_SELECTOR_RESPONSES_KEY: label_responses,
                PREDICTED_CONFIDENCE_KEY: top_label[PREDICTED_CONFIDENCE_KEY],
                INTENT_RESPONSE_KEY: label_intent_response_key,
                RESPONSE_SELECTOR_UTTER_ACTION_KEY: utter_action_key,
            },
            RESPONSE_SELECTOR_RANKING_KEY: label_ranking,
        }

        self._set_message_property(message, prediction_dict, selector_key)

        # Diagnostic data is attached only when requested by the execution
        # context and actually present in the prediction output.
        if (
            self._execution_context.should_add_diagnostic_data
            and out
            and DIAGNOSTIC_DATA in out
        ):
            message.add_diagnostic_data(
                self._execution_context.node_name, out.get(DIAGNOSTIC_DATA)
            )

    return messages
|
|
598
|
+
|
|
599
|
+
def persist(self) -> None:
    """Write the selector's artifacts to model storage.

    Does nothing when no model is set. Otherwise the responses and the list
    of all retrieval intents are dumped as JSON files named after this
    class, and then the parent class's ``persist`` is called.
    """
    if self.model is None:
        return

    with self._model_storage.write_to(self._resource) as model_path:
        prefix = self.__class__.__name__

        responses_file = model_path / f"{prefix}.responses.json"
        rasa.shared.utils.io.dump_obj_as_json_to_file(responses_file, self.responses)

        intents_file = model_path / f"{prefix}.retrieval_intents.json"
        rasa.shared.utils.io.dump_obj_as_json_to_file(
            intents_file, self.all_retrieval_intents
        )

    super().persist()
|
|
617
|
+
|
|
618
|
+
@classmethod
def _load_model_class(
    cls,
    tf_model_file: Text,
    model_data_example: RasaModelData,
    label_data: RasaModelData,
    entity_tag_specs: List[EntityTagSpec],
    config: Dict[Text, Any],
    finetune_mode: bool = False,
) -> "RasaModel":
    """Load the model class chosen by the ``USE_TEXT_AS_LABEL`` config entry.

    A prediction data example is derived from ``model_data_example`` by
    keeping only the features whose name contains ``TEXT``. The model is
    given a deep copy of ``config``.
    """
    text_only_features = {
        name: value for name, value in model_data_example.items() if TEXT in name
    }
    predict_data_example = RasaModelData(
        label_key=model_data_example.label_key, data=text_only_features
    )

    model_cls = cls.model_class(config[USE_TEXT_AS_LABEL])
    return model_cls.load(
        tf_model_file,
        model_data_example,
        predict_data_example,
        data_signature=model_data_example.get_signature(),
        label_data=label_data,
        entity_tag_specs=entity_tag_specs,
        config=copy.deepcopy(config),
        finetune_mode=finetune_mode,
    )
|
|
646
|
+
|
|
647
|
+
def _instantiate_model_class(self, model_data: RasaModelData) -> "RasaModel":
    """Create a fresh model of the class matching ``use_text_as_label``.

    Args:
        model_data: Model data whose signature is handed to the model.

    Returns:
        The newly constructed model instance.
    """
    model_cls = self.model_class(self.use_text_as_label)
    return model_cls(
        data_signature=model_data.get_signature(),
        label_data=self._label_data,
        entity_tag_specs=self._entity_tag_specs,
        config=self.component_config,
    )
|
|
654
|
+
|
|
655
|
+
@classmethod
def load(
    cls,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
    **kwargs: Any,
) -> ResponseSelector:
    """Load a trained selector together with its persisted JSON artifacts.

    The parent class's ``load`` is called first. The responses and the
    retrieval intents are then read from model storage and set on the loaded
    instance. If a ``ValueError`` is raised while reading them, a debug
    message is logged and a newly constructed instance is returned instead.
    """
    loaded = super().load(
        config, model_storage, resource, execution_context, **kwargs
    )

    try:
        with model_storage.read_from(resource) as model_path:
            prefix = cls.__name__
            # Read both files before touching the instance.
            stored_responses = rasa.shared.utils.io.read_json_file(
                model_path / f"{prefix}.responses.json"
            )
            stored_intents = rasa.shared.utils.io.read_json_file(
                model_path / f"{prefix}.retrieval_intents.json"
            )
            loaded.responses, loaded.all_retrieval_intents = (
                stored_responses,
                stored_intents,
            )
            return loaded
    except ValueError:
        logger.debug(
            f"Failed to load {cls.__name__} from model storage. Resource "
            f"'{resource.name}' doesn't exist."
        )
        return cls(config, model_storage, resource, execution_context)
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
class DIET2BOW(DIET):
    """DIET2BOW transformer implementation.

    Defines the mask/response metrics and how they are selected for logging.
    """

    def _create_metrics(self) -> None:
        """Create the running-mean metrics for mask and response loss/accuracy."""
        # Creation order is deliberate (self.metrics preserve order):
        # losses first, accuracies second.
        mean = tf.keras.metrics.Mean
        self.mask_loss = mean(name="m_loss")
        self.response_loss = mean(name="r_loss")
        self.mask_acc = mean(name="m_acc")
        self.response_acc = mean(name="r_acc")

    def _update_metrics_to_log(self) -> None:
        """Append the metric names to log, then report them at debug level.

        Accuracies are always added (mask accuracy only when ``MASKED_LM`` is
        enabled in the config); losses are added only when the ``rasa``
        logger's level is ``DEBUG``.
        """
        verbose = logging.getLogger("rasa").level == logging.DEBUG

        selected = []
        if self.config[MASKED_LM]:
            selected.append("m_acc")
            if verbose:
                selected.append("m_loss")
        selected.append("r_acc")
        if verbose:
            selected.append("r_loss")

        for metric in selected:
            self.metrics_to_log.append(metric)

        self._log_metric_info()

    def _log_metric_info(self) -> None:
        """Log, at debug level, each metric to be logged with a readable name."""
        prefix_names = {"t": "total", "m": "mask", "r": "response"}
        logger.debug("Following metrics will be logged during training: ")
        for metric in self.metrics_to_log:
            pieces = metric.split("_")
            readable = f"{prefix_names[pieces[0]]} {pieces[1]}"
            logger.debug(f" {metric} ({readable})")

    def _update_label_metrics(self, loss: tf.Tensor, acc: tf.Tensor) -> None:
        """Feed the given loss and accuracy into the response metrics."""
        for tracker, value in ((self.response_loss, loss), (self.response_acc, acc)):
            tracker.update_state(value)
|
|
726
|
+
|
|
727
|
+
|
|
728
|
+
class DIET2DIET(DIET):
|
|
729
|
+
"""Diet 2 Diet transformer implementation."""
|
|
730
|
+
|
|
731
|
+
def _check_data(self) -> None:
    """Verify that the data signature is usable for training.

    Raises:
        InvalidConfigException: If the data signature has no ``TEXT`` or no
            ``LABEL`` entry.
        ValueError: If hidden layers are shared but the sentence-level
            signatures of ``TEXT`` and ``LABEL`` differ.
    """
    signature = self.data_signature
    model_name = self.__class__.__name__

    if TEXT not in signature:
        raise InvalidConfigException(
            "No text features specified. " f"Cannot train '{model_name}' model."
        )
    if LABEL not in signature:
        raise InvalidConfigException(
            "No label features specified. " f"Cannot train '{model_name}' model."
        )

    if not self.config[SHARE_HIDDEN_LAYERS]:
        return

    if signature[TEXT][SENTENCE] != signature[LABEL][SENTENCE]:
        raise ValueError(
            "If hidden layer weights are shared, data signatures "
            "for text_features and label_features must coincide."
        )
|
|
751
|
+
|
|
752
|
+
def _create_metrics(self) -> None:
    """Create the running-mean metrics for mask and response loss/accuracy."""
    # Creation order is deliberate (self.metrics preserve order):
    # losses first, accuracies second.
    mean = tf.keras.metrics.Mean
    self.mask_loss = mean(name="m_loss")
    self.response_loss = mean(name="r_loss")
    self.mask_acc = mean(name="m_acc")
    self.response_acc = mean(name="r_acc")
|
|
760
|
+
|
|
761
|
+
def _update_metrics_to_log(self) -> None:
    """Append the metric names to log, then report them at debug level.

    Accuracies are always added (mask accuracy only when ``MASKED_LM`` is
    enabled in the config); losses are added only when the ``rasa`` logger's
    level is ``DEBUG``.
    """
    verbose = logging.getLogger("rasa").level == logging.DEBUG

    selected = []
    if self.config[MASKED_LM]:
        selected.append("m_acc")
        if verbose:
            selected.append("m_loss")
    selected.append("r_acc")
    if verbose:
        selected.append("r_loss")

    for metric in selected:
        self.metrics_to_log.append(metric)

    self._log_metric_info()
|
|
774
|
+
|
|
775
|
+
def _log_metric_info(self) -> None:
    """Log, at debug level, each metric to be logged with a readable name."""
    prefix_names = {"t": "total", "m": "mask", "r": "response"}
    logger.debug("Following metrics will be logged during training: ")
    for metric in self.metrics_to_log:
        # Metric names look like "<prefix>_<kind>", e.g. "r_acc".
        pieces = metric.split("_")
        readable = f"{prefix_names[pieces[0]]} {pieces[1]}"
        logger.debug(f" {metric} ({readable})")
|
|
782
|
+
|
|
783
|
+
def _prepare_layers(self) -> None:
    """Set up the sequence layers, the optional mask loss and label layers.

    One sequence layer is registered for the text attribute and one for the
    label attribute. The label layer uses a copy of the config with sparse
    and dense input dropout switched off. When hidden layers are shared the
    label attribute name equals the text attribute name, so the second
    registration replaces the first under the same key.
    """
    self.text_name = TEXT
    if self.config[SHARE_HIDDEN_LAYERS]:
        self.label_name = TEXT
    else:
        self.label_name = LABEL

    # Label features get no input dropout.
    label_config = self.config.copy()
    label_config.update({SPARSE_INPUT_DROPOUT: False, DENSE_INPUT_DROPOUT: False})

    text_attr = self.text_name
    self._tf_layers[f"sequence_layer.{text_attr}"] = rasa_layers.RasaSequenceLayer(
        text_attr, self.data_signature[text_attr], self.config
    )
    label_attr = self.label_name
    self._tf_layers[f"sequence_layer.{label_attr}"] = rasa_layers.RasaSequenceLayer(
        label_attr, self.data_signature[label_attr], label_config
    )

    if self.config[MASKED_LM]:
        self._prepare_mask_lm_loss(self.text_name)

    self._prepare_label_classification_layers(predictor_attribute=self.text_name)
|
|
806
|
+
|
|
807
|
+
def _create_all_labels(self) -> Tuple[tf.Tensor, tf.Tensor]:
    """Compute the ids and embeddings of all labels from the label data.

    Returns:
        A pair ``(all_label_ids, all_labels_embed)``.
    """
    label_data = self.tf_label_data
    all_label_ids = label_data[LABEL_KEY][LABEL_SUB_KEY][0]

    seq_lengths = self._get_sequence_feature_lengths(label_data, LABEL)

    # Combine all feature types into one and embed using a transformer.
    label_layer = self._tf_layers[f"sequence_layer.{self.label_name}"]
    label_transformed, _, _, _, _, _ = label_layer(
        (label_data[LABEL][SEQUENCE], label_data[LABEL][SENTENCE], seq_lengths),
        training=self._training,
    )

    # The last position with real features depends both on the number of
    # real tokens (sequence-level length) and on whether sentence-level
    # features are present (their effective length being 1 or 0), so the
    # two lengths are added together.
    sent_lengths = self._get_sentence_feature_lengths(label_data, LABEL)
    last_position_features = self._last_token(
        label_transformed, seq_lengths + sent_lengths
    )

    all_labels_embed = self._tf_layers[f"embed.{LABEL}"](last_position_features)
    return all_label_ids, all_labels_embed
|
|
842
|
+
|
|
843
|
+
def batch_loss(
    self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
) -> tf.Tensor:
    """Compute the training loss for one batch.

    The response (label) loss is always included; the masked language model
    loss is added when ``MASKED_LM`` is enabled in the config. The matching
    metrics are updated along the way.

    Args:
        batch_in: The batch.

    Returns:
        The sum of all computed losses for the batch.
    """
    batch = self.batch_to_model_data_format(batch_in, self.data_signature)

    # Run the text features through the text sequence layer.
    text_seq_lengths = self._get_sequence_feature_lengths(batch, TEXT)
    text_layer = self._tf_layers[f"sequence_layer.{self.text_name}"]
    (
        text_transformed,
        text_in,
        _,
        text_seq_ids,
        mlm_mask_text,
        _,
    ) = text_layer(
        (batch[TEXT][SEQUENCE], batch[TEXT][SENTENCE], text_seq_lengths),
        training=self._training,
    )

    # Run the label features through the label sequence layer.
    label_seq_lengths = self._get_sequence_feature_lengths(batch, LABEL)
    label_layer = self._tf_layers[f"sequence_layer.{self.label_name}"]
    label_transformed, _, _, _, _, _ = label_layer(
        (batch[LABEL][SEQUENCE], batch[LABEL][SENTENCE], label_seq_lengths),
        training=self._training,
    )

    losses = []

    if self.config[MASKED_LM]:
        mlm_loss, mlm_acc = self._mask_loss(
            text_transformed, text_in, text_seq_ids, mlm_mask_text, self.text_name
        )
        self.mask_loss.update_state(mlm_loss)
        self.mask_acc.update_state(mlm_acc)
        losses.append(mlm_loss)

    # The sentence vector is taken from the last position with real
    # features; that position is found by adding the lengths of the
    # sequence-level and sentence-level features.
    text_sent_lengths = self._get_sentence_feature_lengths(batch, TEXT)
    sentence_vector_text = self._last_token(
        text_transformed, text_seq_lengths + text_sent_lengths
    )

    # Same extraction for the label attribute.
    label_sent_lengths = self._get_sentence_feature_lengths(batch, LABEL)
    sentence_vector_label = self._last_token(
        label_transformed, label_seq_lengths + label_sent_lengths
    )
    label_ids = batch[LABEL_KEY][LABEL_SUB_KEY][0]

    response_loss, response_acc = self._calculate_label_loss(
        sentence_vector_text, sentence_vector_label, label_ids
    )
    self.response_loss.update_state(response_loss)
    self.response_acc.update_state(response_acc)
    losses.append(response_loss)

    return tf.math.add_n(losses)
|
|
935
|
+
|
|
936
|
+
def batch_predict(
    self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
) -> Dict[Text, Union[tf.Tensor, Dict[Text, tf.Tensor]]]:
    """Compute the prediction output for one batch.

    Args:
        batch_in: The batch.

    Returns:
        A dict holding the diagnostic data (attention weights and the
        transformed text) and, under ``"i_scores"``, the scores of the text
        against all label embeddings.
    """
    batch = self.batch_to_model_data_format(batch_in, self.predict_data_signature)

    seq_lengths = self._get_sequence_feature_lengths(batch, TEXT)
    text_layer = self._tf_layers[f"sequence_layer.{self.text_name}"]
    text_transformed, _, _, _, _, attention_weights = text_layer(
        (batch[TEXT][SEQUENCE], batch[TEXT][SENTENCE], seq_lengths),
        training=self._training,
    )

    diagnostics = {
        "attention_weights": attention_weights,
        "text_transformed": text_transformed,
    }
    predictions = {DIAGNOSTIC_DATA: diagnostics}

    # Label embeddings are computed on first use and then kept.
    if self.all_labels_embed is None:
        _, self.all_labels_embed = self._create_all_labels()

    # Sentence feature vector used for classification.
    # NOTE(review): batch_loss in this class passes the sum of sequence- and
    # sentence-level lengths to _last_token, whereas only the sequence-level
    # lengths are used here — confirm whether this difference is intended.
    sentence_vector = self._last_token(text_transformed, seq_lengths)
    sentence_vector_embed = self._tf_layers[f"embed.{TEXT}"](sentence_vector)

    label_loss_layer = self._tf_layers[f"loss.{LABEL}"]
    _, scores = label_loss_layer.get_similarities_and_confidences_from_embeddings(
        sentence_vector_embed[:, tf.newaxis, :],
        self.all_labels_embed[tf.newaxis, :, :],
    )
    predictions["i_scores"] = scores

    return predictions
|