rasa-pro 3.8.16__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. The information in this diff is provided for informational purposes only and reflects the changes between package versions.

Potentially problematic release.


This version of rasa-pro might be problematic; see the release details below for more information.

Files changed (644)
  1. README.md +380 -0
  2. rasa/__init__.py +10 -0
  3. rasa/__main__.py +151 -0
  4. rasa/anonymization/__init__.py +2 -0
  5. rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
  6. rasa/anonymization/anonymization_pipeline.py +287 -0
  7. rasa/anonymization/anonymization_rule_executor.py +260 -0
  8. rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
  9. rasa/anonymization/schemas/config.yml +47 -0
  10. rasa/anonymization/utils.py +117 -0
  11. rasa/api.py +146 -0
  12. rasa/cli/__init__.py +5 -0
  13. rasa/cli/arguments/__init__.py +0 -0
  14. rasa/cli/arguments/data.py +81 -0
  15. rasa/cli/arguments/default_arguments.py +165 -0
  16. rasa/cli/arguments/evaluate.py +65 -0
  17. rasa/cli/arguments/export.py +51 -0
  18. rasa/cli/arguments/interactive.py +74 -0
  19. rasa/cli/arguments/run.py +204 -0
  20. rasa/cli/arguments/shell.py +13 -0
  21. rasa/cli/arguments/test.py +211 -0
  22. rasa/cli/arguments/train.py +263 -0
  23. rasa/cli/arguments/visualize.py +34 -0
  24. rasa/cli/arguments/x.py +30 -0
  25. rasa/cli/data.py +292 -0
  26. rasa/cli/e2e_test.py +566 -0
  27. rasa/cli/evaluate.py +222 -0
  28. rasa/cli/export.py +251 -0
  29. rasa/cli/inspect.py +63 -0
  30. rasa/cli/interactive.py +164 -0
  31. rasa/cli/license.py +65 -0
  32. rasa/cli/markers.py +78 -0
  33. rasa/cli/project_templates/__init__.py +0 -0
  34. rasa/cli/project_templates/calm/actions/__init__.py +0 -0
  35. rasa/cli/project_templates/calm/actions/action_template.py +27 -0
  36. rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
  37. rasa/cli/project_templates/calm/actions/db.py +57 -0
  38. rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
  39. rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
  40. rasa/cli/project_templates/calm/config.yml +12 -0
  41. rasa/cli/project_templates/calm/credentials.yml +33 -0
  42. rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
  43. rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
  44. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
  45. rasa/cli/project_templates/calm/db/contacts.json +10 -0
  46. rasa/cli/project_templates/calm/domain/add_contact.yml +33 -0
  47. rasa/cli/project_templates/calm/domain/list_contacts.yml +14 -0
  48. rasa/cli/project_templates/calm/domain/remove_contact.yml +31 -0
  49. rasa/cli/project_templates/calm/domain/shared.yml +5 -0
  50. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
  51. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
  52. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
  53. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
  54. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
  55. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
  56. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
  57. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
  58. rasa/cli/project_templates/calm/endpoints.yml +45 -0
  59. rasa/cli/project_templates/default/actions/__init__.py +0 -0
  60. rasa/cli/project_templates/default/actions/actions.py +27 -0
  61. rasa/cli/project_templates/default/config.yml +44 -0
  62. rasa/cli/project_templates/default/credentials.yml +33 -0
  63. rasa/cli/project_templates/default/data/nlu.yml +91 -0
  64. rasa/cli/project_templates/default/data/rules.yml +13 -0
  65. rasa/cli/project_templates/default/data/stories.yml +30 -0
  66. rasa/cli/project_templates/default/domain.yml +34 -0
  67. rasa/cli/project_templates/default/endpoints.yml +42 -0
  68. rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
  69. rasa/cli/project_templates/tutorial/actions.py +22 -0
  70. rasa/cli/project_templates/tutorial/config.yml +11 -0
  71. rasa/cli/project_templates/tutorial/credentials.yml +33 -0
  72. rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
  73. rasa/cli/project_templates/tutorial/domain.yml +17 -0
  74. rasa/cli/project_templates/tutorial/endpoints.yml +45 -0
  75. rasa/cli/run.py +136 -0
  76. rasa/cli/scaffold.py +268 -0
  77. rasa/cli/shell.py +141 -0
  78. rasa/cli/studio/__init__.py +0 -0
  79. rasa/cli/studio/download.py +51 -0
  80. rasa/cli/studio/studio.py +110 -0
  81. rasa/cli/studio/train.py +59 -0
  82. rasa/cli/studio/upload.py +85 -0
  83. rasa/cli/telemetry.py +90 -0
  84. rasa/cli/test.py +280 -0
  85. rasa/cli/train.py +260 -0
  86. rasa/cli/utils.py +453 -0
  87. rasa/cli/visualize.py +40 -0
  88. rasa/cli/x.py +205 -0
  89. rasa/constants.py +37 -0
  90. rasa/core/__init__.py +17 -0
  91. rasa/core/actions/__init__.py +0 -0
  92. rasa/core/actions/action.py +1450 -0
  93. rasa/core/actions/action_clean_stack.py +59 -0
  94. rasa/core/actions/action_run_slot_rejections.py +207 -0
  95. rasa/core/actions/action_trigger_chitchat.py +31 -0
  96. rasa/core/actions/action_trigger_flow.py +109 -0
  97. rasa/core/actions/action_trigger_search.py +31 -0
  98. rasa/core/actions/constants.py +2 -0
  99. rasa/core/actions/forms.py +737 -0
  100. rasa/core/actions/loops.py +111 -0
  101. rasa/core/actions/two_stage_fallback.py +186 -0
  102. rasa/core/agent.py +557 -0
  103. rasa/core/auth_retry_tracker_store.py +122 -0
  104. rasa/core/brokers/__init__.py +0 -0
  105. rasa/core/brokers/broker.py +126 -0
  106. rasa/core/brokers/file.py +58 -0
  107. rasa/core/brokers/kafka.py +322 -0
  108. rasa/core/brokers/pika.py +387 -0
  109. rasa/core/brokers/sql.py +86 -0
  110. rasa/core/channels/__init__.py +55 -0
  111. rasa/core/channels/audiocodes.py +463 -0
  112. rasa/core/channels/botframework.py +339 -0
  113. rasa/core/channels/callback.py +85 -0
  114. rasa/core/channels/channel.py +419 -0
  115. rasa/core/channels/console.py +243 -0
  116. rasa/core/channels/development_inspector.py +93 -0
  117. rasa/core/channels/facebook.py +422 -0
  118. rasa/core/channels/hangouts.py +335 -0
  119. rasa/core/channels/inspector/.eslintrc.cjs +25 -0
  120. rasa/core/channels/inspector/.gitignore +23 -0
  121. rasa/core/channels/inspector/README.md +54 -0
  122. rasa/core/channels/inspector/assets/favicon.ico +0 -0
  123. rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
  124. rasa/core/channels/inspector/custom.d.ts +3 -0
  125. rasa/core/channels/inspector/dist/assets/arc-5623b6dc.js +1 -0
  126. rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
  127. rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-685c106a.js +10 -0
  128. rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-8cbed007.js +2 -0
  129. rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-5889cf12.js +2 -0
  130. rasa/core/channels/inspector/dist/assets/createText-62fc7601-24c249d7.js +7 -0
  131. rasa/core/channels/inspector/dist/assets/edges-f2ad444c-7dd06a75.js +4 -0
  132. rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-62c1e54c.js +51 -0
  133. rasa/core/channels/inspector/dist/assets/flowDb-1972c806-ce49b86f.js +6 -0
  134. rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-4067e48f.js +4 -0
  135. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-85583a23.js +1 -0
  136. rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-59fe4051.js +139 -0
  137. rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-47e3a43b.js +266 -0
  138. rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-5a2ac0d9.js +70 -0
  139. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
  140. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
  141. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
  142. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
  143. rasa/core/channels/inspector/dist/assets/index-268a75c0.js +1040 -0
  144. rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-dfb8efc4.js +1 -0
  145. rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
  146. rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-b0c470f2.js +7 -0
  147. rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
  148. rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-2edb829a.js +139 -0
  149. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
  150. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
  151. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
  152. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
  153. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
  154. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
  155. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
  156. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
  157. rasa/core/channels/inspector/dist/assets/layout-b6873d69.js +1 -0
  158. rasa/core/channels/inspector/dist/assets/line-1efc5781.js +1 -0
  159. rasa/core/channels/inspector/dist/assets/linear-661e9b94.js +1 -0
  160. rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-2d2e727f.js +109 -0
  161. rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
  162. rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
  163. rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-9d3ea93d.js +35 -0
  164. rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-06a178a2.js +7 -0
  165. rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-0bfedffc.js +52 -0
  166. rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-d76d0a04.js +8 -0
  167. rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-37bb4341.js +122 -0
  168. rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-f52f7f57.js +1 -0
  169. rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-4a986a20.js +1 -0
  170. rasa/core/channels/inspector/dist/assets/styles-080da4f6-7dd9ae12.js +110 -0
  171. rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-46e1ca14.js +159 -0
  172. rasa/core/channels/inspector/dist/assets/styles-9c745c82-4a97439a.js +207 -0
  173. rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-823917a3.js +1 -0
  174. rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-9ea72896.js +61 -0
  175. rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-b631a8b6.js +7 -0
  176. rasa/core/channels/inspector/dist/index.html +39 -0
  177. rasa/core/channels/inspector/index.html +37 -0
  178. rasa/core/channels/inspector/jest.config.ts +13 -0
  179. rasa/core/channels/inspector/package.json +48 -0
  180. rasa/core/channels/inspector/setupTests.ts +2 -0
  181. rasa/core/channels/inspector/src/App.tsx +170 -0
  182. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +97 -0
  183. rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
  184. rasa/core/channels/inspector/src/components/DialogueStack.tsx +151 -0
  185. rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
  186. rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
  187. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +19 -0
  188. rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
  189. rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
  190. rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
  191. rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
  192. rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
  193. rasa/core/channels/inspector/src/helpers/formatters.test.ts +385 -0
  194. rasa/core/channels/inspector/src/helpers/formatters.ts +239 -0
  195. rasa/core/channels/inspector/src/helpers/utils.ts +42 -0
  196. rasa/core/channels/inspector/src/main.tsx +13 -0
  197. rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
  198. rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
  199. rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
  200. rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
  201. rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
  202. rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
  203. rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
  204. rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
  205. rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
  206. rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
  207. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
  208. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
  209. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
  210. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
  211. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
  212. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
  213. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
  214. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
  215. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
  216. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
  217. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
  218. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
  219. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
  220. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
  221. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
  222. rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
  223. rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
  224. rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
  225. rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
  226. rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
  227. rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
  228. rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
  229. rasa/core/channels/inspector/src/theme/index.ts +101 -0
  230. rasa/core/channels/inspector/src/types.ts +64 -0
  231. rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
  232. rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
  233. rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
  234. rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
  235. rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
  236. rasa/core/channels/inspector/tsconfig.json +26 -0
  237. rasa/core/channels/inspector/tsconfig.node.json +10 -0
  238. rasa/core/channels/inspector/vite.config.ts +8 -0
  239. rasa/core/channels/inspector/yarn.lock +6156 -0
  240. rasa/core/channels/mattermost.py +229 -0
  241. rasa/core/channels/rasa_chat.py +126 -0
  242. rasa/core/channels/rest.py +210 -0
  243. rasa/core/channels/rocketchat.py +175 -0
  244. rasa/core/channels/slack.py +620 -0
  245. rasa/core/channels/socketio.py +274 -0
  246. rasa/core/channels/telegram.py +298 -0
  247. rasa/core/channels/twilio.py +169 -0
  248. rasa/core/channels/twilio_voice.py +367 -0
  249. rasa/core/channels/vier_cvg.py +374 -0
  250. rasa/core/channels/webexteams.py +135 -0
  251. rasa/core/concurrent_lock_store.py +210 -0
  252. rasa/core/constants.py +107 -0
  253. rasa/core/evaluation/__init__.py +0 -0
  254. rasa/core/evaluation/marker.py +267 -0
  255. rasa/core/evaluation/marker_base.py +925 -0
  256. rasa/core/evaluation/marker_stats.py +294 -0
  257. rasa/core/evaluation/marker_tracker_loader.py +103 -0
  258. rasa/core/exceptions.py +29 -0
  259. rasa/core/exporter.py +284 -0
  260. rasa/core/featurizers/__init__.py +0 -0
  261. rasa/core/featurizers/precomputation.py +410 -0
  262. rasa/core/featurizers/single_state_featurizer.py +402 -0
  263. rasa/core/featurizers/tracker_featurizers.py +1172 -0
  264. rasa/core/http_interpreter.py +89 -0
  265. rasa/core/information_retrieval/__init__.py +0 -0
  266. rasa/core/information_retrieval/faiss.py +116 -0
  267. rasa/core/information_retrieval/information_retrieval.py +72 -0
  268. rasa/core/information_retrieval/milvus.py +59 -0
  269. rasa/core/information_retrieval/qdrant.py +102 -0
  270. rasa/core/jobs.py +63 -0
  271. rasa/core/lock.py +139 -0
  272. rasa/core/lock_store.py +344 -0
  273. rasa/core/migrate.py +404 -0
  274. rasa/core/nlg/__init__.py +3 -0
  275. rasa/core/nlg/callback.py +147 -0
  276. rasa/core/nlg/contextual_response_rephraser.py +270 -0
  277. rasa/core/nlg/generator.py +230 -0
  278. rasa/core/nlg/interpolator.py +143 -0
  279. rasa/core/nlg/response.py +155 -0
  280. rasa/core/nlg/summarize.py +69 -0
  281. rasa/core/policies/__init__.py +0 -0
  282. rasa/core/policies/ensemble.py +329 -0
  283. rasa/core/policies/enterprise_search_policy.py +717 -0
  284. rasa/core/policies/enterprise_search_prompt_template.jinja2 +62 -0
  285. rasa/core/policies/flow_policy.py +205 -0
  286. rasa/core/policies/flows/__init__.py +0 -0
  287. rasa/core/policies/flows/flow_exceptions.py +44 -0
  288. rasa/core/policies/flows/flow_executor.py +582 -0
  289. rasa/core/policies/flows/flow_step_result.py +43 -0
  290. rasa/core/policies/intentless_policy.py +924 -0
  291. rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
  292. rasa/core/policies/memoization.py +538 -0
  293. rasa/core/policies/policy.py +716 -0
  294. rasa/core/policies/rule_policy.py +1276 -0
  295. rasa/core/policies/ted_policy.py +2146 -0
  296. rasa/core/policies/unexpected_intent_policy.py +1015 -0
  297. rasa/core/processor.py +1331 -0
  298. rasa/core/run.py +315 -0
  299. rasa/core/secrets_manager/__init__.py +0 -0
  300. rasa/core/secrets_manager/constants.py +32 -0
  301. rasa/core/secrets_manager/endpoints.py +391 -0
  302. rasa/core/secrets_manager/factory.py +233 -0
  303. rasa/core/secrets_manager/secret_manager.py +262 -0
  304. rasa/core/secrets_manager/vault.py +576 -0
  305. rasa/core/test.py +1337 -0
  306. rasa/core/tracker_store.py +1664 -0
  307. rasa/core/train.py +107 -0
  308. rasa/core/training/__init__.py +89 -0
  309. rasa/core/training/converters/__init__.py +0 -0
  310. rasa/core/training/converters/responses_prefix_converter.py +119 -0
  311. rasa/core/training/interactive.py +1742 -0
  312. rasa/core/training/story_conflict.py +381 -0
  313. rasa/core/training/training.py +93 -0
  314. rasa/core/utils.py +344 -0
  315. rasa/core/visualize.py +70 -0
  316. rasa/dialogue_understanding/__init__.py +0 -0
  317. rasa/dialogue_understanding/coexistence/__init__.py +0 -0
  318. rasa/dialogue_understanding/coexistence/constants.py +4 -0
  319. rasa/dialogue_understanding/coexistence/intent_based_router.py +189 -0
  320. rasa/dialogue_understanding/coexistence/llm_based_router.py +261 -0
  321. rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
  322. rasa/dialogue_understanding/commands/__init__.py +45 -0
  323. rasa/dialogue_understanding/commands/can_not_handle_command.py +61 -0
  324. rasa/dialogue_understanding/commands/cancel_flow_command.py +116 -0
  325. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +48 -0
  326. rasa/dialogue_understanding/commands/clarify_command.py +77 -0
  327. rasa/dialogue_understanding/commands/command.py +85 -0
  328. rasa/dialogue_understanding/commands/correct_slots_command.py +288 -0
  329. rasa/dialogue_understanding/commands/error_command.py +67 -0
  330. rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
  331. rasa/dialogue_understanding/commands/handle_code_change_command.py +64 -0
  332. rasa/dialogue_understanding/commands/human_handoff_command.py +57 -0
  333. rasa/dialogue_understanding/commands/knowledge_answer_command.py +48 -0
  334. rasa/dialogue_understanding/commands/noop_command.py +45 -0
  335. rasa/dialogue_understanding/commands/set_slot_command.py +125 -0
  336. rasa/dialogue_understanding/commands/skip_question_command.py +66 -0
  337. rasa/dialogue_understanding/commands/start_flow_command.py +98 -0
  338. rasa/dialogue_understanding/generator/__init__.py +6 -0
  339. rasa/dialogue_understanding/generator/command_generator.py +257 -0
  340. rasa/dialogue_understanding/generator/command_prompt_template.jinja2 +57 -0
  341. rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
  342. rasa/dialogue_understanding/generator/flow_retrieval.py +410 -0
  343. rasa/dialogue_understanding/generator/llm_command_generator.py +637 -0
  344. rasa/dialogue_understanding/generator/nlu_command_adapter.py +157 -0
  345. rasa/dialogue_understanding/patterns/__init__.py +0 -0
  346. rasa/dialogue_understanding/patterns/cancel.py +111 -0
  347. rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
  348. rasa/dialogue_understanding/patterns/chitchat.py +37 -0
  349. rasa/dialogue_understanding/patterns/clarify.py +97 -0
  350. rasa/dialogue_understanding/patterns/code_change.py +41 -0
  351. rasa/dialogue_understanding/patterns/collect_information.py +90 -0
  352. rasa/dialogue_understanding/patterns/completed.py +40 -0
  353. rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
  354. rasa/dialogue_understanding/patterns/correction.py +278 -0
  355. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +243 -0
  356. rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
  357. rasa/dialogue_understanding/patterns/internal_error.py +47 -0
  358. rasa/dialogue_understanding/patterns/search.py +37 -0
  359. rasa/dialogue_understanding/patterns/skip_question.py +38 -0
  360. rasa/dialogue_understanding/processor/__init__.py +0 -0
  361. rasa/dialogue_understanding/processor/command_processor.py +578 -0
  362. rasa/dialogue_understanding/processor/command_processor_component.py +39 -0
  363. rasa/dialogue_understanding/stack/__init__.py +0 -0
  364. rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
  365. rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
  366. rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
  367. rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
  368. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
  369. rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
  370. rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
  371. rasa/dialogue_understanding/stack/utils.py +211 -0
  372. rasa/e2e_test/__init__.py +0 -0
  373. rasa/e2e_test/constants.py +10 -0
  374. rasa/e2e_test/e2e_test_case.py +322 -0
  375. rasa/e2e_test/e2e_test_result.py +34 -0
  376. rasa/e2e_test/e2e_test_runner.py +659 -0
  377. rasa/e2e_test/e2e_test_schema.yml +67 -0
  378. rasa/engine/__init__.py +0 -0
  379. rasa/engine/caching.py +464 -0
  380. rasa/engine/constants.py +17 -0
  381. rasa/engine/exceptions.py +14 -0
  382. rasa/engine/graph.py +625 -0
  383. rasa/engine/loader.py +36 -0
  384. rasa/engine/recipes/__init__.py +0 -0
  385. rasa/engine/recipes/config_files/default_config.yml +44 -0
  386. rasa/engine/recipes/default_components.py +99 -0
  387. rasa/engine/recipes/default_recipe.py +1252 -0
  388. rasa/engine/recipes/graph_recipe.py +79 -0
  389. rasa/engine/recipes/recipe.py +93 -0
  390. rasa/engine/runner/__init__.py +0 -0
  391. rasa/engine/runner/dask.py +256 -0
  392. rasa/engine/runner/interface.py +49 -0
  393. rasa/engine/storage/__init__.py +0 -0
  394. rasa/engine/storage/local_model_storage.py +248 -0
  395. rasa/engine/storage/resource.py +110 -0
  396. rasa/engine/storage/storage.py +203 -0
  397. rasa/engine/training/__init__.py +0 -0
  398. rasa/engine/training/components.py +176 -0
  399. rasa/engine/training/fingerprinting.py +64 -0
  400. rasa/engine/training/graph_trainer.py +256 -0
  401. rasa/engine/training/hooks.py +164 -0
  402. rasa/engine/validation.py +839 -0
  403. rasa/env.py +5 -0
  404. rasa/exceptions.py +69 -0
  405. rasa/graph_components/__init__.py +0 -0
  406. rasa/graph_components/converters/__init__.py +0 -0
  407. rasa/graph_components/converters/nlu_message_converter.py +48 -0
  408. rasa/graph_components/providers/__init__.py +0 -0
  409. rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
  410. rasa/graph_components/providers/domain_provider.py +71 -0
  411. rasa/graph_components/providers/flows_provider.py +74 -0
  412. rasa/graph_components/providers/forms_provider.py +44 -0
  413. rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
  414. rasa/graph_components/providers/responses_provider.py +44 -0
  415. rasa/graph_components/providers/rule_only_provider.py +49 -0
  416. rasa/graph_components/providers/story_graph_provider.py +43 -0
  417. rasa/graph_components/providers/training_tracker_provider.py +55 -0
  418. rasa/graph_components/validators/__init__.py +0 -0
  419. rasa/graph_components/validators/default_recipe_validator.py +552 -0
  420. rasa/graph_components/validators/finetuning_validator.py +302 -0
  421. rasa/hooks.py +113 -0
  422. rasa/jupyter.py +63 -0
  423. rasa/keys +1 -0
  424. rasa/markers/__init__.py +0 -0
  425. rasa/markers/marker.py +269 -0
  426. rasa/markers/marker_base.py +828 -0
  427. rasa/markers/upload.py +74 -0
  428. rasa/markers/validate.py +21 -0
  429. rasa/model.py +118 -0
  430. rasa/model_testing.py +457 -0
  431. rasa/model_training.py +535 -0
  432. rasa/nlu/__init__.py +7 -0
  433. rasa/nlu/classifiers/__init__.py +3 -0
  434. rasa/nlu/classifiers/classifier.py +5 -0
  435. rasa/nlu/classifiers/diet_classifier.py +1874 -0
  436. rasa/nlu/classifiers/fallback_classifier.py +192 -0
  437. rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
  438. rasa/nlu/classifiers/llm_intent_classifier.py +519 -0
  439. rasa/nlu/classifiers/logistic_regression_classifier.py +240 -0
  440. rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
  441. rasa/nlu/classifiers/regex_message_handler.py +56 -0
  442. rasa/nlu/classifiers/sklearn_intent_classifier.py +309 -0
  443. rasa/nlu/constants.py +77 -0
  444. rasa/nlu/convert.py +40 -0
  445. rasa/nlu/emulators/__init__.py +0 -0
  446. rasa/nlu/emulators/dialogflow.py +55 -0
  447. rasa/nlu/emulators/emulator.py +49 -0
  448. rasa/nlu/emulators/luis.py +86 -0
  449. rasa/nlu/emulators/no_emulator.py +10 -0
  450. rasa/nlu/emulators/wit.py +56 -0
  451. rasa/nlu/extractors/__init__.py +0 -0
  452. rasa/nlu/extractors/crf_entity_extractor.py +672 -0
  453. rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
  454. rasa/nlu/extractors/entity_synonyms.py +178 -0
  455. rasa/nlu/extractors/extractor.py +470 -0
  456. rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
  457. rasa/nlu/extractors/regex_entity_extractor.py +220 -0
  458. rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
  459. rasa/nlu/featurizers/__init__.py +0 -0
  460. rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
  461. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +449 -0
  462. rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
  463. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +772 -0
  464. rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
  465. rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
  466. rasa/nlu/featurizers/featurizer.py +89 -0
  467. rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
  468. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +840 -0
  469. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +539 -0
  470. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +269 -0
  471. rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
  472. rasa/nlu/model.py +24 -0
  473. rasa/nlu/persistor.py +240 -0
  474. rasa/nlu/run.py +27 -0
  475. rasa/nlu/selectors/__init__.py +0 -0
  476. rasa/nlu/selectors/response_selector.py +990 -0
  477. rasa/nlu/test.py +1943 -0
  478. rasa/nlu/tokenizers/__init__.py +0 -0
  479. rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
  480. rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
  481. rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
  482. rasa/nlu/tokenizers/tokenizer.py +239 -0
  483. rasa/nlu/tokenizers/whitespace_tokenizer.py +106 -0
  484. rasa/nlu/utils/__init__.py +35 -0
  485. rasa/nlu/utils/bilou_utils.py +462 -0
  486. rasa/nlu/utils/hugging_face/__init__.py +0 -0
  487. rasa/nlu/utils/hugging_face/registry.py +108 -0
  488. rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
  489. rasa/nlu/utils/mitie_utils.py +113 -0
  490. rasa/nlu/utils/pattern_utils.py +168 -0
  491. rasa/nlu/utils/spacy_utils.py +312 -0
  492. rasa/plugin.py +90 -0
  493. rasa/server.py +1536 -0
  494. rasa/shared/__init__.py +0 -0
  495. rasa/shared/constants.py +181 -0
  496. rasa/shared/core/__init__.py +0 -0
  497. rasa/shared/core/constants.py +168 -0
  498. rasa/shared/core/conversation.py +46 -0
  499. rasa/shared/core/domain.py +2106 -0
  500. rasa/shared/core/events.py +2507 -0
  501. rasa/shared/core/flows/__init__.py +7 -0
  502. rasa/shared/core/flows/flow.py +353 -0
  503. rasa/shared/core/flows/flow_step.py +146 -0
  504. rasa/shared/core/flows/flow_step_links.py +319 -0
  505. rasa/shared/core/flows/flow_step_sequence.py +70 -0
  506. rasa/shared/core/flows/flows_list.py +211 -0
  507. rasa/shared/core/flows/flows_yaml_schema.json +217 -0
  508. rasa/shared/core/flows/nlu_trigger.py +117 -0
  509. rasa/shared/core/flows/steps/__init__.py +24 -0
  510. rasa/shared/core/flows/steps/action.py +51 -0
  511. rasa/shared/core/flows/steps/call.py +64 -0
  512. rasa/shared/core/flows/steps/collect.py +112 -0
  513. rasa/shared/core/flows/steps/constants.py +5 -0
  514. rasa/shared/core/flows/steps/continuation.py +36 -0
  515. rasa/shared/core/flows/steps/end.py +22 -0
  516. rasa/shared/core/flows/steps/internal.py +44 -0
  517. rasa/shared/core/flows/steps/link.py +51 -0
  518. rasa/shared/core/flows/steps/no_operation.py +48 -0
  519. rasa/shared/core/flows/steps/set_slots.py +50 -0
  520. rasa/shared/core/flows/steps/start.py +30 -0
  521. rasa/shared/core/flows/validation.py +527 -0
  522. rasa/shared/core/flows/yaml_flows_io.py +278 -0
  523. rasa/shared/core/generator.py +907 -0
  524. rasa/shared/core/slot_mappings.py +235 -0
  525. rasa/shared/core/slots.py +647 -0
  526. rasa/shared/core/trackers.py +1159 -0
  527. rasa/shared/core/training_data/__init__.py +0 -0
  528. rasa/shared/core/training_data/loading.py +90 -0
  529. rasa/shared/core/training_data/story_reader/__init__.py +0 -0
  530. rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
  531. rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
  532. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
  533. rasa/shared/core/training_data/story_writer/__init__.py +0 -0
  534. rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
  535. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +442 -0
  536. rasa/shared/core/training_data/structures.py +838 -0
  537. rasa/shared/core/training_data/visualization.html +146 -0
  538. rasa/shared/core/training_data/visualization.py +603 -0
  539. rasa/shared/data.py +192 -0
  540. rasa/shared/engine/__init__.py +0 -0
  541. rasa/shared/engine/caching.py +26 -0
  542. rasa/shared/exceptions.py +129 -0
  543. rasa/shared/importers/__init__.py +0 -0
  544. rasa/shared/importers/importer.py +705 -0
  545. rasa/shared/importers/multi_project.py +203 -0
  546. rasa/shared/importers/rasa.py +100 -0
  547. rasa/shared/importers/utils.py +34 -0
  548. rasa/shared/nlu/__init__.py +0 -0
  549. rasa/shared/nlu/constants.py +45 -0
  550. rasa/shared/nlu/interpreter.py +10 -0
  551. rasa/shared/nlu/training_data/__init__.py +0 -0
  552. rasa/shared/nlu/training_data/entities_parser.py +209 -0
  553. rasa/shared/nlu/training_data/features.py +374 -0
  554. rasa/shared/nlu/training_data/formats/__init__.py +10 -0
  555. rasa/shared/nlu/training_data/formats/dialogflow.py +162 -0
  556. rasa/shared/nlu/training_data/formats/luis.py +87 -0
  557. rasa/shared/nlu/training_data/formats/rasa.py +135 -0
  558. rasa/shared/nlu/training_data/formats/rasa_yaml.py +605 -0
  559. rasa/shared/nlu/training_data/formats/readerwriter.py +245 -0
  560. rasa/shared/nlu/training_data/formats/wit.py +52 -0
  561. rasa/shared/nlu/training_data/loading.py +137 -0
  562. rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
  563. rasa/shared/nlu/training_data/message.py +477 -0
  564. rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
  565. rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
  566. rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
  567. rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
  568. rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
  569. rasa/shared/nlu/training_data/training_data.py +732 -0
  570. rasa/shared/nlu/training_data/util.py +223 -0
  571. rasa/shared/providers/__init__.py +0 -0
  572. rasa/shared/providers/openai/__init__.py +0 -0
  573. rasa/shared/providers/openai/clients.py +43 -0
  574. rasa/shared/providers/openai/session_handler.py +110 -0
  575. rasa/shared/utils/__init__.py +0 -0
  576. rasa/shared/utils/cli.py +72 -0
  577. rasa/shared/utils/common.py +308 -0
  578. rasa/shared/utils/constants.py +1 -0
  579. rasa/shared/utils/io.py +403 -0
  580. rasa/shared/utils/llm.py +405 -0
  581. rasa/shared/utils/pykwalify_extensions.py +26 -0
  582. rasa/shared/utils/schemas/__init__.py +0 -0
  583. rasa/shared/utils/schemas/config.yml +2 -0
  584. rasa/shared/utils/schemas/domain.yml +142 -0
  585. rasa/shared/utils/schemas/events.py +212 -0
  586. rasa/shared/utils/schemas/model_config.yml +46 -0
  587. rasa/shared/utils/schemas/stories.yml +173 -0
  588. rasa/shared/utils/yaml.py +777 -0
  589. rasa/studio/__init__.py +0 -0
  590. rasa/studio/auth.py +252 -0
  591. rasa/studio/config.py +127 -0
  592. rasa/studio/constants.py +16 -0
  593. rasa/studio/data_handler.py +352 -0
  594. rasa/studio/download.py +350 -0
  595. rasa/studio/train.py +136 -0
  596. rasa/studio/upload.py +408 -0
  597. rasa/telemetry.py +1583 -0
  598. rasa/tracing/__init__.py +0 -0
  599. rasa/tracing/config.py +338 -0
  600. rasa/tracing/constants.py +38 -0
  601. rasa/tracing/instrumentation/__init__.py +0 -0
  602. rasa/tracing/instrumentation/attribute_extractors.py +663 -0
  603. rasa/tracing/instrumentation/instrumentation.py +939 -0
  604. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +142 -0
  605. rasa/tracing/instrumentation/metrics.py +206 -0
  606. rasa/tracing/metric_instrument_provider.py +125 -0
  607. rasa/utils/__init__.py +0 -0
  608. rasa/utils/beta.py +83 -0
  609. rasa/utils/cli.py +27 -0
  610. rasa/utils/common.py +635 -0
  611. rasa/utils/converter.py +53 -0
  612. rasa/utils/endpoints.py +303 -0
  613. rasa/utils/io.py +326 -0
  614. rasa/utils/licensing.py +319 -0
  615. rasa/utils/log_utils.py +174 -0
  616. rasa/utils/mapper.py +210 -0
  617. rasa/utils/ml_utils.py +145 -0
  618. rasa/utils/plotting.py +362 -0
  619. rasa/utils/singleton.py +23 -0
  620. rasa/utils/tensorflow/__init__.py +0 -0
  621. rasa/utils/tensorflow/callback.py +112 -0
  622. rasa/utils/tensorflow/constants.py +116 -0
  623. rasa/utils/tensorflow/crf.py +492 -0
  624. rasa/utils/tensorflow/data_generator.py +440 -0
  625. rasa/utils/tensorflow/environment.py +161 -0
  626. rasa/utils/tensorflow/exceptions.py +5 -0
  627. rasa/utils/tensorflow/layers.py +1565 -0
  628. rasa/utils/tensorflow/layers_utils.py +113 -0
  629. rasa/utils/tensorflow/metrics.py +281 -0
  630. rasa/utils/tensorflow/model_data.py +991 -0
  631. rasa/utils/tensorflow/model_data_utils.py +500 -0
  632. rasa/utils/tensorflow/models.py +936 -0
  633. rasa/utils/tensorflow/rasa_layers.py +1094 -0
  634. rasa/utils/tensorflow/transformer.py +640 -0
  635. rasa/utils/tensorflow/types.py +6 -0
  636. rasa/utils/train_utils.py +572 -0
  637. rasa/utils/yaml.py +54 -0
  638. rasa/validator.py +1035 -0
  639. rasa/version.py +3 -0
  640. rasa_pro-3.8.16.dist-info/METADATA +528 -0
  641. rasa_pro-3.8.16.dist-info/NOTICE +5 -0
  642. rasa_pro-3.8.16.dist-info/RECORD +644 -0
  643. rasa_pro-3.8.16.dist-info/WHEEL +4 -0
  644. rasa_pro-3.8.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,2146 @@
1
+ from __future__ import annotations
2
+ import logging
3
+
4
+ from rasa.engine.recipes.default_recipe import DefaultV1Recipe
5
+ from pathlib import Path
6
+ from collections import defaultdict
7
+ import contextlib
8
+
9
+ import numpy as np
10
+ import tensorflow as tf
11
+ from typing import Any, List, Optional, Text, Dict, Tuple, Union, Type
12
+
13
+ from rasa.engine.graph import ExecutionContext
14
+ from rasa.engine.storage.resource import Resource
15
+ from rasa.engine.storage.storage import ModelStorage
16
+ from rasa.exceptions import ModelNotFound
17
+ from rasa.nlu.constants import TOKENS_NAMES
18
+ from rasa.nlu.extractors.extractor import EntityTagSpec, EntityExtractorMixin
19
+ import rasa.core.actions.action
20
+ from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
21
+ from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
22
+ from rasa.core.featurizers.tracker_featurizers import MaxHistoryTrackerFeaturizer
23
+ from rasa.shared.exceptions import RasaException
24
+ from rasa.shared.nlu.constants import (
25
+ ACTION_TEXT,
26
+ ACTION_NAME,
27
+ INTENT,
28
+ TEXT,
29
+ ENTITIES,
30
+ FEATURE_TYPE_SENTENCE,
31
+ ENTITY_ATTRIBUTE_TYPE,
32
+ ENTITY_TAGS,
33
+ EXTRACTOR,
34
+ SPLIT_ENTITIES_BY_COMMA,
35
+ SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
36
+ )
37
+ from rasa.core.policies.policy import PolicyPrediction, Policy, SupportedData
38
+ from rasa.core.constants import (
39
+ DIALOGUE,
40
+ POLICY_MAX_HISTORY,
41
+ DEFAULT_MAX_HISTORY,
42
+ DEFAULT_POLICY_PRIORITY,
43
+ POLICY_PRIORITY,
44
+ )
45
+ from rasa.shared.constants import DIAGNOSTIC_DATA
46
+ from rasa.shared.core.constants import ACTIVE_LOOP, SLOTS, ACTION_LISTEN_NAME
47
+ from rasa.shared.core.trackers import DialogueStateTracker
48
+ from rasa.shared.core.generator import TrackerWithCachedStates
49
+ from rasa.shared.core.events import EntitiesAdded, Event
50
+ from rasa.shared.core.domain import Domain
51
+ from rasa.shared.nlu.training_data.message import Message
52
+ from rasa.shared.nlu.training_data.features import Features
53
+ import rasa.shared.utils.io
54
+ import rasa.utils.io
55
+ from rasa.utils import train_utils
56
+ from rasa.utils.tensorflow.models import RasaModel, TransformerRasaModel
57
+ from rasa.utils.tensorflow import rasa_layers
58
+ from rasa.utils.tensorflow.model_data import (
59
+ RasaModelData,
60
+ FeatureSignature,
61
+ FeatureArray,
62
+ Data,
63
+ )
64
+ from rasa.utils.tensorflow.model_data_utils import convert_to_data_format
65
+ from rasa.utils.tensorflow.constants import (
66
+ LABEL,
67
+ IDS,
68
+ TRANSFORMER_SIZE,
69
+ NUM_TRANSFORMER_LAYERS,
70
+ NUM_HEADS,
71
+ BATCH_SIZES,
72
+ BATCH_STRATEGY,
73
+ EPOCHS,
74
+ RANDOM_SEED,
75
+ LEARNING_RATE,
76
+ RANKING_LENGTH,
77
+ RENORMALIZE_CONFIDENCES,
78
+ LOSS_TYPE,
79
+ SIMILARITY_TYPE,
80
+ NUM_NEG,
81
+ EVAL_NUM_EXAMPLES,
82
+ EVAL_NUM_EPOCHS,
83
+ NEGATIVE_MARGIN_SCALE,
84
+ REGULARIZATION_CONSTANT,
85
+ SCALE_LOSS,
86
+ USE_MAX_NEG_SIM,
87
+ MAX_NEG_SIM,
88
+ MAX_POS_SIM,
89
+ EMBEDDING_DIMENSION,
90
+ DROP_RATE_DIALOGUE,
91
+ DROP_RATE_LABEL,
92
+ DROP_RATE,
93
+ DROP_RATE_ATTENTION,
94
+ CONNECTION_DENSITY,
95
+ KEY_RELATIVE_ATTENTION,
96
+ VALUE_RELATIVE_ATTENTION,
97
+ MAX_RELATIVE_POSITION,
98
+ CROSS_ENTROPY,
99
+ AUTO,
100
+ BALANCED,
101
+ TENSORBOARD_LOG_DIR,
102
+ TENSORBOARD_LOG_LEVEL,
103
+ CHECKPOINT_MODEL,
104
+ ENCODING_DIMENSION,
105
+ UNIDIRECTIONAL_ENCODER,
106
+ SEQUENCE,
107
+ SENTENCE,
108
+ SEQUENCE_LENGTH,
109
+ DENSE_DIMENSION,
110
+ CONCAT_DIMENSION,
111
+ SPARSE_INPUT_DROPOUT,
112
+ DENSE_INPUT_DROPOUT,
113
+ MASKED_LM,
114
+ MASK,
115
+ HIDDEN_LAYERS_SIZES,
116
+ FEATURIZERS,
117
+ ENTITY_RECOGNITION,
118
+ CONSTRAIN_SIMILARITIES,
119
+ MODEL_CONFIDENCE,
120
+ SOFTMAX,
121
+ BILOU_FLAG,
122
+ EPOCH_OVERRIDE,
123
+ USE_GPU,
124
+ )
125
+
126
# Module-level logger for this policy.
logger = logging.getLogger(__name__)

# Config key: minimum confidence required to pick an end-to-end prediction.
E2E_CONFIDENCE_THRESHOLD = "e2e_confidence_threshold"
# Key and sub-key under which the numerical label ids live in the model data.
LABEL_KEY = LABEL
LABEL_SUB_KEY = IDS
# Sub-keys used for auxiliary arrays in the model data.
LENGTH = "length"
INDICES = "indices"
# Attributes for which sentence-level features are encoded.
SENTENCE_FEATURES_TO_ENCODE = [INTENT, TEXT, ACTION_NAME, ACTION_TEXT]
# Attributes for which sequence-level (per-token) features are encoded.
SEQUENCE_FEATURES_TO_ENCODE = [TEXT, ACTION_TEXT, f"{LABEL}_{ACTION_TEXT}"]
# Label-side attributes that may carry features.
LABEL_FEATURES_TO_ENCODE = [
    f"{LABEL}_{ACTION_NAME}",
    f"{LABEL}_{ACTION_TEXT}",
    f"{LABEL}_{INTENT}",
]
# Dialogue-state-level attributes (not tied to a single utterance).
STATE_LEVEL_FEATURES = [ENTITIES, SLOTS, ACTIVE_LOOP]
# All attribute keys the model may receive at prediction time.
PREDICTION_FEATURES = STATE_LEVEL_FEATURES + SENTENCE_FEATURES_TO_ENCODE + [DIALOGUE]
142
+
143
+
144
+ @DefaultV1Recipe.register(
145
+ DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
146
+ )
147
+ class TEDPolicy(Policy):
148
+ """Transformer Embedding Dialogue (TED) Policy.
149
+
150
+ The model architecture is described in
151
+ detail in https://arxiv.org/abs/1910.00486.
152
+ In summary, the architecture comprises of the
153
+ following steps:
154
+ - concatenate user input (user intent and entities), previous system actions,
155
+ slots and active forms for each time step into an input vector to
156
+ pre-transformer embedding layer;
157
+ - feed it to transformer;
158
+ - apply a dense layer to the output of the transformer to get embeddings of a
159
+ dialogue for each time step;
160
+ - apply a dense layer to create embeddings for system actions for each time
161
+ step;
162
+ - calculate the similarity between the dialogue embedding and embedded system
163
+ actions. This step is based on the StarSpace
164
+ (https://arxiv.org/abs/1709.03856) idea.
165
+ """
166
+
167
@staticmethod
def get_default_config() -> Dict[Text, Any]:
    """Returns the default config (see parent class for full docstring)."""
    # please make sure to update the docs when changing a default parameter
    return {
        # ## Architecture of the used neural network
        # Hidden layer sizes for layers before the embedding layers for user
        # message and labels.
        # The number of hidden layers is equal to the length of the
        # corresponding list.
        HIDDEN_LAYERS_SIZES: {
            TEXT: [],
            ACTION_TEXT: [],
            f"{LABEL}_{ACTION_TEXT}": [],
        },
        # Dense dimension to use for sparse features.
        DENSE_DIMENSION: {
            TEXT: 128,
            ACTION_TEXT: 128,
            f"{LABEL}_{ACTION_TEXT}": 128,
            INTENT: 20,
            ACTION_NAME: 20,
            f"{LABEL}_{ACTION_NAME}": 20,
            ENTITIES: 20,
            SLOTS: 20,
            ACTIVE_LOOP: 20,
        },
        # Default dimension to use for concatenating sequence and sentence
        # features.
        CONCAT_DIMENSION: {
            TEXT: 128,
            ACTION_TEXT: 128,
            f"{LABEL}_{ACTION_TEXT}": 128,
        },
        # Dimension size of embedding vectors before the dialogue transformer
        # encoder.
        ENCODING_DIMENSION: 50,
        # Number of units in transformer encoders
        TRANSFORMER_SIZE: {
            TEXT: 128,
            ACTION_TEXT: 128,
            f"{LABEL}_{ACTION_TEXT}": 128,
            DIALOGUE: 128,
        },
        # Number of layers in transformer encoders
        NUM_TRANSFORMER_LAYERS: {
            TEXT: 1,
            ACTION_TEXT: 1,
            f"{LABEL}_{ACTION_TEXT}": 1,
            DIALOGUE: 1,
        },
        # Number of attention heads in transformer
        NUM_HEADS: 4,
        # If 'True' use key relative embeddings in attention
        KEY_RELATIVE_ATTENTION: False,
        # If 'True' use value relative embeddings in attention
        VALUE_RELATIVE_ATTENTION: False,
        # Max position for relative embeddings. Only in effect if key- or
        # value-relative attention are turned on
        MAX_RELATIVE_POSITION: 5,
        # Use a unidirectional or bidirectional encoder
        # for `text`, `action_text`, and `label_action_text`.
        UNIDIRECTIONAL_ENCODER: False,
        # ## Training parameters
        # Initial and final batch sizes:
        # Batch size will be linearly increased for each epoch.
        BATCH_SIZES: [64, 256],
        # Strategy used when creating batches.
        # Can be either 'sequence' or 'balanced'.
        BATCH_STRATEGY: BALANCED,
        # Number of epochs to train
        EPOCHS: 1,
        # Set random seed to any 'int' to get reproducible results
        RANDOM_SEED: None,
        # Initial learning rate for the optimizer
        LEARNING_RATE: 0.001,
        # ## Parameters for embeddings
        # Dimension size of embedding vectors
        EMBEDDING_DIMENSION: 20,
        # The number of incorrect labels. The algorithm will minimize
        # their similarity to the user input during training.
        NUM_NEG: 20,
        # Type of similarity measure to use, either 'auto' or 'cosine' or
        # 'inner'.
        SIMILARITY_TYPE: AUTO,
        # The type of the loss function, either 'cross_entropy' or 'margin'.
        LOSS_TYPE: CROSS_ENTROPY,
        # Number of top actions for which confidences should be predicted.
        # Set to `0` if confidences for all actions should be predicted.
        # The confidences for all other actions will be set to 0.
        RANKING_LENGTH: 0,
        # Determines whether the confidences of the chosen top actions should
        # be renormalized so that they sum up to 1. By default, we do not
        # renormalize and return the confidences for the top actions as is.
        # Note that renormalization only makes sense if confidences are
        # generated via `softmax`.
        RENORMALIZE_CONFIDENCES: False,
        # Indicates how similar the algorithm should try to make embedding
        # vectors for correct labels.
        # Should be 0.0 < ... < 1.0 for 'cosine' similarity type.
        MAX_POS_SIM: 0.8,
        # Maximum negative similarity for incorrect labels.
        # Should be -1.0 < ... < 1.0 for 'cosine' similarity type.
        MAX_NEG_SIM: -0.2,
        # If 'True' the algorithm only minimizes maximum similarity over
        # incorrect intent labels, used only if 'loss_type' is set to 'margin'.
        USE_MAX_NEG_SIM: True,
        # If 'True' scale loss inverse proportionally to the confidence
        # of the correct prediction
        SCALE_LOSS: True,
        # ## Regularization parameters
        # The scale of regularization
        REGULARIZATION_CONSTANT: 0.001,
        # The scale of how important it is to minimize the maximum similarity
        # between embeddings of different labels,
        # used only if 'loss_type' is set to 'margin'.
        NEGATIVE_MARGIN_SCALE: 0.8,
        # Dropout rate for embedding layers of dialogue features.
        DROP_RATE_DIALOGUE: 0.1,
        # Dropout rate for embedding layers of utterance level features.
        DROP_RATE: 0.0,
        # Dropout rate for embedding layers of label, e.g. action, features.
        DROP_RATE_LABEL: 0.0,
        # Dropout rate for attention.
        DROP_RATE_ATTENTION: 0.0,
        # Fraction of trainable weights in internal layers.
        CONNECTION_DENSITY: 0.2,
        # If 'True' apply dropout to sparse input tensors
        SPARSE_INPUT_DROPOUT: True,
        # If 'True' apply dropout to dense input tensors
        DENSE_INPUT_DROPOUT: True,
        # If 'True' random tokens of the input message will be masked. Since
        # there is no related loss term used inside TED, the masking
        # effectively becomes just input dropout applied to the text of user
        # utterances.
        MASKED_LM: False,
        # ## Evaluation parameters
        # How often to calculate validation accuracy.
        # Small values may hurt performance.
        EVAL_NUM_EPOCHS: 20,
        # How many examples to use for hold out validation set
        # Large values may hurt performance, e.g. model accuracy.
        # Set to 0 for no validation.
        EVAL_NUM_EXAMPLES: 0,
        # If you want to use tensorboard to visualize training and validation
        # metrics, set this option to a valid output directory.
        TENSORBOARD_LOG_DIR: None,
        # Define when training metrics for tensorboard should be logged.
        # Either after every epoch or for every training step.
        # Valid values: 'epoch' and 'batch'
        TENSORBOARD_LOG_LEVEL: "epoch",
        # Perform model checkpointing
        CHECKPOINT_MODEL: False,
        # Only pick e2e prediction if the policy is confident enough
        E2E_CONFIDENCE_THRESHOLD: 0.5,
        # Specify what features to use as sequence and sentence features.
        # By default all features in the pipeline are used.
        FEATURIZERS: [],
        # If set to true, entities are predicted in user utterances.
        ENTITY_RECOGNITION: True,
        # if 'True' applies sigmoid on all similarity terms and adds
        # it to the loss function to ensure that similarity values are
        # approximately bounded. Used inside cross-entropy loss only.
        CONSTRAIN_SIMILARITIES: False,
        # Model confidence to be returned during inference. Currently, the
        # only possible value is `softmax`.
        MODEL_CONFIDENCE: SOFTMAX,
        # 'BILOU_flag' determines whether to use BILOU tagging or not.
        # If set to 'True' labelling is more rigorous, however more
        # examples per entity are required.
        # Rule of thumb: you should have more than 100 examples per entity.
        BILOU_FLAG: True,
        # Split entities by comma, this makes sense e.g. for a list of
        # ingredients in a recipe, but it doesn't make sense for the parts of
        # an address
        SPLIT_ENTITIES_BY_COMMA: SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
        # Max history of the policy, unbounded by default
        POLICY_MAX_HISTORY: DEFAULT_MAX_HISTORY,
        # Determines the importance of policies, higher values take precedence
        POLICY_PRIORITY: DEFAULT_POLICY_PRIORITY,
        USE_GPU: True,
    }
347
+
348
def __init__(
    self,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
    model: Optional[RasaModel] = None,
    featurizer: Optional[TrackerFeaturizer] = None,
    fake_features: Optional[Dict[Text, List[Features]]] = None,
    entity_tag_specs: Optional[List[EntityTagSpec]] = None,
) -> None:
    """Declares instance variables with default values.

    Args:
        config: Policy configuration (see `get_default_config`).
        model_storage: Storage the trained model is persisted to / loaded from.
        resource: Resource locating this policy inside the model storage.
        execution_context: Information about the current graph run.
        model: Pre-built model instance (e.g. when loading), if any.
        featurizer: Tracker featurizer to use; forwarded to the parent class.
        fake_features: Per-attribute placeholder features representing all
            input features this trained policy was built with.
        entity_tag_specs: Entity tag specifications used when entity
            recognition is enabled.
    """
    super().__init__(
        config, model_storage, resource, execution_context, featurizer=featurizer
    )
    self.split_entities_config = rasa.utils.train_utils.init_split_entities(
        config[SPLIT_ENTITIES_BY_COMMA], SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE
    )
    self._load_params(config)

    self.model = model

    self._entity_tag_specs = entity_tag_specs

    self.fake_features = fake_features or defaultdict(list)
    # TED is only e2e if only text is present in fake features, which represent
    # all possible input features for current version of this trained ted
    self.only_e2e = TEXT in self.fake_features and INTENT not in self.fake_features

    # Filled during training / loading.
    self._label_data: Optional[RasaModelData] = None
    self.data_example: Optional[Dict[Text, Dict[Text, List[FeatureArray]]]] = None

    # Temporary directory for model checkpoints, created only when
    # checkpointing is enabled in the config.
    self.tmp_checkpoint_dir: Optional[Path] = None
    if self.config[CHECKPOINT_MODEL]:
        self.tmp_checkpoint_dir = Path(rasa.utils.io.create_temporary_directory())
383
+
384
@staticmethod
def model_class() -> Type[TED]:
    """Returns the model architecture class used by this policy.

    Returns:
        The `TED` model class.
    """
    return TED
392
+
393
+ @classmethod
394
+ def _metadata_filename(cls) -> Optional[Text]:
395
+ return "ted_policy"
396
+
397
def _load_params(self, config: Dict[Text, Any]) -> None:
    """Stores the given config after migrating deprecated options."""
    # Translate deprecated core option names to their current equivalents.
    new_config = rasa.utils.train_utils.check_core_deprecated_options(config)
    self.config = new_config
    self._auto_update_configuration()
401
+
402
def _auto_update_configuration(self) -> None:
    """Takes care of deprecations and compatibility of parameters."""
    # Each helper returns a possibly-updated copy of the config, so the
    # result must be assigned back to `self.config`.
    self.config = rasa.utils.train_utils.update_confidence_type(self.config)
    rasa.utils.train_utils.validate_configuration_settings(self.config)
    self.config = rasa.utils.train_utils.update_similarity_type(self.config)
    self.config = rasa.utils.train_utils.update_evaluation_parameters(self.config)
408
+
409
def _create_label_data(
    self,
    domain: Domain,
    precomputations: Optional[MessageContainerForCoreFeaturization],
) -> Tuple[RasaModelData, List[Dict[Text, List[Features]]]]:
    """Featurizes all actions of the domain as label data.

    Args:
        domain: Domain of the assistant.
        precomputations: Precomputed features for label attributes, if any.

    Returns:
        Label model data ready for the model, plus the raw per-label
        attribute features it was assembled from.
    """
    # encode all label_ids with policies' featurizer
    state_featurizer = self.featurizer.state_featurizer
    encoded_all_labels = (
        state_featurizer.encode_all_labels(domain, precomputations)
        if state_featurizer is not None
        else []
    )

    attribute_data, _ = convert_to_data_format(
        encoded_all_labels, featurizers=self.config[FEATURIZERS]
    )

    label_data = self._assemble_label_data(attribute_data, domain)

    return label_data, encoded_all_labels
429
+
430
def _assemble_label_data(
    self, attribute_data: Data, domain: Domain
) -> RasaModelData:
    """Builds the label-side model data fed to the model.

    The result may contain one or both of the keys `label_action_name` and
    `label_action_text` (sequence, sentence and mask features for the
    corresponding labels) and always contains the `label` key holding the
    numerical label ids.

    Args:
        attribute_data: Feature data for all labels.
        domain: Domain of the assistant.

    Returns:
        Features of labels ready to be fed to the model.
    """
    assembled = RasaModelData()
    assembled.add_data(attribute_data, key_prefix=f"{LABEL_KEY}_")
    assembled.add_lengths(
        f"{LABEL}_{ACTION_TEXT}",
        SEQUENCE_LENGTH,
        f"{LABEL}_{ACTION_TEXT}",
        SEQUENCE,
    )
    # One numerical id per action in the domain, shaped (num_actions, 1).
    all_label_ids = np.arange(domain.num_actions)
    id_features = [
        FeatureArray(np.expand_dims(all_label_ids, -1), number_of_dimensions=2)
    ]
    assembled.add_features(LABEL_KEY, LABEL_SUB_KEY, id_features)
    return assembled
468
+
469
@staticmethod
def _should_extract_entities(
    entity_tags: List[List[Dict[Text, List[Features]]]]
) -> bool:
    """Checks whether any training turn carries a real entity tag.

    A turn with no tags, or whose entity tag indices are all `0`, only
    contains `NO_ENTITY_TAG` and therefore gives entity extraction nothing
    to learn from.
    """
    return any(
        turn_tags and np.any(turn_tags[ENTITY_TAGS][0].features)
        for turns_tags in entity_tags
        for turn_tags in turns_tags
    )
480
+
481
def _create_data_for_entities(
    self, entity_tags: Optional[List[List[Dict[Text, List[Features]]]]]
) -> Optional[Data]:
    """Converts entity tags to model data when entity recognition is enabled.

    Returns `None` — and switches entity recognition off in the config —
    when the training data contains no real entity tags.
    """
    if not self.config[ENTITY_RECOGNITION]:
        return None

    # check that there are real entity tags
    if entity_tags and self._should_extract_entities(entity_tags):
        entity_tags_data, _ = convert_to_data_format(entity_tags)
        return entity_tags_data

    # there are no "real" entity tags
    logger.debug(
        f"Entity recognition cannot be performed, "
        f"set '{ENTITY_RECOGNITION}' config parameter to 'False'."
    )
    # Disable entity recognition for the rest of training/prediction since
    # there is nothing to learn it from.
    self.config[ENTITY_RECOGNITION] = False

    return None
500
+
501
def _create_model_data(
    self,
    tracker_state_features: List[List[Dict[Text, List[Features]]]],
    label_ids: Optional[np.ndarray] = None,
    entity_tags: Optional[List[List[Dict[Text, List[Features]]]]] = None,
    encoded_all_labels: Optional[List[Dict[Text, List[Features]]]] = None,
) -> RasaModelData:
    """Combine all model related data into RasaModelData.

    Args:
        tracker_state_features: a dictionary of attributes
            (INTENT, TEXT, ACTION_NAME, ACTION_TEXT, ENTITIES, SLOTS, ACTIVE_LOOP)
            to a list of features for all dialogue turns in all training trackers
        label_ids: the label ids (e.g. action ids) for every dialogue turn in all
            training trackers
        entity_tags: a dictionary of entity type (ENTITY_TAGS) to a list of features
            containing entity tag ids for text user inputs otherwise empty dict
            for all dialogue turns in all training trackers
        encoded_all_labels: a list of dictionaries containing attribute features
            for label ids

    Returns:
        RasaModelData
    """
    model_data = RasaModelData(label_key=LABEL_KEY, label_sub_key=LABEL_SUB_KEY)

    # Both `label_ids` and `encoded_all_labels` are only provided during
    # training; their absence means this call happens at prediction time.
    if label_ids is not None and encoded_all_labels is not None:
        label_ids = np.array(
            [np.expand_dims(seq_label_ids, -1) for seq_label_ids in label_ids]
        )
        model_data.add_features(
            LABEL_KEY,
            LABEL_SUB_KEY,
            [FeatureArray(label_ids, number_of_dimensions=3)],
        )

        # During training the fake features are (re)computed and stored so
        # prediction-time input can later be padded to the same shape.
        attribute_data, self.fake_features = convert_to_data_format(
            tracker_state_features, featurizers=self.config[FEATURIZERS]
        )

        entity_tags_data = self._create_data_for_entities(entity_tags)
        if entity_tags_data is not None:
            model_data.add_data(entity_tags_data)
    else:
        # method is called during prediction
        attribute_data, _ = convert_to_data_format(
            tracker_state_features,
            self.fake_features,
            featurizers=self.config[FEATURIZERS],
        )

    model_data.add_data(attribute_data)
    model_data.add_lengths(TEXT, SEQUENCE_LENGTH, TEXT, SEQUENCE)
    model_data.add_lengths(ACTION_TEXT, SEQUENCE_LENGTH, ACTION_TEXT, SEQUENCE)

    # add the dialogue lengths; derived from the mask of an arbitrary
    # attribute that is present in the data
    attribute_present = next(iter(list(attribute_data.keys())))
    dialogue_lengths = np.array(
        [
            np.size(np.squeeze(f, -1))
            for f in model_data.data[attribute_present][MASK][0]
        ]
    )
    model_data.data[DIALOGUE][LENGTH] = [
        FeatureArray(dialogue_lengths, number_of_dimensions=1)
    ]

    # make sure all keys are in the same order during training and prediction
    model_data.sort()

    return model_data
572
+
573
+ @staticmethod
574
+ def _get_trackers_for_training(
575
+ trackers: List[TrackerWithCachedStates],
576
+ ) -> List[TrackerWithCachedStates]:
577
+ """Filters out the list of trackers which should not be used for training.
578
+
579
+ Args:
580
+ trackers: All trackers available for training.
581
+
582
+ Returns:
583
+ Trackers which should be used for training.
584
+ """
585
+ # By default, we train on all available trackers.
586
+ return trackers
587
+
588
def _prepare_for_training(
    self,
    trackers: List[TrackerWithCachedStates],
    domain: Domain,
    precomputations: MessageContainerForCoreFeaturization,
    **kwargs: Any,
) -> Tuple[RasaModelData, np.ndarray]:
    """Prepares data to be fed into the model.

    Args:
        trackers: List of training trackers to be featurized.
        domain: Domain of the assistant.
        precomputations: Contains precomputed features and attributes.
        **kwargs: Any other arguments.

    Returns:
        Featurized data to be fed to the model and corresponding label ids.
    """
    training_trackers = self._get_trackers_for_training(trackers)
    # dealing with training data
    tracker_state_features, label_ids, entity_tags = self._featurize_for_training(
        training_trackers,
        domain,
        precomputations=precomputations,
        bilou_tagging=self.config[BILOU_FLAG],
        **kwargs,
    )

    # Nothing could be featurized: return empty model data so the caller
    # can skip training.
    if not tracker_state_features:
        return RasaModelData(), label_ids

    self._label_data, encoded_all_labels = self._create_label_data(
        domain, precomputations=precomputations
    )

    # extract actual training data to feed to model
    model_data = self._create_model_data(
        tracker_state_features, label_ids, entity_tags, encoded_all_labels
    )

    if self.config[ENTITY_RECOGNITION]:
        self._entity_tag_specs = (
            self.featurizer.state_featurizer.entity_tag_specs
            if self.featurizer.state_featurizer is not None
            else []
        )

    # keep one example for persisting and loading
    self.data_example = model_data.first_data_example()

    return model_data, label_ids
639
+
640
def run_training(
    self, model_data: RasaModelData, label_ids: Optional[np.ndarray] = None
) -> None:
    """Feeds the featurized training data to the model.

    Args:
        model_data: Featurized training data.
        label_ids: Label ids corresponding to the data points in `model_data`.
            These may or may not be used by the function depending
            on how the policy is trained.

    Raises:
        ModelNotFound: If no model instance is available to train.
    """
    if not self.finetune_mode:
        # This means the model wasn't loaded from a
        # previously trained model and hence needs
        # to be instantiated.
        self.model = self.model_class()(
            model_data.get_signature(),
            self.config,
            isinstance(self.featurizer, MaxHistoryTrackerFeaturizer),
            self._label_data,
            self._entity_tag_specs,
        )
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(self.config[LEARNING_RATE])
        )
    (
        data_generator,
        validation_data_generator,
    ) = rasa.utils.train_utils.create_data_generators(
        model_data,
        self.config[BATCH_SIZES],
        self.config[EPOCHS],
        self.config[BATCH_STRATEGY],
        self.config[EVAL_NUM_EXAMPLES],
        self.config[RANDOM_SEED],
    )
    callbacks = rasa.utils.train_utils.create_common_callbacks(
        self.config[EPOCHS],
        self.config[TENSORBOARD_LOG_DIR],
        self.config[TENSORBOARD_LOG_LEVEL],
        self.tmp_checkpoint_dir,
    )

    # In finetune mode the model is expected to have been loaded earlier;
    # guard against training without any model instance.
    if self.model is None:
        raise ModelNotFound("No model was detected prior to training.")

    self.model.fit(
        data_generator,
        epochs=self.config[EPOCHS],
        validation_data=validation_data_generator,
        validation_freq=self.config[EVAL_NUM_EPOCHS],
        callbacks=callbacks,
        verbose=False,
        shuffle=False,  # we use custom shuffle inside data generator
    )
695
+
696
def train(
    self,
    training_trackers: List[TrackerWithCachedStates],
    domain: Domain,
    precomputations: Optional[MessageContainerForCoreFeaturization] = None,
    **kwargs: Any,
) -> Resource:
    """Trains the policy (see parent class for full docstring).

    Args:
        training_trackers: Trackers to train on.
        domain: Domain of the assistant.
        precomputations: Precomputed features for the tracker attributes.
        **kwargs: Forwarded to the featurization step.

    Returns:
        The resource under which the trained policy is persisted.
    """
    # Both "no trackers" and "no featurizable data" end with the same
    # skip-training warning, so build the message once.
    skip_training_warning = (
        f"Skipping training of `{self.__class__.__name__}` "
        f"as no data was provided. You can exclude this "
        f"policy in the configuration "
        f"file to avoid this warning."
    )
    if not training_trackers:
        rasa.shared.utils.io.raise_warning(
            skip_training_warning, category=UserWarning
        )
        return self._resource

    training_trackers = SupportedData.trackers_for_supported_data(
        self.supported_data(), training_trackers
    )

    model_data, label_ids = self._prepare_for_training(
        training_trackers, domain, precomputations
    )

    if model_data.is_empty():
        rasa.shared.utils.io.raise_warning(
            skip_training_warning, category=UserWarning
        )
        return self._resource

    # Use the imported USE_GPU constant rather than a hard-coded "use_gpu"
    # string, keeping the config key consistent with `get_default_config`.
    # When GPU use is disabled, pin training to the CPU device.
    with (
        contextlib.nullcontext() if self.config[USE_GPU] else tf.device("/cpu:0")
    ):
        self.run_training(model_data, label_ids)

    self.persist()

    return self._resource
740
+
741
def _featurize_tracker(
    self,
    tracker: DialogueStateTracker,
    domain: Domain,
    precomputations: Optional[MessageContainerForCoreFeaturization],
    rule_only_data: Optional[Dict[Text, Any]],
) -> List[List[Dict[Text, List[Features]]]]:
    """Featurizes the tracker into one or (optionally) two batch examples.

    Returns:
        The state features for one example (intent- or text-based,
        depending on `self.only_e2e`), plus a second text-based example
        when the conditions below hold.
    """
    # construct two examples in the batch to be fed to the model -
    # one by featurizing last user text
    # and second - an optional one (see conditions below),
    # the first example in the constructed batch either does not contain user input
    # or uses intent or text based on whether TED is e2e only.
    tracker_state_features = self._featurize_for_prediction(
        tracker,
        domain,
        precomputations=precomputations,
        use_text_for_last_user_input=self.only_e2e,
        rule_only_data=rule_only_data,
    )
    # the second - text, but only after user utterance and if not only e2e;
    # a user message always directly follows an `action_listen`.
    if (
        tracker.latest_action_name == ACTION_LISTEN_NAME
        and TEXT in self.fake_features
        and not self.only_e2e
    ):
        tracker_state_features += self._featurize_for_prediction(
            tracker,
            domain,
            precomputations=precomputations,
            use_text_for_last_user_input=True,
            rule_only_data=rule_only_data,
        )
    return tracker_state_features
774
+
775
def _pick_confidence(
    self, confidences: np.ndarray, similarities: np.ndarray, domain: Domain
) -> Tuple[np.ndarray, bool]:
    """Chooses between the intent-based and text-based prediction.

    Both arrays have shape (batch-size x number of actions) and batch-size
    can only be 1 or 2; when it is 2, row 0 was featurized from the user
    intent and row 1 from the user text.

    Returns:
        The chosen confidence row and whether it is an e2e (text) prediction.
    """
    batch_size = confidences.shape[0]
    if batch_size > 2:
        raise ValueError(
            "We cannot pick prediction from batches of size more than 2."
        )

    if batch_size == 2:
        # Heuristic: similarities are the more reliable signal for choosing
        # the input, since the policy is trained to maximize similarity,
        # not confidence.
        intent_based_action = domain.action_names_or_texts[
            np.argmax(confidences[0])
        ]
        logger.debug(f"User intent lead to '{intent_based_action}'.")
        text_based_action = domain.action_names_or_texts[np.argmax(confidences[1])]
        logger.debug(f"User text lead to '{text_based_action}'.")
        text_wins = (
            np.max(confidences[1]) > self.config[E2E_CONFIDENCE_THRESHOLD]
            # TODO maybe compare confidences is better
            and np.max(similarities[1]) > np.max(similarities[0])
        )
        if text_wins:
            logger.debug(f"TED predicted '{text_based_action}' based on user text.")
            return confidences[1], True

        logger.debug(f"TED predicted '{intent_based_action}' based on user intent.")
        return confidences[0], False

    # Batch of size 1: the single row is the prediction.
    chosen_action = domain.action_names_or_texts[np.argmax(confidences[0])]
    prediction_basis = "text" if self.only_e2e else "intent"
    logger.debug(
        f"TED predicted '{chosen_action}' "
        f"based on user {prediction_basis}."
    )
    return confidences[0], self.only_e2e
816
+
817
async def predict_action_probabilities(
    self,
    tracker: DialogueStateTracker,
    domain: Domain,
    rule_only_data: Optional[Dict[Text, Any]] = None,
    precomputations: Optional[MessageContainerForCoreFeaturization] = None,
    **kwargs: Any,
) -> PolicyPrediction:
    """Predicts the next action (see parent class for full docstring)."""
    # Without a trained model (or when abstaining in coexistence mode)
    # fall back to the default predictions for the domain.
    if self.model is None or self.should_abstain_in_coexistence(tracker, False):
        return self._prediction(self._default_predictions(domain))

    # create model data from tracker
    tracker_state_features = self._featurize_tracker(
        tracker, domain, precomputations, rule_only_data=rule_only_data
    )
    model_data = self._create_model_data(tracker_state_features)
    outputs = self.model.run_inference(model_data)

    if isinstance(outputs["similarities"], np.ndarray):
        # take the last prediction in the sequence
        similarities = outputs["similarities"][:, -1, :]
    else:
        raise TypeError(
            "model output for `similarities` " "should be a numpy array"
        )
    if isinstance(outputs["scores"], np.ndarray):
        # likewise, only the last dialogue turn's scores matter here
        confidences = outputs["scores"][:, -1, :]
    else:
        raise TypeError("model output for `scores` should be a numpy array")
    # take correct prediction from batch (intent-based vs. text-based row)
    confidence, is_e2e_prediction = self._pick_confidence(
        confidences, similarities, domain
    )

    # rank and mask the confidence (if we need to)
    ranking_length = self.config[RANKING_LENGTH]
    if 0 < ranking_length < len(confidence):
        # renormalizing is only meaningful for softmax confidences
        renormalize = (
            self.config[RENORMALIZE_CONFIDENCES]
            and self.config[MODEL_CONFIDENCE] == SOFTMAX
        )
        _, confidence = train_utils.rank_and_mask(
            confidence, ranking_length=ranking_length, renormalize=renormalize
        )

    # entities are only attached when the text-based prediction was used
    optional_events = self._create_optional_event_for_entities(
        outputs, is_e2e_prediction, precomputations, tracker
    )

    return self._prediction(
        confidence.tolist(),
        is_end_to_end_prediction=is_e2e_prediction,
        optional_events=optional_events,
        diagnostic_data=outputs.get(DIAGNOSTIC_DATA),
    )
873
+
874
def _create_optional_event_for_entities(
    self,
    prediction_output: Dict[Text, tf.Tensor],
    is_e2e_prediction: bool,
    precomputations: Optional[MessageContainerForCoreFeaturization],
    tracker: DialogueStateTracker,
) -> Optional[List[Event]]:
    """Builds an `EntitiesAdded` event from the model's entity predictions.

    Returns:
        A single-element list with the `EntitiesAdded` event, or `None`
        when no entities can/should be attached to the last user message.
    """
    if tracker.latest_action_name != ACTION_LISTEN_NAME or not is_e2e_prediction:
        # entities belong only to the last user message
        # and only if user text was used for prediction,
        # a user message always comes after action listen
        return None

    if not self.config[ENTITY_RECOGNITION]:
        # entity recognition is not turned on, no entities can be predicted
        return None

    # The batch dimension of entity prediction is not the same as batch size,
    # rather it is the number of last (if max history featurizer else all)
    # text inputs in the batch
    # therefore, in order to pick entities from the latest user message
    # we need to pick entities from the last batch dimension of entity prediction
    predicted_tags, confidence_values = rasa.utils.train_utils.entity_label_to_tags(
        prediction_output,
        self._entity_tag_specs,
        self.config[BILOU_FLAG],
        prediction_index=-1,
    )

    if ENTITY_ATTRIBUTE_TYPE not in predicted_tags:
        # no entities detected
        return None

    # entities belong to the last message of the tracker
    # convert the predicted tags to actual entities
    text = tracker.latest_message.text if tracker.latest_message is not None else ""
    if precomputations is not None:
        # reuse tokens from the precomputed featurization of this message
        parsed_message = precomputations.lookup_message(user_text=text)
    else:
        parsed_message = Message(data={TEXT: text})
    tokens = parsed_message.get(TOKENS_NAMES[TEXT])
    entities = EntityExtractorMixin.convert_predictions_into_entities(
        text,
        tokens,
        predicted_tags,
        self.split_entities_config,
        confidences=confidence_values,
    )

    # add the extractor name
    for entity in entities:
        entity[EXTRACTOR] = "TEDPolicy"

    return [EntitiesAdded(entities)]
928
+
929
def persist(self) -> None:
    """Persists the policy to a storage."""
    if self.model is None:
        logger.debug(
            "Method `persist(...)` was called without a trained model present. "
            "Nothing to persist then!"
        )
        return

    with self._model_storage.write_to(self._resource) as model_path:
        base_name = self._metadata_filename()
        tf_model_file = model_path / f"{base_name}.tf_model"

        rasa.shared.utils.io.create_directory_for_file(tf_model_file)

        self.featurizer.persist(model_path)

        if self.config[CHECKPOINT_MODEL] and self.tmp_checkpoint_dir:
            # Restore the best checkpointed weights before saving, and leave
            # an empty marker file recording that this model was produced
            # using checkpointing.
            self.model.load_weights(self.tmp_checkpoint_dir / "checkpoint.tf_model")
            checkpoint_marker = model_path / f"{base_name}.from_checkpoint.pkl"
            checkpoint_marker.touch()

        self.model.save(str(tf_model_file))
        self.persist_model_utilities(model_path)
956
+
957
def persist_model_utilities(self, model_path: Path) -> None:
    """Persists model's utility attributes like model weights, etc.

    Args:
        model_path: Path where model is to be persisted
    """
    base_name = self._metadata_filename()

    rasa.utils.io.json_pickle(
        model_path / f"{base_name}.priority.pkl", self.priority
    )

    label_data_payload = (
        dict(self._label_data.data) if self._label_data is not None else {}
    )
    # Write the pickled artifacts in a fixed order (suffix -> payload).
    pickled_artifacts = {
        "meta": self.config,
        "data_example": self.data_example,
        "fake_features": self.fake_features,
        "label_data": label_data_payload,
    }
    for suffix, payload in pickled_artifacts.items():
        rasa.utils.io.pickle_dump(model_path / f"{base_name}.{suffix}.pkl", payload)

    # Entity tag specs are JSON-serializable once converted to plain dicts.
    tag_spec_dicts = (
        [tag_spec._asdict() for tag_spec in self._entity_tag_specs]
        if self._entity_tag_specs
        else []
    )
    rasa.shared.utils.io.dump_obj_as_json_to_file(
        model_path / f"{base_name}.entity_tag_specs.json", tag_spec_dicts
    )
988
+
989
@classmethod
def _load_model_utilities(cls, model_path: Path) -> Dict[Text, Any]:
    """Loads model's utility attributes.

    Args:
        model_path: Path where model is to be persisted.
    """
    base_name = cls._metadata_filename()

    def _artifact(suffix: Text) -> Path:
        # All utility files share the metadata filename prefix.
        return model_path / f"{base_name}.{suffix}"

    loaded_data = rasa.utils.io.pickle_load(_artifact("data_example.pkl"))
    raw_label_data = rasa.utils.io.pickle_load(_artifact("label_data.pkl"))
    fake_features = rasa.utils.io.pickle_load(_artifact("fake_features.pkl"))
    label_data = RasaModelData(data=raw_label_data)
    priority = rasa.utils.io.json_unpickle(_artifact("priority.pkl"))

    raw_tag_specs = rasa.shared.utils.io.read_json_file(
        _artifact("entity_tag_specs.json")
    )
    # JSON stringifies the integer tag ids, so restore them on load.
    entity_tag_specs = [
        EntityTagSpec(
            tag_name=spec["tag_name"],
            ids_to_tags={int(key): tag for key, tag in spec["ids_to_tags"].items()},
            tags_to_ids={tag: int(idx) for tag, idx in spec["tags_to_ids"].items()},
            num_tags=spec["num_tags"],
        )
        for spec in raw_tag_specs
    ]

    model_config = rasa.utils.io.pickle_load(_artifact("meta.pkl"))

    return {
        "tf_model_file": _artifact("tf_model"),
        "loaded_data": loaded_data,
        "fake_features": fake_features,
        "label_data": label_data,
        "priority": priority,
        "entity_tag_specs": entity_tag_specs,
        "model_config": model_config,
    }
1039
+
1040
@classmethod
def load(
    cls,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
    **kwargs: Any,
) -> TEDPolicy:
    """Loads a policy from the storage (see parent class for full docstring).

    Falls back to an untrained policy instance when the resource was never
    persisted (``read_from`` raises ``ValueError`` in that case).
    """
    try:
        with model_storage.read_from(resource) as model_path:
            return cls._load(
                model_path, config, model_storage, resource, execution_context
            )
    except ValueError:
        logger.debug(
            # Bug fix: `cls` is already the class in a classmethod, so use
            # `cls.__name__`; `cls.__class__.__name__` would log the name of
            # the *metaclass* (e.g. "ABCMeta") instead of the policy class.
            f"Failed to load {cls.__name__} from model storage. Resource "
            f"'{resource.name}' doesn't exist."
        )
        return cls(config, model_storage, resource, execution_context)
1061
+
1062
@classmethod
def _load(
    cls,
    model_path: Path,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
) -> TEDPolicy:
    """Instantiates the policy from the artifacts found at `model_path`."""
    featurizer = TrackerFeaturizer.load(model_path)

    data_example_file = model_path / f"{cls._metadata_filename()}.data_example.pkl"
    if not data_example_file.is_file():
        # No data example was persisted, i.e. training was skipped — return
        # an untrained policy that only carries the featurizer.
        return cls(
            config,
            model_storage,
            resource,
            execution_context,
            featurizer=featurizer,
        )

    model_utilities = cls._load_model_utilities(model_path)

    config = cls._update_loaded_params(config)
    if execution_context.is_finetuning and EPOCH_OVERRIDE in config:
        # Allow overriding the number of epochs when fine-tuning.
        config[EPOCHS] = config.get(EPOCH_OVERRIDE)

    (
        model_data_example,
        predict_data_example,
    ) = cls._construct_model_initialization_data(model_utilities["loaded_data"])

    # Pin model loading to the CPU unless GPU usage was explicitly enabled.
    device_context = (
        contextlib.nullcontext() if config["use_gpu"] else tf.device("/cpu:0")
    )
    with device_context:
        model = cls._load_tf_model(
            model_utilities,
            model_data_example,
            predict_data_example,
            featurizer,
            execution_context.is_finetuning,
        )

    return cls._load_policy_with_model(
        config,
        model_storage,
        resource,
        execution_context,
        featurizer=featurizer,
        model_utilities=model_utilities,
        model=model,
    )
1113
+
1114
@classmethod
def _load_policy_with_model(
    cls,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
    featurizer: TrackerFeaturizer,
    model: TED,
    model_utilities: Dict[Text, Any],
) -> TEDPolicy:
    """Builds a policy instance around an already-restored TF model."""
    # Only the fake features and entity tag specs from the loaded utilities
    # are needed to reconstruct the policy object itself.
    fake_features = model_utilities["fake_features"]
    entity_tag_specs = model_utilities["entity_tag_specs"]
    return cls(
        config,
        model_storage,
        resource,
        execution_context,
        model=model,
        featurizer=featurizer,
        fake_features=fake_features,
        entity_tag_specs=entity_tag_specs,
    )
1135
+
1136
@classmethod
def _load_tf_model(
    cls,
    model_utilities: Dict[Text, Any],
    model_data_example: RasaModelData,
    predict_data_example: RasaModelData,
    featurizer: TrackerFeaturizer,
    should_finetune: bool,
) -> TED:
    """Restores the underlying TF model from its persisted file."""
    uses_max_history = isinstance(featurizer, MaxHistoryTrackerFeaturizer)
    return cls.model_class().load(
        str(model_utilities["tf_model_file"]),
        model_data_example,
        predict_data_example,
        data_signature=model_data_example.get_signature(),
        config=model_utilities["model_config"],
        max_history_featurizer_is_used=uses_max_history,
        label_data=model_utilities["label_data"],
        entity_tag_specs=model_utilities["entity_tag_specs"],
        finetune_mode=should_finetune,
    )
1159
+
1160
@classmethod
def _construct_model_initialization_data(
    cls, loaded_data: Dict[Text, Dict[Text, List[FeatureArray]]]
) -> Tuple[RasaModelData, RasaModelData]:
    """Builds the training- and prediction-shaped model data examples."""
    model_data_example = RasaModelData(
        label_key=LABEL_KEY, label_sub_key=LABEL_SUB_KEY, data=loaded_data
    )
    # Label features must not be present at prediction time, so keep only
    # the features that are actually used for inference.
    prediction_data = {
        feature_name: features
        for feature_name, features in model_data_example.items()
        if feature_name in PREDICTION_FEATURES
    }
    predict_data_example = RasaModelData(
        label_key=LABEL_KEY,
        label_sub_key=LABEL_SUB_KEY,
        data=prediction_data,
    )
    return model_data_example, predict_data_example
1178
+
1179
@classmethod
def _update_loaded_params(cls, meta: Dict[Text, Any]) -> Dict[Text, Any]:
    """Normalizes confidence and similarity settings of a loaded config."""
    updated_meta = rasa.utils.train_utils.update_confidence_type(meta)
    return rasa.utils.train_utils.update_similarity_type(updated_meta)
1185
+
1186
+
1187
+ class TED(TransformerRasaModel):
1188
+ """TED model architecture from https://arxiv.org/abs/1910.00486."""
1189
+
1190
def __init__(
    self,
    data_signature: Dict[Text, Dict[Text, List[FeatureSignature]]],
    config: Dict[Text, Any],
    max_history_featurizer_is_used: bool,
    label_data: RasaModelData,
    entity_tag_specs: Optional[List[EntityTagSpec]],
) -> None:
    """Initializes the TED model.

    Args:
        data_signature: the data signature of the input data
        config: the model configuration
        max_history_featurizer_is_used: if 'True'
            only the last dialogue turn will be used
        label_data: the label data
        entity_tag_specs: the entity tag specifications
    """
    super().__init__("TED", config, data_signature, label_data)

    self.max_history_featurizer_is_used = max_history_featurizer_is_used

    # the prediction-time signature excludes label-only features
    self.predict_data_signature = {
        feature_name: features
        for feature_name, features in data_signature.items()
        if feature_name in PREDICTION_FEATURES
    }

    self._entity_tag_specs = entity_tag_specs

    # metrics
    self.action_loss = tf.keras.metrics.Mean(name="loss")
    self.action_acc = tf.keras.metrics.Mean(name="acc")
    self.entity_loss = tf.keras.metrics.Mean(name="e_loss")
    self.entity_f1 = tf.keras.metrics.Mean(name="e_f1")
    self.metrics_to_log += ["loss", "acc"]
    if self.config[ENTITY_RECOGNITION]:
        # entity metrics are only logged when entity recognition is enabled
        self.metrics_to_log += ["e_loss", "e_f1"]

    # needed for efficient prediction
    self.all_labels_embed: Optional[tf.Tensor] = None

    self._prepare_layers()
1233
+
1234
def _check_data(self) -> None:
    """Validates that user, action and label features are all present."""
    signature_keys = set(self.data_signature.keys())

    if not signature_keys & {INTENT, TEXT}:
        raise RasaException(
            f"No user features specified. "
            f"Cannot train '{self.__class__.__name__}' model."
        )

    if not signature_keys & {ACTION_NAME, ACTION_TEXT}:
        raise ValueError(
            f"No action features specified. "
            f"Cannot train '{self.__class__.__name__}' model."
        )

    if LABEL not in self.data_signature:
        raise ValueError(
            f"No label features specified. "
            f"Cannot train '{self.__class__.__name__}' model."
        )
1253
+
1254
+ # ---CREATING LAYERS HELPERS---
1255
+
1256
def _prepare_layers(self) -> None:
    """Instantiates all TF layers: inputs, encoders, dialogue transformer, outputs."""
    # input + encoding layers for every attribute in the input data
    for name in self.data_signature.keys():
        self._prepare_input_layers(
            name, self.data_signature[name], is_label_attribute=False
        )
        self._prepare_encoding_layers(name)

    # input + encoding layers for every label attribute
    for name in self.label_signature.keys():
        self._prepare_input_layers(
            name, self.label_signature[name], is_label_attribute=True
        )
        self._prepare_encoding_layers(name)

    self._tf_layers[
        f"transformer.{DIALOGUE}"
    ] = rasa_layers.prepare_transformer_layer(
        attribute_name=DIALOGUE,
        config=self.config,
        num_layers=self.config[NUM_TRANSFORMER_LAYERS][DIALOGUE],
        units=self.config[TRANSFORMER_SIZE][DIALOGUE],
        drop_rate=self.config[DROP_RATE_DIALOGUE],
        # use bidirectional transformer, because
        # we will invert dialogue sequence so that the last turn is located
        # at the first position and would always have
        # exactly the same positional encoding
        unidirectional=not self.max_history_featurizer_is_used,
    )

    self._prepare_label_classification_layers(DIALOGUE)

    if self.config[ENTITY_RECOGNITION]:
        self._prepare_entity_recognition_layers()
1288
+
1289
def _prepare_input_layers(
    self,
    attribute_name: Text,
    attribute_signature: Dict[Text, List[FeatureSignature]],
    is_label_attribute: bool = False,
) -> None:
    """Prepares feature processing layers for sentence/sequence-level features.

    Distinguishes between label features and other features, not applying input
    dropout to the label ones.
    """
    # Disable input dropout in the config to be used if this is a label attribute.
    if is_label_attribute:
        # copy so the shared config is not mutated
        config_to_use = self.config.copy()
        config_to_use.update(
            {SPARSE_INPUT_DROPOUT: False, DENSE_INPUT_DROPOUT: False}
        )
    else:
        config_to_use = self.config
    # Attributes with sequence-level features also have sentence-level features,
    # all these need to be combined and further processed.
    if attribute_name in SEQUENCE_FEATURES_TO_ENCODE:
        self._tf_layers[
            f"sequence_layer.{attribute_name}"
        ] = rasa_layers.RasaSequenceLayer(
            attribute_name, attribute_signature, config_to_use
        )
    # Attributes without sequence-level features require some actual feature
    # processing only if they have sentence-level features. Attributes with no
    # sequence- and sentence-level features (dialogue, entity_tags, label) are
    # skipped here.
    elif SENTENCE in attribute_signature:
        self._tf_layers[
            f"sparse_dense_concat_layer.{attribute_name}"
        ] = rasa_layers.ConcatenateSparseDenseFeatures(
            attribute=attribute_name,
            feature_type=SENTENCE,
            feature_type_signature=attribute_signature[SENTENCE],
            config=config_to_use,
        )
1329
+
1330
def _prepare_encoding_layers(self, name: Text) -> None:
    """Create Ffnn encoding layer used just before combining all dialogue features.

    Args:
        name: attribute name
    """
    # only attributes designated for encoding get an encoding layer
    if name not in SENTENCE_FEATURES_TO_ENCODE + LABEL_FEATURES_TO_ENCODE:
        return

    # skip attributes whose SENTENCE features are absent from the input data
    missing_in_data = (
        name in SENTENCE_FEATURES_TO_ENCODE
        and FEATURE_TYPE_SENTENCE not in self.data_signature[name]
    )
    if missing_in_data:
        return

    # likewise for label attributes and the label data
    missing_in_labels = (
        name in LABEL_FEATURES_TO_ENCODE
        and FEATURE_TYPE_SENTENCE not in self.label_signature[name]
    )
    if missing_in_labels:
        return

    self._prepare_ffnn_layer(
        name,
        [self.config[ENCODING_DIMENSION]],
        self.config[DROP_RATE_DIALOGUE],
        prefix="encoding_layer",
    )
1358
+
1359
+ # ---GRAPH BUILDING HELPERS---
1360
+
1361
@staticmethod
def _compute_dialogue_indices(
    tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]]
) -> None:
    """Stores flattened per-turn indices under `tf_batch_data[DIALOGUE][INDICES]`.

    For dialogue lengths `[2, 1, 3]` the stored values are
    `[0, 1, 0, 0, 1, 2]`, i.e. a `range(length)` per dialogue, concatenated
    (the `.values` of the ragged `tf.map_fn` result).
    """
    dialogue_lengths = tf.cast(tf_batch_data[DIALOGUE][LENGTH][0], dtype=tf.int32)
    # wrap in a list, because that's the structure of tf_batch_data
    tf_batch_data[DIALOGUE][INDICES] = [
        (
            tf.map_fn(
                tf.range,
                dialogue_lengths,
                fn_output_signature=tf.RaggedTensorSpec(
                    shape=[None], dtype=tf.int32
                ),
            )
        ).values
    ]
1378
+
1379
def _create_all_labels_embed(self) -> Tuple[tf.Tensor, tf.Tensor]:
    """Encodes and embeds all labels at once.

    Returns:
        The label ids and the corresponding label embeddings.
    """
    all_label_ids = self.tf_label_data[LABEL_KEY][LABEL_SUB_KEY][0]
    # labels cannot have all features "fake"
    all_labels_encoded = {}
    for key in self.tf_label_data.keys():
        if key != LABEL_KEY:
            attribute_features, _, _ = self._encode_real_features_per_attribute(
                self.tf_label_data, key
            )
            all_labels_encoded[key] = attribute_features

    x = self._collect_label_attribute_encodings(all_labels_encoded)

    # additional sequence axis is artifact of our RasaModelData creation
    # TODO check whether this should be solved in data creation
    x = tf.squeeze(x, axis=1)

    all_labels_embed = self._tf_layers[f"embed.{LABEL}"](x)

    return all_label_ids, all_labels_embed
1399
+
1400
@staticmethod
def _collect_label_attribute_encodings(
    all_labels_encoded: Dict[Text, tf.Tensor]
) -> tf.Tensor:
    """Sums the encodings of all label attributes into one tensor."""
    # Seed the sum with one attribute's encoding (popped from the dict so
    # the loop below only covers the remaining ones) to keep the
    # subsequent TF ops simple.
    attribute_names = list(all_labels_encoded.keys())
    combined = all_labels_encoded.pop(attribute_names[0])

    # Accumulate the rest of the attribute encodings.
    for attribute in all_labels_encoded:
        combined += all_labels_encoded.get(attribute)
    return combined
1413
+
1414
def _embed_dialogue(
    self,
    dialogue_in: tf.Tensor,
    tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Optional[tf.Tensor]]:
    """Creates dialogue level embedding and mask.

    Args:
        dialogue_in: The encoded dialogue.
        tf_batch_data: Batch in model data format.

    Returns:
        The dialogue embedding, the mask, the transformer output, and (for
        diagnostic purposes) also the attention weights.
    """
    dialogue_lengths = tf.cast(tf_batch_data[DIALOGUE][LENGTH][0], tf.int32)
    mask = rasa_layers.compute_mask(dialogue_lengths)

    if self.max_history_featurizer_is_used:
        # invert dialogue sequence so that the last turn would always have
        # exactly the same positional encoding
        dialogue_in = tf.reverse_sequence(dialogue_in, dialogue_lengths, seq_axis=1)

    # NOTE(review): `1 - mask` is passed as the transformer's padding mask —
    # confirm against `prepare_transformer_layer`'s expected mask convention.
    dialogue_transformed, attention_weights = self._tf_layers[
        f"transformer.{DIALOGUE}"
    ](dialogue_in, 1 - mask, self._training)
    dialogue_transformed = tf.nn.gelu(dialogue_transformed)

    if self.max_history_featurizer_is_used:
        # pick last vector if max history featurizer is used, since we inverted
        # dialogue sequence, the last vector is actually the first one
        dialogue_transformed = dialogue_transformed[:, :1, :]
        mask = tf.expand_dims(self._last_token(mask, dialogue_lengths), 1)
    elif not self._training:
        # during prediction we don't care about previous dialogue turns,
        # so to save computation time, use only the last one
        dialogue_transformed = tf.expand_dims(
            self._last_token(dialogue_transformed, dialogue_lengths), 1
        )
        mask = tf.expand_dims(self._last_token(mask, dialogue_lengths), 1)

    dialogue_embed = self._tf_layers[f"embed.{DIALOGUE}"](dialogue_transformed)

    return dialogue_embed, mask, dialogue_transformed, attention_weights
1458
+
1459
def _encode_features_per_attribute(
    self, tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]], attribute: Text
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Encodes an attribute, using zero tensors when no real features exist.

    Dispatches (inside the TF graph, via `tf.cond`) to the "real" encoder
    when the batch contains at least one real sentence feature for the
    attribute, and to the "fake" zero-tensor encoder otherwise.
    """
    # The input is a representation of 4d tensor of
    # shape (batch-size x dialogue-len x sequence-len x units) in 3d of shape
    # (sum of dialogue history length for all tensors in the batch x
    # max sequence length x number of features).

    # However, some dialogue turns contain non existent state features,
    # e.g. `intent` and `text` features are mutually exclusive,
    # as well as `action_name` and `action_text` are mutually exclusive,
    # or some dialogue turns don't contain any `slots`.
    # In order to create 4d full tensors, we created "fake" zero features for
    # these non existent state features. And filtered them during batch generation.
    # Therefore the first dimensions for different attributes are different.
    # It could happen that some batches don't contain "real" features at all,
    # e.g. large number of stories don't contain any `slots`.
    # Therefore actual input tensors will be empty.
    # Since we need actual numbers to create dialogue turn features, we create
    # zero tensors in `_encode_fake_features_per_attribute` for these attributes.
    return tf.cond(
        tf.shape(tf_batch_data[attribute][SENTENCE][0])[0] > 0,
        lambda: self._encode_real_features_per_attribute(tf_batch_data, attribute),
        lambda: self._encode_fake_features_per_attribute(tf_batch_data, attribute),
    )
1484
+
1485
def _encode_fake_features_per_attribute(
    self, tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]], attribute: Text
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Returns dummy outputs for fake features of a given attribute.

    Needs to match the outputs of `_encode_real_features_per_attribute` in shape
    but these outputs will be filled with zeros.

    Args:
        tf_batch_data: Maps each attribute to its features and masks.
        attribute: The attribute whose fake features will be "processed", e.g.
            `ACTION_NAME`, `INTENT`.

    Returns:
        attribute_features: A tensor of shape `(batch_size, dialogue_length, units)`
            filled with zeros.
        text_output: Only for `TEXT` attribute (otherwise an empty tensor): A tensor
            of shape `(combined batch_size & dialogue_length, max seq length,
            units)` filled with zeros.
        text_sequence_lengths: Only for `TEXT` attribute, otherwise an empty tensor:
            Of shape `(combined batch_size & dialogue_length, 1)`, filled with zeros.
    """
    # we need to create real zero tensors with appropriate batch and dialogue dim
    # because they are passed to dialogue transformer
    attribute_mask = tf_batch_data[attribute][MASK][0]

    # determine all dimensions so that fake features of the correct shape can be
    # created
    batch_dim = tf.shape(attribute_mask)[0]
    dialogue_dim = tf.shape(attribute_mask)[1]
    if attribute in set(SENTENCE_FEATURES_TO_ENCODE + LABEL_FEATURES_TO_ENCODE):
        units = self.config[ENCODING_DIMENSION]
    else:
        # state-level attributes don't use an encoding layer, hence their size is
        # just the output size of the corresponding sparse+dense feature combining
        # layer
        units = self._tf_layers[
            f"sparse_dense_concat_layer.{attribute}"
        ].output_units

    attribute_features = tf.zeros(
        (batch_dim, dialogue_dim, units), dtype=tf.float32
    )

    # Only for user text, the transformer output and sequence lengths also have to
    # be created (here using fake features) to enable entity recognition training
    # and prediction.
    if attribute == TEXT:
        # we just need to get the correct last dimension size from the prepared
        # transformer
        text_units = self._tf_layers[f"sequence_layer.{attribute}"].output_units
        text_output = tf.zeros((0, 0, text_units), dtype=tf.float32)
        text_sequence_lengths = tf.zeros((0,), dtype=tf.int32)
    else:
        # simulate None with empty tensor of zeros
        text_output = tf.zeros((0,))
        text_sequence_lengths = tf.zeros((0,))

    return attribute_features, text_output, text_sequence_lengths
1544
+
1545
@staticmethod
def _create_last_dialogue_turns_mask(
    tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]], attribute: Text
) -> tf.Tensor:
    """Creates a boolean mask that marks the last turn of each dialogue.

    Only used when `max_history_featurizer_is_used` is `True`. Works on the
    combined (batch x dialogue-length) dimension and is then restricted to
    positions that carry "real" input for `attribute`.

    How the last turns are found:
    `tf_batch_data[DIALOGUE][INDICES][0]` enumerates the turns inside every
    dialogue, e.g. for `dialogue_lengths = [2, 1, 3]` the indices are
    `[0, 1, 0, 0, 1, 2]`. A `0` always marks the *first* turn of a dialogue,
    so the position directly before any `0` is a *last* turn — and so is the
    very last position overall. Appending a `0` and dropping the first
    element yields `[1, 0, 0, 1, 2, 0]`, where a `0` now marks exactly the
    last turns; casting to bool and negating gives the desired mask
    `[0, 1, 1, 0, 0, 1]`.
    """
    dialogue_indices = tf_batch_data[DIALOGUE][INDICES][0]
    # append a 0 (a "first turn" marker) and drop the leading element so that
    # zeros line up with last dialogue turns
    shifted_indices = tf.concat(
        [
            dialogue_indices,
            tf.zeros((1,), dtype=tf.int32),
        ],
        axis=0,
    )[1:]
    # zeros (= last turns) become True after cast + logical_not
    last_dialogue_turn_mask = tf.math.logical_not(
        tf.cast(shifted_indices, dtype=tf.bool)
    )
    # keep only the positions that correspond to real (non-fake) inputs,
    # i.e. those with a non-zero sequence length for this attribute
    real_input_lengths = tf.reshape(
        tf_batch_data[attribute][SEQUENCE_LENGTH][0], (-1,)
    )
    return tf.boolean_mask(last_dialogue_turn_mask, real_input_lengths)
1590
+
1591
def _encode_real_features_per_attribute(
    self, tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]], attribute: Text
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Encodes features for a given attribute.

    Args:
        tf_batch_data: Maps each attribute to its features and masks.
        attribute: the attribute we will encode features for
            (e.g., ACTION_NAME, INTENT)

    Returns:
        attribute_features: A tensor of shape `(batch_size, dialogue_length, units)`
            with all features for `attribute` processed and combined. If sequence-
            level features are present, the sequence dimension is eliminated using
            a transformer.
        text_output: Only for `TEXT` attribute (otherwise an empty tensor): A tensor
            of shape `(combined batch_size & dialogue_length, max seq length,
            units)` containing token-level embeddings further used for entity
            extraction from user text. Similar to `attribute_features` but returned
            for all tokens, not just for the last one.
        text_sequence_lengths: Only for `TEXT` attribute, otherwise an empty tensor:
            Shape `(combined batch_size & dialogue_length, 1)`, containing the
            sequence length for user text examples in `text_output`. The sequence
            length is effectively the number of tokens + 1 (to account also for
            sentence-level features). Needed for entity extraction from user text.
    """
    # simulate None with empty tensor of zeros (a graph-mode-friendly
    # placeholder for "no text output")
    text_output = tf.zeros((0,))
    text_sequence_lengths = tf.zeros((0,))

    if attribute in SEQUENCE_FEATURES_TO_ENCODE:
        # get lengths of real token sequences as a 3D tensor
        sequence_feature_lengths = self._get_sequence_feature_lengths(
            tf_batch_data, attribute
        )

        # sequence_feature_lengths contain `0` for "fake" features, while
        # tf_batch_data[attribute] contains only "real" features. Hence, we need to
        # get rid of the lengths that are 0. This step produces a 1D tensor.
        # (a tensor used as its own boolean mask: zeros are dropped)
        sequence_feature_lengths = tf.boolean_mask(
            sequence_feature_lengths, sequence_feature_lengths
        )

        # run the sparse+dense combining, FFNN and transformer layers over the
        # token sequence; only the transformer output is needed here
        attribute_features, _, _, _, _, _ = self._tf_layers[
            f"sequence_layer.{attribute}"
        ](
            (
                tf_batch_data[attribute][SEQUENCE],
                tf_batch_data[attribute][SENTENCE],
                sequence_feature_lengths,
            ),
            training=self._training,
        )

        # +1 accounts for the sentence-level feature appended after the tokens
        combined_sentence_sequence_feature_lengths = sequence_feature_lengths + 1

        # Only for user text, the transformer output and sequence lengths also have
        # to be returned to enable entity recognition training and prediction.
        if attribute == TEXT:
            text_output = attribute_features
            text_sequence_lengths = combined_sentence_sequence_feature_lengths

            if self.max_history_featurizer_is_used:
                # get the location of all last dialogue inputs
                last_dialogue_turns_mask = self._create_last_dialogue_turns_mask(
                    tf_batch_data, attribute
                )
                # pick outputs that correspond to the last dialogue turns
                text_output = tf.boolean_mask(text_output, last_dialogue_turns_mask)
                text_sequence_lengths = tf.boolean_mask(
                    text_sequence_lengths, last_dialogue_turns_mask
                )

        # resulting attribute features will have shape
        # combined batch dimension and dialogue length x 1 x units
        # (only the embedding of the last real token/sentence feature is kept)
        attribute_features = tf.expand_dims(
            self._last_token(
                attribute_features, combined_sentence_sequence_feature_lengths
            ),
            axis=1,
        )

    # for attributes without sequence-level features, all we need is to combine the
    # sparse and dense sentence-level features into one
    else:
        # resulting attribute features will have shape
        # combined batch dimension and dialogue length x 1 x units
        attribute_features = self._tf_layers[
            f"sparse_dense_concat_layer.{attribute}"
        ]((tf_batch_data[attribute][SENTENCE],), training=self._training)

    # project to ENCODING_DIMENSION where an encoding layer was prepared
    if attribute in SENTENCE_FEATURES_TO_ENCODE + LABEL_FEATURES_TO_ENCODE:
        attribute_features = self._tf_layers[f"encoding_layer.{attribute}"](
            attribute_features, self._training
        )

    # attribute features have shape
    # (combined batch dimension and dialogue length x 1 x units)
    # convert them back to their original shape of
    # batch size x dialogue length x units
    attribute_features = self._convert_to_original_shape(
        attribute_features, tf_batch_data, attribute
    )

    return attribute_features, text_output, text_sequence_lengths
1696
+
1697
@staticmethod
def _convert_to_original_shape(
    attribute_features: tf.Tensor,
    tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
    attribute: Text,
) -> tf.Tensor:
    """Transform attribute features back to original shape.

    Given shape: (combined batch and dialogue dimension x 1 x units)
    Original shape: (batch x dialogue length x units)

    Args:
        attribute_features: the "real" features to convert
        tf_batch_data: dictionary mapping every attribute to its features and masks
        attribute: the attribute we will encode features for
            (e.g., ACTION_NAME, INTENT)

    Returns:
        The converted attribute features
    """
    # in order to convert the attribute features with shape
    # (combined batch-size and dialogue length x 1 x units)
    # to a shape of (batch-size x dialogue length x units)
    # we use tf.scatter_nd. Therefore, we need the target shape and the indices
    # mapping the values of attribute features to the position in the resulting
    # tensor.

    # attribute_mask has shape batch x dialogue_len x 1
    attribute_mask = tf_batch_data[attribute][MASK][0]

    if attribute in SENTENCE_FEATURES_TO_ENCODE + STATE_LEVEL_FEATURES:
        dialogue_lengths = tf.cast(
            tf_batch_data[DIALOGUE][LENGTH][0], dtype=tf.int32
        )
        dialogue_indices = tf_batch_data[DIALOGUE][INDICES][0]
    else:
        # for labels, dialogue length is a fake dim and equal to 1
        dialogue_lengths = tf.ones((tf.shape(attribute_mask)[0],), dtype=tf.int32)
        dialogue_indices = tf.zeros((tf.shape(attribute_mask)[0],), dtype=tf.int32)

    batch_dim = tf.shape(attribute_mask)[0]
    dialogue_dim = tf.shape(attribute_mask)[1]
    # static (graph-build-time) size of the feature dimension
    units = attribute_features.shape[-1]

    # attribute_mask has shape (batch x dialogue_len x 1), remove last dimension
    attribute_mask = tf.cast(tf.squeeze(attribute_mask, axis=-1), dtype=tf.int32)
    # sum of attribute mask contains number of dialogue turns with "real" features
    non_fake_dialogue_lengths = tf.reduce_sum(attribute_mask, axis=-1)
    # create the batch indices: batch index b repeated once per real turn in b
    batch_indices = tf.repeat(tf.range(batch_dim), non_fake_dialogue_lengths)

    # attribute_mask has shape (batch x dialogue_len x 1), while
    # dialogue_indices has shape (combined_dialogue_len,)
    # in order to find positions of real input we need to flatten
    # attribute mask to (combined_dialogue_len,)
    dialogue_indices_mask = tf.boolean_mask(
        attribute_mask, tf.sequence_mask(dialogue_lengths, dtype=tf.int32)
    )
    # pick only those indices that contain "real" input
    dialogue_indices = tf.boolean_mask(dialogue_indices, dialogue_indices_mask)

    # (batch index, turn index) target coordinates for every real feature row
    indices = tf.stack([batch_indices, dialogue_indices], axis=1)

    shape = tf.convert_to_tensor([batch_dim, dialogue_dim, units])
    # drop the singleton middle dimension before scattering
    attribute_features = tf.squeeze(attribute_features, axis=1)

    # positions without real features are left as zeros by scatter_nd
    return tf.scatter_nd(indices, attribute_features, shape)
1764
+
1765
def _process_batch_data(
    self, tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]]
) -> Tuple[tf.Tensor, Optional[tf.Tensor], Optional[tf.Tensor]]:
    """Encodes batch data.

    Combines intent and text and action name and action text if both are present.

    Args:
        tf_batch_data: dictionary mapping every attribute to its features and masks

    Returns:
        Tensor: encoding of all features in the batch, combined;
        plus the text-level transformer output and sequence lengths
        (both `None` unless the `TEXT` attribute was encoded).
    """
    text_output = None
    text_sequence_lengths = None
    encoded = {}

    # encode every attribute that has features in this batch
    for attribute in tf_batch_data:
        if attribute not in SENTENCE_FEATURES_TO_ENCODE + STATE_LEVEL_FEATURES:
            continue
        encoded_triplet = self._encode_features_per_attribute(
            tf_batch_data, attribute
        )
        features, attr_text_output, attr_text_lengths = encoded_triplet
        encoded[attribute] = features
        if attribute == TEXT:
            # keep the token-level outputs for entity recognition
            text_output = attr_text_output
            text_sequence_lengths = attr_text_lengths

    def _combine(primary: Text, fallback: Text) -> tf.Tensor:
        # sum both encodings when present; otherwise take whichever exists
        if encoded.get(primary) is not None and encoded.get(fallback) is not None:
            return encoded.pop(primary) + encoded.pop(fallback)
        if encoded.get(primary) is not None:
            return encoded.pop(primary)
        return encoded.pop(fallback)

    # if both action text and action name are present, combine them; otherwise,
    # use the one which is present — and the same for user input
    batch_action = _combine(ACTION_TEXT, ACTION_NAME)
    batch_user = _combine(INTENT, TEXT)

    # user input and previous action first, then all remaining attributes
    # (SLOTS, ACTIVE_LOOP, etc.) in their original order
    batch_features = [batch_user, batch_action]
    batch_features.extend(encoded.values())

    return tf.concat(batch_features, axis=-1), text_output, text_sequence_lengths
1829
+
1830
def _reshape_for_entities(
    self,
    tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
    dialogue_transformer_output: tf.Tensor,
    text_output: tf.Tensor,
    text_sequence_lengths: tf.Tensor,
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Combines text and dialogue transformer outputs for entity tagging.

    Args:
        tf_batch_data: dictionary mapping every attribute to its features and
            masks.
        dialogue_transformer_output: output of the dialogue transformer,
            shape (batch size x dialogue length x units).
        text_output: token-level output of the text sequence transformer.
        text_sequence_lengths: sequence lengths of the user text examples
            in `text_output`.

    Returns:
        text_transformed: text transformer output with the (tiled) dialogue
            transformer output concatenated on the feature axis.
        text_mask: sequence mask matching `text_transformed`'s sequence dim.
        text_sequence_lengths: flattened lengths with the sentence-level
            feature removed (i.e. number of tokens).
    """
    # The first dim of the output of the text sequence transformer is the same
    # as number of "real" features for `text` at the last dialogue turns
    # (let's call it `N`),
    # which corresponds to the first dim of the tag ids tensor.
    # To calculate the loss for entities we need the output of the text
    # sequence transformer (shape: N x sequence length x units),
    # the output of the dialogue transformer
    # (shape: batch size x dialogue length x units) and the tag ids for the
    # entities (shape: N x sequence length - 1 x units)
    # In order to process the tensors, they need to have the same shape.
    # Convert the output of the dialogue transformer to shape
    # (N x 1 x units).

    # Note: The CRF layer cannot handle 4D tensors. E.g. we cannot use the shape
    # batch size x dialogue length x sequence length x units

    # convert the output of the dialogue transformer
    # to shape (real entity dim x 1 x units)
    attribute_mask = tf_batch_data[TEXT][MASK][0]
    dialogue_lengths = tf.cast(tf_batch_data[DIALOGUE][LENGTH][0], tf.int32)

    if self.max_history_featurizer_is_used:
        # pick outputs that correspond to the last dialogue turns
        attribute_mask = tf.expand_dims(
            self._last_token(attribute_mask, dialogue_lengths), axis=1
        )
    dialogue_transformer_output = tf.boolean_mask(
        dialogue_transformer_output, tf.squeeze(attribute_mask, axis=-1)
    )

    # boolean mask removed axis=1, add it back
    dialogue_transformer_output = tf.expand_dims(
        dialogue_transformer_output, axis=1
    )

    # broadcast the dialogue transformer output sequence-length-times to get the
    # same shape as the text sequence transformer output
    dialogue_transformer_output = tf.tile(
        dialogue_transformer_output, (1, tf.shape(text_output)[1], 1)
    )

    # concat the output of the dialogue transformer to the output of the text
    # sequence transformer (adding context)
    # resulting shape (N x sequence length x 2 units)
    # N = number of "real" features for `text` at the last dialogue turns
    text_transformed = tf.concat(
        [text_output, dialogue_transformer_output], axis=-1
    )
    text_mask = rasa_layers.compute_mask(text_sequence_lengths)

    # add zeros to match the shape of text_transformed, because
    # max sequence length might differ, since it is calculated dynamically
    # based on a subset of sequence lengths
    sequence_diff = tf.shape(text_transformed)[1] - tf.shape(text_mask)[1]
    text_mask = tf.pad(text_mask, [[0, 0], [0, sequence_diff], [0, 0]])

    # remove additional dims and sentence features
    text_sequence_lengths = tf.reshape(text_sequence_lengths, (-1,)) - 1

    return text_transformed, text_mask, text_sequence_lengths
1897
+
1898
+ # ---TRAINING---
1899
+
1900
def _batch_loss_entities(
    self,
    tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
    dialogue_transformer_output: tf.Tensor,
    text_output: tf.Tensor,
    text_sequence_lengths: tf.Tensor,
) -> tf.Tensor:
    """Returns the entity loss, or a constant zero for text-free batches.

    A batch can contain no "real" features for `text` at all (e.g. when a
    large number of stories are intent only); then `text_output` is empty
    and no loss can be computed from it, so a zero scalar is produced to
    keep the total loss well defined.
    """

    def _loss_on_real_text() -> tf.Tensor:
        return self._real_batch_loss_entities(
            tf_batch_data,
            dialogue_transformer_output,
            text_output,
            text_sequence_lengths,
        )

    def _zero_loss() -> tf.Tensor:
        return tf.constant(0.0)

    has_real_text = tf.shape(text_output)[0] > 0
    return tf.cond(has_real_text, _loss_on_real_text, _zero_loss)
1923
+
1924
def _real_batch_loss_entities(
    self,
    tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
    dialogue_transformer_output: tf.Tensor,
    text_output: tf.Tensor,
    text_sequence_lengths: tf.Tensor,
) -> tf.Tensor:
    """Calculates the entity loss for a batch that contains real user text."""
    (
        text_transformed,
        text_mask,
        text_sequence_lengths,
    ) = self._reshape_for_entities(
        tf_batch_data,
        dialogue_transformer_output,
        text_output,
        text_sequence_lengths,
    )

    tag_ids = tf_batch_data[ENTITY_TAGS][IDS][0]
    # pad tag ids with zeros (= no entity) so that the sentence-level feature
    # positions of `text_transformed` are matched in shape
    pad_amount = tf.shape(text_transformed)[1] - tf.shape(tag_ids)[1]
    padded_tag_ids = tf.pad(tag_ids, [[0, 0], [0, pad_amount], [0, 0]])

    loss, f1, _ = self._calculate_entity_loss(
        text_transformed,
        padded_tag_ids,
        text_mask,
        text_sequence_lengths,
        ENTITY_ATTRIBUTE_TYPE,
    )

    # record entity metrics for reporting
    self.entity_loss.update_state(loss)
    self.entity_f1.update_state(f1)

    return loss
1956
+
1957
@staticmethod
def _get_labels_embed(
    label_ids: tf.Tensor, all_labels_embed: tf.Tensor
) -> tf.Tensor:
    """Looks up label embeddings for the given label ids.

    Instead of encoding the labels a second time, the embeddings for this
    batch are gathered from the pre-computed `all_labels_embed` by id.
    """
    gather_indices = tf.cast(label_ids[:, :, 0], tf.int32)
    return tf.gather(all_labels_embed, gather_indices)
1968
+
1969
def batch_loss(
    self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
) -> tf.Tensor:
    """Calculates the loss for the given batch.

    Args:
        batch_in: The batch.

    Returns:
        The loss of the given batch: the label (action) loss plus,
        when entity recognition is enabled and user text is present,
        the entity loss.
    """
    tf_batch_data = self.batch_to_model_data_format(batch_in, self.data_signature)
    self._compute_dialogue_indices(tf_batch_data)

    # embed all candidate labels once, then look up this batch's labels by id
    all_label_ids, all_labels_embed = self._create_all_labels_embed()

    label_ids = tf_batch_data[LABEL_KEY][LABEL_SUB_KEY][0]
    labels_embed = self._get_labels_embed(label_ids, all_labels_embed)

    # encode all attributes and run the dialogue transformer
    dialogue_in, text_output, text_sequence_lengths = self._process_batch_data(
        tf_batch_data
    )
    (
        dialogue_embed,
        dialogue_mask,
        dialogue_transformer_output,
        _,
    ) = self._embed_dialogue(dialogue_in, tf_batch_data)
    dialogue_mask = tf.squeeze(dialogue_mask, axis=-1)

    losses = []

    # similarity-based label (next action) loss
    loss, acc = self._tf_layers[f"loss.{LABEL}"](
        dialogue_embed,
        labels_embed,
        label_ids,
        all_labels_embed,
        all_label_ids,
        dialogue_mask,
    )
    losses.append(loss)

    # entity loss only applies when entity recognition is on and the batch
    # produced text-level outputs
    if (
        self.config[ENTITY_RECOGNITION]
        and text_output is not None
        and text_sequence_lengths is not None
    ):
        losses.append(
            self._batch_loss_entities(
                tf_batch_data,
                dialogue_transformer_output,
                text_output,
                text_sequence_lengths,
            )
        )

    # track action-prediction metrics (label loss only, not entity loss)
    self.action_loss.update_state(loss)
    self.action_acc.update_state(acc)

    return tf.math.add_n(losses)
2029
+
2030
+ # ---PREDICTION---
2031
def prepare_for_predict(self) -> None:
    """Prepares the model for prediction.

    Pre-computes the embeddings of all labels and caches them on the model
    so that `batch_predict` can compare dialogue embeddings against them.
    """
    all_labels_embed = self._create_all_labels_embed()[1]
    self.all_labels_embed = all_labels_embed
2034
+
2035
def batch_predict(
    self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
) -> Dict[Text, Union[tf.Tensor, Dict[Text, tf.Tensor]]]:
    """Predicts the output of the given batch.

    Args:
        batch_in: The batch.

    Returns:
        The output to predict: label scores and similarities, diagnostic
        attention weights, and — when entity recognition is enabled and
        user text is present — predicted entity tag ids and confidences.

    Raises:
        ValueError: If `prepare_for_predict` was not called first (the
            cached label embeddings are required).
    """
    if self.all_labels_embed is None:
        raise ValueError(
            "The model was not prepared for prediction. "
            "Call `prepare_for_predict` first."
        )

    tf_batch_data = self.batch_to_model_data_format(
        batch_in, self.predict_data_signature
    )
    self._compute_dialogue_indices(tf_batch_data)

    # encode all attributes and run the dialogue transformer
    dialogue_in, text_output, text_sequence_lengths = self._process_batch_data(
        tf_batch_data
    )
    (
        dialogue_embed,
        dialogue_mask,
        dialogue_transformer_output,
        attention_weights,
    ) = self._embed_dialogue(dialogue_in, tf_batch_data)
    dialogue_mask = tf.squeeze(dialogue_mask, axis=-1)

    # compare the dialogue embedding of every turn against all label embeddings
    sim_all, scores = self._tf_layers[
        f"loss.{LABEL}"
    ].get_similarities_and_confidences_from_embeddings(
        dialogue_embed[:, :, tf.newaxis, :],
        self.all_labels_embed[tf.newaxis, tf.newaxis, :, :],
        dialogue_mask,
    )

    predictions = {
        "scores": scores,
        "similarities": sim_all,
        DIAGNOSTIC_DATA: {"attention_weights": attention_weights},
    }

    # entity predictions only apply when entity recognition is on and the
    # batch produced text-level outputs
    if (
        self.config[ENTITY_RECOGNITION]
        and text_output is not None
        and text_sequence_lengths is not None
    ):
        pred_ids, confidences = self._batch_predict_entities(
            tf_batch_data,
            dialogue_transformer_output,
            text_output,
            text_sequence_lengths,
        )
        name = ENTITY_ATTRIBUTE_TYPE
        predictions[f"e_{name}_ids"] = pred_ids
        predictions[f"e_{name}_scores"] = confidences

    return predictions
2098
+
2099
def _batch_predict_entities(
    self,
    tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
    dialogue_transformer_output: tf.Tensor,
    text_output: tf.Tensor,
    text_sequence_lengths: tf.Tensor,
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Predicts entity tag ids and confidences, or zeros without real text.

    The current prediction turn may contain no "real" features for `text`,
    in which case `text_output` is empty and entities cannot be predicted;
    zero tensors of the expected shape are produced instead.
    """

    def _predict_on_real_text() -> Tuple[tf.Tensor, tf.Tensor]:
        return self._real_batch_predict_entities(
            tf_batch_data,
            dialogue_transformer_output,
            text_output,
            text_sequence_lengths,
        )

    def _empty_prediction() -> Tuple[tf.Tensor, tf.Tensor]:
        # the output is of shape (batch_size, max_seq_len)
        zero_shape = tf.shape(text_output)[:2]
        return (
            tf.zeros(zero_shape, dtype=tf.int32),
            tf.zeros(zero_shape, dtype=tf.float32),
        )

    has_real_text = tf.shape(text_output)[0] > 0
    return tf.cond(has_real_text, _predict_on_real_text, _empty_prediction)
2126
+
2127
def _real_batch_predict_entities(
    self,
    tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
    dialogue_transformer_output: tf.Tensor,
    text_output: tf.Tensor,
    text_sequence_lengths: tf.Tensor,
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Runs CRF-based entity prediction on a turn that has real user text."""
    # add dialogue context to the token-level text representations
    text_transformed, _, text_sequence_lengths = self._reshape_for_entities(
        tf_batch_data,
        dialogue_transformer_output,
        text_output,
        text_sequence_lengths,
    )

    entity_name = ENTITY_ATTRIBUTE_TYPE
    # per-token tag logits, then CRF decoding into ids and confidences
    tag_logits = self._tf_layers[f"embed.{entity_name}.logits"](text_transformed)
    return self._tf_layers[f"crf.{entity_name}"](tag_logits, text_sequence_lengths)