rasa-pro 3.12.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (790)
  1. README.md +41 -0
  2. rasa/__init__.py +9 -0
  3. rasa/__main__.py +177 -0
  4. rasa/anonymization/__init__.py +2 -0
  5. rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
  6. rasa/anonymization/anonymization_pipeline.py +286 -0
  7. rasa/anonymization/anonymization_rule_executor.py +260 -0
  8. rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
  9. rasa/anonymization/schemas/config.yml +47 -0
  10. rasa/anonymization/utils.py +118 -0
  11. rasa/api.py +160 -0
  12. rasa/cli/__init__.py +5 -0
  13. rasa/cli/arguments/__init__.py +0 -0
  14. rasa/cli/arguments/data.py +106 -0
  15. rasa/cli/arguments/default_arguments.py +207 -0
  16. rasa/cli/arguments/evaluate.py +65 -0
  17. rasa/cli/arguments/export.py +51 -0
  18. rasa/cli/arguments/interactive.py +74 -0
  19. rasa/cli/arguments/run.py +219 -0
  20. rasa/cli/arguments/shell.py +17 -0
  21. rasa/cli/arguments/test.py +211 -0
  22. rasa/cli/arguments/train.py +279 -0
  23. rasa/cli/arguments/visualize.py +34 -0
  24. rasa/cli/arguments/x.py +30 -0
  25. rasa/cli/data.py +354 -0
  26. rasa/cli/dialogue_understanding_test.py +251 -0
  27. rasa/cli/e2e_test.py +259 -0
  28. rasa/cli/evaluate.py +222 -0
  29. rasa/cli/export.py +250 -0
  30. rasa/cli/inspect.py +75 -0
  31. rasa/cli/interactive.py +166 -0
  32. rasa/cli/license.py +65 -0
  33. rasa/cli/llm_fine_tuning.py +403 -0
  34. rasa/cli/markers.py +78 -0
  35. rasa/cli/project_templates/__init__.py +0 -0
  36. rasa/cli/project_templates/calm/actions/__init__.py +0 -0
  37. rasa/cli/project_templates/calm/actions/action_template.py +27 -0
  38. rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
  39. rasa/cli/project_templates/calm/actions/db.py +57 -0
  40. rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
  41. rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
  42. rasa/cli/project_templates/calm/config.yml +10 -0
  43. rasa/cli/project_templates/calm/credentials.yml +33 -0
  44. rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
  45. rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
  46. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
  47. rasa/cli/project_templates/calm/db/contacts.json +10 -0
  48. rasa/cli/project_templates/calm/domain/add_contact.yml +39 -0
  49. rasa/cli/project_templates/calm/domain/list_contacts.yml +17 -0
  50. rasa/cli/project_templates/calm/domain/remove_contact.yml +38 -0
  51. rasa/cli/project_templates/calm/domain/shared.yml +10 -0
  52. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
  53. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
  54. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
  55. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
  56. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
  57. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
  58. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
  59. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
  60. rasa/cli/project_templates/calm/endpoints.yml +58 -0
  61. rasa/cli/project_templates/default/actions/__init__.py +0 -0
  62. rasa/cli/project_templates/default/actions/actions.py +27 -0
  63. rasa/cli/project_templates/default/config.yml +44 -0
  64. rasa/cli/project_templates/default/credentials.yml +33 -0
  65. rasa/cli/project_templates/default/data/nlu.yml +91 -0
  66. rasa/cli/project_templates/default/data/rules.yml +13 -0
  67. rasa/cli/project_templates/default/data/stories.yml +30 -0
  68. rasa/cli/project_templates/default/domain.yml +34 -0
  69. rasa/cli/project_templates/default/endpoints.yml +42 -0
  70. rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
  71. rasa/cli/project_templates/tutorial/actions/__init__.py +0 -0
  72. rasa/cli/project_templates/tutorial/actions/actions.py +22 -0
  73. rasa/cli/project_templates/tutorial/config.yml +12 -0
  74. rasa/cli/project_templates/tutorial/credentials.yml +33 -0
  75. rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
  76. rasa/cli/project_templates/tutorial/data/patterns.yml +11 -0
  77. rasa/cli/project_templates/tutorial/domain.yml +35 -0
  78. rasa/cli/project_templates/tutorial/endpoints.yml +55 -0
  79. rasa/cli/run.py +143 -0
  80. rasa/cli/scaffold.py +273 -0
  81. rasa/cli/shell.py +141 -0
  82. rasa/cli/studio/__init__.py +0 -0
  83. rasa/cli/studio/download.py +62 -0
  84. rasa/cli/studio/studio.py +296 -0
  85. rasa/cli/studio/train.py +59 -0
  86. rasa/cli/studio/upload.py +62 -0
  87. rasa/cli/telemetry.py +102 -0
  88. rasa/cli/test.py +280 -0
  89. rasa/cli/train.py +278 -0
  90. rasa/cli/utils.py +484 -0
  91. rasa/cli/visualize.py +40 -0
  92. rasa/cli/x.py +206 -0
  93. rasa/constants.py +45 -0
  94. rasa/core/__init__.py +17 -0
  95. rasa/core/actions/__init__.py +0 -0
  96. rasa/core/actions/action.py +1318 -0
  97. rasa/core/actions/action_clean_stack.py +59 -0
  98. rasa/core/actions/action_exceptions.py +24 -0
  99. rasa/core/actions/action_hangup.py +29 -0
  100. rasa/core/actions/action_repeat_bot_messages.py +89 -0
  101. rasa/core/actions/action_run_slot_rejections.py +210 -0
  102. rasa/core/actions/action_trigger_chitchat.py +31 -0
  103. rasa/core/actions/action_trigger_flow.py +109 -0
  104. rasa/core/actions/action_trigger_search.py +31 -0
  105. rasa/core/actions/constants.py +5 -0
  106. rasa/core/actions/custom_action_executor.py +191 -0
  107. rasa/core/actions/direct_custom_actions_executor.py +109 -0
  108. rasa/core/actions/e2e_stub_custom_action_executor.py +72 -0
  109. rasa/core/actions/forms.py +741 -0
  110. rasa/core/actions/grpc_custom_action_executor.py +251 -0
  111. rasa/core/actions/http_custom_action_executor.py +145 -0
  112. rasa/core/actions/loops.py +114 -0
  113. rasa/core/actions/two_stage_fallback.py +186 -0
  114. rasa/core/agent.py +559 -0
  115. rasa/core/auth_retry_tracker_store.py +122 -0
  116. rasa/core/brokers/__init__.py +0 -0
  117. rasa/core/brokers/broker.py +126 -0
  118. rasa/core/brokers/file.py +58 -0
  119. rasa/core/brokers/kafka.py +324 -0
  120. rasa/core/brokers/pika.py +388 -0
  121. rasa/core/brokers/sql.py +86 -0
  122. rasa/core/channels/__init__.py +61 -0
  123. rasa/core/channels/botframework.py +338 -0
  124. rasa/core/channels/callback.py +84 -0
  125. rasa/core/channels/channel.py +456 -0
  126. rasa/core/channels/console.py +241 -0
  127. rasa/core/channels/development_inspector.py +197 -0
  128. rasa/core/channels/facebook.py +419 -0
  129. rasa/core/channels/hangouts.py +329 -0
  130. rasa/core/channels/inspector/.eslintrc.cjs +25 -0
  131. rasa/core/channels/inspector/.gitignore +23 -0
  132. rasa/core/channels/inspector/README.md +54 -0
  133. rasa/core/channels/inspector/assets/favicon.ico +0 -0
  134. rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
  135. rasa/core/channels/inspector/custom.d.ts +3 -0
  136. rasa/core/channels/inspector/dist/assets/arc-861ddd57.js +1 -0
  137. rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
  138. rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-921f02db.js +10 -0
  139. rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-b436c4f8.js +2 -0
  140. rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-511a23cb.js +2 -0
  141. rasa/core/channels/inspector/dist/assets/createText-62fc7601-ef476ecd.js +7 -0
  142. rasa/core/channels/inspector/dist/assets/edges-f2ad444c-f1878e0a.js +4 -0
  143. rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-fac75185.js +51 -0
  144. rasa/core/channels/inspector/dist/assets/flowDb-1972c806-201c5bbc.js +6 -0
  145. rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-f904ae41.js +4 -0
  146. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-b080d6f2.js +1 -0
  147. rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-1813da66.js +139 -0
  148. rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-872af172.js +266 -0
  149. rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-34a0af5a.js +70 -0
  150. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
  151. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
  152. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
  153. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
  154. rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-42ba3e3d.js +1 -0
  155. rasa/core/channels/inspector/dist/assets/index-37817b51.js +1317 -0
  156. rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
  157. rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-6b731386.js +7 -0
  158. rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
  159. rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-e8579ac6.js +139 -0
  160. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
  161. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
  162. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
  163. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
  164. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
  165. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
  166. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
  167. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
  168. rasa/core/channels/inspector/dist/assets/layout-89e6403a.js +1 -0
  169. rasa/core/channels/inspector/dist/assets/line-dc73d3fc.js +1 -0
  170. rasa/core/channels/inspector/dist/assets/linear-f5b1d2bc.js +1 -0
  171. rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-82cb74fa.js +109 -0
  172. rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
  173. rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
  174. rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-bdf5f29b.js +35 -0
  175. rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-c7a0cbe4.js +7 -0
  176. rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-7ec5410f.js +52 -0
  177. rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-caee5554.js +8 -0
  178. rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-2935f8db.js +122 -0
  179. rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-8f5d9693.js +1 -0
  180. rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-d565d1de.js +1 -0
  181. rasa/core/channels/inspector/dist/assets/styles-080da4f6-75ad421d.js +110 -0
  182. rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-7e764226.js +159 -0
  183. rasa/core/channels/inspector/dist/assets/styles-9c745c82-7a4e0e61.js +207 -0
  184. rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-4019d1bf.js +1 -0
  185. rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-01ea12df.js +61 -0
  186. rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-89407137.js +7 -0
  187. rasa/core/channels/inspector/dist/index.html +42 -0
  188. rasa/core/channels/inspector/index.html +40 -0
  189. rasa/core/channels/inspector/jest.config.ts +13 -0
  190. rasa/core/channels/inspector/package.json +52 -0
  191. rasa/core/channels/inspector/setupTests.ts +2 -0
  192. rasa/core/channels/inspector/src/App.tsx +220 -0
  193. rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
  194. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +108 -0
  195. rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
  196. rasa/core/channels/inspector/src/components/DialogueStack.tsx +136 -0
  197. rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
  198. rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
  199. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +22 -0
  200. rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
  201. rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
  202. rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
  203. rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
  204. rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
  205. rasa/core/channels/inspector/src/helpers/audiostream.ts +191 -0
  206. rasa/core/channels/inspector/src/helpers/formatters.test.ts +392 -0
  207. rasa/core/channels/inspector/src/helpers/formatters.ts +306 -0
  208. rasa/core/channels/inspector/src/helpers/utils.ts +127 -0
  209. rasa/core/channels/inspector/src/main.tsx +13 -0
  210. rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
  211. rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
  212. rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
  213. rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
  214. rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
  215. rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
  216. rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
  217. rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
  218. rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
  219. rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
  220. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
  221. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
  222. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
  223. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
  224. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
  225. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
  226. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
  227. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
  228. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
  229. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
  230. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
  231. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
  232. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
  233. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
  234. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
  235. rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
  236. rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
  237. rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
  238. rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
  239. rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
  240. rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
  241. rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
  242. rasa/core/channels/inspector/src/theme/index.ts +101 -0
  243. rasa/core/channels/inspector/src/types.ts +84 -0
  244. rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
  245. rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
  246. rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
  247. rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
  248. rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
  249. rasa/core/channels/inspector/tsconfig.json +26 -0
  250. rasa/core/channels/inspector/tsconfig.node.json +10 -0
  251. rasa/core/channels/inspector/vite.config.ts +8 -0
  252. rasa/core/channels/inspector/yarn.lock +6249 -0
  253. rasa/core/channels/mattermost.py +229 -0
  254. rasa/core/channels/rasa_chat.py +126 -0
  255. rasa/core/channels/rest.py +230 -0
  256. rasa/core/channels/rocketchat.py +174 -0
  257. rasa/core/channels/slack.py +620 -0
  258. rasa/core/channels/socketio.py +302 -0
  259. rasa/core/channels/telegram.py +298 -0
  260. rasa/core/channels/twilio.py +169 -0
  261. rasa/core/channels/vier_cvg.py +374 -0
  262. rasa/core/channels/voice_ready/__init__.py +0 -0
  263. rasa/core/channels/voice_ready/audiocodes.py +501 -0
  264. rasa/core/channels/voice_ready/jambonz.py +121 -0
  265. rasa/core/channels/voice_ready/jambonz_protocol.py +396 -0
  266. rasa/core/channels/voice_ready/twilio_voice.py +403 -0
  267. rasa/core/channels/voice_ready/utils.py +37 -0
  268. rasa/core/channels/voice_stream/__init__.py +0 -0
  269. rasa/core/channels/voice_stream/asr/__init__.py +0 -0
  270. rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
  271. rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
  272. rasa/core/channels/voice_stream/asr/azure.py +130 -0
  273. rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
  274. rasa/core/channels/voice_stream/audio_bytes.py +8 -0
  275. rasa/core/channels/voice_stream/browser_audio.py +107 -0
  276. rasa/core/channels/voice_stream/call_state.py +23 -0
  277. rasa/core/channels/voice_stream/tts/__init__.py +0 -0
  278. rasa/core/channels/voice_stream/tts/azure.py +106 -0
  279. rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
  280. rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
  281. rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
  282. rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
  283. rasa/core/channels/voice_stream/util.py +57 -0
  284. rasa/core/channels/voice_stream/voice_channel.py +427 -0
  285. rasa/core/channels/webexteams.py +134 -0
  286. rasa/core/concurrent_lock_store.py +210 -0
  287. rasa/core/constants.py +112 -0
  288. rasa/core/evaluation/__init__.py +0 -0
  289. rasa/core/evaluation/marker.py +267 -0
  290. rasa/core/evaluation/marker_base.py +923 -0
  291. rasa/core/evaluation/marker_stats.py +293 -0
  292. rasa/core/evaluation/marker_tracker_loader.py +103 -0
  293. rasa/core/exceptions.py +29 -0
  294. rasa/core/exporter.py +284 -0
  295. rasa/core/featurizers/__init__.py +0 -0
  296. rasa/core/featurizers/precomputation.py +410 -0
  297. rasa/core/featurizers/single_state_featurizer.py +421 -0
  298. rasa/core/featurizers/tracker_featurizers.py +1262 -0
  299. rasa/core/http_interpreter.py +89 -0
  300. rasa/core/information_retrieval/__init__.py +7 -0
  301. rasa/core/information_retrieval/faiss.py +124 -0
  302. rasa/core/information_retrieval/information_retrieval.py +137 -0
  303. rasa/core/information_retrieval/milvus.py +59 -0
  304. rasa/core/information_retrieval/qdrant.py +96 -0
  305. rasa/core/jobs.py +63 -0
  306. rasa/core/lock.py +139 -0
  307. rasa/core/lock_store.py +343 -0
  308. rasa/core/migrate.py +403 -0
  309. rasa/core/nlg/__init__.py +3 -0
  310. rasa/core/nlg/callback.py +146 -0
  311. rasa/core/nlg/contextual_response_rephraser.py +320 -0
  312. rasa/core/nlg/generator.py +230 -0
  313. rasa/core/nlg/interpolator.py +143 -0
  314. rasa/core/nlg/response.py +155 -0
  315. rasa/core/nlg/summarize.py +70 -0
  316. rasa/core/persistor.py +538 -0
  317. rasa/core/policies/__init__.py +0 -0
  318. rasa/core/policies/ensemble.py +329 -0
  319. rasa/core/policies/enterprise_search_policy.py +905 -0
  320. rasa/core/policies/enterprise_search_prompt_template.jinja2 +25 -0
  321. rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
  322. rasa/core/policies/flow_policy.py +205 -0
  323. rasa/core/policies/flows/__init__.py +0 -0
  324. rasa/core/policies/flows/flow_exceptions.py +44 -0
  325. rasa/core/policies/flows/flow_executor.py +754 -0
  326. rasa/core/policies/flows/flow_step_result.py +43 -0
  327. rasa/core/policies/intentless_policy.py +1031 -0
  328. rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
  329. rasa/core/policies/memoization.py +538 -0
  330. rasa/core/policies/policy.py +725 -0
  331. rasa/core/policies/rule_policy.py +1273 -0
  332. rasa/core/policies/ted_policy.py +2169 -0
  333. rasa/core/policies/unexpected_intent_policy.py +1022 -0
  334. rasa/core/processor.py +1465 -0
  335. rasa/core/run.py +342 -0
  336. rasa/core/secrets_manager/__init__.py +0 -0
  337. rasa/core/secrets_manager/constants.py +36 -0
  338. rasa/core/secrets_manager/endpoints.py +391 -0
  339. rasa/core/secrets_manager/factory.py +241 -0
  340. rasa/core/secrets_manager/secret_manager.py +262 -0
  341. rasa/core/secrets_manager/vault.py +584 -0
  342. rasa/core/test.py +1335 -0
  343. rasa/core/tracker_store.py +1703 -0
  344. rasa/core/train.py +105 -0
  345. rasa/core/training/__init__.py +89 -0
  346. rasa/core/training/converters/__init__.py +0 -0
  347. rasa/core/training/converters/responses_prefix_converter.py +119 -0
  348. rasa/core/training/interactive.py +1744 -0
  349. rasa/core/training/story_conflict.py +381 -0
  350. rasa/core/training/training.py +93 -0
  351. rasa/core/utils.py +366 -0
  352. rasa/core/visualize.py +70 -0
  353. rasa/dialogue_understanding/__init__.py +0 -0
  354. rasa/dialogue_understanding/coexistence/__init__.py +0 -0
  355. rasa/dialogue_understanding/coexistence/constants.py +4 -0
  356. rasa/dialogue_understanding/coexistence/intent_based_router.py +196 -0
  357. rasa/dialogue_understanding/coexistence/llm_based_router.py +327 -0
  358. rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
  359. rasa/dialogue_understanding/commands/__init__.py +61 -0
  360. rasa/dialogue_understanding/commands/can_not_handle_command.py +70 -0
  361. rasa/dialogue_understanding/commands/cancel_flow_command.py +125 -0
  362. rasa/dialogue_understanding/commands/change_flow_command.py +44 -0
  363. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +57 -0
  364. rasa/dialogue_understanding/commands/clarify_command.py +86 -0
  365. rasa/dialogue_understanding/commands/command.py +85 -0
  366. rasa/dialogue_understanding/commands/correct_slots_command.py +297 -0
  367. rasa/dialogue_understanding/commands/error_command.py +79 -0
  368. rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
  369. rasa/dialogue_understanding/commands/handle_code_change_command.py +73 -0
  370. rasa/dialogue_understanding/commands/human_handoff_command.py +66 -0
  371. rasa/dialogue_understanding/commands/knowledge_answer_command.py +57 -0
  372. rasa/dialogue_understanding/commands/noop_command.py +54 -0
  373. rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
  374. rasa/dialogue_understanding/commands/restart_command.py +58 -0
  375. rasa/dialogue_understanding/commands/session_end_command.py +61 -0
  376. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  377. rasa/dialogue_understanding/commands/set_slot_command.py +160 -0
  378. rasa/dialogue_understanding/commands/skip_question_command.py +75 -0
  379. rasa/dialogue_understanding/commands/start_flow_command.py +107 -0
  380. rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
  381. rasa/dialogue_understanding/commands/utils.py +45 -0
  382. rasa/dialogue_understanding/generator/__init__.py +21 -0
  383. rasa/dialogue_understanding/generator/command_generator.py +464 -0
  384. rasa/dialogue_understanding/generator/constants.py +27 -0
  385. rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
  386. rasa/dialogue_understanding/generator/flow_retrieval.py +466 -0
  387. rasa/dialogue_understanding/generator/llm_based_command_generator.py +500 -0
  388. rasa/dialogue_understanding/generator/llm_command_generator.py +67 -0
  389. rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
  390. rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
  391. rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
  392. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +920 -0
  393. rasa/dialogue_understanding/generator/nlu_command_adapter.py +261 -0
  394. rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
  395. rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +60 -0
  396. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +486 -0
  397. rasa/dialogue_understanding/patterns/__init__.py +0 -0
  398. rasa/dialogue_understanding/patterns/cancel.py +111 -0
  399. rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
  400. rasa/dialogue_understanding/patterns/chitchat.py +37 -0
  401. rasa/dialogue_understanding/patterns/clarify.py +97 -0
  402. rasa/dialogue_understanding/patterns/code_change.py +41 -0
  403. rasa/dialogue_understanding/patterns/collect_information.py +90 -0
  404. rasa/dialogue_understanding/patterns/completed.py +40 -0
  405. rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
  406. rasa/dialogue_understanding/patterns/correction.py +278 -0
  407. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +301 -0
  408. rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
  409. rasa/dialogue_understanding/patterns/internal_error.py +47 -0
  410. rasa/dialogue_understanding/patterns/repeat.py +37 -0
  411. rasa/dialogue_understanding/patterns/restart.py +37 -0
  412. rasa/dialogue_understanding/patterns/search.py +37 -0
  413. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  414. rasa/dialogue_understanding/patterns/skip_question.py +38 -0
  415. rasa/dialogue_understanding/patterns/user_silence.py +37 -0
  416. rasa/dialogue_understanding/processor/__init__.py +0 -0
  417. rasa/dialogue_understanding/processor/command_processor.py +720 -0
  418. rasa/dialogue_understanding/processor/command_processor_component.py +43 -0
  419. rasa/dialogue_understanding/stack/__init__.py +0 -0
  420. rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
  421. rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
  422. rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
  423. rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
  424. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
  425. rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
  426. rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
  427. rasa/dialogue_understanding/stack/utils.py +211 -0
  428. rasa/dialogue_understanding/utils.py +14 -0
  429. rasa/dialogue_understanding_test/__init__.py +0 -0
  430. rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
  431. rasa/dialogue_understanding_test/constants.py +17 -0
  432. rasa/dialogue_understanding_test/du_test_case.py +118 -0
  433. rasa/dialogue_understanding_test/du_test_result.py +11 -0
  434. rasa/dialogue_understanding_test/du_test_runner.py +93 -0
  435. rasa/dialogue_understanding_test/io.py +54 -0
  436. rasa/dialogue_understanding_test/validation.py +22 -0
  437. rasa/e2e_test/__init__.py +0 -0
  438. rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
  439. rasa/e2e_test/assertions.py +1345 -0
  440. rasa/e2e_test/assertions_schema.yml +129 -0
  441. rasa/e2e_test/constants.py +31 -0
  442. rasa/e2e_test/e2e_config.py +220 -0
  443. rasa/e2e_test/e2e_config_schema.yml +26 -0
  444. rasa/e2e_test/e2e_test_case.py +569 -0
  445. rasa/e2e_test/e2e_test_converter.py +363 -0
  446. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  447. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  448. rasa/e2e_test/e2e_test_result.py +54 -0
  449. rasa/e2e_test/e2e_test_runner.py +1192 -0
  450. rasa/e2e_test/e2e_test_schema.yml +181 -0
  451. rasa/e2e_test/pykwalify_extensions.py +39 -0
  452. rasa/e2e_test/stub_custom_action.py +70 -0
  453. rasa/e2e_test/utils/__init__.py +0 -0
  454. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  455. rasa/e2e_test/utils/io.py +598 -0
  456. rasa/e2e_test/utils/validation.py +178 -0
  457. rasa/engine/__init__.py +0 -0
  458. rasa/engine/caching.py +463 -0
  459. rasa/engine/constants.py +17 -0
  460. rasa/engine/exceptions.py +14 -0
  461. rasa/engine/graph.py +642 -0
  462. rasa/engine/loader.py +48 -0
  463. rasa/engine/recipes/__init__.py +0 -0
  464. rasa/engine/recipes/config_files/default_config.yml +41 -0
  465. rasa/engine/recipes/default_components.py +97 -0
  466. rasa/engine/recipes/default_recipe.py +1272 -0
  467. rasa/engine/recipes/graph_recipe.py +79 -0
  468. rasa/engine/recipes/recipe.py +93 -0
  469. rasa/engine/runner/__init__.py +0 -0
  470. rasa/engine/runner/dask.py +250 -0
  471. rasa/engine/runner/interface.py +49 -0
  472. rasa/engine/storage/__init__.py +0 -0
  473. rasa/engine/storage/local_model_storage.py +244 -0
  474. rasa/engine/storage/resource.py +110 -0
  475. rasa/engine/storage/storage.py +199 -0
  476. rasa/engine/training/__init__.py +0 -0
  477. rasa/engine/training/components.py +176 -0
  478. rasa/engine/training/fingerprinting.py +64 -0
  479. rasa/engine/training/graph_trainer.py +256 -0
  480. rasa/engine/training/hooks.py +164 -0
  481. rasa/engine/validation.py +1451 -0
  482. rasa/env.py +14 -0
  483. rasa/exceptions.py +69 -0
  484. rasa/graph_components/__init__.py +0 -0
  485. rasa/graph_components/converters/__init__.py +0 -0
  486. rasa/graph_components/converters/nlu_message_converter.py +48 -0
  487. rasa/graph_components/providers/__init__.py +0 -0
  488. rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
  489. rasa/graph_components/providers/domain_provider.py +71 -0
  490. rasa/graph_components/providers/flows_provider.py +74 -0
  491. rasa/graph_components/providers/forms_provider.py +44 -0
  492. rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
  493. rasa/graph_components/providers/responses_provider.py +44 -0
  494. rasa/graph_components/providers/rule_only_provider.py +49 -0
  495. rasa/graph_components/providers/story_graph_provider.py +96 -0
  496. rasa/graph_components/providers/training_tracker_provider.py +55 -0
  497. rasa/graph_components/validators/__init__.py +0 -0
  498. rasa/graph_components/validators/default_recipe_validator.py +550 -0
  499. rasa/graph_components/validators/finetuning_validator.py +302 -0
  500. rasa/hooks.py +111 -0
  501. rasa/jupyter.py +63 -0
  502. rasa/llm_fine_tuning/__init__.py +0 -0
  503. rasa/llm_fine_tuning/annotation_module.py +241 -0
  504. rasa/llm_fine_tuning/conversations.py +144 -0
  505. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  506. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  507. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  508. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  509. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  510. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  511. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  512. rasa/llm_fine_tuning/storage.py +174 -0
  513. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  514. rasa/markers/__init__.py +0 -0
  515. rasa/markers/marker.py +269 -0
  516. rasa/markers/marker_base.py +828 -0
  517. rasa/markers/upload.py +74 -0
  518. rasa/markers/validate.py +21 -0
  519. rasa/model.py +118 -0
  520. rasa/model_manager/__init__.py +0 -0
  521. rasa/model_manager/config.py +40 -0
  522. rasa/model_manager/model_api.py +559 -0
  523. rasa/model_manager/runner_service.py +286 -0
  524. rasa/model_manager/socket_bridge.py +146 -0
  525. rasa/model_manager/studio_jwt_auth.py +86 -0
  526. rasa/model_manager/trainer_service.py +325 -0
  527. rasa/model_manager/utils.py +87 -0
  528. rasa/model_manager/warm_rasa_process.py +187 -0
  529. rasa/model_service.py +112 -0
  530. rasa/model_testing.py +457 -0
  531. rasa/model_training.py +596 -0
  532. rasa/nlu/__init__.py +7 -0
  533. rasa/nlu/classifiers/__init__.py +3 -0
  534. rasa/nlu/classifiers/classifier.py +5 -0
  535. rasa/nlu/classifiers/diet_classifier.py +1881 -0
  536. rasa/nlu/classifiers/fallback_classifier.py +192 -0
  537. rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
  538. rasa/nlu/classifiers/logistic_regression_classifier.py +253 -0
  539. rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
  540. rasa/nlu/classifiers/regex_message_handler.py +56 -0
  541. rasa/nlu/classifiers/sklearn_intent_classifier.py +330 -0
  542. rasa/nlu/constants.py +77 -0
  543. rasa/nlu/convert.py +40 -0
  544. rasa/nlu/emulators/__init__.py +0 -0
  545. rasa/nlu/emulators/dialogflow.py +55 -0
  546. rasa/nlu/emulators/emulator.py +49 -0
  547. rasa/nlu/emulators/luis.py +86 -0
  548. rasa/nlu/emulators/no_emulator.py +10 -0
  549. rasa/nlu/emulators/wit.py +56 -0
  550. rasa/nlu/extractors/__init__.py +0 -0
  551. rasa/nlu/extractors/crf_entity_extractor.py +715 -0
  552. rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
  553. rasa/nlu/extractors/entity_synonyms.py +178 -0
  554. rasa/nlu/extractors/extractor.py +470 -0
  555. rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
  556. rasa/nlu/extractors/regex_entity_extractor.py +220 -0
  557. rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
  558. rasa/nlu/featurizers/__init__.py +0 -0
  559. rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
  560. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +445 -0
  561. rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
  562. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +768 -0
  563. rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
  564. rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
  565. rasa/nlu/featurizers/featurizer.py +89 -0
  566. rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
  567. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +867 -0
  568. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +571 -0
  569. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +271 -0
  570. rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
  571. rasa/nlu/model.py +24 -0
  572. rasa/nlu/run.py +27 -0
  573. rasa/nlu/selectors/__init__.py +0 -0
  574. rasa/nlu/selectors/response_selector.py +987 -0
  575. rasa/nlu/test.py +1940 -0
  576. rasa/nlu/tokenizers/__init__.py +0 -0
  577. rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
  578. rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
  579. rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
  580. rasa/nlu/tokenizers/tokenizer.py +239 -0
  581. rasa/nlu/tokenizers/whitespace_tokenizer.py +95 -0
  582. rasa/nlu/utils/__init__.py +35 -0
  583. rasa/nlu/utils/bilou_utils.py +462 -0
  584. rasa/nlu/utils/hugging_face/__init__.py +0 -0
  585. rasa/nlu/utils/hugging_face/registry.py +108 -0
  586. rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
  587. rasa/nlu/utils/mitie_utils.py +113 -0
  588. rasa/nlu/utils/pattern_utils.py +168 -0
  589. rasa/nlu/utils/spacy_utils.py +310 -0
  590. rasa/plugin.py +90 -0
  591. rasa/server.py +1588 -0
  592. rasa/shared/__init__.py +0 -0
  593. rasa/shared/constants.py +311 -0
  594. rasa/shared/core/__init__.py +0 -0
  595. rasa/shared/core/command_payload_reader.py +109 -0
  596. rasa/shared/core/constants.py +180 -0
  597. rasa/shared/core/conversation.py +46 -0
  598. rasa/shared/core/domain.py +2172 -0
  599. rasa/shared/core/events.py +2559 -0
  600. rasa/shared/core/flows/__init__.py +7 -0
  601. rasa/shared/core/flows/flow.py +562 -0
  602. rasa/shared/core/flows/flow_path.py +84 -0
  603. rasa/shared/core/flows/flow_step.py +146 -0
  604. rasa/shared/core/flows/flow_step_links.py +319 -0
  605. rasa/shared/core/flows/flow_step_sequence.py +70 -0
  606. rasa/shared/core/flows/flows_list.py +258 -0
  607. rasa/shared/core/flows/flows_yaml_schema.json +303 -0
  608. rasa/shared/core/flows/nlu_trigger.py +117 -0
  609. rasa/shared/core/flows/steps/__init__.py +24 -0
  610. rasa/shared/core/flows/steps/action.py +56 -0
  611. rasa/shared/core/flows/steps/call.py +64 -0
  612. rasa/shared/core/flows/steps/collect.py +112 -0
  613. rasa/shared/core/flows/steps/constants.py +5 -0
  614. rasa/shared/core/flows/steps/continuation.py +36 -0
  615. rasa/shared/core/flows/steps/end.py +22 -0
  616. rasa/shared/core/flows/steps/internal.py +44 -0
  617. rasa/shared/core/flows/steps/link.py +51 -0
  618. rasa/shared/core/flows/steps/no_operation.py +48 -0
  619. rasa/shared/core/flows/steps/set_slots.py +50 -0
  620. rasa/shared/core/flows/steps/start.py +30 -0
  621. rasa/shared/core/flows/utils.py +39 -0
  622. rasa/shared/core/flows/validation.py +735 -0
  623. rasa/shared/core/flows/yaml_flows_io.py +405 -0
  624. rasa/shared/core/generator.py +908 -0
  625. rasa/shared/core/slot_mappings.py +526 -0
  626. rasa/shared/core/slots.py +654 -0
  627. rasa/shared/core/trackers.py +1183 -0
  628. rasa/shared/core/training_data/__init__.py +0 -0
  629. rasa/shared/core/training_data/loading.py +89 -0
  630. rasa/shared/core/training_data/story_reader/__init__.py +0 -0
  631. rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
  632. rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
  633. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
  634. rasa/shared/core/training_data/story_writer/__init__.py +0 -0
  635. rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
  636. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +444 -0
  637. rasa/shared/core/training_data/structures.py +858 -0
  638. rasa/shared/core/training_data/visualization.html +146 -0
  639. rasa/shared/core/training_data/visualization.py +603 -0
  640. rasa/shared/data.py +249 -0
  641. rasa/shared/engine/__init__.py +0 -0
  642. rasa/shared/engine/caching.py +26 -0
  643. rasa/shared/exceptions.py +167 -0
  644. rasa/shared/importers/__init__.py +0 -0
  645. rasa/shared/importers/importer.py +770 -0
  646. rasa/shared/importers/multi_project.py +215 -0
  647. rasa/shared/importers/rasa.py +108 -0
  648. rasa/shared/importers/remote_importer.py +196 -0
  649. rasa/shared/importers/utils.py +36 -0
  650. rasa/shared/nlu/__init__.py +0 -0
  651. rasa/shared/nlu/constants.py +53 -0
  652. rasa/shared/nlu/interpreter.py +10 -0
  653. rasa/shared/nlu/training_data/__init__.py +0 -0
  654. rasa/shared/nlu/training_data/entities_parser.py +208 -0
  655. rasa/shared/nlu/training_data/features.py +492 -0
  656. rasa/shared/nlu/training_data/formats/__init__.py +10 -0
  657. rasa/shared/nlu/training_data/formats/dialogflow.py +163 -0
  658. rasa/shared/nlu/training_data/formats/luis.py +87 -0
  659. rasa/shared/nlu/training_data/formats/rasa.py +135 -0
  660. rasa/shared/nlu/training_data/formats/rasa_yaml.py +618 -0
  661. rasa/shared/nlu/training_data/formats/readerwriter.py +244 -0
  662. rasa/shared/nlu/training_data/formats/wit.py +52 -0
  663. rasa/shared/nlu/training_data/loading.py +137 -0
  664. rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
  665. rasa/shared/nlu/training_data/message.py +490 -0
  666. rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
  667. rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
  668. rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
  669. rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
  670. rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
  671. rasa/shared/nlu/training_data/training_data.py +729 -0
  672. rasa/shared/nlu/training_data/util.py +223 -0
  673. rasa/shared/providers/__init__.py +0 -0
  674. rasa/shared/providers/_configs/__init__.py +0 -0
  675. rasa/shared/providers/_configs/azure_openai_client_config.py +677 -0
  676. rasa/shared/providers/_configs/client_config.py +59 -0
  677. rasa/shared/providers/_configs/default_litellm_client_config.py +132 -0
  678. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +236 -0
  679. rasa/shared/providers/_configs/litellm_router_client_config.py +222 -0
  680. rasa/shared/providers/_configs/model_group_config.py +173 -0
  681. rasa/shared/providers/_configs/openai_client_config.py +177 -0
  682. rasa/shared/providers/_configs/rasa_llm_client_config.py +75 -0
  683. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +178 -0
  684. rasa/shared/providers/_configs/utils.py +117 -0
  685. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  686. rasa/shared/providers/_utils.py +79 -0
  687. rasa/shared/providers/constants.py +7 -0
  688. rasa/shared/providers/embedding/__init__.py +0 -0
  689. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +243 -0
  690. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  691. rasa/shared/providers/embedding/azure_openai_embedding_client.py +335 -0
  692. rasa/shared/providers/embedding/default_litellm_embedding_client.py +126 -0
  693. rasa/shared/providers/embedding/embedding_client.py +90 -0
  694. rasa/shared/providers/embedding/embedding_response.py +41 -0
  695. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  696. rasa/shared/providers/embedding/litellm_router_embedding_client.py +138 -0
  697. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  698. rasa/shared/providers/llm/__init__.py +0 -0
  699. rasa/shared/providers/llm/_base_litellm_client.py +265 -0
  700. rasa/shared/providers/llm/azure_openai_llm_client.py +415 -0
  701. rasa/shared/providers/llm/default_litellm_llm_client.py +110 -0
  702. rasa/shared/providers/llm/litellm_router_llm_client.py +202 -0
  703. rasa/shared/providers/llm/llm_client.py +78 -0
  704. rasa/shared/providers/llm/llm_response.py +50 -0
  705. rasa/shared/providers/llm/openai_llm_client.py +161 -0
  706. rasa/shared/providers/llm/rasa_llm_client.py +120 -0
  707. rasa/shared/providers/llm/self_hosted_llm_client.py +276 -0
  708. rasa/shared/providers/mappings.py +94 -0
  709. rasa/shared/providers/router/__init__.py +0 -0
  710. rasa/shared/providers/router/_base_litellm_router_client.py +185 -0
  711. rasa/shared/providers/router/router_client.py +75 -0
  712. rasa/shared/utils/__init__.py +0 -0
  713. rasa/shared/utils/cli.py +102 -0
  714. rasa/shared/utils/common.py +324 -0
  715. rasa/shared/utils/constants.py +4 -0
  716. rasa/shared/utils/health_check/__init__.py +0 -0
  717. rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
  718. rasa/shared/utils/health_check/health_check.py +258 -0
  719. rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
  720. rasa/shared/utils/io.py +499 -0
  721. rasa/shared/utils/llm.py +764 -0
  722. rasa/shared/utils/pykwalify_extensions.py +27 -0
  723. rasa/shared/utils/schemas/__init__.py +0 -0
  724. rasa/shared/utils/schemas/config.yml +2 -0
  725. rasa/shared/utils/schemas/domain.yml +145 -0
  726. rasa/shared/utils/schemas/events.py +214 -0
  727. rasa/shared/utils/schemas/model_config.yml +36 -0
  728. rasa/shared/utils/schemas/stories.yml +173 -0
  729. rasa/shared/utils/yaml.py +1068 -0
  730. rasa/studio/__init__.py +0 -0
  731. rasa/studio/auth.py +270 -0
  732. rasa/studio/config.py +136 -0
  733. rasa/studio/constants.py +19 -0
  734. rasa/studio/data_handler.py +368 -0
  735. rasa/studio/download.py +489 -0
  736. rasa/studio/results_logger.py +137 -0
  737. rasa/studio/train.py +134 -0
  738. rasa/studio/upload.py +563 -0
  739. rasa/telemetry.py +1876 -0
  740. rasa/tracing/__init__.py +0 -0
  741. rasa/tracing/config.py +355 -0
  742. rasa/tracing/constants.py +62 -0
  743. rasa/tracing/instrumentation/__init__.py +0 -0
  744. rasa/tracing/instrumentation/attribute_extractors.py +765 -0
  745. rasa/tracing/instrumentation/instrumentation.py +1306 -0
  746. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +144 -0
  747. rasa/tracing/instrumentation/metrics.py +294 -0
  748. rasa/tracing/metric_instrument_provider.py +205 -0
  749. rasa/utils/__init__.py +0 -0
  750. rasa/utils/beta.py +83 -0
  751. rasa/utils/cli.py +28 -0
  752. rasa/utils/common.py +639 -0
  753. rasa/utils/converter.py +53 -0
  754. rasa/utils/endpoints.py +331 -0
  755. rasa/utils/io.py +252 -0
  756. rasa/utils/json_utils.py +60 -0
  757. rasa/utils/licensing.py +542 -0
  758. rasa/utils/log_utils.py +181 -0
  759. rasa/utils/mapper.py +210 -0
  760. rasa/utils/ml_utils.py +147 -0
  761. rasa/utils/plotting.py +362 -0
  762. rasa/utils/sanic_error_handler.py +32 -0
  763. rasa/utils/singleton.py +23 -0
  764. rasa/utils/tensorflow/__init__.py +0 -0
  765. rasa/utils/tensorflow/callback.py +112 -0
  766. rasa/utils/tensorflow/constants.py +116 -0
  767. rasa/utils/tensorflow/crf.py +492 -0
  768. rasa/utils/tensorflow/data_generator.py +440 -0
  769. rasa/utils/tensorflow/environment.py +161 -0
  770. rasa/utils/tensorflow/exceptions.py +5 -0
  771. rasa/utils/tensorflow/feature_array.py +366 -0
  772. rasa/utils/tensorflow/layers.py +1565 -0
  773. rasa/utils/tensorflow/layers_utils.py +113 -0
  774. rasa/utils/tensorflow/metrics.py +281 -0
  775. rasa/utils/tensorflow/model_data.py +798 -0
  776. rasa/utils/tensorflow/model_data_utils.py +499 -0
  777. rasa/utils/tensorflow/models.py +935 -0
  778. rasa/utils/tensorflow/rasa_layers.py +1094 -0
  779. rasa/utils/tensorflow/transformer.py +640 -0
  780. rasa/utils/tensorflow/types.py +6 -0
  781. rasa/utils/train_utils.py +572 -0
  782. rasa/utils/url_tools.py +53 -0
  783. rasa/utils/yaml.py +54 -0
  784. rasa/validator.py +1644 -0
  785. rasa/version.py +3 -0
  786. rasa_pro-3.12.0.dev1.dist-info/METADATA +199 -0
  787. rasa_pro-3.12.0.dev1.dist-info/NOTICE +5 -0
  788. rasa_pro-3.12.0.dev1.dist-info/RECORD +790 -0
  789. rasa_pro-3.12.0.dev1.dist-info/WHEEL +4 -0
  790. rasa_pro-3.12.0.dev1.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,987 @@
1
+ from __future__ import annotations
2
+ import copy
3
+ import logging
4
+ from rasa.nlu.featurizers.featurizer import Featurizer
5
+
6
+ import numpy as np
7
+ import tensorflow as tf
8
+
9
+ from typing import Any, Dict, Optional, Text, Tuple, Union, List, Type
10
+
11
+ from rasa.engine.graph import ExecutionContext
12
+ from rasa.engine.recipes.default_recipe import DefaultV1Recipe
13
+ from rasa.engine.storage.resource import Resource
14
+ from rasa.engine.storage.storage import ModelStorage
15
+ from rasa.shared.constants import DIAGNOSTIC_DATA
16
+ from rasa.shared.nlu.training_data import util
17
+ import rasa.shared.utils.io
18
+ from rasa.shared.exceptions import InvalidConfigException
19
+ from rasa.shared.nlu.training_data.training_data import TrainingData
20
+ from rasa.shared.nlu.training_data.message import Message
21
+ from rasa.nlu.classifiers.diet_classifier import (
22
+ DIET,
23
+ LABEL_KEY,
24
+ LABEL_SUB_KEY,
25
+ SENTENCE,
26
+ SEQUENCE,
27
+ DIETClassifier,
28
+ )
29
+ from rasa.nlu.extractors.extractor import EntityTagSpec
30
+ from rasa.utils.tensorflow import rasa_layers
31
+ from rasa.utils.tensorflow.constants import (
32
+ LABEL,
33
+ HIDDEN_LAYERS_SIZES,
34
+ SHARE_HIDDEN_LAYERS,
35
+ TRANSFORMER_SIZE,
36
+ NUM_TRANSFORMER_LAYERS,
37
+ NUM_HEADS,
38
+ BATCH_SIZES,
39
+ BATCH_STRATEGY,
40
+ EPOCHS,
41
+ RANDOM_SEED,
42
+ LEARNING_RATE,
43
+ RANKING_LENGTH,
44
+ RENORMALIZE_CONFIDENCES,
45
+ LOSS_TYPE,
46
+ SIMILARITY_TYPE,
47
+ NUM_NEG,
48
+ SPARSE_INPUT_DROPOUT,
49
+ DENSE_INPUT_DROPOUT,
50
+ MASKED_LM,
51
+ ENTITY_RECOGNITION,
52
+ INTENT_CLASSIFICATION,
53
+ EVAL_NUM_EXAMPLES,
54
+ EVAL_NUM_EPOCHS,
55
+ UNIDIRECTIONAL_ENCODER,
56
+ DROP_RATE,
57
+ DROP_RATE_ATTENTION,
58
+ CONNECTION_DENSITY,
59
+ NEGATIVE_MARGIN_SCALE,
60
+ REGULARIZATION_CONSTANT,
61
+ SCALE_LOSS,
62
+ USE_MAX_NEG_SIM,
63
+ MAX_NEG_SIM,
64
+ MAX_POS_SIM,
65
+ EMBEDDING_DIMENSION,
66
+ BILOU_FLAG,
67
+ KEY_RELATIVE_ATTENTION,
68
+ VALUE_RELATIVE_ATTENTION,
69
+ MAX_RELATIVE_POSITION,
70
+ RETRIEVAL_INTENT,
71
+ USE_TEXT_AS_LABEL,
72
+ CROSS_ENTROPY,
73
+ AUTO,
74
+ BALANCED,
75
+ TENSORBOARD_LOG_DIR,
76
+ TENSORBOARD_LOG_LEVEL,
77
+ CONCAT_DIMENSION,
78
+ FEATURIZERS,
79
+ CHECKPOINT_MODEL,
80
+ DENSE_DIMENSION,
81
+ CONSTRAIN_SIMILARITIES,
82
+ MODEL_CONFIDENCE,
83
+ SOFTMAX,
84
+ )
85
+ from rasa.nlu.constants import (
86
+ RESPONSE_SELECTOR_PROPERTY_NAME,
87
+ RESPONSE_SELECTOR_RETRIEVAL_INTENTS,
88
+ RESPONSE_SELECTOR_RESPONSES_KEY,
89
+ RESPONSE_SELECTOR_PREDICTION_KEY,
90
+ RESPONSE_SELECTOR_RANKING_KEY,
91
+ RESPONSE_SELECTOR_UTTER_ACTION_KEY,
92
+ RESPONSE_SELECTOR_DEFAULT_INTENT,
93
+ DEFAULT_TRANSFORMER_SIZE,
94
+ )
95
+ from rasa.shared.nlu.constants import (
96
+ TEXT,
97
+ INTENT,
98
+ RESPONSE,
99
+ INTENT_RESPONSE_KEY,
100
+ INTENT_NAME_KEY,
101
+ PREDICTED_CONFIDENCE_KEY,
102
+ )
103
+
104
+ from rasa.utils.tensorflow.model_data import RasaModelData
105
+ from rasa.utils.tensorflow.models import RasaModel
106
+
107
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
108
+
109
+
110
+ @DefaultV1Recipe.register(
111
+ DefaultV1Recipe.ComponentType.INTENT_CLASSIFIER, is_trainable=True
112
+ )
113
+ class ResponseSelector(DIETClassifier):
114
+ """Response selector using supervised embeddings.
115
+
116
+ The response selector embeds user inputs
117
+ and candidate response into the same space.
118
+ Supervised embeddings are trained by maximizing similarity between them.
119
+ It also provides rankings of the response that did not "win".
120
+
121
+ The supervised response selector needs to be preceded by
122
+ a featurizer in the pipeline.
123
+ This featurizer creates the features used for the embeddings.
124
+ It is recommended to use ``CountVectorsFeaturizer`` that
125
+ can be optionally preceded by ``SpacyNLP`` and ``SpacyTokenizer``.
126
+
127
+ Based on the starspace idea from: https://arxiv.org/abs/1709.03856.
128
+ However, in this implementation the `mu` parameter is treated differently
129
+ and additional hidden layers are added together with dropout.
130
+ """
131
+
132
@classmethod
def required_components(cls) -> List[Type]:
    """Lists the component types that must precede this one in the pipeline.

    A featurizer has to run earlier in the pipeline because it creates the
    features this selector builds its embeddings from.
    """
    prerequisites: List[Type] = [Featurizer]
    return prerequisites
136
+
137
@staticmethod
def get_default_config() -> Dict[Text, Any]:
    """Returns the component's default config.

    Starts from `DIETClassifier.get_default_config()` and overrides the keys
    below with ResponseSelector-specific defaults (see the parent class for the
    full description of each option).
    """
    return {
        **DIETClassifier.get_default_config(),
        # ## Architecture of the used neural network
        # Hidden layer sizes for layers before the embedding layers for user message
        # and labels.
        # The number of hidden layers is equal to the length of the corresponding
        # list.
        HIDDEN_LAYERS_SIZES: {TEXT: [256, 128], LABEL: [256, 128]},
        # Whether to share the hidden layer weights between input words
        # and responses
        SHARE_HIDDEN_LAYERS: False,
        # Number of units in transformer. Left as `None` here; when
        # `NUM_TRANSFORMER_LAYERS > 0` a missing or non-positive value is replaced
        # (with a warning) by `DEFAULT_TRANSFORMER_SIZE`.
        TRANSFORMER_SIZE: None,
        # Number of transformer layers
        NUM_TRANSFORMER_LAYERS: 0,
        # Number of attention heads in transformer
        NUM_HEADS: 4,
        # If 'True' use key relative embeddings in attention
        KEY_RELATIVE_ATTENTION: False,
        # If 'True' use value relative embeddings in attention
        VALUE_RELATIVE_ATTENTION: False,
        # Max position for relative embeddings. Only in effect if key-
        # or value relative attention are turned on
        MAX_RELATIVE_POSITION: 5,
        # Use a unidirectional or bidirectional encoder.
        UNIDIRECTIONAL_ENCODER: False,
        # ## Training parameters
        # Initial and final batch sizes:
        # Batch size will be linearly increased for each epoch.
        BATCH_SIZES: [64, 256],
        # Strategy used when creating batches.
        # Can be either 'sequence' or 'balanced'.
        BATCH_STRATEGY: BALANCED,
        # Number of epochs to train
        EPOCHS: 300,
        # Set random seed to any 'int' to get reproducible results
        RANDOM_SEED: None,
        # Initial learning rate for the optimizer
        LEARNING_RATE: 0.001,
        # ## Parameters for embeddings
        # Dimension size of embedding vectors
        EMBEDDING_DIMENSION: 20,
        # Default dense dimension to use if no dense features are present.
        DENSE_DIMENSION: {TEXT: 512, LABEL: 512},
        # Default dimension to use for concatenating sequence and sentence features.
        CONCAT_DIMENSION: {TEXT: 512, LABEL: 512},
        # The number of incorrect labels. The algorithm will minimize
        # their similarity to the user input during training.
        NUM_NEG: 20,
        # Type of similarity measure to use, either 'auto' or 'cosine' or 'inner'.
        SIMILARITY_TYPE: AUTO,
        # The type of the loss function, either 'cross_entropy' or 'margin'.
        LOSS_TYPE: CROSS_ENTROPY,
        # Number of top responses for which confidences should be predicted.
        # Set to 0 if confidences for all responses should be reported.
        RANKING_LENGTH: 10,
        # Determines whether the confidences of the chosen top responses should be
        # renormalized so that they sum up to 1. By default, we do not renormalize
        # and return the confidences for the top responses as is.
        # Note that renormalization only makes sense if confidences are generated
        # via `softmax`.
        RENORMALIZE_CONFIDENCES: False,
        # Indicates how similar the algorithm should try to make embedding vectors
        # for correct labels.
        # Should be 0.0 < ... < 1.0 for 'cosine' similarity type.
        MAX_POS_SIM: 0.8,
        # Maximum negative similarity for incorrect labels.
        # Should be -1.0 < ... < 1.0 for 'cosine' similarity type.
        MAX_NEG_SIM: -0.4,
        # If 'True' the algorithm only minimizes maximum similarity over
        # incorrect labels, used only if 'loss_type' is set to 'margin'.
        USE_MAX_NEG_SIM: True,
        # Scale loss inverse proportionally to confidence of correct prediction
        SCALE_LOSS: True,
        # ## Regularization parameters
        # The scale of regularization
        REGULARIZATION_CONSTANT: 0.002,
        # Fraction of trainable weights in internal layers.
        CONNECTION_DENSITY: 1.0,
        # The scale of how important it is to minimize the maximum similarity
        # between embeddings of different labels.
        NEGATIVE_MARGIN_SCALE: 0.8,
        # Dropout rate for encoder
        DROP_RATE: 0.2,
        # Dropout rate for attention
        DROP_RATE_ATTENTION: 0,
        # If 'True' apply dropout to sparse input tensors
        SPARSE_INPUT_DROPOUT: False,
        # If 'True' apply dropout to dense input tensors
        DENSE_INPUT_DROPOUT: False,
        # ## Evaluation parameters
        # How often (in epochs) to calculate validation accuracy.
        # Small values may hurt performance.
        EVAL_NUM_EPOCHS: 20,
        # How many examples to use for the hold-out validation set.
        # Large values may hurt performance, e.g. model accuracy.
        EVAL_NUM_EXAMPLES: 0,
        # ## Selector config
        # If 'True' random tokens of the input message will be masked and the model
        # should predict those tokens.
        MASKED_LM: False,
        # Name of the intent for which this response selector is to be trained
        RETRIEVAL_INTENT: None,
        # Boolean flag to check if actual text of the response
        # should be used as ground truth label for training the model.
        USE_TEXT_AS_LABEL: False,
        # If you want to use tensorboard to visualize training
        # and validation metrics,
        # set this option to a valid output directory.
        TENSORBOARD_LOG_DIR: None,
        # Define when training metrics for tensorboard should be logged.
        # Either after every epoch or for every training step.
        # Valid values: 'epoch' and 'batch'
        TENSORBOARD_LOG_LEVEL: "epoch",
        # Specify what features to use as sequence and sentence features
        # By default all features in the pipeline are used.
        FEATURIZERS: [],
        # Perform model checkpointing
        CHECKPOINT_MODEL: False,
        # if 'True' applies sigmoid on all similarity terms and adds it
        # to the loss function to ensure that similarity values are
        # approximately bounded. Used inside cross-entropy loss only.
        CONSTRAIN_SIMILARITIES: False,
        # Model confidence to be returned during inference. Currently, the only
        # possible value is `softmax`.
        MODEL_CONFIDENCE: SOFTMAX,
    }
267
+
268
def __init__(
    self,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
    index_label_id_mapping: Optional[Dict[int, Text]] = None,
    entity_tag_specs: Optional[List[EntityTagSpec]] = None,
    model: Optional[RasaModel] = None,
    all_retrieval_intents: Optional[List[Text]] = None,
    responses: Optional[Dict[Text, List[Dict[Text, Any]]]] = None,
    sparse_feature_sizes: Optional[Dict[Text, Dict[Text, List[int]]]] = None,
) -> None:
    """Declare instance variables with default values.

    Args:
        config: Configuration for the component. This dict is modified in
            place: intent classification is forced on, entity recognition is
            forced off and the BILOU flag is set to `None`.
        model_storage: Storage which graph components can use to persist and load
            themselves.
        resource: Resource locator for this component which can be used to persist
            and load itself from the `model_storage`.
        execution_context: Information about the current graph run.
        index_label_id_mapping: Mapping between label and index used for encoding.
        entity_tag_specs: Format specification all entity tags.
        model: Model architecture.
        all_retrieval_intents: All retrieval intents defined in the data.
            Defaults to an empty list.
        responses: All responses defined in the data. Defaults to an empty dict.
        sparse_feature_sizes: Sizes of the sparse features the model was trained on.
    """
    # NOTE: no copy is made, so the caller's `config` dict is mutated below.
    component_config = config

    # the following properties cannot be adapted for the ResponseSelector
    component_config[INTENT_CLASSIFICATION] = True
    component_config[ENTITY_RECOGNITION] = False
    component_config[BILOU_FLAG] = None

    # Initialize defaults before calling the parent constructor.
    # `_load_selector_params` later overwrites `retrieval_intent` and
    # `use_text_as_label` from the config. NOTE(review): presumably the parent
    # constructor triggers `_check_config_parameters` (which calls it); confirm
    # in DIETClassifier.
    self.responses = responses or {}
    self.all_retrieval_intents = all_retrieval_intents or []
    self.retrieval_intent = None
    self.use_text_as_label = False

    super().__init__(
        component_config,
        model_storage,
        resource,
        execution_context,
        index_label_id_mapping,
        entity_tag_specs,
        model,
        sparse_feature_sizes=sparse_feature_sizes,
    )
322
+
323
@property
def label_key(self) -> Text:
    """The label key used by this selector (always the `LABEL_KEY` constant)."""
    key: Text = LABEL_KEY
    return key
327
+
328
@property
def label_sub_key(self) -> Text:
    """The label sub-key used by this selector (always `LABEL_SUB_KEY`)."""
    sub_key: Text = LABEL_SUB_KEY
    return sub_key
332
+
333
@staticmethod
def model_class(  # type: ignore[override]
    use_text_as_label: bool,
) -> Type[RasaModel]:
    """Chooses the model architecture to train.

    Args:
        use_text_as_label: Whether the response text (rather than the
            intent-response key) is used as the training label.

    Returns:
        `DIET2DIET` when `use_text_as_label` is truthy, otherwise `DIET2BOW`.
    """
    return DIET2DIET if use_text_as_label else DIET2BOW
342
+
343
def _load_selector_params(self) -> None:
    """Copies the selector-specific options from the config onto the instance."""
    cfg = self.component_config
    self.retrieval_intent = cfg[RETRIEVAL_INTENT]
    self.use_text_as_label = cfg[USE_TEXT_AS_LABEL]
346
+
347
def _warn_about_transformer_and_hidden_layers_enabled(
    self, selector_name: Text
) -> None:
    """Warns user if they enabled the transformer but didn't disable hidden layers.

    ResponseSelector defaults specify considerable hidden layer sizes, but
    this is for cases where no transformer is used. If a transformer exists,
    then, from our experience, the best results are achieved with no hidden layers
    used between the feature-combining layers and the transformer.

    Args:
        selector_name: Human-readable name of this selector, used in the warning.
    """
    current_hidden_layers = self.component_config[HIDDEN_LAYERS_SIZES]
    default_hidden_layers = self.get_default_config()[HIDDEN_LAYERS_SIZES]
    # Same keys as the default, each mapped to an empty list: the configuration
    # that disables every hidden layer.
    config_for_disabling_hidden_layers: Dict[Text, List[Any]] = {
        key: [] for key in default_hidden_layers
    }
    # Nothing to warn about if the hidden layers are already disabled.
    if current_hidden_layers == config_for_disabling_hidden_layers:
        return

    # make the warning text more contextual by explaining what the user did
    # to the hidden layers' config (i.e. what it is they should change)
    if current_hidden_layers == default_hidden_layers:
        what_user_did = "left the hidden layer sizes at their default value:"
    else:
        what_user_did = "set the hidden layer sizes to be non-empty by setting"

    rasa.shared.utils.io.raise_warning(
        f"You have enabled a transformer inside {selector_name} by"
        f" setting a positive value for `{NUM_TRANSFORMER_LAYERS}`, but you "
        f"{what_user_did} `{HIDDEN_LAYERS_SIZES}="
        f"{current_hidden_layers}`. We recommend to "
        f"disable the hidden layers when using a transformer, by specifying "
        f"`{HIDDEN_LAYERS_SIZES}={config_for_disabling_hidden_layers}`.",
        category=UserWarning,
    )
386
+
387
def _warn_and_correct_transformer_size(self, selector_name: Text) -> None:
    """Replaces an unusable `transformer_size` with the default and warns.

    A `None` or non-positive `TRANSFORMER_SIZE` (`None` is this component's
    default) would break training once a transformer is used, so it is
    overwritten with `DEFAULT_TRANSFORMER_SIZE` after a warning is raised.

    Args:
        selector_name: Human-readable name of this selector, used in the warning.
    """
    size = self.component_config[TRANSFORMER_SIZE]
    needs_correction = size is None or size < 1
    if not needs_correction:
        return

    rasa.shared.utils.io.raise_warning(
        f"`{TRANSFORMER_SIZE}` is set to "
        f"`{size}` for "
        f"{selector_name}, but a positive size is required when using "
        f"`{NUM_TRANSFORMER_LAYERS} > 0`. {selector_name} will proceed, using "
        f"`{TRANSFORMER_SIZE}={DEFAULT_TRANSFORMER_SIZE}`. "
        f"Alternatively, specify a different value in the component's config.",
        category=UserWarning,
    )
    self.component_config[TRANSFORMER_SIZE] = DEFAULT_TRANSFORMER_SIZE
407
+
408
def _check_config_params_when_transformer_enabled(self) -> None:
    """Adjusts transformer-dependent config values when a transformer is used.

    Some defaults only make sense without a transformer, so when
    `NUM_TRANSFORMER_LAYERS` is positive the hidden-layer setup is checked
    (warning only) and the transformer size is validated and corrected.
    """
    transformer_enabled = self.component_config[NUM_TRANSFORMER_LAYERS] > 0
    if not transformer_enabled:
        return

    # e.g. "ResponseSelector(faq)" when bound to a retrieval intent.
    intent_suffix = f"({self.retrieval_intent})" if self.retrieval_intent else ""
    selector_name = f"ResponseSelector{intent_suffix}"

    self._warn_about_transformer_and_hidden_layers_enabled(selector_name)
    self._warn_and_correct_transformer_size(selector_name)
420
+
421
def _check_config_parameters(self) -> None:
    """Validates the component configuration, correcting values where needed.

    The generic DIET checks run first. The selector options are then copied onto
    the instance, because the transformer-specific checks that follow read
    `self.retrieval_intent`.
    """
    # 1. checks shared with DIETClassifier
    super()._check_config_parameters()
    # 2. selector-specific options (retrieval intent, text-as-label flag)
    self._load_selector_params()
    # 3. ResponseSelector-specific checks that depend on step 2
    self._check_config_params_when_transformer_enabled()
428
+
429
def _set_message_property(
    self, message: Message, prediction_dict: Dict[Text, Any], selector_key: Text
) -> None:
    """Stores a selector prediction in the message's response-selector property.

    The existing property dict is taken from the message, or a new empty one is
    used. It is updated in place with the list of all retrieval intents and
    with `prediction_dict` under `selector_key`. It is then set back on the
    message with `add_to_output=True`.

    Args:
        message: Message to annotate.
        prediction_dict: Prediction produced by this selector.
        selector_key: Key under which the prediction is stored.
    """
    selector_properties = message.get(RESPONSE_SELECTOR_PROPERTY_NAME, {})
    selector_properties.update(
        {
            RESPONSE_SELECTOR_RETRIEVAL_INTENTS: self.all_retrieval_intents,
            selector_key: prediction_dict,
        }
    )
    message.set(
        RESPONSE_SELECTOR_PROPERTY_NAME, selector_properties, add_to_output=True
    )
442
+
443
def preprocess_train_data(self, training_data: TrainingData) -> RasaModelData:
    """Turns training data into model data for this response selector.

    Steps:
    - Record every retrieval intent in the unfiltered data.
    - If a retrieval intent is configured, keep only its examples.
    - Build the label mappings and label data.
    - Assemble and sanity-check the model data.

    Args:
        training_data: Training data to preprocess.

    Returns:
        The model data to train on. An empty `RasaModelData` is returned when
        no labels are present.
    """
    # Remember all retrieval intents before any filtering happens.
    self.all_retrieval_intents = list(training_data.retrieval_intents)

    if not self.retrieval_intent:
        # The retrieval intent parameter kept its default value.
        logger.info(
            "Retrieval intent parameter was left to its default value. This "
            "response selector will be trained on training examples combining "
            "all retrieval intents."
        )
    else:
        training_data = training_data.filter_training_examples(
            lambda example: self.retrieval_intent == example.get(INTENT)
        )

    label_attribute = INTENT_RESPONSE_KEY if not self.use_text_as_label else RESPONSE

    label_to_index = self._label_id_index_mapping(
        training_data, attribute=label_attribute
    )

    self.responses = training_data.responses

    if not label_to_index:
        # Nothing to train on.
        return RasaModelData()

    self.index_label_id_mapping = self._invert_mapping(label_to_index)
    self._label_data = self._create_label_data(
        training_data, label_to_index, attribute=label_attribute
    )

    prepared = self._create_model_data(
        training_data.intent_examples,
        label_to_index,
        label_attribute=label_attribute,
    )
    self._check_input_dimension_consistency(prepared)
    return prepared
493
+
494
def _resolve_intent_response_key(
    self, label: Dict[Text, Optional[Text]]
) -> Optional[Text]:
    """Maps a predicted label back to the intent response key it belongs to.

    The label's ``"name"`` is compared against the known responses in
    iteration order. For each response template, the name is first compared
    with the template's intent response key, then with the ``TEXT`` of each of
    its responses.

    Args:
        label: Label predicted by the selector; only its ``"name"`` is used.

    Returns:
        The intent response key of the first matching response template, or
        ``None`` when the name matches neither a key nor a response text. The
        caller is responsible for a fallback in that case; ``process`` uses the
        raw label name.
    """
    # Invariant across all templates; look it up once.
    predicted_name = label.get("name")

    for key, responses in self.responses.items():
        intent_response_key = util.template_key_to_intent_response_key(key)

        # The predicted label may be the key itself ...
        if intent_response_key == predicted_name:
            return intent_response_key

        # ... or the text of one of this key's responses.
        if any(response.get(TEXT, "") == predicted_name for response in responses):
            return intent_response_key

    return None
518
+
519
def process(self, messages: List[Message]) -> List[Message]:
    """Selects the most likely response for each message.

    For every message:
    - Resolve the top predicted label to an intent response key, falling back
      to the raw label name.
    - Look up that key's responses. If none exist, warn and use the key
      itself as the response text.
    - Rewrite the label ranking to carry intent response keys.
    - Store the prediction on the message under this selector's key.

    Args:
        messages: List containing latest user message.

    Returns:
        The same messages, each augmented with the most likely response, the
        associated intent_response_key and its similarity to the input.
    """
    # The selector key depends only on this component, not on the message.
    selector_key = (
        self.retrieval_intent
        if self.retrieval_intent
        else RESPONSE_SELECTOR_DEFAULT_INTENT
    )

    for message in messages:
        out = self._predict(message)
        top_label, label_ranking = self._predict_label(out)

        # Get the exact intent_response_key and the associated
        # responses for the top predicted label.
        label_intent_response_key = (
            self._resolve_intent_response_key(top_label)
            or top_label[INTENT_NAME_KEY]
        )
        label_responses = self.responses.get(
            util.intent_response_key_to_template_key(label_intent_response_key)
        )

        if label_intent_response_key and not label_responses:
            # Responses seem to be unavailable, likely an issue with the
            # training data; fall back to using the key itself as the text.
            rasa.shared.utils.io.raise_warning(
                f"Unable to fetch responses for {label_intent_response_key}. "
                "This means that there is likely an issue with the training data. "
                "Please make sure you have added responses for this intent."
            )
            label_responses = [{TEXT: label_intent_response_key}]

        for label in label_ranking:
            label[INTENT_RESPONSE_KEY] = (
                self._resolve_intent_response_key(label) or label[INTENT_NAME_KEY]
            )
            # Remove the "name" key since it is either the same as
            # "intent_response_key" or it is the response text which
            # is not needed in the ranking.
            label.pop(INTENT_NAME_KEY)

        logger.debug(
            f"Adding following selector key to message property: {selector_key}"
        )

        utter_action_key = util.intent_response_key_to_template_key(
            label_intent_response_key
        )
        prediction_dict = {
            RESPONSE_SELECTOR_PREDICTION_KEY: {
                RESPONSE_SELECTOR_RESPONSES_KEY: label_responses,
                PREDICTED_CONFIDENCE_KEY: top_label[PREDICTED_CONFIDENCE_KEY],
                INTENT_RESPONSE_KEY: label_intent_response_key,
                RESPONSE_SELECTOR_UTTER_ACTION_KEY: utter_action_key,
            },
            RESPONSE_SELECTOR_RANKING_KEY: label_ranking,
        }

        self._set_message_property(message, prediction_dict, selector_key)

        if (
            self._execution_context.should_add_diagnostic_data
            and out
            and DIAGNOSTIC_DATA in out
        ):
            message.add_diagnostic_data(
                self._execution_context.node_name, out.get(DIAGNOSTIC_DATA)
            )

    return messages
598
+
599
def persist(self) -> None:
    """Writes the selector's JSON artifacts, then delegates to the parent.

    Nothing is persisted when there is no model. Otherwise the responses and
    the list of all retrieval intents are dumped as
    `<ClassName>.responses.json` and `<ClassName>.retrieval_intents.json`
    into this component's resource. After that, the parent class's `persist`
    is called.
    """
    if self.model is not None:
        with self._model_storage.write_to(self._resource) as target_dir:
            class_name = self.__class__.__name__
            # Attributes are read lazily, one per file, in the original order.
            for suffix, attribute in (
                ("responses", "responses"),
                ("retrieval_intents", "all_retrieval_intents"),
            ):
                rasa.shared.utils.io.dump_obj_as_json_to_file(
                    target_dir / f"{class_name}.{suffix}.json",
                    getattr(self, attribute),
                )

        super().persist()
617
+
618
@classmethod
def _load_model_class(
    cls,
    tf_model_file: Text,
    model_data_example: RasaModelData,
    label_data: RasaModelData,
    entity_tag_specs: List[EntityTagSpec],
    config: Dict[Text, Any],
    finetune_mode: bool = False,
) -> "RasaModel":
    """Restores the TF model from `tf_model_file`.

    The prediction data example keeps only those feature entries whose key
    contains the `TEXT` attribute name (substring match). The model class is
    chosen via `model_class(config[USE_TEXT_AS_LABEL])` and receives a deep
    copy of `config`.
    """
    text_features = {
        name: features
        for name, features in model_data_example.items()
        if TEXT in name
    }
    predict_data_example = RasaModelData(
        label_key=model_data_example.label_key, data=text_features
    )

    model_cls = cls.model_class(config[USE_TEXT_AS_LABEL])
    return model_cls.load(
        tf_model_file,
        model_data_example,
        predict_data_example,
        data_signature=model_data_example.get_signature(),
        label_data=label_data,
        entity_tag_specs=entity_tag_specs,
        config=copy.deepcopy(config),
        finetune_mode=finetune_mode,
    )
646
+
647
+ def _instantiate_model_class(self, model_data: RasaModelData) -> "RasaModel":
648
+ return self.model_class(self.use_text_as_label)(
649
+ data_signature=model_data.get_signature(),
650
+ label_data=self._label_data,
651
+ entity_tag_specs=self._entity_tag_specs,
652
+ config=self.component_config,
653
+ )
654
+
655
@classmethod
def load(
    cls,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
    **kwargs: Any,
) -> ResponseSelector:
    """Loads the trained model together with the selector's JSON artifacts.

    The parent class's `load` restores the model first. Then
    `<ClassName>.responses.json` and `<ClassName>.retrieval_intents.json` are
    read from the same resource and attached to it. If reading raises a
    `ValueError`, a debug message is logged and a freshly constructed
    instance is returned instead.
    """
    model = super().load(
        config, model_storage, resource, execution_context, **kwargs
    )
    class_name = cls.__name__

    try:
        with model_storage.read_from(resource) as source_dir:
            stored_responses = rasa.shared.utils.io.read_json_file(
                source_dir / f"{class_name}.responses.json"
            )
            stored_intents = rasa.shared.utils.io.read_json_file(
                source_dir / f"{class_name}.retrieval_intents.json"
            )
    except ValueError:
        logger.debug(
            f"Failed to load {cls.__name__} from model storage. Resource "
            f"'{resource.name}' doesn't exist."
        )
        return cls(config, model_storage, resource, execution_context)

    model.responses = stored_responses
    model.all_retrieval_intents = stored_intents
    return model
687
+
688
+
689
class DIET2BOW(DIET):
    """DIET variant that tracks mask and response loss/accuracy metrics."""

    def _create_metrics(self) -> None:
        # Per the original note, `self.metrics` preserves creation order, so
        # all losses are created before all accuracies.
        for attribute, metric_name in (
            ("mask_loss", "m_loss"),
            ("response_loss", "r_loss"),
            ("mask_acc", "m_acc"),
            ("response_acc", "r_acc"),
        ):
            setattr(self, attribute, tf.keras.metrics.Mean(name=metric_name))

    def _update_metrics_to_log(self) -> None:
        # Losses are only listed when the "rasa" logger's own level is DEBUG.
        debug_enabled = logging.getLogger("rasa").level == logging.DEBUG
        prefixes = ["m", "r"] if self.config[MASKED_LM] else ["r"]
        for prefix in prefixes:
            self.metrics_to_log.append(f"{prefix}_acc")
            if debug_enabled:
                self.metrics_to_log.append(f"{prefix}_loss")

        self._log_metric_info()

    def _log_metric_info(self) -> None:
        long_prefix = {"t": "total", "m": "mask", "r": "response"}
        logger.debug("Following metrics will be logged during training: ")
        for metric in self.metrics_to_log:
            pieces = metric.split("_")
            readable = f"{long_prefix[pieces[0]]} {pieces[1]}"
            logger.debug(f" {metric} ({readable})")

    def _update_label_metrics(self, loss: tf.Tensor, acc: tf.Tensor) -> None:
        # Feed the response-classification loss/accuracy into their trackers.
        for tracker, value in ((self.response_loss, loss), (self.response_acc, acc)):
            tracker.update_state(value)
726
+
727
+
728
+ class DIET2DIET(DIET):
729
+ """Diet 2 Diet transformer implementation."""
730
+
731
def _check_data(self) -> None:
    """Validates the data signature before the model is built.

    Raises:
        InvalidConfigException: If the data signature lacks text or label
            features, or if hidden layers are shared while the sentence-level
            signatures of text and label differ.
    """
    if TEXT not in self.data_signature:
        raise InvalidConfigException(
            "No text features specified. "
            f"Cannot train '{self.__class__.__name__}' model."
        )
    if LABEL not in self.data_signature:
        raise InvalidConfigException(
            "No label features specified. "
            f"Cannot train '{self.__class__.__name__}' model."
        )
    if (
        self.config[SHARE_HIDDEN_LAYERS]
        and self.data_signature[TEXT][SENTENCE]
        != self.data_signature[LABEL][SENTENCE]
    ):
        # Sharing hidden layers requires identical sentence-level signatures.
        # InvalidConfigException is used (instead of a bare ValueError) for
        # consistency with the checks above.
        # NOTE(review): assumes InvalidConfigException subclasses ValueError
        # (as declared in rasa.shared.exceptions), so existing
        # `except ValueError` callers keep working -- confirm.
        raise InvalidConfigException(
            "If hidden layer weights are shared, data signatures "
            "for text_features and label_features must coincide."
        )
751
+
752
def _create_metrics(self) -> None:
    """Creates the running-mean trackers for mask and response loss/accuracy."""
    # Per the original note, `self.metrics` preserves creation order, so all
    # losses are created before all accuracies.
    for attribute, metric_name in (
        ("mask_loss", "m_loss"),
        ("response_loss", "r_loss"),
        ("mask_acc", "m_acc"),
        ("response_acc", "r_acc"),
    ):
        setattr(self, attribute, tf.keras.metrics.Mean(name=metric_name))
760
+
761
def _update_metrics_to_log(self) -> None:
    """Appends the metric names to log during training, then logs them.

    Accuracies are always listed. Losses are listed only when the "rasa"
    logger's own level equals DEBUG. Mask metrics are included only when
    masked language modelling is enabled.
    """
    debug_enabled = logging.getLogger("rasa").level == logging.DEBUG
    prefixes = ["m", "r"] if self.config[MASKED_LM] else ["r"]
    for prefix in prefixes:
        self.metrics_to_log.append(f"{prefix}_acc")
        if debug_enabled:
            self.metrics_to_log.append(f"{prefix}_loss")

    self._log_metric_info()
774
+
775
def _log_metric_info(self) -> None:
    """Logs (at debug level) each metric to be logged with a readable name."""
    long_prefix = {"t": "total", "m": "mask", "r": "response"}
    logger.debug("Following metrics will be logged during training: ")
    for metric in self.metrics_to_log:
        # "r_acc" -> prefix "r", kind "acc" -> "response acc"
        pieces = metric.split("_")
        readable = f"{long_prefix[pieces[0]]} {pieces[1]}"
        logger.debug(f" {metric} ({readable})")
782
+
783
def _prepare_layers(self) -> None:
    """Creates the sequence, masked-LM and label-classification layers.

    One `RasaSequenceLayer` is registered per attribute (user text and label).
    The label layer's config has sparse and dense input dropout disabled. When
    `SHARE_HIDDEN_LAYERS` is set, labels use the text attribute name, so both
    entries target the same `sequence_layer.<text>` key.
    """
    self.text_name = TEXT
    self.label_name = LABEL if not self.config[SHARE_HIDDEN_LAYERS] else TEXT

    # Label features get no input dropout.
    label_config = self.config.copy()
    label_config[SPARSE_INPUT_DROPOUT] = False
    label_config[DENSE_INPUT_DROPOUT] = False

    layer_specs = [(self.text_name, self.config), (self.label_name, label_config)]
    for attribute, layer_config in layer_specs:
        # NOTE(review): with shared hidden layers both specs share one key, so
        # the second (dropout-free) layer replaces the first -- confirm intended.
        self._tf_layers[f"sequence_layer.{attribute}"] = (
            rasa_layers.RasaSequenceLayer(
                attribute, self.data_signature[attribute], layer_config
            )
        )

    if self.config[MASKED_LM]:
        self._prepare_mask_lm_loss(self.text_name)

    self._prepare_label_classification_layers(predictor_attribute=self.text_name)
806
+
807
def _create_all_labels(self) -> Tuple[tf.Tensor, tf.Tensor]:
    """Embeds every known label with the label sequence layer.

    Returns:
        A tuple of the label ids and the corresponding label embeddings.
    """
    label_data = self.tf_label_data
    label_ids = label_data[LABEL_KEY][LABEL_SUB_KEY][0]

    seq_lengths = self._get_sequence_feature_lengths(label_data, LABEL)
    label_layer = self._tf_layers[f"sequence_layer.{self.label_name}"]
    # Merge all feature types into one sequence and run the transformer.
    transformed, _, _, _, _, _ = label_layer(
        (label_data[LABEL][SEQUENCE], label_data[LABEL][SENTENCE], seq_lengths),
        training=self._training,
    )

    # The label vector is read from the last position holding real features.
    # That position is the number of sequence-level tokens plus the
    # sentence-level slot, whose effective length is 1 when sentence features
    # are present and 0 otherwise.
    sentence_lengths = self._get_sentence_feature_lengths(label_data, LABEL)
    effective_lengths = seq_lengths + sentence_lengths
    label_vectors = self._last_token(transformed, effective_lengths)

    return label_ids, self._tf_layers[f"embed.{LABEL}"](label_vectors)
842
+
843
def batch_loss(
    self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
) -> tf.Tensor:
    """Computes the training loss for one batch.

    Text and label features each pass through their sequence layer. If masked
    language modelling is enabled, its loss is included. The label
    classification loss compares text and label vectors taken from the last
    position holding real features. Metric trackers are updated along the way.

    Args:
        batch_in: The batch.

    Returns:
        The sum of all computed losses.
    """
    batch = self.batch_to_model_data_format(batch_in, self.data_signature)

    # User text: combine feature types and run them through the transformer.
    text_seq_lengths = self._get_sequence_feature_lengths(batch, TEXT)
    (
        text_out,
        text_inputs,
        _,
        text_token_ids,
        text_mlm_mask,
        _,
    ) = self._tf_layers[f"sequence_layer.{self.text_name}"](
        (batch[TEXT][SEQUENCE], batch[TEXT][SENTENCE], text_seq_lengths),
        training=self._training,
    )

    # Labels: same processing; only the transformed output is needed.
    label_seq_lengths = self._get_sequence_feature_lengths(batch, LABEL)
    label_out, _, _, _, _, _ = self._tf_layers[
        f"sequence_layer.{self.label_name}"
    ](
        (batch[LABEL][SEQUENCE], batch[LABEL][SENTENCE], label_seq_lengths),
        training=self._training,
    )

    losses = []

    if self.config[MASKED_LM]:
        mlm_loss, mlm_acc = self._mask_loss(
            text_out, text_inputs, text_token_ids, text_mlm_mask, self.text_name
        )
        self.mask_loss.update_state(mlm_loss)
        self.mask_acc.update_state(mlm_acc)
        losses.append(mlm_loss)

    # Sentence vectors are read from the last position with real features:
    # the sequence-level length plus the sentence-level length (1 or 0).
    text_sentence_lengths = self._get_sentence_feature_lengths(batch, TEXT)
    text_vector = self._last_token(
        text_out, text_seq_lengths + text_sentence_lengths
    )

    label_sentence_lengths = self._get_sentence_feature_lengths(batch, LABEL)
    label_vector = self._last_token(
        label_out, label_seq_lengths + label_sentence_lengths
    )
    label_ids = batch[LABEL_KEY][LABEL_SUB_KEY][0]

    response_loss, response_acc = self._calculate_label_loss(
        text_vector, label_vector, label_ids
    )
    self.response_loss.update_state(response_loss)
    self.response_acc.update_state(response_acc)
    losses.append(response_loss)

    return tf.math.add_n(losses)
935
+
936
def batch_predict(
    self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
) -> Dict[Text, Union[tf.Tensor, Dict[Text, tf.Tensor]]]:
    """Predicts label confidences for the given batch.

    Args:
        batch_in: The batch.

    Returns:
        A dict with diagnostic data (attention weights and transformed text)
        under `DIAGNOSTIC_DATA`, and the label confidences under "i_scores".
    """
    tf_batch_data = self.batch_to_model_data_format(
        batch_in, self.predict_data_signature
    )

    sequence_feature_lengths = self._get_sequence_feature_lengths(
        tf_batch_data, TEXT
    )
    text_transformed, _, _, _, _, attention_weights = self._tf_layers[
        f"sequence_layer.{self.text_name}"
    ](
        (
            tf_batch_data[TEXT][SEQUENCE],
            tf_batch_data[TEXT][SENTENCE],
            sequence_feature_lengths,
        ),
        training=self._training,
    )

    predictions = {
        DIAGNOSTIC_DATA: {
            "attention_weights": attention_weights,
            "text_transformed": text_transformed,
        }
    }

    if self.all_labels_embed is None:
        _, self.all_labels_embed = self._create_all_labels()

    # Read the sentence vector from the same position that `batch_loss` uses
    # during training: the last position with real features, i.e. the
    # sequence-level length plus the sentence-level length (1 or 0).
    # Using only the sequence length here would pick the last sequence token
    # instead of the sentence-level slot the embedding was trained on.
    sentence_feature_lengths = self._get_sentence_feature_lengths(
        tf_batch_data, TEXT
    )
    sentence_vector = self._last_token(
        text_transformed, sequence_feature_lengths + sentence_feature_lengths
    )
    sentence_vector_embed = self._tf_layers[f"embed.{TEXT}"](sentence_vector)

    _, scores = self._tf_layers[
        f"loss.{LABEL}"
    ].get_similarities_and_confidences_from_embeddings(
        sentence_vector_embed[:, tf.newaxis, :],
        self.all_labels_embed[tf.newaxis, :, :],
    )
    predictions["i_scores"] = scores

    return predictions