rasa-pro 3.12.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic.
- README.md +41 -0
- rasa/__init__.py +9 -0
- rasa/__main__.py +177 -0
- rasa/anonymization/__init__.py +2 -0
- rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
- rasa/anonymization/anonymization_pipeline.py +286 -0
- rasa/anonymization/anonymization_rule_executor.py +260 -0
- rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
- rasa/anonymization/schemas/config.yml +47 -0
- rasa/anonymization/utils.py +118 -0
- rasa/api.py +160 -0
- rasa/cli/__init__.py +5 -0
- rasa/cli/arguments/__init__.py +0 -0
- rasa/cli/arguments/data.py +106 -0
- rasa/cli/arguments/default_arguments.py +207 -0
- rasa/cli/arguments/evaluate.py +65 -0
- rasa/cli/arguments/export.py +51 -0
- rasa/cli/arguments/interactive.py +74 -0
- rasa/cli/arguments/run.py +219 -0
- rasa/cli/arguments/shell.py +17 -0
- rasa/cli/arguments/test.py +211 -0
- rasa/cli/arguments/train.py +279 -0
- rasa/cli/arguments/visualize.py +34 -0
- rasa/cli/arguments/x.py +30 -0
- rasa/cli/data.py +354 -0
- rasa/cli/dialogue_understanding_test.py +251 -0
- rasa/cli/e2e_test.py +259 -0
- rasa/cli/evaluate.py +222 -0
- rasa/cli/export.py +250 -0
- rasa/cli/inspect.py +75 -0
- rasa/cli/interactive.py +166 -0
- rasa/cli/license.py +65 -0
- rasa/cli/llm_fine_tuning.py +403 -0
- rasa/cli/markers.py +78 -0
- rasa/cli/project_templates/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/action_template.py +27 -0
- rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
- rasa/cli/project_templates/calm/actions/db.py +57 -0
- rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
- rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
- rasa/cli/project_templates/calm/config.yml +10 -0
- rasa/cli/project_templates/calm/credentials.yml +33 -0
- rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
- rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
- rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
- rasa/cli/project_templates/calm/db/contacts.json +10 -0
- rasa/cli/project_templates/calm/domain/add_contact.yml +39 -0
- rasa/cli/project_templates/calm/domain/list_contacts.yml +17 -0
- rasa/cli/project_templates/calm/domain/remove_contact.yml +38 -0
- rasa/cli/project_templates/calm/domain/shared.yml +10 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
- rasa/cli/project_templates/calm/endpoints.yml +58 -0
- rasa/cli/project_templates/default/actions/__init__.py +0 -0
- rasa/cli/project_templates/default/actions/actions.py +27 -0
- rasa/cli/project_templates/default/config.yml +44 -0
- rasa/cli/project_templates/default/credentials.yml +33 -0
- rasa/cli/project_templates/default/data/nlu.yml +91 -0
- rasa/cli/project_templates/default/data/rules.yml +13 -0
- rasa/cli/project_templates/default/data/stories.yml +30 -0
- rasa/cli/project_templates/default/domain.yml +34 -0
- rasa/cli/project_templates/default/endpoints.yml +42 -0
- rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
- rasa/cli/project_templates/tutorial/actions/__init__.py +0 -0
- rasa/cli/project_templates/tutorial/actions/actions.py +22 -0
- rasa/cli/project_templates/tutorial/config.yml +12 -0
- rasa/cli/project_templates/tutorial/credentials.yml +33 -0
- rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
- rasa/cli/project_templates/tutorial/data/patterns.yml +11 -0
- rasa/cli/project_templates/tutorial/domain.yml +35 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +55 -0
- rasa/cli/run.py +143 -0
- rasa/cli/scaffold.py +273 -0
- rasa/cli/shell.py +141 -0
- rasa/cli/studio/__init__.py +0 -0
- rasa/cli/studio/download.py +62 -0
- rasa/cli/studio/studio.py +296 -0
- rasa/cli/studio/train.py +59 -0
- rasa/cli/studio/upload.py +62 -0
- rasa/cli/telemetry.py +102 -0
- rasa/cli/test.py +280 -0
- rasa/cli/train.py +278 -0
- rasa/cli/utils.py +484 -0
- rasa/cli/visualize.py +40 -0
- rasa/cli/x.py +206 -0
- rasa/constants.py +45 -0
- rasa/core/__init__.py +17 -0
- rasa/core/actions/__init__.py +0 -0
- rasa/core/actions/action.py +1318 -0
- rasa/core/actions/action_clean_stack.py +59 -0
- rasa/core/actions/action_exceptions.py +24 -0
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/actions/action_repeat_bot_messages.py +89 -0
- rasa/core/actions/action_run_slot_rejections.py +210 -0
- rasa/core/actions/action_trigger_chitchat.py +31 -0
- rasa/core/actions/action_trigger_flow.py +109 -0
- rasa/core/actions/action_trigger_search.py +31 -0
- rasa/core/actions/constants.py +5 -0
- rasa/core/actions/custom_action_executor.py +191 -0
- rasa/core/actions/direct_custom_actions_executor.py +109 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +72 -0
- rasa/core/actions/forms.py +741 -0
- rasa/core/actions/grpc_custom_action_executor.py +251 -0
- rasa/core/actions/http_custom_action_executor.py +145 -0
- rasa/core/actions/loops.py +114 -0
- rasa/core/actions/two_stage_fallback.py +186 -0
- rasa/core/agent.py +559 -0
- rasa/core/auth_retry_tracker_store.py +122 -0
- rasa/core/brokers/__init__.py +0 -0
- rasa/core/brokers/broker.py +126 -0
- rasa/core/brokers/file.py +58 -0
- rasa/core/brokers/kafka.py +324 -0
- rasa/core/brokers/pika.py +388 -0
- rasa/core/brokers/sql.py +86 -0
- rasa/core/channels/__init__.py +61 -0
- rasa/core/channels/botframework.py +338 -0
- rasa/core/channels/callback.py +84 -0
- rasa/core/channels/channel.py +456 -0
- rasa/core/channels/console.py +241 -0
- rasa/core/channels/development_inspector.py +197 -0
- rasa/core/channels/facebook.py +419 -0
- rasa/core/channels/hangouts.py +329 -0
- rasa/core/channels/inspector/.eslintrc.cjs +25 -0
- rasa/core/channels/inspector/.gitignore +23 -0
- rasa/core/channels/inspector/README.md +54 -0
- rasa/core/channels/inspector/assets/favicon.ico +0 -0
- rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
- rasa/core/channels/inspector/custom.d.ts +3 -0
- rasa/core/channels/inspector/dist/assets/arc-861ddd57.js +1 -0
- rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
- rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-921f02db.js +10 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-b436c4f8.js +2 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-511a23cb.js +2 -0
- rasa/core/channels/inspector/dist/assets/createText-62fc7601-ef476ecd.js +7 -0
- rasa/core/channels/inspector/dist/assets/edges-f2ad444c-f1878e0a.js +4 -0
- rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-fac75185.js +51 -0
- rasa/core/channels/inspector/dist/assets/flowDb-1972c806-201c5bbc.js +6 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-f904ae41.js +4 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-b080d6f2.js +1 -0
- rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-1813da66.js +139 -0
- rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-872af172.js +266 -0
- rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-34a0af5a.js +70 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-42ba3e3d.js +1 -0
- rasa/core/channels/inspector/dist/assets/index-37817b51.js +1317 -0
- rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
- rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-6b731386.js +7 -0
- rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
- rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-e8579ac6.js +139 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
- rasa/core/channels/inspector/dist/assets/layout-89e6403a.js +1 -0
- rasa/core/channels/inspector/dist/assets/line-dc73d3fc.js +1 -0
- rasa/core/channels/inspector/dist/assets/linear-f5b1d2bc.js +1 -0
- rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-82cb74fa.js +109 -0
- rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
- rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
- rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-bdf5f29b.js +35 -0
- rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-c7a0cbe4.js +7 -0
- rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-7ec5410f.js +52 -0
- rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-caee5554.js +8 -0
- rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-2935f8db.js +122 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-8f5d9693.js +1 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-d565d1de.js +1 -0
- rasa/core/channels/inspector/dist/assets/styles-080da4f6-75ad421d.js +110 -0
- rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-7e764226.js +159 -0
- rasa/core/channels/inspector/dist/assets/styles-9c745c82-7a4e0e61.js +207 -0
- rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-4019d1bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-01ea12df.js +61 -0
- rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-89407137.js +7 -0
- rasa/core/channels/inspector/dist/index.html +42 -0
- rasa/core/channels/inspector/index.html +40 -0
- rasa/core/channels/inspector/jest.config.ts +13 -0
- rasa/core/channels/inspector/package.json +52 -0
- rasa/core/channels/inspector/setupTests.ts +2 -0
- rasa/core/channels/inspector/src/App.tsx +220 -0
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +108 -0
- rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +136 -0
- rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
- rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +22 -0
- rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
- rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
- rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
- rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
- rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
- rasa/core/channels/inspector/src/helpers/audiostream.ts +191 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +392 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +306 -0
- rasa/core/channels/inspector/src/helpers/utils.ts +127 -0
- rasa/core/channels/inspector/src/main.tsx +13 -0
- rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
- rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
- rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
- rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
- rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
- rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
- rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
- rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
- rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
- rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
- rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
- rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
- rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
- rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
- rasa/core/channels/inspector/src/theme/index.ts +101 -0
- rasa/core/channels/inspector/src/types.ts +84 -0
- rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
- rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
- rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
- rasa/core/channels/inspector/tsconfig.json +26 -0
- rasa/core/channels/inspector/tsconfig.node.json +10 -0
- rasa/core/channels/inspector/vite.config.ts +8 -0
- rasa/core/channels/inspector/yarn.lock +6249 -0
- rasa/core/channels/mattermost.py +229 -0
- rasa/core/channels/rasa_chat.py +126 -0
- rasa/core/channels/rest.py +230 -0
- rasa/core/channels/rocketchat.py +174 -0
- rasa/core/channels/slack.py +620 -0
- rasa/core/channels/socketio.py +302 -0
- rasa/core/channels/telegram.py +298 -0
- rasa/core/channels/twilio.py +169 -0
- rasa/core/channels/vier_cvg.py +374 -0
- rasa/core/channels/voice_ready/__init__.py +0 -0
- rasa/core/channels/voice_ready/audiocodes.py +501 -0
- rasa/core/channels/voice_ready/jambonz.py +121 -0
- rasa/core/channels/voice_ready/jambonz_protocol.py +396 -0
- rasa/core/channels/voice_ready/twilio_voice.py +403 -0
- rasa/core/channels/voice_ready/utils.py +37 -0
- rasa/core/channels/voice_stream/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
- rasa/core/channels/voice_stream/asr/azure.py +130 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
- rasa/core/channels/voice_stream/audio_bytes.py +8 -0
- rasa/core/channels/voice_stream/browser_audio.py +107 -0
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +106 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +427 -0
- rasa/core/channels/webexteams.py +134 -0
- rasa/core/concurrent_lock_store.py +210 -0
- rasa/core/constants.py +112 -0
- rasa/core/evaluation/__init__.py +0 -0
- rasa/core/evaluation/marker.py +267 -0
- rasa/core/evaluation/marker_base.py +923 -0
- rasa/core/evaluation/marker_stats.py +293 -0
- rasa/core/evaluation/marker_tracker_loader.py +103 -0
- rasa/core/exceptions.py +29 -0
- rasa/core/exporter.py +284 -0
- rasa/core/featurizers/__init__.py +0 -0
- rasa/core/featurizers/precomputation.py +410 -0
- rasa/core/featurizers/single_state_featurizer.py +421 -0
- rasa/core/featurizers/tracker_featurizers.py +1262 -0
- rasa/core/http_interpreter.py +89 -0
- rasa/core/information_retrieval/__init__.py +7 -0
- rasa/core/information_retrieval/faiss.py +124 -0
- rasa/core/information_retrieval/information_retrieval.py +137 -0
- rasa/core/information_retrieval/milvus.py +59 -0
- rasa/core/information_retrieval/qdrant.py +96 -0
- rasa/core/jobs.py +63 -0
- rasa/core/lock.py +139 -0
- rasa/core/lock_store.py +343 -0
- rasa/core/migrate.py +403 -0
- rasa/core/nlg/__init__.py +3 -0
- rasa/core/nlg/callback.py +146 -0
- rasa/core/nlg/contextual_response_rephraser.py +320 -0
- rasa/core/nlg/generator.py +230 -0
- rasa/core/nlg/interpolator.py +143 -0
- rasa/core/nlg/response.py +155 -0
- rasa/core/nlg/summarize.py +70 -0
- rasa/core/persistor.py +538 -0
- rasa/core/policies/__init__.py +0 -0
- rasa/core/policies/ensemble.py +329 -0
- rasa/core/policies/enterprise_search_policy.py +905 -0
- rasa/core/policies/enterprise_search_prompt_template.jinja2 +25 -0
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
- rasa/core/policies/flow_policy.py +205 -0
- rasa/core/policies/flows/__init__.py +0 -0
- rasa/core/policies/flows/flow_exceptions.py +44 -0
- rasa/core/policies/flows/flow_executor.py +754 -0
- rasa/core/policies/flows/flow_step_result.py +43 -0
- rasa/core/policies/intentless_policy.py +1031 -0
- rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
- rasa/core/policies/memoization.py +538 -0
- rasa/core/policies/policy.py +725 -0
- rasa/core/policies/rule_policy.py +1273 -0
- rasa/core/policies/ted_policy.py +2169 -0
- rasa/core/policies/unexpected_intent_policy.py +1022 -0
- rasa/core/processor.py +1465 -0
- rasa/core/run.py +342 -0
- rasa/core/secrets_manager/__init__.py +0 -0
- rasa/core/secrets_manager/constants.py +36 -0
- rasa/core/secrets_manager/endpoints.py +391 -0
- rasa/core/secrets_manager/factory.py +241 -0
- rasa/core/secrets_manager/secret_manager.py +262 -0
- rasa/core/secrets_manager/vault.py +584 -0
- rasa/core/test.py +1335 -0
- rasa/core/tracker_store.py +1703 -0
- rasa/core/train.py +105 -0
- rasa/core/training/__init__.py +89 -0
- rasa/core/training/converters/__init__.py +0 -0
- rasa/core/training/converters/responses_prefix_converter.py +119 -0
- rasa/core/training/interactive.py +1744 -0
- rasa/core/training/story_conflict.py +381 -0
- rasa/core/training/training.py +93 -0
- rasa/core/utils.py +366 -0
- rasa/core/visualize.py +70 -0
- rasa/dialogue_understanding/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/constants.py +4 -0
- rasa/dialogue_understanding/coexistence/intent_based_router.py +196 -0
- rasa/dialogue_understanding/coexistence/llm_based_router.py +327 -0
- rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
- rasa/dialogue_understanding/commands/__init__.py +61 -0
- rasa/dialogue_understanding/commands/can_not_handle_command.py +70 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +125 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +44 -0
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/clarify_command.py +86 -0
- rasa/dialogue_understanding/commands/command.py +85 -0
- rasa/dialogue_understanding/commands/correct_slots_command.py +297 -0
- rasa/dialogue_understanding/commands/error_command.py +79 -0
- rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/handle_code_change_command.py +73 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +66 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/noop_command.py +54 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
- rasa/dialogue_understanding/commands/restart_command.py +58 -0
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +160 -0
- rasa/dialogue_understanding/commands/skip_question_command.py +75 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +107 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +45 -0
- rasa/dialogue_understanding/generator/__init__.py +21 -0
- rasa/dialogue_understanding/generator/command_generator.py +464 -0
- rasa/dialogue_understanding/generator/constants.py +27 -0
- rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +466 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +500 -0
- rasa/dialogue_understanding/generator/llm_command_generator.py +67 -0
- rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
- rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +920 -0
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +261 -0
- rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +60 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +486 -0
- rasa/dialogue_understanding/patterns/__init__.py +0 -0
- rasa/dialogue_understanding/patterns/cancel.py +111 -0
- rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
- rasa/dialogue_understanding/patterns/chitchat.py +37 -0
- rasa/dialogue_understanding/patterns/clarify.py +97 -0
- rasa/dialogue_understanding/patterns/code_change.py +41 -0
- rasa/dialogue_understanding/patterns/collect_information.py +90 -0
- rasa/dialogue_understanding/patterns/completed.py +40 -0
- rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
- rasa/dialogue_understanding/patterns/correction.py +278 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +301 -0
- rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
- rasa/dialogue_understanding/patterns/internal_error.py +47 -0
- rasa/dialogue_understanding/patterns/repeat.py +37 -0
- rasa/dialogue_understanding/patterns/restart.py +37 -0
- rasa/dialogue_understanding/patterns/search.py +37 -0
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/patterns/skip_question.py +38 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/__init__.py +0 -0
- rasa/dialogue_understanding/processor/command_processor.py +720 -0
- rasa/dialogue_understanding/processor/command_processor_component.py +43 -0
- rasa/dialogue_understanding/stack/__init__.py +0 -0
- rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
- rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
- rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
- rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
- rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
- rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
- rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
- rasa/dialogue_understanding/stack/utils.py +211 -0
- rasa/dialogue_understanding/utils.py +14 -0
- rasa/dialogue_understanding_test/__init__.py +0 -0
- rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
- rasa/dialogue_understanding_test/constants.py +17 -0
- rasa/dialogue_understanding_test/du_test_case.py +118 -0
- rasa/dialogue_understanding_test/du_test_result.py +11 -0
- rasa/dialogue_understanding_test/du_test_runner.py +93 -0
- rasa/dialogue_understanding_test/io.py +54 -0
- rasa/dialogue_understanding_test/validation.py +22 -0
- rasa/e2e_test/__init__.py +0 -0
- rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
- rasa/e2e_test/assertions.py +1345 -0
- rasa/e2e_test/assertions_schema.yml +129 -0
- rasa/e2e_test/constants.py +31 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +569 -0
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +54 -0
- rasa/e2e_test/e2e_test_runner.py +1192 -0
- rasa/e2e_test/e2e_test_schema.yml +181 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +598 -0
- rasa/e2e_test/utils/validation.py +178 -0
- rasa/engine/__init__.py +0 -0
- rasa/engine/caching.py +463 -0
- rasa/engine/constants.py +17 -0
- rasa/engine/exceptions.py +14 -0
- rasa/engine/graph.py +642 -0
- rasa/engine/loader.py +48 -0
- rasa/engine/recipes/__init__.py +0 -0
- rasa/engine/recipes/config_files/default_config.yml +41 -0
- rasa/engine/recipes/default_components.py +97 -0
- rasa/engine/recipes/default_recipe.py +1272 -0
- rasa/engine/recipes/graph_recipe.py +79 -0
- rasa/engine/recipes/recipe.py +93 -0
- rasa/engine/runner/__init__.py +0 -0
- rasa/engine/runner/dask.py +250 -0
- rasa/engine/runner/interface.py +49 -0
- rasa/engine/storage/__init__.py +0 -0
- rasa/engine/storage/local_model_storage.py +244 -0
- rasa/engine/storage/resource.py +110 -0
- rasa/engine/storage/storage.py +199 -0
- rasa/engine/training/__init__.py +0 -0
- rasa/engine/training/components.py +176 -0
- rasa/engine/training/fingerprinting.py +64 -0
- rasa/engine/training/graph_trainer.py +256 -0
- rasa/engine/training/hooks.py +164 -0
- rasa/engine/validation.py +1451 -0
- rasa/env.py +14 -0
- rasa/exceptions.py +69 -0
- rasa/graph_components/__init__.py +0 -0
- rasa/graph_components/converters/__init__.py +0 -0
- rasa/graph_components/converters/nlu_message_converter.py +48 -0
- rasa/graph_components/providers/__init__.py +0 -0
- rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
- rasa/graph_components/providers/domain_provider.py +71 -0
- rasa/graph_components/providers/flows_provider.py +74 -0
- rasa/graph_components/providers/forms_provider.py +44 -0
- rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
- rasa/graph_components/providers/responses_provider.py +44 -0
- rasa/graph_components/providers/rule_only_provider.py +49 -0
- rasa/graph_components/providers/story_graph_provider.py +96 -0
- rasa/graph_components/providers/training_tracker_provider.py +55 -0
- rasa/graph_components/validators/__init__.py +0 -0
- rasa/graph_components/validators/default_recipe_validator.py +550 -0
- rasa/graph_components/validators/finetuning_validator.py +302 -0
- rasa/hooks.py +111 -0
- rasa/jupyter.py +63 -0
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/markers/__init__.py +0 -0
- rasa/markers/marker.py +269 -0
- rasa/markers/marker_base.py +828 -0
- rasa/markers/upload.py +74 -0
- rasa/markers/validate.py +21 -0
- rasa/model.py +118 -0
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +40 -0
- rasa/model_manager/model_api.py +559 -0
- rasa/model_manager/runner_service.py +286 -0
- rasa/model_manager/socket_bridge.py +146 -0
- rasa/model_manager/studio_jwt_auth.py +86 -0
- rasa/model_manager/trainer_service.py +325 -0
- rasa/model_manager/utils.py +87 -0
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +112 -0
- rasa/model_testing.py +457 -0
- rasa/model_training.py +596 -0
- rasa/nlu/__init__.py +7 -0
- rasa/nlu/classifiers/__init__.py +3 -0
- rasa/nlu/classifiers/classifier.py +5 -0
- rasa/nlu/classifiers/diet_classifier.py +1881 -0
- rasa/nlu/classifiers/fallback_classifier.py +192 -0
- rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
- rasa/nlu/classifiers/logistic_regression_classifier.py +253 -0
- rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
- rasa/nlu/classifiers/regex_message_handler.py +56 -0
- rasa/nlu/classifiers/sklearn_intent_classifier.py +330 -0
- rasa/nlu/constants.py +77 -0
- rasa/nlu/convert.py +40 -0
- rasa/nlu/emulators/__init__.py +0 -0
- rasa/nlu/emulators/dialogflow.py +55 -0
- rasa/nlu/emulators/emulator.py +49 -0
- rasa/nlu/emulators/luis.py +86 -0
- rasa/nlu/emulators/no_emulator.py +10 -0
- rasa/nlu/emulators/wit.py +56 -0
- rasa/nlu/extractors/__init__.py +0 -0
- rasa/nlu/extractors/crf_entity_extractor.py +715 -0
- rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
- rasa/nlu/extractors/entity_synonyms.py +178 -0
- rasa/nlu/extractors/extractor.py +470 -0
- rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
- rasa/nlu/extractors/regex_entity_extractor.py +220 -0
- rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
- rasa/nlu/featurizers/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +445 -0
- rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +768 -0
- rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
- rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
- rasa/nlu/featurizers/featurizer.py +89 -0
- rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +867 -0
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +571 -0
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +271 -0
- rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
- rasa/nlu/model.py +24 -0
- rasa/nlu/run.py +27 -0
- rasa/nlu/selectors/__init__.py +0 -0
- rasa/nlu/selectors/response_selector.py +987 -0
- rasa/nlu/test.py +1940 -0
- rasa/nlu/tokenizers/__init__.py +0 -0
- rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
- rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
- rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
- rasa/nlu/tokenizers/tokenizer.py +239 -0
- rasa/nlu/tokenizers/whitespace_tokenizer.py +95 -0
- rasa/nlu/utils/__init__.py +35 -0
- rasa/nlu/utils/bilou_utils.py +462 -0
- rasa/nlu/utils/hugging_face/__init__.py +0 -0
- rasa/nlu/utils/hugging_face/registry.py +108 -0
- rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
- rasa/nlu/utils/mitie_utils.py +113 -0
- rasa/nlu/utils/pattern_utils.py +168 -0
- rasa/nlu/utils/spacy_utils.py +310 -0
- rasa/plugin.py +90 -0
- rasa/server.py +1588 -0
- rasa/shared/__init__.py +0 -0
- rasa/shared/constants.py +311 -0
- rasa/shared/core/__init__.py +0 -0
- rasa/shared/core/command_payload_reader.py +109 -0
- rasa/shared/core/constants.py +180 -0
- rasa/shared/core/conversation.py +46 -0
- rasa/shared/core/domain.py +2172 -0
- rasa/shared/core/events.py +2559 -0
- rasa/shared/core/flows/__init__.py +7 -0
- rasa/shared/core/flows/flow.py +562 -0
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flow_step.py +146 -0
- rasa/shared/core/flows/flow_step_links.py +319 -0
- rasa/shared/core/flows/flow_step_sequence.py +70 -0
- rasa/shared/core/flows/flows_list.py +258 -0
- rasa/shared/core/flows/flows_yaml_schema.json +303 -0
- rasa/shared/core/flows/nlu_trigger.py +117 -0
- rasa/shared/core/flows/steps/__init__.py +24 -0
- rasa/shared/core/flows/steps/action.py +56 -0
- rasa/shared/core/flows/steps/call.py +64 -0
- rasa/shared/core/flows/steps/collect.py +112 -0
- rasa/shared/core/flows/steps/constants.py +5 -0
- rasa/shared/core/flows/steps/continuation.py +36 -0
- rasa/shared/core/flows/steps/end.py +22 -0
- rasa/shared/core/flows/steps/internal.py +44 -0
- rasa/shared/core/flows/steps/link.py +51 -0
- rasa/shared/core/flows/steps/no_operation.py +48 -0
- rasa/shared/core/flows/steps/set_slots.py +50 -0
- rasa/shared/core/flows/steps/start.py +30 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +735 -0
- rasa/shared/core/flows/yaml_flows_io.py +405 -0
- rasa/shared/core/generator.py +908 -0
- rasa/shared/core/slot_mappings.py +526 -0
- rasa/shared/core/slots.py +654 -0
- rasa/shared/core/trackers.py +1183 -0
- rasa/shared/core/training_data/__init__.py +0 -0
- rasa/shared/core/training_data/loading.py +89 -0
- rasa/shared/core/training_data/story_reader/__init__.py +0 -0
- rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
- rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
- rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
- rasa/shared/core/training_data/story_writer/__init__.py +0 -0
- rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
- rasa/shared/core/training_data/story_writer/yaml_story_writer.py +444 -0
- rasa/shared/core/training_data/structures.py +858 -0
- rasa/shared/core/training_data/visualization.html +146 -0
- rasa/shared/core/training_data/visualization.py +603 -0
- rasa/shared/data.py +249 -0
- rasa/shared/engine/__init__.py +0 -0
- rasa/shared/engine/caching.py +26 -0
- rasa/shared/exceptions.py +167 -0
- rasa/shared/importers/__init__.py +0 -0
- rasa/shared/importers/importer.py +770 -0
- rasa/shared/importers/multi_project.py +215 -0
- rasa/shared/importers/rasa.py +108 -0
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +36 -0
- rasa/shared/nlu/__init__.py +0 -0
- rasa/shared/nlu/constants.py +53 -0
- rasa/shared/nlu/interpreter.py +10 -0
- rasa/shared/nlu/training_data/__init__.py +0 -0
- rasa/shared/nlu/training_data/entities_parser.py +208 -0
- rasa/shared/nlu/training_data/features.py +492 -0
- rasa/shared/nlu/training_data/formats/__init__.py +10 -0
- rasa/shared/nlu/training_data/formats/dialogflow.py +163 -0
- rasa/shared/nlu/training_data/formats/luis.py +87 -0
- rasa/shared/nlu/training_data/formats/rasa.py +135 -0
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +618 -0
- rasa/shared/nlu/training_data/formats/readerwriter.py +244 -0
- rasa/shared/nlu/training_data/formats/wit.py +52 -0
- rasa/shared/nlu/training_data/loading.py +137 -0
- rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
- rasa/shared/nlu/training_data/message.py +490 -0
- rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
- rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
- rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
- rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
- rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
- rasa/shared/nlu/training_data/training_data.py +729 -0
- rasa/shared/nlu/training_data/util.py +223 -0
- rasa/shared/providers/__init__.py +0 -0
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +677 -0
- rasa/shared/providers/_configs/client_config.py +59 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +132 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +236 -0
- rasa/shared/providers/_configs/litellm_router_client_config.py +222 -0
- rasa/shared/providers/_configs/model_group_config.py +173 -0
- rasa/shared/providers/_configs/openai_client_config.py +177 -0
- rasa/shared/providers/_configs/rasa_llm_client_config.py +75 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +178 -0
- rasa/shared/providers/_configs/utils.py +117 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/constants.py +7 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +243 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +335 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +126 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +138 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +265 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +415 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +110 -0
- rasa/shared/providers/llm/litellm_router_llm_client.py +202 -0
- rasa/shared/providers/llm/llm_client.py +78 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +161 -0
- rasa/shared/providers/llm/rasa_llm_client.py +120 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +276 -0
- rasa/shared/providers/mappings.py +94 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +185 -0
- rasa/shared/providers/router/router_client.py +75 -0
- rasa/shared/utils/__init__.py +0 -0
- rasa/shared/utils/cli.py +102 -0
- rasa/shared/utils/common.py +324 -0
- rasa/shared/utils/constants.py +4 -0
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +258 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +499 -0
- rasa/shared/utils/llm.py +764 -0
- rasa/shared/utils/pykwalify_extensions.py +27 -0
- rasa/shared/utils/schemas/__init__.py +0 -0
- rasa/shared/utils/schemas/config.yml +2 -0
- rasa/shared/utils/schemas/domain.yml +145 -0
- rasa/shared/utils/schemas/events.py +214 -0
- rasa/shared/utils/schemas/model_config.yml +36 -0
- rasa/shared/utils/schemas/stories.yml +173 -0
- rasa/shared/utils/yaml.py +1068 -0
- rasa/studio/__init__.py +0 -0
- rasa/studio/auth.py +270 -0
- rasa/studio/config.py +136 -0
- rasa/studio/constants.py +19 -0
- rasa/studio/data_handler.py +368 -0
- rasa/studio/download.py +489 -0
- rasa/studio/results_logger.py +137 -0
- rasa/studio/train.py +134 -0
- rasa/studio/upload.py +563 -0
- rasa/telemetry.py +1876 -0
- rasa/tracing/__init__.py +0 -0
- rasa/tracing/config.py +355 -0
- rasa/tracing/constants.py +62 -0
- rasa/tracing/instrumentation/__init__.py +0 -0
- rasa/tracing/instrumentation/attribute_extractors.py +765 -0
- rasa/tracing/instrumentation/instrumentation.py +1306 -0
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +144 -0
- rasa/tracing/instrumentation/metrics.py +294 -0
- rasa/tracing/metric_instrument_provider.py +205 -0
- rasa/utils/__init__.py +0 -0
- rasa/utils/beta.py +83 -0
- rasa/utils/cli.py +28 -0
- rasa/utils/common.py +639 -0
- rasa/utils/converter.py +53 -0
- rasa/utils/endpoints.py +331 -0
- rasa/utils/io.py +252 -0
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +542 -0
- rasa/utils/log_utils.py +181 -0
- rasa/utils/mapper.py +210 -0
- rasa/utils/ml_utils.py +147 -0
- rasa/utils/plotting.py +362 -0
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/utils/singleton.py +23 -0
- rasa/utils/tensorflow/__init__.py +0 -0
- rasa/utils/tensorflow/callback.py +112 -0
- rasa/utils/tensorflow/constants.py +116 -0
- rasa/utils/tensorflow/crf.py +492 -0
- rasa/utils/tensorflow/data_generator.py +440 -0
- rasa/utils/tensorflow/environment.py +161 -0
- rasa/utils/tensorflow/exceptions.py +5 -0
- rasa/utils/tensorflow/feature_array.py +366 -0
- rasa/utils/tensorflow/layers.py +1565 -0
- rasa/utils/tensorflow/layers_utils.py +113 -0
- rasa/utils/tensorflow/metrics.py +281 -0
- rasa/utils/tensorflow/model_data.py +798 -0
- rasa/utils/tensorflow/model_data_utils.py +499 -0
- rasa/utils/tensorflow/models.py +935 -0
- rasa/utils/tensorflow/rasa_layers.py +1094 -0
- rasa/utils/tensorflow/transformer.py +640 -0
- rasa/utils/tensorflow/types.py +6 -0
- rasa/utils/train_utils.py +572 -0
- rasa/utils/url_tools.py +53 -0
- rasa/utils/yaml.py +54 -0
- rasa/validator.py +1644 -0
- rasa/version.py +3 -0
- rasa_pro-3.12.0.dev1.dist-info/METADATA +199 -0
- rasa_pro-3.12.0.dev1.dist-info/NOTICE +5 -0
- rasa_pro-3.12.0.dev1.dist-info/RECORD +790 -0
- rasa_pro-3.12.0.dev1.dist-info/WHEEL +4 -0
- rasa_pro-3.12.0.dev1.dist-info/entry_points.txt +3 -0
rasa/nlu/test.py
ADDED
@@ -0,0 +1,1940 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import itertools
|
|
3
|
+
import os
|
|
4
|
+
import logging
|
|
5
|
+
import structlog
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from collections import defaultdict, namedtuple
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
from typing import (
|
|
12
|
+
Iterable,
|
|
13
|
+
Iterator,
|
|
14
|
+
Tuple,
|
|
15
|
+
List,
|
|
16
|
+
Set,
|
|
17
|
+
Optional,
|
|
18
|
+
Text,
|
|
19
|
+
Union,
|
|
20
|
+
Dict,
|
|
21
|
+
Any,
|
|
22
|
+
NamedTuple,
|
|
23
|
+
TYPE_CHECKING,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
from rasa import telemetry
|
|
27
|
+
from rasa.core.agent import Agent
|
|
28
|
+
from rasa.core.channels import UserMessage
|
|
29
|
+
from rasa.core.processor import MessageProcessor
|
|
30
|
+
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
31
|
+
from rasa.shared.utils.yaml import write_yaml
|
|
32
|
+
from rasa.utils.common import TempDirectoryPath, get_temp_dir_name
|
|
33
|
+
import rasa.shared.utils.io
|
|
34
|
+
import rasa.utils.plotting as plot_utils
|
|
35
|
+
import rasa.utils.io as io_utils
|
|
36
|
+
|
|
37
|
+
from rasa.constants import TEST_DATA_FILE, TRAIN_DATA_FILE, NLG_DATA_FILE
|
|
38
|
+
import rasa.nlu.classifiers.fallback_classifier
|
|
39
|
+
from rasa.nlu.constants import (
|
|
40
|
+
RESPONSE_SELECTOR_DEFAULT_INTENT,
|
|
41
|
+
RESPONSE_SELECTOR_PROPERTY_NAME,
|
|
42
|
+
RESPONSE_SELECTOR_PREDICTION_KEY,
|
|
43
|
+
TOKENS_NAMES,
|
|
44
|
+
ENTITY_ATTRIBUTE_CONFIDENCE_TYPE,
|
|
45
|
+
ENTITY_ATTRIBUTE_CONFIDENCE_ROLE,
|
|
46
|
+
ENTITY_ATTRIBUTE_CONFIDENCE_GROUP,
|
|
47
|
+
RESPONSE_SELECTOR_RETRIEVAL_INTENTS,
|
|
48
|
+
)
|
|
49
|
+
from rasa.shared.nlu.constants import (
|
|
50
|
+
TEXT,
|
|
51
|
+
INTENT,
|
|
52
|
+
INTENT_RESPONSE_KEY,
|
|
53
|
+
ENTITIES,
|
|
54
|
+
EXTRACTOR,
|
|
55
|
+
PRETRAINED_EXTRACTORS,
|
|
56
|
+
ENTITY_ATTRIBUTE_TYPE,
|
|
57
|
+
ENTITY_ATTRIBUTE_GROUP,
|
|
58
|
+
ENTITY_ATTRIBUTE_ROLE,
|
|
59
|
+
NO_ENTITY_TAG,
|
|
60
|
+
INTENT_NAME_KEY,
|
|
61
|
+
PREDICTED_CONFIDENCE_KEY,
|
|
62
|
+
)
|
|
63
|
+
from rasa.nlu.classifiers import fallback_classifier
|
|
64
|
+
from rasa.nlu.tokenizers.tokenizer import Token
|
|
65
|
+
from rasa.shared.importers.importer import TrainingDataImporter
|
|
66
|
+
from rasa.shared.nlu.training_data.formats.rasa_yaml import RasaYAMLWriter
|
|
67
|
+
|
|
68
|
+
if TYPE_CHECKING:
|
|
69
|
+
from typing_extensions import TypedDict
|
|
70
|
+
|
|
71
|
+
EntityPrediction = TypedDict(
|
|
72
|
+
"EntityPrediction",
|
|
73
|
+
{
|
|
74
|
+
"text": Text,
|
|
75
|
+
"entities": List[Dict[Text, Any]],
|
|
76
|
+
"predicted_entities": List[Dict[Text, Any]],
|
|
77
|
+
},
|
|
78
|
+
)
|
|
79
|
+
logger = logging.getLogger(__name__)
|
|
80
|
+
structlogger = structlog.get_logger()
|
|
81
|
+
|
|
82
|
+
# Exclude 'EntitySynonymMapper' and 'ResponseSelector' as their super class
|
|
83
|
+
# performs entity extraction but those two classifiers don't
|
|
84
|
+
ENTITY_PROCESSORS = {"EntitySynonymMapper", "ResponseSelector"}
|
|
85
|
+
|
|
86
|
+
EXTRACTORS_WITH_CONFIDENCES = {"CRFEntityExtractor", "DIETClassifier"}
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class CVEvaluationResult(NamedTuple):
|
|
90
|
+
"""Stores NLU cross-validation results."""
|
|
91
|
+
|
|
92
|
+
train: Dict
|
|
93
|
+
test: Dict
|
|
94
|
+
evaluation: Dict
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
NO_ENTITY = "no_entity"
|
|
98
|
+
|
|
99
|
+
IntentEvaluationResult = namedtuple(
|
|
100
|
+
"IntentEvaluationResult", "intent_target intent_prediction message confidence"
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
ResponseSelectionEvaluationResult = namedtuple(
|
|
104
|
+
"ResponseSelectionEvaluationResult",
|
|
105
|
+
"intent_response_key_target intent_response_key_prediction message confidence",
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
EntityEvaluationResult = namedtuple(
|
|
109
|
+
"EntityEvaluationResult", "entity_targets entity_predictions tokens message"
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
IntentMetrics = Dict[Text, List[float]]
|
|
113
|
+
EntityMetrics = Dict[Text, Dict[Text, List[float]]]
|
|
114
|
+
ResponseSelectionMetrics = Dict[Text, List[float]]
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def log_evaluation_table(
|
|
118
|
+
report: Text, precision: float, f1: float, accuracy: float
|
|
119
|
+
) -> None: # pragma: no cover
|
|
120
|
+
"""Log the sklearn evaluation metrics."""
|
|
121
|
+
logger.info(f"F1-Score: {f1}")
|
|
122
|
+
logger.info(f"Precision: {precision}")
|
|
123
|
+
logger.info(f"Accuracy: {accuracy}")
|
|
124
|
+
logger.info(f"Classification report: \n{report}")
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def remove_empty_intent_examples(
|
|
128
|
+
intent_results: List[IntentEvaluationResult],
|
|
129
|
+
) -> List[IntentEvaluationResult]:
|
|
130
|
+
"""Remove those examples without an intent.
|
|
131
|
+
|
|
132
|
+
Args:
|
|
133
|
+
intent_results: intent evaluation results
|
|
134
|
+
|
|
135
|
+
Returns: intent evaluation results
|
|
136
|
+
"""
|
|
137
|
+
filtered = []
|
|
138
|
+
for r in intent_results:
|
|
139
|
+
# substitute None values with empty string
|
|
140
|
+
# to enable sklearn evaluation
|
|
141
|
+
if r.intent_prediction is None:
|
|
142
|
+
r = r._replace(intent_prediction="")
|
|
143
|
+
|
|
144
|
+
if r.intent_target != "" and r.intent_target is not None:
|
|
145
|
+
filtered.append(r)
|
|
146
|
+
|
|
147
|
+
return filtered
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def remove_empty_response_examples(
|
|
151
|
+
response_results: List[ResponseSelectionEvaluationResult],
|
|
152
|
+
) -> List[ResponseSelectionEvaluationResult]:
|
|
153
|
+
"""Remove those examples without a response.
|
|
154
|
+
|
|
155
|
+
Args:
|
|
156
|
+
response_results: response selection evaluation results
|
|
157
|
+
|
|
158
|
+
Returns:
|
|
159
|
+
Response selection evaluation results
|
|
160
|
+
"""
|
|
161
|
+
filtered = []
|
|
162
|
+
for r in response_results:
|
|
163
|
+
# substitute None values with empty string
|
|
164
|
+
# to enable sklearn evaluation
|
|
165
|
+
if r.intent_response_key_prediction is None:
|
|
166
|
+
r = r._replace(intent_response_key_prediction="")
|
|
167
|
+
|
|
168
|
+
if r.confidence is None:
|
|
169
|
+
# This might happen if response selector training data is present but
|
|
170
|
+
# no response selector is part of the model
|
|
171
|
+
r = r._replace(confidence=0.0)
|
|
172
|
+
|
|
173
|
+
if r.intent_response_key_target:
|
|
174
|
+
filtered.append(r)
|
|
175
|
+
|
|
176
|
+
return filtered
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def drop_intents_below_freq(
|
|
180
|
+
training_data: TrainingData, cutoff: int = 5
|
|
181
|
+
) -> TrainingData:
|
|
182
|
+
"""Remove intent groups with less than cutoff instances.
|
|
183
|
+
|
|
184
|
+
Args:
|
|
185
|
+
training_data: training data
|
|
186
|
+
cutoff: threshold
|
|
187
|
+
|
|
188
|
+
Returns: updated training data
|
|
189
|
+
"""
|
|
190
|
+
logger.debug(
|
|
191
|
+
"Raw data intent examples: {}".format(len(training_data.intent_examples))
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
examples_per_intent = training_data.number_of_examples_per_intent
|
|
195
|
+
return training_data.filter_training_examples(
|
|
196
|
+
lambda ex: examples_per_intent.get(ex.get(INTENT), 0) >= cutoff
|
|
197
|
+
)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def write_intent_successes(
|
|
201
|
+
intent_results: List[IntentEvaluationResult], successes_filename: Text
|
|
202
|
+
) -> None:
|
|
203
|
+
"""Write successful intent predictions to a file.
|
|
204
|
+
|
|
205
|
+
Args:
|
|
206
|
+
intent_results: intent evaluation result
|
|
207
|
+
successes_filename: filename of file to save successful predictions to
|
|
208
|
+
"""
|
|
209
|
+
successes = [
|
|
210
|
+
{
|
|
211
|
+
"text": r.message,
|
|
212
|
+
"intent": r.intent_target,
|
|
213
|
+
"intent_prediction": {
|
|
214
|
+
INTENT_NAME_KEY: r.intent_prediction,
|
|
215
|
+
"confidence": r.confidence,
|
|
216
|
+
},
|
|
217
|
+
}
|
|
218
|
+
for r in intent_results
|
|
219
|
+
if r.intent_target == r.intent_prediction
|
|
220
|
+
]
|
|
221
|
+
|
|
222
|
+
if successes:
|
|
223
|
+
rasa.shared.utils.io.dump_obj_as_json_to_file(successes_filename, successes)
|
|
224
|
+
logger.info(f"Successful intent predictions saved to {successes_filename}.")
|
|
225
|
+
logger.debug(f"\n\nSuccessfully predicted the following intents: \n{successes}")
|
|
226
|
+
else:
|
|
227
|
+
logger.info("No successful intent predictions found.")
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def _write_errors(errors: List[Dict], errors_filename: Text, error_type: Text) -> None:
|
|
231
|
+
"""Write incorrect intent predictions to a file.
|
|
232
|
+
|
|
233
|
+
Args:
|
|
234
|
+
errors: Serializable prediction errors.
|
|
235
|
+
errors_filename: filename of file to save incorrect predictions to
|
|
236
|
+
error_type: NLU entity which was evaluated (e.g. `intent` or `entity`).
|
|
237
|
+
"""
|
|
238
|
+
if errors:
|
|
239
|
+
rasa.shared.utils.io.dump_obj_as_json_to_file(errors_filename, errors)
|
|
240
|
+
logger.info(f"Incorrect {error_type} predictions saved to {errors_filename}.")
|
|
241
|
+
logger.debug(
|
|
242
|
+
f"\n\nThese {error_type} examples could not be classified "
|
|
243
|
+
f"correctly: \n{errors}"
|
|
244
|
+
)
|
|
245
|
+
else:
|
|
246
|
+
logger.info(f"Every {error_type} was predicted correctly by the model.")
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def _get_intent_errors(intent_results: List[IntentEvaluationResult]) -> List[Dict]:
|
|
250
|
+
return [
|
|
251
|
+
{
|
|
252
|
+
"text": r.message,
|
|
253
|
+
"intent": r.intent_target,
|
|
254
|
+
"intent_prediction": {
|
|
255
|
+
INTENT_NAME_KEY: r.intent_prediction,
|
|
256
|
+
"confidence": r.confidence,
|
|
257
|
+
},
|
|
258
|
+
}
|
|
259
|
+
for r in intent_results
|
|
260
|
+
if r.intent_target != r.intent_prediction
|
|
261
|
+
]
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def write_response_successes(
|
|
265
|
+
response_results: List[ResponseSelectionEvaluationResult], successes_filename: Text
|
|
266
|
+
) -> None:
|
|
267
|
+
"""Write successful response selection predictions to a file.
|
|
268
|
+
|
|
269
|
+
Args:
|
|
270
|
+
response_results: response selection evaluation result
|
|
271
|
+
successes_filename: filename of file to save successful predictions to
|
|
272
|
+
"""
|
|
273
|
+
successes = [
|
|
274
|
+
{
|
|
275
|
+
"text": r.message,
|
|
276
|
+
"intent_response_key_target": r.intent_response_key_target,
|
|
277
|
+
"intent_response_key_prediction": {
|
|
278
|
+
"name": r.intent_response_key_prediction,
|
|
279
|
+
"confidence": r.confidence,
|
|
280
|
+
},
|
|
281
|
+
}
|
|
282
|
+
for r in response_results
|
|
283
|
+
if r.intent_response_key_prediction == r.intent_response_key_target
|
|
284
|
+
]
|
|
285
|
+
|
|
286
|
+
if successes:
|
|
287
|
+
rasa.shared.utils.io.dump_obj_as_json_to_file(successes_filename, successes)
|
|
288
|
+
logger.info(f"Successful response predictions saved to {successes_filename}.")
|
|
289
|
+
structlogger.debug("test.write.response", successes=copy.deepcopy(successes))
|
|
290
|
+
else:
|
|
291
|
+
logger.info("No successful response predictions found.")
|
|
292
|
+
|
|
293
|
+
def _response_errors(
    response_results: List[ResponseSelectionEvaluationResult],
) -> List[Dict]:
    """Collects incorrect response selection predictions.

    Args:
        response_results: response selection evaluation result

    Returns:
        Serializable prediction errors.
    """
    return [
        {
            "text": r.message,
            "intent_response_key_target": r.intent_response_key_target,
            "intent_response_key_prediction": {
                "name": r.intent_response_key_prediction,
                "confidence": r.confidence,
            },
        }
        for r in response_results
        if r.intent_response_key_prediction != r.intent_response_key_target
    ]


def plot_attribute_confidences(
    results: Union[
        List[IntentEvaluationResult], List[ResponseSelectionEvaluationResult]
    ],
    hist_filename: Optional[Text],
    target_key: Text,
    prediction_key: Text,
    title: Text,
) -> None:
    """Create histogram of confidence distribution.

    Args:
        results: evaluation results
        hist_filename: filename to save plot to
        target_key: key of target in results
        prediction_key: key of predictions in results
        title: title of plot
    """
    pos_hist = [
        r.confidence
        for r in results
        if getattr(r, target_key) == getattr(r, prediction_key)
    ]

    neg_hist = [
        r.confidence
        for r in results
        if getattr(r, target_key) != getattr(r, prediction_key)
    ]

    plot_utils.plot_paired_histogram([pos_hist, neg_hist], title, hist_filename)


def plot_entity_confidences(
    merged_targets: List[Text],
    merged_predictions: List[Text],
    merged_confidences: List[float],
    hist_filename: Text,
    title: Text,
) -> None:
    """Creates histogram of confidence distribution.

    Args:
        merged_targets: Entity labels.
        merged_predictions: Predicted entities.
        merged_confidences: Confidence scores of predictions.
        hist_filename: filename to save plot to
        title: title of plot
    """
    pos_hist = [
        confidence
        for target, prediction, confidence in zip(
            merged_targets, merged_predictions, merged_confidences
        )
        if target != NO_ENTITY and target == prediction
    ]

    neg_hist = [
        confidence
        for target, prediction, confidence in zip(
            merged_targets, merged_predictions, merged_confidences
        )
        if prediction not in (NO_ENTITY, target)
    ]

    plot_utils.plot_paired_histogram([pos_hist, neg_hist], title, hist_filename)

def evaluate_response_selections(
    response_selection_results: List[ResponseSelectionEvaluationResult],
    output_directory: Optional[Text],
    successes: bool,
    errors: bool,
    disable_plotting: bool,
    report_as_dict: Optional[bool] = None,
) -> Dict:  # pragma: no cover
    """Creates summary statistics for response selection.

    Only considers those examples with a set response.
    Others are filtered out. Returns a dictionary containing the
    evaluation result.

    Args:
        response_selection_results: response selection evaluation results
        output_directory: directory to store files to
        successes: if True successful predictions are written to disk
        errors: if True incorrect predictions are written to disk
        disable_plotting: if True no plots are created
        report_as_dict: `True` if the evaluation report should be returned as `dict`.
            If `False` the report is returned in a human-readable text format. If `None`
            `report_as_dict` is considered as `True` in case an `output_directory` is
            given.

    Returns: dictionary with evaluation results
    """
    # remove empty response targets
    num_examples = len(response_selection_results)
    response_selection_results = remove_empty_response_examples(
        response_selection_results
    )

    logger.info(
        f"Response Selection Evaluation: Only considering those "
        f"{len(response_selection_results)} examples that have a defined response out "
        f"of {num_examples} examples."
    )

    (
        target_intent_response_keys,
        predicted_intent_response_keys,
    ) = _targets_predictions_from(
        response_selection_results,
        "intent_response_key_target",
        "intent_response_key_prediction",
    )

    report, precision, f1, accuracy, confusion_matrix, labels = _calculate_report(
        output_directory,
        target_intent_response_keys,
        predicted_intent_response_keys,
        report_as_dict,
    )
    if output_directory:
        _dump_report(output_directory, "response_selection_report.json", report)

    if successes:
        successes_filename = "response_selection_successes.json"
        if output_directory:
            successes_filename = os.path.join(output_directory, successes_filename)
        # save classified samples to file for debugging
        write_response_successes(response_selection_results, successes_filename)

    response_errors = _response_errors(response_selection_results)

    if errors and output_directory:
        errors_filename = "response_selection_errors.json"
        errors_filename = os.path.join(output_directory, errors_filename)
        _write_errors(response_errors, errors_filename, error_type="response")

    if not disable_plotting:
        confusion_matrix_filename = "response_selection_confusion_matrix.png"
        if output_directory:
            confusion_matrix_filename = os.path.join(
                output_directory, confusion_matrix_filename
            )

        plot_utils.plot_confusion_matrix(
            confusion_matrix,
            classes=labels,
            title="Response Selection Confusion Matrix",
            output_file=confusion_matrix_filename,
        )

        histogram_filename = "response_selection_histogram.png"
        if output_directory:
            histogram_filename = os.path.join(output_directory, histogram_filename)
        plot_attribute_confidences(
            response_selection_results,
            histogram_filename,
            "intent_response_key_target",
            "intent_response_key_prediction",
            title="Response Selection Prediction Confidence Distribution",
        )

    predictions = [
        {
            "text": res.message,
            "intent_response_key_target": res.intent_response_key_target,
            "intent_response_key_prediction": res.intent_response_key_prediction,
            "confidence": res.confidence,
        }
        for res in response_selection_results
    ]

    return {
        "predictions": predictions,
        "report": report,
        "precision": precision,
        "f1_score": f1,
        "accuracy": accuracy,
        "errors": response_errors,
    }

def _add_confused_labels_to_report(
    report: Dict[Text, Dict[Text, Any]],
    confusion_matrix: np.ndarray,
    labels: List[Text],
    exclude_labels: Optional[List[Text]] = None,
) -> Dict[Text, Dict[Text, Union[Dict, Any]]]:
    """Adds a field "confused_with" to the evaluation report.

    The value is a dict of {"false_positive_label": false_positive_count} pairs.
    If there are no false positives in the confusion matrix,
    the dict will be empty. Typically, we include the two most
    commonly false positive labels, three in the rare case that
    the diagonal element in the confusion matrix is not one of the
    three highest values in the row.

    Args:
        report: the evaluation report
        confusion_matrix: confusion matrix
        labels: list of labels
        exclude_labels: labels to exclude from the report

    Returns: updated evaluation report
    """
    if exclude_labels is None:
        exclude_labels = []

    # sort confusion matrix by false positives
    indices = np.argsort(confusion_matrix, axis=1)
    n_candidates = min(3, len(labels))

    for label in labels:
        if label in exclude_labels:
            continue
        # it is possible to predict intent 'None'
        if report.get(label):
            report[label]["confused_with"] = {}

    for i, label in enumerate(labels):
        if label in exclude_labels:
            continue
        for j in range(n_candidates):
            label_idx = indices[i, -(1 + j)]
            false_pos_label = labels[label_idx]
            false_positives = int(confusion_matrix[i, label_idx])
            if (
                false_pos_label != label
                and false_pos_label not in exclude_labels
                and false_positives > 0
            ):
                report[label]["confused_with"][false_pos_label] = false_positives

    return report

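# --- Illustrative sketch (editor's note, not part of the packaged module) ---
# How the "confused_with" field above is derived from a confusion matrix:
# row i counts examples whose true label is labels[i]; off-diagonal cells are
# false positives of the column label, and only the largest per-row counts are kept.
import numpy as np

labels = ["greet", "goodbye", "affirm"]
cm = np.array(
    [
        [10, 2, 1],  # two "greet" examples predicted as "goodbye", one as "affirm"
        [0, 8, 0],
        [3, 0, 5],
    ]
)
confused_with = {}
for i, label in enumerate(labels):
    ranked = np.argsort(cm[i])[::-1][:3]  # indices of the three largest counts in row i
    confused_with[label] = {
        labels[j]: int(cm[i, j]) for j in ranked if j != i and cm[i, j] > 0
    }
# confused_with == {"greet": {"goodbye": 2, "affirm": 1},
#                   "goodbye": {},
#                   "affirm": {"greet": 3}}
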
def evaluate_intents(
    intent_results: List[IntentEvaluationResult],
    output_directory: Optional[Text],
    successes: bool,
    errors: bool,
    disable_plotting: bool,
    report_as_dict: Optional[bool] = None,
) -> Dict:  # pragma: no cover
    """Creates summary statistics for intents.

    Only considers those examples with a set intent. Others are filtered out.
    Returns a dictionary containing the evaluation result.

    Args:
        intent_results: intent evaluation results
        output_directory: directory to store files to
        successes: if True correct predictions are written to disk
        errors: if True incorrect predictions are written to disk
        disable_plotting: if True no plots are created
        report_as_dict: `True` if the evaluation report should be returned as `dict`.
            If `False` the report is returned in a human-readable text format. If `None`
            `report_as_dict` is considered as `True` in case an `output_directory` is
            given.

    Returns: dictionary with evaluation results
    """
    # remove empty intent targets
    num_examples = len(intent_results)
    intent_results = remove_empty_intent_examples(intent_results)

    logger.info(
        f"Intent Evaluation: Only considering those {len(intent_results)} examples "
        f"that have a defined intent out of {num_examples} examples."
    )

    target_intents, predicted_intents = _targets_predictions_from(
        intent_results, "intent_target", "intent_prediction"
    )

    report, precision, f1, accuracy, confusion_matrix, labels = _calculate_report(
        output_directory, target_intents, predicted_intents, report_as_dict
    )
    if output_directory:
        _dump_report(output_directory, "intent_report.json", report)

    if successes and output_directory:
        successes_filename = os.path.join(output_directory, "intent_successes.json")
        # save classified samples to file for debugging
        write_intent_successes(intent_results, successes_filename)

    intent_errors = _get_intent_errors(intent_results)
    if errors and output_directory:
        errors_filename = os.path.join(output_directory, "intent_errors.json")
        _write_errors(intent_errors, errors_filename, "intent")

    if not disable_plotting:
        confusion_matrix_filename = "intent_confusion_matrix.png"
        if output_directory:
            confusion_matrix_filename = os.path.join(
                output_directory, confusion_matrix_filename
            )
        plot_utils.plot_confusion_matrix(
            confusion_matrix,
            classes=labels,
            title="Intent Confusion matrix",
            output_file=confusion_matrix_filename,
        )

        histogram_filename = "intent_histogram.png"
        if output_directory:
            histogram_filename = os.path.join(output_directory, histogram_filename)
        plot_attribute_confidences(
            intent_results,
            histogram_filename,
            "intent_target",
            "intent_prediction",
            title="Intent Prediction Confidence Distribution",
        )

    predictions = [
        {
            "text": res.message,
            "intent": res.intent_target,
            "predicted": res.intent_prediction,
            "confidence": res.confidence,
        }
        for res in intent_results
    ]

    return {
        "predictions": predictions,
        "report": report,
        "precision": precision,
        "f1_score": f1,
        "accuracy": accuracy,
        "errors": intent_errors,
    }

def _calculate_report(
    output_directory: Optional[Text],
    targets: Iterable[Any],
    predictions: Iterable[Any],
    report_as_dict: Optional[bool] = None,
    exclude_label: Optional[Text] = None,
) -> Tuple[Union[Text, Dict], float, float, float, np.ndarray, List[Text]]:
    from rasa.model_testing import get_evaluation_metrics
    import sklearn.metrics
    import sklearn.utils.multiclass

    confusion_matrix = sklearn.metrics.confusion_matrix(targets, predictions)
    labels = sklearn.utils.multiclass.unique_labels(targets, predictions)

    if report_as_dict is None:
        report_as_dict = bool(output_directory)

    report, precision, f1, accuracy = get_evaluation_metrics(
        targets, predictions, output_dict=report_as_dict, exclude_label=exclude_label
    )

    if report_as_dict:
        report = _add_confused_labels_to_report(  # type: ignore[assignment]
            report,
            confusion_matrix,
            labels,
            exclude_labels=[exclude_label] if exclude_label else [],
        )
    elif not output_directory:
        log_evaluation_table(report, precision, f1, accuracy)

    return report, precision, f1, accuracy, confusion_matrix, labels

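# --- Illustrative sketch (editor's note, not part of the packaged module) ---
# The sklearn calls backing the report / confusion-matrix computation above,
# shown on a tiny hand-made label set.
import sklearn.metrics
import sklearn.utils.multiclass

targets = ["greet", "goodbye", "greet", "affirm"]
predictions = ["greet", "greet", "greet", "affirm"]

confusion_matrix = sklearn.metrics.confusion_matrix(targets, predictions)
labels = sklearn.utils.multiclass.unique_labels(targets, predictions)
report = sklearn.metrics.classification_report(
    targets, predictions, output_dict=True, zero_division=0
)
# report["greet"]["precision"] is 2/3 because one "goodbye" example was
# mis-predicted as "greet"; labels == array(["affirm", "goodbye", "greet"]).
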
def _dump_report(output_directory: Text, filename: Text, report: Dict) -> None:
    report_filename = os.path.join(output_directory, filename)
    rasa.shared.utils.io.dump_obj_as_json_to_file(report_filename, report)
    logger.info(f"Classification report saved to {report_filename}.")


def merge_labels(
    aligned_predictions: List[Dict], extractor: Optional[Text] = None
) -> List[Text]:
    """Concatenates all labels of the aligned predictions.

    Takes the aligned prediction labels which are grouped for each message
    and concatenates them.

    Args:
        aligned_predictions: aligned predictions
        extractor: entity extractor name

    Returns:
        Concatenated predictions
    """
    if extractor:
        label_lists = [ap["extractor_labels"][extractor] for ap in aligned_predictions]
    else:
        label_lists = [ap["target_labels"] for ap in aligned_predictions]

    return list(itertools.chain(*label_lists))


def merge_confidences(
    aligned_predictions: List[Dict], extractor: Optional[Text] = None
) -> List[float]:
    """Concatenates all confidences of the aligned predictions.

    Takes the aligned prediction confidences which are grouped for each message
    and concatenates them.

    Args:
        aligned_predictions: aligned predictions
        extractor: entity extractor name

    Returns:
        Concatenated confidences
    """
    label_lists = [ap["confidences"][extractor] for ap in aligned_predictions]
    return list(itertools.chain(*label_lists))


def substitute_labels(labels: List[Text], old: Text, new: Text) -> List[Text]:
    """Replaces label names in a list of labels.

    Args:
        labels: list of labels
        old: old label name that should be replaced
        new: new label name

    Returns: updated labels
    """
    return [new if label == old else label for label in labels]

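# --- Illustrative sketch (editor's note, not part of the packaged module) ---
# substitute_labels simply rewrites one label value; evaluate_entities below uses
# it to map the internal NO_ENTITY_TAG onto NO_ENTITY before computing metrics.
# The tag values here are hypothetical stand-ins.
token_labels = ["O", "city", "O", "time"]
substituted = ["no_entity" if label == "O" else label for label in token_labels]
# substituted == ["no_entity", "city", "no_entity", "time"]
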
def collect_incorrect_entity_predictions(
    entity_results: List[EntityEvaluationResult],
    merged_predictions: List[Text],
    merged_targets: List[Text],
) -> List["EntityPrediction"]:
    """Get incorrect entity predictions.

    Args:
        entity_results: entity evaluation results
        merged_predictions: list of predicted entity labels
        merged_targets: list of true entity labels

    Returns: list of incorrect predictions
    """
    errors = []
    offset = 0
    for entity_result in entity_results:
        for i in range(offset, offset + len(entity_result.tokens)):
            if merged_targets[i] != merged_predictions[i]:
                prediction: EntityPrediction = {
                    "text": entity_result.message,
                    "entities": entity_result.entity_targets,
                    "predicted_entities": entity_result.entity_predictions,
                }
                errors.append(prediction)
                break
        offset += len(entity_result.tokens)
    return errors


def write_successful_entity_predictions(
    entity_results: List[EntityEvaluationResult],
    merged_targets: List[Text],
    merged_predictions: List[Text],
    successes_filename: Text,
) -> None:
    """Write correct entity predictions to a file.

    Args:
        entity_results: entity evaluation results
        merged_predictions: list of predicted entity labels
        merged_targets: list of true entity labels
        successes_filename: filename of file to save correct predictions to
    """
    successes = collect_successful_entity_predictions(
        entity_results, merged_predictions, merged_targets
    )

    if successes:
        rasa.shared.utils.io.dump_obj_as_json_to_file(successes_filename, successes)
        logger.info(f"Successful entity predictions saved to {successes_filename}.")
        structlogger.debug("test.write.entities", successes=copy.deepcopy(successes))
    else:
        logger.info("No successful entity prediction found.")


def collect_successful_entity_predictions(
    entity_results: List[EntityEvaluationResult],
    merged_predictions: List[Text],
    merged_targets: List[Text],
) -> List["EntityPrediction"]:
    """Get correct entity predictions.

    Args:
        entity_results: entity evaluation results
        merged_predictions: list of predicted entity labels
        merged_targets: list of true entity labels

    Returns: list of correct predictions
    """
    successes = []
    offset = 0
    for entity_result in entity_results:
        for i in range(offset, offset + len(entity_result.tokens)):
            if (
                merged_targets[i] == merged_predictions[i]
                and merged_targets[i] != NO_ENTITY
            ):
                prediction: EntityPrediction = {
                    "text": entity_result.message,
                    "entities": entity_result.entity_targets,
                    "predicted_entities": entity_result.entity_predictions,
                }
                successes.append(prediction)
                break
        offset += len(entity_result.tokens)
    return successes

def evaluate_entities(
    entity_results: List[EntityEvaluationResult],
    extractors: Set[Text],
    output_directory: Optional[Text],
    successes: bool,
    errors: bool,
    disable_plotting: bool,
    report_as_dict: Optional[bool] = None,
) -> Dict:  # pragma: no cover
    """Creates summary statistics for each entity extractor.

    Logs precision, recall, and F1 per entity type for each extractor.

    Args:
        entity_results: entity evaluation results
        extractors: entity extractors to consider
        output_directory: directory to store files to
        successes: if True correct predictions are written to disk
        errors: if True incorrect predictions are written to disk
        disable_plotting: if True no plots are created
        report_as_dict: `True` if the evaluation report should be returned as `dict`.
            If `False` the report is returned in a human-readable text format. If `None`
            `report_as_dict` is considered as `True` in case an `output_directory` is
            given.

    Returns: dictionary with evaluation results
    """
    aligned_predictions = align_all_entity_predictions(entity_results, extractors)
    merged_targets = merge_labels(aligned_predictions)
    merged_targets = substitute_labels(merged_targets, NO_ENTITY_TAG, NO_ENTITY)

    result = {}

    for extractor in extractors:
        merged_predictions = merge_labels(aligned_predictions, extractor)
        merged_predictions = substitute_labels(
            merged_predictions, NO_ENTITY_TAG, NO_ENTITY
        )

        logger.info(f"Evaluation for entity extractor: {extractor} ")

        report, precision, f1, accuracy, confusion_matrix, labels = _calculate_report(
            output_directory,
            merged_targets,
            merged_predictions,
            report_as_dict,
            exclude_label=NO_ENTITY,
        )
        if output_directory:
            _dump_report(output_directory, f"{extractor}_report.json", report)

        if successes:
            successes_filename = f"{extractor}_successes.json"
            if output_directory:
                successes_filename = os.path.join(output_directory, successes_filename)
            # save classified samples to file for debugging
            write_successful_entity_predictions(
                entity_results, merged_targets, merged_predictions, successes_filename
            )

        entity_errors = collect_incorrect_entity_predictions(
            entity_results, merged_predictions, merged_targets
        )
        if errors and output_directory:
            errors_filename = os.path.join(output_directory, f"{extractor}_errors.json")

            _write_errors(entity_errors, errors_filename, "entity")

        if not disable_plotting:
            confusion_matrix_filename = f"{extractor}_confusion_matrix.png"
            if output_directory:
                confusion_matrix_filename = os.path.join(
                    output_directory, confusion_matrix_filename
                )
            plot_utils.plot_confusion_matrix(
                confusion_matrix,
                classes=labels,
                title="Entity Confusion matrix",
                output_file=confusion_matrix_filename,
            )

            if extractor in EXTRACTORS_WITH_CONFIDENCES:
                merged_confidences = merge_confidences(aligned_predictions, extractor)
                histogram_filename = f"{extractor}_histogram.png"
                if output_directory:
                    histogram_filename = os.path.join(
                        output_directory, histogram_filename
                    )
                plot_entity_confidences(
                    merged_targets,
                    merged_predictions,
                    merged_confidences,
                    title="Entity Prediction Confidence Distribution",
                    hist_filename=histogram_filename,
                )

        result[extractor] = {
            "report": report,
            "precision": precision,
            "f1_score": f1,
            "accuracy": accuracy,
            "errors": entity_errors,
        }

    return result

def is_token_within_entity(token: Token, entity: Dict) -> bool:
    """Checks if a token is within the boundaries of an entity."""
    return determine_intersection(token, entity) == len(token.text)


def does_token_cross_borders(token: Token, entity: Dict) -> bool:
    """Checks if a token crosses the boundaries of an entity."""
    num_intersect = determine_intersection(token, entity)
    return 0 < num_intersect < len(token.text)


def determine_intersection(token: Token, entity: Dict) -> int:
    """Calculates how many characters a given token and entity share."""
    pos_token = set(range(token.start, token.end))
    pos_entity = set(range(entity["start"], entity["end"]))
    return len(pos_token.intersection(pos_entity))

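# --- Illustrative sketch (editor's note, not part of the packaged module) ---
# The character-overlap arithmetic used by the three helpers above, with a
# hypothetical token and entity span taken from the text "New York".
token_start, token_end, token_text = 4, 8, "York"
entity = {"start": 0, "end": 8, "entity": "city"}

overlap = len(
    set(range(token_start, token_end)) & set(range(entity["start"], entity["end"]))
)
# overlap == 4 == len(token_text): the token lies fully inside the entity, so
# is_token_within_entity would be True and does_token_cross_borders would be False.
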
def do_entities_overlap(entities: List[Dict]) -> bool:
    """Checks if entities overlap.

    I.e. cross each other's start and end boundaries.

    Args:
        entities: list of entities

    Returns: true if entities overlap, false otherwise.
    """
    sorted_entities = sorted(entities, key=lambda e: e["start"])
    for i in range(len(sorted_entities) - 1):
        curr_ent = sorted_entities[i]
        next_ent = sorted_entities[i + 1]
        if (
            next_ent["start"] < curr_ent["end"]
            and next_ent["entity"] != curr_ent["entity"]
        ):
            structlogger.warning(
                "test.overlaping.entities",
                current_entity=copy.deepcopy(curr_ent),
                next_entity=copy.deepcopy(next_ent),
            )
            return True

    return False


def find_intersecting_entities(token: Token, entities: List[Dict]) -> List[Dict]:
    """Finds the entities that intersect with a token.

    Args:
        token: a single token
        entities: entities found by a single extractor

    Returns: list of entities
    """
    candidates = []
    for e in entities:
        if is_token_within_entity(token, e):
            candidates.append(e)
        elif does_token_cross_borders(token, e):
            candidates.append(e)
            structlogger.debug(
                "test.intersecting.entities",
                token_text=copy.deepcopy(token.text),
                token_start=token.start,
                token_end=token.end,
                entity=copy.deepcopy(e),
            )
    return candidates


def pick_best_entity_fit(
    token: Token, candidates: List[Dict[Text, Any]]
) -> Optional[Dict[Text, Any]]:
    """Determines the best fitting entity given intersecting entities.

    Args:
        token: a single token
        candidates: entities found by a single extractor

    Returns:
        the best fitting entity, or `None` if there are no candidates
    """
    if len(candidates) == 0:
        return None
    elif len(candidates) == 1:
        return candidates[0]
    else:
        best_fit = np.argmax([determine_intersection(token, c) for c in candidates])
        return candidates[int(best_fit)]


def determine_token_labels(
    token: Token,
    entities: List[Dict],
    extractors: Optional[Set[Text]] = None,
    attribute_key: Text = ENTITY_ATTRIBUTE_TYPE,
) -> Text:
    """Select token label for the provided attribute key for non-overlapping entities.

    Args:
        token: a single token
        entities: entities found by a single extractor
        extractors: list of extractors
        attribute_key: the attribute key for which the entity type should be returned
    Returns:
        entity type
    """
    entity = determine_entity_for_token(token, entities, extractors)

    if entity is None:
        return NO_ENTITY_TAG

    label = entity.get(attribute_key)

    if not label:
        return NO_ENTITY_TAG

    return label


def determine_entity_for_token(
    token: Token,
    entities: List[Dict[Text, Any]],
    extractors: Optional[Set[Text]] = None,
) -> Optional[Dict[Text, Any]]:
    """Determines the best fitting non-overlapping entity for the given token.

    Args:
        token: a single token
        entities: entities found by a single extractor
        extractors: list of extractors

    Returns:
        the best fitting entity, or `None` if no entity covers the token
    """
    if entities is None or len(entities) == 0:
        return None
    if do_any_extractors_not_support_overlap(extractors) and do_entities_overlap(
        entities
    ):
        raise ValueError("The possible entities should not overlap.")

    candidates = find_intersecting_entities(token, entities)
    return pick_best_entity_fit(token, candidates)


def do_any_extractors_not_support_overlap(extractors: Optional[Set[Text]]) -> bool:
    """Checks if any extractor does not support overlapping entities.

    Args:
        extractors: Names of the entity extractors

    Returns:
        `True` if and only if CRFEntityExtractor or DIETClassifier is in `extractors`
    """
    if extractors is None:
        return False

    from rasa.nlu.extractors.crf_entity_extractor import CRFEntityExtractor
    from rasa.nlu.classifiers.diet_classifier import DIETClassifier

    return not extractors.isdisjoint(
        {CRFEntityExtractor.__name__, DIETClassifier.__name__}
    )

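# --- Illustrative sketch (editor's note, not part of the packaged module) ---
# The set logic used above: overlap checks are only enforced when at least one
# of the extractors that cannot produce overlapping entities is involved.
active_extractors = {"DIETClassifier", "RegexEntityExtractor"}
needs_overlap_check = not active_extractors.isdisjoint(
    {"CRFEntityExtractor", "DIETClassifier"}
)
# needs_overlap_check is True because "DIETClassifier" is in the set.
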
def align_entity_predictions(
    result: EntityEvaluationResult, extractors: Set[Text]
) -> Dict:
    """Aligns entity predictions to the message tokens.

    Determines for every token the true label based on the
    prediction targets and the label assigned by each
    single extractor.

    Args:
        result: entity evaluation result
        extractors: the entity extractors that should be considered

    Returns: dictionary containing the true token labels and token labels
        from the extractors
    """
    true_token_labels = []
    entities_by_extractors: Dict[Text, List] = {
        extractor: [] for extractor in extractors
    }
    for p in result.entity_predictions:
        entities_by_extractors[p[EXTRACTOR]].append(p)
    extractor_labels: Dict[Text, List] = {extractor: [] for extractor in extractors}
    extractor_confidences: Dict[Text, List] = {
        extractor: [] for extractor in extractors
    }
    for t in result.tokens:
        true_token_labels.append(_concat_entity_labels(t, result.entity_targets))
        for extractor, entities in entities_by_extractors.items():
            extracted_labels = _concat_entity_labels(t, entities, {extractor})
            extracted_confidences = _get_entity_confidences(t, entities, {extractor})
            extractor_labels[extractor].append(extracted_labels)
            extractor_confidences[extractor].append(extracted_confidences)

    return {
        "target_labels": true_token_labels,
        "extractor_labels": extractor_labels,
        "confidences": extractor_confidences,
    }


def _concat_entity_labels(
    token: Token, entities: List[Dict], extractors: Optional[Set[Text]] = None
) -> Text:
    """Concatenate labels for entity type, role, and group for evaluation.

    In order to calculate metrics also for entity type, role, and group we need to
    concatenate their labels. For example, 'location.destination'. This allows
    us to report metrics for every combination of entity type, role, and group.

    Args:
        token: the token we are looking at
        entities: the available entities
        extractors: the extractor of interest

    Returns:
        the entity label of the provided token
    """
    entity_label = determine_token_labels(
        token, entities, extractors, ENTITY_ATTRIBUTE_TYPE
    )
    group_label = determine_token_labels(
        token, entities, extractors, ENTITY_ATTRIBUTE_GROUP
    )
    role_label = determine_token_labels(
        token, entities, extractors, ENTITY_ATTRIBUTE_ROLE
    )

    if entity_label == role_label == group_label == NO_ENTITY_TAG:
        return NO_ENTITY_TAG

    labels = [entity_label, group_label, role_label]
    labels = [label for label in labels if label != NO_ENTITY_TAG]

    return ".".join(labels)

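# --- Illustrative sketch (editor's note, not part of the packaged module) ---
# How type, group and role labels are combined into a single evaluation label
# for one token, assuming NO_ENTITY_TAG is the plain "O" tag.
type_label, group_label, role_label = "city", "O", "destination"
parts = [label for label in (type_label, group_label, role_label) if label != "O"]
combined = ".".join(parts)
# combined == "city.destination", so metrics can be reported for every
# combination of entity type, role, and group.
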
def _get_entity_confidences(
    token: Token, entities: List[Dict], extractors: Optional[Set[Text]] = None
) -> float:
    """Get the confidence value of the best fitting entity.

    If multiple confidence values are present, e.g. for type, role, group, we
    pick the lowest confidence value.

    Args:
        token: the token we are looking at
        entities: the available entities
        extractors: the extractor of interest

    Returns:
        the confidence value
    """
    entity = determine_entity_for_token(token, entities, extractors)

    if entity is None:
        return 0.0

    if entity.get("extractor") not in EXTRACTORS_WITH_CONFIDENCES:
        return 0.0

    conf_type = entity.get(ENTITY_ATTRIBUTE_CONFIDENCE_TYPE) or 1.0
    conf_role = entity.get(ENTITY_ATTRIBUTE_CONFIDENCE_ROLE) or 1.0
    conf_group = entity.get(ENTITY_ATTRIBUTE_CONFIDENCE_GROUP) or 1.0

    return min(conf_type, conf_role, conf_group)


def align_all_entity_predictions(
    entity_results: List[EntityEvaluationResult], extractors: Set[Text]
) -> List[Dict]:
    """Aligns entity predictions to the message tokens.

    Processes the whole dataset using align_entity_predictions.

    Args:
        entity_results: list of entity prediction results
        extractors: the entity extractors that should be considered

    Returns: list of dictionaries containing the true token labels and token
        labels from the extractors
    """
    aligned_predictions = []
    for result in entity_results:
        aligned_predictions.append(align_entity_predictions(result, extractors))

    return aligned_predictions

async def get_eval_data(
    processor: MessageProcessor, test_data: TrainingData
) -> Tuple[
    List[IntentEvaluationResult],
    List[ResponseSelectionEvaluationResult],
    List[EntityEvaluationResult],
]:
    """Runs the model for the test set and extracts targets and predictions.

    Returns intent results (intent targets and predictions, the original
    messages and the confidences of the predictions), response results (
    response targets and predictions) as well as entity results
    (entity_targets, entity_predictions, and tokens).

    Args:
        processor: the processor
        test_data: test data

    Returns: intent, response, and entity evaluation results
    """
    logger.info("Running model for predictions:")

    intent_results, entity_results, response_selection_results = [], [], []

    response_labels = {
        e.get(INTENT_RESPONSE_KEY)
        for e in test_data.intent_examples
        if e.get(INTENT_RESPONSE_KEY) is not None
    }
    intent_labels = {e.get(INTENT) for e in test_data.intent_examples}
    should_eval_intents = len(intent_labels) >= 2
    should_eval_response_selection = len(response_labels) >= 2
    should_eval_entities = len(test_data.entity_examples) > 0

    for example in tqdm(test_data.nlu_examples):
        result = await processor.parse_message(
            UserMessage(text=example.get(TEXT)),
            only_output_properties=False,
        )
        _remove_entities_of_extractors(result, PRETRAINED_EXTRACTORS)
        if should_eval_intents:
            if fallback_classifier.is_fallback_classifier_prediction(result):
                # Revert fallback prediction to not shadow
                # the wrongly predicted intent
                # during the test phase.
                result = fallback_classifier.undo_fallback_prediction(result)
            intent_prediction = result.get(INTENT, {})
            intent_results.append(
                IntentEvaluationResult(
                    example.get(INTENT, ""),
                    intent_prediction.get(INTENT_NAME_KEY),
                    result.get(TEXT),
                    intent_prediction.get("confidence"),
                )
            )

        if should_eval_response_selection:
            # including all examples here. Empty response examples are filtered at the
            # time of metric calculation
            intent_target = example.get(INTENT, "")
            selector_properties = result.get(RESPONSE_SELECTOR_PROPERTY_NAME, {})
            response_selector_retrieval_intents = selector_properties.get(
                RESPONSE_SELECTOR_RETRIEVAL_INTENTS, set()
            )
            if (
                intent_target in response_selector_retrieval_intents
                and intent_target in selector_properties
            ):
                response_prediction_key = intent_target
            else:
                response_prediction_key = RESPONSE_SELECTOR_DEFAULT_INTENT

            response_prediction = selector_properties.get(
                response_prediction_key, {}
            ).get(RESPONSE_SELECTOR_PREDICTION_KEY, {})

            intent_response_key_target = example.get(INTENT_RESPONSE_KEY, "")

            response_selection_results.append(
                ResponseSelectionEvaluationResult(
                    intent_response_key_target,
                    response_prediction.get(INTENT_RESPONSE_KEY),
                    result.get(TEXT),
                    response_prediction.get(PREDICTED_CONFIDENCE_KEY),
                )
            )

        if should_eval_entities:
            entity_results.append(
                EntityEvaluationResult(
                    example.get(ENTITIES, []),
                    result.get(ENTITIES, []),
                    result.get(TOKENS_NAMES[TEXT], []),
                    result.get(TEXT),
                )
            )

    return intent_results, response_selection_results, entity_results


def _get_active_entity_extractors(
    entity_results: List[EntityEvaluationResult],
) -> Set[Text]:
    """Finds the names of entity extractors from the EntityEvaluationResults."""
    extractors: Set[Text] = set()
    for result in entity_results:
        for prediction in result.entity_predictions:
            if EXTRACTOR in prediction:
                extractors.add(prediction[EXTRACTOR])
    return extractors


def _remove_entities_of_extractors(
    nlu_parse_result: Dict[Text, Any], extractor_names: Set[Text]
) -> None:
    """Removes the entities annotated by the given extractor names."""
    entities = nlu_parse_result.get(ENTITIES)
    if not entities:
        return
    filtered_entities = [e for e in entities if e.get(EXTRACTOR) not in extractor_names]
    nlu_parse_result[ENTITIES] = filtered_entities


async def run_evaluation(
    data_path: Text,
    processor: MessageProcessor,
    output_directory: Optional[Text] = None,
    successes: bool = False,
    errors: bool = False,
    disable_plotting: bool = False,
    report_as_dict: Optional[bool] = None,
    domain_path: Optional[Text] = None,
) -> Dict:  # pragma: no cover
    """Evaluate intent classification, response selection and entity extraction.

    Args:
        data_path: path to the test data
        processor: the processor used to process and predict
        output_directory: path to folder where all output will be stored
        successes: if true successful predictions are written to a file
        errors: if true incorrect predictions are written to a file
        disable_plotting: if true confusion matrix and histogram will not be rendered
        report_as_dict: `True` if the evaluation report should be returned as `dict`.
            If `False` the report is returned in a human-readable text format. If `None`
            `report_as_dict` is considered as `True` in case an `output_directory` is
            given.
        domain_path: Path to the domain file(s).

    Returns: dictionary containing evaluation results
    """
    import rasa.shared.nlu.training_data.loading
    from rasa.shared.constants import DEFAULT_DOMAIN_PATH

    test_data_importer = TrainingDataImporter.load_from_dict(
        training_data_paths=[data_path],
        domain_path=domain_path if domain_path else DEFAULT_DOMAIN_PATH,
    )
    test_data = test_data_importer.get_nlu_data()

    result: Dict[Text, Optional[Dict]] = {
        "intent_evaluation": None,
        "entity_evaluation": None,
        "response_selection_evaluation": None,
    }

    if output_directory:
        rasa.shared.utils.io.create_directory(output_directory)

    (intent_results, response_selection_results, entity_results) = await get_eval_data(
        processor, test_data
    )

    if intent_results:
        logger.info("Intent evaluation results:")
        result["intent_evaluation"] = evaluate_intents(
            intent_results,
            output_directory,
            successes,
            errors,
            disable_plotting,
            report_as_dict=report_as_dict,
        )

    if response_selection_results:
        logger.info("Response selection evaluation results:")
        result["response_selection_evaluation"] = evaluate_response_selections(
            response_selection_results,
            output_directory,
            successes,
            errors,
            disable_plotting,
            report_as_dict=report_as_dict,
        )

    if any(entity_results):
        logger.info("Entity evaluation results:")
        extractors = _get_active_entity_extractors(entity_results)
        result["entity_evaluation"] = evaluate_entities(
            entity_results,
            extractors,
            output_directory,
            successes,
            errors,
            disable_plotting,
            report_as_dict=report_as_dict,
        )

    telemetry.track_nlu_model_test(test_data)

    return result

def generate_folds(
    n: int, training_data: TrainingData
) -> Iterator[Tuple[TrainingData, TrainingData]]:
    """Generates n cross validation folds for given training data."""
    from sklearn.model_selection import StratifiedKFold

    skf = StratifiedKFold(n_splits=n, shuffle=True)
    x = training_data.intent_examples

    # Get labels as they appear in the training data because we want a
    # stratified split on all intents (including retrieval intents if they exist)
    y = [example.get_full_intent() for example in x]
    for i_fold, (train_index, test_index) in enumerate(skf.split(x, y)):
        logger.debug(f"Fold: {i_fold}")
        train = [x[i] for i in train_index]
        test = [x[i] for i in test_index]
        yield (
            TrainingData(
                training_examples=train,
                entity_synonyms=training_data.entity_synonyms,
                regex_features=training_data.regex_features,
                lookup_tables=training_data.lookup_tables,
                responses=training_data.responses,
            ),
            TrainingData(
                training_examples=test,
                entity_synonyms=training_data.entity_synonyms,
                regex_features=training_data.regex_features,
                lookup_tables=training_data.lookup_tables,
                responses=training_data.responses,
            ),
        )

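# --- Illustrative sketch (editor's note, not part of the packaged module) ---
# The stratified splitting behaviour generate_folds relies on, shown with plain
# lists instead of TrainingData objects.
from sklearn.model_selection import StratifiedKFold

examples = ["hi", "hello", "bye", "goodbye", "yes", "no"]
intents = ["greet", "greet", "goodbye", "goodbye", "affirm", "affirm"]

skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=0)
for train_index, test_index in skf.split(examples, intents):
    train = [examples[i] for i in train_index]
    test = [examples[i] for i in test_index]
    # each fold keeps roughly the same intent distribution in train and test
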
|
|
1490
|
+
async def combine_result(
|
|
1491
|
+
intent_metrics: IntentMetrics,
|
|
1492
|
+
entity_metrics: EntityMetrics,
|
|
1493
|
+
response_selection_metrics: ResponseSelectionMetrics,
|
|
1494
|
+
processor: MessageProcessor,
|
|
1495
|
+
data: TrainingData,
|
|
1496
|
+
intent_results: Optional[List[IntentEvaluationResult]] = None,
|
|
1497
|
+
entity_results: Optional[List[EntityEvaluationResult]] = None,
|
|
1498
|
+
response_selection_results: Optional[
|
|
1499
|
+
List[ResponseSelectionEvaluationResult]
|
|
1500
|
+
] = None,
|
|
1501
|
+
) -> Tuple[IntentMetrics, EntityMetrics, ResponseSelectionMetrics]:
|
|
1502
|
+
"""Collects intent, response selection and entity metrics for cross validation.
|
|
1503
|
+
|
|
1504
|
+
If `intent_results`, `response_selection_results` or `entity_results` is provided
|
|
1505
|
+
as a list, prediction results are also collected.
|
|
1506
|
+
|
|
1507
|
+
Args:
|
|
1508
|
+
intent_metrics: intent metrics
|
|
1509
|
+
entity_metrics: entity metrics
|
|
1510
|
+
response_selection_metrics: response selection metrics
|
|
1511
|
+
processor: the processor
|
|
1512
|
+
data: training data
|
|
1513
|
+
intent_results: intent evaluation results
|
|
1514
|
+
entity_results: entity evaluation results
|
|
1515
|
+
response_selection_results: reponse selection evaluation results
|
|
1516
|
+
|
|
1517
|
+
Returns: intent, entity, and response selection metrics
|
|
1518
|
+
"""
|
|
1519
|
+
(
|
|
1520
|
+
intent_current_metrics,
|
|
1521
|
+
entity_current_metrics,
|
|
1522
|
+
response_selection_current_metrics,
|
|
1523
|
+
current_intent_results,
|
|
1524
|
+
current_entity_results,
|
|
1525
|
+
current_response_selection_results,
|
|
1526
|
+
) = await compute_metrics(processor, data)
|
|
1527
|
+
|
|
1528
|
+
if intent_results is not None:
|
|
1529
|
+
intent_results += current_intent_results
|
|
1530
|
+
|
|
1531
|
+
if entity_results is not None:
|
|
1532
|
+
entity_results += current_entity_results
|
|
1533
|
+
|
|
1534
|
+
if response_selection_results is not None:
|
|
1535
|
+
response_selection_results += current_response_selection_results
|
|
1536
|
+
|
|
1537
|
+
for k, v in intent_current_metrics.items():
|
|
1538
|
+
intent_metrics[k] = v + intent_metrics[k]
|
|
1539
|
+
|
|
1540
|
+
for k, v in response_selection_current_metrics.items():
|
|
1541
|
+
response_selection_metrics[k] = v + response_selection_metrics[k]
|
|
1542
|
+
|
|
1543
|
+
for extractor, extractor_metric in entity_current_metrics.items():
|
|
1544
|
+
entity_metrics[extractor] = {
|
|
1545
|
+
k: v + entity_metrics[extractor][k] for k, v in extractor_metric.items()
|
|
1546
|
+
}
|
|
1547
|
+
|
|
1548
|
+
return intent_metrics, entity_metrics, response_selection_metrics
|
|
1549
|
+
|
|
1550
|
+
|
|
1551
|
+
def _contains_entity_labels(entity_results: List[EntityEvaluationResult]) -> bool:
|
|
1552
|
+
for result in entity_results:
|
|
1553
|
+
if result.entity_targets or result.entity_predictions:
|
|
1554
|
+
return True
|
|
1555
|
+
return False
|
|
1556
|
+
|
|
1557
|
+
|
|
1558
|
+
async def cross_validate(
|
|
1559
|
+
data: TrainingData,
|
|
1560
|
+
n_folds: int,
|
|
1561
|
+
nlu_config: Union[Text, Dict],
|
|
1562
|
+
output: Optional[Text] = None,
|
|
1563
|
+
successes: bool = False,
|
|
1564
|
+
errors: bool = False,
|
|
1565
|
+
disable_plotting: bool = False,
|
|
1566
|
+
report_as_dict: Optional[bool] = None,
|
|
1567
|
+
) -> Tuple[CVEvaluationResult, CVEvaluationResult, CVEvaluationResult]:
|
|
1568
|
+
"""Stratified cross validation on data.
|
|
1569
|
+
|
|
1570
|
+
Args:
|
|
1571
|
+
data: Training Data
|
|
1572
|
+
n_folds: integer, number of cv folds
|
|
1573
|
+
nlu_config: nlu config file
|
|
1574
|
+
output: path to folder where reports are stored
|
|
1575
|
+
successes: if true successful predictions are written to a file
|
|
1576
|
+
errors: if true incorrect predictions are written to a file
|
|
1577
|
+
disable_plotting: if true no confusion matrix and historgram plates are created
|
|
1578
|
+
report_as_dict: `True` if the evaluation report should be returned as `dict`.
|
|
1579
|
+
If `False` the report is returned in a human-readable text format. If `None`
|
|
1580
|
+
`report_as_dict` is considered as `True` in case an `output_directory` is
|
|
1581
|
+
given.
|
|
1582
|
+
|
|
1583
|
+
Returns:
|
|
1584
|
+
dictionary with key, list structure, where each entry in list
|
|
1585
|
+
corresponds to the relevant result for one fold
|
|
1586
|
+
"""
|
|
1587
|
+
import rasa.model_training
|
|
1588
|
+
|
|
1589
|
+
with TempDirectoryPath(get_temp_dir_name()) as temp_dir:
|
|
1590
|
+
tmp_path = Path(temp_dir)
|
|
1591
|
+
|
|
1592
|
+
if isinstance(nlu_config, Dict):
|
|
1593
|
+
config_path = tmp_path / "config.yml"
|
|
1594
|
+
write_yaml(nlu_config, config_path)
|
|
1595
|
+
nlu_config = str(config_path)
|
|
1596
|
+
|
|
1597
|
+
if output:
|
|
1598
|
+
rasa.shared.utils.io.create_directory(output)
|
|
1599
|
+
|
|
1600
|
+
intent_train_metrics: IntentMetrics = defaultdict(list)
|
|
1601
|
+
intent_test_metrics: IntentMetrics = defaultdict(list)
|
|
1602
|
+
entity_train_metrics: EntityMetrics = defaultdict(lambda: defaultdict(list))
|
|
1603
|
+
entity_test_metrics: EntityMetrics = defaultdict(lambda: defaultdict(list))
|
|
1604
|
+
response_selection_train_metrics: ResponseSelectionMetrics = defaultdict(list)
|
|
1605
|
+
response_selection_test_metrics: ResponseSelectionMetrics = defaultdict(list)
|
|
1606
|
+
|
|
1607
|
+
intent_test_results: List[IntentEvaluationResult] = []
|
|
1608
|
+
entity_test_results: List[EntityEvaluationResult] = []
|
|
1609
|
+
response_selection_test_results: List[ResponseSelectionEvaluationResult] = []
|
|
1610
|
+
|
|
1611
|
+
for train, test in generate_folds(n_folds, data):
|
|
1612
|
+
training_data_file = tmp_path / "training_data.yml"
|
|
1613
|
+
RasaYAMLWriter().dump(training_data_file, train)
|
|
1614
|
+
|
|
1615
|
+
model_file = await rasa.model_training.train_nlu(
|
|
1616
|
+
nlu_config, str(training_data_file), str(tmp_path)
|
|
1617
|
+
)
|
|
1618
|
+
|
|
1619
|
+
processor = Agent.load(model_file).processor
|
|
1620
|
+
|
|
1621
|
+
# calculate train accuracy
|
|
1622
|
+
await combine_result(
|
|
1623
|
+
intent_train_metrics,
|
|
1624
|
+
entity_train_metrics,
|
|
1625
|
+
response_selection_train_metrics,
|
|
1626
|
+
processor,
|
|
1627
|
+
train,
|
|
1628
|
+
)
|
|
1629
|
+
# calculate test accuracy
|
|
1630
|
+
await combine_result(
|
|
1631
|
+
intent_test_metrics,
|
|
1632
|
+
entity_test_metrics,
|
|
1633
|
+
response_selection_test_metrics,
|
|
1634
|
+
processor,
|
|
1635
|
+
test,
|
|
1636
|
+
intent_test_results,
|
|
1637
|
+
entity_test_results,
|
|
1638
|
+
response_selection_test_results,
|
|
1639
|
+
)
|
|
1640
|
+
|
|
1641
|
+
intent_evaluation = {}
|
|
1642
|
+
if intent_test_results:
|
|
1643
|
+
logger.info("Accumulated test folds intent evaluation results:")
|
|
1644
|
+
intent_evaluation = evaluate_intents(
|
|
1645
|
+
intent_test_results,
|
|
1646
|
+
output,
|
|
1647
|
+
successes,
|
|
1648
|
+
errors,
|
|
1649
|
+
disable_plotting,
|
|
1650
|
+
report_as_dict=report_as_dict,
|
|
1651
|
+
)
|
|
1652
|
+
|
|
1653
|
+
entity_evaluation = {}
|
|
1654
|
+
if entity_test_results:
|
|
1655
|
+
logger.info("Accumulated test folds entity evaluation results:")
|
|
1656
|
+
extractors = _get_active_entity_extractors(entity_test_results)
|
|
1657
|
+
entity_evaluation = evaluate_entities(
|
|
1658
|
+
entity_test_results,
|
|
1659
|
+
extractors,
|
|
1660
|
+
output,
|
|
1661
|
+
successes,
|
|
1662
|
+
errors,
|
|
1663
|
+
disable_plotting,
|
|
1664
|
+
report_as_dict=report_as_dict,
|
|
1665
|
+
)
|
|
1666
|
+
|
|
1667
|
+
responses_evaluation = {}
|
|
1668
|
+
if response_selection_test_results:
|
|
1669
|
+
logger.info("Accumulated test folds response selection evaluation results:")
|
|
1670
|
+
responses_evaluation = evaluate_response_selections(
|
|
1671
|
+
response_selection_test_results,
|
|
1672
|
+
output,
|
|
1673
|
+
successes,
|
|
1674
|
+
errors,
|
|
1675
|
+
disable_plotting,
|
|
1676
|
+
report_as_dict=report_as_dict,
|
|
1677
|
+
)
|
|
1678
|
+
|
|
1679
|
+
return (
|
|
1680
|
+
CVEvaluationResult(
|
|
1681
|
+
dict(intent_train_metrics), dict(intent_test_metrics), intent_evaluation
|
|
1682
|
+
),
|
|
1683
|
+
CVEvaluationResult(
|
|
1684
|
+
dict(entity_train_metrics), dict(entity_test_metrics), entity_evaluation
|
|
1685
|
+
),
|
|
1686
|
+
CVEvaluationResult(
|
|
1687
|
+
dict(response_selection_train_metrics),
|
|
1688
|
+
dict(response_selection_test_metrics),
|
|
1689
|
+
responses_evaluation,
|
|
1690
|
+
),
|
|
1691
|
+
)
|
|
1692
|
+
|
|
1693
|
+
|
|
1694
|
+
def _targets_predictions_from(
|
|
1695
|
+
results: Union[
|
|
1696
|
+
List[IntentEvaluationResult], List[ResponseSelectionEvaluationResult]
|
|
1697
|
+
],
|
|
1698
|
+
target_key: Text,
|
|
1699
|
+
prediction_key: Text,
|
|
1700
|
+
) -> Iterator[Iterable[Optional[Text]]]:
|
|
1701
|
+
return zip(*[(getattr(r, target_key), getattr(r, prediction_key)) for r in results])
|
|
1702
|
+
|
|
1703
|
+
|
|
1704
|
+
async def compute_metrics(
    processor: MessageProcessor, training_data: TrainingData
) -> Tuple[
    IntentMetrics,
    EntityMetrics,
    ResponseSelectionMetrics,
    List[IntentEvaluationResult],
    List[EntityEvaluationResult],
    List[ResponseSelectionEvaluationResult],
]:
    """Metrics for intent classification, response selection and entity extraction.

    Args:
        processor: the processor
        training_data: training data

    Returns: intent, response selection and entity metrics, and prediction results.
    """
    intent_results, response_selection_results, entity_results = await get_eval_data(
        processor, training_data
    )

    intent_results = remove_empty_intent_examples(intent_results)

    response_selection_results = remove_empty_response_examples(
        response_selection_results
    )

    intent_metrics: IntentMetrics = {}
    if intent_results:
        intent_metrics = _compute_metrics(
            intent_results, "intent_target", "intent_prediction"
        )

    entity_metrics = {}
    if entity_results:
        entity_metrics = _compute_entity_metrics(entity_results)

    response_selection_metrics: ResponseSelectionMetrics = {}
    if response_selection_results:
        response_selection_metrics = _compute_metrics(
            response_selection_results,
            "intent_response_key_target",
            "intent_response_key_prediction",
        )

    return (
        intent_metrics,
        entity_metrics,
        response_selection_metrics,
        intent_results,
        entity_results,
        response_selection_results,
    )

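`compute_metrics` is a coroutine, so callers have to await it (or drive it with `asyncio`). A rough usage sketch; obtaining the processor and training data is left out, and the driver call is only indicated in a comment because those objects are not constructed here:

```python
import asyncio

async def _evaluate(processor, training_data):
    # compute_metrics returns a 6-tuple: three metric dicts followed by the raw
    # per-example results for intents, entities and response selection.
    (
        intent_metrics,
        entity_metrics,
        response_selection_metrics,
        intent_results,
        entity_results,
        response_selection_results,
    ) = await compute_metrics(processor, training_data)
    return intent_metrics

# Hypothetical driver; a processor can be obtained e.g. via
# Agent.load(...).processor, as compare_nlu below does.
# asyncio.run(_evaluate(processor, training_data))
```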
async def compare_nlu(
    configs: List[Text],
    data: TrainingData,
    exclusion_percentages: List[int],
    f_score_results: Dict[Text, List[List[float]]],
    model_names: List[Text],
    output: Text,
    runs: int,
) -> List[int]:
    """Trains and compares multiple NLU models.

    For each run and exclusion percentage a model per config file is trained.
    Thereby, the model is trained only on the current percentage of training data.
    Afterwards, the model is tested on the complete test data of that run.
    All results are stored in the provided output directory.

    Args:
        configs: config files needed for training
        data: training data
        exclusion_percentages: percentages of training data to exclude during comparison
        f_score_results: dictionary of model name to f-score results per run
        model_names: names of the models to train
        output: the output directory
        runs: number of comparison runs

    Returns: training examples per run
    """
    import rasa.model_training

    training_examples_per_run = []

    for run in range(runs):
        logger.info("Beginning comparison run {}/{}".format(run + 1, runs))

        run_path = os.path.join(output, "run_{}".format(run + 1))
        io_utils.create_path(run_path)

        test_path = os.path.join(run_path, TEST_DATA_FILE)
        io_utils.create_path(test_path)

        train, test = data.train_test_split()
        rasa.shared.utils.io.write_text_file(test.nlu_as_yaml(), test_path)

        for percentage in exclusion_percentages:
            percent_string = f"{percentage}%_exclusion"

            _, train_included = train.train_test_split(percentage / 100)
            # only count for the first run and ignore the others
            if run == 0:
                training_examples_per_run.append(len(train_included.nlu_examples))

            model_output_path = os.path.join(run_path, percent_string)
            train_split_path = os.path.join(model_output_path, "train")
            train_nlu_split_path = os.path.join(train_split_path, TRAIN_DATA_FILE)
            train_nlg_split_path = os.path.join(train_split_path, NLG_DATA_FILE)
            io_utils.create_path(train_nlu_split_path)
            rasa.shared.utils.io.write_text_file(
                train_included.nlu_as_yaml(), train_nlu_split_path
            )
            rasa.shared.utils.io.write_text_file(
                train_included.nlg_as_yaml(), train_nlg_split_path
            )

            for nlu_config, model_name in zip(configs, model_names):
                logger.info(
                    "Evaluating configuration '{}' with {} training data.".format(
                        model_name, percent_string
                    )
                )

                try:
                    model_path = await rasa.model_training.train_nlu(
                        nlu_config,
                        train_split_path,
                        model_output_path,
                        fixed_model_name=model_name,
                    )
                except Exception as e:  # skipcq: PYL-W0703
                    # general exception catching needed to continue evaluating other
                    # model configurations
                    logger.warning(f"Training model '{model_name}' failed. Error: {e}")
                    f_score_results[model_name][run].append(0.0)
                    continue

                output_path = os.path.join(model_output_path, f"{model_name}_report")
                processor = Agent.load(model_path=model_path).processor
                result = await run_evaluation(
                    test_path, processor, output_directory=output_path, errors=True
                )

                f1 = result["intent_evaluation"]["f1_score"]
                f_score_results[model_name][run].append(f1)

    return training_examples_per_run

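`compare_nlu` appends one F1 score per configuration, run and exclusion percentage into `f_score_results`, so the caller must pass a dict that already holds an (initially empty) list per run for every model name. A minimal initialization sketch with illustrative model names:

```python
# Illustrative values; the real names come from the comparison configs passed in.
model_names = ["config_1", "config_2"]
runs = 3

# One inner list per run and model; compare_nlu appends an F1 score (or 0.0 on
# training failure) to f_score_results[model_name][run] for each exclusion percentage.
f_score_results = {name: [[] for _ in range(runs)] for name in model_names}
print(f_score_results)
# {'config_1': [[], [], []], 'config_2': [[], [], []]}
```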
def _compute_metrics(
    results: Union[
        List[IntentEvaluationResult], List[ResponseSelectionEvaluationResult]
    ],
    target_key: Text,
    prediction_key: Text,
) -> Union[IntentMetrics, ResponseSelectionMetrics]:
    """Computes evaluation metrics for a given corpus and returns the results.

    Args:
        results: evaluation results
        target_key: target key name
        prediction_key: prediction key name

    Returns: metrics
    """
    from rasa.model_testing import get_evaluation_metrics

    # compute fold metrics
    targets, predictions = _targets_predictions_from(
        results, target_key, prediction_key
    )
    _, precision, f1, accuracy = get_evaluation_metrics(targets, predictions)

    return {"Accuracy": [accuracy], "F1-score": [f1], "Precision": [precision]}

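`_compute_metrics` wraps each score in a single-element list, which makes it straightforward to merge the per-fold dictionaries into the list-valued metric accumulators used elsewhere in this module. A small sketch of that accumulation pattern (the fold values below are made up for illustration):

```python
from collections import defaultdict

# Two hypothetical fold results in the shape returned by _compute_metrics.
fold_1 = {"Accuracy": [0.91], "F1-score": [0.89], "Precision": [0.90]}
fold_2 = {"Accuracy": [0.87], "F1-score": [0.85], "Precision": [0.86]}

accumulated = defaultdict(list)
for fold in (fold_1, fold_2):
    for metric, values in fold.items():
        accumulated[metric].extend(values)

print(dict(accumulated))
# {'Accuracy': [0.91, 0.87], 'F1-score': [0.89, 0.85], 'Precision': [0.9, 0.86]}
```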
def _compute_entity_metrics(
    entity_results: List[EntityEvaluationResult],
) -> EntityMetrics:
    """Computes entity evaluation metrics and returns the results.

    Args:
        entity_results: entity evaluation results
    Returns: entity metrics
    """
    from rasa.model_testing import get_evaluation_metrics

    entity_metric_results: EntityMetrics = defaultdict(lambda: defaultdict(list))
    extractors = _get_active_entity_extractors(entity_results)

    if not extractors:
        return entity_metric_results

    aligned_predictions = align_all_entity_predictions(entity_results, extractors)

    merged_targets = merge_labels(aligned_predictions)
    merged_targets = substitute_labels(merged_targets, NO_ENTITY_TAG, NO_ENTITY)

    for extractor in extractors:
        merged_predictions = merge_labels(aligned_predictions, extractor)
        merged_predictions = substitute_labels(
            merged_predictions, NO_ENTITY_TAG, NO_ENTITY
        )
        _, precision, f1, accuracy = get_evaluation_metrics(
            merged_targets, merged_predictions, exclude_label=NO_ENTITY
        )
        entity_metric_results[extractor]["Accuracy"].append(accuracy)
        entity_metric_results[extractor]["F1-score"].append(f1)
        entity_metric_results[extractor]["Precision"].append(precision)

    return entity_metric_results

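`_compute_entity_metrics` keys its scores by extractor name, so the resulting `EntityMetrics` value is a nested mapping. A self-contained sketch of that structure, mirroring the appends in the function above (extractor name and numbers are illustrative only):

```python
from collections import defaultdict

entity_metric_results = defaultdict(lambda: defaultdict(list))

# Hypothetical single extractor with made-up scores.
extractor = "DIETClassifier"
entity_metric_results[extractor]["Accuracy"].append(0.93)
entity_metric_results[extractor]["F1-score"].append(0.90)
entity_metric_results[extractor]["Precision"].append(0.92)

print({k: dict(v) for k, v in entity_metric_results.items()})
# {'DIETClassifier': {'Accuracy': [0.93], 'F1-score': [0.9], 'Precision': [0.92]}}
```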
def log_results(results: IntentMetrics, dataset_name: Text) -> None:
    """Logs results of cross validation.

    Args:
        results: dictionary of results returned from cross validation
        dataset_name: string of which dataset the results are from, e.g. test/train
    """
    for k, v in results.items():
        logger.info(f"{dataset_name} {k}: {np.mean(v):.3f} ({np.std(v):.3f})")


def log_entity_results(results: EntityMetrics, dataset_name: Text) -> None:
    """Logs entity results of cross validation.

    Args:
        results: dictionary of dictionaries of results returned from cross validation
        dataset_name: string of which dataset the results are from, e.g. test/train
    """
    for extractor, result in results.items():
        logger.info(f"Entity extractor: {extractor}")
        log_results(result, dataset_name)
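Both loggers expect the list-valued metric dictionaries built above; `log_entity_results` only adds the extractor dimension on top of `log_results`. A quick illustrative call (the metric values are made up; the commented output follows the f-strings in the functions above):

```python
import logging
logging.basicConfig(level=logging.INFO)

intent_metrics = {"Accuracy": [0.91, 0.87], "F1-score": [0.89, 0.85]}
entity_metrics = {"DIETClassifier": {"F1-score": [0.90, 0.88]}}

log_results(intent_metrics, "test")
# e.g. "test Accuracy: 0.890 (0.020)" and "test F1-score: 0.870 (0.020)"
log_entity_results(entity_metrics, "test")
# e.g. "Entity extractor: DIETClassifier" followed by "test F1-score: 0.890 (0.010)"
```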