rasa-pro 3.8.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. See the registry's advisory for this release for more details.

Files changed (644):
  1. README.md +380 -0
  2. rasa/__init__.py +10 -0
  3. rasa/__main__.py +151 -0
  4. rasa/anonymization/__init__.py +2 -0
  5. rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
  6. rasa/anonymization/anonymization_pipeline.py +287 -0
  7. rasa/anonymization/anonymization_rule_executor.py +260 -0
  8. rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
  9. rasa/anonymization/schemas/config.yml +47 -0
  10. rasa/anonymization/utils.py +117 -0
  11. rasa/api.py +146 -0
  12. rasa/cli/__init__.py +5 -0
  13. rasa/cli/arguments/__init__.py +0 -0
  14. rasa/cli/arguments/data.py +81 -0
  15. rasa/cli/arguments/default_arguments.py +165 -0
  16. rasa/cli/arguments/evaluate.py +65 -0
  17. rasa/cli/arguments/export.py +51 -0
  18. rasa/cli/arguments/interactive.py +74 -0
  19. rasa/cli/arguments/run.py +204 -0
  20. rasa/cli/arguments/shell.py +13 -0
  21. rasa/cli/arguments/test.py +211 -0
  22. rasa/cli/arguments/train.py +263 -0
  23. rasa/cli/arguments/visualize.py +34 -0
  24. rasa/cli/arguments/x.py +30 -0
  25. rasa/cli/data.py +292 -0
  26. rasa/cli/e2e_test.py +566 -0
  27. rasa/cli/evaluate.py +222 -0
  28. rasa/cli/export.py +251 -0
  29. rasa/cli/inspect.py +63 -0
  30. rasa/cli/interactive.py +164 -0
  31. rasa/cli/license.py +65 -0
  32. rasa/cli/markers.py +78 -0
  33. rasa/cli/project_templates/__init__.py +0 -0
  34. rasa/cli/project_templates/calm/actions/__init__.py +0 -0
  35. rasa/cli/project_templates/calm/actions/action_template.py +27 -0
  36. rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
  37. rasa/cli/project_templates/calm/actions/db.py +57 -0
  38. rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
  39. rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
  40. rasa/cli/project_templates/calm/config.yml +12 -0
  41. rasa/cli/project_templates/calm/credentials.yml +33 -0
  42. rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
  43. rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
  44. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
  45. rasa/cli/project_templates/calm/db/contacts.json +10 -0
  46. rasa/cli/project_templates/calm/domain/add_contact.yml +33 -0
  47. rasa/cli/project_templates/calm/domain/list_contacts.yml +14 -0
  48. rasa/cli/project_templates/calm/domain/remove_contact.yml +31 -0
  49. rasa/cli/project_templates/calm/domain/shared.yml +5 -0
  50. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
  51. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
  52. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
  53. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
  54. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
  55. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
  56. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
  57. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
  58. rasa/cli/project_templates/calm/endpoints.yml +45 -0
  59. rasa/cli/project_templates/default/actions/__init__.py +0 -0
  60. rasa/cli/project_templates/default/actions/actions.py +27 -0
  61. rasa/cli/project_templates/default/config.yml +44 -0
  62. rasa/cli/project_templates/default/credentials.yml +33 -0
  63. rasa/cli/project_templates/default/data/nlu.yml +91 -0
  64. rasa/cli/project_templates/default/data/rules.yml +13 -0
  65. rasa/cli/project_templates/default/data/stories.yml +30 -0
  66. rasa/cli/project_templates/default/domain.yml +34 -0
  67. rasa/cli/project_templates/default/endpoints.yml +42 -0
  68. rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
  69. rasa/cli/project_templates/tutorial/actions.py +22 -0
  70. rasa/cli/project_templates/tutorial/config.yml +11 -0
  71. rasa/cli/project_templates/tutorial/credentials.yml +33 -0
  72. rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
  73. rasa/cli/project_templates/tutorial/domain.yml +17 -0
  74. rasa/cli/project_templates/tutorial/endpoints.yml +45 -0
  75. rasa/cli/run.py +136 -0
  76. rasa/cli/scaffold.py +268 -0
  77. rasa/cli/shell.py +141 -0
  78. rasa/cli/studio/__init__.py +0 -0
  79. rasa/cli/studio/download.py +51 -0
  80. rasa/cli/studio/studio.py +110 -0
  81. rasa/cli/studio/train.py +59 -0
  82. rasa/cli/studio/upload.py +85 -0
  83. rasa/cli/telemetry.py +90 -0
  84. rasa/cli/test.py +280 -0
  85. rasa/cli/train.py +260 -0
  86. rasa/cli/utils.py +453 -0
  87. rasa/cli/visualize.py +40 -0
  88. rasa/cli/x.py +205 -0
  89. rasa/constants.py +37 -0
  90. rasa/core/__init__.py +17 -0
  91. rasa/core/actions/__init__.py +0 -0
  92. rasa/core/actions/action.py +1450 -0
  93. rasa/core/actions/action_clean_stack.py +59 -0
  94. rasa/core/actions/action_run_slot_rejections.py +207 -0
  95. rasa/core/actions/action_trigger_chitchat.py +31 -0
  96. rasa/core/actions/action_trigger_flow.py +109 -0
  97. rasa/core/actions/action_trigger_search.py +31 -0
  98. rasa/core/actions/constants.py +2 -0
  99. rasa/core/actions/forms.py +737 -0
  100. rasa/core/actions/loops.py +111 -0
  101. rasa/core/actions/two_stage_fallback.py +186 -0
  102. rasa/core/agent.py +557 -0
  103. rasa/core/auth_retry_tracker_store.py +122 -0
  104. rasa/core/brokers/__init__.py +0 -0
  105. rasa/core/brokers/broker.py +126 -0
  106. rasa/core/brokers/file.py +58 -0
  107. rasa/core/brokers/kafka.py +322 -0
  108. rasa/core/brokers/pika.py +387 -0
  109. rasa/core/brokers/sql.py +86 -0
  110. rasa/core/channels/__init__.py +55 -0
  111. rasa/core/channels/audiocodes.py +463 -0
  112. rasa/core/channels/botframework.py +339 -0
  113. rasa/core/channels/callback.py +85 -0
  114. rasa/core/channels/channel.py +419 -0
  115. rasa/core/channels/console.py +243 -0
  116. rasa/core/channels/development_inspector.py +93 -0
  117. rasa/core/channels/facebook.py +422 -0
  118. rasa/core/channels/hangouts.py +335 -0
  119. rasa/core/channels/inspector/.eslintrc.cjs +25 -0
  120. rasa/core/channels/inspector/.gitignore +23 -0
  121. rasa/core/channels/inspector/README.md +54 -0
  122. rasa/core/channels/inspector/assets/favicon.ico +0 -0
  123. rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
  124. rasa/core/channels/inspector/custom.d.ts +3 -0
  125. rasa/core/channels/inspector/dist/assets/arc-5623b6dc.js +1 -0
  126. rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
  127. rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-685c106a.js +10 -0
  128. rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-8cbed007.js +2 -0
  129. rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-5889cf12.js +2 -0
  130. rasa/core/channels/inspector/dist/assets/createText-62fc7601-24c249d7.js +7 -0
  131. rasa/core/channels/inspector/dist/assets/edges-f2ad444c-7dd06a75.js +4 -0
  132. rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-62c1e54c.js +51 -0
  133. rasa/core/channels/inspector/dist/assets/flowDb-1972c806-ce49b86f.js +6 -0
  134. rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-4067e48f.js +4 -0
  135. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-85583a23.js +1 -0
  136. rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-59fe4051.js +139 -0
  137. rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-47e3a43b.js +266 -0
  138. rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-5a2ac0d9.js +70 -0
  139. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
  140. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
  141. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
  142. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
  143. rasa/core/channels/inspector/dist/assets/index-268a75c0.js +1040 -0
  144. rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-dfb8efc4.js +1 -0
  145. rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
  146. rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-b0c470f2.js +7 -0
  147. rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
  148. rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-2edb829a.js +139 -0
  149. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
  150. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
  151. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
  152. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
  153. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
  154. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
  155. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
  156. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
  157. rasa/core/channels/inspector/dist/assets/layout-b6873d69.js +1 -0
  158. rasa/core/channels/inspector/dist/assets/line-1efc5781.js +1 -0
  159. rasa/core/channels/inspector/dist/assets/linear-661e9b94.js +1 -0
  160. rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-2d2e727f.js +109 -0
  161. rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
  162. rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
  163. rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-9d3ea93d.js +35 -0
  164. rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-06a178a2.js +7 -0
  165. rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-0bfedffc.js +52 -0
  166. rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-d76d0a04.js +8 -0
  167. rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-37bb4341.js +122 -0
  168. rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-f52f7f57.js +1 -0
  169. rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-4a986a20.js +1 -0
  170. rasa/core/channels/inspector/dist/assets/styles-080da4f6-7dd9ae12.js +110 -0
  171. rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-46e1ca14.js +159 -0
  172. rasa/core/channels/inspector/dist/assets/styles-9c745c82-4a97439a.js +207 -0
  173. rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-823917a3.js +1 -0
  174. rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-9ea72896.js +61 -0
  175. rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-b631a8b6.js +7 -0
  176. rasa/core/channels/inspector/dist/index.html +39 -0
  177. rasa/core/channels/inspector/index.html +37 -0
  178. rasa/core/channels/inspector/jest.config.ts +13 -0
  179. rasa/core/channels/inspector/package.json +48 -0
  180. rasa/core/channels/inspector/setupTests.ts +2 -0
  181. rasa/core/channels/inspector/src/App.tsx +170 -0
  182. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +97 -0
  183. rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
  184. rasa/core/channels/inspector/src/components/DialogueStack.tsx +151 -0
  185. rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
  186. rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
  187. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +19 -0
  188. rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
  189. rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
  190. rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
  191. rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
  192. rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
  193. rasa/core/channels/inspector/src/helpers/formatters.test.ts +385 -0
  194. rasa/core/channels/inspector/src/helpers/formatters.ts +239 -0
  195. rasa/core/channels/inspector/src/helpers/utils.ts +42 -0
  196. rasa/core/channels/inspector/src/main.tsx +13 -0
  197. rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
  198. rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
  199. rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
  200. rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
  201. rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
  202. rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
  203. rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
  204. rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
  205. rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
  206. rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
  207. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
  208. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
  209. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
  210. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
  211. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
  212. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
  213. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
  214. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
  215. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
  216. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
  217. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
  218. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
  219. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
  220. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
  221. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
  222. rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
  223. rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
  224. rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
  225. rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
  226. rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
  227. rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
  228. rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
  229. rasa/core/channels/inspector/src/theme/index.ts +101 -0
  230. rasa/core/channels/inspector/src/types.ts +64 -0
  231. rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
  232. rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
  233. rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
  234. rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
  235. rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
  236. rasa/core/channels/inspector/tsconfig.json +26 -0
  237. rasa/core/channels/inspector/tsconfig.node.json +10 -0
  238. rasa/core/channels/inspector/vite.config.ts +8 -0
  239. rasa/core/channels/inspector/yarn.lock +6156 -0
  240. rasa/core/channels/mattermost.py +229 -0
  241. rasa/core/channels/rasa_chat.py +126 -0
  242. rasa/core/channels/rest.py +210 -0
  243. rasa/core/channels/rocketchat.py +175 -0
  244. rasa/core/channels/slack.py +620 -0
  245. rasa/core/channels/socketio.py +274 -0
  246. rasa/core/channels/telegram.py +298 -0
  247. rasa/core/channels/twilio.py +169 -0
  248. rasa/core/channels/twilio_voice.py +367 -0
  249. rasa/core/channels/vier_cvg.py +374 -0
  250. rasa/core/channels/webexteams.py +135 -0
  251. rasa/core/concurrent_lock_store.py +210 -0
  252. rasa/core/constants.py +107 -0
  253. rasa/core/evaluation/__init__.py +0 -0
  254. rasa/core/evaluation/marker.py +267 -0
  255. rasa/core/evaluation/marker_base.py +925 -0
  256. rasa/core/evaluation/marker_stats.py +294 -0
  257. rasa/core/evaluation/marker_tracker_loader.py +103 -0
  258. rasa/core/exceptions.py +29 -0
  259. rasa/core/exporter.py +284 -0
  260. rasa/core/featurizers/__init__.py +0 -0
  261. rasa/core/featurizers/precomputation.py +410 -0
  262. rasa/core/featurizers/single_state_featurizer.py +402 -0
  263. rasa/core/featurizers/tracker_featurizers.py +1172 -0
  264. rasa/core/http_interpreter.py +89 -0
  265. rasa/core/information_retrieval/__init__.py +0 -0
  266. rasa/core/information_retrieval/faiss.py +116 -0
  267. rasa/core/information_retrieval/information_retrieval.py +72 -0
  268. rasa/core/information_retrieval/milvus.py +59 -0
  269. rasa/core/information_retrieval/qdrant.py +102 -0
  270. rasa/core/jobs.py +63 -0
  271. rasa/core/lock.py +139 -0
  272. rasa/core/lock_store.py +344 -0
  273. rasa/core/migrate.py +404 -0
  274. rasa/core/nlg/__init__.py +3 -0
  275. rasa/core/nlg/callback.py +147 -0
  276. rasa/core/nlg/contextual_response_rephraser.py +270 -0
  277. rasa/core/nlg/generator.py +230 -0
  278. rasa/core/nlg/interpolator.py +143 -0
  279. rasa/core/nlg/response.py +155 -0
  280. rasa/core/nlg/summarize.py +69 -0
  281. rasa/core/policies/__init__.py +0 -0
  282. rasa/core/policies/ensemble.py +329 -0
  283. rasa/core/policies/enterprise_search_policy.py +717 -0
  284. rasa/core/policies/enterprise_search_prompt_template.jinja2 +62 -0
  285. rasa/core/policies/flow_policy.py +205 -0
  286. rasa/core/policies/flows/__init__.py +0 -0
  287. rasa/core/policies/flows/flow_exceptions.py +44 -0
  288. rasa/core/policies/flows/flow_executor.py +582 -0
  289. rasa/core/policies/flows/flow_step_result.py +43 -0
  290. rasa/core/policies/intentless_policy.py +924 -0
  291. rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
  292. rasa/core/policies/memoization.py +538 -0
  293. rasa/core/policies/policy.py +716 -0
  294. rasa/core/policies/rule_policy.py +1276 -0
  295. rasa/core/policies/ted_policy.py +2146 -0
  296. rasa/core/policies/unexpected_intent_policy.py +1015 -0
  297. rasa/core/processor.py +1331 -0
  298. rasa/core/run.py +315 -0
  299. rasa/core/secrets_manager/__init__.py +0 -0
  300. rasa/core/secrets_manager/constants.py +32 -0
  301. rasa/core/secrets_manager/endpoints.py +391 -0
  302. rasa/core/secrets_manager/factory.py +233 -0
  303. rasa/core/secrets_manager/secret_manager.py +262 -0
  304. rasa/core/secrets_manager/vault.py +576 -0
  305. rasa/core/test.py +1337 -0
  306. rasa/core/tracker_store.py +1664 -0
  307. rasa/core/train.py +107 -0
  308. rasa/core/training/__init__.py +89 -0
  309. rasa/core/training/converters/__init__.py +0 -0
  310. rasa/core/training/converters/responses_prefix_converter.py +119 -0
  311. rasa/core/training/interactive.py +1742 -0
  312. rasa/core/training/story_conflict.py +381 -0
  313. rasa/core/training/training.py +93 -0
  314. rasa/core/utils.py +344 -0
  315. rasa/core/visualize.py +70 -0
  316. rasa/dialogue_understanding/__init__.py +0 -0
  317. rasa/dialogue_understanding/coexistence/__init__.py +0 -0
  318. rasa/dialogue_understanding/coexistence/constants.py +4 -0
  319. rasa/dialogue_understanding/coexistence/intent_based_router.py +189 -0
  320. rasa/dialogue_understanding/coexistence/llm_based_router.py +261 -0
  321. rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
  322. rasa/dialogue_understanding/commands/__init__.py +45 -0
  323. rasa/dialogue_understanding/commands/can_not_handle_command.py +61 -0
  324. rasa/dialogue_understanding/commands/cancel_flow_command.py +116 -0
  325. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +48 -0
  326. rasa/dialogue_understanding/commands/clarify_command.py +77 -0
  327. rasa/dialogue_understanding/commands/command.py +85 -0
  328. rasa/dialogue_understanding/commands/correct_slots_command.py +288 -0
  329. rasa/dialogue_understanding/commands/error_command.py +67 -0
  330. rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
  331. rasa/dialogue_understanding/commands/handle_code_change_command.py +64 -0
  332. rasa/dialogue_understanding/commands/human_handoff_command.py +57 -0
  333. rasa/dialogue_understanding/commands/knowledge_answer_command.py +48 -0
  334. rasa/dialogue_understanding/commands/noop_command.py +45 -0
  335. rasa/dialogue_understanding/commands/set_slot_command.py +125 -0
  336. rasa/dialogue_understanding/commands/skip_question_command.py +66 -0
  337. rasa/dialogue_understanding/commands/start_flow_command.py +98 -0
  338. rasa/dialogue_understanding/generator/__init__.py +6 -0
  339. rasa/dialogue_understanding/generator/command_generator.py +257 -0
  340. rasa/dialogue_understanding/generator/command_prompt_template.jinja2 +57 -0
  341. rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
  342. rasa/dialogue_understanding/generator/flow_retrieval.py +410 -0
  343. rasa/dialogue_understanding/generator/llm_command_generator.py +637 -0
  344. rasa/dialogue_understanding/generator/nlu_command_adapter.py +157 -0
  345. rasa/dialogue_understanding/patterns/__init__.py +0 -0
  346. rasa/dialogue_understanding/patterns/cancel.py +111 -0
  347. rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
  348. rasa/dialogue_understanding/patterns/chitchat.py +37 -0
  349. rasa/dialogue_understanding/patterns/clarify.py +97 -0
  350. rasa/dialogue_understanding/patterns/code_change.py +41 -0
  351. rasa/dialogue_understanding/patterns/collect_information.py +90 -0
  352. rasa/dialogue_understanding/patterns/completed.py +40 -0
  353. rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
  354. rasa/dialogue_understanding/patterns/correction.py +278 -0
  355. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +243 -0
  356. rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
  357. rasa/dialogue_understanding/patterns/internal_error.py +47 -0
  358. rasa/dialogue_understanding/patterns/search.py +37 -0
  359. rasa/dialogue_understanding/patterns/skip_question.py +38 -0
  360. rasa/dialogue_understanding/processor/__init__.py +0 -0
  361. rasa/dialogue_understanding/processor/command_processor.py +578 -0
  362. rasa/dialogue_understanding/processor/command_processor_component.py +39 -0
  363. rasa/dialogue_understanding/stack/__init__.py +0 -0
  364. rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
  365. rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
  366. rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
  367. rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
  368. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
  369. rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
  370. rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
  371. rasa/dialogue_understanding/stack/utils.py +211 -0
  372. rasa/e2e_test/__init__.py +0 -0
  373. rasa/e2e_test/constants.py +10 -0
  374. rasa/e2e_test/e2e_test_case.py +322 -0
  375. rasa/e2e_test/e2e_test_result.py +34 -0
  376. rasa/e2e_test/e2e_test_runner.py +659 -0
  377. rasa/e2e_test/e2e_test_schema.yml +67 -0
  378. rasa/engine/__init__.py +0 -0
  379. rasa/engine/caching.py +464 -0
  380. rasa/engine/constants.py +17 -0
  381. rasa/engine/exceptions.py +14 -0
  382. rasa/engine/graph.py +625 -0
  383. rasa/engine/loader.py +36 -0
  384. rasa/engine/recipes/__init__.py +0 -0
  385. rasa/engine/recipes/config_files/default_config.yml +44 -0
  386. rasa/engine/recipes/default_components.py +99 -0
  387. rasa/engine/recipes/default_recipe.py +1252 -0
  388. rasa/engine/recipes/graph_recipe.py +79 -0
  389. rasa/engine/recipes/recipe.py +93 -0
  390. rasa/engine/runner/__init__.py +0 -0
  391. rasa/engine/runner/dask.py +256 -0
  392. rasa/engine/runner/interface.py +49 -0
  393. rasa/engine/storage/__init__.py +0 -0
  394. rasa/engine/storage/local_model_storage.py +248 -0
  395. rasa/engine/storage/resource.py +110 -0
  396. rasa/engine/storage/storage.py +203 -0
  397. rasa/engine/training/__init__.py +0 -0
  398. rasa/engine/training/components.py +176 -0
  399. rasa/engine/training/fingerprinting.py +64 -0
  400. rasa/engine/training/graph_trainer.py +256 -0
  401. rasa/engine/training/hooks.py +164 -0
  402. rasa/engine/validation.py +839 -0
  403. rasa/env.py +5 -0
  404. rasa/exceptions.py +69 -0
  405. rasa/graph_components/__init__.py +0 -0
  406. rasa/graph_components/converters/__init__.py +0 -0
  407. rasa/graph_components/converters/nlu_message_converter.py +48 -0
  408. rasa/graph_components/providers/__init__.py +0 -0
  409. rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
  410. rasa/graph_components/providers/domain_provider.py +71 -0
  411. rasa/graph_components/providers/flows_provider.py +74 -0
  412. rasa/graph_components/providers/forms_provider.py +44 -0
  413. rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
  414. rasa/graph_components/providers/responses_provider.py +44 -0
  415. rasa/graph_components/providers/rule_only_provider.py +49 -0
  416. rasa/graph_components/providers/story_graph_provider.py +43 -0
  417. rasa/graph_components/providers/training_tracker_provider.py +55 -0
  418. rasa/graph_components/validators/__init__.py +0 -0
  419. rasa/graph_components/validators/default_recipe_validator.py +552 -0
  420. rasa/graph_components/validators/finetuning_validator.py +302 -0
  421. rasa/hooks.py +113 -0
  422. rasa/jupyter.py +63 -0
  423. rasa/keys +1 -0
  424. rasa/markers/__init__.py +0 -0
  425. rasa/markers/marker.py +269 -0
  426. rasa/markers/marker_base.py +828 -0
  427. rasa/markers/upload.py +74 -0
  428. rasa/markers/validate.py +21 -0
  429. rasa/model.py +118 -0
  430. rasa/model_testing.py +457 -0
  431. rasa/model_training.py +535 -0
  432. rasa/nlu/__init__.py +7 -0
  433. rasa/nlu/classifiers/__init__.py +3 -0
  434. rasa/nlu/classifiers/classifier.py +5 -0
  435. rasa/nlu/classifiers/diet_classifier.py +1874 -0
  436. rasa/nlu/classifiers/fallback_classifier.py +192 -0
  437. rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
  438. rasa/nlu/classifiers/llm_intent_classifier.py +519 -0
  439. rasa/nlu/classifiers/logistic_regression_classifier.py +240 -0
  440. rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
  441. rasa/nlu/classifiers/regex_message_handler.py +56 -0
  442. rasa/nlu/classifiers/sklearn_intent_classifier.py +309 -0
  443. rasa/nlu/constants.py +77 -0
  444. rasa/nlu/convert.py +40 -0
  445. rasa/nlu/emulators/__init__.py +0 -0
  446. rasa/nlu/emulators/dialogflow.py +55 -0
  447. rasa/nlu/emulators/emulator.py +49 -0
  448. rasa/nlu/emulators/luis.py +86 -0
  449. rasa/nlu/emulators/no_emulator.py +10 -0
  450. rasa/nlu/emulators/wit.py +56 -0
  451. rasa/nlu/extractors/__init__.py +0 -0
  452. rasa/nlu/extractors/crf_entity_extractor.py +672 -0
  453. rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
  454. rasa/nlu/extractors/entity_synonyms.py +178 -0
  455. rasa/nlu/extractors/extractor.py +470 -0
  456. rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
  457. rasa/nlu/extractors/regex_entity_extractor.py +220 -0
  458. rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
  459. rasa/nlu/featurizers/__init__.py +0 -0
  460. rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
  461. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +449 -0
  462. rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
  463. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +772 -0
  464. rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
  465. rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
  466. rasa/nlu/featurizers/featurizer.py +89 -0
  467. rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
  468. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +840 -0
  469. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +539 -0
  470. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +269 -0
  471. rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
  472. rasa/nlu/model.py +24 -0
  473. rasa/nlu/persistor.py +240 -0
  474. rasa/nlu/run.py +27 -0
  475. rasa/nlu/selectors/__init__.py +0 -0
  476. rasa/nlu/selectors/response_selector.py +990 -0
  477. rasa/nlu/test.py +1943 -0
  478. rasa/nlu/tokenizers/__init__.py +0 -0
  479. rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
  480. rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
  481. rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
  482. rasa/nlu/tokenizers/tokenizer.py +239 -0
  483. rasa/nlu/tokenizers/whitespace_tokenizer.py +106 -0
  484. rasa/nlu/utils/__init__.py +35 -0
  485. rasa/nlu/utils/bilou_utils.py +462 -0
  486. rasa/nlu/utils/hugging_face/__init__.py +0 -0
  487. rasa/nlu/utils/hugging_face/registry.py +108 -0
  488. rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
  489. rasa/nlu/utils/mitie_utils.py +113 -0
  490. rasa/nlu/utils/pattern_utils.py +168 -0
  491. rasa/nlu/utils/spacy_utils.py +312 -0
  492. rasa/plugin.py +90 -0
  493. rasa/server.py +1536 -0
  494. rasa/shared/__init__.py +0 -0
  495. rasa/shared/constants.py +181 -0
  496. rasa/shared/core/__init__.py +0 -0
  497. rasa/shared/core/constants.py +168 -0
  498. rasa/shared/core/conversation.py +46 -0
  499. rasa/shared/core/domain.py +2106 -0
  500. rasa/shared/core/events.py +2507 -0
  501. rasa/shared/core/flows/__init__.py +7 -0
  502. rasa/shared/core/flows/flow.py +353 -0
  503. rasa/shared/core/flows/flow_step.py +146 -0
  504. rasa/shared/core/flows/flow_step_links.py +319 -0
  505. rasa/shared/core/flows/flow_step_sequence.py +70 -0
  506. rasa/shared/core/flows/flows_list.py +211 -0
  507. rasa/shared/core/flows/flows_yaml_schema.json +217 -0
  508. rasa/shared/core/flows/nlu_trigger.py +117 -0
  509. rasa/shared/core/flows/steps/__init__.py +24 -0
  510. rasa/shared/core/flows/steps/action.py +51 -0
  511. rasa/shared/core/flows/steps/call.py +64 -0
  512. rasa/shared/core/flows/steps/collect.py +112 -0
  513. rasa/shared/core/flows/steps/constants.py +5 -0
  514. rasa/shared/core/flows/steps/continuation.py +36 -0
  515. rasa/shared/core/flows/steps/end.py +22 -0
  516. rasa/shared/core/flows/steps/internal.py +44 -0
  517. rasa/shared/core/flows/steps/link.py +51 -0
  518. rasa/shared/core/flows/steps/no_operation.py +48 -0
  519. rasa/shared/core/flows/steps/set_slots.py +50 -0
  520. rasa/shared/core/flows/steps/start.py +30 -0
  521. rasa/shared/core/flows/validation.py +527 -0
  522. rasa/shared/core/flows/yaml_flows_io.py +278 -0
  523. rasa/shared/core/generator.py +907 -0
  524. rasa/shared/core/slot_mappings.py +235 -0
  525. rasa/shared/core/slots.py +647 -0
  526. rasa/shared/core/trackers.py +1159 -0
  527. rasa/shared/core/training_data/__init__.py +0 -0
  528. rasa/shared/core/training_data/loading.py +90 -0
  529. rasa/shared/core/training_data/story_reader/__init__.py +0 -0
  530. rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
  531. rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
  532. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
  533. rasa/shared/core/training_data/story_writer/__init__.py +0 -0
  534. rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
  535. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +442 -0
  536. rasa/shared/core/training_data/structures.py +838 -0
  537. rasa/shared/core/training_data/visualization.html +146 -0
  538. rasa/shared/core/training_data/visualization.py +603 -0
  539. rasa/shared/data.py +192 -0
  540. rasa/shared/engine/__init__.py +0 -0
  541. rasa/shared/engine/caching.py +26 -0
  542. rasa/shared/exceptions.py +129 -0
  543. rasa/shared/importers/__init__.py +0 -0
  544. rasa/shared/importers/importer.py +705 -0
  545. rasa/shared/importers/multi_project.py +203 -0
  546. rasa/shared/importers/rasa.py +100 -0
  547. rasa/shared/importers/utils.py +34 -0
  548. rasa/shared/nlu/__init__.py +0 -0
  549. rasa/shared/nlu/constants.py +45 -0
  550. rasa/shared/nlu/interpreter.py +10 -0
  551. rasa/shared/nlu/training_data/__init__.py +0 -0
  552. rasa/shared/nlu/training_data/entities_parser.py +209 -0
  553. rasa/shared/nlu/training_data/features.py +374 -0
  554. rasa/shared/nlu/training_data/formats/__init__.py +10 -0
  555. rasa/shared/nlu/training_data/formats/dialogflow.py +162 -0
  556. rasa/shared/nlu/training_data/formats/luis.py +87 -0
  557. rasa/shared/nlu/training_data/formats/rasa.py +135 -0
  558. rasa/shared/nlu/training_data/formats/rasa_yaml.py +605 -0
  559. rasa/shared/nlu/training_data/formats/readerwriter.py +245 -0
  560. rasa/shared/nlu/training_data/formats/wit.py +52 -0
  561. rasa/shared/nlu/training_data/loading.py +137 -0
  562. rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
  563. rasa/shared/nlu/training_data/message.py +477 -0
  564. rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
  565. rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
  566. rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
  567. rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
  568. rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
  569. rasa/shared/nlu/training_data/training_data.py +732 -0
  570. rasa/shared/nlu/training_data/util.py +223 -0
  571. rasa/shared/providers/__init__.py +0 -0
  572. rasa/shared/providers/openai/__init__.py +0 -0
  573. rasa/shared/providers/openai/clients.py +43 -0
  574. rasa/shared/providers/openai/session_handler.py +110 -0
  575. rasa/shared/utils/__init__.py +0 -0
  576. rasa/shared/utils/cli.py +72 -0
  577. rasa/shared/utils/common.py +308 -0
  578. rasa/shared/utils/constants.py +1 -0
  579. rasa/shared/utils/io.py +403 -0
  580. rasa/shared/utils/llm.py +405 -0
  581. rasa/shared/utils/pykwalify_extensions.py +26 -0
  582. rasa/shared/utils/schemas/__init__.py +0 -0
  583. rasa/shared/utils/schemas/config.yml +2 -0
  584. rasa/shared/utils/schemas/domain.yml +142 -0
  585. rasa/shared/utils/schemas/events.py +212 -0
  586. rasa/shared/utils/schemas/model_config.yml +46 -0
  587. rasa/shared/utils/schemas/stories.yml +173 -0
  588. rasa/shared/utils/yaml.py +777 -0
  589. rasa/studio/__init__.py +0 -0
  590. rasa/studio/auth.py +252 -0
  591. rasa/studio/config.py +127 -0
  592. rasa/studio/constants.py +16 -0
  593. rasa/studio/data_handler.py +352 -0
  594. rasa/studio/download.py +350 -0
  595. rasa/studio/train.py +136 -0
  596. rasa/studio/upload.py +408 -0
  597. rasa/telemetry.py +1583 -0
  598. rasa/tracing/__init__.py +0 -0
  599. rasa/tracing/config.py +338 -0
  600. rasa/tracing/constants.py +38 -0
  601. rasa/tracing/instrumentation/__init__.py +0 -0
  602. rasa/tracing/instrumentation/attribute_extractors.py +663 -0
  603. rasa/tracing/instrumentation/instrumentation.py +939 -0
  604. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +142 -0
  605. rasa/tracing/instrumentation/metrics.py +206 -0
  606. rasa/tracing/metric_instrument_provider.py +125 -0
  607. rasa/utils/__init__.py +0 -0
  608. rasa/utils/beta.py +83 -0
  609. rasa/utils/cli.py +27 -0
  610. rasa/utils/common.py +635 -0
  611. rasa/utils/converter.py +53 -0
  612. rasa/utils/endpoints.py +303 -0
  613. rasa/utils/io.py +326 -0
  614. rasa/utils/licensing.py +319 -0
  615. rasa/utils/log_utils.py +174 -0
  616. rasa/utils/mapper.py +210 -0
  617. rasa/utils/ml_utils.py +145 -0
  618. rasa/utils/plotting.py +362 -0
  619. rasa/utils/singleton.py +23 -0
  620. rasa/utils/tensorflow/__init__.py +0 -0
  621. rasa/utils/tensorflow/callback.py +112 -0
  622. rasa/utils/tensorflow/constants.py +116 -0
  623. rasa/utils/tensorflow/crf.py +492 -0
  624. rasa/utils/tensorflow/data_generator.py +440 -0
  625. rasa/utils/tensorflow/environment.py +161 -0
  626. rasa/utils/tensorflow/exceptions.py +5 -0
  627. rasa/utils/tensorflow/layers.py +1565 -0
  628. rasa/utils/tensorflow/layers_utils.py +113 -0
  629. rasa/utils/tensorflow/metrics.py +281 -0
  630. rasa/utils/tensorflow/model_data.py +991 -0
  631. rasa/utils/tensorflow/model_data_utils.py +500 -0
  632. rasa/utils/tensorflow/models.py +936 -0
  633. rasa/utils/tensorflow/rasa_layers.py +1094 -0
  634. rasa/utils/tensorflow/transformer.py +640 -0
  635. rasa/utils/tensorflow/types.py +6 -0
  636. rasa/utils/train_utils.py +572 -0
  637. rasa/utils/yaml.py +54 -0
  638. rasa/validator.py +1035 -0
  639. rasa/version.py +3 -0
  640. rasa_pro-3.8.16.dist-info/METADATA +528 -0
  641. rasa_pro-3.8.16.dist-info/NOTICE +5 -0
  642. rasa_pro-3.8.16.dist-info/RECORD +644 -0
  643. rasa_pro-3.8.16.dist-info/WHEEL +4 -0
  644. rasa_pro-3.8.16.dist-info/entry_points.txt +3 -0
rasa/nlu/test.py ADDED
@@ -0,0 +1,1943 @@
1
+ import copy
2
+ import itertools
3
+ import os
4
+ import logging
5
+ import structlog
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ from collections import defaultdict, namedtuple
10
+ from tqdm import tqdm
11
+ from typing import (
12
+ Iterable,
13
+ Iterator,
14
+ Tuple,
15
+ List,
16
+ Set,
17
+ Optional,
18
+ Text,
19
+ Union,
20
+ Dict,
21
+ Any,
22
+ NamedTuple,
23
+ TYPE_CHECKING,
24
+ )
25
+
26
+ from rasa import telemetry
27
+ from rasa.core.agent import Agent
28
+ from rasa.core.channels import UserMessage
29
+ from rasa.core.processor import MessageProcessor
30
+ from rasa.shared.nlu.training_data.training_data import TrainingData
31
+ from rasa.shared.utils.yaml import write_yaml
32
+ from rasa.utils.common import TempDirectoryPath, get_temp_dir_name
33
+ import rasa.shared.utils.io
34
+ import rasa.utils.plotting as plot_utils
35
+ import rasa.utils.io as io_utils
36
+
37
+ from rasa.constants import TEST_DATA_FILE, TRAIN_DATA_FILE, NLG_DATA_FILE
38
+ import rasa.nlu.classifiers.fallback_classifier
39
+ from rasa.nlu.constants import (
40
+ RESPONSE_SELECTOR_DEFAULT_INTENT,
41
+ RESPONSE_SELECTOR_PROPERTY_NAME,
42
+ RESPONSE_SELECTOR_PREDICTION_KEY,
43
+ TOKENS_NAMES,
44
+ ENTITY_ATTRIBUTE_CONFIDENCE_TYPE,
45
+ ENTITY_ATTRIBUTE_CONFIDENCE_ROLE,
46
+ ENTITY_ATTRIBUTE_CONFIDENCE_GROUP,
47
+ RESPONSE_SELECTOR_RETRIEVAL_INTENTS,
48
+ )
49
+ from rasa.shared.nlu.constants import (
50
+ TEXT,
51
+ INTENT,
52
+ INTENT_RESPONSE_KEY,
53
+ ENTITIES,
54
+ EXTRACTOR,
55
+ PRETRAINED_EXTRACTORS,
56
+ ENTITY_ATTRIBUTE_TYPE,
57
+ ENTITY_ATTRIBUTE_GROUP,
58
+ ENTITY_ATTRIBUTE_ROLE,
59
+ NO_ENTITY_TAG,
60
+ INTENT_NAME_KEY,
61
+ PREDICTED_CONFIDENCE_KEY,
62
+ )
63
+ from rasa.nlu.classifiers import fallback_classifier
64
+ from rasa.nlu.tokenizers.tokenizer import Token
65
+ from rasa.shared.importers.importer import TrainingDataImporter
66
+ from rasa.shared.nlu.training_data.formats.rasa_yaml import RasaYAMLWriter
67
+
68
# Typing-only definitions: ``TYPE_CHECKING`` is False at runtime, so
# ``EntityPrediction`` only exists for type checkers. At runtime it may only
# appear in quoted annotations (``List["EntityPrediction"]``) or in local
# variable annotations, which Python does not evaluate.
if TYPE_CHECKING:
    from typing_extensions import TypedDict

    # One serialized entity-prediction record: the message text, its annotated
    # (target) entities and the entities the model predicted for it.
    EntityPrediction = TypedDict(
        "EntityPrediction",
        {
            "text": Text,
            "entities": List[Dict[Text, Any]],
            "predicted_entities": List[Dict[Text, Any]],
        },
    )
# Standard-library logger for human-readable messages, plus a structlog logger
# used for structured debug events (e.g. dumping successful predictions).
logger = logging.getLogger(__name__)
structlogger = structlog.get_logger()

# Exclude 'EntitySynonymMapper' and 'ResponseSelector' as their super class
# performs entity extraction but those two classifiers don't
ENTITY_PROCESSORS = {"EntitySynonymMapper", "ResponseSelector"}

# Names of entity extractors treated as reporting entity confidence scores.
# NOTE(review): the consumers of this set are outside this chunk — confirm usage.
EXTRACTORS_WITH_CONFIDENCES = {"CRFEntityExtractor", "DIETClassifier"}
87
+
88
+
89
class CVEvaluationResult(NamedTuple):
    """Stores NLU cross-validation results.

    An immutable triple of result dictionaries. Their contents are produced
    outside this chunk; presumably they hold per-split metrics — confirm against
    the cross-validation code before relying on a particular layout.
    """

    # presumably results computed on the training split(s) — confirm
    train: Dict
    # presumably results computed on the held-out test split(s) — confirm
    test: Dict
    # NOTE(review): role of this third dictionary is not visible here — confirm
    evaluation: Dict
95
+
96
+
97
# Label used for tokens that carry no entity.
NO_ENTITY = "no_entity"

# One evaluated NLU example for intent classification.
IntentEvaluationResult = namedtuple(
    "IntentEvaluationResult",
    ["intent_target", "intent_prediction", "message", "confidence"],
)

# One evaluated NLU example for response selection.
ResponseSelectionEvaluationResult = namedtuple(
    "ResponseSelectionEvaluationResult",
    [
        "intent_response_key_target",
        "intent_response_key_prediction",
        "message",
        "confidence",
    ],
)

# One evaluated NLU example for entity extraction.
EntityEvaluationResult = namedtuple(
    "EntityEvaluationResult",
    ["entity_targets", "entity_predictions", "tokens", "message"],
)

# Aliases describing the shapes of the collected metric dictionaries.
IntentMetrics = Dict[Text, List[float]]
EntityMetrics = Dict[Text, Dict[Text, List[float]]]
ResponseSelectionMetrics = Dict[Text, List[float]]
115
+
116
+
117
def log_evaluation_table(
    report: Text, precision: float, f1: float, accuracy: float
) -> None:  # pragma: no cover
    """Log the headline sklearn metrics and then the full classification report.

    Args:
        report: Classification report (text form).
        precision: Precision score.
        f1: F1 score.
        accuracy: Accuracy score.
    """
    messages = (
        f"F1-Score: {f1}",
        f"Precision: {precision}",
        f"Accuracy: {accuracy}",
        f"Classification report: \n{report}",
    )
    for message in messages:
        logger.info(message)
125
+
126
+
127
def remove_empty_intent_examples(
    intent_results: List[IntentEvaluationResult],
) -> List[IntentEvaluationResult]:
    """Drop results that have no target intent.

    A missing prediction (``None``) is replaced by the empty string so that
    sklearn can compare targets and predictions.

    Args:
        intent_results: intent evaluation results

    Returns:
        The results whose ``intent_target`` is neither ``None`` nor ``""``,
        in their original order.
    """

    def _with_string_prediction(result):
        if result.intent_prediction is None:
            return result._replace(intent_prediction="")
        return result

    return [
        _with_string_prediction(result)
        for result in intent_results
        if result.intent_target is not None and result.intent_target != ""
    ]
148
+
149
+
150
def remove_empty_response_examples(
    response_results: List[ResponseSelectionEvaluationResult],
) -> List[ResponseSelectionEvaluationResult]:
    """Drop results that have no target response key.

    For the results that are kept, a ``None`` prediction becomes ``""`` (so
    sklearn can compare labels) and a ``None`` confidence becomes ``0.0``.

    Args:
        response_results: response selection evaluation results

    Returns:
        Response selection evaluation results with a truthy
        ``intent_response_key_target``, in their original order.
    """
    kept = []
    for result in response_results:
        if not result.intent_response_key_target:
            continue

        replacements = {}
        if result.intent_response_key_prediction is None:
            replacements["intent_response_key_prediction"] = ""
        if result.confidence is None:
            # This might happen if response selector training data is present but
            # no response selector is part of the model
            replacements["confidence"] = 0.0

        kept.append(result._replace(**replacements) if replacements else result)
    return kept
177
+
178
+
179
def drop_intents_below_freq(
    training_data: TrainingData, cutoff: int = 5
) -> TrainingData:
    """Keep only training examples whose intent has at least ``cutoff`` examples.

    Examples whose intent is not found in
    ``training_data.number_of_examples_per_intent`` count as having zero
    examples, so they are dropped whenever ``cutoff`` is positive.

    Args:
        training_data: training data
        cutoff: minimum number of examples an intent needs in order to be kept

    Returns:
        The filtered training data.
    """
    logger.debug(
        "Raw data intent examples: {}".format(len(training_data.intent_examples))
    )

    intent_counts = training_data.number_of_examples_per_intent

    def _has_enough_examples(example) -> bool:
        return intent_counts.get(example.get(INTENT), 0) >= cutoff

    return training_data.filter_training_examples(_has_enough_examples)
198
+
199
+
200
def write_intent_successes(
    intent_results: List[IntentEvaluationResult], successes_filename: Text
) -> None:
    """Persist correctly classified intent examples as JSON.

    If no prediction matches its target, nothing is written and an info message
    is logged instead.

    Args:
        intent_results: intent evaluation results
        successes_filename: path of the JSON file the successes are written to
    """
    correct_results = (
        result
        for result in intent_results
        if result.intent_target == result.intent_prediction
    )
    successes = [
        {
            "text": result.message,
            "intent": result.intent_target,
            "intent_prediction": {
                INTENT_NAME_KEY: result.intent_prediction,
                "confidence": result.confidence,
            },
        }
        for result in correct_results
    ]

    if not successes:
        logger.info("No successful intent predictions found.")
        return

    rasa.shared.utils.io.dump_obj_as_json_to_file(successes_filename, successes)
    logger.info(f"Successful intent predictions saved to {successes_filename}.")
    logger.debug(f"\n\nSuccessfully predicted the following intents: \n{successes}")
228
+
229
+
230
def _write_errors(errors: List[Dict], errors_filename: Text, error_type: Text) -> None:
    """Write serializable prediction errors to a JSON file and log the outcome.

    This is the generic writer used by several evaluations (e.g. intent and
    response selection). ``error_type`` is used only to word the log messages.
    If ``errors`` is empty, no file is written and an info message reports that
    everything was predicted correctly.

    Args:
        errors: Serializable prediction errors.
        errors_filename: Path of the JSON file the errors are written to.
        error_type: NLU attribute which was evaluated (e.g. `intent` or
            `response`); only used in log messages.
    """
    if not errors:
        logger.info(f"Every {error_type} was predicted correctly by the model.")
        return

    rasa.shared.utils.io.dump_obj_as_json_to_file(errors_filename, errors)
    logger.info(f"Incorrect {error_type} predictions saved to {errors_filename}.")
    logger.debug(
        f"\n\nThese {error_type} examples could not be classified "
        f"correctly: \n{errors}"
    )
247
+
248
+
249
def _get_intent_errors(intent_results: List[IntentEvaluationResult]) -> List[Dict]:
    """Build serializable records for every misclassified intent example.

    Args:
        intent_results: intent evaluation results

    Returns:
        One dict per result whose prediction differs from its target, holding
        the message text, the target intent and the predicted intent with its
        confidence.
    """
    errors: List[Dict] = []
    for result in intent_results:
        if result.intent_target != result.intent_prediction:
            errors.append(
                {
                    "text": result.message,
                    "intent": result.intent_target,
                    "intent_prediction": {
                        INTENT_NAME_KEY: result.intent_prediction,
                        "confidence": result.confidence,
                    },
                }
            )
    return errors
262
+
263
+
264
def write_response_successes(
    response_results: List[ResponseSelectionEvaluationResult], successes_filename: Text
) -> None:
    """Persist correctly selected responses as JSON.

    If no prediction matches its target, nothing is written and an info message
    is logged instead. A deep copy of the records is emitted as a structured
    debug event.

    Args:
        response_results: response selection evaluation results
        successes_filename: path of the JSON file the successes are written to
    """
    matching = [
        result
        for result in response_results
        if result.intent_response_key_prediction == result.intent_response_key_target
    ]
    successes = [
        {
            "text": result.message,
            "intent_response_key_target": result.intent_response_key_target,
            "intent_response_key_prediction": {
                "name": result.intent_response_key_prediction,
                "confidence": result.confidence,
            },
        }
        for result in matching
    ]

    if not successes:
        logger.info("No successful response predictions found.")
        return

    rasa.shared.utils.io.dump_obj_as_json_to_file(successes_filename, successes)
    logger.info(f"Successful response predictions saved to {successes_filename}.")
    structlogger.debug("test.write.response", successes=copy.deepcopy(successes))
292
+
293
+
294
+ def _response_errors(
295
+ response_results: List[ResponseSelectionEvaluationResult],
296
+ ) -> List[Dict]:
297
+ """Write incorrect response selection predictions to a file.
298
+
299
+ Args:
300
+ response_results: response selection evaluation result
301
+
302
+ Returns:
303
+ Serializable prediction errors.
304
+ """
305
+ return [
306
+ {
307
+ "text": r.message,
308
+ "intent_response_key_target": r.intent_response_key_target,
309
+ "intent_response_key_prediction": {
310
+ "name": r.intent_response_key_prediction,
311
+ "confidence": r.confidence,
312
+ },
313
+ }
314
+ for r in response_results
315
+ if r.intent_response_key_prediction != r.intent_response_key_target
316
+ ]
317
+
318
+
319
def plot_attribute_confidences(
    results: Union[
        List[IntentEvaluationResult], List[ResponseSelectionEvaluationResult]
    ],
    hist_filename: Optional[Text],
    target_key: Text,
    prediction_key: Text,
    title: Text,
) -> None:
    """Plot a paired histogram of confidences for correct vs. wrong predictions.

    A result counts as correct when its ``target_key`` attribute equals its
    ``prediction_key`` attribute.

    Args:
        results: evaluation results
        hist_filename: filename to save plot to
        target_key: name of the attribute holding the target
        prediction_key: name of the attribute holding the prediction
        title: title of plot
    """
    correct_confidences: List[float] = []
    wrong_confidences: List[float] = []
    for result in results:
        if getattr(result, target_key) == getattr(result, prediction_key):
            correct_confidences.append(result.confidence)
        else:
            wrong_confidences.append(result.confidence)

    plot_utils.plot_paired_histogram(
        [correct_confidences, wrong_confidences], title, hist_filename
    )
350
+
351
+
352
def plot_entity_confidences(
    merged_targets: List[Text],
    merged_predictions: List[Text],
    merged_confidences: List[float],
    hist_filename: Text,
    title: Text,
) -> None:
    """Plot a paired histogram of entity confidences.

    The first histogram holds confidences of tokens whose (non-``NO_ENTITY``)
    target was predicted exactly. The second holds confidences of tokens that
    received an entity prediction different from their target.

    Args:
        merged_targets: Entity labels.
        merged_predictions: Predicted entities.
        merged_confidences: Confidence scores of predictions.
        hist_filename: filename to save plot to
        title: title of plot
    """
    hits: List[float] = []
    misses: List[float] = []
    for target, prediction, confidence in zip(
        merged_targets, merged_predictions, merged_confidences
    ):
        if target != NO_ENTITY and target == prediction:
            hits.append(confidence)
        if prediction not in (NO_ENTITY, target):
            misses.append(confidence)

    plot_utils.plot_paired_histogram([hits, misses], title, hist_filename)
385
+
386
+
387
def evaluate_response_selections(
    response_selection_results: List[ResponseSelectionEvaluationResult],
    output_directory: Optional[Text],
    successes: bool,
    errors: bool,
    disable_plotting: bool,
    report_as_dict: Optional[bool] = None,
) -> Dict:  # pragma: no cover
    """Creates summary statistics for response selection.

    Results without a target response key are discarded first; all statistics
    are computed on the remaining ones.

    Side effects depend on the arguments:
    * ``output_directory`` set: the classification report is dumped as JSON.
    * ``successes``: correct predictions are written to
      ``response_selection_successes.json``, inside ``output_directory`` when
      given, otherwise relative to the current working directory.
    * ``errors`` and ``output_directory``: incorrect predictions are written.
    * plotting enabled: a confusion matrix and a confidence histogram are saved,
      again relative to the working directory when no ``output_directory``.

    Args:
        response_selection_results: response selection evaluation results
        output_directory: directory to store files to
        successes: if True success are written down to disk
        errors: if True errors are written down to disk
        disable_plotting: if True no plots are created
        report_as_dict: `True` if the evaluation report should be returned as `dict`.
            If `False` the report is returned in a human-readable text format. If `None`
            `report_as_dict` is considered as `True` in case an `output_directory` is
            given.

    Returns:
        Dict with the keys ``predictions``, ``report``, ``precision``,
        ``f1_score``, ``accuracy`` and ``errors``.
    """

    def _resolve(filename: Text) -> Text:
        # Artefacts go into the output directory when one is configured.
        if output_directory:
            return os.path.join(output_directory, filename)
        return filename

    # remove empty response targets
    num_examples = len(response_selection_results)
    response_selection_results = remove_empty_response_examples(
        response_selection_results
    )

    logger.info(
        f"Response Selection Evaluation: Only considering those "
        f"{len(response_selection_results)} examples that have a defined response out "
        f"of {num_examples} examples."
    )

    target_keys, predicted_keys = _targets_predictions_from(
        response_selection_results,
        "intent_response_key_target",
        "intent_response_key_prediction",
    )
    report, precision, f1, accuracy, confusion_matrix, labels = _calculate_report(
        output_directory, target_keys, predicted_keys, report_as_dict
    )

    if output_directory:
        _dump_report(output_directory, "response_selection_report.json", report)

    if successes:
        # save classified samples to file for debugging
        write_response_successes(
            response_selection_results,
            _resolve("response_selection_successes.json"),
        )

    response_errors = _response_errors(response_selection_results)
    if errors and output_directory:
        _write_errors(
            response_errors,
            _resolve("response_selection_errors.json"),
            error_type="response",
        )

    if not disable_plotting:
        plot_utils.plot_confusion_matrix(
            confusion_matrix,
            classes=labels,
            title="Response Selection Confusion Matrix",
            output_file=_resolve("response_selection_confusion_matrix.png"),
        )
        plot_attribute_confidences(
            response_selection_results,
            _resolve("response_selection_histogram.png"),
            "intent_response_key_target",
            "intent_response_key_prediction",
            title="Response Selection Prediction Confidence Distribution",
        )

    return {
        "predictions": [
            {
                "text": res.message,
                "intent_response_key_target": res.intent_response_key_target,
                "intent_response_key_prediction": res.intent_response_key_prediction,
                "confidence": res.confidence,
            }
            for res in response_selection_results
        ],
        "report": report,
        "precision": precision,
        "f1_score": f1,
        "accuracy": accuracy,
        "errors": response_errors,
    }
501
+
502
+
503
+ def _add_confused_labels_to_report(
504
+ report: Dict[Text, Dict[Text, Any]],
505
+ confusion_matrix: np.ndarray,
506
+ labels: List[Text],
507
+ exclude_labels: Optional[List[Text]] = None,
508
+ ) -> Dict[Text, Dict[Text, Union[Dict, Any]]]:
509
+ """Adds a field "confused_with" to the evaluation report.
510
+
511
+ The value is a dict of {"false_positive_label": false_positive_count} pairs.
512
+ If there are no false positives in the confusion matrix,
513
+ the dict will be empty. Typically, we include the two most
514
+ commonly false positive labels, three in the rare case that
515
+ the diagonal element in the confusion matrix is not one of the
516
+ three highest values in the row.
517
+
518
+ Args:
519
+ report: the evaluation report
520
+ confusion_matrix: confusion matrix
521
+ labels: list of labels
522
+ exclude_labels: labels to exclude from the report
523
+
524
+ Returns: updated evaluation report
525
+ """
526
+ if exclude_labels is None:
527
+ exclude_labels = []
528
+
529
+ # sort confusion matrix by false positives
530
+ indices = np.argsort(confusion_matrix, axis=1)
531
+ n_candidates = min(3, len(labels))
532
+
533
+ for label in labels:
534
+ if label in exclude_labels:
535
+ continue
536
+ # it is possible to predict intent 'None'
537
+ if report.get(label):
538
+ report[label]["confused_with"] = {}
539
+
540
+ for i, label in enumerate(labels):
541
+ if label in exclude_labels:
542
+ continue
543
+ for j in range(n_candidates):
544
+ label_idx = indices[i, -(1 + j)]
545
+ false_pos_label = labels[label_idx]
546
+ false_positives = int(confusion_matrix[i, label_idx])
547
+ if (
548
+ false_pos_label != label
549
+ and false_pos_label not in exclude_labels
550
+ and false_positives > 0
551
+ ):
552
+ report[label]["confused_with"][false_pos_label] = false_positives
553
+
554
+ return report
555
+
556
+
557
def evaluate_intents(
    intent_results: List[IntentEvaluationResult],
    output_directory: Optional[Text],
    successes: bool,
    errors: bool,
    disable_plotting: bool,
    report_as_dict: Optional[bool] = None,
) -> Dict:  # pragma: no cover
    """Creates summary statistics for intents.

    Results without a target intent are discarded first; all statistics are
    computed on the remaining ones.

    Report, successes and errors are written to disk only when
    ``output_directory`` is set. Plots are written whenever plotting is enabled,
    relative to the current working directory if no ``output_directory`` is set.

    Args:
        intent_results: intent evaluation results
        output_directory: directory to store files to
        successes: if True correct predictions are written to disk
        errors: if True incorrect predictions are written to disk
        disable_plotting: if True no plots are created
        report_as_dict: `True` if the evaluation report should be returned as `dict`.
            If `False` the report is returned in a human-readable text format. If `None`
            `report_as_dict` is considered as `True` in case an `output_directory` is
            given.

    Returns:
        Dict with the keys ``predictions``, ``report``, ``precision``,
        ``f1_score``, ``accuracy`` and ``errors``.
    """

    def _artifact_path(filename: Text) -> Text:
        # Place artefacts in the output directory when one is configured.
        if output_directory:
            return os.path.join(output_directory, filename)
        return filename

    # remove empty intent targets
    num_examples = len(intent_results)
    intent_results = remove_empty_intent_examples(intent_results)

    logger.info(
        f"Intent Evaluation: Only considering those {len(intent_results)} examples "
        f"that have a defined intent out of {num_examples} examples."
    )

    target_intents, predicted_intents = _targets_predictions_from(
        intent_results, "intent_target", "intent_prediction"
    )
    report, precision, f1, accuracy, confusion_matrix, labels = _calculate_report(
        output_directory, target_intents, predicted_intents, report_as_dict
    )
    intent_errors = _get_intent_errors(intent_results)

    if output_directory:
        _dump_report(output_directory, "intent_report.json", report)
        if successes:
            # save classified samples to file for debugging
            write_intent_successes(
                intent_results, _artifact_path("intent_successes.json")
            )
        if errors:
            _write_errors(
                intent_errors, _artifact_path("intent_errors.json"), "intent"
            )

    if not disable_plotting:
        plot_utils.plot_confusion_matrix(
            confusion_matrix,
            classes=labels,
            title="Intent Confusion matrix",
            output_file=_artifact_path("intent_confusion_matrix.png"),
        )
        plot_attribute_confidences(
            intent_results,
            _artifact_path("intent_histogram.png"),
            "intent_target",
            "intent_prediction",
            title="Intent Prediction Confidence Distribution",
        )

    return {
        "predictions": [
            {
                "text": res.message,
                "intent": res.intent_target,
                "predicted": res.intent_prediction,
                "confidence": res.confidence,
            }
            for res in intent_results
        ],
        "report": report,
        "precision": precision,
        "f1_score": f1,
        "accuracy": accuracy,
        "errors": intent_errors,
    }
654
+
655
+
656
def _calculate_report(
    output_directory: Optional[Text],
    targets: Iterable[Any],
    predictions: Iterable[Any],
    report_as_dict: Optional[bool] = None,
    exclude_label: Optional[Text] = None,
) -> Tuple[Union[Text, Dict], float, float, float, np.ndarray, List[Text]]:
    """Compute the classification report and confusion matrix for one attribute.

    Args:
        output_directory: Directory evaluation artefacts go to, if any. Here it
            only decides the default for ``report_as_dict`` and whether the
            text report is logged.
        targets: True labels.
        predictions: Predicted labels, aligned with ``targets``. Any iterable
            is accepted, including one-shot iterators.
        report_as_dict: ``True`` to return the report as a dict (extended with
            "confused_with" entries), ``False`` for the text form. ``None``
            means "dict if and only if ``output_directory`` is set".
        exclude_label: Label passed to ``get_evaluation_metrics`` for exclusion
            and also left out of the "confused_with" entries.

    Returns:
        ``(report, precision, f1, accuracy, confusion_matrix, labels)``, where
        ``labels`` are the unique labels of targets and predictions as returned
        by sklearn's ``unique_labels``.
    """
    from rasa.model_testing import get_evaluation_metrics
    import sklearn.metrics
    import sklearn.utils.multiclass

    # The labels are consumed three times below; materialise them so that
    # one-shot iterables (generators, ``zip`` objects, ...) are not exhausted
    # by the first consumer.
    targets = list(targets)
    predictions = list(predictions)

    confusion_matrix = sklearn.metrics.confusion_matrix(targets, predictions)
    labels = sklearn.utils.multiclass.unique_labels(targets, predictions)

    if report_as_dict is None:
        report_as_dict = bool(output_directory)

    report, precision, f1, accuracy = get_evaluation_metrics(
        targets, predictions, output_dict=report_as_dict, exclude_label=exclude_label
    )

    if report_as_dict:
        report = _add_confused_labels_to_report(  # type: ignore[assignment]
            report,
            confusion_matrix,
            labels,
            exclude_labels=[exclude_label] if exclude_label else [],
        )
    elif not output_directory:
        # Text report and nowhere to save it: surface it in the logs instead.
        log_evaluation_table(report, precision, f1, accuracy)

    return report, precision, f1, accuracy, confusion_matrix, labels
688
+
689
+
690
def _dump_report(output_directory: Text, filename: Text, report: Dict) -> None:
    """Serialize ``report`` as JSON into ``output_directory`` and log the path.

    Args:
        output_directory: Directory the report file is created in.
        filename: Name of the report file.
        report: JSON-serializable evaluation report.
    """
    report_filename = os.path.join(output_directory, filename)
    rasa.shared.utils.io.dump_obj_as_json_to_file(report_filename, report)
    logger.info(f"Classification report saved to {report_filename}.")
694
+
695
+
696
def merge_labels(
    aligned_predictions: List[Dict], extractor: Optional[Text] = None
) -> List[Text]:
    """Flatten per-message label lists into a single list.

    With an ``extractor``, the labels under ``["extractor_labels"][extractor]``
    are concatenated. Without one, the ``["target_labels"]`` lists are used.

    Args:
        aligned_predictions: aligned predictions, one dict per message
        extractor: entity extractor name

    Returns:
        Concatenated labels, in message order.
    """
    if extractor:
        per_message = (
            aligned["extractor_labels"][extractor] for aligned in aligned_predictions
        )
    else:
        per_message = (aligned["target_labels"] for aligned in aligned_predictions)

    return list(itertools.chain.from_iterable(per_message))
717
+
718
+
719
def merge_confidences(
    aligned_predictions: List[Dict], extractor: Optional[Text] = None
) -> List[float]:
    """Flatten per-message confidence lists into a single list.

    Each aligned prediction must have ``["confidences"][extractor]``. Note that
    with the default ``extractor=None`` the lookup key is ``None`` itself.

    Args:
        aligned_predictions: aligned predictions, one dict per message
        extractor: entity extractor name

    Returns:
        Concatenated confidences, in message order.
    """
    merged: List[float] = []
    for aligned in aligned_predictions:
        merged.extend(aligned["confidences"][extractor])
    return merged
736
+
737
+
738
def substitute_labels(labels: List[Text], old: Text, new: Text) -> List[Text]:
    """Return a copy of ``labels`` with every occurrence of ``old`` set to ``new``.

    Args:
        labels: list of labels (left unmodified)
        old: old label name that should be replaced
        new: new label name

    Returns: updated labels
    """
    replaced = list(labels)
    for position, label in enumerate(replaced):
        if label == old:
            replaced[position] = new
    return replaced
749
+
750
+
751
def collect_incorrect_entity_predictions(
    entity_results: List[EntityEvaluationResult],
    merged_predictions: List[Text],
    merged_targets: List[Text],
) -> List["EntityPrediction"]:
    """Return one record per message with at least one mislabelled token.

    ``merged_predictions`` and ``merged_targets`` are flat token-level label
    lists. Each message owns the next ``len(entity_result.tokens)`` positions.

    Args:
        entity_results: entity evaluation results
        merged_predictions: list of predicted entity labels
        merged_targets: list of true entity labels

    Returns: list of incorrect predictions
    """
    errors = []
    start = 0
    for entity_result in entity_results:
        end = start + len(entity_result.tokens)
        has_mismatch = any(
            merged_targets[i] != merged_predictions[i] for i in range(start, end)
        )
        if has_mismatch:
            prediction: EntityPrediction = {
                "text": entity_result.message,
                "entities": entity_result.entity_targets,
                "predicted_entities": entity_result.entity_predictions,
            }
            errors.append(prediction)
        start = end
    return errors
779
+
780
+
781
def write_successful_entity_predictions(
    entity_results: List[EntityEvaluationResult],
    merged_targets: List[Text],
    merged_predictions: List[Text],
    successes_filename: Text,
) -> None:
    """Write correct entity predictions to a file.

    A message counts as a success when at least one of its tokens has a
    non-``NO_ENTITY`` target that was predicted exactly. If there are no
    successes, nothing is written and an info message is logged.

    Note the argument order: this function takes targets before predictions,
    whereas ``collect_successful_entity_predictions`` takes predictions first.

    Args:
        entity_results: entity evaluation results
        merged_targets: list of true entity labels
        merged_predictions: list of predicted entity labels
        successes_filename: filename of file to save correct predictions to
    """
    # The collector expects (predictions, targets), the reverse of this signature.
    successes = collect_successful_entity_predictions(
        entity_results, merged_predictions, merged_targets
    )

    if successes:
        rasa.shared.utils.io.dump_obj_as_json_to_file(successes_filename, successes)
        logger.info(f"Successful entity predictions saved to {successes_filename}.")
        structlogger.debug("test.write.entities", successes=copy.deepcopy(successes))
    else:
        logger.info("No successful entity prediction found.")
805
+
806
+
807
def collect_successful_entity_predictions(
    entity_results: List[EntityEvaluationResult],
    merged_predictions: List[Text],
    merged_targets: List[Text],
) -> List["EntityPrediction"]:
    """Collect the messages that contain at least one correctly predicted entity.

    The merged label lists hold one label per token for all messages back to
    back, in the order of `entity_results`. A message is reported once as soon as
    one of its tokens has a prediction equal to its target, where that target is
    not `NO_ENTITY`.

    Args:
        entity_results: entity evaluation results
        merged_predictions: list of predicted entity labels
        merged_targets: list of true entity labels

    Returns: list of correct predictions
    """
    successes = []
    start = 0
    for entity_result in entity_results:
        end = start + len(entity_result.tokens)
        has_correct_entity = any(
            merged_targets[idx] == merged_predictions[idx]
            and merged_targets[idx] != NO_ENTITY
            for idx in range(start, end)
        )
        if has_correct_entity:
            prediction: EntityPrediction = {
                "text": entity_result.message,
                "entities": entity_result.entity_targets,
                "predicted_entities": entity_result.entity_predictions,
            }
            successes.append(prediction)
        start = end
    return successes
838
+
839
+
840
def evaluate_entities(
    entity_results: List[EntityEvaluationResult],
    extractors: Set[Text],
    output_directory: Optional[Text],
    successes: bool,
    errors: bool,
    disable_plotting: bool,
    report_as_dict: Optional[bool] = None,
) -> Dict:  # pragma: no cover
    """Creates summary statistics for each entity extractor.

    Token-level targets and per-extractor predictions are aligned and merged.
    `NO_ENTITY_TAG` is mapped to `NO_ENTITY` in both, and a report is computed for
    every extractor with `NO_ENTITY` excluded.

    Args:
        entity_results: entity evaluation results
        extractors: entity extractors to consider
        output_directory: directory to store files to
        successes: if True correct predictions are written to disk
        errors: if True incorrect predictions are written to disk (only when an
            `output_directory` is given)
        disable_plotting: if True no plots are created
        report_as_dict: `True` if the evaluation report should be returned as `dict`.
            If `False` the report is returned in a human-readable text format. If `None`
            `report_as_dict` is considered as `True` in case an `output_directory` is
            given.

    Returns: dictionary mapping each extractor name to its "report", "precision",
        "f1_score", "accuracy" and "errors"
    """

    def _output_path(filename: Text) -> Text:
        # Place the file inside `output_directory` when one was given; otherwise
        # keep the bare filename.
        if output_directory:
            return os.path.join(output_directory, filename)
        return filename

    aligned_predictions = align_all_entity_predictions(entity_results, extractors)
    merged_targets = substitute_labels(
        merge_labels(aligned_predictions), NO_ENTITY_TAG, NO_ENTITY
    )

    result = {}

    for extractor in extractors:
        merged_predictions = substitute_labels(
            merge_labels(aligned_predictions, extractor), NO_ENTITY_TAG, NO_ENTITY
        )

        logger.info(f"Evaluation for entity extractor: {extractor} ")

        report, precision, f1, accuracy, confusion_matrix, labels = _calculate_report(
            output_directory,
            merged_targets,
            merged_predictions,
            report_as_dict,
            exclude_label=NO_ENTITY,
        )
        if output_directory:
            _dump_report(output_directory, f"{extractor}_report.json", report)

        if successes:
            # save classified samples to file for debugging
            write_successful_entity_predictions(
                entity_results,
                merged_targets,
                merged_predictions,
                _output_path(f"{extractor}_successes.json"),
            )

        entity_errors = collect_incorrect_entity_predictions(
            entity_results, merged_predictions, merged_targets
        )
        if errors and output_directory:
            _write_errors(
                entity_errors,
                os.path.join(output_directory, f"{extractor}_errors.json"),
                "entity",
            )

        if not disable_plotting:
            plot_utils.plot_confusion_matrix(
                confusion_matrix,
                classes=labels,
                title="Entity Confusion matrix",
                output_file=_output_path(f"{extractor}_confusion_matrix.png"),
            )

            # Only extractors listed in EXTRACTORS_WITH_CONFIDENCES get a
            # confidence histogram.
            if extractor in EXTRACTORS_WITH_CONFIDENCES:
                plot_entity_confidences(
                    merged_targets,
                    merged_predictions,
                    merge_confidences(aligned_predictions, extractor),
                    title="Entity Prediction Confidence Distribution",
                    hist_filename=_output_path(f"{extractor}_histogram.png"),
                )

        result[extractor] = {
            "report": report,
            "precision": precision,
            "f1_score": f1,
            "accuracy": accuracy,
            "errors": entity_errors,
        }

    return result
946
+
947
+
948
def is_token_within_entity(token: Token, entity: Dict) -> bool:
    """Return whether the token lies completely inside the entity's boundaries.

    The token counts as inside when the number of character positions it shares
    with the entity's ``start``/``end`` span equals the length of ``token.text``.
    """
    shared_chars = determine_intersection(token, entity)
    return shared_chars == len(token.text)
951
+
952
+
953
def does_token_cross_borders(token: Token, entity: Dict) -> bool:
    """Return whether the token only partially overlaps the entity.

    True when the token shares at least one character position with the entity,
    but fewer than ``len(token.text)``.
    """
    overlap = determine_intersection(token, entity)
    token_length = len(token.text)
    return overlap > 0 and overlap < token_length
957
+
958
+
959
def determine_intersection(token: Token, entity: Dict) -> int:
    """Calculates how many characters a given token and entity share.

    Both spans are half-open character ranges ``[start, end)``. The overlap length
    is the distance from the later start to the earlier end, clamped at zero. This
    avoids materialising a set of positions for each span.

    Args:
        token: token with ``start`` and ``end`` offsets
        entity: entity dict with ``"start"`` and ``"end"`` offsets

    Returns:
        number of shared character positions (0 if the spans are disjoint)
    """
    overlap_start = max(token.start, entity["start"])
    overlap_end = min(token.end, entity["end"])
    return max(0, overlap_end - overlap_start)
964
+
965
+
966
def _find_conflicting_entity_pair(
    entities: List[Dict],
) -> Optional[Tuple[Dict, Dict]]:
    """Return the first pair of overlapping entities with different types, if any."""
    sorted_entities = sorted(entities, key=lambda e: e["start"])
    for i, curr_ent in enumerate(sorted_entities):
        for next_ent in sorted_entities[i + 1 :]:
            # Entities are sorted by start, so once one starts at or after the end
            # of `curr_ent`, none of the following entities can overlap it either.
            if next_ent["start"] >= curr_ent["end"]:
                break
            if next_ent["entity"] != curr_ent["entity"]:
                return curr_ent, next_ent
    return None


def do_entities_overlap(entities: List[Dict]) -> bool:
    """Checks if entities of different types overlap.

    I.e. cross each others start and end boundaries. Every pair of entities is
    considered, not only neighbours in start order. This catches a long entity
    that overlaps a later, non-adjacent entity of another type. Overlapping
    entities of the same type are not treated as a conflict.

    Args:
        entities: list of entities

    Returns: true if entities overlap, false otherwise.
    """
    conflict = _find_conflicting_entity_pair(entities)
    if conflict is None:
        return False

    curr_ent, next_ent = conflict
    structlogger.warning(
        "test.overlaping.entities",
        current_entity=copy.deepcopy(curr_ent),
        next_entity=copy.deepcopy(next_ent),
    )
    return True
992
+
993
+
994
def find_intersecting_entities(token: Token, entities: List[Dict]) -> List[Dict]:
    """Return the entities that the token lies within or partially overlaps.

    Partial overlaps (the token crosses an entity boundary) are also logged at
    debug level.

    Args:
        token: a single token
        entities: entities found by a single extractor

    Returns: list of entities, in their original order
    """
    intersecting: List[Dict] = []
    for entity in entities:
        if is_token_within_entity(token, entity):
            intersecting.append(entity)
            continue
        if not does_token_cross_borders(token, entity):
            continue
        intersecting.append(entity)
        structlogger.debug(
            "test.intersecting.entities",
            token_text=copy.deepcopy(token.text),
            token_start=token.start,
            token_end=token.end,
            entity=copy.deepcopy(entity),
        )
    return intersecting
1017
+
1018
+
1019
def pick_best_entity_fit(
    token: Token, candidates: List[Dict[Text, Any]]
) -> Optional[Dict[Text, Any]]:
    """Determines the best fitting entity given intersecting entities.

    Args:
        token: a single token
        candidates: entities found by a single extractor that intersect `token`

    Returns:
        the candidate sharing the most characters with `token` (the first one in
        case of a tie), or `None` if there are no candidates
    """
    if not candidates:
        return None
    if len(candidates) == 1:
        return candidates[0]
    # `max` returns the first maximal element, which matches the tie-breaking of
    # `np.argmax` previously used here.
    return max(candidates, key=lambda candidate: determine_intersection(token, candidate))
1039
+
1040
+
1041
def determine_token_labels(
    token: Token,
    entities: List[Dict],
    extractors: Optional[Set[Text]] = None,
    attribute_key: Text = ENTITY_ATTRIBUTE_TYPE,
) -> Text:
    """Return the label stored under `attribute_key` for the entity fitting `token`.

    Args:
        token: a single token
        entities: entities found by a single extractor
        extractors: list of extractors
        attribute_key: the attribute key for which the entity type should be returned

    Returns:
        the attribute value, or `NO_ENTITY_TAG` when no entity fits the token or
        the attribute is missing/empty
    """
    entity = determine_entity_for_token(token, entities, extractors)
    label = entity.get(attribute_key) if entity is not None else None
    return label or NO_ENTITY_TAG
1068
+
1069
+
1070
def determine_entity_for_token(
    token: Token,
    entities: List[Dict[Text, Any]],
    extractors: Optional[Set[Text]] = None,
) -> Optional[Dict[Text, Any]]:
    """Determines the best fitting non-overlapping entity for the given token.

    Args:
        token: a single token
        entities: entities found by a single extractor
        extractors: list of extractors

    Returns:
        the best fitting entity, or `None` if there are no entities or none
        intersects the token

    Raises:
        ValueError: if any of `extractors` does not support overlapping entities
            and `entities` overlap
    """
    if not entities:
        return None

    must_not_overlap = do_any_extractors_not_support_overlap(extractors)
    if must_not_overlap and do_entities_overlap(entities):
        raise ValueError("The possible entities should not overlap.")

    return pick_best_entity_fit(token, find_intersecting_entities(token, entities))
1094
+
1095
+
1096
def do_any_extractors_not_support_overlap(extractors: Optional[Set[Text]]) -> bool:
    """Checks if any extractor does not support overlapping entities.

    Args:
        extractors: Names of the entity extractors

    Returns:
        `True` if and only if CRFEntityExtractor or DIETClassifier is in `extractors`
        (always `False` when `extractors` is `None`)
    """
    if extractors is None:
        return False

    # Imported lazily; the imports are only needed when extractors are given.
    from rasa.nlu.extractors.crf_entity_extractor import CRFEntityExtractor
    from rasa.nlu.classifiers.diet_classifier import DIETClassifier

    no_overlap_support = {CRFEntityExtractor.__name__, DIETClassifier.__name__}
    return not no_overlap_support.isdisjoint(extractors)
1114
+
1115
+
1116
def align_entity_predictions(
    result: EntityEvaluationResult, extractors: Set[Text]
) -> Dict:
    """Aligns entity predictions to the message tokens.

    For every token, determines the true (concatenated) label from the targets.
    It also determines the label and confidence assigned by each extractor, using
    only that extractor's predictions. A prediction whose extractor is not listed
    in `extractors` raises `KeyError`.

    Args:
        result: entity evaluation result
        extractors: the entity extractors that should be considered

    Returns: dictionary with "target_labels" (one label per token),
        "extractor_labels" and "confidences" (per extractor, one entry per token)
    """
    entities_by_extractors: Dict[Text, List] = {name: [] for name in extractors}
    for prediction in result.entity_predictions:
        entities_by_extractors[prediction[EXTRACTOR]].append(prediction)

    true_token_labels = []
    extractor_labels: Dict[Text, List] = {name: [] for name in extractors}
    extractor_confidences: Dict[Text, List] = {name: [] for name in extractors}

    for token in result.tokens:
        true_token_labels.append(_concat_entity_labels(token, result.entity_targets))
        for extractor, extracted in entities_by_extractors.items():
            label = _concat_entity_labels(token, extracted, {extractor})
            confidence = _get_entity_confidences(token, extracted, {extractor})
            extractor_labels[extractor].append(label)
            extractor_confidences[extractor].append(confidence)

    return {
        "target_labels": true_token_labels,
        "extractor_labels": extractor_labels,
        "confidences": extractor_confidences,
    }
1155
+
1156
+
1157
def _concat_entity_labels(
    token: Token, entities: List[Dict], extractors: Optional[Set[Text]] = None
) -> Text:
    """Concatenate labels for entity type, role, and group for evaluation.

    Metrics are reported for every combination of entity type, group, and role.
    To do that, the labels that are present are joined with "." in that order,
    e.g. 'location.destination'.

    Args:
        token: the token we are looking at
        entities: the available entities
        extractors: the extractor of interest

    Returns:
        the entity label of the provided token, or `NO_ENTITY_TAG` if none of the
        three attributes has a label
    """
    attribute_order = (
        ENTITY_ATTRIBUTE_TYPE,
        ENTITY_ATTRIBUTE_GROUP,
        ENTITY_ATTRIBUTE_ROLE,
    )
    labels = [
        determine_token_labels(token, entities, extractors, attribute)
        for attribute in attribute_order
    ]
    present = [label for label in labels if label != NO_ENTITY_TAG]
    return ".".join(present) if present else NO_ENTITY_TAG
1191
+
1192
+
1193
def _get_entity_confidences(
    token: Token, entities: List[Dict], extractors: Optional[Set[Text]] = None
) -> float:
    """Get the confidence value of the best fitting entity.

    If multiple confidence values are present, e.g. for type, role, group, we
    pick the lowest confidence value. A missing (`None`) confidence counts as
    1.0. A genuine confidence of 0.0 is kept as 0.0 rather than being replaced.

    Args:
        token: the token we are looking at
        entities: the available entities
        extractors: the extractor of interest

    Returns:
        the confidence value; 0.0 if no entity fits the token or the entity's
        extractor is not in `EXTRACTORS_WITH_CONFIDENCES`
    """
    entity = determine_entity_for_token(token, entities, extractors)

    if entity is None:
        return 0.0

    if entity.get("extractor") not in EXTRACTORS_WITH_CONFIDENCES:
        return 0.0

    confidences = [
        entity.get(ENTITY_ATTRIBUTE_CONFIDENCE_TYPE),
        entity.get(ENTITY_ATTRIBUTE_CONFIDENCE_ROLE),
        entity.get(ENTITY_ATTRIBUTE_CONFIDENCE_GROUP),
    ]
    # Only substitute the default for absent values: `value or 1.0` would also
    # turn an explicit 0.0 confidence into 1.0.
    return min(1.0 if value is None else value for value in confidences)
1222
+
1223
+
1224
def align_all_entity_predictions(
    entity_results: List[EntityEvaluationResult], extractors: Set[Text]
) -> List[Dict]:
    """Align the entity predictions of every result in the dataset to its tokens.

    Applies `align_entity_predictions` to each result, preserving order.

    Args:
        entity_results: list of entity prediction results
        extractors: the entity extractors that should be considered

    Returns: one alignment dictionary (true token labels and per-extractor token
        labels) per entity result
    """
    return [align_entity_predictions(result, extractors) for result in entity_results]
1243
+
1244
+
1245
async def get_eval_data(
    processor: MessageProcessor, test_data: TrainingData
) -> Tuple[
    List[IntentEvaluationResult],
    List[ResponseSelectionEvaluationResult],
    List[EntityEvaluationResult],
]:
    """Runs the model for the test set and extracts targets and predictions.

    Every NLU example is parsed by `processor`, and entities from pretrained
    extractors are dropped from the parse result. Each kind of result is
    collected only if the test data supports evaluating it:
    - intents: at least two distinct intents
    - response selection: at least two distinct response keys
    - entities: at least one entity example

    Args:
        processor: the processor
        test_data: test data

    Returns: intent, response, and entity evaluation results
    """
    logger.info("Running model for predictions:")

    intent_results: List[IntentEvaluationResult] = []
    entity_results: List[EntityEvaluationResult] = []
    response_selection_results: List[ResponseSelectionEvaluationResult] = []

    response_labels = {
        example.get(INTENT_RESPONSE_KEY)
        for example in test_data.intent_examples
        if example.get(INTENT_RESPONSE_KEY) is not None
    }
    intent_labels = {example.get(INTENT) for example in test_data.intent_examples}

    should_eval_intents = len(intent_labels) >= 2
    should_eval_response_selection = len(response_labels) >= 2
    should_eval_entities = len(test_data.entity_examples) > 0

    for example in tqdm(test_data.nlu_examples):
        result = await processor.parse_message(
            UserMessage(text=example.get(TEXT)), only_output_properties=False
        )
        _remove_entities_of_extractors(result, PRETRAINED_EXTRACTORS)

        if should_eval_intents:
            if fallback_classifier.is_fallback_classifier_prediction(result):
                # Undo the fallback so that the intent which was actually
                # predicted (possibly wrongly) is what gets evaluated.
                result = fallback_classifier.undo_fallback_prediction(result)
            intent_prediction = result.get(INTENT, {})
            intent_results.append(
                IntentEvaluationResult(
                    example.get(INTENT, ""),
                    intent_prediction.get(INTENT_NAME_KEY),
                    result.get(TEXT),
                    intent_prediction.get("confidence"),
                )
            )

        if should_eval_response_selection:
            # Every example is included here; examples without a response are
            # filtered out later, when metrics are calculated.
            intent_target = example.get(INTENT, "")
            selector_properties = result.get(RESPONSE_SELECTOR_PROPERTY_NAME, {})
            retrieval_intents = selector_properties.get(
                RESPONSE_SELECTOR_RETRIEVAL_INTENTS, set()
            )
            has_own_selector = (
                intent_target in retrieval_intents
                and intent_target in selector_properties
            )
            response_prediction_key = (
                intent_target if has_own_selector else RESPONSE_SELECTOR_DEFAULT_INTENT
            )
            response_prediction = selector_properties.get(
                response_prediction_key, {}
            ).get(RESPONSE_SELECTOR_PREDICTION_KEY, {})

            response_selection_results.append(
                ResponseSelectionEvaluationResult(
                    example.get(INTENT_RESPONSE_KEY, ""),
                    response_prediction.get(INTENT_RESPONSE_KEY),
                    result.get(TEXT),
                    response_prediction.get(PREDICTED_CONFIDENCE_KEY),
                )
            )

        if should_eval_entities:
            entity_results.append(
                EntityEvaluationResult(
                    example.get(ENTITIES, []),
                    result.get(ENTITIES, []),
                    result.get(TOKENS_NAMES[TEXT], []),
                    result.get(TEXT),
                )
            )

    return intent_results, response_selection_results, entity_results
1343
+
1344
+
1345
def _get_active_entity_extractors(
    entity_results: List[EntityEvaluationResult],
) -> Set[Text]:
    """Return the extractor names found in the predictions of the given results.

    Predictions without an `EXTRACTOR` key are ignored.
    """
    return {
        prediction[EXTRACTOR]
        for result in entity_results
        for prediction in result.entity_predictions
        if EXTRACTOR in prediction
    }
1355
+
1356
+
1357
def _remove_entities_of_extractors(
    nlu_parse_result: Dict[Text, Any], extractor_names: Set[Text]
) -> None:
    """Drop, in place, the entities annotated by any of the given extractors.

    The parse result is left untouched when it has no (or an empty) entity list.
    """
    entities = nlu_parse_result.get(ENTITIES)
    if not entities:
        return

    def _keep(entity: Dict[Text, Any]) -> bool:
        return entity.get(EXTRACTOR) not in extractor_names

    nlu_parse_result[ENTITIES] = list(filter(_keep, entities))
1366
+
1367
+
1368
async def run_evaluation(
    data_path: Text,
    processor: MessageProcessor,
    output_directory: Optional[Text] = None,
    successes: bool = False,
    errors: bool = False,
    disable_plotting: bool = False,
    report_as_dict: Optional[bool] = None,
    domain_path: Optional[Text] = None,
) -> Dict:  # pragma: no cover
    """Evaluate intent classification, response selection and entity extraction.

    Args:
        data_path: path to the test data
        processor: the processor used to process and predict
        output_directory: path to folder where all output will be stored
        successes: if true successful predictions are written to a file
        errors: if true incorrect predictions are written to a file
        disable_plotting: if true confusion matrix and histogram will not be rendered
        report_as_dict: `True` if the evaluation report should be returned as `dict`.
            If `False` the report is returned in a human-readable text format. If `None`
            `report_as_dict` is considered as `True` in case an `output_directory` is
            given.
        domain_path: Path to the domain file(s); `DEFAULT_DOMAIN_PATH` is used when
            not given.

    Returns: dictionary with the keys "intent_evaluation", "entity_evaluation" and
        "response_selection_evaluation"; each value stays `None` if that part was
        not evaluated
    """
    import rasa.shared.nlu.training_data.loading
    from rasa.shared.constants import DEFAULT_DOMAIN_PATH

    test_data_importer = TrainingDataImporter.load_from_dict(
        training_data_paths=[data_path],
        domain_path=domain_path or DEFAULT_DOMAIN_PATH,
    )
    test_data = test_data_importer.get_nlu_data()

    result: Dict[Text, Optional[Dict]] = dict.fromkeys(
        ["intent_evaluation", "entity_evaluation", "response_selection_evaluation"]
    )

    if output_directory:
        rasa.shared.utils.io.create_directory(output_directory)

    intent_results, response_selection_results, entity_results = await get_eval_data(
        processor, test_data
    )

    # Positional arguments shared by all evaluate_* calls below.
    shared_args = (output_directory, successes, errors, disable_plotting)

    if intent_results:
        logger.info("Intent evaluation results:")
        result["intent_evaluation"] = evaluate_intents(
            intent_results, *shared_args, report_as_dict=report_as_dict
        )

    if response_selection_results:
        logger.info("Response selection evaluation results:")
        result["response_selection_evaluation"] = evaluate_response_selections(
            response_selection_results, *shared_args, report_as_dict=report_as_dict
        )

    if any(entity_results):
        logger.info("Entity evaluation results:")
        result["entity_evaluation"] = evaluate_entities(
            entity_results,
            _get_active_entity_extractors(entity_results),
            *shared_args,
            report_as_dict=report_as_dict,
        )

    telemetry.track_nlu_model_test(test_data)

    return result
1455
+
1456
+
1457
def generate_folds(
    n: int, training_data: TrainingData
) -> Iterator[Tuple[TrainingData, TrainingData]]:
    """Generates n stratified cross validation folds for given training data.

    Yields one `(train, test)` pair of `TrainingData` per fold. Both splits share
    the synonyms, regex features, lookup tables and responses of `training_data`.
    """
    from sklearn.model_selection import StratifiedKFold

    examples = training_data.intent_examples

    def _subset(indices: Iterable[int]) -> TrainingData:
        # Build a TrainingData holding only the selected examples.
        return TrainingData(
            training_examples=[examples[idx] for idx in indices],
            entity_synonyms=training_data.entity_synonyms,
            regex_features=training_data.regex_features,
            lookup_tables=training_data.lookup_tables,
            responses=training_data.responses,
        )

    # Stratify on the intents as they appear in the training data, so that the
    # split also covers retrieval intents if they exist.
    full_intents = [example.get_full_intent() for example in examples]
    splitter = StratifiedKFold(n_splits=n, shuffle=True)
    folds = splitter.split(examples, full_intents)
    for fold_index, (train_indices, test_indices) in enumerate(folds):
        logger.debug(f"Fold: {fold_index}")
        yield _subset(train_indices), _subset(test_indices)
1489
+
1490
+
1491
async def combine_result(
    intent_metrics: IntentMetrics,
    entity_metrics: EntityMetrics,
    response_selection_metrics: ResponseSelectionMetrics,
    processor: MessageProcessor,
    data: TrainingData,
    intent_results: Optional[List[IntentEvaluationResult]] = None,
    entity_results: Optional[List[EntityEvaluationResult]] = None,
    response_selection_results: Optional[
        List[ResponseSelectionEvaluationResult]
    ] = None,
) -> Tuple[IntentMetrics, EntityMetrics, ResponseSelectionMetrics]:
    """Collects intent, response selection and entity metrics for cross validation.

    The metrics computed on `data` are merged into the given metric dicts, which
    are updated in place. For each key, the new values are placed in front of the
    accumulated ones. Any result list that is provided (not `None`) is extended
    in place with the corresponding prediction results.

    Args:
        intent_metrics: intent metrics
        entity_metrics: entity metrics
        response_selection_metrics: response selection metrics
        processor: the processor
        data: training data
        intent_results: intent evaluation results
        entity_results: entity evaluation results
        response_selection_results: response selection evaluation results

    Returns: intent, entity, and response selection metrics
    """
    computed = await compute_metrics(processor, data)
    (
        current_intent_metrics,
        current_entity_metrics,
        current_response_selection_metrics,
    ) = computed[:3]

    collectors = (intent_results, entity_results, response_selection_results)
    for collected, fresh in zip(collectors, computed[3:]):
        if collected is not None:
            collected.extend(fresh)

    for key, values in current_intent_metrics.items():
        intent_metrics[key] = values + intent_metrics[key]

    for key, values in current_response_selection_metrics.items():
        response_selection_metrics[key] = values + response_selection_metrics[key]

    for extractor, fresh_metrics in current_entity_metrics.items():
        entity_metrics[extractor] = {
            name: values + entity_metrics[extractor][name]
            for name, values in fresh_metrics.items()
        }

    return intent_metrics, entity_metrics, response_selection_metrics
1550
+
1551
+
1552
def _contains_entity_labels(entity_results: List[EntityEvaluationResult]) -> bool:
    """Return whether any result has target entities or predicted entities."""
    return any(
        result.entity_targets or result.entity_predictions for result in entity_results
    )
1558
+
1559
+
1560
async def cross_validate(
    data: TrainingData,
    n_folds: int,
    nlu_config: Union[Text, Dict],
    output: Optional[Text] = None,
    successes: bool = False,
    errors: bool = False,
    disable_plotting: bool = False,
    report_as_dict: Optional[bool] = None,
) -> Tuple[CVEvaluationResult, CVEvaluationResult, CVEvaluationResult]:
    """Stratified cross validation on data.

    For every fold, an NLU model is trained on the training split. Metrics are
    collected on both splits. The prediction results of all test splits are
    accumulated and evaluated once at the end.

    Args:
        data: Training Data
        n_folds: integer, number of cv folds
        nlu_config: nlu config, either a file path or a dict (a dict is first
            written to a temporary `config.yml`)
        output: path to folder where reports are stored
        successes: if true successful predictions are written to a file
        errors: if true incorrect predictions are written to a file
        disable_plotting: if true no confusion matrix and histogram plots are created
        report_as_dict: `True` if the evaluation report should be returned as `dict`.
            If `False` the report is returned in a human-readable text format. If `None`
            `report_as_dict` is considered as `True` in case an `output_directory` is
            given.

    Returns:
        intent, entity and response selection `CVEvaluationResult`s, each holding
        the train metrics, the test metrics (dictionary with key, list structure,
        where each entry in a list corresponds to the relevant result for one
        fold) and the evaluation of the accumulated test-fold results
    """
    import rasa.model_training

    with TempDirectoryPath(get_temp_dir_name()) as temp_dir:
        tmp_path = Path(temp_dir)

        if isinstance(nlu_config, dict):
            # Training is given a config path, so persist an in-memory config.
            config_path = tmp_path / "config.yml"
            write_yaml(nlu_config, config_path)
            nlu_config = str(config_path)

        if output:
            rasa.shared.utils.io.create_directory(output)

        intent_train_metrics: IntentMetrics = defaultdict(list)
        intent_test_metrics: IntentMetrics = defaultdict(list)
        entity_train_metrics: EntityMetrics = defaultdict(lambda: defaultdict(list))
        entity_test_metrics: EntityMetrics = defaultdict(lambda: defaultdict(list))
        response_selection_train_metrics: ResponseSelectionMetrics = defaultdict(list)
        response_selection_test_metrics: ResponseSelectionMetrics = defaultdict(list)

        intent_test_results: List[IntentEvaluationResult] = []
        entity_test_results: List[EntityEvaluationResult] = []
        response_selection_test_results: List[ResponseSelectionEvaluationResult] = []

        # The same training data path is reused for every fold.
        training_data_file = tmp_path / "training_data.yml"

        for train, test in generate_folds(n_folds, data):
            RasaYAMLWriter().dump(training_data_file, train)

            model_file = await rasa.model_training.train_nlu(
                nlu_config, str(training_data_file), str(tmp_path)
            )
            processor = Agent.load(model_file).processor

            # Training split: only the metrics are accumulated.
            await combine_result(
                intent_train_metrics,
                entity_train_metrics,
                response_selection_train_metrics,
                processor,
                train,
            )
            # Test split: metrics plus the individual prediction results.
            await combine_result(
                intent_test_metrics,
                entity_test_metrics,
                response_selection_test_metrics,
                processor,
                test,
                intent_test_results,
                entity_test_results,
                response_selection_test_results,
            )

        # Positional arguments shared by all evaluate_* calls below.
        shared_args = (output, successes, errors, disable_plotting)

        intent_evaluation = {}
        if intent_test_results:
            logger.info("Accumulated test folds intent evaluation results:")
            intent_evaluation = evaluate_intents(
                intent_test_results, *shared_args, report_as_dict=report_as_dict
            )

        entity_evaluation = {}
        if entity_test_results:
            logger.info("Accumulated test folds entity evaluation results:")
            entity_evaluation = evaluate_entities(
                entity_test_results,
                _get_active_entity_extractors(entity_test_results),
                *shared_args,
                report_as_dict=report_as_dict,
            )

        responses_evaluation = {}
        if response_selection_test_results:
            logger.info("Accumulated test folds response selection evaluation results:")
            responses_evaluation = evaluate_response_selections(
                response_selection_test_results,
                *shared_args,
                report_as_dict=report_as_dict,
            )

        intent_cv = CVEvaluationResult(
            dict(intent_train_metrics), dict(intent_test_metrics), intent_evaluation
        )
        entity_cv = CVEvaluationResult(
            dict(entity_train_metrics), dict(entity_test_metrics), entity_evaluation
        )
        response_selection_cv = CVEvaluationResult(
            dict(response_selection_train_metrics),
            dict(response_selection_test_metrics),
            responses_evaluation,
        )
        return intent_cv, entity_cv, response_selection_cv
1694
+
1695
+
1696
+ def _targets_predictions_from(
1697
+ results: Union[
1698
+ List[IntentEvaluationResult], List[ResponseSelectionEvaluationResult]
1699
+ ],
1700
+ target_key: Text,
1701
+ prediction_key: Text,
1702
+ ) -> Iterator[Iterable[Optional[Text]]]:
1703
+ return zip(*[(getattr(r, target_key), getattr(r, prediction_key)) for r in results])
1704
+
1705
+
1706
+ async def compute_metrics(
1707
+ processor: MessageProcessor, training_data: TrainingData
1708
+ ) -> Tuple[
1709
+ IntentMetrics,
1710
+ EntityMetrics,
1711
+ ResponseSelectionMetrics,
1712
+ List[IntentEvaluationResult],
1713
+ List[EntityEvaluationResult],
1714
+ List[ResponseSelectionEvaluationResult],
1715
+ ]:
1716
+ """Metrics for intent classification, response selection and entity extraction.
1717
+
1718
+ Args:
1719
+ processor: the processor
1720
+ training_data: training data
1721
+
1722
+ Returns: intent, response selection and entity metrics, and prediction results.
1723
+ """
1724
+ intent_results, response_selection_results, entity_results = await get_eval_data(
1725
+ processor, training_data
1726
+ )
1727
+
1728
+ intent_results = remove_empty_intent_examples(intent_results)
1729
+
1730
+ response_selection_results = remove_empty_response_examples(
1731
+ response_selection_results
1732
+ )
1733
+
1734
+ intent_metrics: IntentMetrics = {}
1735
+ if intent_results:
1736
+ intent_metrics = _compute_metrics(
1737
+ intent_results, "intent_target", "intent_prediction"
1738
+ )
1739
+
1740
+ entity_metrics = {}
1741
+ if entity_results:
1742
+ entity_metrics = _compute_entity_metrics(entity_results)
1743
+
1744
+ response_selection_metrics: ResponseSelectionMetrics = {}
1745
+ if response_selection_results:
1746
+ response_selection_metrics = _compute_metrics(
1747
+ response_selection_results,
1748
+ "intent_response_key_target",
1749
+ "intent_response_key_prediction",
1750
+ )
1751
+
1752
+ return (
1753
+ intent_metrics,
1754
+ entity_metrics,
1755
+ response_selection_metrics,
1756
+ intent_results,
1757
+ entity_results,
1758
+ response_selection_results,
1759
+ )
1760
+
1761
+
1762
+ async def compare_nlu(
1763
+ configs: List[Text],
1764
+ data: TrainingData,
1765
+ exclusion_percentages: List[int],
1766
+ f_score_results: Dict[Text, List[List[float]]],
1767
+ model_names: List[Text],
1768
+ output: Text,
1769
+ runs: int,
1770
+ ) -> List[int]:
1771
+ """Trains and compares multiple NLU models.
1772
+
1773
+ For each run and exclusion percentage a model per config file is trained.
1774
+ Thereby, the model is trained only on the current percentage of training data.
1775
+ Afterwards, the model is tested on the complete test data of that run.
1776
+ All results are stored in the provided output directory.
1777
+
1778
+ Args:
1779
+ configs: config files needed for training
1780
+ data: training data
1781
+ exclusion_percentages: percentages of training data to exclude during comparison
1782
+ f_score_results: dictionary of model name to f-score results per run
1783
+ model_names: names of the models to train
1784
+ output: the output directory
1785
+ runs: number of comparison runs
1786
+
1787
+ Returns: training examples per run
1788
+ """
1789
+ import rasa.model_training
1790
+
1791
+ training_examples_per_run = []
1792
+
1793
+ for run in range(runs):
1794
+
1795
+ logger.info("Beginning comparison run {}/{}".format(run + 1, runs))
1796
+
1797
+ run_path = os.path.join(output, "run_{}".format(run + 1))
1798
+ io_utils.create_path(run_path)
1799
+
1800
+ test_path = os.path.join(run_path, TEST_DATA_FILE)
1801
+ io_utils.create_path(test_path)
1802
+
1803
+ train, test = data.train_test_split()
1804
+ rasa.shared.utils.io.write_text_file(test.nlu_as_yaml(), test_path)
1805
+
1806
+ for percentage in exclusion_percentages:
1807
+ percent_string = f"{percentage}%_exclusion"
1808
+
1809
+ _, train_included = train.train_test_split(percentage / 100)
1810
+ # only count for the first run and ignore the others
1811
+ if run == 0:
1812
+ training_examples_per_run.append(len(train_included.nlu_examples))
1813
+
1814
+ model_output_path = os.path.join(run_path, percent_string)
1815
+ train_split_path = os.path.join(model_output_path, "train")
1816
+ train_nlu_split_path = os.path.join(train_split_path, TRAIN_DATA_FILE)
1817
+ train_nlg_split_path = os.path.join(train_split_path, NLG_DATA_FILE)
1818
+ io_utils.create_path(train_nlu_split_path)
1819
+ rasa.shared.utils.io.write_text_file(
1820
+ train_included.nlu_as_yaml(), train_nlu_split_path
1821
+ )
1822
+ rasa.shared.utils.io.write_text_file(
1823
+ train_included.nlg_as_yaml(), train_nlg_split_path
1824
+ )
1825
+
1826
+ for nlu_config, model_name in zip(configs, model_names):
1827
+ logger.info(
1828
+ "Evaluating configuration '{}' with {} training data.".format(
1829
+ model_name, percent_string
1830
+ )
1831
+ )
1832
+
1833
+ try:
1834
+ model_path = await rasa.model_training.train_nlu(
1835
+ nlu_config,
1836
+ train_split_path,
1837
+ model_output_path,
1838
+ fixed_model_name=model_name,
1839
+ )
1840
+ except Exception as e: # skipcq: PYL-W0703
1841
+ # general exception catching needed to continue evaluating other
1842
+ # model configurations
1843
+ logger.warning(f"Training model '{model_name}' failed. Error: {e}")
1844
+ f_score_results[model_name][run].append(0.0)
1845
+ continue
1846
+
1847
+ output_path = os.path.join(model_output_path, f"{model_name}_report")
1848
+ processor = Agent.load(model_path=model_path).processor
1849
+ result = await run_evaluation(
1850
+ test_path, processor, output_directory=output_path, errors=True
1851
+ )
1852
+
1853
+ f1 = result["intent_evaluation"]["f1_score"]
1854
+ f_score_results[model_name][run].append(f1)
1855
+
1856
+ return training_examples_per_run
1857
+
1858
+
1859
+ def _compute_metrics(
1860
+ results: Union[
1861
+ List[IntentEvaluationResult], List[ResponseSelectionEvaluationResult]
1862
+ ],
1863
+ target_key: Text,
1864
+ prediction_key: Text,
1865
+ ) -> Union[IntentMetrics, ResponseSelectionMetrics]:
1866
+ """Computes evaluation metrics for a given corpus and returns the results.
1867
+
1868
+ Args:
1869
+ results: evaluation results
1870
+ target_key: target key name
1871
+ prediction_key: prediction key name
1872
+
1873
+ Returns: metrics
1874
+ """
1875
+ from rasa.model_testing import get_evaluation_metrics
1876
+
1877
+ # compute fold metrics
1878
+ targets, predictions = _targets_predictions_from(
1879
+ results, target_key, prediction_key
1880
+ )
1881
+ _, precision, f1, accuracy = get_evaluation_metrics(targets, predictions)
1882
+
1883
+ return {"Accuracy": [accuracy], "F1-score": [f1], "Precision": [precision]}
1884
+
1885
+
1886
+ def _compute_entity_metrics(
1887
+ entity_results: List[EntityEvaluationResult],
1888
+ ) -> EntityMetrics:
1889
+ """Computes entity evaluation metrics and returns the results.
1890
+
1891
+ Args:
1892
+ entity_results: entity evaluation results
1893
+ Returns: entity metrics
1894
+ """
1895
+ from rasa.model_testing import get_evaluation_metrics
1896
+
1897
+ entity_metric_results: EntityMetrics = defaultdict(lambda: defaultdict(list))
1898
+ extractors = _get_active_entity_extractors(entity_results)
1899
+
1900
+ if not extractors:
1901
+ return entity_metric_results
1902
+
1903
+ aligned_predictions = align_all_entity_predictions(entity_results, extractors)
1904
+
1905
+ merged_targets = merge_labels(aligned_predictions)
1906
+ merged_targets = substitute_labels(merged_targets, NO_ENTITY_TAG, NO_ENTITY)
1907
+
1908
+ for extractor in extractors:
1909
+ merged_predictions = merge_labels(aligned_predictions, extractor)
1910
+ merged_predictions = substitute_labels(
1911
+ merged_predictions, NO_ENTITY_TAG, NO_ENTITY
1912
+ )
1913
+ _, precision, f1, accuracy = get_evaluation_metrics(
1914
+ merged_targets, merged_predictions, exclude_label=NO_ENTITY
1915
+ )
1916
+ entity_metric_results[extractor]["Accuracy"].append(accuracy)
1917
+ entity_metric_results[extractor]["F1-score"].append(f1)
1918
+ entity_metric_results[extractor]["Precision"].append(precision)
1919
+
1920
+ return entity_metric_results
1921
+
1922
+
1923
+ def log_results(results: IntentMetrics, dataset_name: Text) -> None:
1924
+ """Logs results of cross validation.
1925
+
1926
+ Args:
1927
+ results: dictionary of results returned from cross validation
1928
+ dataset_name: string of which dataset the results are from, e.g. test/train
1929
+ """
1930
+ for k, v in results.items():
1931
+ logger.info(f"{dataset_name} {k}: {np.mean(v):.3f} ({np.std(v):.3f})")
1932
+
1933
+
1934
+ def log_entity_results(results: EntityMetrics, dataset_name: Text) -> None:
1935
+ """Logs entity results of cross validation.
1936
+
1937
+ Args:
1938
+ results: dictionary of dictionaries of results returned from cross validation
1939
+ dataset_name: string of which dataset the results are from, e.g. test/train
1940
+ """
1941
+ for extractor, result in results.items():
1942
+ logger.info(f"Entity extractor: {extractor}")
1943
+ log_results(result, dataset_name)