rasa_pro-3.8.16-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of rasa-pro has been flagged as potentially problematic; see the registry's advisory page for details.

Files changed (644)
  1. README.md +380 -0
  2. rasa/__init__.py +10 -0
  3. rasa/__main__.py +151 -0
  4. rasa/anonymization/__init__.py +2 -0
  5. rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
  6. rasa/anonymization/anonymization_pipeline.py +287 -0
  7. rasa/anonymization/anonymization_rule_executor.py +260 -0
  8. rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
  9. rasa/anonymization/schemas/config.yml +47 -0
  10. rasa/anonymization/utils.py +117 -0
  11. rasa/api.py +146 -0
  12. rasa/cli/__init__.py +5 -0
  13. rasa/cli/arguments/__init__.py +0 -0
  14. rasa/cli/arguments/data.py +81 -0
  15. rasa/cli/arguments/default_arguments.py +165 -0
  16. rasa/cli/arguments/evaluate.py +65 -0
  17. rasa/cli/arguments/export.py +51 -0
  18. rasa/cli/arguments/interactive.py +74 -0
  19. rasa/cli/arguments/run.py +204 -0
  20. rasa/cli/arguments/shell.py +13 -0
  21. rasa/cli/arguments/test.py +211 -0
  22. rasa/cli/arguments/train.py +263 -0
  23. rasa/cli/arguments/visualize.py +34 -0
  24. rasa/cli/arguments/x.py +30 -0
  25. rasa/cli/data.py +292 -0
  26. rasa/cli/e2e_test.py +566 -0
  27. rasa/cli/evaluate.py +222 -0
  28. rasa/cli/export.py +251 -0
  29. rasa/cli/inspect.py +63 -0
  30. rasa/cli/interactive.py +164 -0
  31. rasa/cli/license.py +65 -0
  32. rasa/cli/markers.py +78 -0
  33. rasa/cli/project_templates/__init__.py +0 -0
  34. rasa/cli/project_templates/calm/actions/__init__.py +0 -0
  35. rasa/cli/project_templates/calm/actions/action_template.py +27 -0
  36. rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
  37. rasa/cli/project_templates/calm/actions/db.py +57 -0
  38. rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
  39. rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
  40. rasa/cli/project_templates/calm/config.yml +12 -0
  41. rasa/cli/project_templates/calm/credentials.yml +33 -0
  42. rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
  43. rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
  44. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
  45. rasa/cli/project_templates/calm/db/contacts.json +10 -0
  46. rasa/cli/project_templates/calm/domain/add_contact.yml +33 -0
  47. rasa/cli/project_templates/calm/domain/list_contacts.yml +14 -0
  48. rasa/cli/project_templates/calm/domain/remove_contact.yml +31 -0
  49. rasa/cli/project_templates/calm/domain/shared.yml +5 -0
  50. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
  51. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
  52. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
  53. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
  54. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
  55. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
  56. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
  57. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
  58. rasa/cli/project_templates/calm/endpoints.yml +45 -0
  59. rasa/cli/project_templates/default/actions/__init__.py +0 -0
  60. rasa/cli/project_templates/default/actions/actions.py +27 -0
  61. rasa/cli/project_templates/default/config.yml +44 -0
  62. rasa/cli/project_templates/default/credentials.yml +33 -0
  63. rasa/cli/project_templates/default/data/nlu.yml +91 -0
  64. rasa/cli/project_templates/default/data/rules.yml +13 -0
  65. rasa/cli/project_templates/default/data/stories.yml +30 -0
  66. rasa/cli/project_templates/default/domain.yml +34 -0
  67. rasa/cli/project_templates/default/endpoints.yml +42 -0
  68. rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
  69. rasa/cli/project_templates/tutorial/actions.py +22 -0
  70. rasa/cli/project_templates/tutorial/config.yml +11 -0
  71. rasa/cli/project_templates/tutorial/credentials.yml +33 -0
  72. rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
  73. rasa/cli/project_templates/tutorial/domain.yml +17 -0
  74. rasa/cli/project_templates/tutorial/endpoints.yml +45 -0
  75. rasa/cli/run.py +136 -0
  76. rasa/cli/scaffold.py +268 -0
  77. rasa/cli/shell.py +141 -0
  78. rasa/cli/studio/__init__.py +0 -0
  79. rasa/cli/studio/download.py +51 -0
  80. rasa/cli/studio/studio.py +110 -0
  81. rasa/cli/studio/train.py +59 -0
  82. rasa/cli/studio/upload.py +85 -0
  83. rasa/cli/telemetry.py +90 -0
  84. rasa/cli/test.py +280 -0
  85. rasa/cli/train.py +260 -0
  86. rasa/cli/utils.py +453 -0
  87. rasa/cli/visualize.py +40 -0
  88. rasa/cli/x.py +205 -0
  89. rasa/constants.py +37 -0
  90. rasa/core/__init__.py +17 -0
  91. rasa/core/actions/__init__.py +0 -0
  92. rasa/core/actions/action.py +1450 -0
  93. rasa/core/actions/action_clean_stack.py +59 -0
  94. rasa/core/actions/action_run_slot_rejections.py +207 -0
  95. rasa/core/actions/action_trigger_chitchat.py +31 -0
  96. rasa/core/actions/action_trigger_flow.py +109 -0
  97. rasa/core/actions/action_trigger_search.py +31 -0
  98. rasa/core/actions/constants.py +2 -0
  99. rasa/core/actions/forms.py +737 -0
  100. rasa/core/actions/loops.py +111 -0
  101. rasa/core/actions/two_stage_fallback.py +186 -0
  102. rasa/core/agent.py +557 -0
  103. rasa/core/auth_retry_tracker_store.py +122 -0
  104. rasa/core/brokers/__init__.py +0 -0
  105. rasa/core/brokers/broker.py +126 -0
  106. rasa/core/brokers/file.py +58 -0
  107. rasa/core/brokers/kafka.py +322 -0
  108. rasa/core/brokers/pika.py +387 -0
  109. rasa/core/brokers/sql.py +86 -0
  110. rasa/core/channels/__init__.py +55 -0
  111. rasa/core/channels/audiocodes.py +463 -0
  112. rasa/core/channels/botframework.py +339 -0
  113. rasa/core/channels/callback.py +85 -0
  114. rasa/core/channels/channel.py +419 -0
  115. rasa/core/channels/console.py +243 -0
  116. rasa/core/channels/development_inspector.py +93 -0
  117. rasa/core/channels/facebook.py +422 -0
  118. rasa/core/channels/hangouts.py +335 -0
  119. rasa/core/channels/inspector/.eslintrc.cjs +25 -0
  120. rasa/core/channels/inspector/.gitignore +23 -0
  121. rasa/core/channels/inspector/README.md +54 -0
  122. rasa/core/channels/inspector/assets/favicon.ico +0 -0
  123. rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
  124. rasa/core/channels/inspector/custom.d.ts +3 -0
  125. rasa/core/channels/inspector/dist/assets/arc-5623b6dc.js +1 -0
  126. rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
  127. rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-685c106a.js +10 -0
  128. rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-8cbed007.js +2 -0
  129. rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-5889cf12.js +2 -0
  130. rasa/core/channels/inspector/dist/assets/createText-62fc7601-24c249d7.js +7 -0
  131. rasa/core/channels/inspector/dist/assets/edges-f2ad444c-7dd06a75.js +4 -0
  132. rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-62c1e54c.js +51 -0
  133. rasa/core/channels/inspector/dist/assets/flowDb-1972c806-ce49b86f.js +6 -0
  134. rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-4067e48f.js +4 -0
  135. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-85583a23.js +1 -0
  136. rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-59fe4051.js +139 -0
  137. rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-47e3a43b.js +266 -0
  138. rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-5a2ac0d9.js +70 -0
  139. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
  140. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
  141. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
  142. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
  143. rasa/core/channels/inspector/dist/assets/index-268a75c0.js +1040 -0
  144. rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-dfb8efc4.js +1 -0
  145. rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
  146. rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-b0c470f2.js +7 -0
  147. rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
  148. rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-2edb829a.js +139 -0
  149. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
  150. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
  151. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
  152. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
  153. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
  154. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
  155. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
  156. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
  157. rasa/core/channels/inspector/dist/assets/layout-b6873d69.js +1 -0
  158. rasa/core/channels/inspector/dist/assets/line-1efc5781.js +1 -0
  159. rasa/core/channels/inspector/dist/assets/linear-661e9b94.js +1 -0
  160. rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-2d2e727f.js +109 -0
  161. rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
  162. rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
  163. rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-9d3ea93d.js +35 -0
  164. rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-06a178a2.js +7 -0
  165. rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-0bfedffc.js +52 -0
  166. rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-d76d0a04.js +8 -0
  167. rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-37bb4341.js +122 -0
  168. rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-f52f7f57.js +1 -0
  169. rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-4a986a20.js +1 -0
  170. rasa/core/channels/inspector/dist/assets/styles-080da4f6-7dd9ae12.js +110 -0
  171. rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-46e1ca14.js +159 -0
  172. rasa/core/channels/inspector/dist/assets/styles-9c745c82-4a97439a.js +207 -0
  173. rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-823917a3.js +1 -0
  174. rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-9ea72896.js +61 -0
  175. rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-b631a8b6.js +7 -0
  176. rasa/core/channels/inspector/dist/index.html +39 -0
  177. rasa/core/channels/inspector/index.html +37 -0
  178. rasa/core/channels/inspector/jest.config.ts +13 -0
  179. rasa/core/channels/inspector/package.json +48 -0
  180. rasa/core/channels/inspector/setupTests.ts +2 -0
  181. rasa/core/channels/inspector/src/App.tsx +170 -0
  182. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +97 -0
  183. rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
  184. rasa/core/channels/inspector/src/components/DialogueStack.tsx +151 -0
  185. rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
  186. rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
  187. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +19 -0
  188. rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
  189. rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
  190. rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
  191. rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
  192. rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
  193. rasa/core/channels/inspector/src/helpers/formatters.test.ts +385 -0
  194. rasa/core/channels/inspector/src/helpers/formatters.ts +239 -0
  195. rasa/core/channels/inspector/src/helpers/utils.ts +42 -0
  196. rasa/core/channels/inspector/src/main.tsx +13 -0
  197. rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
  198. rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
  199. rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
  200. rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
  201. rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
  202. rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
  203. rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
  204. rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
  205. rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
  206. rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
  207. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
  208. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
  209. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
  210. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
  211. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
  212. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
  213. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
  214. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
  215. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
  216. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
  217. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
  218. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
  219. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
  220. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
  221. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
  222. rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
  223. rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
  224. rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
  225. rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
  226. rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
  227. rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
  228. rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
  229. rasa/core/channels/inspector/src/theme/index.ts +101 -0
  230. rasa/core/channels/inspector/src/types.ts +64 -0
  231. rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
  232. rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
  233. rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
  234. rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
  235. rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
  236. rasa/core/channels/inspector/tsconfig.json +26 -0
  237. rasa/core/channels/inspector/tsconfig.node.json +10 -0
  238. rasa/core/channels/inspector/vite.config.ts +8 -0
  239. rasa/core/channels/inspector/yarn.lock +6156 -0
  240. rasa/core/channels/mattermost.py +229 -0
  241. rasa/core/channels/rasa_chat.py +126 -0
  242. rasa/core/channels/rest.py +210 -0
  243. rasa/core/channels/rocketchat.py +175 -0
  244. rasa/core/channels/slack.py +620 -0
  245. rasa/core/channels/socketio.py +274 -0
  246. rasa/core/channels/telegram.py +298 -0
  247. rasa/core/channels/twilio.py +169 -0
  248. rasa/core/channels/twilio_voice.py +367 -0
  249. rasa/core/channels/vier_cvg.py +374 -0
  250. rasa/core/channels/webexteams.py +135 -0
  251. rasa/core/concurrent_lock_store.py +210 -0
  252. rasa/core/constants.py +107 -0
  253. rasa/core/evaluation/__init__.py +0 -0
  254. rasa/core/evaluation/marker.py +267 -0
  255. rasa/core/evaluation/marker_base.py +925 -0
  256. rasa/core/evaluation/marker_stats.py +294 -0
  257. rasa/core/evaluation/marker_tracker_loader.py +103 -0
  258. rasa/core/exceptions.py +29 -0
  259. rasa/core/exporter.py +284 -0
  260. rasa/core/featurizers/__init__.py +0 -0
  261. rasa/core/featurizers/precomputation.py +410 -0
  262. rasa/core/featurizers/single_state_featurizer.py +402 -0
  263. rasa/core/featurizers/tracker_featurizers.py +1172 -0
  264. rasa/core/http_interpreter.py +89 -0
  265. rasa/core/information_retrieval/__init__.py +0 -0
  266. rasa/core/information_retrieval/faiss.py +116 -0
  267. rasa/core/information_retrieval/information_retrieval.py +72 -0
  268. rasa/core/information_retrieval/milvus.py +59 -0
  269. rasa/core/information_retrieval/qdrant.py +102 -0
  270. rasa/core/jobs.py +63 -0
  271. rasa/core/lock.py +139 -0
  272. rasa/core/lock_store.py +344 -0
  273. rasa/core/migrate.py +404 -0
  274. rasa/core/nlg/__init__.py +3 -0
  275. rasa/core/nlg/callback.py +147 -0
  276. rasa/core/nlg/contextual_response_rephraser.py +270 -0
  277. rasa/core/nlg/generator.py +230 -0
  278. rasa/core/nlg/interpolator.py +143 -0
  279. rasa/core/nlg/response.py +155 -0
  280. rasa/core/nlg/summarize.py +69 -0
  281. rasa/core/policies/__init__.py +0 -0
  282. rasa/core/policies/ensemble.py +329 -0
  283. rasa/core/policies/enterprise_search_policy.py +717 -0
  284. rasa/core/policies/enterprise_search_prompt_template.jinja2 +62 -0
  285. rasa/core/policies/flow_policy.py +205 -0
  286. rasa/core/policies/flows/__init__.py +0 -0
  287. rasa/core/policies/flows/flow_exceptions.py +44 -0
  288. rasa/core/policies/flows/flow_executor.py +582 -0
  289. rasa/core/policies/flows/flow_step_result.py +43 -0
  290. rasa/core/policies/intentless_policy.py +924 -0
  291. rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
  292. rasa/core/policies/memoization.py +538 -0
  293. rasa/core/policies/policy.py +716 -0
  294. rasa/core/policies/rule_policy.py +1276 -0
  295. rasa/core/policies/ted_policy.py +2146 -0
  296. rasa/core/policies/unexpected_intent_policy.py +1015 -0
  297. rasa/core/processor.py +1331 -0
  298. rasa/core/run.py +315 -0
  299. rasa/core/secrets_manager/__init__.py +0 -0
  300. rasa/core/secrets_manager/constants.py +32 -0
  301. rasa/core/secrets_manager/endpoints.py +391 -0
  302. rasa/core/secrets_manager/factory.py +233 -0
  303. rasa/core/secrets_manager/secret_manager.py +262 -0
  304. rasa/core/secrets_manager/vault.py +576 -0
  305. rasa/core/test.py +1337 -0
  306. rasa/core/tracker_store.py +1664 -0
  307. rasa/core/train.py +107 -0
  308. rasa/core/training/__init__.py +89 -0
  309. rasa/core/training/converters/__init__.py +0 -0
  310. rasa/core/training/converters/responses_prefix_converter.py +119 -0
  311. rasa/core/training/interactive.py +1742 -0
  312. rasa/core/training/story_conflict.py +381 -0
  313. rasa/core/training/training.py +93 -0
  314. rasa/core/utils.py +344 -0
  315. rasa/core/visualize.py +70 -0
  316. rasa/dialogue_understanding/__init__.py +0 -0
  317. rasa/dialogue_understanding/coexistence/__init__.py +0 -0
  318. rasa/dialogue_understanding/coexistence/constants.py +4 -0
  319. rasa/dialogue_understanding/coexistence/intent_based_router.py +189 -0
  320. rasa/dialogue_understanding/coexistence/llm_based_router.py +261 -0
  321. rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
  322. rasa/dialogue_understanding/commands/__init__.py +45 -0
  323. rasa/dialogue_understanding/commands/can_not_handle_command.py +61 -0
  324. rasa/dialogue_understanding/commands/cancel_flow_command.py +116 -0
  325. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +48 -0
  326. rasa/dialogue_understanding/commands/clarify_command.py +77 -0
  327. rasa/dialogue_understanding/commands/command.py +85 -0
  328. rasa/dialogue_understanding/commands/correct_slots_command.py +288 -0
  329. rasa/dialogue_understanding/commands/error_command.py +67 -0
  330. rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
  331. rasa/dialogue_understanding/commands/handle_code_change_command.py +64 -0
  332. rasa/dialogue_understanding/commands/human_handoff_command.py +57 -0
  333. rasa/dialogue_understanding/commands/knowledge_answer_command.py +48 -0
  334. rasa/dialogue_understanding/commands/noop_command.py +45 -0
  335. rasa/dialogue_understanding/commands/set_slot_command.py +125 -0
  336. rasa/dialogue_understanding/commands/skip_question_command.py +66 -0
  337. rasa/dialogue_understanding/commands/start_flow_command.py +98 -0
  338. rasa/dialogue_understanding/generator/__init__.py +6 -0
  339. rasa/dialogue_understanding/generator/command_generator.py +257 -0
  340. rasa/dialogue_understanding/generator/command_prompt_template.jinja2 +57 -0
  341. rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
  342. rasa/dialogue_understanding/generator/flow_retrieval.py +410 -0
  343. rasa/dialogue_understanding/generator/llm_command_generator.py +637 -0
  344. rasa/dialogue_understanding/generator/nlu_command_adapter.py +157 -0
  345. rasa/dialogue_understanding/patterns/__init__.py +0 -0
  346. rasa/dialogue_understanding/patterns/cancel.py +111 -0
  347. rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
  348. rasa/dialogue_understanding/patterns/chitchat.py +37 -0
  349. rasa/dialogue_understanding/patterns/clarify.py +97 -0
  350. rasa/dialogue_understanding/patterns/code_change.py +41 -0
  351. rasa/dialogue_understanding/patterns/collect_information.py +90 -0
  352. rasa/dialogue_understanding/patterns/completed.py +40 -0
  353. rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
  354. rasa/dialogue_understanding/patterns/correction.py +278 -0
  355. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +243 -0
  356. rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
  357. rasa/dialogue_understanding/patterns/internal_error.py +47 -0
  358. rasa/dialogue_understanding/patterns/search.py +37 -0
  359. rasa/dialogue_understanding/patterns/skip_question.py +38 -0
  360. rasa/dialogue_understanding/processor/__init__.py +0 -0
  361. rasa/dialogue_understanding/processor/command_processor.py +578 -0
  362. rasa/dialogue_understanding/processor/command_processor_component.py +39 -0
  363. rasa/dialogue_understanding/stack/__init__.py +0 -0
  364. rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
  365. rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
  366. rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
  367. rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
  368. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
  369. rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
  370. rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
  371. rasa/dialogue_understanding/stack/utils.py +211 -0
  372. rasa/e2e_test/__init__.py +0 -0
  373. rasa/e2e_test/constants.py +10 -0
  374. rasa/e2e_test/e2e_test_case.py +322 -0
  375. rasa/e2e_test/e2e_test_result.py +34 -0
  376. rasa/e2e_test/e2e_test_runner.py +659 -0
  377. rasa/e2e_test/e2e_test_schema.yml +67 -0
  378. rasa/engine/__init__.py +0 -0
  379. rasa/engine/caching.py +464 -0
  380. rasa/engine/constants.py +17 -0
  381. rasa/engine/exceptions.py +14 -0
  382. rasa/engine/graph.py +625 -0
  383. rasa/engine/loader.py +36 -0
  384. rasa/engine/recipes/__init__.py +0 -0
  385. rasa/engine/recipes/config_files/default_config.yml +44 -0
  386. rasa/engine/recipes/default_components.py +99 -0
  387. rasa/engine/recipes/default_recipe.py +1252 -0
  388. rasa/engine/recipes/graph_recipe.py +79 -0
  389. rasa/engine/recipes/recipe.py +93 -0
  390. rasa/engine/runner/__init__.py +0 -0
  391. rasa/engine/runner/dask.py +256 -0
  392. rasa/engine/runner/interface.py +49 -0
  393. rasa/engine/storage/__init__.py +0 -0
  394. rasa/engine/storage/local_model_storage.py +248 -0
  395. rasa/engine/storage/resource.py +110 -0
  396. rasa/engine/storage/storage.py +203 -0
  397. rasa/engine/training/__init__.py +0 -0
  398. rasa/engine/training/components.py +176 -0
  399. rasa/engine/training/fingerprinting.py +64 -0
  400. rasa/engine/training/graph_trainer.py +256 -0
  401. rasa/engine/training/hooks.py +164 -0
  402. rasa/engine/validation.py +839 -0
  403. rasa/env.py +5 -0
  404. rasa/exceptions.py +69 -0
  405. rasa/graph_components/__init__.py +0 -0
  406. rasa/graph_components/converters/__init__.py +0 -0
  407. rasa/graph_components/converters/nlu_message_converter.py +48 -0
  408. rasa/graph_components/providers/__init__.py +0 -0
  409. rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
  410. rasa/graph_components/providers/domain_provider.py +71 -0
  411. rasa/graph_components/providers/flows_provider.py +74 -0
  412. rasa/graph_components/providers/forms_provider.py +44 -0
  413. rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
  414. rasa/graph_components/providers/responses_provider.py +44 -0
  415. rasa/graph_components/providers/rule_only_provider.py +49 -0
  416. rasa/graph_components/providers/story_graph_provider.py +43 -0
  417. rasa/graph_components/providers/training_tracker_provider.py +55 -0
  418. rasa/graph_components/validators/__init__.py +0 -0
  419. rasa/graph_components/validators/default_recipe_validator.py +552 -0
  420. rasa/graph_components/validators/finetuning_validator.py +302 -0
  421. rasa/hooks.py +113 -0
  422. rasa/jupyter.py +63 -0
  423. rasa/keys +1 -0
  424. rasa/markers/__init__.py +0 -0
  425. rasa/markers/marker.py +269 -0
  426. rasa/markers/marker_base.py +828 -0
  427. rasa/markers/upload.py +74 -0
  428. rasa/markers/validate.py +21 -0
  429. rasa/model.py +118 -0
  430. rasa/model_testing.py +457 -0
  431. rasa/model_training.py +535 -0
  432. rasa/nlu/__init__.py +7 -0
  433. rasa/nlu/classifiers/__init__.py +3 -0
  434. rasa/nlu/classifiers/classifier.py +5 -0
  435. rasa/nlu/classifiers/diet_classifier.py +1874 -0
  436. rasa/nlu/classifiers/fallback_classifier.py +192 -0
  437. rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
  438. rasa/nlu/classifiers/llm_intent_classifier.py +519 -0
  439. rasa/nlu/classifiers/logistic_regression_classifier.py +240 -0
  440. rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
  441. rasa/nlu/classifiers/regex_message_handler.py +56 -0
  442. rasa/nlu/classifiers/sklearn_intent_classifier.py +309 -0
  443. rasa/nlu/constants.py +77 -0
  444. rasa/nlu/convert.py +40 -0
  445. rasa/nlu/emulators/__init__.py +0 -0
  446. rasa/nlu/emulators/dialogflow.py +55 -0
  447. rasa/nlu/emulators/emulator.py +49 -0
  448. rasa/nlu/emulators/luis.py +86 -0
  449. rasa/nlu/emulators/no_emulator.py +10 -0
  450. rasa/nlu/emulators/wit.py +56 -0
  451. rasa/nlu/extractors/__init__.py +0 -0
  452. rasa/nlu/extractors/crf_entity_extractor.py +672 -0
  453. rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
  454. rasa/nlu/extractors/entity_synonyms.py +178 -0
  455. rasa/nlu/extractors/extractor.py +470 -0
  456. rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
  457. rasa/nlu/extractors/regex_entity_extractor.py +220 -0
  458. rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
  459. rasa/nlu/featurizers/__init__.py +0 -0
  460. rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
  461. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +449 -0
  462. rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
  463. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +772 -0
  464. rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
  465. rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
  466. rasa/nlu/featurizers/featurizer.py +89 -0
  467. rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
  468. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +840 -0
  469. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +539 -0
  470. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +269 -0
  471. rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
  472. rasa/nlu/model.py +24 -0
  473. rasa/nlu/persistor.py +240 -0
  474. rasa/nlu/run.py +27 -0
  475. rasa/nlu/selectors/__init__.py +0 -0
  476. rasa/nlu/selectors/response_selector.py +990 -0
  477. rasa/nlu/test.py +1943 -0
  478. rasa/nlu/tokenizers/__init__.py +0 -0
  479. rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
  480. rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
  481. rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
  482. rasa/nlu/tokenizers/tokenizer.py +239 -0
  483. rasa/nlu/tokenizers/whitespace_tokenizer.py +106 -0
  484. rasa/nlu/utils/__init__.py +35 -0
  485. rasa/nlu/utils/bilou_utils.py +462 -0
  486. rasa/nlu/utils/hugging_face/__init__.py +0 -0
  487. rasa/nlu/utils/hugging_face/registry.py +108 -0
  488. rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
  489. rasa/nlu/utils/mitie_utils.py +113 -0
  490. rasa/nlu/utils/pattern_utils.py +168 -0
  491. rasa/nlu/utils/spacy_utils.py +312 -0
  492. rasa/plugin.py +90 -0
  493. rasa/server.py +1536 -0
  494. rasa/shared/__init__.py +0 -0
  495. rasa/shared/constants.py +181 -0
  496. rasa/shared/core/__init__.py +0 -0
  497. rasa/shared/core/constants.py +168 -0
  498. rasa/shared/core/conversation.py +46 -0
  499. rasa/shared/core/domain.py +2106 -0
  500. rasa/shared/core/events.py +2507 -0
  501. rasa/shared/core/flows/__init__.py +7 -0
  502. rasa/shared/core/flows/flow.py +353 -0
  503. rasa/shared/core/flows/flow_step.py +146 -0
  504. rasa/shared/core/flows/flow_step_links.py +319 -0
  505. rasa/shared/core/flows/flow_step_sequence.py +70 -0
  506. rasa/shared/core/flows/flows_list.py +211 -0
  507. rasa/shared/core/flows/flows_yaml_schema.json +217 -0
  508. rasa/shared/core/flows/nlu_trigger.py +117 -0
  509. rasa/shared/core/flows/steps/__init__.py +24 -0
  510. rasa/shared/core/flows/steps/action.py +51 -0
  511. rasa/shared/core/flows/steps/call.py +64 -0
  512. rasa/shared/core/flows/steps/collect.py +112 -0
  513. rasa/shared/core/flows/steps/constants.py +5 -0
  514. rasa/shared/core/flows/steps/continuation.py +36 -0
  515. rasa/shared/core/flows/steps/end.py +22 -0
  516. rasa/shared/core/flows/steps/internal.py +44 -0
  517. rasa/shared/core/flows/steps/link.py +51 -0
  518. rasa/shared/core/flows/steps/no_operation.py +48 -0
  519. rasa/shared/core/flows/steps/set_slots.py +50 -0
  520. rasa/shared/core/flows/steps/start.py +30 -0
  521. rasa/shared/core/flows/validation.py +527 -0
  522. rasa/shared/core/flows/yaml_flows_io.py +278 -0
  523. rasa/shared/core/generator.py +907 -0
  524. rasa/shared/core/slot_mappings.py +235 -0
  525. rasa/shared/core/slots.py +647 -0
  526. rasa/shared/core/trackers.py +1159 -0
  527. rasa/shared/core/training_data/__init__.py +0 -0
  528. rasa/shared/core/training_data/loading.py +90 -0
  529. rasa/shared/core/training_data/story_reader/__init__.py +0 -0
  530. rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
  531. rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
  532. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
  533. rasa/shared/core/training_data/story_writer/__init__.py +0 -0
  534. rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
  535. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +442 -0
  536. rasa/shared/core/training_data/structures.py +838 -0
  537. rasa/shared/core/training_data/visualization.html +146 -0
  538. rasa/shared/core/training_data/visualization.py +603 -0
  539. rasa/shared/data.py +192 -0
  540. rasa/shared/engine/__init__.py +0 -0
  541. rasa/shared/engine/caching.py +26 -0
  542. rasa/shared/exceptions.py +129 -0
  543. rasa/shared/importers/__init__.py +0 -0
  544. rasa/shared/importers/importer.py +705 -0
  545. rasa/shared/importers/multi_project.py +203 -0
  546. rasa/shared/importers/rasa.py +100 -0
  547. rasa/shared/importers/utils.py +34 -0
  548. rasa/shared/nlu/__init__.py +0 -0
  549. rasa/shared/nlu/constants.py +45 -0
  550. rasa/shared/nlu/interpreter.py +10 -0
  551. rasa/shared/nlu/training_data/__init__.py +0 -0
  552. rasa/shared/nlu/training_data/entities_parser.py +209 -0
  553. rasa/shared/nlu/training_data/features.py +374 -0
  554. rasa/shared/nlu/training_data/formats/__init__.py +10 -0
  555. rasa/shared/nlu/training_data/formats/dialogflow.py +162 -0
  556. rasa/shared/nlu/training_data/formats/luis.py +87 -0
  557. rasa/shared/nlu/training_data/formats/rasa.py +135 -0
  558. rasa/shared/nlu/training_data/formats/rasa_yaml.py +605 -0
  559. rasa/shared/nlu/training_data/formats/readerwriter.py +245 -0
  560. rasa/shared/nlu/training_data/formats/wit.py +52 -0
  561. rasa/shared/nlu/training_data/loading.py +137 -0
  562. rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
  563. rasa/shared/nlu/training_data/message.py +477 -0
  564. rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
  565. rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
  566. rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
  567. rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
  568. rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
  569. rasa/shared/nlu/training_data/training_data.py +732 -0
  570. rasa/shared/nlu/training_data/util.py +223 -0
  571. rasa/shared/providers/__init__.py +0 -0
  572. rasa/shared/providers/openai/__init__.py +0 -0
  573. rasa/shared/providers/openai/clients.py +43 -0
  574. rasa/shared/providers/openai/session_handler.py +110 -0
  575. rasa/shared/utils/__init__.py +0 -0
  576. rasa/shared/utils/cli.py +72 -0
  577. rasa/shared/utils/common.py +308 -0
  578. rasa/shared/utils/constants.py +1 -0
  579. rasa/shared/utils/io.py +403 -0
  580. rasa/shared/utils/llm.py +405 -0
  581. rasa/shared/utils/pykwalify_extensions.py +26 -0
  582. rasa/shared/utils/schemas/__init__.py +0 -0
  583. rasa/shared/utils/schemas/config.yml +2 -0
  584. rasa/shared/utils/schemas/domain.yml +142 -0
  585. rasa/shared/utils/schemas/events.py +212 -0
  586. rasa/shared/utils/schemas/model_config.yml +46 -0
  587. rasa/shared/utils/schemas/stories.yml +173 -0
  588. rasa/shared/utils/yaml.py +777 -0
  589. rasa/studio/__init__.py +0 -0
  590. rasa/studio/auth.py +252 -0
  591. rasa/studio/config.py +127 -0
  592. rasa/studio/constants.py +16 -0
  593. rasa/studio/data_handler.py +352 -0
  594. rasa/studio/download.py +350 -0
  595. rasa/studio/train.py +136 -0
  596. rasa/studio/upload.py +408 -0
  597. rasa/telemetry.py +1583 -0
  598. rasa/tracing/__init__.py +0 -0
  599. rasa/tracing/config.py +338 -0
  600. rasa/tracing/constants.py +38 -0
  601. rasa/tracing/instrumentation/__init__.py +0 -0
  602. rasa/tracing/instrumentation/attribute_extractors.py +663 -0
  603. rasa/tracing/instrumentation/instrumentation.py +939 -0
  604. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +142 -0
  605. rasa/tracing/instrumentation/metrics.py +206 -0
  606. rasa/tracing/metric_instrument_provider.py +125 -0
  607. rasa/utils/__init__.py +0 -0
  608. rasa/utils/beta.py +83 -0
  609. rasa/utils/cli.py +27 -0
  610. rasa/utils/common.py +635 -0
  611. rasa/utils/converter.py +53 -0
  612. rasa/utils/endpoints.py +303 -0
  613. rasa/utils/io.py +326 -0
  614. rasa/utils/licensing.py +319 -0
  615. rasa/utils/log_utils.py +174 -0
  616. rasa/utils/mapper.py +210 -0
  617. rasa/utils/ml_utils.py +145 -0
  618. rasa/utils/plotting.py +362 -0
  619. rasa/utils/singleton.py +23 -0
  620. rasa/utils/tensorflow/__init__.py +0 -0
  621. rasa/utils/tensorflow/callback.py +112 -0
  622. rasa/utils/tensorflow/constants.py +116 -0
  623. rasa/utils/tensorflow/crf.py +492 -0
  624. rasa/utils/tensorflow/data_generator.py +440 -0
  625. rasa/utils/tensorflow/environment.py +161 -0
  626. rasa/utils/tensorflow/exceptions.py +5 -0
  627. rasa/utils/tensorflow/layers.py +1565 -0
  628. rasa/utils/tensorflow/layers_utils.py +113 -0
  629. rasa/utils/tensorflow/metrics.py +281 -0
  630. rasa/utils/tensorflow/model_data.py +991 -0
  631. rasa/utils/tensorflow/model_data_utils.py +500 -0
  632. rasa/utils/tensorflow/models.py +936 -0
  633. rasa/utils/tensorflow/rasa_layers.py +1094 -0
  634. rasa/utils/tensorflow/transformer.py +640 -0
  635. rasa/utils/tensorflow/types.py +6 -0
  636. rasa/utils/train_utils.py +572 -0
  637. rasa/utils/yaml.py +54 -0
  638. rasa/validator.py +1035 -0
  639. rasa/version.py +3 -0
  640. rasa_pro-3.8.16.dist-info/METADATA +528 -0
  641. rasa_pro-3.8.16.dist-info/NOTICE +5 -0
  642. rasa_pro-3.8.16.dist-info/RECORD +644 -0
  643. rasa_pro-3.8.16.dist-info/WHEEL +4 -0
  644. rasa_pro-3.8.16.dist-info/entry_points.txt +3 -0
rasa/nlu/classifiers/diet_classifier.py
@@ -0,0 +1,1874 @@
1
+ from __future__ import annotations
2
+ import copy
3
+ import logging
4
+ from collections import defaultdict
5
+ from pathlib import Path
6
+
7
+ from rasa.exceptions import ModelNotFound
8
+ from rasa.nlu.featurizers.featurizer import Featurizer
9
+
10
+ import numpy as np
11
+ import scipy.sparse
12
+ import tensorflow as tf
13
+
14
+ from typing import Any, Dict, List, Optional, Text, Tuple, Union, TypeVar, Type
15
+
16
+ from rasa.engine.graph import ExecutionContext, GraphComponent
17
+ from rasa.engine.recipes.default_recipe import DefaultV1Recipe
18
+ from rasa.engine.storage.resource import Resource
19
+ from rasa.engine.storage.storage import ModelStorage
20
+ from rasa.nlu.extractors.extractor import EntityExtractorMixin
21
+ from rasa.nlu.classifiers.classifier import IntentClassifier
22
+ import rasa.shared.utils.io
23
+ import rasa.utils.io as io_utils
24
+ import rasa.nlu.utils.bilou_utils as bilou_utils
25
+ from rasa.shared.constants import DIAGNOSTIC_DATA
26
+ from rasa.nlu.extractors.extractor import EntityTagSpec
27
+ from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
28
+ from rasa.utils import train_utils
29
+ from rasa.utils.tensorflow import rasa_layers
30
+ from rasa.utils.tensorflow.models import RasaModel, TransformerRasaModel
31
+ from rasa.utils.tensorflow.model_data import (
32
+ RasaModelData,
33
+ FeatureSignature,
34
+ FeatureArray,
35
+ )
36
+ from rasa.nlu.constants import TOKENS_NAMES, DEFAULT_TRANSFORMER_SIZE
37
+ from rasa.shared.nlu.constants import (
38
+ SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
39
+ TEXT,
40
+ INTENT,
41
+ INTENT_RESPONSE_KEY,
42
+ ENTITIES,
43
+ ENTITY_ATTRIBUTE_TYPE,
44
+ ENTITY_ATTRIBUTE_GROUP,
45
+ ENTITY_ATTRIBUTE_ROLE,
46
+ NO_ENTITY_TAG,
47
+ SPLIT_ENTITIES_BY_COMMA,
48
+ )
49
+ from rasa.shared.exceptions import InvalidConfigException
50
+ from rasa.shared.nlu.training_data.training_data import TrainingData
51
+ from rasa.shared.nlu.training_data.message import Message
52
+ from rasa.utils.tensorflow.constants import (
53
+ DROP_SMALL_LAST_BATCH,
54
+ LABEL,
55
+ IDS,
56
+ HIDDEN_LAYERS_SIZES,
57
+ RENORMALIZE_CONFIDENCES,
58
+ SHARE_HIDDEN_LAYERS,
59
+ TRANSFORMER_SIZE,
60
+ NUM_TRANSFORMER_LAYERS,
61
+ NUM_HEADS,
62
+ BATCH_SIZES,
63
+ BATCH_STRATEGY,
64
+ EPOCHS,
65
+ RANDOM_SEED,
66
+ LEARNING_RATE,
67
+ RANKING_LENGTH,
68
+ LOSS_TYPE,
69
+ SIMILARITY_TYPE,
70
+ NUM_NEG,
71
+ SPARSE_INPUT_DROPOUT,
72
+ DENSE_INPUT_DROPOUT,
73
+ MASKED_LM,
74
+ ENTITY_RECOGNITION,
75
+ TENSORBOARD_LOG_DIR,
76
+ INTENT_CLASSIFICATION,
77
+ EVAL_NUM_EXAMPLES,
78
+ EVAL_NUM_EPOCHS,
79
+ UNIDIRECTIONAL_ENCODER,
80
+ DROP_RATE,
81
+ DROP_RATE_ATTENTION,
82
+ CONNECTION_DENSITY,
83
+ NEGATIVE_MARGIN_SCALE,
84
+ REGULARIZATION_CONSTANT,
85
+ SCALE_LOSS,
86
+ USE_MAX_NEG_SIM,
87
+ MAX_NEG_SIM,
88
+ MAX_POS_SIM,
89
+ EMBEDDING_DIMENSION,
90
+ BILOU_FLAG,
91
+ KEY_RELATIVE_ATTENTION,
92
+ VALUE_RELATIVE_ATTENTION,
93
+ MAX_RELATIVE_POSITION,
94
+ AUTO,
95
+ BALANCED,
96
+ CROSS_ENTROPY,
97
+ TENSORBOARD_LOG_LEVEL,
98
+ CONCAT_DIMENSION,
99
+ FEATURIZERS,
100
+ CHECKPOINT_MODEL,
101
+ SEQUENCE,
102
+ SENTENCE,
103
+ SEQUENCE_LENGTH,
104
+ DENSE_DIMENSION,
105
+ MASK,
106
+ CONSTRAIN_SIMILARITIES,
107
+ MODEL_CONFIDENCE,
108
+ SOFTMAX,
109
+ RUN_EAGERLY,
110
+ )
111
+
112
+ logger = logging.getLogger(__name__)
113
+
114
+ SPARSE = "sparse"
115
+ DENSE = "dense"
116
+ LABEL_KEY = LABEL
117
+ LABEL_SUB_KEY = IDS
118
+
119
+ POSSIBLE_TAGS = [ENTITY_ATTRIBUTE_TYPE, ENTITY_ATTRIBUTE_ROLE, ENTITY_ATTRIBUTE_GROUP]
120
+
121
+
122
+ DIETClassifierT = TypeVar("DIETClassifierT", bound="DIETClassifier")
123
+
124
+
125
+ @DefaultV1Recipe.register(
126
+ [
127
+ DefaultV1Recipe.ComponentType.INTENT_CLASSIFIER,
128
+ DefaultV1Recipe.ComponentType.ENTITY_EXTRACTOR,
129
+ ],
130
+ is_trainable=True,
131
+ )
132
+ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
133
+ """A multi-task model for intent classification and entity extraction.
134
+
135
+ DIET is Dual Intent and Entity Transformer.
136
+ The architecture is based on a transformer which is shared for both tasks.
137
+ A sequence of entity labels is predicted through a Conditional Random Field (CRF)
138
+ tagging layer on top of the transformer output sequence corresponding to the
139
+ input sequence of tokens. The transformer output for the ``__CLS__`` token and
140
+ intent labels are embedded into a single semantic vector space. We use the
141
+ dot-product loss to maximize the similarity with the target label and minimize
142
+ similarities with negative samples.
143
+ """
144
+
145
+ @classmethod
146
+ def required_components(cls) -> List[Type]:
147
+ """Components that should be included in the pipeline before this component."""
148
+ return [Featurizer]
149
+
150
+ @staticmethod
151
+ def get_default_config() -> Dict[Text, Any]:
152
+ """The component's default config (see parent class for full docstring)."""
153
+ # please make sure to update the docs when changing a default parameter
154
+ return {
155
+ # ## Architecture of the used neural network
156
+ # Hidden layer sizes for layers before the embedding layers for user message
157
+ # and labels.
158
+ # The number of hidden layers is equal to the length of the corresponding
159
+ # list.
160
+ HIDDEN_LAYERS_SIZES: {TEXT: [], LABEL: []},
161
+ # Whether to share the hidden layer weights between user message and labels.
162
+ SHARE_HIDDEN_LAYERS: False,
163
+ # Number of units in transformer
164
+ TRANSFORMER_SIZE: DEFAULT_TRANSFORMER_SIZE,
165
+ # Number of transformer layers
166
+ NUM_TRANSFORMER_LAYERS: 2,
167
+ # Number of attention heads in transformer
168
+ NUM_HEADS: 4,
169
+ # If 'True' use key relative embeddings in attention
170
+ KEY_RELATIVE_ATTENTION: False,
171
+ # If 'True' use value relative embeddings in attention
172
+ VALUE_RELATIVE_ATTENTION: False,
173
+ # Max position for relative embeddings. Only in effect if key- or value
174
+ # relative attention are turned on
175
+ MAX_RELATIVE_POSITION: 5,
176
+ # Use a unidirectional or bidirectional encoder.
177
+ UNIDIRECTIONAL_ENCODER: False,
178
+ # ## Training parameters
179
+ # Initial and final batch sizes:
180
+ # Batch size will be linearly increased for each epoch.
181
+ BATCH_SIZES: [64, 256],
182
+ # Strategy used when creating batches.
183
+ # Can be either 'sequence' or 'balanced'.
184
+ BATCH_STRATEGY: BALANCED,
185
+ # Number of epochs to train
186
+ EPOCHS: 300,
187
+ # Set random seed to any 'int' to get reproducible results
188
+ RANDOM_SEED: None,
189
+ # Initial learning rate for the optimizer
190
+ LEARNING_RATE: 0.001,
191
+ # ## Parameters for embeddings
192
+ # Dimension size of embedding vectors
193
+ EMBEDDING_DIMENSION: 20,
194
+ # Dense dimension to use for sparse features.
195
+ DENSE_DIMENSION: {TEXT: 128, LABEL: 20},
196
+ # Default dimension to use for concatenating sequence and sentence features.
197
+ CONCAT_DIMENSION: {TEXT: 128, LABEL: 20},
198
+ # The number of incorrect labels. The algorithm will minimize
199
+ # their similarity to the user input during training.
200
+ NUM_NEG: 20,
201
+ # Type of similarity measure to use, either 'auto' or 'cosine' or 'inner'.
202
+ SIMILARITY_TYPE: AUTO,
203
+ # The type of the loss function, either 'cross_entropy' or 'margin'.
204
+ LOSS_TYPE: CROSS_ENTROPY,
205
+ # Number of top intents for which confidences should be reported.
206
+ # Set to 0 if confidences for all intents should be reported.
207
+ RANKING_LENGTH: LABEL_RANKING_LENGTH,
208
+ # Indicates how similar the algorithm should try to make embedding vectors
209
+ # for correct labels.
210
+ # Should be 0.0 < ... < 1.0 for 'cosine' similarity type.
211
+ MAX_POS_SIM: 0.8,
212
+ # Maximum negative similarity for incorrect labels.
213
+ # Should be -1.0 < ... < 1.0 for 'cosine' similarity type.
214
+ MAX_NEG_SIM: -0.4,
215
+ # If 'True' the algorithm only minimizes maximum similarity over
216
+ # incorrect intent labels, used only if 'loss_type' is set to 'margin'.
217
+ USE_MAX_NEG_SIM: True,
218
+ # If 'True' scale loss inverse proportionally to the confidence
219
+ # of the correct prediction
220
+ SCALE_LOSS: False,
221
+ # ## Regularization parameters
222
+ # The scale of regularization
223
+ REGULARIZATION_CONSTANT: 0.002,
224
+ # The scale of how important is to minimize the maximum similarity
225
+ # between embeddings of different labels,
226
+ # used only if 'loss_type' is set to 'margin'.
227
+ NEGATIVE_MARGIN_SCALE: 0.8,
228
+ # Dropout rate for encoder
229
+ DROP_RATE: 0.2,
230
+ # Dropout rate for attention
231
+ DROP_RATE_ATTENTION: 0,
232
+ # Fraction of trainable weights in internal layers.
233
+ CONNECTION_DENSITY: 0.2,
234
+ # If 'True' apply dropout to sparse input tensors
235
+ SPARSE_INPUT_DROPOUT: True,
236
+ # If 'True' apply dropout to dense input tensors
237
+ DENSE_INPUT_DROPOUT: True,
238
+ # ## Evaluation parameters
239
+ # How often calculate validation accuracy.
240
+ # Small values may hurt performance.
241
+ EVAL_NUM_EPOCHS: 20,
242
+ # How many examples to use for hold out validation set
243
+ # Large values may hurt performance, e.g. model accuracy.
244
+ # Set to 0 for no validation.
245
+ EVAL_NUM_EXAMPLES: 0,
246
+ # ## Model config
247
+ # If 'True' intent classification is trained and intent predicted.
248
+ INTENT_CLASSIFICATION: True,
249
+ # If 'True' named entity recognition is trained and entities predicted.
250
+ ENTITY_RECOGNITION: True,
251
+ # If 'True' random tokens of the input message will be masked and the model
252
+ # should predict those tokens.
253
+ MASKED_LM: False,
254
+ # 'BILOU_flag' determines whether to use BILOU tagging or not.
255
+ # If set to 'True' labelling is more rigorous, however more
256
+ # examples per entity are required.
257
+ # Rule of thumb: you should have more than 100 examples per entity.
258
+ BILOU_FLAG: True,
259
+ # If you want to use tensorboard to visualize training and validation
260
+ # metrics, set this option to a valid output directory.
261
+ TENSORBOARD_LOG_DIR: None,
262
+ # Define when training metrics for tensorboard should be logged.
263
+ # Either after every epoch or for every training step.
264
+ # Valid values: 'epoch' and 'batch'
265
+ TENSORBOARD_LOG_LEVEL: "epoch",
266
+ # Perform model checkpointing
267
+ CHECKPOINT_MODEL: False,
268
+ # Specify what features to use as sequence and sentence features
269
+ # By default all features in the pipeline are used.
270
+ FEATURIZERS: [],
271
+ # Split entities by comma, this makes sense e.g. for a list of ingredients
272
+ # in a recipie, but it doesn't make sense for the parts of an address
273
+ SPLIT_ENTITIES_BY_COMMA: True,
274
+ # If 'True' applies sigmoid on all similarity terms and adds
275
+ # it to the loss function to ensure that similarity values are
276
+ # approximately bounded. Used inside cross-entropy loss only.
277
+ CONSTRAIN_SIMILARITIES: False,
278
+ # Model confidence to be returned during inference. Currently, the only
279
+ # possible value is `softmax`.
280
+ MODEL_CONFIDENCE: SOFTMAX,
281
+ # Determines whether the confidences of the chosen top intents should be
282
+ # renormalized so that they sum up to 1. By default, we do not renormalize
283
+ # and return the confidences for the top intents as is.
284
+ # Note that renormalization only makes sense if confidences are generated
285
+ # via `softmax`.
286
+ RENORMALIZE_CONFIDENCES: False,
287
+ # Determines whether to construct the model graph or not.
288
+ # This is advantageous when the model is only trained or inferred for
289
+ # a few steps, as the compilation of the graph tends to take more time than
290
+ # running it. It is recommended to not adjust the optimization parameter.
291
+ RUN_EAGERLY: False,
292
+ # Determines whether the last batch should be dropped if it contains fewer
293
+ # than half a batch size of examples
294
+ DROP_SMALL_LAST_BATCH: False,
295
+ }
296
+
297
+ def __init__(
298
+ self,
299
+ config: Dict[Text, Any],
300
+ model_storage: ModelStorage,
301
+ resource: Resource,
302
+ execution_context: ExecutionContext,
303
+ index_label_id_mapping: Optional[Dict[int, Text]] = None,
304
+ entity_tag_specs: Optional[List[EntityTagSpec]] = None,
305
+ model: Optional[RasaModel] = None,
306
+ sparse_feature_sizes: Optional[Dict[Text, Dict[Text, List[int]]]] = None,
307
+ ) -> None:
308
+ """Declare instance variables with default values."""
309
+ if EPOCHS not in config:
310
+ rasa.shared.utils.io.raise_warning(
311
+ f"Please configure the number of '{EPOCHS}' in your configuration file."
312
+ f" We will change the default value of '{EPOCHS}' in the future to 1. "
313
+ )
314
+
315
+ self.component_config = config
316
+ self._model_storage = model_storage
317
+ self._resource = resource
318
+ self._execution_context = execution_context
319
+
320
+ self._check_config_parameters()
321
+
322
+ # transform numbers to labels
323
+ self.index_label_id_mapping = index_label_id_mapping or {}
324
+
325
+ self._entity_tag_specs = entity_tag_specs
326
+
327
+ self.model = model
328
+
329
+ self.tmp_checkpoint_dir = None
330
+ if self.component_config[CHECKPOINT_MODEL]:
331
+ self.tmp_checkpoint_dir = Path(rasa.utils.io.create_temporary_directory())
332
+
333
+ self._label_data: Optional[RasaModelData] = None
334
+ self._data_example: Optional[Dict[Text, Dict[Text, List[FeatureArray]]]] = None
335
+
336
+ self.split_entities_config = rasa.utils.train_utils.init_split_entities(
337
+ self.component_config[SPLIT_ENTITIES_BY_COMMA],
338
+ SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
339
+ )
340
+
341
+ self.finetune_mode = self._execution_context.is_finetuning
342
+ self._sparse_feature_sizes = sparse_feature_sizes
343
+
344
+ # init helpers
345
+ def _check_masked_lm(self) -> None:
346
+ if (
347
+ self.component_config[MASKED_LM]
348
+ and self.component_config[NUM_TRANSFORMER_LAYERS] == 0
349
+ ):
350
+ raise ValueError(
351
+ f"If number of transformer layers is 0, "
352
+ f"'{MASKED_LM}' option should be 'False'."
353
+ )
354
+
355
+ def _check_share_hidden_layers_sizes(self) -> None:
356
+ if self.component_config.get(SHARE_HIDDEN_LAYERS):
357
+ first_hidden_layer_sizes = next(
358
+ iter(self.component_config[HIDDEN_LAYERS_SIZES].values())
359
+ )
360
+ # check that all hidden layer sizes are the same
361
+ identical_hidden_layer_sizes = all(
362
+ current_hidden_layer_sizes == first_hidden_layer_sizes
363
+ for current_hidden_layer_sizes in self.component_config[
364
+ HIDDEN_LAYERS_SIZES
365
+ ].values()
366
+ )
367
+ if not identical_hidden_layer_sizes:
368
+ raise ValueError(
369
+ f"If hidden layer weights are shared, "
370
+ f"{HIDDEN_LAYERS_SIZES} must coincide."
371
+ )
372
+
373
+ def _check_config_parameters(self) -> None:
374
+ self.component_config = train_utils.check_deprecated_options(
375
+ self.component_config
376
+ )
377
+
378
+ self._check_masked_lm()
379
+ self._check_share_hidden_layers_sizes()
380
+
381
+ self.component_config = train_utils.update_confidence_type(
382
+ self.component_config
383
+ )
384
+
385
+ train_utils.validate_configuration_settings(self.component_config)
386
+
387
+ self.component_config = train_utils.update_similarity_type(
388
+ self.component_config
389
+ )
390
+ self.component_config = train_utils.update_evaluation_parameters(
391
+ self.component_config
392
+ )
393
+
394
+ @classmethod
395
+ def create(
396
+ cls,
397
+ config: Dict[Text, Any],
398
+ model_storage: ModelStorage,
399
+ resource: Resource,
400
+ execution_context: ExecutionContext,
401
+ ) -> DIETClassifier:
402
+ """Creates a new untrained component (see parent class for full docstring)."""
403
+ return cls(config, model_storage, resource, execution_context)
404
+
405
+ @property
406
+ def label_key(self) -> Optional[Text]:
407
+ """Return key if intent classification is activated."""
408
+ return LABEL_KEY if self.component_config[INTENT_CLASSIFICATION] else None
409
+
410
+ @property
411
+ def label_sub_key(self) -> Optional[Text]:
412
+ """Return sub key if intent classification is activated."""
413
+ return LABEL_SUB_KEY if self.component_config[INTENT_CLASSIFICATION] else None
414
+
415
+ @staticmethod
416
+ def model_class() -> Type[RasaModel]:
417
+ return DIET
418
+
419
+ # training data helpers:
420
+ @staticmethod
421
+ def _label_id_index_mapping(
422
+ training_data: TrainingData, attribute: Text
423
+ ) -> Dict[Text, int]:
424
+ """Create label_id dictionary."""
425
+ distinct_label_ids = {
426
+ example.get(attribute) for example in training_data.intent_examples
427
+ } - {None}
428
+ return {
429
+ label_id: idx for idx, label_id in enumerate(sorted(distinct_label_ids))
430
+ }
431
+
432
+ @staticmethod
433
+ def _invert_mapping(mapping: Dict) -> Dict:
434
+ return {value: key for key, value in mapping.items()}
435
+
436
+ def _create_entity_tag_specs(
437
+ self, training_data: TrainingData
438
+ ) -> List[EntityTagSpec]:
439
+ """Create entity tag specifications with their respective tag id mappings."""
440
+ _tag_specs = []
441
+
442
+ for tag_name in POSSIBLE_TAGS:
443
+ if self.component_config[BILOU_FLAG]:
444
+ tag_id_index_mapping = bilou_utils.build_tag_id_dict(
445
+ training_data, tag_name
446
+ )
447
+ else:
448
+ tag_id_index_mapping = self._tag_id_index_mapping_for(
449
+ tag_name, training_data
450
+ )
451
+
452
+ if tag_id_index_mapping:
453
+ _tag_specs.append(
454
+ EntityTagSpec(
455
+ tag_name=tag_name,
456
+ tags_to_ids=tag_id_index_mapping,
457
+ ids_to_tags=self._invert_mapping(tag_id_index_mapping),
458
+ num_tags=len(tag_id_index_mapping),
459
+ )
460
+ )
461
+
462
+ return _tag_specs
463
+
464
+ @staticmethod
465
+ def _tag_id_index_mapping_for(
466
+ tag_name: Text, training_data: TrainingData
467
+ ) -> Optional[Dict[Text, int]]:
468
+ """Create mapping from tag name to id."""
469
+ if tag_name == ENTITY_ATTRIBUTE_ROLE:
470
+ distinct_tags = training_data.entity_roles
471
+ elif tag_name == ENTITY_ATTRIBUTE_GROUP:
472
+ distinct_tags = training_data.entity_groups
473
+ else:
474
+ distinct_tags = training_data.entities
475
+
476
+ distinct_tags = distinct_tags - {NO_ENTITY_TAG} - {None}
477
+
478
+ if not distinct_tags:
479
+ return None
480
+
481
+ tag_id_dict = {
482
+ tag_id: idx for idx, tag_id in enumerate(sorted(distinct_tags), 1)
483
+ }
484
+ # NO_ENTITY_TAG corresponds to non-entity which should correspond to 0 index
485
+ # needed for correct prediction for padding
486
+ tag_id_dict[NO_ENTITY_TAG] = 0
487
+
488
+ return tag_id_dict
489
+
490
+ @staticmethod
491
+ def _find_example_for_label(
492
+ label: Text, examples: List[Message], attribute: Text
493
+ ) -> Optional[Message]:
494
+ for ex in examples:
495
+ if ex.get(attribute) == label:
496
+ return ex
497
+ return None
498
+
499
+ def _check_labels_features_exist(
500
+ self, labels_example: List[Message], attribute: Text
501
+ ) -> bool:
502
+ """Checks if all labels have features set."""
503
+ return all(
504
+ label_example.features_present(
505
+ attribute, self.component_config[FEATURIZERS]
506
+ )
507
+ for label_example in labels_example
508
+ )
509
+
510
+ def _extract_features(
511
+ self, message: Message, attribute: Text
512
+ ) -> Dict[Text, Union[scipy.sparse.spmatrix, np.ndarray]]:
513
+
514
+ (
515
+ sparse_sequence_features,
516
+ sparse_sentence_features,
517
+ ) = message.get_sparse_features(attribute, self.component_config[FEATURIZERS])
518
+ dense_sequence_features, dense_sentence_features = message.get_dense_features(
519
+ attribute, self.component_config[FEATURIZERS]
520
+ )
521
+
522
+ if dense_sequence_features is not None and sparse_sequence_features is not None:
523
+ if (
524
+ dense_sequence_features.features.shape[0]
525
+ != sparse_sequence_features.features.shape[0]
526
+ ):
527
+ raise ValueError(
528
+ f"Sequence dimensions for sparse and dense sequence features "
529
+ f"don't coincide in '{message.get(TEXT)}'"
530
+ f"for attribute '{attribute}'."
531
+ )
532
+ if dense_sentence_features is not None and sparse_sentence_features is not None:
533
+ if (
534
+ dense_sentence_features.features.shape[0]
535
+ != sparse_sentence_features.features.shape[0]
536
+ ):
537
+ raise ValueError(
538
+ f"Sequence dimensions for sparse and dense sentence features "
539
+ f"don't coincide in '{message.get(TEXT)}'"
540
+ f"for attribute '{attribute}'."
541
+ )
542
+
543
+ # If we don't use the transformer and we don't do entity recognition, take
544
+ # only the sentence features as the feature vector to speed up training.
545
+ # The sequence features would not be used in this setup anyway, and carrying
546
+ # them over to the actual training process takes quite some time.
547
+ if (
548
+ self.component_config[NUM_TRANSFORMER_LAYERS] == 0
549
+ and not self.component_config[ENTITY_RECOGNITION]
550
+ and attribute not in [INTENT, INTENT_RESPONSE_KEY]
551
+ ):
552
+ sparse_sequence_features = None
553
+ dense_sequence_features = None
554
+
555
+ out = {}
556
+
557
+ if sparse_sentence_features is not None:
558
+ out[f"{SPARSE}_{SENTENCE}"] = sparse_sentence_features.features
559
+ if sparse_sequence_features is not None:
560
+ out[f"{SPARSE}_{SEQUENCE}"] = sparse_sequence_features.features
561
+ if dense_sentence_features is not None:
562
+ out[f"{DENSE}_{SENTENCE}"] = dense_sentence_features.features
563
+ if dense_sequence_features is not None:
564
+ out[f"{DENSE}_{SEQUENCE}"] = dense_sequence_features.features
565
+
566
+ return out
567
+
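_extract_features returns a flat dictionary keyed by feature type and level. An illustrative sketch, assuming the SPARSE/DENSE/SENTENCE/SEQUENCE constants resolve to the plain strings "sparse", "dense", "sentence" and "sequence":

    import numpy as np
    import scipy.sparse

    # Sketch only: one sparse sentence feature and one dense sequence feature.
    sparse_sentence = scipy.sparse.csr_matrix(np.array([[0.0, 1.0, 0.0]]))
    dense_sequence = np.random.rand(5, 16)   # (sequence_length, feature_dim)

    out = {
        "sparse_sentence": sparse_sentence,
        "dense_sequence": dense_sequence,
    }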
568
+ def _check_input_dimension_consistency(self, model_data: RasaModelData) -> None:
569
+ """Checks if features have same dimensionality if hidden layers are shared."""
570
+ if self.component_config.get(SHARE_HIDDEN_LAYERS):
571
+ num_text_sentence_features = model_data.number_of_units(TEXT, SENTENCE)
572
+ num_label_sentence_features = model_data.number_of_units(LABEL, SENTENCE)
573
+ num_text_sequence_features = model_data.number_of_units(TEXT, SEQUENCE)
574
+ num_label_sequence_features = model_data.number_of_units(LABEL, SEQUENCE)
575
+
576
+ if (0 < num_text_sentence_features != num_label_sentence_features > 0) or (
577
+ 0 < num_text_sequence_features != num_label_sequence_features > 0
578
+ ):
579
+ raise ValueError(
580
+ "If embeddings are shared text features and label features "
581
+ "must coincide. Check the output dimensions of previous components."
582
+ )
583
+
584
+ def _extract_labels_precomputed_features(
585
+ self, label_examples: List[Message], attribute: Text = INTENT
586
+ ) -> Tuple[List[FeatureArray], List[FeatureArray]]:
587
+ """Collects precomputed encodings."""
588
+ features = defaultdict(list)
589
+
590
+ for e in label_examples:
591
+ label_features = self._extract_features(e, attribute)
592
+ for feature_key, feature_value in label_features.items():
593
+ features[feature_key].append(feature_value)
594
+ sequence_features = []
595
+ sentence_features = []
596
+ for feature_name, feature_value in features.items():
597
+ if SEQUENCE in feature_name:
598
+ sequence_features.append(
599
+ FeatureArray(np.array(feature_value), number_of_dimensions=3)
600
+ )
601
+ else:
602
+ sentence_features.append(
603
+ FeatureArray(np.array(feature_value), number_of_dimensions=3)
604
+ )
605
+ return sequence_features, sentence_features
606
+
607
+ @staticmethod
608
+ def _compute_default_label_features(
609
+ labels_example: List[Message],
610
+ ) -> List[FeatureArray]:
611
+ """Computes one-hot representation for the labels."""
612
+ logger.debug("No label features found. Computing default label features.")
613
+
614
+ eye_matrix = np.eye(len(labels_example), dtype=np.float32)
615
+ # add sequence dimension to one-hot labels
616
+ return [
617
+ FeatureArray(
618
+ np.array([np.expand_dims(a, 0) for a in eye_matrix]),
619
+ number_of_dimensions=3,
620
+ )
621
+ ]
622
+
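_compute_default_label_features falls back to one-hot label vectors and adds a sequence dimension so the shapes line up with sequence-style features. A sketch of the resulting shape (omitting the FeatureArray wrapper used by the component):

    import numpy as np

    num_labels = 3
    eye_matrix = np.eye(num_labels, dtype=np.float32)
    one_hot_with_seq_dim = np.array([np.expand_dims(row, 0) for row in eye_matrix])
    print(one_hot_with_seq_dim.shape)   # (3, 1, 3): one sequence step per label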
623
+ def _create_label_data(
624
+ self,
625
+ training_data: TrainingData,
626
+ label_id_dict: Dict[Text, int],
627
+ attribute: Text,
628
+ ) -> RasaModelData:
629
+ """Create matrix with label_ids encoded in rows as bag of words.
630
+
631
+ Find a training example for each label and get the encoded features
632
+ from the corresponding Message object.
633
+ If the features are already computed, fetch them from the message object;
634
+ otherwise compute a one-hot encoding for the label as the feature vector.
635
+ """
636
+ # Collect one example for each label
637
+ labels_idx_examples = []
638
+ for label_name, idx in label_id_dict.items():
639
+ label_example = self._find_example_for_label(
640
+ label_name, training_data.intent_examples, attribute
641
+ )
642
+ labels_idx_examples.append((idx, label_example))
643
+
644
+ # Sort the list of tuples based on label_idx
645
+ labels_idx_examples = sorted(labels_idx_examples, key=lambda x: x[0])
646
+ labels_example = [example for (_, example) in labels_idx_examples]
647
+ # Collect features, precomputed if they exist, else compute on the fly
648
+ if self._check_labels_features_exist(labels_example, attribute):
649
+ (
650
+ sequence_features,
651
+ sentence_features,
652
+ ) = self._extract_labels_precomputed_features(labels_example, attribute)
653
+ else:
654
+ sequence_features = None
655
+ sentence_features = self._compute_default_label_features(labels_example)
656
+
657
+ label_data = RasaModelData()
658
+ label_data.add_features(LABEL, SEQUENCE, sequence_features)
659
+ label_data.add_features(LABEL, SENTENCE, sentence_features)
660
+ if label_data.does_feature_not_exist(
661
+ LABEL, SENTENCE
662
+ ) and label_data.does_feature_not_exist(LABEL, SEQUENCE):
663
+ raise ValueError(
664
+ "No label features are present. Please check your configuration file."
665
+ )
666
+
667
+ label_ids = np.array([idx for (idx, _) in labels_idx_examples])
668
+ # explicitly add last dimension to label_ids
669
+ # to track correctly dynamic sequences
670
+ label_data.add_features(
671
+ LABEL_KEY,
672
+ LABEL_SUB_KEY,
673
+ [
674
+ FeatureArray(
675
+ np.expand_dims(label_ids, -1),
676
+ number_of_dimensions=2,
677
+ )
678
+ ],
679
+ )
680
+
681
+ label_data.add_lengths(LABEL, SEQUENCE_LENGTH, LABEL, SEQUENCE)
682
+
683
+ return label_data
684
+
685
+ def _use_default_label_features(self, label_ids: np.ndarray) -> List[FeatureArray]:
686
+ if self._label_data is None:
687
+ return []
688
+
689
+ feature_arrays = self._label_data.get(LABEL, SENTENCE)
690
+ all_label_features = feature_arrays[0]
691
+ return [
692
+ FeatureArray(
693
+ np.array([all_label_features[label_id] for label_id in label_ids]),
694
+ number_of_dimensions=all_label_features.number_of_dimensions,
695
+ )
696
+ ]
697
+
698
+ def _create_model_data(
699
+ self,
700
+ training_data: List[Message],
701
+ label_id_dict: Optional[Dict[Text, int]] = None,
702
+ label_attribute: Optional[Text] = None,
703
+ training: bool = True,
704
+ ) -> RasaModelData:
705
+ """Prepare data for training and create a RasaModelData object."""
706
+ from rasa.utils.tensorflow import model_data_utils
707
+
708
+ attributes_to_consider = [TEXT]
709
+ if training and self.component_config[INTENT_CLASSIFICATION]:
710
+ # we don't have any intent labels during prediction, just add them during
711
+ # training
712
+ attributes_to_consider.append(label_attribute)
713
+ if (
714
+ training
715
+ and self.component_config[ENTITY_RECOGNITION]
716
+ and self._entity_tag_specs
717
+ ):
718
+ # Add entities as labels only during training and only if there was
719
+ # training data added for entities with DIET configured to predict entities.
720
+ attributes_to_consider.append(ENTITIES)
721
+
722
+ if training and label_attribute is not None:
723
+ # only use those training examples that have the label_attribute set
724
+ # during training
725
+ training_data = [
726
+ example for example in training_data if label_attribute in example.data
727
+ ]
728
+
729
+ training_data = [
730
+ message
731
+ for message in training_data
732
+ if message.features_present(
733
+ attribute=TEXT, featurizers=self.component_config.get(FEATURIZERS)
734
+ )
735
+ ]
736
+
737
+ if not training_data:
738
+ # no training examples are present to train on
739
+ return RasaModelData()
740
+
741
+ (
742
+ features_for_examples,
743
+ sparse_feature_sizes,
744
+ ) = model_data_utils.featurize_training_examples(
745
+ training_data,
746
+ attributes_to_consider,
747
+ entity_tag_specs=self._entity_tag_specs,
748
+ featurizers=self.component_config[FEATURIZERS],
749
+ bilou_tagging=self.component_config[BILOU_FLAG],
750
+ )
751
+ attribute_data, _ = model_data_utils.convert_to_data_format(
752
+ features_for_examples, consider_dialogue_dimension=False
753
+ )
754
+
755
+ model_data = RasaModelData(
756
+ label_key=self.label_key, label_sub_key=self.label_sub_key
757
+ )
758
+ model_data.add_data(attribute_data)
759
+ model_data.add_lengths(TEXT, SEQUENCE_LENGTH, TEXT, SEQUENCE)
760
+ # Current implementation doesn't yet account for updating sparse
761
+ # feature sizes of label attributes. That's why we remove them.
762
+ sparse_feature_sizes = self._remove_label_sparse_feature_sizes(
763
+ sparse_feature_sizes=sparse_feature_sizes, label_attribute=label_attribute
764
+ )
765
+ model_data.add_sparse_feature_sizes(sparse_feature_sizes)
766
+
767
+ self._add_label_features(
768
+ model_data, training_data, label_attribute, label_id_dict, training
769
+ )
770
+
771
+ # make sure all keys are in the same order during training and prediction
772
+ # as we rely on the order of key and sub-key when constructing the actual
773
+ # tensors from the model data
774
+ model_data.sort()
775
+
776
+ return model_data
777
+
778
+ @staticmethod
779
+ def _remove_label_sparse_feature_sizes(
780
+ sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
781
+ label_attribute: Optional[Text] = None,
782
+ ) -> Dict[Text, Dict[Text, List[int]]]:
783
+
784
+ if label_attribute in sparse_feature_sizes:
785
+ del sparse_feature_sizes[label_attribute]
786
+ return sparse_feature_sizes
787
+
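_remove_label_sparse_feature_sizes simply drops the label attribute's entry so that only text-side sparse feature sizes are carried forward. A sketch with illustrative values:

    sparse_feature_sizes = {
        "text": {"sequence": [100], "sentence": [100]},
        "intent": {"sentence": [30]},
    }
    label_attribute = "intent"
    if label_attribute in sparse_feature_sizes:
        del sparse_feature_sizes[label_attribute]
    # only the "text" sizes remain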
788
+ def _add_label_features(
789
+ self,
790
+ model_data: RasaModelData,
791
+ training_data: List[Message],
792
+ label_attribute: Text,
793
+ label_id_dict: Dict[Text, int],
794
+ training: bool = True,
795
+ ) -> None:
796
+ label_ids = []
797
+ if training and self.component_config[INTENT_CLASSIFICATION]:
798
+ for example in training_data:
799
+ if example.get(label_attribute):
800
+ label_ids.append(label_id_dict[example.get(label_attribute)])
801
+ # explicitly add last dimension to label_ids
802
+ # to track correctly dynamic sequences
803
+ model_data.add_features(
804
+ LABEL_KEY,
805
+ LABEL_SUB_KEY,
806
+ [
807
+ FeatureArray(
808
+ np.expand_dims(label_ids, -1),
809
+ number_of_dimensions=2,
810
+ )
811
+ ],
812
+ )
813
+
814
+ if (
815
+ label_attribute
816
+ and model_data.does_feature_not_exist(label_attribute, SENTENCE)
817
+ and model_data.does_feature_not_exist(label_attribute, SEQUENCE)
818
+ ):
819
+ # no label features are present, get default features from _label_data
820
+ model_data.add_features(
821
+ LABEL, SENTENCE, self._use_default_label_features(np.array(label_ids))
822
+ )
823
+
824
+ # as label_attribute can have different values, e.g. INTENT or RESPONSE,
825
+ # copy over the features to the LABEL key to make
826
+ # it easier to access the label features inside the model itself
827
+ model_data.update_key(label_attribute, SENTENCE, LABEL, SENTENCE)
828
+ model_data.update_key(label_attribute, SEQUENCE, LABEL, SEQUENCE)
829
+ model_data.update_key(label_attribute, MASK, LABEL, MASK)
830
+
831
+ model_data.add_lengths(LABEL, SEQUENCE_LENGTH, LABEL, SEQUENCE)
832
+
833
+ # train helpers
834
+ def preprocess_train_data(self, training_data: TrainingData) -> RasaModelData:
835
+ """Prepares data for training.
836
+
837
+ Performs sanity checks on training data, extracts encodings for labels.
838
+ """
839
+ if (
840
+ self.component_config[BILOU_FLAG]
841
+ and self.component_config[ENTITY_RECOGNITION]
842
+ ):
843
+ bilou_utils.apply_bilou_schema(training_data)
844
+
845
+ label_id_index_mapping = self._label_id_index_mapping(
846
+ training_data, attribute=INTENT
847
+ )
848
+
849
+ if not label_id_index_mapping:
850
+ # no labels are present to train
851
+ return RasaModelData()
852
+
853
+ self.index_label_id_mapping = self._invert_mapping(label_id_index_mapping)
854
+
855
+ self._label_data = self._create_label_data(
856
+ training_data, label_id_index_mapping, attribute=INTENT
857
+ )
858
+
859
+ self._entity_tag_specs = self._create_entity_tag_specs(training_data)
860
+
861
+ label_attribute = (
862
+ INTENT if self.component_config[INTENT_CLASSIFICATION] else None
863
+ )
864
+ model_data = self._create_model_data(
865
+ training_data.nlu_examples,
866
+ label_id_index_mapping,
867
+ label_attribute=label_attribute,
868
+ )
869
+
870
+ self._check_input_dimension_consistency(model_data)
871
+
872
+ return model_data
873
+
874
+ @staticmethod
875
+ def _check_enough_labels(model_data: RasaModelData) -> bool:
876
+ return len(np.unique(model_data.get(LABEL_KEY, LABEL_SUB_KEY))) >= 2
877
+
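_check_enough_labels requires at least two distinct label ids before intent training proceeds. A minimal sketch with made-up label ids:

    import numpy as np

    label_ids = np.array([[0], [1], [0]])            # ids for three training examples
    enough_labels = len(np.unique(label_ids)) >= 2   # True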
878
+ def train(self, training_data: TrainingData) -> Resource:
879
+ """Train the embedding intent classifier on a data set."""
880
+ model_data = self.preprocess_train_data(training_data)
881
+ if model_data.is_empty():
882
+ logger.debug(
883
+ f"Cannot train '{self.__class__.__name__}'. No data was provided. "
884
+ f"Skipping training of the classifier."
885
+ )
886
+ return self._resource
887
+
888
+ if not self.model and self.finetune_mode:
889
+ raise rasa.shared.exceptions.InvalidParameterException(
890
+ f"{self.__class__.__name__} was instantiated "
891
+ f"with `model=None` and `finetune_mode=True`. "
892
+ f"This is not a valid combination as the component "
893
+ f"needs an already instantiated and trained model "
894
+ f"to continue training in finetune mode."
895
+ )
896
+
897
+ if self.component_config.get(INTENT_CLASSIFICATION):
898
+ if not self._check_enough_labels(model_data):
899
+ logger.error(
900
+ f"Cannot train '{self.__class__.__name__}'. "
901
+ f"Need at least 2 different intent classes. "
902
+ f"Skipping training of classifier."
903
+ )
904
+ return self._resource
905
+ if self.component_config.get(ENTITY_RECOGNITION):
906
+ self.check_correct_entity_annotations(training_data)
907
+
908
+ # keep one example for persisting and loading
909
+ self._data_example = model_data.first_data_example()
910
+
911
+ if not self.finetune_mode:
912
+ # No pre-trained model to load from. Create a new instance of the model.
913
+ self.model = self._instantiate_model_class(model_data)
914
+ self.model.compile(
915
+ optimizer=tf.keras.optimizers.Adam(
916
+ self.component_config[LEARNING_RATE]
917
+ ),
918
+ run_eagerly=self.component_config[RUN_EAGERLY],
919
+ )
920
+ else:
921
+ if self.model is None:
922
+ raise ModelNotFound("Model could not be found. ")
923
+
924
+ self.model.adjust_for_incremental_training(
925
+ data_example=self._data_example,
926
+ new_sparse_feature_sizes=model_data.get_sparse_feature_sizes(),
927
+ old_sparse_feature_sizes=self._sparse_feature_sizes,
928
+ )
929
+ self._sparse_feature_sizes = model_data.get_sparse_feature_sizes()
930
+
931
+ data_generator, validation_data_generator = train_utils.create_data_generators(
932
+ model_data,
933
+ self.component_config[BATCH_SIZES],
934
+ self.component_config[EPOCHS],
935
+ self.component_config[BATCH_STRATEGY],
936
+ self.component_config[EVAL_NUM_EXAMPLES],
937
+ self.component_config[RANDOM_SEED],
938
+ drop_small_last_batch=self.component_config[DROP_SMALL_LAST_BATCH],
939
+ )
940
+ callbacks = train_utils.create_common_callbacks(
941
+ self.component_config[EPOCHS],
942
+ self.component_config[TENSORBOARD_LOG_DIR],
943
+ self.component_config[TENSORBOARD_LOG_LEVEL],
944
+ self.tmp_checkpoint_dir,
945
+ )
946
+
947
+ self.model.fit(
948
+ data_generator,
949
+ epochs=self.component_config[EPOCHS],
950
+ validation_data=validation_data_generator,
951
+ validation_freq=self.component_config[EVAL_NUM_EPOCHS],
952
+ callbacks=callbacks,
953
+ verbose=False,
954
+ shuffle=False, # we use custom shuffle inside data generator
955
+ )
956
+
957
+ self.persist()
958
+
959
+ return self._resource
960
+
961
+ # process helpers
962
+ def _predict(
963
+ self, message: Message
964
+ ) -> Optional[Dict[Text, Union[tf.Tensor, Dict[Text, tf.Tensor]]]]:
965
+ if self.model is None:
966
+ logger.debug(
967
+ f"There is no trained model for '{self.__class__.__name__}': The "
968
+ f"component is either not trained or didn't receive enough training "
969
+ f"data."
970
+ )
971
+ return None
972
+
973
+ # create session data from message and convert it into a batch of 1
974
+ model_data = self._create_model_data([message], training=False)
975
+ if model_data.is_empty():
976
+ return None
977
+ return self.model.run_inference(model_data)
978
+
979
+ def _predict_label(
980
+ self, predict_out: Optional[Dict[Text, tf.Tensor]]
981
+ ) -> Tuple[Dict[Text, Any], List[Dict[Text, Any]]]:
982
+ """Predicts the intent of the provided message."""
983
+ label: Dict[Text, Any] = {"name": None, "confidence": 0.0}
984
+ label_ranking: List[Dict[Text, Any]] = []
985
+
986
+ if predict_out is None:
987
+ return label, label_ranking
988
+
989
+ message_sim = predict_out["i_scores"]
990
+ message_sim = message_sim.flatten() # sim is a matrix
991
+
992
+ # if X contains all zeros do not predict any label
993
+ if message_sim.size == 0:
994
+ return label, label_ranking
995
+
996
+ # rank the confidences
997
+ ranking_length = self.component_config[RANKING_LENGTH]
998
+ renormalize = (
999
+ self.component_config[RENORMALIZE_CONFIDENCES]
1000
+ and self.component_config[MODEL_CONFIDENCE] == SOFTMAX
1001
+ )
1002
+ ranked_label_indices, message_sim = train_utils.rank_and_mask(
1003
+ message_sim, ranking_length=ranking_length, renormalize=renormalize
1004
+ )
1005
+
1006
+ # construct the label and ranking
1007
+ casted_message_sim: List[float] = message_sim.tolist() # np.float to float
1008
+ top_label_idx = ranked_label_indices[0]
1009
+ label = {
1010
+ "name": self.index_label_id_mapping[top_label_idx],
1011
+ "confidence": casted_message_sim[top_label_idx],
1012
+ }
1013
+
1014
+ ranking = [(idx, casted_message_sim[idx]) for idx in ranked_label_indices]
1015
+ label_ranking = [
1016
+ {"name": self.index_label_id_mapping[label_idx], "confidence": score}
1017
+ for label_idx, score in ranking
1018
+ ]
1019
+
1020
+ return label, label_ranking
1021
+
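_predict_label pairs the ranked indices produced by train_utils.rank_and_mask with their confidences and maps them back to intent names. A rough sketch with made-up scores, using a plain argsort instead of the rank-and-mask helper:

    import numpy as np

    index_label_id_mapping = {0: "greet", 1: "goodbye", 2: "affirm"}
    message_sim = np.array([0.1, 0.7, 0.2])
    ranked_label_indices = np.argsort(message_sim)[::-1]   # [1, 2, 0]
    scores = message_sim.tolist()

    label = {
        "name": index_label_id_mapping[ranked_label_indices[0]],
        "confidence": scores[ranked_label_indices[0]],
    }
    label_ranking = [
        {"name": index_label_id_mapping[i], "confidence": scores[i]}
        for i in ranked_label_indices
    ]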
1022
+ def _predict_entities(
1023
+ self, predict_out: Optional[Dict[Text, tf.Tensor]], message: Message
1024
+ ) -> List[Dict]:
1025
+ if predict_out is None:
1026
+ return []
1027
+
1028
+ predicted_tags, confidence_values = train_utils.entity_label_to_tags(
1029
+ predict_out, self._entity_tag_specs, self.component_config[BILOU_FLAG]
1030
+ )
1031
+
1032
+ entities = self.convert_predictions_into_entities(
1033
+ message.get(TEXT),
1034
+ message.get(TOKENS_NAMES[TEXT], []),
1035
+ predicted_tags,
1036
+ self.split_entities_config,
1037
+ confidence_values,
1038
+ )
1039
+
1040
+ entities = self.add_extractor_name(entities)
1041
+ entities = message.get(ENTITIES, []) + entities
1042
+
1043
+ return entities
1044
+
1045
+ def process(self, messages: List[Message]) -> List[Message]:
1046
+ """Augments the message with intents, entities, and diagnostic data."""
1047
+ for message in messages:
1048
+ out = self._predict(message)
1049
+
1050
+ if self.component_config[INTENT_CLASSIFICATION]:
1051
+ label, label_ranking = self._predict_label(out)
1052
+
1053
+ message.set(INTENT, label, add_to_output=True)
1054
+ message.set("intent_ranking", label_ranking, add_to_output=True)
1055
+
1056
+ if self.component_config[ENTITY_RECOGNITION]:
1057
+ entities = self._predict_entities(out, message)
1058
+
1059
+ message.set(ENTITIES, entities, add_to_output=True)
1060
+
1061
+ if out and self._execution_context.should_add_diagnostic_data:
1062
+ message.add_diagnostic_data(
1063
+ self._execution_context.node_name, out.get(DIAGNOSTIC_DATA)
1064
+ )
1065
+
1066
+ return messages
1067
+
1068
+ def persist(self) -> None:
1069
+ """Persist this model into the passed directory."""
1070
+ if self.model is None:
1071
+ return None
1072
+
1073
+ with self._model_storage.write_to(self._resource) as model_path:
1074
+ file_name = self.__class__.__name__
1075
+ tf_model_file = model_path / f"{file_name}.tf_model"
1076
+
1077
+ rasa.shared.utils.io.create_directory_for_file(tf_model_file)
1078
+
1079
+ if self.component_config[CHECKPOINT_MODEL] and self.tmp_checkpoint_dir:
1080
+ self.model.load_weights(self.tmp_checkpoint_dir / "checkpoint.tf_model")
1081
+ # Save an empty file to flag that this model has been
1082
+ # produced using checkpointing
1083
+ checkpoint_marker = model_path / f"{file_name}.from_checkpoint.pkl"
1084
+ checkpoint_marker.touch()
1085
+
1086
+ self.model.save(str(tf_model_file))
1087
+
1088
+ io_utils.pickle_dump(
1089
+ model_path / f"{file_name}.data_example.pkl", self._data_example
1090
+ )
1091
+ io_utils.pickle_dump(
1092
+ model_path / f"{file_name}.sparse_feature_sizes.pkl",
1093
+ self._sparse_feature_sizes,
1094
+ )
1095
+ io_utils.pickle_dump(
1096
+ model_path / f"{file_name}.label_data.pkl",
1097
+ dict(self._label_data.data) if self._label_data is not None else {},
1098
+ )
1099
+ io_utils.json_pickle(
1100
+ model_path / f"{file_name}.index_label_id_mapping.json",
1101
+ self.index_label_id_mapping,
1102
+ )
1103
+
1104
+ entity_tag_specs = (
1105
+ [tag_spec._asdict() for tag_spec in self._entity_tag_specs]
1106
+ if self._entity_tag_specs
1107
+ else []
1108
+ )
1109
+ rasa.shared.utils.io.dump_obj_as_json_to_file(
1110
+ model_path / f"{file_name}.entity_tag_specs.json", entity_tag_specs
1111
+ )
1112
+
1113
+ @classmethod
1114
+ def load(
1115
+ cls: Type[DIETClassifierT],
1116
+ config: Dict[Text, Any],
1117
+ model_storage: ModelStorage,
1118
+ resource: Resource,
1119
+ execution_context: ExecutionContext,
1120
+ **kwargs: Any,
1121
+ ) -> DIETClassifierT:
1122
+ """Loads a policy from the storage (see parent class for full docstring)."""
1123
+ try:
1124
+ with model_storage.read_from(resource) as model_path:
1125
+ return cls._load(
1126
+ model_path, config, model_storage, resource, execution_context
1127
+ )
1128
+ except ValueError:
1129
+ logger.debug(
1130
+ f"Failed to load {cls.__class__.__name__} from model storage. Resource "
1131
+ f"'{resource.name}' doesn't exist."
1132
+ )
1133
+ return cls(config, model_storage, resource, execution_context)
1134
+
1135
+ @classmethod
1136
+ def _load(
1137
+ cls: Type[DIETClassifierT],
1138
+ model_path: Path,
1139
+ config: Dict[Text, Any],
1140
+ model_storage: ModelStorage,
1141
+ resource: Resource,
1142
+ execution_context: ExecutionContext,
1143
+ ) -> DIETClassifierT:
1144
+ """Loads the trained model from the provided directory."""
1145
+ (
1146
+ index_label_id_mapping,
1147
+ entity_tag_specs,
1148
+ label_data,
1149
+ data_example,
1150
+ sparse_feature_sizes,
1151
+ ) = cls._load_from_files(model_path)
1152
+
1153
+ config = train_utils.update_confidence_type(config)
1154
+ config = train_utils.update_similarity_type(config)
1155
+
1156
+ model = cls._load_model(
1157
+ entity_tag_specs,
1158
+ label_data,
1159
+ config,
1160
+ data_example,
1161
+ model_path,
1162
+ finetune_mode=execution_context.is_finetuning,
1163
+ )
1164
+
1165
+ return cls(
1166
+ config=config,
1167
+ model_storage=model_storage,
1168
+ resource=resource,
1169
+ execution_context=execution_context,
1170
+ index_label_id_mapping=index_label_id_mapping,
1171
+ entity_tag_specs=entity_tag_specs,
1172
+ model=model,
1173
+ sparse_feature_sizes=sparse_feature_sizes,
1174
+ )
1175
+
1176
+ @classmethod
1177
+ def _load_from_files(
1178
+ cls, model_path: Path
1179
+ ) -> Tuple[
1180
+ Dict[int, Text],
1181
+ List[EntityTagSpec],
1182
+ RasaModelData,
1183
+ Dict[Text, Dict[Text, List[FeatureArray]]],
1184
+ Dict[Text, Dict[Text, List[int]]],
1185
+ ]:
1186
+ file_name = cls.__name__
1187
+
1188
+ data_example = io_utils.pickle_load(
1189
+ model_path / f"{file_name}.data_example.pkl"
1190
+ )
1191
+ label_data = io_utils.pickle_load(model_path / f"{file_name}.label_data.pkl")
1192
+ label_data = RasaModelData(data=label_data)
1193
+ sparse_feature_sizes = io_utils.pickle_load(
1194
+ model_path / f"{file_name}.sparse_feature_sizes.pkl"
1195
+ )
1196
+ index_label_id_mapping = io_utils.json_unpickle(
1197
+ model_path / f"{file_name}.index_label_id_mapping.json"
1198
+ )
1199
+ entity_tag_specs = rasa.shared.utils.io.read_json_file(
1200
+ model_path / f"{file_name}.entity_tag_specs.json"
1201
+ )
1202
+ entity_tag_specs = [
1203
+ EntityTagSpec(
1204
+ tag_name=tag_spec["tag_name"],
1205
+ ids_to_tags={
1206
+ int(key): value for key, value in tag_spec["ids_to_tags"].items()
1207
+ },
1208
+ tags_to_ids={
1209
+ key: int(value) for key, value in tag_spec["tags_to_ids"].items()
1210
+ },
1211
+ num_tags=tag_spec["num_tags"],
1212
+ )
1213
+ for tag_spec in entity_tag_specs
1214
+ ]
1215
+
1216
+ # jsonpickle converts dictionary keys to strings
1217
+ index_label_id_mapping = {
1218
+ int(key): value for key, value in index_label_id_mapping.items()
1219
+ }
1220
+
1221
+ return (
1222
+ index_label_id_mapping,
1223
+ entity_tag_specs,
1224
+ label_data,
1225
+ data_example,
1226
+ sparse_feature_sizes,
1227
+ )
1228
+
1229
+ @classmethod
1230
+ def _load_model(
1231
+ cls,
1232
+ entity_tag_specs: List[EntityTagSpec],
1233
+ label_data: RasaModelData,
1234
+ config: Dict[Text, Any],
1235
+ data_example: Dict[Text, Dict[Text, List[FeatureArray]]],
1236
+ model_path: Path,
1237
+ finetune_mode: bool = False,
1238
+ ) -> "RasaModel":
1239
+ file_name = cls.__name__
1240
+ tf_model_file = model_path / f"{file_name}.tf_model"
1241
+
1242
+ label_key = LABEL_KEY if config[INTENT_CLASSIFICATION] else None
1243
+ label_sub_key = LABEL_SUB_KEY if config[INTENT_CLASSIFICATION] else None
1244
+
1245
+ model_data_example = RasaModelData(
1246
+ label_key=label_key, label_sub_key=label_sub_key, data=data_example
1247
+ )
1248
+
1249
+ model = cls._load_model_class(
1250
+ tf_model_file,
1251
+ model_data_example,
1252
+ label_data,
1253
+ entity_tag_specs,
1254
+ config,
1255
+ finetune_mode=finetune_mode,
1256
+ )
1257
+
1258
+ return model
1259
+
1260
+ @classmethod
1261
+ def _load_model_class(
1262
+ cls,
1263
+ tf_model_file: Text,
1264
+ model_data_example: RasaModelData,
1265
+ label_data: RasaModelData,
1266
+ entity_tag_specs: List[EntityTagSpec],
1267
+ config: Dict[Text, Any],
1268
+ finetune_mode: bool,
1269
+ ) -> "RasaModel":
1270
+
1271
+ predict_data_example = RasaModelData(
1272
+ label_key=model_data_example.label_key,
1273
+ data={
1274
+ feature_name: features
1275
+ for feature_name, features in model_data_example.items()
1276
+ if TEXT in feature_name
1277
+ },
1278
+ )
1279
+
1280
+ return cls.model_class().load(
1281
+ tf_model_file,
1282
+ model_data_example,
1283
+ predict_data_example,
1284
+ data_signature=model_data_example.get_signature(),
1285
+ label_data=label_data,
1286
+ entity_tag_specs=entity_tag_specs,
1287
+ config=copy.deepcopy(config),
1288
+ finetune_mode=finetune_mode,
1289
+ )
1290
+
1291
+ def _instantiate_model_class(self, model_data: RasaModelData) -> "RasaModel":
1292
+ return self.model_class()(
1293
+ data_signature=model_data.get_signature(),
1294
+ label_data=self._label_data,
1295
+ entity_tag_specs=self._entity_tag_specs,
1296
+ config=self.component_config,
1297
+ )
1298
+
1299
+
1300
+ class DIET(TransformerRasaModel):
1301
+ def __init__(
1302
+ self,
1303
+ data_signature: Dict[Text, Dict[Text, List[FeatureSignature]]],
1304
+ label_data: RasaModelData,
1305
+ entity_tag_specs: Optional[List[EntityTagSpec]],
1306
+ config: Dict[Text, Any],
1307
+ ) -> None:
1308
+ # create entity tag spec before calling super otherwise building the model
1309
+ # will fail
1310
+ super().__init__("DIET", config, data_signature, label_data)
1311
+ self._entity_tag_specs = self._ordered_tag_specs(entity_tag_specs)
1312
+
1313
+ self.predict_data_signature = {
1314
+ feature_name: features
1315
+ for feature_name, features in data_signature.items()
1316
+ if TEXT in feature_name
1317
+ }
1318
+
1319
+ # tf training
1320
+ self._create_metrics()
1321
+ self._update_metrics_to_log()
1322
+
1323
+ # needed for efficient prediction
1324
+ self.all_labels_embed: Optional[tf.Tensor] = None
1325
+
1326
+ self._prepare_layers()
1327
+
1328
+ @staticmethod
1329
+ def _ordered_tag_specs(
1330
+ entity_tag_specs: Optional[List[EntityTagSpec]],
1331
+ ) -> List[EntityTagSpec]:
1332
+ """Ensure that order of entity tag specs matches CRF layer order."""
1333
+ if entity_tag_specs is None:
1334
+ return []
1335
+
1336
+ crf_order = [
1337
+ ENTITY_ATTRIBUTE_TYPE,
1338
+ ENTITY_ATTRIBUTE_ROLE,
1339
+ ENTITY_ATTRIBUTE_GROUP,
1340
+ ]
1341
+
1342
+ ordered_tag_spec = []
1343
+
1344
+ for tag_name in crf_order:
1345
+ for tag_spec in entity_tag_specs:
1346
+ if tag_name == tag_spec.tag_name:
1347
+ ordered_tag_spec.append(tag_spec)
1348
+
1349
+ return ordered_tag_spec
1350
+
1351
+ def _check_data(self) -> None:
1352
+ if TEXT not in self.data_signature:
1353
+ raise InvalidConfigException(
1354
+ f"No text features specified. "
1355
+ f"Cannot train '{self.__class__.__name__}' model."
1356
+ )
1357
+ if self.config[INTENT_CLASSIFICATION]:
1358
+ if LABEL not in self.data_signature:
1359
+ raise InvalidConfigException(
1360
+ f"No label features specified. "
1361
+ f"Cannot train '{self.__class__.__name__}' model."
1362
+ )
1363
+
1364
+ if self.config[SHARE_HIDDEN_LAYERS]:
1365
+ different_sentence_signatures = False
1366
+ different_sequence_signatures = False
1367
+ if (
1368
+ SENTENCE in self.data_signature[TEXT]
1369
+ and SENTENCE in self.data_signature[LABEL]
1370
+ ):
1371
+ different_sentence_signatures = (
1372
+ self.data_signature[TEXT][SENTENCE]
1373
+ != self.data_signature[LABEL][SENTENCE]
1374
+ )
1375
+ if (
1376
+ SEQUENCE in self.data_signature[TEXT]
1377
+ and SEQUENCE in self.data_signature[LABEL]
1378
+ ):
1379
+ different_sequence_signatures = (
1380
+ self.data_signature[TEXT][SEQUENCE]
1381
+ != self.data_signature[LABEL][SEQUENCE]
1382
+ )
1383
+
1384
+ if different_sentence_signatures or different_sequence_signatures:
1385
+ raise ValueError(
1386
+ "If hidden layer weights are shared, data signatures "
1387
+ "for text_features and label_features must coincide."
1388
+ )
1389
+
1390
+ if self.config[ENTITY_RECOGNITION] and (
1391
+ ENTITIES not in self.data_signature
1392
+ or ENTITY_ATTRIBUTE_TYPE not in self.data_signature[ENTITIES]
1393
+ ):
1394
+ logger.debug(
1395
+ f"You specified '{self.__class__.__name__}' to train entities, but "
1396
+ f"no entities are present in the training data. Skipping training of "
1397
+ f"entities."
1398
+ )
1399
+ self.config[ENTITY_RECOGNITION] = False
1400
+
1401
+ def _create_metrics(self) -> None:
1402
+ # self.metrics will have the same order as they are created
1403
+ # so create loss metrics first to output losses first
1404
+ self.mask_loss = tf.keras.metrics.Mean(name="m_loss")
1405
+ self.intent_loss = tf.keras.metrics.Mean(name="i_loss")
1406
+ self.entity_loss = tf.keras.metrics.Mean(name="e_loss")
1407
+ self.entity_group_loss = tf.keras.metrics.Mean(name="g_loss")
1408
+ self.entity_role_loss = tf.keras.metrics.Mean(name="r_loss")
1409
+ # create accuracy metrics second to output accuracies second
1410
+ self.mask_acc = tf.keras.metrics.Mean(name="m_acc")
1411
+ self.intent_acc = tf.keras.metrics.Mean(name="i_acc")
1412
+ self.entity_f1 = tf.keras.metrics.Mean(name="e_f1")
1413
+ self.entity_group_f1 = tf.keras.metrics.Mean(name="g_f1")
1414
+ self.entity_role_f1 = tf.keras.metrics.Mean(name="r_f1")
1415
+
1416
+ def _update_metrics_to_log(self) -> None:
1417
+ debug_log_level = logging.getLogger("rasa").level == logging.DEBUG
1418
+
1419
+ if self.config[MASKED_LM]:
1420
+ self.metrics_to_log.append("m_acc")
1421
+ if debug_log_level:
1422
+ self.metrics_to_log.append("m_loss")
1423
+ if self.config[INTENT_CLASSIFICATION]:
1424
+ self.metrics_to_log.append("i_acc")
1425
+ if debug_log_level:
1426
+ self.metrics_to_log.append("i_loss")
1427
+ if self.config[ENTITY_RECOGNITION]:
1428
+ for tag_spec in self._entity_tag_specs:
1429
+ if tag_spec.num_tags != 0:
1430
+ name = tag_spec.tag_name
1431
+ self.metrics_to_log.append(f"{name[0]}_f1")
1432
+ if debug_log_level:
1433
+ self.metrics_to_log.append(f"{name[0]}_loss")
1434
+
1435
+ self._log_metric_info()
1436
+
1437
+ def _log_metric_info(self) -> None:
1438
+ metric_name = {
1439
+ "t": "total",
1440
+ "i": "intent",
1441
+ "e": "entity",
1442
+ "m": "mask",
1443
+ "r": "role",
1444
+ "g": "group",
1445
+ }
1446
+ logger.debug("Following metrics will be logged during training: ")
1447
+ for metric in self.metrics_to_log:
1448
+ parts = metric.split("_")
1449
+ name = f"{metric_name[parts[0]]} {parts[1]}"
1450
+ logger.debug(f" {metric} ({name})")
1451
+
1452
+ def _prepare_layers(self) -> None:
1453
+ # For user text, prepare layers that combine different feature types, embed
1454
+ # everything using a transformer and optionally also do masked language
1455
+ # modeling.
1456
+ self.text_name = TEXT
1457
+ self._tf_layers[
1458
+ f"sequence_layer.{self.text_name}"
1459
+ ] = rasa_layers.RasaSequenceLayer(
1460
+ self.text_name, self.data_signature[self.text_name], self.config
1461
+ )
1462
+ if self.config[MASKED_LM]:
1463
+ self._prepare_mask_lm_loss(self.text_name)
1464
+
1465
+ # Intent labels are treated similarly to user text but without the transformer,
1466
+ # without masked language modelling, and with no dropout applied to the
1467
+ # individual features, only to the overall label embedding after all label
1468
+ # features have been combined.
1469
+ if self.config[INTENT_CLASSIFICATION]:
1470
+ self.label_name = TEXT if self.config[SHARE_HIDDEN_LAYERS] else LABEL
1471
+
1472
+ # disable input dropout applied to sparse and dense label features
1473
+ label_config = self.config.copy()
1474
+ label_config.update(
1475
+ {SPARSE_INPUT_DROPOUT: False, DENSE_INPUT_DROPOUT: False}
1476
+ )
1477
+
1478
+ self._tf_layers[
1479
+ f"feature_combining_layer.{self.label_name}"
1480
+ ] = rasa_layers.RasaFeatureCombiningLayer(
1481
+ self.label_name, self.label_signature[self.label_name], label_config
1482
+ )
1483
+
1484
+ self._prepare_ffnn_layer(
1485
+ self.label_name,
1486
+ self.config[HIDDEN_LAYERS_SIZES][self.label_name],
1487
+ self.config[DROP_RATE],
1488
+ )
1489
+
1490
+ self._prepare_label_classification_layers(predictor_attribute=TEXT)
1491
+
1492
+ if self.config[ENTITY_RECOGNITION]:
1493
+ self._prepare_entity_recognition_layers()
1494
+
1495
+ def _prepare_mask_lm_loss(self, name: Text) -> None:
1496
+ # for embedding predicted tokens at masked positions
1497
+ self._prepare_embed_layers(f"{name}_lm_mask")
1498
+
1499
+ # for embedding the true tokens that got masked
1500
+ self._prepare_embed_layers(f"{name}_golden_token")
1501
+
1502
+ # mask loss is additional loss
1503
+ # set scaling to False, so that it doesn't overpower other losses
1504
+ self._prepare_dot_product_loss(f"{name}_mask", scale_loss=False)
1505
+
1506
+ def _create_bow(
1507
+ self,
1508
+ sequence_features: List[Union[tf.Tensor, tf.SparseTensor]],
1509
+ sentence_features: List[Union[tf.Tensor, tf.SparseTensor]],
1510
+ sequence_feature_lengths: tf.Tensor,
1511
+ name: Text,
1512
+ ) -> tf.Tensor:
1513
+
1514
+ x, _ = self._tf_layers[f"feature_combining_layer.{name}"](
1515
+ (sequence_features, sentence_features, sequence_feature_lengths),
1516
+ training=self._training,
1517
+ )
1518
+
1519
+ # convert to bag-of-words by summing along the sequence dimension
1520
+ x = tf.reduce_sum(x, axis=1)
1521
+
1522
+ return self._tf_layers[f"ffnn.{name}"](x, self._training)
1523
+
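_create_bow collapses the combined sequence into a single bag-of-words vector by summing over the time axis before the feed-forward layer. The shape change, sketched:

    import tensorflow as tf

    x = tf.random.uniform((2, 5, 8))        # (batch, sequence_length, units)
    bag_of_words = tf.reduce_sum(x, axis=1)
    print(bag_of_words.shape)               # (2, 8)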
1524
+ def _create_all_labels(self) -> Tuple[tf.Tensor, tf.Tensor]:
1525
+ all_label_ids = self.tf_label_data[LABEL_KEY][LABEL_SUB_KEY][0]
1526
+
1527
+ sequence_feature_lengths = self._get_sequence_feature_lengths(
1528
+ self.tf_label_data, LABEL
1529
+ )
1530
+
1531
+ x = self._create_bow(
1532
+ self.tf_label_data[LABEL][SEQUENCE],
1533
+ self.tf_label_data[LABEL][SENTENCE],
1534
+ sequence_feature_lengths,
1535
+ self.label_name,
1536
+ )
1537
+ all_labels_embed = self._tf_layers[f"embed.{LABEL}"](x)
1538
+
1539
+ return all_label_ids, all_labels_embed
1540
+
1541
+ def _mask_loss(
1542
+ self,
1543
+ outputs: tf.Tensor,
1544
+ inputs: tf.Tensor,
1545
+ seq_ids: tf.Tensor,
1546
+ mlm_mask_boolean: tf.Tensor,
1547
+ name: Text,
1548
+ ) -> tf.Tensor:
1549
+ # make sure there is at least one element in the mask
1550
+ mlm_mask_boolean = tf.cond(
1551
+ tf.reduce_any(mlm_mask_boolean),
1552
+ lambda: mlm_mask_boolean,
1553
+ lambda: tf.scatter_nd([[0, 0, 0]], [True], tf.shape(mlm_mask_boolean)),
1554
+ )
1555
+
1556
+ mlm_mask_boolean = tf.squeeze(mlm_mask_boolean, -1)
1557
+
1558
+ # Pick elements that were masked, throwing away the batch & sequence dimension
1559
+ # and effectively switching from shape (batch_size, sequence_length, units) to
1560
+ # (num_masked_elements, units).
1561
+ outputs = tf.boolean_mask(outputs, mlm_mask_boolean)
1562
+ inputs = tf.boolean_mask(inputs, mlm_mask_boolean)
1563
+ ids = tf.boolean_mask(seq_ids, mlm_mask_boolean)
1564
+
1565
+ tokens_predicted_embed = self._tf_layers[f"embed.{name}_lm_mask"](outputs)
1566
+ tokens_true_embed = self._tf_layers[f"embed.{name}_golden_token"](inputs)
1567
+
1568
+ # To limit the otherwise computationally expensive loss calculation, we
1569
+ # constrain the label space in MLM (i.e. token space) to only those tokens that
1570
+ # were masked in this batch. Hence the reduced list of token embeddings
1571
+ # (tokens_true_embed) and the reduced list of labels (ids) are passed as
1572
+ # all_labels_embed and all_labels, respectively. In the future, we could be less
1573
+ # restrictive and construct a slightly bigger label space which could include
1574
+ # tokens not masked in the current batch too.
1575
+ return self._tf_layers[f"loss.{name}_mask"](
1576
+ inputs_embed=tokens_predicted_embed,
1577
+ labels_embed=tokens_true_embed,
1578
+ labels=ids,
1579
+ all_labels_embed=tokens_true_embed,
1580
+ all_labels=ids,
1581
+ )
1582
+
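_mask_loss keeps only the masked positions: after squeezing, tf.boolean_mask with a rank-2 mask drops the batch and sequence dimensions at once. A sketch of the shape change:

    import tensorflow as tf

    outputs = tf.random.uniform((2, 4, 8))               # (batch, seq, units)
    mlm_mask_boolean = tf.constant([[True, False, False, True],
                                    [False, True, False, False]])
    masked = tf.boolean_mask(outputs, mlm_mask_boolean)
    print(masked.shape)                                   # (3, 8)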
1583
+ def _calculate_label_loss(
1584
+ self, text_features: tf.Tensor, label_features: tf.Tensor, label_ids: tf.Tensor
1585
+ ) -> tf.Tensor:
1586
+ all_label_ids, all_labels_embed = self._create_all_labels()
1587
+
1588
+ text_embed = self._tf_layers[f"embed.{TEXT}"](text_features)
1589
+ label_embed = self._tf_layers[f"embed.{LABEL}"](label_features)
1590
+
1591
+ return self._tf_layers[f"loss.{LABEL}"](
1592
+ text_embed, label_embed, label_ids, all_labels_embed, all_label_ids
1593
+ )
1594
+
1595
+ def batch_loss(
1596
+ self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
1597
+ ) -> tf.Tensor:
1598
+ """Calculates the loss for the given batch.
1599
+
1600
+ Args:
1601
+ batch_in: The batch.
1602
+
1603
+ Returns:
1604
+ The loss of the given batch.
1605
+ """
1606
+ tf_batch_data = self.batch_to_model_data_format(batch_in, self.data_signature)
1607
+
1608
+ sequence_feature_lengths = self._get_sequence_feature_lengths(
1609
+ tf_batch_data, TEXT
1610
+ )
1611
+
1612
+ (
1613
+ text_transformed,
1614
+ text_in,
1615
+ mask_combined_sequence_sentence,
1616
+ text_seq_ids,
1617
+ mlm_mask_boolean_text,
1618
+ _,
1619
+ ) = self._tf_layers[f"sequence_layer.{self.text_name}"](
1620
+ (
1621
+ tf_batch_data[TEXT][SEQUENCE],
1622
+ tf_batch_data[TEXT][SENTENCE],
1623
+ sequence_feature_lengths,
1624
+ ),
1625
+ training=self._training,
1626
+ )
1627
+
1628
+ losses = []
1629
+
1630
+ # Lengths of sequences in case of sentence-level features are always 1, but they
1631
+ # can effectively be 0 if sentence-level features aren't present.
1632
+ sentence_feature_lengths = self._get_sentence_feature_lengths(
1633
+ tf_batch_data, TEXT
1634
+ )
1635
+
1636
+ combined_sequence_sentence_feature_lengths = (
1637
+ sequence_feature_lengths + sentence_feature_lengths
1638
+ )
1639
+
1640
+ if self.config[MASKED_LM] and self._training:
1641
+ loss, acc = self._mask_loss(
1642
+ text_transformed, text_in, text_seq_ids, mlm_mask_boolean_text, TEXT
1643
+ )
1644
+ self.mask_loss.update_state(loss)
1645
+ self.mask_acc.update_state(acc)
1646
+ losses.append(loss)
1647
+
1648
+ if self.config[INTENT_CLASSIFICATION]:
1649
+ loss = self._batch_loss_intent(
1650
+ combined_sequence_sentence_feature_lengths,
1651
+ text_transformed,
1652
+ tf_batch_data,
1653
+ )
1654
+ losses.append(loss)
1655
+
1656
+ if self.config[ENTITY_RECOGNITION]:
1657
+ losses += self._batch_loss_entities(
1658
+ mask_combined_sequence_sentence,
1659
+ sequence_feature_lengths,
1660
+ text_transformed,
1661
+ tf_batch_data,
1662
+ )
1663
+
1664
+ return tf.math.add_n(losses)
1665
+
1666
+ def _batch_loss_intent(
1667
+ self,
1668
+ combined_sequence_sentence_feature_lengths_text: tf.Tensor,
1669
+ text_transformed: tf.Tensor,
1670
+ tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
1671
+ ) -> tf.Tensor:
1672
+ # get sentence features vector for intent classification
1673
+ sentence_vector = self._last_token(
1674
+ text_transformed, combined_sequence_sentence_feature_lengths_text
1675
+ )
1676
+
1677
+ sequence_feature_lengths_label = self._get_sequence_feature_lengths(
1678
+ tf_batch_data, LABEL
1679
+ )
1680
+
1681
+ label_ids = tf_batch_data[LABEL_KEY][LABEL_SUB_KEY][0]
1682
+ label = self._create_bow(
1683
+ tf_batch_data[LABEL][SEQUENCE],
1684
+ tf_batch_data[LABEL][SENTENCE],
1685
+ sequence_feature_lengths_label,
1686
+ self.label_name,
1687
+ )
1688
+ loss, acc = self._calculate_label_loss(sentence_vector, label, label_ids)
1689
+
1690
+ self._update_label_metrics(loss, acc)
1691
+
1692
+ return loss
1693
+
1694
+ def _update_label_metrics(self, loss: tf.Tensor, acc: tf.Tensor) -> None:
1695
+
1696
+ self.intent_loss.update_state(loss)
1697
+ self.intent_acc.update_state(acc)
1698
+
1699
+ def _batch_loss_entities(
1700
+ self,
1701
+ mask_combined_sequence_sentence: tf.Tensor,
1702
+ sequence_feature_lengths: tf.Tensor,
1703
+ text_transformed: tf.Tensor,
1704
+ tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
1705
+ ) -> List[tf.Tensor]:
1706
+ losses = []
1707
+
1708
+ entity_tags = None
1709
+
1710
+ for tag_spec in self._entity_tag_specs:
1711
+ if tag_spec.num_tags == 0:
1712
+ continue
1713
+
1714
+ tag_ids = tf_batch_data[ENTITIES][tag_spec.tag_name][0]
1715
+ # add a zero (no entity) for the sentence features to match the shape of
1716
+ # inputs
1717
+ tag_ids = tf.pad(tag_ids, [[0, 0], [0, 1], [0, 0]])
1718
+
1719
+ loss, f1, _logits = self._calculate_entity_loss(
1720
+ text_transformed,
1721
+ tag_ids,
1722
+ mask_combined_sequence_sentence,
1723
+ sequence_feature_lengths,
1724
+ tag_spec.tag_name,
1725
+ entity_tags,
1726
+ )
1727
+
1728
+ if tag_spec.tag_name == ENTITY_ATTRIBUTE_TYPE:
1729
+ # use the entity tags as additional input for the role
1730
+ # and group CRF
1731
+ entity_tags = tf.one_hot(
1732
+ tf.cast(tag_ids[:, :, 0], tf.int32), depth=tag_spec.num_tags
1733
+ )
1734
+
1735
+ self._update_entity_metrics(loss, f1, tag_spec.tag_name)
1736
+
1737
+ losses.append(loss)
1738
+
1739
+ return losses
1740
+
1741
+ def _update_entity_metrics(
1742
+ self, loss: tf.Tensor, f1: tf.Tensor, tag_name: Text
1743
+ ) -> None:
1744
+ if tag_name == ENTITY_ATTRIBUTE_TYPE:
1745
+ self.entity_loss.update_state(loss)
1746
+ self.entity_f1.update_state(f1)
1747
+ elif tag_name == ENTITY_ATTRIBUTE_GROUP:
1748
+ self.entity_group_loss.update_state(loss)
1749
+ self.entity_group_f1.update_state(f1)
1750
+ elif tag_name == ENTITY_ATTRIBUTE_ROLE:
1751
+ self.entity_role_loss.update_state(loss)
1752
+ self.entity_role_f1.update_state(f1)
1753
+
1754
+ def prepare_for_predict(self) -> None:
1755
+ """Prepares the model for prediction."""
1756
+ if self.config[INTENT_CLASSIFICATION]:
1757
+ _, self.all_labels_embed = self._create_all_labels()
1758
+
1759
+ def batch_predict(
1760
+ self, batch_in: Union[Tuple[tf.Tensor, ...], Tuple[np.ndarray, ...]]
1761
+ ) -> Dict[Text, tf.Tensor]:
1762
+ """Predicts the output of the given batch.
1763
+
1764
+ Args:
1765
+ batch_in: The batch.
1766
+
1767
+ Returns:
1768
+ The output to predict.
1769
+ """
1770
+ tf_batch_data = self.batch_to_model_data_format(
1771
+ batch_in, self.predict_data_signature
1772
+ )
1773
+
1774
+ sequence_feature_lengths = self._get_sequence_feature_lengths(
1775
+ tf_batch_data, TEXT
1776
+ )
1777
+ sentence_feature_lengths = self._get_sentence_feature_lengths(
1778
+ tf_batch_data, TEXT
1779
+ )
1780
+
1781
+ text_transformed, _, _, _, _, attention_weights = self._tf_layers[
1782
+ f"sequence_layer.{self.text_name}"
1783
+ ](
1784
+ (
1785
+ tf_batch_data[TEXT][SEQUENCE],
1786
+ tf_batch_data[TEXT][SENTENCE],
1787
+ sequence_feature_lengths,
1788
+ ),
1789
+ training=self._training,
1790
+ )
1791
+ predictions = {
1792
+ DIAGNOSTIC_DATA: {
1793
+ "attention_weights": attention_weights,
1794
+ "text_transformed": text_transformed,
1795
+ }
1796
+ }
1797
+
1798
+ if self.config[INTENT_CLASSIFICATION]:
1799
+ predictions.update(
1800
+ self._batch_predict_intents(
1801
+ sequence_feature_lengths + sentence_feature_lengths,
1802
+ text_transformed,
1803
+ )
1804
+ )
1805
+
1806
+ if self.config[ENTITY_RECOGNITION]:
1807
+ predictions.update(
1808
+ self._batch_predict_entities(sequence_feature_lengths, text_transformed)
1809
+ )
1810
+
1811
+ return predictions
1812
+
1813
+ def _batch_predict_entities(
1814
+ self, sequence_feature_lengths: tf.Tensor, text_transformed: tf.Tensor
1815
+ ) -> Dict[Text, tf.Tensor]:
1816
+ predictions: Dict[Text, tf.Tensor] = {}
1817
+
1818
+ entity_tags = None
1819
+
1820
+ for tag_spec in self._entity_tag_specs:
1821
+ # skip crf layer if it was not trained
1822
+ if tag_spec.num_tags == 0:
1823
+ continue
1824
+
1825
+ name = tag_spec.tag_name
1826
+ _input = text_transformed
1827
+
1828
+ if entity_tags is not None:
1829
+ _tags = self._tf_layers[f"embed.{name}.tags"](entity_tags)
1830
+ _input = tf.concat([_input, _tags], axis=-1)
1831
+
1832
+ _logits = self._tf_layers[f"embed.{name}.logits"](_input)
1833
+ pred_ids, confidences = self._tf_layers[f"crf.{name}"](
1834
+ _logits, sequence_feature_lengths
1835
+ )
1836
+
1837
+ predictions[f"e_{name}_ids"] = pred_ids
1838
+ predictions[f"e_{name}_scores"] = confidences
1839
+
1840
+ if name == ENTITY_ATTRIBUTE_TYPE:
1841
+ # use the entity tags as additional input for the role
1842
+ # and group CRF
1843
+ entity_tags = tf.one_hot(
1844
+ tf.cast(pred_ids, tf.int32), depth=tag_spec.num_tags
1845
+ )
1846
+
1847
+ return predictions
1848
+
1849
+ def _batch_predict_intents(
1850
+ self,
1851
+ combined_sequence_sentence_feature_lengths: tf.Tensor,
1852
+ text_transformed: tf.Tensor,
1853
+ ) -> Dict[Text, tf.Tensor]:
1854
+
1855
+ if self.all_labels_embed is None:
1856
+ raise ValueError(
1857
+ "The model was not prepared for prediction. "
1858
+ "Call `prepare_for_predict` first."
1859
+ )
1860
+
1861
+ # get sentence feature vector for intent classification
1862
+ sentence_vector = self._last_token(
1863
+ text_transformed, combined_sequence_sentence_feature_lengths
1864
+ )
1865
+ sentence_vector_embed = self._tf_layers[f"embed.{TEXT}"](sentence_vector)
1866
+
1867
+ _, scores = self._tf_layers[
1868
+ f"loss.{LABEL}"
1869
+ ].get_similarities_and_confidences_from_embeddings(
1870
+ sentence_vector_embed[:, tf.newaxis, :],
1871
+ self.all_labels_embed[tf.newaxis, :, :],
1872
+ )
1873
+
1874
+ return {"i_scores": scores}