rasa-pro 3.8.16 (rasa_pro-3.8.16-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (644)
  1. README.md +380 -0
  2. rasa/__init__.py +10 -0
  3. rasa/__main__.py +151 -0
  4. rasa/anonymization/__init__.py +2 -0
  5. rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
  6. rasa/anonymization/anonymization_pipeline.py +287 -0
  7. rasa/anonymization/anonymization_rule_executor.py +260 -0
  8. rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
  9. rasa/anonymization/schemas/config.yml +47 -0
  10. rasa/anonymization/utils.py +117 -0
  11. rasa/api.py +146 -0
  12. rasa/cli/__init__.py +5 -0
  13. rasa/cli/arguments/__init__.py +0 -0
  14. rasa/cli/arguments/data.py +81 -0
  15. rasa/cli/arguments/default_arguments.py +165 -0
  16. rasa/cli/arguments/evaluate.py +65 -0
  17. rasa/cli/arguments/export.py +51 -0
  18. rasa/cli/arguments/interactive.py +74 -0
  19. rasa/cli/arguments/run.py +204 -0
  20. rasa/cli/arguments/shell.py +13 -0
  21. rasa/cli/arguments/test.py +211 -0
  22. rasa/cli/arguments/train.py +263 -0
  23. rasa/cli/arguments/visualize.py +34 -0
  24. rasa/cli/arguments/x.py +30 -0
  25. rasa/cli/data.py +292 -0
  26. rasa/cli/e2e_test.py +566 -0
  27. rasa/cli/evaluate.py +222 -0
  28. rasa/cli/export.py +251 -0
  29. rasa/cli/inspect.py +63 -0
  30. rasa/cli/interactive.py +164 -0
  31. rasa/cli/license.py +65 -0
  32. rasa/cli/markers.py +78 -0
  33. rasa/cli/project_templates/__init__.py +0 -0
  34. rasa/cli/project_templates/calm/actions/__init__.py +0 -0
  35. rasa/cli/project_templates/calm/actions/action_template.py +27 -0
  36. rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
  37. rasa/cli/project_templates/calm/actions/db.py +57 -0
  38. rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
  39. rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
  40. rasa/cli/project_templates/calm/config.yml +12 -0
  41. rasa/cli/project_templates/calm/credentials.yml +33 -0
  42. rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
  43. rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
  44. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
  45. rasa/cli/project_templates/calm/db/contacts.json +10 -0
  46. rasa/cli/project_templates/calm/domain/add_contact.yml +33 -0
  47. rasa/cli/project_templates/calm/domain/list_contacts.yml +14 -0
  48. rasa/cli/project_templates/calm/domain/remove_contact.yml +31 -0
  49. rasa/cli/project_templates/calm/domain/shared.yml +5 -0
  50. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
  51. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
  52. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
  53. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
  54. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
  55. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
  56. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
  57. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
  58. rasa/cli/project_templates/calm/endpoints.yml +45 -0
  59. rasa/cli/project_templates/default/actions/__init__.py +0 -0
  60. rasa/cli/project_templates/default/actions/actions.py +27 -0
  61. rasa/cli/project_templates/default/config.yml +44 -0
  62. rasa/cli/project_templates/default/credentials.yml +33 -0
  63. rasa/cli/project_templates/default/data/nlu.yml +91 -0
  64. rasa/cli/project_templates/default/data/rules.yml +13 -0
  65. rasa/cli/project_templates/default/data/stories.yml +30 -0
  66. rasa/cli/project_templates/default/domain.yml +34 -0
  67. rasa/cli/project_templates/default/endpoints.yml +42 -0
  68. rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
  69. rasa/cli/project_templates/tutorial/actions.py +22 -0
  70. rasa/cli/project_templates/tutorial/config.yml +11 -0
  71. rasa/cli/project_templates/tutorial/credentials.yml +33 -0
  72. rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
  73. rasa/cli/project_templates/tutorial/domain.yml +17 -0
  74. rasa/cli/project_templates/tutorial/endpoints.yml +45 -0
  75. rasa/cli/run.py +136 -0
  76. rasa/cli/scaffold.py +268 -0
  77. rasa/cli/shell.py +141 -0
  78. rasa/cli/studio/__init__.py +0 -0
  79. rasa/cli/studio/download.py +51 -0
  80. rasa/cli/studio/studio.py +110 -0
  81. rasa/cli/studio/train.py +59 -0
  82. rasa/cli/studio/upload.py +85 -0
  83. rasa/cli/telemetry.py +90 -0
  84. rasa/cli/test.py +280 -0
  85. rasa/cli/train.py +260 -0
  86. rasa/cli/utils.py +453 -0
  87. rasa/cli/visualize.py +40 -0
  88. rasa/cli/x.py +205 -0
  89. rasa/constants.py +37 -0
  90. rasa/core/__init__.py +17 -0
  91. rasa/core/actions/__init__.py +0 -0
  92. rasa/core/actions/action.py +1450 -0
  93. rasa/core/actions/action_clean_stack.py +59 -0
  94. rasa/core/actions/action_run_slot_rejections.py +207 -0
  95. rasa/core/actions/action_trigger_chitchat.py +31 -0
  96. rasa/core/actions/action_trigger_flow.py +109 -0
  97. rasa/core/actions/action_trigger_search.py +31 -0
  98. rasa/core/actions/constants.py +2 -0
  99. rasa/core/actions/forms.py +737 -0
  100. rasa/core/actions/loops.py +111 -0
  101. rasa/core/actions/two_stage_fallback.py +186 -0
  102. rasa/core/agent.py +557 -0
  103. rasa/core/auth_retry_tracker_store.py +122 -0
  104. rasa/core/brokers/__init__.py +0 -0
  105. rasa/core/brokers/broker.py +126 -0
  106. rasa/core/brokers/file.py +58 -0
  107. rasa/core/brokers/kafka.py +322 -0
  108. rasa/core/brokers/pika.py +387 -0
  109. rasa/core/brokers/sql.py +86 -0
  110. rasa/core/channels/__init__.py +55 -0
  111. rasa/core/channels/audiocodes.py +463 -0
  112. rasa/core/channels/botframework.py +339 -0
  113. rasa/core/channels/callback.py +85 -0
  114. rasa/core/channels/channel.py +419 -0
  115. rasa/core/channels/console.py +243 -0
  116. rasa/core/channels/development_inspector.py +93 -0
  117. rasa/core/channels/facebook.py +422 -0
  118. rasa/core/channels/hangouts.py +335 -0
  119. rasa/core/channels/inspector/.eslintrc.cjs +25 -0
  120. rasa/core/channels/inspector/.gitignore +23 -0
  121. rasa/core/channels/inspector/README.md +54 -0
  122. rasa/core/channels/inspector/assets/favicon.ico +0 -0
  123. rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
  124. rasa/core/channels/inspector/custom.d.ts +3 -0
  125. rasa/core/channels/inspector/dist/assets/arc-5623b6dc.js +1 -0
  126. rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
  127. rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-685c106a.js +10 -0
  128. rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-8cbed007.js +2 -0
  129. rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-5889cf12.js +2 -0
  130. rasa/core/channels/inspector/dist/assets/createText-62fc7601-24c249d7.js +7 -0
  131. rasa/core/channels/inspector/dist/assets/edges-f2ad444c-7dd06a75.js +4 -0
  132. rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-62c1e54c.js +51 -0
  133. rasa/core/channels/inspector/dist/assets/flowDb-1972c806-ce49b86f.js +6 -0
  134. rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-4067e48f.js +4 -0
  135. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-85583a23.js +1 -0
  136. rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-59fe4051.js +139 -0
  137. rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-47e3a43b.js +266 -0
  138. rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-5a2ac0d9.js +70 -0
  139. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
  140. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
  141. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
  142. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
  143. rasa/core/channels/inspector/dist/assets/index-268a75c0.js +1040 -0
  144. rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-dfb8efc4.js +1 -0
  145. rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
  146. rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-b0c470f2.js +7 -0
  147. rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
  148. rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-2edb829a.js +139 -0
  149. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
  150. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
  151. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
  152. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
  153. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
  154. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
  155. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
  156. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
  157. rasa/core/channels/inspector/dist/assets/layout-b6873d69.js +1 -0
  158. rasa/core/channels/inspector/dist/assets/line-1efc5781.js +1 -0
  159. rasa/core/channels/inspector/dist/assets/linear-661e9b94.js +1 -0
  160. rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-2d2e727f.js +109 -0
  161. rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
  162. rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
  163. rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-9d3ea93d.js +35 -0
  164. rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-06a178a2.js +7 -0
  165. rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-0bfedffc.js +52 -0
  166. rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-d76d0a04.js +8 -0
  167. rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-37bb4341.js +122 -0
  168. rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-f52f7f57.js +1 -0
  169. rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-4a986a20.js +1 -0
  170. rasa/core/channels/inspector/dist/assets/styles-080da4f6-7dd9ae12.js +110 -0
  171. rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-46e1ca14.js +159 -0
  172. rasa/core/channels/inspector/dist/assets/styles-9c745c82-4a97439a.js +207 -0
  173. rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-823917a3.js +1 -0
  174. rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-9ea72896.js +61 -0
  175. rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-b631a8b6.js +7 -0
  176. rasa/core/channels/inspector/dist/index.html +39 -0
  177. rasa/core/channels/inspector/index.html +37 -0
  178. rasa/core/channels/inspector/jest.config.ts +13 -0
  179. rasa/core/channels/inspector/package.json +48 -0
  180. rasa/core/channels/inspector/setupTests.ts +2 -0
  181. rasa/core/channels/inspector/src/App.tsx +170 -0
  182. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +97 -0
  183. rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
  184. rasa/core/channels/inspector/src/components/DialogueStack.tsx +151 -0
  185. rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
  186. rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
  187. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +19 -0
  188. rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
  189. rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
  190. rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
  191. rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
  192. rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
  193. rasa/core/channels/inspector/src/helpers/formatters.test.ts +385 -0
  194. rasa/core/channels/inspector/src/helpers/formatters.ts +239 -0
  195. rasa/core/channels/inspector/src/helpers/utils.ts +42 -0
  196. rasa/core/channels/inspector/src/main.tsx +13 -0
  197. rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
  198. rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
  199. rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
  200. rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
  201. rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
  202. rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
  203. rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
  204. rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
  205. rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
  206. rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
  207. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
  208. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
  209. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
  210. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
  211. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
  212. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
  213. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
  214. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
  215. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
  216. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
  217. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
  218. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
  219. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
  220. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
  221. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
  222. rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
  223. rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
  224. rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
  225. rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
  226. rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
  227. rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
  228. rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
  229. rasa/core/channels/inspector/src/theme/index.ts +101 -0
  230. rasa/core/channels/inspector/src/types.ts +64 -0
  231. rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
  232. rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
  233. rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
  234. rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
  235. rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
  236. rasa/core/channels/inspector/tsconfig.json +26 -0
  237. rasa/core/channels/inspector/tsconfig.node.json +10 -0
  238. rasa/core/channels/inspector/vite.config.ts +8 -0
  239. rasa/core/channels/inspector/yarn.lock +6156 -0
  240. rasa/core/channels/mattermost.py +229 -0
  241. rasa/core/channels/rasa_chat.py +126 -0
  242. rasa/core/channels/rest.py +210 -0
  243. rasa/core/channels/rocketchat.py +175 -0
  244. rasa/core/channels/slack.py +620 -0
  245. rasa/core/channels/socketio.py +274 -0
  246. rasa/core/channels/telegram.py +298 -0
  247. rasa/core/channels/twilio.py +169 -0
  248. rasa/core/channels/twilio_voice.py +367 -0
  249. rasa/core/channels/vier_cvg.py +374 -0
  250. rasa/core/channels/webexteams.py +135 -0
  251. rasa/core/concurrent_lock_store.py +210 -0
  252. rasa/core/constants.py +107 -0
  253. rasa/core/evaluation/__init__.py +0 -0
  254. rasa/core/evaluation/marker.py +267 -0
  255. rasa/core/evaluation/marker_base.py +925 -0
  256. rasa/core/evaluation/marker_stats.py +294 -0
  257. rasa/core/evaluation/marker_tracker_loader.py +103 -0
  258. rasa/core/exceptions.py +29 -0
  259. rasa/core/exporter.py +284 -0
  260. rasa/core/featurizers/__init__.py +0 -0
  261. rasa/core/featurizers/precomputation.py +410 -0
  262. rasa/core/featurizers/single_state_featurizer.py +402 -0
  263. rasa/core/featurizers/tracker_featurizers.py +1172 -0
  264. rasa/core/http_interpreter.py +89 -0
  265. rasa/core/information_retrieval/__init__.py +0 -0
  266. rasa/core/information_retrieval/faiss.py +116 -0
  267. rasa/core/information_retrieval/information_retrieval.py +72 -0
  268. rasa/core/information_retrieval/milvus.py +59 -0
  269. rasa/core/information_retrieval/qdrant.py +102 -0
  270. rasa/core/jobs.py +63 -0
  271. rasa/core/lock.py +139 -0
  272. rasa/core/lock_store.py +344 -0
  273. rasa/core/migrate.py +404 -0
  274. rasa/core/nlg/__init__.py +3 -0
  275. rasa/core/nlg/callback.py +147 -0
  276. rasa/core/nlg/contextual_response_rephraser.py +270 -0
  277. rasa/core/nlg/generator.py +230 -0
  278. rasa/core/nlg/interpolator.py +143 -0
  279. rasa/core/nlg/response.py +155 -0
  280. rasa/core/nlg/summarize.py +69 -0
  281. rasa/core/policies/__init__.py +0 -0
  282. rasa/core/policies/ensemble.py +329 -0
  283. rasa/core/policies/enterprise_search_policy.py +717 -0
  284. rasa/core/policies/enterprise_search_prompt_template.jinja2 +62 -0
  285. rasa/core/policies/flow_policy.py +205 -0
  286. rasa/core/policies/flows/__init__.py +0 -0
  287. rasa/core/policies/flows/flow_exceptions.py +44 -0
  288. rasa/core/policies/flows/flow_executor.py +582 -0
  289. rasa/core/policies/flows/flow_step_result.py +43 -0
  290. rasa/core/policies/intentless_policy.py +924 -0
  291. rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
  292. rasa/core/policies/memoization.py +538 -0
  293. rasa/core/policies/policy.py +716 -0
  294. rasa/core/policies/rule_policy.py +1276 -0
  295. rasa/core/policies/ted_policy.py +2146 -0
  296. rasa/core/policies/unexpected_intent_policy.py +1015 -0
  297. rasa/core/processor.py +1331 -0
  298. rasa/core/run.py +315 -0
  299. rasa/core/secrets_manager/__init__.py +0 -0
  300. rasa/core/secrets_manager/constants.py +32 -0
  301. rasa/core/secrets_manager/endpoints.py +391 -0
  302. rasa/core/secrets_manager/factory.py +233 -0
  303. rasa/core/secrets_manager/secret_manager.py +262 -0
  304. rasa/core/secrets_manager/vault.py +576 -0
  305. rasa/core/test.py +1337 -0
  306. rasa/core/tracker_store.py +1664 -0
  307. rasa/core/train.py +107 -0
  308. rasa/core/training/__init__.py +89 -0
  309. rasa/core/training/converters/__init__.py +0 -0
  310. rasa/core/training/converters/responses_prefix_converter.py +119 -0
  311. rasa/core/training/interactive.py +1742 -0
  312. rasa/core/training/story_conflict.py +381 -0
  313. rasa/core/training/training.py +93 -0
  314. rasa/core/utils.py +344 -0
  315. rasa/core/visualize.py +70 -0
  316. rasa/dialogue_understanding/__init__.py +0 -0
  317. rasa/dialogue_understanding/coexistence/__init__.py +0 -0
  318. rasa/dialogue_understanding/coexistence/constants.py +4 -0
  319. rasa/dialogue_understanding/coexistence/intent_based_router.py +189 -0
  320. rasa/dialogue_understanding/coexistence/llm_based_router.py +261 -0
  321. rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
  322. rasa/dialogue_understanding/commands/__init__.py +45 -0
  323. rasa/dialogue_understanding/commands/can_not_handle_command.py +61 -0
  324. rasa/dialogue_understanding/commands/cancel_flow_command.py +116 -0
  325. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +48 -0
  326. rasa/dialogue_understanding/commands/clarify_command.py +77 -0
  327. rasa/dialogue_understanding/commands/command.py +85 -0
  328. rasa/dialogue_understanding/commands/correct_slots_command.py +288 -0
  329. rasa/dialogue_understanding/commands/error_command.py +67 -0
  330. rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
  331. rasa/dialogue_understanding/commands/handle_code_change_command.py +64 -0
  332. rasa/dialogue_understanding/commands/human_handoff_command.py +57 -0
  333. rasa/dialogue_understanding/commands/knowledge_answer_command.py +48 -0
  334. rasa/dialogue_understanding/commands/noop_command.py +45 -0
  335. rasa/dialogue_understanding/commands/set_slot_command.py +125 -0
  336. rasa/dialogue_understanding/commands/skip_question_command.py +66 -0
  337. rasa/dialogue_understanding/commands/start_flow_command.py +98 -0
  338. rasa/dialogue_understanding/generator/__init__.py +6 -0
  339. rasa/dialogue_understanding/generator/command_generator.py +257 -0
  340. rasa/dialogue_understanding/generator/command_prompt_template.jinja2 +57 -0
  341. rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
  342. rasa/dialogue_understanding/generator/flow_retrieval.py +410 -0
  343. rasa/dialogue_understanding/generator/llm_command_generator.py +637 -0
  344. rasa/dialogue_understanding/generator/nlu_command_adapter.py +157 -0
  345. rasa/dialogue_understanding/patterns/__init__.py +0 -0
  346. rasa/dialogue_understanding/patterns/cancel.py +111 -0
  347. rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
  348. rasa/dialogue_understanding/patterns/chitchat.py +37 -0
  349. rasa/dialogue_understanding/patterns/clarify.py +97 -0
  350. rasa/dialogue_understanding/patterns/code_change.py +41 -0
  351. rasa/dialogue_understanding/patterns/collect_information.py +90 -0
  352. rasa/dialogue_understanding/patterns/completed.py +40 -0
  353. rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
  354. rasa/dialogue_understanding/patterns/correction.py +278 -0
  355. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +243 -0
  356. rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
  357. rasa/dialogue_understanding/patterns/internal_error.py +47 -0
  358. rasa/dialogue_understanding/patterns/search.py +37 -0
  359. rasa/dialogue_understanding/patterns/skip_question.py +38 -0
  360. rasa/dialogue_understanding/processor/__init__.py +0 -0
  361. rasa/dialogue_understanding/processor/command_processor.py +578 -0
  362. rasa/dialogue_understanding/processor/command_processor_component.py +39 -0
  363. rasa/dialogue_understanding/stack/__init__.py +0 -0
  364. rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
  365. rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
  366. rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
  367. rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
  368. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
  369. rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
  370. rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
  371. rasa/dialogue_understanding/stack/utils.py +211 -0
  372. rasa/e2e_test/__init__.py +0 -0
  373. rasa/e2e_test/constants.py +10 -0
  374. rasa/e2e_test/e2e_test_case.py +322 -0
  375. rasa/e2e_test/e2e_test_result.py +34 -0
  376. rasa/e2e_test/e2e_test_runner.py +659 -0
  377. rasa/e2e_test/e2e_test_schema.yml +67 -0
  378. rasa/engine/__init__.py +0 -0
  379. rasa/engine/caching.py +464 -0
  380. rasa/engine/constants.py +17 -0
  381. rasa/engine/exceptions.py +14 -0
  382. rasa/engine/graph.py +625 -0
  383. rasa/engine/loader.py +36 -0
  384. rasa/engine/recipes/__init__.py +0 -0
  385. rasa/engine/recipes/config_files/default_config.yml +44 -0
  386. rasa/engine/recipes/default_components.py +99 -0
  387. rasa/engine/recipes/default_recipe.py +1252 -0
  388. rasa/engine/recipes/graph_recipe.py +79 -0
  389. rasa/engine/recipes/recipe.py +93 -0
  390. rasa/engine/runner/__init__.py +0 -0
  391. rasa/engine/runner/dask.py +256 -0
  392. rasa/engine/runner/interface.py +49 -0
  393. rasa/engine/storage/__init__.py +0 -0
  394. rasa/engine/storage/local_model_storage.py +248 -0
  395. rasa/engine/storage/resource.py +110 -0
  396. rasa/engine/storage/storage.py +203 -0
  397. rasa/engine/training/__init__.py +0 -0
  398. rasa/engine/training/components.py +176 -0
  399. rasa/engine/training/fingerprinting.py +64 -0
  400. rasa/engine/training/graph_trainer.py +256 -0
  401. rasa/engine/training/hooks.py +164 -0
  402. rasa/engine/validation.py +839 -0
  403. rasa/env.py +5 -0
  404. rasa/exceptions.py +69 -0
  405. rasa/graph_components/__init__.py +0 -0
  406. rasa/graph_components/converters/__init__.py +0 -0
  407. rasa/graph_components/converters/nlu_message_converter.py +48 -0
  408. rasa/graph_components/providers/__init__.py +0 -0
  409. rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
  410. rasa/graph_components/providers/domain_provider.py +71 -0
  411. rasa/graph_components/providers/flows_provider.py +74 -0
  412. rasa/graph_components/providers/forms_provider.py +44 -0
  413. rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
  414. rasa/graph_components/providers/responses_provider.py +44 -0
  415. rasa/graph_components/providers/rule_only_provider.py +49 -0
  416. rasa/graph_components/providers/story_graph_provider.py +43 -0
  417. rasa/graph_components/providers/training_tracker_provider.py +55 -0
  418. rasa/graph_components/validators/__init__.py +0 -0
  419. rasa/graph_components/validators/default_recipe_validator.py +552 -0
  420. rasa/graph_components/validators/finetuning_validator.py +302 -0
  421. rasa/hooks.py +113 -0
  422. rasa/jupyter.py +63 -0
  423. rasa/keys +1 -0
  424. rasa/markers/__init__.py +0 -0
  425. rasa/markers/marker.py +269 -0
  426. rasa/markers/marker_base.py +828 -0
  427. rasa/markers/upload.py +74 -0
  428. rasa/markers/validate.py +21 -0
  429. rasa/model.py +118 -0
  430. rasa/model_testing.py +457 -0
  431. rasa/model_training.py +535 -0
  432. rasa/nlu/__init__.py +7 -0
  433. rasa/nlu/classifiers/__init__.py +3 -0
  434. rasa/nlu/classifiers/classifier.py +5 -0
  435. rasa/nlu/classifiers/diet_classifier.py +1874 -0
  436. rasa/nlu/classifiers/fallback_classifier.py +192 -0
  437. rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
  438. rasa/nlu/classifiers/llm_intent_classifier.py +519 -0
  439. rasa/nlu/classifiers/logistic_regression_classifier.py +240 -0
  440. rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
  441. rasa/nlu/classifiers/regex_message_handler.py +56 -0
  442. rasa/nlu/classifiers/sklearn_intent_classifier.py +309 -0
  443. rasa/nlu/constants.py +77 -0
  444. rasa/nlu/convert.py +40 -0
  445. rasa/nlu/emulators/__init__.py +0 -0
  446. rasa/nlu/emulators/dialogflow.py +55 -0
  447. rasa/nlu/emulators/emulator.py +49 -0
  448. rasa/nlu/emulators/luis.py +86 -0
  449. rasa/nlu/emulators/no_emulator.py +10 -0
  450. rasa/nlu/emulators/wit.py +56 -0
  451. rasa/nlu/extractors/__init__.py +0 -0
  452. rasa/nlu/extractors/crf_entity_extractor.py +672 -0
  453. rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
  454. rasa/nlu/extractors/entity_synonyms.py +178 -0
  455. rasa/nlu/extractors/extractor.py +470 -0
  456. rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
  457. rasa/nlu/extractors/regex_entity_extractor.py +220 -0
  458. rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
  459. rasa/nlu/featurizers/__init__.py +0 -0
  460. rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
  461. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +449 -0
  462. rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
  463. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +772 -0
  464. rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
  465. rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
  466. rasa/nlu/featurizers/featurizer.py +89 -0
  467. rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
  468. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +840 -0
  469. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +539 -0
  470. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +269 -0
  471. rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
  472. rasa/nlu/model.py +24 -0
  473. rasa/nlu/persistor.py +240 -0
  474. rasa/nlu/run.py +27 -0
  475. rasa/nlu/selectors/__init__.py +0 -0
  476. rasa/nlu/selectors/response_selector.py +990 -0
  477. rasa/nlu/test.py +1943 -0
  478. rasa/nlu/tokenizers/__init__.py +0 -0
  479. rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
  480. rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
  481. rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
  482. rasa/nlu/tokenizers/tokenizer.py +239 -0
  483. rasa/nlu/tokenizers/whitespace_tokenizer.py +106 -0
  484. rasa/nlu/utils/__init__.py +35 -0
  485. rasa/nlu/utils/bilou_utils.py +462 -0
  486. rasa/nlu/utils/hugging_face/__init__.py +0 -0
  487. rasa/nlu/utils/hugging_face/registry.py +108 -0
  488. rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
  489. rasa/nlu/utils/mitie_utils.py +113 -0
  490. rasa/nlu/utils/pattern_utils.py +168 -0
  491. rasa/nlu/utils/spacy_utils.py +312 -0
  492. rasa/plugin.py +90 -0
  493. rasa/server.py +1536 -0
  494. rasa/shared/__init__.py +0 -0
  495. rasa/shared/constants.py +181 -0
  496. rasa/shared/core/__init__.py +0 -0
  497. rasa/shared/core/constants.py +168 -0
  498. rasa/shared/core/conversation.py +46 -0
  499. rasa/shared/core/domain.py +2106 -0
  500. rasa/shared/core/events.py +2507 -0
  501. rasa/shared/core/flows/__init__.py +7 -0
  502. rasa/shared/core/flows/flow.py +353 -0
  503. rasa/shared/core/flows/flow_step.py +146 -0
  504. rasa/shared/core/flows/flow_step_links.py +319 -0
  505. rasa/shared/core/flows/flow_step_sequence.py +70 -0
  506. rasa/shared/core/flows/flows_list.py +211 -0
  507. rasa/shared/core/flows/flows_yaml_schema.json +217 -0
  508. rasa/shared/core/flows/nlu_trigger.py +117 -0
  509. rasa/shared/core/flows/steps/__init__.py +24 -0
  510. rasa/shared/core/flows/steps/action.py +51 -0
  511. rasa/shared/core/flows/steps/call.py +64 -0
  512. rasa/shared/core/flows/steps/collect.py +112 -0
  513. rasa/shared/core/flows/steps/constants.py +5 -0
  514. rasa/shared/core/flows/steps/continuation.py +36 -0
  515. rasa/shared/core/flows/steps/end.py +22 -0
  516. rasa/shared/core/flows/steps/internal.py +44 -0
  517. rasa/shared/core/flows/steps/link.py +51 -0
  518. rasa/shared/core/flows/steps/no_operation.py +48 -0
  519. rasa/shared/core/flows/steps/set_slots.py +50 -0
  520. rasa/shared/core/flows/steps/start.py +30 -0
  521. rasa/shared/core/flows/validation.py +527 -0
  522. rasa/shared/core/flows/yaml_flows_io.py +278 -0
  523. rasa/shared/core/generator.py +907 -0
  524. rasa/shared/core/slot_mappings.py +235 -0
  525. rasa/shared/core/slots.py +647 -0
  526. rasa/shared/core/trackers.py +1159 -0
  527. rasa/shared/core/training_data/__init__.py +0 -0
  528. rasa/shared/core/training_data/loading.py +90 -0
  529. rasa/shared/core/training_data/story_reader/__init__.py +0 -0
  530. rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
  531. rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
  532. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
  533. rasa/shared/core/training_data/story_writer/__init__.py +0 -0
  534. rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
  535. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +442 -0
  536. rasa/shared/core/training_data/structures.py +838 -0
  537. rasa/shared/core/training_data/visualization.html +146 -0
  538. rasa/shared/core/training_data/visualization.py +603 -0
  539. rasa/shared/data.py +192 -0
  540. rasa/shared/engine/__init__.py +0 -0
  541. rasa/shared/engine/caching.py +26 -0
  542. rasa/shared/exceptions.py +129 -0
  543. rasa/shared/importers/__init__.py +0 -0
  544. rasa/shared/importers/importer.py +705 -0
  545. rasa/shared/importers/multi_project.py +203 -0
  546. rasa/shared/importers/rasa.py +100 -0
  547. rasa/shared/importers/utils.py +34 -0
  548. rasa/shared/nlu/__init__.py +0 -0
  549. rasa/shared/nlu/constants.py +45 -0
  550. rasa/shared/nlu/interpreter.py +10 -0
  551. rasa/shared/nlu/training_data/__init__.py +0 -0
  552. rasa/shared/nlu/training_data/entities_parser.py +209 -0
  553. rasa/shared/nlu/training_data/features.py +374 -0
  554. rasa/shared/nlu/training_data/formats/__init__.py +10 -0
  555. rasa/shared/nlu/training_data/formats/dialogflow.py +162 -0
  556. rasa/shared/nlu/training_data/formats/luis.py +87 -0
  557. rasa/shared/nlu/training_data/formats/rasa.py +135 -0
  558. rasa/shared/nlu/training_data/formats/rasa_yaml.py +605 -0
  559. rasa/shared/nlu/training_data/formats/readerwriter.py +245 -0
  560. rasa/shared/nlu/training_data/formats/wit.py +52 -0
  561. rasa/shared/nlu/training_data/loading.py +137 -0
  562. rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
  563. rasa/shared/nlu/training_data/message.py +477 -0
  564. rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
  565. rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
  566. rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
  567. rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
  568. rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
  569. rasa/shared/nlu/training_data/training_data.py +732 -0
  570. rasa/shared/nlu/training_data/util.py +223 -0
  571. rasa/shared/providers/__init__.py +0 -0
  572. rasa/shared/providers/openai/__init__.py +0 -0
  573. rasa/shared/providers/openai/clients.py +43 -0
  574. rasa/shared/providers/openai/session_handler.py +110 -0
  575. rasa/shared/utils/__init__.py +0 -0
  576. rasa/shared/utils/cli.py +72 -0
  577. rasa/shared/utils/common.py +308 -0
  578. rasa/shared/utils/constants.py +1 -0
  579. rasa/shared/utils/io.py +403 -0
  580. rasa/shared/utils/llm.py +405 -0
  581. rasa/shared/utils/pykwalify_extensions.py +26 -0
  582. rasa/shared/utils/schemas/__init__.py +0 -0
  583. rasa/shared/utils/schemas/config.yml +2 -0
  584. rasa/shared/utils/schemas/domain.yml +142 -0
  585. rasa/shared/utils/schemas/events.py +212 -0
  586. rasa/shared/utils/schemas/model_config.yml +46 -0
  587. rasa/shared/utils/schemas/stories.yml +173 -0
  588. rasa/shared/utils/yaml.py +777 -0
  589. rasa/studio/__init__.py +0 -0
  590. rasa/studio/auth.py +252 -0
  591. rasa/studio/config.py +127 -0
  592. rasa/studio/constants.py +16 -0
  593. rasa/studio/data_handler.py +352 -0
  594. rasa/studio/download.py +350 -0
  595. rasa/studio/train.py +136 -0
  596. rasa/studio/upload.py +408 -0
  597. rasa/telemetry.py +1583 -0
  598. rasa/tracing/__init__.py +0 -0
  599. rasa/tracing/config.py +338 -0
  600. rasa/tracing/constants.py +38 -0
  601. rasa/tracing/instrumentation/__init__.py +0 -0
  602. rasa/tracing/instrumentation/attribute_extractors.py +663 -0
  603. rasa/tracing/instrumentation/instrumentation.py +939 -0
  604. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +142 -0
  605. rasa/tracing/instrumentation/metrics.py +206 -0
  606. rasa/tracing/metric_instrument_provider.py +125 -0
  607. rasa/utils/__init__.py +0 -0
  608. rasa/utils/beta.py +83 -0
  609. rasa/utils/cli.py +27 -0
  610. rasa/utils/common.py +635 -0
  611. rasa/utils/converter.py +53 -0
  612. rasa/utils/endpoints.py +303 -0
  613. rasa/utils/io.py +326 -0
  614. rasa/utils/licensing.py +319 -0
  615. rasa/utils/log_utils.py +174 -0
  616. rasa/utils/mapper.py +210 -0
  617. rasa/utils/ml_utils.py +145 -0
  618. rasa/utils/plotting.py +362 -0
  619. rasa/utils/singleton.py +23 -0
  620. rasa/utils/tensorflow/__init__.py +0 -0
  621. rasa/utils/tensorflow/callback.py +112 -0
  622. rasa/utils/tensorflow/constants.py +116 -0
  623. rasa/utils/tensorflow/crf.py +492 -0
  624. rasa/utils/tensorflow/data_generator.py +440 -0
  625. rasa/utils/tensorflow/environment.py +161 -0
  626. rasa/utils/tensorflow/exceptions.py +5 -0
  627. rasa/utils/tensorflow/layers.py +1565 -0
  628. rasa/utils/tensorflow/layers_utils.py +113 -0
  629. rasa/utils/tensorflow/metrics.py +281 -0
  630. rasa/utils/tensorflow/model_data.py +991 -0
  631. rasa/utils/tensorflow/model_data_utils.py +500 -0
  632. rasa/utils/tensorflow/models.py +936 -0
  633. rasa/utils/tensorflow/rasa_layers.py +1094 -0
  634. rasa/utils/tensorflow/transformer.py +640 -0
  635. rasa/utils/tensorflow/types.py +6 -0
  636. rasa/utils/train_utils.py +572 -0
  637. rasa/utils/yaml.py +54 -0
  638. rasa/validator.py +1035 -0
  639. rasa/version.py +3 -0
  640. rasa_pro-3.8.16.dist-info/METADATA +528 -0
  641. rasa_pro-3.8.16.dist-info/NOTICE +5 -0
  642. rasa_pro-3.8.16.dist-info/RECORD +644 -0
  643. rasa_pro-3.8.16.dist-info/WHEEL +4 -0
  644. rasa_pro-3.8.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,1252 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import enum
5
+ import logging
6
+ import math
7
+ from enum import Enum
8
+ from typing import Dict, Text, Any, Tuple, Type, Optional, List, Callable, Set, Union
9
+
10
+ import dataclasses
11
+
12
+ from rasa.core.featurizers.precomputation import (
13
+ CoreFeaturizationInputConverter,
14
+ CoreFeaturizationCollector,
15
+ )
16
+ from rasa.graph_components.providers.flows_provider import FlowsProvider
17
+ from rasa.dialogue_understanding.processor.command_processor_component import (
18
+ CommandProcessorComponent,
19
+ )
20
+ from rasa.shared.exceptions import FileNotFoundException
21
+ from rasa.core.policies.ensemble import DefaultPolicyPredictionEnsemble
22
+
23
+ from rasa.engine.graph import (
24
+ GraphSchema,
25
+ GraphComponent,
26
+ SchemaNode,
27
+ GraphModelConfiguration,
28
+ )
29
+ from rasa.engine.constants import (
30
+ PLACEHOLDER_IMPORTER,
31
+ PLACEHOLDER_MESSAGE,
32
+ PLACEHOLDER_TRACKER,
33
+ PLACEHOLDER_ENDPOINTS,
34
+ )
35
+ from rasa.engine.recipes.recipe import Recipe
36
+ from rasa.engine.storage.resource import Resource
37
+ from rasa.graph_components.converters.nlu_message_converter import NLUMessageConverter
38
+ from rasa.graph_components.providers.domain_provider import DomainProvider
39
+ from rasa.graph_components.providers.forms_provider import FormsProvider
40
+ from rasa.graph_components.providers.responses_provider import ResponsesProvider
41
+ from rasa.graph_components.providers.domain_for_core_training_provider import (
42
+ DomainForCoreTrainingProvider,
43
+ )
44
+ from rasa.graph_components.providers.nlu_training_data_provider import (
45
+ NLUTrainingDataProvider,
46
+ )
47
+ from rasa.graph_components.providers.rule_only_provider import RuleOnlyDataProvider
48
+ from rasa.graph_components.providers.story_graph_provider import StoryGraphProvider
49
+ from rasa.graph_components.providers.training_tracker_provider import (
50
+ TrainingTrackerProvider,
51
+ )
52
+ import rasa.shared.constants
53
+ from rasa.shared.exceptions import RasaException, InvalidConfigException
54
+ from rasa.shared.constants import ASSISTANT_ID_KEY
55
+ from rasa.shared.data import TrainingType
56
+ from rasa.shared.utils.yaml import read_config_file
57
+
58
+ from rasa.utils.tensorflow.constants import EPOCHS
59
+ from rasa.shared.utils.common import (
60
+ class_from_module_path,
61
+ transform_collection_to_sentence,
62
+ )
63
+
64
+ logger = logging.getLogger(__name__)
65
+
66
+
67
+ DEFAULT_PREDICT_KWARGS = dict(constructor_name="load", eager=True, is_target=False)
68
+
69
+ COMMENTS_FOR_KEYS = {
70
+ "pipeline": (
71
+ f"# # No configuration for the NLU pipeline was provided. The following "
72
+ f"default pipeline was used to train your model.\n"
73
+ f"# # If you'd like to customize it, uncomment and adjust the pipeline.\n"
74
+ f"# # See {rasa.shared.constants.DOCS_URL_PIPELINE} for more information.\n"
75
+ ),
76
+ "policies": (
77
+ f"# # No configuration for policies was provided. The following default "
78
+ f"policies were used to train your model.\n"
79
+ f"# # If you'd like to customize them, uncomment and adjust the policies.\n"
80
+ f"# # See {rasa.shared.constants.DOCS_URL_POLICIES} for more information.\n"
81
+ ),
82
+ }
83
+
84
+
85
+ class DefaultV1RecipeRegisterException(RasaException):
86
+ """If you register a class which is not of type `GraphComponent`."""
87
+
88
+ pass
89
+
90
+
91
+ class DefaultV1Recipe(Recipe):
92
+ """Recipe which converts the normal model config to train and predict graph."""
93
+
94
+ @enum.unique
95
+ class ComponentType(Enum):
96
+ """Enum to categorize and place custom components correctly in the graph."""
97
+
98
+ MESSAGE_TOKENIZER = 0
99
+ MESSAGE_FEATURIZER = 1
100
+ INTENT_CLASSIFIER = 2
101
+ ENTITY_EXTRACTOR = 3
102
+ POLICY_WITHOUT_END_TO_END_SUPPORT = 4
103
+ POLICY_WITH_END_TO_END_SUPPORT = 5
104
+ MODEL_LOADER = 6
105
+ COMMAND_GENERATOR = 7
106
+ COEXISTENCE_ROUTER = 8
107
+
108
+ name = "default.v1"
109
+ _registered_components: Dict[Text, RegisteredComponent] = {} # noqa: RUF012
110
+
111
+ def __init__(self) -> None:
112
+ """Creates recipe."""
113
+ self._use_core = True
114
+ self._use_nlu = True
115
+ self._use_end_to_end = True
116
+ self._is_finetuning = False
117
+
118
+ @dataclasses.dataclass()
119
+ class RegisteredComponent:
120
+ """Describes a graph component which was registered with the decorator."""
121
+
122
+ clazz: Type[GraphComponent]
123
+ types: Set[DefaultV1Recipe.ComponentType]
124
+ is_trainable: bool
125
+ model_from: Optional[Text]
126
+
127
+ @classmethod
128
+ def register(
129
+ cls,
130
+ component_types: Union[ComponentType, List[ComponentType]],
131
+ is_trainable: bool,
132
+ model_from: Optional[Text] = None,
133
+ ) -> Callable[[Type[GraphComponent]], Type[GraphComponent]]:
134
+ """This decorator can be used to register classes with the recipe.
135
+
136
+ Args:
137
+ component_types: Describes the types of a component which are then used
138
+ to place the component in the graph.
139
+ is_trainable: `True` if the component requires training.
140
+ model_from: Can be used if this component requires a pre-loaded model
141
+ such as `SpacyNLP` or `MitieNLP`.
142
+
143
+ Returns:
144
+ The registered class.
145
+ """
146
+
147
+ def decorator(registered_class: Type[GraphComponent]) -> Type[GraphComponent]:
148
+ if not issubclass(registered_class, GraphComponent):
149
+ raise DefaultV1RecipeRegisterException(
150
+ f"Failed to register class '{registered_class.__name__}' with "
151
+ f"the recipe '{cls.name}'. The class has to be of type "
152
+ f"'{GraphComponent.__name__}'."
153
+ )
154
+
155
+ if isinstance(component_types, cls.ComponentType):
156
+ unique_types = {component_types}
157
+ else:
158
+ unique_types = set(component_types)
159
+
160
+ cls._registered_components[
161
+ registered_class.__name__
162
+ ] = cls.RegisteredComponent(
163
+ registered_class, unique_types, is_trainable, model_from
164
+ )
165
+ return registered_class
166
+
167
+ return decorator
168
+
169
+ @classmethod
170
+ def _from_registry(cls, name: Text) -> RegisteredComponent:
171
+ # Importing all the default Rasa components will automatically register them
172
+ from rasa.engine.recipes.default_components import DEFAULT_COMPONENTS # noqa
173
+
174
+ if name in cls._registered_components:
175
+ return cls._registered_components[name]
176
+
177
+ if "." in name:
178
+ clazz = class_from_module_path(name)
179
+ if clazz.__name__ in cls._registered_components:
180
+ return cls._registered_components[clazz.__name__]
181
+
182
+ raise InvalidConfigException(
183
+ f"Can't load class for name '{name}'. Please make sure to provide "
184
+ f"a valid name or module path and to register it using the "
185
+ f"'@DefaultV1Recipe.register' decorator."
186
+ )
187
+
188
+ def graph_config_for_recipe(
189
+ self,
190
+ config: Dict,
191
+ cli_parameters: Dict[Text, Any],
192
+ training_type: TrainingType = TrainingType.BOTH,
193
+ is_finetuning: bool = False,
194
+ ) -> GraphModelConfiguration:
195
+ """Converts the default config to graphs (see interface for full docstring)."""
196
+ self._use_core = (
197
+ bool(config.get("policies")) and not training_type == TrainingType.NLU
198
+ )
199
+ self._use_nlu = (
200
+ bool(config.get("pipeline")) and not training_type == TrainingType.CORE
201
+ )
202
+
203
+ if not self._use_nlu and training_type == TrainingType.NLU:
204
+ raise InvalidConfigException(
205
+ "Can't train an NLU model without a specified pipeline. Please make "
206
+ "sure to specify a valid pipeline in your configuration."
207
+ )
208
+
209
+ if not self._use_core and training_type == TrainingType.CORE:
210
+ raise InvalidConfigException(
211
+ "Can't train an Core model without policies. Please make "
212
+ "sure to specify a valid policy in your configuration."
213
+ )
214
+
215
+ self._use_end_to_end = (
216
+ self._use_nlu
217
+ and self._use_core
218
+ and training_type == TrainingType.END_TO_END
219
+ )
220
+
221
+ self._is_finetuning = is_finetuning
222
+
223
+ train_nodes, preprocessors = self._create_train_nodes(config, cli_parameters)
224
+ predict_nodes = self._create_predict_nodes(config, preprocessors, train_nodes)
225
+
226
+ core_target = "select_prediction" if self._use_core else None
227
+
228
+ from rasa.nlu.classifiers.regex_message_handler import RegexMessageHandler
229
+
230
+ return GraphModelConfiguration(
231
+ train_schema=GraphSchema(train_nodes),
232
+ predict_schema=GraphSchema(predict_nodes),
233
+ training_type=training_type,
234
+ assistant_id=config.get(ASSISTANT_ID_KEY),
235
+ language=config.get("language"),
236
+ spaces=config.get("spaces"),
237
+ core_target=core_target,
238
+ nlu_target=f"run_{RegexMessageHandler.__name__}",
239
+ )
240
+
241
+ def _create_train_nodes(
242
+ self, config: Dict[Text, Any], cli_parameters: Dict[Text, Any]
243
+ ) -> Tuple[Dict[Text, SchemaNode], List[Text]]:
244
+ from rasa.graph_components.validators.default_recipe_validator import (
245
+ DefaultV1RecipeValidator,
246
+ )
247
+ from rasa.graph_components.validators.finetuning_validator import (
248
+ FinetuningValidator,
249
+ )
250
+
251
+ train_config = copy.deepcopy(config)
252
+
253
+ train_nodes = {
254
+ "schema_validator": SchemaNode(
255
+ needs={"importer": PLACEHOLDER_IMPORTER},
256
+ uses=DefaultV1RecipeValidator,
257
+ constructor_name="create",
258
+ fn="validate",
259
+ config={},
260
+ is_input=True,
261
+ ),
262
+ "finetuning_validator": SchemaNode(
263
+ needs={"importer": "schema_validator"},
264
+ uses=FinetuningValidator,
265
+ constructor_name="load" if self._is_finetuning else "create",
266
+ fn="validate",
267
+ is_input=True,
268
+ config={"validate_core": self._use_core, "validate_nlu": self._use_nlu},
269
+ ),
270
+ }
271
+
272
+ preprocessors = []
273
+
274
+ if self._use_nlu:
275
+ preprocessors = self._add_nlu_train_nodes(
276
+ train_config, train_nodes, cli_parameters
277
+ )
278
+
279
+ if self._use_core:
280
+ self._add_core_train_nodes(
281
+ train_config, train_nodes, preprocessors, cli_parameters
282
+ )
283
+
284
+ return train_nodes, preprocessors
285
+
286
+ def _add_nlu_train_nodes(
287
+ self,
288
+ train_config: Dict[Text, Any],
289
+ train_nodes: Dict[Text, SchemaNode],
290
+ cli_parameters: Dict[Text, Any],
291
+ ) -> List[Text]:
292
+ train_nodes["flows_provider"] = SchemaNode(
293
+ needs={
294
+ "importer": "finetuning_validator",
295
+ },
296
+ uses=FlowsProvider,
297
+ constructor_name="create",
298
+ fn="provide_train",
299
+ config={},
300
+ is_target=True,
301
+ is_input=True,
302
+ )
303
+ train_nodes["domain_provider"] = SchemaNode(
304
+ needs={
305
+ "importer": "finetuning_validator",
306
+ },
307
+ uses=DomainProvider,
308
+ constructor_name="create",
309
+ fn="provide_train",
310
+ config={},
311
+ is_target=True,
312
+ is_input=True,
313
+ )
314
+ persist_nlu_data = bool(cli_parameters.get("persist_nlu_training_data"))
315
+ train_nodes["nlu_training_data_provider"] = SchemaNode(
316
+ needs={"importer": "finetuning_validator"},
317
+ uses=NLUTrainingDataProvider,
318
+ constructor_name="create",
319
+ fn="provide",
320
+ config={
321
+ "language": train_config.get("language"),
322
+ "persist": persist_nlu_data,
323
+ },
324
+ is_target=persist_nlu_data,
325
+ is_input=True,
326
+ )
327
+
328
+ last_run_node = "nlu_training_data_provider"
329
+ preprocessors: List[Text] = []
330
+
331
+ for idx, config in enumerate(train_config["pipeline"]):
332
+ component_name = config.pop("name")
333
+ component = self._from_registry(component_name)
334
+ component_name = f"{component_name}{idx}"
335
+
336
+ if (
337
+ self.ComponentType.POLICY_WITHOUT_END_TO_END_SUPPORT in component.types
338
+ or self.ComponentType.POLICY_WITH_END_TO_END_SUPPORT in component.types
339
+ ):
340
+ raise InvalidConfigException(
341
+ f"Found policy '{component_name}' in NLU pipeline. Policies should "
342
+ f"be defined in the 'policies' section of your configuration."
343
+ )
344
+ if self.ComponentType.MODEL_LOADER in component.types:
345
+ node_name = f"provide_{component_name}"
346
+ train_nodes[node_name] = SchemaNode(
347
+ needs={},
348
+ uses=component.clazz,
349
+ constructor_name="create",
350
+ fn="provide",
351
+ config=config,
352
+ )
353
+
354
+ from_resource = None
355
+ if component.is_trainable:
356
+ from_resource = self._add_nlu_train_node(
357
+ train_nodes,
358
+ component.clazz,
359
+ component_name,
360
+ last_run_node,
361
+ config,
362
+ cli_parameters,
363
+ )
364
+
365
+ if component.types.intersection(
366
+ {
367
+ self.ComponentType.MESSAGE_TOKENIZER,
368
+ self.ComponentType.MESSAGE_FEATURIZER,
369
+ }
370
+ ):
371
+ last_run_node = self._add_nlu_process_node(
372
+ train_nodes,
373
+ component.clazz,
374
+ component_name,
375
+ last_run_node,
376
+ config,
377
+ from_resource=from_resource,
378
+ )
379
+
380
+ # Remember for End-to-End-Featurization
381
+ preprocessors.append(last_run_node)
382
+
383
+ return preprocessors
384
+
385
+ def _get_needs_from_args(
386
+ self, component: Type[GraphComponent], fn_name: str
387
+ ) -> Dict[str, str]:
388
+ """Get the needed arguments from the method on the component.
389
+
390
+ Filters out arguments that are already provided by other graph
391
+ components. Does not check if the created providers are actually
392
+ part of the graph. If they aren't an error will be raised later on
393
+ when the graph is validated.
394
+
395
+ Args:
396
+ component: The component class.
397
+ fn_name: The name of the method to inspect.
398
+
399
+ Returns:
400
+ The name of the arguments which need to be provided.
401
+ """
402
+ from inspect import signature
403
+
404
+ if not hasattr(component, fn_name):
405
+ return {}
406
+
407
+ def resolver_name_from_parameter(parameter: str) -> str:
408
+ # we got a couple special cases to handle wher the parameter name
409
+ # doesn't match the provider name
410
+ if "training_trackers" == parameter:
411
+ return "training_tracker_provider"
412
+ elif "tracker" == parameter:
413
+ return PLACEHOLDER_TRACKER
414
+ elif "endpoints" == parameter:
415
+ return PLACEHOLDER_ENDPOINTS
416
+ elif "training_data" == parameter:
417
+ return "nlu_training_data_provider"
418
+ return f"{parameter}_provider"
419
+
420
+ sig = signature(getattr(component, fn_name))
421
+ parameters = {
422
+ name
423
+ for name, param in sig.parameters.items()
424
+ if param.kind == param.POSITIONAL_OR_KEYWORD
425
+ }
426
+
427
+ # filter out parameters which are already resolved in other ways
428
+ unprovided_parameters = parameters - {
429
+ "message",
430
+ "messages",
431
+ "self",
432
+ "model",
433
+ "precomputations",
434
+ }
435
+
436
+ return {
437
+ parameter: resolver_name_from_parameter(parameter)
438
+ for parameter in unprovided_parameters
439
+ }
440
+
441
+ def _add_nlu_train_node(
442
+ self,
443
+ train_nodes: Dict[Text, SchemaNode],
444
+ component: Type[GraphComponent],
445
+ component_name: Text,
446
+ last_run_node: Text,
447
+ config: Dict[Text, Any],
448
+ cli_parameters: Dict[Text, Any],
449
+ ) -> Text:
450
+ config_from_cli = self._extra_config_from_cli(cli_parameters, component, config)
451
+ needs = self._get_needs_from_args(component, "train")
452
+ needs.update(self._get_model_provider_needs(train_nodes, component))
453
+ needs["training_data"] = last_run_node
454
+
455
+ train_node_name = f"train_{component_name}"
456
+ train_nodes[train_node_name] = SchemaNode(
457
+ needs=needs,
458
+ uses=component,
459
+ constructor_name="load" if self._is_finetuning else "create",
460
+ fn="train",
461
+ config={**config, **config_from_cli},
462
+ is_target=True,
463
+ )
464
+ return train_node_name
465
+
466
+ def _extra_config_from_cli(
467
+ self,
468
+ cli_parameters: Dict[Text, Any],
469
+ component: Type[GraphComponent],
470
+ component_config: Dict[Text, Any],
471
+ ) -> Dict[Text, Any]:
472
+ from rasa.nlu.classifiers.mitie_intent_classifier import MitieIntentClassifier
473
+ from rasa.nlu.extractors.mitie_entity_extractor import MitieEntityExtractor
474
+ from rasa.nlu.classifiers.sklearn_intent_classifier import (
475
+ SklearnIntentClassifier,
476
+ )
477
+
478
+ cli_args_mapping: Dict[Type[GraphComponent], List[Text]] = {
479
+ MitieIntentClassifier: ["num_threads"],
480
+ MitieEntityExtractor: ["num_threads"],
481
+ SklearnIntentClassifier: ["num_threads"],
482
+ }
483
+
484
+ config_from_cli = {
485
+ param: cli_parameters[param]
486
+ for param in cli_args_mapping.get(component, [])
487
+ if param in cli_parameters and cli_parameters[param] is not None
488
+ }
489
+
490
+ if (
491
+ self._is_finetuning
492
+ and "finetuning_epoch_fraction" in cli_parameters
493
+ and EPOCHS in component.get_default_config()
494
+ ):
495
+ old_number_epochs = component_config.get(
496
+ EPOCHS, component.get_default_config()[EPOCHS]
497
+ )
498
+ epoch_fraction = cli_parameters["finetuning_epoch_fraction"]
499
+ epoch_fraction = epoch_fraction if epoch_fraction is not None else 1.0
500
+ config_from_cli["finetuning_epoch_fraction"] = epoch_fraction
501
+ config_from_cli[EPOCHS] = math.ceil(
502
+ old_number_epochs * float(epoch_fraction)
503
+ )
504
+
505
+ return config_from_cli
506
+
507
+ def _add_nlu_process_node(
508
+ self,
509
+ train_nodes: Dict[Text, SchemaNode],
510
+ component_class: Type[GraphComponent],
511
+ component_name: Text,
512
+ last_run_node: Text,
513
+ component_config: Dict[Text, Any],
514
+ from_resource: Optional[Text] = None,
515
+ ) -> Text:
516
+ needs = self._get_needs_from_args(component_class, "process_training_data")
517
+ needs.update(self._get_model_provider_needs(train_nodes, component_class))
518
+
519
+ if from_resource:
520
+ needs["resource"] = from_resource
521
+
522
+ needs["training_data"] = last_run_node
523
+
524
+ node_name = f"run_{component_name}"
525
+ train_nodes[node_name] = SchemaNode(
526
+ needs=needs,
527
+ uses=component_class,
528
+ constructor_name="load",
529
+ fn="process_training_data",
530
+ config=component_config,
531
+ )
532
+ return node_name
533
+
534
+ def _get_model_provider_needs(
535
+ self, nodes: Dict[Text, SchemaNode], component_class: Type[GraphComponent]
536
+ ) -> Dict[Text, Text]:
537
+ model_provider_needs = {}
538
+ component = self._from_registry(component_class.__name__)
539
+
540
+ if not component.model_from:
541
+ return {}
542
+
543
+ node_name_of_provider = next(
544
+ (
545
+ node_name
546
+ for node_name, node in nodes.items()
547
+ if node.uses.__name__ == component.model_from
548
+ ),
549
+ None,
550
+ )
551
+ if node_name_of_provider:
552
+ model_provider_needs["model"] = node_name_of_provider
553
+
554
+ return model_provider_needs
555
+
556
+ def _add_core_train_nodes(
557
+ self,
558
+ train_config: Dict[Text, Any],
559
+ train_nodes: Dict[Text, SchemaNode],
560
+ preprocessors: List[Text],
561
+ cli_parameters: Dict[Text, Any],
562
+ ) -> None:
563
+ train_nodes["domain_provider"] = SchemaNode(
564
+ needs={"importer": "finetuning_validator"},
565
+ uses=DomainProvider,
566
+ constructor_name="create",
567
+ fn="provide_train",
568
+ config={},
569
+ is_target=True,
570
+ is_input=True,
571
+ )
572
+ train_nodes["domain_for_core_training_provider"] = SchemaNode(
573
+ needs={"domain": "domain_provider"},
574
+ uses=DomainForCoreTrainingProvider,
575
+ constructor_name="create",
576
+ fn="provide",
577
+ config={},
578
+ is_input=True,
579
+ )
580
+ train_nodes["forms_provider"] = SchemaNode(
581
+ needs={"domain": "domain_provider"},
582
+ uses=FormsProvider,
583
+ constructor_name="create",
584
+ fn="provide",
585
+ config={},
586
+ is_input=True,
587
+ )
588
+ train_nodes["responses_provider"] = SchemaNode(
589
+ needs={"domain": "domain_provider"},
590
+ uses=ResponsesProvider,
591
+ constructor_name="create",
592
+ fn="provide",
593
+ config={},
594
+ is_input=True,
595
+ )
596
+ train_nodes["story_graph_provider"] = SchemaNode(
597
+ needs={"importer": "finetuning_validator"},
598
+ uses=StoryGraphProvider,
599
+ constructor_name="create",
600
+ fn="provide",
601
+ config={"exclusion_percentage": cli_parameters.get("exclusion_percentage")},
602
+ is_input=True,
603
+ )
604
+ train_nodes["flows_provider"] = SchemaNode(
605
+ needs={
606
+ "importer": "finetuning_validator",
607
+ },
608
+ uses=FlowsProvider,
609
+ constructor_name="create",
610
+ fn="provide_train",
611
+ config={},
612
+ is_target=True,
613
+ is_input=True,
614
+ )
615
+ train_nodes["training_tracker_provider"] = SchemaNode(
616
+ needs={
617
+ "story_graph": "story_graph_provider",
618
+ "domain": "domain_for_core_training_provider",
619
+ },
620
+ uses=TrainingTrackerProvider,
621
+ constructor_name="create",
622
+ fn="provide",
623
+ config={
624
+ param: cli_parameters[param]
625
+ for param in ["debug_plots", "augmentation_factor"]
626
+ if param in cli_parameters
627
+ },
628
+ )
629
+
630
+ policy_with_end_to_end_support_used = False
631
+ for idx, config in enumerate(train_config["policies"]):
632
+ component_name = config.pop("name")
633
+ component = self._from_registry(component_name)
634
+
635
+ extra_config_from_cli = self._extra_config_from_cli(
636
+ cli_parameters, component.clazz, config
637
+ )
638
+
639
+ requires_end_to_end_data = self._use_end_to_end and (
640
+ self.ComponentType.POLICY_WITH_END_TO_END_SUPPORT in component.types
641
+ )
642
+ policy_with_end_to_end_support_used = (
643
+ policy_with_end_to_end_support_used or requires_end_to_end_data
644
+ )
645
+
646
+ needs = self._get_needs_from_args(component.clazz, "train")
647
+ if requires_end_to_end_data:
648
+ needs["precomputations"] = "end_to_end_features_provider"
649
+ # during core training we use a stripped down version of the domain
650
+ needs["domain"] = "domain_for_core_training_provider"
651
+ train_nodes[f"train_{component_name}{idx}"] = SchemaNode(
652
+ needs=needs,
653
+ uses=component.clazz,
654
+ constructor_name="load" if self._is_finetuning else "create",
655
+ fn="train",
656
+ is_target=True,
657
+ config={**config, **extra_config_from_cli},
658
+ )
659
+
660
+ if self._use_end_to_end and policy_with_end_to_end_support_used:
661
+ self._add_end_to_end_features_for_training(preprocessors, train_nodes)
662
+
663
+ def _add_end_to_end_features_for_training(
664
+ self, preprocessors: List[Text], train_nodes: Dict[Text, SchemaNode]
665
+ ) -> None:
666
+ train_nodes["story_to_nlu_training_data_converter"] = SchemaNode(
667
+ needs={
668
+ "story_graph": "story_graph_provider",
669
+ "domain": "domain_for_core_training_provider",
670
+ },
671
+ uses=CoreFeaturizationInputConverter,
672
+ constructor_name="create",
673
+ fn="convert_for_training",
674
+ config={},
675
+ is_input=True,
676
+ )
677
+
678
+ last_node_name = "story_to_nlu_training_data_converter"
679
+ for preprocessor in preprocessors:
680
+ node = copy.deepcopy(train_nodes[preprocessor])
681
+ node.needs["training_data"] = last_node_name
682
+
683
+ node_name = f"e2e_{preprocessor}"
684
+ train_nodes[node_name] = node
685
+ last_node_name = node_name
686
+
687
+ node_with_e2e_features = "end_to_end_features_provider"
688
+ train_nodes[node_with_e2e_features] = SchemaNode(
689
+ needs={"messages": last_node_name},
690
+ uses=CoreFeaturizationCollector,
691
+ constructor_name="create",
692
+ fn="collect",
693
+ config={},
694
+ )
695
+
696
+ def _create_predict_nodes(
697
+ self,
698
+ config: Dict[Text, SchemaNode],
699
+ preprocessors: List[Text],
700
+ train_nodes: Dict[Text, SchemaNode],
701
+ ) -> Dict[Text, SchemaNode]:
702
+
703
+ predict_config = copy.deepcopy(config)
704
+ predict_nodes = {}
705
+
706
+ from rasa.nlu.classifiers.regex_message_handler import RegexMessageHandler
707
+
708
+ predict_nodes["nlu_message_converter"] = SchemaNode(
709
+ **DEFAULT_PREDICT_KWARGS,
710
+ needs={"messages": PLACEHOLDER_MESSAGE},
711
+ uses=NLUMessageConverter,
712
+ fn="convert_user_message",
713
+ config={},
714
+ )
715
+
716
+ last_run_nlu_node = "nlu_message_converter"
717
+
718
+ if self._use_nlu:
719
+ last_run_nlu_node = self._add_nlu_predict_nodes(
720
+ last_run_nlu_node, predict_config, predict_nodes, train_nodes
721
+ )
722
+
723
+ domain_needs = {}
724
+ if self._use_core:
725
+ domain_needs["domain"] = "domain_provider"
726
+
727
+ regex_handler_node_name = f"run_{RegexMessageHandler.__name__}"
728
+ predict_nodes[regex_handler_node_name] = SchemaNode(
729
+ **DEFAULT_PREDICT_KWARGS,
730
+ needs={"messages": last_run_nlu_node, **domain_needs},
731
+ uses=RegexMessageHandler,
732
+ fn="process",
733
+ config={},
734
+ )
735
+
736
+ if self._use_core:
737
+ self._add_core_predict_nodes(
738
+ predict_config, predict_nodes, train_nodes, preprocessors
739
+ )
740
+
741
+ return predict_nodes
742
+
743
+ def _add_nlu_predict_nodes(
744
+ self,
745
+ last_run_node: Text,
746
+ predict_config: Dict[Text, Any],
747
+ predict_nodes: Dict[Text, SchemaNode],
748
+ train_nodes: Dict[Text, SchemaNode],
749
+ ) -> Text:
750
+ predict_nodes["flows_provider"] = SchemaNode(
751
+ **DEFAULT_PREDICT_KWARGS,
752
+ needs={},
753
+ uses=FlowsProvider,
754
+ fn="provide_inference",
755
+ config={},
756
+ resource=Resource("flows_provider"),
757
+ )
758
+ predict_nodes["domain_provider"] = SchemaNode(
759
+ **DEFAULT_PREDICT_KWARGS,
760
+ needs={},
761
+ uses=DomainProvider,
762
+ fn="provide_inference",
763
+ config={},
764
+ resource=Resource("domain_provider"),
765
+ )
766
+
767
+ for idx, config in enumerate(predict_config["pipeline"]):
768
+ component_name = config.pop("name")
769
+ component = self._from_registry(component_name)
770
+ component_name = f"{component_name}{idx}"
771
+ if self.ComponentType.MODEL_LOADER in component.types:
772
+ predict_nodes[f"provide_{component_name}"] = SchemaNode(
773
+ **DEFAULT_PREDICT_KWARGS,
774
+ needs={},
775
+ uses=component.clazz,
776
+ fn="provide",
777
+ config=config,
778
+ )
779
+
780
+ if component.types.intersection(
781
+ {
782
+ self.ComponentType.MESSAGE_TOKENIZER,
783
+ self.ComponentType.MESSAGE_FEATURIZER,
784
+ }
785
+ ):
786
+ last_run_node = self._add_nlu_predict_node_from_train(
787
+ predict_nodes,
788
+ component_name,
789
+ train_nodes,
790
+ last_run_node,
791
+ config,
792
+ from_resource=component.is_trainable,
793
+ )
794
+ elif component.types.intersection(
795
+ {
796
+ self.ComponentType.INTENT_CLASSIFIER,
797
+ self.ComponentType.ENTITY_EXTRACTOR,
798
+ self.ComponentType.COMMAND_GENERATOR,
799
+ self.ComponentType.COEXISTENCE_ROUTER,
800
+ }
801
+ ):
802
+ if component.is_trainable:
803
+ last_run_node = self._add_nlu_predict_node_from_train(
804
+ predict_nodes,
805
+ component_name,
806
+ train_nodes,
807
+ last_run_node,
808
+ config,
809
+ from_resource=component.is_trainable,
810
+ )
811
+ else:
812
+ new_node = SchemaNode(
813
+ needs={"messages": last_run_node},
814
+ uses=component.clazz,
815
+ constructor_name="create",
816
+ fn="process",
817
+ config=config,
818
+ )
819
+
820
+ last_run_node = self._add_nlu_predict_node(
821
+ predict_nodes, new_node, component_name, last_run_node
822
+ )
823
+
824
+ return last_run_node
825
+
826
+ def _add_nlu_predict_node_from_train(
827
+ self,
828
+ predict_nodes: Dict[Text, SchemaNode],
829
+ node_name: Text,
830
+ train_nodes: Dict[Text, SchemaNode],
831
+ last_run_node: Text,
832
+ item_config: Dict[Text, Any],
833
+ from_resource: bool = False,
834
+ ) -> Text:
835
+ train_node_name = f"run_{node_name}"
836
+ resource = None
837
+ if from_resource:
838
+ train_node_name = f"train_{node_name}"
839
+ resource = Resource(train_node_name)
840
+
841
+ return self._add_nlu_predict_node(
842
+ predict_nodes,
843
+ dataclasses.replace(
844
+ train_nodes[train_node_name], resource=resource, config=item_config
845
+ ),
846
+ node_name,
847
+ last_run_node,
848
+ )
849
+
850
+ def _add_nlu_predict_node(
851
+ self,
852
+ predict_nodes: Dict[Text, SchemaNode],
853
+ node: SchemaNode,
854
+ component_name: Text,
855
+ last_run_node: Text,
856
+ ) -> Text:
857
+ node_name = f"run_{component_name}"
858
+
859
+ needs = self._get_needs_from_args(node.uses, "process")
860
+ needs.update(self._get_model_provider_needs(predict_nodes, node.uses))
861
+ needs["messages"] = last_run_node
862
+ predict_nodes[node_name] = dataclasses.replace(
863
+ node,
864
+ needs=needs,
865
+ fn="process",
866
+ **DEFAULT_PREDICT_KWARGS,
867
+ )
868
+
869
+ return node_name
870
+
871
+ def _add_core_predict_nodes(
872
+ self,
873
+ predict_config: Dict[Text, Any],
874
+ predict_nodes: Dict[Text, SchemaNode],
875
+ train_nodes: Dict[Text, SchemaNode],
876
+ preprocessors: List[Text],
877
+ ) -> None:
878
+ predict_nodes["domain_provider"] = SchemaNode(
879
+ **DEFAULT_PREDICT_KWARGS,
880
+ needs={},
881
+ uses=DomainProvider,
882
+ fn="provide_inference",
883
+ config={},
884
+ resource=Resource("domain_provider"),
885
+ )
886
+ predict_nodes["flows_provider"] = SchemaNode(
887
+ **DEFAULT_PREDICT_KWARGS,
888
+ needs={},
889
+ uses=FlowsProvider,
890
+ fn="provide_inference",
891
+ config={},
892
+ resource=Resource("flows_provider"),
893
+ )
894
+
895
+ node_with_e2e_features = None
896
+
897
+ if "end_to_end_features_provider" in train_nodes:
898
+ node_with_e2e_features = self._add_end_to_end_features_for_inference(
899
+ predict_nodes, preprocessors
900
+ )
901
+
902
+ predict_nodes["command_processor"] = SchemaNode(
903
+ **DEFAULT_PREDICT_KWARGS,
904
+ needs=self._get_needs_from_args(
905
+ CommandProcessorComponent, "execute_commands"
906
+ ),
907
+ uses=CommandProcessorComponent,
908
+ fn="execute_commands",
909
+ config={},
910
+ resource=Resource("command_processor"),
911
+ )
912
+
913
+ rule_policy_resource = None
914
+ policies: List[Text] = []
915
+
916
+ for idx, config in enumerate(predict_config["policies"]):
917
+ component_name = config.pop("name")
918
+ component = self._from_registry(component_name)
919
+
920
+ train_node_name = f"train_{component_name}{idx}"
921
+ node_name = f"run_{component_name}{idx}"
922
+
923
+ from rasa.core.policies.rule_policy import RulePolicy
924
+
925
+ if issubclass(component.clazz, RulePolicy) and not rule_policy_resource:
926
+ rule_policy_resource = train_node_name
927
+
928
+ needs = self._get_needs_from_args(
929
+ train_nodes[train_node_name].uses, "predict_action_probabilities"
930
+ )
931
+ if (
932
+ self.ComponentType.POLICY_WITH_END_TO_END_SUPPORT in component.types
933
+ and node_with_e2e_features
934
+ ):
935
+ needs["precomputations"] = node_with_e2e_features
936
+
937
+ predict_nodes[node_name] = dataclasses.replace(
938
+ train_nodes[train_node_name],
939
+ **DEFAULT_PREDICT_KWARGS,
940
+ needs=needs,
941
+ fn="predict_action_probabilities",
942
+ resource=Resource(train_node_name),
943
+ )
944
+ policies.append(node_name)
945
+
946
+ predict_nodes["rule_only_data_provider"] = SchemaNode(
947
+ **DEFAULT_PREDICT_KWARGS,
948
+ needs={},
949
+ uses=RuleOnlyDataProvider,
950
+ fn="provide",
951
+ config={},
952
+ resource=Resource(rule_policy_resource) if rule_policy_resource else None,
953
+ )
954
+
955
+ predict_nodes["select_prediction"] = SchemaNode(
956
+ **DEFAULT_PREDICT_KWARGS,
957
+ needs={
958
+ **{f"policy{idx}": name for idx, name in enumerate(policies)},
959
+ "domain": "domain_provider",
960
+ "tracker": PLACEHOLDER_TRACKER,
961
+ },
962
+ uses=DefaultPolicyPredictionEnsemble,
963
+ fn="combine_predictions_from_kwargs",
964
+ config={},
965
+ )
966
+
967
+ def _add_end_to_end_features_for_inference(
968
+ self, predict_nodes: Dict[Text, SchemaNode], preprocessors: List[Text]
969
+ ) -> Text:
970
+ predict_nodes["tracker_to_message_converter"] = SchemaNode(
971
+ **DEFAULT_PREDICT_KWARGS,
972
+ needs={"tracker": PLACEHOLDER_TRACKER},
973
+ uses=CoreFeaturizationInputConverter,
974
+ fn="convert_for_inference",
975
+ config={},
976
+ )
977
+
978
+ last_node_name = "tracker_to_message_converter"
979
+ for preprocessor in preprocessors:
980
+ node = dataclasses.replace(
981
+ predict_nodes[preprocessor], needs={"messages": last_node_name}
982
+ )
983
+
984
+ node_name = f"e2e_{preprocessor}"
985
+ predict_nodes[node_name] = node
986
+ last_node_name = node_name
987
+
988
+ node_with_e2e_features = "end_to_end_features_provider"
989
+ predict_nodes[node_with_e2e_features] = SchemaNode(
990
+ **DEFAULT_PREDICT_KWARGS,
991
+ needs={"messages": last_node_name},
992
+ uses=CoreFeaturizationCollector,
993
+ fn="collect",
994
+ config={},
995
+ )
996
+ return node_with_e2e_features
997
+
998
+ @staticmethod
999
+ def auto_configure(
1000
+ config_file_path: Optional[Text],
1001
+ config: Dict,
1002
+ training_type: Optional[TrainingType] = TrainingType.BOTH,
1003
+ ) -> Tuple[Dict[Text, Any], Set[str], Set[str]]:
1004
+ """Determine configuration from auto-filled configuration file.
1005
+
1006
+ Keys that are provided and have a value in the file are kept. Keys that are not
1007
+ provided are configured automatically.
1008
+
1009
+ Note that this needs to be called explicitly; ie. we cannot
1010
+ auto-configure automatically from importers because importers are not
1011
+ allowed to access code outside of `rasa.shared`.
1012
+
1013
+ Args:
1014
+ config_file_path: The path to the configuration file.
1015
+ config: Configuration in dictionary format.
1016
+ training_type: Optional training type to auto-configure. By default
1017
+ both core and NLU will be auto-configured.
1018
+ """
1019
+ missing_keys = DefaultV1Recipe._get_missing_config_keys(config, training_type)
1020
+ keys_to_configure = DefaultV1Recipe._get_unspecified_autoconfigurable_keys(
1021
+ config, training_type
1022
+ )
1023
+
1024
+ if keys_to_configure:
1025
+ config = DefaultV1Recipe.complete_config(config, keys_to_configure)
1026
+ DefaultV1Recipe._dump_config(
1027
+ config, config_file_path, missing_keys, keys_to_configure, training_type
1028
+ )
1029
+
1030
+ return config, missing_keys, keys_to_configure
1031
+
1032
+ @staticmethod
1033
+ def _get_unspecified_autoconfigurable_keys(
1034
+ config: Dict[Text, Any],
1035
+ training_type: Optional[TrainingType] = TrainingType.BOTH,
1036
+ ) -> Set[Text]:
1037
+ if training_type == TrainingType.NLU:
1038
+ all_keys = rasa.shared.constants.CONFIG_AUTOCONFIGURABLE_KEYS_NLU
1039
+ elif training_type == TrainingType.CORE:
1040
+ all_keys = rasa.shared.constants.CONFIG_AUTOCONFIGURABLE_KEYS_CORE
1041
+ else:
1042
+ all_keys = rasa.shared.constants.CONFIG_AUTOCONFIGURABLE_KEYS
1043
+
1044
+ return {k for k in all_keys if config.get(k) is None}
1045
+
1046
+ @staticmethod
1047
+ def _get_missing_config_keys(
1048
+ config: Dict[Text, Any],
1049
+ training_type: Optional[TrainingType] = TrainingType.BOTH,
1050
+ ) -> Set[Text]:
1051
+ if training_type == TrainingType.NLU:
1052
+ all_keys = rasa.shared.constants.CONFIG_KEYS_NLU
1053
+ elif training_type == TrainingType.CORE:
1054
+ all_keys = rasa.shared.constants.CONFIG_KEYS_CORE
1055
+ else:
1056
+ all_keys = rasa.shared.constants.CONFIG_KEYS
1057
+
1058
+ return {k for k in all_keys if k not in config.keys()}
1059
+
1060
+ @staticmethod
1061
+ def complete_config(
1062
+ config: Dict[Text, Any], keys_to_configure: Set[Text]
1063
+ ) -> Dict[Text, Any]:
1064
+ """Complete a config by adding automatic configuration for the specified keys.
1065
+
1066
+ Args:
1067
+ config: The provided configuration.
1068
+ keys_to_configure: Keys to be configured automatically (e.g. `policies`).
1069
+
1070
+ Returns:
1071
+ The resulting configuration including both the provided and
1072
+ the automatically configured keys.
1073
+ """
1074
+ import importlib_resources
1075
+
1076
+ if keys_to_configure:
1077
+ logger.debug(
1078
+ f"The provided configuration does not contain the key(s) "
1079
+ f"{transform_collection_to_sentence(keys_to_configure)}. "
1080
+ f"Values will be provided from the default configuration."
1081
+ )
1082
+
1083
+ default_config_file = str(
1084
+ importlib_resources.files(__name__)
1085
+ .joinpath("config_files")
1086
+ .joinpath("default_config.yml")
1087
+ )
1088
+ default_config = read_config_file(default_config_file)
1089
+
1090
+ config = copy.deepcopy(config)
1091
+ for key in keys_to_configure:
1092
+ config[key] = default_config[key]
1093
+
1094
+ return config
1095
+
1096
+ @staticmethod
1097
+ def _dump_config(
1098
+ config: Dict[Text, Any],
1099
+ config_file_path: Text,
1100
+ missing_keys: Set[Text],
1101
+ auto_configured_keys: Set[Text],
1102
+ training_type: Optional[TrainingType] = TrainingType.BOTH,
1103
+ ) -> None:
1104
+ """Dump the automatically configured keys into the config file.
1105
+
1106
+ The configuration provided in the file is kept as it is (preserving the order of
1107
+ keys and comments).
1108
+ For keys that were automatically configured, an explanatory
1109
+ comment is added and the automatically chosen configuration is
1110
+ added commented-out.
1111
+ If there are already blocks with comments from a previous auto
1112
+ configuration run, they are replaced with the new auto
1113
+ configuration.
1114
+
1115
+ Args:
1116
+ config: The configuration including the automatically configured keys.
1117
+ config_file_path: The file into which the configuration should be dumped.
1118
+ missing_keys: Keys that need to be added to the config file.
1119
+ auto_configured_keys: Keys for which a commented out auto
1120
+ configuration section needs to be added to the config file.
1121
+ training_type: NLU, CORE or BOTH depending on which is trained.
1122
+ """
1123
+ config_as_expected = DefaultV1Recipe._is_config_file_as_expected(
1124
+ config_file_path, missing_keys, auto_configured_keys, training_type
1125
+ )
1126
+ if not config_as_expected:
1127
+ rasa.shared.utils.cli.print_error(
1128
+ f"The configuration file at '{config_file_path}' has been removed or "
1129
+ f"modified while the automatic configuration was running. The current "
1130
+ f"configuration will therefore not be dumped to the file. If you want "
1131
+ f"your model to use the configuration provided in "
1132
+ f"'{config_file_path}' you need to re-run training."
1133
+ )
1134
+ return
1135
+
1136
+ DefaultV1Recipe._add_missing_config_keys_to_file(config_file_path, missing_keys)
1137
+
1138
+ autoconfig_lines = DefaultV1Recipe._get_commented_out_autoconfig_lines(
1139
+ config, auto_configured_keys
1140
+ )
1141
+
1142
+ current_config_content = rasa.shared.utils.io.read_file(config_file_path)
1143
+ current_config_lines = current_config_content.splitlines(keepends=True)
1144
+
1145
+ updated_lines = DefaultV1Recipe._get_lines_including_autoconfig(
1146
+ current_config_lines, autoconfig_lines
1147
+ )
1148
+
1149
+ rasa.shared.utils.io.write_text_file("".join(updated_lines), config_file_path)
1150
+
1151
+ auto_configured_keys_text = transform_collection_to_sentence(
1152
+ auto_configured_keys
1153
+ )
1154
+ rasa.shared.utils.cli.print_info(
1155
+ f"The configuration for {auto_configured_keys_text} "
1156
+ f"was chosen automatically. "
1157
+ f"It was written into the config file at '{config_file_path}'."
1158
+ )
1159
+
1160
+ @staticmethod
1161
+ def _is_config_file_as_expected(
1162
+ config_file_path: Text,
1163
+ missing_keys: Set[Text],
1164
+ auto_configured_keys: Set[Text],
1165
+ training_type: Optional[TrainingType] = TrainingType.BOTH,
1166
+ ) -> bool:
1167
+ try:
1168
+ content = read_config_file(config_file_path)
1169
+ except FileNotFoundException:
1170
+ content = {}
1171
+
1172
+ return (
1173
+ bool(content)
1174
+ and missing_keys
1175
+ == DefaultV1Recipe._get_missing_config_keys(content, training_type)
1176
+ and auto_configured_keys
1177
+ == DefaultV1Recipe._get_unspecified_autoconfigurable_keys(
1178
+ content, training_type
1179
+ )
1180
+ )
1181
+
1182
+ @staticmethod
1183
+ def _add_missing_config_keys_to_file(
1184
+ config_file_path: Text, missing_keys: Set[Text]
1185
+ ) -> None:
1186
+ if not missing_keys:
1187
+ return
1188
+ with open(
1189
+ config_file_path, "a", encoding=rasa.shared.utils.io.DEFAULT_ENCODING
1190
+ ) as f:
1191
+ for key in missing_keys:
1192
+ f.write(f"{key}:\n")
1193
+
1194
+ @staticmethod
1195
+ def _get_lines_including_autoconfig(
1196
+ lines: List[Text], autoconfig_lines: Dict[Text, List[Text]]
1197
+ ) -> List[Text]:
1198
+ auto_configured_keys = autoconfig_lines.keys()
1199
+
1200
+ lines_with_autoconfig = []
1201
+ remove_comments_until_next_uncommented_line = False
1202
+ for line in lines:
1203
+ insert_section = None
1204
+
1205
+ # remove old auto configuration
1206
+ if remove_comments_until_next_uncommented_line:
1207
+ if line.startswith("#"):
1208
+ continue
1209
+ remove_comments_until_next_uncommented_line = False
1210
+
1211
+ # add an explanatory comment to autoconfigured sections
1212
+ for key in auto_configured_keys:
1213
+ if line.startswith(f"{key}:"): # start of next auto-section
1214
+ line = line + COMMENTS_FOR_KEYS[key]
1215
+ insert_section = key
1216
+ remove_comments_until_next_uncommented_line = True
1217
+
1218
+ lines_with_autoconfig.append(line)
1219
+
1220
+ if not insert_section:
1221
+ continue
1222
+
1223
+ # add the autoconfiguration (commented out)
1224
+ lines_with_autoconfig += autoconfig_lines[insert_section]
1225
+
1226
+ return lines_with_autoconfig
1227
+
1228
+ @staticmethod
1229
+ def _get_commented_out_autoconfig_lines(
1230
+ config: Dict[Text, Any], auto_configured_keys: Set[Text]
1231
+ ) -> Dict[Text, List[Text]]:
1232
+ import ruamel.yaml
1233
+ import ruamel.yaml.compat
1234
+
1235
+ yaml_parser = ruamel.yaml.YAML()
1236
+ yaml_parser.indent(mapping=2, sequence=4, offset=2)
1237
+
1238
+ autoconfig_lines = {}
1239
+
1240
+ for key in auto_configured_keys:
1241
+ stream = ruamel.yaml.compat.StringIO()
1242
+ yaml_parser.dump(config.get(key), stream)
1243
+ dump = stream.getvalue()
1244
+
1245
+ lines = dump.split("\n")
1246
+ if not lines[-1]:
1247
+ lines = lines[:-1] # yaml dump adds an empty line at the end
1248
+ lines = [f"# {line}\n" for line in lines]
1249
+
1250
+ autoconfig_lines[key] = lines
1251
+
1252
+ return autoconfig_lines