rasa-pro 3.12.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. See the package's registry page for more details.

Files changed (790)
  1. README.md +41 -0
  2. rasa/__init__.py +9 -0
  3. rasa/__main__.py +177 -0
  4. rasa/anonymization/__init__.py +2 -0
  5. rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
  6. rasa/anonymization/anonymization_pipeline.py +286 -0
  7. rasa/anonymization/anonymization_rule_executor.py +260 -0
  8. rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
  9. rasa/anonymization/schemas/config.yml +47 -0
  10. rasa/anonymization/utils.py +118 -0
  11. rasa/api.py +160 -0
  12. rasa/cli/__init__.py +5 -0
  13. rasa/cli/arguments/__init__.py +0 -0
  14. rasa/cli/arguments/data.py +106 -0
  15. rasa/cli/arguments/default_arguments.py +207 -0
  16. rasa/cli/arguments/evaluate.py +65 -0
  17. rasa/cli/arguments/export.py +51 -0
  18. rasa/cli/arguments/interactive.py +74 -0
  19. rasa/cli/arguments/run.py +219 -0
  20. rasa/cli/arguments/shell.py +17 -0
  21. rasa/cli/arguments/test.py +211 -0
  22. rasa/cli/arguments/train.py +279 -0
  23. rasa/cli/arguments/visualize.py +34 -0
  24. rasa/cli/arguments/x.py +30 -0
  25. rasa/cli/data.py +354 -0
  26. rasa/cli/dialogue_understanding_test.py +251 -0
  27. rasa/cli/e2e_test.py +259 -0
  28. rasa/cli/evaluate.py +222 -0
  29. rasa/cli/export.py +250 -0
  30. rasa/cli/inspect.py +75 -0
  31. rasa/cli/interactive.py +166 -0
  32. rasa/cli/license.py +65 -0
  33. rasa/cli/llm_fine_tuning.py +403 -0
  34. rasa/cli/markers.py +78 -0
  35. rasa/cli/project_templates/__init__.py +0 -0
  36. rasa/cli/project_templates/calm/actions/__init__.py +0 -0
  37. rasa/cli/project_templates/calm/actions/action_template.py +27 -0
  38. rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
  39. rasa/cli/project_templates/calm/actions/db.py +57 -0
  40. rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
  41. rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
  42. rasa/cli/project_templates/calm/config.yml +10 -0
  43. rasa/cli/project_templates/calm/credentials.yml +33 -0
  44. rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
  45. rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
  46. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
  47. rasa/cli/project_templates/calm/db/contacts.json +10 -0
  48. rasa/cli/project_templates/calm/domain/add_contact.yml +39 -0
  49. rasa/cli/project_templates/calm/domain/list_contacts.yml +17 -0
  50. rasa/cli/project_templates/calm/domain/remove_contact.yml +38 -0
  51. rasa/cli/project_templates/calm/domain/shared.yml +10 -0
  52. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
  53. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
  54. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
  55. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
  56. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
  57. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
  58. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
  59. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
  60. rasa/cli/project_templates/calm/endpoints.yml +58 -0
  61. rasa/cli/project_templates/default/actions/__init__.py +0 -0
  62. rasa/cli/project_templates/default/actions/actions.py +27 -0
  63. rasa/cli/project_templates/default/config.yml +44 -0
  64. rasa/cli/project_templates/default/credentials.yml +33 -0
  65. rasa/cli/project_templates/default/data/nlu.yml +91 -0
  66. rasa/cli/project_templates/default/data/rules.yml +13 -0
  67. rasa/cli/project_templates/default/data/stories.yml +30 -0
  68. rasa/cli/project_templates/default/domain.yml +34 -0
  69. rasa/cli/project_templates/default/endpoints.yml +42 -0
  70. rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
  71. rasa/cli/project_templates/tutorial/actions/__init__.py +0 -0
  72. rasa/cli/project_templates/tutorial/actions/actions.py +22 -0
  73. rasa/cli/project_templates/tutorial/config.yml +12 -0
  74. rasa/cli/project_templates/tutorial/credentials.yml +33 -0
  75. rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
  76. rasa/cli/project_templates/tutorial/data/patterns.yml +11 -0
  77. rasa/cli/project_templates/tutorial/domain.yml +35 -0
  78. rasa/cli/project_templates/tutorial/endpoints.yml +55 -0
  79. rasa/cli/run.py +143 -0
  80. rasa/cli/scaffold.py +273 -0
  81. rasa/cli/shell.py +141 -0
  82. rasa/cli/studio/__init__.py +0 -0
  83. rasa/cli/studio/download.py +62 -0
  84. rasa/cli/studio/studio.py +296 -0
  85. rasa/cli/studio/train.py +59 -0
  86. rasa/cli/studio/upload.py +62 -0
  87. rasa/cli/telemetry.py +102 -0
  88. rasa/cli/test.py +280 -0
  89. rasa/cli/train.py +278 -0
  90. rasa/cli/utils.py +484 -0
  91. rasa/cli/visualize.py +40 -0
  92. rasa/cli/x.py +206 -0
  93. rasa/constants.py +45 -0
  94. rasa/core/__init__.py +17 -0
  95. rasa/core/actions/__init__.py +0 -0
  96. rasa/core/actions/action.py +1318 -0
  97. rasa/core/actions/action_clean_stack.py +59 -0
  98. rasa/core/actions/action_exceptions.py +24 -0
  99. rasa/core/actions/action_hangup.py +29 -0
  100. rasa/core/actions/action_repeat_bot_messages.py +89 -0
  101. rasa/core/actions/action_run_slot_rejections.py +210 -0
  102. rasa/core/actions/action_trigger_chitchat.py +31 -0
  103. rasa/core/actions/action_trigger_flow.py +109 -0
  104. rasa/core/actions/action_trigger_search.py +31 -0
  105. rasa/core/actions/constants.py +5 -0
  106. rasa/core/actions/custom_action_executor.py +191 -0
  107. rasa/core/actions/direct_custom_actions_executor.py +109 -0
  108. rasa/core/actions/e2e_stub_custom_action_executor.py +72 -0
  109. rasa/core/actions/forms.py +741 -0
  110. rasa/core/actions/grpc_custom_action_executor.py +251 -0
  111. rasa/core/actions/http_custom_action_executor.py +145 -0
  112. rasa/core/actions/loops.py +114 -0
  113. rasa/core/actions/two_stage_fallback.py +186 -0
  114. rasa/core/agent.py +559 -0
  115. rasa/core/auth_retry_tracker_store.py +122 -0
  116. rasa/core/brokers/__init__.py +0 -0
  117. rasa/core/brokers/broker.py +126 -0
  118. rasa/core/brokers/file.py +58 -0
  119. rasa/core/brokers/kafka.py +324 -0
  120. rasa/core/brokers/pika.py +388 -0
  121. rasa/core/brokers/sql.py +86 -0
  122. rasa/core/channels/__init__.py +61 -0
  123. rasa/core/channels/botframework.py +338 -0
  124. rasa/core/channels/callback.py +84 -0
  125. rasa/core/channels/channel.py +456 -0
  126. rasa/core/channels/console.py +241 -0
  127. rasa/core/channels/development_inspector.py +197 -0
  128. rasa/core/channels/facebook.py +419 -0
  129. rasa/core/channels/hangouts.py +329 -0
  130. rasa/core/channels/inspector/.eslintrc.cjs +25 -0
  131. rasa/core/channels/inspector/.gitignore +23 -0
  132. rasa/core/channels/inspector/README.md +54 -0
  133. rasa/core/channels/inspector/assets/favicon.ico +0 -0
  134. rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
  135. rasa/core/channels/inspector/custom.d.ts +3 -0
  136. rasa/core/channels/inspector/dist/assets/arc-861ddd57.js +1 -0
  137. rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
  138. rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-921f02db.js +10 -0
  139. rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-b436c4f8.js +2 -0
  140. rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-511a23cb.js +2 -0
  141. rasa/core/channels/inspector/dist/assets/createText-62fc7601-ef476ecd.js +7 -0
  142. rasa/core/channels/inspector/dist/assets/edges-f2ad444c-f1878e0a.js +4 -0
  143. rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-fac75185.js +51 -0
  144. rasa/core/channels/inspector/dist/assets/flowDb-1972c806-201c5bbc.js +6 -0
  145. rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-f904ae41.js +4 -0
  146. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-b080d6f2.js +1 -0
  147. rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-1813da66.js +139 -0
  148. rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-872af172.js +266 -0
  149. rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-34a0af5a.js +70 -0
  150. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
  151. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
  152. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
  153. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
  154. rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-42ba3e3d.js +1 -0
  155. rasa/core/channels/inspector/dist/assets/index-37817b51.js +1317 -0
  156. rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
  157. rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-6b731386.js +7 -0
  158. rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
  159. rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-e8579ac6.js +139 -0
  160. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
  161. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
  162. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
  163. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
  164. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
  165. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
  166. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
  167. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
  168. rasa/core/channels/inspector/dist/assets/layout-89e6403a.js +1 -0
  169. rasa/core/channels/inspector/dist/assets/line-dc73d3fc.js +1 -0
  170. rasa/core/channels/inspector/dist/assets/linear-f5b1d2bc.js +1 -0
  171. rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-82cb74fa.js +109 -0
  172. rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
  173. rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
  174. rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-bdf5f29b.js +35 -0
  175. rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-c7a0cbe4.js +7 -0
  176. rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-7ec5410f.js +52 -0
  177. rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-caee5554.js +8 -0
  178. rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-2935f8db.js +122 -0
  179. rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-8f5d9693.js +1 -0
  180. rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-d565d1de.js +1 -0
  181. rasa/core/channels/inspector/dist/assets/styles-080da4f6-75ad421d.js +110 -0
  182. rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-7e764226.js +159 -0
  183. rasa/core/channels/inspector/dist/assets/styles-9c745c82-7a4e0e61.js +207 -0
  184. rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-4019d1bf.js +1 -0
  185. rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-01ea12df.js +61 -0
  186. rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-89407137.js +7 -0
  187. rasa/core/channels/inspector/dist/index.html +42 -0
  188. rasa/core/channels/inspector/index.html +40 -0
  189. rasa/core/channels/inspector/jest.config.ts +13 -0
  190. rasa/core/channels/inspector/package.json +52 -0
  191. rasa/core/channels/inspector/setupTests.ts +2 -0
  192. rasa/core/channels/inspector/src/App.tsx +220 -0
  193. rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
  194. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +108 -0
  195. rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
  196. rasa/core/channels/inspector/src/components/DialogueStack.tsx +136 -0
  197. rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
  198. rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
  199. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +22 -0
  200. rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
  201. rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
  202. rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
  203. rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
  204. rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
  205. rasa/core/channels/inspector/src/helpers/audiostream.ts +191 -0
  206. rasa/core/channels/inspector/src/helpers/formatters.test.ts +392 -0
  207. rasa/core/channels/inspector/src/helpers/formatters.ts +306 -0
  208. rasa/core/channels/inspector/src/helpers/utils.ts +127 -0
  209. rasa/core/channels/inspector/src/main.tsx +13 -0
  210. rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
  211. rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
  212. rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
  213. rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
  214. rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
  215. rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
  216. rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
  217. rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
  218. rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
  219. rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
  220. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
  221. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
  222. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
  223. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
  224. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
  225. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
  226. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
  227. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
  228. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
  229. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
  230. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
  231. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
  232. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
  233. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
  234. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
  235. rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
  236. rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
  237. rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
  238. rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
  239. rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
  240. rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
  241. rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
  242. rasa/core/channels/inspector/src/theme/index.ts +101 -0
  243. rasa/core/channels/inspector/src/types.ts +84 -0
  244. rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
  245. rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
  246. rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
  247. rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
  248. rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
  249. rasa/core/channels/inspector/tsconfig.json +26 -0
  250. rasa/core/channels/inspector/tsconfig.node.json +10 -0
  251. rasa/core/channels/inspector/vite.config.ts +8 -0
  252. rasa/core/channels/inspector/yarn.lock +6249 -0
  253. rasa/core/channels/mattermost.py +229 -0
  254. rasa/core/channels/rasa_chat.py +126 -0
  255. rasa/core/channels/rest.py +230 -0
  256. rasa/core/channels/rocketchat.py +174 -0
  257. rasa/core/channels/slack.py +620 -0
  258. rasa/core/channels/socketio.py +302 -0
  259. rasa/core/channels/telegram.py +298 -0
  260. rasa/core/channels/twilio.py +169 -0
  261. rasa/core/channels/vier_cvg.py +374 -0
  262. rasa/core/channels/voice_ready/__init__.py +0 -0
  263. rasa/core/channels/voice_ready/audiocodes.py +501 -0
  264. rasa/core/channels/voice_ready/jambonz.py +121 -0
  265. rasa/core/channels/voice_ready/jambonz_protocol.py +396 -0
  266. rasa/core/channels/voice_ready/twilio_voice.py +403 -0
  267. rasa/core/channels/voice_ready/utils.py +37 -0
  268. rasa/core/channels/voice_stream/__init__.py +0 -0
  269. rasa/core/channels/voice_stream/asr/__init__.py +0 -0
  270. rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
  271. rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
  272. rasa/core/channels/voice_stream/asr/azure.py +130 -0
  273. rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
  274. rasa/core/channels/voice_stream/audio_bytes.py +8 -0
  275. rasa/core/channels/voice_stream/browser_audio.py +107 -0
  276. rasa/core/channels/voice_stream/call_state.py +23 -0
  277. rasa/core/channels/voice_stream/tts/__init__.py +0 -0
  278. rasa/core/channels/voice_stream/tts/azure.py +106 -0
  279. rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
  280. rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
  281. rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
  282. rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
  283. rasa/core/channels/voice_stream/util.py +57 -0
  284. rasa/core/channels/voice_stream/voice_channel.py +427 -0
  285. rasa/core/channels/webexteams.py +134 -0
  286. rasa/core/concurrent_lock_store.py +210 -0
  287. rasa/core/constants.py +112 -0
  288. rasa/core/evaluation/__init__.py +0 -0
  289. rasa/core/evaluation/marker.py +267 -0
  290. rasa/core/evaluation/marker_base.py +923 -0
  291. rasa/core/evaluation/marker_stats.py +293 -0
  292. rasa/core/evaluation/marker_tracker_loader.py +103 -0
  293. rasa/core/exceptions.py +29 -0
  294. rasa/core/exporter.py +284 -0
  295. rasa/core/featurizers/__init__.py +0 -0
  296. rasa/core/featurizers/precomputation.py +410 -0
  297. rasa/core/featurizers/single_state_featurizer.py +421 -0
  298. rasa/core/featurizers/tracker_featurizers.py +1262 -0
  299. rasa/core/http_interpreter.py +89 -0
  300. rasa/core/information_retrieval/__init__.py +7 -0
  301. rasa/core/information_retrieval/faiss.py +124 -0
  302. rasa/core/information_retrieval/information_retrieval.py +137 -0
  303. rasa/core/information_retrieval/milvus.py +59 -0
  304. rasa/core/information_retrieval/qdrant.py +96 -0
  305. rasa/core/jobs.py +63 -0
  306. rasa/core/lock.py +139 -0
  307. rasa/core/lock_store.py +343 -0
  308. rasa/core/migrate.py +403 -0
  309. rasa/core/nlg/__init__.py +3 -0
  310. rasa/core/nlg/callback.py +146 -0
  311. rasa/core/nlg/contextual_response_rephraser.py +320 -0
  312. rasa/core/nlg/generator.py +230 -0
  313. rasa/core/nlg/interpolator.py +143 -0
  314. rasa/core/nlg/response.py +155 -0
  315. rasa/core/nlg/summarize.py +70 -0
  316. rasa/core/persistor.py +538 -0
  317. rasa/core/policies/__init__.py +0 -0
  318. rasa/core/policies/ensemble.py +329 -0
  319. rasa/core/policies/enterprise_search_policy.py +905 -0
  320. rasa/core/policies/enterprise_search_prompt_template.jinja2 +25 -0
  321. rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
  322. rasa/core/policies/flow_policy.py +205 -0
  323. rasa/core/policies/flows/__init__.py +0 -0
  324. rasa/core/policies/flows/flow_exceptions.py +44 -0
  325. rasa/core/policies/flows/flow_executor.py +754 -0
  326. rasa/core/policies/flows/flow_step_result.py +43 -0
  327. rasa/core/policies/intentless_policy.py +1031 -0
  328. rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
  329. rasa/core/policies/memoization.py +538 -0
  330. rasa/core/policies/policy.py +725 -0
  331. rasa/core/policies/rule_policy.py +1273 -0
  332. rasa/core/policies/ted_policy.py +2169 -0
  333. rasa/core/policies/unexpected_intent_policy.py +1022 -0
  334. rasa/core/processor.py +1465 -0
  335. rasa/core/run.py +342 -0
  336. rasa/core/secrets_manager/__init__.py +0 -0
  337. rasa/core/secrets_manager/constants.py +36 -0
  338. rasa/core/secrets_manager/endpoints.py +391 -0
  339. rasa/core/secrets_manager/factory.py +241 -0
  340. rasa/core/secrets_manager/secret_manager.py +262 -0
  341. rasa/core/secrets_manager/vault.py +584 -0
  342. rasa/core/test.py +1335 -0
  343. rasa/core/tracker_store.py +1703 -0
  344. rasa/core/train.py +105 -0
  345. rasa/core/training/__init__.py +89 -0
  346. rasa/core/training/converters/__init__.py +0 -0
  347. rasa/core/training/converters/responses_prefix_converter.py +119 -0
  348. rasa/core/training/interactive.py +1744 -0
  349. rasa/core/training/story_conflict.py +381 -0
  350. rasa/core/training/training.py +93 -0
  351. rasa/core/utils.py +366 -0
  352. rasa/core/visualize.py +70 -0
  353. rasa/dialogue_understanding/__init__.py +0 -0
  354. rasa/dialogue_understanding/coexistence/__init__.py +0 -0
  355. rasa/dialogue_understanding/coexistence/constants.py +4 -0
  356. rasa/dialogue_understanding/coexistence/intent_based_router.py +196 -0
  357. rasa/dialogue_understanding/coexistence/llm_based_router.py +327 -0
  358. rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
  359. rasa/dialogue_understanding/commands/__init__.py +61 -0
  360. rasa/dialogue_understanding/commands/can_not_handle_command.py +70 -0
  361. rasa/dialogue_understanding/commands/cancel_flow_command.py +125 -0
  362. rasa/dialogue_understanding/commands/change_flow_command.py +44 -0
  363. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +57 -0
  364. rasa/dialogue_understanding/commands/clarify_command.py +86 -0
  365. rasa/dialogue_understanding/commands/command.py +85 -0
  366. rasa/dialogue_understanding/commands/correct_slots_command.py +297 -0
  367. rasa/dialogue_understanding/commands/error_command.py +79 -0
  368. rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
  369. rasa/dialogue_understanding/commands/handle_code_change_command.py +73 -0
  370. rasa/dialogue_understanding/commands/human_handoff_command.py +66 -0
  371. rasa/dialogue_understanding/commands/knowledge_answer_command.py +57 -0
  372. rasa/dialogue_understanding/commands/noop_command.py +54 -0
  373. rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
  374. rasa/dialogue_understanding/commands/restart_command.py +58 -0
  375. rasa/dialogue_understanding/commands/session_end_command.py +61 -0
  376. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  377. rasa/dialogue_understanding/commands/set_slot_command.py +160 -0
  378. rasa/dialogue_understanding/commands/skip_question_command.py +75 -0
  379. rasa/dialogue_understanding/commands/start_flow_command.py +107 -0
  380. rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
  381. rasa/dialogue_understanding/commands/utils.py +45 -0
  382. rasa/dialogue_understanding/generator/__init__.py +21 -0
  383. rasa/dialogue_understanding/generator/command_generator.py +464 -0
  384. rasa/dialogue_understanding/generator/constants.py +27 -0
  385. rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
  386. rasa/dialogue_understanding/generator/flow_retrieval.py +466 -0
  387. rasa/dialogue_understanding/generator/llm_based_command_generator.py +500 -0
  388. rasa/dialogue_understanding/generator/llm_command_generator.py +67 -0
  389. rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
  390. rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
  391. rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
  392. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +920 -0
  393. rasa/dialogue_understanding/generator/nlu_command_adapter.py +261 -0
  394. rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
  395. rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +60 -0
  396. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +486 -0
  397. rasa/dialogue_understanding/patterns/__init__.py +0 -0
  398. rasa/dialogue_understanding/patterns/cancel.py +111 -0
  399. rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
  400. rasa/dialogue_understanding/patterns/chitchat.py +37 -0
  401. rasa/dialogue_understanding/patterns/clarify.py +97 -0
  402. rasa/dialogue_understanding/patterns/code_change.py +41 -0
  403. rasa/dialogue_understanding/patterns/collect_information.py +90 -0
  404. rasa/dialogue_understanding/patterns/completed.py +40 -0
  405. rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
  406. rasa/dialogue_understanding/patterns/correction.py +278 -0
  407. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +301 -0
  408. rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
  409. rasa/dialogue_understanding/patterns/internal_error.py +47 -0
  410. rasa/dialogue_understanding/patterns/repeat.py +37 -0
  411. rasa/dialogue_understanding/patterns/restart.py +37 -0
  412. rasa/dialogue_understanding/patterns/search.py +37 -0
  413. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  414. rasa/dialogue_understanding/patterns/skip_question.py +38 -0
  415. rasa/dialogue_understanding/patterns/user_silence.py +37 -0
  416. rasa/dialogue_understanding/processor/__init__.py +0 -0
  417. rasa/dialogue_understanding/processor/command_processor.py +720 -0
  418. rasa/dialogue_understanding/processor/command_processor_component.py +43 -0
  419. rasa/dialogue_understanding/stack/__init__.py +0 -0
  420. rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
  421. rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
  422. rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
  423. rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
  424. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
  425. rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
  426. rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
  427. rasa/dialogue_understanding/stack/utils.py +211 -0
  428. rasa/dialogue_understanding/utils.py +14 -0
  429. rasa/dialogue_understanding_test/__init__.py +0 -0
  430. rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
  431. rasa/dialogue_understanding_test/constants.py +17 -0
  432. rasa/dialogue_understanding_test/du_test_case.py +118 -0
  433. rasa/dialogue_understanding_test/du_test_result.py +11 -0
  434. rasa/dialogue_understanding_test/du_test_runner.py +93 -0
  435. rasa/dialogue_understanding_test/io.py +54 -0
  436. rasa/dialogue_understanding_test/validation.py +22 -0
  437. rasa/e2e_test/__init__.py +0 -0
  438. rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
  439. rasa/e2e_test/assertions.py +1345 -0
  440. rasa/e2e_test/assertions_schema.yml +129 -0
  441. rasa/e2e_test/constants.py +31 -0
  442. rasa/e2e_test/e2e_config.py +220 -0
  443. rasa/e2e_test/e2e_config_schema.yml +26 -0
  444. rasa/e2e_test/e2e_test_case.py +569 -0
  445. rasa/e2e_test/e2e_test_converter.py +363 -0
  446. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  447. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  448. rasa/e2e_test/e2e_test_result.py +54 -0
  449. rasa/e2e_test/e2e_test_runner.py +1192 -0
  450. rasa/e2e_test/e2e_test_schema.yml +181 -0
  451. rasa/e2e_test/pykwalify_extensions.py +39 -0
  452. rasa/e2e_test/stub_custom_action.py +70 -0
  453. rasa/e2e_test/utils/__init__.py +0 -0
  454. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  455. rasa/e2e_test/utils/io.py +598 -0
  456. rasa/e2e_test/utils/validation.py +178 -0
  457. rasa/engine/__init__.py +0 -0
  458. rasa/engine/caching.py +463 -0
  459. rasa/engine/constants.py +17 -0
  460. rasa/engine/exceptions.py +14 -0
  461. rasa/engine/graph.py +642 -0
  462. rasa/engine/loader.py +48 -0
  463. rasa/engine/recipes/__init__.py +0 -0
  464. rasa/engine/recipes/config_files/default_config.yml +41 -0
  465. rasa/engine/recipes/default_components.py +97 -0
  466. rasa/engine/recipes/default_recipe.py +1272 -0
  467. rasa/engine/recipes/graph_recipe.py +79 -0
  468. rasa/engine/recipes/recipe.py +93 -0
  469. rasa/engine/runner/__init__.py +0 -0
  470. rasa/engine/runner/dask.py +250 -0
  471. rasa/engine/runner/interface.py +49 -0
  472. rasa/engine/storage/__init__.py +0 -0
  473. rasa/engine/storage/local_model_storage.py +244 -0
  474. rasa/engine/storage/resource.py +110 -0
  475. rasa/engine/storage/storage.py +199 -0
  476. rasa/engine/training/__init__.py +0 -0
  477. rasa/engine/training/components.py +176 -0
  478. rasa/engine/training/fingerprinting.py +64 -0
  479. rasa/engine/training/graph_trainer.py +256 -0
  480. rasa/engine/training/hooks.py +164 -0
  481. rasa/engine/validation.py +1451 -0
  482. rasa/env.py +14 -0
  483. rasa/exceptions.py +69 -0
  484. rasa/graph_components/__init__.py +0 -0
  485. rasa/graph_components/converters/__init__.py +0 -0
  486. rasa/graph_components/converters/nlu_message_converter.py +48 -0
  487. rasa/graph_components/providers/__init__.py +0 -0
  488. rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
  489. rasa/graph_components/providers/domain_provider.py +71 -0
  490. rasa/graph_components/providers/flows_provider.py +74 -0
  491. rasa/graph_components/providers/forms_provider.py +44 -0
  492. rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
  493. rasa/graph_components/providers/responses_provider.py +44 -0
  494. rasa/graph_components/providers/rule_only_provider.py +49 -0
  495. rasa/graph_components/providers/story_graph_provider.py +96 -0
  496. rasa/graph_components/providers/training_tracker_provider.py +55 -0
  497. rasa/graph_components/validators/__init__.py +0 -0
  498. rasa/graph_components/validators/default_recipe_validator.py +550 -0
  499. rasa/graph_components/validators/finetuning_validator.py +302 -0
  500. rasa/hooks.py +111 -0
  501. rasa/jupyter.py +63 -0
  502. rasa/llm_fine_tuning/__init__.py +0 -0
  503. rasa/llm_fine_tuning/annotation_module.py +241 -0
  504. rasa/llm_fine_tuning/conversations.py +144 -0
  505. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  506. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  507. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  508. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  509. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  510. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  511. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  512. rasa/llm_fine_tuning/storage.py +174 -0
  513. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  514. rasa/markers/__init__.py +0 -0
  515. rasa/markers/marker.py +269 -0
  516. rasa/markers/marker_base.py +828 -0
  517. rasa/markers/upload.py +74 -0
  518. rasa/markers/validate.py +21 -0
  519. rasa/model.py +118 -0
  520. rasa/model_manager/__init__.py +0 -0
  521. rasa/model_manager/config.py +40 -0
  522. rasa/model_manager/model_api.py +559 -0
  523. rasa/model_manager/runner_service.py +286 -0
  524. rasa/model_manager/socket_bridge.py +146 -0
  525. rasa/model_manager/studio_jwt_auth.py +86 -0
  526. rasa/model_manager/trainer_service.py +325 -0
  527. rasa/model_manager/utils.py +87 -0
  528. rasa/model_manager/warm_rasa_process.py +187 -0
  529. rasa/model_service.py +112 -0
  530. rasa/model_testing.py +457 -0
  531. rasa/model_training.py +596 -0
  532. rasa/nlu/__init__.py +7 -0
  533. rasa/nlu/classifiers/__init__.py +3 -0
  534. rasa/nlu/classifiers/classifier.py +5 -0
  535. rasa/nlu/classifiers/diet_classifier.py +1881 -0
  536. rasa/nlu/classifiers/fallback_classifier.py +192 -0
  537. rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
  538. rasa/nlu/classifiers/logistic_regression_classifier.py +253 -0
  539. rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
  540. rasa/nlu/classifiers/regex_message_handler.py +56 -0
  541. rasa/nlu/classifiers/sklearn_intent_classifier.py +330 -0
  542. rasa/nlu/constants.py +77 -0
  543. rasa/nlu/convert.py +40 -0
  544. rasa/nlu/emulators/__init__.py +0 -0
  545. rasa/nlu/emulators/dialogflow.py +55 -0
  546. rasa/nlu/emulators/emulator.py +49 -0
  547. rasa/nlu/emulators/luis.py +86 -0
  548. rasa/nlu/emulators/no_emulator.py +10 -0
  549. rasa/nlu/emulators/wit.py +56 -0
  550. rasa/nlu/extractors/__init__.py +0 -0
  551. rasa/nlu/extractors/crf_entity_extractor.py +715 -0
  552. rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
  553. rasa/nlu/extractors/entity_synonyms.py +178 -0
  554. rasa/nlu/extractors/extractor.py +470 -0
  555. rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
  556. rasa/nlu/extractors/regex_entity_extractor.py +220 -0
  557. rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
  558. rasa/nlu/featurizers/__init__.py +0 -0
  559. rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
  560. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +445 -0
  561. rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
  562. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +768 -0
  563. rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
  564. rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
  565. rasa/nlu/featurizers/featurizer.py +89 -0
  566. rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
  567. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +867 -0
  568. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +571 -0
  569. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +271 -0
  570. rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
  571. rasa/nlu/model.py +24 -0
  572. rasa/nlu/run.py +27 -0
  573. rasa/nlu/selectors/__init__.py +0 -0
  574. rasa/nlu/selectors/response_selector.py +987 -0
  575. rasa/nlu/test.py +1940 -0
  576. rasa/nlu/tokenizers/__init__.py +0 -0
  577. rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
  578. rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
  579. rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
  580. rasa/nlu/tokenizers/tokenizer.py +239 -0
  581. rasa/nlu/tokenizers/whitespace_tokenizer.py +95 -0
  582. rasa/nlu/utils/__init__.py +35 -0
  583. rasa/nlu/utils/bilou_utils.py +462 -0
  584. rasa/nlu/utils/hugging_face/__init__.py +0 -0
  585. rasa/nlu/utils/hugging_face/registry.py +108 -0
  586. rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
  587. rasa/nlu/utils/mitie_utils.py +113 -0
  588. rasa/nlu/utils/pattern_utils.py +168 -0
  589. rasa/nlu/utils/spacy_utils.py +310 -0
  590. rasa/plugin.py +90 -0
  591. rasa/server.py +1588 -0
  592. rasa/shared/__init__.py +0 -0
  593. rasa/shared/constants.py +311 -0
  594. rasa/shared/core/__init__.py +0 -0
  595. rasa/shared/core/command_payload_reader.py +109 -0
  596. rasa/shared/core/constants.py +180 -0
  597. rasa/shared/core/conversation.py +46 -0
  598. rasa/shared/core/domain.py +2172 -0
  599. rasa/shared/core/events.py +2559 -0
  600. rasa/shared/core/flows/__init__.py +7 -0
  601. rasa/shared/core/flows/flow.py +562 -0
  602. rasa/shared/core/flows/flow_path.py +84 -0
  603. rasa/shared/core/flows/flow_step.py +146 -0
  604. rasa/shared/core/flows/flow_step_links.py +319 -0
  605. rasa/shared/core/flows/flow_step_sequence.py +70 -0
  606. rasa/shared/core/flows/flows_list.py +258 -0
  607. rasa/shared/core/flows/flows_yaml_schema.json +303 -0
  608. rasa/shared/core/flows/nlu_trigger.py +117 -0
  609. rasa/shared/core/flows/steps/__init__.py +24 -0
  610. rasa/shared/core/flows/steps/action.py +56 -0
  611. rasa/shared/core/flows/steps/call.py +64 -0
  612. rasa/shared/core/flows/steps/collect.py +112 -0
  613. rasa/shared/core/flows/steps/constants.py +5 -0
  614. rasa/shared/core/flows/steps/continuation.py +36 -0
  615. rasa/shared/core/flows/steps/end.py +22 -0
  616. rasa/shared/core/flows/steps/internal.py +44 -0
  617. rasa/shared/core/flows/steps/link.py +51 -0
  618. rasa/shared/core/flows/steps/no_operation.py +48 -0
  619. rasa/shared/core/flows/steps/set_slots.py +50 -0
  620. rasa/shared/core/flows/steps/start.py +30 -0
  621. rasa/shared/core/flows/utils.py +39 -0
  622. rasa/shared/core/flows/validation.py +735 -0
  623. rasa/shared/core/flows/yaml_flows_io.py +405 -0
  624. rasa/shared/core/generator.py +908 -0
  625. rasa/shared/core/slot_mappings.py +526 -0
  626. rasa/shared/core/slots.py +654 -0
  627. rasa/shared/core/trackers.py +1183 -0
  628. rasa/shared/core/training_data/__init__.py +0 -0
  629. rasa/shared/core/training_data/loading.py +89 -0
  630. rasa/shared/core/training_data/story_reader/__init__.py +0 -0
  631. rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
  632. rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
  633. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
  634. rasa/shared/core/training_data/story_writer/__init__.py +0 -0
  635. rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
  636. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +444 -0
  637. rasa/shared/core/training_data/structures.py +858 -0
  638. rasa/shared/core/training_data/visualization.html +146 -0
  639. rasa/shared/core/training_data/visualization.py +603 -0
  640. rasa/shared/data.py +249 -0
  641. rasa/shared/engine/__init__.py +0 -0
  642. rasa/shared/engine/caching.py +26 -0
  643. rasa/shared/exceptions.py +167 -0
  644. rasa/shared/importers/__init__.py +0 -0
  645. rasa/shared/importers/importer.py +770 -0
  646. rasa/shared/importers/multi_project.py +215 -0
  647. rasa/shared/importers/rasa.py +108 -0
  648. rasa/shared/importers/remote_importer.py +196 -0
  649. rasa/shared/importers/utils.py +36 -0
  650. rasa/shared/nlu/__init__.py +0 -0
  651. rasa/shared/nlu/constants.py +53 -0
  652. rasa/shared/nlu/interpreter.py +10 -0
  653. rasa/shared/nlu/training_data/__init__.py +0 -0
  654. rasa/shared/nlu/training_data/entities_parser.py +208 -0
  655. rasa/shared/nlu/training_data/features.py +492 -0
  656. rasa/shared/nlu/training_data/formats/__init__.py +10 -0
  657. rasa/shared/nlu/training_data/formats/dialogflow.py +163 -0
  658. rasa/shared/nlu/training_data/formats/luis.py +87 -0
  659. rasa/shared/nlu/training_data/formats/rasa.py +135 -0
  660. rasa/shared/nlu/training_data/formats/rasa_yaml.py +618 -0
  661. rasa/shared/nlu/training_data/formats/readerwriter.py +244 -0
  662. rasa/shared/nlu/training_data/formats/wit.py +52 -0
  663. rasa/shared/nlu/training_data/loading.py +137 -0
  664. rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
  665. rasa/shared/nlu/training_data/message.py +490 -0
  666. rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
  667. rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
  668. rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
  669. rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
  670. rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
  671. rasa/shared/nlu/training_data/training_data.py +729 -0
  672. rasa/shared/nlu/training_data/util.py +223 -0
  673. rasa/shared/providers/__init__.py +0 -0
  674. rasa/shared/providers/_configs/__init__.py +0 -0
  675. rasa/shared/providers/_configs/azure_openai_client_config.py +677 -0
  676. rasa/shared/providers/_configs/client_config.py +59 -0
  677. rasa/shared/providers/_configs/default_litellm_client_config.py +132 -0
  678. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +236 -0
  679. rasa/shared/providers/_configs/litellm_router_client_config.py +222 -0
  680. rasa/shared/providers/_configs/model_group_config.py +173 -0
  681. rasa/shared/providers/_configs/openai_client_config.py +177 -0
  682. rasa/shared/providers/_configs/rasa_llm_client_config.py +75 -0
  683. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +178 -0
  684. rasa/shared/providers/_configs/utils.py +117 -0
  685. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  686. rasa/shared/providers/_utils.py +79 -0
  687. rasa/shared/providers/constants.py +7 -0
  688. rasa/shared/providers/embedding/__init__.py +0 -0
  689. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +243 -0
  690. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  691. rasa/shared/providers/embedding/azure_openai_embedding_client.py +335 -0
  692. rasa/shared/providers/embedding/default_litellm_embedding_client.py +126 -0
  693. rasa/shared/providers/embedding/embedding_client.py +90 -0
  694. rasa/shared/providers/embedding/embedding_response.py +41 -0
  695. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  696. rasa/shared/providers/embedding/litellm_router_embedding_client.py +138 -0
  697. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  698. rasa/shared/providers/llm/__init__.py +0 -0
  699. rasa/shared/providers/llm/_base_litellm_client.py +265 -0
  700. rasa/shared/providers/llm/azure_openai_llm_client.py +415 -0
  701. rasa/shared/providers/llm/default_litellm_llm_client.py +110 -0
  702. rasa/shared/providers/llm/litellm_router_llm_client.py +202 -0
  703. rasa/shared/providers/llm/llm_client.py +78 -0
  704. rasa/shared/providers/llm/llm_response.py +50 -0
  705. rasa/shared/providers/llm/openai_llm_client.py +161 -0
  706. rasa/shared/providers/llm/rasa_llm_client.py +120 -0
  707. rasa/shared/providers/llm/self_hosted_llm_client.py +276 -0
  708. rasa/shared/providers/mappings.py +94 -0
  709. rasa/shared/providers/router/__init__.py +0 -0
  710. rasa/shared/providers/router/_base_litellm_router_client.py +185 -0
  711. rasa/shared/providers/router/router_client.py +75 -0
  712. rasa/shared/utils/__init__.py +0 -0
  713. rasa/shared/utils/cli.py +102 -0
  714. rasa/shared/utils/common.py +324 -0
  715. rasa/shared/utils/constants.py +4 -0
  716. rasa/shared/utils/health_check/__init__.py +0 -0
  717. rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
  718. rasa/shared/utils/health_check/health_check.py +258 -0
  719. rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
  720. rasa/shared/utils/io.py +499 -0
  721. rasa/shared/utils/llm.py +764 -0
  722. rasa/shared/utils/pykwalify_extensions.py +27 -0
  723. rasa/shared/utils/schemas/__init__.py +0 -0
  724. rasa/shared/utils/schemas/config.yml +2 -0
  725. rasa/shared/utils/schemas/domain.yml +145 -0
  726. rasa/shared/utils/schemas/events.py +214 -0
  727. rasa/shared/utils/schemas/model_config.yml +36 -0
  728. rasa/shared/utils/schemas/stories.yml +173 -0
  729. rasa/shared/utils/yaml.py +1068 -0
  730. rasa/studio/__init__.py +0 -0
  731. rasa/studio/auth.py +270 -0
  732. rasa/studio/config.py +136 -0
  733. rasa/studio/constants.py +19 -0
  734. rasa/studio/data_handler.py +368 -0
  735. rasa/studio/download.py +489 -0
  736. rasa/studio/results_logger.py +137 -0
  737. rasa/studio/train.py +134 -0
  738. rasa/studio/upload.py +563 -0
  739. rasa/telemetry.py +1876 -0
  740. rasa/tracing/__init__.py +0 -0
  741. rasa/tracing/config.py +355 -0
  742. rasa/tracing/constants.py +62 -0
  743. rasa/tracing/instrumentation/__init__.py +0 -0
  744. rasa/tracing/instrumentation/attribute_extractors.py +765 -0
  745. rasa/tracing/instrumentation/instrumentation.py +1306 -0
  746. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +144 -0
  747. rasa/tracing/instrumentation/metrics.py +294 -0
  748. rasa/tracing/metric_instrument_provider.py +205 -0
  749. rasa/utils/__init__.py +0 -0
  750. rasa/utils/beta.py +83 -0
  751. rasa/utils/cli.py +28 -0
  752. rasa/utils/common.py +639 -0
  753. rasa/utils/converter.py +53 -0
  754. rasa/utils/endpoints.py +331 -0
  755. rasa/utils/io.py +252 -0
  756. rasa/utils/json_utils.py +60 -0
  757. rasa/utils/licensing.py +542 -0
  758. rasa/utils/log_utils.py +181 -0
  759. rasa/utils/mapper.py +210 -0
  760. rasa/utils/ml_utils.py +147 -0
  761. rasa/utils/plotting.py +362 -0
  762. rasa/utils/sanic_error_handler.py +32 -0
  763. rasa/utils/singleton.py +23 -0
  764. rasa/utils/tensorflow/__init__.py +0 -0
  765. rasa/utils/tensorflow/callback.py +112 -0
  766. rasa/utils/tensorflow/constants.py +116 -0
  767. rasa/utils/tensorflow/crf.py +492 -0
  768. rasa/utils/tensorflow/data_generator.py +440 -0
  769. rasa/utils/tensorflow/environment.py +161 -0
  770. rasa/utils/tensorflow/exceptions.py +5 -0
  771. rasa/utils/tensorflow/feature_array.py +366 -0
  772. rasa/utils/tensorflow/layers.py +1565 -0
  773. rasa/utils/tensorflow/layers_utils.py +113 -0
  774. rasa/utils/tensorflow/metrics.py +281 -0
  775. rasa/utils/tensorflow/model_data.py +798 -0
  776. rasa/utils/tensorflow/model_data_utils.py +499 -0
  777. rasa/utils/tensorflow/models.py +935 -0
  778. rasa/utils/tensorflow/rasa_layers.py +1094 -0
  779. rasa/utils/tensorflow/transformer.py +640 -0
  780. rasa/utils/tensorflow/types.py +6 -0
  781. rasa/utils/train_utils.py +572 -0
  782. rasa/utils/url_tools.py +53 -0
  783. rasa/utils/yaml.py +54 -0
  784. rasa/validator.py +1644 -0
  785. rasa/version.py +3 -0
  786. rasa_pro-3.12.0.dev1.dist-info/METADATA +199 -0
  787. rasa_pro-3.12.0.dev1.dist-info/NOTICE +5 -0
  788. rasa_pro-3.12.0.dev1.dist-info/RECORD +790 -0
  789. rasa_pro-3.12.0.dev1.dist-info/WHEEL +4 -0
  790. rasa_pro-3.12.0.dev1.dist-info/entry_points.txt +3 -0
rasa/nlu/test.py ADDED
@@ -0,0 +1,1940 @@
1
+ import copy
2
+ import itertools
3
+ import os
4
+ import logging
5
+ import structlog
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ from collections import defaultdict, namedtuple
10
+ from tqdm import tqdm
11
+ from typing import (
12
+ Iterable,
13
+ Iterator,
14
+ Tuple,
15
+ List,
16
+ Set,
17
+ Optional,
18
+ Text,
19
+ Union,
20
+ Dict,
21
+ Any,
22
+ NamedTuple,
23
+ TYPE_CHECKING,
24
+ )
25
+
26
+ from rasa import telemetry
27
+ from rasa.core.agent import Agent
28
+ from rasa.core.channels import UserMessage
29
+ from rasa.core.processor import MessageProcessor
30
+ from rasa.shared.nlu.training_data.training_data import TrainingData
31
+ from rasa.shared.utils.yaml import write_yaml
32
+ from rasa.utils.common import TempDirectoryPath, get_temp_dir_name
33
+ import rasa.shared.utils.io
34
+ import rasa.utils.plotting as plot_utils
35
+ import rasa.utils.io as io_utils
36
+
37
+ from rasa.constants import TEST_DATA_FILE, TRAIN_DATA_FILE, NLG_DATA_FILE
38
+ import rasa.nlu.classifiers.fallback_classifier
39
+ from rasa.nlu.constants import (
40
+ RESPONSE_SELECTOR_DEFAULT_INTENT,
41
+ RESPONSE_SELECTOR_PROPERTY_NAME,
42
+ RESPONSE_SELECTOR_PREDICTION_KEY,
43
+ TOKENS_NAMES,
44
+ ENTITY_ATTRIBUTE_CONFIDENCE_TYPE,
45
+ ENTITY_ATTRIBUTE_CONFIDENCE_ROLE,
46
+ ENTITY_ATTRIBUTE_CONFIDENCE_GROUP,
47
+ RESPONSE_SELECTOR_RETRIEVAL_INTENTS,
48
+ )
49
+ from rasa.shared.nlu.constants import (
50
+ TEXT,
51
+ INTENT,
52
+ INTENT_RESPONSE_KEY,
53
+ ENTITIES,
54
+ EXTRACTOR,
55
+ PRETRAINED_EXTRACTORS,
56
+ ENTITY_ATTRIBUTE_TYPE,
57
+ ENTITY_ATTRIBUTE_GROUP,
58
+ ENTITY_ATTRIBUTE_ROLE,
59
+ NO_ENTITY_TAG,
60
+ INTENT_NAME_KEY,
61
+ PREDICTED_CONFIDENCE_KEY,
62
+ )
63
+ from rasa.nlu.classifiers import fallback_classifier
64
+ from rasa.nlu.tokenizers.tokenizer import Token
65
+ from rasa.shared.importers.importer import TrainingDataImporter
66
+ from rasa.shared.nlu.training_data.formats.rasa_yaml import RasaYAMLWriter
67
+
68
if TYPE_CHECKING:
    from typing_extensions import TypedDict

    # Typed record with the keys "text", "entities" and "predicted_entities".
    # Declared only for static type checkers (inside `if TYPE_CHECKING`), so it
    # does not exist at runtime and may only appear in string/deferred annotations.
    EntityPrediction = TypedDict(
        "EntityPrediction",
        {
            "text": Text,
            "entities": List[Dict[Text, Any]],
            "predicted_entities": List[Dict[Text, Any]],
        },
    )
logger = logging.getLogger(__name__)
# Structured logger, used e.g. for the response-selection success debug event.
structlogger = structlog.get_logger()

# Exclude 'EntitySynonymMapper' and 'ResponseSelector': their super class
# performs entity extraction, but these two components do not extract entities.
ENTITY_PROCESSORS = {"EntitySynonymMapper", "ResponseSelector"}

# NOTE(review): judging by the name, these are the extractors that report entity
# confidences; confirm against the code that reads this set.
EXTRACTORS_WITH_CONFIDENCES = {"CRFEntityExtractor", "DIETClassifier"}
87
+
88
+
89
class CVEvaluationResult(NamedTuple):
    """Stores NLU cross-validation results.

    All three fields are plain dicts; their contents are defined by the code
    that builds this result.
    """

    # presumably results measured on the training folds — confirm with the builder
    train: Dict
    # presumably results measured on the held-out test folds — confirm with the builder
    test: Dict
    # presumably the detailed evaluation output — confirm with the builder
    evaluation: Dict
95
+
96
+
97
# Label meaning "no entity"; entity targets/predictions are compared against it
# (e.g. when building the entity confidence histograms).
NO_ENTITY = "no_entity"

# One intent prediction: target intent, predicted intent, the message text and
# the prediction confidence.
IntentEvaluationResult = namedtuple(
    "IntentEvaluationResult", "intent_target intent_prediction message confidence"
)

# One response-selector prediction: target and predicted intent response keys,
# the message text and the prediction confidence.
ResponseSelectionEvaluationResult = namedtuple(
    "ResponseSelectionEvaluationResult",
    "intent_response_key_target intent_response_key_prediction message confidence",
)

# Entity targets and predictions for one message, plus its tokens and message.
EntityEvaluationResult = namedtuple(
    "EntityEvaluationResult", "entity_targets entity_predictions tokens message"
)

# Metric containers. NOTE(review): presumably keyed by metric name with one
# float per evaluation run (e.g. per cross-validation fold); entity metrics add
# an extra level of nesting — confirm against the code that fills them.
IntentMetrics = Dict[Text, List[float]]
EntityMetrics = Dict[Text, Dict[Text, List[float]]]
ResponseSelectionMetrics = Dict[Text, List[float]]
115
+
116
+
117
def log_evaluation_table(
    report: Text, precision: float, f1: float, accuracy: float
) -> None:  # pragma: no cover
    """Log the sklearn evaluation metrics.

    Emits one INFO line per metric (F1-score, precision, accuracy), followed by
    the full classification report.

    Args:
        report: Classification report text.
        precision: Precision score.
        f1: F1 score.
        accuracy: Accuracy score.
    """
    lines = (
        f"F1-Score: {f1}",
        f"Precision: {precision}",
        f"Accuracy: {accuracy}",
        f"Classification report: \n{report}",
    )
    for line in lines:
        logger.info(line)
125
+
126
+
127
def remove_empty_intent_examples(
    intent_results: List[IntentEvaluationResult],
) -> List[IntentEvaluationResult]:
    """Drop evaluation results whose target intent is missing.

    Results whose target intent is ``None`` or the empty string are removed.
    In the remaining results a ``None`` prediction is replaced by ``""`` so that
    sklearn can evaluate them.

    Args:
        intent_results: Intent evaluation results.

    Returns:
        The filtered (and normalised) intent evaluation results.
    """
    # sklearn cannot handle ``None`` labels: map missing predictions to "".
    normalised = (
        result._replace(intent_prediction="")
        if result.intent_prediction is None
        else result
        for result in intent_results
    )
    return [
        result
        for result in normalised
        if result.intent_target is not None and result.intent_target != ""
    ]
148
+
149
+
150
def remove_empty_response_examples(
    response_selection_results: List[ResponseSelectionEvaluationResult],
) -> List[ResponseSelectionEvaluationResult]:
    """Drop response-selection results that have no target response key.

    In every result, a ``None`` prediction becomes ``""`` (sklearn cannot
    handle ``None`` labels) and a ``None`` confidence becomes ``0.0``. Only
    results with a truthy ``intent_response_key_target`` are kept.

    Args:
        response_selection_results: response selection evaluation results

    Returns:
        Response selection evaluation results
    """
    kept = []
    for result in response_selection_results:
        updates = {}
        if result.intent_response_key_prediction is None:
            updates["intent_response_key_prediction"] = ""
        if result.confidence is None:
            # Happens when response selector training data exists but the model
            # contains no response selector.
            updates["confidence"] = 0.0
        if updates:
            result = result._replace(**updates)

        if not result.intent_response_key_target:
            continue
        kept.append(result)

    return kept
177
+
178
+
179
def drop_intents_below_freq(
    training_data: TrainingData, cutoff: int = 5
) -> TrainingData:
    """Filter out examples whose intent occurs fewer than ``cutoff`` times.

    Args:
        training_data: Training data to filter.
        cutoff: Minimum number of examples an intent needs in order to be kept.

    Returns:
        The result of ``training_data.filter_training_examples`` with only the
        sufficiently frequent intents.
    """
    raw_intent_count = len(training_data.intent_examples)
    logger.debug("Raw data intent examples: {}".format(raw_intent_count))

    counts = training_data.number_of_examples_per_intent

    def _is_frequent_enough(example) -> bool:
        # Examples with an unknown intent count as having zero occurrences.
        return counts.get(example.get(INTENT), 0) >= cutoff

    return training_data.filter_training_examples(_is_frequent_enough)
198
+
199
+
200
def write_intent_successes(
    intent_results: List[IntentEvaluationResult], successes_filename: Text
) -> None:
    """Dump all correctly predicted intent examples to a JSON file.

    Nothing is written when no prediction matches its target; an INFO message
    is logged instead.

    Args:
        intent_results: Intent evaluation results to scan.
        successes_filename: Path of the JSON file the successes are written to.
    """
    successes = []
    for result in intent_results:
        if result.intent_target == result.intent_prediction:
            prediction = {
                INTENT_NAME_KEY: result.intent_prediction,
                "confidence": result.confidence,
            }
            successes.append(
                {
                    "text": result.message,
                    "intent": result.intent_target,
                    "intent_prediction": prediction,
                }
            )

    if not successes:
        logger.info("No successful intent predictions found.")
        return

    rasa.shared.utils.io.dump_obj_as_json_to_file(successes_filename, successes)
    logger.info(f"Successful intent predictions saved to {successes_filename}.")
    logger.debug(f"\n\nSuccessfully predicted the following intents: \n{successes}")
228
+
229
+
230
def _write_errors(errors: List[Dict], errors_filename: Text, error_type: Text) -> None:
    """Write serialisable prediction errors of any kind to a JSON file.

    When ``errors`` is empty, nothing is written and an INFO message states
    that everything was predicted correctly.

    Args:
        errors: Serializable prediction errors.
        errors_filename: Path of the file to save incorrect predictions to.
        error_type: NLU entity which was evaluated (e.g. `intent` or `entity`);
            used only in log messages.
    """
    if not errors:
        logger.info(f"Every {error_type} was predicted correctly by the model.")
        return

    rasa.shared.utils.io.dump_obj_as_json_to_file(errors_filename, errors)
    logger.info(f"Incorrect {error_type} predictions saved to {errors_filename}.")
    logger.debug(
        f"\n\nThese {error_type} examples could not be classified "
        f"correctly: \n{errors}"
    )
247
+
248
+
249
def _get_intent_errors(intent_results: List[IntentEvaluationResult]) -> List[Dict]:
    """Collect the intent predictions that differ from their targets.

    Args:
        intent_results: Intent evaluation results to scan.

    Returns:
        One JSON-serialisable dict per misclassified example, holding the
        message text, the target intent and the predicted intent with its
        confidence.
    """
    mistakes: List[Dict] = []
    for result in intent_results:
        if result.intent_target == result.intent_prediction:
            continue
        mistakes.append(
            {
                "text": result.message,
                "intent": result.intent_target,
                "intent_prediction": {
                    INTENT_NAME_KEY: result.intent_prediction,
                    "confidence": result.confidence,
                },
            }
        )
    return mistakes
262
+
263
+
264
def write_response_successes(
    response_results: List[ResponseSelectionEvaluationResult], successes_filename: Text
) -> None:
    """Dump all correctly selected responses to a JSON file.

    When no prediction matches its target, nothing is written and an INFO
    message is logged instead.

    Args:
        response_results: Response selection evaluation results.
        successes_filename: Path of the file to save successful predictions to.
    """
    successes = []
    for result in response_results:
        if result.intent_response_key_prediction != result.intent_response_key_target:
            continue
        successes.append(
            {
                "text": result.message,
                "intent_response_key_target": result.intent_response_key_target,
                "intent_response_key_prediction": {
                    "name": result.intent_response_key_prediction,
                    "confidence": result.confidence,
                },
            }
        )

    if not successes:
        logger.info("No successful response predictions found.")
        return

    rasa.shared.utils.io.dump_obj_as_json_to_file(successes_filename, successes)
    logger.info(f"Successful response predictions saved to {successes_filename}.")
    structlogger.debug("test.write.response", successes=copy.deepcopy(successes))
292
+
293
+
294
def _response_errors(
    response_results: List[ResponseSelectionEvaluationResult],
) -> List[Dict]:
    """Collect the response-selection predictions that differ from their targets.

    This only builds the data; nothing is written to disk here.

    Args:
        response_results: Response selection evaluation results.

    Returns:
        One JSON-serialisable dict per wrong prediction, holding the message
        text, the target response key and the predicted key with its confidence.
    """
    mistakes: List[Dict] = []
    for result in response_results:
        if result.intent_response_key_prediction == result.intent_response_key_target:
            continue
        mistakes.append(
            {
                "text": result.message,
                "intent_response_key_target": result.intent_response_key_target,
                "intent_response_key_prediction": {
                    "name": result.intent_response_key_prediction,
                    "confidence": result.confidence,
                },
            }
        )
    return mistakes
317
+
318
+
319
def plot_attribute_confidences(
    results: Union[
        List[IntentEvaluationResult], List[ResponseSelectionEvaluationResult]
    ],
    hist_filename: Optional[Text],
    target_key: Text,
    prediction_key: Text,
    title: Text,
) -> None:
    """Plot a paired histogram of confidences for correct vs. wrong predictions.

    Args:
        results: Evaluation results, each with a ``confidence`` field.
        hist_filename: File the plot is saved to.
        target_key: Attribute name holding the target label.
        prediction_key: Attribute name holding the predicted label.
        title: Title of the plot.
    """
    correct_confidences: List[float] = []
    wrong_confidences: List[float] = []
    for result in results:
        if getattr(result, target_key) == getattr(result, prediction_key):
            correct_confidences.append(result.confidence)
        else:
            wrong_confidences.append(result.confidence)

    plot_utils.plot_paired_histogram(
        [correct_confidences, wrong_confidences], title, hist_filename
    )
350
+
351
+
352
def plot_entity_confidences(
    merged_targets: List[Text],
    merged_predictions: List[Text],
    merged_confidences: List[float],
    hist_filename: Text,
    title: Text,
) -> None:
    """Plot a paired histogram of entity prediction confidences.

    The left series holds confidences of correctly predicted real entities
    (target is not ``NO_ENTITY`` and equals the prediction). The right series
    holds confidences of wrong entity predictions (the prediction is neither
    ``NO_ENTITY`` nor the target).

    Args:
        merged_targets: Entity labels.
        merged_predictions: Predicted entities.
        merged_confidences: Confidence scores of predictions.
        hist_filename: File the plot is saved to.
        title: Title of the plot.
    """
    hits: List[float] = []
    misses: List[float] = []
    for target, prediction, confidence in zip(
        merged_targets, merged_predictions, merged_confidences
    ):
        if target != NO_ENTITY and target == prediction:
            hits.append(confidence)
        if prediction not in (NO_ENTITY, target):
            misses.append(confidence)

    plot_utils.plot_paired_histogram([hits, misses], title, hist_filename)
385
+
386
+
387
def evaluate_response_selections(
    response_selection_results: List[ResponseSelectionEvaluationResult],
    output_directory: Optional[Text],
    successes: bool,
    errors: bool,
    disable_plotting: bool,
    report_as_dict: Optional[bool] = None,
) -> Dict:  # pragma: no cover
    """Creates summary statistics for response selection.

    Examples without a target response are filtered out first; only the
    remaining ones contribute to the metrics.

    Args:
        response_selection_results: response selection evaluation results
        output_directory: directory to store files to
        successes: if True success are written down to disk
        errors: if True errors are written down to disk (only when an
            `output_directory` is given)
        disable_plotting: if True no plots are created
        report_as_dict: `True` if the evaluation report should be returned as `dict`.
            If `False` the report is returned in a human-readable text format. If `None`
            `report_as_dict` is considered as `True` in case an `output_directory` is
            given.

    Returns:
        Dictionary with the keys `predictions`, `report`, `precision`,
        `f1_score`, `accuracy` and `errors`.
    """

    def _output_path(filename: Text) -> Text:
        # Artifacts go into `output_directory` when one is given; otherwise the
        # bare file name is used.
        if output_directory:
            return os.path.join(output_directory, filename)
        return filename

    num_examples = len(response_selection_results)
    response_selection_results = remove_empty_response_examples(
        response_selection_results
    )
    logger.info(
        f"Response Selection Evaluation: Only considering those "
        f"{len(response_selection_results)} examples that have a defined response out "
        f"of {num_examples} examples."
    )

    target_keys, predicted_keys = _targets_predictions_from(
        response_selection_results,
        "intent_response_key_target",
        "intent_response_key_prediction",
    )
    report, precision, f1, accuracy, confusion_matrix, labels = _calculate_report(
        output_directory, target_keys, predicted_keys, report_as_dict
    )
    if output_directory:
        _dump_report(output_directory, "response_selection_report.json", report)

    if successes:
        # save correctly classified samples to file for debugging
        write_response_successes(
            response_selection_results,
            _output_path("response_selection_successes.json"),
        )

    response_errors = _response_errors(response_selection_results)
    # Unlike successes, errors are only persisted when an output directory exists.
    if errors and output_directory:
        _write_errors(
            response_errors,
            _output_path("response_selection_errors.json"),
            error_type="response",
        )

    if not disable_plotting:
        plot_utils.plot_confusion_matrix(
            confusion_matrix,
            classes=labels,
            title="Response Selection Confusion Matrix",
            output_file=_output_path("response_selection_confusion_matrix.png"),
        )
        plot_attribute_confidences(
            response_selection_results,
            _output_path("response_selection_histogram.png"),
            "intent_response_key_target",
            "intent_response_key_prediction",
            title="Response Selection Prediction Confidence Distribution",
        )

    predictions = []
    for result in response_selection_results:
        predictions.append(
            {
                "text": result.message,
                "intent_response_key_target": result.intent_response_key_target,
                "intent_response_key_prediction": result.intent_response_key_prediction,
                "confidence": result.confidence,
            }
        )

    return {
        "predictions": predictions,
        "report": report,
        "precision": precision,
        "f1_score": f1,
        "accuracy": accuracy,
        "errors": response_errors,
    }
501
+
502
+
503
+ def _add_confused_labels_to_report(
504
+ report: Dict[Text, Dict[Text, Any]],
505
+ confusion_matrix: np.ndarray,
506
+ labels: List[Text],
507
+ exclude_labels: Optional[List[Text]] = None,
508
+ ) -> Dict[Text, Dict[Text, Union[Dict, Any]]]:
509
+ """Adds a field "confused_with" to the evaluation report.
510
+
511
+ The value is a dict of {"false_positive_label": false_positive_count} pairs.
512
+ If there are no false positives in the confusion matrix,
513
+ the dict will be empty. Typically, we include the two most
514
+ commonly false positive labels, three in the rare case that
515
+ the diagonal element in the confusion matrix is not one of the
516
+ three highest values in the row.
517
+
518
+ Args:
519
+ report: the evaluation report
520
+ confusion_matrix: confusion matrix
521
+ labels: list of labels
522
+ exclude_labels: labels to exclude from the report
523
+
524
+ Returns: updated evaluation report
525
+ """
526
+ if exclude_labels is None:
527
+ exclude_labels = []
528
+
529
+ # sort confusion matrix by false positives
530
+ indices = np.argsort(confusion_matrix, axis=1)
531
+ n_candidates = min(3, len(labels))
532
+
533
+ for label in labels:
534
+ if label in exclude_labels:
535
+ continue
536
+ # it is possible to predict intent 'None'
537
+ if report.get(label):
538
+ report[label]["confused_with"] = {}
539
+
540
+ for i, label in enumerate(labels):
541
+ if label in exclude_labels:
542
+ continue
543
+ for j in range(n_candidates):
544
+ label_idx = indices[i, -(1 + j)]
545
+ false_pos_label = labels[label_idx]
546
+ false_positives = int(confusion_matrix[i, label_idx])
547
+ if (
548
+ false_pos_label != label
549
+ and false_pos_label not in exclude_labels
550
+ and false_positives > 0
551
+ ):
552
+ report[label]["confused_with"][false_pos_label] = false_positives
553
+
554
+ return report
555
+
556
+
557
def evaluate_intents(
    intent_results: List[IntentEvaluationResult],
    output_directory: Optional[Text],
    successes: bool,
    errors: bool,
    disable_plotting: bool,
    report_as_dict: Optional[bool] = None,
) -> Dict:  # pragma: no cover
    """Summarise how well intents were classified.

    Examples without a target intent are dropped before any metric is computed.

    Args:
        intent_results: intent evaluation results
        output_directory: directory to store files to
        successes: if True correct predictions are written to disk
        errors: if True incorrect predictions are written to disk
        disable_plotting: if True no plots are created
        report_as_dict: `True` if the evaluation report should be returned as `dict`.
            If `False` the report is returned in a human-readable text format. If
            `None` `report_as_dict` is considered as `True` in case an
            `output_directory` is given.

    Returns:
        Dictionary with the keys ``predictions``, ``report``, ``precision``,
        ``f1_score``, ``accuracy`` and ``errors``.
    """

    def _in_output_dir(filename: Text) -> Text:
        # plots fall back to the working directory when no output dir is set
        if output_directory:
            return os.path.join(output_directory, filename)
        return filename

    total_count = len(intent_results)
    # drop examples that do not have an intent target
    intent_results = remove_empty_intent_examples(intent_results)
    logger.info(
        f"Intent Evaluation: Only considering those {len(intent_results)} examples "
        f"that have a defined intent out of {total_count} examples."
    )

    targets, predicted = _targets_predictions_from(
        intent_results, "intent_target", "intent_prediction"
    )
    (report, precision, f1, accuracy, confusion_matrix, labels) = _calculate_report(
        output_directory, targets, predicted, report_as_dict
    )

    if output_directory:
        _dump_report(output_directory, "intent_report.json", report)
        if successes:
            # keep correctly classified samples around for debugging
            write_intent_successes(
                intent_results,
                os.path.join(output_directory, "intent_successes.json"),
            )

    intent_errors = _get_intent_errors(intent_results)
    if output_directory and errors:
        _write_errors(
            intent_errors,
            os.path.join(output_directory, "intent_errors.json"),
            "intent",
        )

    if not disable_plotting:
        plot_utils.plot_confusion_matrix(
            confusion_matrix,
            classes=labels,
            title="Intent Confusion matrix",
            output_file=_in_output_dir("intent_confusion_matrix.png"),
        )
        plot_attribute_confidences(
            intent_results,
            _in_output_dir("intent_histogram.png"),
            "intent_target",
            "intent_prediction",
            title="Intent Prediction Confidence Distribution",
        )

    predictions = []
    for res in intent_results:
        predictions.append(
            {
                "text": res.message,
                "intent": res.intent_target,
                "predicted": res.intent_prediction,
                "confidence": res.confidence,
            }
        )

    return {
        "predictions": predictions,
        "report": report,
        "precision": precision,
        "f1_score": f1,
        "accuracy": accuracy,
        "errors": intent_errors,
    }
654
+
655
+
656
def _calculate_report(
    output_directory: Optional[Text],
    targets: Iterable[Any],
    predictions: Iterable[Any],
    report_as_dict: Optional[bool] = None,
    exclude_label: Optional[Text] = None,
) -> Tuple[Union[Text, Dict], float, float, float, np.ndarray, List[Text]]:
    """Compute the classification report, scores and confusion matrix.

    Args:
        output_directory: output directory; when ``report_as_dict`` is ``None``
            its truthiness decides whether the report is a dict
        targets: true labels
        predictions: predicted labels
        report_as_dict: return the report as a dict (True) or as text (False)
        exclude_label: label left out of the metrics and "confused_with" data

    Returns:
        ``(report, precision, f1, accuracy, confusion_matrix, labels)``
    """
    from rasa.model_testing import get_evaluation_metrics
    import sklearn.metrics
    import sklearn.utils.multiclass

    confusion_matrix = sklearn.metrics.confusion_matrix(targets, predictions)
    labels = sklearn.utils.multiclass.unique_labels(targets, predictions)

    as_dict = bool(output_directory) if report_as_dict is None else report_as_dict

    report, precision, f1, accuracy = get_evaluation_metrics(
        targets, predictions, output_dict=as_dict, exclude_label=exclude_label
    )

    if not as_dict:
        # text reports are only printed when nothing is written to disk
        if not output_directory:
            log_evaluation_table(report, precision, f1, accuracy)
        return report, precision, f1, accuracy, confusion_matrix, labels

    excluded = [exclude_label] if exclude_label else []
    report = _add_confused_labels_to_report(  # type: ignore[assignment]
        report, confusion_matrix, labels, exclude_labels=excluded
    )
    return report, precision, f1, accuracy, confusion_matrix, labels
688
+
689
+
690
def _dump_report(output_directory: Text, filename: Text, report: Dict) -> None:
    """Write ``report`` as JSON to ``filename`` inside ``output_directory``."""
    target_path = os.path.join(output_directory, filename)
    rasa.shared.utils.io.dump_obj_as_json_to_file(target_path, report)
    logger.info(f"Classification report saved to {target_path}.")
694
+
695
+
696
def merge_labels(
    aligned_predictions: List[Dict], extractor: Optional[Text] = None
) -> List[Text]:
    """Flatten the per-message token labels into one list.

    Args:
        aligned_predictions: aligned predictions, one dict per message
        extractor: entity extractor name; when falsy the target labels are used

    Returns:
        All labels of all messages, in message order.
    """
    if extractor:
        return [
            label
            for aligned in aligned_predictions
            for label in aligned["extractor_labels"][extractor]
        ]
    return [
        label for aligned in aligned_predictions for label in aligned["target_labels"]
    ]
717
+
718
+
719
def merge_confidences(
    aligned_predictions: List[Dict], extractor: Optional[Text] = None
) -> List[float]:
    """Flatten the per-message token confidences of one extractor.

    Args:
        aligned_predictions: aligned predictions, one dict per message
        extractor: entity extractor name used to look up ``confidences``
            (NOTE(review): the ``None`` default is only valid if the
            confidences dicts are keyed by ``None``)

    Returns:
        All confidences of all messages, in message order.
    """
    return [
        confidence
        for aligned in aligned_predictions
        for confidence in aligned["confidences"][extractor]
    ]
736
+
737
+
738
def substitute_labels(labels: List[Text], old: Text, new: Text) -> List[Text]:
    """Return a copy of ``labels`` where every ``old`` is replaced by ``new``.

    Args:
        labels: list of labels (not modified)
        old: label name to replace
        new: replacement label name

    Returns: a new list with the substituted labels
    """
    replaced = list(labels)
    for position, label in enumerate(replaced):
        if label == old:
            replaced[position] = new
    return replaced
749
+
750
+
751
def collect_incorrect_entity_predictions(
    entity_results: List[EntityEvaluationResult],
    merged_predictions: List[Text],
    merged_targets: List[Text],
) -> List["EntityPrediction"]:
    """Collect every message that has at least one mislabelled token.

    ``merged_predictions`` / ``merged_targets`` hold one label per token of all
    messages concatenated, so each message owns a consecutive slice of them.

    Args:
        entity_results: entity evaluation results
        merged_predictions: list of predicted entity labels
        merged_targets: list of true entity labels

    Returns: one entry per message containing a wrong token label
    """
    incorrect: List["EntityPrediction"] = []
    start = 0
    for entity_result in entity_results:
        stop = start + len(entity_result.tokens)
        has_error = any(
            merged_targets[idx] != merged_predictions[idx]
            for idx in range(start, stop)
        )
        if has_error:
            incorrect.append(
                {
                    "text": entity_result.message,
                    "entities": entity_result.entity_targets,
                    "predicted_entities": entity_result.entity_predictions,
                }
            )
        start = stop
    return incorrect
779
+
780
+
781
def write_successful_entity_predictions(
    entity_results: List[EntityEvaluationResult],
    merged_targets: List[Text],
    merged_predictions: List[Text],
    successes_filename: Text,
) -> None:
    """Dump the correctly predicted entity examples to a JSON file.

    Nothing is written when there are no successful predictions.

    Args:
        entity_results: entity evaluation results
        merged_targets: list of true entity labels
        merged_predictions: list of predicted entity labels
        successes_filename: filename of file to save correct predictions to
    """
    correct = collect_successful_entity_predictions(
        entity_results, merged_predictions, merged_targets
    )

    if not correct:
        logger.info("No successful entity prediction found.")
        return

    rasa.shared.utils.io.dump_obj_as_json_to_file(successes_filename, correct)
    logger.info(f"Successful entity predictions saved to {successes_filename}.")
    structlogger.debug("test.write.entities", successes=copy.deepcopy(correct))
805
+
806
+
807
def collect_successful_entity_predictions(
    entity_results: List[EntityEvaluationResult],
    merged_predictions: List[Text],
    merged_targets: List[Text],
) -> List["EntityPrediction"]:
    """Collect every message with at least one correctly predicted entity token.

    Tokens whose target is ``NO_ENTITY`` do not count as a success.

    Args:
        entity_results: entity evaluation results
        merged_predictions: list of predicted entity labels
        merged_targets: list of true entity labels

    Returns: one entry per message containing a correct entity token
    """
    correct: List["EntityPrediction"] = []
    start = 0
    for entity_result in entity_results:
        stop = start + len(entity_result.tokens)
        has_hit = any(
            merged_targets[idx] == merged_predictions[idx]
            and merged_targets[idx] != NO_ENTITY
            for idx in range(start, stop)
        )
        if has_hit:
            correct.append(
                {
                    "text": entity_result.message,
                    "entities": entity_result.entity_targets,
                    "predicted_entities": entity_result.entity_predictions,
                }
            )
        start = stop
    return correct
838
+
839
+
840
def evaluate_entities(
    entity_results: List[EntityEvaluationResult],
    extractors: Set[Text],
    output_directory: Optional[Text],
    successes: bool,
    errors: bool,
    disable_plotting: bool,
    report_as_dict: Optional[bool] = None,
) -> Dict:  # pragma: no cover
    """Compute evaluation metrics separately for every entity extractor.

    Args:
        entity_results: entity evaluation results
        extractors: entity extractors to consider
        output_directory: directory to store files to
        successes: if True correct predictions are written to disk
        errors: if True incorrect predictions are written to disk
        disable_plotting: if True no plots are created
        report_as_dict: `True` if the evaluation report should be returned as `dict`.
            If `False` the report is returned in a human-readable text format. If
            `None` `report_as_dict` is considered as `True` in case an
            `output_directory` is given.

    Returns:
        Mapping from extractor name to a dict with ``report``, ``precision``,
        ``f1_score``, ``accuracy`` and ``errors``.
    """

    def _output_path(filename: Text) -> Text:
        # without an output directory files end up in the working directory
        if output_directory:
            return os.path.join(output_directory, filename)
        return filename

    aligned_predictions = align_all_entity_predictions(entity_results, extractors)
    merged_targets = substitute_labels(
        merge_labels(aligned_predictions), NO_ENTITY_TAG, NO_ENTITY
    )

    result = {}

    for extractor in extractors:
        merged_predictions = substitute_labels(
            merge_labels(aligned_predictions, extractor), NO_ENTITY_TAG, NO_ENTITY
        )

        logger.info(f"Evaluation for entity extractor: {extractor} ")

        (
            report,
            precision,
            f1,
            accuracy,
            confusion_matrix,
            labels,
        ) = _calculate_report(
            output_directory,
            merged_targets,
            merged_predictions,
            report_as_dict,
            exclude_label=NO_ENTITY,
        )
        if output_directory:
            _dump_report(output_directory, f"{extractor}_report.json", report)

        if successes:
            # keep correctly classified samples around for debugging
            write_successful_entity_predictions(
                entity_results,
                merged_targets,
                merged_predictions,
                _output_path(f"{extractor}_successes.json"),
            )

        entity_errors = collect_incorrect_entity_predictions(
            entity_results, merged_predictions, merged_targets
        )
        if output_directory and errors:
            _write_errors(
                entity_errors,
                os.path.join(output_directory, f"{extractor}_errors.json"),
                "entity",
            )

        if not disable_plotting:
            plot_utils.plot_confusion_matrix(
                confusion_matrix,
                classes=labels,
                title="Entity Confusion matrix",
                output_file=_output_path(f"{extractor}_confusion_matrix.png"),
            )

            if extractor in EXTRACTORS_WITH_CONFIDENCES:
                plot_entity_confidences(
                    merged_targets,
                    merged_predictions,
                    merge_confidences(aligned_predictions, extractor),
                    title="Entity Prediction Confidence Distribution",
                    hist_filename=_output_path(f"{extractor}_histogram.png"),
                )

        result[extractor] = {
            "report": report,
            "precision": precision,
            "f1_score": f1,
            "accuracy": accuracy,
            "errors": entity_errors,
        }

    return result
945
+
946
+
947
def is_token_within_entity(token: Token, entity: Dict) -> bool:
    """Return True when every character of ``token`` lies inside ``entity``."""
    shared_chars = determine_intersection(token, entity)
    return shared_chars == len(token.text)
950
+
951
+
952
def does_token_cross_borders(token: Token, entity: Dict) -> bool:
    """Return True when ``token`` is only partially covered by ``entity``."""
    overlap = determine_intersection(token, entity)
    return overlap > 0 and overlap < len(token.text)
956
+
957
+
958
def determine_intersection(token: Token, entity: Dict) -> int:
    """Calculates how many characters a given token and entity share.

    Both the token and the entity are treated as half-open character spans
    ``[start, end)``. The overlap is computed arithmetically in O(1) instead of
    materialising both spans as sets, which cost O(span length) per call. An
    empty or inverted span shares no characters with anything.

    Args:
        token: token exposing integer ``start`` / ``end`` offsets
        entity: entity dict with integer ``"start"`` / ``"end"`` offsets

    Returns:
        Number of character positions covered by both spans.
    """
    overlap = min(token.end, entity["end"]) - max(token.start, entity["start"])
    # disjoint (or empty) spans give a non-positive difference
    return max(0, overlap)
963
+
964
+
965
def do_entities_overlap(entities: List[Dict]) -> bool:
    """Checks if entities of different types overlap.

    I.e. cross each others start and end boundaries. Overlapping entities of the
    same type are tolerated.

    Every pair of entities is considered, not only neighbours in start order.
    A long entity can overlap an entity that does not directly follow it: with
    ``A=[0, 10)``, ``B=[2, 3)`` of A's type and ``C=[5, 6)`` of another type,
    A and C overlap even though B sits between them.

    Args:
        entities: list of entities

    Returns: true if entities overlap, false otherwise.
    """
    sorted_entities = sorted(entities, key=lambda e: e["start"])
    for i, curr_ent in enumerate(sorted_entities):
        for next_ent in sorted_entities[i + 1 :]:
            if next_ent["start"] >= curr_ent["end"]:
                # sorted by start: no later entity can overlap curr_ent either
                break
            if next_ent["entity"] != curr_ent["entity"]:
                structlogger.warning(
                    "test.overlaping.entities",
                    current_entity=copy.deepcopy(curr_ent),
                    next_entity=copy.deepcopy(next_ent),
                )
                return True

    return False
991
+
992
+
993
def find_intersecting_entities(token: Token, entities: List[Dict]) -> List[Dict]:
    """Return the entities that contain or partially cover ``token``.

    Entities that only partially cover the token are additionally logged at
    debug level.

    Args:
        token: a single token
        entities: entities found by a single extractor

    Returns: the intersecting entities, in input order
    """
    intersecting = []
    for entity in entities:
        if is_token_within_entity(token, entity):
            intersecting.append(entity)
            continue
        if not does_token_cross_borders(token, entity):
            continue
        intersecting.append(entity)
        structlogger.debug(
            "test.intersecting.entities",
            token_text=copy.deepcopy(token.text),
            token_start=token.start,
            token_end=token.end,
            entity=copy.deepcopy(entity),
        )
    return intersecting
1016
+
1017
+
1018
def pick_best_entity_fit(
    token: Token, candidates: List[Dict[Text, Any]]
) -> Optional[Dict[Text, Any]]:
    """Choose the candidate entity that shares the most characters with ``token``.

    Ties are resolved in favour of the earliest candidate.

    Args:
        token: a single token
        candidates: entities found by a single extractor that intersect the token

    Returns:
        The best fitting entity dict, or ``None`` if there are no candidates.
    """
    if not candidates:
        return None
    if len(candidates) == 1:
        return candidates[0]

    overlaps = [determine_intersection(token, candidate) for candidate in candidates]
    return candidates[overlaps.index(max(overlaps))]
1038
+
1039
+
1040
def determine_token_labels(
    token: Token,
    entities: List[Dict],
    extractors: Optional[Set[Text]] = None,
    attribute_key: Text = ENTITY_ATTRIBUTE_TYPE,
) -> Text:
    """Return the ``attribute_key`` value of the entity best covering ``token``.

    Args:
        token: a single token
        entities: entities found by a single extractor
        extractors: list of extractors
        attribute_key: the entity attribute to read (type, role or group)

    Returns:
        The attribute value, or ``NO_ENTITY_TAG`` if no entity covers the token
        or the attribute is missing/empty.
    """
    entity = determine_entity_for_token(token, entities, extractors)
    label = None if entity is None else entity.get(attribute_key)
    return label if label else NO_ENTITY_TAG
1067
+
1068
+
1069
def determine_entity_for_token(
    token: Token,
    entities: List[Dict[Text, Any]],
    extractors: Optional[Set[Text]] = None,
) -> Optional[Dict[Text, Any]]:
    """Find the entity that best covers ``token``.

    Args:
        token: a single token
        entities: entities found by a single extractor
        extractors: list of extractors

    Returns:
        The best fitting entity, or ``None`` if no entity intersects the token.

    Raises:
        ValueError: if one of the extractors does not support overlapping
            entities and entities of different types overlap.
    """
    if not entities:
        return None

    overlap_forbidden = do_any_extractors_not_support_overlap(extractors)
    if overlap_forbidden and do_entities_overlap(entities):
        raise ValueError("The possible entities should not overlap.")

    return pick_best_entity_fit(token, find_intersecting_entities(token, entities))
1093
+
1094
+
1095
def do_any_extractors_not_support_overlap(extractors: Optional[Set[Text]]) -> bool:
    """Tell whether an extractor that cannot handle overlapping entities is used.

    Args:
        extractors: Names of the entity extractors, or ``None``

    Returns:
        `True` if and only if CRFEntityExtractor or DIETClassifier is in `extractors`
    """
    if extractors is None:
        return False

    from rasa.nlu.extractors.crf_entity_extractor import CRFEntityExtractor
    from rasa.nlu.classifiers.diet_classifier import DIETClassifier

    no_overlap_support = (CRFEntityExtractor.__name__, DIETClassifier.__name__)
    return any(name in extractors for name in no_overlap_support)
1113
+
1114
+
1115
def align_entity_predictions(
    result: EntityEvaluationResult, extractors: Set[Text]
) -> Dict:
    """Aligns entity predictions to the message tokens.

    Determines for every token the true label based on the
    prediction targets and the label assigned by each
    single extractor.

    Predictions that carry no extractor name, or come from an extractor that is
    not in ``extractors``, are ignored instead of raising a ``KeyError``.

    Args:
        result: entity evaluation result
        extractors: the entity extractors that should be considered

    Returns: dictionary containing the true token labels and token labels
        from the extractors
    """
    entities_by_extractors: Dict[Text, List] = {
        extractor: [] for extractor in extractors
    }
    for p in result.entity_predictions:
        bucket = entities_by_extractors.get(p.get(EXTRACTOR))
        # skip predictions of extractors that are not being evaluated
        if bucket is not None:
            bucket.append(p)

    true_token_labels = []
    extractor_labels: Dict[Text, List] = {extractor: [] for extractor in extractors}
    extractor_confidences: Dict[Text, List] = {
        extractor: [] for extractor in extractors
    }
    for t in result.tokens:
        true_token_labels.append(_concat_entity_labels(t, result.entity_targets))
        for extractor, entities in entities_by_extractors.items():
            extracted_labels = _concat_entity_labels(t, entities, {extractor})
            extracted_confidences = _get_entity_confidences(t, entities, {extractor})
            extractor_labels[extractor].append(extracted_labels)
            extractor_confidences[extractor].append(extracted_confidences)

    return {
        "target_labels": true_token_labels,
        "extractor_labels": extractor_labels,
        "confidences": extractor_confidences,
    }
1154
+
1155
+
1156
def _concat_entity_labels(
    token: Token, entities: List[Dict], extractors: Optional[Set[Text]] = None
) -> Text:
    """Concatenate labels for entity type, role, and group for evaluation.

    In order to calculate metrics also for entity type, role, and group we need to
    concatenate their labels. For example, 'location.destination'. This allows
    us to report metrics for every combination of entity type, role, and group.

    The best fitting entity is looked up once and its three attributes are read
    from it. Calling ``determine_token_labels`` three times repeated the overlap
    check and intersection search, and its debug logging, for every attribute.

    Args:
        token: the token we are looking at
        entities: the available entities
        extractors: the extractor of interest

    Returns:
        the entity label of the provided token
    """
    entity = determine_entity_for_token(token, entities, extractors)
    if entity is None:
        return NO_ENTITY_TAG

    labels = []
    for attribute_key in (
        ENTITY_ATTRIBUTE_TYPE,
        ENTITY_ATTRIBUTE_GROUP,
        ENTITY_ATTRIBUTE_ROLE,
    ):
        label = entity.get(attribute_key)
        # missing/empty attributes and explicit "no entity" tags are left out,
        # matching what determine_token_labels reports for them
        if label and label != NO_ENTITY_TAG:
            labels.append(label)

    return ".".join(labels) if labels else NO_ENTITY_TAG
1190
+
1191
+
1192
def _get_entity_confidences(
    token: Token, entities: List[Dict], extractors: Optional[Set[Text]] = None
) -> float:
    """Get the confidence value of the best fitting entity.

    If multiple confidence values are present, e.g. for type, role, group, we
    pick the lowest confidence value. A missing confidence counts as 1.0 so that
    it never wins the minimum. A real confidence of 0.0 is kept as 0.0; the
    previous ``value or 1.0`` turned it into 1.0.

    Args:
        token: the token we are looking at
        entities: the available entities
        extractors: the extractor of interest

    Returns:
        the confidence value, or 0.0 if there is no entity or its extractor
        does not produce confidences
    """
    entity = determine_entity_for_token(token, entities, extractors)

    if entity is None:
        return 0.0

    if entity.get("extractor") not in EXTRACTORS_WITH_CONFIDENCES:
        return 0.0

    def _confidence_or_default(key: Text) -> float:
        value = entity.get(key)
        return 1.0 if value is None else value

    return min(
        _confidence_or_default(ENTITY_ATTRIBUTE_CONFIDENCE_TYPE),
        _confidence_or_default(ENTITY_ATTRIBUTE_CONFIDENCE_ROLE),
        _confidence_or_default(ENTITY_ATTRIBUTE_CONFIDENCE_GROUP),
    )
1221
+
1222
+
1223
def align_all_entity_predictions(
    entity_results: List[EntityEvaluationResult], extractors: Set[Text]
) -> List[Dict]:
    """Run ``align_entity_predictions`` on every result of the dataset.

    Args:
        entity_results: list of entity prediction results
        extractors: the entity extractors that should be considered

    Returns: one alignment dict (true token labels plus per-extractor token
        labels and confidences) per result, in input order
    """
    return [align_entity_predictions(result, extractors) for result in entity_results]
1242
+
1243
+
1244
async def get_eval_data(
    processor: MessageProcessor, test_data: TrainingData
) -> Tuple[
    List[IntentEvaluationResult],
    List[ResponseSelectionEvaluationResult],
    List[EntityEvaluationResult],
]:
    """Parse every NLU test example and collect targets alongside predictions.

    Intent results are collected only if the test data has at least two
    distinct intents. Response selection results need at least two distinct
    response keys. Entity results need at least one entity example.

    Args:
        processor: the processor
        test_data: test data

    Returns: intent, response, and entity evaluation results
    """
    logger.info("Running model for predictions:")

    intent_results: List[IntentEvaluationResult] = []
    entity_results: List[EntityEvaluationResult] = []
    response_selection_results: List[ResponseSelectionEvaluationResult] = []

    response_labels = set()
    for intent_example in test_data.intent_examples:
        response_key = intent_example.get(INTENT_RESPONSE_KEY)
        if response_key is not None:
            response_labels.add(response_key)
    intent_labels = {
        intent_example.get(INTENT) for intent_example in test_data.intent_examples
    }

    should_eval_intents = len(intent_labels) >= 2
    should_eval_response_selection = len(response_labels) >= 2
    should_eval_entities = len(test_data.entity_examples) > 0

    for example in tqdm(test_data.nlu_examples):
        result = await processor.parse_message(
            UserMessage(text=example.get(TEXT)),
            only_output_properties=False,
        )
        _remove_entities_of_extractors(result, PRETRAINED_EXTRACTORS)

        if should_eval_intents:
            if fallback_classifier.is_fallback_classifier_prediction(result):
                # Undo the fallback so it does not hide the wrongly predicted
                # intent. `result` is rebound here, so the response selection and
                # entity sections below read the reverted parse result.
                result = fallback_classifier.undo_fallback_prediction(result)
            predicted_intent = result.get(INTENT, {})
            intent_results.append(
                IntentEvaluationResult(
                    example.get(INTENT, ""),
                    predicted_intent.get(INTENT_NAME_KEY),
                    result.get(TEXT),
                    predicted_intent.get("confidence"),
                )
            )

        if should_eval_response_selection:
            # every example is kept; empty response targets are filtered out
            # later, when metrics are calculated
            intent_target = example.get(INTENT, "")
            selector_properties = result.get(RESPONSE_SELECTOR_PROPERTY_NAME, {})
            retrieval_intents = selector_properties.get(
                RESPONSE_SELECTOR_RETRIEVAL_INTENTS, set()
            )
            has_own_selector = (
                intent_target in retrieval_intents
                and intent_target in selector_properties
            )
            prediction_key = (
                intent_target if has_own_selector else RESPONSE_SELECTOR_DEFAULT_INTENT
            )
            response_prediction = selector_properties.get(prediction_key, {}).get(
                RESPONSE_SELECTOR_PREDICTION_KEY, {}
            )
            response_selection_results.append(
                ResponseSelectionEvaluationResult(
                    example.get(INTENT_RESPONSE_KEY, ""),
                    response_prediction.get(INTENT_RESPONSE_KEY),
                    result.get(TEXT),
                    response_prediction.get(PREDICTED_CONFIDENCE_KEY),
                )
            )

        if should_eval_entities:
            entity_results.append(
                EntityEvaluationResult(
                    example.get(ENTITIES, []),
                    result.get(ENTITIES, []),
                    result.get(TOKENS_NAMES[TEXT], []),
                    result.get(TEXT),
                )
            )

    return intent_results, response_selection_results, entity_results
1342
+
1343
+
1344
def _get_active_entity_extractors(
    entity_results: List[EntityEvaluationResult],
) -> Set[Text]:
    """Collect the extractor names found in the predictions of all results.

    Predictions without an extractor name are ignored.
    """
    return {
        prediction[EXTRACTOR]
        for result in entity_results
        for prediction in result.entity_predictions
        if EXTRACTOR in prediction
    }
1354
+
1355
+
1356
def _remove_entities_of_extractors(
    nlu_parse_result: Dict[Text, Any], extractor_names: Set[Text]
) -> None:
    """Drop, in place, all parsed entities produced by ``extractor_names``.

    The parse result is left untouched when it has no (or empty) entities.
    """
    entities = nlu_parse_result.get(ENTITIES)
    if entities:
        nlu_parse_result[ENTITIES] = [
            entity
            for entity in entities
            if entity.get(EXTRACTOR) not in extractor_names
        ]
1365
+
1366
+
1367
async def run_evaluation(
    data_path: Text,
    processor: MessageProcessor,
    output_directory: Optional[Text] = None,
    successes: bool = False,
    errors: bool = False,
    disable_plotting: bool = False,
    report_as_dict: Optional[bool] = None,
    domain_path: Optional[Text] = None,
) -> Dict:  # pragma: no cover
    """Run intent, response selection and entity evaluation on test data.

    Args:
        data_path: path to the test data
        processor: the processor used to process and predict
        output_directory: path to folder where all output will be stored
        successes: if true successful predictions are written to a file
        errors: if true incorrect predictions are written to a file
        disable_plotting: if true confusion matrix and histogram will not be rendered
        report_as_dict: `True` if the evaluation report should be returned as `dict`.
            If `False` the report is returned in a human-readable text format. If
            `None` `report_as_dict` is considered as `True` in case an
            `output_directory` is given.
        domain_path: Path to the domain file(s); defaults to DEFAULT_DOMAIN_PATH.

    Returns:
        Dict with ``intent_evaluation``, ``entity_evaluation`` and
        ``response_selection_evaluation`` (each ``None`` if not evaluated).
    """
    import rasa.shared.nlu.training_data.loading
    from rasa.shared.constants import DEFAULT_DOMAIN_PATH

    test_data_importer = TrainingDataImporter.load_from_dict(
        training_data_paths=[data_path],
        domain_path=domain_path or DEFAULT_DOMAIN_PATH,
    )
    test_data = test_data_importer.get_nlu_data()

    result: Dict[Text, Optional[Dict]] = dict.fromkeys(
        ("intent_evaluation", "entity_evaluation", "response_selection_evaluation")
    )

    if output_directory:
        rasa.shared.utils.io.create_directory(output_directory)

    intent_results, response_selection_results, entity_results = await get_eval_data(
        processor, test_data
    )

    # positional arguments shared by all three evaluate_* functions
    shared_args = (output_directory, successes, errors, disable_plotting)

    if intent_results:
        logger.info("Intent evaluation results:")
        result["intent_evaluation"] = evaluate_intents(
            intent_results, *shared_args, report_as_dict=report_as_dict
        )

    if response_selection_results:
        logger.info("Response selection evaluation results:")
        result["response_selection_evaluation"] = evaluate_response_selections(
            response_selection_results, *shared_args, report_as_dict=report_as_dict
        )

    if any(entity_results):
        logger.info("Entity evaluation results:")
        active_extractors = _get_active_entity_extractors(entity_results)
        result["entity_evaluation"] = evaluate_entities(
            entity_results,
            active_extractors,
            *shared_args,
            report_as_dict=report_as_dict,
        )

    telemetry.track_nlu_model_test(test_data)

    return result
1454
+
1455
+
1456
def generate_folds(
    n: int, training_data: TrainingData
) -> Iterator[Tuple[TrainingData, TrainingData]]:
    """Yield `n` stratified `(train, test)` splits of the given training data.

    Only the intent examples of `training_data` are distributed across the
    folds. Stratification uses each example's full intent label
    (`get_full_intent()`), so retrieval intents are split proportionally too.
    Every produced split shares the entity synonyms, regex features, lookup
    tables and responses of the original `training_data`.

    Args:
        n: number of cross validation folds.
        training_data: the data to split.

    Yields:
        One `(train, test)` pair of `TrainingData` objects per fold.
    """
    from sklearn.model_selection import StratifiedKFold

    examples = training_data.intent_examples
    # Labels exactly as they appear in the data (including retrieval intents,
    # if any) so that the split is stratified over all of them.
    labels = [example.get_full_intent() for example in examples]

    def _as_training_data(indices: Iterable[int]) -> TrainingData:
        # Wrap the selected examples together with the shared auxiliary data.
        return TrainingData(
            training_examples=[examples[idx] for idx in indices],
            entity_synonyms=training_data.entity_synonyms,
            regex_features=training_data.regex_features,
            lookup_tables=training_data.lookup_tables,
            responses=training_data.responses,
        )

    splitter = StratifiedKFold(n_splits=n, shuffle=True)
    for fold_number, (train_indices, test_indices) in enumerate(
        splitter.split(examples, labels)
    ):
        logger.debug(f"Fold: {fold_number}")
        yield _as_training_data(train_indices), _as_training_data(test_indices)
1488
+
1489
+
1490
async def combine_result(
    intent_metrics: IntentMetrics,
    entity_metrics: EntityMetrics,
    response_selection_metrics: ResponseSelectionMetrics,
    processor: MessageProcessor,
    data: TrainingData,
    intent_results: Optional[List[IntentEvaluationResult]] = None,
    entity_results: Optional[List[EntityEvaluationResult]] = None,
    response_selection_results: Optional[
        List[ResponseSelectionEvaluationResult]
    ] = None,
) -> Tuple[IntentMetrics, EntityMetrics, ResponseSelectionMetrics]:
    """Add one fold's intent, entity and response selection metrics to the totals.

    The metrics for `data` are computed with `compute_metrics` and merged into
    the given metric dictionaries, which are modified in place. For every
    metric, the current fold's values are placed in front of the values
    accumulated so far.

    Args:
        intent_metrics: accumulated intent metrics (metric name -> values).
        entity_metrics: accumulated entity metrics, keyed by extractor name.
        response_selection_metrics: accumulated response selection metrics.
        processor: the processor used to predict on `data`.
        data: the data of the current fold.
        intent_results: if a list is given, the fold's intent prediction
            results are appended to it.
        entity_results: if a list is given, the fold's entity prediction
            results are appended to it.
        response_selection_results: if a list is given, the fold's response
            selection prediction results are appended to it.

    Returns:
        The (mutated) intent, entity and response selection metric dictionaries.
    """
    (
        fold_intent_metrics,
        fold_entity_metrics,
        fold_response_selection_metrics,
        fold_intent_results,
        fold_entity_results,
        fold_response_selection_results,
    ) = await compute_metrics(processor, data)

    # Only collect raw prediction results when the caller asked for them by
    # passing a list.
    for collected, fold_results in (
        (intent_results, fold_intent_results),
        (entity_results, fold_entity_results),
        (response_selection_results, fold_response_selection_results),
    ):
        if collected is not None:
            collected.extend(fold_results)

    for accumulated, fold_scores in (
        (intent_metrics, fold_intent_metrics),
        (response_selection_metrics, fold_response_selection_metrics),
    ):
        for metric_name, values in fold_scores.items():
            accumulated[metric_name] = values + accumulated[metric_name]

    # Each extractor's entry is rebuilt as a plain dict of concatenated lists.
    for extractor, scores in fold_entity_metrics.items():
        entity_metrics[extractor] = {
            metric_name: values + entity_metrics[extractor][metric_name]
            for metric_name, values in scores.items()
        }

    return intent_metrics, entity_metrics, response_selection_metrics
1549
+
1550
+
1551
def _contains_entity_labels(entity_results: List[EntityEvaluationResult]) -> bool:
    """Tell whether any result carries entity targets or entity predictions.

    Args:
        entity_results: entity evaluation results to inspect.

    Returns:
        `True` if at least one result has truthy `entity_targets` or
        `entity_predictions`; `False` otherwise, including for an empty list.
    """
    return any(
        result.entity_targets or result.entity_predictions
        for result in entity_results
    )
1556
+
1557
+
1558
async def cross_validate(
    data: TrainingData,
    n_folds: int,
    nlu_config: Union[Text, Dict],
    output: Optional[Text] = None,
    successes: bool = False,
    errors: bool = False,
    disable_plotting: bool = False,
    report_as_dict: Optional[bool] = None,
) -> Tuple[CVEvaluationResult, CVEvaluationResult, CVEvaluationResult]:
    """Run stratified k-fold cross validation of an NLU pipeline.

    For every fold, an NLU model is trained on the fold's training split.
    Metrics are computed on both the training split and the test split. The
    test-split prediction results of all folds are pooled and evaluated once
    at the end.

    Args:
        data: training data to split into folds.
        n_folds: number of cross validation folds.
        nlu_config: path to an NLU config file, or the config as a dictionary
            (a dictionary is written to a temporary `config.yml` first).
        output: path to the folder where reports are stored; created if given.
        successes: if true, successful predictions are written to a file.
        errors: if true, incorrect predictions are written to a file.
        disable_plotting: if true, no confusion matrix and histogram plots
            are created.
        report_as_dict: `True` if the evaluation report should be returned as
            `dict`. If `False` the report is returned in a human-readable text
            format. If `None`, `report_as_dict` is considered as `True` in case
            an output directory is given.

    Returns:
        Three `CVEvaluationResult`s: intent, entity and response selection.
        Each holds the per-fold train metrics, the per-fold test metrics and
        the evaluation of the pooled test predictions (an empty dict when there
        were no such predictions).
    """
    import rasa.model_training

    with TempDirectoryPath(get_temp_dir_name()) as temp_dir:
        workdir = Path(temp_dir)

        # `train_nlu` takes a config path, so a dict config is persisted first.
        if isinstance(nlu_config, dict):
            config_file = workdir / "config.yml"
            write_yaml(nlu_config, config_file)
            nlu_config = str(config_file)

        if output:
            rasa.shared.utils.io.create_directory(output)

        # Per-fold metric values, filled by `combine_result`.
        intent_train_metrics: IntentMetrics = defaultdict(list)
        intent_test_metrics: IntentMetrics = defaultdict(list)
        entity_train_metrics: EntityMetrics = defaultdict(lambda: defaultdict(list))
        entity_test_metrics: EntityMetrics = defaultdict(lambda: defaultdict(list))
        response_selection_train_metrics: ResponseSelectionMetrics = defaultdict(list)
        response_selection_test_metrics: ResponseSelectionMetrics = defaultdict(list)

        # Test-split prediction results pooled over all folds.
        pooled_intent_results: List[IntentEvaluationResult] = []
        pooled_entity_results: List[EntityEvaluationResult] = []
        pooled_response_results: List[ResponseSelectionEvaluationResult] = []

        # Every fold overwrites the same training data file.
        fold_data_file = workdir / "training_data.yml"

        for fold_train, fold_test in generate_folds(n_folds, data):
            RasaYAMLWriter().dump(fold_data_file, fold_train)

            model_file = await rasa.model_training.train_nlu(
                nlu_config, str(fold_data_file), str(workdir)
            )
            fold_processor = Agent.load(model_file).processor

            # Train-split metrics: only the aggregated scores are kept.
            await combine_result(
                intent_train_metrics,
                entity_train_metrics,
                response_selection_train_metrics,
                fold_processor,
                fold_train,
            )
            # Test-split metrics, plus raw predictions for the pooled evaluation.
            await combine_result(
                intent_test_metrics,
                entity_test_metrics,
                response_selection_test_metrics,
                fold_processor,
                fold_test,
                pooled_intent_results,
                pooled_entity_results,
                pooled_response_results,
            )

        pooled_intent_evaluation = {}
        if pooled_intent_results:
            logger.info("Accumulated test folds intent evaluation results:")
            pooled_intent_evaluation = evaluate_intents(
                pooled_intent_results,
                output,
                successes,
                errors,
                disable_plotting,
                report_as_dict=report_as_dict,
            )

        pooled_entity_evaluation = {}
        if pooled_entity_results:
            logger.info("Accumulated test folds entity evaluation results:")
            active_extractors = _get_active_entity_extractors(pooled_entity_results)
            pooled_entity_evaluation = evaluate_entities(
                pooled_entity_results,
                active_extractors,
                output,
                successes,
                errors,
                disable_plotting,
                report_as_dict=report_as_dict,
            )

        pooled_response_evaluation = {}
        if pooled_response_results:
            logger.info("Accumulated test folds response selection evaluation results:")
            pooled_response_evaluation = evaluate_response_selections(
                pooled_response_results,
                output,
                successes,
                errors,
                disable_plotting,
                report_as_dict=report_as_dict,
            )

        intent_cv = CVEvaluationResult(
            dict(intent_train_metrics),
            dict(intent_test_metrics),
            pooled_intent_evaluation,
        )
        entity_cv = CVEvaluationResult(
            dict(entity_train_metrics),
            dict(entity_test_metrics),
            pooled_entity_evaluation,
        )
        response_selection_cv = CVEvaluationResult(
            dict(response_selection_train_metrics),
            dict(response_selection_test_metrics),
            pooled_response_evaluation,
        )
        return intent_cv, entity_cv, response_selection_cv
1692
+
1693
+
1694
def _targets_predictions_from(
    results: Union[
        List[IntentEvaluationResult], List[ResponseSelectionEvaluationResult]
    ],
    target_key: Text,
    prediction_key: Text,
) -> Iterator[Iterable[Optional[Text]]]:
    """Transpose evaluation results into a targets tuple and a predictions tuple.

    Args:
        results: evaluation results to read from.
        target_key: name of the attribute holding the target label.
        prediction_key: name of the attribute holding the predicted label.

    Returns:
        A `zip` iterator yielding two tuples: first all targets, then all
        predictions, both in the order of `results`. For empty `results` the
        iterator yields nothing.
    """
    pairs = [
        (getattr(result, target_key), getattr(result, prediction_key))
        for result in results
    ]
    return zip(*pairs)
1702
+
1703
+
1704
async def compute_metrics(
    processor: MessageProcessor, training_data: TrainingData
) -> Tuple[
    IntentMetrics,
    EntityMetrics,
    ResponseSelectionMetrics,
    List[IntentEvaluationResult],
    List[EntityEvaluationResult],
    List[ResponseSelectionEvaluationResult],
]:
    """Compute intent, entity and response selection metrics for one dataset.

    Args:
        processor: the processor used to predict on `training_data`.
        training_data: the data to evaluate.

    Returns:
        In this order: intent metrics, entity metrics, response selection
        metrics, followed by the intent, entity and response selection
        prediction results they were computed from. A metrics dict is empty
        when there are no corresponding results.
    """
    # Note: `get_eval_data` returns intent, response selection, entity results.
    (
        intent_results,
        response_selection_results,
        entity_results,
    ) = await get_eval_data(processor, training_data)

    intent_results = remove_empty_intent_examples(intent_results)
    response_selection_results = remove_empty_response_examples(
        response_selection_results
    )

    intent_metrics: IntentMetrics = (
        _compute_metrics(intent_results, "intent_target", "intent_prediction")
        if intent_results
        else {}
    )
    entity_metrics: EntityMetrics = (
        _compute_entity_metrics(entity_results) if entity_results else {}
    )
    response_selection_metrics: ResponseSelectionMetrics = (
        _compute_metrics(
            response_selection_results,
            "intent_response_key_target",
            "intent_response_key_prediction",
        )
        if response_selection_results
        else {}
    )

    return (
        intent_metrics,
        entity_metrics,
        response_selection_metrics,
        intent_results,
        entity_results,
        response_selection_results,
    )
1758
+
1759
+
1760
async def compare_nlu(
    configs: List[Text],
    data: TrainingData,
    exclusion_percentages: List[int],
    f_score_results: Dict[Text, List[List[float]]],
    model_names: List[Text],
    output: Text,
    runs: int,
) -> List[int]:
    """Trains and compares multiple NLU models.

    For each run and exclusion percentage, one model per config file is
    trained on only the current percentage of training data. Afterwards, the
    model is tested on the complete test data of that run. All results are
    stored in the provided output directory.

    A model whose training fails, or whose evaluation yields no intent
    evaluation, is recorded with an F1-score of 0.0. This keeps every model's
    per-run list aligned with `exclusion_percentages`, and the remaining
    configurations are still evaluated.

    Args:
        configs: config files needed for training.
        data: training data.
        exclusion_percentages: percentages of training data to exclude during
            comparison.
        f_score_results: dictionary of model name to F1-score results per run;
            must already hold one list per run for every name in `model_names`
            (scores are appended to `f_score_results[model_name][run]`).
        model_names: names of the models to train, parallel to `configs`.
        output: the output directory.
        runs: number of comparison runs.

    Returns:
        The number of NLU training examples used for each exclusion
        percentage, counted in the first run only.
    """
    import rasa.model_training

    training_examples_per_run = []

    for run in range(runs):
        logger.info(f"Beginning comparison run {run + 1}/{runs}")

        run_path = os.path.join(output, f"run_{run + 1}")
        io_utils.create_path(run_path)

        test_path = os.path.join(run_path, TEST_DATA_FILE)
        io_utils.create_path(test_path)

        train, test = data.train_test_split()
        rasa.shared.utils.io.write_text_file(test.nlu_as_yaml(), test_path)

        for percentage in exclusion_percentages:
            percent_string = f"{percentage}%_exclusion"

            _, train_included = train.train_test_split(percentage / 100)
            # only count for the first run and ignore the others
            if run == 0:
                training_examples_per_run.append(len(train_included.nlu_examples))

            model_output_path = os.path.join(run_path, percent_string)
            train_split_path = os.path.join(model_output_path, "train")
            train_nlu_split_path = os.path.join(train_split_path, TRAIN_DATA_FILE)
            train_nlg_split_path = os.path.join(train_split_path, NLG_DATA_FILE)
            io_utils.create_path(train_nlu_split_path)
            rasa.shared.utils.io.write_text_file(
                train_included.nlu_as_yaml(), train_nlu_split_path
            )
            rasa.shared.utils.io.write_text_file(
                train_included.nlg_as_yaml(), train_nlg_split_path
            )

            for nlu_config, model_name in zip(configs, model_names):
                logger.info(
                    f"Evaluating configuration '{model_name}' with "
                    f"{percent_string} training data."
                )

                try:
                    model_path = await rasa.model_training.train_nlu(
                        nlu_config,
                        train_split_path,
                        model_output_path,
                        fixed_model_name=model_name,
                    )
                except Exception as e:  # skipcq: PYL-W0703
                    # general exception catching needed to continue evaluating other
                    # model configurations
                    logger.warning(f"Training model '{model_name}' failed. Error: {e}")
                    f_score_results[model_name][run].append(0.0)
                    continue

                output_path = os.path.join(model_output_path, f"{model_name}_report")
                processor = Agent.load(model_path=model_path).processor
                result = await run_evaluation(
                    test_path, processor, output_directory=output_path, errors=True
                )

                intent_evaluation = result["intent_evaluation"]
                if intent_evaluation is None:
                    # `run_evaluation` leaves this entry as `None` when there
                    # are no intent results (presumably a pipeline without an
                    # intent classifier). Subscripting it would abort the whole
                    # comparison, so record 0.0 as for a failed training.
                    logger.warning(
                        f"No intent evaluation results for model '{model_name}'. "
                        f"Recording an F1-score of 0.0."
                    )
                    f1 = 0.0
                else:
                    f1 = intent_evaluation["f1_score"]
                f_score_results[model_name][run].append(f1)

    return training_examples_per_run
1854
+
1855
+
1856
def _compute_metrics(
    results: Union[
        List[IntentEvaluationResult], List[ResponseSelectionEvaluationResult]
    ],
    target_key: Text,
    prediction_key: Text,
) -> Union[IntentMetrics, ResponseSelectionMetrics]:
    """Score the predictions in `results` against their targets.

    Args:
        results: evaluation results; expected to be non-empty, since an empty
            list cannot be split into targets and predictions.
        target_key: name of the attribute holding the target label.
        prediction_key: name of the attribute holding the predicted label.

    Returns:
        A dict mapping "Accuracy", "F1-score" and "Precision" to one-element
        lists, so that values of several folds can be concatenated.
    """
    from rasa.model_testing import get_evaluation_metrics

    targets, predictions = _targets_predictions_from(
        results, target_key, prediction_key
    )
    # The first returned value (the report) is not needed here.
    _report, precision, f1, accuracy = get_evaluation_metrics(targets, predictions)

    scores = {"Accuracy": accuracy, "F1-score": f1, "Precision": precision}
    return {name: [value] for name, value in scores.items()}
1881
+
1882
+
1883
def _compute_entity_metrics(
    entity_results: List[EntityEvaluationResult],
) -> EntityMetrics:
    """Score every active entity extractor against the entity targets.

    Args:
        entity_results: entity evaluation results.

    Returns:
        A mapping from extractor name to its "Accuracy", "F1-score" and
        "Precision" values, each stored in a list. The mapping is empty when no
        extractor is active.
    """
    from rasa.model_testing import get_evaluation_metrics

    metrics_per_extractor: EntityMetrics = defaultdict(lambda: defaultdict(list))
    active_extractors = _get_active_entity_extractors(entity_results)
    if not active_extractors:
        return metrics_per_extractor

    aligned = align_all_entity_predictions(entity_results, active_extractors)

    # Merged target labels, with the no-entity tag normalised to NO_ENTITY.
    gold_labels = substitute_labels(merge_labels(aligned), NO_ENTITY_TAG, NO_ENTITY)

    for extractor in active_extractors:
        predicted_labels = substitute_labels(
            merge_labels(aligned, extractor), NO_ENTITY_TAG, NO_ENTITY
        )
        _report, precision, f1, accuracy = get_evaluation_metrics(
            gold_labels, predicted_labels, exclude_label=NO_ENTITY
        )
        extractor_metrics = metrics_per_extractor[extractor]
        extractor_metrics["Accuracy"].append(accuracy)
        extractor_metrics["F1-score"].append(f1)
        extractor_metrics["Precision"].append(precision)

    return metrics_per_extractor
1918
+
1919
+
1920
def log_results(results: IntentMetrics, dataset_name: Text) -> None:
    """Log the mean and standard deviation of every cross-validation metric.

    Args:
        results: mapping of metric name to its values, as returned from cross
            validation.
        dataset_name: name of the dataset the results belong to, e.g.
            test/train; used as the prefix of each log line.
    """
    for metric_name, fold_values in results.items():
        mean, std = np.mean(fold_values), np.std(fold_values)
        logger.info(f"{dataset_name} {metric_name}: {mean:.3f} ({std:.3f})")
1929
+
1930
+
1931
+ def log_entity_results(results: EntityMetrics, dataset_name: Text) -> None:
1932
+ """Logs entity results of cross validation.
1933
+
1934
+ Args:
1935
+ results: dictionary of dictionaries of results returned from cross validation
1936
+ dataset_name: string of which dataset the results are from, e.g. test/train
1937
+ """
1938
+ for extractor, result in results.items():
1939
+ logger.info(f"Entity extractor: {extractor}")
1940
+ log_results(result, dataset_name)