rasa-pro 3.8.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (644)
  1. README.md +380 -0
  2. rasa/__init__.py +10 -0
  3. rasa/__main__.py +151 -0
  4. rasa/anonymization/__init__.py +2 -0
  5. rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
  6. rasa/anonymization/anonymization_pipeline.py +287 -0
  7. rasa/anonymization/anonymization_rule_executor.py +260 -0
  8. rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
  9. rasa/anonymization/schemas/config.yml +47 -0
  10. rasa/anonymization/utils.py +117 -0
  11. rasa/api.py +146 -0
  12. rasa/cli/__init__.py +5 -0
  13. rasa/cli/arguments/__init__.py +0 -0
  14. rasa/cli/arguments/data.py +81 -0
  15. rasa/cli/arguments/default_arguments.py +165 -0
  16. rasa/cli/arguments/evaluate.py +65 -0
  17. rasa/cli/arguments/export.py +51 -0
  18. rasa/cli/arguments/interactive.py +74 -0
  19. rasa/cli/arguments/run.py +204 -0
  20. rasa/cli/arguments/shell.py +13 -0
  21. rasa/cli/arguments/test.py +211 -0
  22. rasa/cli/arguments/train.py +263 -0
  23. rasa/cli/arguments/visualize.py +34 -0
  24. rasa/cli/arguments/x.py +30 -0
  25. rasa/cli/data.py +292 -0
  26. rasa/cli/e2e_test.py +566 -0
  27. rasa/cli/evaluate.py +222 -0
  28. rasa/cli/export.py +251 -0
  29. rasa/cli/inspect.py +63 -0
  30. rasa/cli/interactive.py +164 -0
  31. rasa/cli/license.py +65 -0
  32. rasa/cli/markers.py +78 -0
  33. rasa/cli/project_templates/__init__.py +0 -0
  34. rasa/cli/project_templates/calm/actions/__init__.py +0 -0
  35. rasa/cli/project_templates/calm/actions/action_template.py +27 -0
  36. rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
  37. rasa/cli/project_templates/calm/actions/db.py +57 -0
  38. rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
  39. rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
  40. rasa/cli/project_templates/calm/config.yml +12 -0
  41. rasa/cli/project_templates/calm/credentials.yml +33 -0
  42. rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
  43. rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
  44. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
  45. rasa/cli/project_templates/calm/db/contacts.json +10 -0
  46. rasa/cli/project_templates/calm/domain/add_contact.yml +33 -0
  47. rasa/cli/project_templates/calm/domain/list_contacts.yml +14 -0
  48. rasa/cli/project_templates/calm/domain/remove_contact.yml +31 -0
  49. rasa/cli/project_templates/calm/domain/shared.yml +5 -0
  50. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
  51. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
  52. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
  53. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
  54. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
  55. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
  56. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
  57. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
  58. rasa/cli/project_templates/calm/endpoints.yml +45 -0
  59. rasa/cli/project_templates/default/actions/__init__.py +0 -0
  60. rasa/cli/project_templates/default/actions/actions.py +27 -0
  61. rasa/cli/project_templates/default/config.yml +44 -0
  62. rasa/cli/project_templates/default/credentials.yml +33 -0
  63. rasa/cli/project_templates/default/data/nlu.yml +91 -0
  64. rasa/cli/project_templates/default/data/rules.yml +13 -0
  65. rasa/cli/project_templates/default/data/stories.yml +30 -0
  66. rasa/cli/project_templates/default/domain.yml +34 -0
  67. rasa/cli/project_templates/default/endpoints.yml +42 -0
  68. rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
  69. rasa/cli/project_templates/tutorial/actions.py +22 -0
  70. rasa/cli/project_templates/tutorial/config.yml +11 -0
  71. rasa/cli/project_templates/tutorial/credentials.yml +33 -0
  72. rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
  73. rasa/cli/project_templates/tutorial/domain.yml +17 -0
  74. rasa/cli/project_templates/tutorial/endpoints.yml +45 -0
  75. rasa/cli/run.py +136 -0
  76. rasa/cli/scaffold.py +268 -0
  77. rasa/cli/shell.py +141 -0
  78. rasa/cli/studio/__init__.py +0 -0
  79. rasa/cli/studio/download.py +51 -0
  80. rasa/cli/studio/studio.py +110 -0
  81. rasa/cli/studio/train.py +59 -0
  82. rasa/cli/studio/upload.py +85 -0
  83. rasa/cli/telemetry.py +90 -0
  84. rasa/cli/test.py +280 -0
  85. rasa/cli/train.py +260 -0
  86. rasa/cli/utils.py +453 -0
  87. rasa/cli/visualize.py +40 -0
  88. rasa/cli/x.py +205 -0
  89. rasa/constants.py +37 -0
  90. rasa/core/__init__.py +17 -0
  91. rasa/core/actions/__init__.py +0 -0
  92. rasa/core/actions/action.py +1450 -0
  93. rasa/core/actions/action_clean_stack.py +59 -0
  94. rasa/core/actions/action_run_slot_rejections.py +207 -0
  95. rasa/core/actions/action_trigger_chitchat.py +31 -0
  96. rasa/core/actions/action_trigger_flow.py +109 -0
  97. rasa/core/actions/action_trigger_search.py +31 -0
  98. rasa/core/actions/constants.py +2 -0
  99. rasa/core/actions/forms.py +737 -0
  100. rasa/core/actions/loops.py +111 -0
  101. rasa/core/actions/two_stage_fallback.py +186 -0
  102. rasa/core/agent.py +557 -0
  103. rasa/core/auth_retry_tracker_store.py +122 -0
  104. rasa/core/brokers/__init__.py +0 -0
  105. rasa/core/brokers/broker.py +126 -0
  106. rasa/core/brokers/file.py +58 -0
  107. rasa/core/brokers/kafka.py +322 -0
  108. rasa/core/brokers/pika.py +387 -0
  109. rasa/core/brokers/sql.py +86 -0
  110. rasa/core/channels/__init__.py +55 -0
  111. rasa/core/channels/audiocodes.py +463 -0
  112. rasa/core/channels/botframework.py +339 -0
  113. rasa/core/channels/callback.py +85 -0
  114. rasa/core/channels/channel.py +419 -0
  115. rasa/core/channels/console.py +243 -0
  116. rasa/core/channels/development_inspector.py +93 -0
  117. rasa/core/channels/facebook.py +422 -0
  118. rasa/core/channels/hangouts.py +335 -0
  119. rasa/core/channels/inspector/.eslintrc.cjs +25 -0
  120. rasa/core/channels/inspector/.gitignore +23 -0
  121. rasa/core/channels/inspector/README.md +54 -0
  122. rasa/core/channels/inspector/assets/favicon.ico +0 -0
  123. rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
  124. rasa/core/channels/inspector/custom.d.ts +3 -0
  125. rasa/core/channels/inspector/dist/assets/arc-5623b6dc.js +1 -0
  126. rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
  127. rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-685c106a.js +10 -0
  128. rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-8cbed007.js +2 -0
  129. rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-5889cf12.js +2 -0
  130. rasa/core/channels/inspector/dist/assets/createText-62fc7601-24c249d7.js +7 -0
  131. rasa/core/channels/inspector/dist/assets/edges-f2ad444c-7dd06a75.js +4 -0
  132. rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-62c1e54c.js +51 -0
  133. rasa/core/channels/inspector/dist/assets/flowDb-1972c806-ce49b86f.js +6 -0
  134. rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-4067e48f.js +4 -0
  135. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-85583a23.js +1 -0
  136. rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-59fe4051.js +139 -0
  137. rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-47e3a43b.js +266 -0
  138. rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-5a2ac0d9.js +70 -0
  139. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
  140. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
  141. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
  142. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
  143. rasa/core/channels/inspector/dist/assets/index-268a75c0.js +1040 -0
  144. rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-dfb8efc4.js +1 -0
  145. rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
  146. rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-b0c470f2.js +7 -0
  147. rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
  148. rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-2edb829a.js +139 -0
  149. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
  150. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
  151. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
  152. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
  153. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
  154. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
  155. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
  156. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
  157. rasa/core/channels/inspector/dist/assets/layout-b6873d69.js +1 -0
  158. rasa/core/channels/inspector/dist/assets/line-1efc5781.js +1 -0
  159. rasa/core/channels/inspector/dist/assets/linear-661e9b94.js +1 -0
  160. rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-2d2e727f.js +109 -0
  161. rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
  162. rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
  163. rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-9d3ea93d.js +35 -0
  164. rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-06a178a2.js +7 -0
  165. rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-0bfedffc.js +52 -0
  166. rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-d76d0a04.js +8 -0
  167. rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-37bb4341.js +122 -0
  168. rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-f52f7f57.js +1 -0
  169. rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-4a986a20.js +1 -0
  170. rasa/core/channels/inspector/dist/assets/styles-080da4f6-7dd9ae12.js +110 -0
  171. rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-46e1ca14.js +159 -0
  172. rasa/core/channels/inspector/dist/assets/styles-9c745c82-4a97439a.js +207 -0
  173. rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-823917a3.js +1 -0
  174. rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-9ea72896.js +61 -0
  175. rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-b631a8b6.js +7 -0
  176. rasa/core/channels/inspector/dist/index.html +39 -0
  177. rasa/core/channels/inspector/index.html +37 -0
  178. rasa/core/channels/inspector/jest.config.ts +13 -0
  179. rasa/core/channels/inspector/package.json +48 -0
  180. rasa/core/channels/inspector/setupTests.ts +2 -0
  181. rasa/core/channels/inspector/src/App.tsx +170 -0
  182. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +97 -0
  183. rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
  184. rasa/core/channels/inspector/src/components/DialogueStack.tsx +151 -0
  185. rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
  186. rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
  187. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +19 -0
  188. rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
  189. rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
  190. rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
  191. rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
  192. rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
  193. rasa/core/channels/inspector/src/helpers/formatters.test.ts +385 -0
  194. rasa/core/channels/inspector/src/helpers/formatters.ts +239 -0
  195. rasa/core/channels/inspector/src/helpers/utils.ts +42 -0
  196. rasa/core/channels/inspector/src/main.tsx +13 -0
  197. rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
  198. rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
  199. rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
  200. rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
  201. rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
  202. rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
  203. rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
  204. rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
  205. rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
  206. rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
  207. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
  208. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
  209. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
  210. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
  211. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
  212. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
  213. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
  214. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
  215. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
  216. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
  217. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
  218. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
  219. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
  220. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
  221. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
  222. rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
  223. rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
  224. rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
  225. rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
  226. rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
  227. rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
  228. rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
  229. rasa/core/channels/inspector/src/theme/index.ts +101 -0
  230. rasa/core/channels/inspector/src/types.ts +64 -0
  231. rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
  232. rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
  233. rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
  234. rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
  235. rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
  236. rasa/core/channels/inspector/tsconfig.json +26 -0
  237. rasa/core/channels/inspector/tsconfig.node.json +10 -0
  238. rasa/core/channels/inspector/vite.config.ts +8 -0
  239. rasa/core/channels/inspector/yarn.lock +6156 -0
  240. rasa/core/channels/mattermost.py +229 -0
  241. rasa/core/channels/rasa_chat.py +126 -0
  242. rasa/core/channels/rest.py +210 -0
  243. rasa/core/channels/rocketchat.py +175 -0
  244. rasa/core/channels/slack.py +620 -0
  245. rasa/core/channels/socketio.py +274 -0
  246. rasa/core/channels/telegram.py +298 -0
  247. rasa/core/channels/twilio.py +169 -0
  248. rasa/core/channels/twilio_voice.py +367 -0
  249. rasa/core/channels/vier_cvg.py +374 -0
  250. rasa/core/channels/webexteams.py +135 -0
  251. rasa/core/concurrent_lock_store.py +210 -0
  252. rasa/core/constants.py +107 -0
  253. rasa/core/evaluation/__init__.py +0 -0
  254. rasa/core/evaluation/marker.py +267 -0
  255. rasa/core/evaluation/marker_base.py +925 -0
  256. rasa/core/evaluation/marker_stats.py +294 -0
  257. rasa/core/evaluation/marker_tracker_loader.py +103 -0
  258. rasa/core/exceptions.py +29 -0
  259. rasa/core/exporter.py +284 -0
  260. rasa/core/featurizers/__init__.py +0 -0
  261. rasa/core/featurizers/precomputation.py +410 -0
  262. rasa/core/featurizers/single_state_featurizer.py +402 -0
  263. rasa/core/featurizers/tracker_featurizers.py +1172 -0
  264. rasa/core/http_interpreter.py +89 -0
  265. rasa/core/information_retrieval/__init__.py +0 -0
  266. rasa/core/information_retrieval/faiss.py +116 -0
  267. rasa/core/information_retrieval/information_retrieval.py +72 -0
  268. rasa/core/information_retrieval/milvus.py +59 -0
  269. rasa/core/information_retrieval/qdrant.py +102 -0
  270. rasa/core/jobs.py +63 -0
  271. rasa/core/lock.py +139 -0
  272. rasa/core/lock_store.py +344 -0
  273. rasa/core/migrate.py +404 -0
  274. rasa/core/nlg/__init__.py +3 -0
  275. rasa/core/nlg/callback.py +147 -0
  276. rasa/core/nlg/contextual_response_rephraser.py +270 -0
  277. rasa/core/nlg/generator.py +230 -0
  278. rasa/core/nlg/interpolator.py +143 -0
  279. rasa/core/nlg/response.py +155 -0
  280. rasa/core/nlg/summarize.py +69 -0
  281. rasa/core/policies/__init__.py +0 -0
  282. rasa/core/policies/ensemble.py +329 -0
  283. rasa/core/policies/enterprise_search_policy.py +717 -0
  284. rasa/core/policies/enterprise_search_prompt_template.jinja2 +62 -0
  285. rasa/core/policies/flow_policy.py +205 -0
  286. rasa/core/policies/flows/__init__.py +0 -0
  287. rasa/core/policies/flows/flow_exceptions.py +44 -0
  288. rasa/core/policies/flows/flow_executor.py +582 -0
  289. rasa/core/policies/flows/flow_step_result.py +43 -0
  290. rasa/core/policies/intentless_policy.py +924 -0
  291. rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
  292. rasa/core/policies/memoization.py +538 -0
  293. rasa/core/policies/policy.py +716 -0
  294. rasa/core/policies/rule_policy.py +1276 -0
  295. rasa/core/policies/ted_policy.py +2146 -0
  296. rasa/core/policies/unexpected_intent_policy.py +1015 -0
  297. rasa/core/processor.py +1331 -0
  298. rasa/core/run.py +315 -0
  299. rasa/core/secrets_manager/__init__.py +0 -0
  300. rasa/core/secrets_manager/constants.py +32 -0
  301. rasa/core/secrets_manager/endpoints.py +391 -0
  302. rasa/core/secrets_manager/factory.py +233 -0
  303. rasa/core/secrets_manager/secret_manager.py +262 -0
  304. rasa/core/secrets_manager/vault.py +576 -0
  305. rasa/core/test.py +1337 -0
  306. rasa/core/tracker_store.py +1664 -0
  307. rasa/core/train.py +107 -0
  308. rasa/core/training/__init__.py +89 -0
  309. rasa/core/training/converters/__init__.py +0 -0
  310. rasa/core/training/converters/responses_prefix_converter.py +119 -0
  311. rasa/core/training/interactive.py +1742 -0
  312. rasa/core/training/story_conflict.py +381 -0
  313. rasa/core/training/training.py +93 -0
  314. rasa/core/utils.py +344 -0
  315. rasa/core/visualize.py +70 -0
  316. rasa/dialogue_understanding/__init__.py +0 -0
  317. rasa/dialogue_understanding/coexistence/__init__.py +0 -0
  318. rasa/dialogue_understanding/coexistence/constants.py +4 -0
  319. rasa/dialogue_understanding/coexistence/intent_based_router.py +189 -0
  320. rasa/dialogue_understanding/coexistence/llm_based_router.py +261 -0
  321. rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
  322. rasa/dialogue_understanding/commands/__init__.py +45 -0
  323. rasa/dialogue_understanding/commands/can_not_handle_command.py +61 -0
  324. rasa/dialogue_understanding/commands/cancel_flow_command.py +116 -0
  325. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +48 -0
  326. rasa/dialogue_understanding/commands/clarify_command.py +77 -0
  327. rasa/dialogue_understanding/commands/command.py +85 -0
  328. rasa/dialogue_understanding/commands/correct_slots_command.py +288 -0
  329. rasa/dialogue_understanding/commands/error_command.py +67 -0
  330. rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
  331. rasa/dialogue_understanding/commands/handle_code_change_command.py +64 -0
  332. rasa/dialogue_understanding/commands/human_handoff_command.py +57 -0
  333. rasa/dialogue_understanding/commands/knowledge_answer_command.py +48 -0
  334. rasa/dialogue_understanding/commands/noop_command.py +45 -0
  335. rasa/dialogue_understanding/commands/set_slot_command.py +125 -0
  336. rasa/dialogue_understanding/commands/skip_question_command.py +66 -0
  337. rasa/dialogue_understanding/commands/start_flow_command.py +98 -0
  338. rasa/dialogue_understanding/generator/__init__.py +6 -0
  339. rasa/dialogue_understanding/generator/command_generator.py +257 -0
  340. rasa/dialogue_understanding/generator/command_prompt_template.jinja2 +57 -0
  341. rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
  342. rasa/dialogue_understanding/generator/flow_retrieval.py +410 -0
  343. rasa/dialogue_understanding/generator/llm_command_generator.py +637 -0
  344. rasa/dialogue_understanding/generator/nlu_command_adapter.py +157 -0
  345. rasa/dialogue_understanding/patterns/__init__.py +0 -0
  346. rasa/dialogue_understanding/patterns/cancel.py +111 -0
  347. rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
  348. rasa/dialogue_understanding/patterns/chitchat.py +37 -0
  349. rasa/dialogue_understanding/patterns/clarify.py +97 -0
  350. rasa/dialogue_understanding/patterns/code_change.py +41 -0
  351. rasa/dialogue_understanding/patterns/collect_information.py +90 -0
  352. rasa/dialogue_understanding/patterns/completed.py +40 -0
  353. rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
  354. rasa/dialogue_understanding/patterns/correction.py +278 -0
  355. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +243 -0
  356. rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
  357. rasa/dialogue_understanding/patterns/internal_error.py +47 -0
  358. rasa/dialogue_understanding/patterns/search.py +37 -0
  359. rasa/dialogue_understanding/patterns/skip_question.py +38 -0
  360. rasa/dialogue_understanding/processor/__init__.py +0 -0
  361. rasa/dialogue_understanding/processor/command_processor.py +578 -0
  362. rasa/dialogue_understanding/processor/command_processor_component.py +39 -0
  363. rasa/dialogue_understanding/stack/__init__.py +0 -0
  364. rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
  365. rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
  366. rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
  367. rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
  368. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
  369. rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
  370. rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
  371. rasa/dialogue_understanding/stack/utils.py +211 -0
  372. rasa/e2e_test/__init__.py +0 -0
  373. rasa/e2e_test/constants.py +10 -0
  374. rasa/e2e_test/e2e_test_case.py +322 -0
  375. rasa/e2e_test/e2e_test_result.py +34 -0
  376. rasa/e2e_test/e2e_test_runner.py +659 -0
  377. rasa/e2e_test/e2e_test_schema.yml +67 -0
  378. rasa/engine/__init__.py +0 -0
  379. rasa/engine/caching.py +464 -0
  380. rasa/engine/constants.py +17 -0
  381. rasa/engine/exceptions.py +14 -0
  382. rasa/engine/graph.py +625 -0
  383. rasa/engine/loader.py +36 -0
  384. rasa/engine/recipes/__init__.py +0 -0
  385. rasa/engine/recipes/config_files/default_config.yml +44 -0
  386. rasa/engine/recipes/default_components.py +99 -0
  387. rasa/engine/recipes/default_recipe.py +1252 -0
  388. rasa/engine/recipes/graph_recipe.py +79 -0
  389. rasa/engine/recipes/recipe.py +93 -0
  390. rasa/engine/runner/__init__.py +0 -0
  391. rasa/engine/runner/dask.py +256 -0
  392. rasa/engine/runner/interface.py +49 -0
  393. rasa/engine/storage/__init__.py +0 -0
  394. rasa/engine/storage/local_model_storage.py +248 -0
  395. rasa/engine/storage/resource.py +110 -0
  396. rasa/engine/storage/storage.py +203 -0
  397. rasa/engine/training/__init__.py +0 -0
  398. rasa/engine/training/components.py +176 -0
  399. rasa/engine/training/fingerprinting.py +64 -0
  400. rasa/engine/training/graph_trainer.py +256 -0
  401. rasa/engine/training/hooks.py +164 -0
  402. rasa/engine/validation.py +839 -0
  403. rasa/env.py +5 -0
  404. rasa/exceptions.py +69 -0
  405. rasa/graph_components/__init__.py +0 -0
  406. rasa/graph_components/converters/__init__.py +0 -0
  407. rasa/graph_components/converters/nlu_message_converter.py +48 -0
  408. rasa/graph_components/providers/__init__.py +0 -0
  409. rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
  410. rasa/graph_components/providers/domain_provider.py +71 -0
  411. rasa/graph_components/providers/flows_provider.py +74 -0
  412. rasa/graph_components/providers/forms_provider.py +44 -0
  413. rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
  414. rasa/graph_components/providers/responses_provider.py +44 -0
  415. rasa/graph_components/providers/rule_only_provider.py +49 -0
  416. rasa/graph_components/providers/story_graph_provider.py +43 -0
  417. rasa/graph_components/providers/training_tracker_provider.py +55 -0
  418. rasa/graph_components/validators/__init__.py +0 -0
  419. rasa/graph_components/validators/default_recipe_validator.py +552 -0
  420. rasa/graph_components/validators/finetuning_validator.py +302 -0
  421. rasa/hooks.py +113 -0
  422. rasa/jupyter.py +63 -0
  423. rasa/keys +1 -0
  424. rasa/markers/__init__.py +0 -0
  425. rasa/markers/marker.py +269 -0
  426. rasa/markers/marker_base.py +828 -0
  427. rasa/markers/upload.py +74 -0
  428. rasa/markers/validate.py +21 -0
  429. rasa/model.py +118 -0
  430. rasa/model_testing.py +457 -0
  431. rasa/model_training.py +535 -0
  432. rasa/nlu/__init__.py +7 -0
  433. rasa/nlu/classifiers/__init__.py +3 -0
  434. rasa/nlu/classifiers/classifier.py +5 -0
  435. rasa/nlu/classifiers/diet_classifier.py +1874 -0
  436. rasa/nlu/classifiers/fallback_classifier.py +192 -0
  437. rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
  438. rasa/nlu/classifiers/llm_intent_classifier.py +519 -0
  439. rasa/nlu/classifiers/logistic_regression_classifier.py +240 -0
  440. rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
  441. rasa/nlu/classifiers/regex_message_handler.py +56 -0
  442. rasa/nlu/classifiers/sklearn_intent_classifier.py +309 -0
  443. rasa/nlu/constants.py +77 -0
  444. rasa/nlu/convert.py +40 -0
  445. rasa/nlu/emulators/__init__.py +0 -0
  446. rasa/nlu/emulators/dialogflow.py +55 -0
  447. rasa/nlu/emulators/emulator.py +49 -0
  448. rasa/nlu/emulators/luis.py +86 -0
  449. rasa/nlu/emulators/no_emulator.py +10 -0
  450. rasa/nlu/emulators/wit.py +56 -0
  451. rasa/nlu/extractors/__init__.py +0 -0
  452. rasa/nlu/extractors/crf_entity_extractor.py +672 -0
  453. rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
  454. rasa/nlu/extractors/entity_synonyms.py +178 -0
  455. rasa/nlu/extractors/extractor.py +470 -0
  456. rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
  457. rasa/nlu/extractors/regex_entity_extractor.py +220 -0
  458. rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
  459. rasa/nlu/featurizers/__init__.py +0 -0
  460. rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
  461. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +449 -0
  462. rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
  463. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +772 -0
  464. rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
  465. rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
  466. rasa/nlu/featurizers/featurizer.py +89 -0
  467. rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
  468. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +840 -0
  469. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +539 -0
  470. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +269 -0
  471. rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
  472. rasa/nlu/model.py +24 -0
  473. rasa/nlu/persistor.py +240 -0
  474. rasa/nlu/run.py +27 -0
  475. rasa/nlu/selectors/__init__.py +0 -0
  476. rasa/nlu/selectors/response_selector.py +990 -0
  477. rasa/nlu/test.py +1943 -0
  478. rasa/nlu/tokenizers/__init__.py +0 -0
  479. rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
  480. rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
  481. rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
  482. rasa/nlu/tokenizers/tokenizer.py +239 -0
  483. rasa/nlu/tokenizers/whitespace_tokenizer.py +106 -0
  484. rasa/nlu/utils/__init__.py +35 -0
  485. rasa/nlu/utils/bilou_utils.py +462 -0
  486. rasa/nlu/utils/hugging_face/__init__.py +0 -0
  487. rasa/nlu/utils/hugging_face/registry.py +108 -0
  488. rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
  489. rasa/nlu/utils/mitie_utils.py +113 -0
  490. rasa/nlu/utils/pattern_utils.py +168 -0
  491. rasa/nlu/utils/spacy_utils.py +312 -0
  492. rasa/plugin.py +90 -0
  493. rasa/server.py +1536 -0
  494. rasa/shared/__init__.py +0 -0
  495. rasa/shared/constants.py +181 -0
  496. rasa/shared/core/__init__.py +0 -0
  497. rasa/shared/core/constants.py +168 -0
  498. rasa/shared/core/conversation.py +46 -0
  499. rasa/shared/core/domain.py +2106 -0
  500. rasa/shared/core/events.py +2507 -0
  501. rasa/shared/core/flows/__init__.py +7 -0
  502. rasa/shared/core/flows/flow.py +353 -0
  503. rasa/shared/core/flows/flow_step.py +146 -0
  504. rasa/shared/core/flows/flow_step_links.py +319 -0
  505. rasa/shared/core/flows/flow_step_sequence.py +70 -0
  506. rasa/shared/core/flows/flows_list.py +211 -0
  507. rasa/shared/core/flows/flows_yaml_schema.json +217 -0
  508. rasa/shared/core/flows/nlu_trigger.py +117 -0
  509. rasa/shared/core/flows/steps/__init__.py +24 -0
  510. rasa/shared/core/flows/steps/action.py +51 -0
  511. rasa/shared/core/flows/steps/call.py +64 -0
  512. rasa/shared/core/flows/steps/collect.py +112 -0
  513. rasa/shared/core/flows/steps/constants.py +5 -0
  514. rasa/shared/core/flows/steps/continuation.py +36 -0
  515. rasa/shared/core/flows/steps/end.py +22 -0
  516. rasa/shared/core/flows/steps/internal.py +44 -0
  517. rasa/shared/core/flows/steps/link.py +51 -0
  518. rasa/shared/core/flows/steps/no_operation.py +48 -0
  519. rasa/shared/core/flows/steps/set_slots.py +50 -0
  520. rasa/shared/core/flows/steps/start.py +30 -0
  521. rasa/shared/core/flows/validation.py +527 -0
  522. rasa/shared/core/flows/yaml_flows_io.py +278 -0
  523. rasa/shared/core/generator.py +907 -0
  524. rasa/shared/core/slot_mappings.py +235 -0
  525. rasa/shared/core/slots.py +647 -0
  526. rasa/shared/core/trackers.py +1159 -0
  527. rasa/shared/core/training_data/__init__.py +0 -0
  528. rasa/shared/core/training_data/loading.py +90 -0
  529. rasa/shared/core/training_data/story_reader/__init__.py +0 -0
  530. rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
  531. rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
  532. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
  533. rasa/shared/core/training_data/story_writer/__init__.py +0 -0
  534. rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
  535. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +442 -0
  536. rasa/shared/core/training_data/structures.py +838 -0
  537. rasa/shared/core/training_data/visualization.html +146 -0
  538. rasa/shared/core/training_data/visualization.py +603 -0
  539. rasa/shared/data.py +192 -0
  540. rasa/shared/engine/__init__.py +0 -0
  541. rasa/shared/engine/caching.py +26 -0
  542. rasa/shared/exceptions.py +129 -0
  543. rasa/shared/importers/__init__.py +0 -0
  544. rasa/shared/importers/importer.py +705 -0
  545. rasa/shared/importers/multi_project.py +203 -0
  546. rasa/shared/importers/rasa.py +100 -0
  547. rasa/shared/importers/utils.py +34 -0
  548. rasa/shared/nlu/__init__.py +0 -0
  549. rasa/shared/nlu/constants.py +45 -0
  550. rasa/shared/nlu/interpreter.py +10 -0
  551. rasa/shared/nlu/training_data/__init__.py +0 -0
  552. rasa/shared/nlu/training_data/entities_parser.py +209 -0
  553. rasa/shared/nlu/training_data/features.py +374 -0
  554. rasa/shared/nlu/training_data/formats/__init__.py +10 -0
  555. rasa/shared/nlu/training_data/formats/dialogflow.py +162 -0
  556. rasa/shared/nlu/training_data/formats/luis.py +87 -0
  557. rasa/shared/nlu/training_data/formats/rasa.py +135 -0
  558. rasa/shared/nlu/training_data/formats/rasa_yaml.py +605 -0
  559. rasa/shared/nlu/training_data/formats/readerwriter.py +245 -0
  560. rasa/shared/nlu/training_data/formats/wit.py +52 -0
  561. rasa/shared/nlu/training_data/loading.py +137 -0
  562. rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
  563. rasa/shared/nlu/training_data/message.py +477 -0
  564. rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
  565. rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
  566. rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
  567. rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
  568. rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
  569. rasa/shared/nlu/training_data/training_data.py +732 -0
  570. rasa/shared/nlu/training_data/util.py +223 -0
  571. rasa/shared/providers/__init__.py +0 -0
  572. rasa/shared/providers/openai/__init__.py +0 -0
  573. rasa/shared/providers/openai/clients.py +43 -0
  574. rasa/shared/providers/openai/session_handler.py +110 -0
  575. rasa/shared/utils/__init__.py +0 -0
  576. rasa/shared/utils/cli.py +72 -0
  577. rasa/shared/utils/common.py +308 -0
  578. rasa/shared/utils/constants.py +1 -0
  579. rasa/shared/utils/io.py +403 -0
  580. rasa/shared/utils/llm.py +405 -0
  581. rasa/shared/utils/pykwalify_extensions.py +26 -0
  582. rasa/shared/utils/schemas/__init__.py +0 -0
  583. rasa/shared/utils/schemas/config.yml +2 -0
  584. rasa/shared/utils/schemas/domain.yml +142 -0
  585. rasa/shared/utils/schemas/events.py +212 -0
  586. rasa/shared/utils/schemas/model_config.yml +46 -0
  587. rasa/shared/utils/schemas/stories.yml +173 -0
  588. rasa/shared/utils/yaml.py +777 -0
  589. rasa/studio/__init__.py +0 -0
  590. rasa/studio/auth.py +252 -0
  591. rasa/studio/config.py +127 -0
  592. rasa/studio/constants.py +16 -0
  593. rasa/studio/data_handler.py +352 -0
  594. rasa/studio/download.py +350 -0
  595. rasa/studio/train.py +136 -0
  596. rasa/studio/upload.py +408 -0
  597. rasa/telemetry.py +1583 -0
  598. rasa/tracing/__init__.py +0 -0
  599. rasa/tracing/config.py +338 -0
  600. rasa/tracing/constants.py +38 -0
  601. rasa/tracing/instrumentation/__init__.py +0 -0
  602. rasa/tracing/instrumentation/attribute_extractors.py +663 -0
  603. rasa/tracing/instrumentation/instrumentation.py +939 -0
  604. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +142 -0
  605. rasa/tracing/instrumentation/metrics.py +206 -0
  606. rasa/tracing/metric_instrument_provider.py +125 -0
  607. rasa/utils/__init__.py +0 -0
  608. rasa/utils/beta.py +83 -0
  609. rasa/utils/cli.py +27 -0
  610. rasa/utils/common.py +635 -0
  611. rasa/utils/converter.py +53 -0
  612. rasa/utils/endpoints.py +303 -0
  613. rasa/utils/io.py +326 -0
  614. rasa/utils/licensing.py +319 -0
  615. rasa/utils/log_utils.py +174 -0
  616. rasa/utils/mapper.py +210 -0
  617. rasa/utils/ml_utils.py +145 -0
  618. rasa/utils/plotting.py +362 -0
  619. rasa/utils/singleton.py +23 -0
  620. rasa/utils/tensorflow/__init__.py +0 -0
  621. rasa/utils/tensorflow/callback.py +112 -0
  622. rasa/utils/tensorflow/constants.py +116 -0
  623. rasa/utils/tensorflow/crf.py +492 -0
  624. rasa/utils/tensorflow/data_generator.py +440 -0
  625. rasa/utils/tensorflow/environment.py +161 -0
  626. rasa/utils/tensorflow/exceptions.py +5 -0
  627. rasa/utils/tensorflow/layers.py +1565 -0
  628. rasa/utils/tensorflow/layers_utils.py +113 -0
  629. rasa/utils/tensorflow/metrics.py +281 -0
  630. rasa/utils/tensorflow/model_data.py +991 -0
  631. rasa/utils/tensorflow/model_data_utils.py +500 -0
  632. rasa/utils/tensorflow/models.py +936 -0
  633. rasa/utils/tensorflow/rasa_layers.py +1094 -0
  634. rasa/utils/tensorflow/transformer.py +640 -0
  635. rasa/utils/tensorflow/types.py +6 -0
  636. rasa/utils/train_utils.py +572 -0
  637. rasa/utils/yaml.py +54 -0
  638. rasa/validator.py +1035 -0
  639. rasa/version.py +3 -0
  640. rasa_pro-3.8.16.dist-info/METADATA +528 -0
  641. rasa_pro-3.8.16.dist-info/NOTICE +5 -0
  642. rasa_pro-3.8.16.dist-info/RECORD +644 -0
  643. rasa_pro-3.8.16.dist-info/WHEEL +4 -0
  644. rasa_pro-3.8.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,717 @@
1
+ import importlib.resources
2
+ import json
3
+ import re
4
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
5
+
6
+ import dotenv
7
+ import rasa.shared.utils.io
8
+ import structlog
9
+ from jinja2 import Template
10
+ from pydantic.error_wrappers import ValidationError
11
+ from rasa.shared.exceptions import RasaException
12
+ from rasa.core.constants import (
13
+ POLICY_MAX_HISTORY,
14
+ POLICY_PRIORITY,
15
+ SEARCH_POLICY_PRIORITY,
16
+ )
17
+ from rasa.core.policies.policy import Policy, PolicyPrediction
18
+ from rasa.core.utils import AvailableEndpoints
19
+ from rasa.dialogue_understanding.patterns.internal_error import (
20
+ InternalErrorPatternFlowStackFrame,
21
+ )
22
+ from rasa.dialogue_understanding.patterns.cannot_handle import (
23
+ CannotHandlePatternFlowStackFrame,
24
+ )
25
+ from rasa.dialogue_understanding.stack.frames import PatternFlowStackFrame
26
+ from rasa.dialogue_understanding.stack.frames import (
27
+ DialogueStackFrame,
28
+ SearchStackFrame,
29
+ )
30
+ from rasa.engine.graph import ExecutionContext
31
+ from rasa.engine.recipes.default_recipe import DefaultV1Recipe
32
+ from rasa.engine.storage.resource import Resource
33
+ from rasa.engine.storage.storage import ModelStorage
34
+ from rasa.graph_components.providers.forms_provider import Forms
35
+ from rasa.graph_components.providers.responses_provider import Responses
36
+ from rasa.shared.core.constants import (
37
+ ACTION_CANCEL_FLOW,
38
+ ACTION_SEND_TEXT_NAME,
39
+ DEFAULT_SLOT_NAMES,
40
+ )
41
+ from rasa.shared.core.domain import Domain
42
+ from rasa.shared.core.events import Event
43
+ from rasa.shared.core.generator import TrackerWithCachedStates
44
+ from rasa.shared.core.trackers import DialogueStateTracker
45
+ from rasa.shared.nlu.training_data.training_data import TrainingData
46
+ from rasa.shared.utils.cli import print_error_and_exit
47
+ from rasa.shared.utils.io import deep_container_fingerprint
48
+ from rasa.shared.utils.llm import (
49
+ DEFAULT_OPENAI_CHAT_MODEL_NAME,
50
+ DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
51
+ embedder_factory,
52
+ get_prompt_template,
53
+ llm_factory,
54
+ sanitize_message_for_prompt,
55
+ tracker_as_readable_transcript,
56
+ )
57
+
58
+ from rasa.core.information_retrieval.faiss import FAISS_Store
59
+ from rasa.core.information_retrieval.information_retrieval import (
60
+ InformationRetrieval,
61
+ InformationRetrievalException,
62
+ create_from_endpoint_config,
63
+ )
64
+
65
+ if TYPE_CHECKING:
66
+ from langchain.schema import Document
67
+ from langchain.schema.embeddings import Embeddings
68
+ from langchain.llms.base import BaseLLM
69
+ from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
70
+
71
+ from rasa.utils.log_utils import log_llm
72
+
73
# Structured logger shared by everything in this module.
logger = structlog.get_logger()

# Import-time side effect: read environment variables from a local `.env` file.
dotenv.load_dotenv("./.env")

# Keys used in the policy configuration.
SOURCE_PROPERTY = "source"
VECTOR_STORE_TYPE_PROPERTY = "type"
VECTOR_STORE_PROPERTY = "vector_store"
VECTOR_STORE_THRESHOLD_PROPERTY = "threshold"
EMBEDDINGS_CONFIG_KEY = "embeddings"
LLM_CONFIG_KEY = "llm"

# Vector store settings used when the policy configuration provides none.
DEFAULT_VECTOR_STORE_TYPE = "faiss"
DEFAULT_VECTOR_STORE_THRESHOLD = 0.0
DEFAULT_VECTOR_STORE = {
    VECTOR_STORE_TYPE_PROPERTY: DEFAULT_VECTOR_STORE_TYPE,
    SOURCE_PROPERTY: "./docs",
    VECTOR_STORE_THRESHOLD_PROPERTY: DEFAULT_VECTOR_STORE_THRESHOLD,
}

# Fallback configuration handed to `llm_factory`.
DEFAULT_LLM_CONFIG = {
    "_type": "openai",
    "request_timeout": 10,
    "temperature": 0.0,
    "max_tokens": 256,
    "model_name": DEFAULT_OPENAI_CHAT_MODEL_NAME,
    "max_retries": 1,
}

# Fallback configuration handed to `embedder_factory`.
DEFAULT_EMBEDDINGS_CONFIG = {
    "_type": "openai",
    "model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
}

# File name under which the prompt template is stored in the model resource.
ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"

# Prompt template shipped with the package, read once at import time.
DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE = importlib.resources.read_text(
    "rasa.core.policies", "enterprise_search_prompt_template.jinja2"
)
111
+
112
+
113
class VectorStoreConnectionError(RasaException):
    """Raised when connecting to the vector store fails."""
115
+
116
+
117
class VectorStoreConfigurationError(RasaException):
    """Raised when the vector store configuration is missing or invalid."""
119
+
120
+
121
+ @DefaultV1Recipe.register(
122
+ DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
123
+ )
124
+ class EnterpriseSearchPolicy(Policy):
125
+ """Policy which uses a vector store and LLMs to respond to user messages.
126
+
127
+ The policy uses a vector store and LLMs to respond to user messages. The
128
+ vector store is used to retrieve the most relevant responses to the user
129
+ message. The LLMs are used to rank the responses and select the best
130
+ response. The policy can be used to respond to user messages without
131
+ training data.
132
+
133
+ Example Configuration:
134
+
135
+ policies:
136
+ # - ...
137
+ - name: EnterpriseSearchPolicy
138
+ vector_store:
139
+ type: "milvus"
140
+ <vector_store_config>
141
+ # - ...
142
+ """
143
+
144
@staticmethod
def does_support_stack_frame(frame: DialogueStackFrame) -> bool:
    """Tell whether this policy can act on the given dialogue stack frame.

    Args:
        frame: The stack frame to check.

    Returns:
        `True` only if `frame` is a `SearchStackFrame`.
    """
    is_search_frame = isinstance(frame, SearchStackFrame)
    return is_search_frame
148
+
149
@staticmethod
def get_default_config() -> Dict[str, Any]:
    """Build the policy's default configuration.

    Returns:
        A dictionary holding the default policy priority and the default
        vector store settings.
    """
    defaults: Dict[str, Any] = {}
    defaults[POLICY_PRIORITY] = SEARCH_POLICY_PRIORITY
    defaults[VECTOR_STORE_PROPERTY] = DEFAULT_VECTOR_STORE
    return defaults
156
+
157
def __init__(
    self,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
    vector_store: Optional[InformationRetrieval] = None,
    featurizer: Optional["TrackerFeaturizer"] = None,
    prompt_template: Optional[Text] = None,
) -> None:
    """Set up the policy from its configuration.

    Args:
        config: The policy configuration.
        model_storage: Storage the policy persists to and loads from.
        resource: The resource identifying this policy inside the storage.
        execution_context: Information about the current graph run.
        vector_store: An already constructed vector store, if any.
        featurizer: Passed through to the parent class.
        prompt_template: Prompt template text. When empty or `None`, the
            template is resolved with `get_prompt_template` from the `prompt`
            config entry and the packaged default template.
    """
    super().__init__(config, model_storage, resource, execution_context, featurizer)

    self.vector_store = vector_store
    self.vector_store_config = config.get(VECTOR_STORE_PROPERTY, DEFAULT_VECTOR_STORE)
    self.max_history = self.config.get(POLICY_MAX_HISTORY)

    if prompt_template:
        self.prompt_template = prompt_template
    else:
        self.prompt_template = get_prompt_template(
            self.config.get("prompt"),
            DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
        )

    self.trace_prompt_tokens = self.config.get("trace_prompt_tokens", False)
    self.citation_enabled = self.config.get("citation_enabled", False)
181
+
182
@classmethod
def _create_plain_embedder(cls, config: Dict[Text, Any]) -> "Embeddings":
    """Build the embedder described by the policy configuration.

    Args:
        config: The policy configuration; its `embeddings` entry is handed to
            `embedder_factory` together with `DEFAULT_EMBEDDINGS_CONFIG`.

    Returns:
        The embedder.
    """
    embeddings_config = config.get(EMBEDDINGS_CONFIG_KEY)
    return embedder_factory(embeddings_config, DEFAULT_EMBEDDINGS_CONFIG)
192
+
193
def train(  # type: ignore[override]
    self,
    training_trackers: List[TrackerWithCachedStates],
    domain: Domain,
    responses: Responses,
    forms: Forms,
    training_data: TrainingData,
    **kwargs: Any,
) -> Resource:
    """Trains a policy.

    Checks that an embedder and an LLM can be created from the configuration,
    constructs a `FAISS_Store` with `create_index=True` inside the model
    storage when the store type is the default one ("faiss"), and finally
    persists the prompt template.

    Args:
        training_trackers: The story and rules trackers from the training data.
        domain: The model's domain.
        responses: The model's responses.
        forms: The model's forms.
        training_data: The model's training data.
        **kwargs: Depending on the specified `needs` section and the resulting
            graph structure the policy can use different input to train itself.

    Returns:
        A policy must return its resource locator so that potential children nodes
        can load the policy from the resource.
    """
    vector_store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)

    # The embedder is needed below; a failure here is reported via
    # `print_error_and_exit`.
    try:
        embeddings = self._create_plain_embedder(self.config)
    except ValidationError as error:
        print_error_and_exit(
            "Unable to create embedder. Please make sure you specified the "
            f"required environment variables. Error: {error}"
        )

    # The LLM is only created to validate its configuration; the instance is
    # discarded.
    try:
        llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
    except (ImportError, ValueError, ValidationError) as error:
        # ImportError: llm library is likely not installed
        # ValueError: llm config is likely invalid
        # ValidationError: environment variables are likely not set
        print_error_and_exit(f"Unable to create LLM. Error: {error}")

    if vector_store_type != DEFAULT_VECTOR_STORE_TYPE:
        logger.info(
            "enterprise_search_policy.train.custom", store_type=vector_store_type
        )
    else:
        logger.info("enterprise_search_policy.train.faiss")
        with self._model_storage.write_to(self._resource) as index_directory:
            self.vector_store = FAISS_Store(
                docs_folder=self.vector_store_config.get(SOURCE_PROPERTY),
                embeddings=embeddings,
                index_path=index_directory,
                create_index=True,
            )

    self.persist()
    return self._resource
251
+
252
def persist(self) -> None:
    """Write the prompt template into the policy's model storage resource."""
    with self._model_storage.write_to(self._resource) as directory:
        prompt_file = directory / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
        rasa.shared.utils.io.write_text_file(self.prompt_template, prompt_file)
258
+
259
def _prepare_slots_for_template(
    self, tracker: DialogueStateTracker
) -> List[Dict[str, str]]:
    """Collect the slots that are exposed to the prompt template.

    Args:
        tracker: The tracker containing the conversation history up to now.

    Returns:
        One dictionary (`name`, `value`, `type`) per slot that is not listed in
        `DEFAULT_SLOT_NAMES` and whose value is not `None`. Values are
        converted with `str`.
    """
    return [
        {
            "name": slot_name,
            "value": str(slot.value),
            "type": slot.type_name,
        }
        for slot_name, slot in tracker.slots.items()
        if slot_name not in DEFAULT_SLOT_NAMES and slot.value is not None
    ]
281
+
282
def _connect_vector_store_or_raise(
    self, endpoints: Optional[AvailableEndpoints]
) -> None:
    """Connects to the vector store or raises an exception.

    Args:
        endpoints: Endpoints configuration. Its `vector_store` entry is passed
            to the vector store's `connect` method; `None` is passed when no
            endpoints are available.

    Raises:
        VectorStoreConfigurationError: If no vector store endpoint is configured
            and the configured store type is not the default one ("faiss").
        VectorStoreConnectionError: If `connect` raises any exception. The
            original exception is attached as the cause.
    """
    config = endpoints.vector_store if endpoints else None
    store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
    if config is None and store_type != DEFAULT_VECTOR_STORE_TYPE:
        logger.error(
            "enterprise_search_policy._connect_vector_store_or_raise.no_config"
        )
        # Adjacent string literals keep the message on a single line; a
        # triple-quoted literal would leak a newline and source indentation
        # into the text shown to the user.
        raise VectorStoreConfigurationError(
            "No vector store specified. Please specify a vector "
            "store in the endpoints configuration"
        )
    try:
        self.vector_store.connect(config)  # type: ignore
    except Exception as e:
        logger.error(
            "enterprise_search_policy._connect_vector_store_or_raise.connect_error",
            error=e,
        )
        # Chain explicitly so the underlying failure stays visible as the cause.
        raise VectorStoreConnectionError(
            f"Unable to connect to the vector store. Error: {e}"
        ) from e
314
+
315
def _get_last_user_message(self, tracker: DialogueStateTracker) -> str:
    """Return the text of the most recent user message.

    Args:
        tracker: The tracker containing the conversation history up to now.

    Returns:
        The text of the latest `UserUttered` event after it was passed through
        `sanitize_message_for_prompt`, or an empty string if the tracker holds
        no such event.
    """
    user_events = (
        event
        for event in reversed(tracker.events)
        if isinstance(event, rasa.shared.core.events.UserUttered)
    )
    latest_user_event = next(user_events, None)
    if latest_user_event is None:
        return ""
    return sanitize_message_for_prompt(latest_user_event.text)
328
+
329
async def predict_action_probabilities(  # type: ignore[override]
    self,
    tracker: DialogueStateTracker,
    domain: Domain,
    endpoints: Optional[AvailableEndpoints],
    rule_only_data: Optional[Dict[Text, Any]] = None,
    **kwargs: Any,
) -> PolicyPrediction:
    """Predicts the next action the bot should take after seeing the tracker.

    The last user message is used to query the vector store, the retrieved
    documents are rendered into the prompt and the LLM's answer is returned as
    the message of the predicted action.

    Args:
        tracker: The tracker containing the conversation history up to now.
        domain: The model's domain.
        endpoints: The model's endpoints.
        rule_only_data: Slots and loops which are specific to rules and hence
            should be ignored by this policy.
        **kwargs: Depending on the specified `needs` section and the resulting
            graph structure the policy can use different input to make predictions.

    Returns:
        The prediction. This is the default prediction when the current stack
        frame is not supported, an internal-error pattern prediction when the
        vector store is missing, cannot be connected or searched, or the LLM
        call fails, and a cannot-handle pattern prediction when the search
        returns no documents.
    """
    logger_key = "enterprise_search_policy.predict_action_probabilities"

    if not self.supports_current_stack_frame(tracker, False, False):
        return self._prediction(self._default_predictions(domain))

    if not self.vector_store:
        logger.error(f"{logger_key}.no_vector_store")
        return self._create_prediction_internal_error(domain, tracker)

    try:
        self._connect_vector_store_or_raise(endpoints)
    except (VectorStoreConfigurationError, VectorStoreConnectionError) as e:
        logger.error(f"{logger_key}.connection_error", error=e)
        return self._create_prediction_internal_error(domain, tracker)

    search_query = self._get_last_user_message(tracker)
    vector_search_threshold = self.vector_store_config.get(
        VECTOR_STORE_THRESHOLD_PROPERTY, DEFAULT_VECTOR_STORE_THRESHOLD
    )

    try:
        documents = await self.vector_store.search(
            query=search_query,
            threshold=vector_search_threshold,
        )
    except InformationRetrievalException as e:
        logger.error(f"{logger_key}.search_error", error=e)
        return self._create_prediction_internal_error(domain, tracker)

    if not documents:
        logger.info(f"{logger_key}.no_documents")
        return self._create_prediction_cannot_handle(domain, tracker)

    logger.debug(f"{logger_key}.documents", num_documents=len(documents))
    prompt = self._render_prompt(tracker, documents)

    # The LLM client is created only once it is needed, so turns that return
    # early above neither pay for nor fail on its construction.
    llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
    llm_answer = await self._generate_llm_answer(llm, prompt)
    if llm_answer is None:
        return self._create_prediction_internal_error(domain, tracker)

    if self.citation_enabled:
        llm_answer = self.post_process_citations(llm_answer)

    logger.debug(f"{logger_key}.llm_answer", llm_answer=llm_answer)
    action_metadata = {
        "message": {
            "text": llm_answer,
        }
    }

    return self._create_prediction(
        domain=domain, tracker=tracker, action_metadata=action_metadata
    )
403
+
404
def _render_prompt(
    self, tracker: DialogueStateTracker, documents: List["Document"]
) -> Text:
    """Fill the prompt template with the conversation, documents and slots.

    Args:
        tracker: The tracker containing the conversation history up to now.
        documents: The documents retrieved from the vector store.

    Returns:
        The rendered prompt.
    """
    transcript = tracker_as_readable_transcript(tracker, max_turns=self.max_history)
    template_inputs = {
        "current_conversation": transcript,
        "docs": documents,
        "slots": self._prepare_slots_for_template(tracker),
        "citation_enabled": self.citation_enabled,
    }
    rendered_prompt = Template(self.prompt_template).render(**template_inputs)
    log_llm(
        logger=logger,
        log_module="EnterpriseSearchPolicy",
        log_event="enterprise_search_policy._render_prompt.prompt_rendered",
        prompt=rendered_prompt,
    )
    return rendered_prompt
432
+
433
async def _generate_llm_answer(
    self, llm: "BaseLLM", prompt: Text
) -> Optional[Text]:
    """Ask the LLM to complete the prompt.

    Args:
        llm: The LLM to query.
        prompt: The rendered prompt.

    Returns:
        The LLM's answer, or `None` if the call raised an exception (the
        exception is logged).
    """
    try:
        return await llm.apredict(prompt)
    except Exception as error:
        # unfortunately, langchain does not wrap LLM exceptions which means
        # we have to catch all exceptions here
        logger.error(
            "enterprise_search_policy._generate_llm_answer.llm_error",
            error=error,
        )
        return None
448
+
449
def _create_prediction(
    self,
    domain: Domain,
    tracker: DialogueStateTracker,
    action_metadata: Dict[Text, Any],
) -> PolicyPrediction:
    """Build a prediction for `ACTION_SEND_TEXT_NAME`.

    If the dialogue stack is not empty, its top frame is popped and the
    resulting stack update events are attached to the prediction.

    Args:
        domain: The model's domain.
        tracker: The tracker containing the conversation history up to now.
        action_metadata: The metadata for the predicted action.

    Returns:
        The prediction.
    """
    probabilities = self._prediction_result(ACTION_SEND_TEXT_NAME, domain)
    dialogue_stack = tracker.stack

    events: List[Event] = []
    if not dialogue_stack.is_empty():
        dialogue_stack.pop()
        events = tracker.create_stack_updated_events(dialogue_stack)

    return self._prediction(
        probabilities, action_metadata=action_metadata, events=events
    )
474
+
475
def _create_prediction_internal_error(
    self, domain: Domain, tracker: DialogueStateTracker
) -> PolicyPrediction:
    """Build a prediction that pushes the internal error pattern frame.

    Args:
        domain: The model's domain.
        tracker: The tracker containing the conversation history up to now.

    Returns:
        The prediction.
    """
    error_frame = InternalErrorPatternFlowStackFrame()
    return self._create_prediction_for_pattern(domain, tracker, error_frame)
481
+
482
def _create_prediction_cannot_handle(
    self, domain: Domain, tracker: DialogueStateTracker
) -> PolicyPrediction:
    """Build a prediction that pushes the cannot-handle pattern frame.

    Args:
        domain: The model's domain.
        tracker: The tracker containing the conversation history up to now.

    Returns:
        The prediction.
    """
    cannot_handle_frame = CannotHandlePatternFlowStackFrame()
    return self._create_prediction_for_pattern(domain, tracker, cannot_handle_frame)
488
+
489
def _create_prediction_for_pattern(
    self,
    domain: Domain,
    tracker: DialogueStateTracker,
    pattern_stack_frame: PatternFlowStackFrame,
) -> PolicyPrediction:
    """Build a prediction that starts the given pattern.

    The prediction selects `ACTION_CANCEL_FLOW`. The top frame of the dialogue
    stack (if any) is popped, the pattern frame is pushed, and the resulting
    stack update events are attached to the prediction.

    Args:
        domain: The model's domain.
        tracker: The tracker containing the conversation history up to now.
        pattern_stack_frame: The pattern stack frame to push.

    Returns:
        The prediction.
    """
    # TODO: replace ACTION_CANCEL_FLOW (ATO-2097)
    probabilities = self._prediction_result(ACTION_CANCEL_FLOW, domain)

    dialogue_stack = tracker.stack
    if not dialogue_stack.is_empty():
        dialogue_stack.pop()
    dialogue_stack.push(pattern_stack_frame)

    stack_events: List[Event] = tracker.create_stack_updated_events(dialogue_stack)
    return self._prediction(probabilities, action_metadata=None, events=stack_events)
516
+
517
def _prediction_result(
    self, action_name: Optional[Text], domain: Domain, score: Optional[float] = 1.0
) -> List[float]:
    """Build the list of action scores for a single predicted action.

    Args:
        action_name: The name of the predicted action.
        domain: The model's domain.
        score: The score of the predicted action.

    Returns:
        The default predictions for the domain, with `score` written at the
        index of `action_name`. When `action_name` is empty or `None`, the
        default predictions are returned unchanged.
    """
    scores = self._default_predictions(domain)
    if not action_name:
        return scores

    action_index = domain.index_for_action(action_name)
    scores[action_index] = score  # type: ignore[assignment]
    return scores
534
+
535
@classmethod
def load(
    cls,
    config: Dict[Text, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
    **kwargs: Any,
) -> "EnterpriseSearchPolicy":
    """Loads a trained policy (see parent class for full docstring).

    For the default store type ("faiss") a `FAISS_Store` is constructed from
    the index kept in the model storage; for any other type the store is
    created with `create_from_endpoint_config`. The persisted prompt template
    is read as well. If its file is missing, a warning is logged and `None`
    is passed on as the prompt template.

    Args:
        config: The policy configuration.
        model_storage: Storage to read the persisted policy from.
        resource: The resource identifying this policy inside the storage.
        execution_context: Information about the current graph run.
        **kwargs: Unused.

    Returns:
        The loaded policy.
    """
    prompt_template = None
    store_type = config.get(VECTOR_STORE_PROPERTY, {}).get(
        VECTOR_STORE_TYPE_PROPERTY
    )

    embeddings = cls._create_plain_embedder(config)
    logger.info("enterprise_search_policy.load", config=config)
    if store_type == DEFAULT_VECTOR_STORE_TYPE:
        # for the FAISS store type the index is stored in the model
        # TODO figure out a way to get path without context manager
        with model_storage.read_from(resource) as path:
            vector_store = FAISS_Store(
                embeddings=embeddings,
                index_path=path,
                docs_folder=None,
                create_index=False,
            )
    else:
        vector_store = create_from_endpoint_config(
            config_type=store_type,
            embeddings=embeddings,
        )  # type: ignore
    try:
        with model_storage.read_from(resource) as path:
            prompt_template = rasa.shared.utils.io.read_file(
                path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
            )
    except FileNotFoundError as e:
        # Deliberately best-effort: the prompt template stays `None`.
        logger.warning(
            "enterprise_search_policy.load.failed", error=e, resource=resource.name
        )

    return cls(
        config,
        model_storage,
        resource,
        execution_context,
        vector_store=vector_store,
        prompt_template=prompt_template,
    )
587
+
588
@classmethod
def _get_local_knowledge_data(cls, config: Dict[str, Any]) -> Optional[List[str]]:
    """Serialize the local knowledge base documents for fingerprinting.

    This is required only for local knowledge base types, e.g. FAISS, to
    ensure that the graph component is retrained when the knowledge base is
    updated.

    Args:
        config: The policy configuration; it is merged on top of the default
            configuration before it is read.

    Returns:
        The sorted JSON dumps of all documents loaded from the configured
        source, or `None` if the store type is not the default one, no source
        is configured, or no documents were loaded.
    """
    merged_config = {**cls.get_default_config(), **config}
    vector_store_config = merged_config.get(VECTOR_STORE_PROPERTY, {})

    configured_type = vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
    if configured_type != DEFAULT_VECTOR_STORE_TYPE:
        return None

    source = merged_config.get(VECTOR_STORE_PROPERTY, {}).get(SOURCE_PROPERTY)
    if not source:
        return None

    docs = FAISS_Store.load_documents(source)
    if len(docs) == 0:
        return None

    return sorted(
        json.dumps(doc.dict(), ensure_ascii=False, sort_keys=True) for doc in docs
    )
616
+
617
@classmethod
def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
    """Fingerprint the prompt template and the local knowledge base.

    Args:
        config: The policy configuration.

    Returns:
        The fingerprint computed by `deep_container_fingerprint` over the
        resolved prompt template and the local knowledge data.
    """
    knowledge_data = cls._get_local_knowledge_data(config)
    template = get_prompt_template(
        config.get("prompt"),
        DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
    )
    fingerprint_inputs = [template, knowledge_data]
    return deep_container_fingerprint(fingerprint_inputs)
627
+
628
@staticmethod
def post_process_citations(llm_answer: str) -> str:
    """Renumber the citations of an LLM answer and reorder its source list.

    The text is split at the last occurrence of "Sources:". Bracketed
    references in the answer part (e.g. "[3]" or "[3, 1]") are renumbered in
    order of first appearance, starting from 1, and the lines of the source
    list are rewritten and sorted to follow the new numbering.

    A source that is cited several times keeps a single number. Sources that
    are listed but never cited in the answer are kept and numbered after the
    cited ones. Non-blank lines of the source list that contain no
    "[<number>" marker stay attached to the numbered source they follow.

    Runs of whitespace in the answer part are collapsed to single spaces.

    Args:
        llm_answer: The LLM answer.

    Returns:
        The post-processed LLM answer, or `llm_answer` unchanged if it does
        not contain "Sources:".
    """
    logger.debug(
        "enterprise_search_policy.post_process_citations", llm_answer=llm_answer
    )

    # Split llm_answer into answer and citations
    answer, separator, citations = llm_answer.rpartition("Sources:")
    if not separator:
        # if there is no "Sources:" in the llm_answer
        return llm_answer

    citation_pattern = re.compile(r"\[\s*(\d+(?:\s*,\s*\d+)*)\s*\]")

    # Map every old source index to its rank of first appearance. `setdefault`
    # keeps the first rank when a source is cited more than once, so the new
    # numbering is contiguous and starts at 1.
    renumber_mapping: Dict[int, int] = {}
    for group in citation_pattern.findall(answer):
        for number in group.split(","):
            renumber_mapping.setdefault(int(number), len(renumber_mapping) + 1)

    def _renumber(match: "re.Match") -> str:
        """Rewrite one bracketed citation group with the new indices."""
        new_indices = [
            renumber_mapping[int(number)] for number in match.group(1).split(",")
        ]
        return f"[{', '.join(map(str, new_indices))}]"

    # A single substitution pass rewrites each citation exactly once; chained
    # `str.replace` calls would rewrite already renumbered digits again.
    renumbered_answer = citation_pattern.sub(_renumber, answer)
    joined_answer = " ".join(renumbered_answer.split()) + "\nSources:\n"

    source_marker = re.compile(r"(?<=\[)\d+")
    leading_lines: List[str] = []
    sources_by_index: Dict[int, List[str]] = {}
    current_entry: Optional[List[str]] = None

    for line in citations.split("\n"):
        marker = source_marker.search(line)
        if marker:
            old_index = int(marker.group(0))
            # sources never cited in the answer get the next free number
            new_index = renumber_mapping.setdefault(
                old_index, len(renumber_mapping) + 1
            )
            # replace only the first occurrence of the old index
            line = line.replace(f"[{old_index}]", f"[{new_index}]", 1)
            current_entry = sources_by_index.setdefault(new_index, [])
            current_entry.append(line)
        elif line.strip():
            target = leading_lines if current_entry is None else current_entry
            target.append(line)

    # Sorting by the new index yields the enumeration order regardless of the
    # order in which the sources were listed.
    new_sources = list(leading_lines)
    for index in sorted(sources_by_index):
        new_sources.extend(sources_by_index[index])

    return joined_answer + "\n".join(new_sources)