rasa-pro 3.12.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +41 -0
- rasa/__init__.py +9 -0
- rasa/__main__.py +177 -0
- rasa/anonymization/__init__.py +2 -0
- rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
- rasa/anonymization/anonymization_pipeline.py +286 -0
- rasa/anonymization/anonymization_rule_executor.py +260 -0
- rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
- rasa/anonymization/schemas/config.yml +47 -0
- rasa/anonymization/utils.py +118 -0
- rasa/api.py +160 -0
- rasa/cli/__init__.py +5 -0
- rasa/cli/arguments/__init__.py +0 -0
- rasa/cli/arguments/data.py +106 -0
- rasa/cli/arguments/default_arguments.py +207 -0
- rasa/cli/arguments/evaluate.py +65 -0
- rasa/cli/arguments/export.py +51 -0
- rasa/cli/arguments/interactive.py +74 -0
- rasa/cli/arguments/run.py +219 -0
- rasa/cli/arguments/shell.py +17 -0
- rasa/cli/arguments/test.py +211 -0
- rasa/cli/arguments/train.py +279 -0
- rasa/cli/arguments/visualize.py +34 -0
- rasa/cli/arguments/x.py +30 -0
- rasa/cli/data.py +354 -0
- rasa/cli/dialogue_understanding_test.py +251 -0
- rasa/cli/e2e_test.py +259 -0
- rasa/cli/evaluate.py +222 -0
- rasa/cli/export.py +250 -0
- rasa/cli/inspect.py +75 -0
- rasa/cli/interactive.py +166 -0
- rasa/cli/license.py +65 -0
- rasa/cli/llm_fine_tuning.py +403 -0
- rasa/cli/markers.py +78 -0
- rasa/cli/project_templates/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/action_template.py +27 -0
- rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
- rasa/cli/project_templates/calm/actions/db.py +57 -0
- rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
- rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
- rasa/cli/project_templates/calm/config.yml +10 -0
- rasa/cli/project_templates/calm/credentials.yml +33 -0
- rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
- rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
- rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
- rasa/cli/project_templates/calm/db/contacts.json +10 -0
- rasa/cli/project_templates/calm/domain/add_contact.yml +39 -0
- rasa/cli/project_templates/calm/domain/list_contacts.yml +17 -0
- rasa/cli/project_templates/calm/domain/remove_contact.yml +38 -0
- rasa/cli/project_templates/calm/domain/shared.yml +10 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
- rasa/cli/project_templates/calm/endpoints.yml +58 -0
- rasa/cli/project_templates/default/actions/__init__.py +0 -0
- rasa/cli/project_templates/default/actions/actions.py +27 -0
- rasa/cli/project_templates/default/config.yml +44 -0
- rasa/cli/project_templates/default/credentials.yml +33 -0
- rasa/cli/project_templates/default/data/nlu.yml +91 -0
- rasa/cli/project_templates/default/data/rules.yml +13 -0
- rasa/cli/project_templates/default/data/stories.yml +30 -0
- rasa/cli/project_templates/default/domain.yml +34 -0
- rasa/cli/project_templates/default/endpoints.yml +42 -0
- rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
- rasa/cli/project_templates/tutorial/actions/__init__.py +0 -0
- rasa/cli/project_templates/tutorial/actions/actions.py +22 -0
- rasa/cli/project_templates/tutorial/config.yml +12 -0
- rasa/cli/project_templates/tutorial/credentials.yml +33 -0
- rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
- rasa/cli/project_templates/tutorial/data/patterns.yml +11 -0
- rasa/cli/project_templates/tutorial/domain.yml +35 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +55 -0
- rasa/cli/run.py +143 -0
- rasa/cli/scaffold.py +273 -0
- rasa/cli/shell.py +141 -0
- rasa/cli/studio/__init__.py +0 -0
- rasa/cli/studio/download.py +62 -0
- rasa/cli/studio/studio.py +296 -0
- rasa/cli/studio/train.py +59 -0
- rasa/cli/studio/upload.py +62 -0
- rasa/cli/telemetry.py +102 -0
- rasa/cli/test.py +280 -0
- rasa/cli/train.py +278 -0
- rasa/cli/utils.py +484 -0
- rasa/cli/visualize.py +40 -0
- rasa/cli/x.py +206 -0
- rasa/constants.py +45 -0
- rasa/core/__init__.py +17 -0
- rasa/core/actions/__init__.py +0 -0
- rasa/core/actions/action.py +1318 -0
- rasa/core/actions/action_clean_stack.py +59 -0
- rasa/core/actions/action_exceptions.py +24 -0
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/actions/action_repeat_bot_messages.py +89 -0
- rasa/core/actions/action_run_slot_rejections.py +210 -0
- rasa/core/actions/action_trigger_chitchat.py +31 -0
- rasa/core/actions/action_trigger_flow.py +109 -0
- rasa/core/actions/action_trigger_search.py +31 -0
- rasa/core/actions/constants.py +5 -0
- rasa/core/actions/custom_action_executor.py +191 -0
- rasa/core/actions/direct_custom_actions_executor.py +109 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +72 -0
- rasa/core/actions/forms.py +741 -0
- rasa/core/actions/grpc_custom_action_executor.py +251 -0
- rasa/core/actions/http_custom_action_executor.py +145 -0
- rasa/core/actions/loops.py +114 -0
- rasa/core/actions/two_stage_fallback.py +186 -0
- rasa/core/agent.py +559 -0
- rasa/core/auth_retry_tracker_store.py +122 -0
- rasa/core/brokers/__init__.py +0 -0
- rasa/core/brokers/broker.py +126 -0
- rasa/core/brokers/file.py +58 -0
- rasa/core/brokers/kafka.py +324 -0
- rasa/core/brokers/pika.py +388 -0
- rasa/core/brokers/sql.py +86 -0
- rasa/core/channels/__init__.py +61 -0
- rasa/core/channels/botframework.py +338 -0
- rasa/core/channels/callback.py +84 -0
- rasa/core/channels/channel.py +456 -0
- rasa/core/channels/console.py +241 -0
- rasa/core/channels/development_inspector.py +197 -0
- rasa/core/channels/facebook.py +419 -0
- rasa/core/channels/hangouts.py +329 -0
- rasa/core/channels/inspector/.eslintrc.cjs +25 -0
- rasa/core/channels/inspector/.gitignore +23 -0
- rasa/core/channels/inspector/README.md +54 -0
- rasa/core/channels/inspector/assets/favicon.ico +0 -0
- rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
- rasa/core/channels/inspector/custom.d.ts +3 -0
- rasa/core/channels/inspector/dist/assets/arc-861ddd57.js +1 -0
- rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
- rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-921f02db.js +10 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-b436c4f8.js +2 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-511a23cb.js +2 -0
- rasa/core/channels/inspector/dist/assets/createText-62fc7601-ef476ecd.js +7 -0
- rasa/core/channels/inspector/dist/assets/edges-f2ad444c-f1878e0a.js +4 -0
- rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-fac75185.js +51 -0
- rasa/core/channels/inspector/dist/assets/flowDb-1972c806-201c5bbc.js +6 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-f904ae41.js +4 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-b080d6f2.js +1 -0
- rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-1813da66.js +139 -0
- rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-872af172.js +266 -0
- rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-34a0af5a.js +70 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-42ba3e3d.js +1 -0
- rasa/core/channels/inspector/dist/assets/index-37817b51.js +1317 -0
- rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
- rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-6b731386.js +7 -0
- rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
- rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-e8579ac6.js +139 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
- rasa/core/channels/inspector/dist/assets/layout-89e6403a.js +1 -0
- rasa/core/channels/inspector/dist/assets/line-dc73d3fc.js +1 -0
- rasa/core/channels/inspector/dist/assets/linear-f5b1d2bc.js +1 -0
- rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-82cb74fa.js +109 -0
- rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
- rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
- rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-bdf5f29b.js +35 -0
- rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-c7a0cbe4.js +7 -0
- rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-7ec5410f.js +52 -0
- rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-caee5554.js +8 -0
- rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-2935f8db.js +122 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-8f5d9693.js +1 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-d565d1de.js +1 -0
- rasa/core/channels/inspector/dist/assets/styles-080da4f6-75ad421d.js +110 -0
- rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-7e764226.js +159 -0
- rasa/core/channels/inspector/dist/assets/styles-9c745c82-7a4e0e61.js +207 -0
- rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-4019d1bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-01ea12df.js +61 -0
- rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-89407137.js +7 -0
- rasa/core/channels/inspector/dist/index.html +42 -0
- rasa/core/channels/inspector/index.html +40 -0
- rasa/core/channels/inspector/jest.config.ts +13 -0
- rasa/core/channels/inspector/package.json +52 -0
- rasa/core/channels/inspector/setupTests.ts +2 -0
- rasa/core/channels/inspector/src/App.tsx +220 -0
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +108 -0
- rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +136 -0
- rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
- rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +22 -0
- rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
- rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
- rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
- rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
- rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
- rasa/core/channels/inspector/src/helpers/audiostream.ts +191 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +392 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +306 -0
- rasa/core/channels/inspector/src/helpers/utils.ts +127 -0
- rasa/core/channels/inspector/src/main.tsx +13 -0
- rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
- rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
- rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
- rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
- rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
- rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
- rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
- rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
- rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
- rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
- rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
- rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
- rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
- rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
- rasa/core/channels/inspector/src/theme/index.ts +101 -0
- rasa/core/channels/inspector/src/types.ts +84 -0
- rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
- rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
- rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
- rasa/core/channels/inspector/tsconfig.json +26 -0
- rasa/core/channels/inspector/tsconfig.node.json +10 -0
- rasa/core/channels/inspector/vite.config.ts +8 -0
- rasa/core/channels/inspector/yarn.lock +6249 -0
- rasa/core/channels/mattermost.py +229 -0
- rasa/core/channels/rasa_chat.py +126 -0
- rasa/core/channels/rest.py +230 -0
- rasa/core/channels/rocketchat.py +174 -0
- rasa/core/channels/slack.py +620 -0
- rasa/core/channels/socketio.py +302 -0
- rasa/core/channels/telegram.py +298 -0
- rasa/core/channels/twilio.py +169 -0
- rasa/core/channels/vier_cvg.py +374 -0
- rasa/core/channels/voice_ready/__init__.py +0 -0
- rasa/core/channels/voice_ready/audiocodes.py +501 -0
- rasa/core/channels/voice_ready/jambonz.py +121 -0
- rasa/core/channels/voice_ready/jambonz_protocol.py +396 -0
- rasa/core/channels/voice_ready/twilio_voice.py +403 -0
- rasa/core/channels/voice_ready/utils.py +37 -0
- rasa/core/channels/voice_stream/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
- rasa/core/channels/voice_stream/asr/azure.py +130 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
- rasa/core/channels/voice_stream/audio_bytes.py +8 -0
- rasa/core/channels/voice_stream/browser_audio.py +107 -0
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +106 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +427 -0
- rasa/core/channels/webexteams.py +134 -0
- rasa/core/concurrent_lock_store.py +210 -0
- rasa/core/constants.py +112 -0
- rasa/core/evaluation/__init__.py +0 -0
- rasa/core/evaluation/marker.py +267 -0
- rasa/core/evaluation/marker_base.py +923 -0
- rasa/core/evaluation/marker_stats.py +293 -0
- rasa/core/evaluation/marker_tracker_loader.py +103 -0
- rasa/core/exceptions.py +29 -0
- rasa/core/exporter.py +284 -0
- rasa/core/featurizers/__init__.py +0 -0
- rasa/core/featurizers/precomputation.py +410 -0
- rasa/core/featurizers/single_state_featurizer.py +421 -0
- rasa/core/featurizers/tracker_featurizers.py +1262 -0
- rasa/core/http_interpreter.py +89 -0
- rasa/core/information_retrieval/__init__.py +7 -0
- rasa/core/information_retrieval/faiss.py +124 -0
- rasa/core/information_retrieval/information_retrieval.py +137 -0
- rasa/core/information_retrieval/milvus.py +59 -0
- rasa/core/information_retrieval/qdrant.py +96 -0
- rasa/core/jobs.py +63 -0
- rasa/core/lock.py +139 -0
- rasa/core/lock_store.py +343 -0
- rasa/core/migrate.py +403 -0
- rasa/core/nlg/__init__.py +3 -0
- rasa/core/nlg/callback.py +146 -0
- rasa/core/nlg/contextual_response_rephraser.py +320 -0
- rasa/core/nlg/generator.py +230 -0
- rasa/core/nlg/interpolator.py +143 -0
- rasa/core/nlg/response.py +155 -0
- rasa/core/nlg/summarize.py +70 -0
- rasa/core/persistor.py +538 -0
- rasa/core/policies/__init__.py +0 -0
- rasa/core/policies/ensemble.py +329 -0
- rasa/core/policies/enterprise_search_policy.py +905 -0
- rasa/core/policies/enterprise_search_prompt_template.jinja2 +25 -0
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
- rasa/core/policies/flow_policy.py +205 -0
- rasa/core/policies/flows/__init__.py +0 -0
- rasa/core/policies/flows/flow_exceptions.py +44 -0
- rasa/core/policies/flows/flow_executor.py +754 -0
- rasa/core/policies/flows/flow_step_result.py +43 -0
- rasa/core/policies/intentless_policy.py +1031 -0
- rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
- rasa/core/policies/memoization.py +538 -0
- rasa/core/policies/policy.py +725 -0
- rasa/core/policies/rule_policy.py +1273 -0
- rasa/core/policies/ted_policy.py +2169 -0
- rasa/core/policies/unexpected_intent_policy.py +1022 -0
- rasa/core/processor.py +1465 -0
- rasa/core/run.py +342 -0
- rasa/core/secrets_manager/__init__.py +0 -0
- rasa/core/secrets_manager/constants.py +36 -0
- rasa/core/secrets_manager/endpoints.py +391 -0
- rasa/core/secrets_manager/factory.py +241 -0
- rasa/core/secrets_manager/secret_manager.py +262 -0
- rasa/core/secrets_manager/vault.py +584 -0
- rasa/core/test.py +1335 -0
- rasa/core/tracker_store.py +1703 -0
- rasa/core/train.py +105 -0
- rasa/core/training/__init__.py +89 -0
- rasa/core/training/converters/__init__.py +0 -0
- rasa/core/training/converters/responses_prefix_converter.py +119 -0
- rasa/core/training/interactive.py +1744 -0
- rasa/core/training/story_conflict.py +381 -0
- rasa/core/training/training.py +93 -0
- rasa/core/utils.py +366 -0
- rasa/core/visualize.py +70 -0
- rasa/dialogue_understanding/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/constants.py +4 -0
- rasa/dialogue_understanding/coexistence/intent_based_router.py +196 -0
- rasa/dialogue_understanding/coexistence/llm_based_router.py +327 -0
- rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
- rasa/dialogue_understanding/commands/__init__.py +61 -0
- rasa/dialogue_understanding/commands/can_not_handle_command.py +70 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +125 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +44 -0
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/clarify_command.py +86 -0
- rasa/dialogue_understanding/commands/command.py +85 -0
- rasa/dialogue_understanding/commands/correct_slots_command.py +297 -0
- rasa/dialogue_understanding/commands/error_command.py +79 -0
- rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/handle_code_change_command.py +73 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +66 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/noop_command.py +54 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
- rasa/dialogue_understanding/commands/restart_command.py +58 -0
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +160 -0
- rasa/dialogue_understanding/commands/skip_question_command.py +75 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +107 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +45 -0
- rasa/dialogue_understanding/generator/__init__.py +21 -0
- rasa/dialogue_understanding/generator/command_generator.py +464 -0
- rasa/dialogue_understanding/generator/constants.py +27 -0
- rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +466 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +500 -0
- rasa/dialogue_understanding/generator/llm_command_generator.py +67 -0
- rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
- rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +920 -0
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +261 -0
- rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +60 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +486 -0
- rasa/dialogue_understanding/patterns/__init__.py +0 -0
- rasa/dialogue_understanding/patterns/cancel.py +111 -0
- rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
- rasa/dialogue_understanding/patterns/chitchat.py +37 -0
- rasa/dialogue_understanding/patterns/clarify.py +97 -0
- rasa/dialogue_understanding/patterns/code_change.py +41 -0
- rasa/dialogue_understanding/patterns/collect_information.py +90 -0
- rasa/dialogue_understanding/patterns/completed.py +40 -0
- rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
- rasa/dialogue_understanding/patterns/correction.py +278 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +301 -0
- rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
- rasa/dialogue_understanding/patterns/internal_error.py +47 -0
- rasa/dialogue_understanding/patterns/repeat.py +37 -0
- rasa/dialogue_understanding/patterns/restart.py +37 -0
- rasa/dialogue_understanding/patterns/search.py +37 -0
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/patterns/skip_question.py +38 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/__init__.py +0 -0
- rasa/dialogue_understanding/processor/command_processor.py +720 -0
- rasa/dialogue_understanding/processor/command_processor_component.py +43 -0
- rasa/dialogue_understanding/stack/__init__.py +0 -0
- rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
- rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
- rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
- rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
- rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
- rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
- rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
- rasa/dialogue_understanding/stack/utils.py +211 -0
- rasa/dialogue_understanding/utils.py +14 -0
- rasa/dialogue_understanding_test/__init__.py +0 -0
- rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
- rasa/dialogue_understanding_test/constants.py +17 -0
- rasa/dialogue_understanding_test/du_test_case.py +118 -0
- rasa/dialogue_understanding_test/du_test_result.py +11 -0
- rasa/dialogue_understanding_test/du_test_runner.py +93 -0
- rasa/dialogue_understanding_test/io.py +54 -0
- rasa/dialogue_understanding_test/validation.py +22 -0
- rasa/e2e_test/__init__.py +0 -0
- rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
- rasa/e2e_test/assertions.py +1345 -0
- rasa/e2e_test/assertions_schema.yml +129 -0
- rasa/e2e_test/constants.py +31 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +569 -0
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +54 -0
- rasa/e2e_test/e2e_test_runner.py +1192 -0
- rasa/e2e_test/e2e_test_schema.yml +181 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +598 -0
- rasa/e2e_test/utils/validation.py +178 -0
- rasa/engine/__init__.py +0 -0
- rasa/engine/caching.py +463 -0
- rasa/engine/constants.py +17 -0
- rasa/engine/exceptions.py +14 -0
- rasa/engine/graph.py +642 -0
- rasa/engine/loader.py +48 -0
- rasa/engine/recipes/__init__.py +0 -0
- rasa/engine/recipes/config_files/default_config.yml +41 -0
- rasa/engine/recipes/default_components.py +97 -0
- rasa/engine/recipes/default_recipe.py +1272 -0
- rasa/engine/recipes/graph_recipe.py +79 -0
- rasa/engine/recipes/recipe.py +93 -0
- rasa/engine/runner/__init__.py +0 -0
- rasa/engine/runner/dask.py +250 -0
- rasa/engine/runner/interface.py +49 -0
- rasa/engine/storage/__init__.py +0 -0
- rasa/engine/storage/local_model_storage.py +244 -0
- rasa/engine/storage/resource.py +110 -0
- rasa/engine/storage/storage.py +199 -0
- rasa/engine/training/__init__.py +0 -0
- rasa/engine/training/components.py +176 -0
- rasa/engine/training/fingerprinting.py +64 -0
- rasa/engine/training/graph_trainer.py +256 -0
- rasa/engine/training/hooks.py +164 -0
- rasa/engine/validation.py +1451 -0
- rasa/env.py +14 -0
- rasa/exceptions.py +69 -0
- rasa/graph_components/__init__.py +0 -0
- rasa/graph_components/converters/__init__.py +0 -0
- rasa/graph_components/converters/nlu_message_converter.py +48 -0
- rasa/graph_components/providers/__init__.py +0 -0
- rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
- rasa/graph_components/providers/domain_provider.py +71 -0
- rasa/graph_components/providers/flows_provider.py +74 -0
- rasa/graph_components/providers/forms_provider.py +44 -0
- rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
- rasa/graph_components/providers/responses_provider.py +44 -0
- rasa/graph_components/providers/rule_only_provider.py +49 -0
- rasa/graph_components/providers/story_graph_provider.py +96 -0
- rasa/graph_components/providers/training_tracker_provider.py +55 -0
- rasa/graph_components/validators/__init__.py +0 -0
- rasa/graph_components/validators/default_recipe_validator.py +550 -0
- rasa/graph_components/validators/finetuning_validator.py +302 -0
- rasa/hooks.py +111 -0
- rasa/jupyter.py +63 -0
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/markers/__init__.py +0 -0
- rasa/markers/marker.py +269 -0
- rasa/markers/marker_base.py +828 -0
- rasa/markers/upload.py +74 -0
- rasa/markers/validate.py +21 -0
- rasa/model.py +118 -0
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +40 -0
- rasa/model_manager/model_api.py +559 -0
- rasa/model_manager/runner_service.py +286 -0
- rasa/model_manager/socket_bridge.py +146 -0
- rasa/model_manager/studio_jwt_auth.py +86 -0
- rasa/model_manager/trainer_service.py +325 -0
- rasa/model_manager/utils.py +87 -0
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +112 -0
- rasa/model_testing.py +457 -0
- rasa/model_training.py +596 -0
- rasa/nlu/__init__.py +7 -0
- rasa/nlu/classifiers/__init__.py +3 -0
- rasa/nlu/classifiers/classifier.py +5 -0
- rasa/nlu/classifiers/diet_classifier.py +1881 -0
- rasa/nlu/classifiers/fallback_classifier.py +192 -0
- rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
- rasa/nlu/classifiers/logistic_regression_classifier.py +253 -0
- rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
- rasa/nlu/classifiers/regex_message_handler.py +56 -0
- rasa/nlu/classifiers/sklearn_intent_classifier.py +330 -0
- rasa/nlu/constants.py +77 -0
- rasa/nlu/convert.py +40 -0
- rasa/nlu/emulators/__init__.py +0 -0
- rasa/nlu/emulators/dialogflow.py +55 -0
- rasa/nlu/emulators/emulator.py +49 -0
- rasa/nlu/emulators/luis.py +86 -0
- rasa/nlu/emulators/no_emulator.py +10 -0
- rasa/nlu/emulators/wit.py +56 -0
- rasa/nlu/extractors/__init__.py +0 -0
- rasa/nlu/extractors/crf_entity_extractor.py +715 -0
- rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
- rasa/nlu/extractors/entity_synonyms.py +178 -0
- rasa/nlu/extractors/extractor.py +470 -0
- rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
- rasa/nlu/extractors/regex_entity_extractor.py +220 -0
- rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
- rasa/nlu/featurizers/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +445 -0
- rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +768 -0
- rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
- rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
- rasa/nlu/featurizers/featurizer.py +89 -0
- rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +867 -0
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +571 -0
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +271 -0
- rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
- rasa/nlu/model.py +24 -0
- rasa/nlu/run.py +27 -0
- rasa/nlu/selectors/__init__.py +0 -0
- rasa/nlu/selectors/response_selector.py +987 -0
- rasa/nlu/test.py +1940 -0
- rasa/nlu/tokenizers/__init__.py +0 -0
- rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
- rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
- rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
- rasa/nlu/tokenizers/tokenizer.py +239 -0
- rasa/nlu/tokenizers/whitespace_tokenizer.py +95 -0
- rasa/nlu/utils/__init__.py +35 -0
- rasa/nlu/utils/bilou_utils.py +462 -0
- rasa/nlu/utils/hugging_face/__init__.py +0 -0
- rasa/nlu/utils/hugging_face/registry.py +108 -0
- rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
- rasa/nlu/utils/mitie_utils.py +113 -0
- rasa/nlu/utils/pattern_utils.py +168 -0
- rasa/nlu/utils/spacy_utils.py +310 -0
- rasa/plugin.py +90 -0
- rasa/server.py +1588 -0
- rasa/shared/__init__.py +0 -0
- rasa/shared/constants.py +311 -0
- rasa/shared/core/__init__.py +0 -0
- rasa/shared/core/command_payload_reader.py +109 -0
- rasa/shared/core/constants.py +180 -0
- rasa/shared/core/conversation.py +46 -0
- rasa/shared/core/domain.py +2172 -0
- rasa/shared/core/events.py +2559 -0
- rasa/shared/core/flows/__init__.py +7 -0
- rasa/shared/core/flows/flow.py +562 -0
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flow_step.py +146 -0
- rasa/shared/core/flows/flow_step_links.py +319 -0
- rasa/shared/core/flows/flow_step_sequence.py +70 -0
- rasa/shared/core/flows/flows_list.py +258 -0
- rasa/shared/core/flows/flows_yaml_schema.json +303 -0
- rasa/shared/core/flows/nlu_trigger.py +117 -0
- rasa/shared/core/flows/steps/__init__.py +24 -0
- rasa/shared/core/flows/steps/action.py +56 -0
- rasa/shared/core/flows/steps/call.py +64 -0
- rasa/shared/core/flows/steps/collect.py +112 -0
- rasa/shared/core/flows/steps/constants.py +5 -0
- rasa/shared/core/flows/steps/continuation.py +36 -0
- rasa/shared/core/flows/steps/end.py +22 -0
- rasa/shared/core/flows/steps/internal.py +44 -0
- rasa/shared/core/flows/steps/link.py +51 -0
- rasa/shared/core/flows/steps/no_operation.py +48 -0
- rasa/shared/core/flows/steps/set_slots.py +50 -0
- rasa/shared/core/flows/steps/start.py +30 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +735 -0
- rasa/shared/core/flows/yaml_flows_io.py +405 -0
- rasa/shared/core/generator.py +908 -0
- rasa/shared/core/slot_mappings.py +526 -0
- rasa/shared/core/slots.py +654 -0
- rasa/shared/core/trackers.py +1183 -0
- rasa/shared/core/training_data/__init__.py +0 -0
- rasa/shared/core/training_data/loading.py +89 -0
- rasa/shared/core/training_data/story_reader/__init__.py +0 -0
- rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
- rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
- rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
- rasa/shared/core/training_data/story_writer/__init__.py +0 -0
- rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
- rasa/shared/core/training_data/story_writer/yaml_story_writer.py +444 -0
- rasa/shared/core/training_data/structures.py +858 -0
- rasa/shared/core/training_data/visualization.html +146 -0
- rasa/shared/core/training_data/visualization.py +603 -0
- rasa/shared/data.py +249 -0
- rasa/shared/engine/__init__.py +0 -0
- rasa/shared/engine/caching.py +26 -0
- rasa/shared/exceptions.py +167 -0
- rasa/shared/importers/__init__.py +0 -0
- rasa/shared/importers/importer.py +770 -0
- rasa/shared/importers/multi_project.py +215 -0
- rasa/shared/importers/rasa.py +108 -0
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +36 -0
- rasa/shared/nlu/__init__.py +0 -0
- rasa/shared/nlu/constants.py +53 -0
- rasa/shared/nlu/interpreter.py +10 -0
- rasa/shared/nlu/training_data/__init__.py +0 -0
- rasa/shared/nlu/training_data/entities_parser.py +208 -0
- rasa/shared/nlu/training_data/features.py +492 -0
- rasa/shared/nlu/training_data/formats/__init__.py +10 -0
- rasa/shared/nlu/training_data/formats/dialogflow.py +163 -0
- rasa/shared/nlu/training_data/formats/luis.py +87 -0
- rasa/shared/nlu/training_data/formats/rasa.py +135 -0
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +618 -0
- rasa/shared/nlu/training_data/formats/readerwriter.py +244 -0
- rasa/shared/nlu/training_data/formats/wit.py +52 -0
- rasa/shared/nlu/training_data/loading.py +137 -0
- rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
- rasa/shared/nlu/training_data/message.py +490 -0
- rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
- rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
- rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
- rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
- rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
- rasa/shared/nlu/training_data/training_data.py +729 -0
- rasa/shared/nlu/training_data/util.py +223 -0
- rasa/shared/providers/__init__.py +0 -0
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +677 -0
- rasa/shared/providers/_configs/client_config.py +59 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +132 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +236 -0
- rasa/shared/providers/_configs/litellm_router_client_config.py +222 -0
- rasa/shared/providers/_configs/model_group_config.py +173 -0
- rasa/shared/providers/_configs/openai_client_config.py +177 -0
- rasa/shared/providers/_configs/rasa_llm_client_config.py +75 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +178 -0
- rasa/shared/providers/_configs/utils.py +117 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/constants.py +7 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +243 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +335 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +126 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +138 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +265 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +415 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +110 -0
- rasa/shared/providers/llm/litellm_router_llm_client.py +202 -0
- rasa/shared/providers/llm/llm_client.py +78 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +161 -0
- rasa/shared/providers/llm/rasa_llm_client.py +120 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +276 -0
- rasa/shared/providers/mappings.py +94 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +185 -0
- rasa/shared/providers/router/router_client.py +75 -0
- rasa/shared/utils/__init__.py +0 -0
- rasa/shared/utils/cli.py +102 -0
- rasa/shared/utils/common.py +324 -0
- rasa/shared/utils/constants.py +4 -0
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +258 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +499 -0
- rasa/shared/utils/llm.py +764 -0
- rasa/shared/utils/pykwalify_extensions.py +27 -0
- rasa/shared/utils/schemas/__init__.py +0 -0
- rasa/shared/utils/schemas/config.yml +2 -0
- rasa/shared/utils/schemas/domain.yml +145 -0
- rasa/shared/utils/schemas/events.py +214 -0
- rasa/shared/utils/schemas/model_config.yml +36 -0
- rasa/shared/utils/schemas/stories.yml +173 -0
- rasa/shared/utils/yaml.py +1068 -0
- rasa/studio/__init__.py +0 -0
- rasa/studio/auth.py +270 -0
- rasa/studio/config.py +136 -0
- rasa/studio/constants.py +19 -0
- rasa/studio/data_handler.py +368 -0
- rasa/studio/download.py +489 -0
- rasa/studio/results_logger.py +137 -0
- rasa/studio/train.py +134 -0
- rasa/studio/upload.py +563 -0
- rasa/telemetry.py +1876 -0
- rasa/tracing/__init__.py +0 -0
- rasa/tracing/config.py +355 -0
- rasa/tracing/constants.py +62 -0
- rasa/tracing/instrumentation/__init__.py +0 -0
- rasa/tracing/instrumentation/attribute_extractors.py +765 -0
- rasa/tracing/instrumentation/instrumentation.py +1306 -0
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +144 -0
- rasa/tracing/instrumentation/metrics.py +294 -0
- rasa/tracing/metric_instrument_provider.py +205 -0
- rasa/utils/__init__.py +0 -0
- rasa/utils/beta.py +83 -0
- rasa/utils/cli.py +28 -0
- rasa/utils/common.py +639 -0
- rasa/utils/converter.py +53 -0
- rasa/utils/endpoints.py +331 -0
- rasa/utils/io.py +252 -0
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +542 -0
- rasa/utils/log_utils.py +181 -0
- rasa/utils/mapper.py +210 -0
- rasa/utils/ml_utils.py +147 -0
- rasa/utils/plotting.py +362 -0
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/utils/singleton.py +23 -0
- rasa/utils/tensorflow/__init__.py +0 -0
- rasa/utils/tensorflow/callback.py +112 -0
- rasa/utils/tensorflow/constants.py +116 -0
- rasa/utils/tensorflow/crf.py +492 -0
- rasa/utils/tensorflow/data_generator.py +440 -0
- rasa/utils/tensorflow/environment.py +161 -0
- rasa/utils/tensorflow/exceptions.py +5 -0
- rasa/utils/tensorflow/feature_array.py +366 -0
- rasa/utils/tensorflow/layers.py +1565 -0
- rasa/utils/tensorflow/layers_utils.py +113 -0
- rasa/utils/tensorflow/metrics.py +281 -0
- rasa/utils/tensorflow/model_data.py +798 -0
- rasa/utils/tensorflow/model_data_utils.py +499 -0
- rasa/utils/tensorflow/models.py +935 -0
- rasa/utils/tensorflow/rasa_layers.py +1094 -0
- rasa/utils/tensorflow/transformer.py +640 -0
- rasa/utils/tensorflow/types.py +6 -0
- rasa/utils/train_utils.py +572 -0
- rasa/utils/url_tools.py +53 -0
- rasa/utils/yaml.py +54 -0
- rasa/validator.py +1644 -0
- rasa/version.py +3 -0
- rasa_pro-3.12.0.dev1.dist-info/METADATA +199 -0
- rasa_pro-3.12.0.dev1.dist-info/NOTICE +5 -0
- rasa_pro-3.12.0.dev1.dist-info/RECORD +790 -0
- rasa_pro-3.12.0.dev1.dist-info/WHEEL +4 -0
- rasa_pro-3.12.0.dev1.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,935 @@
|
|
|
1
|
+
import time
|
|
2
|
+
import random
|
|
3
|
+
import tensorflow as tf
|
|
4
|
+
import numpy as np
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
from collections import defaultdict
|
|
8
|
+
from typing import List, Text, Dict, Tuple, Union, Optional, Any, TYPE_CHECKING
|
|
9
|
+
|
|
10
|
+
from keras.src.utils import tf_utils
|
|
11
|
+
from keras import Model
|
|
12
|
+
|
|
13
|
+
from rasa.shared.constants import DIAGNOSTIC_DATA
|
|
14
|
+
from rasa.utils.tensorflow.constants import (
|
|
15
|
+
LABEL,
|
|
16
|
+
IDS,
|
|
17
|
+
INTENT_CLASSIFICATION,
|
|
18
|
+
SENTENCE,
|
|
19
|
+
SEQUENCE_LENGTH,
|
|
20
|
+
RANDOM_SEED,
|
|
21
|
+
EMBEDDING_DIMENSION,
|
|
22
|
+
REGULARIZATION_CONSTANT,
|
|
23
|
+
SIMILARITY_TYPE,
|
|
24
|
+
CONNECTION_DENSITY,
|
|
25
|
+
NUM_NEG,
|
|
26
|
+
LOSS_TYPE,
|
|
27
|
+
MAX_POS_SIM,
|
|
28
|
+
MAX_NEG_SIM,
|
|
29
|
+
USE_MAX_NEG_SIM,
|
|
30
|
+
NEGATIVE_MARGIN_SCALE,
|
|
31
|
+
SCALE_LOSS,
|
|
32
|
+
LEARNING_RATE,
|
|
33
|
+
CONSTRAIN_SIMILARITIES,
|
|
34
|
+
MODEL_CONFIDENCE,
|
|
35
|
+
RUN_EAGERLY,
|
|
36
|
+
)
|
|
37
|
+
from rasa.utils.tensorflow.model_data import (
|
|
38
|
+
RasaModelData,
|
|
39
|
+
FeatureSignature,
|
|
40
|
+
FeatureArray,
|
|
41
|
+
)
|
|
42
|
+
import rasa.utils.train_utils
|
|
43
|
+
from rasa.utils.tensorflow import layers
|
|
44
|
+
from rasa.utils.tensorflow import rasa_layers
|
|
45
|
+
from rasa.utils.tensorflow.data_generator import (
|
|
46
|
+
RasaDataGenerator,
|
|
47
|
+
RasaBatchDataGenerator,
|
|
48
|
+
)
|
|
49
|
+
from rasa.shared.nlu.constants import TEXT
|
|
50
|
+
from rasa.shared.exceptions import RasaException
|
|
51
|
+
from rasa.utils.tensorflow.types import BatchData, MaybeNestedBatchData
|
|
52
|
+
|
|
53
|
+
if TYPE_CHECKING:
|
|
54
|
+
from tensorflow.python.types.core import GenericFunction
|
|
55
|
+
|
|
56
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
57
|
+
|
|
58
|
+
# Alias of the imported LABEL constant (presumably the top-level key of label
# data in the model data - confirm against callers).
LABEL_KEY = LABEL
|
|
59
|
+
# Alias of the imported IDS constant (presumably the sub-key holding label ids -
# confirm against callers).
LABEL_SUB_KEY = IDS
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# noinspection PyMethodOverriding
|
|
63
|
+
class RasaModel(Model):
|
|
64
|
+
"""Abstract custom Keras model.
|
|
65
|
+
|
|
66
|
+
This model overwrites the following methods:
|
|
67
|
+
- train_step
|
|
68
|
+
- test_step
|
|
69
|
+
- predict_step
|
|
70
|
+
- save
|
|
71
|
+
- load
|
|
72
|
+
Cannot be used as tf.keras.Model.
|
|
73
|
+
"""
|
|
74
|
+
|
|
75
|
+
_training: Optional[bool]
|
|
76
|
+
|
|
77
|
+
def __init__(self, random_seed: Optional[int] = None, **kwargs: Any) -> None:
    """Set up the loss metric, the random seed and the prediction state.

    Args:
        random_seed: Seed applied to all random number generators. If `None`,
            the current time in whole seconds is used as seed.
        **kwargs: Passed on unchanged to the constructor of the parent class.
    """
    # reset the keras session so that resources of a previously trained model
    # are released before this one is built
    tf.keras.backend.clear_session()
    super().__init__(**kwargs)

    self.total_loss = tf.keras.metrics.Mean(name="t_loss")
    self.metrics_to_log = ["t_loss"]

    # the phase is only set while a train/test/predict step is running
    self._training = None

    self.random_seed = int(time.time()) if random_seed is None else random_seed
    self._set_random_seed()

    # lazily created graph version of `predict_step`, see `_rasa_predict`
    self._tf_predict_step: Optional["GenericFunction"] = None
    self.prepared_for_prediction = False

    self._checkpoint = tf.train.Checkpoint(model=self)
|
|
101
|
+
|
|
102
|
+
def _set_random_seed(self) -> None:
    """Applies `self.random_seed` to every random number generator in use."""
    seed = self.random_seed
    seeders = (
        random.seed,
        np.random.seed,
        tf.random.set_seed,
        tf.experimental.numpy.random.seed,
        tf.keras.utils.set_random_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)
|
|
110
|
+
|
|
111
|
+
def batch_loss(
    self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
) -> tf.Tensor:
    """Computes the loss of a single batch.

    Subclasses have to provide the implementation.

    Args:
        batch_in: The batch to compute the loss for.

    Returns:
        The loss of the batch.

    Raises:
        NotImplementedError: Always, as this base implementation is abstract.
    """
    raise NotImplementedError
|
|
123
|
+
|
|
124
|
+
def prepare_for_predict(self) -> None:
    """Prepares the tf graph for prediction.

    Hook for subclasses: put the tf calculations here that have to run once
    before predicting and store their results on `self` so that
    `batch_predict` can use them, e.g. a pre-computed `self.all_labels_embed`.
    The base implementation does nothing.
    """
    return None
|
|
132
|
+
|
|
133
|
+
def batch_predict(
    self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
) -> Dict[Text, Union[tf.Tensor, Dict[Text, tf.Tensor]]]:
    """Computes the prediction output of a single batch.

    Subclasses have to provide the implementation.

    Args:
        batch_in: The batch to predict.

    Returns:
        The prediction output.

    Raises:
        NotImplementedError: Always, as this base implementation is abstract.
    """
    raise NotImplementedError
|
|
145
|
+
|
|
146
|
+
def train_step(
    self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
) -> Dict[Text, float]:
    """Runs one optimization step on the given batch.

    Args:
        batch_in: The batch input.

    Returns:
        Training metrics.
    """

    def _merge(pred_grad: Any, reg_grad: Any) -> Any:
        # Without both gradients there is nothing to combine; this also drops
        # the regularization gradient of variables that have no prediction
        # gradient.
        if pred_grad is None or reg_grad is None:
            return pred_grad
        # the regularization gradient is only added where the prediction
        # gradient is positive
        return pred_grad + tf.where(pred_grad > 0, reg_grad, tf.zeros_like(reg_grad))

    self._training = True

    # The tape is persistent because two gradients are taken from it: one for
    # the supervision loss and one for the regularization loss.
    with tf.GradientTape(persistent=True) as grad_tape:
        supervision_loss = self.batch_loss(batch_in)
        penalty_loss = tf.math.add_n(self.losses)
        combined_loss = supervision_loss + penalty_loss

    self.total_loss.update_state(combined_loss)

    supervision_grads = grad_tape.gradient(
        supervision_loss, self.trainable_variables
    )
    penalty_grads = grad_tape.gradient(penalty_loss, self.trainable_variables)
    # a persistent tape is not released automatically, so drop it explicitly
    del grad_tape

    merged_grads = [
        _merge(pred_grad, reg_grad)
        for pred_grad, reg_grad in zip(supervision_grads, penalty_grads)
    ]

    self.optimizer.apply_gradients(zip(merged_grads, self.trainable_variables))

    self._training = None

    return self._get_metric_results()
|
|
194
|
+
|
|
195
|
+
def test_step(
    self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
) -> Dict[Text, float]:
    """Evaluates the model on the given batch without updating any weights.

    Used during validation.

    Args:
        batch_in: The batch input.

    Returns:
        Testing metrics.
    """
    self._training = False

    supervision_loss = self.batch_loss(batch_in)
    penalty_loss = tf.math.add_n(self.losses)
    self.total_loss.update_state(supervision_loss + penalty_loss)

    self._training = None

    return self._get_metric_results()
|
|
218
|
+
|
|
219
|
+
def predict_step(
    self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
) -> Dict[Text, tf.Tensor]:
    """Runs the prediction for the given batch.

    Args:
        batch_in: The batch to predict.

    Returns:
        Prediction output.
    """
    self._training = False

    # A model that was not loaded (e.g. one used right after training) has not
    # been prepared yet, so do it once on the first prediction.
    needs_preparation = not self.prepared_for_prediction
    if needs_preparation:
        self.prepare_for_predict()
        self.prepared_for_prediction = True

    prediction = self.batch_predict(batch_in)
    return prediction
|
|
239
|
+
|
|
240
|
+
@staticmethod
def _dynamic_signature(
    batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]],
) -> List[List[tf.TensorSpec]]:
    """Builds the input signature for the traced prediction function.

    For tensors of rank > 1 only the last dimension is kept static, all other
    dimensions are left unspecified. Tensors of lower rank get a single
    unspecified dimension.

    Args:
        batch_in: The batch the signature is derived from.

    Returns:
        The tensor specs of the batch, wrapped into a list because the batch
        itself is passed to the traced function as one list argument.
    """

    def _spec_of(tensor: Any) -> tf.TensorSpec:
        rank = len(tensor.shape)
        dims: List[Union[None, int]]
        if rank > 1:
            dims = [None] * (rank - 1) + [tensor.shape[-1]]
        else:
            dims = [None]
        return tf.TensorSpec(dims, tensor.dtype)

    return [[_spec_of(tensor) for tensor in batch_in]]
|
|
255
|
+
|
|
256
|
+
def _rasa_predict(
    self, batch_in: Tuple[np.ndarray, ...]
) -> Dict[Text, Union[np.ndarray, Dict[Text, Any]]]:
    """Predicts a batch, tracing `predict_step` into a tf graph on first use.

    When the model runs eagerly, `predict_step` is called directly instead.

    Args:
        batch_in: Prepared batch ready for input to `predict_step` method of model.

    Returns:
        Prediction output, including diagnostic data.
    """
    self._training = False
    if not self.prepared_for_prediction:
        # A model that was not loaded (e.g. one used right after training) has
        # not been prepared yet, so do it once on the first prediction.
        self.prepare_for_predict()
        self.prepared_for_prediction = True

    if self._run_eagerly:
        raw_outputs = self.predict_step(batch_in)
    else:
        if self._tf_predict_step is None:
            self._tf_predict_step = tf.function(
                self.predict_step, input_signature=self._dynamic_signature(batch_in)
            )
        raw_outputs = self._tf_predict_step(list(batch_in))

    # Once we take advantage of TF's distributed training, this is where
    # scheduled functions will be forced to execute and return actual values.
    outputs = tf_utils.sync_to_numpy_or_python_type(raw_outputs)
    if DIAGNOSTIC_DATA in outputs:
        outputs[DIAGNOSTIC_DATA] = self._empty_lists_to_none_in_dict(
            outputs[DIAGNOSTIC_DATA]
        )
    return outputs
|
|
299
|
+
|
|
300
|
+
def run_inference(
    self,
    model_data: RasaModelData,
    batch_size: Union[int, List[int]] = 1,
    output_keys_expected: Optional[List[Text]] = None,
) -> Dict[Text, Union[np.ndarray, Dict[Text, Any]]]:
    """Implements bulk inferencing through the model.

    The data is split into batches, every batch is predicted separately and
    the batch outputs are merged into one output.

    Args:
        model_data: Input data to be fed to the model.
        batch_size: Size of batches that the generator should create.
        output_keys_expected: Keys which are expected in the output.
            The output should be filtered to have only these keys before
            merging it with the output across all batches.

    Returns:
        Model outputs corresponding to the inputs fed.
    """
    outputs: Dict[Text, Union[np.ndarray, Dict[Text, Any]]] = {}
    data_generator, _ = rasa.utils.train_utils.create_data_generators(
        model_data=model_data, batch_sizes=batch_size, epochs=1, shuffle=False
    )
    # Iterate with a plain `for` loop so that only exhaustion of the generator
    # ends the loop. A `StopIteration` raised while predicting or merging is
    # an error and must not be mistaken for "no more batches".
    for batch in data_generator:
        # Every item of the generator is a tuple of 2 elements - input and
        # output. We only need input, since output is always None and not
        # consumed by our TF graphs.
        batch_in = batch[0]
        batch_out: Dict[Text, Union[np.ndarray, Dict[Text, Any]]] = (
            self._rasa_predict(batch_in)
        )
        if output_keys_expected:
            batch_out = {
                key: output
                for key, output in batch_out.items()
                if key in output_keys_expected
            }
        outputs = self._merge_batch_outputs(outputs, batch_out)
    return outputs
|
|
343
|
+
|
|
344
|
+
@staticmethod
|
|
345
|
+
def _merge_batch_outputs(
|
|
346
|
+
all_outputs: Dict[Text, Union[np.ndarray, Dict[Text, Any]]],
|
|
347
|
+
batch_output: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
|
|
348
|
+
) -> Dict[Text, Union[np.ndarray, Dict[Text, Any]]]:
|
|
349
|
+
"""Merges a batch's output into the output for all batches.
|
|
350
|
+
|
|
351
|
+
Function assumes that the schema of batch output remains the same,
|
|
352
|
+
i.e. keys and their value types do not change from one batch's
|
|
353
|
+
output to another.
|
|
354
|
+
|
|
355
|
+
Args:
|
|
356
|
+
all_outputs: Existing output for all previous batches.
|
|
357
|
+
batch_output: Output for a batch.
|
|
358
|
+
|
|
359
|
+
Returns:
|
|
360
|
+
Merged output with the output for current batch stacked
|
|
361
|
+
below the output for all previous batches.
|
|
362
|
+
"""
|
|
363
|
+
if not all_outputs:
|
|
364
|
+
return batch_output
|
|
365
|
+
for key, val in batch_output.items():
|
|
366
|
+
if isinstance(val, np.ndarray):
|
|
367
|
+
all_outputs[key] = np.concatenate(
|
|
368
|
+
[all_outputs[key], batch_output[key]], axis=0
|
|
369
|
+
)
|
|
370
|
+
|
|
371
|
+
elif isinstance(val, dict):
|
|
372
|
+
# recurse and merge the inner dict first
|
|
373
|
+
all_outputs[key] = RasaModel._merge_batch_outputs(all_outputs[key], val)
|
|
374
|
+
|
|
375
|
+
return all_outputs
|
|
376
|
+
|
|
377
|
+
@staticmethod
|
|
378
|
+
def _empty_lists_to_none_in_dict(input_dict: Dict[Text, Any]) -> Dict[Text, Any]:
|
|
379
|
+
"""Recursively replaces empty list or np array with None in a dictionary."""
|
|
380
|
+
|
|
381
|
+
def _recurse(
|
|
382
|
+
x: Union[Dict[Text, Any], List[Any], np.ndarray],
|
|
383
|
+
) -> Optional[Union[Dict[Text, Any], List[Any], np.ndarray]]:
|
|
384
|
+
if isinstance(x, dict):
|
|
385
|
+
return {k: _recurse(v) for k, v in x.items()}
|
|
386
|
+
elif (isinstance(x, list) or isinstance(x, np.ndarray)) and np.size(x) == 0:
|
|
387
|
+
return None
|
|
388
|
+
return x
|
|
389
|
+
|
|
390
|
+
return {k: _recurse(v) for k, v in input_dict.items()}
|
|
391
|
+
|
|
392
|
+
def _get_metric_results(self, prefix: Optional[Text] = "") -> Dict[Text, float]:
    """Collects the current results of all metrics listed in `metrics_to_log`.

    Args:
        prefix: Text put in front of every metric name in the result.

    Returns:
        Mapping from the prefixed metric name to the metric result.
    """
    results = {}
    for metric in self.metrics:
        if metric.name not in self.metrics_to_log:
            continue
        results[f"{prefix}{metric.name}"] = metric.result()
    return results
|
|
398
|
+
|
|
399
|
+
def save(self, model_file_name: Text, overwrite: bool = True) -> None:
    """Writes the weights of the model to the given file.

    Delegates to `save_weights` using the "tf" save format.

    Args:
        model_file_name: The file name to save the model to.
        overwrite: If 'True' an already existing model with the same file name will
            be overwritten.
    """
    weights_format = "tf"
    self.save_weights(
        model_file_name, overwrite=overwrite, save_format=weights_format
    )
|
|
408
|
+
|
|
409
|
+
@classmethod
def load(
    cls,
    model_file_name: Text,
    model_data_example: RasaModelData,
    predict_data_example: Optional[RasaModelData] = None,
    finetune_mode: bool = False,
    *args: Any,
    **kwargs: Any,
) -> "RasaModel":
    """Creates a model and restores its weights from the given file.

    Args:
        model_file_name: Path to file containing model weights.
        model_data_example: Example data point to construct the model architecture.
        predict_data_example: Example data point to speed up prediction during
            inference.
        finetune_mode: Indicates whether to load the model for further finetuning.
        *args: Any other non key-worded arguments, passed to the constructor.
        **kwargs: Any other key-worded arguments, passed to the constructor.
            The learning rate and the eager-execution flag are read from
            `kwargs["config"]` if present.

    Returns:
        Loaded model with weights appropriately set.
    """
    logger.debug(
        f"Loading the model from {model_file_name} "
        f"with finetune_mode={finetune_mode}..."
    )
    # create empty model
    model = cls(*args, **kwargs)
    model_config = kwargs.get("config", {})
    optimizer = tf.keras.optimizers.Adam(model_config.get(LEARNING_RATE, 0.001))
    model.compile(optimizer=optimizer, run_eagerly=model_config.get(RUN_EAGERLY))

    # the weights only get their final size once the model has been fitted,
    # so train on 1 example before restoring the trained weights
    example_batches = RasaBatchDataGenerator(model_data_example, batch_size=1)
    model.fit(example_batches, verbose=False)
    model.load_weights(model_file_name)

    # The first prediction always takes a bit longer because the tf function
    # has to be traced; predicting one example here moves that cost from
    # inference time to load time.
    warm_up = not finetune_mode and predict_data_example
    if warm_up:
        model.run_inference(predict_data_example)

    logger.debug("Finished loading the model.")
    return model
|
|
458
|
+
|
|
459
|
+
@staticmethod
def batch_to_model_data_format(
    batch: MaybeNestedBatchData,
    data_signature: Dict[Text, Dict[Text, List[FeatureSignature]]],
) -> Dict[Text, Dict[Text, List[tf.Tensor]]]:
    """Convert input batch tensors into batch data format.

    Batch contains any number of batch data. The order is equal to the
    key-value pairs in session data. As sparse data were converted into (indices,
    data, shape) before, this method converts them into sparse tensors. Dense
    data is kept.

    Args:
        batch: The flat batch, or a tuple whose first element is the flat batch.
        data_signature: Signature describing, per key and sub-key, the features
            contained in the batch in the order they appear in it.

    Returns:
        Nested dictionary (key -> sub-key -> list of tensors) with one tensor
        per feature signature. Missing keys yield empty defaults because the
        result is built from `defaultdict`s.
    """
    # during training batch is a tuple of input and target data
    # as our target data is inside the input data, we are just interested in the
    # input data
    unpacked_batch = batch[0] if isinstance(batch[0], tuple) else batch

    batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]] = defaultdict(
        lambda: defaultdict(list)
    )

    # position of the next unread element of the flat batch
    idx = 0
    for key, values in data_signature.items():
        for sub_key, signature in values.items():
            for is_sparse, feature_dimension, number_of_dimensions in signature:
                # we converted all 4D features to 3D features before
                effective_dimensions = (
                    3 if number_of_dimensions == 4 else number_of_dimensions
                )
                convert = (
                    RasaModel._convert_sparse_features
                    if is_sparse
                    else RasaModel._convert_dense_features
                )
                # every converter returns the position following the elements
                # it consumed
                tensor, idx = convert(
                    unpacked_batch, feature_dimension, idx, effective_dimensions
                )
                batch_data[key][sub_key].append(tensor)

    return batch_data
|
|
499
|
+
|
|
500
|
+
@staticmethod
def _convert_dense_features(
    batch: BatchData,
    feature_dimension: int,
    idx: int,
    number_of_dimensions: int,
) -> Tuple[tf.Tensor, int]:
    """Reads one dense feature from the batch.

    Args:
        batch: The flat batch.
        feature_dimension: Static size of the last dimension of the feature.
        idx: Position of the feature in the batch.
        number_of_dimensions: Number of dimensions of the feature.

    Returns:
        The feature as tensor and the position of the next batch element.
    """
    feature = batch[idx]
    next_idx = idx + 1

    if not isinstance(feature, tf.Tensor):
        # convert to Tensor
        return (
            tf.constant(feature, dtype=tf.float32, shape=feature.shape),
            next_idx,
        )

    last_dimension_unknown = feature.shape is None or feature.shape[-1] is None
    if number_of_dimensions > 1 and last_dimension_unknown:
        # explicitly substitute last dimension in shape with known
        # static value
        static_shape: List[Optional[int]] = [None] * (number_of_dimensions - 1)
        static_shape.append(feature_dimension)
        feature.set_shape(static_shape)

    return feature, next_idx
|
|
525
|
+
|
|
526
|
+
@staticmethod
def _convert_sparse_features(
    batch: BatchData,
    feature_dimension: int,
    idx: int,
    number_of_dimensions: int,
) -> Tuple[tf.SparseTensor, int]:
    """Builds a sparse tensor from three consecutive batch entries.

    The entries at `idx` and `idx + 1` are handed to `tf.SparseTensor` as its
    first two arguments; the entry at `idx + 2` supplies all but the last
    dimension of the shape.

    Args:
        batch: the batch to read the feature from.
        feature_dimension: static size used as the last dimension of the shape.
        idx: position of the first of the three entries inside `batch`.
        number_of_dimensions: expected number of dimensions of the feature.

    Returns:
        The sparse tensor and the index of the next batch entry.
    """
    # the last dimension is replaced by the known static value
    leading_dims = [
        batch[idx + 2][axis] for axis in range(number_of_dimensions - 1)
    ]
    dense_shape = [*leading_dims, feature_dimension]

    sparse_tensor = tf.SparseTensor(batch[idx], batch[idx + 1], dense_shape)
    return sparse_tensor, idx + 3
|
|
539
|
+
|
|
540
|
+
def call(
    self,
    inputs: Union[tf.Tensor, List[tf.Tensor]],
    training: Optional[tf.Tensor] = None,
    mask: Optional[tf.Tensor] = None,
) -> Union[tf.Tensor, List[tf.Tensor]]:
    """Calls the model on new inputs.

    The method is intentionally a no-op: it only exists because the super
    class raises a `NotImplementedError('When subclassing the `Model` class,
    you should implement a `call` method.')` when it is missing.

    Arguments:
        inputs: A tensor or list of tensors.
        training: Boolean or boolean scalar tensor, indicating whether to run
            the `Network` in training mode or inference mode.
        mask: A mask or list of masks. A mask can be
            either a tensor or None (no mask).

    Returns:
        A tensor if there is a single output, or
        a list of tensors if there are more than one outputs.
    """
    return None
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
# noinspection PyMethodOverriding
|
|
566
|
+
class TransformerRasaModel(RasaModel):
|
|
567
|
+
def __init__(
    self,
    name: Text,
    config: Dict[Text, Any],
    data_signature: Dict[Text, Dict[Text, List[FeatureSignature]]],
    label_data: RasaModelData,
) -> None:
    """Stores config and signatures and converts the label data.

    Args:
        name: name passed on to the super class.
        config: model configuration; `RANDOM_SEED` is read from it.
        data_signature: signature of the training data.
        label_data: label data; its signature and its prepared batch are
            stored on the model.
    """
    super().__init__(name=name, random_seed=config[RANDOM_SEED])

    # `_check_data` runs only after these attributes have been assigned
    self.config = config
    self.data_signature = data_signature
    self.label_signature = label_data.get_signature()
    self._check_data()

    self.tf_label_data = self.batch_to_model_data_format(
        RasaDataGenerator.prepare_batch(label_data.data), self.label_signature
    )

    # container for the tf layers
    self._tf_layers: Dict[Text, tf.keras.layers.Layer] = {}
|
|
588
|
+
|
|
589
|
+
def adjust_for_incremental_training(
    self,
    data_example: Dict[Text, Dict[Text, List[FeatureArray]]],
    new_sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
    old_sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
) -> None:
    """Prepares the model for incremental training.

    A decrease of any sparse feature size is rejected with an exception.
    When no size decreased and at least one increased, the `DenseForSparse`
    layers are updated, the model is compiled and fitted on a sample to
    activate the adjusted layer(s), and the data signatures are updated.

    Sparse feature sizes could look like this:
    {TEXT: {FEATURE_TYPE_SEQUENCE: [4, 24, 128], FEATURE_TYPE_SENTENCE: [4, 128]}}

    Args:
        data_example: a data example that is stored with the ML component.
        new_sparse_feature_sizes: sizes of current sparse features.
        old_sparse_feature_sizes: sizes of sparse features the model was
            previously trained on.
    """
    self._check_if_sparse_feature_sizes_decreased(
        new_sparse_feature_sizes=new_sparse_feature_sizes,
        old_sparse_feature_sizes=old_sparse_feature_sizes,
    )

    sizes_grew = self._sparse_feature_sizes_have_increased(
        new_sparse_feature_sizes=new_sparse_feature_sizes,
        old_sparse_feature_sizes=old_sparse_feature_sizes,
    )
    if not sizes_grew:
        # nothing changed, the model can be used as it is
        return

    self._update_dense_for_sparse_layers(
        new_sparse_feature_sizes, old_sparse_feature_sizes
    )
    self._compile_and_fit(data_example)
|
|
624
|
+
|
|
625
|
+
@staticmethod
|
|
626
|
+
def _check_if_sparse_feature_sizes_decreased(
|
|
627
|
+
new_sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
|
|
628
|
+
old_sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
|
|
629
|
+
) -> None:
|
|
630
|
+
"""Checks if the sizes of sparse features have decreased during fine-tuning.
|
|
631
|
+
|
|
632
|
+
Sparse feature sizes might decrease after changing the training data.
|
|
633
|
+
This can happen for example with `LexicalSyntacticFeaturizer`.
|
|
634
|
+
We don't support this behaviour and we raise an exception if this happens.
|
|
635
|
+
|
|
636
|
+
Args:
|
|
637
|
+
new_sparse_feature_sizes: sizes of current sparse features.
|
|
638
|
+
old_sparse_feature_sizes: sizes of sparse features the model was
|
|
639
|
+
previously trained on.
|
|
640
|
+
|
|
641
|
+
Raises:
|
|
642
|
+
RasaException: When any of the sparse feature sizes decrease
|
|
643
|
+
from the last time training was run.
|
|
644
|
+
"""
|
|
645
|
+
for attribute, new_feature_sizes in new_sparse_feature_sizes.items():
|
|
646
|
+
old_feature_sizes = old_sparse_feature_sizes[attribute]
|
|
647
|
+
for feature_type, new_sizes in new_feature_sizes.items():
|
|
648
|
+
old_sizes = old_feature_sizes[feature_type]
|
|
649
|
+
for new_size, old_size in zip(new_sizes, old_sizes):
|
|
650
|
+
if new_size < old_size:
|
|
651
|
+
raise RasaException(
|
|
652
|
+
"Sparse feature sizes have decreased from the last time "
|
|
653
|
+
"training was run. The training data was changed in a way "
|
|
654
|
+
"that resulted in some features not being present in the "
|
|
655
|
+
"data anymore. This can happen if you had "
|
|
656
|
+
"`LexicalSyntacticFeaturizer` in your pipeline. "
|
|
657
|
+
"The pipeline cannot support incremental training "
|
|
658
|
+
"in this setting. We recommend you to retrain "
|
|
659
|
+
"the model from scratch."
|
|
660
|
+
)
|
|
661
|
+
|
|
662
|
+
@staticmethod
|
|
663
|
+
def _sparse_feature_sizes_have_increased(
|
|
664
|
+
new_sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
|
|
665
|
+
old_sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
|
|
666
|
+
) -> bool:
|
|
667
|
+
"""Checks if the sizes of sparse features have increased during fine-tuning.
|
|
668
|
+
|
|
669
|
+
If there's any sparse feature size that has increased after changing the
|
|
670
|
+
training data, we need to look for the corresponding `DenseForSparse` layer
|
|
671
|
+
and adjust it. On the other hand, if none of them have increased, we don't
|
|
672
|
+
need to change anything. This function helps us with making the decision.
|
|
673
|
+
|
|
674
|
+
Note that the function assumes that none of the sparse feature sizes
|
|
675
|
+
have decreased. In other words, it should get valid arguments in order
|
|
676
|
+
to function well.
|
|
677
|
+
|
|
678
|
+
Args:
|
|
679
|
+
new_sparse_feature_sizes: sizes of current sparse features.
|
|
680
|
+
old_sparse_feature_sizes: sizes of sparse features the model was
|
|
681
|
+
previously trained on.
|
|
682
|
+
|
|
683
|
+
Returns:
|
|
684
|
+
`True` if any of the sparse feature sizes has increased, `False` otherwise.
|
|
685
|
+
"""
|
|
686
|
+
for attribute, new_feature_sizes in new_sparse_feature_sizes.items():
|
|
687
|
+
old_feature_sizes = old_sparse_feature_sizes[attribute]
|
|
688
|
+
for feature_type, new_sizes in new_feature_sizes.items():
|
|
689
|
+
old_sizes = old_feature_sizes[feature_type]
|
|
690
|
+
if sum(new_sizes) > sum(old_sizes):
|
|
691
|
+
return True
|
|
692
|
+
return False
|
|
693
|
+
|
|
694
|
+
def _update_dense_for_sparse_layers(
    self,
    new_sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
    old_sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
) -> None:
    """Adjusts the `DenseForSparse` layers to the current sparse feature sizes.

    This must happen before fine-tuning starts so that a change in the size
    of sparse features caused by newly added data is accounted for.

    Args:
        new_sparse_feature_sizes: sizes of current sparse features.
        old_sparse_feature_sizes: sizes of sparse features the model was
            previously trained on.
    """
    for layer in self._tf_layers.values():
        # only `RasaCustomLayer` instances know how to adjust their sparse
        # layers for incremental training; everything else is left alone
        if not isinstance(layer, rasa_layers.RasaCustomLayer):
            continue
        layer.adjust_sparse_layers_for_incremental_training(
            new_sparse_feature_sizes,
            old_sparse_feature_sizes,
            self.config[REGULARIZATION_CONSTANT],
        )
|
|
720
|
+
|
|
721
|
+
def _compile_and_fit(
    self, data_example: Dict[Text, Dict[Text, List[FeatureArray]]]
) -> None:
    """Compiles the modified model and fits it once on sample data.

    Args:
        data_example: a data example that is stored with the ML component.
    """
    optimizer = tf.keras.optimizers.Adam(self.config[LEARNING_RATE])
    self.compile(optimizer=optimizer, run_eagerly=self.config[RUN_EAGERLY])

    # label keys are only set when intent classification is enabled
    if self.config[INTENT_CLASSIFICATION]:
        label_key, label_sub_key = LABEL_KEY, LABEL_SUB_KEY
    else:
        label_key = label_sub_key = None

    model_data = RasaModelData(
        label_key=label_key, label_sub_key=label_sub_key, data=data_example
    )
    self._update_data_signatures(model_data)

    self.fit(RasaBatchDataGenerator(model_data, batch_size=1), verbose=False)
|
|
742
|
+
|
|
743
|
+
def _update_data_signatures(self, model_data: RasaModelData) -> None:
    """Refreshes the stored data signatures from `model_data`.

    `predict_data_signature` keeps only the entries of the data signature
    whose feature name contains `TEXT`.

    Args:
        model_data: the model data to take the signature from.
    """
    self.data_signature = model_data.get_signature()

    predict_signature = {}
    for feature_name, features in self.data_signature.items():
        if TEXT in feature_name:
            predict_signature[feature_name] = features
    self.predict_data_signature = predict_signature
|
|
750
|
+
|
|
751
|
+
def _check_data(self) -> None:
|
|
752
|
+
raise NotImplementedError
|
|
753
|
+
|
|
754
|
+
def _prepare_layers(self) -> None:
|
|
755
|
+
raise NotImplementedError
|
|
756
|
+
|
|
757
|
+
def _prepare_label_classification_layers(self, predictor_attribute: Text) -> None:
    """Sets up the layers and the loss for the final label prediction step.

    Args:
        predictor_attribute: attribute that gets its own embedding layer next
            to the one created for `LABEL`.
    """
    # one embedding layer for the predictor, one for the labels
    for attribute in (predictor_attribute, LABEL):
        self._prepare_embed_layers(attribute)

    self._prepare_dot_product_loss(LABEL, self.config[SCALE_LOSS])
|
|
762
|
+
|
|
763
|
+
def _prepare_embed_layers(self, name: Text, prefix: Text = "embed") -> None:
    """Registers an `Embed` layer under the key `"{prefix}.{name}"`.

    Args:
        name: name passed to the layer and used in the registry key.
        prefix: first part of the registry key.
    """
    embed_layer = layers.Embed(
        self.config[EMBEDDING_DIMENSION], self.config[REGULARIZATION_CONSTANT], name
    )
    layer_key = f"{prefix}.{name}"
    self._tf_layers[layer_key] = embed_layer
|
|
767
|
+
|
|
768
|
+
def _prepare_ffnn_layer(
    self,
    name: Text,
    layer_sizes: List[int],
    drop_rate: float,
    prefix: Text = "ffnn",
) -> None:
    """Registers an `Ffnn` layer under the key `"{prefix}.{name}"`.

    Args:
        name: used in the registry key and as the layer name suffix.
        layer_sizes: sizes handed to the `Ffnn` layer.
        drop_rate: drop rate handed to the `Ffnn` layer.
        prefix: first part of the registry key.
    """
    ffnn_layer = layers.Ffnn(
        layer_sizes,
        drop_rate,
        self.config[REGULARIZATION_CONSTANT],
        self.config[CONNECTION_DENSITY],
        layer_name_suffix=name,
    )
    layer_key = f"{prefix}.{name}"
    self._tf_layers[layer_key] = ffnn_layer
|
|
782
|
+
|
|
783
|
+
def _prepare_dot_product_loss(
    self, name: Text, scale_loss: bool, prefix: Text = "loss"
) -> None:
    """Registers a dot-product loss layer under the key `"{prefix}.{name}"`.

    The layer is created through `dot_product_loss_layer` and configured
    from the model config.

    Args:
        name: second part of the registry key.
        scale_loss: forwarded to the loss layer as `scale_loss`.
        prefix: first part of the registry key.
    """
    loss_layer_factory = self.dot_product_loss_layer
    loss_layer = loss_layer_factory(
        self.config[NUM_NEG],
        loss_type=self.config[LOSS_TYPE],
        mu_pos=self.config[MAX_POS_SIM],
        mu_neg=self.config[MAX_NEG_SIM],
        use_max_sim_neg=self.config[USE_MAX_NEG_SIM],
        neg_lambda=self.config[NEGATIVE_MARGIN_SCALE],
        scale_loss=scale_loss,
        similarity_type=self.config[SIMILARITY_TYPE],
        constrain_similarities=self.config[CONSTRAIN_SIMILARITIES],
        model_confidence=self.config[MODEL_CONFIDENCE],
    )
    layer_key = f"{prefix}.{name}"
    self._tf_layers[layer_key] = loss_layer
|
|
798
|
+
|
|
799
|
+
@property
def dot_product_loss_layer(self) -> tf.keras.layers.Layer:
    """Provides the dot-product loss layer to use.

    Returns:
        The loss layer that `_prepare_dot_product_loss` instantiates.
    """
    loss_layer = layers.SingleLabelDotProductLoss
    return loss_layer
|
|
807
|
+
|
|
808
|
+
def _prepare_entity_recognition_layers(self) -> None:
    """Registers the logits, CRF and tag embedding layers per entity tag spec.

    For every spec in `self._entity_tag_specs` three layers are stored under
    `"embed.{name}.logits"`, `"crf.{name}"` and `"embed.{name}.tags"`.
    """
    for spec in self._entity_tag_specs:
        tag_name = spec.tag_name
        tag_count = spec.num_tags

        logits_layer = layers.Embed(
            tag_count, self.config[REGULARIZATION_CONSTANT], f"logits.{tag_name}"
        )
        self._tf_layers[f"embed.{tag_name}.logits"] = logits_layer

        crf_layer = layers.CRF(
            tag_count, self.config[REGULARIZATION_CONSTANT], self.config[SCALE_LOSS]
        )
        self._tf_layers[f"crf.{tag_name}"] = crf_layer

        tags_layer = layers.Embed(
            self.config[EMBEDDING_DIMENSION],
            self.config[REGULARIZATION_CONSTANT],
            f"tags.{tag_name}",
        )
        self._tf_layers[f"embed.{tag_name}.tags"] = tags_layer
|
|
823
|
+
|
|
824
|
+
@staticmethod
def _last_token(x: tf.Tensor, sequence_lengths: tf.Tensor) -> tf.Tensor:
    """Gathers from `x` the entry at index `sequence_lengths - 1` per batch row.

    The index is clamped at 0, so a sequence length of 0 selects index 0.

    Args:
        x: tensor to gather from.
        sequence_lengths: one sequence length per batch row.

    Returns:
        The gathered entries.
    """
    final_positions = tf.maximum(0, sequence_lengths - 1)
    row_ids = tf.range(tf.shape(final_positions)[0])

    # pair every row id with the position to pick in that row
    gather_indices = tf.stack([row_ids, final_positions], axis=1)
    return tf.gather_nd(x, gather_indices)
|
|
831
|
+
|
|
832
|
+
def _get_mask_for(
    self,
    tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
    key: Text,
    sub_key: Text,
) -> Optional[tf.Tensor]:
    """Computes the mask for the lengths stored under `key` / `sub_key`.

    Args:
        tf_batch_data: the batch data.
        key: first-level key of the batch data.
        sub_key: second-level key of the batch data.

    Returns:
        The result of `rasa_layers.compute_mask` for the first tensor found
        under the two keys, or `None` if either key is missing.
    """
    if key in tf_batch_data and sub_key in tf_batch_data[key]:
        lengths = tf.cast(tf_batch_data[key][sub_key][0], dtype=tf.int32)
        return rasa_layers.compute_mask(lengths)

    return None
|
|
843
|
+
|
|
844
|
+
def _get_sequence_feature_lengths(
    self, tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]], key: Text
) -> tf.Tensor:
    """Fetches the sequence lengths of real tokens per input example.

    The number of real tokens for an example is the same as the length of the
    sequence of the sequence-level (token-level) features for that input example.

    Args:
        tf_batch_data: the batch data.
        key: attribute whose sequence lengths are requested.

    Returns:
        The first tensor stored under `SEQUENCE_LENGTH` for `key`, cast to
        `int32`, or a zero vector of batch size if no lengths are stored.
    """
    # `.get` keeps a missing attribute from raising a `KeyError` on plain
    # dicts and from being inserted as a side effect into a `defaultdict`
    attribute_data = tf_batch_data.get(key, {})

    if SEQUENCE_LENGTH in attribute_data:
        return tf.cast(attribute_data[SEQUENCE_LENGTH][0], dtype=tf.int32)

    batch_dim = self._get_batch_dim(attribute_data)
    return tf.zeros([batch_dim], dtype=tf.int32)
|
|
857
|
+
|
|
858
|
+
def _get_sentence_feature_lengths(
    self, tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]], key: Text
) -> tf.Tensor:
    """Fetches the sequence lengths of sentence-level features per input example.

    This is needed because we treat sentence-level features as token-level features
    with 1 token per input example. Hence, the sequence lengths returned by this
    function are all 1s if sentence-level features are present, and 0s otherwise.

    Args:
        tf_batch_data: the batch data.
        key: attribute whose sentence feature lengths are requested.

    Returns:
        A vector of batch size filled with 1s if `SENTENCE` features are
        stored for `key`, filled with 0s otherwise.
    """
    # look the attribute up once and tolerate its absence: indexing before
    # the membership check made that check dead code, raised a `KeyError`
    # on plain dicts and inserted the key into a `defaultdict`
    attribute_data = tf_batch_data.get(key, {})
    batch_dim = self._get_batch_dim(attribute_data)

    if SENTENCE in attribute_data:
        return tf.ones([batch_dim], dtype=tf.int32)

    return tf.zeros([batch_dim], dtype=tf.int32)
|
|
873
|
+
|
|
874
|
+
@staticmethod
|
|
875
|
+
def _get_batch_dim(attribute_data: Dict[Text, List[tf.Tensor]]) -> int:
|
|
876
|
+
# All the values in the attribute_data dict should be lists of tensors, each
|
|
877
|
+
# tensor of the shape (batch_dim, ...). So we take the first non-empty list we
|
|
878
|
+
# encounter and infer the batch size from its first tensor.
|
|
879
|
+
for key, data in attribute_data.items():
|
|
880
|
+
if data:
|
|
881
|
+
return tf.shape(data[0])[0]
|
|
882
|
+
|
|
883
|
+
return 0
|
|
884
|
+
|
|
885
|
+
def _calculate_entity_loss(
    self,
    inputs: tf.Tensor,
    tag_ids: tf.Tensor,
    mask: tf.Tensor,
    sequence_lengths: tf.Tensor,
    tag_name: Text,
    entity_tags: Optional[tf.Tensor] = None,
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Computes the CRF loss, the F1 score and the logits for one tag.

    Args:
        inputs: input to the logits layer.
        tag_ids: gold tag ids; index 0 of the last axis is used.
        mask: mask handed to the F1 computation.
        sequence_lengths: sequence lengths handed to the CRF layer.
        tag_name: selects the layers registered for this tag.
        entity_tags: optional tags; when given, their embedding is
            concatenated to `inputs` along the last axis.

    Returns:
        The loss, the F1 score and the logits.
    """
    gold_tag_ids = tf.cast(tag_ids[:, :, 0], tf.int32)

    logits_input = inputs
    if entity_tags is not None:
        tag_embedding = self._tf_layers[f"embed.{tag_name}.tags"](entity_tags)
        logits_input = tf.concat([inputs, tag_embedding], axis=-1)

    logits = self._tf_layers[f"embed.{tag_name}.logits"](logits_input)

    crf_layer = self._tf_layers[f"crf.{tag_name}"]
    # the layer has to be called before `loss` so that its weights get built
    predicted_ids, _ = crf_layer(logits, sequence_lengths)
    loss = crf_layer.loss(logits, gold_tag_ids, sequence_lengths)
    f1 = crf_layer.f1_score(gold_tag_ids, predicted_ids, mask)

    return loss, f1, logits
|
|
910
|
+
|
|
911
|
+
def batch_loss(
    self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
) -> tf.Tensor:
    """Calculates the loss for the given batch.

    No implementation is provided here; the method always raises.

    Args:
        batch_in: The batch.

    Returns:
        The loss of the given batch.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
|
|
923
|
+
|
|
924
|
+
def batch_predict(
    self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
) -> Dict[Text, Union[tf.Tensor, Dict[Text, tf.Tensor]]]:
    """Predicts the output of the given batch.

    No implementation is provided here; the method always raises.

    Args:
        batch_in: The batch.

    Returns:
        The output to predict.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
|