rasa_pro-3.12.0.dev1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic.

Files changed (790)
  1. README.md +41 -0
  2. rasa/__init__.py +9 -0
  3. rasa/__main__.py +177 -0
  4. rasa/anonymization/__init__.py +2 -0
  5. rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
  6. rasa/anonymization/anonymization_pipeline.py +286 -0
  7. rasa/anonymization/anonymization_rule_executor.py +260 -0
  8. rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
  9. rasa/anonymization/schemas/config.yml +47 -0
  10. rasa/anonymization/utils.py +118 -0
  11. rasa/api.py +160 -0
  12. rasa/cli/__init__.py +5 -0
  13. rasa/cli/arguments/__init__.py +0 -0
  14. rasa/cli/arguments/data.py +106 -0
  15. rasa/cli/arguments/default_arguments.py +207 -0
  16. rasa/cli/arguments/evaluate.py +65 -0
  17. rasa/cli/arguments/export.py +51 -0
  18. rasa/cli/arguments/interactive.py +74 -0
  19. rasa/cli/arguments/run.py +219 -0
  20. rasa/cli/arguments/shell.py +17 -0
  21. rasa/cli/arguments/test.py +211 -0
  22. rasa/cli/arguments/train.py +279 -0
  23. rasa/cli/arguments/visualize.py +34 -0
  24. rasa/cli/arguments/x.py +30 -0
  25. rasa/cli/data.py +354 -0
  26. rasa/cli/dialogue_understanding_test.py +251 -0
  27. rasa/cli/e2e_test.py +259 -0
  28. rasa/cli/evaluate.py +222 -0
  29. rasa/cli/export.py +250 -0
  30. rasa/cli/inspect.py +75 -0
  31. rasa/cli/interactive.py +166 -0
  32. rasa/cli/license.py +65 -0
  33. rasa/cli/llm_fine_tuning.py +403 -0
  34. rasa/cli/markers.py +78 -0
  35. rasa/cli/project_templates/__init__.py +0 -0
  36. rasa/cli/project_templates/calm/actions/__init__.py +0 -0
  37. rasa/cli/project_templates/calm/actions/action_template.py +27 -0
  38. rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
  39. rasa/cli/project_templates/calm/actions/db.py +57 -0
  40. rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
  41. rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
  42. rasa/cli/project_templates/calm/config.yml +10 -0
  43. rasa/cli/project_templates/calm/credentials.yml +33 -0
  44. rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
  45. rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
  46. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
  47. rasa/cli/project_templates/calm/db/contacts.json +10 -0
  48. rasa/cli/project_templates/calm/domain/add_contact.yml +39 -0
  49. rasa/cli/project_templates/calm/domain/list_contacts.yml +17 -0
  50. rasa/cli/project_templates/calm/domain/remove_contact.yml +38 -0
  51. rasa/cli/project_templates/calm/domain/shared.yml +10 -0
  52. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
  53. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
  54. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
  55. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
  56. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
  57. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
  58. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
  59. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
  60. rasa/cli/project_templates/calm/endpoints.yml +58 -0
  61. rasa/cli/project_templates/default/actions/__init__.py +0 -0
  62. rasa/cli/project_templates/default/actions/actions.py +27 -0
  63. rasa/cli/project_templates/default/config.yml +44 -0
  64. rasa/cli/project_templates/default/credentials.yml +33 -0
  65. rasa/cli/project_templates/default/data/nlu.yml +91 -0
  66. rasa/cli/project_templates/default/data/rules.yml +13 -0
  67. rasa/cli/project_templates/default/data/stories.yml +30 -0
  68. rasa/cli/project_templates/default/domain.yml +34 -0
  69. rasa/cli/project_templates/default/endpoints.yml +42 -0
  70. rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
  71. rasa/cli/project_templates/tutorial/actions/__init__.py +0 -0
  72. rasa/cli/project_templates/tutorial/actions/actions.py +22 -0
  73. rasa/cli/project_templates/tutorial/config.yml +12 -0
  74. rasa/cli/project_templates/tutorial/credentials.yml +33 -0
  75. rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
  76. rasa/cli/project_templates/tutorial/data/patterns.yml +11 -0
  77. rasa/cli/project_templates/tutorial/domain.yml +35 -0
  78. rasa/cli/project_templates/tutorial/endpoints.yml +55 -0
  79. rasa/cli/run.py +143 -0
  80. rasa/cli/scaffold.py +273 -0
  81. rasa/cli/shell.py +141 -0
  82. rasa/cli/studio/__init__.py +0 -0
  83. rasa/cli/studio/download.py +62 -0
  84. rasa/cli/studio/studio.py +296 -0
  85. rasa/cli/studio/train.py +59 -0
  86. rasa/cli/studio/upload.py +62 -0
  87. rasa/cli/telemetry.py +102 -0
  88. rasa/cli/test.py +280 -0
  89. rasa/cli/train.py +278 -0
  90. rasa/cli/utils.py +484 -0
  91. rasa/cli/visualize.py +40 -0
  92. rasa/cli/x.py +206 -0
  93. rasa/constants.py +45 -0
  94. rasa/core/__init__.py +17 -0
  95. rasa/core/actions/__init__.py +0 -0
  96. rasa/core/actions/action.py +1318 -0
  97. rasa/core/actions/action_clean_stack.py +59 -0
  98. rasa/core/actions/action_exceptions.py +24 -0
  99. rasa/core/actions/action_hangup.py +29 -0
  100. rasa/core/actions/action_repeat_bot_messages.py +89 -0
  101. rasa/core/actions/action_run_slot_rejections.py +210 -0
  102. rasa/core/actions/action_trigger_chitchat.py +31 -0
  103. rasa/core/actions/action_trigger_flow.py +109 -0
  104. rasa/core/actions/action_trigger_search.py +31 -0
  105. rasa/core/actions/constants.py +5 -0
  106. rasa/core/actions/custom_action_executor.py +191 -0
  107. rasa/core/actions/direct_custom_actions_executor.py +109 -0
  108. rasa/core/actions/e2e_stub_custom_action_executor.py +72 -0
  109. rasa/core/actions/forms.py +741 -0
  110. rasa/core/actions/grpc_custom_action_executor.py +251 -0
  111. rasa/core/actions/http_custom_action_executor.py +145 -0
  112. rasa/core/actions/loops.py +114 -0
  113. rasa/core/actions/two_stage_fallback.py +186 -0
  114. rasa/core/agent.py +559 -0
  115. rasa/core/auth_retry_tracker_store.py +122 -0
  116. rasa/core/brokers/__init__.py +0 -0
  117. rasa/core/brokers/broker.py +126 -0
  118. rasa/core/brokers/file.py +58 -0
  119. rasa/core/brokers/kafka.py +324 -0
  120. rasa/core/brokers/pika.py +388 -0
  121. rasa/core/brokers/sql.py +86 -0
  122. rasa/core/channels/__init__.py +61 -0
  123. rasa/core/channels/botframework.py +338 -0
  124. rasa/core/channels/callback.py +84 -0
  125. rasa/core/channels/channel.py +456 -0
  126. rasa/core/channels/console.py +241 -0
  127. rasa/core/channels/development_inspector.py +197 -0
  128. rasa/core/channels/facebook.py +419 -0
  129. rasa/core/channels/hangouts.py +329 -0
  130. rasa/core/channels/inspector/.eslintrc.cjs +25 -0
  131. rasa/core/channels/inspector/.gitignore +23 -0
  132. rasa/core/channels/inspector/README.md +54 -0
  133. rasa/core/channels/inspector/assets/favicon.ico +0 -0
  134. rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
  135. rasa/core/channels/inspector/custom.d.ts +3 -0
  136. rasa/core/channels/inspector/dist/assets/arc-861ddd57.js +1 -0
  137. rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
  138. rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-921f02db.js +10 -0
  139. rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-b436c4f8.js +2 -0
  140. rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-511a23cb.js +2 -0
  141. rasa/core/channels/inspector/dist/assets/createText-62fc7601-ef476ecd.js +7 -0
  142. rasa/core/channels/inspector/dist/assets/edges-f2ad444c-f1878e0a.js +4 -0
  143. rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-fac75185.js +51 -0
  144. rasa/core/channels/inspector/dist/assets/flowDb-1972c806-201c5bbc.js +6 -0
  145. rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-f904ae41.js +4 -0
  146. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-b080d6f2.js +1 -0
  147. rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-1813da66.js +139 -0
  148. rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-872af172.js +266 -0
  149. rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-34a0af5a.js +70 -0
  150. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
  151. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
  152. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
  153. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
  154. rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-42ba3e3d.js +1 -0
  155. rasa/core/channels/inspector/dist/assets/index-37817b51.js +1317 -0
  156. rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
  157. rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-6b731386.js +7 -0
  158. rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
  159. rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-e8579ac6.js +139 -0
  160. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
  161. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
  162. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
  163. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
  164. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
  165. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
  166. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
  167. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
  168. rasa/core/channels/inspector/dist/assets/layout-89e6403a.js +1 -0
  169. rasa/core/channels/inspector/dist/assets/line-dc73d3fc.js +1 -0
  170. rasa/core/channels/inspector/dist/assets/linear-f5b1d2bc.js +1 -0
  171. rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-82cb74fa.js +109 -0
  172. rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
  173. rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
  174. rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-bdf5f29b.js +35 -0
  175. rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-c7a0cbe4.js +7 -0
  176. rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-7ec5410f.js +52 -0
  177. rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-caee5554.js +8 -0
  178. rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-2935f8db.js +122 -0
  179. rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-8f5d9693.js +1 -0
  180. rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-d565d1de.js +1 -0
  181. rasa/core/channels/inspector/dist/assets/styles-080da4f6-75ad421d.js +110 -0
  182. rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-7e764226.js +159 -0
  183. rasa/core/channels/inspector/dist/assets/styles-9c745c82-7a4e0e61.js +207 -0
  184. rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-4019d1bf.js +1 -0
  185. rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-01ea12df.js +61 -0
  186. rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-89407137.js +7 -0
  187. rasa/core/channels/inspector/dist/index.html +42 -0
  188. rasa/core/channels/inspector/index.html +40 -0
  189. rasa/core/channels/inspector/jest.config.ts +13 -0
  190. rasa/core/channels/inspector/package.json +52 -0
  191. rasa/core/channels/inspector/setupTests.ts +2 -0
  192. rasa/core/channels/inspector/src/App.tsx +220 -0
  193. rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
  194. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +108 -0
  195. rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
  196. rasa/core/channels/inspector/src/components/DialogueStack.tsx +136 -0
  197. rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
  198. rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
  199. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +22 -0
  200. rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
  201. rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
  202. rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
  203. rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
  204. rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
  205. rasa/core/channels/inspector/src/helpers/audiostream.ts +191 -0
  206. rasa/core/channels/inspector/src/helpers/formatters.test.ts +392 -0
  207. rasa/core/channels/inspector/src/helpers/formatters.ts +306 -0
  208. rasa/core/channels/inspector/src/helpers/utils.ts +127 -0
  209. rasa/core/channels/inspector/src/main.tsx +13 -0
  210. rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
  211. rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
  212. rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
  213. rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
  214. rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
  215. rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
  216. rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
  217. rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
  218. rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
  219. rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
  220. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
  221. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
  222. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
  223. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
  224. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
  225. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
  226. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
  227. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
  228. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
  229. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
  230. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
  231. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
  232. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
  233. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
  234. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
  235. rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
  236. rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
  237. rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
  238. rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
  239. rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
  240. rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
  241. rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
  242. rasa/core/channels/inspector/src/theme/index.ts +101 -0
  243. rasa/core/channels/inspector/src/types.ts +84 -0
  244. rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
  245. rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
  246. rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
  247. rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
  248. rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
  249. rasa/core/channels/inspector/tsconfig.json +26 -0
  250. rasa/core/channels/inspector/tsconfig.node.json +10 -0
  251. rasa/core/channels/inspector/vite.config.ts +8 -0
  252. rasa/core/channels/inspector/yarn.lock +6249 -0
  253. rasa/core/channels/mattermost.py +229 -0
  254. rasa/core/channels/rasa_chat.py +126 -0
  255. rasa/core/channels/rest.py +230 -0
  256. rasa/core/channels/rocketchat.py +174 -0
  257. rasa/core/channels/slack.py +620 -0
  258. rasa/core/channels/socketio.py +302 -0
  259. rasa/core/channels/telegram.py +298 -0
  260. rasa/core/channels/twilio.py +169 -0
  261. rasa/core/channels/vier_cvg.py +374 -0
  262. rasa/core/channels/voice_ready/__init__.py +0 -0
  263. rasa/core/channels/voice_ready/audiocodes.py +501 -0
  264. rasa/core/channels/voice_ready/jambonz.py +121 -0
  265. rasa/core/channels/voice_ready/jambonz_protocol.py +396 -0
  266. rasa/core/channels/voice_ready/twilio_voice.py +403 -0
  267. rasa/core/channels/voice_ready/utils.py +37 -0
  268. rasa/core/channels/voice_stream/__init__.py +0 -0
  269. rasa/core/channels/voice_stream/asr/__init__.py +0 -0
  270. rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
  271. rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
  272. rasa/core/channels/voice_stream/asr/azure.py +130 -0
  273. rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
  274. rasa/core/channels/voice_stream/audio_bytes.py +8 -0
  275. rasa/core/channels/voice_stream/browser_audio.py +107 -0
  276. rasa/core/channels/voice_stream/call_state.py +23 -0
  277. rasa/core/channels/voice_stream/tts/__init__.py +0 -0
  278. rasa/core/channels/voice_stream/tts/azure.py +106 -0
  279. rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
  280. rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
  281. rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
  282. rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
  283. rasa/core/channels/voice_stream/util.py +57 -0
  284. rasa/core/channels/voice_stream/voice_channel.py +427 -0
  285. rasa/core/channels/webexteams.py +134 -0
  286. rasa/core/concurrent_lock_store.py +210 -0
  287. rasa/core/constants.py +112 -0
  288. rasa/core/evaluation/__init__.py +0 -0
  289. rasa/core/evaluation/marker.py +267 -0
  290. rasa/core/evaluation/marker_base.py +923 -0
  291. rasa/core/evaluation/marker_stats.py +293 -0
  292. rasa/core/evaluation/marker_tracker_loader.py +103 -0
  293. rasa/core/exceptions.py +29 -0
  294. rasa/core/exporter.py +284 -0
  295. rasa/core/featurizers/__init__.py +0 -0
  296. rasa/core/featurizers/precomputation.py +410 -0
  297. rasa/core/featurizers/single_state_featurizer.py +421 -0
  298. rasa/core/featurizers/tracker_featurizers.py +1262 -0
  299. rasa/core/http_interpreter.py +89 -0
  300. rasa/core/information_retrieval/__init__.py +7 -0
  301. rasa/core/information_retrieval/faiss.py +124 -0
  302. rasa/core/information_retrieval/information_retrieval.py +137 -0
  303. rasa/core/information_retrieval/milvus.py +59 -0
  304. rasa/core/information_retrieval/qdrant.py +96 -0
  305. rasa/core/jobs.py +63 -0
  306. rasa/core/lock.py +139 -0
  307. rasa/core/lock_store.py +343 -0
  308. rasa/core/migrate.py +403 -0
  309. rasa/core/nlg/__init__.py +3 -0
  310. rasa/core/nlg/callback.py +146 -0
  311. rasa/core/nlg/contextual_response_rephraser.py +320 -0
  312. rasa/core/nlg/generator.py +230 -0
  313. rasa/core/nlg/interpolator.py +143 -0
  314. rasa/core/nlg/response.py +155 -0
  315. rasa/core/nlg/summarize.py +70 -0
  316. rasa/core/persistor.py +538 -0
  317. rasa/core/policies/__init__.py +0 -0
  318. rasa/core/policies/ensemble.py +329 -0
  319. rasa/core/policies/enterprise_search_policy.py +905 -0
  320. rasa/core/policies/enterprise_search_prompt_template.jinja2 +25 -0
  321. rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
  322. rasa/core/policies/flow_policy.py +205 -0
  323. rasa/core/policies/flows/__init__.py +0 -0
  324. rasa/core/policies/flows/flow_exceptions.py +44 -0
  325. rasa/core/policies/flows/flow_executor.py +754 -0
  326. rasa/core/policies/flows/flow_step_result.py +43 -0
  327. rasa/core/policies/intentless_policy.py +1031 -0
  328. rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
  329. rasa/core/policies/memoization.py +538 -0
  330. rasa/core/policies/policy.py +725 -0
  331. rasa/core/policies/rule_policy.py +1273 -0
  332. rasa/core/policies/ted_policy.py +2169 -0
  333. rasa/core/policies/unexpected_intent_policy.py +1022 -0
  334. rasa/core/processor.py +1465 -0
  335. rasa/core/run.py +342 -0
  336. rasa/core/secrets_manager/__init__.py +0 -0
  337. rasa/core/secrets_manager/constants.py +36 -0
  338. rasa/core/secrets_manager/endpoints.py +391 -0
  339. rasa/core/secrets_manager/factory.py +241 -0
  340. rasa/core/secrets_manager/secret_manager.py +262 -0
  341. rasa/core/secrets_manager/vault.py +584 -0
  342. rasa/core/test.py +1335 -0
  343. rasa/core/tracker_store.py +1703 -0
  344. rasa/core/train.py +105 -0
  345. rasa/core/training/__init__.py +89 -0
  346. rasa/core/training/converters/__init__.py +0 -0
  347. rasa/core/training/converters/responses_prefix_converter.py +119 -0
  348. rasa/core/training/interactive.py +1744 -0
  349. rasa/core/training/story_conflict.py +381 -0
  350. rasa/core/training/training.py +93 -0
  351. rasa/core/utils.py +366 -0
  352. rasa/core/visualize.py +70 -0
  353. rasa/dialogue_understanding/__init__.py +0 -0
  354. rasa/dialogue_understanding/coexistence/__init__.py +0 -0
  355. rasa/dialogue_understanding/coexistence/constants.py +4 -0
  356. rasa/dialogue_understanding/coexistence/intent_based_router.py +196 -0
  357. rasa/dialogue_understanding/coexistence/llm_based_router.py +327 -0
  358. rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
  359. rasa/dialogue_understanding/commands/__init__.py +61 -0
  360. rasa/dialogue_understanding/commands/can_not_handle_command.py +70 -0
  361. rasa/dialogue_understanding/commands/cancel_flow_command.py +125 -0
  362. rasa/dialogue_understanding/commands/change_flow_command.py +44 -0
  363. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +57 -0
  364. rasa/dialogue_understanding/commands/clarify_command.py +86 -0
  365. rasa/dialogue_understanding/commands/command.py +85 -0
  366. rasa/dialogue_understanding/commands/correct_slots_command.py +297 -0
  367. rasa/dialogue_understanding/commands/error_command.py +79 -0
  368. rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
  369. rasa/dialogue_understanding/commands/handle_code_change_command.py +73 -0
  370. rasa/dialogue_understanding/commands/human_handoff_command.py +66 -0
  371. rasa/dialogue_understanding/commands/knowledge_answer_command.py +57 -0
  372. rasa/dialogue_understanding/commands/noop_command.py +54 -0
  373. rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
  374. rasa/dialogue_understanding/commands/restart_command.py +58 -0
  375. rasa/dialogue_understanding/commands/session_end_command.py +61 -0
  376. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  377. rasa/dialogue_understanding/commands/set_slot_command.py +160 -0
  378. rasa/dialogue_understanding/commands/skip_question_command.py +75 -0
  379. rasa/dialogue_understanding/commands/start_flow_command.py +107 -0
  380. rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
  381. rasa/dialogue_understanding/commands/utils.py +45 -0
  382. rasa/dialogue_understanding/generator/__init__.py +21 -0
  383. rasa/dialogue_understanding/generator/command_generator.py +464 -0
  384. rasa/dialogue_understanding/generator/constants.py +27 -0
  385. rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
  386. rasa/dialogue_understanding/generator/flow_retrieval.py +466 -0
  387. rasa/dialogue_understanding/generator/llm_based_command_generator.py +500 -0
  388. rasa/dialogue_understanding/generator/llm_command_generator.py +67 -0
  389. rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
  390. rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
  391. rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
  392. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +920 -0
  393. rasa/dialogue_understanding/generator/nlu_command_adapter.py +261 -0
  394. rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
  395. rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +60 -0
  396. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +486 -0
  397. rasa/dialogue_understanding/patterns/__init__.py +0 -0
  398. rasa/dialogue_understanding/patterns/cancel.py +111 -0
  399. rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
  400. rasa/dialogue_understanding/patterns/chitchat.py +37 -0
  401. rasa/dialogue_understanding/patterns/clarify.py +97 -0
  402. rasa/dialogue_understanding/patterns/code_change.py +41 -0
  403. rasa/dialogue_understanding/patterns/collect_information.py +90 -0
  404. rasa/dialogue_understanding/patterns/completed.py +40 -0
  405. rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
  406. rasa/dialogue_understanding/patterns/correction.py +278 -0
  407. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +301 -0
  408. rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
  409. rasa/dialogue_understanding/patterns/internal_error.py +47 -0
  410. rasa/dialogue_understanding/patterns/repeat.py +37 -0
  411. rasa/dialogue_understanding/patterns/restart.py +37 -0
  412. rasa/dialogue_understanding/patterns/search.py +37 -0
  413. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  414. rasa/dialogue_understanding/patterns/skip_question.py +38 -0
  415. rasa/dialogue_understanding/patterns/user_silence.py +37 -0
  416. rasa/dialogue_understanding/processor/__init__.py +0 -0
  417. rasa/dialogue_understanding/processor/command_processor.py +720 -0
  418. rasa/dialogue_understanding/processor/command_processor_component.py +43 -0
  419. rasa/dialogue_understanding/stack/__init__.py +0 -0
  420. rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
  421. rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
  422. rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
  423. rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
  424. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
  425. rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
  426. rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
  427. rasa/dialogue_understanding/stack/utils.py +211 -0
  428. rasa/dialogue_understanding/utils.py +14 -0
  429. rasa/dialogue_understanding_test/__init__.py +0 -0
  430. rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
  431. rasa/dialogue_understanding_test/constants.py +17 -0
  432. rasa/dialogue_understanding_test/du_test_case.py +118 -0
  433. rasa/dialogue_understanding_test/du_test_result.py +11 -0
  434. rasa/dialogue_understanding_test/du_test_runner.py +93 -0
  435. rasa/dialogue_understanding_test/io.py +54 -0
  436. rasa/dialogue_understanding_test/validation.py +22 -0
  437. rasa/e2e_test/__init__.py +0 -0
  438. rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
  439. rasa/e2e_test/assertions.py +1345 -0
  440. rasa/e2e_test/assertions_schema.yml +129 -0
  441. rasa/e2e_test/constants.py +31 -0
  442. rasa/e2e_test/e2e_config.py +220 -0
  443. rasa/e2e_test/e2e_config_schema.yml +26 -0
  444. rasa/e2e_test/e2e_test_case.py +569 -0
  445. rasa/e2e_test/e2e_test_converter.py +363 -0
  446. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  447. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  448. rasa/e2e_test/e2e_test_result.py +54 -0
  449. rasa/e2e_test/e2e_test_runner.py +1192 -0
  450. rasa/e2e_test/e2e_test_schema.yml +181 -0
  451. rasa/e2e_test/pykwalify_extensions.py +39 -0
  452. rasa/e2e_test/stub_custom_action.py +70 -0
  453. rasa/e2e_test/utils/__init__.py +0 -0
  454. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  455. rasa/e2e_test/utils/io.py +598 -0
  456. rasa/e2e_test/utils/validation.py +178 -0
  457. rasa/engine/__init__.py +0 -0
  458. rasa/engine/caching.py +463 -0
  459. rasa/engine/constants.py +17 -0
  460. rasa/engine/exceptions.py +14 -0
  461. rasa/engine/graph.py +642 -0
  462. rasa/engine/loader.py +48 -0
  463. rasa/engine/recipes/__init__.py +0 -0
  464. rasa/engine/recipes/config_files/default_config.yml +41 -0
  465. rasa/engine/recipes/default_components.py +97 -0
  466. rasa/engine/recipes/default_recipe.py +1272 -0
  467. rasa/engine/recipes/graph_recipe.py +79 -0
  468. rasa/engine/recipes/recipe.py +93 -0
  469. rasa/engine/runner/__init__.py +0 -0
  470. rasa/engine/runner/dask.py +250 -0
  471. rasa/engine/runner/interface.py +49 -0
  472. rasa/engine/storage/__init__.py +0 -0
  473. rasa/engine/storage/local_model_storage.py +244 -0
  474. rasa/engine/storage/resource.py +110 -0
  475. rasa/engine/storage/storage.py +199 -0
  476. rasa/engine/training/__init__.py +0 -0
  477. rasa/engine/training/components.py +176 -0
  478. rasa/engine/training/fingerprinting.py +64 -0
  479. rasa/engine/training/graph_trainer.py +256 -0
  480. rasa/engine/training/hooks.py +164 -0
  481. rasa/engine/validation.py +1451 -0
  482. rasa/env.py +14 -0
  483. rasa/exceptions.py +69 -0
  484. rasa/graph_components/__init__.py +0 -0
  485. rasa/graph_components/converters/__init__.py +0 -0
  486. rasa/graph_components/converters/nlu_message_converter.py +48 -0
  487. rasa/graph_components/providers/__init__.py +0 -0
  488. rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
  489. rasa/graph_components/providers/domain_provider.py +71 -0
  490. rasa/graph_components/providers/flows_provider.py +74 -0
  491. rasa/graph_components/providers/forms_provider.py +44 -0
  492. rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
  493. rasa/graph_components/providers/responses_provider.py +44 -0
  494. rasa/graph_components/providers/rule_only_provider.py +49 -0
  495. rasa/graph_components/providers/story_graph_provider.py +96 -0
  496. rasa/graph_components/providers/training_tracker_provider.py +55 -0
  497. rasa/graph_components/validators/__init__.py +0 -0
  498. rasa/graph_components/validators/default_recipe_validator.py +550 -0
  499. rasa/graph_components/validators/finetuning_validator.py +302 -0
  500. rasa/hooks.py +111 -0
  501. rasa/jupyter.py +63 -0
  502. rasa/llm_fine_tuning/__init__.py +0 -0
  503. rasa/llm_fine_tuning/annotation_module.py +241 -0
  504. rasa/llm_fine_tuning/conversations.py +144 -0
  505. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  506. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  507. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  508. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  509. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  510. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  511. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  512. rasa/llm_fine_tuning/storage.py +174 -0
  513. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  514. rasa/markers/__init__.py +0 -0
  515. rasa/markers/marker.py +269 -0
  516. rasa/markers/marker_base.py +828 -0
  517. rasa/markers/upload.py +74 -0
  518. rasa/markers/validate.py +21 -0
  519. rasa/model.py +118 -0
  520. rasa/model_manager/__init__.py +0 -0
  521. rasa/model_manager/config.py +40 -0
  522. rasa/model_manager/model_api.py +559 -0
  523. rasa/model_manager/runner_service.py +286 -0
  524. rasa/model_manager/socket_bridge.py +146 -0
  525. rasa/model_manager/studio_jwt_auth.py +86 -0
  526. rasa/model_manager/trainer_service.py +325 -0
  527. rasa/model_manager/utils.py +87 -0
  528. rasa/model_manager/warm_rasa_process.py +187 -0
  529. rasa/model_service.py +112 -0
  530. rasa/model_testing.py +457 -0
  531. rasa/model_training.py +596 -0
  532. rasa/nlu/__init__.py +7 -0
  533. rasa/nlu/classifiers/__init__.py +3 -0
  534. rasa/nlu/classifiers/classifier.py +5 -0
  535. rasa/nlu/classifiers/diet_classifier.py +1881 -0
  536. rasa/nlu/classifiers/fallback_classifier.py +192 -0
  537. rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
  538. rasa/nlu/classifiers/logistic_regression_classifier.py +253 -0
  539. rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
  540. rasa/nlu/classifiers/regex_message_handler.py +56 -0
  541. rasa/nlu/classifiers/sklearn_intent_classifier.py +330 -0
  542. rasa/nlu/constants.py +77 -0
  543. rasa/nlu/convert.py +40 -0
  544. rasa/nlu/emulators/__init__.py +0 -0
  545. rasa/nlu/emulators/dialogflow.py +55 -0
  546. rasa/nlu/emulators/emulator.py +49 -0
  547. rasa/nlu/emulators/luis.py +86 -0
  548. rasa/nlu/emulators/no_emulator.py +10 -0
  549. rasa/nlu/emulators/wit.py +56 -0
  550. rasa/nlu/extractors/__init__.py +0 -0
  551. rasa/nlu/extractors/crf_entity_extractor.py +715 -0
  552. rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
  553. rasa/nlu/extractors/entity_synonyms.py +178 -0
  554. rasa/nlu/extractors/extractor.py +470 -0
  555. rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
  556. rasa/nlu/extractors/regex_entity_extractor.py +220 -0
  557. rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
  558. rasa/nlu/featurizers/__init__.py +0 -0
  559. rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
  560. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +445 -0
  561. rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
  562. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +768 -0
  563. rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
  564. rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
  565. rasa/nlu/featurizers/featurizer.py +89 -0
  566. rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
  567. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +867 -0
  568. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +571 -0
  569. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +271 -0
  570. rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
  571. rasa/nlu/model.py +24 -0
  572. rasa/nlu/run.py +27 -0
  573. rasa/nlu/selectors/__init__.py +0 -0
  574. rasa/nlu/selectors/response_selector.py +987 -0
  575. rasa/nlu/test.py +1940 -0
  576. rasa/nlu/tokenizers/__init__.py +0 -0
  577. rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
  578. rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
  579. rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
  580. rasa/nlu/tokenizers/tokenizer.py +239 -0
  581. rasa/nlu/tokenizers/whitespace_tokenizer.py +95 -0
  582. rasa/nlu/utils/__init__.py +35 -0
  583. rasa/nlu/utils/bilou_utils.py +462 -0
  584. rasa/nlu/utils/hugging_face/__init__.py +0 -0
  585. rasa/nlu/utils/hugging_face/registry.py +108 -0
  586. rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
  587. rasa/nlu/utils/mitie_utils.py +113 -0
  588. rasa/nlu/utils/pattern_utils.py +168 -0
  589. rasa/nlu/utils/spacy_utils.py +310 -0
  590. rasa/plugin.py +90 -0
  591. rasa/server.py +1588 -0
  592. rasa/shared/__init__.py +0 -0
  593. rasa/shared/constants.py +311 -0
  594. rasa/shared/core/__init__.py +0 -0
  595. rasa/shared/core/command_payload_reader.py +109 -0
  596. rasa/shared/core/constants.py +180 -0
  597. rasa/shared/core/conversation.py +46 -0
  598. rasa/shared/core/domain.py +2172 -0
  599. rasa/shared/core/events.py +2559 -0
  600. rasa/shared/core/flows/__init__.py +7 -0
  601. rasa/shared/core/flows/flow.py +562 -0
  602. rasa/shared/core/flows/flow_path.py +84 -0
  603. rasa/shared/core/flows/flow_step.py +146 -0
  604. rasa/shared/core/flows/flow_step_links.py +319 -0
  605. rasa/shared/core/flows/flow_step_sequence.py +70 -0
  606. rasa/shared/core/flows/flows_list.py +258 -0
  607. rasa/shared/core/flows/flows_yaml_schema.json +303 -0
  608. rasa/shared/core/flows/nlu_trigger.py +117 -0
  609. rasa/shared/core/flows/steps/__init__.py +24 -0
  610. rasa/shared/core/flows/steps/action.py +56 -0
  611. rasa/shared/core/flows/steps/call.py +64 -0
  612. rasa/shared/core/flows/steps/collect.py +112 -0
  613. rasa/shared/core/flows/steps/constants.py +5 -0
  614. rasa/shared/core/flows/steps/continuation.py +36 -0
  615. rasa/shared/core/flows/steps/end.py +22 -0
  616. rasa/shared/core/flows/steps/internal.py +44 -0
  617. rasa/shared/core/flows/steps/link.py +51 -0
  618. rasa/shared/core/flows/steps/no_operation.py +48 -0
  619. rasa/shared/core/flows/steps/set_slots.py +50 -0
  620. rasa/shared/core/flows/steps/start.py +30 -0
  621. rasa/shared/core/flows/utils.py +39 -0
  622. rasa/shared/core/flows/validation.py +735 -0
  623. rasa/shared/core/flows/yaml_flows_io.py +405 -0
  624. rasa/shared/core/generator.py +908 -0
  625. rasa/shared/core/slot_mappings.py +526 -0
  626. rasa/shared/core/slots.py +654 -0
  627. rasa/shared/core/trackers.py +1183 -0
  628. rasa/shared/core/training_data/__init__.py +0 -0
  629. rasa/shared/core/training_data/loading.py +89 -0
  630. rasa/shared/core/training_data/story_reader/__init__.py +0 -0
  631. rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
  632. rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
  633. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
  634. rasa/shared/core/training_data/story_writer/__init__.py +0 -0
  635. rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
  636. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +444 -0
  637. rasa/shared/core/training_data/structures.py +858 -0
  638. rasa/shared/core/training_data/visualization.html +146 -0
  639. rasa/shared/core/training_data/visualization.py +603 -0
  640. rasa/shared/data.py +249 -0
  641. rasa/shared/engine/__init__.py +0 -0
  642. rasa/shared/engine/caching.py +26 -0
  643. rasa/shared/exceptions.py +167 -0
  644. rasa/shared/importers/__init__.py +0 -0
  645. rasa/shared/importers/importer.py +770 -0
  646. rasa/shared/importers/multi_project.py +215 -0
  647. rasa/shared/importers/rasa.py +108 -0
  648. rasa/shared/importers/remote_importer.py +196 -0
  649. rasa/shared/importers/utils.py +36 -0
  650. rasa/shared/nlu/__init__.py +0 -0
  651. rasa/shared/nlu/constants.py +53 -0
  652. rasa/shared/nlu/interpreter.py +10 -0
  653. rasa/shared/nlu/training_data/__init__.py +0 -0
  654. rasa/shared/nlu/training_data/entities_parser.py +208 -0
  655. rasa/shared/nlu/training_data/features.py +492 -0
  656. rasa/shared/nlu/training_data/formats/__init__.py +10 -0
  657. rasa/shared/nlu/training_data/formats/dialogflow.py +163 -0
  658. rasa/shared/nlu/training_data/formats/luis.py +87 -0
  659. rasa/shared/nlu/training_data/formats/rasa.py +135 -0
  660. rasa/shared/nlu/training_data/formats/rasa_yaml.py +618 -0
  661. rasa/shared/nlu/training_data/formats/readerwriter.py +244 -0
  662. rasa/shared/nlu/training_data/formats/wit.py +52 -0
  663. rasa/shared/nlu/training_data/loading.py +137 -0
  664. rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
  665. rasa/shared/nlu/training_data/message.py +490 -0
  666. rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
  667. rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
  668. rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
  669. rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
  670. rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
  671. rasa/shared/nlu/training_data/training_data.py +729 -0
  672. rasa/shared/nlu/training_data/util.py +223 -0
  673. rasa/shared/providers/__init__.py +0 -0
  674. rasa/shared/providers/_configs/__init__.py +0 -0
  675. rasa/shared/providers/_configs/azure_openai_client_config.py +677 -0
  676. rasa/shared/providers/_configs/client_config.py +59 -0
  677. rasa/shared/providers/_configs/default_litellm_client_config.py +132 -0
  678. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +236 -0
  679. rasa/shared/providers/_configs/litellm_router_client_config.py +222 -0
  680. rasa/shared/providers/_configs/model_group_config.py +173 -0
  681. rasa/shared/providers/_configs/openai_client_config.py +177 -0
  682. rasa/shared/providers/_configs/rasa_llm_client_config.py +75 -0
  683. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +178 -0
  684. rasa/shared/providers/_configs/utils.py +117 -0
  685. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  686. rasa/shared/providers/_utils.py +79 -0
  687. rasa/shared/providers/constants.py +7 -0
  688. rasa/shared/providers/embedding/__init__.py +0 -0
  689. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +243 -0
  690. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  691. rasa/shared/providers/embedding/azure_openai_embedding_client.py +335 -0
  692. rasa/shared/providers/embedding/default_litellm_embedding_client.py +126 -0
  693. rasa/shared/providers/embedding/embedding_client.py +90 -0
  694. rasa/shared/providers/embedding/embedding_response.py +41 -0
  695. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  696. rasa/shared/providers/embedding/litellm_router_embedding_client.py +138 -0
  697. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  698. rasa/shared/providers/llm/__init__.py +0 -0
  699. rasa/shared/providers/llm/_base_litellm_client.py +265 -0
  700. rasa/shared/providers/llm/azure_openai_llm_client.py +415 -0
  701. rasa/shared/providers/llm/default_litellm_llm_client.py +110 -0
  702. rasa/shared/providers/llm/litellm_router_llm_client.py +202 -0
  703. rasa/shared/providers/llm/llm_client.py +78 -0
  704. rasa/shared/providers/llm/llm_response.py +50 -0
  705. rasa/shared/providers/llm/openai_llm_client.py +161 -0
  706. rasa/shared/providers/llm/rasa_llm_client.py +120 -0
  707. rasa/shared/providers/llm/self_hosted_llm_client.py +276 -0
  708. rasa/shared/providers/mappings.py +94 -0
  709. rasa/shared/providers/router/__init__.py +0 -0
  710. rasa/shared/providers/router/_base_litellm_router_client.py +185 -0
  711. rasa/shared/providers/router/router_client.py +75 -0
  712. rasa/shared/utils/__init__.py +0 -0
  713. rasa/shared/utils/cli.py +102 -0
  714. rasa/shared/utils/common.py +324 -0
  715. rasa/shared/utils/constants.py +4 -0
  716. rasa/shared/utils/health_check/__init__.py +0 -0
  717. rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
  718. rasa/shared/utils/health_check/health_check.py +258 -0
  719. rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
  720. rasa/shared/utils/io.py +499 -0
  721. rasa/shared/utils/llm.py +764 -0
  722. rasa/shared/utils/pykwalify_extensions.py +27 -0
  723. rasa/shared/utils/schemas/__init__.py +0 -0
  724. rasa/shared/utils/schemas/config.yml +2 -0
  725. rasa/shared/utils/schemas/domain.yml +145 -0
  726. rasa/shared/utils/schemas/events.py +214 -0
  727. rasa/shared/utils/schemas/model_config.yml +36 -0
  728. rasa/shared/utils/schemas/stories.yml +173 -0
  729. rasa/shared/utils/yaml.py +1068 -0
  730. rasa/studio/__init__.py +0 -0
  731. rasa/studio/auth.py +270 -0
  732. rasa/studio/config.py +136 -0
  733. rasa/studio/constants.py +19 -0
  734. rasa/studio/data_handler.py +368 -0
  735. rasa/studio/download.py +489 -0
  736. rasa/studio/results_logger.py +137 -0
  737. rasa/studio/train.py +134 -0
  738. rasa/studio/upload.py +563 -0
  739. rasa/telemetry.py +1876 -0
  740. rasa/tracing/__init__.py +0 -0
  741. rasa/tracing/config.py +355 -0
  742. rasa/tracing/constants.py +62 -0
  743. rasa/tracing/instrumentation/__init__.py +0 -0
  744. rasa/tracing/instrumentation/attribute_extractors.py +765 -0
  745. rasa/tracing/instrumentation/instrumentation.py +1306 -0
  746. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +144 -0
  747. rasa/tracing/instrumentation/metrics.py +294 -0
  748. rasa/tracing/metric_instrument_provider.py +205 -0
  749. rasa/utils/__init__.py +0 -0
  750. rasa/utils/beta.py +83 -0
  751. rasa/utils/cli.py +28 -0
  752. rasa/utils/common.py +639 -0
  753. rasa/utils/converter.py +53 -0
  754. rasa/utils/endpoints.py +331 -0
  755. rasa/utils/io.py +252 -0
  756. rasa/utils/json_utils.py +60 -0
  757. rasa/utils/licensing.py +542 -0
  758. rasa/utils/log_utils.py +181 -0
  759. rasa/utils/mapper.py +210 -0
  760. rasa/utils/ml_utils.py +147 -0
  761. rasa/utils/plotting.py +362 -0
  762. rasa/utils/sanic_error_handler.py +32 -0
  763. rasa/utils/singleton.py +23 -0
  764. rasa/utils/tensorflow/__init__.py +0 -0
  765. rasa/utils/tensorflow/callback.py +112 -0
  766. rasa/utils/tensorflow/constants.py +116 -0
  767. rasa/utils/tensorflow/crf.py +492 -0
  768. rasa/utils/tensorflow/data_generator.py +440 -0
  769. rasa/utils/tensorflow/environment.py +161 -0
  770. rasa/utils/tensorflow/exceptions.py +5 -0
  771. rasa/utils/tensorflow/feature_array.py +366 -0
  772. rasa/utils/tensorflow/layers.py +1565 -0
  773. rasa/utils/tensorflow/layers_utils.py +113 -0
  774. rasa/utils/tensorflow/metrics.py +281 -0
  775. rasa/utils/tensorflow/model_data.py +798 -0
  776. rasa/utils/tensorflow/model_data_utils.py +499 -0
  777. rasa/utils/tensorflow/models.py +935 -0
  778. rasa/utils/tensorflow/rasa_layers.py +1094 -0
  779. rasa/utils/tensorflow/transformer.py +640 -0
  780. rasa/utils/tensorflow/types.py +6 -0
  781. rasa/utils/train_utils.py +572 -0
  782. rasa/utils/url_tools.py +53 -0
  783. rasa/utils/yaml.py +54 -0
  784. rasa/validator.py +1644 -0
  785. rasa/version.py +3 -0
  786. rasa_pro-3.12.0.dev1.dist-info/METADATA +199 -0
  787. rasa_pro-3.12.0.dev1.dist-info/NOTICE +5 -0
  788. rasa_pro-3.12.0.dev1.dist-info/RECORD +790 -0
  789. rasa_pro-3.12.0.dev1.dist-info/WHEEL +4 -0
  790. rasa_pro-3.12.0.dev1.dist-info/entry_points.txt +3 -0
rasa/core/policies/ted_policy.py
@@ -0,0 +1,2169 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from collections import defaultdict
6
+ import contextlib
7
+ from typing import Any, List, Optional, Text, Dict, Tuple, Union, Type
8
+
9
+ import numpy as np
10
+ import tensorflow as tf
11
+
12
+ from rasa.engine.recipes.default_recipe import DefaultV1Recipe
13
+ from rasa.engine.graph import ExecutionContext
14
+ from rasa.engine.storage.resource import Resource
15
+ from rasa.engine.storage.storage import ModelStorage
16
+ from rasa.exceptions import ModelNotFound
17
+ from rasa.nlu.constants import TOKENS_NAMES
18
+ from rasa.nlu.extractors.extractor import EntityTagSpec, EntityExtractorMixin
19
+ import rasa.core.actions.action
20
+ from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
21
+ from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
22
+ from rasa.core.featurizers.tracker_featurizers import MaxHistoryTrackerFeaturizer
23
+ from rasa.shared.exceptions import RasaException
24
+ from rasa.shared.nlu.constants import (
25
+ ACTION_TEXT,
26
+ ACTION_NAME,
27
+ INTENT,
28
+ TEXT,
29
+ ENTITIES,
30
+ FEATURE_TYPE_SENTENCE,
31
+ ENTITY_ATTRIBUTE_TYPE,
32
+ ENTITY_TAGS,
33
+ EXTRACTOR,
34
+ SPLIT_ENTITIES_BY_COMMA,
35
+ SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
36
+ )
37
+ from rasa.core.policies.policy import PolicyPrediction, Policy, SupportedData
38
+ from rasa.core.constants import (
39
+ DIALOGUE,
40
+ POLICY_MAX_HISTORY,
41
+ DEFAULT_MAX_HISTORY,
42
+ DEFAULT_POLICY_PRIORITY,
43
+ POLICY_PRIORITY,
44
+ )
45
+ from rasa.shared.constants import DIAGNOSTIC_DATA
46
+ from rasa.shared.core.constants import ACTIVE_LOOP, SLOTS, ACTION_LISTEN_NAME
47
+ from rasa.shared.core.trackers import DialogueStateTracker
48
+ from rasa.shared.core.generator import TrackerWithCachedStates
49
+ from rasa.shared.core.events import EntitiesAdded, Event
50
+ from rasa.shared.core.domain import Domain
51
+ from rasa.shared.nlu.training_data.message import Message
52
+ from rasa.shared.nlu.training_data.features import (
53
+ Features,
54
+ save_features,
55
+ load_features,
56
+ )
57
+ import rasa.shared.utils.io
58
+ import rasa.utils.io
59
+ from rasa.utils import train_utils
60
+ from rasa.utils.tensorflow.feature_array import (
61
+ FeatureArray,
62
+ serialize_nested_feature_arrays,
63
+ deserialize_nested_feature_arrays,
64
+ )
65
+ from rasa.utils.tensorflow.models import RasaModel, TransformerRasaModel
66
+ from rasa.utils.tensorflow import rasa_layers
67
+ from rasa.utils.tensorflow.model_data import RasaModelData, FeatureSignature, Data
68
+ from rasa.utils.tensorflow.model_data_utils import convert_to_data_format
69
+ from rasa.utils.tensorflow.constants import (
70
+ LABEL,
71
+ IDS,
72
+ TRANSFORMER_SIZE,
73
+ NUM_TRANSFORMER_LAYERS,
74
+ NUM_HEADS,
75
+ BATCH_SIZES,
76
+ BATCH_STRATEGY,
77
+ EPOCHS,
78
+ RANDOM_SEED,
79
+ LEARNING_RATE,
80
+ RANKING_LENGTH,
81
+ RENORMALIZE_CONFIDENCES,
82
+ LOSS_TYPE,
83
+ SIMILARITY_TYPE,
84
+ NUM_NEG,
85
+ EVAL_NUM_EXAMPLES,
86
+ EVAL_NUM_EPOCHS,
87
+ NEGATIVE_MARGIN_SCALE,
88
+ REGULARIZATION_CONSTANT,
89
+ SCALE_LOSS,
90
+ USE_MAX_NEG_SIM,
91
+ MAX_NEG_SIM,
92
+ MAX_POS_SIM,
93
+ EMBEDDING_DIMENSION,
94
+ DROP_RATE_DIALOGUE,
95
+ DROP_RATE_LABEL,
96
+ DROP_RATE,
97
+ DROP_RATE_ATTENTION,
98
+ CONNECTION_DENSITY,
99
+ KEY_RELATIVE_ATTENTION,
100
+ VALUE_RELATIVE_ATTENTION,
101
+ MAX_RELATIVE_POSITION,
102
+ CROSS_ENTROPY,
103
+ AUTO,
104
+ BALANCED,
105
+ TENSORBOARD_LOG_DIR,
106
+ TENSORBOARD_LOG_LEVEL,
107
+ CHECKPOINT_MODEL,
108
+ ENCODING_DIMENSION,
109
+ UNIDIRECTIONAL_ENCODER,
110
+ SEQUENCE,
111
+ SENTENCE,
112
+ SEQUENCE_LENGTH,
113
+ DENSE_DIMENSION,
114
+ CONCAT_DIMENSION,
115
+ SPARSE_INPUT_DROPOUT,
116
+ DENSE_INPUT_DROPOUT,
117
+ MASKED_LM,
118
+ MASK,
119
+ HIDDEN_LAYERS_SIZES,
120
+ FEATURIZERS,
121
+ ENTITY_RECOGNITION,
122
+ CONSTRAIN_SIMILARITIES,
123
+ MODEL_CONFIDENCE,
124
+ SOFTMAX,
125
+ BILOU_FLAG,
126
+ EPOCH_OVERRIDE,
127
+ USE_GPU,
128
+ )
129
+
130
+ logger = logging.getLogger(__name__)
131
+
132
+ E2E_CONFIDENCE_THRESHOLD = "e2e_confidence_threshold"
133
+ LABEL_KEY = LABEL
134
+ LABEL_SUB_KEY = IDS
135
+ LENGTH = "length"
136
+ INDICES = "indices"
137
+ SENTENCE_FEATURES_TO_ENCODE = [INTENT, TEXT, ACTION_NAME, ACTION_TEXT]
138
+ SEQUENCE_FEATURES_TO_ENCODE = [TEXT, ACTION_TEXT, f"{LABEL}_{ACTION_TEXT}"]
139
+ LABEL_FEATURES_TO_ENCODE = [
140
+ f"{LABEL}_{ACTION_NAME}",
141
+ f"{LABEL}_{ACTION_TEXT}",
142
+ f"{LABEL}_{INTENT}",
143
+ ]
144
+ STATE_LEVEL_FEATURES = [ENTITIES, SLOTS, ACTIVE_LOOP]
145
+ PREDICTION_FEATURES = STATE_LEVEL_FEATURES + SENTENCE_FEATURES_TO_ENCODE + [DIALOGUE]
146
+
147
+
148
+ @DefaultV1Recipe.register(
149
+ DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
150
+ )
151
+ class TEDPolicy(Policy):
152
+ """Transformer Embedding Dialogue (TED) Policy.
153
+
154
+ The model architecture is described in
155
+ detail in https://arxiv.org/abs/1910.00486.
156
+ In summary, the architecture comprises of the
157
+ following steps:
158
+ - concatenate user input (user intent and entities), previous system actions,
159
+ slots and active forms for each time step into an input vector to
160
+ pre-transformer embedding layer;
161
+ - feed it to transformer;
162
+ - apply a dense layer to the output of the transformer to get embeddings of a
163
+ dialogue for each time step;
164
+ - apply a dense layer to create embeddings for system actions for each time
165
+ step;
166
+ - calculate the similarity between the dialogue embedding and embedded system
167
+ actions. This step is based on the StarSpace
168
+ (https://arxiv.org/abs/1709.03856) idea.
169
+ """
170
+
171
+ @staticmethod
172
+ def get_default_config() -> Dict[Text, Any]:
173
+ """Returns the default config (see parent class for full docstring)."""
174
+ # please make sure to update the docs when changing a default parameter
175
+ return {
176
+ # ## Architecture of the used neural network
177
+ # Hidden layer sizes for layers before the embedding layers for user message
178
+ # and labels.
179
+ # The number of hidden layers is equal to the length of the corresponding
180
+ # list.
181
+ HIDDEN_LAYERS_SIZES: {
182
+ TEXT: [],
183
+ ACTION_TEXT: [],
184
+ f"{LABEL}_{ACTION_TEXT}": [],
185
+ },
186
+ # Dense dimension to use for sparse features.
187
+ DENSE_DIMENSION: {
188
+ TEXT: 128,
189
+ ACTION_TEXT: 128,
190
+ f"{LABEL}_{ACTION_TEXT}": 128,
191
+ INTENT: 20,
192
+ ACTION_NAME: 20,
193
+ f"{LABEL}_{ACTION_NAME}": 20,
194
+ ENTITIES: 20,
195
+ SLOTS: 20,
196
+ ACTIVE_LOOP: 20,
197
+ },
198
+ # Default dimension to use for concatenating sequence and sentence features.
199
+ CONCAT_DIMENSION: {
200
+ TEXT: 128,
201
+ ACTION_TEXT: 128,
202
+ f"{LABEL}_{ACTION_TEXT}": 128,
203
+ },
204
+ # Dimension size of embedding vectors before the dialogue transformer
205
+ # encoder.
206
+ ENCODING_DIMENSION: 50,
207
+ # Number of units in transformer encoders
208
+ TRANSFORMER_SIZE: {
209
+ TEXT: 128,
210
+ ACTION_TEXT: 128,
211
+ f"{LABEL}_{ACTION_TEXT}": 128,
212
+ DIALOGUE: 128,
213
+ },
214
+ # Number of layers in transformer encoders
215
+ NUM_TRANSFORMER_LAYERS: {
216
+ TEXT: 1,
217
+ ACTION_TEXT: 1,
218
+ f"{LABEL}_{ACTION_TEXT}": 1,
219
+ DIALOGUE: 1,
220
+ },
221
+ # Number of attention heads in transformer
222
+ NUM_HEADS: 4,
223
+ # If 'True' use key relative embeddings in attention
224
+ KEY_RELATIVE_ATTENTION: False,
225
+ # If 'True' use value relative embeddings in attention
226
+ VALUE_RELATIVE_ATTENTION: False,
227
+ # Max position for relative embeddings. Only in effect if key- or value
228
+ # relative
229
+ # attention are turned on
230
+ MAX_RELATIVE_POSITION: 5,
231
+ # Use a unidirectional or bidirectional encoder
232
+ # for `text`, `action_text`, and `label_action_text`.
233
+ UNIDIRECTIONAL_ENCODER: False,
234
+ # ## Training parameters
235
+ # Initial and final batch sizes:
236
+ # Batch size will be linearly increased for each epoch.
237
+ BATCH_SIZES: [64, 256],
238
+ # Strategy used whenc creating batches.
239
+ # Can be either 'sequence' or 'balanced'.
240
+ BATCH_STRATEGY: BALANCED,
241
+ # Number of epochs to train
242
+ EPOCHS: 1,
243
+ # Set random seed to any 'int' to get reproducible results
244
+ RANDOM_SEED: None,
245
+ # Initial learning rate for the optimizer
246
+ LEARNING_RATE: 0.001,
247
+ # ## Parameters for embeddings
248
+ # Dimension size of embedding vectors
249
+ EMBEDDING_DIMENSION: 20,
250
+ # The number of incorrect labels. The algorithm will minimize
251
+ # their similarity to the user input during training.
252
+ NUM_NEG: 20,
253
+ # Type of similarity measure to use, either 'auto' or 'cosine' or 'inner'.
254
+ SIMILARITY_TYPE: AUTO,
255
+ # The type of the loss function, either 'cross_entropy' or 'margin'.
256
+ LOSS_TYPE: CROSS_ENTROPY,
257
+ # Number of top actions for which confidences should be predicted.
258
+ # The number of Set to `0` if confidences for all actions should be
259
+ # predicted. The confidences for all other actions will be set to 0.
260
+ RANKING_LENGTH: 0,
261
+ # Determines wether the confidences of the chosen top actions should be
262
+ # renormalized so that they sum up to 1. By default, we do not renormalize
263
+ # and return the confidences for the top actions as is.
264
+ # Note that renormalization only makes sense if confidences are generated
265
+ # via `softmax`.
266
+ RENORMALIZE_CONFIDENCES: False,
267
+ # Indicates how similar the algorithm should try to make embedding vectors
268
+ # for correct labels.
269
+ # Should be 0.0 < ... < 1.0 for 'cosine' similarity type.
270
+ MAX_POS_SIM: 0.8,
271
+ # Maximum negative similarity for incorrect labels.
272
+ # Should be -1.0 < ... < 1.0 for 'cosine' similarity type.
273
+ MAX_NEG_SIM: -0.2,
274
+ # If 'True' the algorithm only minimizes maximum similarity over
275
+ # incorrect intent labels, used only if 'loss_type' is set to 'margin'.
276
+ USE_MAX_NEG_SIM: True,
277
+ # If 'True' scale the loss inversely proportionally to the confidence
278
+ # of the correct prediction
279
+ SCALE_LOSS: True,
280
+ # ## Regularization parameters
281
+ # The scale of regularization
282
+ REGULARIZATION_CONSTANT: 0.001,
283
+ # The scale of how important it is to minimize the maximum similarity
284
+ # between embeddings of different labels,
285
+ # used only if 'loss_type' is set to 'margin'.
286
+ NEGATIVE_MARGIN_SCALE: 0.8,
287
+ # Dropout rate for embedding layers of dialogue features.
288
+ DROP_RATE_DIALOGUE: 0.1,
289
+ # Dropout rate for embedding layers of utterance level features.
290
+ DROP_RATE: 0.0,
291
+ # Dropout rate for embedding layers of label, e.g. action, features.
292
+ DROP_RATE_LABEL: 0.0,
293
+ # Dropout rate for attention.
294
+ DROP_RATE_ATTENTION: 0.0,
295
+ # Fraction of trainable weights in internal layers.
296
+ CONNECTION_DENSITY: 0.2,
297
+ # If 'True' apply dropout to sparse input tensors
298
+ SPARSE_INPUT_DROPOUT: True,
299
+ # If 'True' apply dropout to dense input tensors
300
+ DENSE_INPUT_DROPOUT: True,
301
+ # If 'True' random tokens of the input message will be masked. Since there
302
+ # is no related loss term used inside TED, the masking effectively becomes
303
+ # just input dropout applied to the text of user utterances.
304
+ MASKED_LM: False,
305
+ # ## Evaluation parameters
306
+ # How often to calculate validation accuracy.
307
+ # Small values may hurt performance.
308
+ EVAL_NUM_EPOCHS: 20,
309
+ # How many examples to use for the hold-out validation set.
310
+ # Large values may hurt performance, e.g. model accuracy.
311
+ # Set to 0 for no validation.
312
+ EVAL_NUM_EXAMPLES: 0,
313
+ # If you want to use tensorboard to visualize training and validation
314
+ # metrics, set this option to a valid output directory.
315
+ TENSORBOARD_LOG_DIR: None,
316
+ # Define when training metrics for tensorboard should be logged.
317
+ # Either after every epoch or for every training step.
318
+ # Valid values: 'epoch' and 'batch'
319
+ TENSORBOARD_LOG_LEVEL: "epoch",
320
+ # Perform model checkpointing
321
+ CHECKPOINT_MODEL: False,
322
+ # Only pick e2e prediction if the policy is confident enough
323
+ E2E_CONFIDENCE_THRESHOLD: 0.5,
324
+ # Specify what features to use as sequence and sentence features.
325
+ # By default all features in the pipeline are used.
326
+ FEATURIZERS: [],
327
+ # If set to true, entities are predicted in user utterances.
328
+ ENTITY_RECOGNITION: True,
329
+ # if 'True' applies sigmoid on all similarity terms and adds
330
+ # it to the loss function to ensure that similarity values are
331
+ # approximately bounded. Used inside cross-entropy loss only.
332
+ CONSTRAIN_SIMILARITIES: False,
333
+ # Model confidence to be returned during inference. Currently, the only
334
+ # possible value is `softmax`.
335
+ MODEL_CONFIDENCE: SOFTMAX,
336
+ # 'BILOU_flag' determines whether to use BILOU tagging or not.
337
+ # If set to 'True' labelling is more rigorous, however more
338
+ # examples per entity are required.
339
+ # Rule of thumb: you should have more than 100 examples per entity.
340
+ BILOU_FLAG: True,
341
+ # Split entities by comma; this makes sense e.g. for a list of
342
+ # ingredients in a recipe, but it doesn't make sense for the parts of
343
+ # an address
344
+ SPLIT_ENTITIES_BY_COMMA: SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
345
+ # Max history of the policy, unbounded by default
346
+ POLICY_MAX_HISTORY: DEFAULT_MAX_HISTORY,
347
+ # Determines the importance of policies, higher values take precedence
348
+ POLICY_PRIORITY: DEFAULT_POLICY_PRIORITY,
349
+ USE_GPU: True,
350
+ }
351
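The `BATCH_SIZES: [64, 256]` default above configures a batch size that grows linearly over the epochs. A minimal illustrative sketch of such a schedule (the real ramp lives inside Rasa's data generators; this standalone helper only assumes a simple linear interpolation):

def linear_batch_size(epoch: int, epochs: int, batch_sizes=(64, 256)) -> int:
    # Illustrative linear interpolation from the first to the second batch size.
    low, high = batch_sizes
    if epochs <= 1:
        return low
    return int(low + (high - low) * epoch / (epochs - 1))

# With epochs=5 this yields 64, 112, 160, 208, 256 across the five epochs.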
+
352
+ def __init__(
353
+ self,
354
+ config: Dict[Text, Any],
355
+ model_storage: ModelStorage,
356
+ resource: Resource,
357
+ execution_context: ExecutionContext,
358
+ model: Optional[RasaModel] = None,
359
+ featurizer: Optional[TrackerFeaturizer] = None,
360
+ fake_features: Optional[Dict[Text, List[Features]]] = None,
361
+ entity_tag_specs: Optional[List[EntityTagSpec]] = None,
362
+ ) -> None:
363
+ """Declares instance variables with default values."""
364
+ super().__init__(
365
+ config, model_storage, resource, execution_context, featurizer=featurizer
366
+ )
367
+ self.split_entities_config = rasa.utils.train_utils.init_split_entities(
368
+ config[SPLIT_ENTITIES_BY_COMMA], SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE
369
+ )
370
+ self._load_params(config)
371
+
372
+ self.model = model
373
+
374
+ self._entity_tag_specs = entity_tag_specs
375
+
376
+ self.fake_features = fake_features or defaultdict(list)
377
+ # TED is only e2e if only text is present in fake features, which represent
378
+ # all possible input features for the current version of this trained TED
379
+ self.only_e2e = TEXT in self.fake_features and INTENT not in self.fake_features
380
+
381
+ self._label_data: Optional[RasaModelData] = None
382
+ self.data_example: Optional[Dict[Text, Dict[Text, List[FeatureArray]]]] = None
383
+
384
+ self.tmp_checkpoint_dir = None
385
+ if self.config[CHECKPOINT_MODEL]:
386
+ self.tmp_checkpoint_dir = Path(rasa.utils.io.create_temporary_directory())
387
+
388
+ @staticmethod
389
+ def model_class() -> Type[TED]:
390
+ """Gets the class of the model architecture to be used by the policy.
391
+
392
+ Returns:
393
+ Required class.
394
+ """
395
+ return TED
396
+
397
+ @classmethod
398
+ def _metadata_filename(cls) -> Optional[Text]:
399
+ return "ted_policy"
400
+
401
+ def _load_params(self, config: Dict[Text, Any]) -> None:
402
+ new_config = rasa.utils.train_utils.check_core_deprecated_options(config)
403
+ self.config = new_config
404
+ self._auto_update_configuration()
405
+
406
+ def _auto_update_configuration(self) -> None:
407
+ """Takes care of deprecations and compatibility of parameters."""
408
+ self.config = rasa.utils.train_utils.update_confidence_type(self.config)
409
+ rasa.utils.train_utils.validate_configuration_settings(self.config)
410
+ self.config = rasa.utils.train_utils.update_similarity_type(self.config)
411
+ self.config = rasa.utils.train_utils.update_evaluation_parameters(self.config)
412
+
413
+ def _create_label_data(
414
+ self,
415
+ domain: Domain,
416
+ precomputations: Optional[MessageContainerForCoreFeaturization],
417
+ ) -> Tuple[RasaModelData, List[Dict[Text, List[Features]]]]:
418
+ # encode all label_ids with policies' featurizer
419
+ state_featurizer = self.featurizer.state_featurizer
420
+ encoded_all_labels = (
421
+ state_featurizer.encode_all_labels(domain, precomputations)
422
+ if state_featurizer is not None
423
+ else []
424
+ )
425
+
426
+ attribute_data, _ = convert_to_data_format(
427
+ encoded_all_labels, featurizers=self.config[FEATURIZERS]
428
+ )
429
+
430
+ label_data = self._assemble_label_data(attribute_data, domain)
431
+
432
+ return label_data, encoded_all_labels
433
+
434
+ def _assemble_label_data(
435
+ self, attribute_data: Data, domain: Domain
436
+ ) -> RasaModelData:
437
+ """Constructs data regarding labels to be fed to the model.
438
+
439
+ The resultant model data can possibly contain one or both of the
440
+ keys - [`label_action_name`, `label_action_text`] but will definitely
441
+ contain the `label` key.
442
+ `label_action_*` will contain the sequence, sentence and mask features
443
+ for corresponding labels and `label` will contain the numerical label ids.
444
+
445
+ Args:
446
+ attribute_data: Feature data for all labels.
447
+ domain: Domain of the assistant.
448
+
449
+ Returns:
450
+ Features of labels ready to be fed to the model.
451
+ """
452
+ label_data = RasaModelData()
453
+ label_data.add_data(attribute_data, key_prefix=f"{LABEL_KEY}_")
454
+ label_data.add_lengths(
455
+ f"{LABEL}_{ACTION_TEXT}",
456
+ SEQUENCE_LENGTH,
457
+ f"{LABEL}_{ACTION_TEXT}",
458
+ SEQUENCE,
459
+ )
460
+ label_ids = np.arange(domain.num_actions)
461
+ label_data.add_features(
462
+ LABEL_KEY,
463
+ LABEL_SUB_KEY,
464
+ [
465
+ FeatureArray(
466
+ np.expand_dims(label_ids, -1),
467
+ number_of_dimensions=2,
468
+ )
469
+ ],
470
+ )
471
+ return label_data
472
+
473
+ @staticmethod
474
+ def _should_extract_entities(
475
+ entity_tags: List[List[Dict[Text, List[Features]]]],
476
+ ) -> bool:
477
+ for turns_tags in entity_tags:
478
+ for turn_tags in turns_tags:
479
+ # if turn_tags are empty or all entity tag indices are `0`
480
+ # it means that all the inputs only contain NO_ENTITY_TAG
481
+ if turn_tags and np.any(turn_tags[ENTITY_TAGS][0].features):
482
+ return True
483
+ return False
484
+
485
+ def _create_data_for_entities(
486
+ self, entity_tags: Optional[List[List[Dict[Text, List[Features]]]]]
487
+ ) -> Optional[Data]:
488
+ if not self.config[ENTITY_RECOGNITION]:
489
+ return None
490
+
491
+ # check that there are real entity tags
492
+ if entity_tags and self._should_extract_entities(entity_tags):
493
+ entity_tags_data, _ = convert_to_data_format(entity_tags)
494
+ return entity_tags_data
495
+
496
+ # there are no "real" entity tags
497
+ logger.debug(
498
+ f"Entity recognition cannot be performed, "
499
+ f"set '{ENTITY_RECOGNITION}' config parameter to 'False'."
500
+ )
501
+ self.config[ENTITY_RECOGNITION] = False
502
+
503
+ return None
504
+
505
+ def _create_model_data(
506
+ self,
507
+ tracker_state_features: List[List[Dict[Text, List[Features]]]],
508
+ label_ids: Optional[np.ndarray] = None,
509
+ entity_tags: Optional[List[List[Dict[Text, List[Features]]]]] = None,
510
+ encoded_all_labels: Optional[List[Dict[Text, List[Features]]]] = None,
511
+ ) -> RasaModelData:
512
+ """Combine all model related data into RasaModelData.
513
+
514
+ Args:
515
+ tracker_state_features: a dictionary of attributes
516
+ (INTENT, TEXT, ACTION_NAME, ACTION_TEXT, ENTITIES, SLOTS, ACTIVE_LOOP)
517
+ to a list of features for all dialogue turns in all training trackers
518
+ label_ids: the label ids (e.g. action ids) for every dialogue turn in all
519
+ training trackers
520
+ entity_tags: a dictionary of entity type (ENTITY_TAGS) to a list of features
521
+ containing entity tag ids for text user inputs otherwise empty dict
522
+ for all dialogue turns in all training trackers
523
+ encoded_all_labels: a list of dictionaries containing attribute features
524
+ for label ids
525
+
526
+ Returns:
527
+ RasaModelData
528
+ """
529
+ model_data = RasaModelData(label_key=LABEL_KEY, label_sub_key=LABEL_SUB_KEY)
530
+
531
+ if label_ids is not None and encoded_all_labels is not None:
532
+ label_ids = np.array(
533
+ [np.expand_dims(seq_label_ids, -1) for seq_label_ids in label_ids]
534
+ )
535
+ model_data.add_features(
536
+ LABEL_KEY,
537
+ LABEL_SUB_KEY,
538
+ [FeatureArray(label_ids, number_of_dimensions=3)],
539
+ )
540
+
541
+ attribute_data, self.fake_features = convert_to_data_format(
542
+ tracker_state_features, featurizers=self.config[FEATURIZERS]
543
+ )
544
+
545
+ entity_tags_data = self._create_data_for_entities(entity_tags)
546
+ if entity_tags_data is not None:
547
+ model_data.add_data(entity_tags_data)
548
+ else:
549
+ # method is called during prediction
550
+ attribute_data, _ = convert_to_data_format(
551
+ tracker_state_features,
552
+ self.fake_features,
553
+ featurizers=self.config[FEATURIZERS],
554
+ )
555
+
556
+ model_data.add_data(attribute_data)
557
+ model_data.add_lengths(TEXT, SEQUENCE_LENGTH, TEXT, SEQUENCE)
558
+ model_data.add_lengths(ACTION_TEXT, SEQUENCE_LENGTH, ACTION_TEXT, SEQUENCE)
559
+
560
+ # add the dialogue lengths
561
+ attribute_present = next(iter(list(attribute_data.keys())))
562
+ dialogue_lengths = np.array(
563
+ [
564
+ np.size(np.squeeze(f, -1))
565
+ for f in model_data.data[attribute_present][MASK][0]
566
+ ]
567
+ )
568
+ model_data.data[DIALOGUE][LENGTH] = [
569
+ FeatureArray(dialogue_lengths, number_of_dimensions=1)
570
+ ]
571
+
572
+ # make sure all keys are in the same order during training and prediction
573
+ model_data.sort()
574
+
575
+ return model_data
576
+
577
+ @staticmethod
578
+ def _get_trackers_for_training(
579
+ trackers: List[TrackerWithCachedStates],
580
+ ) -> List[TrackerWithCachedStates]:
581
+ """Filters out the list of trackers which should not be used for training.
582
+
583
+ Args:
584
+ trackers: All trackers available for training.
585
+
586
+ Returns:
587
+ Trackers which should be used for training.
588
+ """
589
+ # By default, we train on all available trackers.
590
+ return trackers
591
+
592
+ def _prepare_for_training(
593
+ self,
594
+ trackers: List[TrackerWithCachedStates],
595
+ domain: Domain,
596
+ precomputations: MessageContainerForCoreFeaturization,
597
+ **kwargs: Any,
598
+ ) -> Tuple[RasaModelData, np.ndarray]:
599
+ """Prepares data to be fed into the model.
600
+
601
+ Args:
602
+ trackers: List of training trackers to be featurized.
603
+ domain: Domain of the assistant.
604
+ precomputations: Contains precomputed features and attributes.
605
+ **kwargs: Any other arguments.
606
+
607
+ Returns:
608
+ Featurized data to be fed to the model and corresponding label ids.
609
+ """
610
+ training_trackers = self._get_trackers_for_training(trackers)
611
+ # dealing with training data
612
+ tracker_state_features, label_ids, entity_tags = self._featurize_for_training(
613
+ training_trackers,
614
+ domain,
615
+ precomputations=precomputations,
616
+ bilou_tagging=self.config[BILOU_FLAG],
617
+ **kwargs,
618
+ )
619
+
620
+ if not tracker_state_features:
621
+ return RasaModelData(), label_ids
622
+
623
+ self._label_data, encoded_all_labels = self._create_label_data(
624
+ domain, precomputations=precomputations
625
+ )
626
+
627
+ # extract actual training data to feed to model
628
+ model_data = self._create_model_data(
629
+ tracker_state_features, label_ids, entity_tags, encoded_all_labels
630
+ )
631
+
632
+ if self.config[ENTITY_RECOGNITION]:
633
+ self._entity_tag_specs = (
634
+ self.featurizer.state_featurizer.entity_tag_specs
635
+ if self.featurizer.state_featurizer is not None
636
+ else []
637
+ )
638
+
639
+ # keep one example for persisting and loading
640
+ self.data_example = model_data.first_data_example()
641
+
642
+ return model_data, label_ids
643
+
644
+ def run_training(
645
+ self, model_data: RasaModelData, label_ids: Optional[np.ndarray] = None
646
+ ) -> None:
647
+ """Feeds the featurized training data to the model.
648
+
649
+ Args:
650
+ model_data: Featurized training data.
651
+ label_ids: Label ids corresponding to the data points in `model_data`.
652
+ These may or may not be used by the function depending
653
+ on how the policy is trained.
654
+ """
655
+ if not self.finetune_mode:
656
+ # This means the model wasn't loaded from a
657
+ # previously trained model and hence needs
658
+ # to be instantiated.
659
+ self.model = self.model_class()(
660
+ model_data.get_signature(),
661
+ self.config,
662
+ isinstance(self.featurizer, MaxHistoryTrackerFeaturizer),
663
+ self._label_data,
664
+ self._entity_tag_specs,
665
+ )
666
+ self.model.compile(
667
+ optimizer=tf.keras.optimizers.Adam(self.config[LEARNING_RATE])
668
+ )
669
+ (
670
+ data_generator,
671
+ validation_data_generator,
672
+ ) = rasa.utils.train_utils.create_data_generators(
673
+ model_data,
674
+ self.config[BATCH_SIZES],
675
+ self.config[EPOCHS],
676
+ self.config[BATCH_STRATEGY],
677
+ self.config[EVAL_NUM_EXAMPLES],
678
+ self.config[RANDOM_SEED],
679
+ )
680
+ callbacks = rasa.utils.train_utils.create_common_callbacks(
681
+ self.config[EPOCHS],
682
+ self.config[TENSORBOARD_LOG_DIR],
683
+ self.config[TENSORBOARD_LOG_LEVEL],
684
+ self.tmp_checkpoint_dir,
685
+ )
686
+
687
+ if self.model is None:
688
+ raise ModelNotFound("No model was detected prior to training.")
689
+
690
+ self.model.fit(
691
+ data_generator,
692
+ epochs=self.config[EPOCHS],
693
+ validation_data=validation_data_generator,
694
+ validation_freq=self.config[EVAL_NUM_EPOCHS],
695
+ callbacks=callbacks,
696
+ verbose=False,
697
+ shuffle=False, # we use custom shuffle inside data generator
698
+ )
699
+
700
+ def train(
701
+ self,
702
+ training_trackers: List[TrackerWithCachedStates],
703
+ domain: Domain,
704
+ precomputations: Optional[MessageContainerForCoreFeaturization] = None,
705
+ **kwargs: Any,
706
+ ) -> Resource:
707
+ """Trains the policy (see parent class for full docstring)."""
708
+ if not training_trackers:
709
+ rasa.shared.utils.io.raise_warning(
710
+ f"Skipping training of `{self.__class__.__name__}` "
711
+ f"as no data was provided. You can exclude this "
712
+ f"policy in the configuration "
713
+ f"file to avoid this warning.",
714
+ category=UserWarning,
715
+ )
716
+ return self._resource
717
+
718
+ training_trackers = SupportedData.trackers_for_supported_data(
719
+ self.supported_data(), training_trackers
720
+ )
721
+
722
+ model_data, label_ids = self._prepare_for_training(
723
+ training_trackers, domain, precomputations
724
+ )
725
+
726
+ if model_data.is_empty():
727
+ rasa.shared.utils.io.raise_warning(
728
+ f"Skipping training of `{self.__class__.__name__}` "
729
+ f"as no data was provided. You can exclude this "
730
+ f"policy in the configuration "
731
+ f"file to avoid this warning.",
732
+ category=UserWarning,
733
+ )
734
+ return self._resource
735
+
736
+ with (
737
+ contextlib.nullcontext() if self.config["use_gpu"] else tf.device("/cpu:0")
738
+ ):
739
+ self.run_training(model_data, label_ids)
740
+
741
+ self.persist()
742
+
743
+ return self._resource
744
+
745
+ def _featurize_tracker(
746
+ self,
747
+ tracker: DialogueStateTracker,
748
+ domain: Domain,
749
+ precomputations: Optional[MessageContainerForCoreFeaturization],
750
+ rule_only_data: Optional[Dict[Text, Any]],
751
+ ) -> List[List[Dict[Text, List[Features]]]]:
752
+ # construct up to two examples in the batch to be fed to the model:
753
+ # the first either does not contain user input, or uses the intent or the
754
+ # text of the last user message, depending on whether TED is e2e only;
755
+ # the second is optional (see conditions below) and always featurizes
756
+ # the last user text.
757
+ tracker_state_features = self._featurize_for_prediction(
758
+ tracker,
759
+ domain,
760
+ precomputations=precomputations,
761
+ use_text_for_last_user_input=self.only_e2e,
762
+ rule_only_data=rule_only_data,
763
+ )
764
+ # the second - text, but only after user utterance and if not only e2e
765
+ if (
766
+ tracker.latest_action_name == ACTION_LISTEN_NAME
767
+ and TEXT in self.fake_features
768
+ and not self.only_e2e
769
+ ):
770
+ tracker_state_features += self._featurize_for_prediction(
771
+ tracker,
772
+ domain,
773
+ precomputations=precomputations,
774
+ use_text_for_last_user_input=True,
775
+ rule_only_data=rule_only_data,
776
+ )
777
+ return tracker_state_features
778
+
779
+ def _pick_confidence(
780
+ self, confidences: np.ndarray, similarities: np.ndarray, domain: Domain
781
+ ) -> Tuple[np.ndarray, bool]:
782
+ # the confidences and similarities have shape (batch-size x number of actions)
783
+ # batch-size can only be 1 or 2;
784
+ # in the case batch-size==2, the first example contains the user intent as
785
+ # features and the second contains the user text as features
786
+ if confidences.shape[0] > 2:
787
+ raise ValueError(
788
+ "We cannot pick prediction from batches of size more than 2."
789
+ )
790
+ # we use a heuristic to pick the correct prediction
791
+ if confidences.shape[0] == 2:
792
+ # we use similarities to pick the appropriate input,
793
+ # since they seem to be a more accurate measure:
794
+ # the policy is trained to maximize the similarity, not the confidence
795
+ non_e2e_action_name = domain.action_names_or_texts[
796
+ np.argmax(confidences[0])
797
+ ]
798
+ logger.debug(f"User intent lead to '{non_e2e_action_name}'.")
799
+ e2e_action_name = domain.action_names_or_texts[np.argmax(confidences[1])]
800
+ logger.debug(f"User text lead to '{e2e_action_name}'.")
801
+ if (
802
+ np.max(confidences[1]) > self.config[E2E_CONFIDENCE_THRESHOLD]
803
+ # TODO maybe compare confidences is better
804
+ and np.max(similarities[1]) > np.max(similarities[0])
805
+ ):
806
+ logger.debug(f"TED predicted '{e2e_action_name}' based on user text.")
807
+ return confidences[1], True
808
+
809
+ logger.debug(f"TED predicted '{non_e2e_action_name}' based on user intent.")
810
+ return confidences[0], False
811
+
812
+ # by default the first example in a batch is the one to use for prediction
813
+ predicted_action_name = domain.action_names_or_texts[np.argmax(confidences[0])]
814
+ basis_for_prediction = "text" if self.only_e2e else "intent"
815
+ logger.debug(
816
+ f"TED predicted '{predicted_action_name}' "
817
+ f"based on user {basis_for_prediction}."
818
+ )
819
+ return confidences[0], self.only_e2e
820
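A minimal numpy sketch of the heuristic implemented above, assuming a two-row batch in which row 0 was featurized from the user intent and row 1 from the user text (a hypothetical standalone function, not part of the policy API):

import numpy as np

def pick_confidence(confidences: np.ndarray, similarities: np.ndarray,
                    e2e_threshold: float = 0.5):
    # Prefer the text-based row only if it is confident enough *and* more similar.
    if confidences.shape[0] == 2:
        use_text = (
            np.max(confidences[1]) > e2e_threshold
            and np.max(similarities[1]) > np.max(similarities[0])
        )
        return (confidences[1], True) if use_text else (confidences[0], False)
    return confidences[0], False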
+
821
+ async def predict_action_probabilities(
822
+ self,
823
+ tracker: DialogueStateTracker,
824
+ domain: Domain,
825
+ rule_only_data: Optional[Dict[Text, Any]] = None,
826
+ precomputations: Optional[MessageContainerForCoreFeaturization] = None,
827
+ **kwargs: Any,
828
+ ) -> PolicyPrediction:
829
+ """Predicts the next action (see parent class for full docstring)."""
830
+ if self.model is None or self.should_abstain_in_coexistence(tracker, False):
831
+ return self._prediction(self._default_predictions(domain))
832
+
833
+ # create model data from tracker
834
+ tracker_state_features = self._featurize_tracker(
835
+ tracker, domain, precomputations, rule_only_data=rule_only_data
836
+ )
837
+ model_data = self._create_model_data(tracker_state_features)
838
+ outputs = self.model.run_inference(model_data)
839
+
840
+ if isinstance(outputs["similarities"], np.ndarray):
841
+ # take the last prediction in the sequence
842
+ similarities = outputs["similarities"][:, -1, :]
843
+ else:
844
+ raise TypeError(
845
+ "model output for `similarities` " "should be a numpy array"
846
+ )
847
+ if isinstance(outputs["scores"], np.ndarray):
848
+ confidences = outputs["scores"][:, -1, :]
849
+ else:
850
+ raise TypeError("model output for `scores` should be a numpy array")
851
+ # take correct prediction from batch
852
+ confidence, is_e2e_prediction = self._pick_confidence(
853
+ confidences, similarities, domain
854
+ )
855
+
856
+ # rank and mask the confidence (if we need to)
857
+ ranking_length = self.config[RANKING_LENGTH]
858
+ if 0 < ranking_length < len(confidence):
859
+ renormalize = (
860
+ self.config[RENORMALIZE_CONFIDENCES]
861
+ and self.config[MODEL_CONFIDENCE] == SOFTMAX
862
+ )
863
+ _, confidence = train_utils.rank_and_mask(
864
+ confidence, ranking_length=ranking_length, renormalize=renormalize
865
+ )
866
+
867
+ optional_events = self._create_optional_event_for_entities(
868
+ outputs, is_e2e_prediction, precomputations, tracker
869
+ )
870
+
871
+ return self._prediction(
872
+ confidence.tolist(),
873
+ is_end_to_end_prediction=is_e2e_prediction,
874
+ optional_events=optional_events,
875
+ diagnostic_data=outputs.get(DIAGNOSTIC_DATA),
876
+ )
877
+
878
+ def _create_optional_event_for_entities(
879
+ self,
880
+ prediction_output: Dict[Text, tf.Tensor],
881
+ is_e2e_prediction: bool,
882
+ precomputations: Optional[MessageContainerForCoreFeaturization],
883
+ tracker: DialogueStateTracker,
884
+ ) -> Optional[List[Event]]:
885
+ if tracker.latest_action_name != ACTION_LISTEN_NAME or not is_e2e_prediction:
886
+ # entities belong only to the last user message
887
+ # and only if user text was used for prediction,
888
+ # a user message always comes after action listen
889
+ return None
890
+
891
+ if not self.config[ENTITY_RECOGNITION]:
892
+ # entity recognition is not turned on, no entities can be predicted
893
+ return None
894
+
895
+ # The batch dimension of entity prediction is not the same as batch size,
896
+ # rather it is the number of last (if max history featurizer else all)
897
+ # text inputs in the batch
898
+ # therefore, in order to pick entities from the latest user message
899
+ # we need to pick entities from the last batch dimension of entity prediction
900
+ predicted_tags, confidence_values = rasa.utils.train_utils.entity_label_to_tags(
901
+ prediction_output,
902
+ self._entity_tag_specs,
903
+ self.config[BILOU_FLAG],
904
+ prediction_index=-1,
905
+ )
906
+
907
+ if ENTITY_ATTRIBUTE_TYPE not in predicted_tags:
908
+ # no entities detected
909
+ return None
910
+
911
+ # entities belong to the last message of the tracker
912
+ # convert the predicted tags to actual entities
913
+ text = tracker.latest_message.text if tracker.latest_message is not None else ""
914
+ if precomputations is not None:
915
+ parsed_message = precomputations.lookup_message(user_text=text)
916
+ else:
917
+ parsed_message = Message(data={TEXT: text})
918
+ tokens = parsed_message.get(TOKENS_NAMES[TEXT])
919
+ entities = EntityExtractorMixin.convert_predictions_into_entities(
920
+ text,
921
+ tokens,
922
+ predicted_tags,
923
+ self.split_entities_config,
924
+ confidences=confidence_values,
925
+ )
926
+
927
+ # add the extractor name
928
+ for entity in entities:
929
+ entity[EXTRACTOR] = "TEDPolicy"
930
+
931
+ return [EntitiesAdded(entities)]
932
+
933
+ def persist(self) -> None:
934
+ """Persists the policy to a storage."""
935
+ if self.model is None:
936
+ logger.debug(
937
+ "Method `persist(...)` was called without a trained model present. "
938
+ "Nothing to persist then!"
939
+ )
940
+ return
941
+
942
+ with self._model_storage.write_to(self._resource) as model_path:
943
+ model_filename = self._metadata_filename()
944
+ tf_model_file = model_path / f"{model_filename}.tf_model"
945
+
946
+ rasa.shared.utils.io.create_directory_for_file(tf_model_file)
947
+
948
+ self.featurizer.persist(model_path)
949
+
950
+ if self.config[CHECKPOINT_MODEL] and self.tmp_checkpoint_dir:
951
+ self.model.load_weights(self.tmp_checkpoint_dir / "checkpoint.tf_model")
952
+ # Save an empty file to flag that this model has been
953
+ # produced using checkpointing
954
+ checkpoint_marker = model_path / f"{model_filename}.from_checkpoint.pkl"
955
+ checkpoint_marker.touch()
956
+
957
+ self.model.save(str(tf_model_file))
958
+
959
+ self.persist_model_utilities(model_path)
960
+
961
+ def persist_model_utilities(self, model_path: Path) -> None:
962
+ """Persists model's utility attributes like model weights, etc.
963
+
964
+ Args:
965
+ model_path: Path where model is to be persisted
966
+ """
967
+ model_filename = self._metadata_filename()
968
+ rasa.shared.utils.io.dump_obj_as_json_to_file(
969
+ model_path / f"{model_filename}.priority.json", self.priority
970
+ )
971
+ rasa.shared.utils.io.dump_obj_as_json_to_file(
972
+ model_path / f"{model_filename}.meta.json", self.config
973
+ )
974
+ # save data example
975
+ serialize_nested_feature_arrays(
976
+ self.data_example,
977
+ str(model_path / f"{model_filename}.data_example.st"),
978
+ str(model_path / f"{model_filename}.data_example_metadata.json"),
979
+ )
980
+ # save label data
981
+ serialize_nested_feature_arrays(
982
+ dict(self._label_data.data) if self._label_data is not None else {},
983
+ str(model_path / f"{model_filename}.label_data.st"),
984
+ str(model_path / f"{model_filename}.label_data_metadata.json"),
985
+ )
986
+ # save fake features
987
+ metadata = save_features(
988
+ self.fake_features, str(model_path / f"{model_filename}.fake_features.st")
989
+ )
990
+ rasa.shared.utils.io.dump_obj_as_json_to_file(
991
+ model_path / f"{model_filename}.fake_features_metadata.json", metadata
992
+ )
993
+
994
+ entity_tag_specs = (
995
+ [tag_spec._asdict() for tag_spec in self._entity_tag_specs]
996
+ if self._entity_tag_specs
997
+ else []
998
+ )
999
+ rasa.shared.utils.io.dump_obj_as_json_to_file(
1000
+ model_path / f"{model_filename}.entity_tag_specs.json", entity_tag_specs
1001
+ )
1002
+
1003
+ @classmethod
1004
+ def _load_model_utilities(cls, model_path: Path) -> Dict[Text, Any]:
1005
+ """Loads model's utility attributes.
1006
+
1007
+ Args:
1008
+ model_path: Path where model is to be persisted.
1009
+ """
1010
+ tf_model_file = model_path / f"{cls._metadata_filename()}.tf_model"
1011
+
1012
+ # load data example
1013
+ loaded_data = deserialize_nested_feature_arrays(
1014
+ str(model_path / f"{cls._metadata_filename()}.data_example.st"),
1015
+ str(model_path / f"{cls._metadata_filename()}.data_example_metadata.json"),
1016
+ )
1017
+ # load label data
1018
+ loaded_label_data = deserialize_nested_feature_arrays(
1019
+ str(model_path / f"{cls._metadata_filename()}.label_data.st"),
1020
+ str(model_path / f"{cls._metadata_filename()}.label_data_metadata.json"),
1021
+ )
1022
+ label_data = RasaModelData(data=loaded_label_data)
1023
+
1024
+ # load fake features
1025
+ metadata = rasa.shared.utils.io.read_json_file(
1026
+ model_path / f"{cls._metadata_filename()}.fake_features_metadata.json"
1027
+ )
1028
+ fake_features = load_features(
1029
+ str(model_path / f"{cls._metadata_filename()}.fake_features.st"), metadata
1030
+ )
1031
+
1032
+ priority = rasa.shared.utils.io.read_json_file(
1033
+ model_path / f"{cls._metadata_filename()}.priority.json"
1034
+ )
1035
+ entity_tag_specs = rasa.shared.utils.io.read_json_file(
1036
+ model_path / f"{cls._metadata_filename()}.entity_tag_specs.json"
1037
+ )
1038
+ entity_tag_specs = [
1039
+ EntityTagSpec(
1040
+ tag_name=tag_spec["tag_name"],
1041
+ ids_to_tags={
1042
+ int(key): value for key, value in tag_spec["ids_to_tags"].items()
1043
+ },
1044
+ tags_to_ids={
1045
+ key: int(value) for key, value in tag_spec["tags_to_ids"].items()
1046
+ },
1047
+ num_tags=tag_spec["num_tags"],
1048
+ )
1049
+ for tag_spec in entity_tag_specs
1050
+ ]
1051
+ model_config = rasa.shared.utils.io.read_json_file(
1052
+ model_path / f"{cls._metadata_filename()}.meta.json"
1053
+ )
1054
+
1055
+ return {
1056
+ "tf_model_file": tf_model_file,
1057
+ "loaded_data": loaded_data,
1058
+ "fake_features": fake_features,
1059
+ "label_data": label_data,
1060
+ "priority": priority,
1061
+ "entity_tag_specs": entity_tag_specs,
1062
+ "model_config": model_config,
1063
+ }
1064
+
1065
+ @classmethod
1066
+ def load(
1067
+ cls,
1068
+ config: Dict[Text, Any],
1069
+ model_storage: ModelStorage,
1070
+ resource: Resource,
1071
+ execution_context: ExecutionContext,
1072
+ **kwargs: Any,
1073
+ ) -> TEDPolicy:
1074
+ """Loads a policy from the storage (see parent class for full docstring)."""
1075
+ try:
1076
+ with model_storage.read_from(resource) as model_path:
1077
+ return cls._load(
1078
+ model_path, config, model_storage, resource, execution_context
1079
+ )
1080
+ except ValueError:
1081
+ logger.debug(
1082
+ f"Failed to load {cls.__class__.__name__} from model storage. Resource "
1083
+ f"'{resource.name}' doesn't exist."
1084
+ )
1085
+ return cls(config, model_storage, resource, execution_context)
1086
+
1087
+ @classmethod
1088
+ def _load(
1089
+ cls,
1090
+ model_path: Path,
1091
+ config: Dict[Text, Any],
1092
+ model_storage: ModelStorage,
1093
+ resource: Resource,
1094
+ execution_context: ExecutionContext,
1095
+ ) -> TEDPolicy:
1096
+ featurizer = TrackerFeaturizer.load(model_path)
1097
+
1098
+ if not (model_path / f"{cls._metadata_filename()}.data_example.st").is_file():
1099
+ return cls(
1100
+ config,
1101
+ model_storage,
1102
+ resource,
1103
+ execution_context,
1104
+ featurizer=featurizer,
1105
+ )
1106
+
1107
+ model_utilities = cls._load_model_utilities(model_path)
1108
+
1109
+ config = cls._update_loaded_params(config)
1110
+ if execution_context.is_finetuning and EPOCH_OVERRIDE in config:
1111
+ config[EPOCHS] = config.get(EPOCH_OVERRIDE)
1112
+
1113
+ (
1114
+ model_data_example,
1115
+ predict_data_example,
1116
+ ) = cls._construct_model_initialization_data(model_utilities["loaded_data"])
1117
+
1118
+ model = None
1119
+
1120
+ with contextlib.nullcontext() if config["use_gpu"] else tf.device("/cpu:0"):
1121
+ model = cls._load_tf_model(
1122
+ model_utilities,
1123
+ model_data_example,
1124
+ predict_data_example,
1125
+ featurizer,
1126
+ execution_context.is_finetuning,
1127
+ )
1128
+
1129
+ return cls._load_policy_with_model(
1130
+ config,
1131
+ model_storage,
1132
+ resource,
1133
+ execution_context,
1134
+ featurizer=featurizer,
1135
+ model_utilities=model_utilities,
1136
+ model=model,
1137
+ )
1138
+
1139
+ @classmethod
1140
+ def _load_policy_with_model(
1141
+ cls,
1142
+ config: Dict[Text, Any],
1143
+ model_storage: ModelStorage,
1144
+ resource: Resource,
1145
+ execution_context: ExecutionContext,
1146
+ featurizer: TrackerFeaturizer,
1147
+ model: TED,
1148
+ model_utilities: Dict[Text, Any],
1149
+ ) -> TEDPolicy:
1150
+ return cls(
1151
+ config,
1152
+ model_storage,
1153
+ resource,
1154
+ execution_context,
1155
+ model=model,
1156
+ featurizer=featurizer,
1157
+ fake_features=model_utilities["fake_features"],
1158
+ entity_tag_specs=model_utilities["entity_tag_specs"],
1159
+ )
1160
+
1161
+ @classmethod
1162
+ def _load_tf_model(
1163
+ cls,
1164
+ model_utilities: Dict[Text, Any],
1165
+ model_data_example: RasaModelData,
1166
+ predict_data_example: RasaModelData,
1167
+ featurizer: TrackerFeaturizer,
1168
+ should_finetune: bool,
1169
+ ) -> TED:
1170
+ model = cls.model_class().load(
1171
+ str(model_utilities["tf_model_file"]),
1172
+ model_data_example,
1173
+ predict_data_example,
1174
+ data_signature=model_data_example.get_signature(),
1175
+ config=model_utilities["model_config"],
1176
+ max_history_featurizer_is_used=isinstance(
1177
+ featurizer, MaxHistoryTrackerFeaturizer
1178
+ ),
1179
+ label_data=model_utilities["label_data"],
1180
+ entity_tag_specs=model_utilities["entity_tag_specs"],
1181
+ finetune_mode=should_finetune,
1182
+ )
1183
+ return model
1184
+
1185
+ @classmethod
1186
+ def _construct_model_initialization_data(
1187
+ cls, loaded_data: Dict[Text, Dict[Text, List[FeatureArray]]]
1188
+ ) -> Tuple[RasaModelData, RasaModelData]:
1189
+ model_data_example = RasaModelData(
1190
+ label_key=LABEL_KEY, label_sub_key=LABEL_SUB_KEY, data=loaded_data
1191
+ )
1192
+ predict_data_example = RasaModelData(
1193
+ label_key=LABEL_KEY,
1194
+ label_sub_key=LABEL_SUB_KEY,
1195
+ data={
1196
+ feature_name: features
1197
+ for feature_name, features in model_data_example.items()
1198
+ # we need to remove label features for prediction if they are present
1199
+ if feature_name in PREDICTION_FEATURES
1200
+ },
1201
+ )
1202
+ return model_data_example, predict_data_example
1203
+
1204
+ @classmethod
1205
+ def _update_loaded_params(cls, meta: Dict[Text, Any]) -> Dict[Text, Any]:
1206
+ meta = rasa.utils.train_utils.update_confidence_type(meta)
1207
+ meta = rasa.utils.train_utils.update_similarity_type(meta)
1208
+
1209
+ return meta
1210
+
1211
+
1212
+ class TED(TransformerRasaModel):
1213
+ """TED model architecture from https://arxiv.org/abs/1910.00486."""
1214
+
1215
+ def __init__(
1216
+ self,
1217
+ data_signature: Dict[Text, Dict[Text, List[FeatureSignature]]],
1218
+ config: Dict[Text, Any],
1219
+ max_history_featurizer_is_used: bool,
1220
+ label_data: RasaModelData,
1221
+ entity_tag_specs: Optional[List[EntityTagSpec]],
1222
+ ) -> None:
1223
+ """Initializes the TED model.
1224
+
1225
+ Args:
1226
+ data_signature: the data signature of the input data
1227
+ config: the model configuration
1228
+ max_history_featurizer_is_used: if 'True'
1229
+ only the last dialogue turn will be used
1230
+ label_data: the label data
1231
+ entity_tag_specs: the entity tag specifications
1232
+ """
1233
+ super().__init__("TED", config, data_signature, label_data)
1234
+
1235
+ self.max_history_featurizer_is_used = max_history_featurizer_is_used
1236
+
1237
+ self.predict_data_signature = {
1238
+ feature_name: features
1239
+ for feature_name, features in data_signature.items()
1240
+ if feature_name in PREDICTION_FEATURES
1241
+ }
1242
+
1243
+ self._entity_tag_specs = entity_tag_specs
1244
+
1245
+ # metrics
1246
+ self.action_loss = tf.keras.metrics.Mean(name="loss")
1247
+ self.action_acc = tf.keras.metrics.Mean(name="acc")
1248
+ self.entity_loss = tf.keras.metrics.Mean(name="e_loss")
1249
+ self.entity_f1 = tf.keras.metrics.Mean(name="e_f1")
1250
+ self.metrics_to_log += ["loss", "acc"]
1251
+ if self.config[ENTITY_RECOGNITION]:
1252
+ self.metrics_to_log += ["e_loss", "e_f1"]
1253
+
1254
+ # needed for efficient prediction
1255
+ self.all_labels_embed: Optional[tf.Tensor] = None
1256
+
1257
+ self._prepare_layers()
1258
+
1259
+ def _check_data(self) -> None:
1260
+ if not any(key in [INTENT, TEXT] for key in self.data_signature.keys()):
1261
+ raise RasaException(
1262
+ f"No user features specified. "
1263
+ f"Cannot train '{self.__class__.__name__}' model."
1264
+ )
1265
+
1266
+ if not any(
1267
+ key in [ACTION_NAME, ACTION_TEXT] for key in self.data_signature.keys()
1268
+ ):
1269
+ raise ValueError(
1270
+ f"No action features specified. "
1271
+ f"Cannot train '{self.__class__.__name__}' model."
1272
+ )
1273
+ if LABEL not in self.data_signature:
1274
+ raise ValueError(
1275
+ f"No label features specified. "
1276
+ f"Cannot train '{self.__class__.__name__}' model."
1277
+ )
1278
+
1279
+ # ---CREATING LAYERS HELPERS---
1280
+
1281
+ def _prepare_layers(self) -> None:
1282
+ for name in self.data_signature.keys():
1283
+ self._prepare_input_layers(
1284
+ name, self.data_signature[name], is_label_attribute=False
1285
+ )
1286
+ self._prepare_encoding_layers(name)
1287
+
1288
+ for name in self.label_signature.keys():
1289
+ self._prepare_input_layers(
1290
+ name, self.label_signature[name], is_label_attribute=True
1291
+ )
1292
+ self._prepare_encoding_layers(name)
1293
+
1294
+ self._tf_layers[f"transformer.{DIALOGUE}"] = (
1295
+ rasa_layers.prepare_transformer_layer(
1296
+ attribute_name=DIALOGUE,
1297
+ config=self.config,
1298
+ num_layers=self.config[NUM_TRANSFORMER_LAYERS][DIALOGUE],
1299
+ units=self.config[TRANSFORMER_SIZE][DIALOGUE],
1300
+ drop_rate=self.config[DROP_RATE_DIALOGUE],
1301
+ # use bidirectional transformer, because
1302
+ # we will invert dialogue sequence so that the last turn is located
1303
+ # at the first position and would always have
1304
+ # exactly the same positional encoding
1305
+ unidirectional=not self.max_history_featurizer_is_used,
1306
+ )
1307
+ )
1308
+
1309
+ self._prepare_label_classification_layers(DIALOGUE)
1310
+
1311
+ if self.config[ENTITY_RECOGNITION]:
1312
+ self._prepare_entity_recognition_layers()
1313
+
1314
+ def _prepare_input_layers(
1315
+ self,
1316
+ attribute_name: Text,
1317
+ attribute_signature: Dict[Text, List[FeatureSignature]],
1318
+ is_label_attribute: bool = False,
1319
+ ) -> None:
1320
+ """Prepares feature processing layers for sentence/sequence-level features.
1321
+
1322
+ Distinguishes between label features and other features, not applying input
1323
+ dropout to the label ones.
1324
+ """
1325
+ # Disable input dropout in the config to be used if this is a label attribute.
1326
+ if is_label_attribute:
1327
+ config_to_use = self.config.copy()
1328
+ config_to_use.update(
1329
+ {SPARSE_INPUT_DROPOUT: False, DENSE_INPUT_DROPOUT: False}
1330
+ )
1331
+ else:
1332
+ config_to_use = self.config
1333
+ # Attributes with sequence-level features also have sentence-level features,
1334
+ # all these need to be combined and further processed.
1335
+ if attribute_name in SEQUENCE_FEATURES_TO_ENCODE:
1336
+ self._tf_layers[f"sequence_layer.{attribute_name}"] = (
1337
+ rasa_layers.RasaSequenceLayer(
1338
+ attribute_name, attribute_signature, config_to_use
1339
+ )
1340
+ )
1341
+ # Attributes without sequence-level features require some actual feature
1342
+ # processing only if they have sentence-level features. Attributes with no
1343
+ # sequence- and sentence-level features (dialogue, entity_tags, label) are
1344
+ # skipped here.
1345
+ elif SENTENCE in attribute_signature:
1346
+ self._tf_layers[f"sparse_dense_concat_layer.{attribute_name}"] = (
1347
+ rasa_layers.ConcatenateSparseDenseFeatures(
1348
+ attribute=attribute_name,
1349
+ feature_type=SENTENCE,
1350
+ feature_type_signature=attribute_signature[SENTENCE],
1351
+ config=config_to_use,
1352
+ )
1353
+ )
1354
+
1355
+ def _prepare_encoding_layers(self, name: Text) -> None:
1356
+ """Create Ffnn encoding layer used just before combining all dialogue features.
1357
+
1358
+ Args:
1359
+ name: attribute name
1360
+ """
1361
+ # create encoding layers only for the features which should be encoded;
1362
+ if name not in SENTENCE_FEATURES_TO_ENCODE + LABEL_FEATURES_TO_ENCODE:
1363
+ return
1364
+ # check that there are SENTENCE features for the attribute name in data
1365
+ if (
1366
+ name in SENTENCE_FEATURES_TO_ENCODE
1367
+ and FEATURE_TYPE_SENTENCE not in self.data_signature[name]
1368
+ ):
1369
+ return
1370
+ # same for label_data
1371
+ if (
1372
+ name in LABEL_FEATURES_TO_ENCODE
1373
+ and FEATURE_TYPE_SENTENCE not in self.label_signature[name]
1374
+ ):
1375
+ return
1376
+
1377
+ self._prepare_ffnn_layer(
1378
+ f"{name}",
1379
+ [self.config[ENCODING_DIMENSION]],
1380
+ self.config[DROP_RATE_DIALOGUE],
1381
+ prefix="encoding_layer",
1382
+ )
1383
+
1384
+ # ---GRAPH BUILDING HELPERS---
1385
+
1386
+ @staticmethod
1387
+ def _compute_dialogue_indices(
1388
+ tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
1389
+ ) -> None:
1390
+ dialogue_lengths = tf.cast(tf_batch_data[DIALOGUE][LENGTH][0], dtype=tf.int32)
1391
+ # wrap in a list, because that's the structure of tf_batch_data
1392
+ tf_batch_data[DIALOGUE][INDICES] = [
1393
+ (
1394
+ tf.map_fn(
1395
+ tf.range,
1396
+ dialogue_lengths,
1397
+ fn_output_signature=tf.RaggedTensorSpec(
1398
+ shape=[None], dtype=tf.int32
1399
+ ),
1400
+ )
1401
+ ).values
1402
+ ]
1403
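A minimal sketch of what the `tf.map_fn` above produces: per-dialogue turn indices, flattened into one vector (the `[2, 1, 3]` example matches the one used in `_create_last_dialogue_turns_mask` below):

import tensorflow as tf

dialogue_lengths = tf.constant([2, 1, 3], dtype=tf.int32)
indices = tf.map_fn(
    tf.range,
    dialogue_lengths,
    fn_output_signature=tf.RaggedTensorSpec(shape=[None], dtype=tf.int32),
).values
# indices == [0, 1, 0, 0, 1, 2]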
+
1404
+ def _create_all_labels_embed(self) -> Tuple[tf.Tensor, tf.Tensor]:
1405
+ all_label_ids = self.tf_label_data[LABEL_KEY][LABEL_SUB_KEY][0]
1406
+ # labels cannot have all features "fake"
1407
+ all_labels_encoded = {}
1408
+ for key in self.tf_label_data.keys():
1409
+ if key != LABEL_KEY:
1410
+ attribute_features, _, _ = self._encode_real_features_per_attribute(
1411
+ self.tf_label_data, key
1412
+ )
1413
+ all_labels_encoded[key] = attribute_features
1414
+
1415
+ x = self._collect_label_attribute_encodings(all_labels_encoded)
1416
+
1417
+ # additional sequence axis is artifact of our RasaModelData creation
1418
+ # TODO check whether this should be solved in data creation
1419
+ x = tf.squeeze(x, axis=1)
1420
+
1421
+ all_labels_embed = self._tf_layers[f"embed.{LABEL}"](x)
1422
+
1423
+ return all_label_ids, all_labels_embed
1424
+
1425
+ @staticmethod
1426
+ def _collect_label_attribute_encodings(
1427
+ all_labels_encoded: Dict[Text, tf.Tensor],
1428
+ ) -> tf.Tensor:
1429
+ # Initialize with at least one attribute first
1430
+ # so that the subsequent TF ops are simplified.
1431
+ all_attributes_present = list(all_labels_encoded.keys())
1432
+ x = all_labels_encoded.pop(all_attributes_present[0])
1433
+
1434
+ # Add remaining attributes
1435
+ for attribute in all_labels_encoded:
1436
+ x += all_labels_encoded.get(attribute)
1437
+ return x
1438
+
1439
+ def _embed_dialogue(
1440
+ self,
1441
+ dialogue_in: tf.Tensor,
1442
+ tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
1443
+ ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Optional[tf.Tensor]]:
1444
+ """Creates dialogue level embedding and mask.
1445
+
1446
+ Args:
1447
+ dialogue_in: The encoded dialogue.
1448
+ tf_batch_data: Batch in model data format.
1449
+
1450
+ Returns:
1451
+ The dialogue embedding, the mask, and (for diagnostic purposes)
1452
+ also the attention weights.
1453
+ """
1454
+ dialogue_lengths = tf.cast(tf_batch_data[DIALOGUE][LENGTH][0], tf.int32)
1455
+ mask = rasa_layers.compute_mask(dialogue_lengths)
1456
+
1457
+ if self.max_history_featurizer_is_used:
1458
+ # invert dialogue sequence so that the last turn would always have
1459
+ # exactly the same positional encoding
1460
+ dialogue_in = tf.reverse_sequence(dialogue_in, dialogue_lengths, seq_axis=1)
1461
+
1462
+ dialogue_transformed, attention_weights = self._tf_layers[
1463
+ f"transformer.{DIALOGUE}"
1464
+ ](dialogue_in, 1 - mask, self._training)
1465
+ dialogue_transformed = tf.nn.gelu(dialogue_transformed)
1466
+
1467
+ if self.max_history_featurizer_is_used:
1468
+ # pick last vector if max history featurizer is used, since we inverted
1469
+ # dialogue sequence, the last vector is actually the first one
1470
+ dialogue_transformed = dialogue_transformed[:, :1, :]
1471
+ mask = tf.expand_dims(self._last_token(mask, dialogue_lengths), 1)
1472
+ elif not self._training:
1473
+ # during prediction we don't care about previous dialogue turns,
1474
+ # so to save computation time, use only the last one
1475
+ dialogue_transformed = tf.expand_dims(
1476
+ self._last_token(dialogue_transformed, dialogue_lengths), 1
1477
+ )
1478
+ mask = tf.expand_dims(self._last_token(mask, dialogue_lengths), 1)
1479
+
1480
+ dialogue_embed = self._tf_layers[f"embed.{DIALOGUE}"](dialogue_transformed)
1481
+
1482
+ return dialogue_embed, mask, dialogue_transformed, attention_weights
1483
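A minimal sketch of the sequence inversion used above: with the max-history featurizer, reversing each dialogue by its own length puts the most recent turn at position 0, so it always receives the same positional encoding (toy values):

import tensorflow as tf

dialogue_in = tf.constant([[[1.0], [2.0], [3.0]],   # dialogue of length 3
                           [[4.0], [5.0], [0.0]]])  # dialogue of length 2, zero-padded
dialogue_lengths = tf.constant([3, 2], dtype=tf.int32)
reversed_in = tf.reverse_sequence(dialogue_in, dialogue_lengths, seq_axis=1)
# reversed_in[:, 0, 0] == [3.0, 5.0]  -> the last real turn of each dialogue is now first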
+
1484
+ def _encode_features_per_attribute(
1485
+ self, tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]], attribute: Text
1486
+ ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
1487
+ # The input is a representation of 4d tensor of
1488
+ # shape (batch-size x dialogue-len x sequence-len x units) in 3d of shape
1489
+ # (sum of dialogue history length for all tensors in the batch x
1490
+ # max sequence length x number of features).
1491
+
1492
+ # However, some dialogue turns contain non-existent state features,
1493
+ # e.g. `intent` and `text` features are mutually exclusive,
1494
+ # as well as `action_name` and `action_text` are mutually exclusive,
1495
+ # or some dialogue turns don't contain any `slots`.
1496
+ # In order to create 4d full tensors, we created "fake" zero features for
1497
+ # these non-existent state features, which are filtered out during batch generation.
1498
+ # Therefore the first dimensions for different attributes are different.
1499
+ # It could happen that some batches don't contain "real" features at all,
1500
+ # e.g. large number of stories don't contain any `slots`.
1501
+ # Therefore actual input tensors will be empty.
1502
+ # Since we need actual numbers to create dialogue turn features, we create
1503
+ # zero tensors in `_encode_fake_features_per_attribute` for these attributes.
1504
+ return tf.cond(
1505
+ tf.shape(tf_batch_data[attribute][SENTENCE][0])[0] > 0,
1506
+ lambda: self._encode_real_features_per_attribute(tf_batch_data, attribute),
1507
+ lambda: self._encode_fake_features_per_attribute(tf_batch_data, attribute),
1508
+ )
1509
+
1510
+ def _encode_fake_features_per_attribute(
1511
+ self, tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]], attribute: Text
1512
+ ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
1513
+ """Returns dummy outputs for fake features of a given attribute.
1514
+
1515
+ Needs to match the outputs of `_encode_real_features_per_attribute` in shape
1516
+ but these outputs will be filled with zeros.
1517
+
1518
+ Args:
1519
+ tf_batch_data: Maps each attribute to its features and masks.
1520
+ attribute: The attribute whose fake features will be "processed", e.g.
1521
+ `ACTION_NAME`, `INTENT`.
1522
+
1523
+ Returns:
1524
+ attribute_features: A tensor of shape `(batch_size, dialogue_length, units)`
1525
+ filled with zeros.
1526
+ text_output: Only for `TEXT` attribute (otherwise an empty tensor): A tensor
1527
+ of shape `(combined batch_size & dialogue_length, max seq length,
1528
+ units)` filled with zeros.
1529
+ text_sequence_lengths: Only for `TEXT` attribute, otherwise an empty tensor:
1530
+ Of shape `(combined batch_size & dialogue_length, 1)`, filled with zeros.
1531
+ """
1532
+ # we need to create real zero tensors with appropriate batch and dialogue dim
1533
+ # because they are passed to dialogue transformer
1534
+ attribute_mask = tf_batch_data[attribute][MASK][0]
1535
+
1536
+ # determine all dimensions so that fake features of the correct shape can be
1537
+ # created
1538
+ batch_dim = tf.shape(attribute_mask)[0]
1539
+ dialogue_dim = tf.shape(attribute_mask)[1]
1540
+ if attribute in set(SENTENCE_FEATURES_TO_ENCODE + LABEL_FEATURES_TO_ENCODE):
1541
+ units = self.config[ENCODING_DIMENSION]
1542
+ else:
1543
+ # state-level attributes don't use an encoding layer, hence their size is
1544
+ # just the output size of the corresponding sparse+dense feature combining
1545
+ # layer
1546
+ units = self._tf_layers[
1547
+ f"sparse_dense_concat_layer.{attribute}"
1548
+ ].output_units
1549
+
1550
+ attribute_features = tf.zeros(
1551
+ (batch_dim, dialogue_dim, units), dtype=tf.float32
1552
+ )
1553
+
1554
+ # Only for user text, the transformer output and sequence lengths also have to
1555
+ # be created (here using fake features) to enable entity recognition training
1556
+ # and prediction.
1557
+ if attribute == TEXT:
1558
+ # we just need to get the correct last dimension size from the prepared
1559
+ # transformer
1560
+ text_units = self._tf_layers[f"sequence_layer.{attribute}"].output_units
1561
+ text_output = tf.zeros((0, 0, text_units), dtype=tf.float32)
1562
+ text_sequence_lengths = tf.zeros((0,), dtype=tf.int32)
1563
+ else:
1564
+ # simulate None with empty tensor of zeros
1565
+ text_output = tf.zeros((0,))
1566
+ text_sequence_lengths = tf.zeros((0,))
1567
+
1568
+ return attribute_features, text_output, text_sequence_lengths
1569
+
1570
+ @staticmethod
1571
+ def _create_last_dialogue_turns_mask(
1572
+ tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]], attribute: Text
1573
+ ) -> tf.Tensor:
1574
+ # Since max_history_featurizer_is_used is True,
1575
+ # we need to find the locations of last dialogue turns in
1576
+ # (combined batch dimension and dialogue length,) dimension,
1577
+ # so that we can use `_sequence_lengths` as a boolean mask to pick
1578
+ # which ones are "real" textual input in these last dialogue turns.
1579
+
1580
+ # In order to do that we can use given `dialogue_lengths`.
1581
+ # For example:
1582
+ # If we have `dialogue_lengths = [2, 1, 3]`, then
1583
+ # `dialogue_indices = [0, 1, 0, 0, 1, 2]` here we can spot that `0`
1584
+ # always indicates the first dialogue turn,
1585
+ # which means that previous dialogue turn is the last dialogue turn.
1586
+ # Combining this with the fact that the last element in
1587
+ # `dialogue_indices` is always the last dialogue turn, we can add
1588
+ # a `0` to the end, getting
1589
+ # `_dialogue_indices = [0, 1, 0, 0, 1, 2, 0]`.
1590
+ # Then removing the first element
1591
+ # `_last_dialogue_turn_inverse_indicator = [1, 0, 0, 1, 2, 0]`
1592
+ # we see that `0` points to the last dialogue turn.
1593
+ # We convert all positive numbers to `True` and take
1594
+ # the inverse mask to get
1595
+ # `last_dialogue_mask = [0, 1, 1, 0, 0, 1],
1596
+ # which precisely corresponds to the fact that first dialogue is of
1597
+ # length 2, the second 1 and the third 3.
1598
+ last_dialogue_turn_mask = tf.math.logical_not(
1599
+ tf.cast(
1600
+ tf.concat(
1601
+ [
1602
+ tf_batch_data[DIALOGUE][INDICES][0],
1603
+ tf.zeros((1,), dtype=tf.int32),
1604
+ ],
1605
+ axis=0,
1606
+ )[1:],
1607
+ dtype=tf.bool,
1608
+ )
1609
+ )
1610
+ # get only the indices of real inputs
1611
+ return tf.boolean_mask(
1612
+ last_dialogue_turn_mask,
1613
+ tf.reshape(tf_batch_data[attribute][SEQUENCE_LENGTH][0], (-1,)),
1614
+ )
1615
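A minimal sketch reproducing the shift-by-one trick described in the comments above for `dialogue_lengths = [2, 1, 3]` (illustrative only):

import tensorflow as tf

dialogue_indices = tf.constant([0, 1, 0, 0, 1, 2], dtype=tf.int32)
shifted = tf.concat([dialogue_indices, tf.zeros((1,), dtype=tf.int32)], axis=0)[1:]
last_dialogue_turn_mask = tf.math.logical_not(tf.cast(shifted, dtype=tf.bool))
# last_dialogue_turn_mask == [False, True, True, False, False, True],
# i.e. turns 1, 2 and 5 are the last turns of their dialogues.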
+
1616
+ def _encode_real_features_per_attribute(
1617
+ self, tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]], attribute: Text
1618
+ ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
1619
+ """Encodes features for a given attribute.
1620
+
1621
+ Args:
1622
+ tf_batch_data: Maps each attribute to its features and masks.
1623
+ attribute: the attribute we will encode features for
1624
+ (e.g., ACTION_NAME, INTENT)
1625
+
1626
+ Returns:
1627
+ attribute_features: A tensor of shape `(batch_size, dialogue_length, units)`
1628
+ with all features for `attribute` processed and combined. If sequence-
1629
+ level features are present, the sequence dimension is eliminated using
1630
+ a transformer.
1631
+ text_output: Only for `TEXT` attribute (otherwise an empty tensor): A tensor
1632
+ of shape `(combined batch_size & dialogue_length, max seq length,
1633
+ units)` containing token-level embeddings further used for entity
1634
+ extraction from user text. Similar to `attribute_features` but returned
1635
+ for all tokens, not just for the last one.
1636
+ text_sequence_lengths: Only for `TEXT` attribute, otherwise an empty tensor:
1637
+ Shape `(combined batch_size & dialogue_length, 1)`, containing the
1638
+ sequence length for user text examples in `text_output`. The sequence
1639
+ length is effectively the number of tokens + 1 (to account also for
1640
+ sentence-level features). Needed for entity extraction from user text.
1641
+ """
1642
+ # simulate None with empty tensor of zeros
1643
+ text_output = tf.zeros((0,))
1644
+ text_sequence_lengths = tf.zeros((0,))
1645
+
1646
+ if attribute in SEQUENCE_FEATURES_TO_ENCODE:
1647
+ # get lengths of real token sequences as a 3D tensor
1648
+ sequence_feature_lengths = self._get_sequence_feature_lengths(
1649
+ tf_batch_data, attribute
1650
+ )
1651
+
1652
+ # sequence_feature_lengths contain `0` for "fake" features, while
1653
+ # tf_batch_data[attribute] contains only "real" features. Hence, we need to
1654
+ # get rid of the lengths that are 0. This step produces a 1D tensor.
1655
+ sequence_feature_lengths = tf.boolean_mask(
1656
+ sequence_feature_lengths, sequence_feature_lengths
1657
+ )
1658
+
1659
+ attribute_features, _, _, _, _, _ = self._tf_layers[
1660
+ f"sequence_layer.{attribute}"
1661
+ ](
1662
+ (
1663
+ tf_batch_data[attribute][SEQUENCE],
1664
+ tf_batch_data[attribute][SENTENCE],
1665
+ sequence_feature_lengths,
1666
+ ),
1667
+ training=self._training,
1668
+ )
1669
+
1670
+ combined_sentence_sequence_feature_lengths = sequence_feature_lengths + 1
1671
+
1672
+ # Only for user text, the transformer output and sequence lengths also have
1673
+ # to be returned to enable entity recognition training and prediction.
1674
+ if attribute == TEXT:
1675
+ text_output = attribute_features
1676
+ text_sequence_lengths = combined_sentence_sequence_feature_lengths
1677
+
1678
+ if self.max_history_featurizer_is_used:
1679
+ # get the location of all last dialogue inputs
1680
+ last_dialogue_turns_mask = self._create_last_dialogue_turns_mask(
1681
+ tf_batch_data, attribute
1682
+ )
1683
+ # pick outputs that correspond to the last dialogue turns
1684
+ text_output = tf.boolean_mask(text_output, last_dialogue_turns_mask)
1685
+ text_sequence_lengths = tf.boolean_mask(
1686
+ text_sequence_lengths, last_dialogue_turns_mask
1687
+ )
1688
+
1689
+ # resulting attribute features will have shape
1690
+ # combined batch dimension and dialogue length x 1 x units
1691
+ attribute_features = tf.expand_dims(
1692
+ self._last_token(
1693
+ attribute_features, combined_sentence_sequence_feature_lengths
1694
+ ),
1695
+ axis=1,
1696
+ )
1697
+
1698
+ # for attributes without sequence-level features, all we need is to combine the
1699
+ # sparse and dense sentence-level features into one
1700
+ else:
1701
+ # resulting attribute features will have shape
1702
+ # combined batch dimension and dialogue length x 1 x units
1703
+ attribute_features = self._tf_layers[
1704
+ f"sparse_dense_concat_layer.{attribute}"
1705
+ ]((tf_batch_data[attribute][SENTENCE],), training=self._training)
1706
+
1707
+ if attribute in SENTENCE_FEATURES_TO_ENCODE + LABEL_FEATURES_TO_ENCODE:
1708
+ attribute_features = self._tf_layers[f"encoding_layer.{attribute}"](
1709
+ attribute_features, self._training
1710
+ )
1711
+
1712
+ # attribute features have shape
1713
+ # (combined batch dimension and dialogue length x 1 x units)
1714
+ # convert them back to their original shape of
1715
+ # batch size x dialogue length x units
1716
+ attribute_features = self._convert_to_original_shape(
1717
+ attribute_features, tf_batch_data, attribute
1718
+ )
1719
+
1720
+ return attribute_features, text_output, text_sequence_lengths
1721
+
+     @staticmethod
+     def _convert_to_original_shape(
+         attribute_features: tf.Tensor,
+         tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
+         attribute: Text,
+     ) -> tf.Tensor:
+         """Transform attribute features back to original shape.
+
+         Given shape: (combined batch and dialogue dimension x 1 x units)
+         Original shape: (batch x dialogue length x units)
+
+         Args:
+             attribute_features: the "real" features to convert
+             tf_batch_data: dictionary mapping every attribute to its features and masks
+             attribute: the attribute we will encode features for
+                 (e.g., ACTION_NAME, INTENT)
+
+         Returns:
+             The converted attribute features
+         """
+         # in order to convert the attribute features with shape
+         # (combined batch-size and dialogue length x 1 x units)
+         # to a shape of (batch-size x dialogue length x units)
+         # we use tf.scatter_nd. Therefore, we need the target shape and the indices
+         # mapping the values of attribute features to the position in the resulting
+         # tensor.
+
+         # attribute_mask has shape batch x dialogue_len x 1
+         attribute_mask = tf_batch_data[attribute][MASK][0]
+
+         if attribute in SENTENCE_FEATURES_TO_ENCODE + STATE_LEVEL_FEATURES:
+             dialogue_lengths = tf.cast(
+                 tf_batch_data[DIALOGUE][LENGTH][0], dtype=tf.int32
+             )
+             dialogue_indices = tf_batch_data[DIALOGUE][INDICES][0]
+         else:
+             # for labels, dialogue length is a fake dim and equal to 1
+             dialogue_lengths = tf.ones((tf.shape(attribute_mask)[0],), dtype=tf.int32)
+             dialogue_indices = tf.zeros((tf.shape(attribute_mask)[0],), dtype=tf.int32)
+
+         batch_dim = tf.shape(attribute_mask)[0]
+         dialogue_dim = tf.shape(attribute_mask)[1]
+         units = attribute_features.shape[-1]
+
+         # attribute_mask has shape (batch x dialogue_len x 1), remove last dimension
+         attribute_mask = tf.cast(tf.squeeze(attribute_mask, axis=-1), dtype=tf.int32)
+         # sum of attribute mask contains number of dialogue turns with "real" features
+         non_fake_dialogue_lengths = tf.reduce_sum(attribute_mask, axis=-1)
+         # create the batch indices
+         batch_indices = tf.repeat(tf.range(batch_dim), non_fake_dialogue_lengths)
+
+         # attribute_mask has shape (batch x dialogue_len x 1), while
+         # dialogue_indices has shape (combined_dialogue_len,)
+         # in order to find positions of real input we need to flatten
+         # attribute mask to (combined_dialogue_len,)
+         dialogue_indices_mask = tf.boolean_mask(
+             attribute_mask, tf.sequence_mask(dialogue_lengths, dtype=tf.int32)
+         )
+         # pick only those indices that contain "real" input
+         dialogue_indices = tf.boolean_mask(dialogue_indices, dialogue_indices_mask)
+
+         indices = tf.stack([batch_indices, dialogue_indices], axis=1)
+
+         shape = tf.convert_to_tensor([batch_dim, dialogue_dim, units])
+         attribute_features = tf.squeeze(attribute_features, axis=1)
+
+         return tf.scatter_nd(indices, attribute_features, shape)
+
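`_convert_to_original_shape` relies on `tf.scatter_nd` to re-pad the flat list of "real" feature vectors back into a `(batch x dialogue length x units)` tensor. A self-contained toy example with hypothetical numbers, independent of the batch data structures used above:

```python
import tensorflow as tf

# Three "real" feature vectors and the (batch index, turn index) slot each belongs to.
real_features = tf.constant([[1.0, 1.0],   # dialogue 0, turn 0
                             [2.0, 2.0],   # dialogue 0, turn 2
                             [3.0, 3.0]])  # dialogue 1, turn 1
indices = tf.constant([[0, 0], [0, 2], [1, 1]])
shape = tf.constant([2, 3, 2])  # batch x dialogue length x units

# scatter_nd writes every row to its slot and fills all other positions with zeros.
restored = tf.scatter_nd(indices, real_features, shape)
print(restored.numpy())
# [[[1. 1.] [0. 0.] [2. 2.]]
#  [[0. 0.] [3. 3.] [0. 0.]]]
```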
+     def _process_batch_data(
+         self, tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]]
+     ) -> Tuple[tf.Tensor, Optional[tf.Tensor], Optional[tf.Tensor]]:
+         """Encodes batch data.
+
+         Combines intent and text and action name and action text if both are present.
+
+         Args:
+             tf_batch_data: dictionary mapping every attribute to its features and masks
+
+         Returns:
+             Encoding of all features in the batch combined, together with the text
+             transformer output and text sequence lengths needed for entity
+             recognition (both `None` if no "real" text features are present).
+         """
+         # encode each attribute present in tf_batch_data
+         text_output = None
+         text_sequence_lengths = None
+         batch_encoded = {}
+         for attribute in tf_batch_data.keys():
+             if attribute in SENTENCE_FEATURES_TO_ENCODE + STATE_LEVEL_FEATURES:
+                 (
+                     attribute_features,
+                     _text_output,
+                     _text_sequence_lengths,
+                 ) = self._encode_features_per_attribute(tf_batch_data, attribute)
+
+                 batch_encoded[attribute] = attribute_features
+                 if attribute == TEXT:
+                     text_output = _text_output
+                     text_sequence_lengths = _text_sequence_lengths
+
+         # if both action text and action name are present, combine them; otherwise,
+         # use the one which is present
+
+         if (
+             batch_encoded.get(ACTION_TEXT) is not None
+             and batch_encoded.get(ACTION_NAME) is not None
+         ):
+             batch_action = batch_encoded.pop(ACTION_TEXT) + batch_encoded.pop(
+                 ACTION_NAME
+             )
+         elif batch_encoded.get(ACTION_TEXT) is not None:
+             batch_action = batch_encoded.pop(ACTION_TEXT)
+         else:
+             batch_action = batch_encoded.pop(ACTION_NAME)
+         # same for user input
+         if (
+             batch_encoded.get(INTENT) is not None
+             and batch_encoded.get(TEXT) is not None
+         ):
+             batch_user = batch_encoded.pop(INTENT) + batch_encoded.pop(TEXT)
+         elif batch_encoded.get(TEXT) is not None:
+             batch_user = batch_encoded.pop(TEXT)
+         else:
+             batch_user = batch_encoded.pop(INTENT)
+
+         batch_features = [batch_user, batch_action]
+         # once we have user input and previous action,
+         # add all other attributes (SLOTS, ACTIVE_LOOP, etc.) to batch_features
+         for key in batch_encoded.keys():
+             batch_features.append(batch_encoded.get(key))
+
+         batch_features = tf.concat(batch_features, axis=-1)
+
+         return batch_features, text_output, text_sequence_lengths
+
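The combination step in `_process_batch_data` sums the user-side encodings (intent + text) and the action-side encodings (action name + action text), then concatenates everything else along the feature axis. A simplified sketch with hypothetical attribute names and shapes; the method above additionally handles the case where only one of each paired attribute is present:

```python
import tensorflow as tf

# Hypothetical per-attribute encodings, all shaped (batch x dialogue length x units).
batch_encoded = {
    "text": tf.fill((2, 3, 4), 1.0),
    "intent": tf.fill((2, 3, 4), 2.0),
    "action_name": tf.fill((2, 3, 4), 3.0),
    "slots": tf.fill((2, 3, 4), 4.0),
}

# Paired encodings of the same turn are summed; the rest is concatenated.
batch_user = batch_encoded.pop("intent") + batch_encoded.pop("text")
batch_action = batch_encoded.pop("action_name")
batch_features = tf.concat(
    [batch_user, batch_action] + list(batch_encoded.values()), axis=-1
)
print(batch_features.shape)  # (2, 3, 12)
```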
+     def _reshape_for_entities(
+         self,
+         tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
+         dialogue_transformer_output: tf.Tensor,
+         text_output: tf.Tensor,
+         text_sequence_lengths: tf.Tensor,
+     ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+         # The first dim of the output of the text sequence transformer is the same
+         # as the number of "real" features for `text` at the last dialogue turns
+         # (let's call it `N`),
+         # which corresponds to the first dim of the tag ids tensor.
+         # To calculate the loss for entities we need the output of the text
+         # sequence transformer (shape: N x sequence length x units),
+         # the output of the dialogue transformer
+         # (shape: batch size x dialogue length x units) and the tag ids for the
+         # entities (shape: N x sequence length - 1 x units)
+         # In order to process the tensors, they need to have the same shape.
+         # Convert the output of the dialogue transformer to shape
+         # (N x 1 x units).
+
+         # Note: The CRF layer cannot handle 4D tensors. E.g. we cannot use the shape
+         # batch size x dialogue length x sequence length x units
+
+         # convert the output of the dialogue transformer
+         # to shape (real entity dim x 1 x units)
+         attribute_mask = tf_batch_data[TEXT][MASK][0]
+         dialogue_lengths = tf.cast(tf_batch_data[DIALOGUE][LENGTH][0], tf.int32)
+
+         if self.max_history_featurizer_is_used:
+             # pick outputs that correspond to the last dialogue turns
+             attribute_mask = tf.expand_dims(
+                 self._last_token(attribute_mask, dialogue_lengths), axis=1
+             )
+         dialogue_transformer_output = tf.boolean_mask(
+             dialogue_transformer_output, tf.squeeze(attribute_mask, axis=-1)
+         )
+
+         # boolean mask removed axis=1, add it back
+         dialogue_transformer_output = tf.expand_dims(
+             dialogue_transformer_output, axis=1
+         )
+
+         # broadcast the dialogue transformer output sequence-length-times to get the
+         # same shape as the text sequence transformer output
+         dialogue_transformer_output = tf.tile(
+             dialogue_transformer_output, (1, tf.shape(text_output)[1], 1)
+         )
+
+         # concat the output of the dialogue transformer to the output of the text
+         # sequence transformer (adding context)
+         # resulting shape (N x sequence length x 2 units)
+         # N = number of "real" features for `text` at the last dialogue turns
+         text_transformed = tf.concat(
+             [text_output, dialogue_transformer_output], axis=-1
+         )
+         text_mask = rasa_layers.compute_mask(text_sequence_lengths)
+
+         # add zeros to match the shape of text_transformed, because
+         # max sequence length might differ, since it is calculated dynamically
+         # based on a subset of sequence lengths
+         sequence_diff = tf.shape(text_transformed)[1] - tf.shape(text_mask)[1]
+         text_mask = tf.pad(text_mask, [[0, 0], [0, sequence_diff], [0, 0]])
+
+         # remove additional dims and sentence features
+         text_sequence_lengths = tf.reshape(text_sequence_lengths, (-1,)) - 1
+
+         return text_transformed, text_mask, text_sequence_lengths
+
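The core of `_reshape_for_entities` is broadcasting one dialogue-context vector per text input across all of its token positions and concatenating it with the token-level transformer output. A standalone sketch with hypothetical shapes:

```python
import tensorflow as tf

text_output = tf.random.normal((3, 5, 4))        # N x sequence length x units
dialogue_context = tf.random.normal((3, 1, 4))   # N x 1 x units (one vector per input)

# Repeat the context vector once per token position, then attach it to every token.
dialogue_context = tf.tile(dialogue_context, (1, tf.shape(text_output)[1], 1))
text_transformed = tf.concat([text_output, dialogue_context], axis=-1)
print(text_transformed.shape)  # (3, 5, 8)
```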
+     # ---TRAINING---
+
+     def _batch_loss_entities(
+         self,
+         tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
+         dialogue_transformer_output: tf.Tensor,
+         text_output: tf.Tensor,
+         text_sequence_lengths: tf.Tensor,
+     ) -> tf.Tensor:
+         # It could happen that some batches don't contain "real" features for `text`,
+         # e.g. when a large number of stories are intent-only.
+         # In that case the actual `text_output` will be empty and we cannot
+         # create a loss from empty tensors.
+         # Since we need actual numbers to create a full loss, we output
+         # zero in this case.
+         return tf.cond(
+             tf.shape(text_output)[0] > 0,
+             lambda: self._real_batch_loss_entities(
+                 tf_batch_data,
+                 dialogue_transformer_output,
+                 text_output,
+                 text_sequence_lengths,
+             ),
+             lambda: tf.constant(0.0),
+         )
+
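The `tf.cond` guard above is the standard way to handle possibly-empty tensors inside a compiled graph: both branches must return tensors of matching dtype and shape, so the empty case returns a scalar zero instead of attempting a loss over zero rows. A minimal sketch of the same pattern (the quadratic "loss" is a placeholder, not the model's actual entity loss):

```python
import tensorflow as tf

def loss_or_zero(text_output: tf.Tensor) -> tf.Tensor:
    return tf.cond(
        tf.shape(text_output)[0] > 0,
        lambda: tf.reduce_mean(tf.square(text_output)),  # placeholder loss
        lambda: tf.constant(0.0),
    )

print(loss_or_zero(tf.ones((4, 3))).numpy())   # 1.0
print(loss_or_zero(tf.zeros((0, 3))).numpy())  # 0.0, instead of NaN from an empty mean
```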
+     def _real_batch_loss_entities(
+         self,
+         tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
+         dialogue_transformer_output: tf.Tensor,
+         text_output: tf.Tensor,
+         text_sequence_lengths: tf.Tensor,
+     ) -> tf.Tensor:
+         text_transformed, text_mask, text_sequence_lengths = self._reshape_for_entities(
+             tf_batch_data,
+             dialogue_transformer_output,
+             text_output,
+             text_sequence_lengths,
+         )
+
+         tag_ids = tf_batch_data[ENTITY_TAGS][IDS][0]
+         # add a zero (no entity) for the sentence features to match the shape of inputs
+         sequence_diff = tf.shape(text_transformed)[1] - tf.shape(tag_ids)[1]
+         tag_ids = tf.pad(tag_ids, [[0, 0], [0, sequence_diff], [0, 0]])
+
+         loss, f1, _ = self._calculate_entity_loss(
+             text_transformed,
+             tag_ids,
+             text_mask,
+             text_sequence_lengths,
+             ENTITY_ATTRIBUTE_TYPE,
+         )
+
+         self.entity_loss.update_state(loss)
+         self.entity_f1.update_state(f1)
+
+         return loss
+
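The `tf.pad` call above appends extra positions with tag id `0` (no entity) so that the tag tensor lines up with `text_transformed`, which also covers the sentence-level feature slot. A sketch with hypothetical tag ids:

```python
import tensorflow as tf

# Two text inputs, three tokens each; the trailing dimension of size 1 matches the ids layout.
tag_ids = tf.constant([[[1], [0], [2]],
                       [[0], [0], [1]]])

sequence_diff = 1  # assume text_transformed is one step longer than the tag ids
padded = tf.pad(tag_ids, [[0, 0], [0, sequence_diff], [0, 0]])
print(padded.shape)             # (2, 4, 1)
print(padded[0, :, 0].numpy())  # [1 0 2 0] -- a zero "no entity" tag was appended
```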
+     @staticmethod
+     def _get_labels_embed(
+         label_ids: tf.Tensor, all_labels_embed: tf.Tensor
+     ) -> tf.Tensor:
+         # instead of processing labels again, gather embeddings from
+         # all_labels_embed using label ids
+
+         indices = tf.cast(label_ids[:, :, 0], tf.int32)
+         labels_embed = tf.gather(all_labels_embed, indices)
+
+         return labels_embed
+
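`_get_labels_embed` avoids re-encoding labels by treating `all_labels_embed` as a lookup table and gathering rows by label id. A toy example with hypothetical embeddings:

```python
import tensorflow as tf

# Four candidate labels embedded into 3-dimensional vectors (hypothetical values).
all_labels_embed = tf.constant([[0.0, 0.0, 1.0],
                                [0.0, 1.0, 0.0],
                                [1.0, 0.0, 0.0],
                                [1.0, 1.0, 0.0]])
# Label ids carry an extra trailing feature dimension, as in the batch data above.
label_ids = tf.constant([[[2.0]], [[0.0]]])

indices = tf.cast(label_ids[:, :, 0], tf.int32)       # shape (2, 1)
labels_embed = tf.gather(all_labels_embed, indices)   # shape (2, 1, 3)
print(labels_embed[0, 0].numpy())  # [1. 0. 0.] -- the embedding of label id 2
```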
+     def batch_loss(
+         self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
+     ) -> tf.Tensor:
+         """Calculates the loss for the given batch.
+
+         Args:
+             batch_in: The batch.
+
+         Returns:
+             The loss of the given batch.
+         """
+         tf_batch_data = self.batch_to_model_data_format(batch_in, self.data_signature)
+         self._compute_dialogue_indices(tf_batch_data)
+
+         all_label_ids, all_labels_embed = self._create_all_labels_embed()
+
+         label_ids = tf_batch_data[LABEL_KEY][LABEL_SUB_KEY][0]
+         labels_embed = self._get_labels_embed(label_ids, all_labels_embed)
+
+         dialogue_in, text_output, text_sequence_lengths = self._process_batch_data(
+             tf_batch_data
+         )
+         (
+             dialogue_embed,
+             dialogue_mask,
+             dialogue_transformer_output,
+             _,
+         ) = self._embed_dialogue(dialogue_in, tf_batch_data)
+         dialogue_mask = tf.squeeze(dialogue_mask, axis=-1)
+
+         losses = []
+
+         loss, acc = self._tf_layers[f"loss.{LABEL}"](
+             dialogue_embed,
+             labels_embed,
+             label_ids,
+             all_labels_embed,
+             all_label_ids,
+             dialogue_mask,
+         )
+         losses.append(loss)
+
+         if (
+             self.config[ENTITY_RECOGNITION]
+             and text_output is not None
+             and text_sequence_lengths is not None
+         ):
+             losses.append(
+                 self._batch_loss_entities(
+                     tf_batch_data,
+                     dialogue_transformer_output,
+                     text_output,
+                     text_sequence_lengths,
+                 )
+             )
+
+         self.action_loss.update_state(loss)
+         self.action_acc.update_state(acc)
+
+         return tf.math.add_n(losses)
+
+     # ---PREDICTION---
+     def prepare_for_predict(self) -> None:
+         """Prepares the model for prediction."""
+         _, self.all_labels_embed = self._create_all_labels_embed()
+
+     def batch_predict(
+         self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
+     ) -> Dict[Text, Union[tf.Tensor, Dict[Text, tf.Tensor]]]:
+         """Predicts the output of the given batch.
+
+         Args:
+             batch_in: The batch.
+
+         Returns:
+             The output to predict.
+         """
+         if self.all_labels_embed is None:
+             raise ValueError(
+                 "The model was not prepared for prediction. "
+                 "Call `prepare_for_predict` first."
+             )
+
+         tf_batch_data = self.batch_to_model_data_format(
+             batch_in, self.predict_data_signature
+         )
+         self._compute_dialogue_indices(tf_batch_data)
+
+         dialogue_in, text_output, text_sequence_lengths = self._process_batch_data(
+             tf_batch_data
+         )
+         (
+             dialogue_embed,
+             dialogue_mask,
+             dialogue_transformer_output,
+             attention_weights,
+         ) = self._embed_dialogue(dialogue_in, tf_batch_data)
+         dialogue_mask = tf.squeeze(dialogue_mask, axis=-1)
+
+         sim_all, scores = self._tf_layers[
+             f"loss.{LABEL}"
+         ].get_similarities_and_confidences_from_embeddings(
+             dialogue_embed[:, :, tf.newaxis, :],
+             self.all_labels_embed[tf.newaxis, tf.newaxis, :, :],
+             dialogue_mask,
+         )
+
+         predictions = {
+             "scores": scores,
+             "similarities": sim_all,
+             DIAGNOSTIC_DATA: {"attention_weights": attention_weights},
+         }
+
+         if (
+             self.config[ENTITY_RECOGNITION]
+             and text_output is not None
+             and text_sequence_lengths is not None
+         ):
+             pred_ids, confidences = self._batch_predict_entities(
+                 tf_batch_data,
+                 dialogue_transformer_output,
+                 text_output,
+                 text_sequence_lengths,
+             )
+             name = ENTITY_ATTRIBUTE_TYPE
+             predictions[f"e_{name}_ids"] = pred_ids
+             predictions[f"e_{name}_scores"] = confidences
+
+         return predictions
+
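Conceptually, the similarity/confidence step in `batch_predict` compares every dialogue-turn embedding against every candidate label embedding. The sketch below is a simplified dot-product stand-in for `get_similarities_and_confidences_from_embeddings` (the real loss layer supports further similarity types, masking, and configurable confidence normalisation), using hypothetical shapes:

```python
import tensorflow as tf

dialogue_embed = tf.random.normal((2, 3, 4))   # batch x dialogue length x units
all_labels_embed = tf.random.normal((5, 4))    # number of labels x units

# Broadcast to (batch x dialogue length x labels x units), dot along the unit axis,
# then normalise over the label axis to obtain confidences.
sim_all = tf.reduce_sum(
    dialogue_embed[:, :, tf.newaxis, :] * all_labels_embed[tf.newaxis, tf.newaxis, :, :],
    axis=-1,
)
scores = tf.nn.softmax(sim_all, axis=-1)
print(sim_all.shape, scores.shape)  # (2, 3, 5) (2, 3, 5)
```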
+     def _batch_predict_entities(
+         self,
+         tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
+         dialogue_transformer_output: tf.Tensor,
+         text_output: tf.Tensor,
+         text_sequence_lengths: tf.Tensor,
+     ) -> Tuple[tf.Tensor, tf.Tensor]:
+         # It could happen that the current prediction turn doesn't contain
+         # "real" features for `text`, in which case the actual `text_output`
+         # will be empty and we cannot predict entities from empty tensors.
+         # Since we need to output some tensors of the same shape, we output
+         # zero tensors.
+         return tf.cond(
+             tf.shape(text_output)[0] > 0,
+             lambda: self._real_batch_predict_entities(
+                 tf_batch_data,
+                 dialogue_transformer_output,
+                 text_output,
+                 text_sequence_lengths,
+             ),
+             lambda: (
+                 # the output is of shape (batch_size, max_seq_len)
+                 tf.zeros(tf.shape(text_output)[:2], dtype=tf.int32),
+                 tf.zeros(tf.shape(text_output)[:2], dtype=tf.float32),
+             ),
+         )
+
+     def _real_batch_predict_entities(
+         self,
+         tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
+         dialogue_transformer_output: tf.Tensor,
+         text_output: tf.Tensor,
+         text_sequence_lengths: tf.Tensor,
+     ) -> Tuple[tf.Tensor, tf.Tensor]:
+         text_transformed, _, text_sequence_lengths = self._reshape_for_entities(
+             tf_batch_data,
+             dialogue_transformer_output,
+             text_output,
+             text_sequence_lengths,
+         )
+
+         name = ENTITY_ATTRIBUTE_TYPE
+
+         _logits = self._tf_layers[f"embed.{name}.logits"](text_transformed)
+
+         return self._tf_layers[f"crf.{name}"](_logits, text_sequence_lengths)