rasa_pro-3.12.0.dev1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic.

Files changed (790)
  1. README.md +41 -0
  2. rasa/__init__.py +9 -0
  3. rasa/__main__.py +177 -0
  4. rasa/anonymization/__init__.py +2 -0
  5. rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
  6. rasa/anonymization/anonymization_pipeline.py +286 -0
  7. rasa/anonymization/anonymization_rule_executor.py +260 -0
  8. rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
  9. rasa/anonymization/schemas/config.yml +47 -0
  10. rasa/anonymization/utils.py +118 -0
  11. rasa/api.py +160 -0
  12. rasa/cli/__init__.py +5 -0
  13. rasa/cli/arguments/__init__.py +0 -0
  14. rasa/cli/arguments/data.py +106 -0
  15. rasa/cli/arguments/default_arguments.py +207 -0
  16. rasa/cli/arguments/evaluate.py +65 -0
  17. rasa/cli/arguments/export.py +51 -0
  18. rasa/cli/arguments/interactive.py +74 -0
  19. rasa/cli/arguments/run.py +219 -0
  20. rasa/cli/arguments/shell.py +17 -0
  21. rasa/cli/arguments/test.py +211 -0
  22. rasa/cli/arguments/train.py +279 -0
  23. rasa/cli/arguments/visualize.py +34 -0
  24. rasa/cli/arguments/x.py +30 -0
  25. rasa/cli/data.py +354 -0
  26. rasa/cli/dialogue_understanding_test.py +251 -0
  27. rasa/cli/e2e_test.py +259 -0
  28. rasa/cli/evaluate.py +222 -0
  29. rasa/cli/export.py +250 -0
  30. rasa/cli/inspect.py +75 -0
  31. rasa/cli/interactive.py +166 -0
  32. rasa/cli/license.py +65 -0
  33. rasa/cli/llm_fine_tuning.py +403 -0
  34. rasa/cli/markers.py +78 -0
  35. rasa/cli/project_templates/__init__.py +0 -0
  36. rasa/cli/project_templates/calm/actions/__init__.py +0 -0
  37. rasa/cli/project_templates/calm/actions/action_template.py +27 -0
  38. rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
  39. rasa/cli/project_templates/calm/actions/db.py +57 -0
  40. rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
  41. rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
  42. rasa/cli/project_templates/calm/config.yml +10 -0
  43. rasa/cli/project_templates/calm/credentials.yml +33 -0
  44. rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
  45. rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
  46. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
  47. rasa/cli/project_templates/calm/db/contacts.json +10 -0
  48. rasa/cli/project_templates/calm/domain/add_contact.yml +39 -0
  49. rasa/cli/project_templates/calm/domain/list_contacts.yml +17 -0
  50. rasa/cli/project_templates/calm/domain/remove_contact.yml +38 -0
  51. rasa/cli/project_templates/calm/domain/shared.yml +10 -0
  52. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
  53. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
  54. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
  55. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
  56. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
  57. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
  58. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
  59. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
  60. rasa/cli/project_templates/calm/endpoints.yml +58 -0
  61. rasa/cli/project_templates/default/actions/__init__.py +0 -0
  62. rasa/cli/project_templates/default/actions/actions.py +27 -0
  63. rasa/cli/project_templates/default/config.yml +44 -0
  64. rasa/cli/project_templates/default/credentials.yml +33 -0
  65. rasa/cli/project_templates/default/data/nlu.yml +91 -0
  66. rasa/cli/project_templates/default/data/rules.yml +13 -0
  67. rasa/cli/project_templates/default/data/stories.yml +30 -0
  68. rasa/cli/project_templates/default/domain.yml +34 -0
  69. rasa/cli/project_templates/default/endpoints.yml +42 -0
  70. rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
  71. rasa/cli/project_templates/tutorial/actions/__init__.py +0 -0
  72. rasa/cli/project_templates/tutorial/actions/actions.py +22 -0
  73. rasa/cli/project_templates/tutorial/config.yml +12 -0
  74. rasa/cli/project_templates/tutorial/credentials.yml +33 -0
  75. rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
  76. rasa/cli/project_templates/tutorial/data/patterns.yml +11 -0
  77. rasa/cli/project_templates/tutorial/domain.yml +35 -0
  78. rasa/cli/project_templates/tutorial/endpoints.yml +55 -0
  79. rasa/cli/run.py +143 -0
  80. rasa/cli/scaffold.py +273 -0
  81. rasa/cli/shell.py +141 -0
  82. rasa/cli/studio/__init__.py +0 -0
  83. rasa/cli/studio/download.py +62 -0
  84. rasa/cli/studio/studio.py +296 -0
  85. rasa/cli/studio/train.py +59 -0
  86. rasa/cli/studio/upload.py +62 -0
  87. rasa/cli/telemetry.py +102 -0
  88. rasa/cli/test.py +280 -0
  89. rasa/cli/train.py +278 -0
  90. rasa/cli/utils.py +484 -0
  91. rasa/cli/visualize.py +40 -0
  92. rasa/cli/x.py +206 -0
  93. rasa/constants.py +45 -0
  94. rasa/core/__init__.py +17 -0
  95. rasa/core/actions/__init__.py +0 -0
  96. rasa/core/actions/action.py +1318 -0
  97. rasa/core/actions/action_clean_stack.py +59 -0
  98. rasa/core/actions/action_exceptions.py +24 -0
  99. rasa/core/actions/action_hangup.py +29 -0
  100. rasa/core/actions/action_repeat_bot_messages.py +89 -0
  101. rasa/core/actions/action_run_slot_rejections.py +210 -0
  102. rasa/core/actions/action_trigger_chitchat.py +31 -0
  103. rasa/core/actions/action_trigger_flow.py +109 -0
  104. rasa/core/actions/action_trigger_search.py +31 -0
  105. rasa/core/actions/constants.py +5 -0
  106. rasa/core/actions/custom_action_executor.py +191 -0
  107. rasa/core/actions/direct_custom_actions_executor.py +109 -0
  108. rasa/core/actions/e2e_stub_custom_action_executor.py +72 -0
  109. rasa/core/actions/forms.py +741 -0
  110. rasa/core/actions/grpc_custom_action_executor.py +251 -0
  111. rasa/core/actions/http_custom_action_executor.py +145 -0
  112. rasa/core/actions/loops.py +114 -0
  113. rasa/core/actions/two_stage_fallback.py +186 -0
  114. rasa/core/agent.py +559 -0
  115. rasa/core/auth_retry_tracker_store.py +122 -0
  116. rasa/core/brokers/__init__.py +0 -0
  117. rasa/core/brokers/broker.py +126 -0
  118. rasa/core/brokers/file.py +58 -0
  119. rasa/core/brokers/kafka.py +324 -0
  120. rasa/core/brokers/pika.py +388 -0
  121. rasa/core/brokers/sql.py +86 -0
  122. rasa/core/channels/__init__.py +61 -0
  123. rasa/core/channels/botframework.py +338 -0
  124. rasa/core/channels/callback.py +84 -0
  125. rasa/core/channels/channel.py +456 -0
  126. rasa/core/channels/console.py +241 -0
  127. rasa/core/channels/development_inspector.py +197 -0
  128. rasa/core/channels/facebook.py +419 -0
  129. rasa/core/channels/hangouts.py +329 -0
  130. rasa/core/channels/inspector/.eslintrc.cjs +25 -0
  131. rasa/core/channels/inspector/.gitignore +23 -0
  132. rasa/core/channels/inspector/README.md +54 -0
  133. rasa/core/channels/inspector/assets/favicon.ico +0 -0
  134. rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
  135. rasa/core/channels/inspector/custom.d.ts +3 -0
  136. rasa/core/channels/inspector/dist/assets/arc-861ddd57.js +1 -0
  137. rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
  138. rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-921f02db.js +10 -0
  139. rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-b436c4f8.js +2 -0
  140. rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-511a23cb.js +2 -0
  141. rasa/core/channels/inspector/dist/assets/createText-62fc7601-ef476ecd.js +7 -0
  142. rasa/core/channels/inspector/dist/assets/edges-f2ad444c-f1878e0a.js +4 -0
  143. rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-fac75185.js +51 -0
  144. rasa/core/channels/inspector/dist/assets/flowDb-1972c806-201c5bbc.js +6 -0
  145. rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-f904ae41.js +4 -0
  146. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-b080d6f2.js +1 -0
  147. rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-1813da66.js +139 -0
  148. rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-872af172.js +266 -0
  149. rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-34a0af5a.js +70 -0
  150. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
  151. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
  152. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
  153. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
  154. rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-42ba3e3d.js +1 -0
  155. rasa/core/channels/inspector/dist/assets/index-37817b51.js +1317 -0
  156. rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
  157. rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-6b731386.js +7 -0
  158. rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
  159. rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-e8579ac6.js +139 -0
  160. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
  161. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
  162. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
  163. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
  164. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
  165. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
  166. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
  167. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
  168. rasa/core/channels/inspector/dist/assets/layout-89e6403a.js +1 -0
  169. rasa/core/channels/inspector/dist/assets/line-dc73d3fc.js +1 -0
  170. rasa/core/channels/inspector/dist/assets/linear-f5b1d2bc.js +1 -0
  171. rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-82cb74fa.js +109 -0
  172. rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
  173. rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
  174. rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-bdf5f29b.js +35 -0
  175. rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-c7a0cbe4.js +7 -0
  176. rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-7ec5410f.js +52 -0
  177. rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-caee5554.js +8 -0
  178. rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-2935f8db.js +122 -0
  179. rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-8f5d9693.js +1 -0
  180. rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-d565d1de.js +1 -0
  181. rasa/core/channels/inspector/dist/assets/styles-080da4f6-75ad421d.js +110 -0
  182. rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-7e764226.js +159 -0
  183. rasa/core/channels/inspector/dist/assets/styles-9c745c82-7a4e0e61.js +207 -0
  184. rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-4019d1bf.js +1 -0
  185. rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-01ea12df.js +61 -0
  186. rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-89407137.js +7 -0
  187. rasa/core/channels/inspector/dist/index.html +42 -0
  188. rasa/core/channels/inspector/index.html +40 -0
  189. rasa/core/channels/inspector/jest.config.ts +13 -0
  190. rasa/core/channels/inspector/package.json +52 -0
  191. rasa/core/channels/inspector/setupTests.ts +2 -0
  192. rasa/core/channels/inspector/src/App.tsx +220 -0
  193. rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
  194. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +108 -0
  195. rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
  196. rasa/core/channels/inspector/src/components/DialogueStack.tsx +136 -0
  197. rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
  198. rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
  199. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +22 -0
  200. rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
  201. rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
  202. rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
  203. rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
  204. rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
  205. rasa/core/channels/inspector/src/helpers/audiostream.ts +191 -0
  206. rasa/core/channels/inspector/src/helpers/formatters.test.ts +392 -0
  207. rasa/core/channels/inspector/src/helpers/formatters.ts +306 -0
  208. rasa/core/channels/inspector/src/helpers/utils.ts +127 -0
  209. rasa/core/channels/inspector/src/main.tsx +13 -0
  210. rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
  211. rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
  212. rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
  213. rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
  214. rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
  215. rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
  216. rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
  217. rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
  218. rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
  219. rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
  220. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
  221. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
  222. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
  223. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
  224. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
  225. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
  226. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
  227. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
  228. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
  229. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
  230. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
  231. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
  232. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
  233. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
  234. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
  235. rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
  236. rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
  237. rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
  238. rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
  239. rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
  240. rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
  241. rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
  242. rasa/core/channels/inspector/src/theme/index.ts +101 -0
  243. rasa/core/channels/inspector/src/types.ts +84 -0
  244. rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
  245. rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
  246. rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
  247. rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
  248. rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
  249. rasa/core/channels/inspector/tsconfig.json +26 -0
  250. rasa/core/channels/inspector/tsconfig.node.json +10 -0
  251. rasa/core/channels/inspector/vite.config.ts +8 -0
  252. rasa/core/channels/inspector/yarn.lock +6249 -0
  253. rasa/core/channels/mattermost.py +229 -0
  254. rasa/core/channels/rasa_chat.py +126 -0
  255. rasa/core/channels/rest.py +230 -0
  256. rasa/core/channels/rocketchat.py +174 -0
  257. rasa/core/channels/slack.py +620 -0
  258. rasa/core/channels/socketio.py +302 -0
  259. rasa/core/channels/telegram.py +298 -0
  260. rasa/core/channels/twilio.py +169 -0
  261. rasa/core/channels/vier_cvg.py +374 -0
  262. rasa/core/channels/voice_ready/__init__.py +0 -0
  263. rasa/core/channels/voice_ready/audiocodes.py +501 -0
  264. rasa/core/channels/voice_ready/jambonz.py +121 -0
  265. rasa/core/channels/voice_ready/jambonz_protocol.py +396 -0
  266. rasa/core/channels/voice_ready/twilio_voice.py +403 -0
  267. rasa/core/channels/voice_ready/utils.py +37 -0
  268. rasa/core/channels/voice_stream/__init__.py +0 -0
  269. rasa/core/channels/voice_stream/asr/__init__.py +0 -0
  270. rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
  271. rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
  272. rasa/core/channels/voice_stream/asr/azure.py +130 -0
  273. rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
  274. rasa/core/channels/voice_stream/audio_bytes.py +8 -0
  275. rasa/core/channels/voice_stream/browser_audio.py +107 -0
  276. rasa/core/channels/voice_stream/call_state.py +23 -0
  277. rasa/core/channels/voice_stream/tts/__init__.py +0 -0
  278. rasa/core/channels/voice_stream/tts/azure.py +106 -0
  279. rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
  280. rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
  281. rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
  282. rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
  283. rasa/core/channels/voice_stream/util.py +57 -0
  284. rasa/core/channels/voice_stream/voice_channel.py +427 -0
  285. rasa/core/channels/webexteams.py +134 -0
  286. rasa/core/concurrent_lock_store.py +210 -0
  287. rasa/core/constants.py +112 -0
  288. rasa/core/evaluation/__init__.py +0 -0
  289. rasa/core/evaluation/marker.py +267 -0
  290. rasa/core/evaluation/marker_base.py +923 -0
  291. rasa/core/evaluation/marker_stats.py +293 -0
  292. rasa/core/evaluation/marker_tracker_loader.py +103 -0
  293. rasa/core/exceptions.py +29 -0
  294. rasa/core/exporter.py +284 -0
  295. rasa/core/featurizers/__init__.py +0 -0
  296. rasa/core/featurizers/precomputation.py +410 -0
  297. rasa/core/featurizers/single_state_featurizer.py +421 -0
  298. rasa/core/featurizers/tracker_featurizers.py +1262 -0
  299. rasa/core/http_interpreter.py +89 -0
  300. rasa/core/information_retrieval/__init__.py +7 -0
  301. rasa/core/information_retrieval/faiss.py +124 -0
  302. rasa/core/information_retrieval/information_retrieval.py +137 -0
  303. rasa/core/information_retrieval/milvus.py +59 -0
  304. rasa/core/information_retrieval/qdrant.py +96 -0
  305. rasa/core/jobs.py +63 -0
  306. rasa/core/lock.py +139 -0
  307. rasa/core/lock_store.py +343 -0
  308. rasa/core/migrate.py +403 -0
  309. rasa/core/nlg/__init__.py +3 -0
  310. rasa/core/nlg/callback.py +146 -0
  311. rasa/core/nlg/contextual_response_rephraser.py +320 -0
  312. rasa/core/nlg/generator.py +230 -0
  313. rasa/core/nlg/interpolator.py +143 -0
  314. rasa/core/nlg/response.py +155 -0
  315. rasa/core/nlg/summarize.py +70 -0
  316. rasa/core/persistor.py +538 -0
  317. rasa/core/policies/__init__.py +0 -0
  318. rasa/core/policies/ensemble.py +329 -0
  319. rasa/core/policies/enterprise_search_policy.py +905 -0
  320. rasa/core/policies/enterprise_search_prompt_template.jinja2 +25 -0
  321. rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
  322. rasa/core/policies/flow_policy.py +205 -0
  323. rasa/core/policies/flows/__init__.py +0 -0
  324. rasa/core/policies/flows/flow_exceptions.py +44 -0
  325. rasa/core/policies/flows/flow_executor.py +754 -0
  326. rasa/core/policies/flows/flow_step_result.py +43 -0
  327. rasa/core/policies/intentless_policy.py +1031 -0
  328. rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
  329. rasa/core/policies/memoization.py +538 -0
  330. rasa/core/policies/policy.py +725 -0
  331. rasa/core/policies/rule_policy.py +1273 -0
  332. rasa/core/policies/ted_policy.py +2169 -0
  333. rasa/core/policies/unexpected_intent_policy.py +1022 -0
  334. rasa/core/processor.py +1465 -0
  335. rasa/core/run.py +342 -0
  336. rasa/core/secrets_manager/__init__.py +0 -0
  337. rasa/core/secrets_manager/constants.py +36 -0
  338. rasa/core/secrets_manager/endpoints.py +391 -0
  339. rasa/core/secrets_manager/factory.py +241 -0
  340. rasa/core/secrets_manager/secret_manager.py +262 -0
  341. rasa/core/secrets_manager/vault.py +584 -0
  342. rasa/core/test.py +1335 -0
  343. rasa/core/tracker_store.py +1703 -0
  344. rasa/core/train.py +105 -0
  345. rasa/core/training/__init__.py +89 -0
  346. rasa/core/training/converters/__init__.py +0 -0
  347. rasa/core/training/converters/responses_prefix_converter.py +119 -0
  348. rasa/core/training/interactive.py +1744 -0
  349. rasa/core/training/story_conflict.py +381 -0
  350. rasa/core/training/training.py +93 -0
  351. rasa/core/utils.py +366 -0
  352. rasa/core/visualize.py +70 -0
  353. rasa/dialogue_understanding/__init__.py +0 -0
  354. rasa/dialogue_understanding/coexistence/__init__.py +0 -0
  355. rasa/dialogue_understanding/coexistence/constants.py +4 -0
  356. rasa/dialogue_understanding/coexistence/intent_based_router.py +196 -0
  357. rasa/dialogue_understanding/coexistence/llm_based_router.py +327 -0
  358. rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
  359. rasa/dialogue_understanding/commands/__init__.py +61 -0
  360. rasa/dialogue_understanding/commands/can_not_handle_command.py +70 -0
  361. rasa/dialogue_understanding/commands/cancel_flow_command.py +125 -0
  362. rasa/dialogue_understanding/commands/change_flow_command.py +44 -0
  363. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +57 -0
  364. rasa/dialogue_understanding/commands/clarify_command.py +86 -0
  365. rasa/dialogue_understanding/commands/command.py +85 -0
  366. rasa/dialogue_understanding/commands/correct_slots_command.py +297 -0
  367. rasa/dialogue_understanding/commands/error_command.py +79 -0
  368. rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
  369. rasa/dialogue_understanding/commands/handle_code_change_command.py +73 -0
  370. rasa/dialogue_understanding/commands/human_handoff_command.py +66 -0
  371. rasa/dialogue_understanding/commands/knowledge_answer_command.py +57 -0
  372. rasa/dialogue_understanding/commands/noop_command.py +54 -0
  373. rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
  374. rasa/dialogue_understanding/commands/restart_command.py +58 -0
  375. rasa/dialogue_understanding/commands/session_end_command.py +61 -0
  376. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  377. rasa/dialogue_understanding/commands/set_slot_command.py +160 -0
  378. rasa/dialogue_understanding/commands/skip_question_command.py +75 -0
  379. rasa/dialogue_understanding/commands/start_flow_command.py +107 -0
  380. rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
  381. rasa/dialogue_understanding/commands/utils.py +45 -0
  382. rasa/dialogue_understanding/generator/__init__.py +21 -0
  383. rasa/dialogue_understanding/generator/command_generator.py +464 -0
  384. rasa/dialogue_understanding/generator/constants.py +27 -0
  385. rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
  386. rasa/dialogue_understanding/generator/flow_retrieval.py +466 -0
  387. rasa/dialogue_understanding/generator/llm_based_command_generator.py +500 -0
  388. rasa/dialogue_understanding/generator/llm_command_generator.py +67 -0
  389. rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
  390. rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
  391. rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
  392. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +920 -0
  393. rasa/dialogue_understanding/generator/nlu_command_adapter.py +261 -0
  394. rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
  395. rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +60 -0
  396. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +486 -0
  397. rasa/dialogue_understanding/patterns/__init__.py +0 -0
  398. rasa/dialogue_understanding/patterns/cancel.py +111 -0
  399. rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
  400. rasa/dialogue_understanding/patterns/chitchat.py +37 -0
  401. rasa/dialogue_understanding/patterns/clarify.py +97 -0
  402. rasa/dialogue_understanding/patterns/code_change.py +41 -0
  403. rasa/dialogue_understanding/patterns/collect_information.py +90 -0
  404. rasa/dialogue_understanding/patterns/completed.py +40 -0
  405. rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
  406. rasa/dialogue_understanding/patterns/correction.py +278 -0
  407. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +301 -0
  408. rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
  409. rasa/dialogue_understanding/patterns/internal_error.py +47 -0
  410. rasa/dialogue_understanding/patterns/repeat.py +37 -0
  411. rasa/dialogue_understanding/patterns/restart.py +37 -0
  412. rasa/dialogue_understanding/patterns/search.py +37 -0
  413. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  414. rasa/dialogue_understanding/patterns/skip_question.py +38 -0
  415. rasa/dialogue_understanding/patterns/user_silence.py +37 -0
  416. rasa/dialogue_understanding/processor/__init__.py +0 -0
  417. rasa/dialogue_understanding/processor/command_processor.py +720 -0
  418. rasa/dialogue_understanding/processor/command_processor_component.py +43 -0
  419. rasa/dialogue_understanding/stack/__init__.py +0 -0
  420. rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
  421. rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
  422. rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
  423. rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
  424. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
  425. rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
  426. rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
  427. rasa/dialogue_understanding/stack/utils.py +211 -0
  428. rasa/dialogue_understanding/utils.py +14 -0
  429. rasa/dialogue_understanding_test/__init__.py +0 -0
  430. rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
  431. rasa/dialogue_understanding_test/constants.py +17 -0
  432. rasa/dialogue_understanding_test/du_test_case.py +118 -0
  433. rasa/dialogue_understanding_test/du_test_result.py +11 -0
  434. rasa/dialogue_understanding_test/du_test_runner.py +93 -0
  435. rasa/dialogue_understanding_test/io.py +54 -0
  436. rasa/dialogue_understanding_test/validation.py +22 -0
  437. rasa/e2e_test/__init__.py +0 -0
  438. rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
  439. rasa/e2e_test/assertions.py +1345 -0
  440. rasa/e2e_test/assertions_schema.yml +129 -0
  441. rasa/e2e_test/constants.py +31 -0
  442. rasa/e2e_test/e2e_config.py +220 -0
  443. rasa/e2e_test/e2e_config_schema.yml +26 -0
  444. rasa/e2e_test/e2e_test_case.py +569 -0
  445. rasa/e2e_test/e2e_test_converter.py +363 -0
  446. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  447. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  448. rasa/e2e_test/e2e_test_result.py +54 -0
  449. rasa/e2e_test/e2e_test_runner.py +1192 -0
  450. rasa/e2e_test/e2e_test_schema.yml +181 -0
  451. rasa/e2e_test/pykwalify_extensions.py +39 -0
  452. rasa/e2e_test/stub_custom_action.py +70 -0
  453. rasa/e2e_test/utils/__init__.py +0 -0
  454. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  455. rasa/e2e_test/utils/io.py +598 -0
  456. rasa/e2e_test/utils/validation.py +178 -0
  457. rasa/engine/__init__.py +0 -0
  458. rasa/engine/caching.py +463 -0
  459. rasa/engine/constants.py +17 -0
  460. rasa/engine/exceptions.py +14 -0
  461. rasa/engine/graph.py +642 -0
  462. rasa/engine/loader.py +48 -0
  463. rasa/engine/recipes/__init__.py +0 -0
  464. rasa/engine/recipes/config_files/default_config.yml +41 -0
  465. rasa/engine/recipes/default_components.py +97 -0
  466. rasa/engine/recipes/default_recipe.py +1272 -0
  467. rasa/engine/recipes/graph_recipe.py +79 -0
  468. rasa/engine/recipes/recipe.py +93 -0
  469. rasa/engine/runner/__init__.py +0 -0
  470. rasa/engine/runner/dask.py +250 -0
  471. rasa/engine/runner/interface.py +49 -0
  472. rasa/engine/storage/__init__.py +0 -0
  473. rasa/engine/storage/local_model_storage.py +244 -0
  474. rasa/engine/storage/resource.py +110 -0
  475. rasa/engine/storage/storage.py +199 -0
  476. rasa/engine/training/__init__.py +0 -0
  477. rasa/engine/training/components.py +176 -0
  478. rasa/engine/training/fingerprinting.py +64 -0
  479. rasa/engine/training/graph_trainer.py +256 -0
  480. rasa/engine/training/hooks.py +164 -0
  481. rasa/engine/validation.py +1451 -0
  482. rasa/env.py +14 -0
  483. rasa/exceptions.py +69 -0
  484. rasa/graph_components/__init__.py +0 -0
  485. rasa/graph_components/converters/__init__.py +0 -0
  486. rasa/graph_components/converters/nlu_message_converter.py +48 -0
  487. rasa/graph_components/providers/__init__.py +0 -0
  488. rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
  489. rasa/graph_components/providers/domain_provider.py +71 -0
  490. rasa/graph_components/providers/flows_provider.py +74 -0
  491. rasa/graph_components/providers/forms_provider.py +44 -0
  492. rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
  493. rasa/graph_components/providers/responses_provider.py +44 -0
  494. rasa/graph_components/providers/rule_only_provider.py +49 -0
  495. rasa/graph_components/providers/story_graph_provider.py +96 -0
  496. rasa/graph_components/providers/training_tracker_provider.py +55 -0
  497. rasa/graph_components/validators/__init__.py +0 -0
  498. rasa/graph_components/validators/default_recipe_validator.py +550 -0
  499. rasa/graph_components/validators/finetuning_validator.py +302 -0
  500. rasa/hooks.py +111 -0
  501. rasa/jupyter.py +63 -0
  502. rasa/llm_fine_tuning/__init__.py +0 -0
  503. rasa/llm_fine_tuning/annotation_module.py +241 -0
  504. rasa/llm_fine_tuning/conversations.py +144 -0
  505. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  506. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  507. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  508. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  509. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  510. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  511. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  512. rasa/llm_fine_tuning/storage.py +174 -0
  513. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  514. rasa/markers/__init__.py +0 -0
  515. rasa/markers/marker.py +269 -0
  516. rasa/markers/marker_base.py +828 -0
  517. rasa/markers/upload.py +74 -0
  518. rasa/markers/validate.py +21 -0
  519. rasa/model.py +118 -0
  520. rasa/model_manager/__init__.py +0 -0
  521. rasa/model_manager/config.py +40 -0
  522. rasa/model_manager/model_api.py +559 -0
  523. rasa/model_manager/runner_service.py +286 -0
  524. rasa/model_manager/socket_bridge.py +146 -0
  525. rasa/model_manager/studio_jwt_auth.py +86 -0
  526. rasa/model_manager/trainer_service.py +325 -0
  527. rasa/model_manager/utils.py +87 -0
  528. rasa/model_manager/warm_rasa_process.py +187 -0
  529. rasa/model_service.py +112 -0
  530. rasa/model_testing.py +457 -0
  531. rasa/model_training.py +596 -0
  532. rasa/nlu/__init__.py +7 -0
  533. rasa/nlu/classifiers/__init__.py +3 -0
  534. rasa/nlu/classifiers/classifier.py +5 -0
  535. rasa/nlu/classifiers/diet_classifier.py +1881 -0
  536. rasa/nlu/classifiers/fallback_classifier.py +192 -0
  537. rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
  538. rasa/nlu/classifiers/logistic_regression_classifier.py +253 -0
  539. rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
  540. rasa/nlu/classifiers/regex_message_handler.py +56 -0
  541. rasa/nlu/classifiers/sklearn_intent_classifier.py +330 -0
  542. rasa/nlu/constants.py +77 -0
  543. rasa/nlu/convert.py +40 -0
  544. rasa/nlu/emulators/__init__.py +0 -0
  545. rasa/nlu/emulators/dialogflow.py +55 -0
  546. rasa/nlu/emulators/emulator.py +49 -0
  547. rasa/nlu/emulators/luis.py +86 -0
  548. rasa/nlu/emulators/no_emulator.py +10 -0
  549. rasa/nlu/emulators/wit.py +56 -0
  550. rasa/nlu/extractors/__init__.py +0 -0
  551. rasa/nlu/extractors/crf_entity_extractor.py +715 -0
  552. rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
  553. rasa/nlu/extractors/entity_synonyms.py +178 -0
  554. rasa/nlu/extractors/extractor.py +470 -0
  555. rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
  556. rasa/nlu/extractors/regex_entity_extractor.py +220 -0
  557. rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
  558. rasa/nlu/featurizers/__init__.py +0 -0
  559. rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
  560. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +445 -0
  561. rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
  562. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +768 -0
  563. rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
  564. rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
  565. rasa/nlu/featurizers/featurizer.py +89 -0
  566. rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
  567. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +867 -0
  568. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +571 -0
  569. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +271 -0
  570. rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
  571. rasa/nlu/model.py +24 -0
  572. rasa/nlu/run.py +27 -0
  573. rasa/nlu/selectors/__init__.py +0 -0
  574. rasa/nlu/selectors/response_selector.py +987 -0
  575. rasa/nlu/test.py +1940 -0
  576. rasa/nlu/tokenizers/__init__.py +0 -0
  577. rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
  578. rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
  579. rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
  580. rasa/nlu/tokenizers/tokenizer.py +239 -0
  581. rasa/nlu/tokenizers/whitespace_tokenizer.py +95 -0
  582. rasa/nlu/utils/__init__.py +35 -0
  583. rasa/nlu/utils/bilou_utils.py +462 -0
  584. rasa/nlu/utils/hugging_face/__init__.py +0 -0
  585. rasa/nlu/utils/hugging_face/registry.py +108 -0
  586. rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
  587. rasa/nlu/utils/mitie_utils.py +113 -0
  588. rasa/nlu/utils/pattern_utils.py +168 -0
  589. rasa/nlu/utils/spacy_utils.py +310 -0
  590. rasa/plugin.py +90 -0
  591. rasa/server.py +1588 -0
  592. rasa/shared/__init__.py +0 -0
  593. rasa/shared/constants.py +311 -0
  594. rasa/shared/core/__init__.py +0 -0
  595. rasa/shared/core/command_payload_reader.py +109 -0
  596. rasa/shared/core/constants.py +180 -0
  597. rasa/shared/core/conversation.py +46 -0
  598. rasa/shared/core/domain.py +2172 -0
  599. rasa/shared/core/events.py +2559 -0
  600. rasa/shared/core/flows/__init__.py +7 -0
  601. rasa/shared/core/flows/flow.py +562 -0
  602. rasa/shared/core/flows/flow_path.py +84 -0
  603. rasa/shared/core/flows/flow_step.py +146 -0
  604. rasa/shared/core/flows/flow_step_links.py +319 -0
  605. rasa/shared/core/flows/flow_step_sequence.py +70 -0
  606. rasa/shared/core/flows/flows_list.py +258 -0
  607. rasa/shared/core/flows/flows_yaml_schema.json +303 -0
  608. rasa/shared/core/flows/nlu_trigger.py +117 -0
  609. rasa/shared/core/flows/steps/__init__.py +24 -0
  610. rasa/shared/core/flows/steps/action.py +56 -0
  611. rasa/shared/core/flows/steps/call.py +64 -0
  612. rasa/shared/core/flows/steps/collect.py +112 -0
  613. rasa/shared/core/flows/steps/constants.py +5 -0
  614. rasa/shared/core/flows/steps/continuation.py +36 -0
  615. rasa/shared/core/flows/steps/end.py +22 -0
  616. rasa/shared/core/flows/steps/internal.py +44 -0
  617. rasa/shared/core/flows/steps/link.py +51 -0
  618. rasa/shared/core/flows/steps/no_operation.py +48 -0
  619. rasa/shared/core/flows/steps/set_slots.py +50 -0
  620. rasa/shared/core/flows/steps/start.py +30 -0
  621. rasa/shared/core/flows/utils.py +39 -0
  622. rasa/shared/core/flows/validation.py +735 -0
  623. rasa/shared/core/flows/yaml_flows_io.py +405 -0
  624. rasa/shared/core/generator.py +908 -0
  625. rasa/shared/core/slot_mappings.py +526 -0
  626. rasa/shared/core/slots.py +654 -0
  627. rasa/shared/core/trackers.py +1183 -0
  628. rasa/shared/core/training_data/__init__.py +0 -0
  629. rasa/shared/core/training_data/loading.py +89 -0
  630. rasa/shared/core/training_data/story_reader/__init__.py +0 -0
  631. rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
  632. rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
  633. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
  634. rasa/shared/core/training_data/story_writer/__init__.py +0 -0
  635. rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
  636. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +444 -0
  637. rasa/shared/core/training_data/structures.py +858 -0
  638. rasa/shared/core/training_data/visualization.html +146 -0
  639. rasa/shared/core/training_data/visualization.py +603 -0
  640. rasa/shared/data.py +249 -0
  641. rasa/shared/engine/__init__.py +0 -0
  642. rasa/shared/engine/caching.py +26 -0
  643. rasa/shared/exceptions.py +167 -0
  644. rasa/shared/importers/__init__.py +0 -0
  645. rasa/shared/importers/importer.py +770 -0
  646. rasa/shared/importers/multi_project.py +215 -0
  647. rasa/shared/importers/rasa.py +108 -0
  648. rasa/shared/importers/remote_importer.py +196 -0
  649. rasa/shared/importers/utils.py +36 -0
  650. rasa/shared/nlu/__init__.py +0 -0
  651. rasa/shared/nlu/constants.py +53 -0
  652. rasa/shared/nlu/interpreter.py +10 -0
  653. rasa/shared/nlu/training_data/__init__.py +0 -0
  654. rasa/shared/nlu/training_data/entities_parser.py +208 -0
  655. rasa/shared/nlu/training_data/features.py +492 -0
  656. rasa/shared/nlu/training_data/formats/__init__.py +10 -0
  657. rasa/shared/nlu/training_data/formats/dialogflow.py +163 -0
  658. rasa/shared/nlu/training_data/formats/luis.py +87 -0
  659. rasa/shared/nlu/training_data/formats/rasa.py +135 -0
  660. rasa/shared/nlu/training_data/formats/rasa_yaml.py +618 -0
  661. rasa/shared/nlu/training_data/formats/readerwriter.py +244 -0
  662. rasa/shared/nlu/training_data/formats/wit.py +52 -0
  663. rasa/shared/nlu/training_data/loading.py +137 -0
  664. rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
  665. rasa/shared/nlu/training_data/message.py +490 -0
  666. rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
  667. rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
  668. rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
  669. rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
  670. rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
  671. rasa/shared/nlu/training_data/training_data.py +729 -0
  672. rasa/shared/nlu/training_data/util.py +223 -0
  673. rasa/shared/providers/__init__.py +0 -0
  674. rasa/shared/providers/_configs/__init__.py +0 -0
  675. rasa/shared/providers/_configs/azure_openai_client_config.py +677 -0
  676. rasa/shared/providers/_configs/client_config.py +59 -0
  677. rasa/shared/providers/_configs/default_litellm_client_config.py +132 -0
  678. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +236 -0
  679. rasa/shared/providers/_configs/litellm_router_client_config.py +222 -0
  680. rasa/shared/providers/_configs/model_group_config.py +173 -0
  681. rasa/shared/providers/_configs/openai_client_config.py +177 -0
  682. rasa/shared/providers/_configs/rasa_llm_client_config.py +75 -0
  683. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +178 -0
  684. rasa/shared/providers/_configs/utils.py +117 -0
  685. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  686. rasa/shared/providers/_utils.py +79 -0
  687. rasa/shared/providers/constants.py +7 -0
  688. rasa/shared/providers/embedding/__init__.py +0 -0
  689. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +243 -0
  690. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  691. rasa/shared/providers/embedding/azure_openai_embedding_client.py +335 -0
  692. rasa/shared/providers/embedding/default_litellm_embedding_client.py +126 -0
  693. rasa/shared/providers/embedding/embedding_client.py +90 -0
  694. rasa/shared/providers/embedding/embedding_response.py +41 -0
  695. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  696. rasa/shared/providers/embedding/litellm_router_embedding_client.py +138 -0
  697. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  698. rasa/shared/providers/llm/__init__.py +0 -0
  699. rasa/shared/providers/llm/_base_litellm_client.py +265 -0
  700. rasa/shared/providers/llm/azure_openai_llm_client.py +415 -0
  701. rasa/shared/providers/llm/default_litellm_llm_client.py +110 -0
  702. rasa/shared/providers/llm/litellm_router_llm_client.py +202 -0
  703. rasa/shared/providers/llm/llm_client.py +78 -0
  704. rasa/shared/providers/llm/llm_response.py +50 -0
  705. rasa/shared/providers/llm/openai_llm_client.py +161 -0
  706. rasa/shared/providers/llm/rasa_llm_client.py +120 -0
  707. rasa/shared/providers/llm/self_hosted_llm_client.py +276 -0
  708. rasa/shared/providers/mappings.py +94 -0
  709. rasa/shared/providers/router/__init__.py +0 -0
  710. rasa/shared/providers/router/_base_litellm_router_client.py +185 -0
  711. rasa/shared/providers/router/router_client.py +75 -0
  712. rasa/shared/utils/__init__.py +0 -0
  713. rasa/shared/utils/cli.py +102 -0
  714. rasa/shared/utils/common.py +324 -0
  715. rasa/shared/utils/constants.py +4 -0
  716. rasa/shared/utils/health_check/__init__.py +0 -0
  717. rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
  718. rasa/shared/utils/health_check/health_check.py +258 -0
  719. rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
  720. rasa/shared/utils/io.py +499 -0
  721. rasa/shared/utils/llm.py +764 -0
  722. rasa/shared/utils/pykwalify_extensions.py +27 -0
  723. rasa/shared/utils/schemas/__init__.py +0 -0
  724. rasa/shared/utils/schemas/config.yml +2 -0
  725. rasa/shared/utils/schemas/domain.yml +145 -0
  726. rasa/shared/utils/schemas/events.py +214 -0
  727. rasa/shared/utils/schemas/model_config.yml +36 -0
  728. rasa/shared/utils/schemas/stories.yml +173 -0
  729. rasa/shared/utils/yaml.py +1068 -0
  730. rasa/studio/__init__.py +0 -0
  731. rasa/studio/auth.py +270 -0
  732. rasa/studio/config.py +136 -0
  733. rasa/studio/constants.py +19 -0
  734. rasa/studio/data_handler.py +368 -0
  735. rasa/studio/download.py +489 -0
  736. rasa/studio/results_logger.py +137 -0
  737. rasa/studio/train.py +134 -0
  738. rasa/studio/upload.py +563 -0
  739. rasa/telemetry.py +1876 -0
  740. rasa/tracing/__init__.py +0 -0
  741. rasa/tracing/config.py +355 -0
  742. rasa/tracing/constants.py +62 -0
  743. rasa/tracing/instrumentation/__init__.py +0 -0
  744. rasa/tracing/instrumentation/attribute_extractors.py +765 -0
  745. rasa/tracing/instrumentation/instrumentation.py +1306 -0
  746. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +144 -0
  747. rasa/tracing/instrumentation/metrics.py +294 -0
  748. rasa/tracing/metric_instrument_provider.py +205 -0
  749. rasa/utils/__init__.py +0 -0
  750. rasa/utils/beta.py +83 -0
  751. rasa/utils/cli.py +28 -0
  752. rasa/utils/common.py +639 -0
  753. rasa/utils/converter.py +53 -0
  754. rasa/utils/endpoints.py +331 -0
  755. rasa/utils/io.py +252 -0
  756. rasa/utils/json_utils.py +60 -0
  757. rasa/utils/licensing.py +542 -0
  758. rasa/utils/log_utils.py +181 -0
  759. rasa/utils/mapper.py +210 -0
  760. rasa/utils/ml_utils.py +147 -0
  761. rasa/utils/plotting.py +362 -0
  762. rasa/utils/sanic_error_handler.py +32 -0
  763. rasa/utils/singleton.py +23 -0
  764. rasa/utils/tensorflow/__init__.py +0 -0
  765. rasa/utils/tensorflow/callback.py +112 -0
  766. rasa/utils/tensorflow/constants.py +116 -0
  767. rasa/utils/tensorflow/crf.py +492 -0
  768. rasa/utils/tensorflow/data_generator.py +440 -0
  769. rasa/utils/tensorflow/environment.py +161 -0
  770. rasa/utils/tensorflow/exceptions.py +5 -0
  771. rasa/utils/tensorflow/feature_array.py +366 -0
  772. rasa/utils/tensorflow/layers.py +1565 -0
  773. rasa/utils/tensorflow/layers_utils.py +113 -0
  774. rasa/utils/tensorflow/metrics.py +281 -0
  775. rasa/utils/tensorflow/model_data.py +798 -0
  776. rasa/utils/tensorflow/model_data_utils.py +499 -0
  777. rasa/utils/tensorflow/models.py +935 -0
  778. rasa/utils/tensorflow/rasa_layers.py +1094 -0
  779. rasa/utils/tensorflow/transformer.py +640 -0
  780. rasa/utils/tensorflow/types.py +6 -0
  781. rasa/utils/train_utils.py +572 -0
  782. rasa/utils/url_tools.py +53 -0
  783. rasa/utils/yaml.py +54 -0
  784. rasa/validator.py +1644 -0
  785. rasa/version.py +3 -0
  786. rasa_pro-3.12.0.dev1.dist-info/METADATA +199 -0
  787. rasa_pro-3.12.0.dev1.dist-info/NOTICE +5 -0
  788. rasa_pro-3.12.0.dev1.dist-info/RECORD +790 -0
  789. rasa_pro-3.12.0.dev1.dist-info/WHEEL +4 -0
  790. rasa_pro-3.12.0.dev1.dist-info/entry_points.txt +3 -0
rasa/utils/tensorflow/layers.py
@@ -0,0 +1,1565 @@
+ import logging
+ from typing import List, Optional, Text, Tuple, Callable, Union, Any
+ import tensorflow as tf
+
+ # TODO: The following is not (yet) available via tf.keras
+ from keras.src.utils.control_flow_util import smart_cond
+ import tensorflow.keras.backend as K
+
+ import rasa.utils.tensorflow.crf
+ from rasa.utils.tensorflow.constants import (
+     SOFTMAX,
+     MARGIN,
+     COSINE,
+     INNER,
+     CROSS_ENTROPY,
+     LABEL,
+     LABEL_PAD_ID,
+ )
+ from rasa.core.constants import DIALOGUE
+ from rasa.shared.nlu.constants import FEATURE_TYPE_SENTENCE, FEATURE_TYPE_SEQUENCE
+ from rasa.shared.nlu.constants import TEXT, INTENT, ACTION_NAME, ACTION_TEXT
+
+ from rasa.utils.tensorflow.metrics import F1Score
+ from rasa.utils.tensorflow.exceptions import TFLayerConfigException
+ import rasa.utils.tensorflow.layers_utils as layers_utils
+ from rasa.utils.tensorflow.crf import crf_log_likelihood
+
+ logger = logging.getLogger(__name__)
+
+
+ POSSIBLE_ATTRIBUTES = [
+     TEXT,
+     INTENT,
+     LABEL,
+     DIALOGUE,
+     ACTION_NAME,
+     ACTION_TEXT,
+     f"{LABEL}_{ACTION_NAME}",
+     f"{LABEL}_{ACTION_TEXT}",
+ ]
+
+
+ class SparseDropout(tf.keras.layers.Dropout):
+     """Applies Dropout to the input.
+
+     Dropout consists in randomly setting
+     a fraction `rate` of input units to 0 at each update during training time,
+     which helps prevent overfitting.
+
+     Arguments:
+         rate: Fraction of the input units to drop (between 0 and 1).
+     """
+
+     def call(
+         self, inputs: tf.SparseTensor, training: Optional[Union[tf.Tensor, bool]] = None
+     ) -> tf.SparseTensor:
+         """Apply dropout to sparse inputs.
+
+         Arguments:
+             inputs: Input sparse tensor (of any rank).
+             training: Indicates whether the layer should behave in
+                 training mode (adding dropout) or in inference mode (doing nothing).
+
+         Returns:
+             Output of dropout layer.
+
+         Raises:
+             A ValueError if inputs is not a sparse tensor
+         """
+         if not isinstance(inputs, tf.SparseTensor):
+             raise ValueError("Input tensor should be sparse.")
+
+         if training is None:
+             training = K.learning_phase()
+
+         def dropped_inputs() -> tf.SparseTensor:
+             to_retain_prob = tf.random.uniform(
+                 tf.shape(inputs.values), 0, 1, inputs.values.dtype
+             )
+             to_retain = tf.greater_equal(to_retain_prob, self.rate)
+             return tf.sparse.retain(inputs, to_retain)
+
+         outputs = smart_cond(training, dropped_inputs, lambda: tf.identity(inputs))
+         # need to explicitly recreate sparse tensor, because otherwise the shape
+         # information will be lost after `retain`
+         # noinspection PyProtectedMember
+         return tf.SparseTensor(outputs.indices, outputs.values, inputs._dense_shape)
+
+
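For context, a minimal usage sketch of `SparseDropout` (illustrative only and not part of the package diff; the tensor shapes and values below are assumptions):

import tensorflow as tf
from rasa.utils.tensorflow.layers import SparseDropout

# A 2 x 4 sparse tensor with three non-zero entries (hypothetical data).
sparse_features = tf.SparseTensor(
    indices=[[0, 0], [0, 3], [1, 1]],
    values=[1.0, 2.0, 3.0],
    dense_shape=[2, 4],
)

dropout = SparseDropout(rate=0.5)
# In training mode, each non-zero value is kept with probability 1 - rate.
dropped = dropout(sparse_features, training=True)
# In inference mode, the sparse tensor passes through unchanged.
unchanged = dropout(sparse_features, training=False)
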
+ class DenseForSparse(tf.keras.layers.Dense):
+     """Dense layer for sparse input tensor.
+
+     Just your regular densely-connected NN layer but for sparse tensors.
+
+     `Dense` implements the operation:
+     `output = activation(dot(input, kernel) + bias)`
+     where `activation` is the element-wise activation function
+     passed as the `activation` argument, `kernel` is a weights matrix
+     created by the layer, and `bias` is a bias vector created by the layer
+     (only applicable if `use_bias` is `True`).
+
+     Note: If the input to the layer has a rank greater than 2, then
+     it is flattened prior to the initial dot product with `kernel`.
+
+     Arguments:
+         units: Positive integer, dimensionality of the output space.
+         activation: Activation function to use.
+             If you don't specify anything, no activation is applied
+             (ie. "linear" activation: `a(x) = x`).
+         use_bias: Indicates whether the layer uses a bias vector.
+         kernel_initializer: Initializer for the `kernel` weights matrix.
+         bias_initializer: Initializer for the bias vector.
+         reg_lambda: regularization factor
+         bias_regularizer: Regularizer function applied to the bias vector.
+         activity_regularizer: Regularizer function applied to
+             the output of the layer (its "activation").
+         kernel_constraint: Constraint function applied to
+             the `kernel` weights matrix.
+         bias_constraint: Constraint function applied to the bias vector.
+
+     Input shape:
+         N-D tensor with shape: `(batch_size, ..., input_dim)`.
+         The most common situation would be
+         a 2D input with shape `(batch_size, input_dim)`.
+
+     Output shape:
+         N-D tensor with shape: `(batch_size, ..., units)`.
+         For instance, for a 2D input with shape `(batch_size, input_dim)`,
+         the output would have shape `(batch_size, units)`.
+     """
+
+     def __init__(self, reg_lambda: float = 0, **kwargs: Any) -> None:
+         if reg_lambda > 0:
+             regularizer = tf.keras.regularizers.l2(reg_lambda)
+         else:
+             regularizer = None
+
+         super().__init__(kernel_regularizer=regularizer, **kwargs)
+
+     def get_units(self) -> int:
+         """Returns number of output units."""
+         return self.units
+
+     def get_kernel(self) -> tf.Tensor:
+         """Returns kernel tensor."""
+         return self.kernel
+
+     def get_bias(self) -> Union[tf.Tensor, None]:
+         """Returns bias tensor."""
+         if self.use_bias:
+             return self.bias
+         return None
+
+     def get_feature_type(self) -> Union[Text, None]:
+         """Returns a feature type of the data that's fed to the layer.
+
+         In order to correctly return a feature type, the function heavily relies
+         on the name of `DenseForSparse` layer to contain the feature type.
+         Acceptable values of feature types are `FEATURE_TYPE_SENTENCE`
+         and `FEATURE_TYPE_SEQUENCE`.
+
+         Returns:
+             feature type of dense layer.
+         """
+         for feature_type in [FEATURE_TYPE_SENTENCE, FEATURE_TYPE_SEQUENCE]:
+             if feature_type in self.name:
+                 return feature_type
+         return None
+
+     def get_attribute(self) -> Union[Text, None]:
+         """Returns the attribute for which this layer was constructed.
+
+         For example: TEXT, LABEL, etc.
+
+         In order to correctly return an attribute, the function heavily relies
+         on the name of `DenseForSparse` layer being in the following format:
+         f"sparse_to_dense.{attribute}_{feature_type}".
+
+         Returns:
+             attribute of the layer.
+         """
+         metadata = self.name.split(".")
+         if len(metadata) > 1:
+             attribute_splits = metadata[1].split("_")[:-1]
+             attribute = "_".join(attribute_splits)
+             if attribute in POSSIBLE_ATTRIBUTES:
+                 return attribute
+         return None
+
+     def call(self, inputs: tf.SparseTensor) -> tf.Tensor:
+         """Apply dense layer to sparse inputs.
+
+         Arguments:
+             inputs: Input sparse tensor (of any rank).
+
+         Returns:
+             Output of dense layer.
+
+         Raises:
+             A ValueError if inputs is not a sparse tensor
+         """
+         if not isinstance(inputs, tf.SparseTensor):
+             raise ValueError("Input tensor should be sparse.")
+
+         # outputs will be 2D
+         outputs = tf.sparse.sparse_dense_matmul(
+             tf.sparse.reshape(inputs, [-1, tf.shape(inputs)[-1]]), self.kernel
+         )
+
+         if len(inputs.shape) == 3:
+             # reshape back
+             outputs = tf.reshape(
+                 outputs, (tf.shape(inputs)[0], tf.shape(inputs)[1], self.units)
+             )
+
+         if self.use_bias:
+             outputs = tf.nn.bias_add(outputs, self.bias)
+         if self.activation is not None:
+             return self.activation(outputs)
+         return outputs
+
+
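A brief sketch of how `DenseForSparse` projects sparse features into a dense space (illustrative only; the input values and the layer name are invented here, but the name follows the `sparse_to_dense.{attribute}_{feature_type}` convention that `get_attribute` and `get_feature_type` parse):

import tensorflow as tf
from rasa.utils.tensorflow.layers import DenseForSparse

# Sparse 2 x 10 input, e.g. bag-of-words counts (hypothetical values).
sparse_input = tf.sparse.from_dense(
    tf.constant(
        [
            [0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0],
        ]
    )
)

layer = DenseForSparse(
    units=16, reg_lambda=0.002, name="sparse_to_dense.text_sentence"
)
dense_output = layer(sparse_input)  # dense tensor of shape (2, 16)

layer.get_attribute()     # "text"
layer.get_feature_type()  # "sentence"
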
223
+ class RandomlyConnectedDense(tf.keras.layers.Dense):
224
+ """Layer with dense ouputs that are connected to a random subset of inputs.
225
+
226
+ `RandomlyConnectedDense` implements the operation:
227
+ `output = activation(dot(input, kernel) + bias)`
228
+ where `activation` is the element-wise activation function
229
+ passed as the `activation` argument, `kernel` is a weights matrix
230
+ created by the layer, and `bias` is a bias vector created by the layer
231
+ (only applicable if `use_bias` is `True`).
232
+ It creates `kernel_mask` to set a fraction of the `kernel` weights to zero.
233
+
234
+ Note: If the input to the layer has a rank greater than 2, then
235
+ it is flattened prior to the initial dot product with `kernel`.
236
+
237
+ The output is guaranteed to be dense (each output is connected to at least one
238
+ input), and no input is disconnected (each input is connected to at least one
239
+ output).
240
+
241
+ At `density = 0.0` the number of trainable weights is `max(input_size, units)`. At
242
+ `density = 1.0` this layer is equivalent to `tf.keras.layers.Dense`.
243
+
244
+ Input shape:
245
+ N-D tensor with shape: `(batch_size, ..., input_dim)`.
246
+ The most common situation would be
247
+ a 2D input with shape `(batch_size, input_dim)`.
248
+
249
+ Output shape:
250
+ N-D tensor with shape: `(batch_size, ..., units)`.
251
+ For instance, for a 2D input with shape `(batch_size, input_dim)`,
252
+ the output would have shape `(batch_size, units)`.
253
+ """
254
+
255
+ def __init__(self, density: float = 0.2, **kwargs: Any) -> None:
256
+ """Declares instance variables with default values.
257
+
258
+ Args:
259
+ density: Approximate fraction of trainable weights (between 0 and 1).
260
+ units: Positive integer, dimensionality of the output space.
261
+ activation: Activation function to use.
262
+ If you don't specify anything, no activation is applied
263
+ (ie. "linear" activation: `a(x) = x`).
264
+ use_bias: Indicates whether the layer uses a bias vector.
265
+ kernel_initializer: Initializer for the `kernel` weights matrix.
266
+ bias_initializer: Initializer for the bias vector.
267
+ kernel_regularizer: Regularizer function applied to
268
+ the `kernel` weights matrix.
269
+ bias_regularizer: Regularizer function applied to the bias vector.
270
+ activity_regularizer: Regularizer function applied to
271
+ the output of the layer (its "activation").
272
+ kernel_constraint: Constraint function applied to
273
+ the `kernel` weights matrix.
274
+ bias_constraint: Constraint function applied to the bias vector.
275
+ """
276
+ super().__init__(**kwargs)
277
+
278
+ if density < 0.0 or density > 1.0:
279
+ raise TFLayerConfigException("Layer density must be in [0, 1].")
280
+
281
+ self.density = density
282
+
283
+ def build(self, input_shape: tf.TensorShape) -> None:
284
+ """Prepares the kernel mask.
285
+
286
+ Args:
287
+ input_shape: Shape of the inputs to this layer
288
+ """
289
+ super().build(input_shape)
290
+
291
+ if self.density == 1.0:
292
+ self.kernel_mask = None
293
+ return
294
+
295
+ # Construct mask with given density and guarantee that every output is
296
+ # connected to at least one input
297
+ kernel_mask = self._minimal_mask() + self._random_mask()
298
+
299
+ # We might accidentally have added a random connection on top of
300
+ # a fixed connection
301
+ kernel_mask = tf.clip_by_value(kernel_mask, 0, 1)
302
+
303
+ self.kernel_mask = tf.Variable(
304
+ initial_value=kernel_mask, trainable=False, name="kernel_mask"
305
+ )
306
+
307
+ def _random_mask(self) -> tf.Tensor:
308
+ """Creates a random matrix with `num_ones` 1s and 0s otherwise.
309
+
310
+ Returns:
311
+ A random mask matrix
312
+ """
313
+ mask = tf.random.uniform(tf.shape(self.kernel), 0, 1)
314
+ mask = tf.cast(tf.math.less(mask, self.density), self.kernel.dtype)
315
+ return mask
316
+
317
+ def _minimal_mask(self) -> tf.Tensor:
318
+ """Creates a matrix with a minimal number of 1s to connect everythinig.
319
+
320
+ If num_rows == num_cols, this creates the identity matrix.
321
+ If num_rows > num_cols, this creates
322
+ 1 0 0 0
323
+ 0 1 0 0
324
+ 0 0 1 0
325
+ 0 0 0 1
326
+ 1 0 0 0
327
+ 0 1 0 0
328
+ 0 0 1 0
329
+ . . . .
330
+ . . . .
331
+ . . . .
332
+ If num_rows < num_cols, this creates
333
+ 1 0 0 1 0 0 1 ...
334
+ 0 1 0 0 1 0 0 ...
335
+ 0 0 1 0 0 1 0 ...
336
+
337
+ Returns:
338
+ A tiled and cropped identity matrix.
339
+ """
340
+ kernel_shape = tf.shape(self.kernel)
341
+ num_rows = kernel_shape[0]
342
+ num_cols = kernel_shape[1]
343
+ short_dimension = tf.minimum(num_rows, num_cols)
344
+
345
+ mask = tf.tile(
346
+ tf.eye(short_dimension, dtype=self.kernel.dtype),
347
+ [
348
+ tf.math.ceil(num_rows / short_dimension),
349
+ tf.math.ceil(num_cols / short_dimension),
350
+ ],
351
+ )[:num_rows, :num_cols]
352
+
353
+ return mask
354
+
355
+ def call(self, inputs: tf.Tensor) -> tf.Tensor:
356
+ """Processes the given inputs.
357
+
358
+ Args:
359
+ inputs: What goes into this layer
360
+
361
+ Returns:
362
+ The processed inputs.
363
+ """
364
+ if self.density < 1.0:
365
+ # Set a fraction of the `kernel` weights to zero according to the precomputed mask
366
+ self.kernel.assign(self.kernel * self.kernel_mask)
367
+ return super().call(inputs)
368
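A short sketch of how the layer above could be used (argument values and the module path are illustrative assumptions): at `density=0.2` roughly a fifth of the kernel entries stay connected, and the minimal mask guarantees that no input or output is fully disconnected.

import tensorflow as tf
from rasa.utils.tensorflow.layers import RandomlyConnectedDense  # assumed module path

layer = RandomlyConnectedDense(units=8, density=0.2, activation=tf.nn.relu)
outputs = layer(tf.random.uniform((4, 32)))  # first call builds kernel and kernel_mask
print(outputs.shape)                         # (4, 8)
# Fraction of retained kernel weights: at least `density`, slightly more because
# the minimal connectivity mask is added on top of the random mask.
print(float(tf.reduce_mean(layer.kernel_mask)))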
+
369
+
370
+ class Ffnn(tf.keras.layers.Layer):
371
+ """Feed-forward network layer.
372
+
373
+ Arguments:
374
+ layer_sizes: List of integers with dimensionality of the layers.
375
+ dropout_rate: Fraction of the input units to drop (between 0 and 1).
376
+ reg_lambda: regularization factor.
377
+ density: Approximate fraction of trainable weights (between 0 and 1).
378
+ layer_name_suffix: Text added to the name of the layers.
379
+
380
+ Input shape:
381
+ N-D tensor with shape: `(batch_size, ..., input_dim)`.
382
+ The most common situation would be
383
+ a 2D input with shape `(batch_size, input_dim)`.
384
+
385
+ Output shape:
386
+ N-D tensor with shape: `(batch_size, ..., layer_sizes[-1])`.
387
+ For instance, for a 2D input with shape `(batch_size, input_dim)`,
388
+ the output would have shape `(batch_size, layer_sizes[-1])`.
389
+ """
390
+
391
+ def __init__(
392
+ self,
393
+ layer_sizes: List[int],
394
+ dropout_rate: float,
395
+ reg_lambda: float,
396
+ density: float,
397
+ layer_name_suffix: Text,
398
+ ) -> None:
399
+ super().__init__(name=f"ffnn_{layer_name_suffix}")
400
+
401
+ l2_regularizer = tf.keras.regularizers.l2(reg_lambda)
402
+ self._ffn_layers = []
403
+ for i, layer_size in enumerate(layer_sizes):
404
+ self._ffn_layers.append(
405
+ RandomlyConnectedDense(
406
+ units=layer_size,
407
+ density=density,
408
+ activation=tf.nn.gelu,
409
+ kernel_regularizer=l2_regularizer,
410
+ name=f"hidden_layer_{layer_name_suffix}_{i}",
411
+ )
412
+ )
413
+ self._ffn_layers.append(tf.keras.layers.Dropout(dropout_rate))
414
+
415
+ def call(
416
+ self, x: tf.Tensor, training: Optional[Union[tf.Tensor, bool]] = None
417
+ ) -> tf.Tensor:
418
+ """Apply feed-forward network layer."""
419
+ for layer in self._ffn_layers:
420
+ x = layer(x, training=training)
421
+
422
+ return x
423
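A hypothetical configuration of the feed-forward block above (all argument values are assumptions): two randomly connected hidden layers, each followed by dropout.

import tensorflow as tf
from rasa.utils.tensorflow.layers import Ffnn  # assumed module path

ffnn = Ffnn(
    layer_sizes=[256, 128],
    dropout_rate=0.2,
    reg_lambda=0.002,
    density=0.5,
    layer_name_suffix="text",
)
hidden = ffnn(tf.random.uniform((4, 10, 512)), training=True)
print(hidden.shape)  # (4, 10, 128) -- the last entry of `layer_sizes`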
+
424
+
425
+ class Embed(tf.keras.layers.Layer):
426
+ """Dense embedding layer.
427
+
428
+ Input shape:
429
+ N-D tensor with shape: `(batch_size, ..., input_dim)`.
430
+ The most common situation would be
431
+ a 2D input with shape `(batch_size, input_dim)`.
432
+
433
+ Output shape:
434
+ N-D tensor with shape: `(batch_size, ..., embed_dim)`.
435
+ For instance, for a 2D input with shape `(batch_size, input_dim)`,
436
+ the output would have shape `(batch_size, embed_dim)`.
437
+ """
438
+
439
+ def __init__(
440
+ self, embed_dim: int, reg_lambda: float, layer_name_suffix: Text
441
+ ) -> None:
442
+ """Initialize layer.
443
+
444
+ Args:
445
+ embed_dim: Dimensionality of the output space.
446
+ reg_lambda: Regularization factor.
447
+ layer_name_suffix: Text added to the name of the layers.
448
+ """
449
+ super().__init__(name=f"embed_{layer_name_suffix}")
450
+
451
+ regularizer = tf.keras.regularizers.l2(reg_lambda)
452
+ self._dense = tf.keras.layers.Dense(
453
+ units=embed_dim,
454
+ activation=None,
455
+ kernel_regularizer=regularizer,
456
+ name=f"embed_layer_{layer_name_suffix}",
457
+ )
458
+
459
+ # noinspection PyMethodOverriding
460
+ def call(self, x: tf.Tensor) -> tf.Tensor:
461
+ """Apply dense layer."""
462
+ x = self._dense(x)
463
+ return x
464
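A minimal sketch of the embedding layer (values and module path assumed): a plain linear projection into the embedding space.

import tensorflow as tf
from rasa.utils.tensorflow.layers import Embed  # assumed module path

embed = Embed(embed_dim=20, reg_lambda=0.002, layer_name_suffix="label")
vectors = embed(tf.random.uniform((4, 128)))
print(vectors.shape)  # (4, 20)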
+
465
+
466
+ class InputMask(tf.keras.layers.Layer):
467
+ """The layer that masks 15% of the input.
468
+
469
+ Input shape:
470
+ N-D tensor with shape: `(batch_size, ..., input_dim)`.
471
+ The most common situation would be
472
+ a 2D input with shape `(batch_size, input_dim)`.
473
+
474
+ Output shape:
475
+ N-D tensor with shape: `(batch_size, ..., input_dim)`.
476
+ For instance, for a 2D input with shape `(batch_size, input_dim)`,
477
+ the output would have shape `(batch_size, input_dim)`.
478
+ """
479
+
480
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
481
+ super().__init__(*args, **kwargs)
482
+
483
+ self._masking_prob = 0.85
484
+ self._mask_vector_prob = 0.7
485
+ self._random_vector_prob = 0.1
486
+
487
+ def build(self, input_shape: tf.TensorShape) -> None:
488
+ self.mask_vector = self.add_weight(
489
+ shape=(1, 1, input_shape[-1]), name="mask_vector"
490
+ )
491
+ self.built = True
492
+
493
+ # noinspection PyMethodOverriding
494
+ def call(
495
+ self,
496
+ x: tf.Tensor,
497
+ mask: tf.Tensor,
498
+ training: Optional[Union[tf.Tensor, bool]] = None,
499
+ ) -> Tuple[tf.Tensor, tf.Tensor]:
500
+ """Randomly mask input sequences.
501
+
502
+ Arguments:
503
+ x: Input sequence tensor of rank 3.
504
+ mask: A tensor representing sequence mask,
505
+ contains `1` for inputs and `0` for padding.
506
+ training: Indicates whether the layer should run in
507
+ training mode (mask inputs) or in inference mode (doing nothing).
508
+
509
+ Returns:
510
+ A tuple of masked inputs and boolean mask.
511
+ """
512
+ if training is None:
513
+ training = K.learning_phase()
514
+
515
+ lm_mask_prob = tf.random.uniform(tf.shape(mask), 0, 1, mask.dtype) * mask
516
+ lm_mask_bool = tf.greater_equal(lm_mask_prob, self._masking_prob)
517
+
518
+ def x_masked() -> tf.Tensor:
519
+ x_random_pad = tf.random.uniform(
520
+ tf.shape(x), tf.reduce_min(x), tf.reduce_max(x), x.dtype
521
+ ) * (1 - mask)
522
+ # shuffle over batch dim
523
+ x_shuffle = tf.random.shuffle(x * mask + x_random_pad)
524
+
525
+ # shuffle over sequence dim
526
+ x_shuffle = tf.transpose(x_shuffle, [1, 0, 2])
527
+ x_shuffle = tf.random.shuffle(x_shuffle)
528
+ x_shuffle = tf.transpose(x_shuffle, [1, 0, 2])
529
+
530
+ # shuffle doesn't support backprop
531
+ x_shuffle = tf.stop_gradient(x_shuffle)
532
+
533
+ mask_vector = tf.tile(self.mask_vector, (tf.shape(x)[0], tf.shape(x)[1], 1))
534
+
535
+ other_prob = tf.random.uniform(tf.shape(mask), 0, 1, mask.dtype)
536
+ other_prob = tf.tile(other_prob, (1, 1, x.shape[-1]))
537
+ x_other = tf.where(
538
+ other_prob < self._mask_vector_prob,
539
+ mask_vector,
540
+ tf.where(
541
+ other_prob < self._mask_vector_prob + self._random_vector_prob,
542
+ x_shuffle,
543
+ x,
544
+ ),
545
+ )
546
+
547
+ return tf.where(tf.tile(lm_mask_bool, (1, 1, x.shape[-1])), x_other, x)
548
+
549
+ return (smart_cond(training, x_masked, lambda: tf.identity(x)), lm_mask_bool)
550
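A sketch of masking a batch of sequence features (shapes and module path are assumptions): roughly 15% of the real, non-padding positions are replaced, and the returned boolean mask marks which positions were touched, so a masked-language-model loss can be computed only there.

import tensorflow as tf
from rasa.utils.tensorflow.layers import InputMask  # assumed module path

input_mask = InputMask()
x = tf.random.uniform((2, 6, 32))  # (batch, sequence, features)
mask = tf.ones((2, 6, 1))          # 1 = real token, 0 = padding
x_masked, lm_mask_bool = input_mask(x, mask, training=True)
print(x_masked.shape, lm_mask_bool.shape)  # (2, 6, 32) (2, 6, 1)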
+
551
+
552
+ def _scale_loss(log_likelihood: tf.Tensor) -> tf.Tensor:
553
+ """Creates scaling loss coefficient depending on the prediction probability.
554
+
555
+ Arguments:
556
+ log_likelihood: a tensor, log-likelihood of prediction
557
+
558
+ Returns:
559
+ Scaling tensor.
560
+ """
561
+ p = tf.math.exp(log_likelihood)
562
+ # only scale loss if some examples are already learned
563
+ return tf.cond(
564
+ tf.reduce_max(p) > 0.5,
565
+ lambda: tf.stop_gradient(tf.pow((1 - p) / 0.5, 4)),
566
+ lambda: tf.ones_like(p),
567
+ )
568
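A worked example of the scaling rule (input values chosen for illustration, module path assumed): once the best prediction probability in the batch exceeds 0.5, each example is re-weighted by `((1 - p) / 0.5) ** 4`, so well-predicted examples contribute less and poorly predicted ones more.

import tensorflow as tf
from rasa.utils.tensorflow.layers import _scale_loss  # assumed module path

log_likelihood = tf.math.log(tf.constant([0.9, 0.6, 0.2]))
print(_scale_loss(log_likelihood).numpy())  # approx. [0.0016, 0.4096, 6.5536]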
+
569
+
570
+ class CRF(tf.keras.layers.Layer):
571
+ """CRF layer.
572
+
573
+ Arguments:
574
+ num_tags: Positive integer, number of tags.
575
+ reg_lambda: regularization factor.
576
+ name: Optional name of the layer.
577
+ """
578
+
579
+ def __init__(
580
+ self,
581
+ num_tags: int,
582
+ reg_lambda: float,
583
+ scale_loss: bool,
584
+ name: Optional[Text] = None,
585
+ ) -> None:
586
+ super().__init__(name=name)
587
+ self.num_tags = num_tags
588
+ self.scale_loss = scale_loss
589
+ self.transition_regularizer = tf.keras.regularizers.l2(reg_lambda)
590
+ self.f1_score_metric = F1Score(
591
+ num_classes=num_tags - 1, # `0` prediction is not a prediction
592
+ average="micro",
593
+ )
594
+
595
+ def build(self, input_shape: tf.TensorShape) -> None:
596
+ # the weights should be created in `build` to apply random_seed
597
+ self.transition_params = self.add_weight(
598
+ shape=(self.num_tags, self.num_tags),
599
+ regularizer=self.transition_regularizer,
600
+ name="transitions",
601
+ )
602
+ self.built = True
603
+
604
+ # noinspection PyMethodOverriding
605
+ def call(
606
+ self, logits: tf.Tensor, sequence_lengths: tf.Tensor
607
+ ) -> Tuple[tf.Tensor, tf.Tensor]:
608
+ """Decodes the highest scoring sequence of tags.
609
+
610
+ Arguments:
611
+ logits: A [batch_size, max_seq_len, num_tags] tensor of
612
+ unary potentials.
613
+ sequence_lengths: A [batch_size] vector of true sequence lengths.
614
+
615
+ Returns:
616
+ A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
617
+ Contains the highest scoring tag indices.
618
+ A [batch_size, max_seq_len] matrix, with dtype `tf.float32`.
619
+ Contains the confidence values of the highest scoring tag indices.
620
+ """
621
+ predicted_ids, scores, _ = rasa.utils.tensorflow.crf.crf_decode(
622
+ logits, self.transition_params, sequence_lengths
623
+ )
624
+ # set prediction index for padding to `0`
625
+ mask = tf.sequence_mask(
626
+ sequence_lengths,
627
+ maxlen=tf.shape(predicted_ids)[1],
628
+ dtype=predicted_ids.dtype,
629
+ )
630
+
631
+ confidence_values = scores * tf.cast(mask, tf.float32)
632
+ predicted_ids = predicted_ids * mask
633
+
634
+ return predicted_ids, confidence_values
635
+
636
+ def loss(
637
+ self, logits: tf.Tensor, tag_indices: tf.Tensor, sequence_lengths: tf.Tensor
638
+ ) -> tf.Tensor:
639
+ """Computes the log-likelihood of tag sequences in a CRF.
640
+
641
+ Arguments:
642
+ logits: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
643
+ to use as input to the CRF layer.
644
+ tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which
645
+ we compute the log-likelihood.
646
+ sequence_lengths: A [batch_size] vector of true sequence lengths.
647
+
648
+ Returns:
649
+ Negative mean log-likelihood of all examples,
650
+ given the sequence of tag indices.
651
+ """
652
+ log_likelihood, _ = crf_log_likelihood(
653
+ logits, tag_indices, sequence_lengths, self.transition_params
654
+ )
655
+ loss = -log_likelihood
656
+ if self.scale_loss:
657
+ loss *= _scale_loss(log_likelihood)
658
+
659
+ return tf.reduce_mean(loss)
660
+
661
+ def f1_score(
662
+ self, tag_ids: tf.Tensor, pred_ids: tf.Tensor, mask: tf.Tensor
663
+ ) -> tf.Tensor:
664
+ """Calculates f1 score for train predictions."""
665
+ mask_bool = tf.cast(mask[:, :, 0], tf.bool)
666
+
667
+ # pick only non padding values and flatten sequences
668
+ tag_ids_flat = tf.boolean_mask(tag_ids, mask_bool)
669
+ pred_ids_flat = tf.boolean_mask(pred_ids, mask_bool)
670
+
671
+ # the `0` index is not a real prediction, so exclude it
672
+ num_tags = self.num_tags - 1
673
+
674
+ tag_ids_flat_one_hot = tf.one_hot(tag_ids_flat - 1, num_tags)
675
+ pred_ids_flat_one_hot = tf.one_hot(pred_ids_flat - 1, num_tags)
676
+
677
+ return self.f1_score_metric(tag_ids_flat_one_hot, pred_ids_flat_one_hot)
678
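A sketch of decoding and training with the CRF layer above (batch size, sequence length, tag values, and the module path are assumptions; tag `0` is treated as padding, as in the code above).

import tensorflow as tf
from rasa.utils.tensorflow.layers import CRF  # assumed module path

crf = CRF(num_tags=4, reg_lambda=0.002, scale_loss=False)
logits = tf.random.uniform((2, 5, 4))        # (batch, max_seq_len, num_tags)
sequence_lengths = tf.constant([5, 3])
pred_ids, confidences = crf(logits, sequence_lengths)
tag_ids = tf.constant([[1, 2, 1, 3, 1],
                       [2, 2, 1, 0, 0]])     # 0 marks padding positions
loss = crf.loss(logits, tag_ids, sequence_lengths)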
+
679
+
680
+ class DotProductLoss(tf.keras.layers.Layer):
681
+ """Abstract dot-product loss layer class.
682
+
683
+ Idea based on StarSpace paper: http://arxiv.org/abs/1709.03856
684
+
685
+ Implements similarity methods
686
+ * `sim` (computes a similarity between vectors)
687
+ * `get_similarities_and_confidences_from_embeddings` (calls `sim` and also computes
688
+ confidence values)
689
+
690
+ Specific loss functions (single- or multi-label) must be implemented in child
691
+ classes.
692
+ """
693
+
694
+ def __init__(
695
+ self,
696
+ num_candidates: int,
697
+ scale_loss: bool = False,
698
+ constrain_similarities: bool = True,
699
+ model_confidence: Text = SOFTMAX,
700
+ similarity_type: Text = INNER,
701
+ name: Optional[Text] = None,
702
+ **kwargs: Any,
703
+ ):
704
+ """Declares instance variables with default values.
705
+
706
+ Args:
707
+ num_candidates: Number of labels besides the positive one. Depending on
708
+ whether single- or multi-label loss is implemented (done in
709
+ sub-classes), these can be all negative example labels, or a mixture of
710
+ negative and further positive labels, respectively.
711
+ scale_loss: Boolean, if `True` scale loss inverse proportionally to
712
+ the confidence of the correct prediction.
713
+ constrain_similarities: Boolean, if `True` applies sigmoid on all
714
+ similarity terms and adds to the loss function to
715
+ ensure that similarity values are approximately bounded.
716
+ Used inside _loss_cross_entropy() only.
717
+ model_confidence: Normalization of confidence values during inference.
718
+ Currently, the only possible value is `SOFTMAX`.
719
+ similarity_type: Similarity measure to use, either `cosine` or `inner`.
720
+ name: Optional name of the layer.
721
+
722
+ Raises:
723
+ TFLayerConfigException: When `similarity_type` is not one of `COSINE` or
724
+ `INNER`.
725
+ """
726
+ super().__init__(name=name)
727
+ self.num_neg = num_candidates
728
+ self.scale_loss = scale_loss
729
+ self.constrain_similarities = constrain_similarities
730
+ self.model_confidence = model_confidence
731
+ self.similarity_type = similarity_type
732
+ if self.similarity_type not in {COSINE, INNER}:
733
+ raise TFLayerConfigException(
734
+ f"Unsupported similarity type '{self.similarity_type}', "
735
+ f"should be '{COSINE}' or '{INNER}'."
736
+ )
737
+
738
+ def sim(
739
+ self, a: tf.Tensor, b: tf.Tensor, mask: Optional[tf.Tensor] = None
740
+ ) -> tf.Tensor:
741
+ """Calculates similarity between `a` and `b`.
742
+
743
+ Operates on the last dimension. When `a` and `b` are vectors, then `sim`
744
+ computes either the dot-product, or the cosine of the angle between `a` and `b`,
745
+ depending on `self.similarity_type`.
746
+ Specifically, when the similarity type is `INNER`, then we compute the scalar
747
+ product `a . b`. When the similarity type is `COSINE`, we compute
748
+ `a . b / (|a| |b|)`, i.e. the cosine of the angle between `a` and `b`.
749
+
750
+ Args:
751
+ a: Any float tensor
752
+ b: Any tensor of the same shape and type as `a`
753
+ mask: Mask (should contain 1s for inputs and 0s for padding). Note that
754
+ `len(mask.shape) == len(a.shape) - 1` should hold.
755
+
756
+ Returns:
757
+ Similarities between vectors in `a` and `b`.
758
+ """
759
+ if self.similarity_type == COSINE:
760
+ a = tf.nn.l2_normalize(a, axis=-1)
761
+ b = tf.nn.l2_normalize(b, axis=-1)
762
+ sim = tf.reduce_sum(a * b, axis=-1)
763
+ if mask is not None:
764
+ sim *= tf.expand_dims(mask, 2)
765
+
766
+ return sim
767
+
768
+ def get_similarities_and_confidences_from_embeddings(
769
+ self,
770
+ input_embeddings: tf.Tensor,
771
+ label_embeddings: tf.Tensor,
772
+ mask: Optional[tf.Tensor] = None,
773
+ ) -> Tuple[tf.Tensor, tf.Tensor]:
774
+ """Computes similary between input and label embeddings and model's confidence.
775
+
776
+ First compute the similarity from embeddings and then apply an activation
777
+ function if needed to get the confidence.
778
+
779
+ Args:
780
+ input_embeddings: Embeddings of input.
781
+ label_embeddings: Embeddings of labels.
782
+ mask: Mask (should contain 1s for inputs and 0s for padding). Note that
783
+ `len(mask.shape) == len(input_embeddings.shape) - 1` should hold.
784
+
785
+ Returns:
786
+ similarity between input and label embeddings and model's prediction
787
+ confidence for each label.
788
+ """
789
+ similarities = self.sim(input_embeddings, label_embeddings, mask)
790
+ confidences = similarities
791
+ if self.model_confidence == SOFTMAX:
792
+ confidences = tf.nn.softmax(similarities)
793
+ return similarities, confidences
794
+
795
+ def call(self, *args: Any, **kwargs: Any) -> Tuple[tf.Tensor, tf.Tensor]:
796
+ """Layer's logic - to be implemented in child class."""
797
+ raise NotImplementedError
798
+
799
+ def apply_mask_and_scaling(
800
+ self, loss: tf.Tensor, mask: Optional[tf.Tensor]
801
+ ) -> tf.Tensor:
802
+ """Scales the loss and applies the mask if necessary.
803
+
804
+ Args:
805
+ loss: The loss tensor
806
+ mask: (Optional) A mask to multiply with the loss
807
+
808
+ Returns:
809
+ The scaled loss, potentially averaged over the sequence
810
+ dimension.
811
+ """
812
+ if self.scale_loss:
813
+ # in case of cross entropy log_likelihood = -loss
814
+ loss *= _scale_loss(-loss)
815
+
816
+ if mask is not None:
817
+ loss *= mask
818
+
819
+ if len(loss.shape) == 2:
820
+ # average over the sequence
821
+ if mask is not None:
822
+ loss = tf.reduce_sum(loss, axis=-1) / tf.reduce_sum(mask, axis=-1)
823
+ else:
824
+ loss = tf.reduce_mean(loss, axis=-1)
825
+
826
+ return loss
827
+
828
+
829
+ class SingleLabelDotProductLoss(DotProductLoss):
830
+ """Single-label dot-product loss layer.
831
+
832
+ This loss layer assumes that only one output (label) is correct for any given input.
833
+ """
834
+
835
+ def __init__(
836
+ self,
837
+ num_candidates: int,
838
+ scale_loss: bool = False,
839
+ constrain_similarities: bool = True,
840
+ model_confidence: Text = SOFTMAX,
841
+ similarity_type: Text = INNER,
842
+ name: Optional[Text] = None,
843
+ loss_type: Text = CROSS_ENTROPY,
844
+ mu_pos: float = 0.8,
845
+ mu_neg: float = -0.2,
846
+ use_max_sim_neg: bool = True,
847
+ neg_lambda: float = 0.5,
848
+ same_sampling: bool = False,
849
+ **kwargs: Any,
850
+ ) -> None:
851
+ """Declares instance variables with default values.
852
+
853
+ Args:
854
+ num_candidates: Positive integer, the number of incorrect labels;
855
+ the algorithm will minimize their similarity to the input.
856
+ loss_type: The type of the loss function, either `cross_entropy` or
857
+ `margin`.
858
+ mu_pos: Indicates how similar the algorithm should
859
+ try to make embedding vectors for correct labels;
860
+ should be 0.0 < ... < 1.0 for `cosine` similarity type.
861
+ mu_neg: Maximum negative similarity for incorrect labels,
862
+ should be -1.0 < ... < 1.0 for `cosine` similarity type.
863
+ use_max_sim_neg: If `True` the algorithm only minimizes
864
+ maximum similarity over incorrect intent labels,
865
+ used only if `loss_type` is set to `margin`.
866
+ neg_lambda: The scale of how important it is to minimize
867
+ the maximum similarity between embeddings of different labels,
868
+ used only if `loss_type` is set to `margin`.
869
+ scale_loss: If `True` scale loss inverse proportionally to
870
+ the confidence of the correct prediction.
871
+ similarity_type: Similarity measure to use, either `cosine` or `inner`.
872
+ name: Optional name of the layer.
873
+ same_sampling: If `True` sample same negative labels
874
+ for the whole batch.
875
+ constrain_similarities: If `True` and loss_type is `cross_entropy`, a
876
+ sigmoid loss term is added to the total loss to ensure that similarity
877
+ values are approximately bounded.
878
+ model_confidence: Normalization of confidence values during inference.
879
+ Currently, the only possible value is `SOFTMAX`.
880
+ """
881
+ super().__init__(
882
+ num_candidates,
883
+ scale_loss=scale_loss,
884
+ constrain_similarities=constrain_similarities,
885
+ model_confidence=model_confidence,
886
+ similarity_type=similarity_type,
887
+ name=name,
888
+ )
889
+ self.loss_type = loss_type
890
+ self.mu_pos = mu_pos
891
+ self.mu_neg = mu_neg
892
+ self.use_max_sim_neg = use_max_sim_neg
893
+ self.neg_lambda = neg_lambda
894
+ self.same_sampling = same_sampling
895
+
896
+ def _get_bad_mask(
897
+ self, labels: tf.Tensor, target_labels: tf.Tensor, idxs: tf.Tensor
898
+ ) -> tf.Tensor:
899
+ """Calculate bad mask for given indices.
900
+
901
+ Checks that input features are different for positive and negative samples.
902
+ """
903
+ pos_labels = tf.expand_dims(target_labels, axis=-2)
904
+ neg_labels = layers_utils.get_candidate_values(labels, idxs)
905
+
906
+ return tf.cast(
907
+ tf.reduce_all(tf.equal(neg_labels, pos_labels), axis=-1), pos_labels.dtype
908
+ )
909
+
910
+ def _get_negs(
911
+ self, embeds: tf.Tensor, labels: tf.Tensor, target_labels: tf.Tensor
912
+ ) -> Tuple[tf.Tensor, tf.Tensor]:
913
+ """Gets negative examples from given tensor."""
914
+ embeds_flat = layers_utils.batch_flatten(embeds)
915
+ labels_flat = layers_utils.batch_flatten(labels)
916
+ target_labels_flat = layers_utils.batch_flatten(target_labels)
917
+
918
+ total_candidates = tf.shape(embeds_flat)[0]
919
+ target_size = tf.shape(target_labels_flat)[0]
920
+
921
+ neg_ids = layers_utils.random_indices(
922
+ target_size, self.num_neg, total_candidates
923
+ )
924
+
925
+ neg_embeds = layers_utils.get_candidate_values(embeds_flat, neg_ids)
926
+ bad_negs = self._get_bad_mask(labels_flat, target_labels_flat, neg_ids)
927
+
928
+ # check if inputs have sequence dimension
929
+ if len(target_labels.shape) == 3:
930
+ # tensors were flattened for sampling, reshape back
931
+ # add sequence dimension if it was present in the inputs
932
+ target_shape = tf.shape(target_labels)
933
+ neg_embeds = tf.reshape(
934
+ neg_embeds, (target_shape[0], target_shape[1], -1, embeds.shape[-1])
935
+ )
936
+ bad_negs = tf.reshape(bad_negs, (target_shape[0], target_shape[1], -1))
937
+
938
+ return neg_embeds, bad_negs
939
+
940
+ def _sample_negatives(
941
+ self,
942
+ inputs_embed: tf.Tensor,
943
+ labels_embed: tf.Tensor,
944
+ labels: tf.Tensor,
945
+ all_labels_embed: tf.Tensor,
946
+ all_labels: tf.Tensor,
947
+ ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
948
+ """Sample negative examples."""
949
+ pos_inputs_embed = tf.expand_dims(inputs_embed, axis=-2)
950
+ pos_labels_embed = tf.expand_dims(labels_embed, axis=-2)
951
+
952
+ # sample negative inputs
953
+ neg_inputs_embed, inputs_bad_negs = self._get_negs(inputs_embed, labels, labels)
954
+ # sample negative labels
955
+ neg_labels_embed, labels_bad_negs = self._get_negs(
956
+ all_labels_embed, all_labels, labels
957
+ )
958
+ return (
959
+ pos_inputs_embed,
960
+ pos_labels_embed,
961
+ neg_inputs_embed,
962
+ neg_labels_embed,
963
+ inputs_bad_negs,
964
+ labels_bad_negs,
965
+ )
966
+
967
+ def _train_sim(
968
+ self,
969
+ pos_inputs_embed: tf.Tensor,
970
+ pos_labels_embed: tf.Tensor,
971
+ neg_inputs_embed: tf.Tensor,
972
+ neg_labels_embed: tf.Tensor,
973
+ inputs_bad_negs: tf.Tensor,
974
+ labels_bad_negs: tf.Tensor,
975
+ mask: Optional[tf.Tensor],
976
+ ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
977
+ """Define similarity."""
978
+ # calculate similarity with several
979
+ # embedded actions for the loss
980
+ neg_inf = tf.constant(-1e9)
981
+
982
+ sim_pos = self.sim(pos_inputs_embed, pos_labels_embed, mask)
983
+ sim_neg_il = (
984
+ self.sim(pos_inputs_embed, neg_labels_embed, mask)
985
+ + neg_inf * labels_bad_negs
986
+ )
987
+ sim_neg_ll = (
988
+ self.sim(pos_labels_embed, neg_labels_embed, mask)
989
+ + neg_inf * labels_bad_negs
990
+ )
991
+ sim_neg_ii = (
992
+ self.sim(pos_inputs_embed, neg_inputs_embed, mask)
993
+ + neg_inf * inputs_bad_negs
994
+ )
995
+ sim_neg_li = (
996
+ self.sim(pos_labels_embed, neg_inputs_embed, mask)
997
+ + neg_inf * inputs_bad_negs
998
+ )
999
+
1000
+ # output similarities between user input and bot actions
1001
+ # as well as similarities among bot actions and among user inputs
1002
+ return sim_pos, sim_neg_il, sim_neg_ll, sim_neg_ii, sim_neg_li
1003
+
1004
+ @staticmethod
1005
+ def _calc_accuracy(sim_pos: tf.Tensor, sim_neg: tf.Tensor) -> tf.Tensor:
1006
+ """Calculate accuracy."""
1007
+ max_all_sim = tf.reduce_max(tf.concat([sim_pos, sim_neg], axis=-1), axis=-1)
1008
+ sim_pos = tf.squeeze(sim_pos, axis=-1)
1009
+ return layers_utils.reduce_mean_equal(max_all_sim, sim_pos)
1010
+
1011
+ def _loss_margin(
1012
+ self,
1013
+ sim_pos: tf.Tensor,
1014
+ sim_neg_il: tf.Tensor,
1015
+ sim_neg_ll: tf.Tensor,
1016
+ sim_neg_ii: tf.Tensor,
1017
+ sim_neg_li: tf.Tensor,
1018
+ mask: Optional[tf.Tensor],
1019
+ ) -> tf.Tensor:
1020
+ """Define max margin loss."""
1021
+ # loss for maximizing similarity with correct action
1022
+ loss = tf.maximum(0.0, self.mu_pos - tf.squeeze(sim_pos, axis=-1))
1023
+
1024
+ # loss for minimizing similarity with `num_neg` incorrect actions
1025
+ if self.use_max_sim_neg:
1026
+ # minimize only maximum similarity over incorrect actions
1027
+ max_sim_neg_il = tf.reduce_max(sim_neg_il, axis=-1)
1028
+ loss += tf.maximum(0.0, self.mu_neg + max_sim_neg_il)
1029
+ else:
1030
+ # minimize all similarities with incorrect actions
1031
+ max_margin = tf.maximum(0.0, self.mu_neg + sim_neg_il)
1032
+ loss += tf.reduce_sum(max_margin, axis=-1)
1033
+
1034
+ # penalize max similarity between pos bot and neg bot embeddings
1035
+ max_sim_neg_ll = tf.maximum(
1036
+ 0.0, self.mu_neg + tf.reduce_max(sim_neg_ll, axis=-1)
1037
+ )
1038
+ loss += max_sim_neg_ll * self.neg_lambda
1039
+
1040
+ # penalize max similarity between pos dial and neg dial embeddings
1041
+ max_sim_neg_ii = tf.maximum(
1042
+ 0.0, self.mu_neg + tf.reduce_max(sim_neg_ii, axis=-1)
1043
+ )
1044
+ loss += max_sim_neg_ii * self.neg_lambda
1045
+
1046
+ # penalize max similarity between pos bot and neg dial embeddings
1047
+ max_sim_neg_li = tf.maximum(
1048
+ 0.0, self.mu_neg + tf.reduce_max(sim_neg_li, axis=-1)
1049
+ )
1050
+ loss += max_sim_neg_li * self.neg_lambda
1051
+
1052
+ if mask is not None:
1053
+ # mask loss for different length sequences
1054
+ loss *= mask
1055
+ # average the loss over sequence length
1056
+ loss = tf.reduce_sum(loss, axis=-1) / tf.reduce_sum(mask, axis=1)
1057
+
1058
+ # average the loss over the batch
1059
+ loss = tf.reduce_mean(loss)
1060
+
1061
+ return loss
1062
+
1063
+ def _loss_cross_entropy(
1064
+ self,
1065
+ sim_pos: tf.Tensor,
1066
+ sim_neg_il: tf.Tensor,
1067
+ sim_neg_ll: tf.Tensor,
1068
+ sim_neg_ii: tf.Tensor,
1069
+ sim_neg_li: tf.Tensor,
1070
+ mask: Optional[tf.Tensor],
1071
+ ) -> tf.Tensor:
1072
+ """Defines cross entropy loss."""
1073
+ loss = self._compute_softmax_loss(
1074
+ sim_pos, sim_neg_il, sim_neg_ll, sim_neg_ii, sim_neg_li
1075
+ )
1076
+
1077
+ if self.constrain_similarities:
1078
+ loss += self._compute_sigmoid_loss(
1079
+ sim_pos, sim_neg_il, sim_neg_ll, sim_neg_ii, sim_neg_li
1080
+ )
1081
+
1082
+ loss = self.apply_mask_and_scaling(loss, mask)
1083
+
1084
+ # average the loss over the batch
1085
+ return tf.reduce_mean(loss)
1086
+
1087
+ @staticmethod
1088
+ def _compute_sigmoid_loss(
1089
+ sim_pos: tf.Tensor,
1090
+ sim_neg_il: tf.Tensor,
1091
+ sim_neg_ll: tf.Tensor,
1092
+ sim_neg_ii: tf.Tensor,
1093
+ sim_neg_li: tf.Tensor,
1094
+ ) -> tf.Tensor:
1095
+ # Constrain similarity values in a range by applying sigmoid
1096
+ # on them individually so that they saturate at extreme values.
1097
+ sigmoid_logits = tf.concat(
1098
+ [sim_pos, sim_neg_il, sim_neg_ll, sim_neg_ii, sim_neg_li], axis=-1
1099
+ )
1100
+ sigmoid_labels = tf.concat(
1101
+ [
1102
+ tf.ones_like(sigmoid_logits[..., :1]),
1103
+ tf.zeros_like(sigmoid_logits[..., 1:]),
1104
+ ],
1105
+ axis=-1,
1106
+ )
1107
+ sigmoid_loss = tf.nn.sigmoid_cross_entropy_with_logits(
1108
+ labels=sigmoid_labels, logits=sigmoid_logits
1109
+ )
1110
+ # average over logits axis
1111
+ return tf.reduce_mean(sigmoid_loss, axis=-1)
1112
+
1113
+ def _compute_softmax_loss(
1114
+ self,
1115
+ sim_pos: tf.Tensor,
1116
+ sim_neg_il: tf.Tensor,
1117
+ sim_neg_ll: tf.Tensor,
1118
+ sim_neg_ii: tf.Tensor,
1119
+ sim_neg_li: tf.Tensor,
1120
+ ) -> tf.Tensor:
1121
+ # Similarity terms between input and label should be optimized relative
1122
+ # to each other and hence use them as logits for softmax term
1123
+ softmax_logits = tf.concat([sim_pos, sim_neg_il, sim_neg_li], axis=-1)
1124
+ if not self.constrain_similarities:
1125
+ # Concatenate other similarity terms as well. Due to this,
1126
+ # similarity values between input and label may not be
1127
+ # approximately bounded in a defined range.
1128
+ softmax_logits = tf.concat(
1129
+ [softmax_logits, sim_neg_ii, sim_neg_ll], axis=-1
1130
+ )
1131
+ # create label_ids for softmax
1132
+ softmax_label_ids = tf.zeros_like(softmax_logits[..., 0], tf.int32)
1133
+ softmax_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
1134
+ labels=softmax_label_ids, logits=softmax_logits
1135
+ )
1136
+ return softmax_loss
1137
+
1138
+ @property
1139
+ def _chosen_loss(self) -> Callable:
1140
+ """Use loss depending on given option."""
1141
+ if self.loss_type == MARGIN:
1142
+ return self._loss_margin
1143
+ elif self.loss_type == CROSS_ENTROPY:
1144
+ return self._loss_cross_entropy
1145
+ else:
1146
+ raise TFLayerConfigException(
1147
+ f"Wrong loss type '{self.loss_type}', "
1148
+ f"should be '{MARGIN}' or '{CROSS_ENTROPY}'"
1149
+ )
1150
+
1151
+ # noinspection PyMethodOverriding
1152
+ def call(
1153
+ self,
1154
+ inputs_embed: tf.Tensor,
1155
+ labels_embed: tf.Tensor,
1156
+ labels: tf.Tensor,
1157
+ all_labels_embed: tf.Tensor,
1158
+ all_labels: tf.Tensor,
1159
+ mask: Optional[tf.Tensor] = None,
1160
+ ) -> Tuple[tf.Tensor, tf.Tensor]:
1161
+ """Calculate loss and accuracy.
1162
+
1163
+ Args:
1164
+ inputs_embed: Embedding tensor for the batch inputs;
1165
+ shape `(batch_size, ..., num_features)`
1166
+ labels_embed: Embedding tensor for the batch labels;
1167
+ shape `(batch_size, ..., num_features)`
1168
+ labels: Tensor representing batch labels; shape `(batch_size, ..., 1)`
1169
+ all_labels_embed: Embedding tensor for all labels;
1170
+ shape `(num_labels, num_features)`
1171
+ all_labels: Tensor representing all labels; shape `(num_labels, 1)`
1172
+ mask: Optional mask, contains `1` for inputs and `0` for padding;
1173
+ shape `(batch_size, 1)`
1174
+
1175
+ Returns:
1176
+ loss: Total loss.
1177
+ accuracy: Training accuracy.
1178
+ """
1179
+ (
1180
+ pos_inputs_embed,
1181
+ pos_labels_embed,
1182
+ neg_inputs_embed,
1183
+ neg_labels_embed,
1184
+ inputs_bad_negs,
1185
+ labels_bad_negs,
1186
+ ) = self._sample_negatives(
1187
+ inputs_embed, labels_embed, labels, all_labels_embed, all_labels
1188
+ )
1189
+
1190
+ # calculate similarities
1191
+ sim_pos, sim_neg_il, sim_neg_ll, sim_neg_ii, sim_neg_li = self._train_sim(
1192
+ pos_inputs_embed,
1193
+ pos_labels_embed,
1194
+ neg_inputs_embed,
1195
+ neg_labels_embed,
1196
+ inputs_bad_negs,
1197
+ labels_bad_negs,
1198
+ mask,
1199
+ )
1200
+
1201
+ accuracy = self._calc_accuracy(sim_pos, sim_neg_il)
1202
+
1203
+ loss = self._chosen_loss(
1204
+ sim_pos, sim_neg_il, sim_neg_ll, sim_neg_ii, sim_neg_li, mask
1205
+ )
1206
+
1207
+ return loss, accuracy
1208
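A sketch of a full single-label loss computation (embedding sizes, label ids, number of candidates, and the module path are assumptions): each input has exactly one correct label, and three negative candidates are sampled per example.

import tensorflow as tf
from rasa.utils.tensorflow.layers import SingleLabelDotProductLoss  # assumed module path

loss_layer = SingleLabelDotProductLoss(num_candidates=3)
num_labels, embed_dim = 10, 20

all_labels = tf.reshape(tf.range(num_labels, dtype=tf.float32), (num_labels, 1))
all_labels_embed = tf.random.uniform((num_labels, embed_dim))

label_ids = tf.constant([[1.0], [4.0], [7.0], [2.0]])  # one label per input
labels_embed = tf.gather(
    all_labels_embed, tf.cast(tf.squeeze(label_ids, -1), tf.int32)
)
inputs_embed = tf.random.uniform((4, embed_dim))

loss, accuracy = loss_layer(
    inputs_embed, labels_embed, label_ids, all_labels_embed, all_labels
)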
+
1209
+
1210
+ class MultiLabelDotProductLoss(DotProductLoss):
1211
+ """Multi-label dot-product loss layer.
1212
+
1213
+ This loss layer assumes that multiple outputs (labels) can be correct for any given
1214
+ input. To accommodate this, we use a sigmoid cross-entropy loss here.
1215
+ """
1216
+
1217
+ def __init__(
1218
+ self,
1219
+ num_candidates: int,
1220
+ scale_loss: bool = False,
1221
+ constrain_similarities: bool = True,
1222
+ model_confidence: Text = SOFTMAX,
1223
+ similarity_type: Text = INNER,
1224
+ name: Optional[Text] = None,
1225
+ **kwargs: Any,
1226
+ ) -> None:
1227
+ """Declares instance variables with default values.
1228
+
1229
+ Args:
1230
+ num_candidates: Positive integer, the number of candidate labels.
1231
+ scale_loss: If `True` scale loss inverse proportionally to
1232
+ the confidence of the correct prediction.
1233
+ similarity_type: Similarity measure to use, either `cosine` or `inner`.
1234
+ name: Optional name of the layer.
1235
+ constrain_similarities: Boolean, if `True` applies sigmoid on all
1236
+ similarity terms and adds to the loss function to
1237
+ ensure that similarity values are approximately bounded.
1238
+ Used inside _loss_cross_entropy() only.
1239
+ model_confidence: Normalization of confidence values during inference.
1240
+ Currently, the only possible value is `SOFTMAX`.
1241
+ """
1242
+ super().__init__(
1243
+ num_candidates,
1244
+ scale_loss=scale_loss,
1245
+ similarity_type=similarity_type,
1246
+ name=name,
1247
+ constrain_similarities=constrain_similarities,
1248
+ model_confidence=model_confidence,
1249
+ )
1250
+
1251
+ def call(
1252
+ self,
1253
+ batch_inputs_embed: tf.Tensor,
1254
+ batch_labels_embed: tf.Tensor,
1255
+ batch_labels_ids: tf.Tensor,
1256
+ all_labels_embed: tf.Tensor,
1257
+ all_labels_ids: tf.Tensor,
1258
+ mask: Optional[tf.Tensor] = None,
1259
+ ) -> Tuple[tf.Tensor, tf.Tensor]:
1260
+ """Calculates loss and accuracy.
1261
+
1262
+ Args:
1263
+ batch_inputs_embed: Embeddings of the batch inputs (e.g. featurized
1264
+ trackers); shape `(batch_size, 1, num_features)`
1265
+ batch_labels_embed: Embeddings of the batch labels (e.g. featurized intents
1266
+ for IntentTED);
1267
+ shape `(batch_size, max_num_labels_per_input, num_features)`
1268
+ batch_labels_ids: Batch label indices (e.g. indices of the intents). We
1269
+ assume that indices are integers that run from `0` to
1270
+ `(number of labels) - 1`.
1271
+ shape `(batch_size, max_num_labels_per_input, 1)`
1272
+ all_labels_embed: Embeddings for all labels in the domain;
1273
+ shape `(num_labels, num_features)`
1274
+ all_labels_ids: Indices for all labels in the domain;
1275
+ shape `(num_labels, 1)`
1276
+ mask: Optional sequence mask, which contains `1` for inputs and `0` for
1277
+ padding.
1278
+
1279
+ Returns:
1280
+ loss: Total loss (based on StarSpace http://arxiv.org/abs/1709.03856);
1281
+ scalar
1282
+ accuracy: Training accuracy; scalar
1283
+ """
1284
+ (
1285
+ pos_inputs_embed, # (batch_size, 1, 1, num_features)
1286
+ pos_labels_embed, # (batch_size, 1, max_num_labels_per_input, num_features)
1287
+ candidate_labels_embed, # (batch_size, 1, num_candidates, num_features)
1288
+ pos_neg_labels, # (batch_size, num_candidates)
1289
+ ) = self._sample_candidates(
1290
+ batch_inputs_embed,
1291
+ batch_labels_embed,
1292
+ batch_labels_ids,
1293
+ all_labels_embed,
1294
+ all_labels_ids,
1295
+ )
1296
+
1297
+ # Calculate similarities
1298
+ sim_pos, sim_candidate_il = self._train_sim(
1299
+ pos_inputs_embed, pos_labels_embed, candidate_labels_embed, mask
1300
+ )
1301
+
1302
+ label_padding_mask = self._construct_mask_for_label_padding(
1303
+ batch_labels_ids, tf.shape(pos_neg_labels)[-1]
1304
+ )
1305
+
1306
+ # Repurpose the `mask` argument of `_accuracy` and `_loss_sigmoid`
1307
+ # to pass the `label_padding_mask`. We can do this right now because
1308
+ # we don't use `MultiLabelDotProductLoss` for sequence tagging tasks
1309
+ # yet. Hence, the `mask` argument passed to this function will always
1310
+ # be empty. Whenever we come across a use case where `mask` is
1311
+ # non-empty, we'll have to refactor the `_accuracy` and `_loss_sigmoid`
1312
+ # functions to take into consideration both, sequence level masks as
1313
+ # well as label padding masks.
1314
+
1315
+ accuracy = self._accuracy(
1316
+ sim_pos, sim_candidate_il, pos_neg_labels, label_padding_mask
1317
+ )
1318
+ loss = self._loss_sigmoid(
1319
+ sim_pos, sim_candidate_il, pos_neg_labels, mask=label_padding_mask
1320
+ )
1321
+
1322
+ return loss, accuracy
1323
+
1324
+ @staticmethod
1325
+ def _construct_mask_for_label_padding(
1326
+ batch_labels_ids: tf.Tensor, num_candidates: tf.Tensor
1327
+ ) -> tf.Tensor:
1328
+ """Constructs a mask which indicates indices for valid label ids.
1329
+
1330
+ Indices corresponding to valid label ids have a
1331
+ `1` and indices corresponding to `LABEL_PAD_ID`
1332
+ have a `0`.
1333
+
1334
+ Args:
1335
+ batch_labels_ids: Batch label indices (e.g. indices of the intents). We
1336
+ assume that indices are integers that run from `0` to
1337
+ `(number of labels) - 1` with a special
1338
+ value for padding which is set to `LABEL_PAD_ID`.
1339
+ shape `(batch_size, max_num_labels_per_input, 1)`
1340
+ num_candidates: Number of candidates sampled.
1341
+
1342
+ Returns:
1343
+ Constructed mask.
1344
+ """
1345
+ pos_label_pad_indices = tf.cast(
1346
+ tf.equal(tf.squeeze(batch_labels_ids, -1), LABEL_PAD_ID), dtype=tf.float32
1347
+ )
1348
+
1349
+ # Flip 1 and 0 to 0 and 1 respectively
1350
+ pos_label_pad_mask = 1 - pos_label_pad_indices
1351
+
1352
+ # `pos_label_pad_mask` only contains the mask for label ids
1353
+ # seen in the batch. For sampled candidate label ids, the mask
1354
+ # should be a tensor of `1`s since all candidate label ids
1355
+ # are valid. From this, we construct the padding mask for
1356
+ # all label ids: label ids seen in the batch + label ids sampled.
1357
+ all_label_pad_mask = tf.concat(
1358
+ [
1359
+ pos_label_pad_mask,
1360
+ tf.ones(
1361
+ (tf.shape(batch_labels_ids)[0], num_candidates), dtype=tf.float32
1362
+ ),
1363
+ ],
1364
+ axis=-1,
1365
+ )
1366
+
1367
+ return all_label_pad_mask
1368
+
1369
+ def _train_sim(
1370
+ self,
1371
+ pos_inputs_embed: tf.Tensor,
1372
+ pos_labels_embed: tf.Tensor,
1373
+ candidate_labels_embed: tf.Tensor,
1374
+ mask: tf.Tensor,
1375
+ ) -> Tuple[tf.Tensor, tf.Tensor]:
1376
+ sim_pos = self.sim(
1377
+ pos_inputs_embed, pos_labels_embed, mask
1378
+ ) # (batch_size, 1, max_labels_per_input)
1379
+ sim_candidate_il = self.sim(
1380
+ pos_inputs_embed, candidate_labels_embed, mask
1381
+ ) # (batch_size, 1, num_candidates)
1382
+
1383
+ return sim_pos, sim_candidate_il
1384
+
1385
+ def _sample_candidates(
1386
+ self,
1387
+ batch_inputs_embed: tf.Tensor,
1388
+ batch_labels_embed: tf.Tensor,
1389
+ batch_labels_ids: tf.Tensor,
1390
+ all_labels_embed: tf.Tensor,
1391
+ all_labels_ids: tf.Tensor,
1392
+ ) -> Tuple[
1393
+ tf.Tensor, # (batch_size, 1, 1, num_features)
1394
+ tf.Tensor,  # (batch_size, 1, max_num_labels_per_input, num_features)
1395
+ tf.Tensor, # (batch_size, 1, num_candidates, num_features)
1396
+ tf.Tensor, # (batch_size, num_candidates)
1397
+ ]:
1398
+ """Samples candidate examples.
1399
+
1400
+ Args:
1401
+ batch_inputs_embed: Embeddings of the batch inputs (e.g. featurized
1402
+ trackers) # (batch_size, 1, num_features)
1403
+ batch_labels_embed: Embeddings of the batch labels (e.g. featurized intents
1404
+ for IntentTED) # (batch_size, max_num_labels_per_input, num_features)
1405
+ batch_labels_ids: Batch label indices (e.g. indices of the
1406
+ intents) # (batch_size, max_num_labels_per_input, 1)
1407
+ all_labels_embed: Embeddings for all labels in
1408
+ the domain # (num_labels, num_features)
1409
+ all_labels_ids: Indices for all labels in the
1410
+ domain # (num_labels, 1)
1411
+
1412
+ Returns:
1413
+ pos_inputs_embed: Embeddings of the batch inputs
1414
+ pos_labels_embed: Embeddings of the batch labels with an extra
1415
+ dimension inserted.
1416
+ candidate_labels_embed: More examples of embeddings of labels, some positive
1417
+ some negative
1418
+ pos_neg_indicators: Indicator for which candidates are positives and which
1419
+ are negatives
1420
+ """
1421
+ pos_inputs_embed = tf.expand_dims(
1422
+ batch_inputs_embed, axis=-2, name="expand_pos_input"
1423
+ )
1424
+
1425
+ pos_labels_embed = tf.expand_dims(
1426
+ batch_labels_embed, axis=1, name="expand_pos_labels"
1427
+ )
1428
+
1429
+ # Pick random examples from the batch
1430
+ candidate_ids = layers_utils.random_indices(
1431
+ batch_size=tf.shape(batch_inputs_embed)[0],
1432
+ n=self.num_neg,
1433
+ n_max=tf.shape(all_labels_embed)[0],
1434
+ )
1435
+
1436
+ # Get the label embeddings corresponding to candidate indices
1437
+ candidate_labels_embed = layers_utils.get_candidate_values(
1438
+ all_labels_embed, candidate_ids
1439
+ )
1440
+ candidate_labels_embed = tf.expand_dims(candidate_labels_embed, axis=1)
1441
+
1442
+ # Get binary indicators of whether a candidate is positive or not
1443
+ pos_neg_indicators = self._get_pos_neg_indicators(
1444
+ all_labels_ids, batch_labels_ids, candidate_ids
1445
+ )
1446
+
1447
+ return (
1448
+ pos_inputs_embed,
1449
+ pos_labels_embed,
1450
+ candidate_labels_embed,
1451
+ pos_neg_indicators,
1452
+ )
1453
+
1454
+ def _get_pos_neg_indicators(
1455
+ self,
1456
+ all_labels_ids: tf.Tensor,
1457
+ batch_labels_ids: tf.Tensor,
1458
+ candidate_ids: tf.Tensor,
1459
+ ) -> tf.Tensor:
1460
+ """Computes indicators for which candidates are positive labels.
1461
+
1462
+ Args:
1463
+ all_labels_ids: Indices of all the labels
1464
+ batch_labels_ids: Indices of the labels in the examples
1465
+ candidate_ids: Indices of labels that may or may not appear in the examples
1466
+
1467
+ Returns:
1468
+ Binary indicators of whether or not a label is positive
1469
+ """
1470
+ candidate_labels_ids = layers_utils.get_candidate_values(
1471
+ all_labels_ids, candidate_ids
1472
+ )
1473
+ candidate_labels_ids = tf.expand_dims(candidate_labels_ids, axis=1)
1474
+
1475
+ # Determine how many distinct labels exist (highest label index)
1476
+ max_label_id = tf.cast(tf.math.reduce_max(all_labels_ids), dtype=tf.int32)
1477
+
1478
+ # Convert the positive label ids to their one_hot representation.
1479
+ # Note: -1 indices yield a zeros-only vector. We use -1 as a padding token,
1480
+ # as the number of positive labels in each example can differ. The padding is
1481
+ # added in the TrackerFeaturizer.
1482
+ batch_labels_one_hot = tf.one_hot(
1483
+ tf.cast(tf.squeeze(batch_labels_ids, axis=-1), tf.int32),
1484
+ max_label_id + 1,
1485
+ axis=-1,
1486
+ ) # (batch_size, max_num_labels_per_input, max_label_id)
1487
+
1488
+ # Collapse the extra dimension and convert to a multi-hot representation
1489
+ # by aggregating all ones in the one-hot representation.
1490
+ # We use tf.reduce_any instead of tf.reduce_sum because several examples can
1491
+ # have the same positive label.
1492
+ batch_labels_multi_hot = tf.cast(
1493
+ tf.math.reduce_any(tf.cast(batch_labels_one_hot, dtype=tf.bool), axis=-2),
1494
+ tf.float32,
1495
+ ) # (batch_size, max_label_id)
1496
+
1497
+ # Remove extra dimensions for gather
1498
+ candidate_labels_ids = tf.squeeze(tf.squeeze(candidate_labels_ids, 1), -1)
1499
+
1500
+ # Collect binary indicators of whether or not a label is positive
1501
+ return tf.gather(
1502
+ batch_labels_multi_hot,
1503
+ tf.cast(candidate_labels_ids, tf.int32),
1504
+ batch_dims=1,
1505
+ name="gather_labels",
1506
+ )
1507
+
1508
+ def _loss_sigmoid(
1509
+ self,
1510
+ sim_pos: tf.Tensor, # (batch_size, 1, max_num_labels_per_input)
1511
+ sim_candidates_il: tf.Tensor, # (batch_size, 1, num_candidates)
1512
+ pos_neg_labels: tf.Tensor, # (batch_size, num_candidates)
1513
+ mask: Optional[
1514
+ tf.Tensor
1515
+ ] = None, # (batch_size, max_num_labels_per_input + num_candidates)
1516
+ ) -> tf.Tensor: # ()
1517
+ """Computes the sigmoid loss."""
1518
+ # Concatenate the guaranteed positive examples with the candidate examples,
1519
+ # some of which are positives and others are negatives. Which are which
1520
+ # is stored in `pos_neg_labels`.
1521
+ logits = tf.concat([sim_pos, sim_candidates_il], axis=-1, name="logit_concat")
1522
+ logits = tf.squeeze(logits, 1)
1523
+
1524
+ # Create label_ids for sigmoid. `mask` will take care of the
1525
+ # extra 1s we create as label ids for indices corresponding
1526
+ # to padding ids.
1527
+ pos_label_ids = tf.squeeze(tf.ones_like(sim_pos, tf.float32), 1)
1528
+ label_ids = tf.concat(
1529
+ [pos_label_ids, pos_neg_labels], axis=-1, name="gt_concat"
1530
+ )
1531
+
1532
+ # Compute the sigmoid cross-entropy loss. When minimized, the embeddings
1533
+ # for the two classes (positive and negative) are pushed away from each
1534
+ # other in the embedding space, while it is allowed that any input embedding
1535
+ # corresponds to more than one label.
1536
+ loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=label_ids, logits=logits)
1537
+
1538
+ loss = self.apply_mask_and_scaling(loss, mask)
1539
+
1540
+ # Average the loss over the batch
1541
+ return tf.reduce_mean(loss)
1542
+
1543
+ @staticmethod
1544
+ def _accuracy(
1545
+ sim_pos: tf.Tensor, # (batch_size, 1, max_num_labels_per_input)
1546
+ sim_candidates: tf.Tensor, # (batch_size, 1, num_candidates)
1547
+ pos_neg_indicators: tf.Tensor, # (batch_size, num_candidates)
1548
+ mask: tf.Tensor, # (batch_size, max_num_labels_per_input + num_candidates)
1549
+ ) -> tf.Tensor: # ()
1550
+ """Calculates the accuracy."""
1551
+ all_preds = tf.concat(
1552
+ [sim_pos, sim_candidates], axis=-1, name="acc_concat_preds"
1553
+ )
1554
+ all_preds_sigmoid = tf.nn.sigmoid(all_preds)
1555
+ all_pred_labels = tf.squeeze(tf.math.round(all_preds_sigmoid), 1)
1556
+
1557
+ # Create an indicator for the positive labels by concatenating the 1 for all
1558
+ # guaranteed positive labels and the `pos_neg_indicators`
1559
+ all_positives = tf.concat(
1560
+ [tf.squeeze(tf.ones_like(sim_pos), axis=1), pos_neg_indicators],
1561
+ axis=-1,
1562
+ name="acc_concat_gt",
1563
+ )
1564
+
1565
+ return layers_utils.reduce_mean_equal(all_pred_labels, all_positives, mask=mask)
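A sketch of the multi-label case (sizes, label ids, and the module path are assumptions): each input may have several correct labels, padded to a fixed number of slots with the module's `LABEL_PAD_ID` (assumed to be `-1` here), and four candidate labels are sampled per example.

import tensorflow as tf
from rasa.utils.tensorflow.layers import MultiLabelDotProductLoss  # assumed module path

loss_layer = MultiLabelDotProductLoss(num_candidates=4)
num_labels, embed_dim = 6, 10

all_labels_ids = tf.reshape(tf.range(num_labels, dtype=tf.float32), (num_labels, 1))
all_labels_embed = tf.random.uniform((num_labels, embed_dim))

batch_inputs_embed = tf.random.uniform((2, 1, embed_dim))
batch_labels_ids = tf.constant(
    [[[0.0], [2.0], [-1.0]],    # -1.0 assumed to equal LABEL_PAD_ID
     [[1.0], [-1.0], [-1.0]]]
)
batch_labels_embed = tf.random.uniform((2, 3, embed_dim))

loss, accuracy = loss_layer(
    batch_inputs_embed,
    batch_labels_embed,
    batch_labels_ids,
    all_labels_embed,
    all_labels_ids,
)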