rasa-pro 3.8.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (644)
  1. README.md +380 -0
  2. rasa/__init__.py +10 -0
  3. rasa/__main__.py +151 -0
  4. rasa/anonymization/__init__.py +2 -0
  5. rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
  6. rasa/anonymization/anonymization_pipeline.py +287 -0
  7. rasa/anonymization/anonymization_rule_executor.py +260 -0
  8. rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
  9. rasa/anonymization/schemas/config.yml +47 -0
  10. rasa/anonymization/utils.py +117 -0
  11. rasa/api.py +146 -0
  12. rasa/cli/__init__.py +5 -0
  13. rasa/cli/arguments/__init__.py +0 -0
  14. rasa/cli/arguments/data.py +81 -0
  15. rasa/cli/arguments/default_arguments.py +165 -0
  16. rasa/cli/arguments/evaluate.py +65 -0
  17. rasa/cli/arguments/export.py +51 -0
  18. rasa/cli/arguments/interactive.py +74 -0
  19. rasa/cli/arguments/run.py +204 -0
  20. rasa/cli/arguments/shell.py +13 -0
  21. rasa/cli/arguments/test.py +211 -0
  22. rasa/cli/arguments/train.py +263 -0
  23. rasa/cli/arguments/visualize.py +34 -0
  24. rasa/cli/arguments/x.py +30 -0
  25. rasa/cli/data.py +292 -0
  26. rasa/cli/e2e_test.py +566 -0
  27. rasa/cli/evaluate.py +222 -0
  28. rasa/cli/export.py +251 -0
  29. rasa/cli/inspect.py +63 -0
  30. rasa/cli/interactive.py +164 -0
  31. rasa/cli/license.py +65 -0
  32. rasa/cli/markers.py +78 -0
  33. rasa/cli/project_templates/__init__.py +0 -0
  34. rasa/cli/project_templates/calm/actions/__init__.py +0 -0
  35. rasa/cli/project_templates/calm/actions/action_template.py +27 -0
  36. rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
  37. rasa/cli/project_templates/calm/actions/db.py +57 -0
  38. rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
  39. rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
  40. rasa/cli/project_templates/calm/config.yml +12 -0
  41. rasa/cli/project_templates/calm/credentials.yml +33 -0
  42. rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
  43. rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
  44. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
  45. rasa/cli/project_templates/calm/db/contacts.json +10 -0
  46. rasa/cli/project_templates/calm/domain/add_contact.yml +33 -0
  47. rasa/cli/project_templates/calm/domain/list_contacts.yml +14 -0
  48. rasa/cli/project_templates/calm/domain/remove_contact.yml +31 -0
  49. rasa/cli/project_templates/calm/domain/shared.yml +5 -0
  50. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
  51. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
  52. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
  53. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
  54. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
  55. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
  56. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
  57. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
  58. rasa/cli/project_templates/calm/endpoints.yml +45 -0
  59. rasa/cli/project_templates/default/actions/__init__.py +0 -0
  60. rasa/cli/project_templates/default/actions/actions.py +27 -0
  61. rasa/cli/project_templates/default/config.yml +44 -0
  62. rasa/cli/project_templates/default/credentials.yml +33 -0
  63. rasa/cli/project_templates/default/data/nlu.yml +91 -0
  64. rasa/cli/project_templates/default/data/rules.yml +13 -0
  65. rasa/cli/project_templates/default/data/stories.yml +30 -0
  66. rasa/cli/project_templates/default/domain.yml +34 -0
  67. rasa/cli/project_templates/default/endpoints.yml +42 -0
  68. rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
  69. rasa/cli/project_templates/tutorial/actions.py +22 -0
  70. rasa/cli/project_templates/tutorial/config.yml +11 -0
  71. rasa/cli/project_templates/tutorial/credentials.yml +33 -0
  72. rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
  73. rasa/cli/project_templates/tutorial/domain.yml +17 -0
  74. rasa/cli/project_templates/tutorial/endpoints.yml +45 -0
  75. rasa/cli/run.py +136 -0
  76. rasa/cli/scaffold.py +268 -0
  77. rasa/cli/shell.py +141 -0
  78. rasa/cli/studio/__init__.py +0 -0
  79. rasa/cli/studio/download.py +51 -0
  80. rasa/cli/studio/studio.py +110 -0
  81. rasa/cli/studio/train.py +59 -0
  82. rasa/cli/studio/upload.py +85 -0
  83. rasa/cli/telemetry.py +90 -0
  84. rasa/cli/test.py +280 -0
  85. rasa/cli/train.py +260 -0
  86. rasa/cli/utils.py +453 -0
  87. rasa/cli/visualize.py +40 -0
  88. rasa/cli/x.py +205 -0
  89. rasa/constants.py +37 -0
  90. rasa/core/__init__.py +17 -0
  91. rasa/core/actions/__init__.py +0 -0
  92. rasa/core/actions/action.py +1450 -0
  93. rasa/core/actions/action_clean_stack.py +59 -0
  94. rasa/core/actions/action_run_slot_rejections.py +207 -0
  95. rasa/core/actions/action_trigger_chitchat.py +31 -0
  96. rasa/core/actions/action_trigger_flow.py +109 -0
  97. rasa/core/actions/action_trigger_search.py +31 -0
  98. rasa/core/actions/constants.py +2 -0
  99. rasa/core/actions/forms.py +737 -0
  100. rasa/core/actions/loops.py +111 -0
  101. rasa/core/actions/two_stage_fallback.py +186 -0
  102. rasa/core/agent.py +557 -0
  103. rasa/core/auth_retry_tracker_store.py +122 -0
  104. rasa/core/brokers/__init__.py +0 -0
  105. rasa/core/brokers/broker.py +126 -0
  106. rasa/core/brokers/file.py +58 -0
  107. rasa/core/brokers/kafka.py +322 -0
  108. rasa/core/brokers/pika.py +387 -0
  109. rasa/core/brokers/sql.py +86 -0
  110. rasa/core/channels/__init__.py +55 -0
  111. rasa/core/channels/audiocodes.py +463 -0
  112. rasa/core/channels/botframework.py +339 -0
  113. rasa/core/channels/callback.py +85 -0
  114. rasa/core/channels/channel.py +419 -0
  115. rasa/core/channels/console.py +243 -0
  116. rasa/core/channels/development_inspector.py +93 -0
  117. rasa/core/channels/facebook.py +422 -0
  118. rasa/core/channels/hangouts.py +335 -0
  119. rasa/core/channels/inspector/.eslintrc.cjs +25 -0
  120. rasa/core/channels/inspector/.gitignore +23 -0
  121. rasa/core/channels/inspector/README.md +54 -0
  122. rasa/core/channels/inspector/assets/favicon.ico +0 -0
  123. rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
  124. rasa/core/channels/inspector/custom.d.ts +3 -0
  125. rasa/core/channels/inspector/dist/assets/arc-5623b6dc.js +1 -0
  126. rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
  127. rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-685c106a.js +10 -0
  128. rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-8cbed007.js +2 -0
  129. rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-5889cf12.js +2 -0
  130. rasa/core/channels/inspector/dist/assets/createText-62fc7601-24c249d7.js +7 -0
  131. rasa/core/channels/inspector/dist/assets/edges-f2ad444c-7dd06a75.js +4 -0
  132. rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-62c1e54c.js +51 -0
  133. rasa/core/channels/inspector/dist/assets/flowDb-1972c806-ce49b86f.js +6 -0
  134. rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-4067e48f.js +4 -0
  135. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-85583a23.js +1 -0
  136. rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-59fe4051.js +139 -0
  137. rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-47e3a43b.js +266 -0
  138. rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-5a2ac0d9.js +70 -0
  139. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
  140. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
  141. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
  142. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
  143. rasa/core/channels/inspector/dist/assets/index-268a75c0.js +1040 -0
  144. rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-dfb8efc4.js +1 -0
  145. rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
  146. rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-b0c470f2.js +7 -0
  147. rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
  148. rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-2edb829a.js +139 -0
  149. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
  150. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
  151. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
  152. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
  153. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
  154. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
  155. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
  156. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
  157. rasa/core/channels/inspector/dist/assets/layout-b6873d69.js +1 -0
  158. rasa/core/channels/inspector/dist/assets/line-1efc5781.js +1 -0
  159. rasa/core/channels/inspector/dist/assets/linear-661e9b94.js +1 -0
  160. rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-2d2e727f.js +109 -0
  161. rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
  162. rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
  163. rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-9d3ea93d.js +35 -0
  164. rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-06a178a2.js +7 -0
  165. rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-0bfedffc.js +52 -0
  166. rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-d76d0a04.js +8 -0
  167. rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-37bb4341.js +122 -0
  168. rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-f52f7f57.js +1 -0
  169. rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-4a986a20.js +1 -0
  170. rasa/core/channels/inspector/dist/assets/styles-080da4f6-7dd9ae12.js +110 -0
  171. rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-46e1ca14.js +159 -0
  172. rasa/core/channels/inspector/dist/assets/styles-9c745c82-4a97439a.js +207 -0
  173. rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-823917a3.js +1 -0
  174. rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-9ea72896.js +61 -0
  175. rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-b631a8b6.js +7 -0
  176. rasa/core/channels/inspector/dist/index.html +39 -0
  177. rasa/core/channels/inspector/index.html +37 -0
  178. rasa/core/channels/inspector/jest.config.ts +13 -0
  179. rasa/core/channels/inspector/package.json +48 -0
  180. rasa/core/channels/inspector/setupTests.ts +2 -0
  181. rasa/core/channels/inspector/src/App.tsx +170 -0
  182. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +97 -0
  183. rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
  184. rasa/core/channels/inspector/src/components/DialogueStack.tsx +151 -0
  185. rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
  186. rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
  187. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +19 -0
  188. rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
  189. rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
  190. rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
  191. rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
  192. rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
  193. rasa/core/channels/inspector/src/helpers/formatters.test.ts +385 -0
  194. rasa/core/channels/inspector/src/helpers/formatters.ts +239 -0
  195. rasa/core/channels/inspector/src/helpers/utils.ts +42 -0
  196. rasa/core/channels/inspector/src/main.tsx +13 -0
  197. rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
  198. rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
  199. rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
  200. rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
  201. rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
  202. rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
  203. rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
  204. rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
  205. rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
  206. rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
  207. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
  208. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
  209. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
  210. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
  211. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
  212. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
  213. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
  214. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
  215. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
  216. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
  217. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
  218. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
  219. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
  220. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
  221. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
  222. rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
  223. rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
  224. rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
  225. rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
  226. rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
  227. rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
  228. rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
  229. rasa/core/channels/inspector/src/theme/index.ts +101 -0
  230. rasa/core/channels/inspector/src/types.ts +64 -0
  231. rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
  232. rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
  233. rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
  234. rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
  235. rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
  236. rasa/core/channels/inspector/tsconfig.json +26 -0
  237. rasa/core/channels/inspector/tsconfig.node.json +10 -0
  238. rasa/core/channels/inspector/vite.config.ts +8 -0
  239. rasa/core/channels/inspector/yarn.lock +6156 -0
  240. rasa/core/channels/mattermost.py +229 -0
  241. rasa/core/channels/rasa_chat.py +126 -0
  242. rasa/core/channels/rest.py +210 -0
  243. rasa/core/channels/rocketchat.py +175 -0
  244. rasa/core/channels/slack.py +620 -0
  245. rasa/core/channels/socketio.py +274 -0
  246. rasa/core/channels/telegram.py +298 -0
  247. rasa/core/channels/twilio.py +169 -0
  248. rasa/core/channels/twilio_voice.py +367 -0
  249. rasa/core/channels/vier_cvg.py +374 -0
  250. rasa/core/channels/webexteams.py +135 -0
  251. rasa/core/concurrent_lock_store.py +210 -0
  252. rasa/core/constants.py +107 -0
  253. rasa/core/evaluation/__init__.py +0 -0
  254. rasa/core/evaluation/marker.py +267 -0
  255. rasa/core/evaluation/marker_base.py +925 -0
  256. rasa/core/evaluation/marker_stats.py +294 -0
  257. rasa/core/evaluation/marker_tracker_loader.py +103 -0
  258. rasa/core/exceptions.py +29 -0
  259. rasa/core/exporter.py +284 -0
  260. rasa/core/featurizers/__init__.py +0 -0
  261. rasa/core/featurizers/precomputation.py +410 -0
  262. rasa/core/featurizers/single_state_featurizer.py +402 -0
  263. rasa/core/featurizers/tracker_featurizers.py +1172 -0
  264. rasa/core/http_interpreter.py +89 -0
  265. rasa/core/information_retrieval/__init__.py +0 -0
  266. rasa/core/information_retrieval/faiss.py +116 -0
  267. rasa/core/information_retrieval/information_retrieval.py +72 -0
  268. rasa/core/information_retrieval/milvus.py +59 -0
  269. rasa/core/information_retrieval/qdrant.py +102 -0
  270. rasa/core/jobs.py +63 -0
  271. rasa/core/lock.py +139 -0
  272. rasa/core/lock_store.py +344 -0
  273. rasa/core/migrate.py +404 -0
  274. rasa/core/nlg/__init__.py +3 -0
  275. rasa/core/nlg/callback.py +147 -0
  276. rasa/core/nlg/contextual_response_rephraser.py +270 -0
  277. rasa/core/nlg/generator.py +230 -0
  278. rasa/core/nlg/interpolator.py +143 -0
  279. rasa/core/nlg/response.py +155 -0
  280. rasa/core/nlg/summarize.py +69 -0
  281. rasa/core/policies/__init__.py +0 -0
  282. rasa/core/policies/ensemble.py +329 -0
  283. rasa/core/policies/enterprise_search_policy.py +717 -0
  284. rasa/core/policies/enterprise_search_prompt_template.jinja2 +62 -0
  285. rasa/core/policies/flow_policy.py +205 -0
  286. rasa/core/policies/flows/__init__.py +0 -0
  287. rasa/core/policies/flows/flow_exceptions.py +44 -0
  288. rasa/core/policies/flows/flow_executor.py +582 -0
  289. rasa/core/policies/flows/flow_step_result.py +43 -0
  290. rasa/core/policies/intentless_policy.py +924 -0
  291. rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
  292. rasa/core/policies/memoization.py +538 -0
  293. rasa/core/policies/policy.py +716 -0
  294. rasa/core/policies/rule_policy.py +1276 -0
  295. rasa/core/policies/ted_policy.py +2146 -0
  296. rasa/core/policies/unexpected_intent_policy.py +1015 -0
  297. rasa/core/processor.py +1331 -0
  298. rasa/core/run.py +315 -0
  299. rasa/core/secrets_manager/__init__.py +0 -0
  300. rasa/core/secrets_manager/constants.py +32 -0
  301. rasa/core/secrets_manager/endpoints.py +391 -0
  302. rasa/core/secrets_manager/factory.py +233 -0
  303. rasa/core/secrets_manager/secret_manager.py +262 -0
  304. rasa/core/secrets_manager/vault.py +576 -0
  305. rasa/core/test.py +1337 -0
  306. rasa/core/tracker_store.py +1664 -0
  307. rasa/core/train.py +107 -0
  308. rasa/core/training/__init__.py +89 -0
  309. rasa/core/training/converters/__init__.py +0 -0
  310. rasa/core/training/converters/responses_prefix_converter.py +119 -0
  311. rasa/core/training/interactive.py +1742 -0
  312. rasa/core/training/story_conflict.py +381 -0
  313. rasa/core/training/training.py +93 -0
  314. rasa/core/utils.py +344 -0
  315. rasa/core/visualize.py +70 -0
  316. rasa/dialogue_understanding/__init__.py +0 -0
  317. rasa/dialogue_understanding/coexistence/__init__.py +0 -0
  318. rasa/dialogue_understanding/coexistence/constants.py +4 -0
  319. rasa/dialogue_understanding/coexistence/intent_based_router.py +189 -0
  320. rasa/dialogue_understanding/coexistence/llm_based_router.py +261 -0
  321. rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
  322. rasa/dialogue_understanding/commands/__init__.py +45 -0
  323. rasa/dialogue_understanding/commands/can_not_handle_command.py +61 -0
  324. rasa/dialogue_understanding/commands/cancel_flow_command.py +116 -0
  325. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +48 -0
  326. rasa/dialogue_understanding/commands/clarify_command.py +77 -0
  327. rasa/dialogue_understanding/commands/command.py +85 -0
  328. rasa/dialogue_understanding/commands/correct_slots_command.py +288 -0
  329. rasa/dialogue_understanding/commands/error_command.py +67 -0
  330. rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
  331. rasa/dialogue_understanding/commands/handle_code_change_command.py +64 -0
  332. rasa/dialogue_understanding/commands/human_handoff_command.py +57 -0
  333. rasa/dialogue_understanding/commands/knowledge_answer_command.py +48 -0
  334. rasa/dialogue_understanding/commands/noop_command.py +45 -0
  335. rasa/dialogue_understanding/commands/set_slot_command.py +125 -0
  336. rasa/dialogue_understanding/commands/skip_question_command.py +66 -0
  337. rasa/dialogue_understanding/commands/start_flow_command.py +98 -0
  338. rasa/dialogue_understanding/generator/__init__.py +6 -0
  339. rasa/dialogue_understanding/generator/command_generator.py +257 -0
  340. rasa/dialogue_understanding/generator/command_prompt_template.jinja2 +57 -0
  341. rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
  342. rasa/dialogue_understanding/generator/flow_retrieval.py +410 -0
  343. rasa/dialogue_understanding/generator/llm_command_generator.py +637 -0
  344. rasa/dialogue_understanding/generator/nlu_command_adapter.py +157 -0
  345. rasa/dialogue_understanding/patterns/__init__.py +0 -0
  346. rasa/dialogue_understanding/patterns/cancel.py +111 -0
  347. rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
  348. rasa/dialogue_understanding/patterns/chitchat.py +37 -0
  349. rasa/dialogue_understanding/patterns/clarify.py +97 -0
  350. rasa/dialogue_understanding/patterns/code_change.py +41 -0
  351. rasa/dialogue_understanding/patterns/collect_information.py +90 -0
  352. rasa/dialogue_understanding/patterns/completed.py +40 -0
  353. rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
  354. rasa/dialogue_understanding/patterns/correction.py +278 -0
  355. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +243 -0
  356. rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
  357. rasa/dialogue_understanding/patterns/internal_error.py +47 -0
  358. rasa/dialogue_understanding/patterns/search.py +37 -0
  359. rasa/dialogue_understanding/patterns/skip_question.py +38 -0
  360. rasa/dialogue_understanding/processor/__init__.py +0 -0
  361. rasa/dialogue_understanding/processor/command_processor.py +578 -0
  362. rasa/dialogue_understanding/processor/command_processor_component.py +39 -0
  363. rasa/dialogue_understanding/stack/__init__.py +0 -0
  364. rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
  365. rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
  366. rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
  367. rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
  368. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
  369. rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
  370. rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
  371. rasa/dialogue_understanding/stack/utils.py +211 -0
  372. rasa/e2e_test/__init__.py +0 -0
  373. rasa/e2e_test/constants.py +10 -0
  374. rasa/e2e_test/e2e_test_case.py +322 -0
  375. rasa/e2e_test/e2e_test_result.py +34 -0
  376. rasa/e2e_test/e2e_test_runner.py +659 -0
  377. rasa/e2e_test/e2e_test_schema.yml +67 -0
  378. rasa/engine/__init__.py +0 -0
  379. rasa/engine/caching.py +464 -0
  380. rasa/engine/constants.py +17 -0
  381. rasa/engine/exceptions.py +14 -0
  382. rasa/engine/graph.py +625 -0
  383. rasa/engine/loader.py +36 -0
  384. rasa/engine/recipes/__init__.py +0 -0
  385. rasa/engine/recipes/config_files/default_config.yml +44 -0
  386. rasa/engine/recipes/default_components.py +99 -0
  387. rasa/engine/recipes/default_recipe.py +1252 -0
  388. rasa/engine/recipes/graph_recipe.py +79 -0
  389. rasa/engine/recipes/recipe.py +93 -0
  390. rasa/engine/runner/__init__.py +0 -0
  391. rasa/engine/runner/dask.py +256 -0
  392. rasa/engine/runner/interface.py +49 -0
  393. rasa/engine/storage/__init__.py +0 -0
  394. rasa/engine/storage/local_model_storage.py +248 -0
  395. rasa/engine/storage/resource.py +110 -0
  396. rasa/engine/storage/storage.py +203 -0
  397. rasa/engine/training/__init__.py +0 -0
  398. rasa/engine/training/components.py +176 -0
  399. rasa/engine/training/fingerprinting.py +64 -0
  400. rasa/engine/training/graph_trainer.py +256 -0
  401. rasa/engine/training/hooks.py +164 -0
  402. rasa/engine/validation.py +839 -0
  403. rasa/env.py +5 -0
  404. rasa/exceptions.py +69 -0
  405. rasa/graph_components/__init__.py +0 -0
  406. rasa/graph_components/converters/__init__.py +0 -0
  407. rasa/graph_components/converters/nlu_message_converter.py +48 -0
  408. rasa/graph_components/providers/__init__.py +0 -0
  409. rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
  410. rasa/graph_components/providers/domain_provider.py +71 -0
  411. rasa/graph_components/providers/flows_provider.py +74 -0
  412. rasa/graph_components/providers/forms_provider.py +44 -0
  413. rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
  414. rasa/graph_components/providers/responses_provider.py +44 -0
  415. rasa/graph_components/providers/rule_only_provider.py +49 -0
  416. rasa/graph_components/providers/story_graph_provider.py +43 -0
  417. rasa/graph_components/providers/training_tracker_provider.py +55 -0
  418. rasa/graph_components/validators/__init__.py +0 -0
  419. rasa/graph_components/validators/default_recipe_validator.py +552 -0
  420. rasa/graph_components/validators/finetuning_validator.py +302 -0
  421. rasa/hooks.py +113 -0
  422. rasa/jupyter.py +63 -0
  423. rasa/keys +1 -0
  424. rasa/markers/__init__.py +0 -0
  425. rasa/markers/marker.py +269 -0
  426. rasa/markers/marker_base.py +828 -0
  427. rasa/markers/upload.py +74 -0
  428. rasa/markers/validate.py +21 -0
  429. rasa/model.py +118 -0
  430. rasa/model_testing.py +457 -0
  431. rasa/model_training.py +535 -0
  432. rasa/nlu/__init__.py +7 -0
  433. rasa/nlu/classifiers/__init__.py +3 -0
  434. rasa/nlu/classifiers/classifier.py +5 -0
  435. rasa/nlu/classifiers/diet_classifier.py +1874 -0
  436. rasa/nlu/classifiers/fallback_classifier.py +192 -0
  437. rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
  438. rasa/nlu/classifiers/llm_intent_classifier.py +519 -0
  439. rasa/nlu/classifiers/logistic_regression_classifier.py +240 -0
  440. rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
  441. rasa/nlu/classifiers/regex_message_handler.py +56 -0
  442. rasa/nlu/classifiers/sklearn_intent_classifier.py +309 -0
  443. rasa/nlu/constants.py +77 -0
  444. rasa/nlu/convert.py +40 -0
  445. rasa/nlu/emulators/__init__.py +0 -0
  446. rasa/nlu/emulators/dialogflow.py +55 -0
  447. rasa/nlu/emulators/emulator.py +49 -0
  448. rasa/nlu/emulators/luis.py +86 -0
  449. rasa/nlu/emulators/no_emulator.py +10 -0
  450. rasa/nlu/emulators/wit.py +56 -0
  451. rasa/nlu/extractors/__init__.py +0 -0
  452. rasa/nlu/extractors/crf_entity_extractor.py +672 -0
  453. rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
  454. rasa/nlu/extractors/entity_synonyms.py +178 -0
  455. rasa/nlu/extractors/extractor.py +470 -0
  456. rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
  457. rasa/nlu/extractors/regex_entity_extractor.py +220 -0
  458. rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
  459. rasa/nlu/featurizers/__init__.py +0 -0
  460. rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
  461. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +449 -0
  462. rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
  463. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +772 -0
  464. rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
  465. rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
  466. rasa/nlu/featurizers/featurizer.py +89 -0
  467. rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
  468. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +840 -0
  469. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +539 -0
  470. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +269 -0
  471. rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
  472. rasa/nlu/model.py +24 -0
  473. rasa/nlu/persistor.py +240 -0
  474. rasa/nlu/run.py +27 -0
  475. rasa/nlu/selectors/__init__.py +0 -0
  476. rasa/nlu/selectors/response_selector.py +990 -0
  477. rasa/nlu/test.py +1943 -0
  478. rasa/nlu/tokenizers/__init__.py +0 -0
  479. rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
  480. rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
  481. rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
  482. rasa/nlu/tokenizers/tokenizer.py +239 -0
  483. rasa/nlu/tokenizers/whitespace_tokenizer.py +106 -0
  484. rasa/nlu/utils/__init__.py +35 -0
  485. rasa/nlu/utils/bilou_utils.py +462 -0
  486. rasa/nlu/utils/hugging_face/__init__.py +0 -0
  487. rasa/nlu/utils/hugging_face/registry.py +108 -0
  488. rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
  489. rasa/nlu/utils/mitie_utils.py +113 -0
  490. rasa/nlu/utils/pattern_utils.py +168 -0
  491. rasa/nlu/utils/spacy_utils.py +312 -0
  492. rasa/plugin.py +90 -0
  493. rasa/server.py +1536 -0
  494. rasa/shared/__init__.py +0 -0
  495. rasa/shared/constants.py +181 -0
  496. rasa/shared/core/__init__.py +0 -0
  497. rasa/shared/core/constants.py +168 -0
  498. rasa/shared/core/conversation.py +46 -0
  499. rasa/shared/core/domain.py +2106 -0
  500. rasa/shared/core/events.py +2507 -0
  501. rasa/shared/core/flows/__init__.py +7 -0
  502. rasa/shared/core/flows/flow.py +353 -0
  503. rasa/shared/core/flows/flow_step.py +146 -0
  504. rasa/shared/core/flows/flow_step_links.py +319 -0
  505. rasa/shared/core/flows/flow_step_sequence.py +70 -0
  506. rasa/shared/core/flows/flows_list.py +211 -0
  507. rasa/shared/core/flows/flows_yaml_schema.json +217 -0
  508. rasa/shared/core/flows/nlu_trigger.py +117 -0
  509. rasa/shared/core/flows/steps/__init__.py +24 -0
  510. rasa/shared/core/flows/steps/action.py +51 -0
  511. rasa/shared/core/flows/steps/call.py +64 -0
  512. rasa/shared/core/flows/steps/collect.py +112 -0
  513. rasa/shared/core/flows/steps/constants.py +5 -0
  514. rasa/shared/core/flows/steps/continuation.py +36 -0
  515. rasa/shared/core/flows/steps/end.py +22 -0
  516. rasa/shared/core/flows/steps/internal.py +44 -0
  517. rasa/shared/core/flows/steps/link.py +51 -0
  518. rasa/shared/core/flows/steps/no_operation.py +48 -0
  519. rasa/shared/core/flows/steps/set_slots.py +50 -0
  520. rasa/shared/core/flows/steps/start.py +30 -0
  521. rasa/shared/core/flows/validation.py +527 -0
  522. rasa/shared/core/flows/yaml_flows_io.py +278 -0
  523. rasa/shared/core/generator.py +907 -0
  524. rasa/shared/core/slot_mappings.py +235 -0
  525. rasa/shared/core/slots.py +647 -0
  526. rasa/shared/core/trackers.py +1159 -0
  527. rasa/shared/core/training_data/__init__.py +0 -0
  528. rasa/shared/core/training_data/loading.py +90 -0
  529. rasa/shared/core/training_data/story_reader/__init__.py +0 -0
  530. rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
  531. rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
  532. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
  533. rasa/shared/core/training_data/story_writer/__init__.py +0 -0
  534. rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
  535. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +442 -0
  536. rasa/shared/core/training_data/structures.py +838 -0
  537. rasa/shared/core/training_data/visualization.html +146 -0
  538. rasa/shared/core/training_data/visualization.py +603 -0
  539. rasa/shared/data.py +192 -0
  540. rasa/shared/engine/__init__.py +0 -0
  541. rasa/shared/engine/caching.py +26 -0
  542. rasa/shared/exceptions.py +129 -0
  543. rasa/shared/importers/__init__.py +0 -0
  544. rasa/shared/importers/importer.py +705 -0
  545. rasa/shared/importers/multi_project.py +203 -0
  546. rasa/shared/importers/rasa.py +100 -0
  547. rasa/shared/importers/utils.py +34 -0
  548. rasa/shared/nlu/__init__.py +0 -0
  549. rasa/shared/nlu/constants.py +45 -0
  550. rasa/shared/nlu/interpreter.py +10 -0
  551. rasa/shared/nlu/training_data/__init__.py +0 -0
  552. rasa/shared/nlu/training_data/entities_parser.py +209 -0
  553. rasa/shared/nlu/training_data/features.py +374 -0
  554. rasa/shared/nlu/training_data/formats/__init__.py +10 -0
  555. rasa/shared/nlu/training_data/formats/dialogflow.py +162 -0
  556. rasa/shared/nlu/training_data/formats/luis.py +87 -0
  557. rasa/shared/nlu/training_data/formats/rasa.py +135 -0
  558. rasa/shared/nlu/training_data/formats/rasa_yaml.py +605 -0
  559. rasa/shared/nlu/training_data/formats/readerwriter.py +245 -0
  560. rasa/shared/nlu/training_data/formats/wit.py +52 -0
  561. rasa/shared/nlu/training_data/loading.py +137 -0
  562. rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
  563. rasa/shared/nlu/training_data/message.py +477 -0
  564. rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
  565. rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
  566. rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
  567. rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
  568. rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
  569. rasa/shared/nlu/training_data/training_data.py +732 -0
  570. rasa/shared/nlu/training_data/util.py +223 -0
  571. rasa/shared/providers/__init__.py +0 -0
  572. rasa/shared/providers/openai/__init__.py +0 -0
  573. rasa/shared/providers/openai/clients.py +43 -0
  574. rasa/shared/providers/openai/session_handler.py +110 -0
  575. rasa/shared/utils/__init__.py +0 -0
  576. rasa/shared/utils/cli.py +72 -0
  577. rasa/shared/utils/common.py +308 -0
  578. rasa/shared/utils/constants.py +1 -0
  579. rasa/shared/utils/io.py +403 -0
  580. rasa/shared/utils/llm.py +405 -0
  581. rasa/shared/utils/pykwalify_extensions.py +26 -0
  582. rasa/shared/utils/schemas/__init__.py +0 -0
  583. rasa/shared/utils/schemas/config.yml +2 -0
  584. rasa/shared/utils/schemas/domain.yml +142 -0
  585. rasa/shared/utils/schemas/events.py +212 -0
  586. rasa/shared/utils/schemas/model_config.yml +46 -0
  587. rasa/shared/utils/schemas/stories.yml +173 -0
  588. rasa/shared/utils/yaml.py +777 -0
  589. rasa/studio/__init__.py +0 -0
  590. rasa/studio/auth.py +252 -0
  591. rasa/studio/config.py +127 -0
  592. rasa/studio/constants.py +16 -0
  593. rasa/studio/data_handler.py +352 -0
  594. rasa/studio/download.py +350 -0
  595. rasa/studio/train.py +136 -0
  596. rasa/studio/upload.py +408 -0
  597. rasa/telemetry.py +1583 -0
  598. rasa/tracing/__init__.py +0 -0
  599. rasa/tracing/config.py +338 -0
  600. rasa/tracing/constants.py +38 -0
  601. rasa/tracing/instrumentation/__init__.py +0 -0
  602. rasa/tracing/instrumentation/attribute_extractors.py +663 -0
  603. rasa/tracing/instrumentation/instrumentation.py +939 -0
  604. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +142 -0
  605. rasa/tracing/instrumentation/metrics.py +206 -0
  606. rasa/tracing/metric_instrument_provider.py +125 -0
  607. rasa/utils/__init__.py +0 -0
  608. rasa/utils/beta.py +83 -0
  609. rasa/utils/cli.py +27 -0
  610. rasa/utils/common.py +635 -0
  611. rasa/utils/converter.py +53 -0
  612. rasa/utils/endpoints.py +303 -0
  613. rasa/utils/io.py +326 -0
  614. rasa/utils/licensing.py +319 -0
  615. rasa/utils/log_utils.py +174 -0
  616. rasa/utils/mapper.py +210 -0
  617. rasa/utils/ml_utils.py +145 -0
  618. rasa/utils/plotting.py +362 -0
  619. rasa/utils/singleton.py +23 -0
  620. rasa/utils/tensorflow/__init__.py +0 -0
  621. rasa/utils/tensorflow/callback.py +112 -0
  622. rasa/utils/tensorflow/constants.py +116 -0
  623. rasa/utils/tensorflow/crf.py +492 -0
  624. rasa/utils/tensorflow/data_generator.py +440 -0
  625. rasa/utils/tensorflow/environment.py +161 -0
  626. rasa/utils/tensorflow/exceptions.py +5 -0
  627. rasa/utils/tensorflow/layers.py +1565 -0
  628. rasa/utils/tensorflow/layers_utils.py +113 -0
  629. rasa/utils/tensorflow/metrics.py +281 -0
  630. rasa/utils/tensorflow/model_data.py +991 -0
  631. rasa/utils/tensorflow/model_data_utils.py +500 -0
  632. rasa/utils/tensorflow/models.py +936 -0
  633. rasa/utils/tensorflow/rasa_layers.py +1094 -0
  634. rasa/utils/tensorflow/transformer.py +640 -0
  635. rasa/utils/tensorflow/types.py +6 -0
  636. rasa/utils/train_utils.py +572 -0
  637. rasa/utils/yaml.py +54 -0
  638. rasa/validator.py +1035 -0
  639. rasa/version.py +3 -0
  640. rasa_pro-3.8.16.dist-info/METADATA +528 -0
  641. rasa_pro-3.8.16.dist-info/NOTICE +5 -0
  642. rasa_pro-3.8.16.dist-info/RECORD +644 -0
  643. rasa_pro-3.8.16.dist-info/WHEEL +4 -0
  644. rasa_pro-3.8.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,907 @@
1
+ from collections import defaultdict, namedtuple, deque
2
+
3
+ import copy
4
+ import logging
5
+ import random
6
+ from contextlib import contextmanager
7
+
8
+ from tqdm import tqdm
9
+ from typing import (
10
+ Optional,
11
+ List,
12
+ Text,
13
+ Set,
14
+ Dict,
15
+ Tuple,
16
+ Deque,
17
+ DefaultDict,
18
+ Any,
19
+ Iterable,
20
+ Generator,
21
+ )
22
+
23
+ from rasa.shared.constants import DOCS_URL_STORIES
24
+ from rasa.shared.core.constants import SHOULD_NOT_BE_SET
25
+ from rasa.shared.core.domain import Domain, State
26
+ from rasa.shared.core.events import (
27
+ ActionExecuted,
28
+ UserUttered,
29
+ ActionReverted,
30
+ UserUtteranceReverted,
31
+ Restarted,
32
+ Event,
33
+ SlotSet,
34
+ ActiveLoop,
35
+ )
36
+ from rasa.shared.core.trackers import DialogueStateTracker, FrozenState
37
+ from rasa.shared.core.slots import Slot
38
+ from rasa.shared.core.training_data.structures import (
39
+ StoryGraph,
40
+ STORY_START,
41
+ StoryStep,
42
+ RuleStep,
43
+ GENERATED_CHECKPOINT_PREFIX,
44
+ )
45
+ from rasa.shared.utils.io import is_logging_disabled
46
+ import rasa.shared.utils.io
47
+
48
logger = logging.getLogger(__name__)

# Immutable bundle of the settings that drive tracker generation. The field
# order matters for positional construction and must stay as listed.
ExtractorConfig = namedtuple(
    "ExtractorConfig",
    [
        "remove_duplicates",
        "unique_last_num_states",
        "augmentation_factor",
        "max_number_of_augmented_trackers",
        "tracker_limit",
        "use_story_concatenation",
        "rand",
    ],
)
60
+
61
+
62
class TrackerWithCachedStates(DialogueStateTracker):
    """A tracker wrapper that caches the state creation of the tracker.

    The frozen states derived from the event history are kept in
    ``_states_for_hashing`` and maintained incrementally by :meth:`update`,
    so that :meth:`past_states` can be answered from the cache.
    """

    def __init__(
        self,
        sender_id: Text,
        slots: Optional[Iterable[Slot]],
        max_event_history: Optional[int] = None,
        domain: Optional[Domain] = None,
        is_augmented: bool = False,
        is_rule_tracker: bool = False,
    ) -> None:
        """Initializes a tracker with cached states.

        Args:
            sender_id: Id of the conversation this tracker belongs to.
            slots: Slots the tracker is initialised with.
            max_event_history: Forwarded unchanged to the base tracker.
            domain: Domain used to compute states; an empty domain is used
                when `None` is given.
            is_augmented: Marks the tracker as produced by augmentation.
            is_rule_tracker: Forwarded unchanged to the base tracker.
        """
        super().__init__(
            sender_id, slots, max_event_history, is_rule_tracker=is_rule_tracker
        )
        self._states_for_hashing: Deque[FrozenState] = deque()
        self.domain = domain if domain is not None else Domain.empty()
        # T/F property to filter augmented stories
        self.is_augmented = is_augmented
        # while set, `update` leaves the state cache untouched
        self.__skip_states = False

    @classmethod
    def from_events(
        cls,
        sender_id: Text,
        evts: List[Event],
        slots: Optional[Iterable[Slot]] = None,
        max_event_history: Optional[int] = None,
        sender_source: Optional[Text] = None,
        domain: Optional[Domain] = None,
        is_rule_tracker: bool = False,
    ) -> "TrackerWithCachedStates":
        """Initializes a tracker with given events.

        Args:
            sender_id: Id of the conversation.
            evts: Events that are applied, in order, to the new tracker.
            slots: Slots the tracker is initialised with.
            max_event_history: Forwarded to the constructor.
            sender_source: Source of the conversation; stored on the tracker
                when it is not `None`.
            domain: Domain used for state computation.
            is_rule_tracker: Forwarded to the constructor.

        Returns:
            The tracker after all events have been applied.
        """
        tracker = cls(
            sender_id, slots, max_event_history, domain, is_rule_tracker=is_rule_tracker
        )
        # The constructor of this class does not take `sender_source`, so the
        # argument used to be dropped silently. Apply it the same way `copy`
        # does, but only when the caller actually provided a value.
        if sender_source is not None:
            tracker.sender_source = sender_source
        for e in evts:
            tracker.update(e)
        return tracker

    def past_states_for_hashing(
        self, domain: Domain, omit_unset_slots: bool = False
    ) -> Deque[FrozenState]:
        """Generates and caches the past states of this tracker based on the history.

        Args:
            domain: a :class:`rasa.shared.core.domain.Domain`
            omit_unset_slots: If `True` do not include the initial values of slots.

        Returns:
            A deque of frozen states.

        Raises:
            ValueError: If `domain` differs from the domain of this tracker.
        """
        if domain != self.domain:
            raise ValueError(
                "TrackerWithCachedStates cannot be used with a domain "
                "that is different from the one it was created with."
            )

        if omit_unset_slots:
            # the tracker caches states with omit_unset_slots=False
            # Retrieving them from cache with omit_unset_slots=True is not possible as
            # this information is lost after a position in the event stream is turned
            # into a state. The result is therefore computed fresh and NOT cached.
            states = super().past_states(domain, omit_unset_slots=omit_unset_slots)
            return deque(self.freeze_current_state(s) for s in states)

        # if we don't have it cached, we use the domain to calculate the states
        # from the events
        states_for_hashing = self._states_for_hashing
        if not states_for_hashing:
            states = super().past_states(domain)
            states_for_hashing = deque(self.freeze_current_state(s) for s in states)

        self._states_for_hashing = states_for_hashing
        return states_for_hashing

    @staticmethod
    def _unfreeze_states(frozen_states: Deque[FrozenState]) -> List[State]:
        """Converts frozen states back into nested plain dictionaries."""
        return [
            {key: dict(value) for key, value in dict(frozen_state).items()}
            for frozen_state in frozen_states
        ]

    def past_states(
        self,
        domain: Domain,
        omit_unset_slots: bool = False,
        ignore_rule_only_turns: bool = False,
        rule_only_data: Optional[Dict[Text, Any]] = None,
    ) -> List[State]:
        """Generates the past states of this tracker based on the history.

        Args:
            domain: The Domain.
            omit_unset_slots: If `True` do not include the initial values of slots.
            ignore_rule_only_turns: If True ignore dialogue turns that are present
                only in rules. Accepted but not used by this implementation.
            rule_only_data: Slots and loops,
                which only occur in rules but not in stories. Accepted but
                not used by this implementation.

        Returns:
            a list of states
        """
        states_for_hashing = self.past_states_for_hashing(
            domain, omit_unset_slots=omit_unset_slots
        )
        return self._unfreeze_states(states_for_hashing)

    def clear_states(self) -> None:
        """Reset the states."""
        self._states_for_hashing = deque()

    def init_copy(self) -> "TrackerWithCachedStates":
        """Create a new state tracker with the same initial values."""
        return type(self)(
            "",
            self.slots.values(),
            self._max_event_history,
            self.domain,
            self.is_augmented,
            self.is_rule_tracker,
        )

    @contextmanager
    def _skip_states_manager(self) -> Generator[None, None, None]:
        """Disables cache maintenance in `update` for the duration of the block."""
        self.__skip_states = True
        try:
            yield
        finally:
            self.__skip_states = False

    def copy(
        self, sender_id: Text = "", sender_source: Text = ""
    ) -> "TrackerWithCachedStates":
        """Creates a duplicate of this tracker.

        A new tracker will be created and all events
        will be replayed.

        Args:
            sender_id: Sender id assigned to the copy.
            sender_source: Sender source assigned to the copy.

        Returns:
            The copy, carrying a shallow copy of the cached states.
        """
        # This is an optimization, we could use the original copy, but
        # the states would be lost and we would need to recalculate them

        tracker = self.init_copy()
        tracker.sender_id = sender_id
        tracker.sender_source = sender_source

        # replay the events without touching the cache ...
        with tracker._skip_states_manager():
            for event in self.events:
                tracker.update(event)

        # ... and take over the already computed states instead
        tracker._states_for_hashing = copy.copy(self._states_for_hashing)

        return tracker

    def _append_current_state(self) -> None:
        """Appends the frozen currently active state to the cache."""
        if self._states_for_hashing is None:
            # defensive: the cache is initialised to an empty deque in this
            # class, so this branch only triggers if it was set to None elsewhere
            self._states_for_hashing = self.past_states_for_hashing(self.domain)
        else:
            state = self.domain.get_active_state(self)
            frozen_state = self.freeze_current_state(state)
            self._states_for_hashing.append(frozen_state)

    def update(
        self,
        event: Event,
        domain: Optional[Domain] = None,
    ) -> None:
        """Modify the state of the tracker according to an ``Event``.

        Args:
            event: The event to apply.
            domain: Accepted but not used by this implementation.
        """
        # if `skip_states` is `True`, this function behaves exactly like the
        # normal update of the `DialogueStateTracker`
        if not self._states_for_hashing and not self.__skip_states:
            # rest of this function assumes we have the previous state
            # cached. let's make sure it is there.
            self._states_for_hashing = self.past_states_for_hashing(self.domain)

        super().update(event)

        if self.__skip_states:
            return

        if isinstance(event, ActionExecuted):
            # nothing is removed; the new state is appended below
            pass
        elif isinstance(event, ActionReverted):
            self._states_for_hashing.pop()  # removes the state after the action
            self._states_for_hashing.pop()  # removes the state used for the action
        elif isinstance(event, (UserUtteranceReverted, Restarted)):
            self.clear_states()
        else:
            # any other event replaces the most recent cached state
            self._states_for_hashing.pop()

        self._append_current_state()
257
+
258
+
259
# define types

# Trackers grouped by the name of the checkpoint they are waiting at.
TrackerLookupDict = DefaultDict[Text, List[TrackerWithCachedStates]]

# A pair of tracker lists, as returned by the step-processing helpers.
TrackersTuple = Tuple[
    List[TrackerWithCachedStates],
    List[TrackerWithCachedStates],
]
263
+
264
+
265
+ class TrainingDataGenerator:
266
+ """Generates trackers from training data."""
267
+
268
def __init__(
    self,
    story_graph: StoryGraph,
    domain: Domain,
    remove_duplicates: bool = True,
    unique_last_num_states: Optional[int] = None,
    augmentation_factor: int = 50,
    tracker_limit: Optional[int] = None,
    use_story_concatenation: bool = True,
    debug_plots: bool = False,
):
    """Given a set of story parts, generates all stories that are possible.

    The different story parts can end and start with checkpoints
    and this generator will match start and end checkpoints to
    connect complete stories. Afterwards, duplicate stories will be
    removed and the data is augmented (if augmentation is enabled).

    Args:
        story_graph: Graph of story parts; its cycles are removed before use.
        domain: Domain the trackers are created for.
        remove_duplicates: Whether duplicate trackers are dropped.
        unique_last_num_states: Number of trailing states used for
            deduplication, if set.
        augmentation_factor: Controls how many trackers are used for
            augmentation.
        tracker_limit: Maximum event history of the generated trackers.
        use_story_concatenation: Whether finished stories are reused as
            starting points during augmentation.
        debug_plots: If `True`, the story graph is written to
            `story_blocks_connections.html`.
    """
    graph_without_cycles = story_graph.with_cycles_removed()
    self.story_graph = graph_without_cycles
    if debug_plots:
        graph_without_cycles.visualize("story_blocks_connections.html")

    self.domain = domain

    self.config = ExtractorConfig(
        remove_duplicates=remove_duplicates,
        unique_last_num_states=unique_last_num_states,
        augmentation_factor=augmentation_factor,
        # 10x factor is a heuristic for augmentation rounds
        max_number_of_augmented_trackers=augmentation_factor * 10,
        tracker_limit=tracker_limit,
        use_story_concatenation=use_story_concatenation,
        # fixed seed keeps the subsampling reproducible
        rand=random.Random(42),
    )
    # hashed featurization of all finished trackers
    self.hashed_featurizations: Set[int] = set()
306
+
307
+ @staticmethod
308
+ def _phase_name(everything_reachable_is_reached: bool, phase: int) -> Text:
309
+ if everything_reachable_is_reached:
310
+ return f"augmentation round {phase}"
311
+ else:
312
+ return f"data generation round {phase}"
313
+
314
def generate(self) -> List[TrackerWithCachedStates]:
    """Generate trackers from stories and rules.

    Returns:
        The story trackers followed by the rule trackers.
    """
    story_trackers = self.generate_story_trackers()
    rule_trackers = self._generate_rule_trackers()
    return story_trackers + rule_trackers
321
+
322
def generate_story_trackers(self) -> List[TrackerWithCachedStates]:
    """Generate trackers from stories (exclude rule trackers).

    Returns:
        The generated story trackers.
    """
    story_steps = []
    for candidate in self.story_graph.ordered_steps():
        # rule steps are handled by `_generate_rule_trackers`
        if isinstance(candidate, RuleStep):
            continue
        story_steps.append(candidate)

    return self._generate(story_steps, is_rule_data=False)
335
+
336
def _generate_rule_trackers(self) -> List[TrackerWithCachedStates]:
    """Generate trackers from the rule steps of the story graph only."""
    rule_steps = []
    for candidate in self.story_graph.ordered_steps():
        if isinstance(candidate, RuleStep):
            rule_steps.append(candidate)

    return self._generate(rule_steps, is_rule_data=True)
344
+
345
def _generate(
    self, story_steps: List[StoryStep], is_rule_data: bool = False
) -> List[TrackerWithCachedStates]:
    """Traverses the steps phase by phase and collects the resulting trackers.

    One phase is one pass over all `story_steps`. Data generation phases are
    repeated until no new checkpoints are reached; afterwards, for story data
    with a positive augmentation factor, augmentation phases follow.

    Args:
        story_steps: The steps to process.
        is_rule_data: Whether the steps are rules; rule data is not augmented.

    Returns:
        The finished trackers; an empty list if there are no steps.
    """
    if not story_steps:
        logger.debug(f"No {'rules' if is_rule_data else 'story blocks'} found.")
        return []

    if self.config.remove_duplicates and self.config.unique_last_num_states:
        logger.debug(
            "Generated trackers will be deduplicated "
            "based on their unique last {} states."
            "".format(self.config.unique_last_num_states)
        )
    self._mark_first_action_in_story_steps_as_unpredictable()

    active_trackers: DefaultDict[Text, List[TrackerWithCachedStates]] = defaultdict(
        list
    )

    # every story begins with one fresh tracker at the start checkpoint
    start_tracker = TrackerWithCachedStates(
        "",
        self.domain.slots,
        max_event_history=self.config.tracker_limit,
        domain=self.domain,
        is_rule_tracker=is_rule_data,
    )
    active_trackers[STORY_START].append(start_tracker)

    # trackers that are sent to a featurizer
    finished_trackers = []
    # keep story end trackers separately for augmentation
    story_end_trackers = []

    phase = 0  # one phase is one traversal of all story steps.

    # do not augment rule data
    if is_rule_data:
        min_num_aug_phases = 0
    else:
        min_num_aug_phases = 3 if self.config.augmentation_factor > 0 else 0
        logger.debug(f"Number of augmentation rounds is {min_num_aug_phases}")

    # placeholder to track gluing process of checkpoints
    used_checkpoints: Set[Text] = set()
    previous_unused: Set[Text] = set()
    everything_reachable_is_reached = False

    # we will continue generating data until we have reached all
    # checkpoints that seem to be reachable. This is a heuristic,
    # if we did not reach any new checkpoints in an iteration, we
    # assume we have reached all and stop.

    while not everything_reachable_is_reached or phase < min_num_aug_phases:
        phase_name = self._phase_name(everything_reachable_is_reached, phase)

        num_active_trackers = self._count_trackers(active_trackers)

        if not num_active_trackers:
            logger.debug(f"There are no trackers for {phase_name}")
            break

        logger.debug(
            "Starting {} ... (with {} trackers)"
            "".format(phase_name, num_active_trackers)
        )

        # track unused checkpoints for this phase
        unused_checkpoints: Set[Text] = set()

        desc = f"Processed {'rules' if is_rule_data else 'story blocks'}"
        progress = tqdm(story_steps, desc=desc, disable=is_logging_disabled())
        for step in progress:
            incoming_trackers: List[TrackerWithCachedStates] = []
            for start in step.start_checkpoints:
                waiting = active_trackers[start.name]
                if waiting:
                    incoming_trackers.extend(start.filter_trackers(waiting))
                    used_checkpoints.add(start.name)
                elif start.name not in used_checkpoints:
                    # need to skip - there was no previous step that
                    # had this start checkpoint as an end checkpoint
                    # it will be processed in next phases
                    unused_checkpoints.add(start.name)

            if not incoming_trackers:
                # if there are no trackers,
                # we can skip the rest of the loop
                continue

            # these are the trackers that reached this story
            # step and that need to handle all events of the step

            if self.config.remove_duplicates:
                incoming_trackers, end_trackers = self._remove_duplicate_trackers(
                    incoming_trackers
                )

                # append end trackers to finished trackers
                finished_trackers.extend(end_trackers)

            if everything_reachable_is_reached:
                # augmentation round
                incoming_trackers = self._subsample_trackers(
                    incoming_trackers, self.config.max_number_of_augmented_trackers
                )

            # update progress bar
            progress.set_postfix(
                {"# trackers": "{:d}".format(len(incoming_trackers))}
            )

            trackers, end_trackers = self._process_step(step, incoming_trackers)

            # add end trackers to finished trackers
            finished_trackers.extend(end_trackers)

            # update our tracker dictionary with the trackers
            # that handled the events of the step and
            # that can now be used for further story steps
            # that start with the checkpoint this step ended with

            for end in step.end_checkpoints:
                start_name = self._find_start_checkpoint_name(end.name)

                active_trackers[start_name].extend(trackers)

                if start_name in used_checkpoints:
                    # add end checkpoint as unused
                    # if this checkpoint was processed as
                    # start one before
                    unused_checkpoints.add(start_name)

            if not step.end_checkpoints:
                # the step ends a story
                unique_ends = self._remove_duplicate_story_end_trackers(trackers)
                story_end_trackers.extend(unique_ends)

        num_finished = len(finished_trackers) + len(story_end_trackers)
        logger.debug(f"Finished phase ({num_finished} training samples found).")

        # prepare next round
        phase += 1

        if not everything_reachable_is_reached:
            # check if we reached all nodes that can be reached
            # if we reached at least one more node this round
            # than last one, we assume there is still
            # something left to reach and we continue

            unused_checkpoints = self._add_unused_end_checkpoints(
                set(active_trackers), unused_checkpoints, used_checkpoints
            )
            active_trackers = self._filter_active_trackers(
                active_trackers, unused_checkpoints
            )
            num_active_trackers = self._count_trackers(active_trackers)

            everything_reachable_is_reached = (
                unused_checkpoints == previous_unused or num_active_trackers == 0
            )
            previous_unused = unused_checkpoints

            if everything_reachable_is_reached:
                # should happen only once

                # in-place on purpose: `previous_unused` and
                # `unused_checkpoints` are the same set object here
                previous_unused -= used_checkpoints
                # add trackers with unused checkpoints
                # to finished trackers
                for start_name in previous_unused:
                    finished_trackers.extend(active_trackers[start_name])

                logger.debug("Data generation rounds finished.")
                logger.debug(
                    "Found {} unused checkpoints".format(len(previous_unused))
                )
                # the phase counter restarts for the augmentation rounds
                phase = 0
            else:
                logger.debug(
                    "Found {} unused checkpoints "
                    "in current phase."
                    "".format(len(unused_checkpoints))
                )
                logger.debug(
                    "Found {} active trackers "
                    "for these checkpoints."
                    "".format(num_active_trackers)
                )

        if everything_reachable_is_reached:
            # augmentation round, so we process only
            # story end checkpoints
            # reset used checkpoints
            used_checkpoints = set()

            # generate active trackers for augmentation
            active_trackers = self._create_start_trackers_for_augmentation(
                story_end_trackers
            )

    finished_trackers.extend(story_end_trackers)
    self._issue_unused_checkpoint_notification(previous_unused)
    logger.debug("Found {} training trackers.".format(len(finished_trackers)))

    if self.config.augmentation_factor > 0:
        # only the augmented trackers are subsampled; originals are all kept
        original_trackers = [t for t in finished_trackers if not t.is_augmented]
        augmented_trackers = [t for t in finished_trackers if t.is_augmented]
        augmented_trackers = self._subsample_trackers(
            augmented_trackers, self.config.max_number_of_augmented_trackers
        )
        logger.debug(
            "Subsampled to {} augmented training trackers."
            "".format(len(augmented_trackers))
        )
        logger.debug(
            "There are {} original trackers.".format(len(original_trackers))
        )
        finished_trackers = original_trackers + augmented_trackers

    return finished_trackers
564
+
565
@staticmethod
def _count_trackers(active_trackers: TrackerLookupDict) -> int:
    """Count the number of trackers in the tracker dictionary."""
    total = 0
    for trackers_at_checkpoint in active_trackers.values():
        total += len(trackers_at_checkpoint)
    return total
569
+
570
def _subsample_trackers(
    self,
    incoming_trackers: List[TrackerWithCachedStates],
    max_number_of_trackers: int,
) -> List[TrackerWithCachedStates]:
    """Subsample the list of trackers to retrieve a random subset.

    Args:
        incoming_trackers: The trackers to sample from.
        max_number_of_trackers: Upper bound passed to the sampler; when it
            is `None` the input list is returned as is.
    """
    # if flows get very long and have a lot of forks we
    # get into trouble by collecting too many trackers
    # hence the sub sampling
    if max_number_of_trackers is None:
        return incoming_trackers

    return _subsample_array(
        incoming_trackers, max_number_of_trackers, rand=self.config.rand
    )
585
+
586
+ def _find_start_checkpoint_name(self, end_name: Text) -> Text:
587
+ """Find start checkpoint name given end checkpoint name of a cycle."""
588
+ return self.story_graph.story_end_checkpoints.get(end_name, end_name)
589
+
590
+ @staticmethod
591
+ def _add_unused_end_checkpoints(
592
+ start_checkpoints: Set[Text],
593
+ unused_checkpoints: Set[Text],
594
+ used_checkpoints: Set[Text],
595
+ ) -> Set[Text]:
596
+ """Add unused end checkpoints
597
+ if they were never encountered as start checkpoints.
598
+ """
599
+ return unused_checkpoints.union(
600
+ {
601
+ start_name
602
+ for start_name in start_checkpoints
603
+ if start_name not in used_checkpoints
604
+ }
605
+ )
606
+
607
@staticmethod
def _filter_active_trackers(
    active_trackers: TrackerLookupDict, unused_checkpoints: Set[Text]
) -> TrackerLookupDict:
    """Keep only the trackers waiting at unused checkpoints.

    Trackers that ended with an unused checkpoint or are parts of loops are
    carried over; the `STORY_START` checkpoint is never carried over.
    """
    carried_over = defaultdict(list)

    for checkpoint_name in unused_checkpoints:
        if checkpoint_name == STORY_START:
            # there is no point to process STORY_START checkpoint again
            continue
        # process trackers ended with unused checkpoints further
        carried_over[checkpoint_name] = active_trackers.get(checkpoint_name, [])

    return carried_over
623
+
624
def _create_start_trackers_for_augmentation(
    self, story_end_trackers: List[TrackerWithCachedStates]
) -> TrackerLookupDict:
    """This is where the augmentation magic happens.

    We will reuse all the trackers that reached the
    end checkpoint `None` (which is the end of a
    story) and start processing all steps again. So instead
    of starting with a fresh tracker, the second and
    all following phases will reuse a couple of the trackers
    that made their way to a story end.

    We need to do some cleanup before processing them again.

    Returns:
        The new start trackers, registered under `STORY_START`. The
        mapping is empty when story concatenation is disabled.
    """
    start_trackers = defaultdict(list)

    if not self.config.use_story_concatenation:
        return start_trackers

    sampled_ends = _subsample_array(
        story_end_trackers,
        self.config.augmentation_factor,
        rand=self.config.rand,
    )
    for finished in sampled_ends:
        # this is a nasty thing - all stories end and
        # start with action listen - so after logging the first
        # actions in the next phase the trackers would
        # contain action listen followed by action listen.
        # to fix this we are going to "undo" the last action listen

        # tracker should be copied,
        # otherwise original tracker is updated
        continued = finished.copy()
        continued.is_augmented = True
        continued.update(ActionReverted())
        start_trackers[STORY_START].append(continued)

    return start_trackers
661
+
662
def _process_step(
    self, step: StoryStep, incoming_trackers: List[TrackerWithCachedStates]
) -> TrackersTuple:
    """Processes a steps events with all trackers.

    The trackers that reached the steps starting checkpoint will
    be used to process the events. Collects and returns training
    data while processing the story step.

    Returns:
        A tuple of the copied trackers after all events were applied and
        the snapshots taken right before a reverting event was applied.
    """
    events = step.explicit_events(self.domain)

    trackers = []
    if events:  # small optimization
        # need to copy the tracker as multiple story steps
        # might start with the same checkpoint and all of them
        # will use the same set of incoming trackers

        for incoming in incoming_trackers:
            # sender id is used to be able for a human to see where the
            # messages and events for this tracker came from - to do this
            # we concatenate the story block names of the blocks that
            # contribute to the trackers events
            previous_sender = incoming.sender_id
            if not previous_sender:
                new_sender = step.block_name
            elif (
                step.block_name
                and step.block_name not in previous_sender.split(" > ")
            ):
                new_sender = previous_sender + " > " + step.block_name
            else:
                new_sender = previous_sender
            trackers.append(incoming.copy(new_sender, step.source_name))

    history_rewinding_events = (ActionReverted, UserUtteranceReverted, Restarted)
    is_rule_step = isinstance(step, RuleStep)

    end_trackers = []
    for event in events:
        has_unknown_bot_utterance = (
            isinstance(event, ActionExecuted)
            and event.action_text
            and event.action_text not in self.domain.action_texts
        )
        if has_unknown_bot_utterance:
            rasa.shared.utils.cli.print_warning(
                f"Test story '{step.block_name}' in "
                f"'{step.source_name}' contains the bot utterance "
                f"'{event.action_text}', which is not part "
                f"of the training data / domain."
            )

        rewinds_history = isinstance(event, history_rewinding_events)
        for tracker in trackers:
            if rewinds_history:
                # keep a snapshot from before the event is applied
                end_trackers.append(tracker.copy(tracker.sender_id))
            if is_rule_step:
                # The rules can specify that a form or a slot shouldn't be set,
                # therefore we need to distinguish between not set
                # and explicitly set to None
                if isinstance(event, ActiveLoop) and event.name is None:
                    event.name = SHOULD_NOT_BE_SET

                if isinstance(event, SlotSet) and event.value is None:
                    event.value = SHOULD_NOT_BE_SET

            tracker.update(event)

    # end trackers should be returned separately
    # to avoid using them for augmentation
    return trackers, end_trackers
730
+
731
def _remove_duplicate_trackers(
    self, trackers: List[TrackerWithCachedStates]
) -> TrackersTuple:
    """Removes trackers that create equal featurizations for current story step.

    From multiple trackers that create equal featurizations
    we only need to keep one. Because as we continue processing
    events and story steps, all trackers that created the
    same featurization once will do so in the future (as we
    feed the same events to all trackers).

    Returns:
        A tuple of the trackers to continue with for the current step and
        the trackers that are set aside as end trackers.
    """
    seen_in_step = set()

    # collected trackers that created different featurizations
    unique_trackers = []  # for current step
    end_trackers = []  # for all steps

    num_last_states = self.config.unique_last_num_states

    for tracker in trackers:
        frozen_states = tuple(tracker.past_states_for_hashing(self.domain))
        full_hash = hash(frozen_states)

        # only continue with trackers that created a
        # hashed_featurization we haven't observed
        if full_hash in seen_in_step:
            continue

        if not num_last_states:
            unique_trackers.append(tracker)
        else:
            recent_states = frozen_states[-num_last_states:]
            recent_hash = hash(recent_states)

            if recent_hash not in seen_in_step:
                seen_in_step.add(recent_hash)
                unique_trackers.append(tracker)
            elif (
                len(frozen_states) > len(recent_states)
                and full_hash not in self.hashed_featurizations
            ):
                # same recent states as another tracker, but a longer
                # history that was not recorded yet
                self.hashed_featurizations.add(full_hash)
                end_trackers.append(tracker)

        seen_in_step.add(full_hash)

    return unique_trackers, end_trackers
777
+
778
+ def _remove_duplicate_story_end_trackers(
779
+ self, trackers: List[TrackerWithCachedStates]
780
+ ) -> List[TrackerWithCachedStates]:
781
+ """Removes trackers that reached story end and
782
+ created equal featurizations.
783
+ """
784
+ # collected trackers that created different featurizations
785
+ unique_trackers = [] # for all steps
786
+
787
+ # deduplication of finished trackers is needed,
788
+ # otherwise featurization does a lot of unnecessary work
789
+
790
+ for tracker in trackers:
791
+ states_for_hashing = tuple(tracker.past_states_for_hashing(self.domain))
792
+ hashed = hash(states_for_hashing + (tracker.is_rule_tracker,))
793
+
794
+ # only continue with trackers that created a
795
+ # hashed_featurization we haven't observed
796
+
797
+ if hashed not in self.hashed_featurizations:
798
+ self.hashed_featurizations.add(hashed)
799
+ unique_trackers.append(tracker)
800
+
801
+ return unique_trackers
802
+
803
def _mark_first_action_in_story_steps_as_unpredictable(self) -> None:
    """Flag leading actions that must not serve as ML training examples.

    An action that opens a story has no history to predict it from, so it
    is marked ``unpredictable``. Action listen is the exception: it should
    be predicted, but stories never contain it explicitly (it is added when
    a story is converted to a dialogue). A user utterance therefore implies
    a preceding action listen, and any action after the first user
    utterance is left untouched.
    """
    for step in self.story_graph.story_steps:
        # TODO: this does not work if a step is the conversational start
        #       as well as an intermediary part of a conversation.
        #       This means a checkpoint can either have multiple
        #       checkpoints OR be the start of a conversation
        #       but not both.
        start_names = {checkpoint.name for checkpoint in step.start_checkpoints}
        if STORY_START not in start_names:
            continue

        # Only the first user utterance or action matters: whichever comes
        # first decides whether anything gets marked.
        first_relevant = next(
            (
                event
                for event in step.events
                if isinstance(event, (UserUttered, ActionExecuted))
            ),
            None,
        )

        # A user utterance first means an action listen preceded it, so the
        # tracker is no longer empty and later actions are predictable.
        if first_relevant is not None and not isinstance(
            first_relevant, UserUttered
        ):
            first_relevant.unpredictable = True
835
+
836
def _issue_unused_checkpoint_notification(
    self, unused_checkpoints: Set[Text]
) -> None:
    """Emit warnings for story blocks with unsatisfied checkpoints.

    A block is reported when one of its start or end checkpoints is listed
    in ``unused_checkpoints``. Checkpoints whose name begins with
    ``GENERATED_CHECKPOINT_PREFIX`` are not reported.

    Args:
        unused_checkpoints: Names of checkpoints that no one provided.
    """
    if STORY_START in unused_checkpoints:
        rasa.shared.utils.io.raise_warning(
            "There is no starting story block in the training data. "
            "All your story blocks start with some checkpoint. "
            "There should be at least one story block "
            "that starts without any checkpoint.",
            docs=DOCS_URL_STORIES + "#stories",
        )

    # Gathering (checkpoint, block) pairs in sets first yields a single
    # warning per block, even when a block consists of several steps.
    dangling_starts = set()
    dangling_ends = set()
    for step in self.story_graph.story_steps:
        # a leftover start checkpoint means nothing ends with it
        dangling_starts.update(
            (checkpoint.name, step.block_name)
            for checkpoint in step.start_checkpoints
            if checkpoint.name in unused_checkpoints
        )
        # a leftover end checkpoint means nothing starts with it
        dangling_ends.update(
            (checkpoint.name, step.block_name)
            for checkpoint in step.end_checkpoints
            if checkpoint.name in unused_checkpoints
        )

    for checkpoint_name, block in dangling_starts:
        if checkpoint_name.startswith(GENERATED_CHECKPOINT_PREFIX):
            continue
        rasa.shared.utils.io.raise_warning(
            f"Unsatisfied start checkpoint '{checkpoint_name}' "
            f"in block '{block}'. "
            f"Remove this checkpoint or add story blocks that end "
            f"with this checkpoint.",
            docs=DOCS_URL_STORIES + "#checkpoints",
        )

    for checkpoint_name, block in dangling_ends:
        if checkpoint_name.startswith(GENERATED_CHECKPOINT_PREFIX):
            continue
        rasa.shared.utils.io.raise_warning(
            f"Unsatisfied end checkpoint '{checkpoint_name}' "
            f"in block '{block}'. "
            f"Remove this checkpoint or add story blocks that start "
            f"with this checkpoint.",
            docs=DOCS_URL_STORIES + "#checkpoints",
        )
892
+
893
+
894
+ def _subsample_array(
895
+ arr: List[Any],
896
+ max_values: int,
897
+ can_modify_incoming_array: bool = True,
898
+ rand: Optional[random.Random] = None,
899
+ ) -> List[Any]:
900
+ """Shuffles the array and returns `max_values` number of elements."""
901
+ if not can_modify_incoming_array:
902
+ arr = arr[:]
903
+ if rand is not None:
904
+ rand.shuffle(arr)
905
+ else:
906
+ random.shuffle(arr)
907
+ return arr[:max_values]