rasa-pro 3.9.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (662)
  1. README.md +415 -0
  2. rasa/__init__.py +10 -0
  3. rasa/__main__.py +156 -0
  4. rasa/anonymization/__init__.py +2 -0
  5. rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
  6. rasa/anonymization/anonymization_pipeline.py +286 -0
  7. rasa/anonymization/anonymization_rule_executor.py +260 -0
  8. rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
  9. rasa/anonymization/schemas/config.yml +47 -0
  10. rasa/anonymization/utils.py +118 -0
  11. rasa/api.py +146 -0
  12. rasa/cli/__init__.py +5 -0
  13. rasa/cli/arguments/__init__.py +0 -0
  14. rasa/cli/arguments/data.py +81 -0
  15. rasa/cli/arguments/default_arguments.py +165 -0
  16. rasa/cli/arguments/evaluate.py +65 -0
  17. rasa/cli/arguments/export.py +51 -0
  18. rasa/cli/arguments/interactive.py +74 -0
  19. rasa/cli/arguments/run.py +204 -0
  20. rasa/cli/arguments/shell.py +13 -0
  21. rasa/cli/arguments/test.py +211 -0
  22. rasa/cli/arguments/train.py +263 -0
  23. rasa/cli/arguments/visualize.py +34 -0
  24. rasa/cli/arguments/x.py +30 -0
  25. rasa/cli/data.py +292 -0
  26. rasa/cli/e2e_test.py +586 -0
  27. rasa/cli/evaluate.py +222 -0
  28. rasa/cli/export.py +250 -0
  29. rasa/cli/inspect.py +63 -0
  30. rasa/cli/interactive.py +164 -0
  31. rasa/cli/license.py +65 -0
  32. rasa/cli/markers.py +78 -0
  33. rasa/cli/project_templates/__init__.py +0 -0
  34. rasa/cli/project_templates/calm/actions/__init__.py +0 -0
  35. rasa/cli/project_templates/calm/actions/action_template.py +27 -0
  36. rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
  37. rasa/cli/project_templates/calm/actions/db.py +57 -0
  38. rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
  39. rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
  40. rasa/cli/project_templates/calm/config.yml +12 -0
  41. rasa/cli/project_templates/calm/credentials.yml +33 -0
  42. rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
  43. rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
  44. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
  45. rasa/cli/project_templates/calm/db/contacts.json +10 -0
  46. rasa/cli/project_templates/calm/domain/add_contact.yml +39 -0
  47. rasa/cli/project_templates/calm/domain/list_contacts.yml +17 -0
  48. rasa/cli/project_templates/calm/domain/remove_contact.yml +38 -0
  49. rasa/cli/project_templates/calm/domain/shared.yml +10 -0
  50. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
  51. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
  52. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
  53. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
  54. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
  55. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
  56. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
  57. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
  58. rasa/cli/project_templates/calm/endpoints.yml +45 -0
  59. rasa/cli/project_templates/default/actions/__init__.py +0 -0
  60. rasa/cli/project_templates/default/actions/actions.py +27 -0
  61. rasa/cli/project_templates/default/config.yml +44 -0
  62. rasa/cli/project_templates/default/credentials.yml +33 -0
  63. rasa/cli/project_templates/default/data/nlu.yml +91 -0
  64. rasa/cli/project_templates/default/data/rules.yml +13 -0
  65. rasa/cli/project_templates/default/data/stories.yml +30 -0
  66. rasa/cli/project_templates/default/domain.yml +34 -0
  67. rasa/cli/project_templates/default/endpoints.yml +42 -0
  68. rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
  69. rasa/cli/project_templates/tutorial/actions.py +22 -0
  70. rasa/cli/project_templates/tutorial/config.yml +11 -0
  71. rasa/cli/project_templates/tutorial/credentials.yml +33 -0
  72. rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
  73. rasa/cli/project_templates/tutorial/data/patterns.yml +6 -0
  74. rasa/cli/project_templates/tutorial/domain.yml +21 -0
  75. rasa/cli/project_templates/tutorial/endpoints.yml +45 -0
  76. rasa/cli/run.py +135 -0
  77. rasa/cli/scaffold.py +269 -0
  78. rasa/cli/shell.py +141 -0
  79. rasa/cli/studio/__init__.py +0 -0
  80. rasa/cli/studio/download.py +62 -0
  81. rasa/cli/studio/studio.py +266 -0
  82. rasa/cli/studio/train.py +59 -0
  83. rasa/cli/studio/upload.py +77 -0
  84. rasa/cli/telemetry.py +102 -0
  85. rasa/cli/test.py +280 -0
  86. rasa/cli/train.py +260 -0
  87. rasa/cli/utils.py +464 -0
  88. rasa/cli/visualize.py +40 -0
  89. rasa/cli/x.py +206 -0
  90. rasa/constants.py +37 -0
  91. rasa/core/__init__.py +17 -0
  92. rasa/core/actions/__init__.py +0 -0
  93. rasa/core/actions/action.py +1225 -0
  94. rasa/core/actions/action_clean_stack.py +59 -0
  95. rasa/core/actions/action_exceptions.py +24 -0
  96. rasa/core/actions/action_run_slot_rejections.py +207 -0
  97. rasa/core/actions/action_trigger_chitchat.py +31 -0
  98. rasa/core/actions/action_trigger_flow.py +109 -0
  99. rasa/core/actions/action_trigger_search.py +31 -0
  100. rasa/core/actions/constants.py +5 -0
  101. rasa/core/actions/custom_action_executor.py +188 -0
  102. rasa/core/actions/forms.py +741 -0
  103. rasa/core/actions/grpc_custom_action_executor.py +251 -0
  104. rasa/core/actions/http_custom_action_executor.py +140 -0
  105. rasa/core/actions/loops.py +114 -0
  106. rasa/core/actions/two_stage_fallback.py +186 -0
  107. rasa/core/agent.py +555 -0
  108. rasa/core/auth_retry_tracker_store.py +122 -0
  109. rasa/core/brokers/__init__.py +0 -0
  110. rasa/core/brokers/broker.py +126 -0
  111. rasa/core/brokers/file.py +58 -0
  112. rasa/core/brokers/kafka.py +322 -0
  113. rasa/core/brokers/pika.py +386 -0
  114. rasa/core/brokers/sql.py +86 -0
  115. rasa/core/channels/__init__.py +55 -0
  116. rasa/core/channels/audiocodes.py +463 -0
  117. rasa/core/channels/botframework.py +338 -0
  118. rasa/core/channels/callback.py +84 -0
  119. rasa/core/channels/channel.py +419 -0
  120. rasa/core/channels/console.py +241 -0
  121. rasa/core/channels/development_inspector.py +93 -0
  122. rasa/core/channels/facebook.py +419 -0
  123. rasa/core/channels/hangouts.py +329 -0
  124. rasa/core/channels/inspector/.eslintrc.cjs +25 -0
  125. rasa/core/channels/inspector/.gitignore +23 -0
  126. rasa/core/channels/inspector/README.md +54 -0
  127. rasa/core/channels/inspector/assets/favicon.ico +0 -0
  128. rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
  129. rasa/core/channels/inspector/custom.d.ts +3 -0
  130. rasa/core/channels/inspector/dist/assets/arc-b6e548fe.js +1 -0
  131. rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
  132. rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-fa03ac9e.js +10 -0
  133. rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-ee67392a.js +2 -0
  134. rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-9b283fae.js +2 -0
  135. rasa/core/channels/inspector/dist/assets/createText-62fc7601-8b6fcc2a.js +7 -0
  136. rasa/core/channels/inspector/dist/assets/edges-f2ad444c-22e77f4f.js +4 -0
  137. rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-60ffc87f.js +51 -0
  138. rasa/core/channels/inspector/dist/assets/flowDb-1972c806-9dd802e4.js +6 -0
  139. rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-5fa1912f.js +4 -0
  140. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +1 -0
  141. rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-622a1fd2.js +139 -0
  142. rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-e285a63a.js +266 -0
  143. rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-f237bdca.js +70 -0
  144. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
  145. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
  146. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
  147. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
  148. rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-4b03d70e.js +1 -0
  149. rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
  150. rasa/core/channels/inspector/dist/assets/index-a5d3e69d.js +1040 -0
  151. rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-72a0fa5f.js +7 -0
  152. rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
  153. rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-82218c41.js +139 -0
  154. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
  155. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
  156. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
  157. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
  158. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
  159. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
  160. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
  161. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
  162. rasa/core/channels/inspector/dist/assets/layout-78cff630.js +1 -0
  163. rasa/core/channels/inspector/dist/assets/line-5038b469.js +1 -0
  164. rasa/core/channels/inspector/dist/assets/linear-c4fc4098.js +1 -0
  165. rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-c33c8ea6.js +109 -0
  166. rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
  167. rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
  168. rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-a8d03059.js +35 -0
  169. rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-6a0e56b2.js +7 -0
  170. rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-2dc7c7bd.js +52 -0
  171. rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-2360fe39.js +8 -0
  172. rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-41b9f9ad.js +122 -0
  173. rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-0aad326f.js +1 -0
  174. rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-9847d984.js +1 -0
  175. rasa/core/channels/inspector/dist/assets/styles-080da4f6-564d890e.js +110 -0
  176. rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-38957613.js +159 -0
  177. rasa/core/channels/inspector/dist/assets/styles-9c745c82-f0fc6921.js +207 -0
  178. rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-ef3c5a77.js +1 -0
  179. rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-bf3e91c1.js +61 -0
  180. rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-4d4026c0.js +7 -0
  181. rasa/core/channels/inspector/dist/index.html +41 -0
  182. rasa/core/channels/inspector/index.html +39 -0
  183. rasa/core/channels/inspector/jest.config.ts +13 -0
  184. rasa/core/channels/inspector/package.json +48 -0
  185. rasa/core/channels/inspector/setupTests.ts +2 -0
  186. rasa/core/channels/inspector/src/App.tsx +170 -0
  187. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +107 -0
  188. rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
  189. rasa/core/channels/inspector/src/components/DialogueStack.tsx +151 -0
  190. rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
  191. rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
  192. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +19 -0
  193. rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
  194. rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
  195. rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
  196. rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
  197. rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
  198. rasa/core/channels/inspector/src/helpers/formatters.test.ts +382 -0
  199. rasa/core/channels/inspector/src/helpers/formatters.ts +240 -0
  200. rasa/core/channels/inspector/src/helpers/utils.ts +42 -0
  201. rasa/core/channels/inspector/src/main.tsx +13 -0
  202. rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
  203. rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
  204. rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
  205. rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
  206. rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
  207. rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
  208. rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
  209. rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
  210. rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
  211. rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
  212. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
  213. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
  214. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
  215. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
  216. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
  217. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
  218. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
  219. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
  220. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
  221. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
  222. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
  223. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
  224. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
  225. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
  226. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
  227. rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
  228. rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
  229. rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
  230. rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
  231. rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
  232. rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
  233. rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
  234. rasa/core/channels/inspector/src/theme/index.ts +101 -0
  235. rasa/core/channels/inspector/src/types.ts +64 -0
  236. rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
  237. rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
  238. rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
  239. rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
  240. rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
  241. rasa/core/channels/inspector/tsconfig.json +26 -0
  242. rasa/core/channels/inspector/tsconfig.node.json +10 -0
  243. rasa/core/channels/inspector/vite.config.ts +8 -0
  244. rasa/core/channels/inspector/yarn.lock +6156 -0
  245. rasa/core/channels/mattermost.py +229 -0
  246. rasa/core/channels/rasa_chat.py +126 -0
  247. rasa/core/channels/rest.py +225 -0
  248. rasa/core/channels/rocketchat.py +174 -0
  249. rasa/core/channels/slack.py +620 -0
  250. rasa/core/channels/socketio.py +274 -0
  251. rasa/core/channels/telegram.py +298 -0
  252. rasa/core/channels/twilio.py +169 -0
  253. rasa/core/channels/twilio_voice.py +367 -0
  254. rasa/core/channels/vier_cvg.py +374 -0
  255. rasa/core/channels/webexteams.py +134 -0
  256. rasa/core/concurrent_lock_store.py +210 -0
  257. rasa/core/constants.py +107 -0
  258. rasa/core/evaluation/__init__.py +0 -0
  259. rasa/core/evaluation/marker.py +267 -0
  260. rasa/core/evaluation/marker_base.py +923 -0
  261. rasa/core/evaluation/marker_stats.py +293 -0
  262. rasa/core/evaluation/marker_tracker_loader.py +103 -0
  263. rasa/core/exceptions.py +29 -0
  264. rasa/core/exporter.py +284 -0
  265. rasa/core/featurizers/__init__.py +0 -0
  266. rasa/core/featurizers/precomputation.py +410 -0
  267. rasa/core/featurizers/single_state_featurizer.py +421 -0
  268. rasa/core/featurizers/tracker_featurizers.py +1262 -0
  269. rasa/core/http_interpreter.py +89 -0
  270. rasa/core/information_retrieval/__init__.py +7 -0
  271. rasa/core/information_retrieval/faiss.py +121 -0
  272. rasa/core/information_retrieval/information_retrieval.py +129 -0
  273. rasa/core/information_retrieval/milvus.py +52 -0
  274. rasa/core/information_retrieval/qdrant.py +95 -0
  275. rasa/core/jobs.py +63 -0
  276. rasa/core/lock.py +139 -0
  277. rasa/core/lock_store.py +343 -0
  278. rasa/core/migrate.py +403 -0
  279. rasa/core/nlg/__init__.py +3 -0
  280. rasa/core/nlg/callback.py +146 -0
  281. rasa/core/nlg/contextual_response_rephraser.py +270 -0
  282. rasa/core/nlg/generator.py +230 -0
  283. rasa/core/nlg/interpolator.py +143 -0
  284. rasa/core/nlg/response.py +155 -0
  285. rasa/core/nlg/summarize.py +69 -0
  286. rasa/core/policies/__init__.py +0 -0
  287. rasa/core/policies/ensemble.py +329 -0
  288. rasa/core/policies/enterprise_search_policy.py +781 -0
  289. rasa/core/policies/enterprise_search_prompt_template.jinja2 +25 -0
  290. rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
  291. rasa/core/policies/flow_policy.py +205 -0
  292. rasa/core/policies/flows/__init__.py +0 -0
  293. rasa/core/policies/flows/flow_exceptions.py +44 -0
  294. rasa/core/policies/flows/flow_executor.py +705 -0
  295. rasa/core/policies/flows/flow_step_result.py +43 -0
  296. rasa/core/policies/intentless_policy.py +922 -0
  297. rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
  298. rasa/core/policies/memoization.py +538 -0
  299. rasa/core/policies/policy.py +725 -0
  300. rasa/core/policies/rule_policy.py +1273 -0
  301. rasa/core/policies/ted_policy.py +2169 -0
  302. rasa/core/policies/unexpected_intent_policy.py +1022 -0
  303. rasa/core/processor.py +1422 -0
  304. rasa/core/run.py +331 -0
  305. rasa/core/secrets_manager/__init__.py +0 -0
  306. rasa/core/secrets_manager/constants.py +32 -0
  307. rasa/core/secrets_manager/endpoints.py +391 -0
  308. rasa/core/secrets_manager/factory.py +233 -0
  309. rasa/core/secrets_manager/secret_manager.py +262 -0
  310. rasa/core/secrets_manager/vault.py +574 -0
  311. rasa/core/test.py +1335 -0
  312. rasa/core/tracker_store.py +1699 -0
  313. rasa/core/train.py +105 -0
  314. rasa/core/training/__init__.py +89 -0
  315. rasa/core/training/converters/__init__.py +0 -0
  316. rasa/core/training/converters/responses_prefix_converter.py +119 -0
  317. rasa/core/training/interactive.py +1745 -0
  318. rasa/core/training/story_conflict.py +381 -0
  319. rasa/core/training/training.py +93 -0
  320. rasa/core/utils.py +339 -0
  321. rasa/core/visualize.py +70 -0
  322. rasa/dialogue_understanding/__init__.py +0 -0
  323. rasa/dialogue_understanding/coexistence/__init__.py +0 -0
  324. rasa/dialogue_understanding/coexistence/constants.py +4 -0
  325. rasa/dialogue_understanding/coexistence/intent_based_router.py +196 -0
  326. rasa/dialogue_understanding/coexistence/llm_based_router.py +260 -0
  327. rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
  328. rasa/dialogue_understanding/commands/__init__.py +49 -0
  329. rasa/dialogue_understanding/commands/can_not_handle_command.py +70 -0
  330. rasa/dialogue_understanding/commands/cancel_flow_command.py +125 -0
  331. rasa/dialogue_understanding/commands/change_flow_command.py +44 -0
  332. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +57 -0
  333. rasa/dialogue_understanding/commands/clarify_command.py +86 -0
  334. rasa/dialogue_understanding/commands/command.py +85 -0
  335. rasa/dialogue_understanding/commands/correct_slots_command.py +297 -0
  336. rasa/dialogue_understanding/commands/error_command.py +79 -0
  337. rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
  338. rasa/dialogue_understanding/commands/handle_code_change_command.py +73 -0
  339. rasa/dialogue_understanding/commands/human_handoff_command.py +66 -0
  340. rasa/dialogue_understanding/commands/knowledge_answer_command.py +57 -0
  341. rasa/dialogue_understanding/commands/noop_command.py +54 -0
  342. rasa/dialogue_understanding/commands/set_slot_command.py +160 -0
  343. rasa/dialogue_understanding/commands/skip_question_command.py +75 -0
  344. rasa/dialogue_understanding/commands/start_flow_command.py +107 -0
  345. rasa/dialogue_understanding/generator/__init__.py +21 -0
  346. rasa/dialogue_understanding/generator/command_generator.py +343 -0
  347. rasa/dialogue_understanding/generator/constants.py +18 -0
  348. rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
  349. rasa/dialogue_understanding/generator/flow_retrieval.py +412 -0
  350. rasa/dialogue_understanding/generator/llm_based_command_generator.py +467 -0
  351. rasa/dialogue_understanding/generator/llm_command_generator.py +67 -0
  352. rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
  353. rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
  354. rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
  355. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +827 -0
  356. rasa/dialogue_understanding/generator/nlu_command_adapter.py +218 -0
  357. rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
  358. rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +57 -0
  359. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +345 -0
  360. rasa/dialogue_understanding/patterns/__init__.py +0 -0
  361. rasa/dialogue_understanding/patterns/cancel.py +111 -0
  362. rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
  363. rasa/dialogue_understanding/patterns/chitchat.py +37 -0
  364. rasa/dialogue_understanding/patterns/clarify.py +97 -0
  365. rasa/dialogue_understanding/patterns/code_change.py +41 -0
  366. rasa/dialogue_understanding/patterns/collect_information.py +90 -0
  367. rasa/dialogue_understanding/patterns/completed.py +40 -0
  368. rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
  369. rasa/dialogue_understanding/patterns/correction.py +278 -0
  370. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +248 -0
  371. rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
  372. rasa/dialogue_understanding/patterns/internal_error.py +47 -0
  373. rasa/dialogue_understanding/patterns/search.py +37 -0
  374. rasa/dialogue_understanding/patterns/skip_question.py +38 -0
  375. rasa/dialogue_understanding/processor/__init__.py +0 -0
  376. rasa/dialogue_understanding/processor/command_processor.py +687 -0
  377. rasa/dialogue_understanding/processor/command_processor_component.py +39 -0
  378. rasa/dialogue_understanding/stack/__init__.py +0 -0
  379. rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
  380. rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
  381. rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
  382. rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
  383. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
  384. rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
  385. rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
  386. rasa/dialogue_understanding/stack/utils.py +211 -0
  387. rasa/e2e_test/__init__.py +0 -0
  388. rasa/e2e_test/constants.py +11 -0
  389. rasa/e2e_test/e2e_test_case.py +366 -0
  390. rasa/e2e_test/e2e_test_result.py +34 -0
  391. rasa/e2e_test/e2e_test_runner.py +768 -0
  392. rasa/e2e_test/e2e_test_schema.yml +85 -0
  393. rasa/engine/__init__.py +0 -0
  394. rasa/engine/caching.py +463 -0
  395. rasa/engine/constants.py +17 -0
  396. rasa/engine/exceptions.py +14 -0
  397. rasa/engine/graph.py +637 -0
  398. rasa/engine/loader.py +36 -0
  399. rasa/engine/recipes/__init__.py +0 -0
  400. rasa/engine/recipes/config_files/default_config.yml +44 -0
  401. rasa/engine/recipes/default_components.py +99 -0
  402. rasa/engine/recipes/default_recipe.py +1251 -0
  403. rasa/engine/recipes/graph_recipe.py +79 -0
  404. rasa/engine/recipes/recipe.py +93 -0
  405. rasa/engine/runner/__init__.py +0 -0
  406. rasa/engine/runner/dask.py +250 -0
  407. rasa/engine/runner/interface.py +49 -0
  408. rasa/engine/storage/__init__.py +0 -0
  409. rasa/engine/storage/local_model_storage.py +246 -0
  410. rasa/engine/storage/resource.py +110 -0
  411. rasa/engine/storage/storage.py +203 -0
  412. rasa/engine/training/__init__.py +0 -0
  413. rasa/engine/training/components.py +176 -0
  414. rasa/engine/training/fingerprinting.py +64 -0
  415. rasa/engine/training/graph_trainer.py +256 -0
  416. rasa/engine/training/hooks.py +164 -0
  417. rasa/engine/validation.py +873 -0
  418. rasa/env.py +5 -0
  419. rasa/exceptions.py +69 -0
  420. rasa/graph_components/__init__.py +0 -0
  421. rasa/graph_components/converters/__init__.py +0 -0
  422. rasa/graph_components/converters/nlu_message_converter.py +48 -0
  423. rasa/graph_components/providers/__init__.py +0 -0
  424. rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
  425. rasa/graph_components/providers/domain_provider.py +71 -0
  426. rasa/graph_components/providers/flows_provider.py +74 -0
  427. rasa/graph_components/providers/forms_provider.py +44 -0
  428. rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
  429. rasa/graph_components/providers/responses_provider.py +44 -0
  430. rasa/graph_components/providers/rule_only_provider.py +49 -0
  431. rasa/graph_components/providers/story_graph_provider.py +43 -0
  432. rasa/graph_components/providers/training_tracker_provider.py +55 -0
  433. rasa/graph_components/validators/__init__.py +0 -0
  434. rasa/graph_components/validators/default_recipe_validator.py +550 -0
  435. rasa/graph_components/validators/finetuning_validator.py +302 -0
  436. rasa/hooks.py +112 -0
  437. rasa/jupyter.py +63 -0
  438. rasa/markers/__init__.py +0 -0
  439. rasa/markers/marker.py +269 -0
  440. rasa/markers/marker_base.py +828 -0
  441. rasa/markers/upload.py +74 -0
  442. rasa/markers/validate.py +21 -0
  443. rasa/model.py +118 -0
  444. rasa/model_testing.py +457 -0
  445. rasa/model_training.py +536 -0
  446. rasa/nlu/__init__.py +7 -0
  447. rasa/nlu/classifiers/__init__.py +3 -0
  448. rasa/nlu/classifiers/classifier.py +5 -0
  449. rasa/nlu/classifiers/diet_classifier.py +1881 -0
  450. rasa/nlu/classifiers/fallback_classifier.py +192 -0
  451. rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
  452. rasa/nlu/classifiers/llm_intent_classifier.py +519 -0
  453. rasa/nlu/classifiers/logistic_regression_classifier.py +253 -0
  454. rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
  455. rasa/nlu/classifiers/regex_message_handler.py +56 -0
  456. rasa/nlu/classifiers/sklearn_intent_classifier.py +330 -0
  457. rasa/nlu/constants.py +77 -0
  458. rasa/nlu/convert.py +40 -0
  459. rasa/nlu/emulators/__init__.py +0 -0
  460. rasa/nlu/emulators/dialogflow.py +55 -0
  461. rasa/nlu/emulators/emulator.py +49 -0
  462. rasa/nlu/emulators/luis.py +86 -0
  463. rasa/nlu/emulators/no_emulator.py +10 -0
  464. rasa/nlu/emulators/wit.py +56 -0
  465. rasa/nlu/extractors/__init__.py +0 -0
  466. rasa/nlu/extractors/crf_entity_extractor.py +715 -0
  467. rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
  468. rasa/nlu/extractors/entity_synonyms.py +178 -0
  469. rasa/nlu/extractors/extractor.py +470 -0
  470. rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
  471. rasa/nlu/extractors/regex_entity_extractor.py +220 -0
  472. rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
  473. rasa/nlu/featurizers/__init__.py +0 -0
  474. rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
  475. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +445 -0
  476. rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
  477. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +768 -0
  478. rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
  479. rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
  480. rasa/nlu/featurizers/featurizer.py +89 -0
  481. rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
  482. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +867 -0
  483. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +571 -0
  484. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +271 -0
  485. rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
  486. rasa/nlu/model.py +24 -0
  487. rasa/nlu/persistor.py +282 -0
  488. rasa/nlu/run.py +27 -0
  489. rasa/nlu/selectors/__init__.py +0 -0
  490. rasa/nlu/selectors/response_selector.py +987 -0
  491. rasa/nlu/test.py +1940 -0
  492. rasa/nlu/tokenizers/__init__.py +0 -0
  493. rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
  494. rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
  495. rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
  496. rasa/nlu/tokenizers/tokenizer.py +239 -0
  497. rasa/nlu/tokenizers/whitespace_tokenizer.py +106 -0
  498. rasa/nlu/utils/__init__.py +35 -0
  499. rasa/nlu/utils/bilou_utils.py +462 -0
  500. rasa/nlu/utils/hugging_face/__init__.py +0 -0
  501. rasa/nlu/utils/hugging_face/registry.py +108 -0
  502. rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
  503. rasa/nlu/utils/mitie_utils.py +113 -0
  504. rasa/nlu/utils/pattern_utils.py +168 -0
  505. rasa/nlu/utils/spacy_utils.py +310 -0
  506. rasa/plugin.py +90 -0
  507. rasa/server.py +1551 -0
  508. rasa/shared/__init__.py +0 -0
  509. rasa/shared/constants.py +192 -0
  510. rasa/shared/core/__init__.py +0 -0
  511. rasa/shared/core/command_payload_reader.py +109 -0
  512. rasa/shared/core/constants.py +167 -0
  513. rasa/shared/core/conversation.py +46 -0
  514. rasa/shared/core/domain.py +2107 -0
  515. rasa/shared/core/events.py +2504 -0
  516. rasa/shared/core/flows/__init__.py +7 -0
  517. rasa/shared/core/flows/flow.py +362 -0
  518. rasa/shared/core/flows/flow_step.py +146 -0
  519. rasa/shared/core/flows/flow_step_links.py +319 -0
  520. rasa/shared/core/flows/flow_step_sequence.py +70 -0
  521. rasa/shared/core/flows/flows_list.py +223 -0
  522. rasa/shared/core/flows/flows_yaml_schema.json +217 -0
  523. rasa/shared/core/flows/nlu_trigger.py +117 -0
  524. rasa/shared/core/flows/steps/__init__.py +24 -0
  525. rasa/shared/core/flows/steps/action.py +56 -0
  526. rasa/shared/core/flows/steps/call.py +64 -0
  527. rasa/shared/core/flows/steps/collect.py +112 -0
  528. rasa/shared/core/flows/steps/constants.py +5 -0
  529. rasa/shared/core/flows/steps/continuation.py +36 -0
  530. rasa/shared/core/flows/steps/end.py +22 -0
  531. rasa/shared/core/flows/steps/internal.py +44 -0
  532. rasa/shared/core/flows/steps/link.py +51 -0
  533. rasa/shared/core/flows/steps/no_operation.py +48 -0
  534. rasa/shared/core/flows/steps/set_slots.py +50 -0
  535. rasa/shared/core/flows/steps/start.py +30 -0
  536. rasa/shared/core/flows/validation.py +527 -0
  537. rasa/shared/core/flows/yaml_flows_io.py +278 -0
  538. rasa/shared/core/generator.py +908 -0
  539. rasa/shared/core/slot_mappings.py +526 -0
  540. rasa/shared/core/slots.py +649 -0
  541. rasa/shared/core/trackers.py +1177 -0
  542. rasa/shared/core/training_data/__init__.py +0 -0
  543. rasa/shared/core/training_data/loading.py +89 -0
  544. rasa/shared/core/training_data/story_reader/__init__.py +0 -0
  545. rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
  546. rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
  547. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
  548. rasa/shared/core/training_data/story_writer/__init__.py +0 -0
  549. rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
  550. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +444 -0
  551. rasa/shared/core/training_data/structures.py +838 -0
  552. rasa/shared/core/training_data/visualization.html +146 -0
  553. rasa/shared/core/training_data/visualization.py +603 -0
  554. rasa/shared/data.py +249 -0
  555. rasa/shared/engine/__init__.py +0 -0
  556. rasa/shared/engine/caching.py +26 -0
  557. rasa/shared/exceptions.py +163 -0
  558. rasa/shared/importers/__init__.py +0 -0
  559. rasa/shared/importers/importer.py +704 -0
  560. rasa/shared/importers/multi_project.py +203 -0
  561. rasa/shared/importers/rasa.py +99 -0
  562. rasa/shared/importers/utils.py +34 -0
  563. rasa/shared/nlu/__init__.py +0 -0
  564. rasa/shared/nlu/constants.py +47 -0
  565. rasa/shared/nlu/interpreter.py +10 -0
  566. rasa/shared/nlu/training_data/__init__.py +0 -0
  567. rasa/shared/nlu/training_data/entities_parser.py +208 -0
  568. rasa/shared/nlu/training_data/features.py +492 -0
  569. rasa/shared/nlu/training_data/formats/__init__.py +10 -0
  570. rasa/shared/nlu/training_data/formats/dialogflow.py +163 -0
  571. rasa/shared/nlu/training_data/formats/luis.py +87 -0
  572. rasa/shared/nlu/training_data/formats/rasa.py +135 -0
  573. rasa/shared/nlu/training_data/formats/rasa_yaml.py +603 -0
  574. rasa/shared/nlu/training_data/formats/readerwriter.py +244 -0
  575. rasa/shared/nlu/training_data/formats/wit.py +52 -0
  576. rasa/shared/nlu/training_data/loading.py +137 -0
  577. rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
  578. rasa/shared/nlu/training_data/message.py +490 -0
  579. rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
  580. rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
  581. rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
  582. rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
  583. rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
  584. rasa/shared/nlu/training_data/training_data.py +730 -0
  585. rasa/shared/nlu/training_data/util.py +223 -0
  586. rasa/shared/providers/__init__.py +0 -0
  587. rasa/shared/providers/openai/__init__.py +0 -0
  588. rasa/shared/providers/openai/clients.py +43 -0
  589. rasa/shared/providers/openai/session_handler.py +110 -0
  590. rasa/shared/utils/__init__.py +0 -0
  591. rasa/shared/utils/cli.py +72 -0
  592. rasa/shared/utils/common.py +308 -0
  593. rasa/shared/utils/constants.py +4 -0
  594. rasa/shared/utils/io.py +415 -0
  595. rasa/shared/utils/llm.py +404 -0
  596. rasa/shared/utils/pykwalify_extensions.py +27 -0
  597. rasa/shared/utils/schemas/__init__.py +0 -0
  598. rasa/shared/utils/schemas/config.yml +2 -0
  599. rasa/shared/utils/schemas/domain.yml +145 -0
  600. rasa/shared/utils/schemas/events.py +212 -0
  601. rasa/shared/utils/schemas/model_config.yml +46 -0
  602. rasa/shared/utils/schemas/stories.yml +173 -0
  603. rasa/shared/utils/yaml.py +786 -0
  604. rasa/studio/__init__.py +0 -0
  605. rasa/studio/auth.py +268 -0
  606. rasa/studio/config.py +127 -0
  607. rasa/studio/constants.py +18 -0
  608. rasa/studio/data_handler.py +359 -0
  609. rasa/studio/download.py +483 -0
  610. rasa/studio/results_logger.py +137 -0
  611. rasa/studio/train.py +135 -0
  612. rasa/studio/upload.py +433 -0
  613. rasa/telemetry.py +1737 -0
  614. rasa/tracing/__init__.py +0 -0
  615. rasa/tracing/config.py +353 -0
  616. rasa/tracing/constants.py +62 -0
  617. rasa/tracing/instrumentation/__init__.py +0 -0
  618. rasa/tracing/instrumentation/attribute_extractors.py +672 -0
  619. rasa/tracing/instrumentation/instrumentation.py +1185 -0
  620. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +144 -0
  621. rasa/tracing/instrumentation/metrics.py +294 -0
  622. rasa/tracing/metric_instrument_provider.py +205 -0
  623. rasa/utils/__init__.py +0 -0
  624. rasa/utils/beta.py +83 -0
  625. rasa/utils/cli.py +28 -0
  626. rasa/utils/common.py +635 -0
  627. rasa/utils/converter.py +53 -0
  628. rasa/utils/endpoints.py +302 -0
  629. rasa/utils/io.py +260 -0
  630. rasa/utils/licensing.py +534 -0
  631. rasa/utils/log_utils.py +174 -0
  632. rasa/utils/mapper.py +210 -0
  633. rasa/utils/ml_utils.py +145 -0
  634. rasa/utils/plotting.py +362 -0
  635. rasa/utils/singleton.py +23 -0
  636. rasa/utils/tensorflow/__init__.py +0 -0
  637. rasa/utils/tensorflow/callback.py +112 -0
  638. rasa/utils/tensorflow/constants.py +116 -0
  639. rasa/utils/tensorflow/crf.py +492 -0
  640. rasa/utils/tensorflow/data_generator.py +440 -0
  641. rasa/utils/tensorflow/environment.py +161 -0
  642. rasa/utils/tensorflow/exceptions.py +5 -0
  643. rasa/utils/tensorflow/feature_array.py +366 -0
  644. rasa/utils/tensorflow/layers.py +1565 -0
  645. rasa/utils/tensorflow/layers_utils.py +113 -0
  646. rasa/utils/tensorflow/metrics.py +281 -0
  647. rasa/utils/tensorflow/model_data.py +798 -0
  648. rasa/utils/tensorflow/model_data_utils.py +499 -0
  649. rasa/utils/tensorflow/models.py +935 -0
  650. rasa/utils/tensorflow/rasa_layers.py +1094 -0
  651. rasa/utils/tensorflow/transformer.py +640 -0
  652. rasa/utils/tensorflow/types.py +6 -0
  653. rasa/utils/train_utils.py +572 -0
  654. rasa/utils/url_tools.py +53 -0
  655. rasa/utils/yaml.py +54 -0
  656. rasa/validator.py +1337 -0
  657. rasa/version.py +3 -0
  658. rasa_pro-3.9.18.dist-info/METADATA +563 -0
  659. rasa_pro-3.9.18.dist-info/NOTICE +5 -0
  660. rasa_pro-3.9.18.dist-info/RECORD +662 -0
  661. rasa_pro-3.9.18.dist-info/WHEEL +4 -0
  662. rasa_pro-3.9.18.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,640 @@
1
+ from typing import Optional, Text, Tuple, Union
2
+
3
+ import numpy as np
4
+ import tensorflow as tf
5
+
6
+ # TODO: The following is not (yet) available via tf.keras
7
+ from keras.src.utils.control_flow_util import smart_cond
8
+ from tensorflow.keras import backend as K
9
+
10
+ import rasa.shared.utils.cli
11
+ from rasa.utils.tensorflow.layers import RandomlyConnectedDense
12
+
13
+
14
# from https://www.tensorflow.org/tutorials/text/transformer
# and https://github.com/tensorflow/tensor2tensor
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-headed attention layer.

    Arguments:
        units: Positive integer, output dim of hidden layer.
        num_heads: Positive integer, number of heads
            to repeat the same attention structure.
        attention_dropout_rate: Float, dropout rate inside attention for training.
        density: Approximate fraction of trainable weights (in
            `RandomlyConnectedDense` layers).
        unidirectional: Boolean, use a unidirectional or bidirectional encoder.
        use_key_relative_position: Boolean, if 'True' use key
            relative embeddings in attention.
        use_value_relative_position: Boolean, if 'True' use value
            relative embeddings in attention.
        max_relative_position: Non-negative integer, max position for relative
            embeddings.
        heads_share_relative_embedding: Boolean, if 'True'
            heads will share relative embeddings.
    """

    def __init__(
        self,
        units: int,
        num_heads: int,
        attention_dropout_rate: float = 0.0,
        density: float = 0.2,
        unidirectional: bool = False,
        use_key_relative_position: bool = False,
        use_value_relative_position: bool = False,
        max_relative_position: int = 5,
        heads_share_relative_embedding: bool = False,
    ) -> None:
        super().__init__()

        if units % num_heads != 0:
            rasa.shared.utils.cli.print_error_and_exit(
                f"Value Error: The given transformer size {units} should be a "
                f"multiple of the number of attention heads {num_heads}."
            )

        self.num_heads = num_heads
        self.units = units
        self.attention_dropout_rate = attention_dropout_rate
        self.unidirectional = unidirectional
        self.use_key_relative_position = use_key_relative_position
        self.use_value_relative_position = use_value_relative_position
        # number of relative offsets on one side, including the current time
        self.relative_length = max_relative_position + 1
        self.heads_share_relative_embedding = heads_share_relative_embedding

        # per-head feature size
        self._depth = units // self.num_heads

        # process queries
        self._query_dense_layer = RandomlyConnectedDense(
            units=units, use_bias=False, density=density
        )
        # process keys
        self._key_dense_layer = RandomlyConnectedDense(
            units=units, use_bias=False, density=density
        )
        # process values
        self._value_dense_layer = RandomlyConnectedDense(
            units=units, use_bias=False, density=density
        )
        # process attention output
        self._output_dense_layer = RandomlyConnectedDense(units=units, density=density)

        self._create_relative_embeddings()

    def _create_relative_embeddings(self) -> None:
        """Create the key / value relative position embedding weights, if enabled.

        `key_relative_embeddings` / `value_relative_embeddings` stay `None` when
        the corresponding relative position type is disabled. The weight shape
        is `(relative_length, depth)` when heads share embeddings, otherwise
        `(num_heads, relative_length, depth)`; in the bidirectional case
        `relative_length` covers both sides, i.e. `2 * self.relative_length - 1`.

        Raises:
            ValueError: If relative attention is requested with a negative
                `max_relative_position`.
        """
        relative_embedding_shape: Optional[
            Union[Tuple[int, int], Tuple[int, int, int]]
        ] = None
        self.key_relative_embeddings = None
        self.value_relative_embeddings = None

        if self.use_key_relative_position or self.use_value_relative_position:
            # `relative_length` is `max_relative_position + 1`; anything below 1
            # would otherwise surface later as an obscure negative-shape error
            # from `add_weight`.
            if self.relative_length < 1:
                raise ValueError(
                    f"Max relative position {self.relative_length - 1} "
                    "should be >= 0 when using relative attention."
                )

            if self.unidirectional:
                relative_length = self.relative_length
            else:
                relative_length = 2 * self.relative_length - 1

            if self.heads_share_relative_embedding:
                relative_embedding_shape = (relative_length, self._depth)
            else:
                relative_embedding_shape = (
                    self.num_heads,
                    relative_length,
                    self._depth,
                )

        if self.use_key_relative_position:
            self.key_relative_embeddings = self.add_weight(
                shape=relative_embedding_shape, name="key_relative_embeddings"
            )

        if self.use_value_relative_position:
            self.value_relative_embeddings = self.add_weight(
                shape=relative_embedding_shape, name="value_relative_embeddings"
            )

    def _pad_relative_embeddings(self, x: tf.Tensor, length: tf.Tensor) -> tf.Tensor:
        """Extend axis -2 of a 5-D relative tensor to `2 * length - 1` entries.

        Used when the sequence is longer than `relative_length`. The left side
        repeats the first entry. The right side repeats the last entry
        (bidirectional) or is filled with zeros (unidirectional).
        """
        # pad the left side to length
        pad_left = x[:, :, :, :1, :]
        pad_left = tf.tile(pad_left, (1, 1, 1, length - self.relative_length, 1))

        # pad the right side to length
        if self.unidirectional:
            right_relative_length = 1  # current time
            pad_right = tf.zeros_like(x[:, :, :, -1:, :])
        else:
            right_relative_length = self.relative_length
            pad_right = x[:, :, :, -1:, :]
        pad_right = tf.tile(pad_right, (1, 1, 1, length - right_relative_length, 1))

        return tf.concat([pad_left, x, pad_right], axis=-2)

    def _slice_relative_embeddings(self, x: tf.Tensor, length: tf.Tensor) -> tf.Tensor:
        """Cut axis -2 of a 5-D relative tensor down to `2 * length - 1` entries.

        Used when the sequence is not longer than `relative_length`. In the
        unidirectional case, zeros are first appended for the missing "future"
        side.
        """
        if self.unidirectional:
            # pad the right side to relative_length
            pad_right = tf.zeros_like(x[:, :, :, -1:, :])
            pad_right = tf.tile(pad_right, (1, 1, 1, self.relative_length - 1, 1))
            x = tf.concat([x, pad_right], axis=-2)

        extra_length = self.relative_length - length
        full_length = tf.shape(x)[-2]
        return x[:, :, :, extra_length : full_length - extra_length, :]

    def _relative_to_absolute_position(self, x: tf.Tensor) -> tf.Tensor:
        """Universal method to convert tensor from relative to absolute indexing.

        "Slides" relative embeddings by 45 degree.

        Arguments:
            x: A tensor of shape (batch, num_heads, length, relative_length, depth)
                or (batch, num_heads, length, relative_length)

        Returns:
            A tensor of shape (batch, num_heads, length, length, depth)
            or (batch, num_heads, length, length)

        Raises:
            ValueError: If `x` does not have 4 or 5 dimensions.
        """
        x_dim = len(x.shape)

        if x_dim < 4 or x_dim > 5:
            raise ValueError(
                f"Relative tensor has a wrong shape {x.shape}, "
                f"it should have 4 or 5 dimensions."
            )
        if x_dim == 4:
            # add fake depth dimension
            x = tf.expand_dims(x, axis=-1)

        batch = tf.shape(x)[0]
        num_heads = tf.shape(x)[1]
        length = tf.shape(x)[2]
        depth = tf.shape(x)[-1]

        # bring axis -2 to exactly 2 * length - 1 relative offsets
        x = tf.cond(
            length > self.relative_length,
            lambda: self._pad_relative_embeddings(x, length),
            lambda: self._slice_relative_embeddings(x, length),
        )

        # add a column of zeros to "slide" columns to diagonals through reshape
        pad_shift = tf.zeros_like(x[:, :, :, -1:, :])
        x = tf.concat([x, pad_shift], axis=-2)

        # flatten length dimensions
        x = tf.reshape(x, (batch, num_heads, -1, depth))
        width = 2 * length

        # add zeros so that the result of back reshape is still a matrix
        pad_flat = tf.zeros_like(
            x[:, :, : ((width - 1) - width * length % (width - 1)) % (width - 1), :]
        )
        x = tf.concat([x, pad_flat], axis=-2)

        # "slide" columns to diagonals through reshape
        x = tf.reshape(x, (batch, num_heads, -1, width - 1, depth))

        # slice needed "diagonal" matrix
        x = x[:, :, :-1, -length:, :]

        if x_dim == 4:
            # remove fake depth dimension
            x = tf.squeeze(x, axis=-1)

        return x

    def _matmul_with_relative_keys(self, x: tf.Tensor) -> tf.Tensor:
        """Return query-vs-relative-key logits in absolute indexing.

        `x` is the split-heads query (batch, num_heads, length, depth); the
        result has shape (batch, num_heads, length, length).
        """
        y = self.key_relative_embeddings

        if self.heads_share_relative_embedding:
            matmul = tf.einsum("bhld,md->bhlm", x, y)
        else:
            matmul = tf.einsum("bhld,hmd->bhlm", x, y)

        return self._relative_to_absolute_position(matmul)

    def _tile_relative_embeddings(self, x: tf.Tensor, length: tf.Tensor) -> tf.Tensor:
        """Tile relative embeddings to (1, heads, length, relative_length, depth).

        `heads` is 1 when the embeddings are shared across heads.
        """
        if self.heads_share_relative_embedding:
            x = tf.expand_dims(x, axis=0)  # add head dimension

        x = tf.expand_dims(x, axis=1)  # add length dimension
        x = tf.tile(x, (1, length, 1, 1))
        return tf.expand_dims(x, axis=0)  # add batch dimension

    def _squeeze_relative_embeddings(self, x: tf.Tensor) -> tf.Tensor:
        """Remove the size-1 batch (and shared head) axes added for tiling.

        Input is (1, heads, length, length, depth). The output is
        (heads, length, length, depth), or (length, length, depth) when heads
        share embeddings.
        """
        if self.heads_share_relative_embedding:
            # the shared head axis is axis 1 while the batch axis is still
            # present; squeezing it after the batch axis would instead hit the
            # length axis and fail for any length > 1
            x = tf.squeeze(x, axis=1)  # squeeze head dimension
        return tf.squeeze(x, axis=0)  # squeeze batch dimension

    def _matmul_with_relative_values(self, x: tf.Tensor) -> tf.Tensor:
        """Combine attention weights with relative value embeddings.

        `x` is the attention weights (batch, num_heads, length, length); the
        result has shape (batch, num_heads, length, depth).
        """
        y = self._tile_relative_embeddings(
            self.value_relative_embeddings, tf.shape(x)[-2]
        )
        y = self._relative_to_absolute_position(y)
        y = self._squeeze_relative_embeddings(y)

        if self.heads_share_relative_embedding:
            return tf.einsum("bhlm,lmd->bhld", x, y)
        else:
            return tf.einsum("bhlm,hlmd->bhld", x, y)

    def _drop_attention_logits(
        self,
        logits: tf.Tensor,
        pad_mask: Optional[tf.Tensor],
        training: Union[tf.Tensor, bool],
    ) -> tf.Tensor:
        """Apply attention dropout to the logits when `training` is true.

        A dropped logit gets -1e9 added so its softmax weight is ~0. Each logit
        is dropped when a uniform sample (plus `pad_mask`, if given) is below
        `attention_dropout_rate`. Adding the pad mask keeps already-masked
        positions out of the dropout draw (assuming a 0/1 mask).
        """

        def dropped_logits() -> tf.Tensor:
            keep_prob = tf.random.uniform(tf.shape(logits), 0, 1)
            if pad_mask is not None:
                keep_prob += pad_mask
            drop_mask = tf.cast(
                tf.less(keep_prob, self.attention_dropout_rate), logits.dtype
            )

            return logits + drop_mask * -1e9

        return smart_cond(training, dropped_logits, lambda: tf.identity(logits))

    def _scaled_dot_product_attention(
        self,
        query: tf.Tensor,
        key: tf.Tensor,
        value: tf.Tensor,
        pad_mask: Optional[tf.Tensor],
        training: Union[tf.Tensor, bool],
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """Calculate the attention weights.

        query, key, value must have matching leading dimensions.
        key, value must have matching penultimate dimension,
        i.e.: seq_len_k = seq_len_v.
        The mask has different shapes depending on its type (padding or look ahead)
        but it must be broadcastable for addition.

        Arguments:
            query: A tensor with shape (..., length, depth).
            key: A tensor with shape (..., length, depth).
            value: A tensor with shape (..., length, depth).
            pad_mask: Float tensor with shape broadcastable
                to (..., length, length), or None for no masking.
            training: Whether attention dropout should be applied.

        Returns:
            output: A tensor with shape (..., length, depth).
            attention_weights: A tensor with shape (..., length, length).
        """
        matmul_qk = tf.matmul(query, key, transpose_b=True)  # (..., length, length)

        if self.use_key_relative_position:
            matmul_qk += self._matmul_with_relative_keys(query)

        # scale matmul_qk
        dk = tf.cast(tf.shape(key)[-1], tf.float32)
        logits = matmul_qk / tf.math.sqrt(dk)

        # add the mask to the scaled tensor.
        if pad_mask is not None:
            logits += pad_mask * -1e9

        # apply attention dropout before softmax to maintain attention_weights norm as 1
        if self.attention_dropout_rate > 0:
            logits = self._drop_attention_logits(logits, pad_mask, training)

        # softmax is normalized on the last axis (length) so that the scores
        # add up to 1.
        attention_weights = tf.nn.softmax(logits, axis=-1)  # (..., length, length)

        output = tf.matmul(attention_weights, value)  # (..., length, depth)
        if self.use_value_relative_position:
            output += self._matmul_with_relative_values(attention_weights)

        return output, attention_weights

    def _split_heads(self, x: tf.Tensor) -> tf.Tensor:
        """Split the last dimension into (num_heads, depth).

        Transpose the result such that the shape is
        (batch_size, num_heads, length, depth)
        """
        x = tf.reshape(x, (tf.shape(x)[0], -1, self.num_heads, self._depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def _combine_heads(self, x: tf.Tensor) -> tf.Tensor:
        """Inverse of split_heads.

        Args:
            x: A Tensor with shape [batch, num_heads, length, units / num_heads]

        Returns:
            A Tensor with shape [batch, length, units]
        """
        # (batch_size, length, num_heads, depth)
        x = tf.transpose(x, perm=[0, 2, 1, 3])
        # (batch_size, length, units)
        return tf.reshape(x, (tf.shape(x)[0], -1, self.units))

    # noinspection PyMethodOverriding
    def call(
        self,
        query_input: tf.Tensor,
        source_input: tf.Tensor,
        pad_mask: Optional[tf.Tensor] = None,
        training: Optional[Union[tf.Tensor, bool]] = None,
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """Apply attention mechanism to query_input and source_input.

        Arguments:
            query_input: A tensor with shape [batch_size, length, input_size].
            source_input: A tensor with shape [batch_size, length, input_size].
            pad_mask: Float tensor with shape broadcastable
                to (..., length, length). Defaults to None.
            training: A bool, whether in training mode or not.

        Returns:
            Attention layer output with shape [batch_size, length, units] and
            attention weights with shape [batch_size, num_heads, length, length].
        """
        if training is None:
            training = K.learning_phase()

        query = self._query_dense_layer(query_input)  # (batch_size, length, units)
        key = self._key_dense_layer(source_input)  # (batch_size, length, units)
        value = self._value_dense_layer(source_input)  # (batch_size, length, units)

        query = self._split_heads(query)  # (batch_size, num_heads, length, depth)
        key = self._split_heads(key)  # (batch_size, num_heads, length, depth)
        value = self._split_heads(value)  # (batch_size, num_heads, length, depth)

        attention, attention_weights = self._scaled_dot_product_attention(
            query, key, value, pad_mask, training
        )
        # attention.shape == (batch_size, num_heads, length, depth)
        # attention_weights.shape == (batch_size, num_heads, length, length)
        attention = self._combine_heads(attention)  # (batch_size, length, units)

        output = self._output_dense_layer(attention)  # (batch_size, length, units)

        return output, attention_weights
379
+
380
+
381
class TransformerEncoderLayer(tf.keras.layers.Layer):
    """One transformer encoder block with pre-sublayer layer normalization.

    Two residual sublayers are applied in sequence:
      1. layer norm -> multi-head self-attention -> dropout, added to the input;
      2. layer norm -> dense (gelu, `filter_units`) -> dropout -> dense (`units`)
         -> dropout, added to the result of step 1.

    Arguments:
        units: Positive integer, output dim of hidden layer.
        num_heads: Positive integer, number of heads
            to repeat the same attention structure.
        filter_units: Positive integer, output dim of the first ffn hidden layer.
        dropout_rate: Float between 0 and 1; fraction of the input units to drop.
        attention_dropout_rate: Float, dropout rate inside attention for training.
        density: Fraction of trainable weights in `RandomlyConnectedDense` layers.
        unidirectional: Boolean, use a unidirectional or bidirectional encoder.
        use_key_relative_position: Boolean, if 'True' use key
            relative embeddings in attention.
        use_value_relative_position: Boolean, if 'True' use value
            relative embeddings in attention.
        max_relative_position: Positive integer, max position for relative embeddings.
        heads_share_relative_embedding: Boolean, if 'True'
            heads will share relative embeddings.
    """

    def __init__(
        self,
        units: int,
        num_heads: int,
        filter_units: int,
        dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        density: float = 0.2,
        unidirectional: bool = False,
        use_key_relative_position: bool = False,
        use_value_relative_position: bool = False,
        max_relative_position: int = 5,
        heads_share_relative_embedding: bool = False,
    ) -> None:
        super().__init__()

        # --- self-attention sublayer ---
        self._layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self._mha = MultiHeadAttention(
            units=units,
            num_heads=num_heads,
            attention_dropout_rate=attention_dropout_rate,
            density=density,
            unidirectional=unidirectional,
            use_key_relative_position=use_key_relative_position,
            use_value_relative_position=use_value_relative_position,
            max_relative_position=max_relative_position,
            heads_share_relative_embedding=heads_share_relative_embedding,
        )
        self._dropout = tf.keras.layers.Dropout(dropout_rate)

        # --- feed-forward sublayer (built in application order) ---
        ffn_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        # (batch_size, length, filter_units)
        ffn_expand = RandomlyConnectedDense(
            units=filter_units, activation=tf.nn.gelu, density=density
        )
        ffn_expand_dropout = tf.keras.layers.Dropout(dropout_rate)
        # (batch_size, length, units)
        ffn_project = RandomlyConnectedDense(units=units, density=density)
        ffn_project_dropout = tf.keras.layers.Dropout(dropout_rate)
        self._ffn_layers = [
            ffn_norm,
            ffn_expand,
            ffn_expand_dropout,
            ffn_project,
            ffn_project_dropout,
        ]

    def call(
        self,
        x: tf.Tensor,
        pad_mask: Optional[tf.Tensor] = None,
        training: Optional[Union[tf.Tensor, bool]] = None,
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """Run the encoder block on `x`.

        Arguments:
            x: A tensor with shape [batch_size, length, units].
            pad_mask: Float tensor with shape broadcastable
                to (..., length, length). Defaults to None.
            training: A bool, whether in training mode or not. When None, the
                Keras learning phase is used.

        Returns:
            The block output with shape [batch_size, length, units] and the
            attention weights with shape [batch_size, num_heads, length, length].
        """
        if training is None:
            training = K.learning_phase()

        # residual self-attention branch
        normed = self._layer_norm(x)
        attention_output, attention_weights = self._mha(
            normed, normed, pad_mask=pad_mask, training=training
        )
        x = x + self._dropout(attention_output, training=training)

        # residual feed-forward branch
        hidden = x
        for sublayer in self._ffn_layers:
            hidden = sublayer(hidden, training=training)

        return x + hidden, attention_weights
482
+
483
+
484
class TransformerEncoder(tf.keras.layers.Layer):
    """Transformer encoder.

    Encoder stack is made up of `num_layers` identical encoder layers. The
    input is first projected to `units`, scaled by sqrt(units) and summed with
    sinusoidal positional encodings. The stack output is layer-normalized.

    Arguments:
        num_layers: Positive integer, number of encoder layers.
        units: Positive integer, output dim of hidden layer.
        num_heads: Positive integer, number of heads
            to repeat the same attention structure.
        filter_units: Positive integer, output dim of the first ffn hidden layer.
        reg_lambda: Float, regularization factor.
        dropout_rate: Float between 0 and 1; fraction of the input units to drop.
        attention_dropout_rate: Float, dropout rate inside attention for training.
        density: Approximate fraction of trainable weights (in
            `RandomlyConnectedDense` layers).
        unidirectional: Boolean, use a unidirectional or bidirectional encoder.
        use_key_relative_position: Boolean, if 'True' use key
            relative embeddings in attention.
        use_value_relative_position: Boolean, if 'True' use value
            relative embeddings in attention.
        max_relative_position: Positive integer, max position for relative embeddings.
        heads_share_relative_embedding: Boolean, if 'True'
            heads will share relative embeddings.
        name: Optional name of the layer.
    """

    def __init__(
        self,
        num_layers: int,
        units: int,
        num_heads: int,
        filter_units: int,
        reg_lambda: float,
        dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        density: float = 0.2,
        unidirectional: bool = False,
        use_key_relative_position: bool = False,
        use_value_relative_position: bool = False,
        max_relative_position: int = 5,
        heads_share_relative_embedding: bool = False,
        name: Optional[Text] = None,
    ) -> None:
        super().__init__(name=name)

        self.units = units
        self.unidirectional = unidirectional

        l2_regularizer = tf.keras.regularizers.l2(reg_lambda)
        self._embedding = RandomlyConnectedDense(
            units=units, kernel_regularizer=l2_regularizer, density=density
        )
        # positional encoding helpers
        self._angles = self._get_angles()
        self._even_indices = np.arange(0, self.units, 2, dtype=np.int32)[:, np.newaxis]
        self._odd_indices = np.arange(1, self.units, 2, dtype=np.int32)[:, np.newaxis]

        self._dropout = tf.keras.layers.Dropout(dropout_rate)

        self._enc_layers = [
            TransformerEncoderLayer(
                units,
                num_heads,
                filter_units,
                dropout_rate,
                attention_dropout_rate,
                density,
                unidirectional,
                use_key_relative_position,
                use_value_relative_position,
                max_relative_position,
                heads_share_relative_embedding,
            )
            for _ in range(num_layers)
        ]
        self._layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    def _get_angles(self) -> np.ndarray:
        """Return the (1, units) angle rates of the sinusoidal encoding.

        Dimensions 2i and 2i + 1 share the rate 1 / 10000 ** (2i / units).
        """
        array_2d = np.arange(self.units)[np.newaxis, :]
        return 1 / np.power(10000, (2 * (array_2d // 2)) / np.float32(self.units))

    def _positional_encoding(self, max_position: tf.Tensor) -> tf.Tensor:
        """Build sinusoidal positional encodings of shape (1, max_position, units).

        Even feature indices use sin, odd ones use cos. Gradients are stopped.
        """
        max_position = tf.cast(max_position, dtype=tf.float32)
        angle_rads = tf.range(max_position)[:, tf.newaxis] * self._angles

        # transpose for easy slicing
        angle_rads = tf.transpose(angle_rads, perm=[1, 0])
        shape = tf.shape(angle_rads)
        # apply sin to even indices in the array; 2i
        sin_even = tf.sin(tf.gather_nd(angle_rads, self._even_indices))
        pos_encoding_even = tf.scatter_nd(self._even_indices, sin_even, shape)
        # apply cos to odd indices in the array; 2i+1
        cos_odd = tf.cos(tf.gather_nd(angle_rads, self._odd_indices))
        pos_encoding_odd = tf.scatter_nd(self._odd_indices, cos_odd, shape)
        # combine even and odd positions and transpose back
        pos_encoding = tf.transpose(pos_encoding_even + pos_encoding_odd, perm=[1, 0])
        # add batch dimension
        return tf.stop_gradient(pos_encoding[tf.newaxis, ...])

    @staticmethod
    def _look_ahead_pad_mask(max_position: tf.Tensor) -> tf.Tensor:
        """Return a (1, 1, max_position, max_position) mask of future positions.

        Entries strictly above the diagonal are 1 (masked); the diagonal and
        everything below it are 0.
        """
        pad_mask = 1 - tf.linalg.band_part(tf.ones((max_position, max_position)), -1, 0)
        return pad_mask[tf.newaxis, tf.newaxis, :, :]  # (1, 1, seq_len, seq_len)

    def call(
        self,
        x: tf.Tensor,
        pad_mask: Optional[tf.Tensor] = None,
        training: Optional[Union[tf.Tensor, bool]] = None,
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """Apply transformer encoder.

        Arguments:
            x: A tensor with shape [batch_size, length, input_size].
            pad_mask: Float tensor with shape [batch_size, length, 1] (its last
                axis is squeezed), or None for no padding mask.
            training: A bool, whether in training mode or not.

        Returns:
            Transformer encoder output with shape [batch_size, length, units]
            and attention weights with shape
            [batch_size, num_layers, num_heads, length, length].
        """
        # adding embedding and position encoding.
        x = self._embedding(x)  # (batch_size, length, units)
        x *= tf.math.sqrt(tf.cast(self.units, tf.float32))
        x += self._positional_encoding(tf.shape(x)[1])
        x = self._dropout(x, training=training)

        if pad_mask is not None:
            pad_mask = tf.squeeze(pad_mask, -1)  # (batch_size, length)
            pad_mask = pad_mask[:, tf.newaxis, tf.newaxis, :]
            # pad_mask.shape = (batch_size, 1, 1, length)

        if self.unidirectional:
            # add look ahead pad mask to emulate unidirectional behavior. This
            # must happen even without a padding mask, otherwise every position
            # could attend to future positions.
            look_ahead_mask = self._look_ahead_pad_mask(
                tf.shape(x)[1]
            )  # (1, 1, length, length)
            if pad_mask is None:
                pad_mask = look_ahead_mask
            else:
                pad_mask = tf.minimum(
                    1.0, pad_mask + look_ahead_mask
                )  # (batch_size, 1, length, length)

        layer_attention_weights = []

        for layer in self._enc_layers:
            x, attn_weights = layer(x, pad_mask=pad_mask, training=training)
            layer_attention_weights.append(attn_weights)

        # if normalization is done in encoding layers, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        x = self._layer_norm(x)  # (batch_size, length, units)

        # Keep the batch dimension on the first axis
        attention_weights_as_output = tf.transpose(
            tf.stack(layer_attention_weights), (1, 0, 2, 3, 4)
        )

        # (batch_size, length, units),
        # (batch_size, num_layers, num_heads, length, length)
        return x, attention_weights_as_output
@@ -0,0 +1,6 @@
1
+ from typing import Tuple, Union
2
+ import tensorflow as tf
3
+ import numpy as np
4
+
5
# A batch: a variable-length tuple of TensorFlow tensors, or a variable-length
# tuple of NumPy arrays.
BatchData = Union[
    Tuple[tf.Tensor, ...],
    Tuple[np.ndarray, ...],
]

# Either a tuple of `BatchData` entries or a single `BatchData`.
MaybeNestedBatchData = Union[
    Tuple[BatchData, ...],
    BatchData,
]