rasa-pro 3.12.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic; see the package's registry page for more details.

Files changed (790)
  1. README.md +41 -0
  2. rasa/__init__.py +9 -0
  3. rasa/__main__.py +177 -0
  4. rasa/anonymization/__init__.py +2 -0
  5. rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
  6. rasa/anonymization/anonymization_pipeline.py +286 -0
  7. rasa/anonymization/anonymization_rule_executor.py +260 -0
  8. rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
  9. rasa/anonymization/schemas/config.yml +47 -0
  10. rasa/anonymization/utils.py +118 -0
  11. rasa/api.py +160 -0
  12. rasa/cli/__init__.py +5 -0
  13. rasa/cli/arguments/__init__.py +0 -0
  14. rasa/cli/arguments/data.py +106 -0
  15. rasa/cli/arguments/default_arguments.py +207 -0
  16. rasa/cli/arguments/evaluate.py +65 -0
  17. rasa/cli/arguments/export.py +51 -0
  18. rasa/cli/arguments/interactive.py +74 -0
  19. rasa/cli/arguments/run.py +219 -0
  20. rasa/cli/arguments/shell.py +17 -0
  21. rasa/cli/arguments/test.py +211 -0
  22. rasa/cli/arguments/train.py +279 -0
  23. rasa/cli/arguments/visualize.py +34 -0
  24. rasa/cli/arguments/x.py +30 -0
  25. rasa/cli/data.py +354 -0
  26. rasa/cli/dialogue_understanding_test.py +251 -0
  27. rasa/cli/e2e_test.py +259 -0
  28. rasa/cli/evaluate.py +222 -0
  29. rasa/cli/export.py +250 -0
  30. rasa/cli/inspect.py +75 -0
  31. rasa/cli/interactive.py +166 -0
  32. rasa/cli/license.py +65 -0
  33. rasa/cli/llm_fine_tuning.py +403 -0
  34. rasa/cli/markers.py +78 -0
  35. rasa/cli/project_templates/__init__.py +0 -0
  36. rasa/cli/project_templates/calm/actions/__init__.py +0 -0
  37. rasa/cli/project_templates/calm/actions/action_template.py +27 -0
  38. rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
  39. rasa/cli/project_templates/calm/actions/db.py +57 -0
  40. rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
  41. rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
  42. rasa/cli/project_templates/calm/config.yml +10 -0
  43. rasa/cli/project_templates/calm/credentials.yml +33 -0
  44. rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
  45. rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
  46. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
  47. rasa/cli/project_templates/calm/db/contacts.json +10 -0
  48. rasa/cli/project_templates/calm/domain/add_contact.yml +39 -0
  49. rasa/cli/project_templates/calm/domain/list_contacts.yml +17 -0
  50. rasa/cli/project_templates/calm/domain/remove_contact.yml +38 -0
  51. rasa/cli/project_templates/calm/domain/shared.yml +10 -0
  52. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
  53. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
  54. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
  55. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
  56. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
  57. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
  58. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
  59. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
  60. rasa/cli/project_templates/calm/endpoints.yml +58 -0
  61. rasa/cli/project_templates/default/actions/__init__.py +0 -0
  62. rasa/cli/project_templates/default/actions/actions.py +27 -0
  63. rasa/cli/project_templates/default/config.yml +44 -0
  64. rasa/cli/project_templates/default/credentials.yml +33 -0
  65. rasa/cli/project_templates/default/data/nlu.yml +91 -0
  66. rasa/cli/project_templates/default/data/rules.yml +13 -0
  67. rasa/cli/project_templates/default/data/stories.yml +30 -0
  68. rasa/cli/project_templates/default/domain.yml +34 -0
  69. rasa/cli/project_templates/default/endpoints.yml +42 -0
  70. rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
  71. rasa/cli/project_templates/tutorial/actions/__init__.py +0 -0
  72. rasa/cli/project_templates/tutorial/actions/actions.py +22 -0
  73. rasa/cli/project_templates/tutorial/config.yml +12 -0
  74. rasa/cli/project_templates/tutorial/credentials.yml +33 -0
  75. rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
  76. rasa/cli/project_templates/tutorial/data/patterns.yml +11 -0
  77. rasa/cli/project_templates/tutorial/domain.yml +35 -0
  78. rasa/cli/project_templates/tutorial/endpoints.yml +55 -0
  79. rasa/cli/run.py +143 -0
  80. rasa/cli/scaffold.py +273 -0
  81. rasa/cli/shell.py +141 -0
  82. rasa/cli/studio/__init__.py +0 -0
  83. rasa/cli/studio/download.py +62 -0
  84. rasa/cli/studio/studio.py +296 -0
  85. rasa/cli/studio/train.py +59 -0
  86. rasa/cli/studio/upload.py +62 -0
  87. rasa/cli/telemetry.py +102 -0
  88. rasa/cli/test.py +280 -0
  89. rasa/cli/train.py +278 -0
  90. rasa/cli/utils.py +484 -0
  91. rasa/cli/visualize.py +40 -0
  92. rasa/cli/x.py +206 -0
  93. rasa/constants.py +45 -0
  94. rasa/core/__init__.py +17 -0
  95. rasa/core/actions/__init__.py +0 -0
  96. rasa/core/actions/action.py +1318 -0
  97. rasa/core/actions/action_clean_stack.py +59 -0
  98. rasa/core/actions/action_exceptions.py +24 -0
  99. rasa/core/actions/action_hangup.py +29 -0
  100. rasa/core/actions/action_repeat_bot_messages.py +89 -0
  101. rasa/core/actions/action_run_slot_rejections.py +210 -0
  102. rasa/core/actions/action_trigger_chitchat.py +31 -0
  103. rasa/core/actions/action_trigger_flow.py +109 -0
  104. rasa/core/actions/action_trigger_search.py +31 -0
  105. rasa/core/actions/constants.py +5 -0
  106. rasa/core/actions/custom_action_executor.py +191 -0
  107. rasa/core/actions/direct_custom_actions_executor.py +109 -0
  108. rasa/core/actions/e2e_stub_custom_action_executor.py +72 -0
  109. rasa/core/actions/forms.py +741 -0
  110. rasa/core/actions/grpc_custom_action_executor.py +251 -0
  111. rasa/core/actions/http_custom_action_executor.py +145 -0
  112. rasa/core/actions/loops.py +114 -0
  113. rasa/core/actions/two_stage_fallback.py +186 -0
  114. rasa/core/agent.py +559 -0
  115. rasa/core/auth_retry_tracker_store.py +122 -0
  116. rasa/core/brokers/__init__.py +0 -0
  117. rasa/core/brokers/broker.py +126 -0
  118. rasa/core/brokers/file.py +58 -0
  119. rasa/core/brokers/kafka.py +324 -0
  120. rasa/core/brokers/pika.py +388 -0
  121. rasa/core/brokers/sql.py +86 -0
  122. rasa/core/channels/__init__.py +61 -0
  123. rasa/core/channels/botframework.py +338 -0
  124. rasa/core/channels/callback.py +84 -0
  125. rasa/core/channels/channel.py +456 -0
  126. rasa/core/channels/console.py +241 -0
  127. rasa/core/channels/development_inspector.py +197 -0
  128. rasa/core/channels/facebook.py +419 -0
  129. rasa/core/channels/hangouts.py +329 -0
  130. rasa/core/channels/inspector/.eslintrc.cjs +25 -0
  131. rasa/core/channels/inspector/.gitignore +23 -0
  132. rasa/core/channels/inspector/README.md +54 -0
  133. rasa/core/channels/inspector/assets/favicon.ico +0 -0
  134. rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
  135. rasa/core/channels/inspector/custom.d.ts +3 -0
  136. rasa/core/channels/inspector/dist/assets/arc-861ddd57.js +1 -0
  137. rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
  138. rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-921f02db.js +10 -0
  139. rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-b436c4f8.js +2 -0
  140. rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-511a23cb.js +2 -0
  141. rasa/core/channels/inspector/dist/assets/createText-62fc7601-ef476ecd.js +7 -0
  142. rasa/core/channels/inspector/dist/assets/edges-f2ad444c-f1878e0a.js +4 -0
  143. rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-fac75185.js +51 -0
  144. rasa/core/channels/inspector/dist/assets/flowDb-1972c806-201c5bbc.js +6 -0
  145. rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-f904ae41.js +4 -0
  146. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-b080d6f2.js +1 -0
  147. rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-1813da66.js +139 -0
  148. rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-872af172.js +266 -0
  149. rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-34a0af5a.js +70 -0
  150. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
  151. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
  152. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
  153. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
  154. rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-42ba3e3d.js +1 -0
  155. rasa/core/channels/inspector/dist/assets/index-37817b51.js +1317 -0
  156. rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
  157. rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-6b731386.js +7 -0
  158. rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
  159. rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-e8579ac6.js +139 -0
  160. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
  161. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
  162. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
  163. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
  164. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
  165. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
  166. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
  167. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
  168. rasa/core/channels/inspector/dist/assets/layout-89e6403a.js +1 -0
  169. rasa/core/channels/inspector/dist/assets/line-dc73d3fc.js +1 -0
  170. rasa/core/channels/inspector/dist/assets/linear-f5b1d2bc.js +1 -0
  171. rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-82cb74fa.js +109 -0
  172. rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
  173. rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
  174. rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-bdf5f29b.js +35 -0
  175. rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-c7a0cbe4.js +7 -0
  176. rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-7ec5410f.js +52 -0
  177. rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-caee5554.js +8 -0
  178. rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-2935f8db.js +122 -0
  179. rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-8f5d9693.js +1 -0
  180. rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-d565d1de.js +1 -0
  181. rasa/core/channels/inspector/dist/assets/styles-080da4f6-75ad421d.js +110 -0
  182. rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-7e764226.js +159 -0
  183. rasa/core/channels/inspector/dist/assets/styles-9c745c82-7a4e0e61.js +207 -0
  184. rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-4019d1bf.js +1 -0
  185. rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-01ea12df.js +61 -0
  186. rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-89407137.js +7 -0
  187. rasa/core/channels/inspector/dist/index.html +42 -0
  188. rasa/core/channels/inspector/index.html +40 -0
  189. rasa/core/channels/inspector/jest.config.ts +13 -0
  190. rasa/core/channels/inspector/package.json +52 -0
  191. rasa/core/channels/inspector/setupTests.ts +2 -0
  192. rasa/core/channels/inspector/src/App.tsx +220 -0
  193. rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
  194. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +108 -0
  195. rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
  196. rasa/core/channels/inspector/src/components/DialogueStack.tsx +136 -0
  197. rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
  198. rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
  199. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +22 -0
  200. rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
  201. rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
  202. rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
  203. rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
  204. rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
  205. rasa/core/channels/inspector/src/helpers/audiostream.ts +191 -0
  206. rasa/core/channels/inspector/src/helpers/formatters.test.ts +392 -0
  207. rasa/core/channels/inspector/src/helpers/formatters.ts +306 -0
  208. rasa/core/channels/inspector/src/helpers/utils.ts +127 -0
  209. rasa/core/channels/inspector/src/main.tsx +13 -0
  210. rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
  211. rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
  212. rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
  213. rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
  214. rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
  215. rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
  216. rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
  217. rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
  218. rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
  219. rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
  220. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
  221. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
  222. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
  223. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
  224. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
  225. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
  226. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
  227. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
  228. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
  229. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
  230. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
  231. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
  232. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
  233. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
  234. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
  235. rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
  236. rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
  237. rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
  238. rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
  239. rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
  240. rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
  241. rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
  242. rasa/core/channels/inspector/src/theme/index.ts +101 -0
  243. rasa/core/channels/inspector/src/types.ts +84 -0
  244. rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
  245. rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
  246. rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
  247. rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
  248. rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
  249. rasa/core/channels/inspector/tsconfig.json +26 -0
  250. rasa/core/channels/inspector/tsconfig.node.json +10 -0
  251. rasa/core/channels/inspector/vite.config.ts +8 -0
  252. rasa/core/channels/inspector/yarn.lock +6249 -0
  253. rasa/core/channels/mattermost.py +229 -0
  254. rasa/core/channels/rasa_chat.py +126 -0
  255. rasa/core/channels/rest.py +230 -0
  256. rasa/core/channels/rocketchat.py +174 -0
  257. rasa/core/channels/slack.py +620 -0
  258. rasa/core/channels/socketio.py +302 -0
  259. rasa/core/channels/telegram.py +298 -0
  260. rasa/core/channels/twilio.py +169 -0
  261. rasa/core/channels/vier_cvg.py +374 -0
  262. rasa/core/channels/voice_ready/__init__.py +0 -0
  263. rasa/core/channels/voice_ready/audiocodes.py +501 -0
  264. rasa/core/channels/voice_ready/jambonz.py +121 -0
  265. rasa/core/channels/voice_ready/jambonz_protocol.py +396 -0
  266. rasa/core/channels/voice_ready/twilio_voice.py +403 -0
  267. rasa/core/channels/voice_ready/utils.py +37 -0
  268. rasa/core/channels/voice_stream/__init__.py +0 -0
  269. rasa/core/channels/voice_stream/asr/__init__.py +0 -0
  270. rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
  271. rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
  272. rasa/core/channels/voice_stream/asr/azure.py +130 -0
  273. rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
  274. rasa/core/channels/voice_stream/audio_bytes.py +8 -0
  275. rasa/core/channels/voice_stream/browser_audio.py +107 -0
  276. rasa/core/channels/voice_stream/call_state.py +23 -0
  277. rasa/core/channels/voice_stream/tts/__init__.py +0 -0
  278. rasa/core/channels/voice_stream/tts/azure.py +106 -0
  279. rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
  280. rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
  281. rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
  282. rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
  283. rasa/core/channels/voice_stream/util.py +57 -0
  284. rasa/core/channels/voice_stream/voice_channel.py +427 -0
  285. rasa/core/channels/webexteams.py +134 -0
  286. rasa/core/concurrent_lock_store.py +210 -0
  287. rasa/core/constants.py +112 -0
  288. rasa/core/evaluation/__init__.py +0 -0
  289. rasa/core/evaluation/marker.py +267 -0
  290. rasa/core/evaluation/marker_base.py +923 -0
  291. rasa/core/evaluation/marker_stats.py +293 -0
  292. rasa/core/evaluation/marker_tracker_loader.py +103 -0
  293. rasa/core/exceptions.py +29 -0
  294. rasa/core/exporter.py +284 -0
  295. rasa/core/featurizers/__init__.py +0 -0
  296. rasa/core/featurizers/precomputation.py +410 -0
  297. rasa/core/featurizers/single_state_featurizer.py +421 -0
  298. rasa/core/featurizers/tracker_featurizers.py +1262 -0
  299. rasa/core/http_interpreter.py +89 -0
  300. rasa/core/information_retrieval/__init__.py +7 -0
  301. rasa/core/information_retrieval/faiss.py +124 -0
  302. rasa/core/information_retrieval/information_retrieval.py +137 -0
  303. rasa/core/information_retrieval/milvus.py +59 -0
  304. rasa/core/information_retrieval/qdrant.py +96 -0
  305. rasa/core/jobs.py +63 -0
  306. rasa/core/lock.py +139 -0
  307. rasa/core/lock_store.py +343 -0
  308. rasa/core/migrate.py +403 -0
  309. rasa/core/nlg/__init__.py +3 -0
  310. rasa/core/nlg/callback.py +146 -0
  311. rasa/core/nlg/contextual_response_rephraser.py +320 -0
  312. rasa/core/nlg/generator.py +230 -0
  313. rasa/core/nlg/interpolator.py +143 -0
  314. rasa/core/nlg/response.py +155 -0
  315. rasa/core/nlg/summarize.py +70 -0
  316. rasa/core/persistor.py +538 -0
  317. rasa/core/policies/__init__.py +0 -0
  318. rasa/core/policies/ensemble.py +329 -0
  319. rasa/core/policies/enterprise_search_policy.py +905 -0
  320. rasa/core/policies/enterprise_search_prompt_template.jinja2 +25 -0
  321. rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
  322. rasa/core/policies/flow_policy.py +205 -0
  323. rasa/core/policies/flows/__init__.py +0 -0
  324. rasa/core/policies/flows/flow_exceptions.py +44 -0
  325. rasa/core/policies/flows/flow_executor.py +754 -0
  326. rasa/core/policies/flows/flow_step_result.py +43 -0
  327. rasa/core/policies/intentless_policy.py +1031 -0
  328. rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
  329. rasa/core/policies/memoization.py +538 -0
  330. rasa/core/policies/policy.py +725 -0
  331. rasa/core/policies/rule_policy.py +1273 -0
  332. rasa/core/policies/ted_policy.py +2169 -0
  333. rasa/core/policies/unexpected_intent_policy.py +1022 -0
  334. rasa/core/processor.py +1465 -0
  335. rasa/core/run.py +342 -0
  336. rasa/core/secrets_manager/__init__.py +0 -0
  337. rasa/core/secrets_manager/constants.py +36 -0
  338. rasa/core/secrets_manager/endpoints.py +391 -0
  339. rasa/core/secrets_manager/factory.py +241 -0
  340. rasa/core/secrets_manager/secret_manager.py +262 -0
  341. rasa/core/secrets_manager/vault.py +584 -0
  342. rasa/core/test.py +1335 -0
  343. rasa/core/tracker_store.py +1703 -0
  344. rasa/core/train.py +105 -0
  345. rasa/core/training/__init__.py +89 -0
  346. rasa/core/training/converters/__init__.py +0 -0
  347. rasa/core/training/converters/responses_prefix_converter.py +119 -0
  348. rasa/core/training/interactive.py +1744 -0
  349. rasa/core/training/story_conflict.py +381 -0
  350. rasa/core/training/training.py +93 -0
  351. rasa/core/utils.py +366 -0
  352. rasa/core/visualize.py +70 -0
  353. rasa/dialogue_understanding/__init__.py +0 -0
  354. rasa/dialogue_understanding/coexistence/__init__.py +0 -0
  355. rasa/dialogue_understanding/coexistence/constants.py +4 -0
  356. rasa/dialogue_understanding/coexistence/intent_based_router.py +196 -0
  357. rasa/dialogue_understanding/coexistence/llm_based_router.py +327 -0
  358. rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
  359. rasa/dialogue_understanding/commands/__init__.py +61 -0
  360. rasa/dialogue_understanding/commands/can_not_handle_command.py +70 -0
  361. rasa/dialogue_understanding/commands/cancel_flow_command.py +125 -0
  362. rasa/dialogue_understanding/commands/change_flow_command.py +44 -0
  363. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +57 -0
  364. rasa/dialogue_understanding/commands/clarify_command.py +86 -0
  365. rasa/dialogue_understanding/commands/command.py +85 -0
  366. rasa/dialogue_understanding/commands/correct_slots_command.py +297 -0
  367. rasa/dialogue_understanding/commands/error_command.py +79 -0
  368. rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
  369. rasa/dialogue_understanding/commands/handle_code_change_command.py +73 -0
  370. rasa/dialogue_understanding/commands/human_handoff_command.py +66 -0
  371. rasa/dialogue_understanding/commands/knowledge_answer_command.py +57 -0
  372. rasa/dialogue_understanding/commands/noop_command.py +54 -0
  373. rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
  374. rasa/dialogue_understanding/commands/restart_command.py +58 -0
  375. rasa/dialogue_understanding/commands/session_end_command.py +61 -0
  376. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  377. rasa/dialogue_understanding/commands/set_slot_command.py +160 -0
  378. rasa/dialogue_understanding/commands/skip_question_command.py +75 -0
  379. rasa/dialogue_understanding/commands/start_flow_command.py +107 -0
  380. rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
  381. rasa/dialogue_understanding/commands/utils.py +45 -0
  382. rasa/dialogue_understanding/generator/__init__.py +21 -0
  383. rasa/dialogue_understanding/generator/command_generator.py +464 -0
  384. rasa/dialogue_understanding/generator/constants.py +27 -0
  385. rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
  386. rasa/dialogue_understanding/generator/flow_retrieval.py +466 -0
  387. rasa/dialogue_understanding/generator/llm_based_command_generator.py +500 -0
  388. rasa/dialogue_understanding/generator/llm_command_generator.py +67 -0
  389. rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
  390. rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
  391. rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
  392. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +920 -0
  393. rasa/dialogue_understanding/generator/nlu_command_adapter.py +261 -0
  394. rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
  395. rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +60 -0
  396. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +486 -0
  397. rasa/dialogue_understanding/patterns/__init__.py +0 -0
  398. rasa/dialogue_understanding/patterns/cancel.py +111 -0
  399. rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
  400. rasa/dialogue_understanding/patterns/chitchat.py +37 -0
  401. rasa/dialogue_understanding/patterns/clarify.py +97 -0
  402. rasa/dialogue_understanding/patterns/code_change.py +41 -0
  403. rasa/dialogue_understanding/patterns/collect_information.py +90 -0
  404. rasa/dialogue_understanding/patterns/completed.py +40 -0
  405. rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
  406. rasa/dialogue_understanding/patterns/correction.py +278 -0
  407. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +301 -0
  408. rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
  409. rasa/dialogue_understanding/patterns/internal_error.py +47 -0
  410. rasa/dialogue_understanding/patterns/repeat.py +37 -0
  411. rasa/dialogue_understanding/patterns/restart.py +37 -0
  412. rasa/dialogue_understanding/patterns/search.py +37 -0
  413. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  414. rasa/dialogue_understanding/patterns/skip_question.py +38 -0
  415. rasa/dialogue_understanding/patterns/user_silence.py +37 -0
  416. rasa/dialogue_understanding/processor/__init__.py +0 -0
  417. rasa/dialogue_understanding/processor/command_processor.py +720 -0
  418. rasa/dialogue_understanding/processor/command_processor_component.py +43 -0
  419. rasa/dialogue_understanding/stack/__init__.py +0 -0
  420. rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
  421. rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
  422. rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
  423. rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
  424. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
  425. rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
  426. rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
  427. rasa/dialogue_understanding/stack/utils.py +211 -0
  428. rasa/dialogue_understanding/utils.py +14 -0
  429. rasa/dialogue_understanding_test/__init__.py +0 -0
  430. rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
  431. rasa/dialogue_understanding_test/constants.py +17 -0
  432. rasa/dialogue_understanding_test/du_test_case.py +118 -0
  433. rasa/dialogue_understanding_test/du_test_result.py +11 -0
  434. rasa/dialogue_understanding_test/du_test_runner.py +93 -0
  435. rasa/dialogue_understanding_test/io.py +54 -0
  436. rasa/dialogue_understanding_test/validation.py +22 -0
  437. rasa/e2e_test/__init__.py +0 -0
  438. rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
  439. rasa/e2e_test/assertions.py +1345 -0
  440. rasa/e2e_test/assertions_schema.yml +129 -0
  441. rasa/e2e_test/constants.py +31 -0
  442. rasa/e2e_test/e2e_config.py +220 -0
  443. rasa/e2e_test/e2e_config_schema.yml +26 -0
  444. rasa/e2e_test/e2e_test_case.py +569 -0
  445. rasa/e2e_test/e2e_test_converter.py +363 -0
  446. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  447. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  448. rasa/e2e_test/e2e_test_result.py +54 -0
  449. rasa/e2e_test/e2e_test_runner.py +1192 -0
  450. rasa/e2e_test/e2e_test_schema.yml +181 -0
  451. rasa/e2e_test/pykwalify_extensions.py +39 -0
  452. rasa/e2e_test/stub_custom_action.py +70 -0
  453. rasa/e2e_test/utils/__init__.py +0 -0
  454. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  455. rasa/e2e_test/utils/io.py +598 -0
  456. rasa/e2e_test/utils/validation.py +178 -0
  457. rasa/engine/__init__.py +0 -0
  458. rasa/engine/caching.py +463 -0
  459. rasa/engine/constants.py +17 -0
  460. rasa/engine/exceptions.py +14 -0
  461. rasa/engine/graph.py +642 -0
  462. rasa/engine/loader.py +48 -0
  463. rasa/engine/recipes/__init__.py +0 -0
  464. rasa/engine/recipes/config_files/default_config.yml +41 -0
  465. rasa/engine/recipes/default_components.py +97 -0
  466. rasa/engine/recipes/default_recipe.py +1272 -0
  467. rasa/engine/recipes/graph_recipe.py +79 -0
  468. rasa/engine/recipes/recipe.py +93 -0
  469. rasa/engine/runner/__init__.py +0 -0
  470. rasa/engine/runner/dask.py +250 -0
  471. rasa/engine/runner/interface.py +49 -0
  472. rasa/engine/storage/__init__.py +0 -0
  473. rasa/engine/storage/local_model_storage.py +244 -0
  474. rasa/engine/storage/resource.py +110 -0
  475. rasa/engine/storage/storage.py +199 -0
  476. rasa/engine/training/__init__.py +0 -0
  477. rasa/engine/training/components.py +176 -0
  478. rasa/engine/training/fingerprinting.py +64 -0
  479. rasa/engine/training/graph_trainer.py +256 -0
  480. rasa/engine/training/hooks.py +164 -0
  481. rasa/engine/validation.py +1451 -0
  482. rasa/env.py +14 -0
  483. rasa/exceptions.py +69 -0
  484. rasa/graph_components/__init__.py +0 -0
  485. rasa/graph_components/converters/__init__.py +0 -0
  486. rasa/graph_components/converters/nlu_message_converter.py +48 -0
  487. rasa/graph_components/providers/__init__.py +0 -0
  488. rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
  489. rasa/graph_components/providers/domain_provider.py +71 -0
  490. rasa/graph_components/providers/flows_provider.py +74 -0
  491. rasa/graph_components/providers/forms_provider.py +44 -0
  492. rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
  493. rasa/graph_components/providers/responses_provider.py +44 -0
  494. rasa/graph_components/providers/rule_only_provider.py +49 -0
  495. rasa/graph_components/providers/story_graph_provider.py +96 -0
  496. rasa/graph_components/providers/training_tracker_provider.py +55 -0
  497. rasa/graph_components/validators/__init__.py +0 -0
  498. rasa/graph_components/validators/default_recipe_validator.py +550 -0
  499. rasa/graph_components/validators/finetuning_validator.py +302 -0
  500. rasa/hooks.py +111 -0
  501. rasa/jupyter.py +63 -0
  502. rasa/llm_fine_tuning/__init__.py +0 -0
  503. rasa/llm_fine_tuning/annotation_module.py +241 -0
  504. rasa/llm_fine_tuning/conversations.py +144 -0
  505. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  506. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  507. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  508. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  509. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  510. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  511. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  512. rasa/llm_fine_tuning/storage.py +174 -0
  513. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  514. rasa/markers/__init__.py +0 -0
  515. rasa/markers/marker.py +269 -0
  516. rasa/markers/marker_base.py +828 -0
  517. rasa/markers/upload.py +74 -0
  518. rasa/markers/validate.py +21 -0
  519. rasa/model.py +118 -0
  520. rasa/model_manager/__init__.py +0 -0
  521. rasa/model_manager/config.py +40 -0
  522. rasa/model_manager/model_api.py +559 -0
  523. rasa/model_manager/runner_service.py +286 -0
  524. rasa/model_manager/socket_bridge.py +146 -0
  525. rasa/model_manager/studio_jwt_auth.py +86 -0
  526. rasa/model_manager/trainer_service.py +325 -0
  527. rasa/model_manager/utils.py +87 -0
  528. rasa/model_manager/warm_rasa_process.py +187 -0
  529. rasa/model_service.py +112 -0
  530. rasa/model_testing.py +457 -0
  531. rasa/model_training.py +596 -0
  532. rasa/nlu/__init__.py +7 -0
  533. rasa/nlu/classifiers/__init__.py +3 -0
  534. rasa/nlu/classifiers/classifier.py +5 -0
  535. rasa/nlu/classifiers/diet_classifier.py +1881 -0
  536. rasa/nlu/classifiers/fallback_classifier.py +192 -0
  537. rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
  538. rasa/nlu/classifiers/logistic_regression_classifier.py +253 -0
  539. rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
  540. rasa/nlu/classifiers/regex_message_handler.py +56 -0
  541. rasa/nlu/classifiers/sklearn_intent_classifier.py +330 -0
  542. rasa/nlu/constants.py +77 -0
  543. rasa/nlu/convert.py +40 -0
  544. rasa/nlu/emulators/__init__.py +0 -0
  545. rasa/nlu/emulators/dialogflow.py +55 -0
  546. rasa/nlu/emulators/emulator.py +49 -0
  547. rasa/nlu/emulators/luis.py +86 -0
  548. rasa/nlu/emulators/no_emulator.py +10 -0
  549. rasa/nlu/emulators/wit.py +56 -0
  550. rasa/nlu/extractors/__init__.py +0 -0
  551. rasa/nlu/extractors/crf_entity_extractor.py +715 -0
  552. rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
  553. rasa/nlu/extractors/entity_synonyms.py +178 -0
  554. rasa/nlu/extractors/extractor.py +470 -0
  555. rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
  556. rasa/nlu/extractors/regex_entity_extractor.py +220 -0
  557. rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
  558. rasa/nlu/featurizers/__init__.py +0 -0
  559. rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
  560. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +445 -0
  561. rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
  562. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +768 -0
  563. rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
  564. rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
  565. rasa/nlu/featurizers/featurizer.py +89 -0
  566. rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
  567. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +867 -0
  568. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +571 -0
  569. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +271 -0
  570. rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
  571. rasa/nlu/model.py +24 -0
  572. rasa/nlu/run.py +27 -0
  573. rasa/nlu/selectors/__init__.py +0 -0
  574. rasa/nlu/selectors/response_selector.py +987 -0
  575. rasa/nlu/test.py +1940 -0
  576. rasa/nlu/tokenizers/__init__.py +0 -0
  577. rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
  578. rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
  579. rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
  580. rasa/nlu/tokenizers/tokenizer.py +239 -0
  581. rasa/nlu/tokenizers/whitespace_tokenizer.py +95 -0
  582. rasa/nlu/utils/__init__.py +35 -0
  583. rasa/nlu/utils/bilou_utils.py +462 -0
  584. rasa/nlu/utils/hugging_face/__init__.py +0 -0
  585. rasa/nlu/utils/hugging_face/registry.py +108 -0
  586. rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
  587. rasa/nlu/utils/mitie_utils.py +113 -0
  588. rasa/nlu/utils/pattern_utils.py +168 -0
  589. rasa/nlu/utils/spacy_utils.py +310 -0
  590. rasa/plugin.py +90 -0
  591. rasa/server.py +1588 -0
  592. rasa/shared/__init__.py +0 -0
  593. rasa/shared/constants.py +311 -0
  594. rasa/shared/core/__init__.py +0 -0
  595. rasa/shared/core/command_payload_reader.py +109 -0
  596. rasa/shared/core/constants.py +180 -0
  597. rasa/shared/core/conversation.py +46 -0
  598. rasa/shared/core/domain.py +2172 -0
  599. rasa/shared/core/events.py +2559 -0
  600. rasa/shared/core/flows/__init__.py +7 -0
  601. rasa/shared/core/flows/flow.py +562 -0
  602. rasa/shared/core/flows/flow_path.py +84 -0
  603. rasa/shared/core/flows/flow_step.py +146 -0
  604. rasa/shared/core/flows/flow_step_links.py +319 -0
  605. rasa/shared/core/flows/flow_step_sequence.py +70 -0
  606. rasa/shared/core/flows/flows_list.py +258 -0
  607. rasa/shared/core/flows/flows_yaml_schema.json +303 -0
  608. rasa/shared/core/flows/nlu_trigger.py +117 -0
  609. rasa/shared/core/flows/steps/__init__.py +24 -0
  610. rasa/shared/core/flows/steps/action.py +56 -0
  611. rasa/shared/core/flows/steps/call.py +64 -0
  612. rasa/shared/core/flows/steps/collect.py +112 -0
  613. rasa/shared/core/flows/steps/constants.py +5 -0
  614. rasa/shared/core/flows/steps/continuation.py +36 -0
  615. rasa/shared/core/flows/steps/end.py +22 -0
  616. rasa/shared/core/flows/steps/internal.py +44 -0
  617. rasa/shared/core/flows/steps/link.py +51 -0
  618. rasa/shared/core/flows/steps/no_operation.py +48 -0
  619. rasa/shared/core/flows/steps/set_slots.py +50 -0
  620. rasa/shared/core/flows/steps/start.py +30 -0
  621. rasa/shared/core/flows/utils.py +39 -0
  622. rasa/shared/core/flows/validation.py +735 -0
  623. rasa/shared/core/flows/yaml_flows_io.py +405 -0
  624. rasa/shared/core/generator.py +908 -0
  625. rasa/shared/core/slot_mappings.py +526 -0
  626. rasa/shared/core/slots.py +654 -0
  627. rasa/shared/core/trackers.py +1183 -0
  628. rasa/shared/core/training_data/__init__.py +0 -0
  629. rasa/shared/core/training_data/loading.py +89 -0
  630. rasa/shared/core/training_data/story_reader/__init__.py +0 -0
  631. rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
  632. rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
  633. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
  634. rasa/shared/core/training_data/story_writer/__init__.py +0 -0
  635. rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
  636. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +444 -0
  637. rasa/shared/core/training_data/structures.py +858 -0
  638. rasa/shared/core/training_data/visualization.html +146 -0
  639. rasa/shared/core/training_data/visualization.py +603 -0
  640. rasa/shared/data.py +249 -0
  641. rasa/shared/engine/__init__.py +0 -0
  642. rasa/shared/engine/caching.py +26 -0
  643. rasa/shared/exceptions.py +167 -0
  644. rasa/shared/importers/__init__.py +0 -0
  645. rasa/shared/importers/importer.py +770 -0
  646. rasa/shared/importers/multi_project.py +215 -0
  647. rasa/shared/importers/rasa.py +108 -0
  648. rasa/shared/importers/remote_importer.py +196 -0
  649. rasa/shared/importers/utils.py +36 -0
  650. rasa/shared/nlu/__init__.py +0 -0
  651. rasa/shared/nlu/constants.py +53 -0
  652. rasa/shared/nlu/interpreter.py +10 -0
  653. rasa/shared/nlu/training_data/__init__.py +0 -0
  654. rasa/shared/nlu/training_data/entities_parser.py +208 -0
  655. rasa/shared/nlu/training_data/features.py +492 -0
  656. rasa/shared/nlu/training_data/formats/__init__.py +10 -0
  657. rasa/shared/nlu/training_data/formats/dialogflow.py +163 -0
  658. rasa/shared/nlu/training_data/formats/luis.py +87 -0
  659. rasa/shared/nlu/training_data/formats/rasa.py +135 -0
  660. rasa/shared/nlu/training_data/formats/rasa_yaml.py +618 -0
  661. rasa/shared/nlu/training_data/formats/readerwriter.py +244 -0
  662. rasa/shared/nlu/training_data/formats/wit.py +52 -0
  663. rasa/shared/nlu/training_data/loading.py +137 -0
  664. rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
  665. rasa/shared/nlu/training_data/message.py +490 -0
  666. rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
  667. rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
  668. rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
  669. rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
  670. rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
  671. rasa/shared/nlu/training_data/training_data.py +729 -0
  672. rasa/shared/nlu/training_data/util.py +223 -0
  673. rasa/shared/providers/__init__.py +0 -0
  674. rasa/shared/providers/_configs/__init__.py +0 -0
  675. rasa/shared/providers/_configs/azure_openai_client_config.py +677 -0
  676. rasa/shared/providers/_configs/client_config.py +59 -0
  677. rasa/shared/providers/_configs/default_litellm_client_config.py +132 -0
  678. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +236 -0
  679. rasa/shared/providers/_configs/litellm_router_client_config.py +222 -0
  680. rasa/shared/providers/_configs/model_group_config.py +173 -0
  681. rasa/shared/providers/_configs/openai_client_config.py +177 -0
  682. rasa/shared/providers/_configs/rasa_llm_client_config.py +75 -0
  683. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +178 -0
  684. rasa/shared/providers/_configs/utils.py +117 -0
  685. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  686. rasa/shared/providers/_utils.py +79 -0
  687. rasa/shared/providers/constants.py +7 -0
  688. rasa/shared/providers/embedding/__init__.py +0 -0
  689. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +243 -0
  690. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  691. rasa/shared/providers/embedding/azure_openai_embedding_client.py +335 -0
  692. rasa/shared/providers/embedding/default_litellm_embedding_client.py +126 -0
  693. rasa/shared/providers/embedding/embedding_client.py +90 -0
  694. rasa/shared/providers/embedding/embedding_response.py +41 -0
  695. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  696. rasa/shared/providers/embedding/litellm_router_embedding_client.py +138 -0
  697. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  698. rasa/shared/providers/llm/__init__.py +0 -0
  699. rasa/shared/providers/llm/_base_litellm_client.py +265 -0
  700. rasa/shared/providers/llm/azure_openai_llm_client.py +415 -0
  701. rasa/shared/providers/llm/default_litellm_llm_client.py +110 -0
  702. rasa/shared/providers/llm/litellm_router_llm_client.py +202 -0
  703. rasa/shared/providers/llm/llm_client.py +78 -0
  704. rasa/shared/providers/llm/llm_response.py +50 -0
  705. rasa/shared/providers/llm/openai_llm_client.py +161 -0
  706. rasa/shared/providers/llm/rasa_llm_client.py +120 -0
  707. rasa/shared/providers/llm/self_hosted_llm_client.py +276 -0
  708. rasa/shared/providers/mappings.py +94 -0
  709. rasa/shared/providers/router/__init__.py +0 -0
  710. rasa/shared/providers/router/_base_litellm_router_client.py +185 -0
  711. rasa/shared/providers/router/router_client.py +75 -0
  712. rasa/shared/utils/__init__.py +0 -0
  713. rasa/shared/utils/cli.py +102 -0
  714. rasa/shared/utils/common.py +324 -0
  715. rasa/shared/utils/constants.py +4 -0
  716. rasa/shared/utils/health_check/__init__.py +0 -0
  717. rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
  718. rasa/shared/utils/health_check/health_check.py +258 -0
  719. rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
  720. rasa/shared/utils/io.py +499 -0
  721. rasa/shared/utils/llm.py +764 -0
  722. rasa/shared/utils/pykwalify_extensions.py +27 -0
  723. rasa/shared/utils/schemas/__init__.py +0 -0
  724. rasa/shared/utils/schemas/config.yml +2 -0
  725. rasa/shared/utils/schemas/domain.yml +145 -0
  726. rasa/shared/utils/schemas/events.py +214 -0
  727. rasa/shared/utils/schemas/model_config.yml +36 -0
  728. rasa/shared/utils/schemas/stories.yml +173 -0
  729. rasa/shared/utils/yaml.py +1068 -0
  730. rasa/studio/__init__.py +0 -0
  731. rasa/studio/auth.py +270 -0
  732. rasa/studio/config.py +136 -0
  733. rasa/studio/constants.py +19 -0
  734. rasa/studio/data_handler.py +368 -0
  735. rasa/studio/download.py +489 -0
  736. rasa/studio/results_logger.py +137 -0
  737. rasa/studio/train.py +134 -0
  738. rasa/studio/upload.py +563 -0
  739. rasa/telemetry.py +1876 -0
  740. rasa/tracing/__init__.py +0 -0
  741. rasa/tracing/config.py +355 -0
  742. rasa/tracing/constants.py +62 -0
  743. rasa/tracing/instrumentation/__init__.py +0 -0
  744. rasa/tracing/instrumentation/attribute_extractors.py +765 -0
  745. rasa/tracing/instrumentation/instrumentation.py +1306 -0
  746. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +144 -0
  747. rasa/tracing/instrumentation/metrics.py +294 -0
  748. rasa/tracing/metric_instrument_provider.py +205 -0
  749. rasa/utils/__init__.py +0 -0
  750. rasa/utils/beta.py +83 -0
  751. rasa/utils/cli.py +28 -0
  752. rasa/utils/common.py +639 -0
  753. rasa/utils/converter.py +53 -0
  754. rasa/utils/endpoints.py +331 -0
  755. rasa/utils/io.py +252 -0
  756. rasa/utils/json_utils.py +60 -0
  757. rasa/utils/licensing.py +542 -0
  758. rasa/utils/log_utils.py +181 -0
  759. rasa/utils/mapper.py +210 -0
  760. rasa/utils/ml_utils.py +147 -0
  761. rasa/utils/plotting.py +362 -0
  762. rasa/utils/sanic_error_handler.py +32 -0
  763. rasa/utils/singleton.py +23 -0
  764. rasa/utils/tensorflow/__init__.py +0 -0
  765. rasa/utils/tensorflow/callback.py +112 -0
  766. rasa/utils/tensorflow/constants.py +116 -0
  767. rasa/utils/tensorflow/crf.py +492 -0
  768. rasa/utils/tensorflow/data_generator.py +440 -0
  769. rasa/utils/tensorflow/environment.py +161 -0
  770. rasa/utils/tensorflow/exceptions.py +5 -0
  771. rasa/utils/tensorflow/feature_array.py +366 -0
  772. rasa/utils/tensorflow/layers.py +1565 -0
  773. rasa/utils/tensorflow/layers_utils.py +113 -0
  774. rasa/utils/tensorflow/metrics.py +281 -0
  775. rasa/utils/tensorflow/model_data.py +798 -0
  776. rasa/utils/tensorflow/model_data_utils.py +499 -0
  777. rasa/utils/tensorflow/models.py +935 -0
  778. rasa/utils/tensorflow/rasa_layers.py +1094 -0
  779. rasa/utils/tensorflow/transformer.py +640 -0
  780. rasa/utils/tensorflow/types.py +6 -0
  781. rasa/utils/train_utils.py +572 -0
  782. rasa/utils/url_tools.py +53 -0
  783. rasa/utils/yaml.py +54 -0
  784. rasa/validator.py +1644 -0
  785. rasa/version.py +3 -0
  786. rasa_pro-3.12.0.dev1.dist-info/METADATA +199 -0
  787. rasa_pro-3.12.0.dev1.dist-info/NOTICE +5 -0
  788. rasa_pro-3.12.0.dev1.dist-info/RECORD +790 -0
  789. rasa_pro-3.12.0.dev1.dist-info/WHEEL +4 -0
  790. rasa_pro-3.12.0.dev1.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,1881 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import logging
5
+ from collections import defaultdict
6
+ from pathlib import Path
7
+ from typing import Any, Dict, List, Optional, Text, Tuple, Union, TypeVar, Type
8
+
9
+ import numpy as np
10
+ import scipy.sparse
11
+ import tensorflow as tf
12
+
13
+ from rasa.exceptions import ModelNotFound
14
+ from rasa.nlu.featurizers.featurizer import Featurizer
15
+ from rasa.engine.graph import ExecutionContext, GraphComponent
16
+ from rasa.engine.recipes.default_recipe import DefaultV1Recipe
17
+ from rasa.engine.storage.resource import Resource
18
+ from rasa.engine.storage.storage import ModelStorage
19
+ from rasa.nlu.extractors.extractor import EntityExtractorMixin
20
+ from rasa.nlu.classifiers.classifier import IntentClassifier
21
+ import rasa.shared.utils.io
22
+ import rasa.nlu.utils.bilou_utils as bilou_utils
23
+ from rasa.shared.constants import DIAGNOSTIC_DATA
24
+ from rasa.nlu.extractors.extractor import EntityTagSpec
25
+ from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
26
+ from rasa.utils import train_utils
27
+ from rasa.utils.tensorflow import rasa_layers
28
+ from rasa.utils.tensorflow.feature_array import (
29
+ FeatureArray,
30
+ serialize_nested_feature_arrays,
31
+ deserialize_nested_feature_arrays,
32
+ )
33
+ from rasa.utils.tensorflow.models import RasaModel, TransformerRasaModel
34
+ from rasa.utils.tensorflow.model_data import (
35
+ RasaModelData,
36
+ FeatureSignature,
37
+ )
38
+ from rasa.nlu.constants import TOKENS_NAMES, DEFAULT_TRANSFORMER_SIZE
39
+ from rasa.shared.nlu.constants import (
40
+ SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
41
+ TEXT,
42
+ INTENT,
43
+ INTENT_RESPONSE_KEY,
44
+ ENTITIES,
45
+ ENTITY_ATTRIBUTE_TYPE,
46
+ ENTITY_ATTRIBUTE_GROUP,
47
+ ENTITY_ATTRIBUTE_ROLE,
48
+ NO_ENTITY_TAG,
49
+ SPLIT_ENTITIES_BY_COMMA,
50
+ )
51
+ from rasa.shared.exceptions import InvalidConfigException
52
+ from rasa.shared.nlu.training_data.training_data import TrainingData
53
+ from rasa.shared.nlu.training_data.message import Message
54
+ from rasa.utils.tensorflow.constants import (
55
+ DROP_SMALL_LAST_BATCH,
56
+ LABEL,
57
+ IDS,
58
+ HIDDEN_LAYERS_SIZES,
59
+ RENORMALIZE_CONFIDENCES,
60
+ SHARE_HIDDEN_LAYERS,
61
+ TRANSFORMER_SIZE,
62
+ NUM_TRANSFORMER_LAYERS,
63
+ NUM_HEADS,
64
+ BATCH_SIZES,
65
+ BATCH_STRATEGY,
66
+ EPOCHS,
67
+ RANDOM_SEED,
68
+ LEARNING_RATE,
69
+ RANKING_LENGTH,
70
+ LOSS_TYPE,
71
+ SIMILARITY_TYPE,
72
+ NUM_NEG,
73
+ SPARSE_INPUT_DROPOUT,
74
+ DENSE_INPUT_DROPOUT,
75
+ MASKED_LM,
76
+ ENTITY_RECOGNITION,
77
+ TENSORBOARD_LOG_DIR,
78
+ INTENT_CLASSIFICATION,
79
+ EVAL_NUM_EXAMPLES,
80
+ EVAL_NUM_EPOCHS,
81
+ UNIDIRECTIONAL_ENCODER,
82
+ DROP_RATE,
83
+ DROP_RATE_ATTENTION,
84
+ CONNECTION_DENSITY,
85
+ NEGATIVE_MARGIN_SCALE,
86
+ REGULARIZATION_CONSTANT,
87
+ SCALE_LOSS,
88
+ USE_MAX_NEG_SIM,
89
+ MAX_NEG_SIM,
90
+ MAX_POS_SIM,
91
+ EMBEDDING_DIMENSION,
92
+ BILOU_FLAG,
93
+ KEY_RELATIVE_ATTENTION,
94
+ VALUE_RELATIVE_ATTENTION,
95
+ MAX_RELATIVE_POSITION,
96
+ AUTO,
97
+ BALANCED,
98
+ CROSS_ENTROPY,
99
+ TENSORBOARD_LOG_LEVEL,
100
+ CONCAT_DIMENSION,
101
+ FEATURIZERS,
102
+ CHECKPOINT_MODEL,
103
+ SEQUENCE,
104
+ SENTENCE,
105
+ SEQUENCE_LENGTH,
106
+ DENSE_DIMENSION,
107
+ MASK,
108
+ CONSTRAIN_SIMILARITIES,
109
+ MODEL_CONFIDENCE,
110
+ SOFTMAX,
111
+ RUN_EAGERLY,
112
+ )
113
+
114
logger = logging.getLogger(__name__)

# String identifiers for sparse and dense features.
SPARSE, DENSE = "sparse", "dense"

# Label key / sub-key; these are what `DIETClassifier.label_key` and
# `DIETClassifier.label_sub_key` return when intent classification is enabled.
LABEL_KEY, LABEL_SUB_KEY = LABEL, IDS

# Entity attributes for which tag specifications may be built; the
# `_create_entity_tag_specs` method iterates over this list in this order.
POSSIBLE_TAGS = [
    ENTITY_ATTRIBUTE_TYPE,
    ENTITY_ATTRIBUTE_ROLE,
    ENTITY_ATTRIBUTE_GROUP,
]

# Type variable bound to `DIETClassifier` (presumably for annotating methods
# that return instances of the concrete subclass — confirm at use sites).
DIETClassifierT = TypeVar("DIETClassifierT", bound="DIETClassifier")
124
+
125
+
126
+ @DefaultV1Recipe.register(
127
+ [
128
+ DefaultV1Recipe.ComponentType.INTENT_CLASSIFIER,
129
+ DefaultV1Recipe.ComponentType.ENTITY_EXTRACTOR,
130
+ ],
131
+ is_trainable=True,
132
+ )
133
+ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
134
+ """A multi-task model for intent classification and entity extraction.
135
+
136
+ DIET is Dual Intent and Entity Transformer.
137
+ The architecture is based on a transformer which is shared for both tasks.
138
+ A sequence of entity labels is predicted through a Conditional Random Field (CRF)
139
+ tagging layer on top of the transformer output sequence corresponding to the
140
+ input sequence of tokens. The transformer output for the ``__CLS__`` token and
141
+ intent labels are embedded into a single semantic vector space. We use the
142
+ dot-product loss to maximize the similarity with the target label and minimize
143
+ similarities with negative samples.
144
+ """
145
+
146
@classmethod
def required_components(cls) -> List[Type]:
    """List the component types that must precede this one in the pipeline.

    Returns:
        A single-element list containing `Featurizer`.
    """
    prerequisites: List[Type] = [Featurizer]
    return prerequisites
150
+
151
@staticmethod
def get_default_config() -> Dict[Text, Any]:
    """Return a fresh dict holding DIET's default hyperparameters.

    The defaults are declared per section (architecture, training,
    embeddings, regularization, evaluation, model config) and merged in that
    order, so the resulting key order matches the documented layout. A new
    dict (with new nested lists/dicts) is built on every call.
    """
    # please make sure to update the docs when changing a default parameter

    # ## Architecture of the used neural network
    architecture = {
        # Hidden layer sizes for layers before the embedding layers for user
        # message and labels. The number of hidden layers is equal to the
        # length of the corresponding list.
        HIDDEN_LAYERS_SIZES: {TEXT: [], LABEL: []},
        # Whether to share the hidden layer weights between user message and
        # labels.
        SHARE_HIDDEN_LAYERS: False,
        # Number of units in transformer
        TRANSFORMER_SIZE: DEFAULT_TRANSFORMER_SIZE,
        # Number of transformer layers
        NUM_TRANSFORMER_LAYERS: 2,
        # Number of attention heads in transformer
        NUM_HEADS: 4,
        # If 'True' use key relative embeddings in attention
        KEY_RELATIVE_ATTENTION: False,
        # If 'True' use value relative embeddings in attention
        VALUE_RELATIVE_ATTENTION: False,
        # Max position for relative embeddings. Only in effect if key- or
        # value relative attention are turned on
        MAX_RELATIVE_POSITION: 5,
        # Use a unidirectional or bidirectional encoder.
        UNIDIRECTIONAL_ENCODER: False,
    }

    # ## Training parameters
    training = {
        # Initial and final batch sizes: the batch size will be linearly
        # increased for each epoch.
        BATCH_SIZES: [64, 256],
        # Strategy used when creating batches.
        # Can be either 'sequence' or 'balanced'.
        BATCH_STRATEGY: BALANCED,
        # Number of epochs to train
        EPOCHS: 300,
        # Set random seed to any 'int' to get reproducible results
        RANDOM_SEED: None,
        # Initial learning rate for the optimizer
        LEARNING_RATE: 0.001,
    }

    # ## Parameters for embeddings
    embeddings = {
        # Dimension size of embedding vectors
        EMBEDDING_DIMENSION: 20,
        # Dense dimension to use for sparse features.
        DENSE_DIMENSION: {TEXT: 128, LABEL: 20},
        # Default dimension to use for concatenating sequence and sentence
        # features.
        CONCAT_DIMENSION: {TEXT: 128, LABEL: 20},
        # The number of incorrect labels. The algorithm will minimize
        # their similarity to the user input during training.
        NUM_NEG: 20,
        # Type of similarity measure to use, either 'auto' or 'cosine' or
        # 'inner'.
        SIMILARITY_TYPE: AUTO,
        # The type of the loss function, either 'cross_entropy' or 'margin'.
        LOSS_TYPE: CROSS_ENTROPY,
        # Number of top intents for which confidences should be reported.
        # Set to 0 if confidences for all intents should be reported.
        RANKING_LENGTH: LABEL_RANKING_LENGTH,
        # Indicates how similar the algorithm should try to make embedding
        # vectors for correct labels.
        # Should be 0.0 < ... < 1.0 for 'cosine' similarity type.
        MAX_POS_SIM: 0.8,
        # Maximum negative similarity for incorrect labels.
        # Should be -1.0 < ... < 1.0 for 'cosine' similarity type.
        MAX_NEG_SIM: -0.4,
        # If 'True' the algorithm only minimizes maximum similarity over
        # incorrect intent labels, used only if 'loss_type' is set to 'margin'.
        USE_MAX_NEG_SIM: True,
        # If 'True' scale loss inverse proportionally to the confidence
        # of the correct prediction
        SCALE_LOSS: False,
    }

    # ## Regularization parameters
    regularization = {
        # The scale of regularization
        REGULARIZATION_CONSTANT: 0.002,
        # The scale of how important it is to minimize the maximum similarity
        # between embeddings of different labels,
        # used only if 'loss_type' is set to 'margin'.
        NEGATIVE_MARGIN_SCALE: 0.8,
        # Dropout rate for encoder
        DROP_RATE: 0.2,
        # Dropout rate for attention
        DROP_RATE_ATTENTION: 0,
        # Fraction of trainable weights in internal layers.
        CONNECTION_DENSITY: 0.2,
        # If 'True' apply dropout to sparse input tensors
        SPARSE_INPUT_DROPOUT: True,
        # If 'True' apply dropout to dense input tensors
        DENSE_INPUT_DROPOUT: True,
    }

    # ## Evaluation parameters
    evaluation = {
        # How often to calculate validation accuracy.
        # Small values may hurt performance.
        EVAL_NUM_EPOCHS: 20,
        # How many examples to use for hold out validation set
        # Large values may hurt performance, e.g. model accuracy.
        # Set to 0 for no validation.
        EVAL_NUM_EXAMPLES: 0,
    }

    # ## Model config
    model_config = {
        # If 'True' intent classification is trained and intent predicted.
        INTENT_CLASSIFICATION: True,
        # If 'True' named entity recognition is trained and entities predicted.
        ENTITY_RECOGNITION: True,
        # If 'True' random tokens of the input message will be masked and the
        # model should predict those tokens.
        MASKED_LM: False,
        # 'BILOU_flag' determines whether to use BILOU tagging or not.
        # If set to 'True' labelling is more rigorous, however more
        # examples per entity are required.
        # Rule of thumb: you should have more than 100 examples per entity.
        BILOU_FLAG: True,
        # If you want to use tensorboard to visualize training and validation
        # metrics, set this option to a valid output directory.
        TENSORBOARD_LOG_DIR: None,
        # Define when training metrics for tensorboard should be logged.
        # Either after every epoch or for every training step.
        # Valid values: 'epoch' and 'batch'
        TENSORBOARD_LOG_LEVEL: "epoch",
        # Perform model checkpointing
        CHECKPOINT_MODEL: False,
        # Specify what features to use as sequence and sentence features
        # By default all features in the pipeline are used.
        FEATURIZERS: [],
        # Split entities by comma, this makes sense e.g. for a list of
        # ingredients in a recipe, but it doesn't make sense for the parts of
        # an address
        SPLIT_ENTITIES_BY_COMMA: True,
        # If 'True' applies sigmoid on all similarity terms and adds
        # it to the loss function to ensure that similarity values are
        # approximately bounded. Used inside cross-entropy loss only.
        CONSTRAIN_SIMILARITIES: False,
        # Model confidence to be returned during inference. Currently, the
        # only possible value is `softmax`.
        MODEL_CONFIDENCE: SOFTMAX,
        # Determines whether the confidences of the chosen top intents should
        # be renormalized so that they sum up to 1. By default, we do not
        # renormalize and return the confidences for the top intents as is.
        # Note that renormalization only makes sense if confidences are
        # generated via `softmax`.
        RENORMALIZE_CONFIDENCES: False,
        # Determines whether to construct the model graph or not.
        # This is advantageous when the model is only trained or inferred for
        # a few steps, as the compilation of the graph tends to take more time
        # than running it. It is recommended to not adjust the optimization
        # parameter.
        RUN_EAGERLY: False,
        # Determines whether the last batch should be dropped if it contains
        # fewer than half a batch size of examples
        DROP_SMALL_LAST_BATCH: False,
    }

    return {
        **architecture,
        **training,
        **embeddings,
        **regularization,
        **evaluation,
        **model_config,
    }
297
+
298
def __init__(
    self,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
    index_label_id_mapping: Optional[Dict[int, Text]] = None,
    entity_tag_specs: Optional[List[EntityTagSpec]] = None,
    model: Optional[RasaModel] = None,
    sparse_feature_sizes: Optional[Dict[Text, Dict[Text, List[int]]]] = None,
) -> None:
    """Declare instance variables with default values.

    Args:
        config: Component configuration. It is validated and normalized by
            `_check_config_parameters`, which may replace the dict stored on
            `self.component_config`.
        model_storage: Storage the component persists to / loads from.
        resource: Resource identifying this component within the storage.
        execution_context: Graph execution context; its `is_finetuning`
            flag is stored as `self.finetune_mode`.
        index_label_id_mapping: Mapping from label index to label id;
            defaults to an empty dict.
        entity_tag_specs: Entity tag specifications (if already known).
        model: An already-built model (if any).
        sparse_feature_sizes: Sparse feature sizes per attribute / feature
            type (if already known).
    """
    if EPOCHS not in config:
        rasa.shared.utils.io.raise_warning(
            f"Please configure the number of '{EPOCHS}' in your configuration file."
            f" We will change the default value of '{EPOCHS}' in the future to 1. "
        )

    self.component_config = config
    self._model_storage = model_storage
    self._resource = resource
    self._execution_context = execution_context

    # Validates and normalizes the config; must run before any
    # `self.component_config[...]` lookup below.
    self._check_config_parameters()

    # transform numbers to labels
    self.index_label_id_mapping = index_label_id_mapping or {}

    self._entity_tag_specs = entity_tag_specs

    self.model = model

    self.tmp_checkpoint_dir: Optional[Path] = None
    if self.component_config[CHECKPOINT_MODEL]:
        # `rasa.utils.io` is not imported at module level; accessing it as an
        # attribute of the `rasa` package only works if another module had
        # already loaded it. Import it explicitly instead. (A `from` import is
        # used so that `rasa` does not become a local name in this function.)
        from rasa.utils.io import create_temporary_directory

        self.tmp_checkpoint_dir = Path(create_temporary_directory())

    self._label_data: Optional[RasaModelData] = None
    self._data_example: Optional[Dict[Text, Dict[Text, List[FeatureArray]]]] = None

    # `train_utils` is the module-level alias of `rasa.utils.train_utils`.
    self.split_entities_config = train_utils.init_split_entities(
        self.component_config[SPLIT_ENTITIES_BY_COMMA],
        SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
    )

    self.finetune_mode = self._execution_context.is_finetuning
    self._sparse_feature_sizes = sparse_feature_sizes
344
+
345
+ # init helpers
346
def _check_masked_lm(self) -> None:
    """Reject enabling masked LM when there are no transformer layers.

    Raises:
        ValueError: if `MASKED_LM` is truthy and `NUM_TRANSFORMER_LAYERS` is 0.
    """
    cfg = self.component_config
    if not cfg[MASKED_LM]:
        return

    if cfg[NUM_TRANSFORMER_LAYERS] == 0:
        raise ValueError(
            f"If number of transformer layers is 0, "
            f"'{MASKED_LM}' option should be 'False'."
        )
355
+
356
def _check_share_hidden_layers_sizes(self) -> None:
    """Require identical hidden layer sizes when hidden layers are shared.

    Does nothing unless `SHARE_HIDDEN_LAYERS` is truthy. Otherwise every list
    stored under `HIDDEN_LAYERS_SIZES` is compared with the first one.

    Raises:
        ValueError: if the hidden layer size lists do not all coincide.
    """
    if not self.component_config.get(SHARE_HIDDEN_LAYERS):
        return

    sizes_by_attribute = self.component_config[HIDDEN_LAYERS_SIZES]
    reference_sizes = next(iter(sizes_by_attribute.values()))

    if all(sizes == reference_sizes for sizes in sizes_by_attribute.values()):
        return

    raise ValueError(
        f"If hidden layer weights are shared, "
        f"{HIDDEN_LAYERS_SIZES} must coincide."
    )
373
+
374
def _check_config_parameters(self) -> None:
    """Validate and normalize `self.component_config`.

    Steps, in order:
    1. pass the config through `train_utils.check_deprecated_options`,
    2. run the masked-LM and shared-hidden-layer checks,
    3. update the confidence type, then call
       `train_utils.validate_configuration_settings`,
    4. update the similarity type and the evaluation parameters.

    Every `train_utils` helper that returns a config replaces
    `self.component_config` with its result.
    """
    self.component_config = train_utils.check_deprecated_options(
        self.component_config
    )

    # Both checks read `self.component_config`, so they run on the config
    # returned by `check_deprecated_options`.
    self._check_masked_lm()
    self._check_share_hidden_layers_sizes()

    self.component_config = train_utils.update_confidence_type(
        self.component_config
    )
    train_utils.validate_configuration_settings(self.component_config)

    for normalize in (
        train_utils.update_similarity_type,
        train_utils.update_evaluation_parameters,
    ):
        self.component_config = normalize(self.component_config)
394
+
395
@classmethod
def create(
    cls,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
) -> DIETClassifier:
    """Build a new, untrained instance (see parent class for full docstring).

    The arguments are forwarded positionally to the constructor; all optional
    constructor arguments keep their defaults.
    """
    untrained_component = cls(config, model_storage, resource, execution_context)
    return untrained_component
405
+
406
@property
def label_key(self) -> Optional[Text]:
    """`LABEL_KEY` when intent classification is enabled, otherwise `None`."""
    if not self.component_config[INTENT_CLASSIFICATION]:
        return None
    return LABEL_KEY
410
+
411
@property
def label_sub_key(self) -> Optional[Text]:
    """`LABEL_SUB_KEY` when intent classification is enabled, otherwise `None`."""
    if not self.component_config[INTENT_CLASSIFICATION]:
        return None
    return LABEL_SUB_KEY
415
+
416
@staticmethod
def model_class() -> Type[RasaModel]:
    """Return the model class used by this component, i.e. `DIET`."""
    diet_model_class: Type[RasaModel] = DIET
    return diet_model_class
419
+
420
+ # training data helpers:
421
@staticmethod
def _label_id_index_mapping(
    training_data: TrainingData, attribute: Text
) -> Dict[Text, int]:
    """Assign a consecutive index to every distinct label id.

    Label ids are read with `example.get(attribute)` from
    `training_data.intent_examples`; `None` values are skipped. Indices start
    at 0 and follow the sorted order of the label ids.
    """
    label_ids = set()
    for example in training_data.intent_examples:
        value = example.get(attribute)
        if value is not None:
            label_ids.add(value)

    ordered_ids = sorted(label_ids)
    return dict(zip(ordered_ids, range(len(ordered_ids))))
432
+
433
@staticmethod
def _invert_mapping(mapping: Dict) -> Dict:
    """Swap keys and values of `mapping`.

    If several keys share a value, the key that comes last in iteration order
    wins.
    """
    return dict(zip(mapping.values(), mapping.keys()))
436
+
437
def _create_entity_tag_specs(
    self, training_data: TrainingData
) -> List[EntityTagSpec]:
    """Build an ``EntityTagSpec`` for every entity tag type in the data.

    For each name in ``POSSIBLE_TAGS`` a tag -> id mapping is built. With
    ``BILOU_FLAG`` set it comes from ``bilou_utils.build_tag_id_dict``;
    otherwise it comes from ``_tag_id_index_mapping_for``. Tag types whose
    mapping is empty or None are skipped.
    """
    specs = []
    for tag_name in POSSIBLE_TAGS:
        mapping = (
            bilou_utils.build_tag_id_dict(training_data, tag_name)
            if self.component_config[BILOU_FLAG]
            else self._tag_id_index_mapping_for(tag_name, training_data)
        )
        if not mapping:
            continue

        specs.append(
            EntityTagSpec(
                tag_name=tag_name,
                tags_to_ids=mapping,
                ids_to_tags=self._invert_mapping(mapping),
                num_tags=len(mapping),
            )
        )
    return specs
464
+
465
@staticmethod
def _tag_id_index_mapping_for(
    tag_name: Text, training_data: TrainingData
) -> Optional[Dict[Text, int]]:
    """Assign an integer id to each distinct tag of one entity attribute.

    ``tag_name`` selects the source set:
    - ``ENTITY_ATTRIBUTE_ROLE`` uses the entity roles.
    - ``ENTITY_ATTRIBUTE_GROUP`` uses the entity groups.
    - Anything else uses the entity types.

    Real tags receive ids ``1..n`` in sorted order. ``NO_ENTITY_TAG`` always
    maps to 0.

    Returns:
        The tag -> id mapping, or None if there are no real tags.
    """
    if tag_name == ENTITY_ATTRIBUTE_ROLE:
        available_tags = training_data.entity_roles
    elif tag_name == ENTITY_ATTRIBUTE_GROUP:
        available_tags = training_data.entity_groups
    else:
        available_tags = training_data.entities

    real_tags = available_tags - {NO_ENTITY_TAG, None}
    if not real_tags:
        return None

    ordered_tags = sorted(real_tags)
    mapping = dict(zip(ordered_tags, range(1, len(ordered_tags) + 1)))
    # the non-entity tag must own index 0 so that padded positions are
    # predicted correctly
    mapping[NO_ENTITY_TAG] = 0
    return mapping
490
+
491
+ @staticmethod
492
+ def _find_example_for_label(
493
+ label: Text, examples: List[Message], attribute: Text
494
+ ) -> Optional[Message]:
495
+ for ex in examples:
496
+ if ex.get(attribute) == label:
497
+ return ex
498
+ return None
499
+
500
def _check_labels_features_exist(
    self, labels_example: List[Message], attribute: Text
) -> bool:
    """Return True if every label example has features for ``attribute``.

    Only features from the configured featurizers count. An empty list yields
    True.
    """
    for label_example in labels_example:
        has_features = label_example.features_present(
            attribute, self.component_config[FEATURIZERS]
        )
        if not has_features:
            return False
    return True
510
+
511
def _extract_features(
    self, message: Message, attribute: Text
) -> Dict[Text, Union[scipy.sparse.spmatrix, np.ndarray]]:
    """Collect the sparse/dense sentence and sequence features of a message.

    Args:
        message: The message whose features are read.
        attribute: The message attribute whose features are collected
            (e.g. TEXT or INTENT).

    Returns:
        A dict keyed by ``f"{SPARSE|DENSE}_{SENTENCE|SEQUENCE}"`` holding the
        raw feature matrices. Feature kinds that are absent are omitted.
        Sequence features are dropped when no transformer layers are used,
        entity recognition is off, and ``attribute`` is not a label attribute.

    Raises:
        ValueError: If sparse and dense features of the same kind (sequence
            or sentence) differ in their first dimension.
    """
    featurizers = self.component_config[FEATURIZERS]
    (
        sparse_sequence_features,
        sparse_sentence_features,
    ) = message.get_sparse_features(attribute, featurizers)
    dense_sequence_features, dense_sentence_features = message.get_dense_features(
        attribute, featurizers
    )

    # Sparse and dense features of the same kind must agree in their first
    # dimension, otherwise they cannot be combined later on.
    for kind, dense, sparse in (
        ("sequence", dense_sequence_features, sparse_sequence_features),
        ("sentence", dense_sentence_features, sparse_sentence_features),
    ):
        if (
            dense is not None
            and sparse is not None
            and dense.features.shape[0] != sparse.features.shape[0]
        ):
            raise ValueError(
                f"Sequence dimensions for sparse and dense {kind} features "
                f"don't coincide in '{message.get(TEXT)}' "
                f"for attribute '{attribute}'."
            )

    # If we don't use the transformer and we don't want to do entity recognition,
    # to speed up training take only the sentence features as feature vector.
    # We would not make use of the sequence anyway in this setup. Carrying over
    # those features to the actual training process takes quite some time.
    if (
        self.component_config[NUM_TRANSFORMER_LAYERS] == 0
        and not self.component_config[ENTITY_RECOGNITION]
        and attribute not in [INTENT, INTENT_RESPONSE_KEY]
    ):
        sparse_sequence_features = None
        dense_sequence_features = None

    candidates = (
        (f"{SPARSE}_{SENTENCE}", sparse_sentence_features),
        (f"{SPARSE}_{SEQUENCE}", sparse_sequence_features),
        (f"{DENSE}_{SENTENCE}", dense_sentence_features),
        (f"{DENSE}_{SEQUENCE}", dense_sequence_features),
    )
    return {key: feature.features for key, feature in candidates if feature is not None}
567
+
568
def _check_input_dimension_consistency(self, model_data: RasaModelData) -> None:
    """Ensure text and label features have equal sizes when layers are shared.

    The check only runs if ``SHARE_HIDDEN_LAYERS`` is enabled. A size of 0
    means the feature kind is absent and never counts as a mismatch.

    Raises:
        ValueError: If text and label sentence (or sequence) feature sizes are
            both non-zero and differ.
    """
    if not self.component_config.get(SHARE_HIDDEN_LAYERS):
        return

    text_sentence = model_data.number_of_units(TEXT, SENTENCE)
    label_sentence = model_data.number_of_units(LABEL, SENTENCE)
    text_sequence = model_data.number_of_units(TEXT, SEQUENCE)
    label_sequence = model_data.number_of_units(LABEL, SEQUENCE)

    def _sizes_clash(text_units: int, label_units: int) -> bool:
        return text_units > 0 and label_units > 0 and text_units != label_units

    if _sizes_clash(text_sentence, label_sentence) or _sizes_clash(
        text_sequence, label_sequence
    ):
        raise ValueError(
            "If embeddings are shared text features and label features "
            "must coincide. Check the output dimensions of previous components."
        )
583
+
584
def _extract_labels_precomputed_features(
    self, label_examples: List[Message], attribute: Text = INTENT
) -> Tuple[List[FeatureArray], List[FeatureArray]]:
    """Stack the precomputed features of all label examples.

    Features are grouped by key (see ``_extract_features``). Each group
    becomes one 3-D ``FeatureArray``.

    Returns:
        ``(sequence_features, sentence_features)``. Keys containing SEQUENCE
        go to the first list; every other key goes to the second.
    """
    grouped = defaultdict(list)
    for example in label_examples:
        for key, value in self._extract_features(example, attribute).items():
            grouped[key].append(value)

    sequence_features = []
    sentence_features = []
    for key, values in grouped.items():
        target = sequence_features if SEQUENCE in key else sentence_features
        target.append(FeatureArray(np.array(values), number_of_dimensions=3))
    return sequence_features, sentence_features
606
+
607
@staticmethod
def _compute_default_label_features(
    labels_example: List[Message],
) -> List[FeatureArray]:
    """Build one-hot features for the labels, one row per label example.

    For a non-empty input the single returned ``FeatureArray`` has shape
    ``(num_labels, 1, num_labels)``: an identity matrix with an extra
    length-1 sequence axis inserted for every row.
    """
    logger.debug("No label features found. Computing default label features.")

    one_hot_rows = np.eye(len(labels_example), dtype=np.float32)
    with_sequence_axis = np.array([np.expand_dims(row, 0) for row in one_hot_rows])
    return [FeatureArray(with_sequence_axis, number_of_dimensions=3)]
622
+
623
def _create_label_data(
    self,
    training_data: TrainingData,
    label_id_dict: Dict[Text, int],
    attribute: Text,
) -> RasaModelData:
    """Build the label model data: one feature row per label, ordered by index.

    A representative example is looked up in ``training_data.intent_examples``
    for each label. If every example already carries features for
    ``attribute``, those features are used. Otherwise one-hot default
    features are computed.

    Raises:
        ValueError: If neither sentence nor sequence label features exist.
    """
    indexed_examples = sorted(
        (
            (
                idx,
                self._find_example_for_label(
                    label_name, training_data.intent_examples, attribute
                ),
            )
            for label_name, idx in label_id_dict.items()
        ),
        key=lambda pair: pair[0],
    )
    examples = [example for _, example in indexed_examples]

    if self._check_labels_features_exist(examples, attribute):
        sequence_features, sentence_features = (
            self._extract_labels_precomputed_features(examples, attribute)
        )
    else:
        sequence_features = None
        sentence_features = self._compute_default_label_features(examples)

    label_data = RasaModelData()
    label_data.add_features(LABEL, SEQUENCE, sequence_features)
    label_data.add_features(LABEL, SENTENCE, sentence_features)
    sentence_missing = label_data.does_feature_not_exist(LABEL, SENTENCE)
    if sentence_missing and label_data.does_feature_not_exist(LABEL, SEQUENCE):
        raise ValueError(
            "No label features are present. Please check your configuration file."
        )

    # a trailing axis of size 1 is added so dynamic sequences are tracked
    # correctly
    ids = np.array([idx for idx, _ in indexed_examples])
    label_data.add_features(
        LABEL_KEY,
        LABEL_SUB_KEY,
        [FeatureArray(np.expand_dims(ids, -1), number_of_dimensions=2)],
    )
    label_data.add_lengths(LABEL, SEQUENCE_LENGTH, LABEL, SEQUENCE)
    return label_data
684
+
685
def _use_default_label_features(self, label_ids: np.ndarray) -> List[FeatureArray]:
    """Pick the stored sentence-level label features for the given label ids.

    Returns an empty list if ``self._label_data`` has not been created yet.
    """
    if self._label_data is None:
        return []

    all_label_features = self._label_data.get(LABEL, SENTENCE)[0]
    selected = np.array([all_label_features[i] for i in label_ids])
    return [
        FeatureArray(
            selected, number_of_dimensions=all_label_features.number_of_dimensions
        )
    ]
697
+
698
def _create_model_data(
    self,
    training_data: List[Message],
    label_id_dict: Optional[Dict[Text, int]] = None,
    label_attribute: Optional[Text] = None,
    training: bool = True,
) -> RasaModelData:
    """Prepare data for training and create a RasaModelData object.

    Args:
        training_data: Messages to featurize. At prediction time this is a
            single-message list (see ``_predict``).
        label_id_dict: Label name -> label index mapping. It is forwarded to
            ``_add_label_features`` to look up label ids during training.
        label_attribute: Message attribute holding the label (``INTENT`` when
            called from ``preprocess_train_data``), or None.
        training: If True, label and entity attributes are featurized too and
            messages without ``label_attribute`` are dropped.

    Returns:
        The sorted model data. An empty ``RasaModelData`` is returned if no
        message has TEXT features from the configured featurizers.
    """
    from rasa.utils.tensorflow import model_data_utils

    attributes_to_consider = [TEXT]
    if training and self.component_config[INTENT_CLASSIFICATION]:
        # we don't have any intent labels during prediction, just add them during
        # training
        attributes_to_consider.append(label_attribute)
    if (
        training
        and self.component_config[ENTITY_RECOGNITION]
        and self._entity_tag_specs
    ):
        # Add entities as labels only during training and only if there was
        # training data added for entities with DIET configured to predict entities.
        attributes_to_consider.append(ENTITIES)

    if training and label_attribute is not None:
        # only use those training examples that have the label_attribute set
        # during training
        training_data = [
            example for example in training_data if label_attribute in example.data
        ]

    # drop messages that have no TEXT features from the configured featurizers
    training_data = [
        message
        for message in training_data
        if message.features_present(
            attribute=TEXT, featurizers=self.component_config.get(FEATURIZERS)
        )
    ]

    if not training_data:
        # no training data are present to train
        return RasaModelData()

    (
        features_for_examples,
        sparse_feature_sizes,
    ) = model_data_utils.featurize_training_examples(
        training_data,
        attributes_to_consider,
        entity_tag_specs=self._entity_tag_specs,
        featurizers=self.component_config[FEATURIZERS],
        bilou_tagging=self.component_config[BILOU_FLAG],
    )
    attribute_data, _ = model_data_utils.convert_to_data_format(
        features_for_examples, consider_dialogue_dimension=False
    )

    model_data = RasaModelData(
        label_key=self.label_key, label_sub_key=self.label_sub_key
    )
    model_data.add_data(attribute_data)
    model_data.add_lengths(TEXT, SEQUENCE_LENGTH, TEXT, SEQUENCE)
    # Current implementation doesn't yet account for updating sparse
    # feature sizes of label attributes. That's why we remove them.
    sparse_feature_sizes = self._remove_label_sparse_feature_sizes(
        sparse_feature_sizes=sparse_feature_sizes, label_attribute=label_attribute
    )
    model_data.add_sparse_feature_sizes(sparse_feature_sizes)

    self._add_label_features(
        model_data, training_data, label_attribute, label_id_dict, training
    )

    # make sure all keys are in the same order during training and prediction
    # as we rely on the order of key and sub-key when constructing the actual
    # tensors from the model data
    model_data.sort()

    return model_data
777
+
778
+ @staticmethod
779
+ def _remove_label_sparse_feature_sizes(
780
+ sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
781
+ label_attribute: Optional[Text] = None,
782
+ ) -> Dict[Text, Dict[Text, List[int]]]:
783
+ if label_attribute in sparse_feature_sizes:
784
+ del sparse_feature_sizes[label_attribute]
785
+ return sparse_feature_sizes
786
+
787
def _add_label_features(
    self,
    model_data: RasaModelData,
    training_data: List[Message],
    label_attribute: Optional[Text],
    label_id_dict: Optional[Dict[Text, int]],
    training: bool = True,
) -> None:
    """Add label ids and label features to ``model_data`` in place.

    Args:
        model_data: Model data to extend.
        training_data: Messages the model data was built from.
        label_attribute: Attribute holding the label. It is None at prediction
            time, when ``_create_model_data`` is called without it.
        label_id_dict: Label name -> index mapping. It is only read when
            training with intent classification enabled.
        training: Whether label ids should be added.
    """
    label_ids = []
    if training and self.component_config[INTENT_CLASSIFICATION]:
        for example in training_data:
            # NOTE(review): examples with a falsy label value are skipped here,
            # which would misalign label_ids with the examples — confirm such
            # values cannot occur.
            if example.get(label_attribute):
                label_ids.append(label_id_dict[example.get(label_attribute)])
        # explicitly add last dimension to label_ids
        # to track correctly dynamic sequences
        model_data.add_features(
            LABEL_KEY,
            LABEL_SUB_KEY,
            [
                FeatureArray(
                    np.expand_dims(label_ids, -1),
                    number_of_dimensions=2,
                )
            ],
        )

    if (
        label_attribute
        and model_data.does_feature_not_exist(label_attribute, SENTENCE)
        and model_data.does_feature_not_exist(label_attribute, SEQUENCE)
    ):
        # no label features are present, get default features from _label_data
        model_data.add_features(
            LABEL, SENTENCE, self._use_default_label_features(np.array(label_ids))
        )

    # as label_attribute can have different values, e.g. INTENT or RESPONSE,
    # copy over the features to the LABEL key to make
    # it easier to access the label features inside the model itself
    model_data.update_key(label_attribute, SENTENCE, LABEL, SENTENCE)
    model_data.update_key(label_attribute, SEQUENCE, LABEL, SEQUENCE)
    model_data.update_key(label_attribute, MASK, LABEL, MASK)

    model_data.add_lengths(LABEL, SEQUENCE_LENGTH, LABEL, SEQUENCE)
831
+
832
# train helpers
def preprocess_train_data(self, training_data: TrainingData) -> RasaModelData:
    """Turn NLU training data into model data ready for training.

    The steps are:
    - Apply the BILOU schema when both BILOU tagging and entity recognition
      are enabled.
    - Build the intent label index mapping (stored inverted on the
      component), the label data, and the entity tag specs.
    - Featurize the NLU examples.
    - Check text/label feature dimensions.

    Returns:
        The model data, or an empty ``RasaModelData`` when the training data
        contains no intent labels.
    """
    if self.component_config[BILOU_FLAG] and self.component_config[ENTITY_RECOGNITION]:
        bilou_utils.apply_bilou_schema(training_data)

    label_to_index = self._label_id_index_mapping(training_data, attribute=INTENT)
    if not label_to_index:
        # nothing to learn from: there are no labels in the data
        return RasaModelData()

    self.index_label_id_mapping = self._invert_mapping(label_to_index)
    self._label_data = self._create_label_data(
        training_data, label_to_index, attribute=INTENT
    )
    self._entity_tag_specs = self._create_entity_tag_specs(training_data)

    if self.component_config[INTENT_CLASSIFICATION]:
        label_attribute = INTENT
    else:
        label_attribute = None

    model_data = self._create_model_data(
        training_data.nlu_examples,
        label_to_index,
        label_attribute=label_attribute,
    )
    self._check_input_dimension_consistency(model_data)
    return model_data
872
+
873
@staticmethod
def _check_enough_labels(model_data: RasaModelData) -> bool:
    """Return True when the label ids contain at least two distinct values."""
    unique_label_ids = np.unique(model_data.get(LABEL_KEY, LABEL_SUB_KEY))
    return len(unique_label_ids) > 1
876
+
877
def train(self, training_data: TrainingData) -> Resource:
    """Train the embedding intent classifier on a data set.

    Training is skipped (with a log message) and the resource is returned
    unchanged in two cases:
    - Preprocessing yields no model data.
    - Intent classification is enabled but fewer than two distinct intents
      are present.

    Otherwise the steps are:
    - Create and compile a new model, or adapt the existing one in finetune
      mode.
    - Fit the model.
    - Persist the result.

    Args:
        training_data: The NLU training data.

    Returns:
        The resource this component is persisted to.

    Raises:
        InvalidParameterException: If finetune mode is requested without a
            model.
        ModelNotFound: If no model is available in finetune mode.
    """
    model_data = self.preprocess_train_data(training_data)
    if model_data.is_empty():
        logger.debug(
            f"Cannot train '{self.__class__.__name__}'. No data was provided. "
            f"Skipping training of the classifier."
        )
        return self._resource

    if not self.model and self.finetune_mode:
        raise rasa.shared.exceptions.InvalidParameterException(
            f"{self.__class__.__name__} was instantiated "
            f"with `model=None` and `finetune_mode=True`. "
            f"This is not a valid combination as the component "
            f"needs an already instantiated and trained model "
            f"to continue training in finetune mode."
        )

    if self.component_config.get(INTENT_CLASSIFICATION):
        if not self._check_enough_labels(model_data):
            logger.error(
                f"Cannot train '{self.__class__.__name__}'. "
                f"Need at least 2 different intent classes. "
                f"Skipping training of classifier."
            )
            return self._resource
    if self.component_config.get(ENTITY_RECOGNITION):
        self.check_correct_entity_annotations(training_data)

    # keep one example for persisting and loading
    self._data_example = model_data.first_data_example()

    if not self.finetune_mode:
        # No pre-trained model to load from. Create a new instance of the model.
        self.model = self._instantiate_model_class(model_data)
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(
                self.component_config[LEARNING_RATE]
            ),
            run_eagerly=self.component_config[RUN_EAGERLY],
        )
    else:
        # NOTE(review): this looks unreachable — the
        # InvalidParameterException check above already rejects a missing
        # model in finetune mode. Kept as a defensive guard.
        if self.model is None:
            raise ModelNotFound("Model could not be found. ")

        # adapt the loaded model to the (possibly changed) sparse feature sizes
        self.model.adjust_for_incremental_training(
            data_example=self._data_example,
            new_sparse_feature_sizes=model_data.get_sparse_feature_sizes(),
            old_sparse_feature_sizes=self._sparse_feature_sizes,
        )
    # remember the current sizes so a later finetuning run can compare them
    self._sparse_feature_sizes = model_data.get_sparse_feature_sizes()

    data_generator, validation_data_generator = train_utils.create_data_generators(
        model_data,
        self.component_config[BATCH_SIZES],
        self.component_config[EPOCHS],
        self.component_config[BATCH_STRATEGY],
        self.component_config[EVAL_NUM_EXAMPLES],
        self.component_config[RANDOM_SEED],
        drop_small_last_batch=self.component_config[DROP_SMALL_LAST_BATCH],
    )
    callbacks = train_utils.create_common_callbacks(
        self.component_config[EPOCHS],
        self.component_config[TENSORBOARD_LOG_DIR],
        self.component_config[TENSORBOARD_LOG_LEVEL],
        self.tmp_checkpoint_dir,
    )

    self.model.fit(
        data_generator,
        epochs=self.component_config[EPOCHS],
        validation_data=validation_data_generator,
        validation_freq=self.component_config[EVAL_NUM_EPOCHS],
        callbacks=callbacks,
        verbose=False,
        shuffle=False,  # we use custom shuffle inside data generator
    )

    self.persist()

    return self._resource
959
+
960
# process helpers
def _predict(
    self, message: Message
) -> Optional[Dict[Text, Union[tf.Tensor, Dict[Text, tf.Tensor]]]]:
    """Run the trained model on a single message.

    Returns:
        The output of ``run_inference``, or None if there is no trained model
        or the message produces empty model data (e.g. it has no TEXT
        features).
    """
    if self.model is not None:
        # wrap the single message into model data, i.e. a batch of one
        model_data = self._create_model_data([message], training=False)
        if model_data.is_empty():
            return None
        return self.model.run_inference(model_data)

    logger.debug(
        f"There is no trained model for '{self.__class__.__name__}': The "
        f"component is either not trained or didn't receive enough training "
        f"data."
    )
    return None
977
+
978
def _predict_label(
    self, predict_out: Optional[Dict[Text, tf.Tensor]]
) -> Tuple[Dict[Text, Any], List[Dict[Text, Any]]]:
    """Derive the top intent and the intent ranking from the model output.

    Args:
        predict_out: Output of ``_predict``. Its ``"i_scores"`` entry holds
            the label similarity scores.

    Returns:
        ``(label, ranking)``. ``label`` is ``{"name", "confidence"}`` for the
        best label, and ``ranking`` lists the ranked labels in the same
        format. If there is no output or no scores, the result is
        ``({"name": None, "confidence": 0.0}, [])``.
    """
    no_prediction: Dict[Text, Any] = {"name": None, "confidence": 0.0}
    if predict_out is None:
        return no_prediction, []

    # the scores come back as a matrix; work on a flat vector
    scores = predict_out["i_scores"].flatten()
    if scores.size == 0:
        return no_prediction, []

    ranking_length = self.component_config[RANKING_LENGTH]
    # renormalise only if requested AND confidences are softmax-based
    should_renormalize = (
        self.component_config[RENORMALIZE_CONFIDENCES]
        and self.component_config[MODEL_CONFIDENCE] == SOFTMAX
    )
    ranked_indices, scores = train_utils.rank_and_mask(
        scores, ranking_length=ranking_length, renormalize=should_renormalize
    )

    confidences: List[float] = scores.tolist()  # numpy floats -> Python floats
    best = ranked_indices[0]
    label = {
        "name": self.index_label_id_mapping[best],
        "confidence": confidences[best],
    }
    label_ranking = [
        {"name": self.index_label_id_mapping[idx], "confidence": confidences[idx]}
        for idx in ranked_indices
    ]
    return label, label_ranking
1020
+
1021
def _predict_entities(
    self, predict_out: Optional[Dict[Text, tf.Tensor]], message: Message
) -> List[Dict]:
    """Decode the predicted entity tags and merge them with existing entities.

    Args:
        predict_out: Output of ``_predict``, or None if nothing was predicted.
        message: The message being processed. Its TEXT and tokens are used to
            build the entities.

    Returns:
        The entities already present on ``message`` followed by the newly
        predicted ones, which are tagged with this extractor's name. If there
        is no prediction output, only the existing entities are returned.
        This lets ``process`` store the result without erasing entities found
        by earlier components.
    """
    existing_entities = message.get(ENTITIES, [])
    if predict_out is None:
        return existing_entities

    predicted_tags, confidence_values = train_utils.entity_label_to_tags(
        predict_out, self._entity_tag_specs, self.component_config[BILOU_FLAG]
    )

    entities = self.convert_predictions_into_entities(
        message.get(TEXT),
        message.get(TOKENS_NAMES[TEXT], []),
        predicted_tags,
        self.split_entities_config,
        confidence_values,
    )

    entities = self.add_extractor_name(entities)
    return existing_entities + entities
1043
+
1044
def process(self, messages: List[Message]) -> List[Message]:
    """Attach intent, entity and diagnostic predictions to each message.

    Depending on the config:
    - With intent classification enabled, INTENT and ``"intent_ranking"``
      are set.
    - With entity recognition enabled, ENTITIES are set.

    Diagnostic data is added when the execution context requests it and the
    model produced output.

    Returns:
        The same list of messages that was passed in.
    """
    for message in messages:
        prediction = self._predict(message)

        if self.component_config[INTENT_CLASSIFICATION]:
            top_intent, ranking = self._predict_label(prediction)
            message.set(INTENT, top_intent, add_to_output=True)
            message.set("intent_ranking", ranking, add_to_output=True)

        if self.component_config[ENTITY_RECOGNITION]:
            message.set(
                ENTITIES,
                self._predict_entities(prediction, message),
                add_to_output=True,
            )

        wants_diagnostics = self._execution_context.should_add_diagnostic_data
        if prediction and wants_diagnostics:
            message.add_diagnostic_data(
                self._execution_context.node_name, prediction.get(DIAGNOSTIC_DATA)
            )

    return messages
1066
+
1067
def persist(self) -> None:
    """Persist the trained model and its artifacts to the model storage.

    Files are written into ``self._model_storage`` under ``self._resource``.
    Nothing is written when no model exists. Every file name is prefixed with
    the class name:
    - the TF model (``.tf_model``);
    - the data example and label data (``.st`` plus ``_metadata.json``);
    - the sparse feature sizes, the index -> label mapping and the entity tag
      specs (JSON).

    When checkpointing is enabled and a checkpoint directory is set, the
    checkpointed weights are loaded into the model before saving. An empty
    ``.from_checkpoint.pkl`` marker file is also written in that case.
    """
    if self.model is None:
        return None

    with self._model_storage.write_to(self._resource) as model_path:
        file_name = self.__class__.__name__
        tf_model_file = model_path / f"{file_name}.tf_model"

        rasa.shared.utils.io.create_directory_for_file(tf_model_file)

        if self.component_config[CHECKPOINT_MODEL] and self.tmp_checkpoint_dir:
            # restore the checkpointed weights so they (not the final-epoch
            # weights) are the ones saved below
            self.model.load_weights(self.tmp_checkpoint_dir / "checkpoint.tf_model")
            # Save an empty file to flag that this model has been
            # produced using checkpointing
            checkpoint_marker = model_path / f"{file_name}.from_checkpoint.pkl"
            checkpoint_marker.touch()

        self.model.save(str(tf_model_file))

        # save data example
        serialize_nested_feature_arrays(
            self._data_example,
            model_path / f"{file_name}.data_example.st",
            model_path / f"{file_name}.data_example_metadata.json",
        )
        # save label data
        serialize_nested_feature_arrays(
            dict(self._label_data.data) if self._label_data is not None else {},
            model_path / f"{file_name}.label_data.st",
            model_path / f"{file_name}.label_data_metadata.json",
        )

        rasa.shared.utils.io.dump_obj_as_json_to_file(
            model_path / f"{file_name}.sparse_feature_sizes.json",
            self._sparse_feature_sizes,
        )
        # int keys of this mapping are converted back with int() on load
        rasa.shared.utils.io.dump_obj_as_json_to_file(
            model_path / f"{file_name}.index_label_id_mapping.json",
            self.index_label_id_mapping,
        )

        entity_tag_specs = (
            [tag_spec._asdict() for tag_spec in self._entity_tag_specs]
            if self._entity_tag_specs
            else []
        )
        rasa.shared.utils.io.dump_obj_as_json_to_file(
            model_path / f"{file_name}.entity_tag_specs.json", entity_tag_specs
        )
1117
+
1118
@classmethod
def load(
    cls: Type[DIETClassifierT],
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
    **kwargs: Any,
) -> DIETClassifierT:
    """Loads a trained component from the storage (see parent class for full docstring).

    Args:
        config: Component configuration.
        model_storage: Storage the persisted artifacts are read from.
        resource: Resource identifying the persisted component.
        execution_context: Context of the current graph run.
        **kwargs: Unused here.

    Returns:
        The loaded component. If loading raises ``ValueError``, a fresh,
        untrained instance is returned instead.
    """
    try:
        with model_storage.read_from(resource) as model_path:
            return cls._load(
                model_path, config, model_storage, resource, execution_context
            )
    except ValueError:
        # NOTE(review): this also swallows ValueErrors raised while restoring
        # the model inside ``_load``, not only a missing resource.
        # ``cls`` is already the class, so ``cls.__name__`` is its name;
        # ``cls.__class__.__name__`` would give the metaclass name instead.
        logger.debug(
            f"Failed to load {cls.__name__} from model storage. Resource "
            f"'{resource.name}' doesn't exist."
        )
        return cls(config, model_storage, resource, execution_context)
1139
+
1140
@classmethod
def _load(
    cls: Type[DIETClassifierT],
    model_path: Path,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
) -> DIETClassifierT:
    """Build a component instance from the artifacts stored in ``model_path``.

    ``config`` is passed through ``train_utils.update_confidence_type`` and
    then ``train_utils.update_similarity_type`` before the model is restored.
    The model is restored in finetune mode when the execution context is
    finetuning.
    """
    artifacts = cls._load_from_files(model_path)
    (
        index_label_id_mapping,
        entity_tag_specs,
        label_data,
        data_example,
        sparse_feature_sizes,
    ) = artifacts

    updated_config = train_utils.update_similarity_type(
        train_utils.update_confidence_type(config)
    )

    restored_model = cls._load_model(
        entity_tag_specs,
        label_data,
        updated_config,
        data_example,
        model_path,
        finetune_mode=execution_context.is_finetuning,
    )

    return cls(
        config=updated_config,
        model_storage=model_storage,
        resource=resource,
        execution_context=execution_context,
        index_label_id_mapping=index_label_id_mapping,
        entity_tag_specs=entity_tag_specs,
        model=restored_model,
        sparse_feature_sizes=sparse_feature_sizes,
    )
1180
+
1181
@classmethod
def _load_from_files(
    cls, model_path: Path
) -> Tuple[
    Dict[int, Text],
    List[EntityTagSpec],
    RasaModelData,
    Dict[Text, Dict[Text, List[FeatureArray]]],
    Dict[Text, Dict[Text, List[int]]],
]:
    """Read the artifacts written by ``persist`` back from ``model_path``.

    Returns:
        ``(index_label_id_mapping, entity_tag_specs, label_data, data_example,
        sparse_feature_sizes)``. Integer keys/values that come back from JSON
        (label indices, tag ids) are converted to ``int``.
    """
    file_name = cls.__name__

    def _artifact(suffix: Text) -> Path:
        return model_path / f"{file_name}.{suffix}"

    data_example = deserialize_nested_feature_arrays(
        str(_artifact("data_example.st")),
        str(_artifact("data_example_metadata.json")),
    )
    label_data = RasaModelData(
        data=deserialize_nested_feature_arrays(
            str(_artifact("label_data.st")),
            str(_artifact("label_data_metadata.json")),
        )
    )

    sparse_feature_sizes = rasa.shared.utils.io.read_json_file(
        _artifact("sparse_feature_sizes.json")
    )
    raw_index_label_id_mapping = rasa.shared.utils.io.read_json_file(
        _artifact("index_label_id_mapping.json")
    )
    raw_entity_tag_specs = rasa.shared.utils.io.read_json_file(
        _artifact("entity_tag_specs.json")
    )

    entity_tag_specs = [
        EntityTagSpec(
            tag_name=spec["tag_name"],
            ids_to_tags={int(k): v for k, v in spec["ids_to_tags"].items()},
            tags_to_ids={k: int(v) for k, v in spec["tags_to_ids"].items()},
            num_tags=spec["num_tags"],
        )
        for spec in raw_entity_tag_specs
    ]
    index_label_id_mapping = {
        int(k): v for k, v in raw_index_label_id_mapping.items()
    }

    return (
        index_label_id_mapping,
        entity_tag_specs,
        label_data,
        data_example,
        sparse_feature_sizes,
    )
1239
+
1240
@classmethod
def _load_model(
    cls,
    entity_tag_specs: List[EntityTagSpec],
    label_data: RasaModelData,
    config: Dict[Text, Any],
    data_example: Dict[Text, Dict[Text, List[FeatureArray]]],
    model_path: Path,
    finetune_mode: bool = False,
) -> "RasaModel":
    """Restore the TF model saved as ``<ClassName>.tf_model`` in ``model_path``.

    Label keys are set on the data example only if intent classification is
    enabled in ``config``.
    """
    with_labels = config[INTENT_CLASSIFICATION]
    model_data_example = RasaModelData(
        label_key=LABEL_KEY if with_labels else None,
        label_sub_key=LABEL_SUB_KEY if with_labels else None,
        data=data_example,
    )

    return cls._load_model_class(
        model_path / f"{cls.__name__}.tf_model",
        model_data_example,
        label_data,
        entity_tag_specs,
        config,
        finetune_mode=finetune_mode,
    )
1270
+
1271
@classmethod
def _load_model_class(
    cls,
    tf_model_file: Union[Text, Path],
    model_data_example: RasaModelData,
    label_data: RasaModelData,
    entity_tag_specs: List[EntityTagSpec],
    config: Dict[Text, Any],
    finetune_mode: bool,
) -> "RasaModel":
    """Load the model class returned by ``model_class()`` from ``tf_model_file``.

    Args:
        tf_model_file: Location of the saved TF model. ``_load_model`` passes
            a ``Path``, so both ``str`` and ``Path`` are accepted.
        model_data_example: Example of the full model data.
        label_data: Label model data handed to the model.
        entity_tag_specs: Entity tag specifications handed to the model.
        config: Component configuration. A deep copy is given to the model so
            the caller's dict is not shared with it.
        finetune_mode: Whether the model is loaded to continue training.

    Returns:
        The loaded model instance.
    """
    # the prediction data example keeps only features whose key contains TEXT
    predict_data_example = RasaModelData(
        label_key=model_data_example.label_key,
        data={
            feature_name: features
            for feature_name, features in model_data_example.items()
            if TEXT in feature_name
        },
    )

    return cls.model_class().load(
        tf_model_file,
        model_data_example,
        predict_data_example,
        data_signature=model_data_example.get_signature(),
        label_data=label_data,
        entity_tag_specs=entity_tag_specs,
        config=copy.deepcopy(config),
        finetune_mode=finetune_mode,
    )
1300
+
1301
def _instantiate_model_class(self, model_data: RasaModelData) -> "RasaModel":
    """Create a new, untrained model matching the signature of ``model_data``."""
    model_type = self.model_class()
    signature = model_data.get_signature()
    return model_type(
        data_signature=signature,
        label_data=self._label_data,
        entity_tag_specs=self._entity_tag_specs,
        config=self.component_config,
    )
1308
+
1309
+
1310
+ class DIET(TransformerRasaModel):
1311
def __init__(
    self,
    data_signature: Dict[Text, Dict[Text, List[FeatureSignature]]],
    label_data: RasaModelData,
    entity_tag_specs: Optional[List[EntityTagSpec]],
    config: Dict[Text, Any],
) -> None:
    """Initialize the DIET model.

    Args:
        data_signature: Feature signatures keyed by attribute and feature type
            (e.g. ``data_signature[TEXT][SENTENCE]``).
        label_data: Label model data, passed to the parent class.
        entity_tag_specs: Entity tag specifications. None is treated as an
            empty list (see ``_ordered_tag_specs``).
        config: Model configuration.
    """
    # NOTE(review): the original comment said the entity tag specs must be
    # created "before calling super", but they are set right AFTER
    # super().__init__() and BEFORE _prepare_layers(). Presumably the real
    # requirement is that they exist before the layers are built — confirm.
    super().__init__("DIET", config, data_signature, label_data)
    self._entity_tag_specs = self._ordered_tag_specs(entity_tag_specs)

    # signatures of the features whose key contains TEXT only
    self.predict_data_signature = {
        feature_name: features
        for feature_name, features in data_signature.items()
        if TEXT in feature_name
    }

    # tf training
    self._create_metrics()
    self._update_metrics_to_log()

    # needed for efficient prediction
    self.all_labels_embed: Optional[tf.Tensor] = None

    self._prepare_layers()
1337
+
1338
@staticmethod
def _ordered_tag_specs(
    entity_tag_specs: Optional[List[EntityTagSpec]],
) -> List[EntityTagSpec]:
    """Returns the entity tag specs sorted into CRF layer order.

    The order is entity type, then role, then group. The type CRF must come
    first because its tags are fed as extra input to the role and group CRFs.
    Specs whose tag name is none of these three are dropped, and ``None``
    yields an empty list.
    """
    if entity_tag_specs is None:
        return []

    return [
        spec
        for wanted in (
            ENTITY_ATTRIBUTE_TYPE,
            ENTITY_ATTRIBUTE_ROLE,
            ENTITY_ATTRIBUTE_GROUP,
        )
        for spec in entity_tag_specs
        if wanted == spec.tag_name
    ]
1360
+
1361
def _check_data(self) -> None:
    """Validates the data signature against the model configuration.

    Raises:
        InvalidConfigException: If no text features are present, or if intent
            classification is enabled but no label features are present.
        ValueError: If hidden layers are shared but text and label features
            have different sentence- or sequence-level signatures.

    Side effect: if entity recognition is enabled but the signature contains
    no entity type tags, ``self.config[ENTITY_RECOGNITION]`` is set to False.
    """
    signature = self.data_signature

    if TEXT not in signature:
        raise InvalidConfigException(
            f"No text features specified. "
            f"Cannot train '{self.__class__.__name__}' model."
        )

    if self.config[INTENT_CLASSIFICATION]:
        if LABEL not in signature:
            raise InvalidConfigException(
                f"No label features specified. "
                f"Cannot train '{self.__class__.__name__}' model."
            )

        if self.config[SHARE_HIDDEN_LAYERS]:
            text_signature = signature[TEXT]
            label_signature = signature[LABEL]
            # Shared weights require identical signatures for every feature
            # type that both text and labels provide.
            mismatch = any(
                feature_type in text_signature
                and feature_type in label_signature
                and text_signature[feature_type] != label_signature[feature_type]
                for feature_type in (SENTENCE, SEQUENCE)
            )
            if mismatch:
                raise ValueError(
                    "If hidden layer weights are shared, data signatures "
                    "for text_features and label_features must coincide."
                )

    if not self.config[ENTITY_RECOGNITION]:
        return
    if ENTITIES in signature and ENTITY_ATTRIBUTE_TYPE in signature[ENTITIES]:
        return

    logger.debug(
        f"You specified '{self.__class__.__name__}' to train entities, but "
        f"no entities are present in the training data. Skipping training of "
        f"entities."
    )
    self.config[ENTITY_RECOGNITION] = False
1410
+
1411
def _create_metrics(self) -> None:
    """Creates the running-mean Keras metrics tracked during training.

    ``self.metrics`` lists metrics in creation order. All loss metrics are
    therefore created before the accuracy/F1 metrics, so losses are reported
    first.
    """
    mean = tf.keras.metrics.Mean

    # losses: masked LM, intent, entity type, entity group, entity role
    self.mask_loss = mean(name="m_loss")
    self.intent_loss = mean(name="i_loss")
    self.entity_loss = mean(name="e_loss")
    self.entity_group_loss = mean(name="g_loss")
    self.entity_role_loss = mean(name="r_loss")

    # scores, in the same component order as the losses above
    self.mask_acc = mean(name="m_acc")
    self.intent_acc = mean(name="i_acc")
    self.entity_f1 = mean(name="e_f1")
    self.entity_group_f1 = mean(name="g_f1")
    self.entity_role_f1 = mean(name="r_f1")
1425
+
1426
def _update_metrics_to_log(self) -> None:
    """Appends the metrics to display during training to ``metrics_to_log``.

    For each enabled component its accuracy/F1 metric is added. The matching
    loss metric is added only when the ``rasa`` logger's own level is exactly
    DEBUG. Finally the selection is logged via ``_log_metric_info``.
    """
    show_losses = logging.getLogger("rasa").level == logging.DEBUG

    def register(score_name: Text, loss_name: Text) -> None:
        # score first, optional loss right after it
        self.metrics_to_log.append(score_name)
        if show_losses:
            self.metrics_to_log.append(loss_name)

    if self.config[MASKED_LM]:
        register("m_acc", "m_loss")
    if self.config[INTENT_CLASSIFICATION]:
        register("i_acc", "i_loss")
    if self.config[ENTITY_RECOGNITION]:
        for spec in self._entity_tag_specs:
            if spec.num_tags == 0:
                continue
            initial = spec.tag_name[0]
            register(f"{initial}_f1", f"{initial}_loss")

    self._log_metric_info()
1446
+
1447
def _log_metric_info(self) -> None:
    """Logs, at DEBUG level, a readable description of each logged metric.

    Metric names look like ``<prefix>_<kind>`` (for example ``i_acc``). The
    one-letter prefix is expanded into the component it belongs to.
    """
    prefix_to_component = {
        "t": "total",
        "i": "intent",
        "e": "entity",
        "m": "mask",
        "r": "role",
        "g": "group",
    }
    logger.debug("Following metrics will be logged during training: ")
    for metric in self.metrics_to_log:
        pieces = metric.split("_")
        prefix, kind = pieces[0], pieces[1]
        name = f"{prefix_to_component[prefix]} {kind}"
        logger.debug(f" {metric} ({name})")
1461
+
1462
def _prepare_layers(self) -> None:
    """Creates every layer the model uses; creation order is kept fixed.

    * Text: a ``RasaSequenceLayer`` that combines the feature types and
      encodes them, plus masked-LM loss layers when MASKED_LM is enabled.
    * Labels (INTENT_CLASSIFICATION only): a feature-combining layer, an FFNN
      and the label classification layers.
      - The label layers are keyed under TEXT instead of LABEL when
        SHARE_HIDDEN_LAYERS is enabled.
      - Sparse/dense input dropout is switched off for label features; only
        the FFNN over the combined label features uses DROP_RATE.
      - Unlike text, labels get no transformer and no masked-LM.
    * Entities (ENTITY_RECOGNITION only): the entity recognition layers.
    """
    self.text_name = TEXT
    text_encoder = rasa_layers.RasaSequenceLayer(
        self.text_name, self.data_signature[self.text_name], self.config
    )
    self._tf_layers[f"sequence_layer.{self.text_name}"] = text_encoder
    if self.config[MASKED_LM]:
        self._prepare_mask_lm_loss(self.text_name)

    if self.config[INTENT_CLASSIFICATION]:
        self.label_name = TEXT if self.config[SHARE_HIDDEN_LAYERS] else LABEL

        # same config, but without input dropout on the label features
        label_config = {
            **self.config,
            SPARSE_INPUT_DROPOUT: False,
            DENSE_INPUT_DROPOUT: False,
        }
        label_combiner = rasa_layers.RasaFeatureCombiningLayer(
            self.label_name, self.label_signature[self.label_name], label_config
        )
        self._tf_layers[f"feature_combining_layer.{self.label_name}"] = (
            label_combiner
        )

        label_hidden_sizes = self.config[HIDDEN_LAYERS_SIZES][self.label_name]
        self._prepare_ffnn_layer(
            self.label_name, label_hidden_sizes, self.config[DROP_RATE]
        )

        self._prepare_label_classification_layers(predictor_attribute=TEXT)

    if self.config[ENTITY_RECOGNITION]:
        self._prepare_entity_recognition_layers()
1504
+
1505
def _prepare_mask_lm_loss(self, name: Text) -> None:
    """Creates the layers for the masked-language-modelling loss of ``name``.

    Two embedding layers are created:

    * ``{name}_lm_mask`` embeds the predictions at the masked positions.
    * ``{name}_golden_token`` embeds the true tokens that were masked.

    Then a dot-product loss layer ``{name}_mask`` is created. Its loss scaling
    is disabled so this auxiliary loss does not overpower the other losses.
    """
    for embed_name in (f"{name}_lm_mask", f"{name}_golden_token"):
        self._prepare_embed_layers(embed_name)

    self._prepare_dot_product_loss(f"{name}_mask", scale_loss=False)
1515
+
1516
def _create_bow(
    self,
    sequence_features: List[Union[tf.Tensor, tf.SparseTensor]],
    sentence_features: List[Union[tf.Tensor, tf.SparseTensor]],
    sequence_feature_lengths: tf.Tensor,
    name: Text,
) -> tf.Tensor:
    """Encodes the features of attribute ``name`` as a bag-of-words vector.

    The processing runs in three steps:

    1. The ``feature_combining_layer.{name}`` layer merges the sequence- and
       sentence-level features.
    2. The result is summed over axis 1 (the sequence dimension).
    3. The sum is passed through the ``ffnn.{name}`` layer.

    Args:
        sequence_features: Sequence-level features of the attribute.
        sentence_features: Sentence-level features of the attribute.
        sequence_feature_lengths: Lengths of the sequence-level features.
        name: Attribute name used to select the layers.

    Returns:
        The output of the ``ffnn.{name}`` layer.
    """
    combiner = self._tf_layers[f"feature_combining_layer.{name}"]
    combined, _mask = combiner(
        (sequence_features, sentence_features, sequence_feature_lengths),
        training=self._training,
    )

    bag_of_words = tf.reduce_sum(combined, axis=1)

    feed_forward = self._tf_layers[f"ffnn.{name}"]
    return feed_forward(bag_of_words, self._training)
1532
+
1533
def _create_all_labels(self) -> Tuple[tf.Tensor, tf.Tensor]:
    """Embeds every label stored in ``self.tf_label_data``.

    Features are always read from the LABEL attribute. The encoding layers are
    selected by ``self.label_name``.

    Returns:
        A pair ``(label_ids, label_embeddings)``. The embeddings are the output
        of the ``embed.{LABEL}`` layer applied to each label's bag-of-words
        encoding.
    """
    label_data = self.tf_label_data
    label_ids = label_data[LABEL_KEY][LABEL_SUB_KEY][0]

    lengths = self._get_sequence_feature_lengths(label_data, LABEL)
    label_features = label_data[LABEL]
    encoded = self._create_bow(
        label_features[SEQUENCE],
        label_features[SENTENCE],
        lengths,
        self.label_name,
    )

    return label_ids, self._tf_layers[f"embed.{LABEL}"](encoded)
1549
+
1550
def _mask_loss(
    self,
    outputs: tf.Tensor,
    inputs: tf.Tensor,
    seq_ids: tf.Tensor,
    mlm_mask_boolean: tf.Tensor,
    name: Text,
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Computes the masked-language-modelling loss and accuracy.

    Args:
        outputs: Encoder outputs for the token sequence.
        inputs: Features of the original tokens; those at masked positions are
            embedded as the "golden" tokens.
        seq_ids: Token ids; those at masked positions serve as the labels.
        mlm_mask_boolean: Boolean mask of masked positions. It has rank 3 with
            a trailing axis of size 1, as required by the ``scatter_nd`` index
            and the ``squeeze`` below.
        name: Attribute name selecting the ``embed.{name}_*`` and
            ``loss.{name}_mask`` layers created by ``_prepare_mask_lm_loss``.

    Returns:
        The ``loss.{name}_mask`` layer's output, which callers unpack as
        ``(loss, accuracy)``.
    """
    # make sure there is at least one element in the mask: if nothing is
    # masked, switch on position [0, 0, 0]
    mlm_mask_boolean = tf.cond(
        tf.reduce_any(mlm_mask_boolean),
        lambda: mlm_mask_boolean,
        lambda: tf.scatter_nd([[0, 0, 0]], [True], tf.shape(mlm_mask_boolean)),
    )

    mlm_mask_boolean = tf.squeeze(mlm_mask_boolean, -1)

    # Pick elements that were masked, throwing away the batch & sequence dimension
    # and effectively switching from shape (batch_size, sequence_length, units) to
    # (num_masked_elements, units).
    outputs = tf.boolean_mask(outputs, mlm_mask_boolean)
    inputs = tf.boolean_mask(inputs, mlm_mask_boolean)
    ids = tf.boolean_mask(seq_ids, mlm_mask_boolean)

    tokens_predicted_embed = self._tf_layers[f"embed.{name}_lm_mask"](outputs)
    tokens_true_embed = self._tf_layers[f"embed.{name}_golden_token"](inputs)

    # To limit the otherwise computationally expensive loss calculation, we
    # constrain the label space in MLM (i.e. token space) to only those tokens that
    # were masked in this batch. Hence the reduced list of token embeddings
    # (tokens_true_embed) and the reduced list of labels (ids) are passed as
    # all_labels_embed and all_labels, respectively. In the future, we could be less
    # restrictive and construct a slightly bigger label space which could include
    # tokens not masked in the current batch too.
    return self._tf_layers[f"loss.{name}_mask"](
        inputs_embed=tokens_predicted_embed,
        labels_embed=tokens_true_embed,
        labels=ids,
        all_labels_embed=tokens_true_embed,
        all_labels=ids,
    )
1591
+
1592
def _calculate_label_loss(
    self, text_features: tf.Tensor, label_features: tf.Tensor, label_ids: tf.Tensor
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Computes the intent (label) loss and accuracy for a batch.

    The steps are:

    * Text and label features are projected by the ``embed.{TEXT}`` and
      ``embed.{LABEL}`` layers.
    * The ``loss.{LABEL}`` layer scores the projections.
    * The embeddings of all labels (rebuilt on every call via
      ``_create_all_labels``) serve as the candidate set.

    Args:
        text_features: Sentence vectors of the user text.
        label_features: Bag-of-words encodings of the true labels.
        label_ids: Ids of the true labels.

    Returns:
        The ``loss.{LABEL}`` layer's output, which callers unpack as
        ``(loss, accuracy)``.
    """
    all_label_ids, all_labels_embed = self._create_all_labels()

    text_embed = self._tf_layers[f"embed.{TEXT}"](text_features)
    label_embed = self._tf_layers[f"embed.{LABEL}"](label_features)

    return self._tf_layers[f"loss.{LABEL}"](
        text_embed, label_embed, label_ids, all_labels_embed, all_label_ids
    )
1603
+
1604
def batch_loss(
    self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
) -> tf.Tensor:
    """Computes the total training loss of one batch.

    The text is encoded once by the sequence layer. Then, depending on the
    config, these losses are computed and summed:

    * the masked-LM loss (only while training);
    * the intent loss;
    * the entity losses.

    Args:
        batch_in: The batch, as tensors or numpy arrays.

    Returns:
        The sum of all enabled losses.
    """
    batch = self.batch_to_model_data_format(batch_in, self.data_signature)

    seq_lengths = self._get_sequence_feature_lengths(batch, TEXT)

    encoder = self._tf_layers[f"sequence_layer.{self.text_name}"]
    text_features = batch[TEXT]
    (
        encoded_text,
        text_inputs,
        combined_mask,
        token_ids,
        mlm_mask,
        _attention_weights,
    ) = encoder(
        (text_features[SEQUENCE], text_features[SENTENCE], seq_lengths),
        training=self._training,
    )

    batch_losses = []

    # Sentence-level lengths are always 1, or effectively 0 when sentence-level
    # features are absent.
    sentence_lengths = self._get_sentence_feature_lengths(batch, TEXT)
    total_lengths = seq_lengths + sentence_lengths

    if self.config[MASKED_LM] and self._training:
        mlm_loss, mlm_acc = self._mask_loss(
            encoded_text, text_inputs, token_ids, mlm_mask, TEXT
        )
        self.mask_loss.update_state(mlm_loss)
        self.mask_acc.update_state(mlm_acc)
        batch_losses.append(mlm_loss)

    if self.config[INTENT_CLASSIFICATION]:
        intent_loss = self._batch_loss_intent(total_lengths, encoded_text, batch)
        batch_losses.append(intent_loss)

    if self.config[ENTITY_RECOGNITION]:
        batch_losses.extend(
            self._batch_loss_entities(
                combined_mask, seq_lengths, encoded_text, batch
            )
        )

    return tf.math.add_n(batch_losses)
1674
+
1675
def _batch_loss_intent(
    self,
    combined_sequence_sentence_feature_lengths_text: tf.Tensor,
    text_transformed: tf.Tensor,
    tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
) -> tf.Tensor:
    """Computes the intent loss of a batch and updates the intent metrics.

    The user text is represented by its last token, as selected via the
    combined sequence + sentence lengths. The true labels are represented by
    their bag-of-words encoding.

    Args:
        combined_sequence_sentence_feature_lengths_text: Sequence plus
            sentence feature lengths of the text.
        text_transformed: Encoder output for the text.
        tf_batch_data: The batch in model-data format.

    Returns:
        The intent loss.
    """
    text_vector = self._last_token(
        text_transformed, combined_sequence_sentence_feature_lengths_text
    )

    label_lengths = self._get_sequence_feature_lengths(tf_batch_data, LABEL)
    label_ids = tf_batch_data[LABEL_KEY][LABEL_SUB_KEY][0]

    label_features = tf_batch_data[LABEL]
    label_bow = self._create_bow(
        label_features[SEQUENCE],
        label_features[SENTENCE],
        label_lengths,
        self.label_name,
    )

    loss, acc = self._calculate_label_loss(text_vector, label_bow, label_ids)
    self._update_label_metrics(loss, acc)
    return loss
1702
+
1703
def _update_label_metrics(self, loss: tf.Tensor, acc: tf.Tensor) -> None:
    """Feeds the intent loss and accuracy into their running-mean metrics."""
    for metric, value in ((self.intent_loss, loss), (self.intent_acc, acc)):
        metric.update_state(value)
1706
+
1707
def _batch_loss_entities(
    self,
    mask_combined_sequence_sentence: tf.Tensor,
    sequence_feature_lengths: tf.Tensor,
    text_transformed: tf.Tensor,
    tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
) -> List[tf.Tensor]:
    """Computes one CRF loss per trained entity tag and updates its metrics.

    Tag specs with zero tags are skipped. The one-hot entity *type* tags are
    passed as extra input to the CRFs that follow (role and group, per
    ``_ordered_tag_specs``).

    Returns:
        The losses, in tag spec order.
    """
    entity_losses = []
    type_tags_one_hot = None

    for spec in self._entity_tag_specs:
        if spec.num_tags == 0:
            continue

        tag_name = spec.tag_name
        # Append one zero ("no entity") position along axis 1 so the tags line
        # up with the sentence-level feature appended to the text inputs.
        padded_tag_ids = tf.pad(
            tf_batch_data[ENTITIES][tag_name][0], [[0, 0], [0, 1], [0, 0]]
        )

        loss, f1, _ = self._calculate_entity_loss(
            text_transformed,
            padded_tag_ids,
            mask_combined_sequence_sentence,
            sequence_feature_lengths,
            tag_name,
            type_tags_one_hot,
        )

        if tag_name == ENTITY_ATTRIBUTE_TYPE:
            type_tags_one_hot = tf.one_hot(
                tf.cast(padded_tag_ids[:, :, 0], tf.int32), depth=spec.num_tags
            )

        self._update_entity_metrics(loss, f1, tag_name)
        entity_losses.append(loss)

    return entity_losses
1748
+
1749
def _update_entity_metrics(
    self, loss: tf.Tensor, f1: tf.Tensor, tag_name: Text
) -> None:
    """Feeds an entity loss and F1 score into the metrics for ``tag_name``.

    Tag names other than entity type, group or role are ignored.
    """
    metrics_by_tag = {
        ENTITY_ATTRIBUTE_TYPE: (self.entity_loss, self.entity_f1),
        ENTITY_ATTRIBUTE_GROUP: (self.entity_group_loss, self.entity_group_f1),
        ENTITY_ATTRIBUTE_ROLE: (self.entity_role_loss, self.entity_role_f1),
    }
    pair = metrics_by_tag.get(tag_name)
    if pair is None:
        return

    loss_metric, f1_metric = pair
    loss_metric.update_state(loss)
    f1_metric.update_state(f1)
1761
+
1762
def prepare_for_predict(self) -> None:
    """Caches the embeddings of all labels before prediction.

    This only has an effect when intent classification is enabled. Intent
    prediction raises a ``ValueError`` if the cache was never filled.
    """
    if not self.config[INTENT_CLASSIFICATION]:
        return

    _label_ids, self.all_labels_embed = self._create_all_labels()
1766
+
1767
def batch_predict(
    self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
) -> Dict[Text, Union[tf.Tensor, Dict[Text, tf.Tensor]]]:
    """Predicts intents and/or entities for the given batch.

    The batch is decoded with ``self.predict_data_signature`` (text features
    only) and encoded by the text sequence layer.

    Args:
        batch_in: The batch, as tensors or numpy arrays.

    Returns:
        A dict with these entries:

        * ``DIAGNOSTIC_DATA``: a nested dict holding the attention weights and
          the transformed text.
        * ``"i_scores"``: present when intent classification is enabled.
        * ``"e_<tag>_ids"`` / ``"e_<tag>_scores"``: one pair per trained
          entity tag, present when entity recognition is enabled.
    """
    tf_batch_data = self.batch_to_model_data_format(
        batch_in, self.predict_data_signature
    )

    sequence_feature_lengths = self._get_sequence_feature_lengths(
        tf_batch_data, TEXT
    )
    sentence_feature_lengths = self._get_sentence_feature_lengths(
        tf_batch_data, TEXT
    )

    text_transformed, _, _, _, _, attention_weights = self._tf_layers[
        f"sequence_layer.{self.text_name}"
    ](
        (
            tf_batch_data[TEXT][SEQUENCE],
            tf_batch_data[TEXT][SENTENCE],
            sequence_feature_lengths,
        ),
        training=self._training,
    )
    # values are tensors, except DIAGNOSTIC_DATA which maps to a nested dict
    predictions: Dict[Text, Union[tf.Tensor, Dict[Text, tf.Tensor]]] = {
        DIAGNOSTIC_DATA: {
            "attention_weights": attention_weights,
            "text_transformed": text_transformed,
        }
    }

    if self.config[INTENT_CLASSIFICATION]:
        predictions.update(
            self._batch_predict_intents(
                sequence_feature_lengths + sentence_feature_lengths,
                text_transformed,
            )
        )

    if self.config[ENTITY_RECOGNITION]:
        predictions.update(
            self._batch_predict_entities(sequence_feature_lengths, text_transformed)
        )

    return predictions
1820
+
1821
def _batch_predict_entities(
    self, sequence_feature_lengths: tf.Tensor, text_transformed: tf.Tensor
) -> Dict[Text, tf.Tensor]:
    """Runs the CRF of every trained entity tag over the encoded text.

    Once entity *types* are predicted, their one-hot encoding is embedded and
    concatenated onto the input of the subsequent (role/group) CRFs.

    Returns:
        ``e_<tag>_ids`` and ``e_<tag>_scores`` for each tag with a trained CRF.
    """
    predictions: Dict[Text, tf.Tensor] = {}
    type_tags_one_hot = None

    for spec in self._entity_tag_specs:
        # no CRF layer was trained for tags without any values
        if spec.num_tags == 0:
            continue

        tag_name = spec.tag_name
        crf_input = text_transformed
        if type_tags_one_hot is not None:
            tag_embedding = self._tf_layers[f"embed.{tag_name}.tags"](
                type_tags_one_hot
            )
            crf_input = tf.concat([crf_input, tag_embedding], axis=-1)

        logits = self._tf_layers[f"embed.{tag_name}.logits"](crf_input)
        pred_ids, confidences = self._tf_layers[f"crf.{tag_name}"](
            logits, sequence_feature_lengths
        )

        predictions[f"e_{tag_name}_ids"] = pred_ids
        predictions[f"e_{tag_name}_scores"] = confidences

        if tag_name == ENTITY_ATTRIBUTE_TYPE:
            type_tags_one_hot = tf.one_hot(
                tf.cast(pred_ids, tf.int32), depth=spec.num_tags
            )

    return predictions
1856
+
1857
def _batch_predict_intents(
    self,
    combined_sequence_sentence_feature_lengths: tf.Tensor,
    text_transformed: tf.Tensor,
) -> Dict[Text, tf.Tensor]:
    """Scores every label against the sentence vector of each example.

    Raises:
        ValueError: If ``prepare_for_predict`` has not cached the label
            embeddings yet.

    Returns:
        ``{"i_scores": ...}`` with the confidences produced by the
        ``loss.{LABEL}`` layer.
    """
    if self.all_labels_embed is None:
        raise ValueError(
            "The model was not prepared for prediction. "
            "Call `prepare_for_predict` first."
        )

    text_vector = self._last_token(
        text_transformed, combined_sequence_sentence_feature_lengths
    )
    text_embed = self._tf_layers[f"embed.{TEXT}"](text_vector)

    # Both embeddings are rank 2. The inserted axes let each example be
    # compared with every label; presumably the shapes are (batch, 1, dim)
    # vs (1, num_labels, dim) — TODO confirm.
    label_scorer = self._tf_layers[f"loss.{LABEL}"]
    _similarities, scores = label_scorer.get_similarities_and_confidences_from_embeddings(
        text_embed[:, tf.newaxis, :],
        self.all_labels_embed[tf.newaxis, :, :],
    )

    return {"i_scores": scores}