rasa-pro 3.9.18__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +415 -0
- rasa/__init__.py +10 -0
- rasa/__main__.py +156 -0
- rasa/anonymization/__init__.py +2 -0
- rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
- rasa/anonymization/anonymization_pipeline.py +286 -0
- rasa/anonymization/anonymization_rule_executor.py +260 -0
- rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
- rasa/anonymization/schemas/config.yml +47 -0
- rasa/anonymization/utils.py +118 -0
- rasa/api.py +146 -0
- rasa/cli/__init__.py +5 -0
- rasa/cli/arguments/__init__.py +0 -0
- rasa/cli/arguments/data.py +81 -0
- rasa/cli/arguments/default_arguments.py +165 -0
- rasa/cli/arguments/evaluate.py +65 -0
- rasa/cli/arguments/export.py +51 -0
- rasa/cli/arguments/interactive.py +74 -0
- rasa/cli/arguments/run.py +204 -0
- rasa/cli/arguments/shell.py +13 -0
- rasa/cli/arguments/test.py +211 -0
- rasa/cli/arguments/train.py +263 -0
- rasa/cli/arguments/visualize.py +34 -0
- rasa/cli/arguments/x.py +30 -0
- rasa/cli/data.py +292 -0
- rasa/cli/e2e_test.py +586 -0
- rasa/cli/evaluate.py +222 -0
- rasa/cli/export.py +250 -0
- rasa/cli/inspect.py +63 -0
- rasa/cli/interactive.py +164 -0
- rasa/cli/license.py +65 -0
- rasa/cli/markers.py +78 -0
- rasa/cli/project_templates/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/__init__.py +0 -0
- rasa/cli/project_templates/calm/actions/action_template.py +27 -0
- rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
- rasa/cli/project_templates/calm/actions/db.py +57 -0
- rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
- rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
- rasa/cli/project_templates/calm/config.yml +12 -0
- rasa/cli/project_templates/calm/credentials.yml +33 -0
- rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
- rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
- rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
- rasa/cli/project_templates/calm/db/contacts.json +10 -0
- rasa/cli/project_templates/calm/domain/add_contact.yml +39 -0
- rasa/cli/project_templates/calm/domain/list_contacts.yml +17 -0
- rasa/cli/project_templates/calm/domain/remove_contact.yml +38 -0
- rasa/cli/project_templates/calm/domain/shared.yml +10 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
- rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
- rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
- rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
- rasa/cli/project_templates/calm/endpoints.yml +45 -0
- rasa/cli/project_templates/default/actions/__init__.py +0 -0
- rasa/cli/project_templates/default/actions/actions.py +27 -0
- rasa/cli/project_templates/default/config.yml +44 -0
- rasa/cli/project_templates/default/credentials.yml +33 -0
- rasa/cli/project_templates/default/data/nlu.yml +91 -0
- rasa/cli/project_templates/default/data/rules.yml +13 -0
- rasa/cli/project_templates/default/data/stories.yml +30 -0
- rasa/cli/project_templates/default/domain.yml +34 -0
- rasa/cli/project_templates/default/endpoints.yml +42 -0
- rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
- rasa/cli/project_templates/tutorial/actions.py +22 -0
- rasa/cli/project_templates/tutorial/config.yml +11 -0
- rasa/cli/project_templates/tutorial/credentials.yml +33 -0
- rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
- rasa/cli/project_templates/tutorial/data/patterns.yml +6 -0
- rasa/cli/project_templates/tutorial/domain.yml +21 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +45 -0
- rasa/cli/run.py +135 -0
- rasa/cli/scaffold.py +269 -0
- rasa/cli/shell.py +141 -0
- rasa/cli/studio/__init__.py +0 -0
- rasa/cli/studio/download.py +62 -0
- rasa/cli/studio/studio.py +266 -0
- rasa/cli/studio/train.py +59 -0
- rasa/cli/studio/upload.py +77 -0
- rasa/cli/telemetry.py +102 -0
- rasa/cli/test.py +280 -0
- rasa/cli/train.py +260 -0
- rasa/cli/utils.py +464 -0
- rasa/cli/visualize.py +40 -0
- rasa/cli/x.py +206 -0
- rasa/constants.py +37 -0
- rasa/core/__init__.py +17 -0
- rasa/core/actions/__init__.py +0 -0
- rasa/core/actions/action.py +1225 -0
- rasa/core/actions/action_clean_stack.py +59 -0
- rasa/core/actions/action_exceptions.py +24 -0
- rasa/core/actions/action_run_slot_rejections.py +207 -0
- rasa/core/actions/action_trigger_chitchat.py +31 -0
- rasa/core/actions/action_trigger_flow.py +109 -0
- rasa/core/actions/action_trigger_search.py +31 -0
- rasa/core/actions/constants.py +5 -0
- rasa/core/actions/custom_action_executor.py +188 -0
- rasa/core/actions/forms.py +741 -0
- rasa/core/actions/grpc_custom_action_executor.py +251 -0
- rasa/core/actions/http_custom_action_executor.py +140 -0
- rasa/core/actions/loops.py +114 -0
- rasa/core/actions/two_stage_fallback.py +186 -0
- rasa/core/agent.py +555 -0
- rasa/core/auth_retry_tracker_store.py +122 -0
- rasa/core/brokers/__init__.py +0 -0
- rasa/core/brokers/broker.py +126 -0
- rasa/core/brokers/file.py +58 -0
- rasa/core/brokers/kafka.py +322 -0
- rasa/core/brokers/pika.py +386 -0
- rasa/core/brokers/sql.py +86 -0
- rasa/core/channels/__init__.py +55 -0
- rasa/core/channels/audiocodes.py +463 -0
- rasa/core/channels/botframework.py +338 -0
- rasa/core/channels/callback.py +84 -0
- rasa/core/channels/channel.py +419 -0
- rasa/core/channels/console.py +241 -0
- rasa/core/channels/development_inspector.py +93 -0
- rasa/core/channels/facebook.py +419 -0
- rasa/core/channels/hangouts.py +329 -0
- rasa/core/channels/inspector/.eslintrc.cjs +25 -0
- rasa/core/channels/inspector/.gitignore +23 -0
- rasa/core/channels/inspector/README.md +54 -0
- rasa/core/channels/inspector/assets/favicon.ico +0 -0
- rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
- rasa/core/channels/inspector/custom.d.ts +3 -0
- rasa/core/channels/inspector/dist/assets/arc-b6e548fe.js +1 -0
- rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
- rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-fa03ac9e.js +10 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-ee67392a.js +2 -0
- rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-9b283fae.js +2 -0
- rasa/core/channels/inspector/dist/assets/createText-62fc7601-8b6fcc2a.js +7 -0
- rasa/core/channels/inspector/dist/assets/edges-f2ad444c-22e77f4f.js +4 -0
- rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-60ffc87f.js +51 -0
- rasa/core/channels/inspector/dist/assets/flowDb-1972c806-9dd802e4.js +6 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-5fa1912f.js +4 -0
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +1 -0
- rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-622a1fd2.js +139 -0
- rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-e285a63a.js +266 -0
- rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-f237bdca.js +70 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
- rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-4b03d70e.js +1 -0
- rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
- rasa/core/channels/inspector/dist/assets/index-a5d3e69d.js +1040 -0
- rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-72a0fa5f.js +7 -0
- rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
- rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-82218c41.js +139 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
- rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
- rasa/core/channels/inspector/dist/assets/layout-78cff630.js +1 -0
- rasa/core/channels/inspector/dist/assets/line-5038b469.js +1 -0
- rasa/core/channels/inspector/dist/assets/linear-c4fc4098.js +1 -0
- rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-c33c8ea6.js +109 -0
- rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
- rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
- rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-a8d03059.js +35 -0
- rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-6a0e56b2.js +7 -0
- rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-2dc7c7bd.js +52 -0
- rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-2360fe39.js +8 -0
- rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-41b9f9ad.js +122 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-0aad326f.js +1 -0
- rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-9847d984.js +1 -0
- rasa/core/channels/inspector/dist/assets/styles-080da4f6-564d890e.js +110 -0
- rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-38957613.js +159 -0
- rasa/core/channels/inspector/dist/assets/styles-9c745c82-f0fc6921.js +207 -0
- rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-ef3c5a77.js +1 -0
- rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-bf3e91c1.js +61 -0
- rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-4d4026c0.js +7 -0
- rasa/core/channels/inspector/dist/index.html +41 -0
- rasa/core/channels/inspector/index.html +39 -0
- rasa/core/channels/inspector/jest.config.ts +13 -0
- rasa/core/channels/inspector/package.json +48 -0
- rasa/core/channels/inspector/setupTests.ts +2 -0
- rasa/core/channels/inspector/src/App.tsx +170 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +107 -0
- rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +151 -0
- rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
- rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +19 -0
- rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
- rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
- rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
- rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
- rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +382 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +240 -0
- rasa/core/channels/inspector/src/helpers/utils.ts +42 -0
- rasa/core/channels/inspector/src/main.tsx +13 -0
- rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
- rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
- rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
- rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
- rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
- rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
- rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
- rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
- rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
- rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
- rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
- rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
- rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
- rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
- rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
- rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
- rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
- rasa/core/channels/inspector/src/theme/index.ts +101 -0
- rasa/core/channels/inspector/src/types.ts +64 -0
- rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
- rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
- rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
- rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
- rasa/core/channels/inspector/tsconfig.json +26 -0
- rasa/core/channels/inspector/tsconfig.node.json +10 -0
- rasa/core/channels/inspector/vite.config.ts +8 -0
- rasa/core/channels/inspector/yarn.lock +6156 -0
- rasa/core/channels/mattermost.py +229 -0
- rasa/core/channels/rasa_chat.py +126 -0
- rasa/core/channels/rest.py +225 -0
- rasa/core/channels/rocketchat.py +174 -0
- rasa/core/channels/slack.py +620 -0
- rasa/core/channels/socketio.py +274 -0
- rasa/core/channels/telegram.py +298 -0
- rasa/core/channels/twilio.py +169 -0
- rasa/core/channels/twilio_voice.py +367 -0
- rasa/core/channels/vier_cvg.py +374 -0
- rasa/core/channels/webexteams.py +134 -0
- rasa/core/concurrent_lock_store.py +210 -0
- rasa/core/constants.py +107 -0
- rasa/core/evaluation/__init__.py +0 -0
- rasa/core/evaluation/marker.py +267 -0
- rasa/core/evaluation/marker_base.py +923 -0
- rasa/core/evaluation/marker_stats.py +293 -0
- rasa/core/evaluation/marker_tracker_loader.py +103 -0
- rasa/core/exceptions.py +29 -0
- rasa/core/exporter.py +284 -0
- rasa/core/featurizers/__init__.py +0 -0
- rasa/core/featurizers/precomputation.py +410 -0
- rasa/core/featurizers/single_state_featurizer.py +421 -0
- rasa/core/featurizers/tracker_featurizers.py +1262 -0
- rasa/core/http_interpreter.py +89 -0
- rasa/core/information_retrieval/__init__.py +7 -0
- rasa/core/information_retrieval/faiss.py +121 -0
- rasa/core/information_retrieval/information_retrieval.py +129 -0
- rasa/core/information_retrieval/milvus.py +52 -0
- rasa/core/information_retrieval/qdrant.py +95 -0
- rasa/core/jobs.py +63 -0
- rasa/core/lock.py +139 -0
- rasa/core/lock_store.py +343 -0
- rasa/core/migrate.py +403 -0
- rasa/core/nlg/__init__.py +3 -0
- rasa/core/nlg/callback.py +146 -0
- rasa/core/nlg/contextual_response_rephraser.py +270 -0
- rasa/core/nlg/generator.py +230 -0
- rasa/core/nlg/interpolator.py +143 -0
- rasa/core/nlg/response.py +155 -0
- rasa/core/nlg/summarize.py +69 -0
- rasa/core/policies/__init__.py +0 -0
- rasa/core/policies/ensemble.py +329 -0
- rasa/core/policies/enterprise_search_policy.py +781 -0
- rasa/core/policies/enterprise_search_prompt_template.jinja2 +25 -0
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
- rasa/core/policies/flow_policy.py +205 -0
- rasa/core/policies/flows/__init__.py +0 -0
- rasa/core/policies/flows/flow_exceptions.py +44 -0
- rasa/core/policies/flows/flow_executor.py +705 -0
- rasa/core/policies/flows/flow_step_result.py +43 -0
- rasa/core/policies/intentless_policy.py +922 -0
- rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
- rasa/core/policies/memoization.py +538 -0
- rasa/core/policies/policy.py +725 -0
- rasa/core/policies/rule_policy.py +1273 -0
- rasa/core/policies/ted_policy.py +2169 -0
- rasa/core/policies/unexpected_intent_policy.py +1022 -0
- rasa/core/processor.py +1422 -0
- rasa/core/run.py +331 -0
- rasa/core/secrets_manager/__init__.py +0 -0
- rasa/core/secrets_manager/constants.py +32 -0
- rasa/core/secrets_manager/endpoints.py +391 -0
- rasa/core/secrets_manager/factory.py +233 -0
- rasa/core/secrets_manager/secret_manager.py +262 -0
- rasa/core/secrets_manager/vault.py +574 -0
- rasa/core/test.py +1335 -0
- rasa/core/tracker_store.py +1699 -0
- rasa/core/train.py +105 -0
- rasa/core/training/__init__.py +89 -0
- rasa/core/training/converters/__init__.py +0 -0
- rasa/core/training/converters/responses_prefix_converter.py +119 -0
- rasa/core/training/interactive.py +1745 -0
- rasa/core/training/story_conflict.py +381 -0
- rasa/core/training/training.py +93 -0
- rasa/core/utils.py +339 -0
- rasa/core/visualize.py +70 -0
- rasa/dialogue_understanding/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/__init__.py +0 -0
- rasa/dialogue_understanding/coexistence/constants.py +4 -0
- rasa/dialogue_understanding/coexistence/intent_based_router.py +196 -0
- rasa/dialogue_understanding/coexistence/llm_based_router.py +260 -0
- rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
- rasa/dialogue_understanding/commands/__init__.py +49 -0
- rasa/dialogue_understanding/commands/can_not_handle_command.py +70 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +125 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +44 -0
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/clarify_command.py +86 -0
- rasa/dialogue_understanding/commands/command.py +85 -0
- rasa/dialogue_understanding/commands/correct_slots_command.py +297 -0
- rasa/dialogue_understanding/commands/error_command.py +79 -0
- rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/handle_code_change_command.py +73 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +66 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +57 -0
- rasa/dialogue_understanding/commands/noop_command.py +54 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +160 -0
- rasa/dialogue_understanding/commands/skip_question_command.py +75 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +107 -0
- rasa/dialogue_understanding/generator/__init__.py +21 -0
- rasa/dialogue_understanding/generator/command_generator.py +343 -0
- rasa/dialogue_understanding/generator/constants.py +18 -0
- rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +412 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +467 -0
- rasa/dialogue_understanding/generator/llm_command_generator.py +67 -0
- rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
- rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +827 -0
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +218 -0
- rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +57 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +345 -0
- rasa/dialogue_understanding/patterns/__init__.py +0 -0
- rasa/dialogue_understanding/patterns/cancel.py +111 -0
- rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
- rasa/dialogue_understanding/patterns/chitchat.py +37 -0
- rasa/dialogue_understanding/patterns/clarify.py +97 -0
- rasa/dialogue_understanding/patterns/code_change.py +41 -0
- rasa/dialogue_understanding/patterns/collect_information.py +90 -0
- rasa/dialogue_understanding/patterns/completed.py +40 -0
- rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
- rasa/dialogue_understanding/patterns/correction.py +278 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +248 -0
- rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
- rasa/dialogue_understanding/patterns/internal_error.py +47 -0
- rasa/dialogue_understanding/patterns/search.py +37 -0
- rasa/dialogue_understanding/patterns/skip_question.py +38 -0
- rasa/dialogue_understanding/processor/__init__.py +0 -0
- rasa/dialogue_understanding/processor/command_processor.py +687 -0
- rasa/dialogue_understanding/processor/command_processor_component.py +39 -0
- rasa/dialogue_understanding/stack/__init__.py +0 -0
- rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
- rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
- rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
- rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
- rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
- rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
- rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
- rasa/dialogue_understanding/stack/utils.py +211 -0
- rasa/e2e_test/__init__.py +0 -0
- rasa/e2e_test/constants.py +11 -0
- rasa/e2e_test/e2e_test_case.py +366 -0
- rasa/e2e_test/e2e_test_result.py +34 -0
- rasa/e2e_test/e2e_test_runner.py +768 -0
- rasa/e2e_test/e2e_test_schema.yml +85 -0
- rasa/engine/__init__.py +0 -0
- rasa/engine/caching.py +463 -0
- rasa/engine/constants.py +17 -0
- rasa/engine/exceptions.py +14 -0
- rasa/engine/graph.py +637 -0
- rasa/engine/loader.py +36 -0
- rasa/engine/recipes/__init__.py +0 -0
- rasa/engine/recipes/config_files/default_config.yml +44 -0
- rasa/engine/recipes/default_components.py +99 -0
- rasa/engine/recipes/default_recipe.py +1251 -0
- rasa/engine/recipes/graph_recipe.py +79 -0
- rasa/engine/recipes/recipe.py +93 -0
- rasa/engine/runner/__init__.py +0 -0
- rasa/engine/runner/dask.py +250 -0
- rasa/engine/runner/interface.py +49 -0
- rasa/engine/storage/__init__.py +0 -0
- rasa/engine/storage/local_model_storage.py +246 -0
- rasa/engine/storage/resource.py +110 -0
- rasa/engine/storage/storage.py +203 -0
- rasa/engine/training/__init__.py +0 -0
- rasa/engine/training/components.py +176 -0
- rasa/engine/training/fingerprinting.py +64 -0
- rasa/engine/training/graph_trainer.py +256 -0
- rasa/engine/training/hooks.py +164 -0
- rasa/engine/validation.py +873 -0
- rasa/env.py +5 -0
- rasa/exceptions.py +69 -0
- rasa/graph_components/__init__.py +0 -0
- rasa/graph_components/converters/__init__.py +0 -0
- rasa/graph_components/converters/nlu_message_converter.py +48 -0
- rasa/graph_components/providers/__init__.py +0 -0
- rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
- rasa/graph_components/providers/domain_provider.py +71 -0
- rasa/graph_components/providers/flows_provider.py +74 -0
- rasa/graph_components/providers/forms_provider.py +44 -0
- rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
- rasa/graph_components/providers/responses_provider.py +44 -0
- rasa/graph_components/providers/rule_only_provider.py +49 -0
- rasa/graph_components/providers/story_graph_provider.py +43 -0
- rasa/graph_components/providers/training_tracker_provider.py +55 -0
- rasa/graph_components/validators/__init__.py +0 -0
- rasa/graph_components/validators/default_recipe_validator.py +550 -0
- rasa/graph_components/validators/finetuning_validator.py +302 -0
- rasa/hooks.py +112 -0
- rasa/jupyter.py +63 -0
- rasa/markers/__init__.py +0 -0
- rasa/markers/marker.py +269 -0
- rasa/markers/marker_base.py +828 -0
- rasa/markers/upload.py +74 -0
- rasa/markers/validate.py +21 -0
- rasa/model.py +118 -0
- rasa/model_testing.py +457 -0
- rasa/model_training.py +536 -0
- rasa/nlu/__init__.py +7 -0
- rasa/nlu/classifiers/__init__.py +3 -0
- rasa/nlu/classifiers/classifier.py +5 -0
- rasa/nlu/classifiers/diet_classifier.py +1881 -0
- rasa/nlu/classifiers/fallback_classifier.py +192 -0
- rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
- rasa/nlu/classifiers/llm_intent_classifier.py +519 -0
- rasa/nlu/classifiers/logistic_regression_classifier.py +253 -0
- rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
- rasa/nlu/classifiers/regex_message_handler.py +56 -0
- rasa/nlu/classifiers/sklearn_intent_classifier.py +330 -0
- rasa/nlu/constants.py +77 -0
- rasa/nlu/convert.py +40 -0
- rasa/nlu/emulators/__init__.py +0 -0
- rasa/nlu/emulators/dialogflow.py +55 -0
- rasa/nlu/emulators/emulator.py +49 -0
- rasa/nlu/emulators/luis.py +86 -0
- rasa/nlu/emulators/no_emulator.py +10 -0
- rasa/nlu/emulators/wit.py +56 -0
- rasa/nlu/extractors/__init__.py +0 -0
- rasa/nlu/extractors/crf_entity_extractor.py +715 -0
- rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
- rasa/nlu/extractors/entity_synonyms.py +178 -0
- rasa/nlu/extractors/extractor.py +470 -0
- rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
- rasa/nlu/extractors/regex_entity_extractor.py +220 -0
- rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
- rasa/nlu/featurizers/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +445 -0
- rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +768 -0
- rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
- rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
- rasa/nlu/featurizers/featurizer.py +89 -0
- rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +867 -0
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +571 -0
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +271 -0
- rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
- rasa/nlu/model.py +24 -0
- rasa/nlu/persistor.py +282 -0
- rasa/nlu/run.py +27 -0
- rasa/nlu/selectors/__init__.py +0 -0
- rasa/nlu/selectors/response_selector.py +987 -0
- rasa/nlu/test.py +1940 -0
- rasa/nlu/tokenizers/__init__.py +0 -0
- rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
- rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
- rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
- rasa/nlu/tokenizers/tokenizer.py +239 -0
- rasa/nlu/tokenizers/whitespace_tokenizer.py +106 -0
- rasa/nlu/utils/__init__.py +35 -0
- rasa/nlu/utils/bilou_utils.py +462 -0
- rasa/nlu/utils/hugging_face/__init__.py +0 -0
- rasa/nlu/utils/hugging_face/registry.py +108 -0
- rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
- rasa/nlu/utils/mitie_utils.py +113 -0
- rasa/nlu/utils/pattern_utils.py +168 -0
- rasa/nlu/utils/spacy_utils.py +310 -0
- rasa/plugin.py +90 -0
- rasa/server.py +1551 -0
- rasa/shared/__init__.py +0 -0
- rasa/shared/constants.py +192 -0
- rasa/shared/core/__init__.py +0 -0
- rasa/shared/core/command_payload_reader.py +109 -0
- rasa/shared/core/constants.py +167 -0
- rasa/shared/core/conversation.py +46 -0
- rasa/shared/core/domain.py +2107 -0
- rasa/shared/core/events.py +2504 -0
- rasa/shared/core/flows/__init__.py +7 -0
- rasa/shared/core/flows/flow.py +362 -0
- rasa/shared/core/flows/flow_step.py +146 -0
- rasa/shared/core/flows/flow_step_links.py +319 -0
- rasa/shared/core/flows/flow_step_sequence.py +70 -0
- rasa/shared/core/flows/flows_list.py +223 -0
- rasa/shared/core/flows/flows_yaml_schema.json +217 -0
- rasa/shared/core/flows/nlu_trigger.py +117 -0
- rasa/shared/core/flows/steps/__init__.py +24 -0
- rasa/shared/core/flows/steps/action.py +56 -0
- rasa/shared/core/flows/steps/call.py +64 -0
- rasa/shared/core/flows/steps/collect.py +112 -0
- rasa/shared/core/flows/steps/constants.py +5 -0
- rasa/shared/core/flows/steps/continuation.py +36 -0
- rasa/shared/core/flows/steps/end.py +22 -0
- rasa/shared/core/flows/steps/internal.py +44 -0
- rasa/shared/core/flows/steps/link.py +51 -0
- rasa/shared/core/flows/steps/no_operation.py +48 -0
- rasa/shared/core/flows/steps/set_slots.py +50 -0
- rasa/shared/core/flows/steps/start.py +30 -0
- rasa/shared/core/flows/validation.py +527 -0
- rasa/shared/core/flows/yaml_flows_io.py +278 -0
- rasa/shared/core/generator.py +908 -0
- rasa/shared/core/slot_mappings.py +526 -0
- rasa/shared/core/slots.py +649 -0
- rasa/shared/core/trackers.py +1177 -0
- rasa/shared/core/training_data/__init__.py +0 -0
- rasa/shared/core/training_data/loading.py +89 -0
- rasa/shared/core/training_data/story_reader/__init__.py +0 -0
- rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
- rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
- rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
- rasa/shared/core/training_data/story_writer/__init__.py +0 -0
- rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
- rasa/shared/core/training_data/story_writer/yaml_story_writer.py +444 -0
- rasa/shared/core/training_data/structures.py +838 -0
- rasa/shared/core/training_data/visualization.html +146 -0
- rasa/shared/core/training_data/visualization.py +603 -0
- rasa/shared/data.py +249 -0
- rasa/shared/engine/__init__.py +0 -0
- rasa/shared/engine/caching.py +26 -0
- rasa/shared/exceptions.py +163 -0
- rasa/shared/importers/__init__.py +0 -0
- rasa/shared/importers/importer.py +704 -0
- rasa/shared/importers/multi_project.py +203 -0
- rasa/shared/importers/rasa.py +99 -0
- rasa/shared/importers/utils.py +34 -0
- rasa/shared/nlu/__init__.py +0 -0
- rasa/shared/nlu/constants.py +47 -0
- rasa/shared/nlu/interpreter.py +10 -0
- rasa/shared/nlu/training_data/__init__.py +0 -0
- rasa/shared/nlu/training_data/entities_parser.py +208 -0
- rasa/shared/nlu/training_data/features.py +492 -0
- rasa/shared/nlu/training_data/formats/__init__.py +10 -0
- rasa/shared/nlu/training_data/formats/dialogflow.py +163 -0
- rasa/shared/nlu/training_data/formats/luis.py +87 -0
- rasa/shared/nlu/training_data/formats/rasa.py +135 -0
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +603 -0
- rasa/shared/nlu/training_data/formats/readerwriter.py +244 -0
- rasa/shared/nlu/training_data/formats/wit.py +52 -0
- rasa/shared/nlu/training_data/loading.py +137 -0
- rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
- rasa/shared/nlu/training_data/message.py +490 -0
- rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
- rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
- rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
- rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
- rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
- rasa/shared/nlu/training_data/training_data.py +730 -0
- rasa/shared/nlu/training_data/util.py +223 -0
- rasa/shared/providers/__init__.py +0 -0
- rasa/shared/providers/openai/__init__.py +0 -0
- rasa/shared/providers/openai/clients.py +43 -0
- rasa/shared/providers/openai/session_handler.py +110 -0
- rasa/shared/utils/__init__.py +0 -0
- rasa/shared/utils/cli.py +72 -0
- rasa/shared/utils/common.py +308 -0
- rasa/shared/utils/constants.py +4 -0
- rasa/shared/utils/io.py +415 -0
- rasa/shared/utils/llm.py +404 -0
- rasa/shared/utils/pykwalify_extensions.py +27 -0
- rasa/shared/utils/schemas/__init__.py +0 -0
- rasa/shared/utils/schemas/config.yml +2 -0
- rasa/shared/utils/schemas/domain.yml +145 -0
- rasa/shared/utils/schemas/events.py +212 -0
- rasa/shared/utils/schemas/model_config.yml +46 -0
- rasa/shared/utils/schemas/stories.yml +173 -0
- rasa/shared/utils/yaml.py +786 -0
- rasa/studio/__init__.py +0 -0
- rasa/studio/auth.py +268 -0
- rasa/studio/config.py +127 -0
- rasa/studio/constants.py +18 -0
- rasa/studio/data_handler.py +359 -0
- rasa/studio/download.py +483 -0
- rasa/studio/results_logger.py +137 -0
- rasa/studio/train.py +135 -0
- rasa/studio/upload.py +433 -0
- rasa/telemetry.py +1737 -0
- rasa/tracing/__init__.py +0 -0
- rasa/tracing/config.py +353 -0
- rasa/tracing/constants.py +62 -0
- rasa/tracing/instrumentation/__init__.py +0 -0
- rasa/tracing/instrumentation/attribute_extractors.py +672 -0
- rasa/tracing/instrumentation/instrumentation.py +1185 -0
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +144 -0
- rasa/tracing/instrumentation/metrics.py +294 -0
- rasa/tracing/metric_instrument_provider.py +205 -0
- rasa/utils/__init__.py +0 -0
- rasa/utils/beta.py +83 -0
- rasa/utils/cli.py +28 -0
- rasa/utils/common.py +635 -0
- rasa/utils/converter.py +53 -0
- rasa/utils/endpoints.py +302 -0
- rasa/utils/io.py +260 -0
- rasa/utils/licensing.py +534 -0
- rasa/utils/log_utils.py +174 -0
- rasa/utils/mapper.py +210 -0
- rasa/utils/ml_utils.py +145 -0
- rasa/utils/plotting.py +362 -0
- rasa/utils/singleton.py +23 -0
- rasa/utils/tensorflow/__init__.py +0 -0
- rasa/utils/tensorflow/callback.py +112 -0
- rasa/utils/tensorflow/constants.py +116 -0
- rasa/utils/tensorflow/crf.py +492 -0
- rasa/utils/tensorflow/data_generator.py +440 -0
- rasa/utils/tensorflow/environment.py +161 -0
- rasa/utils/tensorflow/exceptions.py +5 -0
- rasa/utils/tensorflow/feature_array.py +366 -0
- rasa/utils/tensorflow/layers.py +1565 -0
- rasa/utils/tensorflow/layers_utils.py +113 -0
- rasa/utils/tensorflow/metrics.py +281 -0
- rasa/utils/tensorflow/model_data.py +798 -0
- rasa/utils/tensorflow/model_data_utils.py +499 -0
- rasa/utils/tensorflow/models.py +935 -0
- rasa/utils/tensorflow/rasa_layers.py +1094 -0
- rasa/utils/tensorflow/transformer.py +640 -0
- rasa/utils/tensorflow/types.py +6 -0
- rasa/utils/train_utils.py +572 -0
- rasa/utils/url_tools.py +53 -0
- rasa/utils/yaml.py +54 -0
- rasa/validator.py +1337 -0
- rasa/version.py +3 -0
- rasa_pro-3.9.18.dist-info/METADATA +563 -0
- rasa_pro-3.9.18.dist-info/NOTICE +5 -0
- rasa_pro-3.9.18.dist-info/RECORD +662 -0
- rasa_pro-3.9.18.dist-info/WHEEL +4 -0
- rasa_pro-3.9.18.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,462 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import operator
|
|
3
|
+
from collections import defaultdict, Counter
|
|
4
|
+
from typing import List, Tuple, Text, Optional, Dict, Any, TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
from rasa.nlu.constants import (
|
|
7
|
+
TOKENS_NAMES,
|
|
8
|
+
BILOU_ENTITIES,
|
|
9
|
+
BILOU_ENTITIES_GROUP,
|
|
10
|
+
BILOU_ENTITIES_ROLE,
|
|
11
|
+
)
|
|
12
|
+
from rasa.shared.nlu.constants import (
|
|
13
|
+
TEXT,
|
|
14
|
+
ENTITIES,
|
|
15
|
+
ENTITY_ATTRIBUTE_START,
|
|
16
|
+
ENTITY_ATTRIBUTE_END,
|
|
17
|
+
ENTITY_ATTRIBUTE_TYPE,
|
|
18
|
+
ENTITY_ATTRIBUTE_GROUP,
|
|
19
|
+
ENTITY_ATTRIBUTE_ROLE,
|
|
20
|
+
NO_ENTITY_TAG,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from rasa.nlu.tokenizers.tokenizer import Token
|
|
25
|
+
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
26
|
+
from rasa.shared.nlu.training_data.message import Message
|
|
27
|
+
|
|
28
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Two-character prefixes of the BILOU tagging scheme: an entity spanning several
# tokens starts with a B- tag, continues with I- tags and is closed by an L- tag,
# while U- marks an entity that covers exactly one token.
BEGINNING, INSIDE, LAST, UNIT = "B-", "I-", "L-", "U-"

# Every known prefix. All of them are two characters long, which is what the
# `tag[:2]` slicing of the prefix helpers in this module relies on.
BILOU_PREFIXES = [BEGINNING, INSIDE, LAST, UNIT]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def bilou_prefix_from_tag(tag: Text) -> Optional[Text]:
    """Extract the BILOU prefix of a tag, if it has one.

    Args:
        tag: the tag to inspect

    Returns: the first two characters of the tag when they are one of the known
        BILOU prefixes, otherwise None
    """
    candidate = tag[:2]
    return candidate if candidate in BILOU_PREFIXES else None
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def tag_without_prefix(tag: Text) -> Text:
    """Strip a leading BILOU prefix from the given tag.

    Args:
        tag: the tag, with or without a BILOU prefix

    Returns: the tag minus its first two characters if those form a known BILOU
        prefix; the unchanged tag otherwise
    """
    has_prefix = tag[:2] in BILOU_PREFIXES
    return tag[2:] if has_prefix else tag
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def bilou_tags_to_ids(
    message: "Message",
    tag_id_dict: Dict[Text, int],
    tag_name: Text = ENTITY_ATTRIBUTE_TYPE,
) -> List[int]:
    """Translate the BILOU tags stored on a message into numeric ids.

    Args:
        message: the message whose tags are looked up
        tag_id_dict: mapping of tags to ids; it has to contain NO_ENTITY_TAG
            whenever the fallback id is needed
        tag_name: tag name of interest; selects which BILOU key is read

    Returns: one id per stored BILOU tag, where tags missing from the mapping get
        the id of NO_ENTITY_TAG. If nothing is stored under the BILOU key, the id
        of NO_ENTITY_TAG is returned once per text token.
    """
    bilou_key = get_bilou_key_for_tag(tag_name)

    if not message.get(bilou_key):
        # nothing tagged: every text token counts as "no entity"
        return [tag_id_dict[NO_ENTITY_TAG] for _ in message.get(TOKENS_NAMES[TEXT])]

    tag_ids = []
    for bilou_tag in message.get(bilou_key):
        # tags the mapping does not know are treated like the no-entity tag
        lookup_key = bilou_tag if bilou_tag in tag_id_dict else NO_ENTITY_TAG
        tag_ids.append(tag_id_dict[lookup_key])

    return tag_ids
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def get_bilou_key_for_tag(tag_name: Text) -> Text:
    """Look up the message key under which the BILOU tags of a tag name live.

    Args:
        tag_name: the tag name

    Returns:
        BILOU_ENTITIES_ROLE for the role attribute, BILOU_ENTITIES_GROUP for the
        group attribute and BILOU_ENTITIES for everything else
    """
    special_keys = (
        (ENTITY_ATTRIBUTE_ROLE, BILOU_ENTITIES_ROLE),
        (ENTITY_ATTRIBUTE_GROUP, BILOU_ENTITIES_GROUP),
    )
    for attribute, bilou_key in special_keys:
        if tag_name == attribute:
            return bilou_key

    # any other tag name shares the default entity key
    return BILOU_ENTITIES
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def build_tag_id_dict(
    training_data: "TrainingData", tag_name: Text = ENTITY_ATTRIBUTE_TYPE
) -> Optional[Dict[Text, int]]:
    """Build a mapping from every prefixed BILOU tag to a unique id.

    The distinct tags (without prefix) are collected from the BILOU tags stored on
    the NLU examples. Each tag is combined with every BILOU prefix; ids are
    assigned in sorted tag order, starting at 1.

    Args:
        training_data: the training data
        tag_name: tag name of interest

    Returns: a mapping of tags to ids, or None if the examples contain no tag
        other than NO_ENTITY_TAG
    """
    bilou_key = get_bilou_key_for_tag(tag_name)

    distinct_tags = set()
    for example in training_data.nlu_examples:
        if not example.get(bilou_key):
            continue
        for bilou_tag in example.get(bilou_key):
            distinct_tags.add(tag_without_prefix(bilou_tag))
    distinct_tags.discard(NO_ENTITY_TAG)

    if not distinct_tags:
        return None

    prefix_count = len(BILOU_PREFIXES)
    tag_id_dict = {}
    for tag_pos, tag in enumerate(sorted(distinct_tags)):
        for prefix_pos, prefix in enumerate(BILOU_PREFIXES):
            tag_id_dict[f"{prefix}{tag}"] = tag_pos * prefix_count + prefix_pos + 1

    # NO_ENTITY_TAG corresponds to non-entity which should correspond to 0 index
    # needed for correct prediction for padding
    tag_id_dict[NO_ENTITY_TAG] = 0

    return tag_id_dict
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def apply_bilou_schema(training_data: "TrainingData") -> None:
    """Compute BILOU entity tags for every NLU example and store them in place.

    Args:
        training_data: the training data whose examples get tagged
    """
    for nlu_example in training_data.nlu_examples:
        apply_bilou_schema_to_message(nlu_example)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def apply_bilou_schema_to_message(message: "Message") -> None:
    """Compute the BILOU tags of a single message and store them on it.

    For the entity type, role and group attributes a separate list of BILOU tags
    is derived from the message's entities and text tokens and set on the message
    under the matching BILOU key. Messages without entities are left untouched.

    Args:
        message: the message
    """
    if not message.get(ENTITIES):
        return

    tokens = message.get(TOKENS_NAMES[TEXT])

    message_key_per_attribute = (
        (ENTITY_ATTRIBUTE_TYPE, BILOU_ENTITIES),
        (ENTITY_ATTRIBUTE_ROLE, BILOU_ENTITIES_ROLE),
        (ENTITY_ATTRIBUTE_GROUP, BILOU_ENTITIES_GROUP),
    )
    for attribute, message_key in message_key_per_attribute:
        attribute_entities = map_message_entities(message, attribute)
        message.set(message_key, bilou_tags_from_offsets(tokens, attribute_entities))
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def map_message_entities(
    message: "Message", attribute_key: Text = ENTITY_ATTRIBUTE_TYPE
) -> List[Tuple[int, int, Text]]:
    """Reduce the entities of a message to (start, end, tag value) tuples.

    Args:
        message: the message
        attribute_key: entity key whose value is used as the tag value

    Returns: a list of start, end, and tag value tuples. Entities whose tag value
        is missing, empty or equal to NO_ENTITY_TAG are left out.
    """
    mapped = []

    for entity in message.get(ENTITIES, []):
        start = entity[ENTITY_ATTRIBUTE_START]
        end = entity[ENTITY_ATTRIBUTE_END]
        # a missing or falsy attribute value counts as "no entity"
        tag_value = entity.get(attribute_key) or NO_ENTITY_TAG

        if tag_value != NO_ENTITY_TAG:
            mapped.append((start, end, tag_value))

    return mapped
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def bilou_tags_from_offsets(
    tokens: List["Token"], entities: List[Tuple[int, int, Text]]
) -> List[Text]:
    """Creates BILOU tags for the given tokens and entities.

    Args:
        tokens: The list of tokens.
        entities: The list of start, end, and tag tuples.

    Returns:
        BILOU tags, one per token. Tokens that are not part of an entity keep
        NO_ENTITY_TAG.
    """
    # character offsets of each token -> position of that token in the list
    start_pos_to_token_idx = {}
    end_pos_to_token_idx = {}
    for token_idx, token in enumerate(tokens):
        start_pos_to_token_idx[token.start] = token_idx
        end_pos_to_token_idx[token.end] = token_idx

    bilou = [NO_ENTITY_TAG] * len(tokens)

    _add_bilou_tags_to_entities(
        bilou, entities, end_pos_to_token_idx, start_pos_to_token_idx
    )

    return bilou
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def _add_bilou_tags_to_entities(
    bilou: List[Text],
    entities: List[Tuple[int, int, Text]],
    end_pos_to_token_idx: Dict[int, int],
    start_pos_to_token_idx: Dict[int, int],
) -> None:
    """Write the BILOU tags of the given entities into `bilou` in place.

    Args:
        bilou: The per-token tag list that gets updated.
        entities: The list of start, end, and tag tuples.
        end_pos_to_token_idx: Maps a token end offset to the token's index.
        start_pos_to_token_idx: Maps a token start offset to the token's index.
    """
    for start_pos, end_pos, label in entities:
        first_idx = start_pos_to_token_idx.get(start_pos)
        last_idx = end_pos_to_token_idx.get(end_pos)

        # Only interested if the tokenization is correct, i.e. the entity
        # boundaries coincide with token boundaries
        if first_idx is None or last_idx is None:
            continue

        if first_idx == last_idx:
            # the entity covers exactly one token
            bilou[first_idx] = f"{UNIT}{label}"
            continue

        bilou[first_idx] = f"{BEGINNING}{label}"
        for inner_idx in range(first_idx + 1, last_idx):
            bilou[inner_idx] = f"{INSIDE}{label}"
        bilou[last_idx] = f"{LAST}{label}"
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def ensure_consistent_bilou_tagging(
    predicted_tags: List[Text], predicted_confidences: List[float]
) -> Tuple[List[Text], List[float]]:
    """Repair predicted tags so that they follow the BILOU tagging schema.

    Every B- tag is taken as the start of an entity. If the tags belonging to that
    entity disagree on the entity type, a single type is chosen based on the
    confidence values, and the confidences of the tags that change are updated.
    Afterwards the prefixes of the entity are rewritten to a valid sequence.
    For example, B-a I-b L-a is updated to B-a I-a L-a and B-a I-a O is changed to
    B-a L-a O.

    Both lists are modified in place and returned.

    Args:
        predicted_tags: predicted tags
        predicted_confidences: predicted confidences

    Return:
        List of tags.
        List of confidences.
    """
    for idx, predicted_tag in enumerate(predicted_tags):
        if bilou_prefix_from_tag(predicted_tag) != BEGINNING:
            # only B- tags open an entity that needs checking
            continue

        tag = tag_without_prefix(predicted_tag)
        last_idx = _find_bilou_end(idx, predicted_tags)
        entity_span = slice(idx, last_idx + 1)

        relevant_confidences = predicted_confidences[entity_span]
        relevant_tags = [
            tag_without_prefix(entity_tag) for entity_tag in predicted_tags[entity_span]
        ]

        # if not all tags are the same, for example, B-person I-person L-location
        # we need to check what tag we should use depending on the confidence
        # values and update the tags and confidences accordingly
        if any(entity_tag != relevant_tags[0] for entity_tag in relevant_tags):
            # decide which tag this entity should use
            tag, tag_score = _tag_to_use(relevant_tags, relevant_confidences)

            logger.debug(
                f"Using tag '{tag}' for entity with mixed tag labels "
                f"(original tags: {predicted_tags[idx : last_idx + 1]}, "
                f"(original confidences: "
                f"{predicted_confidences[idx : last_idx + 1]})."
            )

            # all tags that change get the score of that tag assigned
            predicted_confidences = _update_confidences(
                predicted_confidences, predicted_tags, tag, tag_score, idx, last_idx
            )

        # ensure correct BILOU annotations
        if last_idx == idx:
            predicted_tags[idx] = f"{UNIT}{tag}"
        else:
            predicted_tags[idx] = f"{BEGINNING}{tag}"
            for inner_idx in range(idx + 1, last_idx):
                predicted_tags[inner_idx] = f"{INSIDE}{tag}"
            predicted_tags[last_idx] = f"{LAST}{tag}"

    return predicted_tags, predicted_confidences
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def _tag_to_use(
    relevant_tags: List[Text], relevant_confidences: List[float]
) -> Tuple[Text, float]:
    """Pick the tag an entity with mixed tags should use.

    For every tag two metrics are computed: the average confidence of the tokens
    carrying that tag, and the share of the entity's tokens carrying that tag
    (rounded to two decimals). The score of a tag is the harmonic mean of both
    metrics; the tag with the highest score wins.

    Args:
        relevant_tags: The tags of the entity.
        relevant_confidences: The confidence values.

    Returns:
        The tag to use. The score of that tag.
    """
    avg_confidence_per_tag = _avg_confidence_per_tag(
        relevant_tags, relevant_confidences
    )
    token_count = len(relevant_tags)

    score_per_tag = {}
    for tag, count in Counter(relevant_tags).items():
        token_percentage = round(count / token_count, 2)
        avg_confidence = avg_confidence_per_tag[tag]
        # harmonic mean of average confidence and token share
        score_per_tag[tag] = (
            2
            * (avg_confidence * token_percentage)
            / (avg_confidence + token_percentage)
        )

    # the (tag, score) pair with the highest score
    return max(score_per_tag.items(), key=operator.itemgetter(1))
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def _update_confidences(
    predicted_confidences: List[float],
    predicted_tags: List[Text],
    tag: Text,
    score: float,
    idx: int,
    last_idx: int,
) -> List[float]:
    """Overwrite the confidences of the entity's tokens whose tag changes.

    A token between idx and last_idx (both inclusive) gets the score, rounded to
    two decimals, as its new confidence if its predicted tag (without prefix)
    differs from the tag chosen for the entity. The list is updated in place.

    Args:
        predicted_confidences: The list of predicted confidences.
        predicted_tags: The list of predicted tags.
        tag: The tag of the entity.
        score: The score value of that tag.
        idx: The start index of the entity.
        last_idx: The end index of the entity.

    Returns:
        The updated list of confidences.
    """
    for position in range(idx, last_idx + 1):
        if tag_without_prefix(predicted_tags[position]) != tag:
            predicted_confidences[position] = round(score, 2)

    return predicted_confidences
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def _avg_confidence_per_tag(
|
|
394
|
+
relevant_tags: List[Text], relevant_confidences: List[float]
|
|
395
|
+
) -> Dict[Text, float]:
|
|
396
|
+
confidences_per_tag = defaultdict(list)
|
|
397
|
+
|
|
398
|
+
for tag, confidence in zip(relevant_tags, relevant_confidences):
|
|
399
|
+
confidences_per_tag[tag].append(confidence)
|
|
400
|
+
|
|
401
|
+
avg_confidence_per_tag = {}
|
|
402
|
+
for tag, confidences in confidences_per_tag.items():
|
|
403
|
+
avg_confidence_per_tag[tag] = round(sum(confidences) / len(confidences), 2)
|
|
404
|
+
|
|
405
|
+
return avg_confidence_per_tag
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
def _find_bilou_end(start_idx: int, predicted_tags: List[Text]) -> int:
    """Locate the index of the last token of an entity.

    The start index is pointing to a B- tag. Following I- tags are part of the
    entity. The entity ends at the first L- tag; if a tag with any other prefix
    (or the end of the list) comes first, the entity ends at the tag before it.

    Args:
        start_idx: The start index of the entity
        predicted_tags: The list of predicted tags

    Returns:
        The end index of the entity
    """
    not_closed_message = (
        "Inconsistent BILOU tagging found, B- tag not closed by L- tag, "
        "i.e [B-a, I-a, O] instead of [B-a, L-a, O].\n"
        "Assuming last tag is L- instead of I-."
    )
    start_tag = tag_without_prefix(predicted_tags[start_idx])
    cursor = start_idx + 1

    while True:
        if cursor >= len(predicted_tags):
            # ran off the end of the list without seeing an L- tag
            logger.debug(not_closed_message)
            return cursor - 1

        current_label = predicted_tags[cursor]

        if tag_without_prefix(current_label) != start_tag:
            # words are not tagged the same entity class
            logger.debug(
                "Inconsistent BILOU tagging found, B- tag, L- tag pair encloses "
                "multiple entity classes.i.e. [B-a, I-b, L-a] instead of "
                "[B-a, I-a, L-a].\nAssuming B- class is correct."
            )

        prefix = bilou_prefix_from_tag(current_label)

        if prefix == LAST:
            return cursor

        if prefix != INSIDE:
            # entity not closed by an L- tag
            logger.debug(not_closed_message)
            return cursor - 1

        # middle part of the entity
        cursor += 1
|
|
File without changes
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Dict, Text, Type
|
|
3
|
+
|
|
4
|
+
# Explicitly set logging level for this module before any import
|
|
5
|
+
# because otherwise it logs tensorflow/pytorch versions
|
|
6
|
+
logging.getLogger("transformers.file_utils").setLevel(logging.WARNING)
|
|
7
|
+
|
|
8
|
+
from transformers import ( # noqa: E402
|
|
9
|
+
TFPreTrainedModel,
|
|
10
|
+
TFBertModel,
|
|
11
|
+
TFOpenAIGPTModel,
|
|
12
|
+
TFGPT2Model,
|
|
13
|
+
TFXLNetModel,
|
|
14
|
+
# TFXLMModel,
|
|
15
|
+
TFDistilBertModel,
|
|
16
|
+
TFRobertaModel,
|
|
17
|
+
TFCamembertModel,
|
|
18
|
+
PreTrainedTokenizer,
|
|
19
|
+
BertTokenizer,
|
|
20
|
+
OpenAIGPTTokenizer,
|
|
21
|
+
GPT2Tokenizer,
|
|
22
|
+
XLNetTokenizer,
|
|
23
|
+
# XLMTokenizer,
|
|
24
|
+
DistilBertTokenizer,
|
|
25
|
+
RobertaTokenizer,
|
|
26
|
+
CamembertTokenizer,
|
|
27
|
+
)
|
|
28
|
+
from rasa.nlu.utils.hugging_face.transformers_pre_post_processors import ( # noqa: E402
|
|
29
|
+
bert_tokens_pre_processor,
|
|
30
|
+
gpt_tokens_pre_processor,
|
|
31
|
+
xlnet_tokens_pre_processor,
|
|
32
|
+
roberta_tokens_pre_processor,
|
|
33
|
+
bert_embeddings_post_processor,
|
|
34
|
+
gpt_embeddings_post_processor,
|
|
35
|
+
xlnet_embeddings_post_processor,
|
|
36
|
+
roberta_embeddings_post_processor,
|
|
37
|
+
bert_tokens_cleaner,
|
|
38
|
+
openaigpt_tokens_cleaner,
|
|
39
|
+
gpt2_tokens_cleaner,
|
|
40
|
+
xlnet_tokens_cleaner,
|
|
41
|
+
camembert_tokens_pre_processor,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Registries keyed by model name. All of them offer the same set of model names;
# the "xlm" entries are commented out everywhere (see the note in
# `model_class_dict`).

# Model name -> TensorFlow model class from the `transformers` library.
model_class_dict: Dict[Text, Type[TFPreTrainedModel]] = {
    "bert": TFBertModel,
    "gpt": TFOpenAIGPTModel,
    "gpt2": TFGPT2Model,
    "xlnet": TFXLNetModel,
    # "xlm": TFXLMModel,  # Currently doesn't work because of a bug in transformers
    # library https://github.com/huggingface/transformers/issues/2729
    "distilbert": TFDistilBertModel,
    "roberta": TFRobertaModel,
    "camembert": TFCamembertModel,
}
# Model name -> tokenizer class from the `transformers` library.
model_tokenizer_dict: Dict[Text, Type[PreTrainedTokenizer]] = {
    "bert": BertTokenizer,
    "gpt": OpenAIGPTTokenizer,
    "gpt2": GPT2Tokenizer,
    "xlnet": XLNetTokenizer,
    # "xlm": XLMTokenizer,
    "distilbert": DistilBertTokenizer,
    "roberta": RobertaTokenizer,
    "camembert": CamembertTokenizer,
}
# Model name -> default weights identifier for that model.
model_weights_defaults: Dict[Text, Text] = {
    "bert": "rasa/LaBSE",
    "gpt": "openai-gpt",
    "gpt2": "gpt2",
    "xlnet": "xlnet-base-cased",
    # "xlm": "xlm-mlm-enfr-1024",
    "distilbert": "distilbert-base-uncased",
    "roberta": "roberta-base",
    "camembert": "camembert-base",
}

# Model name -> token pre-processor function. "gpt" and "gpt2" share one
# function, and "distilbert" reuses the BERT one.
model_special_tokens_pre_processors = {
    "bert": bert_tokens_pre_processor,
    "gpt": gpt_tokens_pre_processor,
    "gpt2": gpt_tokens_pre_processor,
    "xlnet": xlnet_tokens_pre_processor,
    # "xlm": xlm_tokens_pre_processor,
    "distilbert": bert_tokens_pre_processor,
    "roberta": roberta_tokens_pre_processor,
    "camembert": camembert_tokens_pre_processor,
}

# Model name -> token cleaner function. Several models reuse the cleaner of
# another model, as noted per entry.
model_tokens_cleaners = {
    "bert": bert_tokens_cleaner,
    "gpt": openaigpt_tokens_cleaner,
    "gpt2": gpt2_tokens_cleaner,
    "xlnet": xlnet_tokens_cleaner,
    # "xlm": xlm_tokens_pre_processor,
    "distilbert": bert_tokens_cleaner,  # uses the same as BERT
    "roberta": gpt2_tokens_cleaner,  # Uses the same as GPT2
    "camembert": xlnet_tokens_cleaner,  # Removing underscores _
}

# Model name -> embeddings post-processor function. "gpt"/"gpt2" share one
# function, "distilbert" reuses the BERT one and "camembert" the RoBERTa one.
model_embeddings_post_processors = {
    "bert": bert_embeddings_post_processor,
    "gpt": gpt_embeddings_post_processor,
    "gpt2": gpt_embeddings_post_processor,
    "xlnet": xlnet_embeddings_post_processor,
    # "xlm": xlm_embeddings_post_processor,
    "distilbert": bert_embeddings_post_processor,
    "roberta": roberta_embeddings_post_processor,
    "camembert": roberta_embeddings_post_processor,
}
|