rasa-pro 3.12.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. See the package's registry page for more details.
- README.md +41 -0
- rasa/__init__.py +9 -0
- rasa/__main__.py +177 -0
- rasa/anonymization/__init__.py +2 -0
- rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
- rasa/anonymization/anonymization_pipeline.py +286 -0
- rasa/anonymization/anonymization_rule_executor.py +260 -0
- rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
- rasa/anonymization/schemas/config.yml +47 -0
- rasa/anonymization/utils.py +118 -0
- rasa/api.py +160 -0
- rasa/cli/__init__.py +5 -0
- rasa/cli/arguments/__init__.py +0 -0
- rasa/cli/arguments/data.py +106 -0
- rasa/cli/arguments/default_arguments.py +207 -0
- rasa/cli/arguments/evaluate.py +65 -0
- rasa/cli/arguments/export.py +51 -0
- rasa/cli/arguments/interactive.py +74 -0
- rasa/cli/arguments/run.py +219 -0
- rasa/cli/arguments/shell.py +17 -0
- rasa/cli/arguments/test.py +211 -0
- rasa/cli/arguments/train.py +279 -0
- rasa/cli/arguments/visualize.py +34 -0
- rasa/cli/arguments/x.py +30 -0
- rasa/cli/data.py +354 -0
- rasa/cli/dialogue_understanding_test.py +251 -0
- rasa/cli/e2e_test.py +259 -0
- rasa/cli/evaluate.py +222 -0
- rasa/cli/export.py +250 -0
- rasa/cli/inspect.py +75 -0
- rasa/cli/interactive.py +166 -0
- rasa/cli/license.py +65 -0
- rasa/cli/llm_fine_tuning.py +403 -0
- rasa/cli/markers.py +78 -0
- rasa/cli/project_templates/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/action_template.py +27 -0
- rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
- rasa/cli/project_templates/calm/actions/db.py +57 -0
- rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
- rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
- rasa/cli/project_templates/calm/config.yml +10 -0
- rasa/cli/project_templates/calm/credentials.yml +33 -0
- rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
- rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
- rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
- rasa/cli/project_templates/calm/db/contacts.json +10 -0
- rasa/cli/project_templates/calm/domain/add_contact.yml +39 -0
- rasa/cli/project_templates/calm/domain/list_contacts.yml +17 -0
- rasa/cli/project_templates/calm/domain/remove_contact.yml +38 -0
- rasa/cli/project_templates/calm/domain/shared.yml +10 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
- rasa/cli/project_templates/calm/endpoints.yml +58 -0
- rasa/cli/project_templates/default/actions/__init__.py +0 -0
- rasa/cli/project_templates/default/actions/actions.py +27 -0
- rasa/cli/project_templates/default/config.yml +44 -0
- rasa/cli/project_templates/default/credentials.yml +33 -0
- rasa/cli/project_templates/default/data/nlu.yml +91 -0
- rasa/cli/project_templates/default/data/rules.yml +13 -0
- rasa/cli/project_templates/default/data/stories.yml +30 -0
- rasa/cli/project_templates/default/domain.yml +34 -0
- rasa/cli/project_templates/default/endpoints.yml +42 -0
- rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
- rasa/cli/project_templates/tutorial/actions/__init__.py +0 -0
- rasa/cli/project_templates/tutorial/actions/actions.py +22 -0
- rasa/cli/project_templates/tutorial/config.yml +12 -0
- rasa/cli/project_templates/tutorial/credentials.yml +33 -0
- rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
- rasa/cli/project_templates/tutorial/data/patterns.yml +11 -0
- rasa/cli/project_templates/tutorial/domain.yml +35 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +55 -0
- rasa/cli/run.py +143 -0
- rasa/cli/scaffold.py +273 -0
- rasa/cli/shell.py +141 -0
- rasa/cli/studio/__init__.py +0 -0
- rasa/cli/studio/download.py +62 -0
- rasa/cli/studio/studio.py +296 -0
- rasa/cli/studio/train.py +59 -0
- rasa/cli/studio/upload.py +62 -0
- rasa/cli/telemetry.py +102 -0
- rasa/cli/test.py +280 -0
- rasa/cli/train.py +278 -0
- rasa/cli/utils.py +484 -0
- rasa/cli/visualize.py +40 -0
- rasa/cli/x.py +206 -0
- rasa/constants.py +45 -0
- rasa/core/__init__.py +17 -0
- rasa/core/actions/__init__.py +0 -0
- rasa/core/actions/action.py +1318 -0
- rasa/core/actions/action_clean_stack.py +59 -0
- rasa/core/actions/action_exceptions.py +24 -0
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/actions/action_repeat_bot_messages.py +89 -0
- rasa/core/actions/action_run_slot_rejections.py +210 -0
- rasa/core/actions/action_trigger_chitchat.py +31 -0
- rasa/core/actions/action_trigger_flow.py +109 -0
- rasa/core/actions/action_trigger_search.py +31 -0
- rasa/core/actions/constants.py +5 -0
- rasa/core/actions/custom_action_executor.py +191 -0
- rasa/core/actions/direct_custom_actions_executor.py +109 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +72 -0
- rasa/core/actions/forms.py +741 -0
- rasa/core/actions/grpc_custom_action_executor.py +251 -0
- rasa/core/actions/http_custom_action_executor.py +145 -0
- rasa/core/actions/loops.py +114 -0
- rasa/core/actions/two_stage_fallback.py +186 -0
- rasa/core/agent.py +559 -0
- rasa/core/auth_retry_tracker_store.py +122 -0
- rasa/core/brokers/__init__.py +0 -0
- rasa/core/brokers/broker.py +126 -0
- rasa/core/brokers/file.py +58 -0
- rasa/core/brokers/kafka.py +324 -0
- rasa/core/brokers/pika.py +388 -0
- rasa/core/brokers/sql.py +86 -0
- rasa/core/channels/__init__.py +61 -0
- rasa/core/channels/botframework.py +338 -0
- rasa/core/channels/callback.py +84 -0
- rasa/core/channels/channel.py +456 -0
- rasa/core/channels/console.py +241 -0
- rasa/core/channels/development_inspector.py +197 -0
- rasa/core/channels/facebook.py +419 -0
- rasa/core/channels/hangouts.py +329 -0
- rasa/core/channels/inspector/.eslintrc.cjs +25 -0
- rasa/core/channels/inspector/.gitignore +23 -0
- rasa/core/channels/inspector/README.md +54 -0
- rasa/core/channels/inspector/assets/favicon.ico +0 -0
- rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
- rasa/core/channels/inspector/custom.d.ts +3 -0
- rasa/core/channels/inspector/dist/assets/arc-861ddd57.js +1 -0
- rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
- rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-921f02db.js +10 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-b436c4f8.js +2 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-511a23cb.js +2 -0
- rasa/core/channels/inspector/dist/assets/createText-62fc7601-ef476ecd.js +7 -0
- rasa/core/channels/inspector/dist/assets/edges-f2ad444c-f1878e0a.js +4 -0
- rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-fac75185.js +51 -0
- rasa/core/channels/inspector/dist/assets/flowDb-1972c806-201c5bbc.js +6 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-f904ae41.js +4 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-b080d6f2.js +1 -0
- rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-1813da66.js +139 -0
- rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-872af172.js +266 -0
- rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-34a0af5a.js +70 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-42ba3e3d.js +1 -0
- rasa/core/channels/inspector/dist/assets/index-37817b51.js +1317 -0
- rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
- rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-6b731386.js +7 -0
- rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
- rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-e8579ac6.js +139 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
- rasa/core/channels/inspector/dist/assets/layout-89e6403a.js +1 -0
- rasa/core/channels/inspector/dist/assets/line-dc73d3fc.js +1 -0
- rasa/core/channels/inspector/dist/assets/linear-f5b1d2bc.js +1 -0
- rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-82cb74fa.js +109 -0
- rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
- rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
- rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-bdf5f29b.js +35 -0
- rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-c7a0cbe4.js +7 -0
- rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-7ec5410f.js +52 -0
- rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-caee5554.js +8 -0
- rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-2935f8db.js +122 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-8f5d9693.js +1 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-d565d1de.js +1 -0
- rasa/core/channels/inspector/dist/assets/styles-080da4f6-75ad421d.js +110 -0
- rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-7e764226.js +159 -0
- rasa/core/channels/inspector/dist/assets/styles-9c745c82-7a4e0e61.js +207 -0
- rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-4019d1bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-01ea12df.js +61 -0
- rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-89407137.js +7 -0
- rasa/core/channels/inspector/dist/index.html +42 -0
- rasa/core/channels/inspector/index.html +40 -0
- rasa/core/channels/inspector/jest.config.ts +13 -0
- rasa/core/channels/inspector/package.json +52 -0
- rasa/core/channels/inspector/setupTests.ts +2 -0
- rasa/core/channels/inspector/src/App.tsx +220 -0
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +108 -0
- rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +136 -0
- rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
- rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +22 -0
- rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
- rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
- rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
- rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
- rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
- rasa/core/channels/inspector/src/helpers/audiostream.ts +191 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +392 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +306 -0
- rasa/core/channels/inspector/src/helpers/utils.ts +127 -0
- rasa/core/channels/inspector/src/main.tsx +13 -0
- rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
- rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
- rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
- rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
- rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
- rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
- rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
- rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
- rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
- rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
- rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
- rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
- rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
- rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
- rasa/core/channels/inspector/src/theme/index.ts +101 -0
- rasa/core/channels/inspector/src/types.ts +84 -0
- rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
- rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
- rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
- rasa/core/channels/inspector/tsconfig.json +26 -0
- rasa/core/channels/inspector/tsconfig.node.json +10 -0
- rasa/core/channels/inspector/vite.config.ts +8 -0
- rasa/core/channels/inspector/yarn.lock +6249 -0
- rasa/core/channels/mattermost.py +229 -0
- rasa/core/channels/rasa_chat.py +126 -0
- rasa/core/channels/rest.py +230 -0
- rasa/core/channels/rocketchat.py +174 -0
- rasa/core/channels/slack.py +620 -0
- rasa/core/channels/socketio.py +302 -0
- rasa/core/channels/telegram.py +298 -0
- rasa/core/channels/twilio.py +169 -0
- rasa/core/channels/vier_cvg.py +374 -0
- rasa/core/channels/voice_ready/__init__.py +0 -0
- rasa/core/channels/voice_ready/audiocodes.py +501 -0
- rasa/core/channels/voice_ready/jambonz.py +121 -0
- rasa/core/channels/voice_ready/jambonz_protocol.py +396 -0
- rasa/core/channels/voice_ready/twilio_voice.py +403 -0
- rasa/core/channels/voice_ready/utils.py +37 -0
- rasa/core/channels/voice_stream/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
- rasa/core/channels/voice_stream/asr/azure.py +130 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
- rasa/core/channels/voice_stream/audio_bytes.py +8 -0
- rasa/core/channels/voice_stream/browser_audio.py +107 -0
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +106 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +427 -0
- rasa/core/channels/webexteams.py +134 -0
- rasa/core/concurrent_lock_store.py +210 -0
- rasa/core/constants.py +112 -0
- rasa/core/evaluation/__init__.py +0 -0
- rasa/core/evaluation/marker.py +267 -0
- rasa/core/evaluation/marker_base.py +923 -0
- rasa/core/evaluation/marker_stats.py +293 -0
- rasa/core/evaluation/marker_tracker_loader.py +103 -0
- rasa/core/exceptions.py +29 -0
- rasa/core/exporter.py +284 -0
- rasa/core/featurizers/__init__.py +0 -0
- rasa/core/featurizers/precomputation.py +410 -0
- rasa/core/featurizers/single_state_featurizer.py +421 -0
- rasa/core/featurizers/tracker_featurizers.py +1262 -0
- rasa/core/http_interpreter.py +89 -0
- rasa/core/information_retrieval/__init__.py +7 -0
- rasa/core/information_retrieval/faiss.py +124 -0
- rasa/core/information_retrieval/information_retrieval.py +137 -0
- rasa/core/information_retrieval/milvus.py +59 -0
- rasa/core/information_retrieval/qdrant.py +96 -0
- rasa/core/jobs.py +63 -0
- rasa/core/lock.py +139 -0
- rasa/core/lock_store.py +343 -0
- rasa/core/migrate.py +403 -0
- rasa/core/nlg/__init__.py +3 -0
- rasa/core/nlg/callback.py +146 -0
- rasa/core/nlg/contextual_response_rephraser.py +320 -0
- rasa/core/nlg/generator.py +230 -0
- rasa/core/nlg/interpolator.py +143 -0
- rasa/core/nlg/response.py +155 -0
- rasa/core/nlg/summarize.py +70 -0
- rasa/core/persistor.py +538 -0
- rasa/core/policies/__init__.py +0 -0
- rasa/core/policies/ensemble.py +329 -0
- rasa/core/policies/enterprise_search_policy.py +905 -0
- rasa/core/policies/enterprise_search_prompt_template.jinja2 +25 -0
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
- rasa/core/policies/flow_policy.py +205 -0
- rasa/core/policies/flows/__init__.py +0 -0
- rasa/core/policies/flows/flow_exceptions.py +44 -0
- rasa/core/policies/flows/flow_executor.py +754 -0
- rasa/core/policies/flows/flow_step_result.py +43 -0
- rasa/core/policies/intentless_policy.py +1031 -0
- rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
- rasa/core/policies/memoization.py +538 -0
- rasa/core/policies/policy.py +725 -0
- rasa/core/policies/rule_policy.py +1273 -0
- rasa/core/policies/ted_policy.py +2169 -0
- rasa/core/policies/unexpected_intent_policy.py +1022 -0
- rasa/core/processor.py +1465 -0
- rasa/core/run.py +342 -0
- rasa/core/secrets_manager/__init__.py +0 -0
- rasa/core/secrets_manager/constants.py +36 -0
- rasa/core/secrets_manager/endpoints.py +391 -0
- rasa/core/secrets_manager/factory.py +241 -0
- rasa/core/secrets_manager/secret_manager.py +262 -0
- rasa/core/secrets_manager/vault.py +584 -0
- rasa/core/test.py +1335 -0
- rasa/core/tracker_store.py +1703 -0
- rasa/core/train.py +105 -0
- rasa/core/training/__init__.py +89 -0
- rasa/core/training/converters/__init__.py +0 -0
- rasa/core/training/converters/responses_prefix_converter.py +119 -0
- rasa/core/training/interactive.py +1744 -0
- rasa/core/training/story_conflict.py +381 -0
- rasa/core/training/training.py +93 -0
- rasa/core/utils.py +366 -0
- rasa/core/visualize.py +70 -0
- rasa/dialogue_understanding/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/constants.py +4 -0
- rasa/dialogue_understanding/coexistence/intent_based_router.py +196 -0
- rasa/dialogue_understanding/coexistence/llm_based_router.py +327 -0
- rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
- rasa/dialogue_understanding/commands/__init__.py +61 -0
- rasa/dialogue_understanding/commands/can_not_handle_command.py +70 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +125 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +44 -0
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/clarify_command.py +86 -0
- rasa/dialogue_understanding/commands/command.py +85 -0
- rasa/dialogue_understanding/commands/correct_slots_command.py +297 -0
- rasa/dialogue_understanding/commands/error_command.py +79 -0
- rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/handle_code_change_command.py +73 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +66 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/noop_command.py +54 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
- rasa/dialogue_understanding/commands/restart_command.py +58 -0
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +160 -0
- rasa/dialogue_understanding/commands/skip_question_command.py +75 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +107 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +45 -0
- rasa/dialogue_understanding/generator/__init__.py +21 -0
- rasa/dialogue_understanding/generator/command_generator.py +464 -0
- rasa/dialogue_understanding/generator/constants.py +27 -0
- rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +466 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +500 -0
- rasa/dialogue_understanding/generator/llm_command_generator.py +67 -0
- rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
- rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +920 -0
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +261 -0
- rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +60 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +486 -0
- rasa/dialogue_understanding/patterns/__init__.py +0 -0
- rasa/dialogue_understanding/patterns/cancel.py +111 -0
- rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
- rasa/dialogue_understanding/patterns/chitchat.py +37 -0
- rasa/dialogue_understanding/patterns/clarify.py +97 -0
- rasa/dialogue_understanding/patterns/code_change.py +41 -0
- rasa/dialogue_understanding/patterns/collect_information.py +90 -0
- rasa/dialogue_understanding/patterns/completed.py +40 -0
- rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
- rasa/dialogue_understanding/patterns/correction.py +278 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +301 -0
- rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
- rasa/dialogue_understanding/patterns/internal_error.py +47 -0
- rasa/dialogue_understanding/patterns/repeat.py +37 -0
- rasa/dialogue_understanding/patterns/restart.py +37 -0
- rasa/dialogue_understanding/patterns/search.py +37 -0
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/patterns/skip_question.py +38 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/__init__.py +0 -0
- rasa/dialogue_understanding/processor/command_processor.py +720 -0
- rasa/dialogue_understanding/processor/command_processor_component.py +43 -0
- rasa/dialogue_understanding/stack/__init__.py +0 -0
- rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
- rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
- rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
- rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
- rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
- rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
- rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
- rasa/dialogue_understanding/stack/utils.py +211 -0
- rasa/dialogue_understanding/utils.py +14 -0
- rasa/dialogue_understanding_test/__init__.py +0 -0
- rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
- rasa/dialogue_understanding_test/constants.py +17 -0
- rasa/dialogue_understanding_test/du_test_case.py +118 -0
- rasa/dialogue_understanding_test/du_test_result.py +11 -0
- rasa/dialogue_understanding_test/du_test_runner.py +93 -0
- rasa/dialogue_understanding_test/io.py +54 -0
- rasa/dialogue_understanding_test/validation.py +22 -0
- rasa/e2e_test/__init__.py +0 -0
- rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
- rasa/e2e_test/assertions.py +1345 -0
- rasa/e2e_test/assertions_schema.yml +129 -0
- rasa/e2e_test/constants.py +31 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +569 -0
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +54 -0
- rasa/e2e_test/e2e_test_runner.py +1192 -0
- rasa/e2e_test/e2e_test_schema.yml +181 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +598 -0
- rasa/e2e_test/utils/validation.py +178 -0
- rasa/engine/__init__.py +0 -0
- rasa/engine/caching.py +463 -0
- rasa/engine/constants.py +17 -0
- rasa/engine/exceptions.py +14 -0
- rasa/engine/graph.py +642 -0
- rasa/engine/loader.py +48 -0
- rasa/engine/recipes/__init__.py +0 -0
- rasa/engine/recipes/config_files/default_config.yml +41 -0
- rasa/engine/recipes/default_components.py +97 -0
- rasa/engine/recipes/default_recipe.py +1272 -0
- rasa/engine/recipes/graph_recipe.py +79 -0
- rasa/engine/recipes/recipe.py +93 -0
- rasa/engine/runner/__init__.py +0 -0
- rasa/engine/runner/dask.py +250 -0
- rasa/engine/runner/interface.py +49 -0
- rasa/engine/storage/__init__.py +0 -0
- rasa/engine/storage/local_model_storage.py +244 -0
- rasa/engine/storage/resource.py +110 -0
- rasa/engine/storage/storage.py +199 -0
- rasa/engine/training/__init__.py +0 -0
- rasa/engine/training/components.py +176 -0
- rasa/engine/training/fingerprinting.py +64 -0
- rasa/engine/training/graph_trainer.py +256 -0
- rasa/engine/training/hooks.py +164 -0
- rasa/engine/validation.py +1451 -0
- rasa/env.py +14 -0
- rasa/exceptions.py +69 -0
- rasa/graph_components/__init__.py +0 -0
- rasa/graph_components/converters/__init__.py +0 -0
- rasa/graph_components/converters/nlu_message_converter.py +48 -0
- rasa/graph_components/providers/__init__.py +0 -0
- rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
- rasa/graph_components/providers/domain_provider.py +71 -0
- rasa/graph_components/providers/flows_provider.py +74 -0
- rasa/graph_components/providers/forms_provider.py +44 -0
- rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
- rasa/graph_components/providers/responses_provider.py +44 -0
- rasa/graph_components/providers/rule_only_provider.py +49 -0
- rasa/graph_components/providers/story_graph_provider.py +96 -0
- rasa/graph_components/providers/training_tracker_provider.py +55 -0
- rasa/graph_components/validators/__init__.py +0 -0
- rasa/graph_components/validators/default_recipe_validator.py +550 -0
- rasa/graph_components/validators/finetuning_validator.py +302 -0
- rasa/hooks.py +111 -0
- rasa/jupyter.py +63 -0
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/markers/__init__.py +0 -0
- rasa/markers/marker.py +269 -0
- rasa/markers/marker_base.py +828 -0
- rasa/markers/upload.py +74 -0
- rasa/markers/validate.py +21 -0
- rasa/model.py +118 -0
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +40 -0
- rasa/model_manager/model_api.py +559 -0
- rasa/model_manager/runner_service.py +286 -0
- rasa/model_manager/socket_bridge.py +146 -0
- rasa/model_manager/studio_jwt_auth.py +86 -0
- rasa/model_manager/trainer_service.py +325 -0
- rasa/model_manager/utils.py +87 -0
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +112 -0
- rasa/model_testing.py +457 -0
- rasa/model_training.py +596 -0
- rasa/nlu/__init__.py +7 -0
- rasa/nlu/classifiers/__init__.py +3 -0
- rasa/nlu/classifiers/classifier.py +5 -0
- rasa/nlu/classifiers/diet_classifier.py +1881 -0
- rasa/nlu/classifiers/fallback_classifier.py +192 -0
- rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
- rasa/nlu/classifiers/logistic_regression_classifier.py +253 -0
- rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
- rasa/nlu/classifiers/regex_message_handler.py +56 -0
- rasa/nlu/classifiers/sklearn_intent_classifier.py +330 -0
- rasa/nlu/constants.py +77 -0
- rasa/nlu/convert.py +40 -0
- rasa/nlu/emulators/__init__.py +0 -0
- rasa/nlu/emulators/dialogflow.py +55 -0
- rasa/nlu/emulators/emulator.py +49 -0
- rasa/nlu/emulators/luis.py +86 -0
- rasa/nlu/emulators/no_emulator.py +10 -0
- rasa/nlu/emulators/wit.py +56 -0
- rasa/nlu/extractors/__init__.py +0 -0
- rasa/nlu/extractors/crf_entity_extractor.py +715 -0
- rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
- rasa/nlu/extractors/entity_synonyms.py +178 -0
- rasa/nlu/extractors/extractor.py +470 -0
- rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
- rasa/nlu/extractors/regex_entity_extractor.py +220 -0
- rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
- rasa/nlu/featurizers/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +445 -0
- rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +768 -0
- rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
- rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
- rasa/nlu/featurizers/featurizer.py +89 -0
- rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +867 -0
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +571 -0
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +271 -0
- rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
- rasa/nlu/model.py +24 -0
- rasa/nlu/run.py +27 -0
- rasa/nlu/selectors/__init__.py +0 -0
- rasa/nlu/selectors/response_selector.py +987 -0
- rasa/nlu/test.py +1940 -0
- rasa/nlu/tokenizers/__init__.py +0 -0
- rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
- rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
- rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
- rasa/nlu/tokenizers/tokenizer.py +239 -0
- rasa/nlu/tokenizers/whitespace_tokenizer.py +95 -0
- rasa/nlu/utils/__init__.py +35 -0
- rasa/nlu/utils/bilou_utils.py +462 -0
- rasa/nlu/utils/hugging_face/__init__.py +0 -0
- rasa/nlu/utils/hugging_face/registry.py +108 -0
- rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
- rasa/nlu/utils/mitie_utils.py +113 -0
- rasa/nlu/utils/pattern_utils.py +168 -0
- rasa/nlu/utils/spacy_utils.py +310 -0
- rasa/plugin.py +90 -0
- rasa/server.py +1588 -0
- rasa/shared/__init__.py +0 -0
- rasa/shared/constants.py +311 -0
- rasa/shared/core/__init__.py +0 -0
- rasa/shared/core/command_payload_reader.py +109 -0
- rasa/shared/core/constants.py +180 -0
- rasa/shared/core/conversation.py +46 -0
- rasa/shared/core/domain.py +2172 -0
- rasa/shared/core/events.py +2559 -0
- rasa/shared/core/flows/__init__.py +7 -0
- rasa/shared/core/flows/flow.py +562 -0
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flow_step.py +146 -0
- rasa/shared/core/flows/flow_step_links.py +319 -0
- rasa/shared/core/flows/flow_step_sequence.py +70 -0
- rasa/shared/core/flows/flows_list.py +258 -0
- rasa/shared/core/flows/flows_yaml_schema.json +303 -0
- rasa/shared/core/flows/nlu_trigger.py +117 -0
- rasa/shared/core/flows/steps/__init__.py +24 -0
- rasa/shared/core/flows/steps/action.py +56 -0
- rasa/shared/core/flows/steps/call.py +64 -0
- rasa/shared/core/flows/steps/collect.py +112 -0
- rasa/shared/core/flows/steps/constants.py +5 -0
- rasa/shared/core/flows/steps/continuation.py +36 -0
- rasa/shared/core/flows/steps/end.py +22 -0
- rasa/shared/core/flows/steps/internal.py +44 -0
- rasa/shared/core/flows/steps/link.py +51 -0
- rasa/shared/core/flows/steps/no_operation.py +48 -0
- rasa/shared/core/flows/steps/set_slots.py +50 -0
- rasa/shared/core/flows/steps/start.py +30 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +735 -0
- rasa/shared/core/flows/yaml_flows_io.py +405 -0
- rasa/shared/core/generator.py +908 -0
- rasa/shared/core/slot_mappings.py +526 -0
- rasa/shared/core/slots.py +654 -0
- rasa/shared/core/trackers.py +1183 -0
- rasa/shared/core/training_data/__init__.py +0 -0
- rasa/shared/core/training_data/loading.py +89 -0
- rasa/shared/core/training_data/story_reader/__init__.py +0 -0
- rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
- rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
- rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
- rasa/shared/core/training_data/story_writer/__init__.py +0 -0
- rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
- rasa/shared/core/training_data/story_writer/yaml_story_writer.py +444 -0
- rasa/shared/core/training_data/structures.py +858 -0
- rasa/shared/core/training_data/visualization.html +146 -0
- rasa/shared/core/training_data/visualization.py +603 -0
- rasa/shared/data.py +249 -0
- rasa/shared/engine/__init__.py +0 -0
- rasa/shared/engine/caching.py +26 -0
- rasa/shared/exceptions.py +167 -0
- rasa/shared/importers/__init__.py +0 -0
- rasa/shared/importers/importer.py +770 -0
- rasa/shared/importers/multi_project.py +215 -0
- rasa/shared/importers/rasa.py +108 -0
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +36 -0
- rasa/shared/nlu/__init__.py +0 -0
- rasa/shared/nlu/constants.py +53 -0
- rasa/shared/nlu/interpreter.py +10 -0
- rasa/shared/nlu/training_data/__init__.py +0 -0
- rasa/shared/nlu/training_data/entities_parser.py +208 -0
- rasa/shared/nlu/training_data/features.py +492 -0
- rasa/shared/nlu/training_data/formats/__init__.py +10 -0
- rasa/shared/nlu/training_data/formats/dialogflow.py +163 -0
- rasa/shared/nlu/training_data/formats/luis.py +87 -0
- rasa/shared/nlu/training_data/formats/rasa.py +135 -0
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +618 -0
- rasa/shared/nlu/training_data/formats/readerwriter.py +244 -0
- rasa/shared/nlu/training_data/formats/wit.py +52 -0
- rasa/shared/nlu/training_data/loading.py +137 -0
- rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
- rasa/shared/nlu/training_data/message.py +490 -0
- rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
- rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
- rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
- rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
- rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
- rasa/shared/nlu/training_data/training_data.py +729 -0
- rasa/shared/nlu/training_data/util.py +223 -0
- rasa/shared/providers/__init__.py +0 -0
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +677 -0
- rasa/shared/providers/_configs/client_config.py +59 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +132 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +236 -0
- rasa/shared/providers/_configs/litellm_router_client_config.py +222 -0
- rasa/shared/providers/_configs/model_group_config.py +173 -0
- rasa/shared/providers/_configs/openai_client_config.py +177 -0
- rasa/shared/providers/_configs/rasa_llm_client_config.py +75 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +178 -0
- rasa/shared/providers/_configs/utils.py +117 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/constants.py +7 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +243 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +335 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +126 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +138 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +265 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +415 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +110 -0
- rasa/shared/providers/llm/litellm_router_llm_client.py +202 -0
- rasa/shared/providers/llm/llm_client.py +78 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +161 -0
- rasa/shared/providers/llm/rasa_llm_client.py +120 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +276 -0
- rasa/shared/providers/mappings.py +94 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +185 -0
- rasa/shared/providers/router/router_client.py +75 -0
- rasa/shared/utils/__init__.py +0 -0
- rasa/shared/utils/cli.py +102 -0
- rasa/shared/utils/common.py +324 -0
- rasa/shared/utils/constants.py +4 -0
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +258 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +499 -0
- rasa/shared/utils/llm.py +764 -0
- rasa/shared/utils/pykwalify_extensions.py +27 -0
- rasa/shared/utils/schemas/__init__.py +0 -0
- rasa/shared/utils/schemas/config.yml +2 -0
- rasa/shared/utils/schemas/domain.yml +145 -0
- rasa/shared/utils/schemas/events.py +214 -0
- rasa/shared/utils/schemas/model_config.yml +36 -0
- rasa/shared/utils/schemas/stories.yml +173 -0
- rasa/shared/utils/yaml.py +1068 -0
- rasa/studio/__init__.py +0 -0
- rasa/studio/auth.py +270 -0
- rasa/studio/config.py +136 -0
- rasa/studio/constants.py +19 -0
- rasa/studio/data_handler.py +368 -0
- rasa/studio/download.py +489 -0
- rasa/studio/results_logger.py +137 -0
- rasa/studio/train.py +134 -0
- rasa/studio/upload.py +563 -0
- rasa/telemetry.py +1876 -0
- rasa/tracing/__init__.py +0 -0
- rasa/tracing/config.py +355 -0
- rasa/tracing/constants.py +62 -0
- rasa/tracing/instrumentation/__init__.py +0 -0
- rasa/tracing/instrumentation/attribute_extractors.py +765 -0
- rasa/tracing/instrumentation/instrumentation.py +1306 -0
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +144 -0
- rasa/tracing/instrumentation/metrics.py +294 -0
- rasa/tracing/metric_instrument_provider.py +205 -0
- rasa/utils/__init__.py +0 -0
- rasa/utils/beta.py +83 -0
- rasa/utils/cli.py +28 -0
- rasa/utils/common.py +639 -0
- rasa/utils/converter.py +53 -0
- rasa/utils/endpoints.py +331 -0
- rasa/utils/io.py +252 -0
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +542 -0
- rasa/utils/log_utils.py +181 -0
- rasa/utils/mapper.py +210 -0
- rasa/utils/ml_utils.py +147 -0
- rasa/utils/plotting.py +362 -0
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/utils/singleton.py +23 -0
- rasa/utils/tensorflow/__init__.py +0 -0
- rasa/utils/tensorflow/callback.py +112 -0
- rasa/utils/tensorflow/constants.py +116 -0
- rasa/utils/tensorflow/crf.py +492 -0
- rasa/utils/tensorflow/data_generator.py +440 -0
- rasa/utils/tensorflow/environment.py +161 -0
- rasa/utils/tensorflow/exceptions.py +5 -0
- rasa/utils/tensorflow/feature_array.py +366 -0
- rasa/utils/tensorflow/layers.py +1565 -0
- rasa/utils/tensorflow/layers_utils.py +113 -0
- rasa/utils/tensorflow/metrics.py +281 -0
- rasa/utils/tensorflow/model_data.py +798 -0
- rasa/utils/tensorflow/model_data_utils.py +499 -0
- rasa/utils/tensorflow/models.py +935 -0
- rasa/utils/tensorflow/rasa_layers.py +1094 -0
- rasa/utils/tensorflow/transformer.py +640 -0
- rasa/utils/tensorflow/types.py +6 -0
- rasa/utils/train_utils.py +572 -0
- rasa/utils/url_tools.py +53 -0
- rasa/utils/yaml.py +54 -0
- rasa/validator.py +1644 -0
- rasa/version.py +3 -0
- rasa_pro-3.12.0.dev1.dist-info/METADATA +199 -0
- rasa_pro-3.12.0.dev1.dist-info/NOTICE +5 -0
- rasa_pro-3.12.0.dev1.dist-info/RECORD +790 -0
- rasa_pro-3.12.0.dev1.dist-info/WHEEL +4 -0
- rasa_pro-3.12.0.dev1.dist-info/entry_points.txt +3 -0
rasa/core/test.py
ADDED
|
@@ -0,0 +1,1335 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
import tempfile
|
|
5
|
+
import warnings as pywarnings
|
|
6
|
+
from collections import defaultdict, namedtuple
|
|
7
|
+
from typing import Any, Dict, List, Optional, Text, Tuple, TYPE_CHECKING, cast
|
|
8
|
+
|
|
9
|
+
from rasa import telemetry
|
|
10
|
+
from rasa.core.constants import (
|
|
11
|
+
CONFUSION_MATRIX_STORIES_FILE,
|
|
12
|
+
REPORT_STORIES_FILE,
|
|
13
|
+
FAILED_STORIES_FILE,
|
|
14
|
+
SUCCESSFUL_STORIES_FILE,
|
|
15
|
+
STORIES_WITH_WARNINGS_FILE,
|
|
16
|
+
)
|
|
17
|
+
from rasa.core.channels import UserMessage
|
|
18
|
+
from rasa.core.policies.policy import PolicyPrediction
|
|
19
|
+
from rasa.nlu.test import EntityEvaluationResult, evaluate_entities
|
|
20
|
+
from rasa.nlu.tokenizers.tokenizer import Token
|
|
21
|
+
from rasa.shared.constants import ROUTE_TO_CALM_SLOT
|
|
22
|
+
from rasa.shared.core.constants import (
|
|
23
|
+
POLICIES_THAT_EXTRACT_ENTITIES,
|
|
24
|
+
ACTION_UNLIKELY_INTENT_NAME,
|
|
25
|
+
)
|
|
26
|
+
from rasa.shared.exceptions import RasaException
|
|
27
|
+
import rasa.shared.utils.io
|
|
28
|
+
from rasa.shared.core.training_data.story_writer.yaml_story_writer import (
|
|
29
|
+
YAMLStoryWriter,
|
|
30
|
+
)
|
|
31
|
+
from rasa.shared.core.training_data.structures import StoryStep
|
|
32
|
+
from rasa.shared.core.domain import Domain
|
|
33
|
+
from rasa.nlu.constants import (
|
|
34
|
+
RESPONSE_SELECTOR_DEFAULT_INTENT,
|
|
35
|
+
RESPONSE_SELECTOR_RETRIEVAL_INTENTS,
|
|
36
|
+
TOKENS_NAMES,
|
|
37
|
+
RESPONSE_SELECTOR_PROPERTY_NAME,
|
|
38
|
+
)
|
|
39
|
+
from rasa.shared.nlu.constants import (
|
|
40
|
+
INTENT,
|
|
41
|
+
ENTITIES,
|
|
42
|
+
ENTITY_ATTRIBUTE_VALUE,
|
|
43
|
+
ENTITY_ATTRIBUTE_START,
|
|
44
|
+
ENTITY_ATTRIBUTE_END,
|
|
45
|
+
EXTRACTOR,
|
|
46
|
+
ENTITY_ATTRIBUTE_TYPE,
|
|
47
|
+
INTENT_RESPONSE_KEY,
|
|
48
|
+
INTENT_NAME_KEY,
|
|
49
|
+
RESPONSE,
|
|
50
|
+
RESPONSE_SELECTOR,
|
|
51
|
+
FULL_RETRIEVAL_INTENT_NAME_KEY,
|
|
52
|
+
TEXT,
|
|
53
|
+
ENTITY_ATTRIBUTE_TEXT,
|
|
54
|
+
)
|
|
55
|
+
from rasa.constants import RESULTS_FILE, PERCENTAGE_KEY
|
|
56
|
+
from rasa.shared.core.events import ActionExecuted, EntitiesAdded, UserUttered, SlotSet
|
|
57
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
58
|
+
from rasa.shared.nlu.training_data.formats.readerwriter import TrainingDataWriter
|
|
59
|
+
from rasa.shared.importers.importer import TrainingDataImporter
|
|
60
|
+
from rasa.shared.utils.io import DEFAULT_ENCODING
|
|
61
|
+
from rasa.utils.tensorflow.constants import QUERY_INTENT_KEY, SEVERITY_KEY
|
|
62
|
+
from rasa.exceptions import ActionLimitReached
|
|
63
|
+
|
|
64
|
+
from rasa.core.actions.action import ActionRetrieveResponse
|
|
65
|
+
|
|
66
|
+
if TYPE_CHECKING:
|
|
67
|
+
from rasa.core.agent import Agent
|
|
68
|
+
from rasa.core.processor import MessageProcessor
|
|
69
|
+
from rasa.shared.core.generator import TrainingDataGenerator
|
|
70
|
+
from rasa.shared.core.events import Event, EntityPrediction
|
|
71
|
+
|
|
72
|
+
logger = logging.getLogger(__name__)

# Result bundle of a story evaluation run; the field names below are the
# values it groups together.
StoryEvaluation = namedtuple(
    "StoryEvaluation",
    (
        "evaluation_store failed_stories successful_stories "
        "stories_with_warnings action_list in_training_data_fraction"
    ),
)

# Flat list of predicted or target labels; entries may be `None`.
PredictionList = List[Optional[Text]]
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class WrongPredictionException(RasaException, ValueError):
    """Signals that a prediction did not match the expected label.

    Also derives from `ValueError`, so handlers catching `ValueError`
    catch this exception as well.
    """
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class WarningPredictedAction(ActionExecuted):
    """Action the model predicted correctly, but with a warning attached."""

    type_name = "warning_predicted"

    def __init__(
        self,
        action_name_prediction: Text,
        action_name: Optional[Text] = None,
        policy: Optional[Text] = None,
        confidence: Optional[float] = None,
        timestamp: Optional[float] = None,
        metadata: Optional[Dict] = None,
    ) -> None:
        """Create an `action_unlikely_intent` event that was flagged as a warning.

        Args:
            action_name_prediction: Name of the action the model predicted;
                rendered by `inline_comment`.
            action_name: Forwarded to `ActionExecuted`.
            policy: Forwarded to `ActionExecuted`.
            confidence: Forwarded to `ActionExecuted`.
            timestamp: Forwarded to `ActionExecuted`.
            metadata: Forwarded to `ActionExecuted`.
        """
        self.action_name_prediction = action_name_prediction
        super().__init__(action_name, policy, confidence, timestamp, metadata)

    def inline_comment(self, **kwargs: Any) -> Text:
        """Return the comment written next to this event when it is dumped."""
        prediction = self.action_name_prediction
        return f"predicted: {prediction}"
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class WronglyPredictedAction(ActionExecuted):
    """The model predicted the wrong action.

    Mostly used to mark wrong predictions and be able to
    dump them as stories. The event itself represents the *target*
    (expected) action; the wrong prediction is kept separately and
    rendered as an inline comment.
    """

    type_name = "wrong_action"

    def __init__(
        self,
        action_name_target: Text,
        action_text_target: Text,
        action_name_prediction: Text,
        policy: Optional[Text] = None,
        confidence: Optional[float] = None,
        timestamp: Optional[float] = None,
        metadata: Optional[Dict] = None,
        predicted_action_unlikely_intent: bool = False,
    ) -> None:
        """Creates an event recording that an action was predicted wrongly.

        See the docstring of the parent class `ActionExecuted` for more information.

        Args:
            action_name_target: Name of the expected action; becomes the
                event's `action_name`.
            action_text_target: Text of the expected action; forwarded as
                `action_text`.
            action_name_prediction: Name of the action that was predicted.
            policy: Forwarded to `ActionExecuted`.
            confidence: Forwarded to `ActionExecuted`.
            timestamp: Forwarded to `ActionExecuted`.
            metadata: Forwarded to `ActionExecuted`.
            predicted_action_unlikely_intent: If `True`, the inline comment
                notes that the prediction came after `ACTION_UNLIKELY_INTENT_NAME`.
        """
        self.action_name_prediction = action_name_prediction
        self.predicted_action_unlikely_intent = predicted_action_unlikely_intent
        super().__init__(
            action_name_target,
            policy,
            confidence,
            timestamp,
            metadata,
            action_text=action_text_target,
        )

    def inline_comment(self, **kwargs: Any) -> Text:
        """A comment attached to this event. Used during dumping."""
        comment = f"predicted: {self.action_name_prediction}"
        if self.predicted_action_unlikely_intent:
            return f"{comment} after {ACTION_UNLIKELY_INTENT_NAME}"
        return comment

    def as_story_string(self) -> Text:
        """Returns the story equivalent representation."""
        return f"{self.action_name} <!-- {self.inline_comment()} -->"

    def __repr__(self) -> Text:
        """Returns event as string for debugging."""
        return (
            f"WronglyPredictedAction(action_target: {self.action_name}, "
            f"action_prediction: {self.action_name_prediction}, "
            f"policy: {self.policy}, confidence: {self.confidence}, "
            f"metadata: {self.metadata})"
        )
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
class EvaluationStore:
    """Class storing action, intent and entity predictions and targets.

    Predictions and targets are kept in parallel lists per category, so they
    can be compared directly or flattened via `serialise`.
    """

    def __init__(
        self,
        action_predictions: Optional[PredictionList] = None,
        action_targets: Optional[PredictionList] = None,
        intent_predictions: Optional[PredictionList] = None,
        intent_targets: Optional[PredictionList] = None,
        entity_predictions: Optional[List["EntityPrediction"]] = None,
        entity_targets: Optional[List["EntityPrediction"]] = None,
    ) -> None:
        """Initialize store attributes.

        Missing (or empty) arguments become fresh empty lists. Non-empty lists
        are stored by reference, not copied, so `add_to_store` extends the
        caller's list objects.
        """
        self.action_predictions = action_predictions or []
        self.action_targets = action_targets or []
        self.intent_predictions = intent_predictions or []
        self.intent_targets = intent_targets or []
        self.entity_predictions: List["EntityPrediction"] = entity_predictions or []
        self.entity_targets: List["EntityPrediction"] = entity_targets or []

    def add_to_store(
        self,
        action_predictions: Optional[PredictionList] = None,
        action_targets: Optional[PredictionList] = None,
        intent_predictions: Optional[PredictionList] = None,
        intent_targets: Optional[PredictionList] = None,
        entity_predictions: Optional[List["EntityPrediction"]] = None,
        entity_targets: Optional[List["EntityPrediction"]] = None,
    ) -> None:
        """Add items or lists of items to the store."""
        self.action_predictions.extend(action_predictions or [])
        self.action_targets.extend(action_targets or [])
        self.intent_targets.extend(intent_targets or [])
        self.intent_predictions.extend(intent_predictions or [])
        self.entity_predictions.extend(entity_predictions or [])
        self.entity_targets.extend(entity_targets or [])

    def merge_store(self, other: "EvaluationStore") -> None:
        """Add the contents of other to self."""
        self.add_to_store(
            action_predictions=other.action_predictions,
            action_targets=other.action_targets,
            intent_predictions=other.intent_predictions,
            intent_targets=other.intent_targets,
            entity_predictions=other.entity_predictions,
            entity_targets=other.entity_targets,
        )

    def _check_entity_prediction_target_mismatch(self) -> bool:
        """Checks that same entities were expected and actually extracted.

        Possible duplicates or differences in order should not matter - neither
        the order of the entities nor the key order inside each entity dict.
        Entity values must be hashable.
        """
        # frozenset (instead of tuple) makes the comparison independent of
        # the insertion order of the keys in each entity dict
        deduplicated_targets = {
            frozenset(entity.items()) for entity in self.entity_targets
        }
        deduplicated_predictions = {
            frozenset(entity.items()) for entity in self.entity_predictions
        }
        return deduplicated_targets != deduplicated_predictions

    def check_prediction_target_mismatch(self) -> bool:
        """Checks if intent, entity or action predictions don't match expected ones."""
        return (
            self.intent_predictions != self.intent_targets
            or self._check_entity_prediction_target_mismatch()
            or self.action_predictions != self.action_targets
        )

    @staticmethod
    def _compare_entities(
        entity_predictions: List["EntityPrediction"],
        entity_targets: List["EntityPrediction"],
        i_pred: int,
        i_target: int,
    ) -> int:
        """Picks the first entity from the current predicted and target entities.

        If the predicted entity comes first it returns -1,
        while it returns 1 if the target entity comes first.
        If target and predicted are aligned it returns 0.
        An index past the end of its list counts as "no entity", so the
        other side is reported as coming first.
        """
        pred = None
        target = None
        if i_pred < len(entity_predictions):
            pred = entity_predictions[i_pred]
        if i_target < len(entity_targets):
            target = entity_targets[i_target]
        if target and pred:
            # Check which entity has the lower "start" value
            if pred.get(ENTITY_ATTRIBUTE_START) < target.get(ENTITY_ATTRIBUTE_START):
                return -1
            elif target.get(ENTITY_ATTRIBUTE_START) < pred.get(ENTITY_ATTRIBUTE_START):
                return 1
            else:
                # Since both have the same "start" values,
                # check which one has the lower "end" value
                if pred.get(ENTITY_ATTRIBUTE_END) < target.get(ENTITY_ATTRIBUTE_END):
                    return -1
                elif target.get(ENTITY_ATTRIBUTE_END) < pred.get(ENTITY_ATTRIBUTE_END):
                    return 1
                else:
                    # The entities have the same "start" and "end" values
                    return 0
        return 1 if target else -1

    @staticmethod
    def _generate_entity_training_data(entity: Dict[Text, Any]) -> Text:
        """Render a single entity in training data notation."""
        return TrainingDataWriter.generate_entity(
            entity.get(ENTITY_ATTRIBUTE_TEXT), entity
        )

    @staticmethod
    def _group_entities_by_text(
        entities: List["EntityPrediction"],
    ) -> Dict[Text, List["EntityPrediction"]]:
        """Group entities by their (string) source text, sorted by start offset."""
        grouped: Dict[Text, List["EntityPrediction"]] = defaultdict(list)
        for entity in entities:
            text = entity.get(ENTITY_ATTRIBUTE_TEXT)
            # entities without a string text can never be matched to a sentence
            if isinstance(text, str):
                grouped[text].append(entity)
        for group in grouped.values():
            # list.sort is stable, so entities with equal starts keep input order
            group.sort(key=lambda x: x[ENTITY_ATTRIBUTE_START])  # type: ignore[literal-required]
        return dict(grouped)

    def serialise(self) -> Tuple[PredictionList, PredictionList]:
        """Turn targets and predictions to lists of equal size for sklearn.

        Action and intent labels are concatenated as they are. Entities are
        grouped per text, sorted by start offset and aligned pairwise; an
        entity without a counterpart is paired with the string "None", so the
        entity part contributes equally many items to both returned lists.

        Returns:
            A `(targets, predictions)` tuple.
        """
        # group once instead of re-filtering all entities for every text
        targets_by_text = self._group_entities_by_text(self.entity_targets)
        predictions_by_text = self._group_entities_by_text(self.entity_predictions)

        aligned_entity_targets: List[Optional[Text]] = []
        aligned_entity_predictions: List[Optional[Text]] = []

        for text in sorted(set(targets_by_text) | set(predictions_by_text)):
            entity_targets = targets_by_text.get(text, [])
            entity_predictions = predictions_by_text.get(text, [])

            i_pred, i_target = 0, 0

            while i_pred < len(entity_predictions) or i_target < len(entity_targets):
                cmp = self._compare_entities(
                    entity_predictions, entity_targets, i_pred, i_target
                )
                if cmp == -1:  # predicted comes first
                    aligned_entity_predictions.append(
                        self._generate_entity_training_data(entity_predictions[i_pred])
                    )
                    aligned_entity_targets.append("None")
                    i_pred += 1
                elif cmp == 1:  # target entity comes first
                    aligned_entity_targets.append(
                        self._generate_entity_training_data(entity_targets[i_target])
                    )
                    aligned_entity_predictions.append("None")
                    i_target += 1
                else:  # target and predicted entity are aligned
                    aligned_entity_predictions.append(
                        self._generate_entity_training_data(entity_predictions[i_pred])
                    )
                    aligned_entity_targets.append(
                        self._generate_entity_training_data(entity_targets[i_target])
                    )
                    i_pred += 1
                    i_target += 1

        targets = self.action_targets + self.intent_targets + aligned_entity_targets

        predictions = (
            self.action_predictions
            + self.intent_predictions
            + aligned_entity_predictions
        )
        return targets, predictions
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
class EndToEndUserUtterance(UserUttered):
    """User utterance that is always rendered in its end-to-end form.

    Used so that the complete user message appears in the
    `failed_test_stories.yml` output file.
    """

    def as_story_string(self, e2e: bool = True) -> Text:
        """Return the story representation, always in end-to-end form.

        The `e2e` argument is accepted for interface compatibility only;
        end-to-end formatting is forced regardless of its value.
        """
        force_end_to_end = True
        return super().as_story_string(e2e=force_end_to_end)
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
class WronglyClassifiedUserUtterance(UserUttered):
    """The NLU model predicted the wrong user utterance.

    Mostly used to mark wrong predictions and be able to
    dump them as stories. The event carries the *target* intent and
    entities; the predictions are kept for rendering comments.
    """

    type_name = "wrong_utterance"

    def __init__(self, event: UserUttered, eval_store: EvaluationStore) -> None:
        """Set `predicted_intent` and `predicted_entities` attributes.

        Args:
            event: The user utterance from the test story.
            eval_store: Store holding the intent/entity targets and
                predictions collected for `event`.
        """
        # a missing prediction or target is tolerated and treated as `None`
        self.predicted_intent = (
            eval_store.intent_predictions[0]
            if eval_store.intent_predictions
            else None
        )

        self.target_entities = eval_store.entity_targets
        self.predicted_entities = eval_store.entity_predictions

        target_intent = (
            eval_store.intent_targets[0] if eval_store.intent_targets else None
        )
        intent = {"name": target_intent}

        super().__init__(
            event.text,
            intent,
            eval_store.entity_targets,
            event.parse_data,
            event.timestamp,
            event.input_channel,
        )

    def inline_comment(self, force_comment_generation: bool = False) -> Optional[Text]:
        """A comment attached to this event. Used during dumping.

        Returns `None` when the predicted intent matches the target intent,
        unless `force_comment_generation` is set.
        """
        from rasa.shared.core.events import format_message

        if force_comment_generation or self.predicted_intent != self.intent["name"]:
            predicted_message = format_message(
                self.text, self.predicted_intent, self.predicted_entities
            )

            return f"predicted: {self.predicted_intent}: {predicted_message}"
        else:
            return None

    @staticmethod
    def inline_comment_for_entity(
        predicted: Dict[Text, Any], entity: Dict[Text, Any]
    ) -> Optional[Text]:
        """Returns the predicted entity which is then printed as a comment.

        Returns `None` when the predicted and target entity types match.
        """
        if predicted["entity"] != entity["entity"]:
            # f-string: entity values are not necessarily strings
            return f"predicted: {predicted['entity']}: {predicted['value']}"
        else:
            return None

    def as_story_string(self, e2e: bool = True) -> Text:
        """Returns text representation of event.

        The prediction comment is appended only when `inline_comment` yields
        one, so no `<!-- None -->` marker is emitted.
        """
        from rasa.shared.core.events import format_message

        correct_message = format_message(
            self.text, self.intent.get("name"), self.entities
        )
        story_string = f"{self.intent.get('name')}: {correct_message}"
        comment = self.inline_comment()
        if comment is None:
            return story_string
        return f"{story_string} <!-- {comment} -->"
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def _create_data_generator(
    resource_name: Text,
    agent: "Agent",
    max_stories: Optional[int] = None,
    use_conversation_test_files: bool = False,
) -> "TrainingDataGenerator":
    """Create a training data generator over the test data in `resource_name`.

    The agent's domain (or an empty domain if the agent has none) is written
    to a temporary directory so the importer can load it. The directory is
    removed once the story graph has been built.

    Args:
        resource_name: Path to the test stories or conversation tests.
        agent: Agent whose domain is used.
        max_stories: Passed to the generator as `tracker_limit`.
        use_conversation_test_files: If `True`, load conversation tests
            (`get_conversation_tests`) instead of stories (`get_stories`).

    Returns:
        A generator configured without story concatenation or augmentation.
    """
    from rasa.shared.core.generator import TrainingDataGenerator

    domain = agent.domain if agent.domain is not None else Domain.empty()

    # a self-cleaning temp dir replaces the previously leaked `mkdtemp` dir
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_domain_path = Path(tmp_dir) / "domain.yaml"
        domain.persist(tmp_domain_path)
        test_data_importer = TrainingDataImporter.load_from_dict(
            training_data_paths=[resource_name], domain_path=str(tmp_domain_path)
        )
        # load the stories while the domain file still exists, in case the
        # importer reads it lazily
        if use_conversation_test_files:
            story_graph = test_data_importer.get_conversation_tests()
        else:
            story_graph = test_data_importer.get_stories()

    return TrainingDataGenerator(
        story_graph,
        # use the same (possibly empty) domain the stories were loaded with,
        # rather than a possibly `None` `agent.domain`
        domain,
        use_story_concatenation=False,
        augmentation_factor=0,
        tracker_limit=max_stories,
    )
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
def _clean_entity_results(
    text: Text, entity_results: List[Dict[Text, Any]]
) -> List["EntityPrediction"]:
    """Extract only the token variables from an entity dict.

    Each returned entity carries the message `text` plus whichever of the
    start, end, type and value attributes the input entity has. If the input
    entity has an extractor key (presumably a predicted entity), its value is
    converted to a string, because target values are all strings.

    The input dicts are left unmodified.

    Args:
        text: Text of the message the entities belong to.
        entity_results: Entity dicts to clean.

    Returns:
        The cleaned entities, in input order.
    """
    cleaned_entities: List["EntityPrediction"] = []

    for entity in entity_results:
        cleaned_entity: EntityPrediction = {ENTITY_ATTRIBUTE_TEXT: text}  # type: ignore[misc]
        has_extractor = EXTRACTOR in entity
        for key in (
            ENTITY_ATTRIBUTE_START,
            ENTITY_ATTRIBUTE_END,
            ENTITY_ATTRIBUTE_TYPE,
            ENTITY_ATTRIBUTE_VALUE,
        ):
            if key not in entity:
                continue
            value = entity[key]
            if key == ENTITY_ATTRIBUTE_VALUE and has_extractor:
                # convert values to strings for evaluation as
                # target values are all of type string; done on the copy so
                # the caller's entity dict is not mutated
                value = str(value)
            cleaned_entity[key] = value  # type: ignore[literal-required]
        cleaned_entities.append(cleaned_entity)

    return cleaned_entities
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
def _get_full_retrieval_intent(parsed: Dict[Text, Any]) -> Text:
    """Return full retrieval intent, if it's present, or normal intent otherwise.

    Args:
        parsed: Predicted parsed data.

    Returns:
        The full retrieval intent (e.g. ``faq/ask_name``) when the predicted
        intent is one of the response selector's retrieval intents and the
        selector output names one; the predicted base intent otherwise.
    """
    predicted_intent = parsed.get(INTENT, {}).get(INTENT_NAME_KEY)
    selector_output = parsed.get(RESPONSE_SELECTOR, {})

    retrieval_intents = selector_output.get(RESPONSE_SELECTOR_RETRIEVAL_INTENTS, {})
    if predicted_intent not in retrieval_intents:
        # plain intent: nothing to resolve
        return predicted_intent

    # If the response selector parameter was not specified in the config, the
    # selector output is stored under a "default" key; otherwise it is keyed
    # by the base intent itself.
    if RESPONSE_SELECTOR_DEFAULT_INTENT in selector_output:
        selector_key = RESPONSE_SELECTOR_DEFAULT_INTENT
    else:
        selector_key = predicted_intent

    selector_entry = selector_output.get(selector_key, {})
    resolved_intent = selector_entry.get(RESPONSE, {}).get(INTENT_RESPONSE_KEY)
    return resolved_intent or predicted_intent
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
def _collect_user_uttered_predictions(
    event: UserUttered,
    predicted: Dict[Text, Any],
    partial_tracker: DialogueStateTracker,
    fail_on_prediction_errors: bool,
) -> EvaluationStore:
    """Compare a test-story user turn with the predicted parse data.

    The intent (and, if any, entity) targets and predictions are recorded in a
    new `EvaluationStore`. `partial_tracker` is then extended with a
    `WronglyClassifiedUserUtterance` on a mismatch, or with an
    `EndToEndUserUtterance` otherwise.

    Args:
        event: User utterance from the test story (the target).
        predicted: Parse data predicted for the utterance.
        partial_tracker: Tracker being rebuilt from the test story; updated
            in place.
        fail_on_prediction_errors: Whether to raise on a mismatch.

    Returns:
        The evaluation store for this user turn.

    Raises:
        WrongPredictionException: On a mismatch when
            `fail_on_prediction_errors` is set.
    """
    user_uttered_eval_store = EvaluationStore()

    # intent from the test story, may either be base intent or full retrieval intent
    base_intent = event.intent.get(INTENT_NAME_KEY)
    full_retrieval_intent = event.intent.get(FULL_RETRIEVAL_INTENT_NAME_KEY)
    intent_gold = full_retrieval_intent if full_retrieval_intent else base_intent

    # predicted intent: note that this is only the base intent at this point
    intent_prediction = predicted.get(INTENT, {}).get(INTENT_NAME_KEY)
    # if the test story only provides the base intent AND the prediction was correct,
    # we are not interested in full retrieval intents and skip this section.
    # In any other case we are interested in the full retrieval intent (e.g. for report)
    if intent_gold != intent_prediction:
        intent_prediction = _get_full_retrieval_intent(predicted)

    user_uttered_eval_store.add_to_store(
        intent_targets=[intent_gold], intent_predictions=[intent_prediction]
    )

    entity_gold = event.entities
    # Parse data without an entities key means "no entities predicted"; use an
    # empty list so the cleaning step below never receives `None`.
    predicted_entities = predicted.get(ENTITIES) or []

    if entity_gold or predicted_entities:
        user_uttered_eval_store.add_to_store(
            entity_targets=_clean_entity_results(event.text, entity_gold),
            entity_predictions=_clean_entity_results(event.text, predicted_entities),
        )

    if user_uttered_eval_store.check_prediction_target_mismatch():
        partial_tracker.update(
            WronglyClassifiedUserUtterance(event, user_uttered_eval_store)
        )
        if fail_on_prediction_errors:
            story_dump = YAMLStoryWriter().dumps(partial_tracker.as_story().story_steps)
            raise WrongPredictionException(
                f"NLU model predicted a wrong intent or entities. Failed Story:"
                f" \n\n{story_dump}"
            )
    else:
        response_selector_info = (
            {
                RESPONSE_SELECTOR_PROPERTY_NAME: predicted[
                    RESPONSE_SELECTOR_PROPERTY_NAME
                ]
            }
            if RESPONSE_SELECTOR_PROPERTY_NAME in predicted
            else None
        )
        end_to_end_user_utterance = EndToEndUserUtterance(
            text=event.text,
            intent=event.intent,
            entities=event.entities,
            parse_data=response_selector_info,
        )
        partial_tracker.update(end_to_end_user_utterance)

    return user_uttered_eval_store
|
|
583
|
+
|
|
584
|
+
|
|
585
|
+
def emulate_loop_rejection(partial_tracker: DialogueStateTracker) -> None:
    """Append an `ActionExecutionRejected` event for the currently active loop.

    No action server runs during evaluation, so a loop can never actually
    reject its own execution. To still be able to test unhappy paths of loops,
    that rejection is emulated here.

    Args:
        partial_tracker: a :class:`rasa.core.trackers.DialogueStateTracker`
            that receives the rejection event for its active loop.
    """
    from rasa.shared.core.events import ActionExecutionRejected

    rejection = ActionExecutionRejected(partial_tracker.active_loop_name)
    partial_tracker.update(rejection)
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
async def _get_e2e_entity_evaluation_result(
    processor: "MessageProcessor",
    tracker: DialogueStateTracker,
    prediction: PolicyPrediction,
) -> Optional[EntityEvaluationResult]:
    """Evaluate policy-predicted entities against the latest user utterance.

    The entity targets of the most recent `UserUttered` event are paired with
    the entities that policies added through `EntitiesAdded` events in
    `prediction`. The utterance is re-parsed only to obtain its tokens.

    Args:
        processor: Processor used to parse the utterance text.
        tracker: Tracker whose last events are inspected.
        prediction: Policy prediction that may contain `EntitiesAdded` events.

    Returns:
        The entity evaluation result, or `None` if the relevant event is not a
        user utterance, there are neither target nor predicted entities, the
        utterance has no text, or parsing yields nothing.
    """
    last_event: Optional["Event"] = tracker.events[-1]
    if isinstance(last_event, SlotSet):
        # UserUttered events with entities can be followed by SlotSet events
        # if slots are defined in the domain
        last_event = tracker.get_last_event_for((UserUttered, ActionExecuted))

    if not isinstance(last_event, UserUttered):
        return None

    policy_entities = []
    for prediction_event in prediction.events:
        if isinstance(prediction_event, EntitiesAdded):
            policy_entities.extend(prediction_event.entities)

    entity_targets = last_event.entities
    if not entity_targets and not policy_entities:
        return None

    text = last_event.text
    if not text:
        return None

    parsed_message = await processor.parse_message(UserMessage(text=text))
    if not parsed_message:
        return None

    tokens = [
        Token(text[start:end], start, end)
        for start, end in parsed_message.get(TOKENS_NAMES[TEXT], [])
    ]
    return EntityEvaluationResult(entity_targets, policy_entities, tokens, text)
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
def _get_predicted_action_name(
    predicted_action: rasa.core.actions.action.Action,
    partial_tracker: DialogueStateTracker,
    expected_action_name: Text,
) -> Text:
    """Get the name of predicted action.

    For an `ActionRetrieveResponse`, the full action name including its
    retrieval intent (e.g. utter_faq/is-this-legit) is returned. The exception
    is a test story that expects exactly the plain retrieval action name
    (e.g. utter_faq): returning the full name there would produce a spurious
    mismatch such as ``utter_faq (expected) != utter_faq/is-this-legit
    (predicted)``. In that case, and for all other action types, only the
    action name is returned. If no full retrieval name is available, the plain
    name is returned as well.
    """
    action_name = predicted_action.name()

    if not isinstance(predicted_action, ActionRetrieveResponse):
        return action_name
    if expected_action_name == action_name:
        return action_name

    return predicted_action.get_full_retrieval_name(partial_tracker) or action_name
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
async def _run_action_prediction(
    processor: "MessageProcessor",
    partial_tracker: DialogueStateTracker,
    expected_action: Text,
) -> Tuple[Text, PolicyPrediction, Optional[EntityEvaluationResult]]:
    """Predict the next action for `partial_tracker`.

    If the prediction differs from `expected_action` but the predicted action is
    the active form (see `_form_might_have_been_rejected`), a loop rejection is
    emulated on `partial_tracker` and the prediction is run once more.

    Args:
        processor: Processor used to make the prediction.
        partial_tracker: Tracker replayed from the test story so far; may get
            an `ActionExecutionRejected` event appended.
        expected_action: Action name (or action text) expected by the test story.

    Returns:
        The predicted action name (see `_get_predicted_action_name`), the policy
        prediction (from the second attempt if a rejection was emulated), and
        the entity evaluation result computed from the first attempt.
    """
    action, prediction = await processor.predict_next_with_tracker_if_should(
        partial_tracker
    )
    predicted_action = _get_predicted_action_name(
        action, partial_tracker, expected_action
    )

    # Computed before any loop-rejection emulation below: that emulation adds a
    # new last event to the tracker, which (presumably) would hide the user
    # utterance this evaluation inspects.
    policy_entity_result = await _get_e2e_entity_evaluation_result(
        processor, partial_tracker, prediction
    )
    if (
        prediction.policy_name
        and predicted_action != expected_action
        and _form_might_have_been_rejected(
            processor.domain, partial_tracker, predicted_action
        )
    ):
        # Wrong action was predicted,
        # but it might be Ok if form action is rejected.
        emulate_loop_rejection(partial_tracker)
        # try again
        action, prediction = await processor.predict_next_with_tracker_if_should(
            partial_tracker
        )
        # Even if the prediction is also wrong, we don't have to undo the emulation
        # of the action rejection as we know that the user explicitly specified
        # that something else than the form was supposed to run.
        predicted_action = _get_predicted_action_name(
            action, partial_tracker, expected_action
        )

    return predicted_action, prediction, policy_entity_result
|
|
701
|
+
|
|
702
|
+
|
|
703
|
+
async def _collect_action_executed_predictions(
    processor: "MessageProcessor",
    partial_tracker: DialogueStateTracker,
    event: ActionExecuted,
    fail_on_prediction_errors: bool,
) -> Tuple[EvaluationStore, PolicyPrediction, Optional[EntityEvaluationResult]]:
    """Predict the action for a test-story `ActionExecuted` event and record it.

    The predicted action is compared with the expected one, and
    `partial_tracker` is extended with an event describing the outcome:

    * `WronglyPredictedAction` if prediction and target mismatch,
    * `WarningPredictedAction` if the expected action was predicted, but only
      after an `action_unlikely_intent` prediction,
    * `ActionExecuted` otherwise.

    An unexpected `action_unlikely_intent` prediction is first recorded as a
    `WronglyPredictedAction` that uses it as both the target and the predicted
    action name, and the prediction is then run a second time. Reaching the
    action limit (`ActionLimitReached`) is recorded as the prediction
    "circuit breaker tripped".

    Args:
        processor: Processor used to make predictions.
        partial_tracker: Tracker replayed from the test story; updated in place.
        event: The expected action from the test story.
        fail_on_prediction_errors: Whether to raise on a wrong prediction.

    Returns:
        The evaluation store for this action, the policy prediction of the last
        prediction attempt, and the entity evaluation result (if any).

    Raises:
        WrongPredictionException: If `fail_on_prediction_errors` is set and the
            wrongly predicted action is not `action_unlikely_intent`.
    """
    action_executed_eval_store = EvaluationStore()

    expected_action_name = event.action_name
    expected_action_text = event.action_text
    # end-to-end stories may specify the bot turn by text instead of by name
    expected_action = expected_action_name or expected_action_text

    policy_entity_result = None
    prev_action_unlikely_intent = False

    try:
        (
            predicted_action,
            prediction,
            policy_entity_result,
        ) = await _run_action_prediction(processor, partial_tracker, expected_action)
    except ActionLimitReached:
        prediction = PolicyPrediction([], policy_name=None)
        predicted_action = "circuit breaker tripped"

    predicted_action_unlikely_intent = predicted_action == ACTION_UNLIKELY_INTENT_NAME
    if predicted_action_unlikely_intent and predicted_action != expected_action:
        # Record the `action_unlikely_intent` prediction as both target and
        # prediction so that it shows up as a warning marker rather than as a
        # failure (`_filter_step_events` drops such events when writing stories).
        partial_tracker.update(
            WronglyPredictedAction(
                predicted_action,
                expected_action_text,
                predicted_action,
                prediction.policy_name,
                prediction.max_confidence,
                event.timestamp,
                metadata=prediction.action_metadata,
            )
        )
        prev_action_unlikely_intent = True

        # Predict again with the warning on the tracker. If the limit is hit,
        # `policy_entity_result` keeps the value from the first attempt.
        try:
            (
                predicted_action,
                prediction,
                policy_entity_result,
            ) = await _run_action_prediction(
                processor, partial_tracker, expected_action
            )
        except ActionLimitReached:
            prediction = PolicyPrediction([], policy_name=None)
            predicted_action = "circuit breaker tripped"

    action_executed_eval_store.add_to_store(
        action_predictions=[predicted_action], action_targets=[expected_action]
    )

    if action_executed_eval_store.check_prediction_target_mismatch():
        partial_tracker.update(
            WronglyPredictedAction(
                expected_action_name,
                expected_action_text,
                predicted_action,
                prediction.policy_name,
                prediction.max_confidence,
                event.timestamp,
                metadata=prediction.action_metadata,
                predicted_action_unlikely_intent=prev_action_unlikely_intent,
            )
        )
        if (
            fail_on_prediction_errors
            and predicted_action != ACTION_UNLIKELY_INTENT_NAME
            and predicted_action != expected_action
        ):
            story_dump = YAMLStoryWriter().dumps(partial_tracker.as_story().story_steps)
            error_msg = (
                f"Model predicted a wrong action. Failed Story: " f"\n\n{story_dump}"
            )
            raise WrongPredictionException(error_msg)
    elif prev_action_unlikely_intent:
        # correct action, but only after an `action_unlikely_intent` prediction
        partial_tracker.update(
            WarningPredictedAction(
                ACTION_UNLIKELY_INTENT_NAME,
                predicted_action,
                prediction.policy_name,
                prediction.max_confidence,
                event.timestamp,
                prediction.action_metadata,
            )
        )
    else:
        partial_tracker.update(
            ActionExecuted(
                predicted_action,
                prediction.policy_name,
                prediction.max_confidence,
                event.timestamp,
                metadata=prediction.action_metadata,
            )
        )

    return action_executed_eval_store, prediction, policy_entity_result
|
|
805
|
+
|
|
806
|
+
|
|
807
|
+
def _form_might_have_been_rejected(
    domain: Domain, tracker: DialogueStateTracker, predicted_action_name: Text
) -> bool:
    """Check whether a wrong prediction could be explained by a rejected form.

    Args:
        domain: Domain that lists the form names.
        tracker: Tracker whose active loop is inspected.
        predicted_action_name: Name of the (wrongly) predicted action.

    Returns:
        `True` if the predicted action is the tracker's active loop and is a
        form defined in `domain`, else `False`.
    """
    if tracker.active_loop_name != predicted_action_name:
        return False
    return predicted_action_name in domain.form_names
|
|
814
|
+
|
|
815
|
+
|
|
816
|
+
async def _predict_tracker_actions(
    tracker: DialogueStateTracker,
    agent: "Agent",
    fail_on_prediction_errors: bool = False,
    use_e2e: bool = False,
) -> Tuple[
    EvaluationStore,
    DialogueStateTracker,
    List[Dict[Text, Any]],
    List[EntityEvaluationResult],
]:
    """Replay a test-story tracker and collect predictions turn by turn.

    A fresh tracker is started from the first event of `tracker`. Every
    following `ActionExecuted` event is predicted and compared with the story.
    With `use_e2e`, `UserUttered` events are evaluated as well. All other events
    are applied to the replayed tracker unchanged.

    Args:
        tracker: Tracker generated from a test story.
        agent: Agent whose processor makes the predictions.
        fail_on_prediction_errors: Whether to raise on the first wrong prediction.
        use_e2e: Whether to evaluate user utterances (end-to-end evaluation).

    Returns:
        The evaluation store of the story, the replayed tracker, one dict per
        evaluated action (``action``, ``predicted``, ``policy``,
        ``confidence``), and the entity evaluation results of policies.

    Raises:
        RasaException: If the agent's processor has not been instantiated.
    """
    processor = agent.processor
    if processor is None:
        raise RasaException(
            "The agent's processor has not been instantiated. "
            "The processor needs to be defined before running "
            "prediction."
        )

    tracker_eval_store = EvaluationStore()

    events = list(tracker.events)

    slots = agent.domain.slots if agent.domain is not None else []

    # the replay starts with only the first event of the test story
    partial_tracker = DialogueStateTracker.from_events(
        tracker.sender_id,
        events[:1],
        slots,
        sender_source=tracker.sender_source,
    )
    tracker_actions = []
    policy_entity_results = []

    for event in events[1:]:
        if isinstance(event, ActionExecuted):
            (
                action_executed_result,
                prediction,
                entity_result,
            ) = await _collect_action_executed_predictions(
                processor, partial_tracker, event, fail_on_prediction_errors
            )
            if entity_result:
                policy_entity_results.append(entity_result)

            if action_executed_result.action_targets:
                tracker_eval_store.merge_store(action_executed_result)
                tracker_actions.append(
                    {
                        "action": action_executed_result.action_targets[0],
                        "predicted": action_executed_result.action_predictions[0],
                        "policy": prediction.policy_name,
                        "confidence": prediction.max_confidence,
                    }
                )
        elif use_e2e and isinstance(event, UserUttered):
            if not event.text:
                # This means that user utterance didn't have a user message, only
                # intent, so we can skip the NLU part and take the parse data
                # directly. Indirectly that means that the test story was in
                # YAML format.
                # FIXME: better type annotation for `parse_data` would require
                # a larger refactoring (e.g. switch to dataclass)
                predicted = cast(Dict[Text, Any], event.parse_data)
            else:
                # Indirectly that means that the test story was either:
                # in YAML format containing a user message, or in Markdown format.
                # Leaving that as it is because Markdown is in legacy mode.
                predicted = await processor.parse_message(UserMessage(event.text))

            user_uttered_result = _collect_user_uttered_predictions(
                event, predicted, partial_tracker, fail_on_prediction_errors
            )
            tracker_eval_store.merge_store(user_uttered_result)
        else:
            partial_tracker.update(event)
    return tracker_eval_store, partial_tracker, tracker_actions, policy_entity_results
|
|
895
|
+
|
|
896
|
+
|
|
897
|
+
def _in_training_data_fraction(action_list: List[Dict[Text, Any]]) -> float:
    """Given a list of actions, returns the fraction predicted by non ML policies.

    An action counts when its ``"policy"`` entry is set and
    `rasa.core.policies.ensemble.is_not_in_training_data` is false for it.

    Args:
        action_list: Action dicts with (at least) a ``"policy"`` entry.

    Returns:
        The fraction as a float; ``0.0`` for an empty list.
    """
    if not action_list:
        # nothing to count; also avoids importing the ensemble module
        return 0.0

    import rasa.core.policies.ensemble

    in_training_data_count = sum(
        1
        for action in action_list
        if action["policy"]
        and not rasa.core.policies.ensemble.is_not_in_training_data(action["policy"])
    )
    return in_training_data_count / len(action_list)
|
|
909
|
+
|
|
910
|
+
|
|
911
|
+
def _sort_trackers_with_severity_of_warning(
    trackers_to_sort: List[DialogueStateTracker],
) -> List[DialogueStateTracker]:
    """Sort the given trackers according to 'severity' of `action_unlikely_intent`.

    Severity is calculated by `IntentTEDPolicy` and is attached as
    metadata to `ActionExecuted` event. A tracker's score is the highest
    severity among its `WronglyPredictedAction` events predicting
    `action_unlikely_intent` (0 if there are none).

    Args:
        trackers_to_sort: Trackers to be sorted

    Returns:
        Sorted trackers in descending order of severity. Trackers with equal
        severity keep their original relative order.
    """

    def _highest_severity(tracker: DialogueStateTracker) -> Any:
        highest = 0
        for event in tracker.applied_events():
            if not isinstance(event, WronglyPredictedAction):
                continue
            if event.action_name_prediction != ACTION_UNLIKELY_INTENT_NAME:
                continue
            highest = max(
                highest,
                event.metadata.get(QUERY_INTENT_KEY, {}).get(SEVERITY_KEY, 0),
            )
        return highest

    # `sorted` stays stable with `reverse=True`, so equally severe trackers
    # keep their input order.
    return sorted(trackers_to_sort, key=_highest_severity, reverse=True)
|
|
947
|
+
|
|
948
|
+
|
|
949
|
+
async def _collect_story_predictions(
    completed_trackers: List["DialogueStateTracker"],
    agent: "Agent",
    fail_on_prediction_errors: bool = False,
    use_e2e: bool = False,
) -> Tuple[StoryEvaluation, int, List[EntityEvaluationResult]]:
    """Test the stories from a file, running them through the stored model.

    Each tracker is replayed with `_predict_tracker_actions`. A story fails if
    any of its predictions mismatches the target. It is also listed as a story
    with warnings if an `action_unlikely_intent` prediction was recorded for it.

    Note: if the agent's domain contains the CALM routing slot, that slot's
    initial value is set to `False` on the domain itself.

    Args:
        completed_trackers: Trackers generated from the test stories.
        agent: Agent used for predictions.
        fail_on_prediction_errors: Whether to raise on the first wrong prediction.
        use_e2e: Whether to evaluate user utterances (end-to-end evaluation).

    Returns:
        The story evaluation, the number of stories, and the entity evaluation
        results of policies.
    """
    from sklearn.metrics import accuracy_score
    from tqdm import tqdm

    story_eval_store = EvaluationStore()
    failed_stories = []
    successful_stories = []
    stories_with_warnings = []
    correct_dialogues = []
    number_of_stories = len(completed_trackers)

    logger.info(f"Evaluating {number_of_stories} stories\nProgress:")

    action_list = []
    entity_results = []

    if agent.domain:
        # set the routing slot to False in case the coexistence feature is used
        # this way the DM1 policies will run and the CALM policies will keep silent
        routing_slots = (
            slot for slot in agent.domain.slots if slot.name == ROUTE_TO_CALM_SLOT
        )
        for routing_slot in routing_slots:
            routing_slot.initial_value = False

    for tracker in tqdm(completed_trackers):
        (
            tracker_results,
            predicted_tracker,
            tracker_actions,
            tracker_entity_results,
        ) = await _predict_tracker_actions(
            tracker, agent, fail_on_prediction_errors, use_e2e
        )

        entity_results.extend(tracker_entity_results)
        story_eval_store.merge_store(tracker_results)
        action_list.extend(tracker_actions)

        if not tracker_results.check_prediction_target_mismatch():
            successful_stories.append(predicted_tracker)
            correct_dialogues.append(1)
        else:
            # there is at least one wrong prediction
            failed_stories.append(predicted_tracker)
            correct_dialogues.append(0)

        has_unlikely_intent_warning = any(
            isinstance(event, WronglyPredictedAction)
            and event.action_name_prediction == ACTION_UNLIKELY_INTENT_NAME
            for event in predicted_tracker.events
        )
        if has_unlikely_intent_warning:
            stories_with_warnings.append(predicted_tracker)

    logger.info("Finished collecting predictions.")

    in_training_data_fraction = _in_training_data_fraction(action_list)

    accuracy = (
        accuracy_score([1] * len(correct_dialogues), correct_dialogues)
        if correct_dialogues
        else 0
    )

    _log_evaluation_table([1] * len(completed_trackers), "CONVERSATION", accuracy)

    story_evaluation = StoryEvaluation(
        evaluation_store=story_eval_store,
        failed_stories=failed_stories,
        successful_stories=successful_stories,
        stories_with_warnings=_sort_trackers_with_severity_of_warning(
            stories_with_warnings
        ),
        action_list=action_list,
        in_training_data_fraction=in_training_data_fraction,
    )
    return story_evaluation, number_of_stories, entity_results
|
|
1034
|
+
|
|
1035
|
+
|
|
1036
|
+
def _filter_step_events(step: StoryStep) -> StoryStep:
    """Return a copy of `step` without `action_unlikely_intent` marker events.

    `WronglyPredictedAction` events whose action name and predicted action name
    are both `action_unlikely_intent` are dropped. All other events are kept in
    their original order. The copy keeps the id of `step`.
    """
    kept_events = [
        event
        for event in step.events
        if not (
            isinstance(event, WronglyPredictedAction)
            and event.action_name
            == event.action_name_prediction
            == ACTION_UNLIKELY_INTENT_NAME
        )
    ]
    filtered_step = step.create_copy(use_new_id=False)
    filtered_step.events = kept_events
    return filtered_step
|
|
1050
|
+
|
|
1051
|
+
|
|
1052
|
+
def _log_stories(
    trackers: List[DialogueStateTracker], file_path: Text, message_if_no_trackers: Text
) -> None:
    """Write given stories to the given file.

    The stories are written as YAML, with `action_unlikely_intent` marker
    events filtered out. If there are no trackers, only a comment line
    containing `message_if_no_trackers` is written.

    Args:
        trackers: Trackers to dump as stories.
        file_path: Destination file; it is overwritten.
        message_if_no_trackers: Text written when `trackers` is empty.
    """
    with open(file_path, "w", encoding=DEFAULT_ENCODING) as out_file:
        if trackers:
            tracker_stories = [
                tracker.as_story(include_source=True) for tracker in trackers
            ]
            filtered_steps = [
                _filter_step_events(story_step)
                for tracker_story in tracker_stories
                for story_step in tracker_story.story_steps
            ]
            out_file.write(YAMLStoryWriter().dumps(filtered_steps))
        else:
            out_file.write(f"# {message_if_no_trackers}")
|
|
1067
|
+
|
|
1068
|
+
|
|
1069
|
+
async def test(
    stories: Text,
    agent: "Agent",
    max_stories: Optional[int] = None,
    out_directory: Optional[Text] = None,
    fail_on_prediction_errors: bool = False,
    e2e: bool = False,
    disable_plotting: bool = False,
    successes: bool = False,
    errors: bool = True,
    warnings: bool = True,
) -> Dict[Text, Any]:
    """Run the evaluation of the stories, optionally plot the results.

    Args:
        stories: the stories to evaluate on
        agent: the agent
        max_stories: maximum number of stories to consider
        out_directory: path to the directory to write results to. The stories
            report, the confusion matrix plot and the story files are only
            written when this is set.
        fail_on_prediction_errors: boolean indicating whether to fail on prediction
            errors or not
        e2e: boolean indicating whether to use end to end evaluation or not
        disable_plotting: boolean indicating whether to disable plotting or not
        successes: boolean indicating whether to write down successful predictions or
            not
        errors: boolean indicating whether to write down incorrect predictions or not
        warnings: boolean indicating whether to write down prediction warnings or not

    Returns:
        Evaluation summary.
    """
    from rasa.model_testing import get_evaluation_metrics

    generator = _create_data_generator(stories, agent, max_stories, e2e)
    completed_trackers = generator.generate_story_trackers()

    story_evaluation, _, entity_results = await _collect_story_predictions(
        completed_trackers, agent, fail_on_prediction_errors, use_e2e=e2e
    )

    evaluation_store = story_evaluation.evaluation_store

    # The `warnings` parameter shadows the name of the stdlib module, which is
    # why the module is referred to as `pywarnings` here.
    with pywarnings.catch_warnings():
        from sklearn.exceptions import UndefinedMetricWarning

        pywarnings.simplefilter("ignore", UndefinedMetricWarning)

        targets, predictions = evaluation_store.serialise()

        report, precision, f1, action_accuracy = get_evaluation_metrics(
            targets, predictions, output_dict=True
        )
        if out_directory:
            # Add conversation level accuracy to story report.
            num_failed = len(story_evaluation.failed_stories)
            num_correct = len(story_evaluation.successful_stories)
            num_warnings = len(story_evaluation.stories_with_warnings)
            num_convs = num_failed + num_correct
            if num_convs and isinstance(report, dict):
                conv_accuracy = num_correct / num_convs
                report["conversation_accuracy"] = {
                    "accuracy": conv_accuracy,
                    "correct": num_correct,
                    "with_warnings": num_warnings,
                    "total": num_convs,
                }
            report_filename = os.path.join(out_directory, REPORT_STORIES_FILE)
            rasa.shared.utils.io.dump_obj_as_json_to_file(report_filename, report)
            logger.info(f"Stories report saved to {report_filename}.")

    evaluate_entities(
        entity_results,
        POLICIES_THAT_EXTRACT_ENTITIES,
        out_directory,
        successes,
        errors,
        disable_plotting,
    )

    telemetry.track_core_model_test(len(generator.story_graph.story_steps), e2e, agent)

    _log_evaluation_table(
        evaluation_store.action_targets,
        "ACTION",
        action_accuracy,
        precision=precision,
        f1=f1,
        in_training_data_fraction=story_evaluation.in_training_data_fraction,
    )

    if not disable_plotting and out_directory:
        _plot_story_evaluation(
            evaluation_store.action_targets,
            evaluation_store.action_predictions,
            out_directory,
        )

    if errors and out_directory:
        _log_stories(
            story_evaluation.failed_stories,
            os.path.join(out_directory, FAILED_STORIES_FILE),
            "None of the test stories failed - all good!",
        )
    if successes and out_directory:
        _log_stories(
            story_evaluation.successful_stories,
            os.path.join(out_directory, SUCCESSFUL_STORIES_FILE),
            "None of the test stories succeeded :(",
        )
    if warnings and out_directory:
        _log_stories(
            story_evaluation.stories_with_warnings,
            os.path.join(out_directory, STORIES_WITH_WARNINGS_FILE),
            "No warnings for test stories",
        )

    return {
        "report": report,
        "precision": precision,
        "f1": f1,
        "accuracy": action_accuracy,
        "actions": story_evaluation.action_list,
        "in_training_data_fraction": story_evaluation.in_training_data_fraction,
        "is_end_to_end_evaluation": e2e,
    }
|
|
1194
|
+
|
|
1195
|
+
|
|
1196
|
+
def _log_evaluation_table(
    golds: List[Any],
    name: Text,
    accuracy: float,
    report: Optional[Dict[Text, Any]] = None,
    precision: Optional[float] = None,
    f1: Optional[float] = None,
    in_training_data_fraction: Optional[float] = None,
    include_report: bool = True,
) -> None:  # pragma: no cover
    """Log the sklearn evaluation metrics.

    Args:
        golds: Gold labels; only their number is used.
        name: Evaluation level shown in the heading (e.g. "ACTION").
        accuracy: Fraction of correct predictions.
        report: Optional classification report.
        precision: Optional precision score.
        f1: Optional F1 score.
        in_training_data_fraction: Optional in-data fraction.
        include_report: Whether to log `report` when it is given.
    """
    total = len(golds)
    # `total * accuracy` reconstructs the count of correct predictions. Round
    # instead of truncating, so that float error (e.g. 100 * 0.29 ==
    # 28.999999999999996) does not under-report the count by one.
    num_correct = int(round(total * accuracy))

    logger.info(f"Evaluation Results on {name} level:")
    logger.info(f"\tCorrect: {num_correct} / {total}")
    if f1 is not None:
        logger.info(f"\tF1-Score: {f1:.3f}")
    if precision is not None:
        logger.info(f"\tPrecision: {precision:.3f}")
    logger.info(f"\tAccuracy: {accuracy:.3f}")
    if in_training_data_fraction is not None:
        logger.info(f"\tIn-data fraction: {in_training_data_fraction:.3g}")

    if include_report and report is not None:
        logger.info(f"\tClassification report: \n{report}")
|
|
1219
|
+
|
|
1220
|
+
|
|
1221
|
+
def _plot_story_evaluation(
    targets: PredictionList,
    predictions: PredictionList,
    output_directory: Optional[Text],
) -> None:
    """Plot a confusion matrix of story evaluation.

    Args:
        targets: Expected action names.
        predictions: Predicted action names, aligned with `targets`.
        output_directory: Directory for the plot file; if not set, the bare
            file name is used as the output path.
    """
    from sklearn.metrics import confusion_matrix
    from sklearn.utils.multiclass import unique_labels
    from rasa.utils.plotting import plot_confusion_matrix

    if output_directory:
        plot_path = os.path.join(output_directory, CONFUSION_MATRIX_STORIES_FILE)
    else:
        plot_path = CONFUSION_MATRIX_STORIES_FILE

    plot_confusion_matrix(
        confusion_matrix(targets, predictions),
        classes=unique_labels(targets, predictions),
        title="Action Confusion matrix",
        output_file=plot_path,
    )
|
|
1245
|
+
|
|
1246
|
+
|
|
1247
|
+
async def compare_models_in_dir(
    model_dir: Text,
    stories_file: Text,
    output: Text,
    use_conversation_test_files: bool = False,
) -> None:
    """Evaluates multiple trained models in a directory on a test set.

    Each subdirectory of ``model_dir`` is treated as one run. Within a run,
    every file ending in ``tar.gz`` is evaluated in sorted filename order.

    The JSON written to ``output/RESULTS_FILE`` maps each config name to a
    list with one entry per run. Each entry is the list of correct-story
    counts for that config's models in that run.

    Args:
        model_dir: path to directory that contains the models to evaluate
        stories_file: path to the story file
        output: output directory to store results to
        use_conversation_test_files: `True` if conversation test files should be used
            for testing instead of regular Core story files.
    """
    number_correct = defaultdict(list)

    for run_dir in rasa.shared.utils.io.list_subdirectories(model_dir):
        run_results = defaultdict(list)

        model_paths = [
            path
            for path in sorted(rasa.shared.utils.io.list_files(run_dir))
            if path.endswith("tar.gz")
        ]

        for model_path in model_paths:
            # Archives are named <config-name>PERCENTAGE_KEY<number>.tar.gz;
            # the text before PERCENTAGE_KEY identifies the config.
            config_name, *_ = os.path.basename(model_path).split(PERCENTAGE_KEY)
            correct_stories = await _evaluate_core_model(
                model_path,
                stories_file,
                use_conversation_test_files=use_conversation_test_files,
            )
            run_results[config_name].append(correct_stories)

        for name, counts in run_results.items():
            number_correct[name].append(counts)

    rasa.shared.utils.io.dump_obj_as_json_to_file(
        os.path.join(output, RESULTS_FILE), number_correct
    )
|
|
1287
|
+
|
|
1288
|
+
|
|
1289
|
+
async def compare_models(
    models: List[Text],
    stories_file: Text,
    output: Text,
    use_conversation_test_files: bool = False,
) -> None:
    """Evaluates multiple trained models on a test set.

    The JSON written to ``output/RESULTS_FILE`` maps each model's file
    basename to a list of its correct-story counts.

    Args:
        models: Paths to model files.
        stories_file: path to the story file
        output: output directory to store results to
        use_conversation_test_files: `True` if conversation test files should be used
            for testing instead of regular Core story files.
    """
    number_correct = defaultdict(list)

    for model_path in models:
        correct_stories = await _evaluate_core_model(
            model_path,
            stories_file,
            use_conversation_test_files=use_conversation_test_files,
        )
        model_name = os.path.basename(model_path)
        number_correct[model_name].append(correct_stories)

    results_path = os.path.join(output, RESULTS_FILE)
    rasa.shared.utils.io.dump_obj_as_json_to_file(results_path, number_correct)
|
|
1315
|
+
|
|
1316
|
+
|
|
1317
|
+
async def _evaluate_core_model(
    model: Text, stories_file: Text, use_conversation_test_files: bool = False
) -> int:
    """Count how many test stories a single trained model predicts correctly.

    Args:
        model: Path of the model passed to ``Agent.load``.
        stories_file: Path to the test stories.
        use_conversation_test_files: Forwarded to ``_create_data_generator``.

    Returns:
        The number of evaluated stories minus the number of failed stories.
    """
    from rasa.core.agent import Agent

    logger.info(f"Evaluating model '{model}'")

    agent = Agent.load(model)
    trackers = _create_data_generator(
        stories_file, agent, use_conversation_test_files=use_conversation_test_files
    ).generate_story_trackers()

    # Only the count of correct stories matters here, so the entity
    # results (third element) are discarded.
    evaluation_store, total_stories, _entities = await _collect_story_predictions(
        trackers, agent
    )
    return total_stories - len(evaluation_store.failed_stories)
|