rasa-pro 3.9.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (662)
  1. README.md +415 -0
  2. rasa/__init__.py +10 -0
  3. rasa/__main__.py +156 -0
  4. rasa/anonymization/__init__.py +2 -0
  5. rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
  6. rasa/anonymization/anonymization_pipeline.py +286 -0
  7. rasa/anonymization/anonymization_rule_executor.py +260 -0
  8. rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
  9. rasa/anonymization/schemas/config.yml +47 -0
  10. rasa/anonymization/utils.py +118 -0
  11. rasa/api.py +146 -0
  12. rasa/cli/__init__.py +5 -0
  13. rasa/cli/arguments/__init__.py +0 -0
  14. rasa/cli/arguments/data.py +81 -0
  15. rasa/cli/arguments/default_arguments.py +165 -0
  16. rasa/cli/arguments/evaluate.py +65 -0
  17. rasa/cli/arguments/export.py +51 -0
  18. rasa/cli/arguments/interactive.py +74 -0
  19. rasa/cli/arguments/run.py +204 -0
  20. rasa/cli/arguments/shell.py +13 -0
  21. rasa/cli/arguments/test.py +211 -0
  22. rasa/cli/arguments/train.py +263 -0
  23. rasa/cli/arguments/visualize.py +34 -0
  24. rasa/cli/arguments/x.py +30 -0
  25. rasa/cli/data.py +292 -0
  26. rasa/cli/e2e_test.py +586 -0
  27. rasa/cli/evaluate.py +222 -0
  28. rasa/cli/export.py +250 -0
  29. rasa/cli/inspect.py +63 -0
  30. rasa/cli/interactive.py +164 -0
  31. rasa/cli/license.py +65 -0
  32. rasa/cli/markers.py +78 -0
  33. rasa/cli/project_templates/__init__.py +0 -0
  34. rasa/cli/project_templates/calm/actions/__init__.py +0 -0
  35. rasa/cli/project_templates/calm/actions/action_template.py +27 -0
  36. rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
  37. rasa/cli/project_templates/calm/actions/db.py +57 -0
  38. rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
  39. rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
  40. rasa/cli/project_templates/calm/config.yml +12 -0
  41. rasa/cli/project_templates/calm/credentials.yml +33 -0
  42. rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
  43. rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
  44. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
  45. rasa/cli/project_templates/calm/db/contacts.json +10 -0
  46. rasa/cli/project_templates/calm/domain/add_contact.yml +39 -0
  47. rasa/cli/project_templates/calm/domain/list_contacts.yml +17 -0
  48. rasa/cli/project_templates/calm/domain/remove_contact.yml +38 -0
  49. rasa/cli/project_templates/calm/domain/shared.yml +10 -0
  50. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
  51. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
  52. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
  53. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
  54. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
  55. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
  56. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
  57. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
  58. rasa/cli/project_templates/calm/endpoints.yml +45 -0
  59. rasa/cli/project_templates/default/actions/__init__.py +0 -0
  60. rasa/cli/project_templates/default/actions/actions.py +27 -0
  61. rasa/cli/project_templates/default/config.yml +44 -0
  62. rasa/cli/project_templates/default/credentials.yml +33 -0
  63. rasa/cli/project_templates/default/data/nlu.yml +91 -0
  64. rasa/cli/project_templates/default/data/rules.yml +13 -0
  65. rasa/cli/project_templates/default/data/stories.yml +30 -0
  66. rasa/cli/project_templates/default/domain.yml +34 -0
  67. rasa/cli/project_templates/default/endpoints.yml +42 -0
  68. rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
  69. rasa/cli/project_templates/tutorial/actions.py +22 -0
  70. rasa/cli/project_templates/tutorial/config.yml +11 -0
  71. rasa/cli/project_templates/tutorial/credentials.yml +33 -0
  72. rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
  73. rasa/cli/project_templates/tutorial/data/patterns.yml +6 -0
  74. rasa/cli/project_templates/tutorial/domain.yml +21 -0
  75. rasa/cli/project_templates/tutorial/endpoints.yml +45 -0
  76. rasa/cli/run.py +135 -0
  77. rasa/cli/scaffold.py +269 -0
  78. rasa/cli/shell.py +141 -0
  79. rasa/cli/studio/__init__.py +0 -0
  80. rasa/cli/studio/download.py +62 -0
  81. rasa/cli/studio/studio.py +266 -0
  82. rasa/cli/studio/train.py +59 -0
  83. rasa/cli/studio/upload.py +77 -0
  84. rasa/cli/telemetry.py +102 -0
  85. rasa/cli/test.py +280 -0
  86. rasa/cli/train.py +260 -0
  87. rasa/cli/utils.py +464 -0
  88. rasa/cli/visualize.py +40 -0
  89. rasa/cli/x.py +206 -0
  90. rasa/constants.py +37 -0
  91. rasa/core/__init__.py +17 -0
  92. rasa/core/actions/__init__.py +0 -0
  93. rasa/core/actions/action.py +1225 -0
  94. rasa/core/actions/action_clean_stack.py +59 -0
  95. rasa/core/actions/action_exceptions.py +24 -0
  96. rasa/core/actions/action_run_slot_rejections.py +207 -0
  97. rasa/core/actions/action_trigger_chitchat.py +31 -0
  98. rasa/core/actions/action_trigger_flow.py +109 -0
  99. rasa/core/actions/action_trigger_search.py +31 -0
  100. rasa/core/actions/constants.py +5 -0
  101. rasa/core/actions/custom_action_executor.py +188 -0
  102. rasa/core/actions/forms.py +741 -0
  103. rasa/core/actions/grpc_custom_action_executor.py +251 -0
  104. rasa/core/actions/http_custom_action_executor.py +140 -0
  105. rasa/core/actions/loops.py +114 -0
  106. rasa/core/actions/two_stage_fallback.py +186 -0
  107. rasa/core/agent.py +555 -0
  108. rasa/core/auth_retry_tracker_store.py +122 -0
  109. rasa/core/brokers/__init__.py +0 -0
  110. rasa/core/brokers/broker.py +126 -0
  111. rasa/core/brokers/file.py +58 -0
  112. rasa/core/brokers/kafka.py +322 -0
  113. rasa/core/brokers/pika.py +386 -0
  114. rasa/core/brokers/sql.py +86 -0
  115. rasa/core/channels/__init__.py +55 -0
  116. rasa/core/channels/audiocodes.py +463 -0
  117. rasa/core/channels/botframework.py +338 -0
  118. rasa/core/channels/callback.py +84 -0
  119. rasa/core/channels/channel.py +419 -0
  120. rasa/core/channels/console.py +241 -0
  121. rasa/core/channels/development_inspector.py +93 -0
  122. rasa/core/channels/facebook.py +419 -0
  123. rasa/core/channels/hangouts.py +329 -0
  124. rasa/core/channels/inspector/.eslintrc.cjs +25 -0
  125. rasa/core/channels/inspector/.gitignore +23 -0
  126. rasa/core/channels/inspector/README.md +54 -0
  127. rasa/core/channels/inspector/assets/favicon.ico +0 -0
  128. rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
  129. rasa/core/channels/inspector/custom.d.ts +3 -0
  130. rasa/core/channels/inspector/dist/assets/arc-b6e548fe.js +1 -0
  131. rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
  132. rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-fa03ac9e.js +10 -0
  133. rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-ee67392a.js +2 -0
  134. rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-9b283fae.js +2 -0
  135. rasa/core/channels/inspector/dist/assets/createText-62fc7601-8b6fcc2a.js +7 -0
  136. rasa/core/channels/inspector/dist/assets/edges-f2ad444c-22e77f4f.js +4 -0
  137. rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-60ffc87f.js +51 -0
  138. rasa/core/channels/inspector/dist/assets/flowDb-1972c806-9dd802e4.js +6 -0
  139. rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-5fa1912f.js +4 -0
  140. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +1 -0
  141. rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-622a1fd2.js +139 -0
  142. rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-e285a63a.js +266 -0
  143. rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-f237bdca.js +70 -0
  144. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
  145. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
  146. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
  147. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
  148. rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-4b03d70e.js +1 -0
  149. rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
  150. rasa/core/channels/inspector/dist/assets/index-a5d3e69d.js +1040 -0
  151. rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-72a0fa5f.js +7 -0
  152. rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
  153. rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-82218c41.js +139 -0
  154. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
  155. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
  156. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
  157. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
  158. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
  159. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
  160. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
  161. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
  162. rasa/core/channels/inspector/dist/assets/layout-78cff630.js +1 -0
  163. rasa/core/channels/inspector/dist/assets/line-5038b469.js +1 -0
  164. rasa/core/channels/inspector/dist/assets/linear-c4fc4098.js +1 -0
  165. rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-c33c8ea6.js +109 -0
  166. rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
  167. rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
  168. rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-a8d03059.js +35 -0
  169. rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-6a0e56b2.js +7 -0
  170. rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-2dc7c7bd.js +52 -0
  171. rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-2360fe39.js +8 -0
  172. rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-41b9f9ad.js +122 -0
  173. rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-0aad326f.js +1 -0
  174. rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-9847d984.js +1 -0
  175. rasa/core/channels/inspector/dist/assets/styles-080da4f6-564d890e.js +110 -0
  176. rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-38957613.js +159 -0
  177. rasa/core/channels/inspector/dist/assets/styles-9c745c82-f0fc6921.js +207 -0
  178. rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-ef3c5a77.js +1 -0
  179. rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-bf3e91c1.js +61 -0
  180. rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-4d4026c0.js +7 -0
  181. rasa/core/channels/inspector/dist/index.html +41 -0
  182. rasa/core/channels/inspector/index.html +39 -0
  183. rasa/core/channels/inspector/jest.config.ts +13 -0
  184. rasa/core/channels/inspector/package.json +48 -0
  185. rasa/core/channels/inspector/setupTests.ts +2 -0
  186. rasa/core/channels/inspector/src/App.tsx +170 -0
  187. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +107 -0
  188. rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
  189. rasa/core/channels/inspector/src/components/DialogueStack.tsx +151 -0
  190. rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
  191. rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
  192. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +19 -0
  193. rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
  194. rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
  195. rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
  196. rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
  197. rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
  198. rasa/core/channels/inspector/src/helpers/formatters.test.ts +382 -0
  199. rasa/core/channels/inspector/src/helpers/formatters.ts +240 -0
  200. rasa/core/channels/inspector/src/helpers/utils.ts +42 -0
  201. rasa/core/channels/inspector/src/main.tsx +13 -0
  202. rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
  203. rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
  204. rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
  205. rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
  206. rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
  207. rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
  208. rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
  209. rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
  210. rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
  211. rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
  212. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
  213. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
  214. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
  215. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
  216. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
  217. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
  218. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
  219. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
  220. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
  221. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
  222. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
  223. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
  224. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
  225. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
  226. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
  227. rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
  228. rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
  229. rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
  230. rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
  231. rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
  232. rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
  233. rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
  234. rasa/core/channels/inspector/src/theme/index.ts +101 -0
  235. rasa/core/channels/inspector/src/types.ts +64 -0
  236. rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
  237. rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
  238. rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
  239. rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
  240. rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
  241. rasa/core/channels/inspector/tsconfig.json +26 -0
  242. rasa/core/channels/inspector/tsconfig.node.json +10 -0
  243. rasa/core/channels/inspector/vite.config.ts +8 -0
  244. rasa/core/channels/inspector/yarn.lock +6156 -0
  245. rasa/core/channels/mattermost.py +229 -0
  246. rasa/core/channels/rasa_chat.py +126 -0
  247. rasa/core/channels/rest.py +225 -0
  248. rasa/core/channels/rocketchat.py +174 -0
  249. rasa/core/channels/slack.py +620 -0
  250. rasa/core/channels/socketio.py +274 -0
  251. rasa/core/channels/telegram.py +298 -0
  252. rasa/core/channels/twilio.py +169 -0
  253. rasa/core/channels/twilio_voice.py +367 -0
  254. rasa/core/channels/vier_cvg.py +374 -0
  255. rasa/core/channels/webexteams.py +134 -0
  256. rasa/core/concurrent_lock_store.py +210 -0
  257. rasa/core/constants.py +107 -0
  258. rasa/core/evaluation/__init__.py +0 -0
  259. rasa/core/evaluation/marker.py +267 -0
  260. rasa/core/evaluation/marker_base.py +923 -0
  261. rasa/core/evaluation/marker_stats.py +293 -0
  262. rasa/core/evaluation/marker_tracker_loader.py +103 -0
  263. rasa/core/exceptions.py +29 -0
  264. rasa/core/exporter.py +284 -0
  265. rasa/core/featurizers/__init__.py +0 -0
  266. rasa/core/featurizers/precomputation.py +410 -0
  267. rasa/core/featurizers/single_state_featurizer.py +421 -0
  268. rasa/core/featurizers/tracker_featurizers.py +1262 -0
  269. rasa/core/http_interpreter.py +89 -0
  270. rasa/core/information_retrieval/__init__.py +7 -0
  271. rasa/core/information_retrieval/faiss.py +121 -0
  272. rasa/core/information_retrieval/information_retrieval.py +129 -0
  273. rasa/core/information_retrieval/milvus.py +52 -0
  274. rasa/core/information_retrieval/qdrant.py +95 -0
  275. rasa/core/jobs.py +63 -0
  276. rasa/core/lock.py +139 -0
  277. rasa/core/lock_store.py +343 -0
  278. rasa/core/migrate.py +403 -0
  279. rasa/core/nlg/__init__.py +3 -0
  280. rasa/core/nlg/callback.py +146 -0
  281. rasa/core/nlg/contextual_response_rephraser.py +270 -0
  282. rasa/core/nlg/generator.py +230 -0
  283. rasa/core/nlg/interpolator.py +143 -0
  284. rasa/core/nlg/response.py +155 -0
  285. rasa/core/nlg/summarize.py +69 -0
  286. rasa/core/policies/__init__.py +0 -0
  287. rasa/core/policies/ensemble.py +329 -0
  288. rasa/core/policies/enterprise_search_policy.py +781 -0
  289. rasa/core/policies/enterprise_search_prompt_template.jinja2 +25 -0
  290. rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
  291. rasa/core/policies/flow_policy.py +205 -0
  292. rasa/core/policies/flows/__init__.py +0 -0
  293. rasa/core/policies/flows/flow_exceptions.py +44 -0
  294. rasa/core/policies/flows/flow_executor.py +705 -0
  295. rasa/core/policies/flows/flow_step_result.py +43 -0
  296. rasa/core/policies/intentless_policy.py +922 -0
  297. rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
  298. rasa/core/policies/memoization.py +538 -0
  299. rasa/core/policies/policy.py +725 -0
  300. rasa/core/policies/rule_policy.py +1273 -0
  301. rasa/core/policies/ted_policy.py +2169 -0
  302. rasa/core/policies/unexpected_intent_policy.py +1022 -0
  303. rasa/core/processor.py +1422 -0
  304. rasa/core/run.py +331 -0
  305. rasa/core/secrets_manager/__init__.py +0 -0
  306. rasa/core/secrets_manager/constants.py +32 -0
  307. rasa/core/secrets_manager/endpoints.py +391 -0
  308. rasa/core/secrets_manager/factory.py +233 -0
  309. rasa/core/secrets_manager/secret_manager.py +262 -0
  310. rasa/core/secrets_manager/vault.py +574 -0
  311. rasa/core/test.py +1335 -0
  312. rasa/core/tracker_store.py +1699 -0
  313. rasa/core/train.py +105 -0
  314. rasa/core/training/__init__.py +89 -0
  315. rasa/core/training/converters/__init__.py +0 -0
  316. rasa/core/training/converters/responses_prefix_converter.py +119 -0
  317. rasa/core/training/interactive.py +1745 -0
  318. rasa/core/training/story_conflict.py +381 -0
  319. rasa/core/training/training.py +93 -0
  320. rasa/core/utils.py +339 -0
  321. rasa/core/visualize.py +70 -0
  322. rasa/dialogue_understanding/__init__.py +0 -0
  323. rasa/dialogue_understanding/coexistence/__init__.py +0 -0
  324. rasa/dialogue_understanding/coexistence/constants.py +4 -0
  325. rasa/dialogue_understanding/coexistence/intent_based_router.py +196 -0
  326. rasa/dialogue_understanding/coexistence/llm_based_router.py +260 -0
  327. rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
  328. rasa/dialogue_understanding/commands/__init__.py +49 -0
  329. rasa/dialogue_understanding/commands/can_not_handle_command.py +70 -0
  330. rasa/dialogue_understanding/commands/cancel_flow_command.py +125 -0
  331. rasa/dialogue_understanding/commands/change_flow_command.py +44 -0
  332. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +57 -0
  333. rasa/dialogue_understanding/commands/clarify_command.py +86 -0
  334. rasa/dialogue_understanding/commands/command.py +85 -0
  335. rasa/dialogue_understanding/commands/correct_slots_command.py +297 -0
  336. rasa/dialogue_understanding/commands/error_command.py +79 -0
  337. rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
  338. rasa/dialogue_understanding/commands/handle_code_change_command.py +73 -0
  339. rasa/dialogue_understanding/commands/human_handoff_command.py +66 -0
  340. rasa/dialogue_understanding/commands/knowledge_answer_command.py +57 -0
  341. rasa/dialogue_understanding/commands/noop_command.py +54 -0
  342. rasa/dialogue_understanding/commands/set_slot_command.py +160 -0
  343. rasa/dialogue_understanding/commands/skip_question_command.py +75 -0
  344. rasa/dialogue_understanding/commands/start_flow_command.py +107 -0
  345. rasa/dialogue_understanding/generator/__init__.py +21 -0
  346. rasa/dialogue_understanding/generator/command_generator.py +343 -0
  347. rasa/dialogue_understanding/generator/constants.py +18 -0
  348. rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
  349. rasa/dialogue_understanding/generator/flow_retrieval.py +412 -0
  350. rasa/dialogue_understanding/generator/llm_based_command_generator.py +467 -0
  351. rasa/dialogue_understanding/generator/llm_command_generator.py +67 -0
  352. rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
  353. rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
  354. rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
  355. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +827 -0
  356. rasa/dialogue_understanding/generator/nlu_command_adapter.py +218 -0
  357. rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
  358. rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +57 -0
  359. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +345 -0
  360. rasa/dialogue_understanding/patterns/__init__.py +0 -0
  361. rasa/dialogue_understanding/patterns/cancel.py +111 -0
  362. rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
  363. rasa/dialogue_understanding/patterns/chitchat.py +37 -0
  364. rasa/dialogue_understanding/patterns/clarify.py +97 -0
  365. rasa/dialogue_understanding/patterns/code_change.py +41 -0
  366. rasa/dialogue_understanding/patterns/collect_information.py +90 -0
  367. rasa/dialogue_understanding/patterns/completed.py +40 -0
  368. rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
  369. rasa/dialogue_understanding/patterns/correction.py +278 -0
  370. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +248 -0
  371. rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
  372. rasa/dialogue_understanding/patterns/internal_error.py +47 -0
  373. rasa/dialogue_understanding/patterns/search.py +37 -0
  374. rasa/dialogue_understanding/patterns/skip_question.py +38 -0
  375. rasa/dialogue_understanding/processor/__init__.py +0 -0
  376. rasa/dialogue_understanding/processor/command_processor.py +687 -0
  377. rasa/dialogue_understanding/processor/command_processor_component.py +39 -0
  378. rasa/dialogue_understanding/stack/__init__.py +0 -0
  379. rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
  380. rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
  381. rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
  382. rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
  383. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
  384. rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
  385. rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
  386. rasa/dialogue_understanding/stack/utils.py +211 -0
  387. rasa/e2e_test/__init__.py +0 -0
  388. rasa/e2e_test/constants.py +11 -0
  389. rasa/e2e_test/e2e_test_case.py +366 -0
  390. rasa/e2e_test/e2e_test_result.py +34 -0
  391. rasa/e2e_test/e2e_test_runner.py +768 -0
  392. rasa/e2e_test/e2e_test_schema.yml +85 -0
  393. rasa/engine/__init__.py +0 -0
  394. rasa/engine/caching.py +463 -0
  395. rasa/engine/constants.py +17 -0
  396. rasa/engine/exceptions.py +14 -0
  397. rasa/engine/graph.py +637 -0
  398. rasa/engine/loader.py +36 -0
  399. rasa/engine/recipes/__init__.py +0 -0
  400. rasa/engine/recipes/config_files/default_config.yml +44 -0
  401. rasa/engine/recipes/default_components.py +99 -0
  402. rasa/engine/recipes/default_recipe.py +1251 -0
  403. rasa/engine/recipes/graph_recipe.py +79 -0
  404. rasa/engine/recipes/recipe.py +93 -0
  405. rasa/engine/runner/__init__.py +0 -0
  406. rasa/engine/runner/dask.py +250 -0
  407. rasa/engine/runner/interface.py +49 -0
  408. rasa/engine/storage/__init__.py +0 -0
  409. rasa/engine/storage/local_model_storage.py +246 -0
  410. rasa/engine/storage/resource.py +110 -0
  411. rasa/engine/storage/storage.py +203 -0
  412. rasa/engine/training/__init__.py +0 -0
  413. rasa/engine/training/components.py +176 -0
  414. rasa/engine/training/fingerprinting.py +64 -0
  415. rasa/engine/training/graph_trainer.py +256 -0
  416. rasa/engine/training/hooks.py +164 -0
  417. rasa/engine/validation.py +873 -0
  418. rasa/env.py +5 -0
  419. rasa/exceptions.py +69 -0
  420. rasa/graph_components/__init__.py +0 -0
  421. rasa/graph_components/converters/__init__.py +0 -0
  422. rasa/graph_components/converters/nlu_message_converter.py +48 -0
  423. rasa/graph_components/providers/__init__.py +0 -0
  424. rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
  425. rasa/graph_components/providers/domain_provider.py +71 -0
  426. rasa/graph_components/providers/flows_provider.py +74 -0
  427. rasa/graph_components/providers/forms_provider.py +44 -0
  428. rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
  429. rasa/graph_components/providers/responses_provider.py +44 -0
  430. rasa/graph_components/providers/rule_only_provider.py +49 -0
  431. rasa/graph_components/providers/story_graph_provider.py +43 -0
  432. rasa/graph_components/providers/training_tracker_provider.py +55 -0
  433. rasa/graph_components/validators/__init__.py +0 -0
  434. rasa/graph_components/validators/default_recipe_validator.py +550 -0
  435. rasa/graph_components/validators/finetuning_validator.py +302 -0
  436. rasa/hooks.py +112 -0
  437. rasa/jupyter.py +63 -0
  438. rasa/markers/__init__.py +0 -0
  439. rasa/markers/marker.py +269 -0
  440. rasa/markers/marker_base.py +828 -0
  441. rasa/markers/upload.py +74 -0
  442. rasa/markers/validate.py +21 -0
  443. rasa/model.py +118 -0
  444. rasa/model_testing.py +457 -0
  445. rasa/model_training.py +536 -0
  446. rasa/nlu/__init__.py +7 -0
  447. rasa/nlu/classifiers/__init__.py +3 -0
  448. rasa/nlu/classifiers/classifier.py +5 -0
  449. rasa/nlu/classifiers/diet_classifier.py +1881 -0
  450. rasa/nlu/classifiers/fallback_classifier.py +192 -0
  451. rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
  452. rasa/nlu/classifiers/llm_intent_classifier.py +519 -0
  453. rasa/nlu/classifiers/logistic_regression_classifier.py +253 -0
  454. rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
  455. rasa/nlu/classifiers/regex_message_handler.py +56 -0
  456. rasa/nlu/classifiers/sklearn_intent_classifier.py +330 -0
  457. rasa/nlu/constants.py +77 -0
  458. rasa/nlu/convert.py +40 -0
  459. rasa/nlu/emulators/__init__.py +0 -0
  460. rasa/nlu/emulators/dialogflow.py +55 -0
  461. rasa/nlu/emulators/emulator.py +49 -0
  462. rasa/nlu/emulators/luis.py +86 -0
  463. rasa/nlu/emulators/no_emulator.py +10 -0
  464. rasa/nlu/emulators/wit.py +56 -0
  465. rasa/nlu/extractors/__init__.py +0 -0
  466. rasa/nlu/extractors/crf_entity_extractor.py +715 -0
  467. rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
  468. rasa/nlu/extractors/entity_synonyms.py +178 -0
  469. rasa/nlu/extractors/extractor.py +470 -0
  470. rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
  471. rasa/nlu/extractors/regex_entity_extractor.py +220 -0
  472. rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
  473. rasa/nlu/featurizers/__init__.py +0 -0
  474. rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
  475. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +445 -0
  476. rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
  477. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +768 -0
  478. rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
  479. rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
  480. rasa/nlu/featurizers/featurizer.py +89 -0
  481. rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
  482. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +867 -0
  483. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +571 -0
  484. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +271 -0
  485. rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
  486. rasa/nlu/model.py +24 -0
  487. rasa/nlu/persistor.py +282 -0
  488. rasa/nlu/run.py +27 -0
  489. rasa/nlu/selectors/__init__.py +0 -0
  490. rasa/nlu/selectors/response_selector.py +987 -0
  491. rasa/nlu/test.py +1940 -0
  492. rasa/nlu/tokenizers/__init__.py +0 -0
  493. rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
  494. rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
  495. rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
  496. rasa/nlu/tokenizers/tokenizer.py +239 -0
  497. rasa/nlu/tokenizers/whitespace_tokenizer.py +106 -0
  498. rasa/nlu/utils/__init__.py +35 -0
  499. rasa/nlu/utils/bilou_utils.py +462 -0
  500. rasa/nlu/utils/hugging_face/__init__.py +0 -0
  501. rasa/nlu/utils/hugging_face/registry.py +108 -0
  502. rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
  503. rasa/nlu/utils/mitie_utils.py +113 -0
  504. rasa/nlu/utils/pattern_utils.py +168 -0
  505. rasa/nlu/utils/spacy_utils.py +310 -0
  506. rasa/plugin.py +90 -0
  507. rasa/server.py +1551 -0
  508. rasa/shared/__init__.py +0 -0
  509. rasa/shared/constants.py +192 -0
  510. rasa/shared/core/__init__.py +0 -0
  511. rasa/shared/core/command_payload_reader.py +109 -0
  512. rasa/shared/core/constants.py +167 -0
  513. rasa/shared/core/conversation.py +46 -0
  514. rasa/shared/core/domain.py +2107 -0
  515. rasa/shared/core/events.py +2504 -0
  516. rasa/shared/core/flows/__init__.py +7 -0
  517. rasa/shared/core/flows/flow.py +362 -0
  518. rasa/shared/core/flows/flow_step.py +146 -0
  519. rasa/shared/core/flows/flow_step_links.py +319 -0
  520. rasa/shared/core/flows/flow_step_sequence.py +70 -0
  521. rasa/shared/core/flows/flows_list.py +223 -0
  522. rasa/shared/core/flows/flows_yaml_schema.json +217 -0
  523. rasa/shared/core/flows/nlu_trigger.py +117 -0
  524. rasa/shared/core/flows/steps/__init__.py +24 -0
  525. rasa/shared/core/flows/steps/action.py +56 -0
  526. rasa/shared/core/flows/steps/call.py +64 -0
  527. rasa/shared/core/flows/steps/collect.py +112 -0
  528. rasa/shared/core/flows/steps/constants.py +5 -0
  529. rasa/shared/core/flows/steps/continuation.py +36 -0
  530. rasa/shared/core/flows/steps/end.py +22 -0
  531. rasa/shared/core/flows/steps/internal.py +44 -0
  532. rasa/shared/core/flows/steps/link.py +51 -0
  533. rasa/shared/core/flows/steps/no_operation.py +48 -0
  534. rasa/shared/core/flows/steps/set_slots.py +50 -0
  535. rasa/shared/core/flows/steps/start.py +30 -0
  536. rasa/shared/core/flows/validation.py +527 -0
  537. rasa/shared/core/flows/yaml_flows_io.py +278 -0
  538. rasa/shared/core/generator.py +908 -0
  539. rasa/shared/core/slot_mappings.py +526 -0
  540. rasa/shared/core/slots.py +649 -0
  541. rasa/shared/core/trackers.py +1177 -0
  542. rasa/shared/core/training_data/__init__.py +0 -0
  543. rasa/shared/core/training_data/loading.py +89 -0
  544. rasa/shared/core/training_data/story_reader/__init__.py +0 -0
  545. rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
  546. rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
  547. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
  548. rasa/shared/core/training_data/story_writer/__init__.py +0 -0
  549. rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
  550. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +444 -0
  551. rasa/shared/core/training_data/structures.py +838 -0
  552. rasa/shared/core/training_data/visualization.html +146 -0
  553. rasa/shared/core/training_data/visualization.py +603 -0
  554. rasa/shared/data.py +249 -0
  555. rasa/shared/engine/__init__.py +0 -0
  556. rasa/shared/engine/caching.py +26 -0
  557. rasa/shared/exceptions.py +163 -0
  558. rasa/shared/importers/__init__.py +0 -0
  559. rasa/shared/importers/importer.py +704 -0
  560. rasa/shared/importers/multi_project.py +203 -0
  561. rasa/shared/importers/rasa.py +99 -0
  562. rasa/shared/importers/utils.py +34 -0
  563. rasa/shared/nlu/__init__.py +0 -0
  564. rasa/shared/nlu/constants.py +47 -0
  565. rasa/shared/nlu/interpreter.py +10 -0
  566. rasa/shared/nlu/training_data/__init__.py +0 -0
  567. rasa/shared/nlu/training_data/entities_parser.py +208 -0
  568. rasa/shared/nlu/training_data/features.py +492 -0
  569. rasa/shared/nlu/training_data/formats/__init__.py +10 -0
  570. rasa/shared/nlu/training_data/formats/dialogflow.py +163 -0
  571. rasa/shared/nlu/training_data/formats/luis.py +87 -0
  572. rasa/shared/nlu/training_data/formats/rasa.py +135 -0
  573. rasa/shared/nlu/training_data/formats/rasa_yaml.py +603 -0
  574. rasa/shared/nlu/training_data/formats/readerwriter.py +244 -0
  575. rasa/shared/nlu/training_data/formats/wit.py +52 -0
  576. rasa/shared/nlu/training_data/loading.py +137 -0
  577. rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
  578. rasa/shared/nlu/training_data/message.py +490 -0
  579. rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
  580. rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
  581. rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
  582. rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
  583. rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
  584. rasa/shared/nlu/training_data/training_data.py +730 -0
  585. rasa/shared/nlu/training_data/util.py +223 -0
  586. rasa/shared/providers/__init__.py +0 -0
  587. rasa/shared/providers/openai/__init__.py +0 -0
  588. rasa/shared/providers/openai/clients.py +43 -0
  589. rasa/shared/providers/openai/session_handler.py +110 -0
  590. rasa/shared/utils/__init__.py +0 -0
  591. rasa/shared/utils/cli.py +72 -0
  592. rasa/shared/utils/common.py +308 -0
  593. rasa/shared/utils/constants.py +4 -0
  594. rasa/shared/utils/io.py +415 -0
  595. rasa/shared/utils/llm.py +404 -0
  596. rasa/shared/utils/pykwalify_extensions.py +27 -0
  597. rasa/shared/utils/schemas/__init__.py +0 -0
  598. rasa/shared/utils/schemas/config.yml +2 -0
  599. rasa/shared/utils/schemas/domain.yml +145 -0
  600. rasa/shared/utils/schemas/events.py +212 -0
  601. rasa/shared/utils/schemas/model_config.yml +46 -0
  602. rasa/shared/utils/schemas/stories.yml +173 -0
  603. rasa/shared/utils/yaml.py +786 -0
  604. rasa/studio/__init__.py +0 -0
  605. rasa/studio/auth.py +268 -0
  606. rasa/studio/config.py +127 -0
  607. rasa/studio/constants.py +18 -0
  608. rasa/studio/data_handler.py +359 -0
  609. rasa/studio/download.py +483 -0
  610. rasa/studio/results_logger.py +137 -0
  611. rasa/studio/train.py +135 -0
  612. rasa/studio/upload.py +433 -0
  613. rasa/telemetry.py +1737 -0
  614. rasa/tracing/__init__.py +0 -0
  615. rasa/tracing/config.py +353 -0
  616. rasa/tracing/constants.py +62 -0
  617. rasa/tracing/instrumentation/__init__.py +0 -0
  618. rasa/tracing/instrumentation/attribute_extractors.py +672 -0
  619. rasa/tracing/instrumentation/instrumentation.py +1185 -0
  620. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +144 -0
  621. rasa/tracing/instrumentation/metrics.py +294 -0
  622. rasa/tracing/metric_instrument_provider.py +205 -0
  623. rasa/utils/__init__.py +0 -0
  624. rasa/utils/beta.py +83 -0
  625. rasa/utils/cli.py +28 -0
  626. rasa/utils/common.py +635 -0
  627. rasa/utils/converter.py +53 -0
  628. rasa/utils/endpoints.py +302 -0
  629. rasa/utils/io.py +260 -0
  630. rasa/utils/licensing.py +534 -0
  631. rasa/utils/log_utils.py +174 -0
  632. rasa/utils/mapper.py +210 -0
  633. rasa/utils/ml_utils.py +145 -0
  634. rasa/utils/plotting.py +362 -0
  635. rasa/utils/singleton.py +23 -0
  636. rasa/utils/tensorflow/__init__.py +0 -0
  637. rasa/utils/tensorflow/callback.py +112 -0
  638. rasa/utils/tensorflow/constants.py +116 -0
  639. rasa/utils/tensorflow/crf.py +492 -0
  640. rasa/utils/tensorflow/data_generator.py +440 -0
  641. rasa/utils/tensorflow/environment.py +161 -0
  642. rasa/utils/tensorflow/exceptions.py +5 -0
  643. rasa/utils/tensorflow/feature_array.py +366 -0
  644. rasa/utils/tensorflow/layers.py +1565 -0
  645. rasa/utils/tensorflow/layers_utils.py +113 -0
  646. rasa/utils/tensorflow/metrics.py +281 -0
  647. rasa/utils/tensorflow/model_data.py +798 -0
  648. rasa/utils/tensorflow/model_data_utils.py +499 -0
  649. rasa/utils/tensorflow/models.py +935 -0
  650. rasa/utils/tensorflow/rasa_layers.py +1094 -0
  651. rasa/utils/tensorflow/transformer.py +640 -0
  652. rasa/utils/tensorflow/types.py +6 -0
  653. rasa/utils/train_utils.py +572 -0
  654. rasa/utils/url_tools.py +53 -0
  655. rasa/utils/yaml.py +54 -0
  656. rasa/validator.py +1337 -0
  657. rasa/version.py +3 -0
  658. rasa_pro-3.9.18.dist-info/METADATA +563 -0
  659. rasa_pro-3.9.18.dist-info/NOTICE +5 -0
  660. rasa_pro-3.9.18.dist-info/RECORD +662 -0
  661. rasa_pro-3.9.18.dist-info/WHEEL +4 -0
  662. rasa_pro-3.9.18.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,492 @@
1
+ import tensorflow as tf
2
+ from tensorflow import TensorShape
3
+ from tensorflow.types.experimental import TensorLike
4
+ from typing import Tuple, Any, List, Union, Optional
5
+
6
+
7
+ # original code taken from
8
+ # https://github.com/tensorflow/addons/blob/b8cab7fd61af4f697a1cdae4f51c37c346b9c6f0/tensorflow_addons/text/crf.py
9
+ # (modified to our needs)
10
+
11
+
12
class CrfDecodeForwardRnnCell(tf.keras.layers.AbstractRNNCell):
    """RNN cell that performs one forward decoding step of a linear-chain CRF."""

    def __init__(self, transition_params: TensorLike, **kwargs: Any) -> None:
        """Create the cell.

        Args:
            transition_params: A [num_tags, num_tags] matrix of binary
                potentials. It is stored with an extra leading axis, i.e. as
                [1, num_tags, num_tags], so that it broadcasts in the
                summation performed in `call`.
            **kwargs: Passed through to the parent constructor.
        """
        super().__init__(**kwargs)
        self._num_tags = transition_params.shape[0]
        self._transition_params = tf.expand_dims(transition_params, 0)

    @property
    def state_size(self) -> int:
        """Returns the size of the cell state, i.e. the count of tags."""
        return self._num_tags

    @property
    def output_size(self) -> int:
        """Returns count of tags."""
        return self._num_tags

    def build(self, input_shape: Union[TensorShape, List[TensorShape]]) -> None:
        """Builds the layer; delegates entirely to the parent implementation."""
        super().build(input_shape)

    def call(
        self, inputs: TensorLike, state: TensorLike
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """Run a single forward decoding step.

        Args:
            inputs: A [batch_size, num_tags] matrix of unary potentials.
            state: The previous step's score values. The
                [batch_size, num_tags] matrix is read from `state[0]`.

        Returns:
            output: A [batch_size, num_tags * 2] matrix of backpointers and scores.
            new_state: A [batch_size, num_tags] matrix of new score values.
        """
        previous_scores = tf.expand_dims(state[0], 2)
        # Broadcast sum of the previous scores and the transition matrix.
        candidates = previous_scores + self._transition_params

        # The backpointers are stored as floats so they can be concatenated
        # with the scores below.
        best_previous = tf.cast(tf.argmax(candidates, 1), tf.float32)

        # Softmax over the candidates maps the scores into the range 0 to 1.
        confidences = tf.reduce_max(tf.nn.softmax(candidates, axis=1), [1])

        updated_scores = inputs + tf.reduce_max(candidates, [1])

        # An RNN only keeps the first value returned by the cell for every time
        # step. Both the backpointers and the scores are needed per time step,
        # hence they are packed into a single output tensor.
        packed_output = tf.concat([best_previous, confidences], axis=1)
        return packed_output, updated_scores
70
+
71
+
72
def crf_decode_forward(
    inputs: TensorLike,
    state: TensorLike,
    transition_params: TensorLike,
    sequence_lengths: TensorLike,
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Runs the forward decoding pass of a linear-chain CRF over all time steps.

    Args:
        inputs: Unary potentials that are fed to the RNN. Axis 1 is taken as
            the time axis when the step mask is built.
        state: A [batch_size, num_tags] matrix containing the initial
            score values.
        transition_params: A [num_tags, num_tags] matrix of binary potentials.
        sequence_lengths: A [batch_size] vector of true sequence lengths.

    Returns:
        output: The per-step cell outputs (backpointers and scores
            concatenated on the last axis).
        new_state: A [batch_size, num_tags] matrix of final score values.
    """
    lengths = tf.cast(sequence_lengths, dtype=tf.int32)
    num_steps = tf.shape(inputs)[1]
    step_mask = tf.sequence_mask(lengths, num_steps)

    forward_rnn = tf.keras.layers.RNN(
        CrfDecodeForwardRnnCell(transition_params),
        return_sequences=True,
        return_state=True,
    )
    return forward_rnn(inputs, state, mask=step_mask)
98
+
99
+
100
def crf_decode_backward(
    backpointers: TensorLike, scores: TensorLike, state: TensorLike
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Computes backward decoding in a linear-chain CRF.

    Args:
        backpointers: A [batch_size, num_steps, num_tags] tensor of backpointers.
            The steps are walked in the order given along axis 1.
        scores: A [batch_size, num_steps, num_tags] tensor of scores, aligned
            with `backpointers`.
        state: A [batch_size, 1] matrix holding the tag index the walk starts from.

    Returns:
        new_tags: A [batch_size, num_steps, 1] tensor containing the tag index
            selected at every step.
        new_scores: A [batch_size, num_steps, 1] tensor containing the score
            that belongs to the selection made at every step.
    """
    # Move the step axis to the front, as tf.scan iterates over axis 0.
    backpointers = tf.transpose(backpointers, [1, 0, 2])
    scores = tf.transpose(scores, [1, 0, 2])

    def _scan_fn(_state: TensorLike, _inputs: TensorLike) -> tf.Tensor:
        # Look up, for every batch entry, the value stored for the current tag.
        _state = tf.cast(tf.squeeze(_state, axis=[1]), dtype=tf.int32)
        idxs = tf.stack([tf.range(tf.shape(_inputs)[0]), _state], axis=1)
        return tf.expand_dims(tf.gather_nd(_inputs, idxs), axis=-1)

    output_tags = tf.scan(_scan_fn, backpointers, state)

    # The score of a step has to be read at the same tag index that was used
    # for the backpointer lookup of that step: `state` for the first step and
    # the tag emitted by the preceding step afterwards. Scanning over `scores`
    # with `_scan_fn` would instead feed the previous *score* back in as the
    # index, which selects the wrong tag for every step but the first.
    num_steps = tf.shape(scores)[0]
    lookup_tags = tf.concat(
        [
            tf.cast(tf.expand_dims(state, 0), dtype=tf.int32),
            tf.cast(output_tags, dtype=tf.int32),
        ],
        axis=0,
    )[:num_steps]
    output_scores = tf.gather(scores, lookup_tags, batch_dims=2)

    return tf.transpose(output_tags, [1, 0, 2]), tf.transpose(output_scores, [1, 0, 2])
130
+
131
+
132
def crf_decode(
    potentials: TensorLike, transition_params: TensorLike, sequence_length: TensorLike
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Finds the highest scoring tag sequence for every example.

    Args:
        potentials: A [batch_size, max_seq_len, num_tags] tensor of
            unary potentials.
        transition_params: A [num_tags, num_tags] matrix of
            binary potentials.
        sequence_length: A [batch_size] vector of true sequence lengths.

    Returns:
        decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
            Contains the highest scoring tag indices.
        decode_scores: A [batch_size, max_seq_len] matrix, containing the score of
            `decode_tags`.
        best_score: A [batch_size] vector, containing the best score of `decode_tags`.
    """
    sequence_length = tf.cast(sequence_length, dtype=tf.int32)

    def _single_seq_fn() -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        # With a single time step there is nothing to decode: take the argmax
        # tag and the max activation directly.
        tags = tf.cast(tf.argmax(potentials, axis=2), dtype=tf.int32)
        confidences = tf.reduce_max(tf.nn.softmax(potentials, axis=2), axis=2)
        top_potential = tf.reshape(tf.reduce_max(potentials, axis=2), shape=[-1])
        return tags, confidences, top_potential

    def _multi_seq_fn() -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        # The first time step seeds the forward pass, the remaining steps are
        # its inputs.
        first_step = tf.squeeze(
            tf.slice(potentials, [0, 0, 0], [-1, 1, -1]), axis=[1]
        )
        remaining_steps = tf.slice(potentials, [0, 1, 0], [-1, -1, -1])

        lengths_minus_one = tf.maximum(
            tf.constant(0, dtype=tf.int32), sequence_length - 1
        )

        # Forward decoding yields the backpointers/scores of every step and
        # the final score values.
        output, last_score = crf_decode_forward(
            remaining_steps, first_step, transition_params, lengths_minus_one
        )

        # `output` is of size [batch-size, max-seq-length, num-tags * 2];
        # splitting it in half on axis 2 separates backpointers from scores,
        # both of size [batch-size, max-seq-length, num-tags].
        backpointers, scores = tf.split(output, 2, axis=2)

        backpointers = tf.reverse_sequence(
            tf.cast(backpointers, dtype=tf.int32), lengths_minus_one, seq_axis=1
        )
        scores = tf.reverse_sequence(scores, lengths_minus_one, seq_axis=1)

        # Best final tag and its softmax score are the start of the backward walk.
        last_tag = tf.expand_dims(
            tf.cast(tf.argmax(last_score, axis=1), dtype=tf.int32), axis=-1
        )
        last_tag_score = tf.expand_dims(
            tf.reduce_max(tf.nn.softmax(last_score, axis=1), axis=[1]), axis=-1
        )

        walked_tags, walked_scores = crf_decode_backward(
            backpointers, scores, last_tag
        )

        # Prepend the start of the walk and flip back into time order.
        decode_tags = tf.concat(
            [last_tag, tf.squeeze(walked_tags, axis=[2])], axis=1
        )
        decode_tags = tf.reverse_sequence(decode_tags, sequence_length, seq_axis=1)

        decode_scores = tf.concat(
            [last_tag_score, tf.squeeze(walked_scores, axis=[2])], axis=1
        )
        decode_scores = tf.reverse_sequence(decode_scores, sequence_length, seq_axis=1)

        best_score = tf.reduce_max(last_score, axis=1)

        return decode_tags, decode_scores, best_score

    static_seq_len = potentials.shape[1]
    if static_seq_len is None:
        # The length is only known at run time, so the choice is deferred.
        return tf.cond(
            tf.equal(tf.shape(potentials)[1], 1), _single_seq_fn, _multi_seq_fn
        )

    # The shape is statically known, so the appropriate code path is executed
    # directly.
    if static_seq_len == 1:
        return _single_seq_fn()
    return _multi_seq_fn()
218
+
219
+
220
def crf_unary_score(
    tag_indices: TensorLike, sequence_lengths: TensorLike, inputs: TensorLike
) -> tf.Tensor:
    """Computes the unary scores of tag sequences.

    Args:
        tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
        sequence_lengths: A [batch_size] vector of true sequence lengths.
        inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials.

    Returns:
        unary_scores: A [batch_size] vector of unary scores.
    """
    # Both are always converted to int32, so all index arithmetic below is
    # done in int32.
    tag_indices = tf.cast(tag_indices, dtype=tf.int32)
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

    batch_size = tf.shape(inputs)[0]
    max_seq_len = tf.shape(inputs)[1]
    num_tags = tf.shape(inputs)[2]

    flattened_inputs = tf.reshape(inputs, [-1])

    # Position of every (batch, step) pair inside the flattened inputs;
    # adding the tag index selects the potential of the given tag.
    offsets = tf.expand_dims(tf.range(batch_size) * max_seq_len * num_tags, 1)
    offsets += tf.expand_dims(tf.range(max_seq_len) * num_tags, 0)
    flattened_tag_indices = tf.reshape(offsets + tag_indices, [-1])

    unary_scores = tf.reshape(
        tf.gather(flattened_inputs, flattened_tag_indices), [batch_size, max_seq_len]
    )

    # Zero out the steps that lie beyond the true sequence length.
    masks = tf.sequence_mask(
        sequence_lengths, maxlen=tf.shape(tag_indices)[1], dtype=unary_scores.dtype
    )

    unary_scores = tf.reduce_sum(unary_scores * masks, 1)
    return unary_scores
259
+
260
+
261
def crf_binary_score(
    tag_indices: TensorLike, sequence_lengths: TensorLike, transition_params: TensorLike
) -> tf.Tensor:
    """Sums up the transition (binary) scores along every tag sequence.

    Args:
        tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
        sequence_lengths: A [batch_size] vector of true sequence lengths.
        transition_params: A [num_tags, num_tags] matrix of binary potentials.

    Returns:
        binary_scores: A [batch_size] vector of binary scores.
    """
    tags = tf.cast(tag_indices, dtype=tf.int32)
    lengths = tf.cast(sequence_lengths, dtype=tf.int32)

    tag_count = tf.shape(transition_params)[0]
    padded_len = tf.shape(tags)[1]
    transition_count = padded_len - 1

    # Dropping the last / the first step gives the tag every transition
    # starts from / ends in.
    from_tags = tf.slice(tags, [0, 0], [-1, transition_count])
    to_tags = tf.slice(tags, [0, 1], [-1, transition_count])

    # Address the transition matrix in its flattened (row-major) form.
    flat_positions = from_tags * tag_count + to_tags
    flat_transitions = tf.reshape(transition_params, [-1])
    per_transition_scores = tf.gather(flat_transitions, flat_positions)

    # Only transitions that end within the true sequence length count.
    step_mask = tf.sequence_mask(
        lengths, maxlen=padded_len, dtype=per_transition_scores.dtype
    )
    transition_mask = tf.slice(step_mask, [0, 1], [-1, -1])

    return tf.reduce_sum(per_transition_scores * transition_mask, 1)
298
+
299
+
300
def crf_sequence_score(
    inputs: TensorLike,
    tag_indices: TensorLike,
    sequence_lengths: TensorLike,
    transition_params: TensorLike,
) -> tf.Tensor:
    """Scores a given tag sequence without normalizing the result.

    Args:
        inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
            to use as input to the CRF layer.
        tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which
            we compute the unnormalized score.
        sequence_lengths: A [batch_size] vector of true sequence lengths.
        transition_params: A [num_tags, num_tags] transition matrix.

    Returns:
        sequence_scores: A [batch_size] vector of unnormalized sequence scores.
    """
    tag_indices = tf.cast(tag_indices, dtype=tf.int32)
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

    def _single_seq_fn() -> TensorLike:
        # Only one time step: the score is the unary potential of that tag.
        num_examples = tf.shape(inputs, out_type=tf.int32)[0]
        example_ids = tf.reshape(tf.range(num_examples), [-1, 1])

        # (example, step 0) pairs to read the tag of the only time step.
        first_step_positions = tf.concat(
            [example_ids, tf.zeros_like(example_ids)], axis=1
        )
        first_tags = tf.reshape(
            tf.gather_nd(tag_indices, first_step_positions), [-1, 1]
        )

        # (example, step 0, tag) triples to read the potential itself.
        potential_positions = tf.concat([first_step_positions, first_tags], axis=1)
        single_scores = tf.gather_nd(inputs, potential_positions)

        # Sequences of length <= zero get a score of zero.
        return tf.where(
            tf.less_equal(sequence_lengths, 0),
            tf.zeros_like(single_scores),
            single_scores,
        )

    def _multi_seq_fn() -> TensorLike:
        # The score is the sum of all unary and all transition potentials.
        unary_part = crf_unary_score(tag_indices, sequence_lengths, inputs)
        binary_part = crf_binary_score(
            tag_indices, sequence_lengths, transition_params
        )
        return unary_part + binary_part

    has_single_step = tf.equal(tf.shape(inputs)[1], 1)
    return tf.cond(has_single_step, _single_seq_fn, _multi_seq_fn)
352
+
353
+
354
def crf_forward(
    inputs: TensorLike,
    state: TensorLike,
    transition_params: TensorLike,
    sequence_lengths: TensorLike,
) -> tf.Tensor:
    """Runs the forward algorithm of a linear-chain CRF to get the alpha values.

    See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.

    Args:
        inputs: A rank-3 tensor of unary potentials with the batch on axis 0
            and the time steps on axis 1.
        state: A [batch_size, num_tags] matrix containing the initial alpha
            values.
        transition_params: A [num_tags, num_tags] matrix of binary potentials.
            This matrix is expanded into a [1, num_tags, num_tags] in preparation
            for the broadcast summation occurring in every step.
        sequence_lengths: A [batch_size] vector of true sequence lengths.

    Returns:
        new_alphas: A [batch_size, num_tags] matrix containing the alpha
            values at the last true step of every sequence.
    """
    lengths = tf.cast(sequence_lengths, dtype=tf.int32)
    final_step = tf.maximum(tf.constant(0, dtype=lengths.dtype), lengths - 1)

    # tf.scan iterates over axis 0, so the time axis has to come first.
    time_major_inputs = tf.transpose(inputs, [1, 0, 2])
    broadcast_transitions = tf.expand_dims(transition_params, 0)

    def _scan_fn(_state: TensorLike, _inputs: TensorLike) -> TensorLike:
        candidates = tf.expand_dims(_state, 2) + broadcast_transitions
        return _inputs + tf.reduce_logsumexp(candidates, [1])

    scanned_alphas = tf.scan(_scan_fn, time_major_inputs, state)
    batch_major_alphas = tf.transpose(scanned_alphas, [1, 0, 2])

    # Put the initial state in front; it is the result for sequences of length 1.
    all_alphas = tf.concat([tf.expand_dims(state, 1), batch_major_alphas], 1)

    # Select, per example, the alphas at its last true step.
    example_ids = tf.range(tf.shape(final_step)[0])
    return tf.gather_nd(all_alphas, tf.stack([example_ids, final_step], axis=1))
397
+
398
+
399
def crf_log_norm(
    inputs: TensorLike, sequence_lengths: TensorLike, transition_params: TensorLike
) -> tf.Tensor:
    """Computes the log of the partition function (normalizer) of a CRF.

    Args:
        inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
            to use as input to the CRF layer.
        sequence_lengths: A [batch_size] vector of true sequence lengths.
        transition_params: A [num_tags, num_tags] transition matrix.

    Returns:
        log_norm: A [batch_size] vector of normalizers for a CRF.
    """
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

    # The first time step acts as the initial state of the forward algorithm.
    first_input = tf.squeeze(tf.slice(inputs, [0, 0, 0], [-1, 1, -1]), [1])

    def _zero_for_empty(values: TensorLike) -> TensorLike:
        # Mask the entries of the sequences with length <= zero.
        return tf.where(
            tf.less_equal(sequence_lengths, 0), tf.zeros_like(values), values
        )

    def _single_seq_fn() -> TensorLike:
        # Only one time step: reduce_logsumexp over the "initial state"
        # (the unary potentials) is all that is needed.
        return _zero_for_empty(tf.reduce_logsumexp(first_input, [1]))

    def _multi_seq_fn() -> TensorLike:
        """Forward computation of alpha values."""
        remaining_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1])
        # The alpha values of the forward algorithm give the partition function.
        alphas = crf_forward(
            remaining_input, first_input, transition_params, sequence_lengths
        )
        return _zero_for_empty(tf.reduce_logsumexp(alphas, [1]))

    has_single_step = tf.equal(tf.shape(inputs)[1], 1)
    return tf.cond(has_single_step, _single_seq_fn, _multi_seq_fn)
446
+
447
+
448
def crf_log_likelihood(
    inputs: TensorLike,
    tag_indices: TensorLike,
    sequence_lengths: TensorLike,
    transition_params: Optional[TensorLike] = None,
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Computes the log-likelihood of tag sequences in a CRF.

    Args:
        inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
            to use as input to the CRF layer.
        tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which
            we compute the log-likelihood.
        sequence_lengths: A [batch_size] vector of true sequence lengths.
        transition_params: A [num_tags, num_tags] transition matrix,
            if available.

    Returns:
        log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of
            each example, given the sequence of tag indices.
        transition_params: A [num_tags, num_tags] transition matrix. This is
            either provided by the caller or created in this function.
    """
    inputs = tf.convert_to_tensor(inputs)

    num_tags = inputs.shape[2]

    # cast type to handle different types
    tag_indices = tf.cast(tag_indices, dtype=tf.int32)
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

    if transition_params is None:
        initializer = tf.keras.initializers.GlorotUniform()
        # The name has to be passed by keyword: the second positional
        # parameter of `tf.Variable` is `trainable`, not `name`.
        transition_params = tf.Variable(
            initializer([num_tags, num_tags]), name="transitions"
        )
    transition_params = tf.cast(transition_params, inputs.dtype)
    sequence_scores = crf_sequence_score(
        inputs, tag_indices, sequence_lengths, transition_params
    )
    log_norm = crf_log_norm(inputs, sequence_lengths, transition_params)

    # Normalize the scores to get the log-likelihood per example.
    log_likelihood = sequence_scores - log_norm
    return log_likelihood, transition_params