rasa-pro 3.8.16__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (644)
  1. README.md +380 -0
  2. rasa/__init__.py +10 -0
  3. rasa/__main__.py +151 -0
  4. rasa/anonymization/__init__.py +2 -0
  5. rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
  6. rasa/anonymization/anonymization_pipeline.py +287 -0
  7. rasa/anonymization/anonymization_rule_executor.py +260 -0
  8. rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
  9. rasa/anonymization/schemas/config.yml +47 -0
  10. rasa/anonymization/utils.py +117 -0
  11. rasa/api.py +146 -0
  12. rasa/cli/__init__.py +5 -0
  13. rasa/cli/arguments/__init__.py +0 -0
  14. rasa/cli/arguments/data.py +81 -0
  15. rasa/cli/arguments/default_arguments.py +165 -0
  16. rasa/cli/arguments/evaluate.py +65 -0
  17. rasa/cli/arguments/export.py +51 -0
  18. rasa/cli/arguments/interactive.py +74 -0
  19. rasa/cli/arguments/run.py +204 -0
  20. rasa/cli/arguments/shell.py +13 -0
  21. rasa/cli/arguments/test.py +211 -0
  22. rasa/cli/arguments/train.py +263 -0
  23. rasa/cli/arguments/visualize.py +34 -0
  24. rasa/cli/arguments/x.py +30 -0
  25. rasa/cli/data.py +292 -0
  26. rasa/cli/e2e_test.py +566 -0
  27. rasa/cli/evaluate.py +222 -0
  28. rasa/cli/export.py +251 -0
  29. rasa/cli/inspect.py +63 -0
  30. rasa/cli/interactive.py +164 -0
  31. rasa/cli/license.py +65 -0
  32. rasa/cli/markers.py +78 -0
  33. rasa/cli/project_templates/__init__.py +0 -0
  34. rasa/cli/project_templates/calm/actions/__init__.py +0 -0
  35. rasa/cli/project_templates/calm/actions/action_template.py +27 -0
  36. rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
  37. rasa/cli/project_templates/calm/actions/db.py +57 -0
  38. rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
  39. rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
  40. rasa/cli/project_templates/calm/config.yml +12 -0
  41. rasa/cli/project_templates/calm/credentials.yml +33 -0
  42. rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
  43. rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
  44. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
  45. rasa/cli/project_templates/calm/db/contacts.json +10 -0
  46. rasa/cli/project_templates/calm/domain/add_contact.yml +33 -0
  47. rasa/cli/project_templates/calm/domain/list_contacts.yml +14 -0
  48. rasa/cli/project_templates/calm/domain/remove_contact.yml +31 -0
  49. rasa/cli/project_templates/calm/domain/shared.yml +5 -0
  50. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
  51. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
  52. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
  53. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
  54. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
  55. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
  56. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
  57. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
  58. rasa/cli/project_templates/calm/endpoints.yml +45 -0
  59. rasa/cli/project_templates/default/actions/__init__.py +0 -0
  60. rasa/cli/project_templates/default/actions/actions.py +27 -0
  61. rasa/cli/project_templates/default/config.yml +44 -0
  62. rasa/cli/project_templates/default/credentials.yml +33 -0
  63. rasa/cli/project_templates/default/data/nlu.yml +91 -0
  64. rasa/cli/project_templates/default/data/rules.yml +13 -0
  65. rasa/cli/project_templates/default/data/stories.yml +30 -0
  66. rasa/cli/project_templates/default/domain.yml +34 -0
  67. rasa/cli/project_templates/default/endpoints.yml +42 -0
  68. rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
  69. rasa/cli/project_templates/tutorial/actions.py +22 -0
  70. rasa/cli/project_templates/tutorial/config.yml +11 -0
  71. rasa/cli/project_templates/tutorial/credentials.yml +33 -0
  72. rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
  73. rasa/cli/project_templates/tutorial/domain.yml +17 -0
  74. rasa/cli/project_templates/tutorial/endpoints.yml +45 -0
  75. rasa/cli/run.py +136 -0
  76. rasa/cli/scaffold.py +268 -0
  77. rasa/cli/shell.py +141 -0
  78. rasa/cli/studio/__init__.py +0 -0
  79. rasa/cli/studio/download.py +51 -0
  80. rasa/cli/studio/studio.py +110 -0
  81. rasa/cli/studio/train.py +59 -0
  82. rasa/cli/studio/upload.py +85 -0
  83. rasa/cli/telemetry.py +90 -0
  84. rasa/cli/test.py +280 -0
  85. rasa/cli/train.py +260 -0
  86. rasa/cli/utils.py +453 -0
  87. rasa/cli/visualize.py +40 -0
  88. rasa/cli/x.py +205 -0
  89. rasa/constants.py +37 -0
  90. rasa/core/__init__.py +17 -0
  91. rasa/core/actions/__init__.py +0 -0
  92. rasa/core/actions/action.py +1450 -0
  93. rasa/core/actions/action_clean_stack.py +59 -0
  94. rasa/core/actions/action_run_slot_rejections.py +207 -0
  95. rasa/core/actions/action_trigger_chitchat.py +31 -0
  96. rasa/core/actions/action_trigger_flow.py +109 -0
  97. rasa/core/actions/action_trigger_search.py +31 -0
  98. rasa/core/actions/constants.py +2 -0
  99. rasa/core/actions/forms.py +737 -0
  100. rasa/core/actions/loops.py +111 -0
  101. rasa/core/actions/two_stage_fallback.py +186 -0
  102. rasa/core/agent.py +557 -0
  103. rasa/core/auth_retry_tracker_store.py +122 -0
  104. rasa/core/brokers/__init__.py +0 -0
  105. rasa/core/brokers/broker.py +126 -0
  106. rasa/core/brokers/file.py +58 -0
  107. rasa/core/brokers/kafka.py +322 -0
  108. rasa/core/brokers/pika.py +387 -0
  109. rasa/core/brokers/sql.py +86 -0
  110. rasa/core/channels/__init__.py +55 -0
  111. rasa/core/channels/audiocodes.py +463 -0
  112. rasa/core/channels/botframework.py +339 -0
  113. rasa/core/channels/callback.py +85 -0
  114. rasa/core/channels/channel.py +419 -0
  115. rasa/core/channels/console.py +243 -0
  116. rasa/core/channels/development_inspector.py +93 -0
  117. rasa/core/channels/facebook.py +422 -0
  118. rasa/core/channels/hangouts.py +335 -0
  119. rasa/core/channels/inspector/.eslintrc.cjs +25 -0
  120. rasa/core/channels/inspector/.gitignore +23 -0
  121. rasa/core/channels/inspector/README.md +54 -0
  122. rasa/core/channels/inspector/assets/favicon.ico +0 -0
  123. rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
  124. rasa/core/channels/inspector/custom.d.ts +3 -0
  125. rasa/core/channels/inspector/dist/assets/arc-5623b6dc.js +1 -0
  126. rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
  127. rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-685c106a.js +10 -0
  128. rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-8cbed007.js +2 -0
  129. rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-5889cf12.js +2 -0
  130. rasa/core/channels/inspector/dist/assets/createText-62fc7601-24c249d7.js +7 -0
  131. rasa/core/channels/inspector/dist/assets/edges-f2ad444c-7dd06a75.js +4 -0
  132. rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-62c1e54c.js +51 -0
  133. rasa/core/channels/inspector/dist/assets/flowDb-1972c806-ce49b86f.js +6 -0
  134. rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-4067e48f.js +4 -0
  135. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-85583a23.js +1 -0
  136. rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-59fe4051.js +139 -0
  137. rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-47e3a43b.js +266 -0
  138. rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-5a2ac0d9.js +70 -0
  139. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
  140. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
  141. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
  142. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
  143. rasa/core/channels/inspector/dist/assets/index-268a75c0.js +1040 -0
  144. rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-dfb8efc4.js +1 -0
  145. rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
  146. rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-b0c470f2.js +7 -0
  147. rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
  148. rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-2edb829a.js +139 -0
  149. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
  150. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
  151. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
  152. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
  153. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
  154. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
  155. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
  156. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
  157. rasa/core/channels/inspector/dist/assets/layout-b6873d69.js +1 -0
  158. rasa/core/channels/inspector/dist/assets/line-1efc5781.js +1 -0
  159. rasa/core/channels/inspector/dist/assets/linear-661e9b94.js +1 -0
  160. rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-2d2e727f.js +109 -0
  161. rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
  162. rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
  163. rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-9d3ea93d.js +35 -0
  164. rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-06a178a2.js +7 -0
  165. rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-0bfedffc.js +52 -0
  166. rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-d76d0a04.js +8 -0
  167. rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-37bb4341.js +122 -0
  168. rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-f52f7f57.js +1 -0
  169. rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-4a986a20.js +1 -0
  170. rasa/core/channels/inspector/dist/assets/styles-080da4f6-7dd9ae12.js +110 -0
  171. rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-46e1ca14.js +159 -0
  172. rasa/core/channels/inspector/dist/assets/styles-9c745c82-4a97439a.js +207 -0
  173. rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-823917a3.js +1 -0
  174. rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-9ea72896.js +61 -0
  175. rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-b631a8b6.js +7 -0
  176. rasa/core/channels/inspector/dist/index.html +39 -0
  177. rasa/core/channels/inspector/index.html +37 -0
  178. rasa/core/channels/inspector/jest.config.ts +13 -0
  179. rasa/core/channels/inspector/package.json +48 -0
  180. rasa/core/channels/inspector/setupTests.ts +2 -0
  181. rasa/core/channels/inspector/src/App.tsx +170 -0
  182. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +97 -0
  183. rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
  184. rasa/core/channels/inspector/src/components/DialogueStack.tsx +151 -0
  185. rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
  186. rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
  187. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +19 -0
  188. rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
  189. rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
  190. rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
  191. rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
  192. rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
  193. rasa/core/channels/inspector/src/helpers/formatters.test.ts +385 -0
  194. rasa/core/channels/inspector/src/helpers/formatters.ts +239 -0
  195. rasa/core/channels/inspector/src/helpers/utils.ts +42 -0
  196. rasa/core/channels/inspector/src/main.tsx +13 -0
  197. rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
  198. rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
  199. rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
  200. rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
  201. rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
  202. rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
  203. rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
  204. rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
  205. rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
  206. rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
  207. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
  208. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
  209. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
  210. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
  211. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
  212. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
  213. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
  214. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
  215. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
  216. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
  217. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
  218. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
  219. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
  220. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
  221. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
  222. rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
  223. rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
  224. rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
  225. rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
  226. rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
  227. rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
  228. rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
  229. rasa/core/channels/inspector/src/theme/index.ts +101 -0
  230. rasa/core/channels/inspector/src/types.ts +64 -0
  231. rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
  232. rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
  233. rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
  234. rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
  235. rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
  236. rasa/core/channels/inspector/tsconfig.json +26 -0
  237. rasa/core/channels/inspector/tsconfig.node.json +10 -0
  238. rasa/core/channels/inspector/vite.config.ts +8 -0
  239. rasa/core/channels/inspector/yarn.lock +6156 -0
  240. rasa/core/channels/mattermost.py +229 -0
  241. rasa/core/channels/rasa_chat.py +126 -0
  242. rasa/core/channels/rest.py +210 -0
  243. rasa/core/channels/rocketchat.py +175 -0
  244. rasa/core/channels/slack.py +620 -0
  245. rasa/core/channels/socketio.py +274 -0
  246. rasa/core/channels/telegram.py +298 -0
  247. rasa/core/channels/twilio.py +169 -0
  248. rasa/core/channels/twilio_voice.py +367 -0
  249. rasa/core/channels/vier_cvg.py +374 -0
  250. rasa/core/channels/webexteams.py +135 -0
  251. rasa/core/concurrent_lock_store.py +210 -0
  252. rasa/core/constants.py +107 -0
  253. rasa/core/evaluation/__init__.py +0 -0
  254. rasa/core/evaluation/marker.py +267 -0
  255. rasa/core/evaluation/marker_base.py +925 -0
  256. rasa/core/evaluation/marker_stats.py +294 -0
  257. rasa/core/evaluation/marker_tracker_loader.py +103 -0
  258. rasa/core/exceptions.py +29 -0
  259. rasa/core/exporter.py +284 -0
  260. rasa/core/featurizers/__init__.py +0 -0
  261. rasa/core/featurizers/precomputation.py +410 -0
  262. rasa/core/featurizers/single_state_featurizer.py +402 -0
  263. rasa/core/featurizers/tracker_featurizers.py +1172 -0
  264. rasa/core/http_interpreter.py +89 -0
  265. rasa/core/information_retrieval/__init__.py +0 -0
  266. rasa/core/information_retrieval/faiss.py +116 -0
  267. rasa/core/information_retrieval/information_retrieval.py +72 -0
  268. rasa/core/information_retrieval/milvus.py +59 -0
  269. rasa/core/information_retrieval/qdrant.py +102 -0
  270. rasa/core/jobs.py +63 -0
  271. rasa/core/lock.py +139 -0
  272. rasa/core/lock_store.py +344 -0
  273. rasa/core/migrate.py +404 -0
  274. rasa/core/nlg/__init__.py +3 -0
  275. rasa/core/nlg/callback.py +147 -0
  276. rasa/core/nlg/contextual_response_rephraser.py +270 -0
  277. rasa/core/nlg/generator.py +230 -0
  278. rasa/core/nlg/interpolator.py +143 -0
  279. rasa/core/nlg/response.py +155 -0
  280. rasa/core/nlg/summarize.py +69 -0
  281. rasa/core/policies/__init__.py +0 -0
  282. rasa/core/policies/ensemble.py +329 -0
  283. rasa/core/policies/enterprise_search_policy.py +717 -0
  284. rasa/core/policies/enterprise_search_prompt_template.jinja2 +62 -0
  285. rasa/core/policies/flow_policy.py +205 -0
  286. rasa/core/policies/flows/__init__.py +0 -0
  287. rasa/core/policies/flows/flow_exceptions.py +44 -0
  288. rasa/core/policies/flows/flow_executor.py +582 -0
  289. rasa/core/policies/flows/flow_step_result.py +43 -0
  290. rasa/core/policies/intentless_policy.py +924 -0
  291. rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
  292. rasa/core/policies/memoization.py +538 -0
  293. rasa/core/policies/policy.py +716 -0
  294. rasa/core/policies/rule_policy.py +1276 -0
  295. rasa/core/policies/ted_policy.py +2146 -0
  296. rasa/core/policies/unexpected_intent_policy.py +1015 -0
  297. rasa/core/processor.py +1331 -0
  298. rasa/core/run.py +315 -0
  299. rasa/core/secrets_manager/__init__.py +0 -0
  300. rasa/core/secrets_manager/constants.py +32 -0
  301. rasa/core/secrets_manager/endpoints.py +391 -0
  302. rasa/core/secrets_manager/factory.py +233 -0
  303. rasa/core/secrets_manager/secret_manager.py +262 -0
  304. rasa/core/secrets_manager/vault.py +576 -0
  305. rasa/core/test.py +1337 -0
  306. rasa/core/tracker_store.py +1664 -0
  307. rasa/core/train.py +107 -0
  308. rasa/core/training/__init__.py +89 -0
  309. rasa/core/training/converters/__init__.py +0 -0
  310. rasa/core/training/converters/responses_prefix_converter.py +119 -0
  311. rasa/core/training/interactive.py +1742 -0
  312. rasa/core/training/story_conflict.py +381 -0
  313. rasa/core/training/training.py +93 -0
  314. rasa/core/utils.py +344 -0
  315. rasa/core/visualize.py +70 -0
  316. rasa/dialogue_understanding/__init__.py +0 -0
  317. rasa/dialogue_understanding/coexistence/__init__.py +0 -0
  318. rasa/dialogue_understanding/coexistence/constants.py +4 -0
  319. rasa/dialogue_understanding/coexistence/intent_based_router.py +189 -0
  320. rasa/dialogue_understanding/coexistence/llm_based_router.py +261 -0
  321. rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
  322. rasa/dialogue_understanding/commands/__init__.py +45 -0
  323. rasa/dialogue_understanding/commands/can_not_handle_command.py +61 -0
  324. rasa/dialogue_understanding/commands/cancel_flow_command.py +116 -0
  325. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +48 -0
  326. rasa/dialogue_understanding/commands/clarify_command.py +77 -0
  327. rasa/dialogue_understanding/commands/command.py +85 -0
  328. rasa/dialogue_understanding/commands/correct_slots_command.py +288 -0
  329. rasa/dialogue_understanding/commands/error_command.py +67 -0
  330. rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
  331. rasa/dialogue_understanding/commands/handle_code_change_command.py +64 -0
  332. rasa/dialogue_understanding/commands/human_handoff_command.py +57 -0
  333. rasa/dialogue_understanding/commands/knowledge_answer_command.py +48 -0
  334. rasa/dialogue_understanding/commands/noop_command.py +45 -0
  335. rasa/dialogue_understanding/commands/set_slot_command.py +125 -0
  336. rasa/dialogue_understanding/commands/skip_question_command.py +66 -0
  337. rasa/dialogue_understanding/commands/start_flow_command.py +98 -0
  338. rasa/dialogue_understanding/generator/__init__.py +6 -0
  339. rasa/dialogue_understanding/generator/command_generator.py +257 -0
  340. rasa/dialogue_understanding/generator/command_prompt_template.jinja2 +57 -0
  341. rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
  342. rasa/dialogue_understanding/generator/flow_retrieval.py +410 -0
  343. rasa/dialogue_understanding/generator/llm_command_generator.py +637 -0
  344. rasa/dialogue_understanding/generator/nlu_command_adapter.py +157 -0
  345. rasa/dialogue_understanding/patterns/__init__.py +0 -0
  346. rasa/dialogue_understanding/patterns/cancel.py +111 -0
  347. rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
  348. rasa/dialogue_understanding/patterns/chitchat.py +37 -0
  349. rasa/dialogue_understanding/patterns/clarify.py +97 -0
  350. rasa/dialogue_understanding/patterns/code_change.py +41 -0
  351. rasa/dialogue_understanding/patterns/collect_information.py +90 -0
  352. rasa/dialogue_understanding/patterns/completed.py +40 -0
  353. rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
  354. rasa/dialogue_understanding/patterns/correction.py +278 -0
  355. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +243 -0
  356. rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
  357. rasa/dialogue_understanding/patterns/internal_error.py +47 -0
  358. rasa/dialogue_understanding/patterns/search.py +37 -0
  359. rasa/dialogue_understanding/patterns/skip_question.py +38 -0
  360. rasa/dialogue_understanding/processor/__init__.py +0 -0
  361. rasa/dialogue_understanding/processor/command_processor.py +578 -0
  362. rasa/dialogue_understanding/processor/command_processor_component.py +39 -0
  363. rasa/dialogue_understanding/stack/__init__.py +0 -0
  364. rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
  365. rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
  366. rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
  367. rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
  368. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
  369. rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
  370. rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
  371. rasa/dialogue_understanding/stack/utils.py +211 -0
  372. rasa/e2e_test/__init__.py +0 -0
  373. rasa/e2e_test/constants.py +10 -0
  374. rasa/e2e_test/e2e_test_case.py +322 -0
  375. rasa/e2e_test/e2e_test_result.py +34 -0
  376. rasa/e2e_test/e2e_test_runner.py +659 -0
  377. rasa/e2e_test/e2e_test_schema.yml +67 -0
  378. rasa/engine/__init__.py +0 -0
  379. rasa/engine/caching.py +464 -0
  380. rasa/engine/constants.py +17 -0
  381. rasa/engine/exceptions.py +14 -0
  382. rasa/engine/graph.py +625 -0
  383. rasa/engine/loader.py +36 -0
  384. rasa/engine/recipes/__init__.py +0 -0
  385. rasa/engine/recipes/config_files/default_config.yml +44 -0
  386. rasa/engine/recipes/default_components.py +99 -0
  387. rasa/engine/recipes/default_recipe.py +1252 -0
  388. rasa/engine/recipes/graph_recipe.py +79 -0
  389. rasa/engine/recipes/recipe.py +93 -0
  390. rasa/engine/runner/__init__.py +0 -0
  391. rasa/engine/runner/dask.py +256 -0
  392. rasa/engine/runner/interface.py +49 -0
  393. rasa/engine/storage/__init__.py +0 -0
  394. rasa/engine/storage/local_model_storage.py +248 -0
  395. rasa/engine/storage/resource.py +110 -0
  396. rasa/engine/storage/storage.py +203 -0
  397. rasa/engine/training/__init__.py +0 -0
  398. rasa/engine/training/components.py +176 -0
  399. rasa/engine/training/fingerprinting.py +64 -0
  400. rasa/engine/training/graph_trainer.py +256 -0
  401. rasa/engine/training/hooks.py +164 -0
  402. rasa/engine/validation.py +839 -0
  403. rasa/env.py +5 -0
  404. rasa/exceptions.py +69 -0
  405. rasa/graph_components/__init__.py +0 -0
  406. rasa/graph_components/converters/__init__.py +0 -0
  407. rasa/graph_components/converters/nlu_message_converter.py +48 -0
  408. rasa/graph_components/providers/__init__.py +0 -0
  409. rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
  410. rasa/graph_components/providers/domain_provider.py +71 -0
  411. rasa/graph_components/providers/flows_provider.py +74 -0
  412. rasa/graph_components/providers/forms_provider.py +44 -0
  413. rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
  414. rasa/graph_components/providers/responses_provider.py +44 -0
  415. rasa/graph_components/providers/rule_only_provider.py +49 -0
  416. rasa/graph_components/providers/story_graph_provider.py +43 -0
  417. rasa/graph_components/providers/training_tracker_provider.py +55 -0
  418. rasa/graph_components/validators/__init__.py +0 -0
  419. rasa/graph_components/validators/default_recipe_validator.py +552 -0
  420. rasa/graph_components/validators/finetuning_validator.py +302 -0
  421. rasa/hooks.py +113 -0
  422. rasa/jupyter.py +63 -0
  423. rasa/keys +1 -0
  424. rasa/markers/__init__.py +0 -0
  425. rasa/markers/marker.py +269 -0
  426. rasa/markers/marker_base.py +828 -0
  427. rasa/markers/upload.py +74 -0
  428. rasa/markers/validate.py +21 -0
  429. rasa/model.py +118 -0
  430. rasa/model_testing.py +457 -0
  431. rasa/model_training.py +535 -0
  432. rasa/nlu/__init__.py +7 -0
  433. rasa/nlu/classifiers/__init__.py +3 -0
  434. rasa/nlu/classifiers/classifier.py +5 -0
  435. rasa/nlu/classifiers/diet_classifier.py +1874 -0
  436. rasa/nlu/classifiers/fallback_classifier.py +192 -0
  437. rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
  438. rasa/nlu/classifiers/llm_intent_classifier.py +519 -0
  439. rasa/nlu/classifiers/logistic_regression_classifier.py +240 -0
  440. rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
  441. rasa/nlu/classifiers/regex_message_handler.py +56 -0
  442. rasa/nlu/classifiers/sklearn_intent_classifier.py +309 -0
  443. rasa/nlu/constants.py +77 -0
  444. rasa/nlu/convert.py +40 -0
  445. rasa/nlu/emulators/__init__.py +0 -0
  446. rasa/nlu/emulators/dialogflow.py +55 -0
  447. rasa/nlu/emulators/emulator.py +49 -0
  448. rasa/nlu/emulators/luis.py +86 -0
  449. rasa/nlu/emulators/no_emulator.py +10 -0
  450. rasa/nlu/emulators/wit.py +56 -0
  451. rasa/nlu/extractors/__init__.py +0 -0
  452. rasa/nlu/extractors/crf_entity_extractor.py +672 -0
  453. rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
  454. rasa/nlu/extractors/entity_synonyms.py +178 -0
  455. rasa/nlu/extractors/extractor.py +470 -0
  456. rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
  457. rasa/nlu/extractors/regex_entity_extractor.py +220 -0
  458. rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
  459. rasa/nlu/featurizers/__init__.py +0 -0
  460. rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
  461. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +449 -0
  462. rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
  463. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +772 -0
  464. rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
  465. rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
  466. rasa/nlu/featurizers/featurizer.py +89 -0
  467. rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
  468. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +840 -0
  469. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +539 -0
  470. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +269 -0
  471. rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
  472. rasa/nlu/model.py +24 -0
  473. rasa/nlu/persistor.py +240 -0
  474. rasa/nlu/run.py +27 -0
  475. rasa/nlu/selectors/__init__.py +0 -0
  476. rasa/nlu/selectors/response_selector.py +990 -0
  477. rasa/nlu/test.py +1943 -0
  478. rasa/nlu/tokenizers/__init__.py +0 -0
  479. rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
  480. rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
  481. rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
  482. rasa/nlu/tokenizers/tokenizer.py +239 -0
  483. rasa/nlu/tokenizers/whitespace_tokenizer.py +106 -0
  484. rasa/nlu/utils/__init__.py +35 -0
  485. rasa/nlu/utils/bilou_utils.py +462 -0
  486. rasa/nlu/utils/hugging_face/__init__.py +0 -0
  487. rasa/nlu/utils/hugging_face/registry.py +108 -0
  488. rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
  489. rasa/nlu/utils/mitie_utils.py +113 -0
  490. rasa/nlu/utils/pattern_utils.py +168 -0
  491. rasa/nlu/utils/spacy_utils.py +312 -0
  492. rasa/plugin.py +90 -0
  493. rasa/server.py +1536 -0
  494. rasa/shared/__init__.py +0 -0
  495. rasa/shared/constants.py +181 -0
  496. rasa/shared/core/__init__.py +0 -0
  497. rasa/shared/core/constants.py +168 -0
  498. rasa/shared/core/conversation.py +46 -0
  499. rasa/shared/core/domain.py +2106 -0
  500. rasa/shared/core/events.py +2507 -0
  501. rasa/shared/core/flows/__init__.py +7 -0
  502. rasa/shared/core/flows/flow.py +353 -0
  503. rasa/shared/core/flows/flow_step.py +146 -0
  504. rasa/shared/core/flows/flow_step_links.py +319 -0
  505. rasa/shared/core/flows/flow_step_sequence.py +70 -0
  506. rasa/shared/core/flows/flows_list.py +211 -0
  507. rasa/shared/core/flows/flows_yaml_schema.json +217 -0
  508. rasa/shared/core/flows/nlu_trigger.py +117 -0
  509. rasa/shared/core/flows/steps/__init__.py +24 -0
  510. rasa/shared/core/flows/steps/action.py +51 -0
  511. rasa/shared/core/flows/steps/call.py +64 -0
  512. rasa/shared/core/flows/steps/collect.py +112 -0
  513. rasa/shared/core/flows/steps/constants.py +5 -0
  514. rasa/shared/core/flows/steps/continuation.py +36 -0
  515. rasa/shared/core/flows/steps/end.py +22 -0
  516. rasa/shared/core/flows/steps/internal.py +44 -0
  517. rasa/shared/core/flows/steps/link.py +51 -0
  518. rasa/shared/core/flows/steps/no_operation.py +48 -0
  519. rasa/shared/core/flows/steps/set_slots.py +50 -0
  520. rasa/shared/core/flows/steps/start.py +30 -0
  521. rasa/shared/core/flows/validation.py +527 -0
  522. rasa/shared/core/flows/yaml_flows_io.py +278 -0
  523. rasa/shared/core/generator.py +907 -0
  524. rasa/shared/core/slot_mappings.py +235 -0
  525. rasa/shared/core/slots.py +647 -0
  526. rasa/shared/core/trackers.py +1159 -0
  527. rasa/shared/core/training_data/__init__.py +0 -0
  528. rasa/shared/core/training_data/loading.py +90 -0
  529. rasa/shared/core/training_data/story_reader/__init__.py +0 -0
  530. rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
  531. rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
  532. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
  533. rasa/shared/core/training_data/story_writer/__init__.py +0 -0
  534. rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
  535. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +442 -0
  536. rasa/shared/core/training_data/structures.py +838 -0
  537. rasa/shared/core/training_data/visualization.html +146 -0
  538. rasa/shared/core/training_data/visualization.py +603 -0
  539. rasa/shared/data.py +192 -0
  540. rasa/shared/engine/__init__.py +0 -0
  541. rasa/shared/engine/caching.py +26 -0
  542. rasa/shared/exceptions.py +129 -0
  543. rasa/shared/importers/__init__.py +0 -0
  544. rasa/shared/importers/importer.py +705 -0
  545. rasa/shared/importers/multi_project.py +203 -0
  546. rasa/shared/importers/rasa.py +100 -0
  547. rasa/shared/importers/utils.py +34 -0
  548. rasa/shared/nlu/__init__.py +0 -0
  549. rasa/shared/nlu/constants.py +45 -0
  550. rasa/shared/nlu/interpreter.py +10 -0
  551. rasa/shared/nlu/training_data/__init__.py +0 -0
  552. rasa/shared/nlu/training_data/entities_parser.py +209 -0
  553. rasa/shared/nlu/training_data/features.py +374 -0
  554. rasa/shared/nlu/training_data/formats/__init__.py +10 -0
  555. rasa/shared/nlu/training_data/formats/dialogflow.py +162 -0
  556. rasa/shared/nlu/training_data/formats/luis.py +87 -0
  557. rasa/shared/nlu/training_data/formats/rasa.py +135 -0
  558. rasa/shared/nlu/training_data/formats/rasa_yaml.py +605 -0
  559. rasa/shared/nlu/training_data/formats/readerwriter.py +245 -0
  560. rasa/shared/nlu/training_data/formats/wit.py +52 -0
  561. rasa/shared/nlu/training_data/loading.py +137 -0
  562. rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
  563. rasa/shared/nlu/training_data/message.py +477 -0
  564. rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
  565. rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
  566. rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
  567. rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
  568. rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
  569. rasa/shared/nlu/training_data/training_data.py +732 -0
  570. rasa/shared/nlu/training_data/util.py +223 -0
  571. rasa/shared/providers/__init__.py +0 -0
  572. rasa/shared/providers/openai/__init__.py +0 -0
  573. rasa/shared/providers/openai/clients.py +43 -0
  574. rasa/shared/providers/openai/session_handler.py +110 -0
  575. rasa/shared/utils/__init__.py +0 -0
  576. rasa/shared/utils/cli.py +72 -0
  577. rasa/shared/utils/common.py +308 -0
  578. rasa/shared/utils/constants.py +1 -0
  579. rasa/shared/utils/io.py +403 -0
  580. rasa/shared/utils/llm.py +405 -0
  581. rasa/shared/utils/pykwalify_extensions.py +26 -0
  582. rasa/shared/utils/schemas/__init__.py +0 -0
  583. rasa/shared/utils/schemas/config.yml +2 -0
  584. rasa/shared/utils/schemas/domain.yml +142 -0
  585. rasa/shared/utils/schemas/events.py +212 -0
  586. rasa/shared/utils/schemas/model_config.yml +46 -0
  587. rasa/shared/utils/schemas/stories.yml +173 -0
  588. rasa/shared/utils/yaml.py +777 -0
  589. rasa/studio/__init__.py +0 -0
  590. rasa/studio/auth.py +252 -0
  591. rasa/studio/config.py +127 -0
  592. rasa/studio/constants.py +16 -0
  593. rasa/studio/data_handler.py +352 -0
  594. rasa/studio/download.py +350 -0
  595. rasa/studio/train.py +136 -0
  596. rasa/studio/upload.py +408 -0
  597. rasa/telemetry.py +1583 -0
  598. rasa/tracing/__init__.py +0 -0
  599. rasa/tracing/config.py +338 -0
  600. rasa/tracing/constants.py +38 -0
  601. rasa/tracing/instrumentation/__init__.py +0 -0
  602. rasa/tracing/instrumentation/attribute_extractors.py +663 -0
  603. rasa/tracing/instrumentation/instrumentation.py +939 -0
  604. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +142 -0
  605. rasa/tracing/instrumentation/metrics.py +206 -0
  606. rasa/tracing/metric_instrument_provider.py +125 -0
  607. rasa/utils/__init__.py +0 -0
  608. rasa/utils/beta.py +83 -0
  609. rasa/utils/cli.py +27 -0
  610. rasa/utils/common.py +635 -0
  611. rasa/utils/converter.py +53 -0
  612. rasa/utils/endpoints.py +303 -0
  613. rasa/utils/io.py +326 -0
  614. rasa/utils/licensing.py +319 -0
  615. rasa/utils/log_utils.py +174 -0
  616. rasa/utils/mapper.py +210 -0
  617. rasa/utils/ml_utils.py +145 -0
  618. rasa/utils/plotting.py +362 -0
  619. rasa/utils/singleton.py +23 -0
  620. rasa/utils/tensorflow/__init__.py +0 -0
  621. rasa/utils/tensorflow/callback.py +112 -0
  622. rasa/utils/tensorflow/constants.py +116 -0
  623. rasa/utils/tensorflow/crf.py +492 -0
  624. rasa/utils/tensorflow/data_generator.py +440 -0
  625. rasa/utils/tensorflow/environment.py +161 -0
  626. rasa/utils/tensorflow/exceptions.py +5 -0
  627. rasa/utils/tensorflow/layers.py +1565 -0
  628. rasa/utils/tensorflow/layers_utils.py +113 -0
  629. rasa/utils/tensorflow/metrics.py +281 -0
  630. rasa/utils/tensorflow/model_data.py +991 -0
  631. rasa/utils/tensorflow/model_data_utils.py +500 -0
  632. rasa/utils/tensorflow/models.py +936 -0
  633. rasa/utils/tensorflow/rasa_layers.py +1094 -0
  634. rasa/utils/tensorflow/transformer.py +640 -0
  635. rasa/utils/tensorflow/types.py +6 -0
  636. rasa/utils/train_utils.py +572 -0
  637. rasa/utils/yaml.py +54 -0
  638. rasa/validator.py +1035 -0
  639. rasa/version.py +3 -0
  640. rasa_pro-3.8.16.dist-info/METADATA +528 -0
  641. rasa_pro-3.8.16.dist-info/NOTICE +5 -0
  642. rasa_pro-3.8.16.dist-info/RECORD +644 -0
  643. rasa_pro-3.8.16.dist-info/WHEEL +4 -0
  644. rasa_pro-3.8.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,772 @@
1
+ from __future__ import annotations
2
+ import numpy as np
3
+ import logging
4
+
5
+ from typing import Any, Text, List, Dict, Tuple, Type
6
+ import tensorflow as tf
7
+
8
+ from rasa.engine.graph import ExecutionContext, GraphComponent
9
+ from rasa.engine.recipes.default_recipe import DefaultV1Recipe
10
+ from rasa.engine.storage.resource import Resource
11
+ from rasa.engine.storage.storage import ModelStorage
12
+ from rasa.nlu.featurizers.dense_featurizer.dense_featurizer import DenseFeaturizer
13
+ from rasa.nlu.tokenizers.tokenizer import Token, Tokenizer
14
+ from rasa.shared.nlu.training_data.training_data import TrainingData
15
+ from rasa.shared.nlu.training_data.message import Message
16
+ from rasa.nlu.constants import (
17
+ DENSE_FEATURIZABLE_ATTRIBUTES,
18
+ SEQUENCE_FEATURES,
19
+ SENTENCE_FEATURES,
20
+ NO_LENGTH_RESTRICTION,
21
+ NUMBER_OF_SUB_TOKENS,
22
+ TOKENS_NAMES,
23
+ )
24
+ from rasa.shared.nlu.constants import TEXT, ACTION_TEXT
25
+ from rasa.utils import train_utils
26
+ from rasa.utils.tensorflow.model_data import ragged_array_to_ndarray
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
# Longest input sequence each supported architecture accepts, keyed by the
# ``model_name`` config value. ``NO_LENGTH_RESTRICTION`` marks architectures
# for which no truncation or length validation is applied.
MAX_SEQUENCE_LENGTHS = {
    "bert": 512,
    "gpt": 512,
    "gpt2": 512,
    "xlnet": NO_LENGTH_RESTRICTION,
    "distilbert": 512,
    "roberta": 512,
    "camembert": 512,
}
39
+
40
+
41
+ @DefaultV1Recipe.register(
42
+ DefaultV1Recipe.ComponentType.MESSAGE_FEATURIZER, is_trainable=False
43
+ )
44
+ class LanguageModelFeaturizer(DenseFeaturizer, GraphComponent):
45
+ """A featurizer that uses transformer-based language models.
46
+
47
+ This component loads a pre-trained language model
48
+ from the Transformers library (https://github.com/huggingface/transformers)
49
+ including BERT, GPT, GPT-2, xlnet, distilbert, and roberta.
50
+ It also tokenizes and featurizes the featurizable dense attributes of
51
+ each message.
52
+ """
53
+
54
@classmethod
def required_components(cls) -> List[Type]:
    """Lists the component types that must precede this one in the pipeline.

    Returns:
        A one-element list containing ``Tokenizer``; this featurizer reads the
        tokens that component stores on each message.
    """
    prerequisites = [Tokenizer]
    return prerequisites
58
+
59
def __init__(
    self, config: Dict[Text, Any], execution_context: ExecutionContext
) -> None:
    """Sets up the featurizer and loads the configured language model.

    Args:
        config: Component configuration (see ``get_default_config``).
        execution_context: Graph execution context; its ``node_name`` is
            forwarded to the parent constructor.
    """
    super().__init__(execution_context.node_name, config)

    # Metadata first: loading the model instance reads ``model_name``,
    # ``model_weights`` and ``cache_dir`` set by the metadata step.
    self._load_model_metadata()
    self._load_model_instance()
68
+
69
@staticmethod
def get_default_config() -> Dict[Text, Any]:
    """Builds the default configuration of the LanguageModelFeaturizer.

    Returns:
        The ``DenseFeaturizer`` defaults extended (and, on key clashes,
        overridden) by the language-model specific options.
    """
    defaults = dict(DenseFeaturizer.get_default_config())
    defaults.update(
        {
            # name of the language model to load.
            "model_name": "bert",
            # Pre-Trained weights to be loaded(string)
            "model_weights": None,
            # an optional path to a specific directory to download
            # and cache the pre-trained model weights.
            "cache_dir": None,
        }
    )
    return defaults
82
+
83
+ @classmethod
84
+ def validate_config(cls, config: Dict[Text, Any]) -> None:
85
+ """Validates the configuration."""
86
+ pass
87
+
88
@classmethod
def create(
    cls,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
) -> LanguageModelFeaturizer:
    """Instantiates the featurizer for the given configuration.

    Args:
        config: Component configuration naming the model to load.
        model_storage: Not used by this component.
        resource: Not used by this component.
        execution_context: Forwarded to the constructor.

    Returns:
        A new instance; the constructor loads the configured model.
    """
    featurizer = cls(config, execution_context)
    return featurizer
101
+
102
+ @staticmethod
103
+ def required_packages() -> List[Text]:
104
+ """Returns the extra python dependencies required."""
105
+ return ["transformers"]
106
+
107
def _load_model_metadata(self) -> None:
    """Loads the metadata for the specified model and sets it as attributes.

    Sets ``model_name``, ``model_weights``, ``cache_dir`` and
    ``max_model_sequence_length`` from the component config, falling back to
    the registry's default weights when none are configured.

    Raises:
        KeyError: If the configured ``model_name`` is not present in the
            registry of supported model classes.
    """
    # Imported lazily so the registry is only loaded when a featurizer is
    # actually instantiated.
    from rasa.nlu.utils.hugging_face.registry import (
        model_class_dict,
        model_weights_defaults,
    )

    self.model_name = self._config["model_name"]

    if self.model_name not in model_class_dict:
        # Adjacent f-strings are concatenated verbatim, so every fragment
        # must carry its own trailing space ("or create " + "a new class").
        raise KeyError(
            f"'{self.model_name}' not a valid model name. Choose from "
            f"{list(model_class_dict.keys())!s} or create "
            f"a new class inheriting from this class to support your model."
        )

    self.model_weights = self._config["model_weights"]
    self.cache_dir = self._config["cache_dir"]

    if not self.model_weights:
        logger.info(
            f"Model weights not specified. Will choose default model "
            f"weights: {model_weights_defaults[self.model_name]}"
        )
        self.model_weights = model_weights_defaults[self.model_name]

    self.max_model_sequence_length = MAX_SEQUENCE_LENGTHS[self.model_name]
138
+
139
def _load_model_instance(self) -> None:
    """Instantiates the tokenizer and the model for ``self.model_name``.

    Both are created through ``from_pretrained`` using the configured weights
    and cache directory. Model loading should be skipped in unit tests; see
    the unit tests for examples.
    """
    from rasa.nlu.utils.hugging_face.registry import (
        model_class_dict,
        model_tokenizer_dict,
    )

    logger.debug(f"Loading Tokenizer and Model for {self.model_name}")

    tokenizer_class = model_tokenizer_dict[self.model_name]
    self.tokenizer = tokenizer_class.from_pretrained(
        self.model_weights, cache_dir=self.cache_dir
    )

    model_class = model_class_dict[self.model_name]
    self.model = model_class.from_pretrained(
        self.model_weights, cache_dir=self.cache_dir
    )

    # Transformer architectures do not share a consistent pad token and
    # pad_token_id is not set for all of them, so the unknown-token id is used
    # for padding instead. Adding a new token is not an option either, since
    # vocabulary resizing is not yet supported for TF classes. Predictions are
    # unaffected because an attention mask is passed alongside the input.
    self.pad_token_id = self.tokenizer.unk_token_id
166
+
167
+ def _lm_tokenize(self, text: Text) -> Tuple[List[int], List[Text]]:
168
+ """Passes the text through the tokenizer of the language model.
169
+
170
+ Args:
171
+ text: Text to be tokenized.
172
+
173
+ Returns: List of token ids and token strings.
174
+ """
175
+ split_token_ids = self.tokenizer.encode(text, add_special_tokens=False)
176
+
177
+ split_token_strings = self.tokenizer.convert_ids_to_tokens(split_token_ids)
178
+
179
+ return split_token_ids, split_token_strings
180
+
181
def _add_lm_specific_special_tokens(
    self, token_ids: List[List[int]]
) -> List[List[int]]:
    """Applies the model-specific special-token pre-processor to each example.

    Args:
        token_ids: List of token ids for each example in the batch.

    Returns:
        The pre-processor's output for every example, in batch order.
    """
    from rasa.nlu.utils.hugging_face.registry import (
        model_special_tokens_pre_processors,
    )

    augmented = []
    for ids_of_example in token_ids:
        # The registry lookup stays inside the loop so that an empty batch
        # never touches the registry.
        add_special_tokens = model_special_tokens_pre_processors[self.model_name]
        augmented.append(add_special_tokens(ids_of_example))
    return augmented
200
+
201
def _lm_specific_token_cleanup(
    self, split_token_ids: List[int], token_strings: List[Text]
) -> Tuple[List[int], List[Text]]:
    """Strips tokenizer-specific marker characters from the tokens.

    Many language models add a special char in front/back of (some) words.
    Those chars are not needed once the features are computed, so the
    cleaner registered for ``self.model_name`` is applied.

    Args:
        split_token_ids: Token ids produced by the language model tokenizer.
        token_strings: Token strings produced by the language model tokenizer.

    Returns:
        Cleaned up token ids and token strings.
    """
    from rasa.nlu.utils.hugging_face.registry import model_tokens_cleaners

    cleaner = model_tokens_cleaners[self.model_name]
    return cleaner(split_token_ids, token_strings)
221
+
222
def _post_process_sequence_embeddings(
    self, sequence_embeddings: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Derives sentence and sequence level representations per example.

    Each example is passed through the post-processor registered for
    ``self.model_name``, which yields a (sentence, sequence) pair.

    Args:
        sequence_embeddings: Sequence level dense features received as output
            from the language model.

    Returns:
        The sentence embeddings as one array and the post-processed sequence
        embeddings packed via ``ragged_array_to_ndarray``.
    """
    from rasa.nlu.utils.hugging_face.registry import (
        model_embeddings_post_processors,
    )

    processed_pairs = [
        model_embeddings_post_processors[self.model_name](example_embedding)
        for example_embedding in sequence_embeddings
    ]
    sentence_level = [sentence for sentence, _ in processed_pairs]
    sequence_level = [sequence for _, sequence in processed_pairs]

    return np.array(sentence_level), ragged_array_to_ndarray(sequence_level)
253
+
254
def _tokenize_example(
    self, message: Message, attribute: Text
) -> Tuple[List[Token], List[int]]:
    """Sub-tokenizes the existing tokens of one message attribute.

    The tokens produced by the Tokenizer component are kept so that entity
    start and end values keep matching. Each of them is additionally run
    through the language model tokenizer, and the number of resulting
    sub-tokens is recorded on the token.

    Args:
        message: Single message object to be processed.
        attribute: Message attribute whose tokens are processed.

    Returns:
        The tokens that produced at least one sub-token, and the flat list of
        sub-token ids of those tokens.
    """
    kept_tokens = []
    flat_token_ids = []

    for token in message.get(TOKENS_NAMES[attribute]):
        # let the lm specific tokenizer split the token further
        sub_ids, sub_strings = self._lm_tokenize(token.text)

        # Tokens made only of whitespace or other special characters yield no
        # sub-tokens; they are dropped because the cleanup step would raise
        # on empty input.
        if sub_ids:
            sub_ids, sub_strings = self._lm_specific_token_cleanup(
                sub_ids, sub_strings
            )
            flat_token_ids.extend(sub_ids)
            token.set(NUMBER_OF_SUB_TOKENS, len(sub_strings))
            kept_tokens.append(token)

    return kept_tokens, flat_token_ids
300
+
301
+ def _get_token_ids_for_batch(
302
+ self, batch_examples: List[Message], attribute: Text
303
+ ) -> Tuple[List[List[Token]], List[List[int]]]:
304
+ """Computes token ids and token strings for each example in batch.
305
+
306
+ A token id is the id of that token in the vocabulary of the language model.
307
+
308
+ Args:
309
+ batch_examples: Batch of message objects for which tokens need to be
310
+ computed.
311
+ attribute: Property of message to be processed, one of ``TEXT`` or
312
+ ``RESPONSE``.
313
+
314
+ Returns: List of token strings and token ids for each example in the batch.
315
+ """
316
+ batch_token_ids = []
317
+ batch_tokens = []
318
+ for example in batch_examples:
319
+
320
+ example_tokens, example_token_ids = self._tokenize_example(
321
+ example, attribute
322
+ )
323
+ batch_tokens.append(example_tokens)
324
+ batch_token_ids.append(example_token_ids)
325
+
326
+ return batch_tokens, batch_token_ids
327
+
328
+ @staticmethod
329
+ def _compute_attention_mask(
330
+ actual_sequence_lengths: List[int], max_input_sequence_length: int
331
+ ) -> np.ndarray:
332
+ """Computes a mask for padding tokens.
333
+
334
+ This mask will be used by the language model so that it does not attend to
335
+ padding tokens.
336
+
337
+ Args:
338
+ actual_sequence_lengths: List of length of each example without any
339
+ padding.
340
+ max_input_sequence_length: Maximum length of a sequence that will be
341
+ present in the input batch. This is
342
+ after taking into consideration the maximum input sequence the model
343
+ can handle. Hence it can never be
344
+ greater than self.max_model_sequence_length in case the model
345
+ applies length restriction.
346
+
347
+ Returns: Computed attention mask, 0 for padding and 1 for non-padding
348
+ tokens.
349
+ """
350
+ attention_mask = []
351
+
352
+ for actual_sequence_length in actual_sequence_lengths:
353
+ # add 1s for present tokens, fill up the remaining space up to max
354
+ # sequence length with 0s (non-existing tokens)
355
+ padded_sequence = [1] * min(
356
+ actual_sequence_length, max_input_sequence_length
357
+ ) + [0] * (
358
+ max_input_sequence_length
359
+ - min(actual_sequence_length, max_input_sequence_length)
360
+ )
361
+ attention_mask.append(padded_sequence)
362
+
363
+ return np.array(attention_mask).astype(np.float32)
364
+
365
def _extract_sequence_lengths(
    self, batch_token_ids: List[List[int]]
) -> Tuple[List[int], int]:
    """Measures each example and derives the batch's input length.

    Args:
        batch_token_ids: List of token ids for each example in the batch.

    Returns:
        Tuple consisting of: the actual sequence length of each example, and
        the maximum input sequence length (capped at the maximum sequence
        length the model can handle, if it has one).
    """
    actual_sequence_lengths = [len(token_ids) for token_ids in batch_token_ids]

    # An empty batch yields a longest length of 0.
    longest = max(actual_sequence_lengths, default=0)

    # Cap at what the model can handle, unless it has no restriction.
    if self.max_model_sequence_length != NO_LENGTH_RESTRICTION:
        longest = min(longest, self.max_model_sequence_length)

    return actual_sequence_lengths, longest
397
+
398
+ def _add_padding_to_batch(
399
+ self, batch_token_ids: List[List[int]], max_sequence_length_model: int
400
+ ) -> List[List[int]]:
401
+ """Adds padding so that all examples in the batch are of the same length.
402
+
403
+ Args:
404
+ batch_token_ids: Batch of examples where each example is a non-padded list
405
+ of token ids.
406
+ max_sequence_length_model: Maximum length of any input sequence in the batch
407
+ to be fed to the model.
408
+
409
+ Returns:
410
+ Padded batch with all examples of the same length.
411
+ """
412
+ padded_token_ids = []
413
+
414
+ # Add padding according to max_sequence_length
415
+ # Some models don't contain pad token, we use unknown token as padding token.
416
+ # This doesn't affect the computation since we compute an attention mask
417
+ # anyways.
418
+ for example_token_ids in batch_token_ids:
419
+
420
+ # Truncate any longer sequences so that they can be fed to the model
421
+ if len(example_token_ids) > max_sequence_length_model:
422
+ example_token_ids = example_token_ids[:max_sequence_length_model]
423
+
424
+ padded_token_ids.append(
425
+ example_token_ids
426
+ + [self.pad_token_id]
427
+ * (max_sequence_length_model - len(example_token_ids))
428
+ )
429
+ return padded_token_ids
430
+
431
@staticmethod
def _extract_nonpadded_embeddings(
    embeddings: np.ndarray, actual_sequence_lengths: List[int]
) -> np.ndarray:
    """Drops the embeddings that belong to padding positions.

    Each example is cut down to its pre-computed non-padded length.

    Args:
        embeddings: sequence level representations for each example of the
            batch.
        actual_sequence_lengths: non-padded lengths of each example of the
            batch.

    Returns:
        Sequence level embeddings for only non-padding tokens of the batch.
    """
    trimmed = [
        example_embedding[: actual_sequence_lengths[position]]
        for position, example_embedding in enumerate(embeddings)
    ]
    return ragged_array_to_ndarray(trimmed)
453
+
454
def _compute_batch_sequence_features(
    self, batch_attention_mask: np.ndarray, padded_token_ids: List[List[int]]
) -> np.ndarray:
    """Runs the padded batch through the language model.

    Args:
        batch_attention_mask: Mask of 0s and 1s which indicate whether the
            token is a padding token or not.
        padded_token_ids: Batch of token ids for each example. The batch is
            padded and hence can be fed at once.

    Returns:
        Sequence level representations from the language model, converted to
        a numpy array.
    """
    outputs = self.model(
        tf.convert_to_tensor(padded_token_ids),
        attention_mask=tf.convert_to_tensor(batch_attention_mask),
    )

    # the sequence hidden states are the first output of every model
    hidden_states = outputs[0]
    return hidden_states.numpy()
478
+
479
def _validate_sequence_lengths(
    self,
    actual_sequence_lengths: List[int],
    batch_examples: List[Message],
    attribute: Text,
    inference_mode: bool = False,
) -> None:
    """Checks the inputs against the model's maximum sequence length.

    Does nothing for models without a length restriction. Otherwise an
    example that is too long aborts training with an error, while during
    inference it is only reported through a debug log message.

    Args:
        actual_sequence_lengths: original sequence length of all inputs
        batch_examples: all message instances in the batch
        attribute: attribute of message object to be processed
        inference_mode: whether this is during training or inference

    Raises:
        RuntimeError: If an example exceeds the limit and ``inference_mode``
            is false.
    """
    if self.max_model_sequence_length == NO_LENGTH_RESTRICTION:
        # There is no restriction on sequence length from the model
        return

    for sequence_length, example in zip(actual_sequence_lengths, batch_examples):
        if sequence_length <= self.max_model_sequence_length:
            continue

        if inference_mode:
            logger.debug(
                f"The sequence length of '{example.get(attribute)[:20]}...' "
                f"is too long({sequence_length} tokens) for the "
                f"model chosen {self.model_name} which has a maximum "
                f"sequence length of {self.max_model_sequence_length} tokens. "
                f"Downstream model predictions may be affected because of this."
            )
            continue

        raise RuntimeError(
            f"The sequence length of '{example.get(attribute)[:20]}...' "
            f"is too long({sequence_length} tokens) for the "
            f"model chosen {self.model_name} which has a maximum "
            f"sequence length of {self.max_model_sequence_length} tokens. "
            f"Either shorten the message or use a model which has no "
            f"restriction on input sequence length like XLNet."
        )
523
+
524
def _add_extra_padding(
    self, sequence_embeddings: np.ndarray, actual_sequence_lengths: List[int]
) -> np.ndarray:
    """Appends zero rows to embeddings of examples that were truncated.

    An example is padded only if its original length exceeds the maximum
    sequence length of the model; one zero row is added per excess position.

    Args:
        sequence_embeddings: Embeddings returned from the model
        actual_sequence_lengths: original sequence length of all inputs

    Returns:
        Modified sequence embeddings with padding if necessary
    """
    if self.max_model_sequence_length == NO_LENGTH_RESTRICTION:
        # Nothing was truncated, so there is nothing to pad back.
        return sequence_embeddings

    padded_embeddings = []
    for position, embedding in enumerate(sequence_embeddings):
        feature_size = embedding.shape[-1]
        missing_rows = (
            actual_sequence_lengths[position] - self.max_model_sequence_length
        )
        if missing_rows > 0:
            zero_rows = np.zeros((missing_rows, feature_size), dtype=np.float32)
            embedding = np.concatenate([embedding, zero_rows])
        padded_embeddings.append(embedding)

    return ragged_array_to_ndarray(padded_embeddings)
563
+
564
def _get_model_features_for_batch(
    self,
    batch_token_ids: List[List[int]],
    batch_tokens: List[List[Token]],
    batch_examples: List[Message],
    attribute: Text,
    inference_mode: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Computes dense features of each example in the batch.

    Pipeline: add the model-specific special tokens, measure and validate the
    sequence lengths, pad the batch and build the matching attention mask,
    run the language model, drop the embeddings of padding positions,
    post-process into sentence and sequence embeddings, and finally align the
    sequence embeddings with the tokens.

    Args:
        batch_token_ids: List of token ids of each example in the batch.
        batch_tokens: List of token objects for each example in the batch.
        batch_examples: List of examples in the batch.
        attribute: attribute of the Message object to be processed.
        inference_mode: Whether the call is during training or during
            inference.

    Returns:
        Sentence and token level dense representations.
    """
    # tokenizer specific special tokens come first
    augmented_ids = self._add_lm_specific_special_tokens(batch_token_ids)

    lengths, max_input_length = self._extract_sequence_lengths(augmented_ids)

    # make sure the model can process every sequence
    self._validate_sequence_lengths(
        lengths, batch_examples, attribute, inference_mode
    )

    # pad so that the whole batch can be fed to the model at once
    padded_ids = self._add_padding_to_batch(augmented_ids, max_input_length)
    attention_mask = self._compute_attention_mask(lengths, max_input_length)

    # token level features from the model
    hidden_states = self._compute_batch_sequence_features(
        attention_mask, padded_ids
    )

    # keep features of non-padding tokens only
    nonpadded = self._extract_nonpadded_embeddings(hidden_states, lengths)

    sentence_embeddings, sequence_embeddings = (
        self._post_process_sequence_embeddings(nonpadded)
    )

    # Zero-pad examples that were truncated in inference mode. This is
    # deliberately done after the sentence embeddings were extracted so that
    # they are not affected.
    sequence_embeddings = self._add_extra_padding(sequence_embeddings, lengths)

    # shape of the matrix holding all sequence embeddings
    shape = (
        len(sequence_embeddings),
        max(e.shape[0] for e in sequence_embeddings),
        sequence_embeddings[0].shape[1],
    )

    # one vector per token (sub-tokens are not kept separately)
    aligned_embeddings = train_utils.align_token_features(
        batch_tokens, sequence_embeddings, shape
    )

    # the aligned array is padded; keep one row per token of each example
    final_embeddings = [
        example_embeddings[: len(example_tokens)]
        for example_embeddings, example_tokens in zip(
            aligned_embeddings, batch_tokens
        )
    ]

    return sentence_embeddings, ragged_array_to_ndarray(final_embeddings)
660
+
661
def _get_docs_for_batch(
    self,
    batch_examples: List[Message],
    attribute: Text,
    inference_mode: bool = False,
) -> List[Dict[Text, Any]]:
    """Computes language model docs for all examples in the batch.

    Args:
        batch_examples: Batch of message objects for which language model
            docs need to be computed.
        attribute: Message attribute to be processed.
        inference_mode: Whether the call is during inference or during
            training.

    Returns:
        One doc per message, holding its sequence features and its sentence
        features (the latter reshaped to a single row).
    """
    batch_tokens, batch_token_ids = self._get_token_ids_for_batch(
        batch_examples, attribute
    )

    sentence_features, sequence_features = self._get_model_features_for_batch(
        batch_token_ids, batch_tokens, batch_examples, attribute, inference_mode
    )

    # A doc consists of
    # {'sequence_features': ..., 'sentence_features': ...}
    return [
        {
            SEQUENCE_FEATURES: sequence_features[position],
            SENTENCE_FEATURES: np.reshape(sentence_features[position], (1, -1)),
        }
        for position in range(len(batch_examples))
    ]
702
+
703
def process_training_data(self, training_data: TrainingData) -> TrainingData:
    """Computes dense language model features for the training examples.

    For every attribute in ``DENSE_FEATURIZABLE_ATTRIBUTES`` the examples that
    carry a truthy value for that attribute are featurized in fixed-size
    batches.

    Args:
        training_data: NLU training data whose examples are featurized.

    Returns:
        The same ``training_data`` object that was passed in.
    """
    batch_size = 64

    for attribute in DENSE_FEATURIZABLE_ATTRIBUTES:
        # Only examples that actually have this attribute set are featurized.
        examples = [
            example
            for example in training_data.training_examples
            if example.get(attribute)
        ]

        for start in range(0, len(examples), batch_size):
            # Slicing clamps at the end of the list, so the last batch may be
            # shorter than ``batch_size``.
            batch = examples[start : start + batch_size]

            # One doc (sequence + sentence features) per message in the batch.
            docs = self._get_docs_for_batch(batch, attribute)

            for position, example in enumerate(batch):
                self._set_lm_features(docs[position], example, attribute)

    return training_data
737
+
738
def process(self, messages: List[Message]) -> List[Message]:
    """Runs ``_process_message`` on every message and returns the same list.

    Args:
        messages: Messages to process one by one, in order.

    Returns:
        The ``messages`` list that was passed in.
    """
    for single_message in messages:
        self._process_message(single_message)

    return messages
743
+
744
def _process_message(self, message: Message) -> Message:
    """Computes language model features for a single message.

    Args:
        message: Message to featurize.

    Returns:
        The ``message`` object that was passed in.
    """
    # Processing featurizers operate only on the TEXT and ACTION_TEXT
    # attributes, because all other attributes are labels which are featurized
    # during training and whose features are stored by the model itself.
    for attribute in {TEXT, ACTION_TEXT}:
        if not message.get(attribute):
            continue

        # A batch of one message, computed in inference mode.
        docs = self._get_docs_for_batch(
            [message], attribute=attribute, inference_mode=True
        )
        self._set_lm_features(docs[0], message, attribute)

    return message
759
+
760
def _set_lm_features(
    self, doc: Dict[Text, Any], message: Message, attribute: Text = TEXT
) -> None:
    """Passes the precomputed features of ``doc`` on to the message.

    Args:
        doc: Dict holding ``SEQUENCE_FEATURES`` and ``SENTENCE_FEATURES``.
        message: Message the features belong to.
        attribute: Message attribute the features were computed for.
    """
    self.add_features_to_message(
        sequence=doc[SEQUENCE_FEATURES],
        sentence=doc[SENTENCE_FEATURES],
        attribute=attribute,
        message=message,
    )