rasa-pro 3.12.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +41 -0
- rasa/__init__.py +9 -0
- rasa/__main__.py +177 -0
- rasa/anonymization/__init__.py +2 -0
- rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
- rasa/anonymization/anonymization_pipeline.py +286 -0
- rasa/anonymization/anonymization_rule_executor.py +260 -0
- rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
- rasa/anonymization/schemas/config.yml +47 -0
- rasa/anonymization/utils.py +118 -0
- rasa/api.py +160 -0
- rasa/cli/__init__.py +5 -0
- rasa/cli/arguments/__init__.py +0 -0
- rasa/cli/arguments/data.py +106 -0
- rasa/cli/arguments/default_arguments.py +207 -0
- rasa/cli/arguments/evaluate.py +65 -0
- rasa/cli/arguments/export.py +51 -0
- rasa/cli/arguments/interactive.py +74 -0
- rasa/cli/arguments/run.py +219 -0
- rasa/cli/arguments/shell.py +17 -0
- rasa/cli/arguments/test.py +211 -0
- rasa/cli/arguments/train.py +279 -0
- rasa/cli/arguments/visualize.py +34 -0
- rasa/cli/arguments/x.py +30 -0
- rasa/cli/data.py +354 -0
- rasa/cli/dialogue_understanding_test.py +251 -0
- rasa/cli/e2e_test.py +259 -0
- rasa/cli/evaluate.py +222 -0
- rasa/cli/export.py +250 -0
- rasa/cli/inspect.py +75 -0
- rasa/cli/interactive.py +166 -0
- rasa/cli/license.py +65 -0
- rasa/cli/llm_fine_tuning.py +403 -0
- rasa/cli/markers.py +78 -0
- rasa/cli/project_templates/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/action_template.py +27 -0
- rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
- rasa/cli/project_templates/calm/actions/db.py +57 -0
- rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
- rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
- rasa/cli/project_templates/calm/config.yml +10 -0
- rasa/cli/project_templates/calm/credentials.yml +33 -0
- rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
- rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
- rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
- rasa/cli/project_templates/calm/db/contacts.json +10 -0
- rasa/cli/project_templates/calm/domain/add_contact.yml +39 -0
- rasa/cli/project_templates/calm/domain/list_contacts.yml +17 -0
- rasa/cli/project_templates/calm/domain/remove_contact.yml +38 -0
- rasa/cli/project_templates/calm/domain/shared.yml +10 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
- rasa/cli/project_templates/calm/endpoints.yml +58 -0
- rasa/cli/project_templates/default/actions/__init__.py +0 -0
- rasa/cli/project_templates/default/actions/actions.py +27 -0
- rasa/cli/project_templates/default/config.yml +44 -0
- rasa/cli/project_templates/default/credentials.yml +33 -0
- rasa/cli/project_templates/default/data/nlu.yml +91 -0
- rasa/cli/project_templates/default/data/rules.yml +13 -0
- rasa/cli/project_templates/default/data/stories.yml +30 -0
- rasa/cli/project_templates/default/domain.yml +34 -0
- rasa/cli/project_templates/default/endpoints.yml +42 -0
- rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
- rasa/cli/project_templates/tutorial/actions/__init__.py +0 -0
- rasa/cli/project_templates/tutorial/actions/actions.py +22 -0
- rasa/cli/project_templates/tutorial/config.yml +12 -0
- rasa/cli/project_templates/tutorial/credentials.yml +33 -0
- rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
- rasa/cli/project_templates/tutorial/data/patterns.yml +11 -0
- rasa/cli/project_templates/tutorial/domain.yml +35 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +55 -0
- rasa/cli/run.py +143 -0
- rasa/cli/scaffold.py +273 -0
- rasa/cli/shell.py +141 -0
- rasa/cli/studio/__init__.py +0 -0
- rasa/cli/studio/download.py +62 -0
- rasa/cli/studio/studio.py +296 -0
- rasa/cli/studio/train.py +59 -0
- rasa/cli/studio/upload.py +62 -0
- rasa/cli/telemetry.py +102 -0
- rasa/cli/test.py +280 -0
- rasa/cli/train.py +278 -0
- rasa/cli/utils.py +484 -0
- rasa/cli/visualize.py +40 -0
- rasa/cli/x.py +206 -0
- rasa/constants.py +45 -0
- rasa/core/__init__.py +17 -0
- rasa/core/actions/__init__.py +0 -0
- rasa/core/actions/action.py +1318 -0
- rasa/core/actions/action_clean_stack.py +59 -0
- rasa/core/actions/action_exceptions.py +24 -0
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/actions/action_repeat_bot_messages.py +89 -0
- rasa/core/actions/action_run_slot_rejections.py +210 -0
- rasa/core/actions/action_trigger_chitchat.py +31 -0
- rasa/core/actions/action_trigger_flow.py +109 -0
- rasa/core/actions/action_trigger_search.py +31 -0
- rasa/core/actions/constants.py +5 -0
- rasa/core/actions/custom_action_executor.py +191 -0
- rasa/core/actions/direct_custom_actions_executor.py +109 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +72 -0
- rasa/core/actions/forms.py +741 -0
- rasa/core/actions/grpc_custom_action_executor.py +251 -0
- rasa/core/actions/http_custom_action_executor.py +145 -0
- rasa/core/actions/loops.py +114 -0
- rasa/core/actions/two_stage_fallback.py +186 -0
- rasa/core/agent.py +559 -0
- rasa/core/auth_retry_tracker_store.py +122 -0
- rasa/core/brokers/__init__.py +0 -0
- rasa/core/brokers/broker.py +126 -0
- rasa/core/brokers/file.py +58 -0
- rasa/core/brokers/kafka.py +324 -0
- rasa/core/brokers/pika.py +388 -0
- rasa/core/brokers/sql.py +86 -0
- rasa/core/channels/__init__.py +61 -0
- rasa/core/channels/botframework.py +338 -0
- rasa/core/channels/callback.py +84 -0
- rasa/core/channels/channel.py +456 -0
- rasa/core/channels/console.py +241 -0
- rasa/core/channels/development_inspector.py +197 -0
- rasa/core/channels/facebook.py +419 -0
- rasa/core/channels/hangouts.py +329 -0
- rasa/core/channels/inspector/.eslintrc.cjs +25 -0
- rasa/core/channels/inspector/.gitignore +23 -0
- rasa/core/channels/inspector/README.md +54 -0
- rasa/core/channels/inspector/assets/favicon.ico +0 -0
- rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
- rasa/core/channels/inspector/custom.d.ts +3 -0
- rasa/core/channels/inspector/dist/assets/arc-861ddd57.js +1 -0
- rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
- rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-921f02db.js +10 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-b436c4f8.js +2 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-511a23cb.js +2 -0
- rasa/core/channels/inspector/dist/assets/createText-62fc7601-ef476ecd.js +7 -0
- rasa/core/channels/inspector/dist/assets/edges-f2ad444c-f1878e0a.js +4 -0
- rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-fac75185.js +51 -0
- rasa/core/channels/inspector/dist/assets/flowDb-1972c806-201c5bbc.js +6 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-f904ae41.js +4 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-b080d6f2.js +1 -0
- rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-1813da66.js +139 -0
- rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-872af172.js +266 -0
- rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-34a0af5a.js +70 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-42ba3e3d.js +1 -0
- rasa/core/channels/inspector/dist/assets/index-37817b51.js +1317 -0
- rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
- rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-6b731386.js +7 -0
- rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
- rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-e8579ac6.js +139 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
- rasa/core/channels/inspector/dist/assets/layout-89e6403a.js +1 -0
- rasa/core/channels/inspector/dist/assets/line-dc73d3fc.js +1 -0
- rasa/core/channels/inspector/dist/assets/linear-f5b1d2bc.js +1 -0
- rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-82cb74fa.js +109 -0
- rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
- rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
- rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-bdf5f29b.js +35 -0
- rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-c7a0cbe4.js +7 -0
- rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-7ec5410f.js +52 -0
- rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-caee5554.js +8 -0
- rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-2935f8db.js +122 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-8f5d9693.js +1 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-d565d1de.js +1 -0
- rasa/core/channels/inspector/dist/assets/styles-080da4f6-75ad421d.js +110 -0
- rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-7e764226.js +159 -0
- rasa/core/channels/inspector/dist/assets/styles-9c745c82-7a4e0e61.js +207 -0
- rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-4019d1bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-01ea12df.js +61 -0
- rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-89407137.js +7 -0
- rasa/core/channels/inspector/dist/index.html +42 -0
- rasa/core/channels/inspector/index.html +40 -0
- rasa/core/channels/inspector/jest.config.ts +13 -0
- rasa/core/channels/inspector/package.json +52 -0
- rasa/core/channels/inspector/setupTests.ts +2 -0
- rasa/core/channels/inspector/src/App.tsx +220 -0
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +108 -0
- rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +136 -0
- rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
- rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +22 -0
- rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
- rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
- rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
- rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
- rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
- rasa/core/channels/inspector/src/helpers/audiostream.ts +191 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +392 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +306 -0
- rasa/core/channels/inspector/src/helpers/utils.ts +127 -0
- rasa/core/channels/inspector/src/main.tsx +13 -0
- rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
- rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
- rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
- rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
- rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
- rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
- rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
- rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
- rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
- rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
- rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
- rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
- rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
- rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
- rasa/core/channels/inspector/src/theme/index.ts +101 -0
- rasa/core/channels/inspector/src/types.ts +84 -0
- rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
- rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
- rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
- rasa/core/channels/inspector/tsconfig.json +26 -0
- rasa/core/channels/inspector/tsconfig.node.json +10 -0
- rasa/core/channels/inspector/vite.config.ts +8 -0
- rasa/core/channels/inspector/yarn.lock +6249 -0
- rasa/core/channels/mattermost.py +229 -0
- rasa/core/channels/rasa_chat.py +126 -0
- rasa/core/channels/rest.py +230 -0
- rasa/core/channels/rocketchat.py +174 -0
- rasa/core/channels/slack.py +620 -0
- rasa/core/channels/socketio.py +302 -0
- rasa/core/channels/telegram.py +298 -0
- rasa/core/channels/twilio.py +169 -0
- rasa/core/channels/vier_cvg.py +374 -0
- rasa/core/channels/voice_ready/__init__.py +0 -0
- rasa/core/channels/voice_ready/audiocodes.py +501 -0
- rasa/core/channels/voice_ready/jambonz.py +121 -0
- rasa/core/channels/voice_ready/jambonz_protocol.py +396 -0
- rasa/core/channels/voice_ready/twilio_voice.py +403 -0
- rasa/core/channels/voice_ready/utils.py +37 -0
- rasa/core/channels/voice_stream/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
- rasa/core/channels/voice_stream/asr/azure.py +130 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
- rasa/core/channels/voice_stream/audio_bytes.py +8 -0
- rasa/core/channels/voice_stream/browser_audio.py +107 -0
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +106 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +427 -0
- rasa/core/channels/webexteams.py +134 -0
- rasa/core/concurrent_lock_store.py +210 -0
- rasa/core/constants.py +112 -0
- rasa/core/evaluation/__init__.py +0 -0
- rasa/core/evaluation/marker.py +267 -0
- rasa/core/evaluation/marker_base.py +923 -0
- rasa/core/evaluation/marker_stats.py +293 -0
- rasa/core/evaluation/marker_tracker_loader.py +103 -0
- rasa/core/exceptions.py +29 -0
- rasa/core/exporter.py +284 -0
- rasa/core/featurizers/__init__.py +0 -0
- rasa/core/featurizers/precomputation.py +410 -0
- rasa/core/featurizers/single_state_featurizer.py +421 -0
- rasa/core/featurizers/tracker_featurizers.py +1262 -0
- rasa/core/http_interpreter.py +89 -0
- rasa/core/information_retrieval/__init__.py +7 -0
- rasa/core/information_retrieval/faiss.py +124 -0
- rasa/core/information_retrieval/information_retrieval.py +137 -0
- rasa/core/information_retrieval/milvus.py +59 -0
- rasa/core/information_retrieval/qdrant.py +96 -0
- rasa/core/jobs.py +63 -0
- rasa/core/lock.py +139 -0
- rasa/core/lock_store.py +343 -0
- rasa/core/migrate.py +403 -0
- rasa/core/nlg/__init__.py +3 -0
- rasa/core/nlg/callback.py +146 -0
- rasa/core/nlg/contextual_response_rephraser.py +320 -0
- rasa/core/nlg/generator.py +230 -0
- rasa/core/nlg/interpolator.py +143 -0
- rasa/core/nlg/response.py +155 -0
- rasa/core/nlg/summarize.py +70 -0
- rasa/core/persistor.py +538 -0
- rasa/core/policies/__init__.py +0 -0
- rasa/core/policies/ensemble.py +329 -0
- rasa/core/policies/enterprise_search_policy.py +905 -0
- rasa/core/policies/enterprise_search_prompt_template.jinja2 +25 -0
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
- rasa/core/policies/flow_policy.py +205 -0
- rasa/core/policies/flows/__init__.py +0 -0
- rasa/core/policies/flows/flow_exceptions.py +44 -0
- rasa/core/policies/flows/flow_executor.py +754 -0
- rasa/core/policies/flows/flow_step_result.py +43 -0
- rasa/core/policies/intentless_policy.py +1031 -0
- rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
- rasa/core/policies/memoization.py +538 -0
- rasa/core/policies/policy.py +725 -0
- rasa/core/policies/rule_policy.py +1273 -0
- rasa/core/policies/ted_policy.py +2169 -0
- rasa/core/policies/unexpected_intent_policy.py +1022 -0
- rasa/core/processor.py +1465 -0
- rasa/core/run.py +342 -0
- rasa/core/secrets_manager/__init__.py +0 -0
- rasa/core/secrets_manager/constants.py +36 -0
- rasa/core/secrets_manager/endpoints.py +391 -0
- rasa/core/secrets_manager/factory.py +241 -0
- rasa/core/secrets_manager/secret_manager.py +262 -0
- rasa/core/secrets_manager/vault.py +584 -0
- rasa/core/test.py +1335 -0
- rasa/core/tracker_store.py +1703 -0
- rasa/core/train.py +105 -0
- rasa/core/training/__init__.py +89 -0
- rasa/core/training/converters/__init__.py +0 -0
- rasa/core/training/converters/responses_prefix_converter.py +119 -0
- rasa/core/training/interactive.py +1744 -0
- rasa/core/training/story_conflict.py +381 -0
- rasa/core/training/training.py +93 -0
- rasa/core/utils.py +366 -0
- rasa/core/visualize.py +70 -0
- rasa/dialogue_understanding/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/constants.py +4 -0
- rasa/dialogue_understanding/coexistence/intent_based_router.py +196 -0
- rasa/dialogue_understanding/coexistence/llm_based_router.py +327 -0
- rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
- rasa/dialogue_understanding/commands/__init__.py +61 -0
- rasa/dialogue_understanding/commands/can_not_handle_command.py +70 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +125 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +44 -0
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/clarify_command.py +86 -0
- rasa/dialogue_understanding/commands/command.py +85 -0
- rasa/dialogue_understanding/commands/correct_slots_command.py +297 -0
- rasa/dialogue_understanding/commands/error_command.py +79 -0
- rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/handle_code_change_command.py +73 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +66 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/noop_command.py +54 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
- rasa/dialogue_understanding/commands/restart_command.py +58 -0
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +160 -0
- rasa/dialogue_understanding/commands/skip_question_command.py +75 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +107 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +45 -0
- rasa/dialogue_understanding/generator/__init__.py +21 -0
- rasa/dialogue_understanding/generator/command_generator.py +464 -0
- rasa/dialogue_understanding/generator/constants.py +27 -0
- rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +466 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +500 -0
- rasa/dialogue_understanding/generator/llm_command_generator.py +67 -0
- rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
- rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +920 -0
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +261 -0
- rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +60 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +486 -0
- rasa/dialogue_understanding/patterns/__init__.py +0 -0
- rasa/dialogue_understanding/patterns/cancel.py +111 -0
- rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
- rasa/dialogue_understanding/patterns/chitchat.py +37 -0
- rasa/dialogue_understanding/patterns/clarify.py +97 -0
- rasa/dialogue_understanding/patterns/code_change.py +41 -0
- rasa/dialogue_understanding/patterns/collect_information.py +90 -0
- rasa/dialogue_understanding/patterns/completed.py +40 -0
- rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
- rasa/dialogue_understanding/patterns/correction.py +278 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +301 -0
- rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
- rasa/dialogue_understanding/patterns/internal_error.py +47 -0
- rasa/dialogue_understanding/patterns/repeat.py +37 -0
- rasa/dialogue_understanding/patterns/restart.py +37 -0
- rasa/dialogue_understanding/patterns/search.py +37 -0
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/patterns/skip_question.py +38 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/__init__.py +0 -0
- rasa/dialogue_understanding/processor/command_processor.py +720 -0
- rasa/dialogue_understanding/processor/command_processor_component.py +43 -0
- rasa/dialogue_understanding/stack/__init__.py +0 -0
- rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
- rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
- rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
- rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
- rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
- rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
- rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
- rasa/dialogue_understanding/stack/utils.py +211 -0
- rasa/dialogue_understanding/utils.py +14 -0
- rasa/dialogue_understanding_test/__init__.py +0 -0
- rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
- rasa/dialogue_understanding_test/constants.py +17 -0
- rasa/dialogue_understanding_test/du_test_case.py +118 -0
- rasa/dialogue_understanding_test/du_test_result.py +11 -0
- rasa/dialogue_understanding_test/du_test_runner.py +93 -0
- rasa/dialogue_understanding_test/io.py +54 -0
- rasa/dialogue_understanding_test/validation.py +22 -0
- rasa/e2e_test/__init__.py +0 -0
- rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
- rasa/e2e_test/assertions.py +1345 -0
- rasa/e2e_test/assertions_schema.yml +129 -0
- rasa/e2e_test/constants.py +31 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +569 -0
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +54 -0
- rasa/e2e_test/e2e_test_runner.py +1192 -0
- rasa/e2e_test/e2e_test_schema.yml +181 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +598 -0
- rasa/e2e_test/utils/validation.py +178 -0
- rasa/engine/__init__.py +0 -0
- rasa/engine/caching.py +463 -0
- rasa/engine/constants.py +17 -0
- rasa/engine/exceptions.py +14 -0
- rasa/engine/graph.py +642 -0
- rasa/engine/loader.py +48 -0
- rasa/engine/recipes/__init__.py +0 -0
- rasa/engine/recipes/config_files/default_config.yml +41 -0
- rasa/engine/recipes/default_components.py +97 -0
- rasa/engine/recipes/default_recipe.py +1272 -0
- rasa/engine/recipes/graph_recipe.py +79 -0
- rasa/engine/recipes/recipe.py +93 -0
- rasa/engine/runner/__init__.py +0 -0
- rasa/engine/runner/dask.py +250 -0
- rasa/engine/runner/interface.py +49 -0
- rasa/engine/storage/__init__.py +0 -0
- rasa/engine/storage/local_model_storage.py +244 -0
- rasa/engine/storage/resource.py +110 -0
- rasa/engine/storage/storage.py +199 -0
- rasa/engine/training/__init__.py +0 -0
- rasa/engine/training/components.py +176 -0
- rasa/engine/training/fingerprinting.py +64 -0
- rasa/engine/training/graph_trainer.py +256 -0
- rasa/engine/training/hooks.py +164 -0
- rasa/engine/validation.py +1451 -0
- rasa/env.py +14 -0
- rasa/exceptions.py +69 -0
- rasa/graph_components/__init__.py +0 -0
- rasa/graph_components/converters/__init__.py +0 -0
- rasa/graph_components/converters/nlu_message_converter.py +48 -0
- rasa/graph_components/providers/__init__.py +0 -0
- rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
- rasa/graph_components/providers/domain_provider.py +71 -0
- rasa/graph_components/providers/flows_provider.py +74 -0
- rasa/graph_components/providers/forms_provider.py +44 -0
- rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
- rasa/graph_components/providers/responses_provider.py +44 -0
- rasa/graph_components/providers/rule_only_provider.py +49 -0
- rasa/graph_components/providers/story_graph_provider.py +96 -0
- rasa/graph_components/providers/training_tracker_provider.py +55 -0
- rasa/graph_components/validators/__init__.py +0 -0
- rasa/graph_components/validators/default_recipe_validator.py +550 -0
- rasa/graph_components/validators/finetuning_validator.py +302 -0
- rasa/hooks.py +111 -0
- rasa/jupyter.py +63 -0
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/markers/__init__.py +0 -0
- rasa/markers/marker.py +269 -0
- rasa/markers/marker_base.py +828 -0
- rasa/markers/upload.py +74 -0
- rasa/markers/validate.py +21 -0
- rasa/model.py +118 -0
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +40 -0
- rasa/model_manager/model_api.py +559 -0
- rasa/model_manager/runner_service.py +286 -0
- rasa/model_manager/socket_bridge.py +146 -0
- rasa/model_manager/studio_jwt_auth.py +86 -0
- rasa/model_manager/trainer_service.py +325 -0
- rasa/model_manager/utils.py +87 -0
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +112 -0
- rasa/model_testing.py +457 -0
- rasa/model_training.py +596 -0
- rasa/nlu/__init__.py +7 -0
- rasa/nlu/classifiers/__init__.py +3 -0
- rasa/nlu/classifiers/classifier.py +5 -0
- rasa/nlu/classifiers/diet_classifier.py +1881 -0
- rasa/nlu/classifiers/fallback_classifier.py +192 -0
- rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
- rasa/nlu/classifiers/logistic_regression_classifier.py +253 -0
- rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
- rasa/nlu/classifiers/regex_message_handler.py +56 -0
- rasa/nlu/classifiers/sklearn_intent_classifier.py +330 -0
- rasa/nlu/constants.py +77 -0
- rasa/nlu/convert.py +40 -0
- rasa/nlu/emulators/__init__.py +0 -0
- rasa/nlu/emulators/dialogflow.py +55 -0
- rasa/nlu/emulators/emulator.py +49 -0
- rasa/nlu/emulators/luis.py +86 -0
- rasa/nlu/emulators/no_emulator.py +10 -0
- rasa/nlu/emulators/wit.py +56 -0
- rasa/nlu/extractors/__init__.py +0 -0
- rasa/nlu/extractors/crf_entity_extractor.py +715 -0
- rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
- rasa/nlu/extractors/entity_synonyms.py +178 -0
- rasa/nlu/extractors/extractor.py +470 -0
- rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
- rasa/nlu/extractors/regex_entity_extractor.py +220 -0
- rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
- rasa/nlu/featurizers/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +445 -0
- rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +768 -0
- rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
- rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
- rasa/nlu/featurizers/featurizer.py +89 -0
- rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +867 -0
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +571 -0
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +271 -0
- rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
- rasa/nlu/model.py +24 -0
- rasa/nlu/run.py +27 -0
- rasa/nlu/selectors/__init__.py +0 -0
- rasa/nlu/selectors/response_selector.py +987 -0
- rasa/nlu/test.py +1940 -0
- rasa/nlu/tokenizers/__init__.py +0 -0
- rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
- rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
- rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
- rasa/nlu/tokenizers/tokenizer.py +239 -0
- rasa/nlu/tokenizers/whitespace_tokenizer.py +95 -0
- rasa/nlu/utils/__init__.py +35 -0
- rasa/nlu/utils/bilou_utils.py +462 -0
- rasa/nlu/utils/hugging_face/__init__.py +0 -0
- rasa/nlu/utils/hugging_face/registry.py +108 -0
- rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
- rasa/nlu/utils/mitie_utils.py +113 -0
- rasa/nlu/utils/pattern_utils.py +168 -0
- rasa/nlu/utils/spacy_utils.py +310 -0
- rasa/plugin.py +90 -0
- rasa/server.py +1588 -0
- rasa/shared/__init__.py +0 -0
- rasa/shared/constants.py +311 -0
- rasa/shared/core/__init__.py +0 -0
- rasa/shared/core/command_payload_reader.py +109 -0
- rasa/shared/core/constants.py +180 -0
- rasa/shared/core/conversation.py +46 -0
- rasa/shared/core/domain.py +2172 -0
- rasa/shared/core/events.py +2559 -0
- rasa/shared/core/flows/__init__.py +7 -0
- rasa/shared/core/flows/flow.py +562 -0
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flow_step.py +146 -0
- rasa/shared/core/flows/flow_step_links.py +319 -0
- rasa/shared/core/flows/flow_step_sequence.py +70 -0
- rasa/shared/core/flows/flows_list.py +258 -0
- rasa/shared/core/flows/flows_yaml_schema.json +303 -0
- rasa/shared/core/flows/nlu_trigger.py +117 -0
- rasa/shared/core/flows/steps/__init__.py +24 -0
- rasa/shared/core/flows/steps/action.py +56 -0
- rasa/shared/core/flows/steps/call.py +64 -0
- rasa/shared/core/flows/steps/collect.py +112 -0
- rasa/shared/core/flows/steps/constants.py +5 -0
- rasa/shared/core/flows/steps/continuation.py +36 -0
- rasa/shared/core/flows/steps/end.py +22 -0
- rasa/shared/core/flows/steps/internal.py +44 -0
- rasa/shared/core/flows/steps/link.py +51 -0
- rasa/shared/core/flows/steps/no_operation.py +48 -0
- rasa/shared/core/flows/steps/set_slots.py +50 -0
- rasa/shared/core/flows/steps/start.py +30 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +735 -0
- rasa/shared/core/flows/yaml_flows_io.py +405 -0
- rasa/shared/core/generator.py +908 -0
- rasa/shared/core/slot_mappings.py +526 -0
- rasa/shared/core/slots.py +654 -0
- rasa/shared/core/trackers.py +1183 -0
- rasa/shared/core/training_data/__init__.py +0 -0
- rasa/shared/core/training_data/loading.py +89 -0
- rasa/shared/core/training_data/story_reader/__init__.py +0 -0
- rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
- rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
- rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
- rasa/shared/core/training_data/story_writer/__init__.py +0 -0
- rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
- rasa/shared/core/training_data/story_writer/yaml_story_writer.py +444 -0
- rasa/shared/core/training_data/structures.py +858 -0
- rasa/shared/core/training_data/visualization.html +146 -0
- rasa/shared/core/training_data/visualization.py +603 -0
- rasa/shared/data.py +249 -0
- rasa/shared/engine/__init__.py +0 -0
- rasa/shared/engine/caching.py +26 -0
- rasa/shared/exceptions.py +167 -0
- rasa/shared/importers/__init__.py +0 -0
- rasa/shared/importers/importer.py +770 -0
- rasa/shared/importers/multi_project.py +215 -0
- rasa/shared/importers/rasa.py +108 -0
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +36 -0
- rasa/shared/nlu/__init__.py +0 -0
- rasa/shared/nlu/constants.py +53 -0
- rasa/shared/nlu/interpreter.py +10 -0
- rasa/shared/nlu/training_data/__init__.py +0 -0
- rasa/shared/nlu/training_data/entities_parser.py +208 -0
- rasa/shared/nlu/training_data/features.py +492 -0
- rasa/shared/nlu/training_data/formats/__init__.py +10 -0
- rasa/shared/nlu/training_data/formats/dialogflow.py +163 -0
- rasa/shared/nlu/training_data/formats/luis.py +87 -0
- rasa/shared/nlu/training_data/formats/rasa.py +135 -0
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +618 -0
- rasa/shared/nlu/training_data/formats/readerwriter.py +244 -0
- rasa/shared/nlu/training_data/formats/wit.py +52 -0
- rasa/shared/nlu/training_data/loading.py +137 -0
- rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
- rasa/shared/nlu/training_data/message.py +490 -0
- rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
- rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
- rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
- rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
- rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
- rasa/shared/nlu/training_data/training_data.py +729 -0
- rasa/shared/nlu/training_data/util.py +223 -0
- rasa/shared/providers/__init__.py +0 -0
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +677 -0
- rasa/shared/providers/_configs/client_config.py +59 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +132 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +236 -0
- rasa/shared/providers/_configs/litellm_router_client_config.py +222 -0
- rasa/shared/providers/_configs/model_group_config.py +173 -0
- rasa/shared/providers/_configs/openai_client_config.py +177 -0
- rasa/shared/providers/_configs/rasa_llm_client_config.py +75 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +178 -0
- rasa/shared/providers/_configs/utils.py +117 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/constants.py +7 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +243 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +335 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +126 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +138 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +265 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +415 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +110 -0
- rasa/shared/providers/llm/litellm_router_llm_client.py +202 -0
- rasa/shared/providers/llm/llm_client.py +78 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +161 -0
- rasa/shared/providers/llm/rasa_llm_client.py +120 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +276 -0
- rasa/shared/providers/mappings.py +94 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +185 -0
- rasa/shared/providers/router/router_client.py +75 -0
- rasa/shared/utils/__init__.py +0 -0
- rasa/shared/utils/cli.py +102 -0
- rasa/shared/utils/common.py +324 -0
- rasa/shared/utils/constants.py +4 -0
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +258 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +499 -0
- rasa/shared/utils/llm.py +764 -0
- rasa/shared/utils/pykwalify_extensions.py +27 -0
- rasa/shared/utils/schemas/__init__.py +0 -0
- rasa/shared/utils/schemas/config.yml +2 -0
- rasa/shared/utils/schemas/domain.yml +145 -0
- rasa/shared/utils/schemas/events.py +214 -0
- rasa/shared/utils/schemas/model_config.yml +36 -0
- rasa/shared/utils/schemas/stories.yml +173 -0
- rasa/shared/utils/yaml.py +1068 -0
- rasa/studio/__init__.py +0 -0
- rasa/studio/auth.py +270 -0
- rasa/studio/config.py +136 -0
- rasa/studio/constants.py +19 -0
- rasa/studio/data_handler.py +368 -0
- rasa/studio/download.py +489 -0
- rasa/studio/results_logger.py +137 -0
- rasa/studio/train.py +134 -0
- rasa/studio/upload.py +563 -0
- rasa/telemetry.py +1876 -0
- rasa/tracing/__init__.py +0 -0
- rasa/tracing/config.py +355 -0
- rasa/tracing/constants.py +62 -0
- rasa/tracing/instrumentation/__init__.py +0 -0
- rasa/tracing/instrumentation/attribute_extractors.py +765 -0
- rasa/tracing/instrumentation/instrumentation.py +1306 -0
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +144 -0
- rasa/tracing/instrumentation/metrics.py +294 -0
- rasa/tracing/metric_instrument_provider.py +205 -0
- rasa/utils/__init__.py +0 -0
- rasa/utils/beta.py +83 -0
- rasa/utils/cli.py +28 -0
- rasa/utils/common.py +639 -0
- rasa/utils/converter.py +53 -0
- rasa/utils/endpoints.py +331 -0
- rasa/utils/io.py +252 -0
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +542 -0
- rasa/utils/log_utils.py +181 -0
- rasa/utils/mapper.py +210 -0
- rasa/utils/ml_utils.py +147 -0
- rasa/utils/plotting.py +362 -0
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/utils/singleton.py +23 -0
- rasa/utils/tensorflow/__init__.py +0 -0
- rasa/utils/tensorflow/callback.py +112 -0
- rasa/utils/tensorflow/constants.py +116 -0
- rasa/utils/tensorflow/crf.py +492 -0
- rasa/utils/tensorflow/data_generator.py +440 -0
- rasa/utils/tensorflow/environment.py +161 -0
- rasa/utils/tensorflow/exceptions.py +5 -0
- rasa/utils/tensorflow/feature_array.py +366 -0
- rasa/utils/tensorflow/layers.py +1565 -0
- rasa/utils/tensorflow/layers_utils.py +113 -0
- rasa/utils/tensorflow/metrics.py +281 -0
- rasa/utils/tensorflow/model_data.py +798 -0
- rasa/utils/tensorflow/model_data_utils.py +499 -0
- rasa/utils/tensorflow/models.py +935 -0
- rasa/utils/tensorflow/rasa_layers.py +1094 -0
- rasa/utils/tensorflow/transformer.py +640 -0
- rasa/utils/tensorflow/types.py +6 -0
- rasa/utils/train_utils.py +572 -0
- rasa/utils/url_tools.py +53 -0
- rasa/utils/yaml.py +54 -0
- rasa/validator.py +1644 -0
- rasa/version.py +3 -0
- rasa_pro-3.12.0.dev1.dist-info/METADATA +199 -0
- rasa_pro-3.12.0.dev1.dist-info/NOTICE +5 -0
- rasa_pro-3.12.0.dev1.dist-info/RECORD +790 -0
- rasa_pro-3.12.0.dev1.dist-info/WHEEL +4 -0
- rasa_pro-3.12.0.dev1.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Any, Dict, Optional, Text, Type
|
|
3
|
+
import dataclasses
|
|
4
|
+
import uuid
|
|
5
|
+
|
|
6
|
+
from rasa.engine.caching import Cacheable, TrainingCache
|
|
7
|
+
from rasa.engine.graph import ExecutionContext, GraphComponent, SchemaNode
|
|
8
|
+
from rasa.engine.storage.resource import Resource
|
|
9
|
+
from rasa.engine.storage.storage import ModelStorage
|
|
10
|
+
from rasa.engine.training import fingerprinting
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class PrecomputedValueProvider(GraphComponent):
    """Serves the output of a `GraphNode` that was computed in an earlier training.

    The served value is either
    - an output restored from the training cache, or
    - an output that an input node produced during the fingerprint run.
    """

    def __init__(self, output: Cacheable):
        """Stores the value this provider hands out.

        Args:
            output: The precomputed output that `get_value` will return.
        """
        self._output = output

    @classmethod
    def create(
        cls,
        config: Dict[Text, Any],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ) -> PrecomputedValueProvider:
        """Builds an instance from `config["output"]` (see parent class)."""
        precomputed = config["output"]
        return cls(output=precomputed)

    def get_value(self) -> Cacheable:
        """Hands back the stored precomputed output."""
        return self._output

    @classmethod
    def replace_schema_node(cls, node: SchemaNode, output: Any) -> None:
        """Rewires `node` in place so that a `PrecomputedValueProvider` serves it.

        Use this when the output a node produced in a previous training should be
        reused in a later training. The node's `uses` becomes this class, its
        config carries the precomputed value, and its run function becomes
        `get_value`.

        Args:
            node: The schema node to rewrite.
            output: The precomputed (cached) value the provider will return.
        """
        constructor = cls.create.__name__
        run_fn = cls.get_value.__name__

        node.uses = cls
        node.config = {"output": output}
        node.fn = run_fn
        node.constructor_name = constructor
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclasses.dataclass
class FingerprintStatus:
    """Result of a `FingerprintComponent` run; used when pruning the graph.

    Attributes:
        output_fingerprint: Fingerprint of the node's output value, or `None`
            when none is known.
        is_hit: `True` if the node's fingerprint key exists in the cache,
            `False` otherwise.
    """

    output_fingerprint: Optional[Text]
    is_hit: bool

    def fingerprint(self) -> Text:
        """Returns the stored output fingerprint.

        If no (non-empty) fingerprint is stored, a freshly generated random hex
        string is returned instead, so that it never matches anything.
        """
        if self.output_fingerprint:
            return self.output_fingerprint
        return uuid.uuid4().hex
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class FingerprintComponent(GraphComponent):
    """Stands in for every non-input node while the graph runs in fingerprint mode."""

    def __init__(
        self,
        cache: TrainingCache,
        config_of_replaced_component: Dict[Text, Any],
        class_of_replaced_component: Type,
    ) -> None:
        """Creates a `FingerprintComponent`.

        Args:
            cache: Training cache that is queried to decide whether the run is a hit.
            config_of_replaced_component: Config of the component being replaced;
                it is part of the fingerprint key.
            class_of_replaced_component: Class of the component being replaced;
                it is part of the fingerprint key.
        """
        self._cache = cache
        self._config_of_replaced_component = config_of_replaced_component
        self._class_of_replaced_component = class_of_replaced_component

    @classmethod
    def create(
        cls,
        config: Dict[Text, Any],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ) -> FingerprintComponent:
        """Builds a `FingerprintComponent` from its node config (see parent class)."""
        cache = config["cache"]
        replaced_config = config["config_of_replaced_component"]
        replaced_class = config["graph_component_class"]
        return cls(
            cache=cache,
            config_of_replaced_component=replaced_config,
            class_of_replaced_component=replaced_class,
        )

    def run(self, **kwargs: Any) -> FingerprintStatus:
        """Computes the fingerprint key and checks whether the cache knows it.

        A key found in the cache means a previous execution used the same component
        class, the same config and the same inputs. Such a node is a candidate for
        pruning, or for replacement with a cached value, before the next graph run.

        Args:
            **kwargs: Inputs from all parent nodes.

        Returns:
            A `FingerprintStatus` saying whether the run was a hit and, for a hit,
            carrying the output fingerprint found in the cache.
        """
        component_class = self._class_of_replaced_component
        # Explicitly configured values take precedence over the component defaults.
        effective_config = {
            **component_class.get_default_config(),
            **self._config_of_replaced_component,
        }
        key = fingerprinting.calculate_fingerprint_key(
            graph_component_class=component_class,
            config=effective_config,
            inputs=kwargs,
        )

        cached_fingerprint = self._cache.get_cached_output_fingerprint(key)
        found = cached_fingerprint is not None
        return FingerprintStatus(output_fingerprint=cached_fingerprint, is_hit=found)

    @classmethod
    def replace_schema_node(cls, node: SchemaNode, cache: TrainingCache) -> None:
        """Rewires `node` in place so that a `FingerprintComponent` runs it.

        This is used for a fingerprint run. All non-input nodes are swapped for
        `FingerprintComponent`s, which decides whether each node can be pruned or
        served from the cache before the next graph run, without executing the
        real components.

        Args:
            node: The schema node to rewrite.
            cache: Cache used to determine whether there is a hit for the
                fingerprint key.
        """
        replaced_class = node.uses
        replaced_config = node.config

        node.uses = cls
        # Making the node eager lets `FingerprintComponent.run` see every input of
        # the node; a lazy node would not receive the arguments consumed by the
        # constructor.
        node.eager = True
        node.constructor_name = cls.create.__name__
        node.fn = cls.run.__name__
        node.config = {
            "config_of_replaced_component": replaced_config,
            "cache": cache,
            "graph_component_class": replaced_class,
        }
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import logging
|
|
3
|
+
from typing import Any, Dict, Text, Type
|
|
4
|
+
from typing_extensions import Protocol, runtime_checkable
|
|
5
|
+
import importlib_metadata
|
|
6
|
+
import rasa.utils.common
|
|
7
|
+
import rasa.shared.utils.io
|
|
8
|
+
from rasa.engine.graph import GraphComponent
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
# Maps Python import names (as returned by `required_packages()`) to the
# distribution names that `importlib_metadata.version` expects, for the cases
# where the two differ. Names not listed here are looked up unchanged.
import_name_to_package_map = dict(sklearn="scikit_learn")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@runtime_checkable
class Fingerprintable(Protocol):
    """Structural type for training data that can produce a fingerprint."""

    def fingerprint(self) -> Text:
        """Returns a stable fingerprint that uniquely identifies the data."""
        ...
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def calculate_fingerprint_key(
    graph_component_class: Type[GraphComponent],
    config: Dict[Text, Any],
    inputs: Dict[Text, Fingerprintable],
) -> Text:
    """Computes the key that uniquely identifies one execution of a graph node.

    The key covers the component's module path, its source code, its config, its
    inputs, the installed versions of its required packages and, when the
    component provides one, its fingerprint addon.

    Args:
        graph_component_class: The graph component class.
        config: The component config.
        inputs: The inputs as a mapping of parent node name to input value.

    Returns:
        The fingerprint key.
    """
    dependency_versions = {}
    for import_name in graph_component_class.required_packages():
        # Import names may differ from distribution names (e.g. `sklearn`).
        distribution_name = import_name_to_package_map.get(import_name, import_name)
        dependency_versions[import_name] = importlib_metadata.version(
            distribution_name
        )

    module_path = rasa.utils.common.module_path_from_class(graph_component_class)
    source_code = inspect.getsource(graph_component_class)
    fingerprint_data = {
        "node_name": module_path,
        "component_implementation": source_code,
        "config": config,
        "inputs": inputs,
        "dependency_versions": dependency_versions,
    }

    addon = graph_component_class.fingerprint_addon(config)
    if addon is not None:
        fingerprint_data["addon"] = addon

    key = rasa.shared.utils.io.deep_container_fingerprint(fingerprint_data)

    logger.debug(
        f"Calculated fingerprint_key '{key}' for class "
        f"'{graph_component_class.__name__}'."
    )
    return key
|
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
import copy
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Set, Text, Type, Union

from rasa.engine.caching import TrainingCache
from rasa.engine.graph import ExecutionContext, GraphSchema, GraphModelConfiguration
from rasa.engine.constants import PLACEHOLDER_IMPORTER
from rasa.engine.runner.interface import GraphRunner
from rasa.engine.storage.storage import ModelStorage, ModelMetadata
from rasa.engine.training.components import (
    PrecomputedValueProvider,
    FingerprintComponent,
    FingerprintStatus,
)
from rasa.engine.training.hooks import TrainingHook, LoggingHook
from rasa.shared.importers.importer import TrainingDataImporter
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class GraphTrainer:
    """Trains a model using a graph schema."""

    def __init__(
        self,
        model_storage: ModelStorage,
        cache: TrainingCache,
        graph_runner_class: Type[GraphRunner],
    ) -> None:
        """Initializes a `GraphTrainer`.

        Args:
            model_storage: Storage which graph components can use to persist and load.
                Also used for packaging the trained model.
            cache: Cache used to store fingerprints and outputs.
            graph_runner_class: The class to instantiate the runner from.
        """
        self._model_storage = model_storage
        self._cache = cache
        self._graph_runner_class = graph_runner_class

    async def train(
        self,
        model_configuration: GraphModelConfiguration,
        importer: TrainingDataImporter,
        output_filename: Path,
        force_retraining: bool = False,
        is_finetuning: bool = False,
    ) -> ModelMetadata:
        """Trains and packages a model and returns the metadata of the package.

        Unless `force_retraining` is set, the train graph is first run in
        fingerprint mode. It is then pruned so that only nodes without usable
        cached output are executed for real.

        Args:
            model_configuration: The model configuration (schemas, language, etc.)
            importer: The importer which provides the training data for the training.
            output_filename: The location to save the packaged model.
            force_retraining: If `True` then the cache is skipped and all components
                are retrained.
            is_finetuning: `True` if we want to finetune the model.

        Returns:
            The metadata describing the trained model.
        """
        logger.debug("Starting training.")

        # Retrieve the domain for the model metadata right at the start.
        # This avoids that something during the graph runs mutates it.
        domain = copy.deepcopy(importer.get_domain())

        if force_retraining:
            logger.debug(
                "Skip fingerprint run as a full training of the model was enforced."
            )
            pruned_training_schema = model_configuration.train_schema
        else:
            fingerprint_run_outputs = await self.fingerprint(
                model_configuration.train_schema,
                importer=importer,
                is_finetuning=is_finetuning,
            )
            pruned_training_schema = self._prune_schema(
                model_configuration.train_schema, fingerprint_run_outputs
            )

        hooks = [
            LoggingHook(pruned_schema=pruned_training_schema),
            TrainingHook(
                cache=self._cache,
                model_storage=self._model_storage,
                pruned_schema=pruned_training_schema,
            ),
        ]

        graph_runner = self._graph_runner_class.create(
            graph_schema=pruned_training_schema,
            model_storage=self._model_storage,
            # The execution context deliberately references the full (unpruned)
            # train schema, while the runner executes the pruned one.
            execution_context=ExecutionContext(
                graph_schema=model_configuration.train_schema,
                is_finetuning=is_finetuning,
            ),
            hooks=hooks,
        )

        logger.debug("Running the pruned train graph with real node execution.")

        await graph_runner.run(inputs={PLACEHOLDER_IMPORTER: importer})

        return self._model_storage.create_model_package(
            output_filename, model_configuration, domain
        )

    async def fingerprint(
        self,
        train_schema: GraphSchema,
        importer: TrainingDataImporter,
        is_finetuning: bool = False,
    ) -> Dict[Text, Union[FingerprintStatus, Any]]:
        """Runs the graph using fingerprints to determine which nodes need to re-run.

        Nodes which have a matching fingerprint key in the cache can either be removed
        entirely from the graph, or replaced with a cached value if their output is
        needed by descendant nodes.

        Args:
            train_schema: The train graph schema that will be run in fingerprint mode.
            importer: The importer which provides the training data for the training.
            is_finetuning: `True` if we want to finetune the model.

        Returns:
            Mapping of node names to fingerprint results.
        """
        fingerprint_schema = self._create_fingerprint_schema(train_schema)

        fingerprint_graph_runner = self._graph_runner_class.create(
            graph_schema=fingerprint_schema,
            model_storage=self._model_storage,
            execution_context=ExecutionContext(
                graph_schema=train_schema, is_finetuning=is_finetuning
            ),
        )

        logger.debug("Running the train graph in fingerprint mode.")
        return await fingerprint_graph_runner.run(
            inputs={PLACEHOLDER_IMPORTER: importer}
        )

    def _create_fingerprint_schema(self, train_schema: GraphSchema) -> GraphSchema:
        """Builds a copy of `train_schema` suitable for a fingerprint run.

        Every node becomes a target, and every non-input node is swapped for a
        `FingerprintComponent`. The given schema is not modified.

        Args:
            train_schema: The train graph schema to derive the fingerprint schema from.

        Returns:
            The fingerprint schema.
        """
        fingerprint_schema = copy.deepcopy(train_schema)
        for schema_node in fingerprint_schema.nodes.values():
            # We make every node a target so that `graph_runner.run(...)` returns
            # the output for each node. We need the output of each node
            # to decide which nodes we can prune.
            schema_node.is_target = True

            # We do not replace the input nodes as we need an up-to-date fingerprint of
            # any input data to the graph. This means we can prune according to what
            # has actually changed.
            if not schema_node.is_input:
                FingerprintComponent.replace_schema_node(schema_node, self._cache)
        return fingerprint_schema

    def _prune_schema(
        self,
        schema: GraphSchema,
        fingerprint_run_outputs: Dict[Text, Union[FingerprintStatus, Any]],
    ) -> GraphSchema:
        """Uses the fingerprint statuses to prune the graph schema.

        Walks the graph starting at each target node. If a node has a cache hit we
        replace it with a `PrecomputedValueProvider` and remove its input dependencies.
        At the end, any node that is not an ancestor of a target node will be pruned
        when we call `minimal_graph_schema()`.

        Args:
            schema: The graph to prune.
            fingerprint_run_outputs: Node outputs from the fingerprint run as a mapping
                from node name to output.

        Returns:
            The pruned schema.
        """
        pruned_schema = copy.deepcopy(schema)
        target_node_names = pruned_schema.target_names

        # Shared across all targets so that every node is processed at most once,
        # even when several targets have ancestors in common.
        visited: Set[Text] = set()
        for target_node_name in target_node_names:
            self._walk_and_prune(
                pruned_schema, target_node_name, fingerprint_run_outputs, visited
            )

        return pruned_schema.minimal_graph_schema()

    def _walk_and_prune(
        self,
        schema: GraphSchema,
        current_node_name: Text,
        fingerprint_run_outputs: Dict[Text, Union[FingerprintStatus, Any]],
        visited: Optional[Set[Text]] = None,
    ) -> None:
        """Recursively walks backwards through a graph checking the status of each node.

        If node has a fingerprint key hit then we check if there is a cached output.
        If there is a cached output we will replace the node with a
        `PrecomputedValueProvider` and remove all its dependencies (`.needs`). If
        there is not a fingerprint key hit, or there is no cached output, the node is
        left untouched and will be executed again next run unless it is no longer the
        ancestor of a target node.

        Args:
            schema: The graph we are currently walking.
            current_node_name: The current node on the walk.
            fingerprint_run_outputs: The fingerprint statuses of every node as a mapping
                from node name to status.
            visited: Names of nodes already processed during this pruning pass. Pass
                the same set across calls to avoid re-walking shared ancestors; a
                fresh set is used when omitted.
        """
        if visited is None:
            visited = set()
        # The outcome of processing a node does not depend on which descendant
        # reached it, so each node only needs to be handled once. Without this
        # check, nodes shared by several descendants would be re-walked once per
        # path through the graph.
        if current_node_name in visited:
            return
        visited.add(current_node_name)

        fingerprint_run_output = fingerprint_run_outputs[current_node_name]
        node = schema.nodes[current_node_name]

        # If we have replaced this node with a `PrecomputedValueProvider` we have
        # already visited this node. A `PrecomputedValueProvider` is updated to have
        # no parent nodes, so we can end the walk here.
        if node.uses == PrecomputedValueProvider:
            return

        # If the output was a `FingerprintStatus` we must check the cache and status.
        if isinstance(fingerprint_run_output, FingerprintStatus):
            # If there is a fingerprint key hit we can potentially use a cached output.
            if fingerprint_run_output.is_hit:
                output_result = self._cache.get_cached_result(
                    output_fingerprint_key=fingerprint_run_output.output_fingerprint,
                    node_name=current_node_name,
                    model_storage=self._model_storage,
                )
                # `None` signals "no usable cached output"; a cached value that
                # happens to be falsy is still a valid result.
                if output_result is not None:
                    logger.debug(
                        f"Updating '{current_node_name}' to use a "
                        f"'{PrecomputedValueProvider.__name__}'."
                    )
                    PrecomputedValueProvider.replace_schema_node(node, output_result)
                    # We remove all parent dependencies as the cached output value will
                    # be used.
                    node.needs = {}
                else:
                    # If there is no cached output the node must be re-run if it ends
                    # up as an ancestor of a target node.
                    fingerprint_run_output.is_hit = False

        # Else the node was an input node and the output is the actual node's output.
        else:
            # As fingerprint_run_output is just the node's output there is no need to
            # execute the node again. We can just return it from a
            # `PrecomputedValueProvider`.
            PrecomputedValueProvider.replace_schema_node(node, fingerprint_run_output)
            node.needs = {}

        # Continue walking for every parent node.
        for parent_node_name in node.needs.values():
            self._walk_and_prune(
                schema, parent_node_name, fingerprint_run_outputs, visited
            )
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Any, Dict, Text, Type
|
|
3
|
+
|
|
4
|
+
from rasa.engine.caching import TrainingCache
|
|
5
|
+
from rasa.engine.graph import ExecutionContext, GraphNodeHook, GraphSchema, SchemaNode
|
|
6
|
+
from rasa.engine.storage.storage import ModelStorage
|
|
7
|
+
from rasa.engine.training.components import PrecomputedValueProvider
|
|
8
|
+
import rasa.shared.utils.io
|
|
9
|
+
from rasa.engine.training import fingerprinting
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TrainingHook(GraphNodeHook):
    """Graph node hook which fingerprints nodes and caches their outputs.

    Before a node runs, a fingerprint key is derived from the node's component
    class, its configuration and the inputs it received. After the node has run,
    its output is fingerprinted and handed to the training cache under that key.
    """

    def __init__(
        self,
        cache: TrainingCache,
        model_storage: ModelStorage,
        pruned_schema: GraphSchema,
    ) -> None:
        """Creates the hook.

        Args:
            cache: Cache which receives fingerprints and node outputs.
            model_storage: Storage passed along when caching `Resource`s.
            pruned_schema: The pruned training schema. It is consulted to detect
                nodes which were replaced by a `PrecomputedValueProvider`.
        """
        self._cache = cache
        self._model_storage = model_storage
        self._pruned_schema = pruned_schema

    def on_before_node(
        self,
        node_name: Text,
        execution_context: ExecutionContext,
        config: Dict[Text, Any],
        received_inputs: Dict[Text, Any],
    ) -> Dict:
        """Computes the fingerprint key which `on_after_node` caches under.

        Args:
            node_name: Name of the node which is about to run.
            execution_context: Context whose graph schema supplies the node's
                component class and schema config.
            config: Configuration the node is run with.
            received_inputs: Inputs the node received from its parents.

        Returns:
            A dict holding the computed key under `"fingerprint_key"`.
        """
        component_class = self._get_graph_component_class(
            execution_context, node_name
        )
        # Entries of the schema's node config win over entries of `config`.
        combined_config = {
            **config,
            **self._get_graph_component_config(execution_context, node_name),
        }

        return {
            "fingerprint_key": fingerprinting.calculate_fingerprint_key(
                graph_component_class=component_class,
                config=combined_config,
                inputs=received_inputs,
            )
        }

    def on_after_node(
        self,
        node_name: Text,
        execution_context: ExecutionContext,
        config: Dict[Text, Any],
        output: Any,
        input_hook_data: Dict,
    ) -> None:
        """Fingerprints the node's output and stores it in the training cache.

        Args:
            node_name: Name of the node which just ran.
            execution_context: Context of the current graph run (unused here).
            config: Configuration the node ran with (unused here).
            output: The value the node produced.
            input_hook_data: Data returned by `on_before_node`; must contain
                `"fingerprint_key"`.
        """
        # A `PrecomputedValueProvider` only replays a previously cached value,
        # so caching its output again is skipped.
        if self._pruned_schema.nodes[node_name].uses == PrecomputedValueProvider:
            return None

        output_fingerprint = rasa.shared.utils.io.deep_container_fingerprint(output)
        fingerprint_key = input_hook_data["fingerprint_key"]

        logger.debug(
            f"Caching '{output.__class__.__name__}' with fingerprint_key: "
            f"'{fingerprint_key}' and output_fingerprint '{output_fingerprint}'."
        )
        self._cache.cache_output(
            fingerprint_key=fingerprint_key,
            output=output,
            output_fingerprint=output_fingerprint,
            model_storage=self._model_storage,
        )

    @staticmethod
    def _get_graph_component_class(
        execution_context: ExecutionContext, node_name: Text
    ) -> Type:
        """Returns the class the execution context's schema uses for the node."""
        return execution_context.graph_schema.nodes[node_name].uses

    @staticmethod
    def _get_graph_component_config(
        execution_context: ExecutionContext, node_name: str
    ) -> Dict[str, Any]:
        """Returns the config the execution context's schema holds for the node."""
        schema_node = execution_context.graph_schema.nodes[node_name]
        return schema_node.config
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class LoggingHook(GraphNodeHook):
    """Reports via logging when components start and finish their training."""

    def __init__(self, pruned_schema: GraphSchema) -> None:
        """Creates the hook.

        Args:
            pruned_schema: The pruned training schema. A node using a
                `PrecomputedValueProvider` in it is treated as restored from the
                cache; the node's target/input flags decide whether it trains.
        """
        self._pruned_schema = pruned_schema

    def on_before_node(
        self,
        node_name: Text,
        execution_context: ExecutionContext,
        config: Dict[Text, Any],
        received_inputs: Dict[Text, Any],
    ) -> Dict:
        """Logs that a graph node begins training.

        Nothing is logged for cached nodes or for nodes which do not train.

        Returns:
            An empty dict, as no data needs to be passed to `on_after_node`.
        """
        pruned_node = self._pruned_schema.nodes[node_name]

        if self._is_cached_node(pruned_node) or not self._does_node_train(
            pruned_node
        ):
            return {}

        logger.info(f"Starting to train component '{pruned_node.uses.__name__}'.")
        return {}

    @staticmethod
    def _does_node_train(node: SchemaNode) -> bool:
        # Training nodes are always targets so that their output ends up in the
        # model storage. Input nodes are excluded: they don't really train but,
        # for example, persist some training data.
        return not node.is_input and node.is_target

    @staticmethod
    def _is_cached_node(node: SchemaNode) -> bool:
        # Cached nodes were swapped for a `PrecomputedValueProvider`.
        return node.uses == PrecomputedValueProvider

    def on_after_node(
        self,
        node_name: Text,
        execution_context: ExecutionContext,
        config: Dict[Text, Any],
        output: Any,
        input_hook_data: Dict,
    ) -> None:
        """Logs that a component finished training or was restored from cache."""
        pruned_node = self._pruned_schema.nodes[node_name]

        if not self._does_node_train(pruned_node):
            return

        if not self._is_cached_node(pruned_node):
            logger.info(f"Finished training component '{pruned_node.uses.__name__}'.")
            return

        # The pruned node only names the `PrecomputedValueProvider`, so the real
        # component is taken from the execution context's schema (presumably the
        # unpruned one).
        original_node = execution_context.graph_schema.nodes[node_name]
        logger.info(
            f"Restored component '{original_node.uses.__name__}' from cache."
        )
|