rasa-pro 3.11.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +41 -0
- rasa/__init__.py +9 -0
- rasa/__main__.py +177 -0
- rasa/anonymization/__init__.py +2 -0
- rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
- rasa/anonymization/anonymization_pipeline.py +286 -0
- rasa/anonymization/anonymization_rule_executor.py +260 -0
- rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
- rasa/anonymization/schemas/config.yml +47 -0
- rasa/anonymization/utils.py +118 -0
- rasa/api.py +160 -0
- rasa/cli/__init__.py +5 -0
- rasa/cli/arguments/__init__.py +0 -0
- rasa/cli/arguments/data.py +106 -0
- rasa/cli/arguments/default_arguments.py +207 -0
- rasa/cli/arguments/evaluate.py +65 -0
- rasa/cli/arguments/export.py +51 -0
- rasa/cli/arguments/interactive.py +74 -0
- rasa/cli/arguments/run.py +219 -0
- rasa/cli/arguments/shell.py +17 -0
- rasa/cli/arguments/test.py +211 -0
- rasa/cli/arguments/train.py +279 -0
- rasa/cli/arguments/visualize.py +34 -0
- rasa/cli/arguments/x.py +30 -0
- rasa/cli/data.py +354 -0
- rasa/cli/e2e_test.py +259 -0
- rasa/cli/evaluate.py +222 -0
- rasa/cli/export.py +250 -0
- rasa/cli/inspect.py +75 -0
- rasa/cli/interactive.py +166 -0
- rasa/cli/license.py +65 -0
- rasa/cli/llm_fine_tuning.py +403 -0
- rasa/cli/markers.py +78 -0
- rasa/cli/project_templates/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/action_template.py +27 -0
- rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
- rasa/cli/project_templates/calm/actions/db.py +57 -0
- rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
- rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
- rasa/cli/project_templates/calm/config.yml +10 -0
- rasa/cli/project_templates/calm/credentials.yml +33 -0
- rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
- rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
- rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
- rasa/cli/project_templates/calm/db/contacts.json +10 -0
- rasa/cli/project_templates/calm/domain/add_contact.yml +39 -0
- rasa/cli/project_templates/calm/domain/list_contacts.yml +17 -0
- rasa/cli/project_templates/calm/domain/remove_contact.yml +38 -0
- rasa/cli/project_templates/calm/domain/shared.yml +10 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
- rasa/cli/project_templates/calm/endpoints.yml +58 -0
- rasa/cli/project_templates/default/actions/__init__.py +0 -0
- rasa/cli/project_templates/default/actions/actions.py +27 -0
- rasa/cli/project_templates/default/config.yml +44 -0
- rasa/cli/project_templates/default/credentials.yml +33 -0
- rasa/cli/project_templates/default/data/nlu.yml +91 -0
- rasa/cli/project_templates/default/data/rules.yml +13 -0
- rasa/cli/project_templates/default/data/stories.yml +30 -0
- rasa/cli/project_templates/default/domain.yml +34 -0
- rasa/cli/project_templates/default/endpoints.yml +42 -0
- rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
- rasa/cli/project_templates/tutorial/actions/__init__.py +0 -0
- rasa/cli/project_templates/tutorial/actions/actions.py +22 -0
- rasa/cli/project_templates/tutorial/config.yml +12 -0
- rasa/cli/project_templates/tutorial/credentials.yml +33 -0
- rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
- rasa/cli/project_templates/tutorial/data/patterns.yml +11 -0
- rasa/cli/project_templates/tutorial/domain.yml +35 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +55 -0
- rasa/cli/run.py +143 -0
- rasa/cli/scaffold.py +273 -0
- rasa/cli/shell.py +141 -0
- rasa/cli/studio/__init__.py +0 -0
- rasa/cli/studio/download.py +62 -0
- rasa/cli/studio/studio.py +296 -0
- rasa/cli/studio/train.py +59 -0
- rasa/cli/studio/upload.py +62 -0
- rasa/cli/telemetry.py +102 -0
- rasa/cli/test.py +280 -0
- rasa/cli/train.py +278 -0
- rasa/cli/utils.py +484 -0
- rasa/cli/visualize.py +40 -0
- rasa/cli/x.py +206 -0
- rasa/constants.py +45 -0
- rasa/core/__init__.py +17 -0
- rasa/core/actions/__init__.py +0 -0
- rasa/core/actions/action.py +1318 -0
- rasa/core/actions/action_clean_stack.py +59 -0
- rasa/core/actions/action_exceptions.py +24 -0
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/actions/action_repeat_bot_messages.py +89 -0
- rasa/core/actions/action_run_slot_rejections.py +210 -0
- rasa/core/actions/action_trigger_chitchat.py +31 -0
- rasa/core/actions/action_trigger_flow.py +109 -0
- rasa/core/actions/action_trigger_search.py +31 -0
- rasa/core/actions/constants.py +5 -0
- rasa/core/actions/custom_action_executor.py +191 -0
- rasa/core/actions/direct_custom_actions_executor.py +80 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +72 -0
- rasa/core/actions/forms.py +741 -0
- rasa/core/actions/grpc_custom_action_executor.py +251 -0
- rasa/core/actions/http_custom_action_executor.py +145 -0
- rasa/core/actions/loops.py +114 -0
- rasa/core/actions/two_stage_fallback.py +186 -0
- rasa/core/agent.py +559 -0
- rasa/core/auth_retry_tracker_store.py +122 -0
- rasa/core/brokers/__init__.py +0 -0
- rasa/core/brokers/broker.py +126 -0
- rasa/core/brokers/file.py +58 -0
- rasa/core/brokers/kafka.py +324 -0
- rasa/core/brokers/pika.py +388 -0
- rasa/core/brokers/sql.py +86 -0
- rasa/core/channels/__init__.py +61 -0
- rasa/core/channels/botframework.py +338 -0
- rasa/core/channels/callback.py +84 -0
- rasa/core/channels/channel.py +456 -0
- rasa/core/channels/console.py +241 -0
- rasa/core/channels/development_inspector.py +197 -0
- rasa/core/channels/facebook.py +419 -0
- rasa/core/channels/hangouts.py +329 -0
- rasa/core/channels/inspector/.eslintrc.cjs +25 -0
- rasa/core/channels/inspector/.gitignore +23 -0
- rasa/core/channels/inspector/README.md +54 -0
- rasa/core/channels/inspector/assets/favicon.ico +0 -0
- rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
- rasa/core/channels/inspector/custom.d.ts +3 -0
- rasa/core/channels/inspector/dist/assets/arc-632a63ec.js +1 -0
- rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
- rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-081e0df4.js +10 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-3df0afc2.js +2 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-8c5ed31e.js +2 -0
- rasa/core/channels/inspector/dist/assets/createText-62fc7601-89c73b31.js +7 -0
- rasa/core/channels/inspector/dist/assets/edges-f2ad444c-4fc48c3e.js +4 -0
- rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-907e0440.js +51 -0
- rasa/core/channels/inspector/dist/assets/flowDb-1972c806-9ec53a3c.js +6 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-41da787a.js +4 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-8bea338b.js +1 -0
- rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-ce370633.js +139 -0
- rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-90a36523.js +266 -0
- rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-41e1aa3f.js +70 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-e6f2af62.js +1 -0
- rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
- rasa/core/channels/inspector/dist/assets/index-e793d777.js +1317 -0
- rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-8ceba4db.js +7 -0
- rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
- rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-960d3809.js +139 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
- rasa/core/channels/inspector/dist/assets/layout-498807d8.js +1 -0
- rasa/core/channels/inspector/dist/assets/line-eeccc4e2.js +1 -0
- rasa/core/channels/inspector/dist/assets/linear-8a078617.js +1 -0
- rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-396d17dd.js +109 -0
- rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
- rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
- rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-dc9b5e1b.js +35 -0
- rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-a08cba6d.js +7 -0
- rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-87242b9e.js +52 -0
- rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-53f6f391.js +8 -0
- rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-715c9c20.js +122 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-2e8fb31f.js +1 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-7e2d2aa0.js +1 -0
- rasa/core/channels/inspector/dist/assets/styles-080da4f6-4420cea6.js +110 -0
- rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-28676cf4.js +159 -0
- rasa/core/channels/inspector/dist/assets/styles-9c745c82-cef936a6.js +207 -0
- rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-151251e9.js +1 -0
- rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-0d39bdb2.js +61 -0
- rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-a03fa445.js +7 -0
- rasa/core/channels/inspector/dist/index.html +44 -0
- rasa/core/channels/inspector/index.html +42 -0
- rasa/core/channels/inspector/jest.config.ts +13 -0
- rasa/core/channels/inspector/package.json +52 -0
- rasa/core/channels/inspector/setupTests.ts +2 -0
- rasa/core/channels/inspector/src/App.tsx +217 -0
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +108 -0
- rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +136 -0
- rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
- rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +22 -0
- rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
- rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
- rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
- rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
- rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
- rasa/core/channels/inspector/src/helpers/audiostream.ts +191 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +392 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +306 -0
- rasa/core/channels/inspector/src/helpers/utils.ts +127 -0
- rasa/core/channels/inspector/src/main.tsx +13 -0
- rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
- rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
- rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
- rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
- rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
- rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
- rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
- rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
- rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
- rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
- rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
- rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
- rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
- rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
- rasa/core/channels/inspector/src/theme/index.ts +101 -0
- rasa/core/channels/inspector/src/types.ts +84 -0
- rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
- rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
- rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
- rasa/core/channels/inspector/tsconfig.json +26 -0
- rasa/core/channels/inspector/tsconfig.node.json +10 -0
- rasa/core/channels/inspector/vite.config.ts +8 -0
- rasa/core/channels/inspector/yarn.lock +6249 -0
- rasa/core/channels/mattermost.py +229 -0
- rasa/core/channels/rasa_chat.py +126 -0
- rasa/core/channels/rest.py +230 -0
- rasa/core/channels/rocketchat.py +174 -0
- rasa/core/channels/slack.py +620 -0
- rasa/core/channels/socketio.py +301 -0
- rasa/core/channels/telegram.py +298 -0
- rasa/core/channels/twilio.py +169 -0
- rasa/core/channels/vier_cvg.py +374 -0
- rasa/core/channels/voice_ready/__init__.py +0 -0
- rasa/core/channels/voice_ready/audiocodes.py +501 -0
- rasa/core/channels/voice_ready/jambonz.py +121 -0
- rasa/core/channels/voice_ready/jambonz_protocol.py +396 -0
- rasa/core/channels/voice_ready/twilio_voice.py +403 -0
- rasa/core/channels/voice_ready/utils.py +37 -0
- rasa/core/channels/voice_stream/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
- rasa/core/channels/voice_stream/asr/azure.py +130 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
- rasa/core/channels/voice_stream/audio_bytes.py +8 -0
- rasa/core/channels/voice_stream/browser_audio.py +107 -0
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +106 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +427 -0
- rasa/core/channels/webexteams.py +134 -0
- rasa/core/concurrent_lock_store.py +210 -0
- rasa/core/constants.py +112 -0
- rasa/core/evaluation/__init__.py +0 -0
- rasa/core/evaluation/marker.py +267 -0
- rasa/core/evaluation/marker_base.py +923 -0
- rasa/core/evaluation/marker_stats.py +293 -0
- rasa/core/evaluation/marker_tracker_loader.py +103 -0
- rasa/core/exceptions.py +29 -0
- rasa/core/exporter.py +284 -0
- rasa/core/featurizers/__init__.py +0 -0
- rasa/core/featurizers/precomputation.py +410 -0
- rasa/core/featurizers/single_state_featurizer.py +421 -0
- rasa/core/featurizers/tracker_featurizers.py +1262 -0
- rasa/core/http_interpreter.py +89 -0
- rasa/core/information_retrieval/__init__.py +7 -0
- rasa/core/information_retrieval/faiss.py +124 -0
- rasa/core/information_retrieval/information_retrieval.py +137 -0
- rasa/core/information_retrieval/milvus.py +59 -0
- rasa/core/information_retrieval/qdrant.py +96 -0
- rasa/core/jobs.py +63 -0
- rasa/core/lock.py +139 -0
- rasa/core/lock_store.py +343 -0
- rasa/core/migrate.py +403 -0
- rasa/core/nlg/__init__.py +3 -0
- rasa/core/nlg/callback.py +146 -0
- rasa/core/nlg/contextual_response_rephraser.py +320 -0
- rasa/core/nlg/generator.py +230 -0
- rasa/core/nlg/interpolator.py +143 -0
- rasa/core/nlg/response.py +155 -0
- rasa/core/nlg/summarize.py +70 -0
- rasa/core/persistor.py +538 -0
- rasa/core/policies/__init__.py +0 -0
- rasa/core/policies/ensemble.py +329 -0
- rasa/core/policies/enterprise_search_policy.py +905 -0
- rasa/core/policies/enterprise_search_prompt_template.jinja2 +25 -0
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
- rasa/core/policies/flow_policy.py +205 -0
- rasa/core/policies/flows/__init__.py +0 -0
- rasa/core/policies/flows/flow_exceptions.py +44 -0
- rasa/core/policies/flows/flow_executor.py +754 -0
- rasa/core/policies/flows/flow_step_result.py +43 -0
- rasa/core/policies/intentless_policy.py +1031 -0
- rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
- rasa/core/policies/memoization.py +538 -0
- rasa/core/policies/policy.py +725 -0
- rasa/core/policies/rule_policy.py +1273 -0
- rasa/core/policies/ted_policy.py +2169 -0
- rasa/core/policies/unexpected_intent_policy.py +1022 -0
- rasa/core/processor.py +1465 -0
- rasa/core/run.py +342 -0
- rasa/core/secrets_manager/__init__.py +0 -0
- rasa/core/secrets_manager/constants.py +36 -0
- rasa/core/secrets_manager/endpoints.py +391 -0
- rasa/core/secrets_manager/factory.py +241 -0
- rasa/core/secrets_manager/secret_manager.py +262 -0
- rasa/core/secrets_manager/vault.py +584 -0
- rasa/core/test.py +1335 -0
- rasa/core/tracker_store.py +1703 -0
- rasa/core/train.py +105 -0
- rasa/core/training/__init__.py +89 -0
- rasa/core/training/converters/__init__.py +0 -0
- rasa/core/training/converters/responses_prefix_converter.py +119 -0
- rasa/core/training/interactive.py +1744 -0
- rasa/core/training/story_conflict.py +381 -0
- rasa/core/training/training.py +93 -0
- rasa/core/utils.py +366 -0
- rasa/core/visualize.py +70 -0
- rasa/dialogue_understanding/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/constants.py +4 -0
- rasa/dialogue_understanding/coexistence/intent_based_router.py +196 -0
- rasa/dialogue_understanding/coexistence/llm_based_router.py +327 -0
- rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
- rasa/dialogue_understanding/commands/__init__.py +61 -0
- rasa/dialogue_understanding/commands/can_not_handle_command.py +70 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +125 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +44 -0
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/clarify_command.py +86 -0
- rasa/dialogue_understanding/commands/command.py +85 -0
- rasa/dialogue_understanding/commands/correct_slots_command.py +297 -0
- rasa/dialogue_understanding/commands/error_command.py +79 -0
- rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/handle_code_change_command.py +73 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +66 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/noop_command.py +54 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
- rasa/dialogue_understanding/commands/restart_command.py +58 -0
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +160 -0
- rasa/dialogue_understanding/commands/skip_question_command.py +75 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +107 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +45 -0
- rasa/dialogue_understanding/generator/__init__.py +21 -0
- rasa/dialogue_understanding/generator/command_generator.py +343 -0
- rasa/dialogue_understanding/generator/constants.py +27 -0
- rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +466 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +500 -0
- rasa/dialogue_understanding/generator/llm_command_generator.py +67 -0
- rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
- rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +893 -0
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +258 -0
- rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +60 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +478 -0
- rasa/dialogue_understanding/patterns/__init__.py +0 -0
- rasa/dialogue_understanding/patterns/cancel.py +111 -0
- rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
- rasa/dialogue_understanding/patterns/chitchat.py +37 -0
- rasa/dialogue_understanding/patterns/clarify.py +97 -0
- rasa/dialogue_understanding/patterns/code_change.py +41 -0
- rasa/dialogue_understanding/patterns/collect_information.py +90 -0
- rasa/dialogue_understanding/patterns/completed.py +40 -0
- rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
- rasa/dialogue_understanding/patterns/correction.py +278 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +301 -0
- rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
- rasa/dialogue_understanding/patterns/internal_error.py +47 -0
- rasa/dialogue_understanding/patterns/repeat.py +37 -0
- rasa/dialogue_understanding/patterns/restart.py +37 -0
- rasa/dialogue_understanding/patterns/search.py +37 -0
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/patterns/skip_question.py +38 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/__init__.py +0 -0
- rasa/dialogue_understanding/processor/command_processor.py +720 -0
- rasa/dialogue_understanding/processor/command_processor_component.py +43 -0
- rasa/dialogue_understanding/stack/__init__.py +0 -0
- rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
- rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
- rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
- rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
- rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
- rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
- rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
- rasa/dialogue_understanding/stack/utils.py +211 -0
- rasa/e2e_test/__init__.py +0 -0
- rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
- rasa/e2e_test/assertions.py +1345 -0
- rasa/e2e_test/assertions_schema.yml +129 -0
- rasa/e2e_test/constants.py +31 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +568 -0
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +54 -0
- rasa/e2e_test/e2e_test_runner.py +1190 -0
- rasa/e2e_test/e2e_test_schema.yml +181 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +598 -0
- rasa/e2e_test/utils/validation.py +80 -0
- rasa/engine/__init__.py +0 -0
- rasa/engine/caching.py +463 -0
- rasa/engine/constants.py +17 -0
- rasa/engine/exceptions.py +14 -0
- rasa/engine/graph.py +642 -0
- rasa/engine/loader.py +48 -0
- rasa/engine/recipes/__init__.py +0 -0
- rasa/engine/recipes/config_files/default_config.yml +41 -0
- rasa/engine/recipes/default_components.py +97 -0
- rasa/engine/recipes/default_recipe.py +1258 -0
- rasa/engine/recipes/graph_recipe.py +78 -0
- rasa/engine/recipes/recipe.py +93 -0
- rasa/engine/runner/__init__.py +0 -0
- rasa/engine/runner/dask.py +250 -0
- rasa/engine/runner/interface.py +49 -0
- rasa/engine/storage/__init__.py +0 -0
- rasa/engine/storage/local_model_storage.py +244 -0
- rasa/engine/storage/resource.py +110 -0
- rasa/engine/storage/storage.py +199 -0
- rasa/engine/training/__init__.py +0 -0
- rasa/engine/training/components.py +176 -0
- rasa/engine/training/fingerprinting.py +64 -0
- rasa/engine/training/graph_trainer.py +256 -0
- rasa/engine/training/hooks.py +164 -0
- rasa/engine/validation.py +1451 -0
- rasa/env.py +14 -0
- rasa/exceptions.py +69 -0
- rasa/graph_components/__init__.py +0 -0
- rasa/graph_components/converters/__init__.py +0 -0
- rasa/graph_components/converters/nlu_message_converter.py +48 -0
- rasa/graph_components/providers/__init__.py +0 -0
- rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
- rasa/graph_components/providers/domain_provider.py +71 -0
- rasa/graph_components/providers/flows_provider.py +74 -0
- rasa/graph_components/providers/forms_provider.py +44 -0
- rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
- rasa/graph_components/providers/responses_provider.py +44 -0
- rasa/graph_components/providers/rule_only_provider.py +49 -0
- rasa/graph_components/providers/story_graph_provider.py +96 -0
- rasa/graph_components/providers/training_tracker_provider.py +55 -0
- rasa/graph_components/validators/__init__.py +0 -0
- rasa/graph_components/validators/default_recipe_validator.py +550 -0
- rasa/graph_components/validators/finetuning_validator.py +302 -0
- rasa/hooks.py +112 -0
- rasa/jupyter.py +63 -0
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/markers/__init__.py +0 -0
- rasa/markers/marker.py +269 -0
- rasa/markers/marker_base.py +828 -0
- rasa/markers/upload.py +74 -0
- rasa/markers/validate.py +21 -0
- rasa/model.py +118 -0
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +40 -0
- rasa/model_manager/model_api.py +559 -0
- rasa/model_manager/runner_service.py +286 -0
- rasa/model_manager/socket_bridge.py +146 -0
- rasa/model_manager/studio_jwt_auth.py +86 -0
- rasa/model_manager/trainer_service.py +325 -0
- rasa/model_manager/utils.py +87 -0
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +112 -0
- rasa/model_testing.py +457 -0
- rasa/model_training.py +595 -0
- rasa/nlu/__init__.py +7 -0
- rasa/nlu/classifiers/__init__.py +3 -0
- rasa/nlu/classifiers/classifier.py +5 -0
- rasa/nlu/classifiers/diet_classifier.py +1881 -0
- rasa/nlu/classifiers/fallback_classifier.py +192 -0
- rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
- rasa/nlu/classifiers/logistic_regression_classifier.py +253 -0
- rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
- rasa/nlu/classifiers/regex_message_handler.py +56 -0
- rasa/nlu/classifiers/sklearn_intent_classifier.py +330 -0
- rasa/nlu/constants.py +77 -0
- rasa/nlu/convert.py +40 -0
- rasa/nlu/emulators/__init__.py +0 -0
- rasa/nlu/emulators/dialogflow.py +55 -0
- rasa/nlu/emulators/emulator.py +49 -0
- rasa/nlu/emulators/luis.py +86 -0
- rasa/nlu/emulators/no_emulator.py +10 -0
- rasa/nlu/emulators/wit.py +56 -0
- rasa/nlu/extractors/__init__.py +0 -0
- rasa/nlu/extractors/crf_entity_extractor.py +715 -0
- rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
- rasa/nlu/extractors/entity_synonyms.py +178 -0
- rasa/nlu/extractors/extractor.py +470 -0
- rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
- rasa/nlu/extractors/regex_entity_extractor.py +220 -0
- rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
- rasa/nlu/featurizers/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +445 -0
- rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +768 -0
- rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
- rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
- rasa/nlu/featurizers/featurizer.py +89 -0
- rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +867 -0
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +571 -0
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +271 -0
- rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
- rasa/nlu/model.py +24 -0
- rasa/nlu/run.py +27 -0
- rasa/nlu/selectors/__init__.py +0 -0
- rasa/nlu/selectors/response_selector.py +987 -0
- rasa/nlu/test.py +1940 -0
- rasa/nlu/tokenizers/__init__.py +0 -0
- rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
- rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
- rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
- rasa/nlu/tokenizers/tokenizer.py +239 -0
- rasa/nlu/tokenizers/whitespace_tokenizer.py +95 -0
- rasa/nlu/utils/__init__.py +35 -0
- rasa/nlu/utils/bilou_utils.py +462 -0
- rasa/nlu/utils/hugging_face/__init__.py +0 -0
- rasa/nlu/utils/hugging_face/registry.py +108 -0
- rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
- rasa/nlu/utils/mitie_utils.py +113 -0
- rasa/nlu/utils/pattern_utils.py +168 -0
- rasa/nlu/utils/spacy_utils.py +310 -0
- rasa/plugin.py +90 -0
- rasa/server.py +1624 -0
- rasa/shared/__init__.py +0 -0
- rasa/shared/constants.py +310 -0
- rasa/shared/core/__init__.py +0 -0
- rasa/shared/core/command_payload_reader.py +109 -0
- rasa/shared/core/constants.py +180 -0
- rasa/shared/core/conversation.py +46 -0
- rasa/shared/core/domain.py +2172 -0
- rasa/shared/core/events.py +2559 -0
- rasa/shared/core/flows/__init__.py +7 -0
- rasa/shared/core/flows/flow.py +562 -0
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flow_step.py +146 -0
- rasa/shared/core/flows/flow_step_links.py +319 -0
- rasa/shared/core/flows/flow_step_sequence.py +70 -0
- rasa/shared/core/flows/flows_list.py +258 -0
- rasa/shared/core/flows/flows_yaml_schema.json +303 -0
- rasa/shared/core/flows/nlu_trigger.py +117 -0
- rasa/shared/core/flows/steps/__init__.py +24 -0
- rasa/shared/core/flows/steps/action.py +56 -0
- rasa/shared/core/flows/steps/call.py +64 -0
- rasa/shared/core/flows/steps/collect.py +112 -0
- rasa/shared/core/flows/steps/constants.py +5 -0
- rasa/shared/core/flows/steps/continuation.py +36 -0
- rasa/shared/core/flows/steps/end.py +22 -0
- rasa/shared/core/flows/steps/internal.py +44 -0
- rasa/shared/core/flows/steps/link.py +51 -0
- rasa/shared/core/flows/steps/no_operation.py +48 -0
- rasa/shared/core/flows/steps/set_slots.py +50 -0
- rasa/shared/core/flows/steps/start.py +30 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +735 -0
- rasa/shared/core/flows/yaml_flows_io.py +405 -0
- rasa/shared/core/generator.py +908 -0
- rasa/shared/core/slot_mappings.py +526 -0
- rasa/shared/core/slots.py +654 -0
- rasa/shared/core/trackers.py +1183 -0
- rasa/shared/core/training_data/__init__.py +0 -0
- rasa/shared/core/training_data/loading.py +89 -0
- rasa/shared/core/training_data/story_reader/__init__.py +0 -0
- rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
- rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
- rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
- rasa/shared/core/training_data/story_writer/__init__.py +0 -0
- rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
- rasa/shared/core/training_data/story_writer/yaml_story_writer.py +444 -0
- rasa/shared/core/training_data/structures.py +858 -0
- rasa/shared/core/training_data/visualization.html +146 -0
- rasa/shared/core/training_data/visualization.py +603 -0
- rasa/shared/data.py +249 -0
- rasa/shared/engine/__init__.py +0 -0
- rasa/shared/engine/caching.py +26 -0
- rasa/shared/exceptions.py +167 -0
- rasa/shared/importers/__init__.py +0 -0
- rasa/shared/importers/importer.py +770 -0
- rasa/shared/importers/multi_project.py +215 -0
- rasa/shared/importers/rasa.py +108 -0
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +36 -0
- rasa/shared/nlu/__init__.py +0 -0
- rasa/shared/nlu/constants.py +49 -0
- rasa/shared/nlu/interpreter.py +10 -0
- rasa/shared/nlu/training_data/__init__.py +0 -0
- rasa/shared/nlu/training_data/entities_parser.py +208 -0
- rasa/shared/nlu/training_data/features.py +492 -0
- rasa/shared/nlu/training_data/formats/__init__.py +10 -0
- rasa/shared/nlu/training_data/formats/dialogflow.py +163 -0
- rasa/shared/nlu/training_data/formats/luis.py +87 -0
- rasa/shared/nlu/training_data/formats/rasa.py +135 -0
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +618 -0
- rasa/shared/nlu/training_data/formats/readerwriter.py +244 -0
- rasa/shared/nlu/training_data/formats/wit.py +52 -0
- rasa/shared/nlu/training_data/loading.py +137 -0
- rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
- rasa/shared/nlu/training_data/message.py +490 -0
- rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
- rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
- rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
- rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
- rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
- rasa/shared/nlu/training_data/training_data.py +729 -0
- rasa/shared/nlu/training_data/util.py +223 -0
- rasa/shared/providers/__init__.py +0 -0
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +183 -0
- rasa/shared/providers/_configs/client_config.py +57 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
- rasa/shared/providers/_configs/litellm_router_client_config.py +220 -0
- rasa/shared/providers/_configs/model_group_config.py +167 -0
- rasa/shared/providers/_configs/openai_client_config.py +175 -0
- rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +177 -0
- rasa/shared/providers/_configs/utils.py +117 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +243 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +310 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +126 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +263 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +359 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +108 -0
- rasa/shared/providers/llm/litellm_router_llm_client.py +182 -0
- rasa/shared/providers/llm/llm_client.py +76 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +155 -0
- rasa/shared/providers/llm/rasa_llm_client.py +112 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +269 -0
- rasa/shared/providers/mappings.py +94 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +183 -0
- rasa/shared/providers/router/router_client.py +73 -0
- rasa/shared/utils/__init__.py +0 -0
- rasa/shared/utils/cli.py +102 -0
- rasa/shared/utils/common.py +324 -0
- rasa/shared/utils/constants.py +4 -0
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +258 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +499 -0
- rasa/shared/utils/llm.py +760 -0
- rasa/shared/utils/pykwalify_extensions.py +27 -0
- rasa/shared/utils/schemas/__init__.py +0 -0
- rasa/shared/utils/schemas/config.yml +2 -0
- rasa/shared/utils/schemas/domain.yml +145 -0
- rasa/shared/utils/schemas/events.py +214 -0
- rasa/shared/utils/schemas/model_config.yml +36 -0
- rasa/shared/utils/schemas/stories.yml +173 -0
- rasa/shared/utils/yaml.py +1067 -0
- rasa/studio/__init__.py +0 -0
- rasa/studio/auth.py +270 -0
- rasa/studio/config.py +136 -0
- rasa/studio/constants.py +19 -0
- rasa/studio/data_handler.py +368 -0
- rasa/studio/download.py +489 -0
- rasa/studio/results_logger.py +137 -0
- rasa/studio/train.py +134 -0
- rasa/studio/upload.py +549 -0
- rasa/telemetry.py +1869 -0
- rasa/tracing/__init__.py +0 -0
- rasa/tracing/config.py +355 -0
- rasa/tracing/constants.py +62 -0
- rasa/tracing/instrumentation/__init__.py +0 -0
- rasa/tracing/instrumentation/attribute_extractors.py +764 -0
- rasa/tracing/instrumentation/instrumentation.py +1306 -0
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +144 -0
- rasa/tracing/instrumentation/metrics.py +294 -0
- rasa/tracing/metric_instrument_provider.py +205 -0
- rasa/utils/__init__.py +0 -0
- rasa/utils/beta.py +83 -0
- rasa/utils/cli.py +28 -0
- rasa/utils/common.py +639 -0
- rasa/utils/converter.py +53 -0
- rasa/utils/endpoints.py +331 -0
- rasa/utils/io.py +252 -0
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +542 -0
- rasa/utils/log_utils.py +181 -0
- rasa/utils/mapper.py +210 -0
- rasa/utils/ml_utils.py +147 -0
- rasa/utils/plotting.py +362 -0
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/utils/singleton.py +23 -0
- rasa/utils/tensorflow/__init__.py +0 -0
- rasa/utils/tensorflow/callback.py +112 -0
- rasa/utils/tensorflow/constants.py +116 -0
- rasa/utils/tensorflow/crf.py +492 -0
- rasa/utils/tensorflow/data_generator.py +440 -0
- rasa/utils/tensorflow/environment.py +161 -0
- rasa/utils/tensorflow/exceptions.py +5 -0
- rasa/utils/tensorflow/feature_array.py +366 -0
- rasa/utils/tensorflow/layers.py +1565 -0
- rasa/utils/tensorflow/layers_utils.py +113 -0
- rasa/utils/tensorflow/metrics.py +281 -0
- rasa/utils/tensorflow/model_data.py +798 -0
- rasa/utils/tensorflow/model_data_utils.py +499 -0
- rasa/utils/tensorflow/models.py +935 -0
- rasa/utils/tensorflow/rasa_layers.py +1094 -0
- rasa/utils/tensorflow/transformer.py +640 -0
- rasa/utils/tensorflow/types.py +6 -0
- rasa/utils/train_utils.py +572 -0
- rasa/utils/url_tools.py +53 -0
- rasa/utils/yaml.py +54 -0
- rasa/validator.py +1653 -0
- rasa/version.py +3 -0
- rasa_pro-3.11.3.dist-info/METADATA +198 -0
- rasa_pro-3.11.3.dist-info/NOTICE +5 -0
- rasa_pro-3.11.3.dist-info/RECORD +779 -0
- rasa_pro-3.11.3.dist-info/WHEEL +4 -0
- rasa_pro-3.11.3.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,1258 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
import enum
|
|
5
|
+
import logging
|
|
6
|
+
import math
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from typing import Dict, Text, Any, Tuple, Type, Optional, List, Callable, Set, Union
|
|
9
|
+
|
|
10
|
+
import dataclasses
|
|
11
|
+
|
|
12
|
+
from rasa.core.featurizers.precomputation import (
|
|
13
|
+
CoreFeaturizationInputConverter,
|
|
14
|
+
CoreFeaturizationCollector,
|
|
15
|
+
)
|
|
16
|
+
from rasa.graph_components.providers.flows_provider import FlowsProvider
|
|
17
|
+
from rasa.dialogue_understanding.processor.command_processor_component import (
|
|
18
|
+
CommandProcessorComponent,
|
|
19
|
+
)
|
|
20
|
+
from rasa.shared.exceptions import FileNotFoundException
|
|
21
|
+
from rasa.core.policies.ensemble import DefaultPolicyPredictionEnsemble
|
|
22
|
+
|
|
23
|
+
from rasa.engine.graph import (
|
|
24
|
+
GraphSchema,
|
|
25
|
+
GraphComponent,
|
|
26
|
+
SchemaNode,
|
|
27
|
+
GraphModelConfiguration,
|
|
28
|
+
)
|
|
29
|
+
from rasa.engine.constants import (
|
|
30
|
+
PLACEHOLDER_IMPORTER,
|
|
31
|
+
PLACEHOLDER_MESSAGE,
|
|
32
|
+
PLACEHOLDER_TRACKER,
|
|
33
|
+
PLACEHOLDER_ENDPOINTS,
|
|
34
|
+
)
|
|
35
|
+
from rasa.engine.recipes.recipe import Recipe
|
|
36
|
+
from rasa.engine.storage.resource import Resource
|
|
37
|
+
from rasa.graph_components.converters.nlu_message_converter import NLUMessageConverter
|
|
38
|
+
from rasa.graph_components.providers.domain_provider import DomainProvider
|
|
39
|
+
from rasa.graph_components.providers.forms_provider import FormsProvider
|
|
40
|
+
from rasa.graph_components.providers.responses_provider import ResponsesProvider
|
|
41
|
+
from rasa.graph_components.providers.domain_for_core_training_provider import (
|
|
42
|
+
DomainForCoreTrainingProvider,
|
|
43
|
+
)
|
|
44
|
+
from rasa.graph_components.providers.nlu_training_data_provider import (
|
|
45
|
+
NLUTrainingDataProvider,
|
|
46
|
+
)
|
|
47
|
+
from rasa.graph_components.providers.rule_only_provider import RuleOnlyDataProvider
|
|
48
|
+
from rasa.graph_components.providers.story_graph_provider import StoryGraphProvider
|
|
49
|
+
from rasa.graph_components.providers.training_tracker_provider import (
|
|
50
|
+
TrainingTrackerProvider,
|
|
51
|
+
)
|
|
52
|
+
import rasa.shared.constants
|
|
53
|
+
from rasa.shared.exceptions import RasaException, InvalidConfigException
|
|
54
|
+
from rasa.shared.constants import ASSISTANT_ID_KEY
|
|
55
|
+
from rasa.shared.data import TrainingType
|
|
56
|
+
from rasa.shared.utils.yaml import read_config_file
|
|
57
|
+
|
|
58
|
+
from rasa.utils.tensorflow.constants import EPOCHS
|
|
59
|
+
from rasa.shared.utils.common import (
|
|
60
|
+
class_from_module_path,
|
|
61
|
+
transform_collection_to_sentence,
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
logger = logging.getLogger(__name__)


# Default `SchemaNode` keyword arguments; by their name presumably meant for
# nodes of the prediction graph (their use is outside this chunk).
DEFAULT_PREDICT_KWARGS = {
    "constructor_name": "load",
    "eager": True,
    "is_target": False,
}

# YAML comment text per top-level config section, explaining that a default
# was used because the section was not configured.
COMMENTS_FOR_KEYS = {
    "pipeline": (
        "# # No configuration for the NLU pipeline was provided. The following "
        "default pipeline was used to train your model.\n"
        "# # If you'd like to customize it, uncomment and adjust the pipeline.\n"
        f"# # See {rasa.shared.constants.DOCS_URL_PIPELINE} for more information.\n"
    ),
    "policies": (
        "# # No configuration for policies was provided. The following default "
        "policies were used to train your model.\n"
        "# # If you'd like to customize them, uncomment and adjust the policies.\n"
        f"# # See {rasa.shared.constants.DOCS_URL_POLICIES} for more information.\n"
    ),
}
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class DefaultV1RecipeRegisterException(RasaException):
    """Raised when registering a class that is not a `GraphComponent`.

    `DefaultV1Recipe.register` raises it from its decorator when the decorated
    class is not a subclass of `GraphComponent`.
    """
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class DefaultV1Recipe(Recipe):
|
|
92
|
+
"""Recipe which converts the normal model config to train and predict graph."""
|
|
93
|
+
|
|
94
|
+
@enum.unique
class ComponentType(Enum):
    """Categories used to place (custom) components correctly in the graph.

    Values are explicit integers; `@enum.unique` rejects duplicate values.
    """

    # NLU message processing
    MESSAGE_TOKENIZER = 0
    MESSAGE_FEATURIZER = 1
    INTENT_CLASSIFIER = 2
    ENTITY_EXTRACTOR = 3

    # Dialogue policies (registering one in the NLU pipeline is an error)
    POLICY_WITHOUT_END_TO_END_SUPPORT = 4
    POLICY_WITH_END_TO_END_SUPPORT = 5

    # Components providing a pre-loaded model that others reference via
    # `model_from`
    MODEL_LOADER = 6

    # Further component categories
    COMMAND_GENERATOR = 7
    COEXISTENCE_ROUTER = 8
|
|
107
|
+
|
|
108
|
+
# The recipe's name; also quoted in registration error messages.
name = "default.v1"

# Registry filled by the `register` decorator, keyed by the registered
# class's `__name__`. Being a class attribute, it is shared by every instance.
_registered_components: Dict[Text, RegisteredComponent] = {}  # noqa: RUF012
|
|
110
|
+
|
|
111
|
+
def __init__(self) -> None:
    """Creates the recipe with NLU, Core and end-to-end enabled, no fine-tuning.

    `graph_config_for_recipe` recomputes these flags from the model config
    and the requested training type.
    """
    self._use_core = self._use_nlu = self._use_end_to_end = True
    self._is_finetuning = False
|
|
117
|
+
|
|
118
|
+
@dataclasses.dataclass
class RegisteredComponent:
    """Registry entry stored by the `register` decorator for one component."""

    # The registered graph component class.
    clazz: Type[GraphComponent]
    # All categories the component was registered under.
    types: Set[DefaultV1Recipe.ComponentType]
    # Whether a training node has to be created for the component.
    is_trainable: bool
    # `__name__` of the component class providing a pre-loaded model, if any.
    model_from: Optional[Text]
|
|
126
|
+
|
|
127
|
+
@classmethod
def register(
    cls,
    component_types: Union[ComponentType, List[ComponentType]],
    is_trainable: bool,
    model_from: Optional[Text] = None,
) -> Callable[[Type[GraphComponent]], Type[GraphComponent]]:
    """Builds a class decorator which adds a component to the recipe registry.

    Args:
        component_types: One component type or a list of them. They decide
            where the component is placed in the graph.
        is_trainable: `True` if the component requires training.
        model_from: `__name__` of the component class providing a
            pre-loaded model this component needs, such as `SpacyNLP` or
            `MitieNLP`.

    Returns:
        A decorator that records the class under its `__name__` and returns
        the class unchanged.

    Raises:
        DefaultV1RecipeRegisterException: From the decorator, if the
            decorated class is not a subclass of `GraphComponent`.
    """

    def _record(component_class: Type[GraphComponent]) -> Type[GraphComponent]:
        if not issubclass(component_class, GraphComponent):
            raise DefaultV1RecipeRegisterException(
                f"Failed to register class '{component_class.__name__}' with "
                f"the recipe '{cls.name}'. The class has to be of type "
                f"'{GraphComponent.__name__}'."
            )

        # Normalise a single type and a list of types to a set.
        if isinstance(component_types, cls.ComponentType):
            types = {component_types}
        else:
            types = set(component_types)

        entry = cls.RegisteredComponent(
            clazz=component_class,
            types=types,
            is_trainable=is_trainable,
            model_from=model_from,
        )
        cls._registered_components[component_class.__name__] = entry
        return component_class

    return _record
|
|
168
|
+
|
|
169
|
+
@classmethod
def _from_registry(cls, name: Text) -> RegisteredComponent:
    """Returns the registry entry for a component name used in the config.

    Args:
        name: Either the `__name__` of a registered class, or a module path
            (containing a `.`) to a class which was registered.

    Returns:
        The registry entry of the component.

    Raises:
        InvalidConfigException: If no registered component matches `name`.
    """
    # Importing all the default Rasa components will automatically register them
    from rasa.engine.recipes.default_components import DEFAULT_COMPONENTS  # noqa

    registry = cls._registered_components
    lookup_key = name
    if lookup_key not in registry and "." in name:
        # Module path: resolve the class and look it up by its plain name.
        lookup_key = class_from_module_path(name).__name__

    if lookup_key in registry:
        return registry[lookup_key]

    raise InvalidConfigException(
        f"Can't load class for name '{name}'. Please make sure to provide "
        f"a valid name or module path and to register it using the "
        f"'@DefaultV1Recipe.register' decorator."
    )
|
|
187
|
+
|
|
188
|
+
def graph_config_for_recipe(
    self,
    config: Dict,
    cli_parameters: Dict[Text, Any],
    training_type: TrainingType = TrainingType.BOTH,
    is_finetuning: bool = False,
) -> GraphModelConfiguration:
    """Converts the default config to graphs (see interface for full docstring).

    Side effect: sets the recipe's `_use_core`, `_use_nlu`, `_use_end_to_end`
    and `_is_finetuning` flags, which the node builders read afterwards.

    Args:
        config: The model configuration (`pipeline`, `policies`, `language`,
            assistant id, ...).
        cli_parameters: Parameters passed via the command line.
        training_type: Which parts of the model should be trained.
        is_finetuning: Whether an existing model is being fine-tuned.

    Returns:
        The train and predict graph schemas together with model metadata.

    Raises:
        InvalidConfigException: If NLU-only training is requested without a
            pipeline, or Core-only training without policies.
    """
    # A part of the model is only used if it is configured AND the requested
    # training type does not exclude it.
    self._use_core = (
        bool(config.get("policies")) and training_type != TrainingType.NLU
    )
    self._use_nlu = (
        bool(config.get("pipeline")) and training_type != TrainingType.CORE
    )

    if not self._use_nlu and training_type == TrainingType.NLU:
        raise InvalidConfigException(
            "Can't train an NLU model without a specified pipeline. Please make "
            "sure to specify a valid pipeline in your configuration."
        )

    if not self._use_core and training_type == TrainingType.CORE:
        raise InvalidConfigException(
            "Can't train a Core model without policies. Please make "
            "sure to specify a valid policy in your configuration."
        )

    # End-to-end needs both parts and must be requested explicitly.
    self._use_end_to_end = (
        self._use_nlu
        and self._use_core
        and training_type == TrainingType.END_TO_END
    )

    self._is_finetuning = is_finetuning

    train_nodes, preprocessors = self._create_train_nodes(config, cli_parameters)
    predict_nodes = self._create_predict_nodes(config, preprocessors, train_nodes)

    core_target = "select_prediction" if self._use_core else None

    from rasa.nlu.classifiers.regex_message_handler import RegexMessageHandler

    return GraphModelConfiguration(
        train_schema=GraphSchema(train_nodes),
        predict_schema=GraphSchema(predict_nodes),
        training_type=training_type,
        assistant_id=config.get(ASSISTANT_ID_KEY),
        language=config.get("language"),
        core_target=core_target,
        nlu_target=f"run_{RegexMessageHandler.__name__}",
    )
|
|
239
|
+
|
|
240
|
+
def _create_train_nodes(
    self, config: Dict[Text, Any], cli_parameters: Dict[Text, Any]
) -> Tuple[Dict[Text, SchemaNode], List[Text]]:
    """Builds all nodes of the training graph.

    Args:
        config: The model configuration. Only a deep copy of it is modified.
        cli_parameters: Parameters passed via the command line.

    Returns:
        The training nodes keyed by node name, together with the NLU
        preprocessing node names remembered for end-to-end featurization
        (an empty list when no NLU pipeline is trained).
    """
    from rasa.graph_components.validators.default_recipe_validator import (
        DefaultV1RecipeValidator,
    )
    from rasa.graph_components.validators.finetuning_validator import (
        FinetuningValidator,
    )

    # The node builders pop keys from the pipeline entries, so the caller's
    # config must not be touched.
    train_config = copy.deepcopy(config)

    # Every training graph starts with these two validators: the schema
    # validator reads the importer, the fine-tuning validator reads the
    # schema validator's output.
    schema_validator = SchemaNode(
        needs={"importer": PLACEHOLDER_IMPORTER},
        uses=DefaultV1RecipeValidator,
        constructor_name="create",
        fn="validate",
        config={},
        is_input=True,
    )
    finetuning_validator = SchemaNode(
        needs={"importer": "schema_validator"},
        uses=FinetuningValidator,
        constructor_name="load" if self._is_finetuning else "create",
        fn="validate",
        is_input=True,
        config={"validate_core": self._use_core, "validate_nlu": self._use_nlu},
    )
    train_nodes = {
        "schema_validator": schema_validator,
        "finetuning_validator": finetuning_validator,
    }

    preprocessors = (
        self._add_nlu_train_nodes(train_config, train_nodes, cli_parameters)
        if self._use_nlu
        else []
    )
    if self._use_core:
        self._add_core_train_nodes(
            train_config, train_nodes, preprocessors, cli_parameters
        )

    return train_nodes, preprocessors
|
|
284
|
+
|
|
285
|
+
def _add_nlu_train_nodes(
    self,
    train_config: Dict[Text, Any],
    train_nodes: Dict[Text, SchemaNode],
    cli_parameters: Dict[Text, Any],
) -> List[Text]:
    """Adds the data provider nodes and the nodes for every NLU component.

    Args:
        train_config: A (deep-copied) model configuration. The `name` key is
            popped from every entry of its `pipeline`.
        train_nodes: The training graph, extended in place.
        cli_parameters: Parameters passed via the command line.

    Returns:
        The names of the tokenizer/featurizer processing nodes in pipeline
        order, remembered for end-to-end featurization.

    Raises:
        InvalidConfigException: If a policy is listed in the NLU pipeline.
    """
    train_nodes["flows_provider"] = SchemaNode(
        needs={"importer": "finetuning_validator"},
        uses=FlowsProvider,
        constructor_name="create",
        fn="provide_train",
        config={},
        is_target=True,
        is_input=True,
    )
    train_nodes["domain_provider"] = SchemaNode(
        needs={"importer": "finetuning_validator"},
        uses=DomainProvider,
        constructor_name="create",
        fn="provide_train",
        config={},
        is_target=True,
        is_input=True,
    )
    persist_nlu_data = bool(cli_parameters.get("persist_nlu_training_data"))
    train_nodes["nlu_training_data_provider"] = SchemaNode(
        needs={"importer": "finetuning_validator"},
        uses=NLUTrainingDataProvider,
        constructor_name="create",
        fn="provide",
        config={
            "language": train_config.get("language"),
            "persist": persist_nlu_data,
        },
        is_target=persist_nlu_data,
        is_input=True,
    )

    # Each component trains on / processes the training data as output by
    # the last processing node before it.
    last_run_node = "nlu_training_data_provider"
    preprocessors: List[Text] = []

    for idx, component_config in enumerate(train_config["pipeline"]):
        configured_name = component_config.pop("name")
        component = self._from_registry(configured_name)

        if (
            self.ComponentType.POLICY_WITHOUT_END_TO_END_SUPPORT in component.types
            or self.ComponentType.POLICY_WITH_END_TO_END_SUPPORT in component.types
        ):
            # Report the name as the user wrote it, not the index-suffixed
            # node name derived from it below.
            raise InvalidConfigException(
                f"Found policy '{configured_name}' in NLU pipeline. Policies should "
                "be defined in the 'policies' section of your configuration."
            )

        # The pipeline index keeps node names unique when the same component
        # appears more than once.
        component_name = f"{configured_name}{idx}"

        if self.ComponentType.MODEL_LOADER in component.types:
            # Other components find this node via their `model_from` entry
            # (see `_get_model_provider_needs`).
            train_nodes[f"provide_{component_name}"] = SchemaNode(
                needs={},
                uses=component.clazz,
                constructor_name="create",
                fn="provide",
                config=component_config,
            )

        from_resource = None
        if component.is_trainable:
            from_resource = self._add_nlu_train_node(
                train_nodes,
                component.clazz,
                component_name,
                last_run_node,
                component_config,
                cli_parameters,
            )

        if component.types.intersection(
            {
                self.ComponentType.MESSAGE_TOKENIZER,
                self.ComponentType.MESSAGE_FEATURIZER,
            }
        ):
            last_run_node = self._add_nlu_process_node(
                train_nodes,
                component.clazz,
                component_name,
                last_run_node,
                component_config,
                from_resource=from_resource,
            )

            # Remember for End-to-End-Featurization
            preprocessors.append(last_run_node)

    return preprocessors
|
|
383
|
+
|
|
384
|
+
def _get_needs_from_args(
    self, component: Type[GraphComponent], fn_name: str
) -> Dict[str, str]:
    """Maps the parameters of a component method to the nodes providing them.

    Skipped parameters:
    - those the graph supplies in other ways: `self`, `message`, `messages`,
      `model` and `precomputations`;
    - `*args`, `**kwargs` and keyword-only parameters.

    The referenced provider nodes are not checked for existence here. If
    one is missing, an error is raised later when the graph is validated.

    Args:
        component: The component class.
        fn_name: Name of the method whose signature is inspected.

    Returns:
        A mapping from parameter name to the name of the providing node (or
        placeholder). It is empty if the component has no attribute
        `fn_name`.
    """
    import inspect

    if not hasattr(component, fn_name):
        return {}

    def provider_for(parameter: str) -> str:
        # Special cases where the parameter name does not match the default
        # "{parameter}_provider" node name.
        if parameter == "tracker":
            return PLACEHOLDER_TRACKER
        if parameter == "endpoints":
            return PLACEHOLDER_ENDPOINTS
        irregular = {
            "training_trackers": "training_tracker_provider",
            "training_data": "nlu_training_data_provider",
        }
        return irregular.get(parameter, f"{parameter}_provider")

    resolved_elsewhere = {"message", "messages", "self", "model", "precomputations"}
    signature_params = inspect.signature(getattr(component, fn_name)).parameters
    return {
        param_name: provider_for(param_name)
        for param_name, param in signature_params.items()
        if param.kind == param.POSITIONAL_OR_KEYWORD
        and param_name not in resolved_elsewhere
    }
|
|
439
|
+
|
|
440
|
+
def _add_nlu_train_node(
    self,
    train_nodes: Dict[Text, SchemaNode],
    component: Type[GraphComponent],
    component_name: Text,
    last_run_node: Text,
    config: Dict[Text, Any],
    cli_parameters: Dict[Text, Any],
) -> Text:
    """Adds the node which trains one NLU component.

    Args:
        train_nodes: The training graph, extended in place.
        component: The component class to train.
        component_name: Unique node-name stem of the component.
        last_run_node: Node whose output is used as `training_data`.
        config: The component's configuration from the pipeline.
        cli_parameters: Parameters passed via the command line. Values
            derived from them override `config` (see
            `_extra_config_from_cli`).

    Returns:
        The name of the added node, `train_{component_name}`.
    """
    cli_overrides = self._extra_config_from_cli(cli_parameters, component, config)
    needs = {
        **self._get_needs_from_args(component, "train"),
        **self._get_model_provider_needs(train_nodes, component),
        "training_data": last_run_node,
    }

    node_name = f"train_{component_name}"
    # When fine-tuning, the component is constructed via `load` instead of
    # `create`.
    train_nodes[node_name] = SchemaNode(
        needs=needs,
        uses=component,
        constructor_name="load" if self._is_finetuning else "create",
        fn="train",
        config={**config, **cli_overrides},
        is_target=True,
    )
    return node_name
|
|
464
|
+
|
|
465
|
+
def _extra_config_from_cli(
    self,
    cli_parameters: Dict[Text, Any],
    component: Type[GraphComponent],
    component_config: Dict[Text, Any],
) -> Dict[Text, Any]:
    """Derives component config values from command line parameters.

    * `num_threads` is passed on to `MitieIntentClassifier`,
      `MitieEntityExtractor` and `SklearnIntentClassifier` when it is set
      (not `None`).
    * When fine-tuning, a component whose default config contains `EPOCHS`
      gets an epoch count computed as follows: the configured epochs (or
      the default) are multiplied by `finetuning_epoch_fraction` (`None`
      counts as 1.0) and rounded up. The fraction itself is passed on too.

    Args:
        cli_parameters: Parameters passed via the command line.
        component: The component class the values are for.
        component_config: The component's configuration from the pipeline.

    Returns:
        The values to merge over the component's configuration.
    """
    from rasa.nlu.classifiers.mitie_intent_classifier import MitieIntentClassifier
    from rasa.nlu.extractors.mitie_entity_extractor import MitieEntityExtractor
    from rasa.nlu.classifiers.sklearn_intent_classifier import (
        SklearnIntentClassifier,
    )

    forwarded_cli_params: Dict[Type[GraphComponent], List[Text]] = {
        MitieIntentClassifier: ["num_threads"],
        MitieEntityExtractor: ["num_threads"],
        SklearnIntentClassifier: ["num_threads"],
    }

    overrides: Dict[Text, Any] = {}
    for param in forwarded_cli_params.get(component, []):
        value = cli_parameters.get(param)
        if value is not None:
            overrides[param] = value

    wants_epoch_scaling = (
        self._is_finetuning
        and "finetuning_epoch_fraction" in cli_parameters
        and EPOCHS in component.get_default_config()
    )
    if not wants_epoch_scaling:
        return overrides

    configured_epochs = component_config.get(
        EPOCHS, component.get_default_config()[EPOCHS]
    )
    fraction = cli_parameters["finetuning_epoch_fraction"]
    if fraction is None:
        fraction = 1.0
    overrides["finetuning_epoch_fraction"] = fraction
    overrides[EPOCHS] = math.ceil(configured_epochs * float(fraction))
    return overrides
|
|
505
|
+
|
|
506
|
+
def _add_nlu_process_node(
    self,
    train_nodes: Dict[Text, SchemaNode],
    component_class: Type[GraphComponent],
    component_name: Text,
    last_run_node: Text,
    component_config: Dict[Text, Any],
    from_resource: Optional[Text] = None,
) -> Text:
    """Adds a node running `process_training_data` of one NLU component.

    Args:
        train_nodes: The training graph, extended in place.
        component_class: The component class.
        component_name: Unique node-name stem of the component.
        last_run_node: Node whose output is used as `training_data`.
        component_config: The component's configuration from the pipeline.
        from_resource: If given, the node whose output is passed as
            `resource` (the caller passes the component's training node).

    Returns:
        The name of the added node, `run_{component_name}`.
    """
    needs = {
        **self._get_needs_from_args(component_class, "process_training_data"),
        **self._get_model_provider_needs(train_nodes, component_class),
    }
    if from_resource:
        needs["resource"] = from_resource
    # Chain onto the output of the previous processing node.
    needs["training_data"] = last_run_node

    node_name = f"run_{component_name}"
    train_nodes[node_name] = SchemaNode(
        needs=needs,
        uses=component_class,
        constructor_name="load",
        fn="process_training_data",
        config=component_config,
    )
    return node_name
|
|
532
|
+
|
|
533
|
+
def _get_model_provider_needs(
    self, nodes: Dict[Text, SchemaNode], component_class: Type[GraphComponent]
) -> Dict[Text, Text]:
    """Return the `model` need for a component that depends on a model provider.

    The registered entry for `component_class` is looked up. If it declares a
    `model_from` class name, the first node in `nodes` whose `uses` class has
    that name becomes the `model` need.

    Args:
        nodes: Schema nodes to search for the provider.
        component_class: The component whose provider is wanted.

    Returns:
        `{"model": <provider node name>}`, or an empty dict if the component
        has no `model_from` or no matching (non-empty named) node exists.
    """
    registered = self._from_registry(component_class.__name__)
    provider_class_name = registered.model_from
    if not provider_class_name:
        return {}

    for candidate_name, candidate_node in nodes.items():
        if candidate_node.uses.__name__ != provider_class_name:
            continue
        # Only the first matching node is considered.
        return {"model": candidate_name} if candidate_name else {}
    return {}
|
|
554
|
+
|
|
555
|
+
def _add_core_train_nodes(
    self,
    train_config: Dict[Text, Any],
    train_nodes: Dict[Text, SchemaNode],
    preprocessors: List[Text],
    cli_parameters: Dict[Text, Any],
) -> None:
    """Add the data-provider and policy-training nodes for dialogue (core) training.

    The domain, story-graph and flows providers read from the
    `finetuning_validator` importer node. One `train_<Policy><idx>` node is
    added per entry of `train_config["policies"]`. The `name` key is popped
    from each policy config in place.

    Args:
        train_config: Recipe config; its `policies` list is consumed.
        train_nodes: Training schema nodes; updated in place.
        preprocessors: NLU preprocessing node names, reused for end-to-end
            featurization when an end-to-end policy is configured.
        cli_parameters: CLI arguments that can override node configs.
    """
    train_nodes["domain_provider"] = SchemaNode(
        needs={"importer": "finetuning_validator"},
        uses=DomainProvider,
        constructor_name="create",
        fn="provide_train",
        config={},
        is_target=True,
        is_input=True,
    )
    # These providers all derive their output from the full domain.
    for provider_name, provider_class in (
        ("domain_for_core_training_provider", DomainForCoreTrainingProvider),
        ("forms_provider", FormsProvider),
        ("responses_provider", ResponsesProvider),
    ):
        train_nodes[provider_name] = SchemaNode(
            needs={"domain": "domain_provider"},
            uses=provider_class,
            constructor_name="create",
            fn="provide",
            config={},
            is_input=True,
        )
    train_nodes["story_graph_provider"] = SchemaNode(
        needs={"importer": "finetuning_validator"},
        uses=StoryGraphProvider,
        constructor_name="create",
        fn="provide_train",
        config={"exclusion_percentage": cli_parameters.get("exclusion_percentage")},
        is_input=True,
    )
    train_nodes["flows_provider"] = SchemaNode(
        needs={"importer": "finetuning_validator"},
        uses=FlowsProvider,
        constructor_name="create",
        fn="provide_train",
        config={},
        is_target=True,
        is_input=True,
    )
    tracker_config = {
        key: cli_parameters[key]
        for key in ("debug_plots", "augmentation_factor")
        if key in cli_parameters
    }
    train_nodes["training_tracker_provider"] = SchemaNode(
        needs={
            "story_graph": "story_graph_provider",
            "domain": "domain_for_core_training_provider",
        },
        uses=TrainingTrackerProvider,
        constructor_name="create",
        fn="provide",
        config=tracker_config,
    )

    any_e2e_policy = False
    for index, policy_config in enumerate(train_config["policies"]):
        policy_name = policy_config.pop("name")
        registered = self._from_registry(policy_name)
        cli_overrides = self._extra_config_from_cli(
            cli_parameters, registered.clazz, policy_config
        )

        needs_e2e_data = self._use_end_to_end and (
            self.ComponentType.POLICY_WITH_END_TO_END_SUPPORT in registered.types
        )
        any_e2e_policy = any_e2e_policy or needs_e2e_data

        policy_needs = self._get_needs_from_args(registered.clazz, "train")
        if needs_e2e_data:
            policy_needs["precomputations"] = "end_to_end_features_provider"
        # Core training uses a stripped-down version of the domain.
        policy_needs["domain"] = "domain_for_core_training_provider"
        train_nodes[f"train_{policy_name}{index}"] = SchemaNode(
            needs=policy_needs,
            uses=registered.clazz,
            constructor_name="load" if self._is_finetuning else "create",
            fn="train",
            is_target=True,
            config={**policy_config, **cli_overrides},
        )

    if self._use_end_to_end and any_e2e_policy:
        self._add_end_to_end_features_for_training(preprocessors, train_nodes)
|
|
661
|
+
|
|
662
|
+
def _add_end_to_end_features_for_training(
    self, preprocessors: List[Text], train_nodes: Dict[Text, SchemaNode]
) -> None:
    """Add training nodes that featurize story data for end-to-end policies.

    Stories are converted to NLU training data. Each preprocessor node is then
    deep-copied as `e2e_<preprocessor>` and chained on `training_data`.
    Finally an `end_to_end_features_provider` node collects the result.

    Args:
        preprocessors: Names of existing preprocessing nodes, in order.
        train_nodes: Training schema nodes; updated in place.
    """
    converter_name = "story_to_nlu_training_data_converter"
    train_nodes[converter_name] = SchemaNode(
        needs={
            "story_graph": "story_graph_provider",
            "domain": "domain_for_core_training_provider",
        },
        uses=CoreFeaturizationInputConverter,
        constructor_name="create",
        fn="convert_for_training",
        config={},
        is_input=True,
    )

    previous = converter_name
    for preprocessor_name in preprocessors:
        # Deep copy so the original NLU node's needs are left untouched.
        e2e_copy = copy.deepcopy(train_nodes[preprocessor_name])
        e2e_copy.needs["training_data"] = previous
        e2e_name = f"e2e_{preprocessor_name}"
        train_nodes[e2e_name] = e2e_copy
        previous = e2e_name

    train_nodes["end_to_end_features_provider"] = SchemaNode(
        needs={"messages": previous},
        uses=CoreFeaturizationCollector,
        constructor_name="create",
        fn="collect",
        config={},
    )
|
|
694
|
+
|
|
695
|
+
def _create_predict_nodes(
    self,
    config: Dict[Text, Any],
    preprocessors: List[Text],
    train_nodes: Dict[Text, SchemaNode],
) -> Dict[Text, SchemaNode]:
    """Build the schema nodes of the inference graph.

    Args:
        config: The recipe configuration (with `pipeline` / `policies` lists).
            A deep copy is used, so the caller's dict is not mutated by the
            `name` pops done while building nodes.
        preprocessors: NLU preprocessing node names, reused for end-to-end
            featurization of trackers.
        train_nodes: The training schema nodes; predict nodes are derived
            from them.

    Returns:
        The inference schema nodes keyed by node name.
    """
    predict_config = copy.deepcopy(config)
    predict_nodes = {}

    from rasa.nlu.classifiers.regex_message_handler import RegexMessageHandler

    predict_nodes["nlu_message_converter"] = SchemaNode(
        **DEFAULT_PREDICT_KWARGS,
        needs={"messages": PLACEHOLDER_MESSAGE},
        uses=NLUMessageConverter,
        fn="convert_user_message",
        config={},
    )

    last_run_nlu_node = "nlu_message_converter"

    if self._use_nlu:
        last_run_nlu_node = self._add_nlu_predict_nodes(
            last_run_nlu_node, predict_config, predict_nodes, train_nodes
        )

    # The regex handler only receives the domain when core is part of the model.
    domain_needs = {}
    if self._use_core:
        domain_needs["domain"] = "domain_provider"

    regex_handler_node_name = f"run_{RegexMessageHandler.__name__}"
    predict_nodes[regex_handler_node_name] = SchemaNode(
        **DEFAULT_PREDICT_KWARGS,
        needs={"messages": last_run_nlu_node, **domain_needs},
        uses=RegexMessageHandler,
        fn="process",
        config={},
    )

    if self._use_core:
        self._add_core_predict_nodes(
            predict_config, predict_nodes, train_nodes, preprocessors
        )

    return predict_nodes
|
|
740
|
+
|
|
741
|
+
def _add_nlu_predict_nodes(
    self,
    last_run_node: Text,
    predict_config: Dict[Text, Any],
    predict_nodes: Dict[Text, SchemaNode],
    train_nodes: Dict[Text, SchemaNode],
) -> Text:
    """Add inference nodes for every component of the NLU pipeline.

    `flows_provider` and `domain_provider` inference nodes are added first.
    Then the pipeline components are chained in order on their `messages`
    input. The `name` key is popped from each pipeline entry in place.

    Args:
        last_run_node: Node whose output feeds the first pipeline component.
        predict_config: Config copy whose `pipeline` list is consumed.
        predict_nodes: Inference schema nodes; updated in place.
        train_nodes: Training schema nodes that trained/processing nodes are
            derived from.

    Returns:
        The name of the last node in the NLU chain.
    """
    for provider_name, provider_class in (
        ("flows_provider", FlowsProvider),
        ("domain_provider", DomainProvider),
    ):
        predict_nodes[provider_name] = SchemaNode(
            **DEFAULT_PREDICT_KWARGS,
            needs={},
            uses=provider_class,
            fn="provide_inference",
            config={},
            resource=Resource(provider_name),
        )

    preprocessing_types = {
        self.ComponentType.MESSAGE_TOKENIZER,
        self.ComponentType.MESSAGE_FEATURIZER,
    }
    prediction_types = {
        self.ComponentType.INTENT_CLASSIFIER,
        self.ComponentType.ENTITY_EXTRACTOR,
        self.ComponentType.COMMAND_GENERATOR,
        self.ComponentType.COEXISTENCE_ROUTER,
    }

    for index, item_config in enumerate(predict_config["pipeline"]):
        registry_name = item_config.pop("name")
        registered = self._from_registry(registry_name)
        indexed_name = f"{registry_name}{index}"

        if self.ComponentType.MODEL_LOADER in registered.types:
            predict_nodes[f"provide_{indexed_name}"] = SchemaNode(
                **DEFAULT_PREDICT_KWARGS,
                needs={},
                uses=registered.clazz,
                fn="provide",
                config=item_config,
            )

        if registered.types.intersection(preprocessing_types):
            last_run_node = self._add_nlu_predict_node_from_train(
                predict_nodes,
                indexed_name,
                train_nodes,
                last_run_node,
                item_config,
                from_resource=registered.is_trainable,
            )
        elif registered.types.intersection(prediction_types):
            if registered.is_trainable:
                last_run_node = self._add_nlu_predict_node_from_train(
                    predict_nodes,
                    indexed_name,
                    train_nodes,
                    last_run_node,
                    item_config,
                    from_resource=registered.is_trainable,
                )
            else:
                # Untrained components have no training node to copy from.
                standalone_node = SchemaNode(
                    needs={"messages": last_run_node},
                    uses=registered.clazz,
                    constructor_name="create",
                    fn="process",
                    config=item_config,
                )
                last_run_node = self._add_nlu_predict_node(
                    predict_nodes, standalone_node, indexed_name, last_run_node
                )

    return last_run_node
|
|
823
|
+
|
|
824
|
+
def _add_nlu_predict_node_from_train(
    self,
    predict_nodes: Dict[Text, SchemaNode],
    node_name: Text,
    train_nodes: Dict[Text, SchemaNode],
    last_run_node: Text,
    item_config: Dict[Text, Any],
    from_resource: bool = False,
) -> Text:
    """Derive an inference node from an existing training node.

    When `from_resource` is set, the `train_<node_name>` node is copied and
    loads that node's resource. Otherwise `run_<node_name>` is copied without
    a resource.

    Args:
        predict_nodes: Inference schema nodes; updated in place.
        node_name: Base component name (already suffixed with its index).
        train_nodes: Training schema nodes to copy from.
        last_run_node: Node whose output becomes `messages`.
        item_config: Config to set on the copied node.
        from_resource: Whether to load the trained resource.

    Returns:
        The name of the added inference node.
    """
    if from_resource:
        source_name = f"train_{node_name}"
        resource = Resource(source_name)
    else:
        source_name = f"run_{node_name}"
        resource = None

    template = dataclasses.replace(
        train_nodes[source_name], resource=resource, config=item_config
    )
    return self._add_nlu_predict_node(
        predict_nodes, template, node_name, last_run_node
    )
|
|
847
|
+
|
|
848
|
+
def _add_nlu_predict_node(
    self,
    predict_nodes: Dict[Text, SchemaNode],
    node: SchemaNode,
    component_name: Text,
    last_run_node: Text,
) -> Text:
    """Store `node` as a `process` node named `run_<component_name>`.

    The node's needs are rebuilt from the arguments of its `process` method,
    plus any model-provider need. `messages` is wired to `last_run_node`.

    Args:
        predict_nodes: Inference schema nodes; updated in place.
        node: Template node to adapt.
        component_name: Base name used for the node name.
        last_run_node: Node whose output becomes `messages`.

    Returns:
        The name of the stored node.
    """
    component_class = node.uses
    process_needs = self._get_needs_from_args(component_class, "process")
    process_needs.update(
        self._get_model_provider_needs(predict_nodes, component_class)
    )
    process_needs["messages"] = last_run_node

    target_name = f"run_{component_name}"
    predict_nodes[target_name] = dataclasses.replace(
        node,
        needs=process_needs,
        fn="process",
        **DEFAULT_PREDICT_KWARGS,
    )
    return target_name
|
|
868
|
+
|
|
869
|
+
def _add_core_predict_nodes(
    self,
    predict_config: Dict[Text, Any],
    predict_nodes: Dict[Text, SchemaNode],
    train_nodes: Dict[Text, SchemaNode],
    preprocessors: List[Text],
) -> None:
    """Add the dialogue-prediction part of the inference graph.

    This adds:
    - inference providers for domain, story graph and flows;
    - optional end-to-end featurization;
    - the command processor;
    - one `run_<Policy><idx>` node per policy;
    - the rule-only data provider;
    - the ensemble that combines policy predictions.

    The `name` key is popped from each policy config in place.

    Args:
        predict_config: Config copy whose `policies` list is consumed.
        predict_nodes: Inference schema nodes; updated in place.
        train_nodes: Training schema nodes the policy nodes are copied from.
        preprocessors: Preprocessing node names used for end-to-end features.
    """
    for provider_name, provider_class in (
        ("domain_provider", DomainProvider),
        ("story_graph_provider", StoryGraphProvider),
        ("flows_provider", FlowsProvider),
    ):
        predict_nodes[provider_name] = SchemaNode(
            **DEFAULT_PREDICT_KWARGS,
            needs={},
            uses=provider_class,
            fn="provide_inference",
            config={},
            resource=Resource(provider_name),
        )

    # End-to-end features only exist at inference if they were trained.
    e2e_features_node = (
        self._add_end_to_end_features_for_inference(predict_nodes, preprocessors)
        if "end_to_end_features_provider" in train_nodes
        else None
    )

    predict_nodes["command_processor"] = SchemaNode(
        **DEFAULT_PREDICT_KWARGS,
        needs=self._get_needs_from_args(
            CommandProcessorComponent, "execute_commands"
        ),
        uses=CommandProcessorComponent,
        fn="execute_commands",
        config={},
        resource=Resource("command_processor"),
    )

    rule_policy_resource = None
    policy_node_names: List[Text] = []

    for index, policy_config in enumerate(predict_config["policies"]):
        policy_name = policy_config.pop("name")
        registered = self._from_registry(policy_name)
        trained_name = f"train_{policy_name}{index}"
        run_name = f"run_{policy_name}{index}"

        from rasa.core.policies.rule_policy import RulePolicy

        # The first RulePolicy's resource backs the rule-only data provider.
        if issubclass(registered.clazz, RulePolicy) and not rule_policy_resource:
            rule_policy_resource = trained_name

        policy_needs = self._get_needs_from_args(
            train_nodes[trained_name].uses, "predict_action_probabilities"
        )
        supports_e2e = (
            self.ComponentType.POLICY_WITH_END_TO_END_SUPPORT in registered.types
        )
        if supports_e2e and e2e_features_node:
            policy_needs["precomputations"] = e2e_features_node

        predict_nodes[run_name] = dataclasses.replace(
            train_nodes[trained_name],
            **DEFAULT_PREDICT_KWARGS,
            needs=policy_needs,
            fn="predict_action_probabilities",
            resource=Resource(trained_name),
        )
        policy_node_names.append(run_name)

    predict_nodes["rule_only_data_provider"] = SchemaNode(
        **DEFAULT_PREDICT_KWARGS,
        needs={},
        uses=RuleOnlyDataProvider,
        fn="provide",
        config={},
        resource=Resource(rule_policy_resource) if rule_policy_resource else None,
    )

    ensemble_needs = {
        f"policy{position}": name
        for position, name in enumerate(policy_node_names)
    }
    ensemble_needs["domain"] = "domain_provider"
    ensemble_needs["tracker"] = PLACEHOLDER_TRACKER
    predict_nodes["select_prediction"] = SchemaNode(
        **DEFAULT_PREDICT_KWARGS,
        needs=ensemble_needs,
        uses=DefaultPolicyPredictionEnsemble,
        fn="combine_predictions_from_kwargs",
        config={},
    )
|
|
972
|
+
|
|
973
|
+
def _add_end_to_end_features_for_inference(
    self, predict_nodes: Dict[Text, SchemaNode], preprocessors: List[Text]
) -> Text:
    """Add inference nodes that featurize tracker messages for e2e policies.

    A `tracker_to_message_converter` node turns the tracker into messages.
    Each preprocessor node is copied as `e2e_<preprocessor>` and chained after
    the previous one. An `end_to_end_features_provider` node then collects
    the result.

    Args:
        predict_nodes: Inference schema nodes; updated in place. Must already
            contain a node for every entry of `preprocessors`.
        preprocessors: Names of the preprocessing nodes to replicate, in order.

    Returns:
        The name of the node providing the end-to-end features.
    """
    predict_nodes["tracker_to_message_converter"] = SchemaNode(
        **DEFAULT_PREDICT_KWARGS,
        needs={"tracker": PLACEHOLDER_TRACKER},
        uses=CoreFeaturizationInputConverter,
        fn="convert_for_inference",
        config={},
    )

    last_node_name = "tracker_to_message_converter"
    for preprocessor in preprocessors:
        original_node = predict_nodes[preprocessor]
        # Only re-route the `messages` input. Other needs (e.g. a `model`
        # provider added by `_get_model_provider_needs`) are kept, just like
        # the training-time e2e copies keep all of their needs. Replacing the
        # whole dict would stop passing those inputs to the copied node.
        node = dataclasses.replace(
            original_node,
            needs={**original_node.needs, "messages": last_node_name},
        )

        node_name = f"e2e_{preprocessor}"
        predict_nodes[node_name] = node
        last_node_name = node_name

    node_with_e2e_features = "end_to_end_features_provider"
    predict_nodes[node_with_e2e_features] = SchemaNode(
        **DEFAULT_PREDICT_KWARGS,
        needs={"messages": last_node_name},
        uses=CoreFeaturizationCollector,
        fn="collect",
        config={},
    )
    return node_with_e2e_features
|
|
1003
|
+
|
|
1004
|
+
@staticmethod
def auto_configure(
    config_file_path: Optional[Text],
    config: Dict,
    training_type: Optional[TrainingType] = TrainingType.BOTH,
) -> Tuple[Dict[Text, Any], Set[str], Set[str]]:
    """Fill in unspecified, auto-configurable keys of a model configuration.

    Keys that already have a value are left untouched. Auto-configurable keys
    that are absent or empty get defaults, and the result is written back to
    the config file.

    This must be called explicitly, i.e. it cannot happen automatically in
    the importers, because importers may not access code outside of
    `rasa.shared`.

    Args:
        config_file_path: Path of the configuration file.
        config: The configuration as a dictionary.
        training_type: Which part (NLU, core or both, the default) to
            auto-configure.

    Returns:
        The (possibly completed) config, the keys missing from the file, and
        the keys that were auto-configured.
    """
    missing_keys = DefaultV1Recipe._get_missing_config_keys(config, training_type)
    keys_to_configure = DefaultV1Recipe._get_unspecified_autoconfigurable_keys(
        config, training_type
    )
    if not keys_to_configure:
        return config, missing_keys, keys_to_configure

    completed_config = DefaultV1Recipe.complete_config(config, keys_to_configure)
    DefaultV1Recipe._dump_config(
        completed_config,
        config_file_path,
        missing_keys,
        keys_to_configure,
        training_type,
    )
    return completed_config, missing_keys, keys_to_configure
|
|
1037
|
+
|
|
1038
|
+
@staticmethod
def _get_unspecified_autoconfigurable_keys(
    config: Dict[Text, Any],
    training_type: Optional[TrainingType] = TrainingType.BOTH,
) -> Set[Text]:
    """Return auto-configurable keys that are absent from or `None` in `config`.

    The candidate keys depend on `training_type`: NLU keys, core keys, or
    all auto-configurable keys for anything else.
    """
    constants = rasa.shared.constants
    candidate_keys = (
        constants.CONFIG_AUTOCONFIGURABLE_KEYS_NLU
        if training_type == TrainingType.NLU
        else constants.CONFIG_AUTOCONFIGURABLE_KEYS_CORE
        if training_type == TrainingType.CORE
        else constants.CONFIG_AUTOCONFIGURABLE_KEYS
    )
    return {key for key in candidate_keys if config.get(key) is None}
|
|
1051
|
+
|
|
1052
|
+
@staticmethod
def _get_missing_config_keys(
    config: Dict[Text, Any],
    training_type: Optional[TrainingType] = TrainingType.BOTH,
) -> Set[Text]:
    """Return the config keys for `training_type` that `config` lacks entirely.

    Unlike `_get_unspecified_autoconfigurable_keys`, a key that is present
    with a `None` value does not count as missing.
    """
    constants = rasa.shared.constants
    expected_keys = (
        constants.CONFIG_KEYS_NLU
        if training_type == TrainingType.NLU
        else constants.CONFIG_KEYS_CORE
        if training_type == TrainingType.CORE
        else constants.CONFIG_KEYS
    )
    return {key for key in expected_keys if key not in config}
|
|
1065
|
+
|
|
1066
|
+
@staticmethod
def complete_config(
    config: Dict[Text, Any], keys_to_configure: Set[Text]
) -> Dict[Text, Any]:
    """Return a copy of `config` with the given keys taken from the default config.

    The defaults are read from `config_files/default_config.yml` next to this
    module. The input `config` is deep-copied and not modified.

    Args:
        config: The user-provided configuration.
        keys_to_configure: Keys (e.g. `policies`) to fill from the defaults.

    Returns:
        The configuration with both the provided and the default-filled keys.
    """
    import importlib_resources

    if keys_to_configure:
        logger.debug(
            f"The provided configuration does not contain the key(s) "
            f"{transform_collection_to_sentence(keys_to_configure)}. "
            f"Values will be provided from the default configuration."
        )

    defaults_location = (
        importlib_resources.files(__name__)
        .joinpath("config_files")
        .joinpath("default_config.yml")
    )
    defaults = read_config_file(str(defaults_location))

    completed = copy.deepcopy(config)
    for missing_key in keys_to_configure:
        completed[missing_key] = defaults[missing_key]
    return completed
|
|
1101
|
+
|
|
1102
|
+
@staticmethod
def _dump_config(
    config: Dict[Text, Any],
    config_file_path: Text,
    missing_keys: Set[Text],
    auto_configured_keys: Set[Text],
    training_type: Optional[TrainingType] = TrainingType.BOTH,
) -> None:
    """Write the automatically chosen configuration into the config file.

    The existing file content (key order and comments) is preserved. For each
    auto-configured key, an explanatory comment plus the chosen configuration,
    commented out, is inserted below the key. Commented blocks from an earlier
    auto-configuration run are replaced. If the file changed or disappeared
    since it was read, an error is printed and nothing is written.

    Args:
        config: Configuration including the auto-configured keys.
        config_file_path: File to update.
        missing_keys: Keys to append to the file first.
        auto_configured_keys: Keys that get a commented-out section.
        training_type: NLU, CORE or BOTH depending on what is trained.
    """
    file_unchanged = DefaultV1Recipe._is_config_file_as_expected(
        config_file_path, missing_keys, auto_configured_keys, training_type
    )
    if not file_unchanged:
        rasa.shared.utils.cli.print_error(
            f"The configuration file at '{config_file_path}' has been removed or "
            f"modified while the automatic configuration was running. The current "
            f"configuration will therefore not be dumped to the file. If you want "
            f"your model to use the configuration provided in "
            f"'{config_file_path}' you need to re-run training."
        )
        return

    DefaultV1Recipe._add_missing_config_keys_to_file(config_file_path, missing_keys)

    commented_sections = DefaultV1Recipe._get_commented_out_autoconfig_lines(
        config, auto_configured_keys
    )
    existing_lines = rasa.shared.utils.io.read_file(config_file_path).splitlines(
        keepends=True
    )
    merged_lines = DefaultV1Recipe._get_lines_including_autoconfig(
        existing_lines, commented_sections
    )
    rasa.shared.utils.io.write_text_file("".join(merged_lines), config_file_path)

    configured_text = transform_collection_to_sentence(auto_configured_keys)
    rasa.shared.utils.cli.print_info(
        f"The configuration for {configured_text} "
        f"was chosen automatically. "
        f"It was written into the config file at '{config_file_path}'."
    )
|
|
1165
|
+
|
|
1166
|
+
@staticmethod
def _is_config_file_as_expected(
    config_file_path: Text,
    missing_keys: Set[Text],
    auto_configured_keys: Set[Text],
    training_type: Optional[TrainingType] = TrainingType.BOTH,
) -> bool:
    """Check that the config file still matches what auto-configuration expects.

    Returns False if the file is missing or empty. Otherwise returns whether
    re-reading it yields exactly the same missing and auto-configurable keys.
    """
    try:
        content = read_config_file(config_file_path)
    except FileNotFoundException:
        return False

    if not content:
        return False

    return missing_keys == DefaultV1Recipe._get_missing_config_keys(
        content, training_type
    ) and auto_configured_keys == (
        DefaultV1Recipe._get_unspecified_autoconfigurable_keys(content, training_type)
    )
|
|
1187
|
+
|
|
1188
|
+
@staticmethod
def _add_missing_config_keys_to_file(
    config_file_path: Text, missing_keys: Set[Text]
) -> None:
    """Append an empty top-level entry (`<key>:`) for every missing key.

    Keys are written in sorted order so the resulting file is deterministic.
    Iterating a set of strings directly gives an order that can differ
    between interpreter runs. If the file does not end with a newline, one
    is written first so the first key is not glued onto the last existing line.

    Args:
        config_file_path: Config file to append to (created if missing).
        missing_keys: Keys to append; nothing happens if empty.
    """
    if not missing_keys:
        return

    encoding = rasa.shared.utils.io.DEFAULT_ENCODING
    try:
        with open(config_file_path, "r", encoding=encoding) as f:
            existing_content = f.read()
    except FileNotFoundError:
        # Appending below creates the file, as before.
        existing_content = ""

    with open(config_file_path, "a", encoding=encoding) as f:
        if existing_content and not existing_content.endswith("\n"):
            f.write("\n")
        for key in sorted(missing_keys):
            f.write(f"{key}:\n")
|
|
1199
|
+
|
|
1200
|
+
@staticmethod
def _get_lines_including_autoconfig(
    lines: List[Text], autoconfig_lines: Dict[Text, List[Text]]
) -> List[Text]:
    """Merge commented-out auto-configuration sections into config file lines.

    For every line that starts an auto-configured section (`<key>:`):
    - the explanatory comment `COMMENTS_FOR_KEYS[key]` is appended;
    - the commented-out lines for that key are inserted after it;
    - lines starting with `#` directly below it (left by a previous run) are
      dropped until the next uncommented line.

    Args:
        lines: Current config file lines, with line endings kept.
        autoconfig_lines: Commented-out lines per auto-configured key.

    Returns:
        The updated list of lines.
    """
    auto_configured_keys = autoconfig_lines.keys()

    lines_with_autoconfig = []
    remove_comments_until_next_uncommented_line = False
    for line in lines:
        insert_section = None

        # remove old auto configuration
        if remove_comments_until_next_uncommented_line:
            if line.startswith("#"):
                continue
            remove_comments_until_next_uncommented_line = False

        # add an explanatory comment to autoconfigured sections
        for key in auto_configured_keys:
            if line.startswith(f"{key}:"):  # start of next auto-section
                if not line.endswith("\n"):
                    # The key is the last line of a file without a trailing
                    # newline; without this the comment would be appended to
                    # the key itself (e.g. `policies:# ...`).
                    line += "\n"
                line = line + COMMENTS_FOR_KEYS[key]
                insert_section = key
                remove_comments_until_next_uncommented_line = True

        lines_with_autoconfig.append(line)

        if not insert_section:
            continue

        # add the autoconfiguration (commented out)
        lines_with_autoconfig += autoconfig_lines[insert_section]

    return lines_with_autoconfig
|
|
1233
|
+
|
|
1234
|
+
@staticmethod
def _get_commented_out_autoconfig_lines(
    config: Dict[Text, Any], auto_configured_keys: Set[Text]
) -> Dict[Text, List[Text]]:
    """Render each auto-configured value as commented-out YAML lines.

    Args:
        config: Configuration containing the auto-configured values.
        auto_configured_keys: Keys to render.

    Returns:
        For each key, the YAML dump of `config.get(key)`, with every line
        prefixed by `# ` and terminated by a newline.
    """
    import ruamel.yaml
    import ruamel.yaml.compat

    dumper = ruamel.yaml.YAML()
    dumper.indent(mapping=2, sequence=4, offset=2)

    def _commented_yaml(value: Any) -> List[Text]:
        buffer = ruamel.yaml.compat.StringIO()
        dumper.dump(value, buffer)
        dumped_lines = buffer.getvalue().split("\n")
        # The YAML dump ends with a newline, leaving an empty trailing entry.
        if dumped_lines[-1] == "":
            dumped_lines.pop()
        return [f"# {text}\n" for text in dumped_lines]

    return {key: _commented_yaml(config.get(key)) for key in auto_configured_keys}
|