rasa-pro 3.12.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (790)
  1. README.md +41 -0
  2. rasa/__init__.py +9 -0
  3. rasa/__main__.py +177 -0
  4. rasa/anonymization/__init__.py +2 -0
  5. rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
  6. rasa/anonymization/anonymization_pipeline.py +286 -0
  7. rasa/anonymization/anonymization_rule_executor.py +260 -0
  8. rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
  9. rasa/anonymization/schemas/config.yml +47 -0
  10. rasa/anonymization/utils.py +118 -0
  11. rasa/api.py +160 -0
  12. rasa/cli/__init__.py +5 -0
  13. rasa/cli/arguments/__init__.py +0 -0
  14. rasa/cli/arguments/data.py +106 -0
  15. rasa/cli/arguments/default_arguments.py +207 -0
  16. rasa/cli/arguments/evaluate.py +65 -0
  17. rasa/cli/arguments/export.py +51 -0
  18. rasa/cli/arguments/interactive.py +74 -0
  19. rasa/cli/arguments/run.py +219 -0
  20. rasa/cli/arguments/shell.py +17 -0
  21. rasa/cli/arguments/test.py +211 -0
  22. rasa/cli/arguments/train.py +279 -0
  23. rasa/cli/arguments/visualize.py +34 -0
  24. rasa/cli/arguments/x.py +30 -0
  25. rasa/cli/data.py +354 -0
  26. rasa/cli/dialogue_understanding_test.py +251 -0
  27. rasa/cli/e2e_test.py +259 -0
  28. rasa/cli/evaluate.py +222 -0
  29. rasa/cli/export.py +250 -0
  30. rasa/cli/inspect.py +75 -0
  31. rasa/cli/interactive.py +166 -0
  32. rasa/cli/license.py +65 -0
  33. rasa/cli/llm_fine_tuning.py +403 -0
  34. rasa/cli/markers.py +78 -0
  35. rasa/cli/project_templates/__init__.py +0 -0
  36. rasa/cli/project_templates/calm/actions/__init__.py +0 -0
  37. rasa/cli/project_templates/calm/actions/action_template.py +27 -0
  38. rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
  39. rasa/cli/project_templates/calm/actions/db.py +57 -0
  40. rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
  41. rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
  42. rasa/cli/project_templates/calm/config.yml +10 -0
  43. rasa/cli/project_templates/calm/credentials.yml +33 -0
  44. rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
  45. rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
  46. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
  47. rasa/cli/project_templates/calm/db/contacts.json +10 -0
  48. rasa/cli/project_templates/calm/domain/add_contact.yml +39 -0
  49. rasa/cli/project_templates/calm/domain/list_contacts.yml +17 -0
  50. rasa/cli/project_templates/calm/domain/remove_contact.yml +38 -0
  51. rasa/cli/project_templates/calm/domain/shared.yml +10 -0
  52. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
  53. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
  54. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
  55. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
  56. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
  57. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
  58. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
  59. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
  60. rasa/cli/project_templates/calm/endpoints.yml +58 -0
  61. rasa/cli/project_templates/default/actions/__init__.py +0 -0
  62. rasa/cli/project_templates/default/actions/actions.py +27 -0
  63. rasa/cli/project_templates/default/config.yml +44 -0
  64. rasa/cli/project_templates/default/credentials.yml +33 -0
  65. rasa/cli/project_templates/default/data/nlu.yml +91 -0
  66. rasa/cli/project_templates/default/data/rules.yml +13 -0
  67. rasa/cli/project_templates/default/data/stories.yml +30 -0
  68. rasa/cli/project_templates/default/domain.yml +34 -0
  69. rasa/cli/project_templates/default/endpoints.yml +42 -0
  70. rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
  71. rasa/cli/project_templates/tutorial/actions/__init__.py +0 -0
  72. rasa/cli/project_templates/tutorial/actions/actions.py +22 -0
  73. rasa/cli/project_templates/tutorial/config.yml +12 -0
  74. rasa/cli/project_templates/tutorial/credentials.yml +33 -0
  75. rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
  76. rasa/cli/project_templates/tutorial/data/patterns.yml +11 -0
  77. rasa/cli/project_templates/tutorial/domain.yml +35 -0
  78. rasa/cli/project_templates/tutorial/endpoints.yml +55 -0
  79. rasa/cli/run.py +143 -0
  80. rasa/cli/scaffold.py +273 -0
  81. rasa/cli/shell.py +141 -0
  82. rasa/cli/studio/__init__.py +0 -0
  83. rasa/cli/studio/download.py +62 -0
  84. rasa/cli/studio/studio.py +296 -0
  85. rasa/cli/studio/train.py +59 -0
  86. rasa/cli/studio/upload.py +62 -0
  87. rasa/cli/telemetry.py +102 -0
  88. rasa/cli/test.py +280 -0
  89. rasa/cli/train.py +278 -0
  90. rasa/cli/utils.py +484 -0
  91. rasa/cli/visualize.py +40 -0
  92. rasa/cli/x.py +206 -0
  93. rasa/constants.py +45 -0
  94. rasa/core/__init__.py +17 -0
  95. rasa/core/actions/__init__.py +0 -0
  96. rasa/core/actions/action.py +1318 -0
  97. rasa/core/actions/action_clean_stack.py +59 -0
  98. rasa/core/actions/action_exceptions.py +24 -0
  99. rasa/core/actions/action_hangup.py +29 -0
  100. rasa/core/actions/action_repeat_bot_messages.py +89 -0
  101. rasa/core/actions/action_run_slot_rejections.py +210 -0
  102. rasa/core/actions/action_trigger_chitchat.py +31 -0
  103. rasa/core/actions/action_trigger_flow.py +109 -0
  104. rasa/core/actions/action_trigger_search.py +31 -0
  105. rasa/core/actions/constants.py +5 -0
  106. rasa/core/actions/custom_action_executor.py +191 -0
  107. rasa/core/actions/direct_custom_actions_executor.py +109 -0
  108. rasa/core/actions/e2e_stub_custom_action_executor.py +72 -0
  109. rasa/core/actions/forms.py +741 -0
  110. rasa/core/actions/grpc_custom_action_executor.py +251 -0
  111. rasa/core/actions/http_custom_action_executor.py +145 -0
  112. rasa/core/actions/loops.py +114 -0
  113. rasa/core/actions/two_stage_fallback.py +186 -0
  114. rasa/core/agent.py +559 -0
  115. rasa/core/auth_retry_tracker_store.py +122 -0
  116. rasa/core/brokers/__init__.py +0 -0
  117. rasa/core/brokers/broker.py +126 -0
  118. rasa/core/brokers/file.py +58 -0
  119. rasa/core/brokers/kafka.py +324 -0
  120. rasa/core/brokers/pika.py +388 -0
  121. rasa/core/brokers/sql.py +86 -0
  122. rasa/core/channels/__init__.py +61 -0
  123. rasa/core/channels/botframework.py +338 -0
  124. rasa/core/channels/callback.py +84 -0
  125. rasa/core/channels/channel.py +456 -0
  126. rasa/core/channels/console.py +241 -0
  127. rasa/core/channels/development_inspector.py +197 -0
  128. rasa/core/channels/facebook.py +419 -0
  129. rasa/core/channels/hangouts.py +329 -0
  130. rasa/core/channels/inspector/.eslintrc.cjs +25 -0
  131. rasa/core/channels/inspector/.gitignore +23 -0
  132. rasa/core/channels/inspector/README.md +54 -0
  133. rasa/core/channels/inspector/assets/favicon.ico +0 -0
  134. rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
  135. rasa/core/channels/inspector/custom.d.ts +3 -0
  136. rasa/core/channels/inspector/dist/assets/arc-861ddd57.js +1 -0
  137. rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
  138. rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-921f02db.js +10 -0
  139. rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-b436c4f8.js +2 -0
  140. rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-511a23cb.js +2 -0
  141. rasa/core/channels/inspector/dist/assets/createText-62fc7601-ef476ecd.js +7 -0
  142. rasa/core/channels/inspector/dist/assets/edges-f2ad444c-f1878e0a.js +4 -0
  143. rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-fac75185.js +51 -0
  144. rasa/core/channels/inspector/dist/assets/flowDb-1972c806-201c5bbc.js +6 -0
  145. rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-f904ae41.js +4 -0
  146. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-b080d6f2.js +1 -0
  147. rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-1813da66.js +139 -0
  148. rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-872af172.js +266 -0
  149. rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-34a0af5a.js +70 -0
  150. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
  151. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
  152. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
  153. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
  154. rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-42ba3e3d.js +1 -0
  155. rasa/core/channels/inspector/dist/assets/index-37817b51.js +1317 -0
  156. rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
  157. rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-6b731386.js +7 -0
  158. rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
  159. rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-e8579ac6.js +139 -0
  160. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
  161. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
  162. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
  163. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
  164. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
  165. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
  166. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
  167. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
  168. rasa/core/channels/inspector/dist/assets/layout-89e6403a.js +1 -0
  169. rasa/core/channels/inspector/dist/assets/line-dc73d3fc.js +1 -0
  170. rasa/core/channels/inspector/dist/assets/linear-f5b1d2bc.js +1 -0
  171. rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-82cb74fa.js +109 -0
  172. rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
  173. rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
  174. rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-bdf5f29b.js +35 -0
  175. rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-c7a0cbe4.js +7 -0
  176. rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-7ec5410f.js +52 -0
  177. rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-caee5554.js +8 -0
  178. rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-2935f8db.js +122 -0
  179. rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-8f5d9693.js +1 -0
  180. rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-d565d1de.js +1 -0
  181. rasa/core/channels/inspector/dist/assets/styles-080da4f6-75ad421d.js +110 -0
  182. rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-7e764226.js +159 -0
  183. rasa/core/channels/inspector/dist/assets/styles-9c745c82-7a4e0e61.js +207 -0
  184. rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-4019d1bf.js +1 -0
  185. rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-01ea12df.js +61 -0
  186. rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-89407137.js +7 -0
  187. rasa/core/channels/inspector/dist/index.html +42 -0
  188. rasa/core/channels/inspector/index.html +40 -0
  189. rasa/core/channels/inspector/jest.config.ts +13 -0
  190. rasa/core/channels/inspector/package.json +52 -0
  191. rasa/core/channels/inspector/setupTests.ts +2 -0
  192. rasa/core/channels/inspector/src/App.tsx +220 -0
  193. rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
  194. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +108 -0
  195. rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
  196. rasa/core/channels/inspector/src/components/DialogueStack.tsx +136 -0
  197. rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
  198. rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
  199. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +22 -0
  200. rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
  201. rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
  202. rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
  203. rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
  204. rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
  205. rasa/core/channels/inspector/src/helpers/audiostream.ts +191 -0
  206. rasa/core/channels/inspector/src/helpers/formatters.test.ts +392 -0
  207. rasa/core/channels/inspector/src/helpers/formatters.ts +306 -0
  208. rasa/core/channels/inspector/src/helpers/utils.ts +127 -0
  209. rasa/core/channels/inspector/src/main.tsx +13 -0
  210. rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
  211. rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
  212. rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
  213. rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
  214. rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
  215. rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
  216. rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
  217. rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
  218. rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
  219. rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
  220. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
  221. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
  222. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
  223. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
  224. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
  225. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
  226. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
  227. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
  228. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
  229. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
  230. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
  231. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
  232. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
  233. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
  234. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
  235. rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
  236. rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
  237. rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
  238. rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
  239. rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
  240. rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
  241. rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
  242. rasa/core/channels/inspector/src/theme/index.ts +101 -0
  243. rasa/core/channels/inspector/src/types.ts +84 -0
  244. rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
  245. rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
  246. rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
  247. rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
  248. rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
  249. rasa/core/channels/inspector/tsconfig.json +26 -0
  250. rasa/core/channels/inspector/tsconfig.node.json +10 -0
  251. rasa/core/channels/inspector/vite.config.ts +8 -0
  252. rasa/core/channels/inspector/yarn.lock +6249 -0
  253. rasa/core/channels/mattermost.py +229 -0
  254. rasa/core/channels/rasa_chat.py +126 -0
  255. rasa/core/channels/rest.py +230 -0
  256. rasa/core/channels/rocketchat.py +174 -0
  257. rasa/core/channels/slack.py +620 -0
  258. rasa/core/channels/socketio.py +302 -0
  259. rasa/core/channels/telegram.py +298 -0
  260. rasa/core/channels/twilio.py +169 -0
  261. rasa/core/channels/vier_cvg.py +374 -0
  262. rasa/core/channels/voice_ready/__init__.py +0 -0
  263. rasa/core/channels/voice_ready/audiocodes.py +501 -0
  264. rasa/core/channels/voice_ready/jambonz.py +121 -0
  265. rasa/core/channels/voice_ready/jambonz_protocol.py +396 -0
  266. rasa/core/channels/voice_ready/twilio_voice.py +403 -0
  267. rasa/core/channels/voice_ready/utils.py +37 -0
  268. rasa/core/channels/voice_stream/__init__.py +0 -0
  269. rasa/core/channels/voice_stream/asr/__init__.py +0 -0
  270. rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
  271. rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
  272. rasa/core/channels/voice_stream/asr/azure.py +130 -0
  273. rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
  274. rasa/core/channels/voice_stream/audio_bytes.py +8 -0
  275. rasa/core/channels/voice_stream/browser_audio.py +107 -0
  276. rasa/core/channels/voice_stream/call_state.py +23 -0
  277. rasa/core/channels/voice_stream/tts/__init__.py +0 -0
  278. rasa/core/channels/voice_stream/tts/azure.py +106 -0
  279. rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
  280. rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
  281. rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
  282. rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
  283. rasa/core/channels/voice_stream/util.py +57 -0
  284. rasa/core/channels/voice_stream/voice_channel.py +427 -0
  285. rasa/core/channels/webexteams.py +134 -0
  286. rasa/core/concurrent_lock_store.py +210 -0
  287. rasa/core/constants.py +112 -0
  288. rasa/core/evaluation/__init__.py +0 -0
  289. rasa/core/evaluation/marker.py +267 -0
  290. rasa/core/evaluation/marker_base.py +923 -0
  291. rasa/core/evaluation/marker_stats.py +293 -0
  292. rasa/core/evaluation/marker_tracker_loader.py +103 -0
  293. rasa/core/exceptions.py +29 -0
  294. rasa/core/exporter.py +284 -0
  295. rasa/core/featurizers/__init__.py +0 -0
  296. rasa/core/featurizers/precomputation.py +410 -0
  297. rasa/core/featurizers/single_state_featurizer.py +421 -0
  298. rasa/core/featurizers/tracker_featurizers.py +1262 -0
  299. rasa/core/http_interpreter.py +89 -0
  300. rasa/core/information_retrieval/__init__.py +7 -0
  301. rasa/core/information_retrieval/faiss.py +124 -0
  302. rasa/core/information_retrieval/information_retrieval.py +137 -0
  303. rasa/core/information_retrieval/milvus.py +59 -0
  304. rasa/core/information_retrieval/qdrant.py +96 -0
  305. rasa/core/jobs.py +63 -0
  306. rasa/core/lock.py +139 -0
  307. rasa/core/lock_store.py +343 -0
  308. rasa/core/migrate.py +403 -0
  309. rasa/core/nlg/__init__.py +3 -0
  310. rasa/core/nlg/callback.py +146 -0
  311. rasa/core/nlg/contextual_response_rephraser.py +320 -0
  312. rasa/core/nlg/generator.py +230 -0
  313. rasa/core/nlg/interpolator.py +143 -0
  314. rasa/core/nlg/response.py +155 -0
  315. rasa/core/nlg/summarize.py +70 -0
  316. rasa/core/persistor.py +538 -0
  317. rasa/core/policies/__init__.py +0 -0
  318. rasa/core/policies/ensemble.py +329 -0
  319. rasa/core/policies/enterprise_search_policy.py +905 -0
  320. rasa/core/policies/enterprise_search_prompt_template.jinja2 +25 -0
  321. rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
  322. rasa/core/policies/flow_policy.py +205 -0
  323. rasa/core/policies/flows/__init__.py +0 -0
  324. rasa/core/policies/flows/flow_exceptions.py +44 -0
  325. rasa/core/policies/flows/flow_executor.py +754 -0
  326. rasa/core/policies/flows/flow_step_result.py +43 -0
  327. rasa/core/policies/intentless_policy.py +1031 -0
  328. rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
  329. rasa/core/policies/memoization.py +538 -0
  330. rasa/core/policies/policy.py +725 -0
  331. rasa/core/policies/rule_policy.py +1273 -0
  332. rasa/core/policies/ted_policy.py +2169 -0
  333. rasa/core/policies/unexpected_intent_policy.py +1022 -0
  334. rasa/core/processor.py +1465 -0
  335. rasa/core/run.py +342 -0
  336. rasa/core/secrets_manager/__init__.py +0 -0
  337. rasa/core/secrets_manager/constants.py +36 -0
  338. rasa/core/secrets_manager/endpoints.py +391 -0
  339. rasa/core/secrets_manager/factory.py +241 -0
  340. rasa/core/secrets_manager/secret_manager.py +262 -0
  341. rasa/core/secrets_manager/vault.py +584 -0
  342. rasa/core/test.py +1335 -0
  343. rasa/core/tracker_store.py +1703 -0
  344. rasa/core/train.py +105 -0
  345. rasa/core/training/__init__.py +89 -0
  346. rasa/core/training/converters/__init__.py +0 -0
  347. rasa/core/training/converters/responses_prefix_converter.py +119 -0
  348. rasa/core/training/interactive.py +1744 -0
  349. rasa/core/training/story_conflict.py +381 -0
  350. rasa/core/training/training.py +93 -0
  351. rasa/core/utils.py +366 -0
  352. rasa/core/visualize.py +70 -0
  353. rasa/dialogue_understanding/__init__.py +0 -0
  354. rasa/dialogue_understanding/coexistence/__init__.py +0 -0
  355. rasa/dialogue_understanding/coexistence/constants.py +4 -0
  356. rasa/dialogue_understanding/coexistence/intent_based_router.py +196 -0
  357. rasa/dialogue_understanding/coexistence/llm_based_router.py +327 -0
  358. rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
  359. rasa/dialogue_understanding/commands/__init__.py +61 -0
  360. rasa/dialogue_understanding/commands/can_not_handle_command.py +70 -0
  361. rasa/dialogue_understanding/commands/cancel_flow_command.py +125 -0
  362. rasa/dialogue_understanding/commands/change_flow_command.py +44 -0
  363. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +57 -0
  364. rasa/dialogue_understanding/commands/clarify_command.py +86 -0
  365. rasa/dialogue_understanding/commands/command.py +85 -0
  366. rasa/dialogue_understanding/commands/correct_slots_command.py +297 -0
  367. rasa/dialogue_understanding/commands/error_command.py +79 -0
  368. rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
  369. rasa/dialogue_understanding/commands/handle_code_change_command.py +73 -0
  370. rasa/dialogue_understanding/commands/human_handoff_command.py +66 -0
  371. rasa/dialogue_understanding/commands/knowledge_answer_command.py +57 -0
  372. rasa/dialogue_understanding/commands/noop_command.py +54 -0
  373. rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
  374. rasa/dialogue_understanding/commands/restart_command.py +58 -0
  375. rasa/dialogue_understanding/commands/session_end_command.py +61 -0
  376. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  377. rasa/dialogue_understanding/commands/set_slot_command.py +160 -0
  378. rasa/dialogue_understanding/commands/skip_question_command.py +75 -0
  379. rasa/dialogue_understanding/commands/start_flow_command.py +107 -0
  380. rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
  381. rasa/dialogue_understanding/commands/utils.py +45 -0
  382. rasa/dialogue_understanding/generator/__init__.py +21 -0
  383. rasa/dialogue_understanding/generator/command_generator.py +464 -0
  384. rasa/dialogue_understanding/generator/constants.py +27 -0
  385. rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
  386. rasa/dialogue_understanding/generator/flow_retrieval.py +466 -0
  387. rasa/dialogue_understanding/generator/llm_based_command_generator.py +500 -0
  388. rasa/dialogue_understanding/generator/llm_command_generator.py +67 -0
  389. rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
  390. rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
  391. rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
  392. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +920 -0
  393. rasa/dialogue_understanding/generator/nlu_command_adapter.py +261 -0
  394. rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
  395. rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +60 -0
  396. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +486 -0
  397. rasa/dialogue_understanding/patterns/__init__.py +0 -0
  398. rasa/dialogue_understanding/patterns/cancel.py +111 -0
  399. rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
  400. rasa/dialogue_understanding/patterns/chitchat.py +37 -0
  401. rasa/dialogue_understanding/patterns/clarify.py +97 -0
  402. rasa/dialogue_understanding/patterns/code_change.py +41 -0
  403. rasa/dialogue_understanding/patterns/collect_information.py +90 -0
  404. rasa/dialogue_understanding/patterns/completed.py +40 -0
  405. rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
  406. rasa/dialogue_understanding/patterns/correction.py +278 -0
  407. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +301 -0
  408. rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
  409. rasa/dialogue_understanding/patterns/internal_error.py +47 -0
  410. rasa/dialogue_understanding/patterns/repeat.py +37 -0
  411. rasa/dialogue_understanding/patterns/restart.py +37 -0
  412. rasa/dialogue_understanding/patterns/search.py +37 -0
  413. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  414. rasa/dialogue_understanding/patterns/skip_question.py +38 -0
  415. rasa/dialogue_understanding/patterns/user_silence.py +37 -0
  416. rasa/dialogue_understanding/processor/__init__.py +0 -0
  417. rasa/dialogue_understanding/processor/command_processor.py +720 -0
  418. rasa/dialogue_understanding/processor/command_processor_component.py +43 -0
  419. rasa/dialogue_understanding/stack/__init__.py +0 -0
  420. rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
  421. rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
  422. rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
  423. rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
  424. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
  425. rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
  426. rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
  427. rasa/dialogue_understanding/stack/utils.py +211 -0
  428. rasa/dialogue_understanding/utils.py +14 -0
  429. rasa/dialogue_understanding_test/__init__.py +0 -0
  430. rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
  431. rasa/dialogue_understanding_test/constants.py +17 -0
  432. rasa/dialogue_understanding_test/du_test_case.py +118 -0
  433. rasa/dialogue_understanding_test/du_test_result.py +11 -0
  434. rasa/dialogue_understanding_test/du_test_runner.py +93 -0
  435. rasa/dialogue_understanding_test/io.py +54 -0
  436. rasa/dialogue_understanding_test/validation.py +22 -0
  437. rasa/e2e_test/__init__.py +0 -0
  438. rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
  439. rasa/e2e_test/assertions.py +1345 -0
  440. rasa/e2e_test/assertions_schema.yml +129 -0
  441. rasa/e2e_test/constants.py +31 -0
  442. rasa/e2e_test/e2e_config.py +220 -0
  443. rasa/e2e_test/e2e_config_schema.yml +26 -0
  444. rasa/e2e_test/e2e_test_case.py +569 -0
  445. rasa/e2e_test/e2e_test_converter.py +363 -0
  446. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  447. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  448. rasa/e2e_test/e2e_test_result.py +54 -0
  449. rasa/e2e_test/e2e_test_runner.py +1192 -0
  450. rasa/e2e_test/e2e_test_schema.yml +181 -0
  451. rasa/e2e_test/pykwalify_extensions.py +39 -0
  452. rasa/e2e_test/stub_custom_action.py +70 -0
  453. rasa/e2e_test/utils/__init__.py +0 -0
  454. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  455. rasa/e2e_test/utils/io.py +598 -0
  456. rasa/e2e_test/utils/validation.py +178 -0
  457. rasa/engine/__init__.py +0 -0
  458. rasa/engine/caching.py +463 -0
  459. rasa/engine/constants.py +17 -0
  460. rasa/engine/exceptions.py +14 -0
  461. rasa/engine/graph.py +642 -0
  462. rasa/engine/loader.py +48 -0
  463. rasa/engine/recipes/__init__.py +0 -0
  464. rasa/engine/recipes/config_files/default_config.yml +41 -0
  465. rasa/engine/recipes/default_components.py +97 -0
  466. rasa/engine/recipes/default_recipe.py +1272 -0
  467. rasa/engine/recipes/graph_recipe.py +79 -0
  468. rasa/engine/recipes/recipe.py +93 -0
  469. rasa/engine/runner/__init__.py +0 -0
  470. rasa/engine/runner/dask.py +250 -0
  471. rasa/engine/runner/interface.py +49 -0
  472. rasa/engine/storage/__init__.py +0 -0
  473. rasa/engine/storage/local_model_storage.py +244 -0
  474. rasa/engine/storage/resource.py +110 -0
  475. rasa/engine/storage/storage.py +199 -0
  476. rasa/engine/training/__init__.py +0 -0
  477. rasa/engine/training/components.py +176 -0
  478. rasa/engine/training/fingerprinting.py +64 -0
  479. rasa/engine/training/graph_trainer.py +256 -0
  480. rasa/engine/training/hooks.py +164 -0
  481. rasa/engine/validation.py +1451 -0
  482. rasa/env.py +14 -0
  483. rasa/exceptions.py +69 -0
  484. rasa/graph_components/__init__.py +0 -0
  485. rasa/graph_components/converters/__init__.py +0 -0
  486. rasa/graph_components/converters/nlu_message_converter.py +48 -0
  487. rasa/graph_components/providers/__init__.py +0 -0
  488. rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
  489. rasa/graph_components/providers/domain_provider.py +71 -0
  490. rasa/graph_components/providers/flows_provider.py +74 -0
  491. rasa/graph_components/providers/forms_provider.py +44 -0
  492. rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
  493. rasa/graph_components/providers/responses_provider.py +44 -0
  494. rasa/graph_components/providers/rule_only_provider.py +49 -0
  495. rasa/graph_components/providers/story_graph_provider.py +96 -0
  496. rasa/graph_components/providers/training_tracker_provider.py +55 -0
  497. rasa/graph_components/validators/__init__.py +0 -0
  498. rasa/graph_components/validators/default_recipe_validator.py +550 -0
  499. rasa/graph_components/validators/finetuning_validator.py +302 -0
  500. rasa/hooks.py +111 -0
  501. rasa/jupyter.py +63 -0
  502. rasa/llm_fine_tuning/__init__.py +0 -0
  503. rasa/llm_fine_tuning/annotation_module.py +241 -0
  504. rasa/llm_fine_tuning/conversations.py +144 -0
  505. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  506. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  507. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  508. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  509. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  510. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  511. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  512. rasa/llm_fine_tuning/storage.py +174 -0
  513. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  514. rasa/markers/__init__.py +0 -0
  515. rasa/markers/marker.py +269 -0
  516. rasa/markers/marker_base.py +828 -0
  517. rasa/markers/upload.py +74 -0
  518. rasa/markers/validate.py +21 -0
  519. rasa/model.py +118 -0
  520. rasa/model_manager/__init__.py +0 -0
  521. rasa/model_manager/config.py +40 -0
  522. rasa/model_manager/model_api.py +559 -0
  523. rasa/model_manager/runner_service.py +286 -0
  524. rasa/model_manager/socket_bridge.py +146 -0
  525. rasa/model_manager/studio_jwt_auth.py +86 -0
  526. rasa/model_manager/trainer_service.py +325 -0
  527. rasa/model_manager/utils.py +87 -0
  528. rasa/model_manager/warm_rasa_process.py +187 -0
  529. rasa/model_service.py +112 -0
  530. rasa/model_testing.py +457 -0
  531. rasa/model_training.py +596 -0
  532. rasa/nlu/__init__.py +7 -0
  533. rasa/nlu/classifiers/__init__.py +3 -0
  534. rasa/nlu/classifiers/classifier.py +5 -0
  535. rasa/nlu/classifiers/diet_classifier.py +1881 -0
  536. rasa/nlu/classifiers/fallback_classifier.py +192 -0
  537. rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
  538. rasa/nlu/classifiers/logistic_regression_classifier.py +253 -0
  539. rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
  540. rasa/nlu/classifiers/regex_message_handler.py +56 -0
  541. rasa/nlu/classifiers/sklearn_intent_classifier.py +330 -0
  542. rasa/nlu/constants.py +77 -0
  543. rasa/nlu/convert.py +40 -0
  544. rasa/nlu/emulators/__init__.py +0 -0
  545. rasa/nlu/emulators/dialogflow.py +55 -0
  546. rasa/nlu/emulators/emulator.py +49 -0
  547. rasa/nlu/emulators/luis.py +86 -0
  548. rasa/nlu/emulators/no_emulator.py +10 -0
  549. rasa/nlu/emulators/wit.py +56 -0
  550. rasa/nlu/extractors/__init__.py +0 -0
  551. rasa/nlu/extractors/crf_entity_extractor.py +715 -0
  552. rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
  553. rasa/nlu/extractors/entity_synonyms.py +178 -0
  554. rasa/nlu/extractors/extractor.py +470 -0
  555. rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
  556. rasa/nlu/extractors/regex_entity_extractor.py +220 -0
  557. rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
  558. rasa/nlu/featurizers/__init__.py +0 -0
  559. rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
  560. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +445 -0
  561. rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
  562. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +768 -0
  563. rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
  564. rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
  565. rasa/nlu/featurizers/featurizer.py +89 -0
  566. rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
  567. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +867 -0
  568. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +571 -0
  569. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +271 -0
  570. rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
  571. rasa/nlu/model.py +24 -0
  572. rasa/nlu/run.py +27 -0
  573. rasa/nlu/selectors/__init__.py +0 -0
  574. rasa/nlu/selectors/response_selector.py +987 -0
  575. rasa/nlu/test.py +1940 -0
  576. rasa/nlu/tokenizers/__init__.py +0 -0
  577. rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
  578. rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
  579. rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
  580. rasa/nlu/tokenizers/tokenizer.py +239 -0
  581. rasa/nlu/tokenizers/whitespace_tokenizer.py +95 -0
  582. rasa/nlu/utils/__init__.py +35 -0
  583. rasa/nlu/utils/bilou_utils.py +462 -0
  584. rasa/nlu/utils/hugging_face/__init__.py +0 -0
  585. rasa/nlu/utils/hugging_face/registry.py +108 -0
  586. rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
  587. rasa/nlu/utils/mitie_utils.py +113 -0
  588. rasa/nlu/utils/pattern_utils.py +168 -0
  589. rasa/nlu/utils/spacy_utils.py +310 -0
  590. rasa/plugin.py +90 -0
  591. rasa/server.py +1588 -0
  592. rasa/shared/__init__.py +0 -0
  593. rasa/shared/constants.py +311 -0
  594. rasa/shared/core/__init__.py +0 -0
  595. rasa/shared/core/command_payload_reader.py +109 -0
  596. rasa/shared/core/constants.py +180 -0
  597. rasa/shared/core/conversation.py +46 -0
  598. rasa/shared/core/domain.py +2172 -0
  599. rasa/shared/core/events.py +2559 -0
  600. rasa/shared/core/flows/__init__.py +7 -0
  601. rasa/shared/core/flows/flow.py +562 -0
  602. rasa/shared/core/flows/flow_path.py +84 -0
  603. rasa/shared/core/flows/flow_step.py +146 -0
  604. rasa/shared/core/flows/flow_step_links.py +319 -0
  605. rasa/shared/core/flows/flow_step_sequence.py +70 -0
  606. rasa/shared/core/flows/flows_list.py +258 -0
  607. rasa/shared/core/flows/flows_yaml_schema.json +303 -0
  608. rasa/shared/core/flows/nlu_trigger.py +117 -0
  609. rasa/shared/core/flows/steps/__init__.py +24 -0
  610. rasa/shared/core/flows/steps/action.py +56 -0
  611. rasa/shared/core/flows/steps/call.py +64 -0
  612. rasa/shared/core/flows/steps/collect.py +112 -0
  613. rasa/shared/core/flows/steps/constants.py +5 -0
  614. rasa/shared/core/flows/steps/continuation.py +36 -0
  615. rasa/shared/core/flows/steps/end.py +22 -0
  616. rasa/shared/core/flows/steps/internal.py +44 -0
  617. rasa/shared/core/flows/steps/link.py +51 -0
  618. rasa/shared/core/flows/steps/no_operation.py +48 -0
  619. rasa/shared/core/flows/steps/set_slots.py +50 -0
  620. rasa/shared/core/flows/steps/start.py +30 -0
  621. rasa/shared/core/flows/utils.py +39 -0
  622. rasa/shared/core/flows/validation.py +735 -0
  623. rasa/shared/core/flows/yaml_flows_io.py +405 -0
  624. rasa/shared/core/generator.py +908 -0
  625. rasa/shared/core/slot_mappings.py +526 -0
  626. rasa/shared/core/slots.py +654 -0
  627. rasa/shared/core/trackers.py +1183 -0
  628. rasa/shared/core/training_data/__init__.py +0 -0
  629. rasa/shared/core/training_data/loading.py +89 -0
  630. rasa/shared/core/training_data/story_reader/__init__.py +0 -0
  631. rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
  632. rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
  633. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
  634. rasa/shared/core/training_data/story_writer/__init__.py +0 -0
  635. rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
  636. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +444 -0
  637. rasa/shared/core/training_data/structures.py +858 -0
  638. rasa/shared/core/training_data/visualization.html +146 -0
  639. rasa/shared/core/training_data/visualization.py +603 -0
  640. rasa/shared/data.py +249 -0
  641. rasa/shared/engine/__init__.py +0 -0
  642. rasa/shared/engine/caching.py +26 -0
  643. rasa/shared/exceptions.py +167 -0
  644. rasa/shared/importers/__init__.py +0 -0
  645. rasa/shared/importers/importer.py +770 -0
  646. rasa/shared/importers/multi_project.py +215 -0
  647. rasa/shared/importers/rasa.py +108 -0
  648. rasa/shared/importers/remote_importer.py +196 -0
  649. rasa/shared/importers/utils.py +36 -0
  650. rasa/shared/nlu/__init__.py +0 -0
  651. rasa/shared/nlu/constants.py +53 -0
  652. rasa/shared/nlu/interpreter.py +10 -0
  653. rasa/shared/nlu/training_data/__init__.py +0 -0
  654. rasa/shared/nlu/training_data/entities_parser.py +208 -0
  655. rasa/shared/nlu/training_data/features.py +492 -0
  656. rasa/shared/nlu/training_data/formats/__init__.py +10 -0
  657. rasa/shared/nlu/training_data/formats/dialogflow.py +163 -0
  658. rasa/shared/nlu/training_data/formats/luis.py +87 -0
  659. rasa/shared/nlu/training_data/formats/rasa.py +135 -0
  660. rasa/shared/nlu/training_data/formats/rasa_yaml.py +618 -0
  661. rasa/shared/nlu/training_data/formats/readerwriter.py +244 -0
  662. rasa/shared/nlu/training_data/formats/wit.py +52 -0
  663. rasa/shared/nlu/training_data/loading.py +137 -0
  664. rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
  665. rasa/shared/nlu/training_data/message.py +490 -0
  666. rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
  667. rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
  668. rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
  669. rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
  670. rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
  671. rasa/shared/nlu/training_data/training_data.py +729 -0
  672. rasa/shared/nlu/training_data/util.py +223 -0
  673. rasa/shared/providers/__init__.py +0 -0
  674. rasa/shared/providers/_configs/__init__.py +0 -0
  675. rasa/shared/providers/_configs/azure_openai_client_config.py +677 -0
  676. rasa/shared/providers/_configs/client_config.py +59 -0
  677. rasa/shared/providers/_configs/default_litellm_client_config.py +132 -0
  678. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +236 -0
  679. rasa/shared/providers/_configs/litellm_router_client_config.py +222 -0
  680. rasa/shared/providers/_configs/model_group_config.py +173 -0
  681. rasa/shared/providers/_configs/openai_client_config.py +177 -0
  682. rasa/shared/providers/_configs/rasa_llm_client_config.py +75 -0
  683. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +178 -0
  684. rasa/shared/providers/_configs/utils.py +117 -0
  685. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  686. rasa/shared/providers/_utils.py +79 -0
  687. rasa/shared/providers/constants.py +7 -0
  688. rasa/shared/providers/embedding/__init__.py +0 -0
  689. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +243 -0
  690. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  691. rasa/shared/providers/embedding/azure_openai_embedding_client.py +335 -0
  692. rasa/shared/providers/embedding/default_litellm_embedding_client.py +126 -0
  693. rasa/shared/providers/embedding/embedding_client.py +90 -0
  694. rasa/shared/providers/embedding/embedding_response.py +41 -0
  695. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  696. rasa/shared/providers/embedding/litellm_router_embedding_client.py +138 -0
  697. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  698. rasa/shared/providers/llm/__init__.py +0 -0
  699. rasa/shared/providers/llm/_base_litellm_client.py +265 -0
  700. rasa/shared/providers/llm/azure_openai_llm_client.py +415 -0
  701. rasa/shared/providers/llm/default_litellm_llm_client.py +110 -0
  702. rasa/shared/providers/llm/litellm_router_llm_client.py +202 -0
  703. rasa/shared/providers/llm/llm_client.py +78 -0
  704. rasa/shared/providers/llm/llm_response.py +50 -0
  705. rasa/shared/providers/llm/openai_llm_client.py +161 -0
  706. rasa/shared/providers/llm/rasa_llm_client.py +120 -0
  707. rasa/shared/providers/llm/self_hosted_llm_client.py +276 -0
  708. rasa/shared/providers/mappings.py +94 -0
  709. rasa/shared/providers/router/__init__.py +0 -0
  710. rasa/shared/providers/router/_base_litellm_router_client.py +185 -0
  711. rasa/shared/providers/router/router_client.py +75 -0
  712. rasa/shared/utils/__init__.py +0 -0
  713. rasa/shared/utils/cli.py +102 -0
  714. rasa/shared/utils/common.py +324 -0
  715. rasa/shared/utils/constants.py +4 -0
  716. rasa/shared/utils/health_check/__init__.py +0 -0
  717. rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
  718. rasa/shared/utils/health_check/health_check.py +258 -0
  719. rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
  720. rasa/shared/utils/io.py +499 -0
  721. rasa/shared/utils/llm.py +764 -0
  722. rasa/shared/utils/pykwalify_extensions.py +27 -0
  723. rasa/shared/utils/schemas/__init__.py +0 -0
  724. rasa/shared/utils/schemas/config.yml +2 -0
  725. rasa/shared/utils/schemas/domain.yml +145 -0
  726. rasa/shared/utils/schemas/events.py +214 -0
  727. rasa/shared/utils/schemas/model_config.yml +36 -0
  728. rasa/shared/utils/schemas/stories.yml +173 -0
  729. rasa/shared/utils/yaml.py +1068 -0
  730. rasa/studio/__init__.py +0 -0
  731. rasa/studio/auth.py +270 -0
  732. rasa/studio/config.py +136 -0
  733. rasa/studio/constants.py +19 -0
  734. rasa/studio/data_handler.py +368 -0
  735. rasa/studio/download.py +489 -0
  736. rasa/studio/results_logger.py +137 -0
  737. rasa/studio/train.py +134 -0
  738. rasa/studio/upload.py +563 -0
  739. rasa/telemetry.py +1876 -0
  740. rasa/tracing/__init__.py +0 -0
  741. rasa/tracing/config.py +355 -0
  742. rasa/tracing/constants.py +62 -0
  743. rasa/tracing/instrumentation/__init__.py +0 -0
  744. rasa/tracing/instrumentation/attribute_extractors.py +765 -0
  745. rasa/tracing/instrumentation/instrumentation.py +1306 -0
  746. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +144 -0
  747. rasa/tracing/instrumentation/metrics.py +294 -0
  748. rasa/tracing/metric_instrument_provider.py +205 -0
  749. rasa/utils/__init__.py +0 -0
  750. rasa/utils/beta.py +83 -0
  751. rasa/utils/cli.py +28 -0
  752. rasa/utils/common.py +639 -0
  753. rasa/utils/converter.py +53 -0
  754. rasa/utils/endpoints.py +331 -0
  755. rasa/utils/io.py +252 -0
  756. rasa/utils/json_utils.py +60 -0
  757. rasa/utils/licensing.py +542 -0
  758. rasa/utils/log_utils.py +181 -0
  759. rasa/utils/mapper.py +210 -0
  760. rasa/utils/ml_utils.py +147 -0
  761. rasa/utils/plotting.py +362 -0
  762. rasa/utils/sanic_error_handler.py +32 -0
  763. rasa/utils/singleton.py +23 -0
  764. rasa/utils/tensorflow/__init__.py +0 -0
  765. rasa/utils/tensorflow/callback.py +112 -0
  766. rasa/utils/tensorflow/constants.py +116 -0
  767. rasa/utils/tensorflow/crf.py +492 -0
  768. rasa/utils/tensorflow/data_generator.py +440 -0
  769. rasa/utils/tensorflow/environment.py +161 -0
  770. rasa/utils/tensorflow/exceptions.py +5 -0
  771. rasa/utils/tensorflow/feature_array.py +366 -0
  772. rasa/utils/tensorflow/layers.py +1565 -0
  773. rasa/utils/tensorflow/layers_utils.py +113 -0
  774. rasa/utils/tensorflow/metrics.py +281 -0
  775. rasa/utils/tensorflow/model_data.py +798 -0
  776. rasa/utils/tensorflow/model_data_utils.py +499 -0
  777. rasa/utils/tensorflow/models.py +935 -0
  778. rasa/utils/tensorflow/rasa_layers.py +1094 -0
  779. rasa/utils/tensorflow/transformer.py +640 -0
  780. rasa/utils/tensorflow/types.py +6 -0
  781. rasa/utils/train_utils.py +572 -0
  782. rasa/utils/url_tools.py +53 -0
  783. rasa/utils/yaml.py +54 -0
  784. rasa/validator.py +1644 -0
  785. rasa/version.py +3 -0
  786. rasa_pro-3.12.0.dev1.dist-info/METADATA +199 -0
  787. rasa_pro-3.12.0.dev1.dist-info/NOTICE +5 -0
  788. rasa_pro-3.12.0.dev1.dist-info/RECORD +790 -0
  789. rasa_pro-3.12.0.dev1.dist-info/WHEEL +4 -0
  790. rasa_pro-3.12.0.dev1.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,1022 @@
1
+ import dataclasses
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import Any, List, Optional, Text, Dict, Type, Union
5
+
6
+ import numpy as np
7
+ import tensorflow as tf
8
+
9
+ import rasa.utils.common
10
+ from rasa.engine.graph import ExecutionContext
11
+ from rasa.engine.recipes.default_recipe import DefaultV1Recipe
12
+ from rasa.engine.storage.resource import Resource
13
+ from rasa.engine.storage.storage import ModelStorage
14
+ from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
15
+ from rasa.shared.nlu.training_data.features import Features
16
+ from rasa.shared.core.domain import Domain
17
+ from rasa.shared.core.trackers import DialogueStateTracker
18
+ from rasa.shared.core.constants import SLOTS, ACTIVE_LOOP, ACTION_UNLIKELY_INTENT_NAME
19
+ from rasa.shared.core.events import UserUttered, ActionExecuted
20
+ import rasa.shared.utils.io
21
+ from rasa.shared.nlu.constants import (
22
+ INTENT,
23
+ TEXT,
24
+ ENTITIES,
25
+ ACTION_NAME,
26
+ SPLIT_ENTITIES_BY_COMMA,
27
+ SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
28
+ )
29
+ from rasa.nlu.extractors.extractor import EntityTagSpec
30
+ from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
31
+ from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
32
+ from rasa.core.featurizers.tracker_featurizers import IntentMaxHistoryTrackerFeaturizer
33
+ from rasa.core.featurizers.single_state_featurizer import (
34
+ IntentTokenizerSingleStateFeaturizer,
35
+ )
36
+ from rasa.shared.core.generator import TrackerWithCachedStates
37
+ from rasa.core.constants import (
38
+ DIALOGUE,
39
+ POLICY_MAX_HISTORY,
40
+ POLICY_PRIORITY,
41
+ UNLIKELY_INTENT_POLICY_PRIORITY,
42
+ )
43
+ from rasa.core.policies.policy import PolicyPrediction
44
+ from rasa.core.policies.ted_policy import (
45
+ LABEL_KEY,
46
+ LABEL_SUB_KEY,
47
+ TEDPolicy,
48
+ TED,
49
+ SEQUENCE_LENGTH,
50
+ SEQUENCE,
51
+ PREDICTION_FEATURES,
52
+ )
53
+ from rasa.utils import train_utils
54
+ from rasa.utils.tensorflow.models import RasaModel
55
+ from rasa.utils.tensorflow.constants import (
56
+ LABEL,
57
+ DENSE_DIMENSION,
58
+ ENCODING_DIMENSION,
59
+ UNIDIRECTIONAL_ENCODER,
60
+ TRANSFORMER_SIZE,
61
+ NUM_TRANSFORMER_LAYERS,
62
+ NUM_HEADS,
63
+ BATCH_SIZES,
64
+ BATCH_STRATEGY,
65
+ EPOCHS,
66
+ RANDOM_SEED,
67
+ RANKING_LENGTH,
68
+ LOSS_TYPE,
69
+ SIMILARITY_TYPE,
70
+ NUM_NEG,
71
+ EVAL_NUM_EXAMPLES,
72
+ EVAL_NUM_EPOCHS,
73
+ REGULARIZATION_CONSTANT,
74
+ SCALE_LOSS,
75
+ EMBEDDING_DIMENSION,
76
+ DROP_RATE_DIALOGUE,
77
+ DROP_RATE_LABEL,
78
+ DROP_RATE,
79
+ DROP_RATE_ATTENTION,
80
+ CONNECTION_DENSITY,
81
+ KEY_RELATIVE_ATTENTION,
82
+ VALUE_RELATIVE_ATTENTION,
83
+ MAX_RELATIVE_POSITION,
84
+ INNER,
85
+ BALANCED,
86
+ TENSORBOARD_LOG_DIR,
87
+ TENSORBOARD_LOG_LEVEL,
88
+ CHECKPOINT_MODEL,
89
+ FEATURIZERS,
90
+ ENTITY_RECOGNITION,
91
+ IGNORE_INTENTS_LIST,
92
+ BILOU_FLAG,
93
+ LEARNING_RATE,
94
+ CROSS_ENTROPY,
95
+ SPARSE_INPUT_DROPOUT,
96
+ DENSE_INPUT_DROPOUT,
97
+ MASKED_LM,
98
+ HIDDEN_LAYERS_SIZES,
99
+ CONCAT_DIMENSION,
100
+ TOLERANCE,
101
+ LABEL_PAD_ID,
102
+ POSITIVE_SCORES_KEY,
103
+ NEGATIVE_SCORES_KEY,
104
+ USE_GPU,
105
+ )
106
+ from rasa.utils.tensorflow import layers
107
+ from rasa.utils.tensorflow.model_data import RasaModelData, FeatureArray, Data
108
+ from rasa.core.exceptions import RasaCoreException
109
+ from rasa.shared.utils import common
110
+
111
+
112
@dataclasses.dataclass
class RankingCandidateMetadata:
    """Dataclass to represent metadata for a candidate intent."""

    # Name of the candidate intent.
    name: Text
    # Score recorded for this candidate.
    score: float
    # Threshold associated with this candidate; may be `None`.
    threshold: Optional[float]
    # Severity value for this candidate; may be `None`.
    # NOTE(review): how severity is derived is not visible here — confirm.
    severity: Optional[float]
120
+
121
+
122
@dataclasses.dataclass
class UnexpecTEDIntentPolicyMetadata:
    """Dataclass to represent policy metadata.

    Bundles the metadata of the queried intent together with a list of
    metadata entries for ranked candidate intents.
    """

    # Metadata of the intent that was queried.
    query_intent: RankingCandidateMetadata
    # Metadata entries of the ranked candidate intents.
    ranking: List[RankingCandidateMetadata]
128
+
129
+
130
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
131
+
132
+
133
+ @DefaultV1Recipe.register(
134
+ DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
135
+ )
136
+ class UnexpecTEDIntentPolicy(TEDPolicy):
137
+ """`UnexpecTEDIntentPolicy` has the same model architecture as `TEDPolicy`.
138
+
139
+ The difference is at a task level.
140
+ Instead of predicting the next probable action, this policy
141
+ predicts whether the last predicted intent is a likely intent
142
+ according to the training stories and conversation context.
143
+ """
144
+
145
@staticmethod
def get_default_config() -> Dict[Text, Any]:
    """Returns the default config (see parent class for full docstring).

    Returns:
        A freshly built dictionary mapping configuration keys to their
        default values for this policy.
    """
    return {
        # ## Architecture of the used neural network
        # Hidden layer sizes for layers before the embedding layers for user message
        # and labels.
        # The number of hidden layers is equal to the length
        # of the corresponding list.
        HIDDEN_LAYERS_SIZES: {TEXT: []},
        # Dense dimension to use for sparse features.
        DENSE_DIMENSION: {
            TEXT: 128,
            INTENT: 20,
            ACTION_NAME: 20,
            ENTITIES: 20,
            SLOTS: 20,
            ACTIVE_LOOP: 20,
            f"{LABEL}_{INTENT}": 20,
        },
        # Default dimension to use for concatenating sequence and sentence features.
        CONCAT_DIMENSION: {TEXT: 128},
        # Dimension size of embedding vectors before
        # the dialogue transformer encoder.
        ENCODING_DIMENSION: 50,
        # Number of units in transformer encoders
        TRANSFORMER_SIZE: {TEXT: 128, DIALOGUE: 128},
        # Number of layers in transformer encoders
        NUM_TRANSFORMER_LAYERS: {TEXT: 1, DIALOGUE: 1},
        # Number of attention heads in transformer
        NUM_HEADS: 4,
        # If 'True' use key relative embeddings in attention
        KEY_RELATIVE_ATTENTION: False,
        # If 'True' use value relative embeddings in attention
        VALUE_RELATIVE_ATTENTION: False,
        # Max position for relative embeddings. Only in effect
        # if key- or value relative attention are turned on
        MAX_RELATIVE_POSITION: 5,
        # Use a unidirectional or bidirectional encoder
        # for `text`, `action_text`, and `label_action_text`.
        UNIDIRECTIONAL_ENCODER: False,
        # ## Training parameters
        # Initial and final batch sizes:
        # Batch size will be linearly increased for each epoch.
        BATCH_SIZES: [64, 256],
        # Strategy used when creating batches.
        # Can be either 'sequence' or 'balanced'.
        BATCH_STRATEGY: BALANCED,
        # Number of epochs to train
        EPOCHS: 1,
        # Set random seed to any 'int' to get reproducible results
        RANDOM_SEED: None,
        # Initial learning rate for the optimizer
        LEARNING_RATE: 0.001,
        # ## Parameters for embeddings
        # Dimension size of embedding vectors
        EMBEDDING_DIMENSION: 20,
        # The number of incorrect labels. The algorithm will minimize
        # their similarity to the user input during training.
        NUM_NEG: 20,
        # Number of intents to store in ranking key of predicted action metadata.
        # Set this to `0` to include all intents.
        RANKING_LENGTH: LABEL_RANKING_LENGTH,
        # If 'True' scale loss inverse proportionally to the confidence
        # of the correct prediction
        SCALE_LOSS: True,
        # ## Regularization parameters
        # The scale of regularization
        REGULARIZATION_CONSTANT: 0.001,
        # Dropout rate for embedding layers of dialogue features.
        DROP_RATE_DIALOGUE: 0.1,
        # Dropout rate for embedding layers of utterance level features.
        DROP_RATE: 0.0,
        # Dropout rate for embedding layers of label, e.g. action, features.
        DROP_RATE_LABEL: 0.0,
        # Dropout rate for attention.
        DROP_RATE_ATTENTION: 0.0,
        # Fraction of trainable weights in internal layers.
        CONNECTION_DENSITY: 0.2,
        # If 'True' apply dropout to sparse input tensors
        SPARSE_INPUT_DROPOUT: True,
        # If 'True' apply dropout to dense input tensors
        DENSE_INPUT_DROPOUT: True,
        # If 'True' random tokens of the input message will be masked.
        # Since there is no related loss term used inside TED, the masking
        # effectively becomes just input dropout applied to the text of user
        # utterances.
        MASKED_LM: False,
        # ## Evaluation parameters
        # How often calculate validation accuracy.
        # Small values may hurt performance, e.g. model accuracy.
        EVAL_NUM_EPOCHS: 20,
        # How many examples to use for hold out validation set
        # Large values may hurt performance, e.g. model accuracy.
        EVAL_NUM_EXAMPLES: 0,
        # If you want to use tensorboard to visualize training and validation
        # metrics, set this option to a valid output directory.
        TENSORBOARD_LOG_DIR: None,
        # Define when training metrics for tensorboard should be logged.
        # Either after every epoch or for every training step.
        # Valid values: 'epoch' and 'batch'
        TENSORBOARD_LOG_LEVEL: "epoch",
        # Perform model checkpointing
        CHECKPOINT_MODEL: False,
        # Specify what features to use as sequence and sentence features.
        # By default all features in the pipeline are used.
        FEATURIZERS: [],
        # List of intents to ignore for `action_unlikely_intent` prediction.
        IGNORE_INTENTS_LIST: [],
        # Tolerance for prediction of `action_unlikely_intent`.
        # For each intent, the tolerance is the percentage of
        # negative training instances (trackers for which
        # the corresponding intent is not the correct label) that
        # would be ignored by `UnexpecTEDIntentPolicy`. This is converted
        # into a similarity threshold by identifying the similarity
        # score for the (1 - tolerance) percentile of negative
        # examples. Any tracker with a similarity score below this
        # threshold will trigger an `action_unlikely_intent`.
        # Higher values of `tolerance` means the policy is more
        # "tolerant" to surprising paths in conversations and
        # hence will result in lesser number of `action_unlikely_intent`
        # triggers. Acceptable values are between 0.0 and 1.0 (inclusive).
        TOLERANCE: 0.0,
        # Split entities by comma, this makes sense e.g. for a list of
        # ingredients in a recipe, but it doesn't make sense for the parts of
        # an address
        SPLIT_ENTITIES_BY_COMMA: SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
        # Type of similarity measure to use, either 'auto' or 'cosine' or 'inner'.
        # NOTE: `__init__` overwrites this key with `INNER` unconditionally.
        SIMILARITY_TYPE: INNER,
        # If set to true, entities are predicted in user utterances.
        # NOTE: `__init__` overwrites this key with `False` unconditionally.
        ENTITY_RECOGNITION: False,
        # 'BILOU_flag' determines whether to use BILOU tagging or not.
        # If set to 'True' labelling is more rigorous, however more
        # examples per entity are required.
        # Rule of thumb: you should have more than 100 examples per entity.
        # NOTE: `__init__` overwrites this key with `False` unconditionally.
        BILOU_FLAG: False,
        # The type of the loss function, either 'cross_entropy' or 'margin'.
        # NOTE: `__init__` overwrites this key with `CROSS_ENTROPY` unconditionally.
        LOSS_TYPE: CROSS_ENTROPY,
        # Determines the importance of policies, higher values take precedence
        POLICY_PRIORITY: UNLIKELY_INTENT_POLICY_PRIORITY,
        # NOTE(review): presumably toggles GPU usage for the model — confirm
        # against the code that reads `USE_GPU`.
        USE_GPU: True,
    }
287
+
288
def __init__(
    self,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
    model: Optional[RasaModel] = None,
    featurizer: Optional[TrackerFeaturizer] = None,
    fake_features: Optional[Dict[Text, List[Features]]] = None,
    entity_tag_specs: Optional[List[EntityTagSpec]] = None,
    label_quantiles: Optional[Dict[int, List[float]]] = None,
):
    """Sets up the policy and derives label thresholds from given quantiles.

    Args:
        config: Policy configuration. It is modified in place: the
            non-configurable options are forced to fixed values.
        model_storage: Forwarded to the parent constructor.
        resource: Forwarded to the parent constructor.
        execution_context: Forwarded to the parent constructor.
        model: Forwarded to the parent constructor.
        featurizer: Forwarded to the parent constructor.
        fake_features: Forwarded to the parent constructor.
        entity_tag_specs: Forwarded to the parent constructor.
        label_quantiles: Quantile scores per label id; an empty mapping is
            used when nothing (or something falsy) is passed.
    """
    # Options that must not be changed by users of this policy. They are
    # written onto the caller's dict, in this order.
    forced_options = {
        ENTITY_RECOGNITION: False,
        BILOU_FLAG: False,
        SIMILARITY_TYPE: INNER,
        LOSS_TYPE: CROSS_ENTROPY,
    }
    for option_name, option_value in forced_options.items():
        config[option_name] = option_value
    self.config = config

    super().__init__(
        self.config,
        model_storage,
        resource,
        execution_context,
        model,
        featurizer,
        fake_features,
        entity_tag_specs,
    )

    self.label_quantiles = label_quantiles or {}
    # Thresholds can only be picked when quantiles are available.
    if self.label_quantiles:
        self.label_thresholds = self._pick_thresholds(
            self.label_quantiles, self.config[TOLERANCE]
        )
    else:
        self.label_thresholds = {}
    self.ignore_intent_list = self.config[IGNORE_INTENTS_LIST]

    common.mark_as_experimental_feature("UnexpecTED Intent Policy")
328
+
329
def _standard_featurizer(self) -> IntentMaxHistoryTrackerFeaturizer:
    """Builds the default tracker featurizer for this policy.

    Returns:
        An `IntentMaxHistoryTrackerFeaturizer` wrapping an
        `IntentTokenizerSingleStateFeaturizer`, with `max_history` read
        from the policy config (`None` if the key is absent).
    """
    state_featurizer = IntentTokenizerSingleStateFeaturizer()
    history_limit = self.config.get(POLICY_MAX_HISTORY)
    return IntentMaxHistoryTrackerFeaturizer(
        state_featurizer, max_history=history_limit
    )
334
+
335
@staticmethod
def model_class() -> Type["IntentTED"]:
    """Gets the class of the model architecture to be used by the policy.

    Returns:
        Required class (`IntentTED`); the class itself, not an instance.
    """
    return IntentTED
343
+
344
def _auto_update_configuration(self) -> None:
    """Replaces `self.config` with its evaluation-parameter-updated version.

    The current config is passed through
    `train_utils.update_evaluation_parameters` and the result is stored
    back on the instance.
    """
    self.config = train_utils.update_evaluation_parameters(self.config)
346
+
347
@classmethod
def _metadata_filename(cls) -> Optional[Text]:
    """Returns the fixed name `"unexpected_intent_policy"`.

    NOTE(review): presumably used as the base name for persisted policy
    metadata — confirm against the parent class.
    """
    return "unexpected_intent_policy"
350
+
351
def _assemble_label_data(
    self, attribute_data: Data, domain: Domain
) -> RasaModelData:
    """Builds the label-related model data fed to the model.

    The returned data holds the intent label features under the
    `label_intent` key (including sequence lengths) and the numerical
    label ids under `label`.

    Args:
        attribute_data: Feature data for all intent labels.
        domain: Domain of the assistant.

    Returns:
        Features of labels ready to be fed to the model.
    """
    intent_label_key = f"{LABEL}_{INTENT}"

    assembled = RasaModelData()
    assembled.add_data(attribute_data, key_prefix=f"{LABEL_KEY}_")
    assembled.add_lengths(
        intent_label_key, SEQUENCE_LENGTH, intent_label_key, SEQUENCE
    )

    # One id per intent of the domain, turned into a column by adding a
    # trailing axis.
    intent_ids = np.arange(len(domain.intents))
    id_column = FeatureArray(
        np.expand_dims(intent_ids, -1),
        number_of_dimensions=2,
    )
    assembled.add_features(LABEL_KEY, LABEL_SUB_KEY, [id_column])
    return assembled
384
+
385
@staticmethod
def _prepare_data_for_prediction(model_data: RasaModelData) -> RasaModelData:
    """Reduces training model data to what is usable at prediction time.

    Only entries whose key is listed in `PREDICTION_FEATURES` are kept.
    Other attributes are not part of the prediction signature, so passing
    them through would break prediction.

    Args:
        model_data: Data used during model training.

    Returns:
        Transformed data usable for making predictions.
    """
    prediction_only: Dict[Text, Dict[Text, Any]] = {}
    for attribute, attribute_features in model_data.data.items():
        if attribute not in PREDICTION_FEATURES:
            continue
        prediction_only[attribute] = attribute_features
    return RasaModelData(data=prediction_only)
406
+
407
def compute_label_quantiles_post_training(
    self, model_data: RasaModelData, label_ids: np.ndarray
) -> None:
    """Computes quantile scores for prediction of `action_unlikely_intent`.

    Several quantiles are stored per label so that a threshold matching
    the configured `tolerance` can be picked at inference time. The result
    is stored in `self.label_quantiles`.

    Args:
        model_data: Data used for training the model.
        label_ids: Numerical IDs of labels for each data point used during training.
    """
    # Training-only attributes (e.g. `label`) are not part of the
    # `predict_data_signature`; strip them first, otherwise running the
    # data through the model would break.
    inference_input = self._prepare_data_for_prediction(model_data)

    if self.model is None:
        scores = {}
    else:
        scores = self.model.run_bulk_inference(inference_input)

    scores_by_label_id = self._collect_label_id_grouped_scores(scores, label_ids)

    # The per-label quantiles are looked up during inference to select a
    # threshold according to the `tolerance` value in the configuration.
    self.label_quantiles = self._compute_label_quantiles(scores_by_label_id)
440
+
441
@staticmethod
def _get_trackers_for_training(
    trackers: List[TrackerWithCachedStates],
) -> List[TrackerWithCachedStates]:
    """Selects the trackers this policy can be trained on.

    A tracker is dropped as soon as one of its applied events is either:
    1. a `UserUttered` event with no intent, or
    2. an `ActionExecuted` event with no action_name.

    Args:
        trackers: All trackers available for training.

    Returns:
        Trackers which should be used for training, in their original order.
    """

    def _lacks_required_name(event: Any) -> bool:
        # True for events that make the whole tracker unusable.
        if isinstance(event, UserUttered) and event.intent_name is None:
            return True
        return isinstance(event, ActionExecuted) and event.action_name is None

    return [
        candidate
        for candidate in trackers
        # `any` stops at the first offending event, like an early `break`.
        if not any(
            _lacks_required_name(event)
            for event in candidate.applied_events(True)
        )
    ]
471
+
472
def run_training(
    self, model_data: RasaModelData, label_ids: Optional[np.ndarray] = None
) -> None:
    """Feeds the featurized training data to the model.

    After the parent class has trained the model, the label quantiles
    needed for `action_unlikely_intent` predictions are computed.

    Args:
        model_data: Featurized training data.
        label_ids: Label ids corresponding to the data points in `model_data`.

    Raises:
        `RasaCoreException` if `label_ids` is None as it's needed for
            running post training procedures.
    """
    if label_ids is None:
        # The sentences are joined with an explicit space; previously the
        # class name ran straight into "`label_ids`" in the rendered message.
        raise RasaCoreException(
            f"Incorrect usage of `run_training` "
            f"method of `{self.__class__.__name__}`. "
            f"`label_ids` cannot be left to `None`."
        )
    super().run_training(model_data, label_ids)
    self.compute_label_quantiles_post_training(model_data, label_ids)
493
+
494
def _collect_action_metadata(
    self, domain: Domain, similarities: np.ndarray, query_intent: Text
) -> UnexpecTEDIntentPolicyMetadata:
    """Collects metadata to be attached to the predicted action.

    Metadata schema looks like this:

    {
        "query_intent": <metadata of intent that was queried>,
        "ranking": <sorted list of metadata corresponding to all intents
                    (truncated by `ranking_length` parameter)
                    It also includes the `query_intent`.
                    Sorting is based on predicted similarities.>
    }

    Each metadata dictionary looks like this:

    {
        "name": <name of intent>,
        "score": <predicted similarity score>,
        "threshold": <threshold used for intent>,
        "severity": <numerical difference between threshold and score>
    }

    Args:
        domain: Domain of the assistant.
        similarities: Predicted similarities for each intent.
        query_intent: Name of intent queried in this round of inference.

    Returns:
        Metadata to be attached.
    """
    query_intent_index = domain.intents.index(query_intent)
    # Only the first row of `similarities` is read, indexed by intent position.
    intent_similarities = similarities[0]

    def _compile_metadata_for_label(
        label_name: Text, similarity_score: float, threshold: Optional[float]
    ) -> RankingCandidateMetadata:
        # Compare against `None` explicitly: a threshold of exactly `0.0` is
        # a valid value and must not be reported as "no threshold".
        if threshold is None:
            return RankingCandidateMetadata(
                label_name, float(similarity_score), None, None
            )
        return RankingCandidateMetadata(
            label_name,
            float(similarity_score),
            float(threshold),
            float(threshold - similarity_score),
        )

    query_intent_metadata = _compile_metadata_for_label(
        query_intent,
        intent_similarities[query_intent_index],
        self.label_thresholds.get(query_intent_index),
    )

    # Ranking in descending order of predicted similarities
    sorted_similarities = sorted(
        enumerate(intent_similarities), key=lambda x: -x[1]
    )

    # A non-positive ranking length leaves the ranking untruncated.
    if self.config[RANKING_LENGTH] > 0:
        sorted_similarities = sorted_similarities[: self.config[RANKING_LENGTH]]

    ranking_metadata = [
        _compile_metadata_for_label(
            domain.intents[intent_index],
            similarity,
            self.label_thresholds.get(intent_index),
        )
        for intent_index, similarity in sorted_similarities
    ]

    return UnexpecTEDIntentPolicyMetadata(query_intent_metadata, ranking_metadata)
564
+
565
async def predict_action_probabilities(
    self,
    tracker: DialogueStateTracker,
    domain: Domain,
    rule_only_data: Optional[Dict[Text, Any]] = None,
    precomputations: Optional[MessageContainerForCoreFeaturization] = None,
    **kwargs: Any,
) -> PolicyPrediction:
    """Predicts the next action the bot should take after seeing the tracker.

    Args:
        tracker: Tracker containing past conversation events.
        domain: Domain of the assistant.
        rule_only_data: Slots and loops which are specific to rules and hence
            should be ignored by this policy.
        precomputations: Contains precomputed features and attributes.
        **kwargs: Additional arguments.

    Returns:
        The policy's prediction (e.g. the probabilities for the actions).
    """
    # Without a model, or when abstaining in coexistence, fall back to the
    # default predictions.
    if self.model is None or self.should_abstain_in_coexistence(tracker, False):
        return self._prediction(self._default_predictions(domain))

    # The policy does not predict when:
    # 1. the tracker holds no `UserUttered` event so far, or the intent of
    #    that event is not part of the domain;
    # 2. at least one `ActionExecuted` event follows the last `UserUttered`.
    if self._should_skip_prediction(tracker, domain):
        logger.debug(
            f"Skipping predictions for {self.__class__.__name__} "
            f"as either there is no event of type `UserUttered`, "
            f"event's intent is new and not in domain or "
            f"there is an event of type `ActionExecuted` after "
            f"the last `UserUttered`."
        )
        return self._prediction(self._default_predictions(domain))

    # Turn the tracker into model input and run the model on it.
    model_data = self._create_model_data(
        self._featurize_for_prediction(
            tracker, domain, precomputations, rule_only_data=rule_only_data
        )
    )
    output = self.model.run_inference(model_data)

    raw_similarities = output["similarities"]
    if not isinstance(raw_similarities, np.ndarray):
        raise TypeError("model output for `similarities` should be a numpy array")
    # Only the last prediction of the sequence is relevant.
    sequence_similarities = raw_similarities[:, -1, :]

    # Determine which intent to query and whether it is unlikely here.
    last_user_event = tracker.get_last_event_for(UserUttered)
    query_intent = "" if last_user_event is None else last_user_event.intent_name
    is_unlikely_intent = self._check_unlikely_intent(
        domain, sequence_similarities, query_intent
    )

    confidences = list(np.zeros(domain.num_actions))
    if is_unlikely_intent:
        confidences[domain.index_for_action(ACTION_UNLIKELY_INTENT_NAME)] = 1.0

    action_metadata = dataclasses.asdict(
        self._collect_action_metadata(domain, sequence_similarities, query_intent)
    )
    return self._prediction(confidences, action_metadata=action_metadata)
644
+
645
@staticmethod
def _should_skip_prediction(tracker: DialogueStateTracker, domain: Domain) -> bool:
    """Decides whether the policy should refrain from predicting.

    Prediction is skipped when:
        1. The tracker contains no `UserUttered` event.
        2. The intent of the latest `UserUttered` event is unknown to the
           domain (e.g. it was created in rasa interactive and has not been
           added to the domain yet).
        3. An `ActionExecuted` event follows the latest `UserUttered` event.
           This keeps the dialogue manager out of a prediction loop: if the
           last executed action was `action_unlikely_intent` predicted by
           `UnexpecTEDIntentPolicy`, running the policy again on the same
           tracker would predict `action_unlikely_intent` once more.

    Returns:
        Whether prediction should be skipped.
    """
    # Walk backwards; the most recent relevant event settles the decision.
    for event in reversed(tracker.applied_events(True)):
        if isinstance(event, ActionExecuted):
            return True
        if isinstance(event, UserUttered):
            return event.intent_name not in domain.intents

    # Neither `ActionExecuted` nor `UserUttered` was found, so there is
    # nothing for `UnexpecTEDIntentPolicy` to predict on.
    return True
680
+
681
def _should_check_for_intent(self, intent: Text, domain: Domain) -> bool:
    """Tells whether `action_unlikely_intent` may be raised for the intent.

    Args:
        intent: Intent to be queried.
        domain: Domain of the assistant.

    Returns:
        Whether intent should raise `action_unlikely_intent` or not.
    """
    intent_index = domain.intents.index(intent)

    if intent_index not in self.label_thresholds:
        # No threshold exists, i.e. the intent never appeared in a story.
        logger.debug(
            f"Query intent index {intent_index} not "
            f"found in label thresholds - {self.label_thresholds}. "
            f"Check for `{ACTION_UNLIKELY_INTENT_NAME}` prediction will be skipped."
        )
        return False

    ignored_intents = self.config[IGNORE_INTENTS_LIST]
    if intent in ignored_intents:
        logger.debug(
            f"Query intent `{intent}` found in "
            f"`{IGNORE_INTENTS_LIST}={ignored_intents}`. "
            f"Check for `{ACTION_UNLIKELY_INTENT_NAME}` prediction will be skipped."
        )
        return False

    return True
708
+
709
def _check_unlikely_intent(
    self, domain: Domain, similarities: np.ndarray, query_intent: Text
) -> bool:
    """Checks if the query intent is probable according to model's predictions.

    The intent is considered unlikely when its predicted similarity falls
    below the threshold calculated for it during training and it is not
    the highest scoring intent.

    Args:
        domain: Domain of the assistant.
        similarities: Predicted similarities for all intents.
        query_intent: Intent to be queried.

    Returns:
        `True` if the query intent is unlikely in this context, else `False`.
    """
    logger.debug(f"Querying for intent `{query_intent}`.")

    if not self._should_check_for_intent(query_intent, domain):
        return False

    # Pair every intent name with its score and order ascending by score,
    # so the most likely intent ends up last.
    sorted_intent_scores = sorted(
        [
            (intent_name, similarities[0][position])
            for position, intent_name in enumerate(domain.intents)
        ],
        key=lambda pair: pair[1],
    )

    query_intent_id = domain.intents.index(query_intent)
    query_intent_similarity = similarities[0][query_intent_id]
    top_intent_name = sorted_intent_scores[-1][0]
    highest_likely_intent_id = domain.intents.index(top_intent_name)
    query_threshold = self.label_thresholds[query_intent_id]

    logger.debug(
        f"Score for intent `{query_intent}` is "
        f"`{query_intent_similarity}`, while "
        f"threshold is `{query_threshold}`."
    )
    logger.debug(
        f"Top 5 intents (in ascending order) that "
        f"are likely here are: `{sorted_intent_scores[-5:]}`."
    )

    # Unlikely only if the score is below the threshold AND the query
    # intent is not the top scoring intent.
    if not (
        query_intent_similarity < query_threshold
        and query_intent_id != highest_likely_intent_id
    ):
        return False

    logger.debug(
        f"Intent `{query_intent}-{query_intent_id}` unlikely to occur here."
    )
    return True
768
+
769
@staticmethod
def _collect_label_id_grouped_scores(
    output_scores: Dict[Text, np.ndarray], label_ids: np.ndarray
) -> Dict[int, Dict[Text, List[float]]]:
    """Groups the predicted similarities by label id.

    For every `label_id` the similarity scores across all data points are
    sorted into two buckets:
        1. Scores where `label_id` is one of the correct labels.
        2. Scores where `label_id` is a wrong label.

    Args:
        output_scores: Model's predictions for each data point.
        label_ids: Numerical IDs of labels for each data point.

    Returns:
        Both buckets of similarity scores grouped by each unique label id.
    """
    # The padding id is not a real label and gets no buckets.
    unique_label_ids = [
        label_id
        for label_id in np.unique(label_ids).tolist()
        if label_id != LABEL_PAD_ID
    ]

    label_id_scores: Dict[int, Dict[Text, List[float]]] = {}
    for label_id in unique_label_ids:
        label_id_scores[label_id] = {
            POSITIVE_SCORES_KEY: [],
            NEGATIVE_SCORES_KEY: [],
        }

    for row, positive_labels in enumerate(label_ids):
        for candidate in unique_label_ids:
            bucket_key = (
                POSITIVE_SCORES_KEY
                if candidate in positive_labels
                else NEGATIVE_SCORES_KEY
            )
            label_id_scores[candidate][bucket_key].append(
                output_scores["similarities"][row, 0, candidate]
            )

    return label_id_scores
808
+
809
@staticmethod
def _compute_label_quantiles(
    label_id_scores: Dict[int, Dict[Text, List[float]]],
) -> Dict[int, List[float]]:
    """Computes a series of quantiles for every label id.

    Quantiles are taken over the negative scores of a label id and are
    capped at the smallest positive score collected for that label id.
    If a label id has no negative scores, every entry is that smallest
    positive score.

    Args:
        label_id_scores: Scores collected for each label id
            over positive and negative trackers.

    Returns:
        Computed quantiles for each label id.
    """
    # One quantile per tolerance step of 5%, starting at quantile 1.0.
    quantile_indices = [1 - percent / 100.0 for percent in range(0, 100, 5)]

    label_quantiles = {}
    for label_id, buckets in label_id_scores.items():
        positives = buckets[POSITIVE_SCORES_KEY]
        negatives = buckets[NEGATIVE_SCORES_KEY]
        score_cap = min(positives)

        if not negatives:
            label_quantiles[label_id] = [score_cap] * len(quantile_indices)
            continue

        negative_quantiles = np.quantile(  # type: ignore[call-overload]
            negatives, quantile_indices, interpolation="lower"
        )
        label_quantiles[label_id] = [
            min(score_cap, quantile) for quantile in negative_quantiles
        ]

    return label_quantiles
851
+
852
+ @staticmethod
853
+ def _pick_thresholds(
854
+ label_quantiles: Dict[int, List[float]], tolerance: float
855
+ ) -> Dict[int, float]:
856
+ """Computes a threshold for each label id.
857
+
858
+ Uses tolerance which is the percentage of negative
859
+ trackers for which predicted score should be equal
860
+ to or above the threshold.
861
+
862
+ Args:
863
+ label_quantiles: Quantiles computed for each label id
864
+ tolerance: Specified tolerance value from the configuration.
865
+
866
+ Returns:
867
+ Computed thresholds
868
+ """
869
+ label_thresholds = {}
870
+ for label_id in label_quantiles:
871
+ num_thresholds = len(label_quantiles[label_id])
872
+ label_thresholds[label_id] = label_quantiles[label_id][
873
+ min(int(tolerance * num_thresholds), num_thresholds - 1)
874
+ ]
875
+ return label_thresholds
876
+
877
def persist_model_utilities(self, model_path: Path) -> None:
    """Persists model's utility attributes like model weights, etc.

    On top of what the parent class stores, the label quantiles are written
    to a safetensors file next to the other model files.

    Args:
        model_path: Path where model is to be persisted
    """
    super().persist_model_utilities(model_path)

    from safetensors.numpy import save_file

    # Keys are stringified label ids, values are arrays of quantiles.
    quantile_arrays = {
        str(label_id): np.array(quantiles)
        for label_id, quantiles in self.label_quantiles.items()
    }
    target_file = model_path / f"{self._metadata_filename()}.label_quantiles.st"
    save_file(quantile_arrays, target_file)
891
+
892
@classmethod
def _load_model_utilities(cls, model_path: Path) -> Dict[Text, Any]:
    """Loads model's utility attributes.

    The utilities loaded by the parent class are extended with the label
    quantiles read back from their safetensors file.

    Args:
        model_path: Path where the model was persisted.

    Returns:
        The utilities including a `label_quantiles` entry.
    """
    utilities = super()._load_model_utilities(model_path)

    from safetensors.numpy import load_file

    quantiles_file = model_path / f"{cls._metadata_filename()}.label_quantiles.st"
    stored_quantiles = load_file(quantiles_file)

    # Keys were stringified for storage; turn them back into integer ids.
    utilities["label_quantiles"] = {
        int(label_id): list(quantiles)
        for label_id, quantiles in stored_quantiles.items()
    }
    return utilities
910
+
911
@classmethod
def _update_loaded_params(cls, meta: Dict[Text, Any]) -> Dict[Text, Any]:
    """Combines the loaded parameters with the policy's default config.

    Args:
        meta: Parameters loaded for the policy.

    Returns:
        The result of `override_defaults` applied to the default config
        and `meta`.
    """
    return rasa.utils.common.override_defaults(cls.get_default_config(), meta)
915
+
916
@classmethod
def _load_policy_with_model(
    cls,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
    featurizer: TrackerFeaturizer,
    model: "IntentTED",
    model_utilities: Dict[Text, Any],
) -> "UnexpecTEDIntentPolicy":
    """Instantiates the policy from a loaded model and its utilities.

    Args:
        config: Configuration of the policy.
        model_storage: Storage passed on to the policy.
        resource: Resource passed on to the policy.
        execution_context: Execution context passed on to the policy.
        featurizer: Tracker featurizer of the policy.
        model: The loaded `IntentTED` model.
        model_utilities: Loaded utilities; `fake_features`,
            `entity_tag_specs` and `label_quantiles` are read from it.

    Returns:
        The policy instance.
    """
    fake_features = model_utilities["fake_features"]
    entity_tag_specs = model_utilities["entity_tag_specs"]
    label_quantiles = model_utilities["label_quantiles"]

    return cls(
        config,
        model_storage,
        resource,
        execution_context,
        model=model,
        featurizer=featurizer,
        fake_features=fake_features,
        entity_tag_specs=entity_tag_specs,
        label_quantiles=label_quantiles,
    )
938
+
939
+
940
+ class IntentTED(TED):
941
+ """Follows TED's model architecture from https://arxiv.org/abs/1910.00486.
942
+
943
+ However, it has been re-purposed to predict multiple
944
+ labels (intents) instead of a single label (action).
945
+ """
946
+
947
def _prepare_dot_product_loss(
    self, name: Text, scale_loss: bool, prefix: Text = "loss"
) -> None:
    """Creates the dot-product loss layer and registers it.

    The layer is stored in `self._tf_layers` under the key
    `"<prefix>.<name>"`.

    Args:
        name: Name of the loss, used as the second part of the key.
        scale_loss: Passed on to the loss layer.
        prefix: First part of the key.
    """
    loss_layer_class = self.dot_product_loss_layer
    loss_layer = loss_layer_class(
        self.config[NUM_NEG],
        scale_loss,
        similarity_type=self.config[SIMILARITY_TYPE],
    )
    self._tf_layers[f"{prefix}.{name}"] = loss_layer
955
+
956
@property
def dot_product_loss_layer(self) -> tf.keras.layers.Layer:
    """Provides the dot-product loss layer class used by this model.

    Since several intents can be valid at the same time, `IntentTED`
    relies on `MultiLabelDotProductLoss`.

    Returns:
        The loss layer that is used by `_prepare_dot_product_loss`.
    """
    multi_label_loss = layers.MultiLabelDotProductLoss
    return multi_label_loss
967
+
968
@staticmethod
def _get_labels_embed(
    label_ids: tf.Tensor, all_labels_embed: tf.Tensor
) -> tf.Tensor:
    """Looks up the embeddings of the given label ids.

    Rather than processing the labels again, the embeddings are gathered
    from `all_labels_embed`. Padding entries (equal to `LABEL_PAD_ID`) are
    redirected to index 0 so that the gather op receives valid indices.

    Args:
        label_ids: Label ids; only `label_ids[:, :, 0]` is read.
        all_labels_embed: Embeddings of all labels.

    Returns:
        The gathered label embeddings.
    """
    label_indices = tf.cast(label_ids[:, :, 0], tf.int32)

    # Positions holding the padding value `LABEL_PAD_ID`.
    pad_positions = tf.where(tf.equal(label_indices, LABEL_PAD_ID))

    # One correction per padding position, with the sign opposite to
    # `LABEL_PAD_ID`.
    num_pad_positions = tf.shape(pad_positions)[0]
    pad_corrections = (
        tf.ones((num_pad_positions), dtype=tf.int32) * -1 * LABEL_PAD_ID
    )

    # Adding the corrections cancels out the padding value, so padded
    # positions end up at index 0. Real label indices stay untouched; the
    # padded ones merely become valid input for `tf.gather`.
    gather_indices = tf.cast(
        tf.tensor_scatter_nd_add(label_indices, pad_positions, pad_corrections),
        tf.int32,
    )

    return tf.gather(all_labels_embed, gather_indices)
1000
+
1001
def run_bulk_inference(
    self, model_data: RasaModelData
) -> Dict[Text, Union[np.ndarray, Dict[Text, Any]]]:
    """Computes model's predictions for input data.

    The model is switched out of training mode and inference is run in
    batches, requesting only the `similarities` output.

    Args:
        model_data: Data to be passed as input

    Returns:
        Predictions for the input data.
    """
    self._training = False

    # The batch size is either a plain integer or a sequence, in which case
    # its first entry is used.
    configured_batch_sizes = self.config[BATCH_SIZES]
    if isinstance(configured_batch_sizes, int):
        batch_size = configured_batch_sizes
    else:
        batch_size = configured_batch_sizes[0]

    return self.run_inference(
        model_data, batch_size=batch_size, output_keys_expected=["similarities"]
    )