rasa_pro-3.12.0.dev1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this release of rasa-pro has been flagged as potentially problematic.
- README.md +41 -0
- rasa/__init__.py +9 -0
- rasa/__main__.py +177 -0
- rasa/anonymization/__init__.py +2 -0
- rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
- rasa/anonymization/anonymization_pipeline.py +286 -0
- rasa/anonymization/anonymization_rule_executor.py +260 -0
- rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
- rasa/anonymization/schemas/config.yml +47 -0
- rasa/anonymization/utils.py +118 -0
- rasa/api.py +160 -0
- rasa/cli/__init__.py +5 -0
- rasa/cli/arguments/__init__.py +0 -0
- rasa/cli/arguments/data.py +106 -0
- rasa/cli/arguments/default_arguments.py +207 -0
- rasa/cli/arguments/evaluate.py +65 -0
- rasa/cli/arguments/export.py +51 -0
- rasa/cli/arguments/interactive.py +74 -0
- rasa/cli/arguments/run.py +219 -0
- rasa/cli/arguments/shell.py +17 -0
- rasa/cli/arguments/test.py +211 -0
- rasa/cli/arguments/train.py +279 -0
- rasa/cli/arguments/visualize.py +34 -0
- rasa/cli/arguments/x.py +30 -0
- rasa/cli/data.py +354 -0
- rasa/cli/dialogue_understanding_test.py +251 -0
- rasa/cli/e2e_test.py +259 -0
- rasa/cli/evaluate.py +222 -0
- rasa/cli/export.py +250 -0
- rasa/cli/inspect.py +75 -0
- rasa/cli/interactive.py +166 -0
- rasa/cli/license.py +65 -0
- rasa/cli/llm_fine_tuning.py +403 -0
- rasa/cli/markers.py +78 -0
- rasa/cli/project_templates/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/action_template.py +27 -0
- rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
- rasa/cli/project_templates/calm/actions/db.py +57 -0
- rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
- rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
- rasa/cli/project_templates/calm/config.yml +10 -0
- rasa/cli/project_templates/calm/credentials.yml +33 -0
- rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
- rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
- rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
- rasa/cli/project_templates/calm/db/contacts.json +10 -0
- rasa/cli/project_templates/calm/domain/add_contact.yml +39 -0
- rasa/cli/project_templates/calm/domain/list_contacts.yml +17 -0
- rasa/cli/project_templates/calm/domain/remove_contact.yml +38 -0
- rasa/cli/project_templates/calm/domain/shared.yml +10 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
- rasa/cli/project_templates/calm/endpoints.yml +58 -0
- rasa/cli/project_templates/default/actions/__init__.py +0 -0
- rasa/cli/project_templates/default/actions/actions.py +27 -0
- rasa/cli/project_templates/default/config.yml +44 -0
- rasa/cli/project_templates/default/credentials.yml +33 -0
- rasa/cli/project_templates/default/data/nlu.yml +91 -0
- rasa/cli/project_templates/default/data/rules.yml +13 -0
- rasa/cli/project_templates/default/data/stories.yml +30 -0
- rasa/cli/project_templates/default/domain.yml +34 -0
- rasa/cli/project_templates/default/endpoints.yml +42 -0
- rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
- rasa/cli/project_templates/tutorial/actions/__init__.py +0 -0
- rasa/cli/project_templates/tutorial/actions/actions.py +22 -0
- rasa/cli/project_templates/tutorial/config.yml +12 -0
- rasa/cli/project_templates/tutorial/credentials.yml +33 -0
- rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
- rasa/cli/project_templates/tutorial/data/patterns.yml +11 -0
- rasa/cli/project_templates/tutorial/domain.yml +35 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +55 -0
- rasa/cli/run.py +143 -0
- rasa/cli/scaffold.py +273 -0
- rasa/cli/shell.py +141 -0
- rasa/cli/studio/__init__.py +0 -0
- rasa/cli/studio/download.py +62 -0
- rasa/cli/studio/studio.py +296 -0
- rasa/cli/studio/train.py +59 -0
- rasa/cli/studio/upload.py +62 -0
- rasa/cli/telemetry.py +102 -0
- rasa/cli/test.py +280 -0
- rasa/cli/train.py +278 -0
- rasa/cli/utils.py +484 -0
- rasa/cli/visualize.py +40 -0
- rasa/cli/x.py +206 -0
- rasa/constants.py +45 -0
- rasa/core/__init__.py +17 -0
- rasa/core/actions/__init__.py +0 -0
- rasa/core/actions/action.py +1318 -0
- rasa/core/actions/action_clean_stack.py +59 -0
- rasa/core/actions/action_exceptions.py +24 -0
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/actions/action_repeat_bot_messages.py +89 -0
- rasa/core/actions/action_run_slot_rejections.py +210 -0
- rasa/core/actions/action_trigger_chitchat.py +31 -0
- rasa/core/actions/action_trigger_flow.py +109 -0
- rasa/core/actions/action_trigger_search.py +31 -0
- rasa/core/actions/constants.py +5 -0
- rasa/core/actions/custom_action_executor.py +191 -0
- rasa/core/actions/direct_custom_actions_executor.py +109 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +72 -0
- rasa/core/actions/forms.py +741 -0
- rasa/core/actions/grpc_custom_action_executor.py +251 -0
- rasa/core/actions/http_custom_action_executor.py +145 -0
- rasa/core/actions/loops.py +114 -0
- rasa/core/actions/two_stage_fallback.py +186 -0
- rasa/core/agent.py +559 -0
- rasa/core/auth_retry_tracker_store.py +122 -0
- rasa/core/brokers/__init__.py +0 -0
- rasa/core/brokers/broker.py +126 -0
- rasa/core/brokers/file.py +58 -0
- rasa/core/brokers/kafka.py +324 -0
- rasa/core/brokers/pika.py +388 -0
- rasa/core/brokers/sql.py +86 -0
- rasa/core/channels/__init__.py +61 -0
- rasa/core/channels/botframework.py +338 -0
- rasa/core/channels/callback.py +84 -0
- rasa/core/channels/channel.py +456 -0
- rasa/core/channels/console.py +241 -0
- rasa/core/channels/development_inspector.py +197 -0
- rasa/core/channels/facebook.py +419 -0
- rasa/core/channels/hangouts.py +329 -0
- rasa/core/channels/inspector/.eslintrc.cjs +25 -0
- rasa/core/channels/inspector/.gitignore +23 -0
- rasa/core/channels/inspector/README.md +54 -0
- rasa/core/channels/inspector/assets/favicon.ico +0 -0
- rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
- rasa/core/channels/inspector/custom.d.ts +3 -0
- rasa/core/channels/inspector/dist/assets/arc-861ddd57.js +1 -0
- rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
- rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-921f02db.js +10 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-b436c4f8.js +2 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-511a23cb.js +2 -0
- rasa/core/channels/inspector/dist/assets/createText-62fc7601-ef476ecd.js +7 -0
- rasa/core/channels/inspector/dist/assets/edges-f2ad444c-f1878e0a.js +4 -0
- rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-fac75185.js +51 -0
- rasa/core/channels/inspector/dist/assets/flowDb-1972c806-201c5bbc.js +6 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-f904ae41.js +4 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-b080d6f2.js +1 -0
- rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-1813da66.js +139 -0
- rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-872af172.js +266 -0
- rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-34a0af5a.js +70 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-42ba3e3d.js +1 -0
- rasa/core/channels/inspector/dist/assets/index-37817b51.js +1317 -0
- rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
- rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-6b731386.js +7 -0
- rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
- rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-e8579ac6.js +139 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
- rasa/core/channels/inspector/dist/assets/layout-89e6403a.js +1 -0
- rasa/core/channels/inspector/dist/assets/line-dc73d3fc.js +1 -0
- rasa/core/channels/inspector/dist/assets/linear-f5b1d2bc.js +1 -0
- rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-82cb74fa.js +109 -0
- rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
- rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
- rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-bdf5f29b.js +35 -0
- rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-c7a0cbe4.js +7 -0
- rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-7ec5410f.js +52 -0
- rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-caee5554.js +8 -0
- rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-2935f8db.js +122 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-8f5d9693.js +1 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-d565d1de.js +1 -0
- rasa/core/channels/inspector/dist/assets/styles-080da4f6-75ad421d.js +110 -0
- rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-7e764226.js +159 -0
- rasa/core/channels/inspector/dist/assets/styles-9c745c82-7a4e0e61.js +207 -0
- rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-4019d1bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-01ea12df.js +61 -0
- rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-89407137.js +7 -0
- rasa/core/channels/inspector/dist/index.html +42 -0
- rasa/core/channels/inspector/index.html +40 -0
- rasa/core/channels/inspector/jest.config.ts +13 -0
- rasa/core/channels/inspector/package.json +52 -0
- rasa/core/channels/inspector/setupTests.ts +2 -0
- rasa/core/channels/inspector/src/App.tsx +220 -0
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +108 -0
- rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +136 -0
- rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
- rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +22 -0
- rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
- rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
- rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
- rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
- rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
- rasa/core/channels/inspector/src/helpers/audiostream.ts +191 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +392 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +306 -0
- rasa/core/channels/inspector/src/helpers/utils.ts +127 -0
- rasa/core/channels/inspector/src/main.tsx +13 -0
- rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
- rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
- rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
- rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
- rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
- rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
- rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
- rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
- rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
- rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
- rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
- rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
- rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
- rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
- rasa/core/channels/inspector/src/theme/index.ts +101 -0
- rasa/core/channels/inspector/src/types.ts +84 -0
- rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
- rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
- rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
- rasa/core/channels/inspector/tsconfig.json +26 -0
- rasa/core/channels/inspector/tsconfig.node.json +10 -0
- rasa/core/channels/inspector/vite.config.ts +8 -0
- rasa/core/channels/inspector/yarn.lock +6249 -0
- rasa/core/channels/mattermost.py +229 -0
- rasa/core/channels/rasa_chat.py +126 -0
- rasa/core/channels/rest.py +230 -0
- rasa/core/channels/rocketchat.py +174 -0
- rasa/core/channels/slack.py +620 -0
- rasa/core/channels/socketio.py +302 -0
- rasa/core/channels/telegram.py +298 -0
- rasa/core/channels/twilio.py +169 -0
- rasa/core/channels/vier_cvg.py +374 -0
- rasa/core/channels/voice_ready/__init__.py +0 -0
- rasa/core/channels/voice_ready/audiocodes.py +501 -0
- rasa/core/channels/voice_ready/jambonz.py +121 -0
- rasa/core/channels/voice_ready/jambonz_protocol.py +396 -0
- rasa/core/channels/voice_ready/twilio_voice.py +403 -0
- rasa/core/channels/voice_ready/utils.py +37 -0
- rasa/core/channels/voice_stream/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
- rasa/core/channels/voice_stream/asr/azure.py +130 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
- rasa/core/channels/voice_stream/audio_bytes.py +8 -0
- rasa/core/channels/voice_stream/browser_audio.py +107 -0
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +106 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +427 -0
- rasa/core/channels/webexteams.py +134 -0
- rasa/core/concurrent_lock_store.py +210 -0
- rasa/core/constants.py +112 -0
- rasa/core/evaluation/__init__.py +0 -0
- rasa/core/evaluation/marker.py +267 -0
- rasa/core/evaluation/marker_base.py +923 -0
- rasa/core/evaluation/marker_stats.py +293 -0
- rasa/core/evaluation/marker_tracker_loader.py +103 -0
- rasa/core/exceptions.py +29 -0
- rasa/core/exporter.py +284 -0
- rasa/core/featurizers/__init__.py +0 -0
- rasa/core/featurizers/precomputation.py +410 -0
- rasa/core/featurizers/single_state_featurizer.py +421 -0
- rasa/core/featurizers/tracker_featurizers.py +1262 -0
- rasa/core/http_interpreter.py +89 -0
- rasa/core/information_retrieval/__init__.py +7 -0
- rasa/core/information_retrieval/faiss.py +124 -0
- rasa/core/information_retrieval/information_retrieval.py +137 -0
- rasa/core/information_retrieval/milvus.py +59 -0
- rasa/core/information_retrieval/qdrant.py +96 -0
- rasa/core/jobs.py +63 -0
- rasa/core/lock.py +139 -0
- rasa/core/lock_store.py +343 -0
- rasa/core/migrate.py +403 -0
- rasa/core/nlg/__init__.py +3 -0
- rasa/core/nlg/callback.py +146 -0
- rasa/core/nlg/contextual_response_rephraser.py +320 -0
- rasa/core/nlg/generator.py +230 -0
- rasa/core/nlg/interpolator.py +143 -0
- rasa/core/nlg/response.py +155 -0
- rasa/core/nlg/summarize.py +70 -0
- rasa/core/persistor.py +538 -0
- rasa/core/policies/__init__.py +0 -0
- rasa/core/policies/ensemble.py +329 -0
- rasa/core/policies/enterprise_search_policy.py +905 -0
- rasa/core/policies/enterprise_search_prompt_template.jinja2 +25 -0
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
- rasa/core/policies/flow_policy.py +205 -0
- rasa/core/policies/flows/__init__.py +0 -0
- rasa/core/policies/flows/flow_exceptions.py +44 -0
- rasa/core/policies/flows/flow_executor.py +754 -0
- rasa/core/policies/flows/flow_step_result.py +43 -0
- rasa/core/policies/intentless_policy.py +1031 -0
- rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
- rasa/core/policies/memoization.py +538 -0
- rasa/core/policies/policy.py +725 -0
- rasa/core/policies/rule_policy.py +1273 -0
- rasa/core/policies/ted_policy.py +2169 -0
- rasa/core/policies/unexpected_intent_policy.py +1022 -0
- rasa/core/processor.py +1465 -0
- rasa/core/run.py +342 -0
- rasa/core/secrets_manager/__init__.py +0 -0
- rasa/core/secrets_manager/constants.py +36 -0
- rasa/core/secrets_manager/endpoints.py +391 -0
- rasa/core/secrets_manager/factory.py +241 -0
- rasa/core/secrets_manager/secret_manager.py +262 -0
- rasa/core/secrets_manager/vault.py +584 -0
- rasa/core/test.py +1335 -0
- rasa/core/tracker_store.py +1703 -0
- rasa/core/train.py +105 -0
- rasa/core/training/__init__.py +89 -0
- rasa/core/training/converters/__init__.py +0 -0
- rasa/core/training/converters/responses_prefix_converter.py +119 -0
- rasa/core/training/interactive.py +1744 -0
- rasa/core/training/story_conflict.py +381 -0
- rasa/core/training/training.py +93 -0
- rasa/core/utils.py +366 -0
- rasa/core/visualize.py +70 -0
- rasa/dialogue_understanding/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/constants.py +4 -0
- rasa/dialogue_understanding/coexistence/intent_based_router.py +196 -0
- rasa/dialogue_understanding/coexistence/llm_based_router.py +327 -0
- rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
- rasa/dialogue_understanding/commands/__init__.py +61 -0
- rasa/dialogue_understanding/commands/can_not_handle_command.py +70 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +125 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +44 -0
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/clarify_command.py +86 -0
- rasa/dialogue_understanding/commands/command.py +85 -0
- rasa/dialogue_understanding/commands/correct_slots_command.py +297 -0
- rasa/dialogue_understanding/commands/error_command.py +79 -0
- rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/handle_code_change_command.py +73 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +66 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/noop_command.py +54 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
- rasa/dialogue_understanding/commands/restart_command.py +58 -0
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +160 -0
- rasa/dialogue_understanding/commands/skip_question_command.py +75 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +107 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +45 -0
- rasa/dialogue_understanding/generator/__init__.py +21 -0
- rasa/dialogue_understanding/generator/command_generator.py +464 -0
- rasa/dialogue_understanding/generator/constants.py +27 -0
- rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +466 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +500 -0
- rasa/dialogue_understanding/generator/llm_command_generator.py +67 -0
- rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
- rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +920 -0
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +261 -0
- rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +60 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +486 -0
- rasa/dialogue_understanding/patterns/__init__.py +0 -0
- rasa/dialogue_understanding/patterns/cancel.py +111 -0
- rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
- rasa/dialogue_understanding/patterns/chitchat.py +37 -0
- rasa/dialogue_understanding/patterns/clarify.py +97 -0
- rasa/dialogue_understanding/patterns/code_change.py +41 -0
- rasa/dialogue_understanding/patterns/collect_information.py +90 -0
- rasa/dialogue_understanding/patterns/completed.py +40 -0
- rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
- rasa/dialogue_understanding/patterns/correction.py +278 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +301 -0
- rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
- rasa/dialogue_understanding/patterns/internal_error.py +47 -0
- rasa/dialogue_understanding/patterns/repeat.py +37 -0
- rasa/dialogue_understanding/patterns/restart.py +37 -0
- rasa/dialogue_understanding/patterns/search.py +37 -0
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/patterns/skip_question.py +38 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/__init__.py +0 -0
- rasa/dialogue_understanding/processor/command_processor.py +720 -0
- rasa/dialogue_understanding/processor/command_processor_component.py +43 -0
- rasa/dialogue_understanding/stack/__init__.py +0 -0
- rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
- rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
- rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
- rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
- rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
- rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
- rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
- rasa/dialogue_understanding/stack/utils.py +211 -0
- rasa/dialogue_understanding/utils.py +14 -0
- rasa/dialogue_understanding_test/__init__.py +0 -0
- rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
- rasa/dialogue_understanding_test/constants.py +17 -0
- rasa/dialogue_understanding_test/du_test_case.py +118 -0
- rasa/dialogue_understanding_test/du_test_result.py +11 -0
- rasa/dialogue_understanding_test/du_test_runner.py +93 -0
- rasa/dialogue_understanding_test/io.py +54 -0
- rasa/dialogue_understanding_test/validation.py +22 -0
- rasa/e2e_test/__init__.py +0 -0
- rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
- rasa/e2e_test/assertions.py +1345 -0
- rasa/e2e_test/assertions_schema.yml +129 -0
- rasa/e2e_test/constants.py +31 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +569 -0
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +54 -0
- rasa/e2e_test/e2e_test_runner.py +1192 -0
- rasa/e2e_test/e2e_test_schema.yml +181 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +598 -0
- rasa/e2e_test/utils/validation.py +178 -0
- rasa/engine/__init__.py +0 -0
- rasa/engine/caching.py +463 -0
- rasa/engine/constants.py +17 -0
- rasa/engine/exceptions.py +14 -0
- rasa/engine/graph.py +642 -0
- rasa/engine/loader.py +48 -0
- rasa/engine/recipes/__init__.py +0 -0
- rasa/engine/recipes/config_files/default_config.yml +41 -0
- rasa/engine/recipes/default_components.py +97 -0
- rasa/engine/recipes/default_recipe.py +1272 -0
- rasa/engine/recipes/graph_recipe.py +79 -0
- rasa/engine/recipes/recipe.py +93 -0
- rasa/engine/runner/__init__.py +0 -0
- rasa/engine/runner/dask.py +250 -0
- rasa/engine/runner/interface.py +49 -0
- rasa/engine/storage/__init__.py +0 -0
- rasa/engine/storage/local_model_storage.py +244 -0
- rasa/engine/storage/resource.py +110 -0
- rasa/engine/storage/storage.py +199 -0
- rasa/engine/training/__init__.py +0 -0
- rasa/engine/training/components.py +176 -0
- rasa/engine/training/fingerprinting.py +64 -0
- rasa/engine/training/graph_trainer.py +256 -0
- rasa/engine/training/hooks.py +164 -0
- rasa/engine/validation.py +1451 -0
- rasa/env.py +14 -0
- rasa/exceptions.py +69 -0
- rasa/graph_components/__init__.py +0 -0
- rasa/graph_components/converters/__init__.py +0 -0
- rasa/graph_components/converters/nlu_message_converter.py +48 -0
- rasa/graph_components/providers/__init__.py +0 -0
- rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
- rasa/graph_components/providers/domain_provider.py +71 -0
- rasa/graph_components/providers/flows_provider.py +74 -0
- rasa/graph_components/providers/forms_provider.py +44 -0
- rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
- rasa/graph_components/providers/responses_provider.py +44 -0
- rasa/graph_components/providers/rule_only_provider.py +49 -0
- rasa/graph_components/providers/story_graph_provider.py +96 -0
- rasa/graph_components/providers/training_tracker_provider.py +55 -0
- rasa/graph_components/validators/__init__.py +0 -0
- rasa/graph_components/validators/default_recipe_validator.py +550 -0
- rasa/graph_components/validators/finetuning_validator.py +302 -0
- rasa/hooks.py +111 -0
- rasa/jupyter.py +63 -0
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/markers/__init__.py +0 -0
- rasa/markers/marker.py +269 -0
- rasa/markers/marker_base.py +828 -0
- rasa/markers/upload.py +74 -0
- rasa/markers/validate.py +21 -0
- rasa/model.py +118 -0
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +40 -0
- rasa/model_manager/model_api.py +559 -0
- rasa/model_manager/runner_service.py +286 -0
- rasa/model_manager/socket_bridge.py +146 -0
- rasa/model_manager/studio_jwt_auth.py +86 -0
- rasa/model_manager/trainer_service.py +325 -0
- rasa/model_manager/utils.py +87 -0
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +112 -0
- rasa/model_testing.py +457 -0
- rasa/model_training.py +596 -0
- rasa/nlu/__init__.py +7 -0
- rasa/nlu/classifiers/__init__.py +3 -0
- rasa/nlu/classifiers/classifier.py +5 -0
- rasa/nlu/classifiers/diet_classifier.py +1881 -0
- rasa/nlu/classifiers/fallback_classifier.py +192 -0
- rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
- rasa/nlu/classifiers/logistic_regression_classifier.py +253 -0
- rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
- rasa/nlu/classifiers/regex_message_handler.py +56 -0
- rasa/nlu/classifiers/sklearn_intent_classifier.py +330 -0
- rasa/nlu/constants.py +77 -0
- rasa/nlu/convert.py +40 -0
- rasa/nlu/emulators/__init__.py +0 -0
- rasa/nlu/emulators/dialogflow.py +55 -0
- rasa/nlu/emulators/emulator.py +49 -0
- rasa/nlu/emulators/luis.py +86 -0
- rasa/nlu/emulators/no_emulator.py +10 -0
- rasa/nlu/emulators/wit.py +56 -0
- rasa/nlu/extractors/__init__.py +0 -0
- rasa/nlu/extractors/crf_entity_extractor.py +715 -0
- rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
- rasa/nlu/extractors/entity_synonyms.py +178 -0
- rasa/nlu/extractors/extractor.py +470 -0
- rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
- rasa/nlu/extractors/regex_entity_extractor.py +220 -0
- rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
- rasa/nlu/featurizers/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +445 -0
- rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +768 -0
- rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
- rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
- rasa/nlu/featurizers/featurizer.py +89 -0
- rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +867 -0
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +571 -0
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +271 -0
- rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
- rasa/nlu/model.py +24 -0
- rasa/nlu/run.py +27 -0
- rasa/nlu/selectors/__init__.py +0 -0
- rasa/nlu/selectors/response_selector.py +987 -0
- rasa/nlu/test.py +1940 -0
- rasa/nlu/tokenizers/__init__.py +0 -0
- rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
- rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
- rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
- rasa/nlu/tokenizers/tokenizer.py +239 -0
- rasa/nlu/tokenizers/whitespace_tokenizer.py +95 -0
- rasa/nlu/utils/__init__.py +35 -0
- rasa/nlu/utils/bilou_utils.py +462 -0
- rasa/nlu/utils/hugging_face/__init__.py +0 -0
- rasa/nlu/utils/hugging_face/registry.py +108 -0
- rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
- rasa/nlu/utils/mitie_utils.py +113 -0
- rasa/nlu/utils/pattern_utils.py +168 -0
- rasa/nlu/utils/spacy_utils.py +310 -0
- rasa/plugin.py +90 -0
- rasa/server.py +1588 -0
- rasa/shared/__init__.py +0 -0
- rasa/shared/constants.py +311 -0
- rasa/shared/core/__init__.py +0 -0
- rasa/shared/core/command_payload_reader.py +109 -0
- rasa/shared/core/constants.py +180 -0
- rasa/shared/core/conversation.py +46 -0
- rasa/shared/core/domain.py +2172 -0
- rasa/shared/core/events.py +2559 -0
- rasa/shared/core/flows/__init__.py +7 -0
- rasa/shared/core/flows/flow.py +562 -0
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flow_step.py +146 -0
- rasa/shared/core/flows/flow_step_links.py +319 -0
- rasa/shared/core/flows/flow_step_sequence.py +70 -0
- rasa/shared/core/flows/flows_list.py +258 -0
- rasa/shared/core/flows/flows_yaml_schema.json +303 -0
- rasa/shared/core/flows/nlu_trigger.py +117 -0
- rasa/shared/core/flows/steps/__init__.py +24 -0
- rasa/shared/core/flows/steps/action.py +56 -0
- rasa/shared/core/flows/steps/call.py +64 -0
- rasa/shared/core/flows/steps/collect.py +112 -0
- rasa/shared/core/flows/steps/constants.py +5 -0
- rasa/shared/core/flows/steps/continuation.py +36 -0
- rasa/shared/core/flows/steps/end.py +22 -0
- rasa/shared/core/flows/steps/internal.py +44 -0
- rasa/shared/core/flows/steps/link.py +51 -0
- rasa/shared/core/flows/steps/no_operation.py +48 -0
- rasa/shared/core/flows/steps/set_slots.py +50 -0
- rasa/shared/core/flows/steps/start.py +30 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +735 -0
- rasa/shared/core/flows/yaml_flows_io.py +405 -0
- rasa/shared/core/generator.py +908 -0
- rasa/shared/core/slot_mappings.py +526 -0
- rasa/shared/core/slots.py +654 -0
- rasa/shared/core/trackers.py +1183 -0
- rasa/shared/core/training_data/__init__.py +0 -0
- rasa/shared/core/training_data/loading.py +89 -0
- rasa/shared/core/training_data/story_reader/__init__.py +0 -0
- rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
- rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
- rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
- rasa/shared/core/training_data/story_writer/__init__.py +0 -0
- rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
- rasa/shared/core/training_data/story_writer/yaml_story_writer.py +444 -0
- rasa/shared/core/training_data/structures.py +858 -0
- rasa/shared/core/training_data/visualization.html +146 -0
- rasa/shared/core/training_data/visualization.py +603 -0
- rasa/shared/data.py +249 -0
- rasa/shared/engine/__init__.py +0 -0
- rasa/shared/engine/caching.py +26 -0
- rasa/shared/exceptions.py +167 -0
- rasa/shared/importers/__init__.py +0 -0
- rasa/shared/importers/importer.py +770 -0
- rasa/shared/importers/multi_project.py +215 -0
- rasa/shared/importers/rasa.py +108 -0
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +36 -0
- rasa/shared/nlu/__init__.py +0 -0
- rasa/shared/nlu/constants.py +53 -0
- rasa/shared/nlu/interpreter.py +10 -0
- rasa/shared/nlu/training_data/__init__.py +0 -0
- rasa/shared/nlu/training_data/entities_parser.py +208 -0
- rasa/shared/nlu/training_data/features.py +492 -0
- rasa/shared/nlu/training_data/formats/__init__.py +10 -0
- rasa/shared/nlu/training_data/formats/dialogflow.py +163 -0
- rasa/shared/nlu/training_data/formats/luis.py +87 -0
- rasa/shared/nlu/training_data/formats/rasa.py +135 -0
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +618 -0
- rasa/shared/nlu/training_data/formats/readerwriter.py +244 -0
- rasa/shared/nlu/training_data/formats/wit.py +52 -0
- rasa/shared/nlu/training_data/loading.py +137 -0
- rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
- rasa/shared/nlu/training_data/message.py +490 -0
- rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
- rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
- rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
- rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
- rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
- rasa/shared/nlu/training_data/training_data.py +729 -0
- rasa/shared/nlu/training_data/util.py +223 -0
- rasa/shared/providers/__init__.py +0 -0
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +677 -0
- rasa/shared/providers/_configs/client_config.py +59 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +132 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +236 -0
- rasa/shared/providers/_configs/litellm_router_client_config.py +222 -0
- rasa/shared/providers/_configs/model_group_config.py +173 -0
- rasa/shared/providers/_configs/openai_client_config.py +177 -0
- rasa/shared/providers/_configs/rasa_llm_client_config.py +75 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +178 -0
- rasa/shared/providers/_configs/utils.py +117 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/constants.py +7 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +243 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +335 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +126 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +138 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +265 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +415 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +110 -0
- rasa/shared/providers/llm/litellm_router_llm_client.py +202 -0
- rasa/shared/providers/llm/llm_client.py +78 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +161 -0
- rasa/shared/providers/llm/rasa_llm_client.py +120 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +276 -0
- rasa/shared/providers/mappings.py +94 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +185 -0
- rasa/shared/providers/router/router_client.py +75 -0
- rasa/shared/utils/__init__.py +0 -0
- rasa/shared/utils/cli.py +102 -0
- rasa/shared/utils/common.py +324 -0
- rasa/shared/utils/constants.py +4 -0
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +258 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +499 -0
- rasa/shared/utils/llm.py +764 -0
- rasa/shared/utils/pykwalify_extensions.py +27 -0
- rasa/shared/utils/schemas/__init__.py +0 -0
- rasa/shared/utils/schemas/config.yml +2 -0
- rasa/shared/utils/schemas/domain.yml +145 -0
- rasa/shared/utils/schemas/events.py +214 -0
- rasa/shared/utils/schemas/model_config.yml +36 -0
- rasa/shared/utils/schemas/stories.yml +173 -0
- rasa/shared/utils/yaml.py +1068 -0
- rasa/studio/__init__.py +0 -0
- rasa/studio/auth.py +270 -0
- rasa/studio/config.py +136 -0
- rasa/studio/constants.py +19 -0
- rasa/studio/data_handler.py +368 -0
- rasa/studio/download.py +489 -0
- rasa/studio/results_logger.py +137 -0
- rasa/studio/train.py +134 -0
- rasa/studio/upload.py +563 -0
- rasa/telemetry.py +1876 -0
- rasa/tracing/__init__.py +0 -0
- rasa/tracing/config.py +355 -0
- rasa/tracing/constants.py +62 -0
- rasa/tracing/instrumentation/__init__.py +0 -0
- rasa/tracing/instrumentation/attribute_extractors.py +765 -0
- rasa/tracing/instrumentation/instrumentation.py +1306 -0
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +144 -0
- rasa/tracing/instrumentation/metrics.py +294 -0
- rasa/tracing/metric_instrument_provider.py +205 -0
- rasa/utils/__init__.py +0 -0
- rasa/utils/beta.py +83 -0
- rasa/utils/cli.py +28 -0
- rasa/utils/common.py +639 -0
- rasa/utils/converter.py +53 -0
- rasa/utils/endpoints.py +331 -0
- rasa/utils/io.py +252 -0
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +542 -0
- rasa/utils/log_utils.py +181 -0
- rasa/utils/mapper.py +210 -0
- rasa/utils/ml_utils.py +147 -0
- rasa/utils/plotting.py +362 -0
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/utils/singleton.py +23 -0
- rasa/utils/tensorflow/__init__.py +0 -0
- rasa/utils/tensorflow/callback.py +112 -0
- rasa/utils/tensorflow/constants.py +116 -0
- rasa/utils/tensorflow/crf.py +492 -0
- rasa/utils/tensorflow/data_generator.py +440 -0
- rasa/utils/tensorflow/environment.py +161 -0
- rasa/utils/tensorflow/exceptions.py +5 -0
- rasa/utils/tensorflow/feature_array.py +366 -0
- rasa/utils/tensorflow/layers.py +1565 -0
- rasa/utils/tensorflow/layers_utils.py +113 -0
- rasa/utils/tensorflow/metrics.py +281 -0
- rasa/utils/tensorflow/model_data.py +798 -0
- rasa/utils/tensorflow/model_data_utils.py +499 -0
- rasa/utils/tensorflow/models.py +935 -0
- rasa/utils/tensorflow/rasa_layers.py +1094 -0
- rasa/utils/tensorflow/transformer.py +640 -0
- rasa/utils/tensorflow/types.py +6 -0
- rasa/utils/train_utils.py +572 -0
- rasa/utils/url_tools.py +53 -0
- rasa/utils/yaml.py +54 -0
- rasa/validator.py +1644 -0
- rasa/version.py +3 -0
- rasa_pro-3.12.0.dev1.dist-info/METADATA +199 -0
- rasa_pro-3.12.0.dev1.dist-info/NOTICE +5 -0
- rasa_pro-3.12.0.dev1.dist-info/RECORD +790 -0
- rasa_pro-3.12.0.dev1.dist-info/WHEEL +4 -0
- rasa_pro-3.12.0.dev1.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,1881 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
import logging
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Dict, List, Optional, Text, Tuple, Union, TypeVar, Type
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import scipy.sparse
|
|
11
|
+
import tensorflow as tf
|
|
12
|
+
|
|
13
|
+
from rasa.exceptions import ModelNotFound
|
|
14
|
+
from rasa.nlu.featurizers.featurizer import Featurizer
|
|
15
|
+
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
16
|
+
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
17
|
+
from rasa.engine.storage.resource import Resource
|
|
18
|
+
from rasa.engine.storage.storage import ModelStorage
|
|
19
|
+
from rasa.nlu.extractors.extractor import EntityExtractorMixin
|
|
20
|
+
from rasa.nlu.classifiers.classifier import IntentClassifier
|
|
21
|
+
import rasa.shared.utils.io
|
|
22
|
+
import rasa.nlu.utils.bilou_utils as bilou_utils
|
|
23
|
+
from rasa.shared.constants import DIAGNOSTIC_DATA
|
|
24
|
+
from rasa.nlu.extractors.extractor import EntityTagSpec
|
|
25
|
+
from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
|
|
26
|
+
from rasa.utils import train_utils
|
|
27
|
+
from rasa.utils.tensorflow import rasa_layers
|
|
28
|
+
from rasa.utils.tensorflow.feature_array import (
|
|
29
|
+
FeatureArray,
|
|
30
|
+
serialize_nested_feature_arrays,
|
|
31
|
+
deserialize_nested_feature_arrays,
|
|
32
|
+
)
|
|
33
|
+
from rasa.utils.tensorflow.models import RasaModel, TransformerRasaModel
|
|
34
|
+
from rasa.utils.tensorflow.model_data import (
|
|
35
|
+
RasaModelData,
|
|
36
|
+
FeatureSignature,
|
|
37
|
+
)
|
|
38
|
+
from rasa.nlu.constants import TOKENS_NAMES, DEFAULT_TRANSFORMER_SIZE
|
|
39
|
+
from rasa.shared.nlu.constants import (
|
|
40
|
+
SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
|
|
41
|
+
TEXT,
|
|
42
|
+
INTENT,
|
|
43
|
+
INTENT_RESPONSE_KEY,
|
|
44
|
+
ENTITIES,
|
|
45
|
+
ENTITY_ATTRIBUTE_TYPE,
|
|
46
|
+
ENTITY_ATTRIBUTE_GROUP,
|
|
47
|
+
ENTITY_ATTRIBUTE_ROLE,
|
|
48
|
+
NO_ENTITY_TAG,
|
|
49
|
+
SPLIT_ENTITIES_BY_COMMA,
|
|
50
|
+
)
|
|
51
|
+
from rasa.shared.exceptions import InvalidConfigException
|
|
52
|
+
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
53
|
+
from rasa.shared.nlu.training_data.message import Message
|
|
54
|
+
from rasa.utils.tensorflow.constants import (
|
|
55
|
+
DROP_SMALL_LAST_BATCH,
|
|
56
|
+
LABEL,
|
|
57
|
+
IDS,
|
|
58
|
+
HIDDEN_LAYERS_SIZES,
|
|
59
|
+
RENORMALIZE_CONFIDENCES,
|
|
60
|
+
SHARE_HIDDEN_LAYERS,
|
|
61
|
+
TRANSFORMER_SIZE,
|
|
62
|
+
NUM_TRANSFORMER_LAYERS,
|
|
63
|
+
NUM_HEADS,
|
|
64
|
+
BATCH_SIZES,
|
|
65
|
+
BATCH_STRATEGY,
|
|
66
|
+
EPOCHS,
|
|
67
|
+
RANDOM_SEED,
|
|
68
|
+
LEARNING_RATE,
|
|
69
|
+
RANKING_LENGTH,
|
|
70
|
+
LOSS_TYPE,
|
|
71
|
+
SIMILARITY_TYPE,
|
|
72
|
+
NUM_NEG,
|
|
73
|
+
SPARSE_INPUT_DROPOUT,
|
|
74
|
+
DENSE_INPUT_DROPOUT,
|
|
75
|
+
MASKED_LM,
|
|
76
|
+
ENTITY_RECOGNITION,
|
|
77
|
+
TENSORBOARD_LOG_DIR,
|
|
78
|
+
INTENT_CLASSIFICATION,
|
|
79
|
+
EVAL_NUM_EXAMPLES,
|
|
80
|
+
EVAL_NUM_EPOCHS,
|
|
81
|
+
UNIDIRECTIONAL_ENCODER,
|
|
82
|
+
DROP_RATE,
|
|
83
|
+
DROP_RATE_ATTENTION,
|
|
84
|
+
CONNECTION_DENSITY,
|
|
85
|
+
NEGATIVE_MARGIN_SCALE,
|
|
86
|
+
REGULARIZATION_CONSTANT,
|
|
87
|
+
SCALE_LOSS,
|
|
88
|
+
USE_MAX_NEG_SIM,
|
|
89
|
+
MAX_NEG_SIM,
|
|
90
|
+
MAX_POS_SIM,
|
|
91
|
+
EMBEDDING_DIMENSION,
|
|
92
|
+
BILOU_FLAG,
|
|
93
|
+
KEY_RELATIVE_ATTENTION,
|
|
94
|
+
VALUE_RELATIVE_ATTENTION,
|
|
95
|
+
MAX_RELATIVE_POSITION,
|
|
96
|
+
AUTO,
|
|
97
|
+
BALANCED,
|
|
98
|
+
CROSS_ENTROPY,
|
|
99
|
+
TENSORBOARD_LOG_LEVEL,
|
|
100
|
+
CONCAT_DIMENSION,
|
|
101
|
+
FEATURIZERS,
|
|
102
|
+
CHECKPOINT_MODEL,
|
|
103
|
+
SEQUENCE,
|
|
104
|
+
SENTENCE,
|
|
105
|
+
SEQUENCE_LENGTH,
|
|
106
|
+
DENSE_DIMENSION,
|
|
107
|
+
MASK,
|
|
108
|
+
CONSTRAIN_SIMILARITIES,
|
|
109
|
+
MODEL_CONFIDENCE,
|
|
110
|
+
SOFTMAX,
|
|
111
|
+
RUN_EAGERLY,
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
logger = logging.getLogger(__name__)
|
|
115
|
+
|
|
116
|
+
SPARSE = "sparse"
|
|
117
|
+
DENSE = "dense"
|
|
118
|
+
LABEL_KEY = LABEL
|
|
119
|
+
LABEL_SUB_KEY = IDS
|
|
120
|
+
|
|
121
|
+
POSSIBLE_TAGS = [ENTITY_ATTRIBUTE_TYPE, ENTITY_ATTRIBUTE_ROLE, ENTITY_ATTRIBUTE_GROUP]
|
|
122
|
+
|
|
123
|
+
DIETClassifierT = TypeVar("DIETClassifierT", bound="DIETClassifier")
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@DefaultV1Recipe.register(
|
|
127
|
+
[
|
|
128
|
+
DefaultV1Recipe.ComponentType.INTENT_CLASSIFIER,
|
|
129
|
+
DefaultV1Recipe.ComponentType.ENTITY_EXTRACTOR,
|
|
130
|
+
],
|
|
131
|
+
is_trainable=True,
|
|
132
|
+
)
|
|
133
|
+
class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
|
|
134
|
+
"""A multi-task model for intent classification and entity extraction.
|
|
135
|
+
|
|
136
|
+
DIET is Dual Intent and Entity Transformer.
|
|
137
|
+
The architecture is based on a transformer which is shared for both tasks.
|
|
138
|
+
A sequence of entity labels is predicted through a Conditional Random Field (CRF)
|
|
139
|
+
tagging layer on top of the transformer output sequence corresponding to the
|
|
140
|
+
input sequence of tokens. The transformer output for the ``__CLS__`` token and
|
|
141
|
+
intent labels are embedded into a single semantic vector space. We use the
|
|
142
|
+
dot-product loss to maximize the similarity with the target label and minimize
|
|
143
|
+
similarities with negative samples.
|
|
144
|
+
"""
|
|
145
|
+
|
|
146
|
+
@classmethod
|
|
147
|
+
def required_components(cls) -> List[Type]:
|
|
148
|
+
"""Components that should be included in the pipeline before this component."""
|
|
149
|
+
return [Featurizer]
|
|
150
|
+
|
|
151
|
+
@staticmethod
|
|
152
|
+
def get_default_config() -> Dict[Text, Any]:
|
|
153
|
+
"""The component's default config (see parent class for full docstring)."""
|
|
154
|
+
# please make sure to update the docs when changing a default parameter
|
|
155
|
+
return {
|
|
156
|
+
# ## Architecture of the used neural network
|
|
157
|
+
# Hidden layer sizes for layers before the embedding layers for user message
|
|
158
|
+
# and labels.
|
|
159
|
+
# The number of hidden layers is equal to the length of the corresponding
|
|
160
|
+
# list.
|
|
161
|
+
HIDDEN_LAYERS_SIZES: {TEXT: [], LABEL: []},
|
|
162
|
+
# Whether to share the hidden layer weights between user message and labels.
|
|
163
|
+
SHARE_HIDDEN_LAYERS: False,
|
|
164
|
+
# Number of units in transformer
|
|
165
|
+
TRANSFORMER_SIZE: DEFAULT_TRANSFORMER_SIZE,
|
|
166
|
+
# Number of transformer layers
|
|
167
|
+
NUM_TRANSFORMER_LAYERS: 2,
|
|
168
|
+
# Number of attention heads in transformer
|
|
169
|
+
NUM_HEADS: 4,
|
|
170
|
+
# If 'True' use key relative embeddings in attention
|
|
171
|
+
KEY_RELATIVE_ATTENTION: False,
|
|
172
|
+
# If 'True' use value relative embeddings in attention
|
|
173
|
+
VALUE_RELATIVE_ATTENTION: False,
|
|
174
|
+
# Max position for relative embeddings. Only in effect if key- or value
|
|
175
|
+
# relative attention are turned on
|
|
176
|
+
MAX_RELATIVE_POSITION: 5,
|
|
177
|
+
# Use a unidirectional or bidirectional encoder.
|
|
178
|
+
UNIDIRECTIONAL_ENCODER: False,
|
|
179
|
+
# ## Training parameters
|
|
180
|
+
# Initial and final batch sizes:
|
|
181
|
+
# Batch size will be linearly increased for each epoch.
|
|
182
|
+
BATCH_SIZES: [64, 256],
|
|
183
|
+
# Strategy used when creating batches.
|
|
184
|
+
# Can be either 'sequence' or 'balanced'.
|
|
185
|
+
BATCH_STRATEGY: BALANCED,
|
|
186
|
+
# Number of epochs to train
|
|
187
|
+
EPOCHS: 300,
|
|
188
|
+
# Set random seed to any 'int' to get reproducible results
|
|
189
|
+
RANDOM_SEED: None,
|
|
190
|
+
# Initial learning rate for the optimizer
|
|
191
|
+
LEARNING_RATE: 0.001,
|
|
192
|
+
# ## Parameters for embeddings
|
|
193
|
+
# Dimension size of embedding vectors
|
|
194
|
+
EMBEDDING_DIMENSION: 20,
|
|
195
|
+
# Dense dimension to use for sparse features.
|
|
196
|
+
DENSE_DIMENSION: {TEXT: 128, LABEL: 20},
|
|
197
|
+
# Default dimension to use for concatenating sequence and sentence features.
|
|
198
|
+
CONCAT_DIMENSION: {TEXT: 128, LABEL: 20},
|
|
199
|
+
# The number of incorrect labels. The algorithm will minimize
|
|
200
|
+
# their similarity to the user input during training.
|
|
201
|
+
NUM_NEG: 20,
|
|
202
|
+
# Type of similarity measure to use, either 'auto' or 'cosine' or 'inner'.
|
|
203
|
+
SIMILARITY_TYPE: AUTO,
|
|
204
|
+
# The type of the loss function, either 'cross_entropy' or 'margin'.
|
|
205
|
+
LOSS_TYPE: CROSS_ENTROPY,
|
|
206
|
+
# Number of top intents for which confidences should be reported.
|
|
207
|
+
# Set to 0 if confidences for all intents should be reported.
|
|
208
|
+
RANKING_LENGTH: LABEL_RANKING_LENGTH,
|
|
209
|
+
# Indicates how similar the algorithm should try to make embedding vectors
|
|
210
|
+
# for correct labels.
|
|
211
|
+
# Should be 0.0 < ... < 1.0 for 'cosine' similarity type.
|
|
212
|
+
MAX_POS_SIM: 0.8,
|
|
213
|
+
# Maximum negative similarity for incorrect labels.
|
|
214
|
+
# Should be -1.0 < ... < 1.0 for 'cosine' similarity type.
|
|
215
|
+
MAX_NEG_SIM: -0.4,
|
|
216
|
+
# If 'True' the algorithm only minimizes maximum similarity over
|
|
217
|
+
# incorrect intent labels, used only if 'loss_type' is set to 'margin'.
|
|
218
|
+
USE_MAX_NEG_SIM: True,
|
|
219
|
+
# If 'True' scale loss inverse proportionally to the confidence
|
|
220
|
+
# of the correct prediction
|
|
221
|
+
SCALE_LOSS: False,
|
|
222
|
+
# ## Regularization parameters
|
|
223
|
+
# The scale of regularization
|
|
224
|
+
REGULARIZATION_CONSTANT: 0.002,
|
|
225
|
+
# The scale of how important is to minimize the maximum similarity
|
|
226
|
+
# between embeddings of different labels,
|
|
227
|
+
# used only if 'loss_type' is set to 'margin'.
|
|
228
|
+
NEGATIVE_MARGIN_SCALE: 0.8,
|
|
229
|
+
# Dropout rate for encoder
|
|
230
|
+
            DROP_RATE: 0.2,
            # Dropout rate for attention
            DROP_RATE_ATTENTION: 0,
            # Fraction of trainable weights in internal layers.
            CONNECTION_DENSITY: 0.2,
            # If 'True' apply dropout to sparse input tensors
            SPARSE_INPUT_DROPOUT: True,
            # If 'True' apply dropout to dense input tensors
            DENSE_INPUT_DROPOUT: True,
            # ## Evaluation parameters
            # How often calculate validation accuracy.
            # Small values may hurt performance.
            EVAL_NUM_EPOCHS: 20,
            # How many examples to use for hold out validation set
            # Large values may hurt performance, e.g. model accuracy.
            # Set to 0 for no validation.
            EVAL_NUM_EXAMPLES: 0,
            # ## Model config
            # If 'True' intent classification is trained and intent predicted.
            INTENT_CLASSIFICATION: True,
            # If 'True' named entity recognition is trained and entities predicted.
            ENTITY_RECOGNITION: True,
            # If 'True' random tokens of the input message will be masked and the model
            # should predict those tokens.
            MASKED_LM: False,
            # 'BILOU_flag' determines whether to use BILOU tagging or not.
            # If set to 'True' labelling is more rigorous, however more
            # examples per entity are required.
            # Rule of thumb: you should have more than 100 examples per entity.
            BILOU_FLAG: True,
            # If you want to use tensorboard to visualize training and validation
            # metrics, set this option to a valid output directory.
            TENSORBOARD_LOG_DIR: None,
            # Define when training metrics for tensorboard should be logged.
            # Either after every epoch or for every training step.
            # Valid values: 'epoch' and 'batch'
            TENSORBOARD_LOG_LEVEL: "epoch",
            # Perform model checkpointing
            CHECKPOINT_MODEL: False,
            # Specify what features to use as sequence and sentence features
            # By default all features in the pipeline are used.
            FEATURIZERS: [],
            # Split entities by comma, this makes sense e.g. for a list of ingredients
            # in a recipie, but it doesn't make sense for the parts of an address
            SPLIT_ENTITIES_BY_COMMA: True,
            # If 'True' applies sigmoid on all similarity terms and adds
            # it to the loss function to ensure that similarity values are
            # approximately bounded. Used inside cross-entropy loss only.
            CONSTRAIN_SIMILARITIES: False,
            # Model confidence to be returned during inference. Currently, the only
            # possible value is `softmax`.
            MODEL_CONFIDENCE: SOFTMAX,
            # Determines whether the confidences of the chosen top intents should be
            # renormalized so that they sum up to 1. By default, we do not renormalize
            # and return the confidences for the top intents as is.
            # Note that renormalization only makes sense if confidences are generated
            # via `softmax`.
            RENORMALIZE_CONFIDENCES: False,
            # Determines whether to construct the model graph or not.
            # This is advantageous when the model is only trained or inferred for
            # a few steps, as the compilation of the graph tends to take more time than
            # running it. It is recommended to not adjust the optimization parameter.
            RUN_EAGERLY: False,
            # Determines whether the last batch should be dropped if it contains fewer
            # than half a batch size of examples
            DROP_SMALL_LAST_BATCH: False,
        }

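The defaults above are keyed by rasa's config constants; concrete values from the user's pipeline configuration override these defaults before the component is constructed. As a rough, illustrative sketch only (plain strings stand in for the constants and the numbers are made up, not the real defaults):

    defaults = {"epochs": 300, "checkpoint_model": False, "run_eagerly": False}
    user_config = {"checkpoint_model": True}  # e.g. taken from the project's config

    component_config = {**defaults, **user_config}

    if "epochs" not in user_config:
        # mirrors the raise_warning branch at the top of __init__ below
        print("Please configure the number of 'epochs' in your configuration file.")
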
    def __init__(
        self,
        config: Dict[Text, Any],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        index_label_id_mapping: Optional[Dict[int, Text]] = None,
        entity_tag_specs: Optional[List[EntityTagSpec]] = None,
        model: Optional[RasaModel] = None,
        sparse_feature_sizes: Optional[Dict[Text, Dict[Text, List[int]]]] = None,
    ) -> None:
        """Declare instance variables with default values."""
        if EPOCHS not in config:
            rasa.shared.utils.io.raise_warning(
                f"Please configure the number of '{EPOCHS}' in your configuration file."
                f" We will change the default value of '{EPOCHS}' in the future to 1. "
            )

        self.component_config = config
        self._model_storage = model_storage
        self._resource = resource
        self._execution_context = execution_context

        self._check_config_parameters()

        # transform numbers to labels
        self.index_label_id_mapping = index_label_id_mapping or {}

        self._entity_tag_specs = entity_tag_specs

        self.model = model

        self.tmp_checkpoint_dir = None
        if self.component_config[CHECKPOINT_MODEL]:
            self.tmp_checkpoint_dir = Path(rasa.utils.io.create_temporary_directory())

        self._label_data: Optional[RasaModelData] = None
        self._data_example: Optional[Dict[Text, Dict[Text, List[FeatureArray]]]] = None

        self.split_entities_config = rasa.utils.train_utils.init_split_entities(
            self.component_config[SPLIT_ENTITIES_BY_COMMA],
            SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
        )

        self.finetune_mode = self._execution_context.is_finetuning
        self._sparse_feature_sizes = sparse_feature_sizes

    # init helpers
    def _check_masked_lm(self) -> None:
        if (
            self.component_config[MASKED_LM]
            and self.component_config[NUM_TRANSFORMER_LAYERS] == 0
        ):
            raise ValueError(
                f"If number of transformer layers is 0, "
                f"'{MASKED_LM}' option should be 'False'."
            )

    def _check_share_hidden_layers_sizes(self) -> None:
        if self.component_config.get(SHARE_HIDDEN_LAYERS):
            first_hidden_layer_sizes = next(
                iter(self.component_config[HIDDEN_LAYERS_SIZES].values())
            )
            # check that all hidden layer sizes are the same
            identical_hidden_layer_sizes = all(
                current_hidden_layer_sizes == first_hidden_layer_sizes
                for current_hidden_layer_sizes in self.component_config[
                    HIDDEN_LAYERS_SIZES
                ].values()
            )
            if not identical_hidden_layer_sizes:
                raise ValueError(
                    f"If hidden layer weights are shared, "
                    f"{HIDDEN_LAYERS_SIZES} must coincide."
                )

    def _check_config_parameters(self) -> None:
        self.component_config = train_utils.check_deprecated_options(
            self.component_config
        )

        self._check_masked_lm()
        self._check_share_hidden_layers_sizes()

        self.component_config = train_utils.update_confidence_type(
            self.component_config
        )

        train_utils.validate_configuration_settings(self.component_config)

        self.component_config = train_utils.update_similarity_type(
            self.component_config
        )
        self.component_config = train_utils.update_evaluation_parameters(
            self.component_config
        )

    @classmethod
    def create(
        cls,
        config: Dict[Text, Any],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ) -> DIETClassifier:
        """Creates a new untrained component (see parent class for full docstring)."""
        return cls(config, model_storage, resource, execution_context)

    @property
    def label_key(self) -> Optional[Text]:
        """Return key if intent classification is activated."""
        return LABEL_KEY if self.component_config[INTENT_CLASSIFICATION] else None

    @property
    def label_sub_key(self) -> Optional[Text]:
        """Return sub key if intent classification is activated."""
        return LABEL_SUB_KEY if self.component_config[INTENT_CLASSIFICATION] else None

    @staticmethod
    def model_class() -> Type[RasaModel]:
        return DIET

    # training data helpers:
    @staticmethod
    def _label_id_index_mapping(
        training_data: TrainingData, attribute: Text
    ) -> Dict[Text, int]:
        """Create label_id dictionary."""
        distinct_label_ids = {
            example.get(attribute) for example in training_data.intent_examples
        } - {None}
        return {
            label_id: idx for idx, label_id in enumerate(sorted(distinct_label_ids))
        }

    @staticmethod
    def _invert_mapping(mapping: Dict) -> Dict:
        return {value: key for key, value in mapping.items()}

    def _create_entity_tag_specs(
        self, training_data: TrainingData
    ) -> List[EntityTagSpec]:
        """Create entity tag specifications with their respective tag id mappings."""
        _tag_specs = []

        for tag_name in POSSIBLE_TAGS:
            if self.component_config[BILOU_FLAG]:
                tag_id_index_mapping = bilou_utils.build_tag_id_dict(
                    training_data, tag_name
                )
            else:
                tag_id_index_mapping = self._tag_id_index_mapping_for(
                    tag_name, training_data
                )

            if tag_id_index_mapping:
                _tag_specs.append(
                    EntityTagSpec(
                        tag_name=tag_name,
                        tags_to_ids=tag_id_index_mapping,
                        ids_to_tags=self._invert_mapping(tag_id_index_mapping),
                        num_tags=len(tag_id_index_mapping),
                    )
                )

        return _tag_specs

    @staticmethod
    def _tag_id_index_mapping_for(
        tag_name: Text, training_data: TrainingData
    ) -> Optional[Dict[Text, int]]:
        """Create mapping from tag name to id."""
        if tag_name == ENTITY_ATTRIBUTE_ROLE:
            distinct_tags = training_data.entity_roles
        elif tag_name == ENTITY_ATTRIBUTE_GROUP:
            distinct_tags = training_data.entity_groups
        else:
            distinct_tags = training_data.entities

        distinct_tags = distinct_tags - {NO_ENTITY_TAG} - {None}

        if not distinct_tags:
            return None

        tag_id_dict = {
            tag_id: idx for idx, tag_id in enumerate(sorted(distinct_tags), 1)
        }
        # NO_ENTITY_TAG corresponds to non-entity which should correspond to 0 index
        # needed for correct prediction for padding
        tag_id_dict[NO_ENTITY_TAG] = 0

        return tag_id_dict

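To make the mapping built above concrete, a toy run with made-up entity types (assuming the conventional "O" value of NO_ENTITY_TAG):

    distinct_tags = {"city", "cuisine", "price"}
    tag_id_dict = {tag: idx for idx, tag in enumerate(sorted(distinct_tags), 1)}
    tag_id_dict["O"] = 0  # no-entity tag pinned to index 0, as required for padding
    assert tag_id_dict == {"city": 1, "cuisine": 2, "price": 3, "O": 0}
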
    @staticmethod
    def _find_example_for_label(
        label: Text, examples: List[Message], attribute: Text
    ) -> Optional[Message]:
        for ex in examples:
            if ex.get(attribute) == label:
                return ex
        return None

    def _check_labels_features_exist(
        self, labels_example: List[Message], attribute: Text
    ) -> bool:
        """Checks if all labels have features set."""
        return all(
            label_example.features_present(
                attribute, self.component_config[FEATURIZERS]
            )
            for label_example in labels_example
        )

    def _extract_features(
        self, message: Message, attribute: Text
    ) -> Dict[Text, Union[scipy.sparse.spmatrix, np.ndarray]]:
        (
            sparse_sequence_features,
            sparse_sentence_features,
        ) = message.get_sparse_features(attribute, self.component_config[FEATURIZERS])
        dense_sequence_features, dense_sentence_features = message.get_dense_features(
            attribute, self.component_config[FEATURIZERS]
        )

        if dense_sequence_features is not None and sparse_sequence_features is not None:
            if (
                dense_sequence_features.features.shape[0]
                != sparse_sequence_features.features.shape[0]
            ):
                raise ValueError(
                    f"Sequence dimensions for sparse and dense sequence features "
                    f"don't coincide in '{message.get(TEXT)}'"
                    f"for attribute '{attribute}'."
                )
        if dense_sentence_features is not None and sparse_sentence_features is not None:
            if (
                dense_sentence_features.features.shape[0]
                != sparse_sentence_features.features.shape[0]
            ):
                raise ValueError(
                    f"Sequence dimensions for sparse and dense sentence features "
                    f"don't coincide in '{message.get(TEXT)}'"
                    f"for attribute '{attribute}'."
                )

        # If we don't use the transformer and we don't want to do entity recognition,
        # to speed up training take only the sentence features as feature vector.
        # We would not make use of the sequence anyway in this setup. Carrying over
        # those features to the actual training process takes quite some time.
        if (
            self.component_config[NUM_TRANSFORMER_LAYERS] == 0
            and not self.component_config[ENTITY_RECOGNITION]
            and attribute not in [INTENT, INTENT_RESPONSE_KEY]
        ):
            sparse_sequence_features = None
            dense_sequence_features = None

        out = {}

        if sparse_sentence_features is not None:
            out[f"{SPARSE}_{SENTENCE}"] = sparse_sentence_features.features
        if sparse_sequence_features is not None:
            out[f"{SPARSE}_{SEQUENCE}"] = sparse_sequence_features.features
        if dense_sentence_features is not None:
            out[f"{DENSE}_{SENTENCE}"] = dense_sentence_features.features
        if dense_sequence_features is not None:
            out[f"{DENSE}_{SEQUENCE}"] = dense_sequence_features.features

        return out

    def _check_input_dimension_consistency(self, model_data: RasaModelData) -> None:
        """Checks if features have same dimensionality if hidden layers are shared."""
        if self.component_config.get(SHARE_HIDDEN_LAYERS):
            num_text_sentence_features = model_data.number_of_units(TEXT, SENTENCE)
            num_label_sentence_features = model_data.number_of_units(LABEL, SENTENCE)
            num_text_sequence_features = model_data.number_of_units(TEXT, SEQUENCE)
            num_label_sequence_features = model_data.number_of_units(LABEL, SEQUENCE)

            if (0 < num_text_sentence_features != num_label_sentence_features > 0) or (
                0 < num_text_sequence_features != num_label_sequence_features > 0
            ):
                raise ValueError(
                    "If embeddings are shared text features and label features "
                    "must coincide. Check the output dimensions of previous components."
                )

    def _extract_labels_precomputed_features(
        self, label_examples: List[Message], attribute: Text = INTENT
    ) -> Tuple[List[FeatureArray], List[FeatureArray]]:
        """Collects precomputed encodings."""
        features = defaultdict(list)

        for e in label_examples:
            label_features = self._extract_features(e, attribute)
            for feature_key, feature_value in label_features.items():
                features[feature_key].append(feature_value)
        sequence_features = []
        sentence_features = []
        for feature_name, feature_value in features.items():
            if SEQUENCE in feature_name:
                sequence_features.append(
                    FeatureArray(np.array(feature_value), number_of_dimensions=3)
                )
            else:
                sentence_features.append(
                    FeatureArray(np.array(feature_value), number_of_dimensions=3)
                )
        return sequence_features, sentence_features

    @staticmethod
    def _compute_default_label_features(
        labels_example: List[Message],
    ) -> List[FeatureArray]:
        """Computes one-hot representation for the labels."""
        logger.debug("No label features found. Computing default label features.")

        eye_matrix = np.eye(len(labels_example), dtype=np.float32)
        # add sequence dimension to one-hot labels
        return [
            FeatureArray(
                np.array([np.expand_dims(a, 0) for a in eye_matrix]),
                number_of_dimensions=3,
            )
        ]

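A quick, self-contained shape check of the default (one-hot) label features computed above; values are illustrative only:

    import numpy as np

    num_labels = 3
    eye_matrix = np.eye(num_labels, dtype=np.float32)
    default_features = np.array([np.expand_dims(row, 0) for row in eye_matrix])
    # one feature matrix of shape (num_labels, 1, num_labels): one-hot rows with an
    # extra sequence dimension of length 1
    assert default_features.shape == (3, 1, 3)
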
    def _create_label_data(
        self,
        training_data: TrainingData,
        label_id_dict: Dict[Text, int],
        attribute: Text,
    ) -> RasaModelData:
        """Create matrix with label_ids encoded in rows as bag of words.

        Find a training example for each label and get the encoded features
        from the corresponding Message object.
        If the features are already computed, fetch them from the message object
        else compute a one hot encoding for the label as the feature vector.
        """
        # Collect one example for each label
        labels_idx_examples = []
        for label_name, idx in label_id_dict.items():
            label_example = self._find_example_for_label(
                label_name, training_data.intent_examples, attribute
            )
            labels_idx_examples.append((idx, label_example))

        # Sort the list of tuples based on label_idx
        labels_idx_examples = sorted(labels_idx_examples, key=lambda x: x[0])
        labels_example = [example for (_, example) in labels_idx_examples]
        # Collect features, precomputed if they exist, else compute on the fly
        if self._check_labels_features_exist(labels_example, attribute):
            (
                sequence_features,
                sentence_features,
            ) = self._extract_labels_precomputed_features(labels_example, attribute)
        else:
            sequence_features = None
            sentence_features = self._compute_default_label_features(labels_example)

        label_data = RasaModelData()
        label_data.add_features(LABEL, SEQUENCE, sequence_features)
        label_data.add_features(LABEL, SENTENCE, sentence_features)
        if label_data.does_feature_not_exist(
            LABEL, SENTENCE
        ) and label_data.does_feature_not_exist(LABEL, SEQUENCE):
            raise ValueError(
                "No label features are present. Please check your configuration file."
            )

        label_ids = np.array([idx for (idx, _) in labels_idx_examples])
        # explicitly add last dimension to label_ids
        # to track correctly dynamic sequences
        label_data.add_features(
            LABEL_KEY,
            LABEL_SUB_KEY,
            [
                FeatureArray(
                    np.expand_dims(label_ids, -1),
                    number_of_dimensions=2,
                )
            ],
        )

        label_data.add_lengths(LABEL, SEQUENCE_LENGTH, LABEL, SEQUENCE)

        return label_data

    def _use_default_label_features(self, label_ids: np.ndarray) -> List[FeatureArray]:
        if self._label_data is None:
            return []

        feature_arrays = self._label_data.get(LABEL, SENTENCE)
        all_label_features = feature_arrays[0]
        return [
            FeatureArray(
                np.array([all_label_features[label_id] for label_id in label_ids]),
                number_of_dimensions=all_label_features.number_of_dimensions,
            )
        ]

    def _create_model_data(
        self,
        training_data: List[Message],
        label_id_dict: Optional[Dict[Text, int]] = None,
        label_attribute: Optional[Text] = None,
        training: bool = True,
    ) -> RasaModelData:
        """Prepare data for training and create a RasaModelData object."""
        from rasa.utils.tensorflow import model_data_utils

        attributes_to_consider = [TEXT]
        if training and self.component_config[INTENT_CLASSIFICATION]:
            # we don't have any intent labels during prediction, just add them during
            # training
            attributes_to_consider.append(label_attribute)
        if (
            training
            and self.component_config[ENTITY_RECOGNITION]
            and self._entity_tag_specs
        ):
            # Add entities as labels only during training and only if there was
            # training data added for entities with DIET configured to predict entities.
            attributes_to_consider.append(ENTITIES)

        if training and label_attribute is not None:
            # only use those training examples that have the label_attribute set
            # during training
            training_data = [
                example for example in training_data if label_attribute in example.data
            ]

        training_data = [
            message
            for message in training_data
            if message.features_present(
                attribute=TEXT, featurizers=self.component_config.get(FEATURIZERS)
            )
        ]

        if not training_data:
            # no training data are present to train
            return RasaModelData()

        (
            features_for_examples,
            sparse_feature_sizes,
        ) = model_data_utils.featurize_training_examples(
            training_data,
            attributes_to_consider,
            entity_tag_specs=self._entity_tag_specs,
            featurizers=self.component_config[FEATURIZERS],
            bilou_tagging=self.component_config[BILOU_FLAG],
        )
        attribute_data, _ = model_data_utils.convert_to_data_format(
            features_for_examples, consider_dialogue_dimension=False
        )

        model_data = RasaModelData(
            label_key=self.label_key, label_sub_key=self.label_sub_key
        )
        model_data.add_data(attribute_data)
        model_data.add_lengths(TEXT, SEQUENCE_LENGTH, TEXT, SEQUENCE)
        # Current implementation doesn't yet account for updating sparse
        # feature sizes of label attributes. That's why we remove them.
        sparse_feature_sizes = self._remove_label_sparse_feature_sizes(
            sparse_feature_sizes=sparse_feature_sizes, label_attribute=label_attribute
        )
        model_data.add_sparse_feature_sizes(sparse_feature_sizes)

        self._add_label_features(
            model_data, training_data, label_attribute, label_id_dict, training
        )

        # make sure all keys are in the same order during training and prediction
        # as we rely on the order of key and sub-key when constructing the actual
        # tensors from the model data
        model_data.sort()

        return model_data

    @staticmethod
    def _remove_label_sparse_feature_sizes(
        sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
        label_attribute: Optional[Text] = None,
    ) -> Dict[Text, Dict[Text, List[int]]]:
        if label_attribute in sparse_feature_sizes:
            del sparse_feature_sizes[label_attribute]
        return sparse_feature_sizes

    def _add_label_features(
        self,
        model_data: RasaModelData,
        training_data: List[Message],
        label_attribute: Text,
        label_id_dict: Dict[Text, int],
        training: bool = True,
    ) -> None:
        label_ids = []
        if training and self.component_config[INTENT_CLASSIFICATION]:
            for example in training_data:
                if example.get(label_attribute):
                    label_ids.append(label_id_dict[example.get(label_attribute)])
            # explicitly add last dimension to label_ids
            # to track correctly dynamic sequences
            model_data.add_features(
                LABEL_KEY,
                LABEL_SUB_KEY,
                [
                    FeatureArray(
                        np.expand_dims(label_ids, -1),
                        number_of_dimensions=2,
                    )
                ],
            )

        if (
            label_attribute
            and model_data.does_feature_not_exist(label_attribute, SENTENCE)
            and model_data.does_feature_not_exist(label_attribute, SEQUENCE)
        ):
            # no label features are present, get default features from _label_data
            model_data.add_features(
                LABEL, SENTENCE, self._use_default_label_features(np.array(label_ids))
            )

        # as label_attribute can have different values, e.g. INTENT or RESPONSE,
        # copy over the features to the LABEL key to make
        # it easier to access the label features inside the model itself
        model_data.update_key(label_attribute, SENTENCE, LABEL, SENTENCE)
        model_data.update_key(label_attribute, SEQUENCE, LABEL, SEQUENCE)
        model_data.update_key(label_attribute, MASK, LABEL, MASK)

        model_data.add_lengths(LABEL, SEQUENCE_LENGTH, LABEL, SEQUENCE)

    # train helpers
    def preprocess_train_data(self, training_data: TrainingData) -> RasaModelData:
        """Prepares data for training.

        Performs sanity checks on training data, extracts encodings for labels.
        """
        if (
            self.component_config[BILOU_FLAG]
            and self.component_config[ENTITY_RECOGNITION]
        ):
            bilou_utils.apply_bilou_schema(training_data)

        label_id_index_mapping = self._label_id_index_mapping(
            training_data, attribute=INTENT
        )

        if not label_id_index_mapping:
            # no labels are present to train
            return RasaModelData()

        self.index_label_id_mapping = self._invert_mapping(label_id_index_mapping)

        self._label_data = self._create_label_data(
            training_data, label_id_index_mapping, attribute=INTENT
        )

        self._entity_tag_specs = self._create_entity_tag_specs(training_data)

        label_attribute = (
            INTENT if self.component_config[INTENT_CLASSIFICATION] else None
        )
        model_data = self._create_model_data(
            training_data.nlu_examples,
            label_id_index_mapping,
            label_attribute=label_attribute,
        )

        self._check_input_dimension_consistency(model_data)

        return model_data

    @staticmethod
    def _check_enough_labels(model_data: RasaModelData) -> bool:
        return len(np.unique(model_data.get(LABEL_KEY, LABEL_SUB_KEY))) >= 2

    def train(self, training_data: TrainingData) -> Resource:
        """Train the embedding intent classifier on a data set."""
        model_data = self.preprocess_train_data(training_data)
        if model_data.is_empty():
            logger.debug(
                f"Cannot train '{self.__class__.__name__}'. No data was provided. "
                f"Skipping training of the classifier."
            )
            return self._resource

        if not self.model and self.finetune_mode:
            raise rasa.shared.exceptions.InvalidParameterException(
                f"{self.__class__.__name__} was instantiated "
                f"with `model=None` and `finetune_mode=True`. "
                f"This is not a valid combination as the component "
                f"needs an already instantiated and trained model "
                f"to continue training in finetune mode."
            )

        if self.component_config.get(INTENT_CLASSIFICATION):
            if not self._check_enough_labels(model_data):
                logger.error(
                    f"Cannot train '{self.__class__.__name__}'. "
                    f"Need at least 2 different intent classes. "
                    f"Skipping training of classifier."
                )
                return self._resource
        if self.component_config.get(ENTITY_RECOGNITION):
            self.check_correct_entity_annotations(training_data)

        # keep one example for persisting and loading
        self._data_example = model_data.first_data_example()

        if not self.finetune_mode:
            # No pre-trained model to load from. Create a new instance of the model.
            self.model = self._instantiate_model_class(model_data)
            self.model.compile(
                optimizer=tf.keras.optimizers.Adam(
                    self.component_config[LEARNING_RATE]
                ),
                run_eagerly=self.component_config[RUN_EAGERLY],
            )
        else:
            if self.model is None:
                raise ModelNotFound("Model could not be found. ")

            self.model.adjust_for_incremental_training(
                data_example=self._data_example,
                new_sparse_feature_sizes=model_data.get_sparse_feature_sizes(),
                old_sparse_feature_sizes=self._sparse_feature_sizes,
            )
        self._sparse_feature_sizes = model_data.get_sparse_feature_sizes()

        data_generator, validation_data_generator = train_utils.create_data_generators(
            model_data,
            self.component_config[BATCH_SIZES],
            self.component_config[EPOCHS],
            self.component_config[BATCH_STRATEGY],
            self.component_config[EVAL_NUM_EXAMPLES],
            self.component_config[RANDOM_SEED],
            drop_small_last_batch=self.component_config[DROP_SMALL_LAST_BATCH],
        )
        callbacks = train_utils.create_common_callbacks(
            self.component_config[EPOCHS],
            self.component_config[TENSORBOARD_LOG_DIR],
            self.component_config[TENSORBOARD_LOG_LEVEL],
            self.tmp_checkpoint_dir,
        )

        self.model.fit(
            data_generator,
            epochs=self.component_config[EPOCHS],
            validation_data=validation_data_generator,
            validation_freq=self.component_config[EVAL_NUM_EPOCHS],
            callbacks=callbacks,
            verbose=False,
            shuffle=False,  # we use custom shuffle inside data generator
        )

        self.persist()

        return self._resource

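The "enough labels" guard used in train() above can be illustrated in isolation; the array below is a made-up stand-in for the label ids stored under LABEL_KEY/LABEL_SUB_KEY:

    import numpy as np

    label_ids = np.array([[0], [0], [1]])  # three examples, two distinct intent ids
    enough_labels = len(np.unique(label_ids)) >= 2
    assert enough_labels  # with fewer than 2 classes, training is skipped
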
    # process helpers
    def _predict(
        self, message: Message
    ) -> Optional[Dict[Text, Union[tf.Tensor, Dict[Text, tf.Tensor]]]]:
        if self.model is None:
            logger.debug(
                f"There is no trained model for '{self.__class__.__name__}': The "
                f"component is either not trained or didn't receive enough training "
                f"data."
            )
            return None

        # create session data from message and convert it into a batch of 1
        model_data = self._create_model_data([message], training=False)
        if model_data.is_empty():
            return None
        return self.model.run_inference(model_data)

    def _predict_label(
        self, predict_out: Optional[Dict[Text, tf.Tensor]]
    ) -> Tuple[Dict[Text, Any], List[Dict[Text, Any]]]:
        """Predicts the intent of the provided message."""
        label: Dict[Text, Any] = {"name": None, "confidence": 0.0}
        label_ranking: List[Dict[Text, Any]] = []

        if predict_out is None:
            return label, label_ranking

        message_sim = predict_out["i_scores"]
        message_sim = message_sim.flatten()  # sim is a matrix

        # if X contains all zeros do not predict some label
        if message_sim.size == 0:
            return label, label_ranking

        # rank the confidences
        ranking_length = self.component_config[RANKING_LENGTH]
        renormalize = (
            self.component_config[RENORMALIZE_CONFIDENCES]
            and self.component_config[MODEL_CONFIDENCE] == SOFTMAX
        )
        ranked_label_indices, message_sim = train_utils.rank_and_mask(
            message_sim, ranking_length=ranking_length, renormalize=renormalize
        )

        # construct the label and ranking
        casted_message_sim: List[float] = message_sim.tolist()  # np.float to float
        top_label_idx = ranked_label_indices[0]
        label = {
            "name": self.index_label_id_mapping[top_label_idx],
            "confidence": casted_message_sim[top_label_idx],
        }

        ranking = [(idx, casted_message_sim[idx]) for idx in ranked_label_indices]
        label_ranking = [
            {"name": self.index_label_id_mapping[label_idx], "confidence": score}
            for label_idx, score in ranking
        ]

        return label, label_ranking

    def _predict_entities(
        self, predict_out: Optional[Dict[Text, tf.Tensor]], message: Message
    ) -> List[Dict]:
        if predict_out is None:
            return []

        predicted_tags, confidence_values = train_utils.entity_label_to_tags(
            predict_out, self._entity_tag_specs, self.component_config[BILOU_FLAG]
        )

        entities = self.convert_predictions_into_entities(
            message.get(TEXT),
            message.get(TOKENS_NAMES[TEXT], []),
            predicted_tags,
            self.split_entities_config,
            confidence_values,
        )

        entities = self.add_extractor_name(entities)
        entities = message.get(ENTITIES, []) + entities

        return entities

    def process(self, messages: List[Message]) -> List[Message]:
        """Augments the message with intents, entities, and diagnostic data."""
        for message in messages:
            out = self._predict(message)

            if self.component_config[INTENT_CLASSIFICATION]:
                label, label_ranking = self._predict_label(out)

                message.set(INTENT, label, add_to_output=True)
                message.set("intent_ranking", label_ranking, add_to_output=True)

            if self.component_config[ENTITY_RECOGNITION]:
                entities = self._predict_entities(out, message)

                message.set(ENTITIES, entities, add_to_output=True)

            if out and self._execution_context.should_add_diagnostic_data:
                message.add_diagnostic_data(
                    self._execution_context.node_name, out.get(DIAGNOSTIC_DATA)
                )

        return messages

    def persist(self) -> None:
        """Persist this model into the passed directory."""
        if self.model is None:
            return None

        with self._model_storage.write_to(self._resource) as model_path:
            file_name = self.__class__.__name__
            tf_model_file = model_path / f"{file_name}.tf_model"

            rasa.shared.utils.io.create_directory_for_file(tf_model_file)

            if self.component_config[CHECKPOINT_MODEL] and self.tmp_checkpoint_dir:
                self.model.load_weights(self.tmp_checkpoint_dir / "checkpoint.tf_model")
                # Save an empty file to flag that this model has been
                # produced using checkpointing
                checkpoint_marker = model_path / f"{file_name}.from_checkpoint.pkl"
                checkpoint_marker.touch()

            self.model.save(str(tf_model_file))

            # save data example
            serialize_nested_feature_arrays(
                self._data_example,
                model_path / f"{file_name}.data_example.st",
                model_path / f"{file_name}.data_example_metadata.json",
            )
            # save label data
            serialize_nested_feature_arrays(
                dict(self._label_data.data) if self._label_data is not None else {},
                model_path / f"{file_name}.label_data.st",
                model_path / f"{file_name}.label_data_metadata.json",
            )

            rasa.shared.utils.io.dump_obj_as_json_to_file(
                model_path / f"{file_name}.sparse_feature_sizes.json",
                self._sparse_feature_sizes,
            )
            rasa.shared.utils.io.dump_obj_as_json_to_file(
                model_path / f"{file_name}.index_label_id_mapping.json",
                self.index_label_id_mapping,
            )

            entity_tag_specs = (
                [tag_spec._asdict() for tag_spec in self._entity_tag_specs]
                if self._entity_tag_specs
                else []
            )
            rasa.shared.utils.io.dump_obj_as_json_to_file(
                model_path / f"{file_name}.entity_tag_specs.json", entity_tag_specs
            )

    @classmethod
    def load(
        cls: Type[DIETClassifierT],
        config: Dict[Text, Any],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        **kwargs: Any,
    ) -> DIETClassifierT:
        """Loads a policy from the storage (see parent class for full docstring)."""
        try:
            with model_storage.read_from(resource) as model_path:
                return cls._load(
                    model_path, config, model_storage, resource, execution_context
                )
        except ValueError:
            logger.debug(
                f"Failed to load {cls.__class__.__name__} from model storage. Resource "
                f"'{resource.name}' doesn't exist."
            )
            return cls(config, model_storage, resource, execution_context)

    @classmethod
    def _load(
        cls: Type[DIETClassifierT],
        model_path: Path,
        config: Dict[Text, Any],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ) -> DIETClassifierT:
        """Loads the trained model from the provided directory."""
        (
            index_label_id_mapping,
            entity_tag_specs,
            label_data,
            data_example,
            sparse_feature_sizes,
        ) = cls._load_from_files(model_path)

        config = train_utils.update_confidence_type(config)
        config = train_utils.update_similarity_type(config)

        model = cls._load_model(
            entity_tag_specs,
            label_data,
            config,
            data_example,
            model_path,
            finetune_mode=execution_context.is_finetuning,
        )

        return cls(
            config=config,
            model_storage=model_storage,
            resource=resource,
            execution_context=execution_context,
            index_label_id_mapping=index_label_id_mapping,
            entity_tag_specs=entity_tag_specs,
            model=model,
            sparse_feature_sizes=sparse_feature_sizes,
        )

    @classmethod
    def _load_from_files(
        cls, model_path: Path
    ) -> Tuple[
        Dict[int, Text],
        List[EntityTagSpec],
        RasaModelData,
        Dict[Text, Dict[Text, List[FeatureArray]]],
        Dict[Text, Dict[Text, List[int]]],
    ]:
        file_name = cls.__name__

        # load data example
        data_example = deserialize_nested_feature_arrays(
            str(model_path / f"{file_name}.data_example.st"),
            str(model_path / f"{file_name}.data_example_metadata.json"),
        )
        # load label data
        loaded_label_data = deserialize_nested_feature_arrays(
            str(model_path / f"{file_name}.label_data.st"),
            str(model_path / f"{file_name}.label_data_metadata.json"),
        )
        label_data = RasaModelData(data=loaded_label_data)

        sparse_feature_sizes = rasa.shared.utils.io.read_json_file(
            model_path / f"{file_name}.sparse_feature_sizes.json"
        )
        index_label_id_mapping = rasa.shared.utils.io.read_json_file(
            model_path / f"{file_name}.index_label_id_mapping.json"
        )
        entity_tag_specs = rasa.shared.utils.io.read_json_file(
            model_path / f"{file_name}.entity_tag_specs.json"
        )
        entity_tag_specs = [
            EntityTagSpec(
                tag_name=tag_spec["tag_name"],
                ids_to_tags={
                    int(key): value for key, value in tag_spec["ids_to_tags"].items()
                },
                tags_to_ids={
                    key: int(value) for key, value in tag_spec["tags_to_ids"].items()
                },
                num_tags=tag_spec["num_tags"],
            )
            for tag_spec in entity_tag_specs
        ]

        index_label_id_mapping = {
            int(key): value for key, value in index_label_id_mapping.items()
        }

        return (
            index_label_id_mapping,
            entity_tag_specs,
            label_data,
            data_example,
            sparse_feature_sizes,
        )

    @classmethod
    def _load_model(
        cls,
        entity_tag_specs: List[EntityTagSpec],
        label_data: RasaModelData,
        config: Dict[Text, Any],
        data_example: Dict[Text, Dict[Text, List[FeatureArray]]],
        model_path: Path,
        finetune_mode: bool = False,
    ) -> "RasaModel":
        file_name = cls.__name__
        tf_model_file = model_path / f"{file_name}.tf_model"

        label_key = LABEL_KEY if config[INTENT_CLASSIFICATION] else None
        label_sub_key = LABEL_SUB_KEY if config[INTENT_CLASSIFICATION] else None

        model_data_example = RasaModelData(
            label_key=label_key, label_sub_key=label_sub_key, data=data_example
        )

        model = cls._load_model_class(
            tf_model_file,
            model_data_example,
            label_data,
            entity_tag_specs,
            config,
            finetune_mode=finetune_mode,
        )

        return model

    @classmethod
    def _load_model_class(
        cls,
        tf_model_file: Text,
        model_data_example: RasaModelData,
        label_data: RasaModelData,
        entity_tag_specs: List[EntityTagSpec],
        config: Dict[Text, Any],
        finetune_mode: bool,
    ) -> "RasaModel":
        predict_data_example = RasaModelData(
            label_key=model_data_example.label_key,
            data={
                feature_name: features
                for feature_name, features in model_data_example.items()
                if TEXT in feature_name
            },
        )

        return cls.model_class().load(
            tf_model_file,
            model_data_example,
            predict_data_example,
            data_signature=model_data_example.get_signature(),
            label_data=label_data,
            entity_tag_specs=entity_tag_specs,
            config=copy.deepcopy(config),
            finetune_mode=finetune_mode,
        )

    def _instantiate_model_class(self, model_data: RasaModelData) -> "RasaModel":
        return self.model_class()(
            data_signature=model_data.get_signature(),
            label_data=self._label_data,
            entity_tag_specs=self._entity_tag_specs,
            config=self.component_config,
        )

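To make the prediction output assembled by _predict_label above concrete, a small illustration with made-up similarity scores (np.argsort stands in for train_utils.rank_and_mask; no renormalization):

    import numpy as np

    index_label_id_mapping = {0: "greet", 1: "goodbye", 2: "affirm"}
    message_sim = np.array([0.7, 0.2, 0.1])

    ranked = np.argsort(-message_sim)
    label = {
        "name": index_label_id_mapping[int(ranked[0])],
        "confidence": float(message_sim[ranked[0]]),
    }
    intent_ranking = [
        {"name": index_label_id_mapping[int(i)], "confidence": float(message_sim[i])}
        for i in ranked
    ]
    assert label == {"name": "greet", "confidence": 0.7}
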
class DIET(TransformerRasaModel):
|
|
1311
|
+
def __init__(
|
|
1312
|
+
self,
|
|
1313
|
+
data_signature: Dict[Text, Dict[Text, List[FeatureSignature]]],
|
|
1314
|
+
label_data: RasaModelData,
|
|
1315
|
+
entity_tag_specs: Optional[List[EntityTagSpec]],
|
|
1316
|
+
config: Dict[Text, Any],
|
|
1317
|
+
) -> None:
|
|
1318
|
+
# create entity tag spec before calling super otherwise building the model
|
|
1319
|
+
# will fail
|
|
1320
|
+
super().__init__("DIET", config, data_signature, label_data)
|
|
1321
|
+
self._entity_tag_specs = self._ordered_tag_specs(entity_tag_specs)
|
|
1322
|
+
|
|
1323
|
+
self.predict_data_signature = {
|
|
1324
|
+
feature_name: features
|
|
1325
|
+
for feature_name, features in data_signature.items()
|
|
1326
|
+
if TEXT in feature_name
|
|
1327
|
+
}
|
|
1328
|
+
|
|
1329
|
+
# tf training
|
|
1330
|
+
self._create_metrics()
|
|
1331
|
+
self._update_metrics_to_log()
|
|
1332
|
+
|
|
1333
|
+
# needed for efficient prediction
|
|
1334
|
+
self.all_labels_embed: Optional[tf.Tensor] = None
|
|
1335
|
+
|
|
1336
|
+
self._prepare_layers()
|
|
1337
|
+
|
|
1338
|
+
@staticmethod
|
|
1339
|
+
def _ordered_tag_specs(
|
|
1340
|
+
entity_tag_specs: Optional[List[EntityTagSpec]],
|
|
1341
|
+
) -> List[EntityTagSpec]:
|
|
1342
|
+
"""Ensure that order of entity tag specs matches CRF layer order."""
|
|
1343
|
+
if entity_tag_specs is None:
|
|
1344
|
+
return []
|
|
1345
|
+
|
|
1346
|
+
crf_order = [
|
|
1347
|
+
ENTITY_ATTRIBUTE_TYPE,
|
|
1348
|
+
ENTITY_ATTRIBUTE_ROLE,
|
|
1349
|
+
ENTITY_ATTRIBUTE_GROUP,
|
|
1350
|
+
]
|
|
1351
|
+
|
|
1352
|
+
ordered_tag_spec = []
|
|
1353
|
+
|
|
1354
|
+
for tag_name in crf_order:
|
|
1355
|
+
for tag_spec in entity_tag_specs:
|
|
1356
|
+
if tag_name == tag_spec.tag_name:
|
|
1357
|
+
ordered_tag_spec.append(tag_spec)
|
|
1358
|
+
|
|
1359
|
+
return ordered_tag_spec
|
|
1360
|
+
|
|
1361
|
+
def _check_data(self) -> None:
|
|
1362
|
+
if TEXT not in self.data_signature:
|
|
1363
|
+
raise InvalidConfigException(
|
|
1364
|
+
f"No text features specified. "
|
|
1365
|
+
f"Cannot train '{self.__class__.__name__}' model."
|
|
1366
|
+
)
|
|
1367
|
+
if self.config[INTENT_CLASSIFICATION]:
|
|
1368
|
+
if LABEL not in self.data_signature:
|
|
1369
|
+
raise InvalidConfigException(
|
|
1370
|
+
f"No label features specified. "
|
|
1371
|
+
f"Cannot train '{self.__class__.__name__}' model."
|
|
1372
|
+
)
|
|
1373
|
+
|
|
1374
|
+
if self.config[SHARE_HIDDEN_LAYERS]:
|
|
1375
|
+
different_sentence_signatures = False
|
|
1376
|
+
different_sequence_signatures = False
|
|
1377
|
+
if (
|
|
1378
|
+
SENTENCE in self.data_signature[TEXT]
|
|
1379
|
+
and SENTENCE in self.data_signature[LABEL]
|
|
1380
|
+
):
|
|
1381
|
+
different_sentence_signatures = (
|
|
1382
|
+
self.data_signature[TEXT][SENTENCE]
|
|
1383
|
+
!= self.data_signature[LABEL][SENTENCE]
|
|
1384
|
+
)
|
|
1385
|
+
if (
|
|
1386
|
+
SEQUENCE in self.data_signature[TEXT]
|
|
1387
|
+
and SEQUENCE in self.data_signature[LABEL]
|
|
1388
|
+
):
|
|
1389
|
+
different_sequence_signatures = (
|
|
1390
|
+
self.data_signature[TEXT][SEQUENCE]
|
|
1391
|
+
!= self.data_signature[LABEL][SEQUENCE]
|
|
1392
|
+
)
|
|
1393
|
+
|
|
1394
|
+
if different_sentence_signatures or different_sequence_signatures:
|
|
1395
|
+
raise ValueError(
|
|
1396
|
+
"If hidden layer weights are shared, data signatures "
|
|
1397
|
+
"for text_features and label_features must coincide."
|
|
1398
|
+
)
|
|
1399
|
+
|
|
1400
|
+
if self.config[ENTITY_RECOGNITION] and (
|
|
1401
|
+
ENTITIES not in self.data_signature
|
|
1402
|
+
or ENTITY_ATTRIBUTE_TYPE not in self.data_signature[ENTITIES]
|
|
1403
|
+
):
|
|
1404
|
+
logger.debug(
|
|
1405
|
+
f"You specified '{self.__class__.__name__}' to train entities, but "
|
|
1406
|
+
f"no entities are present in the training data. Skipping training of "
|
|
1407
|
+
f"entities."
|
|
1408
|
+
)
|
|
1409
|
+
self.config[ENTITY_RECOGNITION] = False
|
|
1410
|
+
|
|
1411
|
+
def _create_metrics(self) -> None:
|
|
1412
|
+
# self.metrics will have the same order as they are created
|
|
1413
|
+
# so create loss metrics first to output losses first
|
|
1414
|
+
self.mask_loss = tf.keras.metrics.Mean(name="m_loss")
|
|
1415
|
+
self.intent_loss = tf.keras.metrics.Mean(name="i_loss")
|
|
1416
|
+
self.entity_loss = tf.keras.metrics.Mean(name="e_loss")
|
|
1417
|
+
self.entity_group_loss = tf.keras.metrics.Mean(name="g_loss")
|
|
1418
|
+
self.entity_role_loss = tf.keras.metrics.Mean(name="r_loss")
|
|
1419
|
+
# create accuracy metrics second to output accuracies second
|
|
1420
|
+
self.mask_acc = tf.keras.metrics.Mean(name="m_acc")
|
|
1421
|
+
self.intent_acc = tf.keras.metrics.Mean(name="i_acc")
|
|
1422
|
+
self.entity_f1 = tf.keras.metrics.Mean(name="e_f1")
|
|
1423
|
+
self.entity_group_f1 = tf.keras.metrics.Mean(name="g_f1")
|
|
1424
|
+
self.entity_role_f1 = tf.keras.metrics.Mean(name="r_f1")
|
|
1425
|
+
|
|
1426
|
+
def _update_metrics_to_log(self) -> None:
|
|
1427
|
+
debug_log_level = logging.getLogger("rasa").level == logging.DEBUG
|
|
1428
|
+
|
|
1429
|
+
if self.config[MASKED_LM]:
|
|
1430
|
+
self.metrics_to_log.append("m_acc")
|
|
1431
|
+
if debug_log_level:
|
|
1432
|
+
self.metrics_to_log.append("m_loss")
|
|
1433
|
+
if self.config[INTENT_CLASSIFICATION]:
|
|
1434
|
+
self.metrics_to_log.append("i_acc")
|
|
1435
|
+
if debug_log_level:
|
|
1436
|
+
self.metrics_to_log.append("i_loss")
|
|
1437
|
+
if self.config[ENTITY_RECOGNITION]:
|
|
1438
|
+
for tag_spec in self._entity_tag_specs:
|
|
1439
|
+
if tag_spec.num_tags != 0:
|
|
1440
|
+
name = tag_spec.tag_name
|
|
1441
|
+
self.metrics_to_log.append(f"{name[0]}_f1")
|
|
1442
|
+
if debug_log_level:
|
|
1443
|
+
self.metrics_to_log.append(f"{name[0]}_loss")
|
|
1444
|
+
|
|
1445
|
+
self._log_metric_info()
|
|
1446
|
+
|
|
1447
|
+
def _log_metric_info(self) -> None:
|
|
1448
|
+
metric_name = {
|
|
1449
|
+
"t": "total",
|
|
1450
|
+
"i": "intent",
|
|
1451
|
+
"e": "entity",
|
|
1452
|
+
"m": "mask",
|
|
1453
|
+
"r": "role",
|
|
1454
|
+
"g": "group",
|
|
1455
|
+
}
|
|
1456
|
+
logger.debug("Following metrics will be logged during training: ")
|
|
1457
|
+
for metric in self.metrics_to_log:
|
|
1458
|
+
parts = metric.split("_")
|
|
1459
|
+
name = f"{metric_name[parts[0]]} {parts[1]}"
|
|
1460
|
+
logger.debug(f" {metric} ({name})")
|
|
1461
|
+
|
|
1462
|
+
def _prepare_layers(self) -> None:
|
|
1463
|
+
# For user text, prepare layers that combine different feature types, embed
|
|
1464
|
+
# everything using a transformer and optionally also do masked language
|
|
1465
|
+
# modeling.
|
|
1466
|
+
self.text_name = TEXT
|
|
1467
|
+
self._tf_layers[f"sequence_layer.{self.text_name}"] = (
|
|
1468
|
+
rasa_layers.RasaSequenceLayer(
|
|
1469
|
+
self.text_name, self.data_signature[self.text_name], self.config
|
|
1470
|
+
)
|
|
1471
|
+
)
|
|
1472
|
+
if self.config[MASKED_LM]:
|
|
1473
|
+
self._prepare_mask_lm_loss(self.text_name)
|
|
1474
|
+
|
|
1475
|
+
# Intent labels are treated similarly to user text but without the transformer,
|
|
1476
|
+
# without masked language modelling, and with no dropout applied to the
|
|
1477
|
+
# individual features, only to the overall label embedding after all label
|
|
1478
|
+
# features have been combined.
|
|
1479
|
+
if self.config[INTENT_CLASSIFICATION]:
|
|
1480
|
+
self.label_name = TEXT if self.config[SHARE_HIDDEN_LAYERS] else LABEL
|
|
1481
|
+
|
|
1482
|
+
# disable input dropout applied to sparse and dense label features
|
|
1483
|
+
label_config = self.config.copy()
|
|
1484
|
+
label_config.update(
|
|
1485
|
+
{SPARSE_INPUT_DROPOUT: False, DENSE_INPUT_DROPOUT: False}
|
|
1486
|
+
)
|
|
1487
|
+
|
|
1488
|
+
self._tf_layers[f"feature_combining_layer.{self.label_name}"] = (
|
|
1489
|
+
rasa_layers.RasaFeatureCombiningLayer(
|
|
1490
|
+
self.label_name, self.label_signature[self.label_name], label_config
|
|
1491
|
+
)
|
|
1492
|
+
)
|
|
1493
|
+
|
|
1494
|
+
self._prepare_ffnn_layer(
|
|
1495
|
+
self.label_name,
|
|
1496
|
+
self.config[HIDDEN_LAYERS_SIZES][self.label_name],
|
|
1497
|
+
self.config[DROP_RATE],
|
|
1498
|
+
)
|
|
1499
|
+
|
|
1500
|
+
self._prepare_label_classification_layers(predictor_attribute=TEXT)
|
|
1501
|
+
|
|
1502
|
+
if self.config[ENTITY_RECOGNITION]:
|
|
1503
|
+
self._prepare_entity_recognition_layers()
|
|
1504
|
+
|
|
1505
|
+
def _prepare_mask_lm_loss(self, name: Text) -> None:
|
|
1506
|
+
# for embedding predicted tokens at masked positions
|
|
1507
|
+
self._prepare_embed_layers(f"{name}_lm_mask")
|
|
1508
|
+
|
|
1509
|
+
# for embedding the true tokens that got masked
|
|
1510
|
+
self._prepare_embed_layers(f"{name}_golden_token")
|
|
1511
|
+
|
|
1512
|
+
# mask loss is additional loss
|
|
1513
|
+
# set scaling to False, so that it doesn't overpower other losses
|
|
1514
|
+
self._prepare_dot_product_loss(f"{name}_mask", scale_loss=False)
|
|
1515
|
+
|
|
1516
|
+
def _create_bow(
|
|
1517
|
+
self,
|
|
1518
|
+
sequence_features: List[Union[tf.Tensor, tf.SparseTensor]],
|
|
1519
|
+
sentence_features: List[Union[tf.Tensor, tf.SparseTensor]],
|
|
1520
|
+
sequence_feature_lengths: tf.Tensor,
|
|
1521
|
+
name: Text,
|
|
1522
|
+
) -> tf.Tensor:
|
|
1523
|
+
x, _ = self._tf_layers[f"feature_combining_layer.{name}"](
|
|
1524
|
+
(sequence_features, sentence_features, sequence_feature_lengths),
|
|
1525
|
+
training=self._training,
|
|
1526
|
+
)
|
|
1527
|
+
|
|
1528
|
+
# convert to bag-of-words by summing along the sequence dimension
|
|
1529
|
+
x = tf.reduce_sum(x, axis=1)
|
|
1530
|
+
|
|
1531
|
+
return self._tf_layers[f"ffnn.{name}"](x, self._training)
|
|
1532
|
+
|
|
1533
|
+
def _create_all_labels(self) -> Tuple[tf.Tensor, tf.Tensor]:
|
|
1534
|
+
all_label_ids = self.tf_label_data[LABEL_KEY][LABEL_SUB_KEY][0]
|
|
1535
|
+
|
|
1536
|
+
sequence_feature_lengths = self._get_sequence_feature_lengths(
|
|
1537
|
+
self.tf_label_data, LABEL
|
|
1538
|
+
)
|
|
1539
|
+
|
|
1540
|
+
x = self._create_bow(
|
|
1541
|
+
self.tf_label_data[LABEL][SEQUENCE],
|
|
1542
|
+
self.tf_label_data[LABEL][SENTENCE],
|
|
1543
|
+
sequence_feature_lengths,
|
|
1544
|
+
self.label_name,
|
|
1545
|
+
)
|
|
1546
|
+
all_labels_embed = self._tf_layers[f"embed.{LABEL}"](x)
|
|
1547
|
+
|
|
1548
|
+
return all_label_ids, all_labels_embed
|
|
1549
|
+
|
|
1550
|
+
def _mask_loss(
|
|
1551
|
+
self,
|
|
1552
|
+
outputs: tf.Tensor,
|
|
1553
|
+
inputs: tf.Tensor,
|
|
1554
|
+
seq_ids: tf.Tensor,
|
|
1555
|
+
mlm_mask_boolean: tf.Tensor,
|
|
1556
|
+
name: Text,
|
|
1557
|
+
) -> tf.Tensor:
|
|
1558
|
+
# make sure there is at least one element in the mask
|
|
1559
|
+
mlm_mask_boolean = tf.cond(
|
|
1560
|
+
tf.reduce_any(mlm_mask_boolean),
|
|
1561
|
+
lambda: mlm_mask_boolean,
|
|
1562
|
+
lambda: tf.scatter_nd([[0, 0, 0]], [True], tf.shape(mlm_mask_boolean)),
|
|
1563
|
+
)
|
|
1564
|
+
|
|
1565
|
+
mlm_mask_boolean = tf.squeeze(mlm_mask_boolean, -1)
|
|
1566
|
+
|
|
1567
|
+
# Pick elements that were masked, throwing away the batch & sequence dimension
|
|
1568
|
+
# and effectively switching from shape (batch_size, sequence_length, units) to
|
|
1569
|
+
# (num_masked_elements, units).
|
|
1570
|
+
outputs = tf.boolean_mask(outputs, mlm_mask_boolean)
|
|
1571
|
+
inputs = tf.boolean_mask(inputs, mlm_mask_boolean)
|
|
1572
|
+
ids = tf.boolean_mask(seq_ids, mlm_mask_boolean)
|
|
1573
|
+
|
|
1574
|
+
tokens_predicted_embed = self._tf_layers[f"embed.{name}_lm_mask"](outputs)
|
|
1575
|
+
tokens_true_embed = self._tf_layers[f"embed.{name}_golden_token"](inputs)
|
|
1576
|
+
|
|
1577
|
+
# To limit the otherwise computationally expensive loss calculation, we
|
|
1578
|
+
# constrain the label space in MLM (i.e. token space) to only those tokens that
|
|
1579
|
+
# were masked in this batch. Hence the reduced list of token embeddings
|
|
1580
|
+
# (tokens_true_embed) and the reduced list of labels (ids) are passed as
|
|
1581
|
+
# all_labels_embed and all_labels, respectively. In the future, we could be less
|
|
1582
|
+
# restrictive and construct a slightly bigger label space which could include
|
|
1583
|
+
# tokens not masked in the current batch too.
|
|
1584
|
+
return self._tf_layers[f"loss.{name}_mask"](
|
|
1585
|
+
inputs_embed=tokens_predicted_embed,
|
|
1586
|
+
labels_embed=tokens_true_embed,
|
|
1587
|
+
labels=ids,
|
|
1588
|
+
all_labels_embed=tokens_true_embed,
|
|
1589
|
+
all_labels=ids,
|
|
1590
|
+
)
|
|
1591
|
+
|
|
1592
|
+
def _calculate_label_loss(
|
|
1593
|
+
self, text_features: tf.Tensor, label_features: tf.Tensor, label_ids: tf.Tensor
|
|
1594
|
+
) -> tf.Tensor:
|
|
1595
|
+
all_label_ids, all_labels_embed = self._create_all_labels()
|
|
1596
|
+
|
|
1597
|
+
text_embed = self._tf_layers[f"embed.{TEXT}"](text_features)
|
|
1598
|
+
label_embed = self._tf_layers[f"embed.{LABEL}"](label_features)
|
|
1599
|
+
|
|
1600
|
+
return self._tf_layers[f"loss.{LABEL}"](
|
|
1601
|
+
text_embed, label_embed, label_ids, all_labels_embed, all_label_ids
|
|
1602
|
+
)
|
|
1603
|
+
    def batch_loss(
        self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
    ) -> tf.Tensor:
        """Calculates the loss for the given batch.

        Args:
            batch_in: The batch.

        Returns:
            The loss of the given batch.
        """
        tf_batch_data = self.batch_to_model_data_format(batch_in, self.data_signature)

        sequence_feature_lengths = self._get_sequence_feature_lengths(
            tf_batch_data, TEXT
        )

        (
            text_transformed,
            text_in,
            mask_combined_sequence_sentence,
            text_seq_ids,
            mlm_mask_boolean_text,
            _,
        ) = self._tf_layers[f"sequence_layer.{self.text_name}"](
            (
                tf_batch_data[TEXT][SEQUENCE],
                tf_batch_data[TEXT][SENTENCE],
                sequence_feature_lengths,
            ),
            training=self._training,
        )

        losses = []

        # Lengths of sequences in case of sentence-level features are always 1, but they
        # can effectively be 0 if sentence-level features aren't present.
        sentence_feature_lengths = self._get_sentence_feature_lengths(
            tf_batch_data, TEXT
        )

        combined_sequence_sentence_feature_lengths = (
            sequence_feature_lengths + sentence_feature_lengths
        )

        if self.config[MASKED_LM] and self._training:
            loss, acc = self._mask_loss(
                text_transformed, text_in, text_seq_ids, mlm_mask_boolean_text, TEXT
            )
            self.mask_loss.update_state(loss)
            self.mask_acc.update_state(acc)
            losses.append(loss)

        if self.config[INTENT_CLASSIFICATION]:
            loss = self._batch_loss_intent(
                combined_sequence_sentence_feature_lengths,
                text_transformed,
                tf_batch_data,
            )
            losses.append(loss)

        if self.config[ENTITY_RECOGNITION]:
            losses += self._batch_loss_entities(
                mask_combined_sequence_sentence,
                sequence_feature_lengths,
                text_transformed,
                tf_batch_data,
            )

        return tf.math.add_n(losses)

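`batch_loss` collects only the losses that are enabled in the configuration and sums them with `tf.math.add_n`. A tiny sketch of that accumulation pattern with stand-in scalar losses:

import tensorflow as tf

masked_lm_enabled = True
intent_classification_enabled = True

losses = []
if masked_lm_enabled:
    losses.append(tf.constant(0.7))   # stand-in for the MLM loss
if intent_classification_enabled:
    losses.append(tf.constant(1.2))   # stand-in for the intent loss

# add_n sums all collected loss tensors; it expects a non-empty list.
total_loss = tf.math.add_n(losses)
print(float(total_loss))              # ~1.9
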
    def _batch_loss_intent(
        self,
        combined_sequence_sentence_feature_lengths_text: tf.Tensor,
        text_transformed: tf.Tensor,
        tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
    ) -> tf.Tensor:
        # get sentence features vector for intent classification
        sentence_vector = self._last_token(
            text_transformed, combined_sequence_sentence_feature_lengths_text
        )

        sequence_feature_lengths_label = self._get_sequence_feature_lengths(
            tf_batch_data, LABEL
        )

        label_ids = tf_batch_data[LABEL_KEY][LABEL_SUB_KEY][0]
        label = self._create_bow(
            tf_batch_data[LABEL][SEQUENCE],
            tf_batch_data[LABEL][SENTENCE],
            sequence_feature_lengths_label,
            self.label_name,
        )
        loss, acc = self._calculate_label_loss(sentence_vector, label, label_ids)

        self._update_label_metrics(loss, acc)

        return loss

    def _update_label_metrics(self, loss: tf.Tensor, acc: tf.Tensor) -> None:
        self.intent_loss.update_state(loss)
        self.intent_acc.update_state(acc)

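`_last_token` (defined elsewhere in the package) picks, for each example, the transformer output at the last real position, which is where the appended sentence-level features end up. A standalone sketch of such a gather, with a hypothetical `lengths` tensor standing in for the combined sequence-plus-sentence feature lengths:

import tensorflow as tf

batch_size, max_len, units = 2, 5, 3
text_transformed = tf.random.uniform((batch_size, max_len, units))

# Combined sequence + sentence feature lengths per example (hypothetical values).
lengths = tf.constant([5, 3])

# Index of the last real position for each example, then gather it.
last_index = tf.maximum(lengths - 1, 0)
indices = tf.stack([tf.range(batch_size), last_index], axis=1)   # (batch, 2)
sentence_vector = tf.gather_nd(text_transformed, indices)        # (batch, units)
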
    def _batch_loss_entities(
        self,
        mask_combined_sequence_sentence: tf.Tensor,
        sequence_feature_lengths: tf.Tensor,
        text_transformed: tf.Tensor,
        tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
    ) -> List[tf.Tensor]:
        losses = []

        entity_tags = None

        for tag_spec in self._entity_tag_specs:
            if tag_spec.num_tags == 0:
                continue

            tag_ids = tf_batch_data[ENTITIES][tag_spec.tag_name][0]
            # add a zero (no entity) for the sentence features to match the shape of
            # inputs
            tag_ids = tf.pad(tag_ids, [[0, 0], [0, 1], [0, 0]])

            loss, f1, _logits = self._calculate_entity_loss(
                text_transformed,
                tag_ids,
                mask_combined_sequence_sentence,
                sequence_feature_lengths,
                tag_spec.tag_name,
                entity_tags,
            )

            if tag_spec.tag_name == ENTITY_ATTRIBUTE_TYPE:
                # use the entity tags as additional input for the role
                # and group CRF
                entity_tags = tf.one_hot(
                    tf.cast(tag_ids[:, :, 0], tf.int32), depth=tag_spec.num_tags
                )

            self._update_entity_metrics(loss, f1, tag_spec.tag_name)

            losses.append(loss)

        return losses

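Two small details above are worth spelling out: the gold tag ids receive one extra zero ("no entity") position along the sequence axis so that they line up with the sentence-level feature appended to the inputs, and the entity-type tags are one-hot encoded before being passed to the role and group CRFs as additional input. A standalone sketch of both operations with toy shapes:

import tensorflow as tf

# Toy gold entity-type tag ids: (batch_size=1, sequence_length=3, 1).
tag_ids = tf.constant([[[2], [0], [1]]])

# Append one zero tag along the sequence axis -> (1, 4, 1); the extra position
# corresponds to the sentence-level feature appended to the text inputs.
tag_ids = tf.pad(tag_ids, [[0, 0], [0, 1], [0, 0]])

# One-hot encode the type tags (num_tags=3 here) so a downstream role/group CRF
# can take them as additional input features: (1, 4, 3).
entity_tags = tf.one_hot(tf.cast(tag_ids[:, :, 0], tf.int32), depth=3)
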
    def _update_entity_metrics(
        self, loss: tf.Tensor, f1: tf.Tensor, tag_name: Text
    ) -> None:
        if tag_name == ENTITY_ATTRIBUTE_TYPE:
            self.entity_loss.update_state(loss)
            self.entity_f1.update_state(f1)
        elif tag_name == ENTITY_ATTRIBUTE_GROUP:
            self.entity_group_loss.update_state(loss)
            self.entity_group_f1.update_state(f1)
        elif tag_name == ENTITY_ATTRIBUTE_ROLE:
            self.entity_role_loss.update_state(loss)
            self.entity_role_f1.update_state(f1)

    def prepare_for_predict(self) -> None:
        """Prepares the model for prediction."""
        if self.config[INTENT_CLASSIFICATION]:
            _, self.all_labels_embed = self._create_all_labels()

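`prepare_for_predict` embeds all labels once and caches the result on the model so that `batch_predict` does not recompute them for every batch. A simplified, self-contained sketch of that caching idea (toy class and dimensions, not the package's actual model):

import tensorflow as tf

class ToyIntentScorer:
    """Minimal illustration of caching label embeddings before inference."""

    def __init__(self, num_labels: int = 5, feature_dim: int = 16, embed_dim: int = 8):
        self.embed = tf.keras.layers.Dense(embed_dim)
        # Stand-in for the label features the real model builds from training data.
        self.label_features = tf.random.normal((num_labels, feature_dim))
        self.all_labels_embed = None

    def prepare_for_predict(self) -> None:
        # Embed every label once and keep the result for later batches.
        self.all_labels_embed = self.embed(self.label_features)

    def predict(self, text_embed: tf.Tensor) -> tf.Tensor:
        if self.all_labels_embed is None:
            raise ValueError("Call `prepare_for_predict` first.")
        # Score each example against every cached label embedding.
        return tf.matmul(text_embed, self.all_labels_embed, transpose_b=True)

scorer = ToyIntentScorer()
scorer.prepare_for_predict()
scores = scorer.predict(tf.random.normal((2, 8)))   # (2, num_labels)
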
    def batch_predict(
        self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
    ) -> Dict[Text, tf.Tensor]:
        """Predicts the output of the given batch.

        Args:
            batch_in: The batch.

        Returns:
            The output to predict.
        """
        tf_batch_data = self.batch_to_model_data_format(
            batch_in, self.predict_data_signature
        )

        sequence_feature_lengths = self._get_sequence_feature_lengths(
            tf_batch_data, TEXT
        )
        sentence_feature_lengths = self._get_sentence_feature_lengths(
            tf_batch_data, TEXT
        )

        text_transformed, _, _, _, _, attention_weights = self._tf_layers[
            f"sequence_layer.{self.text_name}"
        ](
            (
                tf_batch_data[TEXT][SEQUENCE],
                tf_batch_data[TEXT][SENTENCE],
                sequence_feature_lengths,
            ),
            training=self._training,
        )
        predictions = {
            DIAGNOSTIC_DATA: {
                "attention_weights": attention_weights,
                "text_transformed": text_transformed,
            }
        }

        if self.config[INTENT_CLASSIFICATION]:
            predictions.update(
                self._batch_predict_intents(
                    sequence_feature_lengths + sentence_feature_lengths,
                    text_transformed,
                )
            )

        if self.config[ENTITY_RECOGNITION]:
            predictions.update(
                self._batch_predict_entities(sequence_feature_lengths, text_transformed)
            )

        return predictions

    def _batch_predict_entities(
        self, sequence_feature_lengths: tf.Tensor, text_transformed: tf.Tensor
    ) -> Dict[Text, tf.Tensor]:
        predictions: Dict[Text, tf.Tensor] = {}

        entity_tags = None

        for tag_spec in self._entity_tag_specs:
            # skip crf layer if it was not trained
            if tag_spec.num_tags == 0:
                continue

            name = tag_spec.tag_name
            _input = text_transformed

            if entity_tags is not None:
                _tags = self._tf_layers[f"embed.{name}.tags"](entity_tags)
                _input = tf.concat([_input, _tags], axis=-1)

            _logits = self._tf_layers[f"embed.{name}.logits"](_input)
            pred_ids, confidences = self._tf_layers[f"crf.{name}"](
                _logits, sequence_feature_lengths
            )

            predictions[f"e_{name}_ids"] = pred_ids
            predictions[f"e_{name}_scores"] = confidences

            if name == ENTITY_ATTRIBUTE_TYPE:
                # use the entity tags as additional input for the role
                # and group CRF
                entity_tags = tf.one_hot(
                    tf.cast(pred_ids, tf.int32), depth=tag_spec.num_tags
                )

        return predictions

    def _batch_predict_intents(
        self,
        combined_sequence_sentence_feature_lengths: tf.Tensor,
        text_transformed: tf.Tensor,
    ) -> Dict[Text, tf.Tensor]:
        if self.all_labels_embed is None:
            raise ValueError(
                "The model was not prepared for prediction. "
                "Call `prepare_for_predict` first."
            )

        # get sentence feature vector for intent classification
        sentence_vector = self._last_token(
            text_transformed, combined_sequence_sentence_feature_lengths
        )
        sentence_vector_embed = self._tf_layers[f"embed.{TEXT}"](sentence_vector)

        _, scores = self._tf_layers[
            f"loss.{LABEL}"
        ].get_similarities_and_confidences_from_embeddings(
            sentence_vector_embed[:, tf.newaxis, :],
            self.all_labels_embed[tf.newaxis, :, :],
        )

        return {"i_scores": scores}
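
`_batch_predict_intents` broadcasts the sentence embedding to shape `(batch, 1, dim)` against the cached label embeddings of shape `(1, num_labels, dim)` and asks the loss layer for similarities and confidences. A simplified sketch of that broadcasting, with a softmax standing in for whatever confidence transform the configured loss layer applies:

import tensorflow as tf

batch_size, num_labels, embed_dim = 2, 4, 8
sentence_vector_embed = tf.random.normal((batch_size, embed_dim))
all_labels_embed = tf.random.normal((num_labels, embed_dim))

# Broadcast to (batch, 1, dim) and (1, num_labels, dim), then reduce over dim:
# the result is a (batch, num_labels) similarity matrix.
sim = tf.reduce_sum(
    sentence_vector_embed[:, tf.newaxis, :] * all_labels_embed[tf.newaxis, :, :],
    axis=-1,
)

# One possible confidence transform: softmax over the label axis.
i_scores = tf.nn.softmax(sim, axis=-1)   # each row sums to 1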