rasa-pro 3.8.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (644):
  1. README.md +380 -0
  2. rasa/__init__.py +10 -0
  3. rasa/__main__.py +151 -0
  4. rasa/anonymization/__init__.py +2 -0
  5. rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
  6. rasa/anonymization/anonymization_pipeline.py +287 -0
  7. rasa/anonymization/anonymization_rule_executor.py +260 -0
  8. rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
  9. rasa/anonymization/schemas/config.yml +47 -0
  10. rasa/anonymization/utils.py +117 -0
  11. rasa/api.py +146 -0
  12. rasa/cli/__init__.py +5 -0
  13. rasa/cli/arguments/__init__.py +0 -0
  14. rasa/cli/arguments/data.py +81 -0
  15. rasa/cli/arguments/default_arguments.py +165 -0
  16. rasa/cli/arguments/evaluate.py +65 -0
  17. rasa/cli/arguments/export.py +51 -0
  18. rasa/cli/arguments/interactive.py +74 -0
  19. rasa/cli/arguments/run.py +204 -0
  20. rasa/cli/arguments/shell.py +13 -0
  21. rasa/cli/arguments/test.py +211 -0
  22. rasa/cli/arguments/train.py +263 -0
  23. rasa/cli/arguments/visualize.py +34 -0
  24. rasa/cli/arguments/x.py +30 -0
  25. rasa/cli/data.py +292 -0
  26. rasa/cli/e2e_test.py +566 -0
  27. rasa/cli/evaluate.py +222 -0
  28. rasa/cli/export.py +251 -0
  29. rasa/cli/inspect.py +63 -0
  30. rasa/cli/interactive.py +164 -0
  31. rasa/cli/license.py +65 -0
  32. rasa/cli/markers.py +78 -0
  33. rasa/cli/project_templates/__init__.py +0 -0
  34. rasa/cli/project_templates/calm/actions/__init__.py +0 -0
  35. rasa/cli/project_templates/calm/actions/action_template.py +27 -0
  36. rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
  37. rasa/cli/project_templates/calm/actions/db.py +57 -0
  38. rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
  39. rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
  40. rasa/cli/project_templates/calm/config.yml +12 -0
  41. rasa/cli/project_templates/calm/credentials.yml +33 -0
  42. rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
  43. rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
  44. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
  45. rasa/cli/project_templates/calm/db/contacts.json +10 -0
  46. rasa/cli/project_templates/calm/domain/add_contact.yml +33 -0
  47. rasa/cli/project_templates/calm/domain/list_contacts.yml +14 -0
  48. rasa/cli/project_templates/calm/domain/remove_contact.yml +31 -0
  49. rasa/cli/project_templates/calm/domain/shared.yml +5 -0
  50. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
  51. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
  52. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
  53. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
  54. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
  55. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
  56. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
  57. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
  58. rasa/cli/project_templates/calm/endpoints.yml +45 -0
  59. rasa/cli/project_templates/default/actions/__init__.py +0 -0
  60. rasa/cli/project_templates/default/actions/actions.py +27 -0
  61. rasa/cli/project_templates/default/config.yml +44 -0
  62. rasa/cli/project_templates/default/credentials.yml +33 -0
  63. rasa/cli/project_templates/default/data/nlu.yml +91 -0
  64. rasa/cli/project_templates/default/data/rules.yml +13 -0
  65. rasa/cli/project_templates/default/data/stories.yml +30 -0
  66. rasa/cli/project_templates/default/domain.yml +34 -0
  67. rasa/cli/project_templates/default/endpoints.yml +42 -0
  68. rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
  69. rasa/cli/project_templates/tutorial/actions.py +22 -0
  70. rasa/cli/project_templates/tutorial/config.yml +11 -0
  71. rasa/cli/project_templates/tutorial/credentials.yml +33 -0
  72. rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
  73. rasa/cli/project_templates/tutorial/domain.yml +17 -0
  74. rasa/cli/project_templates/tutorial/endpoints.yml +45 -0
  75. rasa/cli/run.py +136 -0
  76. rasa/cli/scaffold.py +268 -0
  77. rasa/cli/shell.py +141 -0
  78. rasa/cli/studio/__init__.py +0 -0
  79. rasa/cli/studio/download.py +51 -0
  80. rasa/cli/studio/studio.py +110 -0
  81. rasa/cli/studio/train.py +59 -0
  82. rasa/cli/studio/upload.py +85 -0
  83. rasa/cli/telemetry.py +90 -0
  84. rasa/cli/test.py +280 -0
  85. rasa/cli/train.py +260 -0
  86. rasa/cli/utils.py +453 -0
  87. rasa/cli/visualize.py +40 -0
  88. rasa/cli/x.py +205 -0
  89. rasa/constants.py +37 -0
  90. rasa/core/__init__.py +17 -0
  91. rasa/core/actions/__init__.py +0 -0
  92. rasa/core/actions/action.py +1450 -0
  93. rasa/core/actions/action_clean_stack.py +59 -0
  94. rasa/core/actions/action_run_slot_rejections.py +207 -0
  95. rasa/core/actions/action_trigger_chitchat.py +31 -0
  96. rasa/core/actions/action_trigger_flow.py +109 -0
  97. rasa/core/actions/action_trigger_search.py +31 -0
  98. rasa/core/actions/constants.py +2 -0
  99. rasa/core/actions/forms.py +737 -0
  100. rasa/core/actions/loops.py +111 -0
  101. rasa/core/actions/two_stage_fallback.py +186 -0
  102. rasa/core/agent.py +557 -0
  103. rasa/core/auth_retry_tracker_store.py +122 -0
  104. rasa/core/brokers/__init__.py +0 -0
  105. rasa/core/brokers/broker.py +126 -0
  106. rasa/core/brokers/file.py +58 -0
  107. rasa/core/brokers/kafka.py +322 -0
  108. rasa/core/brokers/pika.py +387 -0
  109. rasa/core/brokers/sql.py +86 -0
  110. rasa/core/channels/__init__.py +55 -0
  111. rasa/core/channels/audiocodes.py +463 -0
  112. rasa/core/channels/botframework.py +339 -0
  113. rasa/core/channels/callback.py +85 -0
  114. rasa/core/channels/channel.py +419 -0
  115. rasa/core/channels/console.py +243 -0
  116. rasa/core/channels/development_inspector.py +93 -0
  117. rasa/core/channels/facebook.py +422 -0
  118. rasa/core/channels/hangouts.py +335 -0
  119. rasa/core/channels/inspector/.eslintrc.cjs +25 -0
  120. rasa/core/channels/inspector/.gitignore +23 -0
  121. rasa/core/channels/inspector/README.md +54 -0
  122. rasa/core/channels/inspector/assets/favicon.ico +0 -0
  123. rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
  124. rasa/core/channels/inspector/custom.d.ts +3 -0
  125. rasa/core/channels/inspector/dist/assets/arc-5623b6dc.js +1 -0
  126. rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
  127. rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-685c106a.js +10 -0
  128. rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-8cbed007.js +2 -0
  129. rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-5889cf12.js +2 -0
  130. rasa/core/channels/inspector/dist/assets/createText-62fc7601-24c249d7.js +7 -0
  131. rasa/core/channels/inspector/dist/assets/edges-f2ad444c-7dd06a75.js +4 -0
  132. rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-62c1e54c.js +51 -0
  133. rasa/core/channels/inspector/dist/assets/flowDb-1972c806-ce49b86f.js +6 -0
  134. rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-4067e48f.js +4 -0
  135. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-85583a23.js +1 -0
  136. rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-59fe4051.js +139 -0
  137. rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-47e3a43b.js +266 -0
  138. rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-5a2ac0d9.js +70 -0
  139. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
  140. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
  141. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
  142. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
  143. rasa/core/channels/inspector/dist/assets/index-268a75c0.js +1040 -0
  144. rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-dfb8efc4.js +1 -0
  145. rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
  146. rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-b0c470f2.js +7 -0
  147. rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
  148. rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-2edb829a.js +139 -0
  149. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
  150. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
  151. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
  152. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
  153. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
  154. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
  155. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
  156. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
  157. rasa/core/channels/inspector/dist/assets/layout-b6873d69.js +1 -0
  158. rasa/core/channels/inspector/dist/assets/line-1efc5781.js +1 -0
  159. rasa/core/channels/inspector/dist/assets/linear-661e9b94.js +1 -0
  160. rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-2d2e727f.js +109 -0
  161. rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
  162. rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
  163. rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-9d3ea93d.js +35 -0
  164. rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-06a178a2.js +7 -0
  165. rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-0bfedffc.js +52 -0
  166. rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-d76d0a04.js +8 -0
  167. rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-37bb4341.js +122 -0
  168. rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-f52f7f57.js +1 -0
  169. rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-4a986a20.js +1 -0
  170. rasa/core/channels/inspector/dist/assets/styles-080da4f6-7dd9ae12.js +110 -0
  171. rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-46e1ca14.js +159 -0
  172. rasa/core/channels/inspector/dist/assets/styles-9c745c82-4a97439a.js +207 -0
  173. rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-823917a3.js +1 -0
  174. rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-9ea72896.js +61 -0
  175. rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-b631a8b6.js +7 -0
  176. rasa/core/channels/inspector/dist/index.html +39 -0
  177. rasa/core/channels/inspector/index.html +37 -0
  178. rasa/core/channels/inspector/jest.config.ts +13 -0
  179. rasa/core/channels/inspector/package.json +48 -0
  180. rasa/core/channels/inspector/setupTests.ts +2 -0
  181. rasa/core/channels/inspector/src/App.tsx +170 -0
  182. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +97 -0
  183. rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
  184. rasa/core/channels/inspector/src/components/DialogueStack.tsx +151 -0
  185. rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
  186. rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
  187. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +19 -0
  188. rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
  189. rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
  190. rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
  191. rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
  192. rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
  193. rasa/core/channels/inspector/src/helpers/formatters.test.ts +385 -0
  194. rasa/core/channels/inspector/src/helpers/formatters.ts +239 -0
  195. rasa/core/channels/inspector/src/helpers/utils.ts +42 -0
  196. rasa/core/channels/inspector/src/main.tsx +13 -0
  197. rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
  198. rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
  199. rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
  200. rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
  201. rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
  202. rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
  203. rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
  204. rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
  205. rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
  206. rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
  207. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
  208. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
  209. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
  210. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
  211. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
  212. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
  213. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
  214. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
  215. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
  216. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
  217. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
  218. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
  219. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
  220. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
  221. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
  222. rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
  223. rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
  224. rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
  225. rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
  226. rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
  227. rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
  228. rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
  229. rasa/core/channels/inspector/src/theme/index.ts +101 -0
  230. rasa/core/channels/inspector/src/types.ts +64 -0
  231. rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
  232. rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
  233. rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
  234. rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
  235. rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
  236. rasa/core/channels/inspector/tsconfig.json +26 -0
  237. rasa/core/channels/inspector/tsconfig.node.json +10 -0
  238. rasa/core/channels/inspector/vite.config.ts +8 -0
  239. rasa/core/channels/inspector/yarn.lock +6156 -0
  240. rasa/core/channels/mattermost.py +229 -0
  241. rasa/core/channels/rasa_chat.py +126 -0
  242. rasa/core/channels/rest.py +210 -0
  243. rasa/core/channels/rocketchat.py +175 -0
  244. rasa/core/channels/slack.py +620 -0
  245. rasa/core/channels/socketio.py +274 -0
  246. rasa/core/channels/telegram.py +298 -0
  247. rasa/core/channels/twilio.py +169 -0
  248. rasa/core/channels/twilio_voice.py +367 -0
  249. rasa/core/channels/vier_cvg.py +374 -0
  250. rasa/core/channels/webexteams.py +135 -0
  251. rasa/core/concurrent_lock_store.py +210 -0
  252. rasa/core/constants.py +107 -0
  253. rasa/core/evaluation/__init__.py +0 -0
  254. rasa/core/evaluation/marker.py +267 -0
  255. rasa/core/evaluation/marker_base.py +925 -0
  256. rasa/core/evaluation/marker_stats.py +294 -0
  257. rasa/core/evaluation/marker_tracker_loader.py +103 -0
  258. rasa/core/exceptions.py +29 -0
  259. rasa/core/exporter.py +284 -0
  260. rasa/core/featurizers/__init__.py +0 -0
  261. rasa/core/featurizers/precomputation.py +410 -0
  262. rasa/core/featurizers/single_state_featurizer.py +402 -0
  263. rasa/core/featurizers/tracker_featurizers.py +1172 -0
  264. rasa/core/http_interpreter.py +89 -0
  265. rasa/core/information_retrieval/__init__.py +0 -0
  266. rasa/core/information_retrieval/faiss.py +116 -0
  267. rasa/core/information_retrieval/information_retrieval.py +72 -0
  268. rasa/core/information_retrieval/milvus.py +59 -0
  269. rasa/core/information_retrieval/qdrant.py +102 -0
  270. rasa/core/jobs.py +63 -0
  271. rasa/core/lock.py +139 -0
  272. rasa/core/lock_store.py +344 -0
  273. rasa/core/migrate.py +404 -0
  274. rasa/core/nlg/__init__.py +3 -0
  275. rasa/core/nlg/callback.py +147 -0
  276. rasa/core/nlg/contextual_response_rephraser.py +270 -0
  277. rasa/core/nlg/generator.py +230 -0
  278. rasa/core/nlg/interpolator.py +143 -0
  279. rasa/core/nlg/response.py +155 -0
  280. rasa/core/nlg/summarize.py +69 -0
  281. rasa/core/policies/__init__.py +0 -0
  282. rasa/core/policies/ensemble.py +329 -0
  283. rasa/core/policies/enterprise_search_policy.py +717 -0
  284. rasa/core/policies/enterprise_search_prompt_template.jinja2 +62 -0
  285. rasa/core/policies/flow_policy.py +205 -0
  286. rasa/core/policies/flows/__init__.py +0 -0
  287. rasa/core/policies/flows/flow_exceptions.py +44 -0
  288. rasa/core/policies/flows/flow_executor.py +582 -0
  289. rasa/core/policies/flows/flow_step_result.py +43 -0
  290. rasa/core/policies/intentless_policy.py +924 -0
  291. rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
  292. rasa/core/policies/memoization.py +538 -0
  293. rasa/core/policies/policy.py +716 -0
  294. rasa/core/policies/rule_policy.py +1276 -0
  295. rasa/core/policies/ted_policy.py +2146 -0
  296. rasa/core/policies/unexpected_intent_policy.py +1015 -0
  297. rasa/core/processor.py +1331 -0
  298. rasa/core/run.py +315 -0
  299. rasa/core/secrets_manager/__init__.py +0 -0
  300. rasa/core/secrets_manager/constants.py +32 -0
  301. rasa/core/secrets_manager/endpoints.py +391 -0
  302. rasa/core/secrets_manager/factory.py +233 -0
  303. rasa/core/secrets_manager/secret_manager.py +262 -0
  304. rasa/core/secrets_manager/vault.py +576 -0
  305. rasa/core/test.py +1337 -0
  306. rasa/core/tracker_store.py +1664 -0
  307. rasa/core/train.py +107 -0
  308. rasa/core/training/__init__.py +89 -0
  309. rasa/core/training/converters/__init__.py +0 -0
  310. rasa/core/training/converters/responses_prefix_converter.py +119 -0
  311. rasa/core/training/interactive.py +1742 -0
  312. rasa/core/training/story_conflict.py +381 -0
  313. rasa/core/training/training.py +93 -0
  314. rasa/core/utils.py +344 -0
  315. rasa/core/visualize.py +70 -0
  316. rasa/dialogue_understanding/__init__.py +0 -0
  317. rasa/dialogue_understanding/coexistence/__init__.py +0 -0
  318. rasa/dialogue_understanding/coexistence/constants.py +4 -0
  319. rasa/dialogue_understanding/coexistence/intent_based_router.py +189 -0
  320. rasa/dialogue_understanding/coexistence/llm_based_router.py +261 -0
  321. rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
  322. rasa/dialogue_understanding/commands/__init__.py +45 -0
  323. rasa/dialogue_understanding/commands/can_not_handle_command.py +61 -0
  324. rasa/dialogue_understanding/commands/cancel_flow_command.py +116 -0
  325. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +48 -0
  326. rasa/dialogue_understanding/commands/clarify_command.py +77 -0
  327. rasa/dialogue_understanding/commands/command.py +85 -0
  328. rasa/dialogue_understanding/commands/correct_slots_command.py +288 -0
  329. rasa/dialogue_understanding/commands/error_command.py +67 -0
  330. rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
  331. rasa/dialogue_understanding/commands/handle_code_change_command.py +64 -0
  332. rasa/dialogue_understanding/commands/human_handoff_command.py +57 -0
  333. rasa/dialogue_understanding/commands/knowledge_answer_command.py +48 -0
  334. rasa/dialogue_understanding/commands/noop_command.py +45 -0
  335. rasa/dialogue_understanding/commands/set_slot_command.py +125 -0
  336. rasa/dialogue_understanding/commands/skip_question_command.py +66 -0
  337. rasa/dialogue_understanding/commands/start_flow_command.py +98 -0
  338. rasa/dialogue_understanding/generator/__init__.py +6 -0
  339. rasa/dialogue_understanding/generator/command_generator.py +257 -0
  340. rasa/dialogue_understanding/generator/command_prompt_template.jinja2 +57 -0
  341. rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
  342. rasa/dialogue_understanding/generator/flow_retrieval.py +410 -0
  343. rasa/dialogue_understanding/generator/llm_command_generator.py +637 -0
  344. rasa/dialogue_understanding/generator/nlu_command_adapter.py +157 -0
  345. rasa/dialogue_understanding/patterns/__init__.py +0 -0
  346. rasa/dialogue_understanding/patterns/cancel.py +111 -0
  347. rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
  348. rasa/dialogue_understanding/patterns/chitchat.py +37 -0
  349. rasa/dialogue_understanding/patterns/clarify.py +97 -0
  350. rasa/dialogue_understanding/patterns/code_change.py +41 -0
  351. rasa/dialogue_understanding/patterns/collect_information.py +90 -0
  352. rasa/dialogue_understanding/patterns/completed.py +40 -0
  353. rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
  354. rasa/dialogue_understanding/patterns/correction.py +278 -0
  355. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +243 -0
  356. rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
  357. rasa/dialogue_understanding/patterns/internal_error.py +47 -0
  358. rasa/dialogue_understanding/patterns/search.py +37 -0
  359. rasa/dialogue_understanding/patterns/skip_question.py +38 -0
  360. rasa/dialogue_understanding/processor/__init__.py +0 -0
  361. rasa/dialogue_understanding/processor/command_processor.py +578 -0
  362. rasa/dialogue_understanding/processor/command_processor_component.py +39 -0
  363. rasa/dialogue_understanding/stack/__init__.py +0 -0
  364. rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
  365. rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
  366. rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
  367. rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
  368. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
  369. rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
  370. rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
  371. rasa/dialogue_understanding/stack/utils.py +211 -0
  372. rasa/e2e_test/__init__.py +0 -0
  373. rasa/e2e_test/constants.py +10 -0
  374. rasa/e2e_test/e2e_test_case.py +322 -0
  375. rasa/e2e_test/e2e_test_result.py +34 -0
  376. rasa/e2e_test/e2e_test_runner.py +659 -0
  377. rasa/e2e_test/e2e_test_schema.yml +67 -0
  378. rasa/engine/__init__.py +0 -0
  379. rasa/engine/caching.py +464 -0
  380. rasa/engine/constants.py +17 -0
  381. rasa/engine/exceptions.py +14 -0
  382. rasa/engine/graph.py +625 -0
  383. rasa/engine/loader.py +36 -0
  384. rasa/engine/recipes/__init__.py +0 -0
  385. rasa/engine/recipes/config_files/default_config.yml +44 -0
  386. rasa/engine/recipes/default_components.py +99 -0
  387. rasa/engine/recipes/default_recipe.py +1252 -0
  388. rasa/engine/recipes/graph_recipe.py +79 -0
  389. rasa/engine/recipes/recipe.py +93 -0
  390. rasa/engine/runner/__init__.py +0 -0
  391. rasa/engine/runner/dask.py +256 -0
  392. rasa/engine/runner/interface.py +49 -0
  393. rasa/engine/storage/__init__.py +0 -0
  394. rasa/engine/storage/local_model_storage.py +248 -0
  395. rasa/engine/storage/resource.py +110 -0
  396. rasa/engine/storage/storage.py +203 -0
  397. rasa/engine/training/__init__.py +0 -0
  398. rasa/engine/training/components.py +176 -0
  399. rasa/engine/training/fingerprinting.py +64 -0
  400. rasa/engine/training/graph_trainer.py +256 -0
  401. rasa/engine/training/hooks.py +164 -0
  402. rasa/engine/validation.py +839 -0
  403. rasa/env.py +5 -0
  404. rasa/exceptions.py +69 -0
  405. rasa/graph_components/__init__.py +0 -0
  406. rasa/graph_components/converters/__init__.py +0 -0
  407. rasa/graph_components/converters/nlu_message_converter.py +48 -0
  408. rasa/graph_components/providers/__init__.py +0 -0
  409. rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
  410. rasa/graph_components/providers/domain_provider.py +71 -0
  411. rasa/graph_components/providers/flows_provider.py +74 -0
  412. rasa/graph_components/providers/forms_provider.py +44 -0
  413. rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
  414. rasa/graph_components/providers/responses_provider.py +44 -0
  415. rasa/graph_components/providers/rule_only_provider.py +49 -0
  416. rasa/graph_components/providers/story_graph_provider.py +43 -0
  417. rasa/graph_components/providers/training_tracker_provider.py +55 -0
  418. rasa/graph_components/validators/__init__.py +0 -0
  419. rasa/graph_components/validators/default_recipe_validator.py +552 -0
  420. rasa/graph_components/validators/finetuning_validator.py +302 -0
  421. rasa/hooks.py +113 -0
  422. rasa/jupyter.py +63 -0
  423. rasa/keys +1 -0
  424. rasa/markers/__init__.py +0 -0
  425. rasa/markers/marker.py +269 -0
  426. rasa/markers/marker_base.py +828 -0
  427. rasa/markers/upload.py +74 -0
  428. rasa/markers/validate.py +21 -0
  429. rasa/model.py +118 -0
  430. rasa/model_testing.py +457 -0
  431. rasa/model_training.py +535 -0
  432. rasa/nlu/__init__.py +7 -0
  433. rasa/nlu/classifiers/__init__.py +3 -0
  434. rasa/nlu/classifiers/classifier.py +5 -0
  435. rasa/nlu/classifiers/diet_classifier.py +1874 -0
  436. rasa/nlu/classifiers/fallback_classifier.py +192 -0
  437. rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
  438. rasa/nlu/classifiers/llm_intent_classifier.py +519 -0
  439. rasa/nlu/classifiers/logistic_regression_classifier.py +240 -0
  440. rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
  441. rasa/nlu/classifiers/regex_message_handler.py +56 -0
  442. rasa/nlu/classifiers/sklearn_intent_classifier.py +309 -0
  443. rasa/nlu/constants.py +77 -0
  444. rasa/nlu/convert.py +40 -0
  445. rasa/nlu/emulators/__init__.py +0 -0
  446. rasa/nlu/emulators/dialogflow.py +55 -0
  447. rasa/nlu/emulators/emulator.py +49 -0
  448. rasa/nlu/emulators/luis.py +86 -0
  449. rasa/nlu/emulators/no_emulator.py +10 -0
  450. rasa/nlu/emulators/wit.py +56 -0
  451. rasa/nlu/extractors/__init__.py +0 -0
  452. rasa/nlu/extractors/crf_entity_extractor.py +672 -0
  453. rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
  454. rasa/nlu/extractors/entity_synonyms.py +178 -0
  455. rasa/nlu/extractors/extractor.py +470 -0
  456. rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
  457. rasa/nlu/extractors/regex_entity_extractor.py +220 -0
  458. rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
  459. rasa/nlu/featurizers/__init__.py +0 -0
  460. rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
  461. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +449 -0
  462. rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
  463. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +772 -0
  464. rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
  465. rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
  466. rasa/nlu/featurizers/featurizer.py +89 -0
  467. rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
  468. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +840 -0
  469. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +539 -0
  470. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +269 -0
  471. rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
  472. rasa/nlu/model.py +24 -0
  473. rasa/nlu/persistor.py +240 -0
  474. rasa/nlu/run.py +27 -0
  475. rasa/nlu/selectors/__init__.py +0 -0
  476. rasa/nlu/selectors/response_selector.py +990 -0
  477. rasa/nlu/test.py +1943 -0
  478. rasa/nlu/tokenizers/__init__.py +0 -0
  479. rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
  480. rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
  481. rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
  482. rasa/nlu/tokenizers/tokenizer.py +239 -0
  483. rasa/nlu/tokenizers/whitespace_tokenizer.py +106 -0
  484. rasa/nlu/utils/__init__.py +35 -0
  485. rasa/nlu/utils/bilou_utils.py +462 -0
  486. rasa/nlu/utils/hugging_face/__init__.py +0 -0
  487. rasa/nlu/utils/hugging_face/registry.py +108 -0
  488. rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
  489. rasa/nlu/utils/mitie_utils.py +113 -0
  490. rasa/nlu/utils/pattern_utils.py +168 -0
  491. rasa/nlu/utils/spacy_utils.py +312 -0
  492. rasa/plugin.py +90 -0
  493. rasa/server.py +1536 -0
  494. rasa/shared/__init__.py +0 -0
  495. rasa/shared/constants.py +181 -0
  496. rasa/shared/core/__init__.py +0 -0
  497. rasa/shared/core/constants.py +168 -0
  498. rasa/shared/core/conversation.py +46 -0
  499. rasa/shared/core/domain.py +2106 -0
  500. rasa/shared/core/events.py +2507 -0
  501. rasa/shared/core/flows/__init__.py +7 -0
  502. rasa/shared/core/flows/flow.py +353 -0
  503. rasa/shared/core/flows/flow_step.py +146 -0
  504. rasa/shared/core/flows/flow_step_links.py +319 -0
  505. rasa/shared/core/flows/flow_step_sequence.py +70 -0
  506. rasa/shared/core/flows/flows_list.py +211 -0
  507. rasa/shared/core/flows/flows_yaml_schema.json +217 -0
  508. rasa/shared/core/flows/nlu_trigger.py +117 -0
  509. rasa/shared/core/flows/steps/__init__.py +24 -0
  510. rasa/shared/core/flows/steps/action.py +51 -0
  511. rasa/shared/core/flows/steps/call.py +64 -0
  512. rasa/shared/core/flows/steps/collect.py +112 -0
  513. rasa/shared/core/flows/steps/constants.py +5 -0
  514. rasa/shared/core/flows/steps/continuation.py +36 -0
  515. rasa/shared/core/flows/steps/end.py +22 -0
  516. rasa/shared/core/flows/steps/internal.py +44 -0
  517. rasa/shared/core/flows/steps/link.py +51 -0
  518. rasa/shared/core/flows/steps/no_operation.py +48 -0
  519. rasa/shared/core/flows/steps/set_slots.py +50 -0
  520. rasa/shared/core/flows/steps/start.py +30 -0
  521. rasa/shared/core/flows/validation.py +527 -0
  522. rasa/shared/core/flows/yaml_flows_io.py +278 -0
  523. rasa/shared/core/generator.py +907 -0
  524. rasa/shared/core/slot_mappings.py +235 -0
  525. rasa/shared/core/slots.py +647 -0
  526. rasa/shared/core/trackers.py +1159 -0
  527. rasa/shared/core/training_data/__init__.py +0 -0
  528. rasa/shared/core/training_data/loading.py +90 -0
  529. rasa/shared/core/training_data/story_reader/__init__.py +0 -0
  530. rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
  531. rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
  532. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
  533. rasa/shared/core/training_data/story_writer/__init__.py +0 -0
  534. rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
  535. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +442 -0
  536. rasa/shared/core/training_data/structures.py +838 -0
  537. rasa/shared/core/training_data/visualization.html +146 -0
  538. rasa/shared/core/training_data/visualization.py +603 -0
  539. rasa/shared/data.py +192 -0
  540. rasa/shared/engine/__init__.py +0 -0
  541. rasa/shared/engine/caching.py +26 -0
  542. rasa/shared/exceptions.py +129 -0
  543. rasa/shared/importers/__init__.py +0 -0
  544. rasa/shared/importers/importer.py +705 -0
  545. rasa/shared/importers/multi_project.py +203 -0
  546. rasa/shared/importers/rasa.py +100 -0
  547. rasa/shared/importers/utils.py +34 -0
  548. rasa/shared/nlu/__init__.py +0 -0
  549. rasa/shared/nlu/constants.py +45 -0
  550. rasa/shared/nlu/interpreter.py +10 -0
  551. rasa/shared/nlu/training_data/__init__.py +0 -0
  552. rasa/shared/nlu/training_data/entities_parser.py +209 -0
  553. rasa/shared/nlu/training_data/features.py +374 -0
  554. rasa/shared/nlu/training_data/formats/__init__.py +10 -0
  555. rasa/shared/nlu/training_data/formats/dialogflow.py +162 -0
  556. rasa/shared/nlu/training_data/formats/luis.py +87 -0
  557. rasa/shared/nlu/training_data/formats/rasa.py +135 -0
  558. rasa/shared/nlu/training_data/formats/rasa_yaml.py +605 -0
  559. rasa/shared/nlu/training_data/formats/readerwriter.py +245 -0
  560. rasa/shared/nlu/training_data/formats/wit.py +52 -0
  561. rasa/shared/nlu/training_data/loading.py +137 -0
  562. rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
  563. rasa/shared/nlu/training_data/message.py +477 -0
  564. rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
  565. rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
  566. rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
  567. rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
  568. rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
  569. rasa/shared/nlu/training_data/training_data.py +732 -0
  570. rasa/shared/nlu/training_data/util.py +223 -0
  571. rasa/shared/providers/__init__.py +0 -0
  572. rasa/shared/providers/openai/__init__.py +0 -0
  573. rasa/shared/providers/openai/clients.py +43 -0
  574. rasa/shared/providers/openai/session_handler.py +110 -0
  575. rasa/shared/utils/__init__.py +0 -0
  576. rasa/shared/utils/cli.py +72 -0
  577. rasa/shared/utils/common.py +308 -0
  578. rasa/shared/utils/constants.py +1 -0
  579. rasa/shared/utils/io.py +403 -0
  580. rasa/shared/utils/llm.py +405 -0
  581. rasa/shared/utils/pykwalify_extensions.py +26 -0
  582. rasa/shared/utils/schemas/__init__.py +0 -0
  583. rasa/shared/utils/schemas/config.yml +2 -0
  584. rasa/shared/utils/schemas/domain.yml +142 -0
  585. rasa/shared/utils/schemas/events.py +212 -0
  586. rasa/shared/utils/schemas/model_config.yml +46 -0
  587. rasa/shared/utils/schemas/stories.yml +173 -0
  588. rasa/shared/utils/yaml.py +777 -0
  589. rasa/studio/__init__.py +0 -0
  590. rasa/studio/auth.py +252 -0
  591. rasa/studio/config.py +127 -0
  592. rasa/studio/constants.py +16 -0
  593. rasa/studio/data_handler.py +352 -0
  594. rasa/studio/download.py +350 -0
  595. rasa/studio/train.py +136 -0
  596. rasa/studio/upload.py +408 -0
  597. rasa/telemetry.py +1583 -0
  598. rasa/tracing/__init__.py +0 -0
  599. rasa/tracing/config.py +338 -0
  600. rasa/tracing/constants.py +38 -0
  601. rasa/tracing/instrumentation/__init__.py +0 -0
  602. rasa/tracing/instrumentation/attribute_extractors.py +663 -0
  603. rasa/tracing/instrumentation/instrumentation.py +939 -0
  604. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +142 -0
  605. rasa/tracing/instrumentation/metrics.py +206 -0
  606. rasa/tracing/metric_instrument_provider.py +125 -0
  607. rasa/utils/__init__.py +0 -0
  608. rasa/utils/beta.py +83 -0
  609. rasa/utils/cli.py +27 -0
  610. rasa/utils/common.py +635 -0
  611. rasa/utils/converter.py +53 -0
  612. rasa/utils/endpoints.py +303 -0
  613. rasa/utils/io.py +326 -0
  614. rasa/utils/licensing.py +319 -0
  615. rasa/utils/log_utils.py +174 -0
  616. rasa/utils/mapper.py +210 -0
  617. rasa/utils/ml_utils.py +145 -0
  618. rasa/utils/plotting.py +362 -0
  619. rasa/utils/singleton.py +23 -0
  620. rasa/utils/tensorflow/__init__.py +0 -0
  621. rasa/utils/tensorflow/callback.py +112 -0
  622. rasa/utils/tensorflow/constants.py +116 -0
  623. rasa/utils/tensorflow/crf.py +492 -0
  624. rasa/utils/tensorflow/data_generator.py +440 -0
  625. rasa/utils/tensorflow/environment.py +161 -0
  626. rasa/utils/tensorflow/exceptions.py +5 -0
  627. rasa/utils/tensorflow/layers.py +1565 -0
  628. rasa/utils/tensorflow/layers_utils.py +113 -0
  629. rasa/utils/tensorflow/metrics.py +281 -0
  630. rasa/utils/tensorflow/model_data.py +991 -0
  631. rasa/utils/tensorflow/model_data_utils.py +500 -0
  632. rasa/utils/tensorflow/models.py +936 -0
  633. rasa/utils/tensorflow/rasa_layers.py +1094 -0
  634. rasa/utils/tensorflow/transformer.py +640 -0
  635. rasa/utils/tensorflow/types.py +6 -0
  636. rasa/utils/train_utils.py +572 -0
  637. rasa/utils/yaml.py +54 -0
  638. rasa/validator.py +1035 -0
  639. rasa/version.py +3 -0
  640. rasa_pro-3.8.16.dist-info/METADATA +528 -0
  641. rasa_pro-3.8.16.dist-info/NOTICE +5 -0
  642. rasa_pro-3.8.16.dist-info/RECORD +644 -0
  643. rasa_pro-3.8.16.dist-info/WHEEL +4 -0
  644. rasa_pro-3.8.16.dist-info/entry_points.txt +3 -0
rasa/core/test.py ADDED
@@ -0,0 +1,1337 @@
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+ import tempfile
5
+ import warnings as pywarnings
6
+ from collections import defaultdict, namedtuple
7
+ from typing import Any, Dict, List, Optional, Text, Tuple, TYPE_CHECKING, cast
8
+
9
+ from rasa import telemetry
10
+ from rasa.core.constants import (
11
+ CONFUSION_MATRIX_STORIES_FILE,
12
+ REPORT_STORIES_FILE,
13
+ FAILED_STORIES_FILE,
14
+ SUCCESSFUL_STORIES_FILE,
15
+ STORIES_WITH_WARNINGS_FILE,
16
+ )
17
+ from rasa.core.channels import UserMessage
18
+ from rasa.core.policies.policy import PolicyPrediction
19
+ from rasa.nlu.test import EntityEvaluationResult, evaluate_entities
20
+ from rasa.nlu.tokenizers.tokenizer import Token
21
+ from rasa.shared.constants import ROUTE_TO_CALM_SLOT
22
+ from rasa.shared.core.constants import (
23
+ POLICIES_THAT_EXTRACT_ENTITIES,
24
+ ACTION_UNLIKELY_INTENT_NAME,
25
+ )
26
+ from rasa.shared.exceptions import RasaException
27
+ import rasa.shared.utils.io
28
+ from rasa.shared.core.training_data.story_writer.yaml_story_writer import (
29
+ YAMLStoryWriter,
30
+ )
31
+ from rasa.shared.core.training_data.structures import StoryStep
32
+ from rasa.shared.core.domain import Domain
33
+ from rasa.nlu.constants import (
34
+ RESPONSE_SELECTOR_DEFAULT_INTENT,
35
+ RESPONSE_SELECTOR_RETRIEVAL_INTENTS,
36
+ TOKENS_NAMES,
37
+ RESPONSE_SELECTOR_PROPERTY_NAME,
38
+ )
39
+ from rasa.shared.nlu.constants import (
40
+ INTENT,
41
+ ENTITIES,
42
+ ENTITY_ATTRIBUTE_VALUE,
43
+ ENTITY_ATTRIBUTE_START,
44
+ ENTITY_ATTRIBUTE_END,
45
+ EXTRACTOR,
46
+ ENTITY_ATTRIBUTE_TYPE,
47
+ INTENT_RESPONSE_KEY,
48
+ INTENT_NAME_KEY,
49
+ RESPONSE,
50
+ RESPONSE_SELECTOR,
51
+ FULL_RETRIEVAL_INTENT_NAME_KEY,
52
+ TEXT,
53
+ ENTITY_ATTRIBUTE_TEXT,
54
+ )
55
+ from rasa.constants import RESULTS_FILE, PERCENTAGE_KEY
56
+ from rasa.shared.core.events import ActionExecuted, EntitiesAdded, UserUttered, SlotSet
57
+ from rasa.shared.core.trackers import DialogueStateTracker
58
+ from rasa.shared.nlu.training_data.formats.readerwriter import TrainingDataWriter
59
+ from rasa.shared.importers.importer import TrainingDataImporter
60
+ from rasa.shared.utils.io import DEFAULT_ENCODING
61
+ from rasa.utils.tensorflow.constants import QUERY_INTENT_KEY, SEVERITY_KEY
62
+ from rasa.exceptions import ActionLimitReached
63
+
64
+ from rasa.core.actions.action import ActionRetrieveResponse
65
+
66
+ if TYPE_CHECKING:
67
+ from rasa.core.agent import Agent
68
+ from rasa.core.processor import MessageProcessor
69
+ from rasa.shared.core.generator import TrainingDataGenerator
70
+ from rasa.shared.core.events import Event, EntityPrediction
71
+
72
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Lightweight record type built with `collections.namedtuple`; it exposes one
# field per name in the list below, in this order.
StoryEvaluation = namedtuple(
    "StoryEvaluation",
    [
        "evaluation_store",
        "failed_stories",
        "successful_stories",
        "stories_with_warnings",
        "action_list",
        "in_training_data_fraction",
    ],
)

# A list of predicted or expected labels; individual entries may be `None`.
PredictionList = List[Optional[Text]]
87
+
88
+
89
class WrongPredictionException(RasaException, ValueError):
    """Signals that a prediction differed from the expected value.

    Derives from both `RasaException` and `ValueError`, so handlers written
    for either of the two types will catch it.
    """
91
+
92
+
93
class WarningPredictedAction(ActionExecuted):
    """Action event for a correct prediction that is reported with a warning."""

    type_name = "warning_predicted"

    def __init__(
        self,
        action_name_prediction: Text,
        action_name: Optional[Text] = None,
        policy: Optional[Text] = None,
        confidence: Optional[float] = None,
        timestamp: Optional[float] = None,
        metadata: Optional[Dict] = None,
    ):
        """Remembers the predicted action name and initialises the parent event.

        Args:
            action_name_prediction: Name of the predicted action; it is echoed
                back by `inline_comment`.
            action_name: Forwarded to `ActionExecuted.__init__`.
            policy: Forwarded to `ActionExecuted.__init__`.
            confidence: Forwarded to `ActionExecuted.__init__`.
            timestamp: Forwarded to `ActionExecuted.__init__`.
            metadata: Forwarded to `ActionExecuted.__init__`.
        """
        # Set before the parent constructor runs, exactly as the parent
        # arguments are forwarded positionally and in declaration order.
        self.action_name_prediction = action_name_prediction
        parent_arguments = (action_name, policy, confidence, timestamp, metadata)
        super().__init__(*parent_arguments)

    def inline_comment(self, **kwargs: Any) -> Text:
        """Returns the comment written next to this event when it is dumped."""
        predicted_name = self.action_name_prediction
        return f"predicted: {predicted_name}"
117
+
118
+
119
class WronglyPredictedAction(ActionExecuted):
    """Action event recording that the model predicted the wrong action.

    The event itself carries the *expected* action (name and text); the
    action that was actually predicted is kept in `action_name_prediction`
    so that it can be written out as a comment when dumping stories.
    """

    type_name = "wrong_action"

    def __init__(
        self,
        action_name_target: Text,
        action_text_target: Text,
        action_name_prediction: Text,
        policy: Optional[Text] = None,
        confidence: Optional[float] = None,
        timestamp: Optional[float] = None,
        metadata: Optional[Dict] = None,
        predicted_action_unlikely_intent: bool = False,
    ) -> None:
        """Creates the event for a wrongly predicted action.

        Args:
            action_name_target: Expected action name; becomes the parent
                event's action name.
            action_text_target: Expected action text; passed to the parent as
                `action_text`.
            action_name_prediction: Name of the action that was predicted.
            policy: Forwarded to `ActionExecuted.__init__`.
            confidence: Forwarded to `ActionExecuted.__init__`.
            timestamp: Forwarded to `ActionExecuted.__init__`.
            metadata: Forwarded to `ActionExecuted.__init__`.
            predicted_action_unlikely_intent: When true, `inline_comment`
                appends a note mentioning `ACTION_UNLIKELY_INTENT_NAME`.
        """
        self.action_name_prediction = action_name_prediction
        self.predicted_action_unlikely_intent = predicted_action_unlikely_intent
        parent_arguments = (action_name_target, policy, confidence, timestamp, metadata)
        super().__init__(*parent_arguments, action_text=action_text_target)

    def inline_comment(self, **kwargs: Any) -> Text:
        """Returns the comment written next to this event when it is dumped."""
        comment = f"predicted: {self.action_name_prediction}"
        if not self.predicted_action_unlikely_intent:
            return comment
        return f"{comment} after {ACTION_UNLIKELY_INTENT_NAME}"

    def as_story_string(self) -> Text:
        """Returns the expected action name followed by the inline comment."""
        expected_name = self.action_name
        comment = self.inline_comment()
        return f"{expected_name} <!-- {comment} -->"

    def __repr__(self) -> Text:
        """Returns a debugging representation of the event."""
        description = (
            f"WronglyPredictedAction(action_target: {self.action_name}, "
            f"action_prediction: {self.action_name_prediction}, "
            f"policy: {self.policy}, confidence: {self.confidence}, "
            f"metadata: {self.metadata})"
        )
        return description
173
+
174
+
175
class EvaluationStore:
    """Class storing action, intent and entity predictions and targets.

    Each of the six attributes is a plain list. Predictions and targets of the
    same kind are compared with each other by the `check_*` methods and are
    flattened into two aligned label lists by `serialise`.
    """

    def __init__(
        self,
        action_predictions: Optional[PredictionList] = None,
        action_targets: Optional[PredictionList] = None,
        intent_predictions: Optional[PredictionList] = None,
        intent_targets: Optional[PredictionList] = None,
        entity_predictions: Optional[List["EntityPrediction"]] = None,
        entity_targets: Optional[List["EntityPrediction"]] = None,
    ) -> None:
        """Initialize store attributes.

        Any argument that is omitted (or otherwise falsy, e.g. an empty list)
        is replaced by a fresh empty list.
        """
        self.action_predictions = action_predictions or []
        self.action_targets = action_targets or []
        self.intent_predictions = intent_predictions or []
        self.intent_targets = intent_targets or []
        self.entity_predictions: List["EntityPrediction"] = entity_predictions or []
        self.entity_targets: List["EntityPrediction"] = entity_targets or []

    def add_to_store(
        self,
        action_predictions: Optional[PredictionList] = None,
        action_targets: Optional[PredictionList] = None,
        intent_predictions: Optional[PredictionList] = None,
        intent_targets: Optional[PredictionList] = None,
        entity_predictions: Optional[List["EntityPrediction"]] = None,
        entity_targets: Optional[List["EntityPrediction"]] = None,
    ) -> None:
        """Append the given lists of items to the matching store attributes.

        Omitted arguments leave the corresponding attribute untouched.
        """
        self.action_predictions.extend(action_predictions or [])
        self.action_targets.extend(action_targets or [])
        self.intent_targets.extend(intent_targets or [])
        self.intent_predictions.extend(intent_predictions or [])
        self.entity_predictions.extend(entity_predictions or [])
        self.entity_targets.extend(entity_targets or [])

    def merge_store(self, other: "EvaluationStore") -> None:
        """Add the contents of other to self."""
        self.add_to_store(
            action_predictions=other.action_predictions,
            action_targets=other.action_targets,
            intent_predictions=other.intent_predictions,
            intent_targets=other.intent_targets,
            entity_predictions=other.entity_predictions,
            entity_targets=other.entity_targets,
        )

    def _check_entity_prediction_target_mismatch(self) -> bool:
        """Checks that same entities were expected and actually extracted.

        Possible duplicates or differences in order should not matter. Each
        entity is reduced to the frozen set of its key/value pairs, so neither
        the order of the entities nor the insertion order of the keys inside
        an entity influences the result.

        Returns:
            `True` if the de-duplicated targets differ from the de-duplicated
            predictions.
        """
        deduplicated_targets = {
            frozenset(entity.items()) for entity in self.entity_targets
        }
        deduplicated_predictions = {
            frozenset(entity.items()) for entity in self.entity_predictions
        }
        return deduplicated_targets != deduplicated_predictions

    def check_prediction_target_mismatch(self) -> bool:
        """Checks if intent, entity or action predictions don't match expected ones."""
        return (
            self.intent_predictions != self.intent_targets
            or self._check_entity_prediction_target_mismatch()
            or self.action_predictions != self.action_targets
        )

    @staticmethod
    def _compare_entities(
        entity_predictions: List["EntityPrediction"],
        entity_targets: List["EntityPrediction"],
        i_pred: int,
        i_target: int,
    ) -> int:
        """Picks the first entity from the current predicted and target entities.

        If the predicted entity comes first it returns -1,
        while it returns 1 if the target entity comes first.
        If target and predicted are aligned it returns 0.

        When one of the two indices is past the end of its list, the entity
        of the other list is reported as coming first.
        """
        pred = entity_predictions[i_pred] if i_pred < len(entity_predictions) else None
        target = entity_targets[i_target] if i_target < len(entity_targets) else None

        if not (target and pred):
            return 1 if target else -1

        # Order by "start" first and fall back to "end" when the starts tie.
        for position in (ENTITY_ATTRIBUTE_START, ENTITY_ATTRIBUTE_END):
            pred_position = pred.get(position)
            target_position = target.get(position)
            if pred_position < target_position:
                return -1
            if target_position < pred_position:
                return 1

        # The entities have the same "start" and "end" values
        return 0

    @staticmethod
    def _generate_entity_training_data(entity: Dict[Text, Any]) -> Text:
        """Renders one entity via `TrainingDataWriter.generate_entity`."""
        return TrainingDataWriter.generate_entity(
            entity.get(ENTITY_ATTRIBUTE_TEXT), entity
        )

    @staticmethod
    def _group_entities_by_text(
        entities: List["EntityPrediction"],
    ) -> Dict[Text, List["EntityPrediction"]]:
        """Groups entities by the message text they belong to.

        Entities whose text attribute is not a string are left out: `serialise`
        only ever looks entities up by a string and would not match them.
        The relative order of the entities of one text is preserved.
        """
        grouped: Dict[Text, List["EntityPrediction"]] = defaultdict(list)
        for entity in entities:
            text = entity.get(ENTITY_ATTRIBUTE_TEXT)
            if isinstance(text, str):
                grouped[text].append(entity)
        return grouped

    def serialise(self) -> Tuple[PredictionList, PredictionList]:
        """Turn targets and predictions to lists of equal size for sklearn.

        Entities are aligned per message text: an entity that exists on one
        side only is paired with the string "None" on the other side.

        Returns:
            A `(targets, predictions)` tuple; each list is the concatenation
            of the action labels, the intent labels and the aligned entity
            labels.
        """
        texts = sorted(
            {
                str(e.get(ENTITY_ATTRIBUTE_TEXT, ""))
                for e in self.entity_targets + self.entity_predictions
            }
        )

        # Group once up front instead of re-filtering the complete entity
        # lists for every distinct text.
        targets_by_text = self._group_entities_by_text(self.entity_targets)
        predictions_by_text = self._group_entities_by_text(self.entity_predictions)

        aligned_entity_targets: List[Optional[Text]] = []
        aligned_entity_predictions: List[Optional[Text]] = []

        for text in texts:
            # sort the entities of this sentence to compare them directly
            entity_targets = sorted(
                targets_by_text.get(text, []),
                key=lambda x: x[ENTITY_ATTRIBUTE_START],  # type: ignore[literal-required]  # noqa: E501
            )
            entity_predictions = sorted(
                predictions_by_text.get(text, []),
                key=lambda x: x[ENTITY_ATTRIBUTE_START],  # type: ignore[literal-required]  # noqa: E501
            )

            i_pred, i_target = 0, 0

            while i_pred < len(entity_predictions) or i_target < len(entity_targets):
                cmp = self._compare_entities(
                    entity_predictions, entity_targets, i_pred, i_target
                )
                if cmp == -1:  # predicted comes first
                    aligned_entity_predictions.append(
                        self._generate_entity_training_data(entity_predictions[i_pred])
                    )
                    aligned_entity_targets.append("None")
                    i_pred += 1
                elif cmp == 1:  # target entity comes first
                    aligned_entity_targets.append(
                        self._generate_entity_training_data(entity_targets[i_target])
                    )
                    aligned_entity_predictions.append("None")
                    i_target += 1
                else:  # target and predicted entity are aligned
                    aligned_entity_predictions.append(
                        self._generate_entity_training_data(entity_predictions[i_pred])
                    )
                    aligned_entity_targets.append(
                        self._generate_entity_training_data(entity_targets[i_target])
                    )
                    i_pred += 1
                    i_target += 1

        targets = self.action_targets + self.intent_targets + aligned_entity_targets

        predictions = (
            self.action_predictions
            + self.intent_predictions
            + aligned_entity_predictions
        )
        return targets, predictions
349
+
350
+
351
class EndToEndUserUtterance(UserUttered):
    """User utterance that is always rendered in its end-to-end form.

    Mostly used to print the full end-to-end user message in the
    `failed_test_stories.yml` output file.
    """

    def as_story_string(self, e2e: bool = True) -> Text:
        """Returns the story representation produced by the parent class.

        The `e2e` argument is accepted for signature compatibility only; the
        parent implementation is always called with `e2e=True`.
        """
        return super().as_story_string(e2e=True)
361
+
362
+
363
class WronglyClassifiedUserUtterance(UserUttered):
    """The NLU model predicted the wrong user utterance.

    Mostly used to mark wrong predictions and be able to
    dump them as stories. The event itself carries the *expected* intent and
    entities; the predicted ones are kept in `predicted_intent` and
    `predicted_entities`.
    """

    type_name = "wrong_utterance"

    def __init__(self, event: UserUttered, eval_store: EvaluationStore) -> None:
        """Set `predicted_intent` and `predicted_entities` attributes.

        Args:
            event: The original user event; its text, parse data, timestamp
                and input channel are passed on to the parent constructor.
            eval_store: Store holding the expected and predicted intent and
                entities for `event`. The first entry of `intent_targets` is
                used as the expected intent name.
        """
        try:
            self.predicted_intent = eval_store.intent_predictions[0]
        except LookupError:
            # No intent prediction was recorded for this utterance.
            self.predicted_intent = None

        self.target_entities = eval_store.entity_targets
        self.predicted_entities = eval_store.entity_predictions

        intent = {"name": eval_store.intent_targets[0]}

        super().__init__(
            event.text,
            intent,
            eval_store.entity_targets,
            event.parse_data,
            event.timestamp,
            event.input_channel,
        )

    def inline_comment(self, force_comment_generation: bool = False) -> Optional[Text]:
        """A comment attached to this event. Used during dumping.

        Args:
            force_comment_generation: Produce the comment even when the
                predicted intent equals the expected one.

        Returns:
            The comment, or `None` when the intents match and the comment is
            not forced.
        """
        from rasa.shared.core.events import format_message

        if force_comment_generation or self.predicted_intent != self.intent["name"]:
            predicted_message = format_message(
                self.text, self.predicted_intent, self.predicted_entities
            )

            return f"predicted: {self.predicted_intent}: {predicted_message}"
        else:
            return None

    @staticmethod
    def inline_comment_for_entity(
        predicted: Dict[Text, Any], entity: Dict[Text, Any]
    ) -> Optional[Text]:
        """Returns the predicted entity which is then printed as a comment.

        Returns `None` when the predicted entity type equals the expected one.
        """
        if predicted["entity"] != entity["entity"]:
            # Formatted with an f-string rather than `+` so that entity
            # values which are not strings (e.g. numbers) do not raise.
            return f"predicted: {predicted['entity']}: {predicted['value']}"
        else:
            return None

    def as_story_string(self, e2e: bool = True) -> Text:
        """Returns text representation of event."""
        from rasa.shared.core.events import format_message

        correct_message = format_message(
            self.text, self.intent.get("name"), self.entities
        )
        return (
            f"{self.intent.get('name')}: {correct_message} "
            f"<!-- {self.inline_comment()} -->"
        )
427
+
428
+
429
def _create_data_generator(
    resource_name: Text,
    agent: "Agent",
    max_stories: Optional[int] = None,
    use_conversation_test_files: bool = False,
) -> "TrainingDataGenerator":
    """Builds a `TrainingDataGenerator` for the stories found in `resource_name`.

    Args:
        resource_name: Path handed to the importer as its only training data
            path.
        agent: Agent whose domain is used; an empty domain is substituted when
            the agent has none.
        max_stories: Passed to the generator as `tracker_limit`.
        use_conversation_test_files: Load the importer's conversation tests
            instead of its stories.

    Returns:
        A generator created without story concatenation and without
        augmentation.
    """
    from rasa.shared.core.generator import TrainingDataGenerator

    domain = agent.domain if agent.domain is not None else Domain.empty()

    # The importer receives the domain as a file path, so the domain is
    # persisted into a scratch directory which is removed again on exit.
    # NOTE(review): assumes the story graph returned by the importer is fully
    # loaded and no longer needs the domain file afterwards - confirm.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_domain_path = Path(tmp_dir) / "domain.yaml"
        domain.persist(tmp_domain_path)
        test_data_importer = TrainingDataImporter.load_from_dict(
            training_data_paths=[resource_name], domain_path=str(tmp_domain_path)
        )
        if use_conversation_test_files:
            story_graph = test_data_importer.get_conversation_tests()
        else:
            story_graph = test_data_importer.get_stories()

    # Pass the same (never `None`) domain that was persisted above.
    return TrainingDataGenerator(
        story_graph,
        domain,
        use_story_concatenation=False,
        augmentation_factor=0,
        tracker_limit=max_stories,
    )
455
+
456
+
457
def _clean_entity_results(
    text: Text, entity_results: List[Dict[Text, Any]]
) -> List["EntityPrediction"]:
    """Extract only the token variables from an entity dict.

    Args:
        text: Message text stored in every cleaned entity.
        entity_results: Entity dicts to clean; `None` is treated like an
            empty list. The dicts are not modified.

    Returns:
        One dict per input entity holding the text plus whichever of the
        start, end, type and value attributes the input entity had.
    """
    cleaned_entities = []

    for result in entity_results or ():
        cleaned_entity: EntityPrediction = {ENTITY_ATTRIBUTE_TEXT: text}  # type: ignore[misc]  # noqa E501
        for key in (
            ENTITY_ATTRIBUTE_START,
            ENTITY_ATTRIBUTE_END,
            ENTITY_ATTRIBUTE_TYPE,
            ENTITY_ATTRIBUTE_VALUE,
        ):
            if key not in result:
                continue
            value = result[key]
            if key == ENTITY_ATTRIBUTE_VALUE and EXTRACTOR in result:
                # convert values to strings for evaluation as
                # target values are all of type string; only the cleaned copy
                # is converted, the caller's dict keeps its original value
                value = str(value)
            cleaned_entity[key] = value  # type: ignore[literal-required]
        cleaned_entities.append(cleaned_entity)

    return cleaned_entities
480
+
481
+
482
def _get_full_retrieval_intent(parsed: Dict[Text, Any]) -> Text:
    """Return full retrieval intent, if it's present, or normal intent otherwise.

    Args:
        parsed: Predicted parsed data.

    Returns:
        The extracted intent.
    """
    intent_data = parsed.get(INTENT, {})
    base_intent = intent_data.get(INTENT_NAME_KEY)
    response_selector = parsed.get(RESPONSE_SELECTOR, {})

    retrieval_intents = response_selector.get(RESPONSE_SELECTOR_RETRIEVAL_INTENTS, {})
    if base_intent not in retrieval_intents:
        # not a retrieval intent: the normal intent is the answer
        return base_intent

    # Without an explicit response selector parameter in the config the
    # selector output lives under the "default" key, otherwise under the
    # base intent itself.
    if RESPONSE_SELECTOR_DEFAULT_INTENT in response_selector:
        selector_key = RESPONSE_SELECTOR_DEFAULT_INTENT
    else:
        selector_key = base_intent

    selector_output = response_selector.get(selector_key, {})
    full_retrieval_intent = selector_output.get(RESPONSE, {}).get(INTENT_RESPONSE_KEY)
    return full_retrieval_intent or base_intent
518
+
519
+
520
def _collect_user_uttered_predictions(
    event: UserUttered,
    predicted: Dict[Text, Any],
    partial_tracker: DialogueStateTracker,
    fail_on_prediction_errors: bool,
) -> EvaluationStore:
    """Compares the NLU prediction for a user event with the expected data.

    The tracker is updated with a `WronglyClassifiedUserUtterance` when the
    prediction differs from the expectation and with an
    `EndToEndUserUtterance` otherwise.

    Args:
        event: Expected user event taken from the test story.
        predicted: Parse data predicted for the event.
        partial_tracker: Tracker that receives the resulting event.
        fail_on_prediction_errors: Raise instead of only recording a mismatch.

    Returns:
        Store with the expected and predicted intent and entities.

    Raises:
        WrongPredictionException: On a mismatch when
            `fail_on_prediction_errors` is set.
    """
    store = EvaluationStore()

    # The test story may name either the base intent or the full retrieval
    # intent; the full retrieval intent wins when it is present.
    expected_base_intent = event.intent.get(INTENT_NAME_KEY)
    expected_full_intent = event.intent.get(FULL_RETRIEVAL_INTENT_NAME_KEY)
    intent_gold = expected_full_intent or expected_base_intent

    # At this point the prediction is the base intent only. The full retrieval
    # intent is looked up unless the story gave just the base intent and it
    # was predicted correctly.
    intent_predicted = predicted.get(INTENT, {}).get(INTENT_NAME_KEY)
    if intent_gold != intent_predicted:
        intent_predicted = _get_full_retrieval_intent(predicted)

    store.add_to_store(
        intent_targets=[intent_gold], intent_predictions=[intent_predicted]
    )

    entity_gold = event.entities
    predicted_entities = predicted.get(ENTITIES)
    if entity_gold or predicted_entities:
        store.add_to_store(
            entity_targets=_clean_entity_results(event.text, entity_gold),
            entity_predictions=_clean_entity_results(event.text, predicted_entities),
        )

    if not store.check_prediction_target_mismatch():
        if RESPONSE_SELECTOR_PROPERTY_NAME in predicted:
            response_selector_info = {
                RESPONSE_SELECTOR_PROPERTY_NAME: predicted[
                    RESPONSE_SELECTOR_PROPERTY_NAME
                ]
            }
        else:
            response_selector_info = None
        partial_tracker.update(
            EndToEndUserUtterance(
                text=event.text,
                intent=event.intent,
                entities=event.entities,
                parse_data=response_selector_info,
            )
        )
        return store

    partial_tracker.update(WronglyClassifiedUserUtterance(event, store))
    if fail_on_prediction_errors:
        story_dump = YAMLStoryWriter().dumps(partial_tracker.as_story().story_steps)
        raise WrongPredictionException(
            f"NLU model predicted a wrong intent or entities. Failed Story:"
            f" \n\n{story_dump}"
        )

    return store
583
+
584
+
585
def emulate_loop_rejection(partial_tracker: DialogueStateTracker) -> None:
    """Add `ActionExecutionRejected` event to the tracker.

    No action server runs during evaluation, so the rejection of the active
    loop has to be emulated to test unhappy loop paths correctly. The
    rejected action is the tracker's currently active loop.

    Args:
        partial_tracker: The `DialogueStateTracker` to update in place.
    """
    from rasa.shared.core.events import ActionExecutionRejected

    rejection = ActionExecutionRejected(partial_tracker.active_loop_name)
    partial_tracker.update(rejection)
598
+
599
+
600
async def _get_e2e_entity_evaluation_result(
    processor: "MessageProcessor",
    tracker: DialogueStateTracker,
    prediction: PolicyPrediction,
) -> Optional[EntityEvaluationResult]:
    """Builds the entity evaluation result for entities predicted by policies.

    Args:
        processor: Used to parse the text of the last user message.
        tracker: Tracker whose most recent event is inspected.
        prediction: Policy prediction whose `EntitiesAdded` events provide the
            predicted entities.

    Returns:
        The evaluation result, or `None` when the relevant last event is not a
        user utterance, when there are neither expected nor predicted
        entities, when the utterance has no text, or when parsing yields
        nothing.
    """
    last_event: Optional["Event"] = tracker.events[-1]

    # UserUttered events with entities can be followed by SlotSet events
    # if slots are defined in the domain, so look past those.
    if isinstance(last_event, SlotSet):
        last_event = tracker.get_last_event_for((UserUttered, ActionExecuted))

    if not isinstance(last_event, UserUttered):
        return None

    entities_predicted_by_policies = []
    for prediction_event in prediction.events:
        if isinstance(prediction_event, EntitiesAdded):
            entities_predicted_by_policies.extend(prediction_event.entities)

    entity_targets = last_event.entities
    if not (entity_targets or entities_predicted_by_policies):
        return None

    text = last_event.text
    if not text:
        return None

    parsed_message = await processor.parse_message(UserMessage(text=text))
    if not parsed_message:
        return None

    token_spans = parsed_message.get(TOKENS_NAMES[TEXT], [])
    tokens = [Token(text[start:end], start, end) for start, end in token_spans]
    return EntityEvaluationResult(
        entity_targets, entities_predicted_by_policies, tokens, text
    )
633
+
634
+
635
def _get_predicted_action_name(
    predicted_action: rasa.core.actions.action.Action,
    partial_tracker: DialogueStateTracker,
    expected_action_name: Text,
) -> Text:
    """Get the name of predicted action.

    For an `ActionRetrieveResponse` the full action name including its
    retrieval intent (e.g. utter_faq/is-this-legit) is returned. The one
    exception is a test story that expects a retrieval action without naming
    the retrieval intent: returning the plain name then avoids the needless
    mismatch utter_faq (expected) != utter_faq/is-this-legit (predicted).
    In that case, and for every other kind of action, only the action name
    (e.g. utter_faq) is returned.
    """
    wants_full_name = (
        isinstance(predicted_action, ActionRetrieveResponse)
        and expected_action_name != predicted_action.name()
    )
    if not wants_full_name:
        return predicted_action.name()

    full_retrieval_name = predicted_action.get_full_retrieval_name(partial_tracker)
    # Fall back to the plain name when no full retrieval name is available.
    return full_retrieval_name or predicted_action.name()
662
+
663
+
664
async def _run_action_prediction(
    processor: "MessageProcessor",
    partial_tracker: DialogueStateTracker,
    expected_action: Text,
) -> Tuple[Text, PolicyPrediction, Optional[EntityEvaluationResult]]:
    """Predicts the next action, retrying once after an emulated loop rejection.

    Args:
        processor: Processor used to predict the next action.
        partial_tracker: Tracker holding the conversation so far; it receives
            an `ActionExecutionRejected` event when a rejection is emulated.
        expected_action: The action the test story expects next.

    Returns:
        The predicted action name, the policy prediction it came from and the
        entity evaluation result computed for the first prediction.
    """

    async def _predict() -> Tuple[Text, PolicyPrediction]:
        # One prediction round: ask the processor and resolve the action name.
        next_action, policy_prediction = (
            await processor.predict_next_with_tracker_if_should(partial_tracker)
        )
        action_name = _get_predicted_action_name(
            next_action, partial_tracker, expected_action
        )
        return action_name, policy_prediction

    predicted_action, prediction = await _predict()

    policy_entity_result = await _get_e2e_entity_evaluation_result(
        processor, partial_tracker, prediction
    )

    wrong_policy_prediction = (
        prediction.policy_name and predicted_action != expected_action
    )
    if wrong_policy_prediction and _form_might_have_been_rejected(
        processor.domain, partial_tracker, predicted_action
    ):
        # Wrong action was predicted,
        # but it might be Ok if form action is rejected.
        emulate_loop_rejection(partial_tracker)
        # Try again. Even if this prediction is wrong as well, the emulated
        # rejection is kept: the user explicitly specified that something
        # else than the form was supposed to run.
        predicted_action, prediction = await _predict()

    return predicted_action, prediction, policy_entity_result
701
+
702
+
703
async def _collect_action_executed_predictions(
    processor: "MessageProcessor",
    partial_tracker: DialogueStateTracker,
    event: ActionExecuted,
    fail_on_prediction_errors: bool,
) -> Tuple[EvaluationStore, PolicyPrediction, Optional[EntityEvaluationResult]]:
    """Predict the action for one `ActionExecuted` event and record the outcome.

    The prediction is compared with the action expected by `event`. Depending on
    the result, a `WronglyPredictedAction`, `WarningPredictedAction` or
    `ActionExecuted` event is appended to `partial_tracker`.

    Args:
        processor: Processor used to run the prediction.
        partial_tracker: Tracker holding the conversation replayed so far;
            updated in place.
        event: The expected action event taken from the test story.
        fail_on_prediction_errors: If `True`, a wrong prediction (other than
            `ACTION_UNLIKELY_INTENT_NAME`) raises instead of being recorded only.

    Returns:
        The evaluation store containing this single prediction/target pair, the
        last policy prediction and the entity evaluation result (if any).

    Raises:
        WrongPredictionException: If `fail_on_prediction_errors` is set and the
            predicted action differs from the expected one.
    """

    async def _predict(
        fallback_entity_result: Optional[EntityEvaluationResult],
    ) -> Tuple[Text, PolicyPrediction, Optional[EntityEvaluationResult]]:
        # Run one prediction; if the action limit is hit, report a placeholder
        # prediction and keep whatever entity result was known before.
        try:
            return await _run_action_prediction(
                processor, partial_tracker, target_action
            )
        except ActionLimitReached:
            return (
                "circuit breaker tripped",
                PolicyPrediction([], policy_name=None),
                fallback_entity_result,
            )

    eval_store = EvaluationStore()

    target_name = event.action_name
    target_text = event.action_text
    target_action = target_name or target_text

    predicted_action, prediction, policy_entity_result = await _predict(None)

    unlikely_intent_seen = False
    if (
        predicted_action == ACTION_UNLIKELY_INTENT_NAME
        and predicted_action != target_action
    ):
        # `action_unlikely_intent` was predicted although the story did not
        # expect it: record it on the tracker and predict once more.
        partial_tracker.update(
            WronglyPredictedAction(
                predicted_action,
                target_text,
                predicted_action,
                prediction.policy_name,
                prediction.max_confidence,
                event.timestamp,
                metadata=prediction.action_metadata,
            )
        )
        unlikely_intent_seen = True
        predicted_action, prediction, policy_entity_result = await _predict(
            policy_entity_result
        )

    eval_store.add_to_store(
        action_predictions=[predicted_action], action_targets=[target_action]
    )

    if not eval_store.check_prediction_target_mismatch():
        # Prediction matches the target.
        if unlikely_intent_seen:
            recorded_event = WarningPredictedAction(
                ACTION_UNLIKELY_INTENT_NAME,
                predicted_action,
                prediction.policy_name,
                prediction.max_confidence,
                event.timestamp,
                prediction.action_metadata,
            )
        else:
            recorded_event = ActionExecuted(
                predicted_action,
                prediction.policy_name,
                prediction.max_confidence,
                event.timestamp,
                metadata=prediction.action_metadata,
            )
        partial_tracker.update(recorded_event)
        return eval_store, prediction, policy_entity_result

    # Prediction differs from the target.
    partial_tracker.update(
        WronglyPredictedAction(
            target_name,
            target_text,
            predicted_action,
            prediction.policy_name,
            prediction.max_confidence,
            event.timestamp,
            metadata=prediction.action_metadata,
            predicted_action_unlikely_intent=unlikely_intent_seen,
        )
    )

    if fail_on_prediction_errors and predicted_action not in (
        ACTION_UNLIKELY_INTENT_NAME,
        target_action,
    ):
        story_dump = YAMLStoryWriter().dumps(partial_tracker.as_story().story_steps)
        raise WrongPredictionException(
            f"Model predicted a wrong action. Failed Story: \n\n{story_dump}"
        )

    return eval_store, prediction, policy_entity_result
806
+
807
+
808
def _form_might_have_been_rejected(
    domain: Domain, tracker: DialogueStateTracker, predicted_action_name: Text
) -> bool:
    """Tell whether the predicted action is the tracker's currently active form.

    Args:
        domain: Domain whose `form_names` are consulted.
        tracker: Tracker whose `active_loop_name` is compared.
        predicted_action_name: Name of the action that was predicted.

    Returns:
        `True` if the predicted action equals the active loop name and is one
        of the domain's forms, `False` otherwise.
    """
    if tracker.active_loop_name != predicted_action_name:
        return False
    return predicted_action_name in domain.form_names
815
+
816
+
817
async def _predict_tracker_actions(
    tracker: DialogueStateTracker,
    agent: "Agent",
    fail_on_prediction_errors: bool = False,
    use_e2e: bool = False,
) -> Tuple[
    EvaluationStore,
    DialogueStateTracker,
    List[Dict[Text, Any]],
    List[EntityEvaluationResult],
]:
    """Replay a story tracker event by event and collect the model's predictions.

    Args:
        tracker: Completed tracker of a test story.
        agent: Agent whose processor runs the predictions.
        fail_on_prediction_errors: Forwarded to the per-event collectors.
        use_e2e: If `True`, `UserUttered` events are evaluated as well.

    Returns:
        The evaluation store for the whole tracker, the tracker rebuilt from the
        predictions, one summary dictionary per evaluated action and the entity
        evaluation results gathered along the way.

    Raises:
        RasaException: If the agent has no processor.
    """
    processor = agent.processor
    if processor is None:
        raise RasaException(
            "The agent's processor has not been instantiated. "
            "The processor needs to be defined before running "
            "prediction."
        )

    tracker_eval_store = EvaluationStore()
    events = list(tracker.events)

    if agent.domain is not None:
        slots = agent.domain.slots
    else:
        slots = []

    # Start from the first event only; the rest is replayed below.
    partial_tracker = DialogueStateTracker.from_events(
        tracker.sender_id,
        events[:1],
        slots,
        sender_source=tracker.sender_source,
    )
    tracker_actions = []
    policy_entity_results = []

    for event in events[1:]:
        if isinstance(event, ActionExecuted):
            (
                action_result,
                prediction,
                entity_result,
            ) = await _collect_action_executed_predictions(
                processor, partial_tracker, event, fail_on_prediction_errors
            )
            if entity_result:
                policy_entity_results.append(entity_result)

            if action_result.action_targets:
                tracker_eval_store.merge_store(action_result)
                tracker_actions.append(
                    {
                        "action": action_result.action_targets[0],
                        "predicted": action_result.action_predictions[0],
                        "policy": prediction.policy_name,
                        "confidence": prediction.max_confidence,
                    }
                )
            continue

        if not (use_e2e and isinstance(event, UserUttered)):
            # Every other event is applied to the tracker unchanged.
            partial_tracker.update(event)
            continue

        if event.text:
            # The story contains an actual user message: run it through NLU.
            predicted = await processor.parse_message(UserMessage(event.text))
        else:
            # The user utterance only carries an intent, so the NLU part is
            # skipped and the parse data is taken directly.
            # FIXME: better type annotation for `parse_data` would require
            # a larger refactoring (e.g. switch to dataclass)
            predicted = cast(Dict[Text, Any], event.parse_data)

        user_uttered_result = _collect_user_uttered_predictions(
            event, predicted, partial_tracker, fail_on_prediction_errors
        )
        tracker_eval_store.merge_store(user_uttered_result)

    return tracker_eval_store, partial_tracker, tracker_actions, policy_entity_results
897
+
898
+
899
def _in_training_data_fraction(action_list: List[Dict[Text, Any]]) -> float:
    """Return the share of actions whose policy counts as "in training data".

    An entry counts when its `"policy"` value is set and
    `rasa.core.policies.ensemble.is_not_in_training_data` returns a falsy value
    for it.

    Args:
        action_list: Dictionaries with at least the keys `"action"` and
            `"policy"`.

    Returns:
        The number of counted entries divided by the total number of entries,
        or `0` for an empty list.
    """
    import rasa.core.policies.ensemble

    counted_actions = []
    for entry in action_list:
        if not entry["policy"]:
            continue
        if rasa.core.policies.ensemble.is_not_in_training_data(entry["policy"]):
            continue
        counted_actions.append(entry["action"])

    if not action_list:
        return 0
    return len(counted_actions) / len(action_list)
911
+
912
+
913
def _sort_trackers_with_severity_of_warning(
    trackers_to_sort: List[DialogueStateTracker],
) -> List[DialogueStateTracker]:
    """Order trackers by the highest `action_unlikely_intent` severity they hold.

    The severity is read from the metadata of `WronglyPredictedAction` events
    whose predicted action is `ACTION_UNLIKELY_INTENT_NAME`.

    Args:
        trackers_to_sort: Trackers to be sorted.

    Returns:
        The trackers in descending order of severity; trackers with equal
        severity keep their relative input order.
    """

    def _highest_severity(tracker: DialogueStateTracker) -> Any:
        # Largest severity among the tracker's unlikely-intent events, 0 if none.
        highest = 0
        for event in tracker.applied_events():
            if not isinstance(event, WronglyPredictedAction):
                continue
            if event.action_name_prediction != ACTION_UNLIKELY_INTENT_NAME:
                continue
            severity = event.metadata.get(QUERY_INTENT_KEY, {}).get(SEVERITY_KEY, 0)
            highest = max(highest, severity)
        return highest

    scored_trackers = [
        (_highest_severity(tracker), tracker) for tracker in trackers_to_sort
    ]
    # Sort on the negated score only, so the (stable) sort never has to compare
    # trackers with each other.
    scored_trackers.sort(key=lambda scored: -scored[0])

    return [tracker for _, tracker in scored_trackers]
949
+
950
+
951
async def _collect_story_predictions(
    completed_trackers: List["DialogueStateTracker"],
    agent: "Agent",
    fail_on_prediction_errors: bool = False,
    use_e2e: bool = False,
) -> Tuple[StoryEvaluation, int, List[EntityEvaluationResult]]:
    """Run every story tracker through the model and aggregate the results.

    Args:
        completed_trackers: Trackers generated from the test stories.
        agent: Agent holding the model under test.
        fail_on_prediction_errors: Forwarded to the per-tracker prediction.
        use_e2e: Whether user messages are evaluated as well.

    Returns:
        The aggregated story evaluation, the number of stories and the entity
        evaluation results of all stories.
    """
    from sklearn.metrics import accuracy_score
    from tqdm import tqdm

    number_of_stories = len(completed_trackers)
    logger.info(f"Evaluating {number_of_stories} stories\nProgress:")

    story_eval_store = EvaluationStore()
    failed_stories = []
    successful_stories = []
    stories_with_warnings = []
    dialogue_outcomes = []  # 1 for a fully correct story, 0 otherwise
    action_list = []
    entity_results = []

    if agent.domain:
        # Set the routing slot to False in case the coexistence feature is used;
        # this way the DM1 policies will run and the CALM policies keep silent.
        for slot in agent.domain.slots:
            if slot.name == ROUTE_TO_CALM_SLOT:
                slot.initial_value = False

    for tracker in tqdm(completed_trackers):
        (
            tracker_results,
            predicted_tracker,
            tracker_actions,
            tracker_entity_results,
        ) = await _predict_tracker_actions(
            tracker, agent, fail_on_prediction_errors, use_e2e
        )

        entity_results.extend(tracker_entity_results)
        story_eval_store.merge_store(tracker_results)
        action_list.extend(tracker_actions)

        if tracker_results.check_prediction_target_mismatch():
            # There is at least one wrong prediction.
            failed_stories.append(predicted_tracker)
            dialogue_outcomes.append(0)
            continue

        successful_stories.append(predicted_tracker)
        dialogue_outcomes.append(1)

        has_unlikely_intent_warning = any(
            isinstance(event, WronglyPredictedAction)
            and event.action_name_prediction == ACTION_UNLIKELY_INTENT_NAME
            for event in predicted_tracker.events
        )
        if has_unlikely_intent_warning:
            stories_with_warnings.append(predicted_tracker)

    logger.info("Finished collecting predictions.")

    in_training_data_fraction = _in_training_data_fraction(action_list)

    if dialogue_outcomes:
        accuracy = accuracy_score([1] * len(dialogue_outcomes), dialogue_outcomes)
    else:
        accuracy = 0

    _log_evaluation_table([1] * len(completed_trackers), "CONVERSATION", accuracy)

    story_evaluation = StoryEvaluation(
        evaluation_store=story_eval_store,
        failed_stories=failed_stories,
        successful_stories=successful_stories,
        stories_with_warnings=_sort_trackers_with_severity_of_warning(
            stories_with_warnings
        ),
        action_list=action_list,
        in_training_data_fraction=in_training_data_fraction,
    )
    return story_evaluation, number_of_stories, entity_results
1036
+
1037
+
1038
def _filter_step_events(step: StoryStep) -> StoryStep:
    """Return a copy of the step without redundant unlikely-intent events.

    A `WronglyPredictedAction` is dropped when both its expected and its
    predicted action name equal `ACTION_UNLIKELY_INTENT_NAME`.

    Args:
        step: Story step whose events are filtered.

    Returns:
        A copy of `step` (same id) holding only the remaining events.
    """

    def _is_redundant(event: Any) -> bool:
        return (
            isinstance(event, WronglyPredictedAction)
            and event.action_name
            == event.action_name_prediction
            == ACTION_UNLIKELY_INTENT_NAME
        )

    kept_events = [event for event in step.events if not _is_redundant(event)]

    filtered_step = step.create_copy(use_new_id=False)
    filtered_step.events = kept_events
    return filtered_step
1052
+
1053
+
1054
def _log_stories(
    trackers: List[DialogueStateTracker], file_path: Text, message_if_no_trackers: Text
) -> None:
    """Write the stories of the given trackers to a file in YAML format.

    Args:
        trackers: Trackers to convert into stories.
        file_path: Path of the file to (over)write.
        message_if_no_trackers: Text written as a `#` comment when `trackers`
            is empty.
    """
    with open(file_path, "w", encoding=DEFAULT_ENCODING) as output_file:
        if not trackers:
            output_file.write(f"# {message_if_no_trackers}")
            return

        # Convert all trackers first, then filter the steps of every story.
        stories = []
        for tracker in trackers:
            stories.append(tracker.as_story(include_source=True))

        filtered_steps = []
        for story in stories:
            for step in story.story_steps:
                filtered_steps.append(_filter_step_events(step))

        output_file.write(YAMLStoryWriter().dumps(filtered_steps))
1069
+
1070
+
1071
async def test(
    stories: Text,
    agent: "Agent",
    max_stories: Optional[int] = None,
    out_directory: Optional[Text] = None,
    fail_on_prediction_errors: bool = False,
    e2e: bool = False,
    disable_plotting: bool = False,
    successes: bool = False,
    errors: bool = True,
    warnings: bool = True,
) -> Dict[Text, Any]:
    """Evaluate the agent on the given stories and optionally write the results.

    Args:
        stories: The stories to evaluate on.
        agent: The agent to evaluate.
        max_stories: Maximum number of stories to consider.
        out_directory: Directory to write the report, plots and story files to.
            Nothing is written if it is not set.
        fail_on_prediction_errors: Whether to fail on prediction errors.
        e2e: Whether to run an end-to-end evaluation.
        disable_plotting: Whether to skip plotting.
        successes: Whether to write down successful predictions.
        errors: Whether to write down incorrect predictions.
        warnings: Whether to write down prediction warnings.

    Returns:
        Evaluation summary.
    """
    from rasa.model_testing import get_evaluation_metrics

    generator = _create_data_generator(stories, agent, max_stories, e2e)
    completed_trackers = generator.generate_story_trackers()

    story_evaluation, _, entity_results = await _collect_story_predictions(
        completed_trackers, agent, fail_on_prediction_errors, use_e2e=e2e
    )
    evaluation_store = story_evaluation.evaluation_store

    with pywarnings.catch_warnings():
        from sklearn.exceptions import UndefinedMetricWarning

        pywarnings.simplefilter("ignore", UndefinedMetricWarning)

        targets, predictions = evaluation_store.serialise()
        report, precision, f1, action_accuracy = get_evaluation_metrics(
            targets, predictions, output_dict=True
        )

        if out_directory:
            # Add conversation level accuracy to story report.
            correct_count = len(story_evaluation.successful_stories)
            conversation_count = len(story_evaluation.failed_stories) + correct_count
            warning_count = len(story_evaluation.stories_with_warnings)
            if conversation_count and isinstance(report, Dict):
                report["conversation_accuracy"] = {
                    "accuracy": correct_count / conversation_count,
                    "correct": correct_count,
                    "with_warnings": warning_count,
                    "total": conversation_count,
                }
            report_filename = os.path.join(out_directory, REPORT_STORIES_FILE)
            rasa.shared.utils.io.dump_obj_as_json_to_file(report_filename, report)
            logger.info(f"Stories report saved to {report_filename}.")

        evaluate_entities(
            entity_results,
            POLICIES_THAT_EXTRACT_ENTITIES,
            out_directory,
            successes,
            errors,
            disable_plotting,
        )

    telemetry.track_core_model_test(len(generator.story_graph.story_steps), e2e, agent)

    _log_evaluation_table(
        evaluation_store.action_targets,
        "ACTION",
        action_accuracy,
        precision=precision,
        f1=f1,
        in_training_data_fraction=story_evaluation.in_training_data_fraction,
    )

    if out_directory:
        if not disable_plotting:
            _plot_story_evaluation(
                evaluation_store.action_targets,
                evaluation_store.action_predictions,
                out_directory,
            )

        # (enabled flag, trackers, file name, text written if there are none)
        story_outputs = (
            (
                errors,
                story_evaluation.failed_stories,
                FAILED_STORIES_FILE,
                "None of the test stories failed - all good!",
            ),
            (
                successes,
                story_evaluation.successful_stories,
                SUCCESSFUL_STORIES_FILE,
                "None of the test stories succeeded :(",
            ),
            (
                warnings,
                story_evaluation.stories_with_warnings,
                STORIES_WITH_WARNINGS_FILE,
                "No warnings for test stories",
            ),
        )
        for enabled, story_trackers, file_name, empty_message in story_outputs:
            if enabled:
                _log_stories(
                    story_trackers,
                    os.path.join(out_directory, file_name),
                    empty_message,
                )

    return {
        "report": report,
        "precision": precision,
        "f1": f1,
        "accuracy": action_accuracy,
        "actions": story_evaluation.action_list,
        "in_training_data_fraction": story_evaluation.in_training_data_fraction,
        "is_end_to_end_evaluation": e2e,
    }
1196
+
1197
+
1198
def _log_evaluation_table(
    golds: List[Any],
    name: Text,
    accuracy: float,
    report: Optional[Dict[Text, Any]] = None,
    precision: Optional[float] = None,
    f1: Optional[float] = None,
    in_training_data_fraction: Optional[float] = None,
    include_report: bool = True,
) -> None:  # pragma: no cover
    """Log the sklearn evaluation metrics.

    Args:
        golds: Gold labels; only their number is used.
        name: Name of the evaluation level, used in the headline.
        accuracy: Accuracy as a fraction.
        report: Optional classification report to log.
        precision: Optional precision; logged only if given.
        f1: Optional F1 score; logged only if given.
        in_training_data_fraction: Optional fraction; logged only if given.
        include_report: Whether `report` should be logged when it is given.
    """
    total = len(golds)
    # `accuracy` is a float ratio, so `total * accuracy` can end up a hair below
    # the true count (e.g. 49 * (1 / 49) == 0.9999999999999999). Truncating
    # would then under-report by one, hence round to the nearest integer.
    correct = int(round(total * accuracy))

    logger.info(f"Evaluation Results on {name} level:")
    logger.info(f"\tCorrect: {correct} / {total}")
    if f1 is not None:
        logger.info(f"\tF1-Score: {f1:.3f}")
    if precision is not None:
        logger.info(f"\tPrecision: {precision:.3f}")
    logger.info(f"\tAccuracy: {accuracy:.3f}")
    if in_training_data_fraction is not None:
        logger.info(f"\tIn-data fraction: {in_training_data_fraction:.3g}")

    if include_report and report is not None:
        logger.info(f"\tClassification report: \n{report}")
1221
+
1222
+
1223
def _plot_story_evaluation(
    targets: PredictionList,
    predictions: PredictionList,
    output_directory: Optional[Text],
) -> None:
    """Plot the confusion matrix of the story evaluation.

    Args:
        targets: Expected labels.
        predictions: Predicted labels.
        output_directory: Directory for the plot file; if not set, the bare
            file name `CONFUSION_MATRIX_STORIES_FILE` is used as output path.
    """
    from sklearn.metrics import confusion_matrix
    from sklearn.utils.multiclass import unique_labels
    from rasa.utils.plotting import plot_confusion_matrix

    if output_directory:
        output_file = os.path.join(output_directory, CONFUSION_MATRIX_STORIES_FILE)
    else:
        output_file = CONFUSION_MATRIX_STORIES_FILE

    plot_confusion_matrix(
        confusion_matrix(targets, predictions),
        classes=unique_labels(targets, predictions),
        title="Action Confusion matrix",
        output_file=output_file,
    )
1247
+
1248
+
1249
async def compare_models_in_dir(
    model_dir: Text,
    stories_file: Text,
    output: Text,
    use_conversation_test_files: bool = False,
) -> None:
    """Evaluate all trained models found in the run directories of `model_dir`.

    The number of correct stories per config and run is dumped as JSON to
    `RESULTS_FILE` inside `output`.

    Args:
        model_dir: Path to the directory that contains one sub-directory per run.
        stories_file: Path to the story file.
        output: Output directory to store the results in.
        use_conversation_test_files: `True` if conversation test files should be
            used for testing instead of regular Core story files.
    """
    correct_by_config = defaultdict(list)

    for run_dir in rasa.shared.utils.io.list_subdirectories(model_dir):
        correct_in_run = defaultdict(list)

        for model_path in sorted(rasa.shared.utils.io.list_files(run_dir)):
            if model_path.endswith("tar.gz"):
                # Model files are named <config-name>PERCENTAGE_KEY<number>.tar.gz;
                # everything before the percentage key is the config name.
                config_name = os.path.basename(model_path).split(PERCENTAGE_KEY)[0]
                correct_stories = await _evaluate_core_model(
                    model_path,
                    stories_file,
                    use_conversation_test_files=use_conversation_test_files,
                )
                correct_in_run[config_name].append(correct_stories)

        for config_name, run_results in correct_in_run.items():
            correct_by_config[config_name].append(run_results)

    results_path = os.path.join(output, RESULTS_FILE)
    rasa.shared.utils.io.dump_obj_as_json_to_file(results_path, correct_by_config)
1289
+
1290
+
1291
async def compare_models(
    models: List[Text],
    stories_file: Text,
    output: Text,
    use_conversation_test_files: bool = False,
) -> None:
    """Evaluate several trained models on one test set.

    The number of correct stories per model file name is dumped as JSON to
    `RESULTS_FILE` inside `output`.

    Args:
        models: Paths to model files.
        stories_file: Path to the story file.
        output: Output directory to store the results in.
        use_conversation_test_files: `True` if conversation test files should be
            used for testing instead of regular Core story files.
    """
    correct_by_model = defaultdict(list)

    for model_path in models:
        correct_stories = await _evaluate_core_model(
            model_path,
            stories_file,
            use_conversation_test_files=use_conversation_test_files,
        )
        model_name = os.path.basename(model_path)
        correct_by_model[model_name].append(correct_stories)

    results_path = os.path.join(output, RESULTS_FILE)
    rasa.shared.utils.io.dump_obj_as_json_to_file(results_path, correct_by_model)
1317
+
1318
+
1319
async def _evaluate_core_model(
    model: Text, stories_file: Text, use_conversation_test_files: bool = False
) -> int:
    """Load a model and count how many test stories it predicts without failure.

    Args:
        model: Path of the model to load.
        stories_file: Path to the story file.
        use_conversation_test_files: Forwarded to the data generator.

    Returns:
        The number of stories minus the number of failed stories.
    """
    from rasa.core.agent import Agent

    logger.info(f"Evaluating model '{model}'")

    agent = Agent.load(model)
    trackers = _create_data_generator(
        stories_file, agent, use_conversation_test_files=use_conversation_test_files
    ).generate_story_trackers()

    # Entities are ignored here as only the number of correct stories matters.
    story_evaluation, story_count, _ = await _collect_story_predictions(
        trackers, agent
    )
    return story_count - len(story_evaluation.failed_stories)