rasa_pro-3.8.16-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (644)
  1. README.md +380 -0
  2. rasa/__init__.py +10 -0
  3. rasa/__main__.py +151 -0
  4. rasa/anonymization/__init__.py +2 -0
  5. rasa/anonymization/anonymisation_rule_yaml_reader.py +91 -0
  6. rasa/anonymization/anonymization_pipeline.py +287 -0
  7. rasa/anonymization/anonymization_rule_executor.py +260 -0
  8. rasa/anonymization/anonymization_rule_orchestrator.py +120 -0
  9. rasa/anonymization/schemas/config.yml +47 -0
  10. rasa/anonymization/utils.py +117 -0
  11. rasa/api.py +146 -0
  12. rasa/cli/__init__.py +5 -0
  13. rasa/cli/arguments/__init__.py +0 -0
  14. rasa/cli/arguments/data.py +81 -0
  15. rasa/cli/arguments/default_arguments.py +165 -0
  16. rasa/cli/arguments/evaluate.py +65 -0
  17. rasa/cli/arguments/export.py +51 -0
  18. rasa/cli/arguments/interactive.py +74 -0
  19. rasa/cli/arguments/run.py +204 -0
  20. rasa/cli/arguments/shell.py +13 -0
  21. rasa/cli/arguments/test.py +211 -0
  22. rasa/cli/arguments/train.py +263 -0
  23. rasa/cli/arguments/visualize.py +34 -0
  24. rasa/cli/arguments/x.py +30 -0
  25. rasa/cli/data.py +292 -0
  26. rasa/cli/e2e_test.py +566 -0
  27. rasa/cli/evaluate.py +222 -0
  28. rasa/cli/export.py +251 -0
  29. rasa/cli/inspect.py +63 -0
  30. rasa/cli/interactive.py +164 -0
  31. rasa/cli/license.py +65 -0
  32. rasa/cli/markers.py +78 -0
  33. rasa/cli/project_templates/__init__.py +0 -0
  34. rasa/cli/project_templates/calm/actions/__init__.py +0 -0
  35. rasa/cli/project_templates/calm/actions/action_template.py +27 -0
  36. rasa/cli/project_templates/calm/actions/add_contact.py +30 -0
  37. rasa/cli/project_templates/calm/actions/db.py +57 -0
  38. rasa/cli/project_templates/calm/actions/list_contacts.py +22 -0
  39. rasa/cli/project_templates/calm/actions/remove_contact.py +35 -0
  40. rasa/cli/project_templates/calm/config.yml +12 -0
  41. rasa/cli/project_templates/calm/credentials.yml +33 -0
  42. rasa/cli/project_templates/calm/data/flows/add_contact.yml +31 -0
  43. rasa/cli/project_templates/calm/data/flows/list_contacts.yml +14 -0
  44. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +29 -0
  45. rasa/cli/project_templates/calm/db/contacts.json +10 -0
  46. rasa/cli/project_templates/calm/domain/add_contact.yml +33 -0
  47. rasa/cli/project_templates/calm/domain/list_contacts.yml +14 -0
  48. rasa/cli/project_templates/calm/domain/remove_contact.yml +31 -0
  49. rasa/cli/project_templates/calm/domain/shared.yml +5 -0
  50. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_cancels_during_a_correction.yml +16 -0
  51. rasa/cli/project_templates/calm/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +7 -0
  52. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_handle.yml +20 -0
  53. rasa/cli/project_templates/calm/e2e_tests/corrections/user_corrects_contact_name.yml +19 -0
  54. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +15 -0
  55. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_lists_contacts.yml +5 -0
  56. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact.yml +11 -0
  57. rasa/cli/project_templates/calm/e2e_tests/happy_paths/user_removes_contact_from_list.yml +12 -0
  58. rasa/cli/project_templates/calm/endpoints.yml +45 -0
  59. rasa/cli/project_templates/default/actions/__init__.py +0 -0
  60. rasa/cli/project_templates/default/actions/actions.py +27 -0
  61. rasa/cli/project_templates/default/config.yml +44 -0
  62. rasa/cli/project_templates/default/credentials.yml +33 -0
  63. rasa/cli/project_templates/default/data/nlu.yml +91 -0
  64. rasa/cli/project_templates/default/data/rules.yml +13 -0
  65. rasa/cli/project_templates/default/data/stories.yml +30 -0
  66. rasa/cli/project_templates/default/domain.yml +34 -0
  67. rasa/cli/project_templates/default/endpoints.yml +42 -0
  68. rasa/cli/project_templates/default/tests/test_stories.yml +91 -0
  69. rasa/cli/project_templates/tutorial/actions.py +22 -0
  70. rasa/cli/project_templates/tutorial/config.yml +11 -0
  71. rasa/cli/project_templates/tutorial/credentials.yml +33 -0
  72. rasa/cli/project_templates/tutorial/data/flows.yml +8 -0
  73. rasa/cli/project_templates/tutorial/domain.yml +17 -0
  74. rasa/cli/project_templates/tutorial/endpoints.yml +45 -0
  75. rasa/cli/run.py +136 -0
  76. rasa/cli/scaffold.py +268 -0
  77. rasa/cli/shell.py +141 -0
  78. rasa/cli/studio/__init__.py +0 -0
  79. rasa/cli/studio/download.py +51 -0
  80. rasa/cli/studio/studio.py +110 -0
  81. rasa/cli/studio/train.py +59 -0
  82. rasa/cli/studio/upload.py +85 -0
  83. rasa/cli/telemetry.py +90 -0
  84. rasa/cli/test.py +280 -0
  85. rasa/cli/train.py +260 -0
  86. rasa/cli/utils.py +453 -0
  87. rasa/cli/visualize.py +40 -0
  88. rasa/cli/x.py +205 -0
  89. rasa/constants.py +37 -0
  90. rasa/core/__init__.py +17 -0
  91. rasa/core/actions/__init__.py +0 -0
  92. rasa/core/actions/action.py +1450 -0
  93. rasa/core/actions/action_clean_stack.py +59 -0
  94. rasa/core/actions/action_run_slot_rejections.py +207 -0
  95. rasa/core/actions/action_trigger_chitchat.py +31 -0
  96. rasa/core/actions/action_trigger_flow.py +109 -0
  97. rasa/core/actions/action_trigger_search.py +31 -0
  98. rasa/core/actions/constants.py +2 -0
  99. rasa/core/actions/forms.py +737 -0
  100. rasa/core/actions/loops.py +111 -0
  101. rasa/core/actions/two_stage_fallback.py +186 -0
  102. rasa/core/agent.py +557 -0
  103. rasa/core/auth_retry_tracker_store.py +122 -0
  104. rasa/core/brokers/__init__.py +0 -0
  105. rasa/core/brokers/broker.py +126 -0
  106. rasa/core/brokers/file.py +58 -0
  107. rasa/core/brokers/kafka.py +322 -0
  108. rasa/core/brokers/pika.py +387 -0
  109. rasa/core/brokers/sql.py +86 -0
  110. rasa/core/channels/__init__.py +55 -0
  111. rasa/core/channels/audiocodes.py +463 -0
  112. rasa/core/channels/botframework.py +339 -0
  113. rasa/core/channels/callback.py +85 -0
  114. rasa/core/channels/channel.py +419 -0
  115. rasa/core/channels/console.py +243 -0
  116. rasa/core/channels/development_inspector.py +93 -0
  117. rasa/core/channels/facebook.py +422 -0
  118. rasa/core/channels/hangouts.py +335 -0
  119. rasa/core/channels/inspector/.eslintrc.cjs +25 -0
  120. rasa/core/channels/inspector/.gitignore +23 -0
  121. rasa/core/channels/inspector/README.md +54 -0
  122. rasa/core/channels/inspector/assets/favicon.ico +0 -0
  123. rasa/core/channels/inspector/assets/rasa-chat.js +2 -0
  124. rasa/core/channels/inspector/custom.d.ts +3 -0
  125. rasa/core/channels/inspector/dist/assets/arc-5623b6dc.js +1 -0
  126. rasa/core/channels/inspector/dist/assets/array-9f3ba611.js +1 -0
  127. rasa/core/channels/inspector/dist/assets/c4Diagram-d0fbc5ce-685c106a.js +10 -0
  128. rasa/core/channels/inspector/dist/assets/classDiagram-936ed81e-8cbed007.js +2 -0
  129. rasa/core/channels/inspector/dist/assets/classDiagram-v2-c3cb15f1-5889cf12.js +2 -0
  130. rasa/core/channels/inspector/dist/assets/createText-62fc7601-24c249d7.js +7 -0
  131. rasa/core/channels/inspector/dist/assets/edges-f2ad444c-7dd06a75.js +4 -0
  132. rasa/core/channels/inspector/dist/assets/erDiagram-9d236eb7-62c1e54c.js +51 -0
  133. rasa/core/channels/inspector/dist/assets/flowDb-1972c806-ce49b86f.js +6 -0
  134. rasa/core/channels/inspector/dist/assets/flowDiagram-7ea5b25a-4067e48f.js +4 -0
  135. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-85583a23.js +1 -0
  136. rasa/core/channels/inspector/dist/assets/flowchart-elk-definition-abe16c3d-59fe4051.js +139 -0
  137. rasa/core/channels/inspector/dist/assets/ganttDiagram-9b5ea136-47e3a43b.js +266 -0
  138. rasa/core/channels/inspector/dist/assets/gitGraphDiagram-99d0ae7c-5a2ac0d9.js +70 -0
  139. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-128cfa44.ttf +0 -0
  140. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-21dbcb97.woff +0 -0
  141. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-222b5e26.svg +329 -0
  142. rasa/core/channels/inspector/dist/assets/ibm-plex-mono-v4-latin-regular-9ad89b2a.woff2 +0 -0
  143. rasa/core/channels/inspector/dist/assets/index-268a75c0.js +1040 -0
  144. rasa/core/channels/inspector/dist/assets/index-2c4b9a3b-dfb8efc4.js +1 -0
  145. rasa/core/channels/inspector/dist/assets/index-3ee28881.css +1 -0
  146. rasa/core/channels/inspector/dist/assets/infoDiagram-736b4530-b0c470f2.js +7 -0
  147. rasa/core/channels/inspector/dist/assets/init-77b53fdd.js +1 -0
  148. rasa/core/channels/inspector/dist/assets/journeyDiagram-df861f2b-2edb829a.js +139 -0
  149. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-60c05ee4.woff +0 -0
  150. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-8335d9b8.svg +438 -0
  151. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-9cc39c75.ttf +0 -0
  152. rasa/core/channels/inspector/dist/assets/lato-v14-latin-700-ead13ccf.woff2 +0 -0
  153. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-16705655.woff2 +0 -0
  154. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-5aeb07f9.woff +0 -0
  155. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9c459044.ttf +0 -0
  156. rasa/core/channels/inspector/dist/assets/lato-v14-latin-regular-9e2898a4.svg +435 -0
  157. rasa/core/channels/inspector/dist/assets/layout-b6873d69.js +1 -0
  158. rasa/core/channels/inspector/dist/assets/line-1efc5781.js +1 -0
  159. rasa/core/channels/inspector/dist/assets/linear-661e9b94.js +1 -0
  160. rasa/core/channels/inspector/dist/assets/mindmap-definition-beec6740-2d2e727f.js +109 -0
  161. rasa/core/channels/inspector/dist/assets/ordinal-ba9b4969.js +1 -0
  162. rasa/core/channels/inspector/dist/assets/path-53f90ab3.js +1 -0
  163. rasa/core/channels/inspector/dist/assets/pieDiagram-dbbf0591-9d3ea93d.js +35 -0
  164. rasa/core/channels/inspector/dist/assets/quadrantDiagram-4d7f4fd6-06a178a2.js +7 -0
  165. rasa/core/channels/inspector/dist/assets/requirementDiagram-6fc4c22a-0bfedffc.js +52 -0
  166. rasa/core/channels/inspector/dist/assets/sankeyDiagram-8f13d901-d76d0a04.js +8 -0
  167. rasa/core/channels/inspector/dist/assets/sequenceDiagram-b655622a-37bb4341.js +122 -0
  168. rasa/core/channels/inspector/dist/assets/stateDiagram-59f0c015-f52f7f57.js +1 -0
  169. rasa/core/channels/inspector/dist/assets/stateDiagram-v2-2b26beab-4a986a20.js +1 -0
  170. rasa/core/channels/inspector/dist/assets/styles-080da4f6-7dd9ae12.js +110 -0
  171. rasa/core/channels/inspector/dist/assets/styles-3dcbcfbf-46e1ca14.js +159 -0
  172. rasa/core/channels/inspector/dist/assets/styles-9c745c82-4a97439a.js +207 -0
  173. rasa/core/channels/inspector/dist/assets/svgDrawCommon-4835440b-823917a3.js +1 -0
  174. rasa/core/channels/inspector/dist/assets/timeline-definition-5b62e21b-9ea72896.js +61 -0
  175. rasa/core/channels/inspector/dist/assets/xychartDiagram-2b33534f-b631a8b6.js +7 -0
  176. rasa/core/channels/inspector/dist/index.html +39 -0
  177. rasa/core/channels/inspector/index.html +37 -0
  178. rasa/core/channels/inspector/jest.config.ts +13 -0
  179. rasa/core/channels/inspector/package.json +48 -0
  180. rasa/core/channels/inspector/setupTests.ts +2 -0
  181. rasa/core/channels/inspector/src/App.tsx +170 -0
  182. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +97 -0
  183. rasa/core/channels/inspector/src/components/DialogueInformation.tsx +187 -0
  184. rasa/core/channels/inspector/src/components/DialogueStack.tsx +151 -0
  185. rasa/core/channels/inspector/src/components/ExpandIcon.tsx +16 -0
  186. rasa/core/channels/inspector/src/components/FullscreenButton.tsx +45 -0
  187. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +19 -0
  188. rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +21 -0
  189. rasa/core/channels/inspector/src/components/RasaLogo.tsx +32 -0
  190. rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +39 -0
  191. rasa/core/channels/inspector/src/components/Slots.tsx +91 -0
  192. rasa/core/channels/inspector/src/components/Welcome.tsx +54 -0
  193. rasa/core/channels/inspector/src/helpers/formatters.test.ts +385 -0
  194. rasa/core/channels/inspector/src/helpers/formatters.ts +239 -0
  195. rasa/core/channels/inspector/src/helpers/utils.ts +42 -0
  196. rasa/core/channels/inspector/src/main.tsx +13 -0
  197. rasa/core/channels/inspector/src/theme/Button/Button.ts +29 -0
  198. rasa/core/channels/inspector/src/theme/Heading/Heading.ts +31 -0
  199. rasa/core/channels/inspector/src/theme/Input/Input.ts +27 -0
  200. rasa/core/channels/inspector/src/theme/Link/Link.ts +10 -0
  201. rasa/core/channels/inspector/src/theme/Modal/Modal.ts +47 -0
  202. rasa/core/channels/inspector/src/theme/Table/Table.tsx +38 -0
  203. rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +12 -0
  204. rasa/core/channels/inspector/src/theme/base/breakpoints.ts +8 -0
  205. rasa/core/channels/inspector/src/theme/base/colors.ts +88 -0
  206. rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +29 -0
  207. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.eot +0 -0
  208. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.svg +329 -0
  209. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.ttf +0 -0
  210. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff +0 -0
  211. rasa/core/channels/inspector/src/theme/base/fonts/ibm-plex-mono-v4-latin/ibm-plex-mono-v4-latin-regular.woff2 +0 -0
  212. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.eot +0 -0
  213. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.svg +438 -0
  214. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.ttf +0 -0
  215. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff +0 -0
  216. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-700.woff2 +0 -0
  217. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.eot +0 -0
  218. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.svg +435 -0
  219. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.ttf +0 -0
  220. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff +0 -0
  221. rasa/core/channels/inspector/src/theme/base/fonts/lato-v14-latin/lato-v14-latin-regular.woff2 +0 -0
  222. rasa/core/channels/inspector/src/theme/base/radii.ts +9 -0
  223. rasa/core/channels/inspector/src/theme/base/shadows.ts +7 -0
  224. rasa/core/channels/inspector/src/theme/base/sizes.ts +7 -0
  225. rasa/core/channels/inspector/src/theme/base/space.ts +15 -0
  226. rasa/core/channels/inspector/src/theme/base/styles.ts +13 -0
  227. rasa/core/channels/inspector/src/theme/base/typography.ts +24 -0
  228. rasa/core/channels/inspector/src/theme/base/zIndices.ts +19 -0
  229. rasa/core/channels/inspector/src/theme/index.ts +101 -0
  230. rasa/core/channels/inspector/src/types.ts +64 -0
  231. rasa/core/channels/inspector/src/vite-env.d.ts +1 -0
  232. rasa/core/channels/inspector/tests/__mocks__/fileMock.ts +1 -0
  233. rasa/core/channels/inspector/tests/__mocks__/matchMedia.ts +16 -0
  234. rasa/core/channels/inspector/tests/__mocks__/styleMock.ts +1 -0
  235. rasa/core/channels/inspector/tests/renderWithProviders.tsx +14 -0
  236. rasa/core/channels/inspector/tsconfig.json +26 -0
  237. rasa/core/channels/inspector/tsconfig.node.json +10 -0
  238. rasa/core/channels/inspector/vite.config.ts +8 -0
  239. rasa/core/channels/inspector/yarn.lock +6156 -0
  240. rasa/core/channels/mattermost.py +229 -0
  241. rasa/core/channels/rasa_chat.py +126 -0
  242. rasa/core/channels/rest.py +210 -0
  243. rasa/core/channels/rocketchat.py +175 -0
  244. rasa/core/channels/slack.py +620 -0
  245. rasa/core/channels/socketio.py +274 -0
  246. rasa/core/channels/telegram.py +298 -0
  247. rasa/core/channels/twilio.py +169 -0
  248. rasa/core/channels/twilio_voice.py +367 -0
  249. rasa/core/channels/vier_cvg.py +374 -0
  250. rasa/core/channels/webexteams.py +135 -0
  251. rasa/core/concurrent_lock_store.py +210 -0
  252. rasa/core/constants.py +107 -0
  253. rasa/core/evaluation/__init__.py +0 -0
  254. rasa/core/evaluation/marker.py +267 -0
  255. rasa/core/evaluation/marker_base.py +925 -0
  256. rasa/core/evaluation/marker_stats.py +294 -0
  257. rasa/core/evaluation/marker_tracker_loader.py +103 -0
  258. rasa/core/exceptions.py +29 -0
  259. rasa/core/exporter.py +284 -0
  260. rasa/core/featurizers/__init__.py +0 -0
  261. rasa/core/featurizers/precomputation.py +410 -0
  262. rasa/core/featurizers/single_state_featurizer.py +402 -0
  263. rasa/core/featurizers/tracker_featurizers.py +1172 -0
  264. rasa/core/http_interpreter.py +89 -0
  265. rasa/core/information_retrieval/__init__.py +0 -0
  266. rasa/core/information_retrieval/faiss.py +116 -0
  267. rasa/core/information_retrieval/information_retrieval.py +72 -0
  268. rasa/core/information_retrieval/milvus.py +59 -0
  269. rasa/core/information_retrieval/qdrant.py +102 -0
  270. rasa/core/jobs.py +63 -0
  271. rasa/core/lock.py +139 -0
  272. rasa/core/lock_store.py +344 -0
  273. rasa/core/migrate.py +404 -0
  274. rasa/core/nlg/__init__.py +3 -0
  275. rasa/core/nlg/callback.py +147 -0
  276. rasa/core/nlg/contextual_response_rephraser.py +270 -0
  277. rasa/core/nlg/generator.py +230 -0
  278. rasa/core/nlg/interpolator.py +143 -0
  279. rasa/core/nlg/response.py +155 -0
  280. rasa/core/nlg/summarize.py +69 -0
  281. rasa/core/policies/__init__.py +0 -0
  282. rasa/core/policies/ensemble.py +329 -0
  283. rasa/core/policies/enterprise_search_policy.py +717 -0
  284. rasa/core/policies/enterprise_search_prompt_template.jinja2 +62 -0
  285. rasa/core/policies/flow_policy.py +205 -0
  286. rasa/core/policies/flows/__init__.py +0 -0
  287. rasa/core/policies/flows/flow_exceptions.py +44 -0
  288. rasa/core/policies/flows/flow_executor.py +582 -0
  289. rasa/core/policies/flows/flow_step_result.py +43 -0
  290. rasa/core/policies/intentless_policy.py +924 -0
  291. rasa/core/policies/intentless_prompt_template.jinja2 +22 -0
  292. rasa/core/policies/memoization.py +538 -0
  293. rasa/core/policies/policy.py +716 -0
  294. rasa/core/policies/rule_policy.py +1276 -0
  295. rasa/core/policies/ted_policy.py +2146 -0
  296. rasa/core/policies/unexpected_intent_policy.py +1015 -0
  297. rasa/core/processor.py +1331 -0
  298. rasa/core/run.py +315 -0
  299. rasa/core/secrets_manager/__init__.py +0 -0
  300. rasa/core/secrets_manager/constants.py +32 -0
  301. rasa/core/secrets_manager/endpoints.py +391 -0
  302. rasa/core/secrets_manager/factory.py +233 -0
  303. rasa/core/secrets_manager/secret_manager.py +262 -0
  304. rasa/core/secrets_manager/vault.py +576 -0
  305. rasa/core/test.py +1337 -0
  306. rasa/core/tracker_store.py +1664 -0
  307. rasa/core/train.py +107 -0
  308. rasa/core/training/__init__.py +89 -0
  309. rasa/core/training/converters/__init__.py +0 -0
  310. rasa/core/training/converters/responses_prefix_converter.py +119 -0
  311. rasa/core/training/interactive.py +1742 -0
  312. rasa/core/training/story_conflict.py +381 -0
  313. rasa/core/training/training.py +93 -0
  314. rasa/core/utils.py +344 -0
  315. rasa/core/visualize.py +70 -0
  316. rasa/dialogue_understanding/__init__.py +0 -0
  317. rasa/dialogue_understanding/coexistence/__init__.py +0 -0
  318. rasa/dialogue_understanding/coexistence/constants.py +4 -0
  319. rasa/dialogue_understanding/coexistence/intent_based_router.py +189 -0
  320. rasa/dialogue_understanding/coexistence/llm_based_router.py +261 -0
  321. rasa/dialogue_understanding/coexistence/router_template.jinja2 +12 -0
  322. rasa/dialogue_understanding/commands/__init__.py +45 -0
  323. rasa/dialogue_understanding/commands/can_not_handle_command.py +61 -0
  324. rasa/dialogue_understanding/commands/cancel_flow_command.py +116 -0
  325. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +48 -0
  326. rasa/dialogue_understanding/commands/clarify_command.py +77 -0
  327. rasa/dialogue_understanding/commands/command.py +85 -0
  328. rasa/dialogue_understanding/commands/correct_slots_command.py +288 -0
  329. rasa/dialogue_understanding/commands/error_command.py +67 -0
  330. rasa/dialogue_understanding/commands/free_form_answer_command.py +9 -0
  331. rasa/dialogue_understanding/commands/handle_code_change_command.py +64 -0
  332. rasa/dialogue_understanding/commands/human_handoff_command.py +57 -0
  333. rasa/dialogue_understanding/commands/knowledge_answer_command.py +48 -0
  334. rasa/dialogue_understanding/commands/noop_command.py +45 -0
  335. rasa/dialogue_understanding/commands/set_slot_command.py +125 -0
  336. rasa/dialogue_understanding/commands/skip_question_command.py +66 -0
  337. rasa/dialogue_understanding/commands/start_flow_command.py +98 -0
  338. rasa/dialogue_understanding/generator/__init__.py +6 -0
  339. rasa/dialogue_understanding/generator/command_generator.py +257 -0
  340. rasa/dialogue_understanding/generator/command_prompt_template.jinja2 +57 -0
  341. rasa/dialogue_understanding/generator/flow_document_template.jinja2 +4 -0
  342. rasa/dialogue_understanding/generator/flow_retrieval.py +410 -0
  343. rasa/dialogue_understanding/generator/llm_command_generator.py +637 -0
  344. rasa/dialogue_understanding/generator/nlu_command_adapter.py +157 -0
  345. rasa/dialogue_understanding/patterns/__init__.py +0 -0
  346. rasa/dialogue_understanding/patterns/cancel.py +111 -0
  347. rasa/dialogue_understanding/patterns/cannot_handle.py +43 -0
  348. rasa/dialogue_understanding/patterns/chitchat.py +37 -0
  349. rasa/dialogue_understanding/patterns/clarify.py +97 -0
  350. rasa/dialogue_understanding/patterns/code_change.py +41 -0
  351. rasa/dialogue_understanding/patterns/collect_information.py +90 -0
  352. rasa/dialogue_understanding/patterns/completed.py +40 -0
  353. rasa/dialogue_understanding/patterns/continue_interrupted.py +42 -0
  354. rasa/dialogue_understanding/patterns/correction.py +278 -0
  355. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +243 -0
  356. rasa/dialogue_understanding/patterns/human_handoff.py +37 -0
  357. rasa/dialogue_understanding/patterns/internal_error.py +47 -0
  358. rasa/dialogue_understanding/patterns/search.py +37 -0
  359. rasa/dialogue_understanding/patterns/skip_question.py +38 -0
  360. rasa/dialogue_understanding/processor/__init__.py +0 -0
  361. rasa/dialogue_understanding/processor/command_processor.py +578 -0
  362. rasa/dialogue_understanding/processor/command_processor_component.py +39 -0
  363. rasa/dialogue_understanding/stack/__init__.py +0 -0
  364. rasa/dialogue_understanding/stack/dialogue_stack.py +178 -0
  365. rasa/dialogue_understanding/stack/frames/__init__.py +19 -0
  366. rasa/dialogue_understanding/stack/frames/chit_chat_frame.py +27 -0
  367. rasa/dialogue_understanding/stack/frames/dialogue_stack_frame.py +137 -0
  368. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +157 -0
  369. rasa/dialogue_understanding/stack/frames/pattern_frame.py +10 -0
  370. rasa/dialogue_understanding/stack/frames/search_frame.py +27 -0
  371. rasa/dialogue_understanding/stack/utils.py +211 -0
  372. rasa/e2e_test/__init__.py +0 -0
  373. rasa/e2e_test/constants.py +10 -0
  374. rasa/e2e_test/e2e_test_case.py +322 -0
  375. rasa/e2e_test/e2e_test_result.py +34 -0
  376. rasa/e2e_test/e2e_test_runner.py +659 -0
  377. rasa/e2e_test/e2e_test_schema.yml +67 -0
  378. rasa/engine/__init__.py +0 -0
  379. rasa/engine/caching.py +464 -0
  380. rasa/engine/constants.py +17 -0
  381. rasa/engine/exceptions.py +14 -0
  382. rasa/engine/graph.py +625 -0
  383. rasa/engine/loader.py +36 -0
  384. rasa/engine/recipes/__init__.py +0 -0
  385. rasa/engine/recipes/config_files/default_config.yml +44 -0
  386. rasa/engine/recipes/default_components.py +99 -0
  387. rasa/engine/recipes/default_recipe.py +1252 -0
  388. rasa/engine/recipes/graph_recipe.py +79 -0
  389. rasa/engine/recipes/recipe.py +93 -0
  390. rasa/engine/runner/__init__.py +0 -0
  391. rasa/engine/runner/dask.py +256 -0
  392. rasa/engine/runner/interface.py +49 -0
  393. rasa/engine/storage/__init__.py +0 -0
  394. rasa/engine/storage/local_model_storage.py +248 -0
  395. rasa/engine/storage/resource.py +110 -0
  396. rasa/engine/storage/storage.py +203 -0
  397. rasa/engine/training/__init__.py +0 -0
  398. rasa/engine/training/components.py +176 -0
  399. rasa/engine/training/fingerprinting.py +64 -0
  400. rasa/engine/training/graph_trainer.py +256 -0
  401. rasa/engine/training/hooks.py +164 -0
  402. rasa/engine/validation.py +839 -0
  403. rasa/env.py +5 -0
  404. rasa/exceptions.py +69 -0
  405. rasa/graph_components/__init__.py +0 -0
  406. rasa/graph_components/converters/__init__.py +0 -0
  407. rasa/graph_components/converters/nlu_message_converter.py +48 -0
  408. rasa/graph_components/providers/__init__.py +0 -0
  409. rasa/graph_components/providers/domain_for_core_training_provider.py +87 -0
  410. rasa/graph_components/providers/domain_provider.py +71 -0
  411. rasa/graph_components/providers/flows_provider.py +74 -0
  412. rasa/graph_components/providers/forms_provider.py +44 -0
  413. rasa/graph_components/providers/nlu_training_data_provider.py +56 -0
  414. rasa/graph_components/providers/responses_provider.py +44 -0
  415. rasa/graph_components/providers/rule_only_provider.py +49 -0
  416. rasa/graph_components/providers/story_graph_provider.py +43 -0
  417. rasa/graph_components/providers/training_tracker_provider.py +55 -0
  418. rasa/graph_components/validators/__init__.py +0 -0
  419. rasa/graph_components/validators/default_recipe_validator.py +552 -0
  420. rasa/graph_components/validators/finetuning_validator.py +302 -0
  421. rasa/hooks.py +113 -0
  422. rasa/jupyter.py +63 -0
  423. rasa/keys +1 -0
  424. rasa/markers/__init__.py +0 -0
  425. rasa/markers/marker.py +269 -0
  426. rasa/markers/marker_base.py +828 -0
  427. rasa/markers/upload.py +74 -0
  428. rasa/markers/validate.py +21 -0
  429. rasa/model.py +118 -0
  430. rasa/model_testing.py +457 -0
  431. rasa/model_training.py +535 -0
  432. rasa/nlu/__init__.py +7 -0
  433. rasa/nlu/classifiers/__init__.py +3 -0
  434. rasa/nlu/classifiers/classifier.py +5 -0
  435. rasa/nlu/classifiers/diet_classifier.py +1874 -0
  436. rasa/nlu/classifiers/fallback_classifier.py +192 -0
  437. rasa/nlu/classifiers/keyword_intent_classifier.py +188 -0
  438. rasa/nlu/classifiers/llm_intent_classifier.py +519 -0
  439. rasa/nlu/classifiers/logistic_regression_classifier.py +240 -0
  440. rasa/nlu/classifiers/mitie_intent_classifier.py +156 -0
  441. rasa/nlu/classifiers/regex_message_handler.py +56 -0
  442. rasa/nlu/classifiers/sklearn_intent_classifier.py +309 -0
  443. rasa/nlu/constants.py +77 -0
  444. rasa/nlu/convert.py +40 -0
  445. rasa/nlu/emulators/__init__.py +0 -0
  446. rasa/nlu/emulators/dialogflow.py +55 -0
  447. rasa/nlu/emulators/emulator.py +49 -0
  448. rasa/nlu/emulators/luis.py +86 -0
  449. rasa/nlu/emulators/no_emulator.py +10 -0
  450. rasa/nlu/emulators/wit.py +56 -0
  451. rasa/nlu/extractors/__init__.py +0 -0
  452. rasa/nlu/extractors/crf_entity_extractor.py +672 -0
  453. rasa/nlu/extractors/duckling_entity_extractor.py +206 -0
  454. rasa/nlu/extractors/entity_synonyms.py +178 -0
  455. rasa/nlu/extractors/extractor.py +470 -0
  456. rasa/nlu/extractors/mitie_entity_extractor.py +293 -0
  457. rasa/nlu/extractors/regex_entity_extractor.py +220 -0
  458. rasa/nlu/extractors/spacy_entity_extractor.py +95 -0
  459. rasa/nlu/featurizers/__init__.py +0 -0
  460. rasa/nlu/featurizers/dense_featurizer/__init__.py +0 -0
  461. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +449 -0
  462. rasa/nlu/featurizers/dense_featurizer/dense_featurizer.py +57 -0
  463. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +772 -0
  464. rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +170 -0
  465. rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +132 -0
  466. rasa/nlu/featurizers/featurizer.py +89 -0
  467. rasa/nlu/featurizers/sparse_featurizer/__init__.py +0 -0
  468. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +840 -0
  469. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +539 -0
  470. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +269 -0
  471. rasa/nlu/featurizers/sparse_featurizer/sparse_featurizer.py +9 -0
  472. rasa/nlu/model.py +24 -0
  473. rasa/nlu/persistor.py +240 -0
  474. rasa/nlu/run.py +27 -0
  475. rasa/nlu/selectors/__init__.py +0 -0
  476. rasa/nlu/selectors/response_selector.py +990 -0
  477. rasa/nlu/test.py +1943 -0
  478. rasa/nlu/tokenizers/__init__.py +0 -0
  479. rasa/nlu/tokenizers/jieba_tokenizer.py +148 -0
  480. rasa/nlu/tokenizers/mitie_tokenizer.py +75 -0
  481. rasa/nlu/tokenizers/spacy_tokenizer.py +72 -0
  482. rasa/nlu/tokenizers/tokenizer.py +239 -0
  483. rasa/nlu/tokenizers/whitespace_tokenizer.py +106 -0
  484. rasa/nlu/utils/__init__.py +35 -0
  485. rasa/nlu/utils/bilou_utils.py +462 -0
  486. rasa/nlu/utils/hugging_face/__init__.py +0 -0
  487. rasa/nlu/utils/hugging_face/registry.py +108 -0
  488. rasa/nlu/utils/hugging_face/transformers_pre_post_processors.py +311 -0
  489. rasa/nlu/utils/mitie_utils.py +113 -0
  490. rasa/nlu/utils/pattern_utils.py +168 -0
  491. rasa/nlu/utils/spacy_utils.py +312 -0
  492. rasa/plugin.py +90 -0
  493. rasa/server.py +1536 -0
  494. rasa/shared/__init__.py +0 -0
  495. rasa/shared/constants.py +181 -0
  496. rasa/shared/core/__init__.py +0 -0
  497. rasa/shared/core/constants.py +168 -0
  498. rasa/shared/core/conversation.py +46 -0
  499. rasa/shared/core/domain.py +2106 -0
  500. rasa/shared/core/events.py +2507 -0
  501. rasa/shared/core/flows/__init__.py +7 -0
  502. rasa/shared/core/flows/flow.py +353 -0
  503. rasa/shared/core/flows/flow_step.py +146 -0
  504. rasa/shared/core/flows/flow_step_links.py +319 -0
  505. rasa/shared/core/flows/flow_step_sequence.py +70 -0
  506. rasa/shared/core/flows/flows_list.py +211 -0
  507. rasa/shared/core/flows/flows_yaml_schema.json +217 -0
  508. rasa/shared/core/flows/nlu_trigger.py +117 -0
  509. rasa/shared/core/flows/steps/__init__.py +24 -0
  510. rasa/shared/core/flows/steps/action.py +51 -0
  511. rasa/shared/core/flows/steps/call.py +64 -0
  512. rasa/shared/core/flows/steps/collect.py +112 -0
  513. rasa/shared/core/flows/steps/constants.py +5 -0
  514. rasa/shared/core/flows/steps/continuation.py +36 -0
  515. rasa/shared/core/flows/steps/end.py +22 -0
  516. rasa/shared/core/flows/steps/internal.py +44 -0
  517. rasa/shared/core/flows/steps/link.py +51 -0
  518. rasa/shared/core/flows/steps/no_operation.py +48 -0
  519. rasa/shared/core/flows/steps/set_slots.py +50 -0
  520. rasa/shared/core/flows/steps/start.py +30 -0
  521. rasa/shared/core/flows/validation.py +527 -0
  522. rasa/shared/core/flows/yaml_flows_io.py +278 -0
  523. rasa/shared/core/generator.py +907 -0
  524. rasa/shared/core/slot_mappings.py +235 -0
  525. rasa/shared/core/slots.py +647 -0
  526. rasa/shared/core/trackers.py +1159 -0
  527. rasa/shared/core/training_data/__init__.py +0 -0
  528. rasa/shared/core/training_data/loading.py +90 -0
  529. rasa/shared/core/training_data/story_reader/__init__.py +0 -0
  530. rasa/shared/core/training_data/story_reader/story_reader.py +129 -0
  531. rasa/shared/core/training_data/story_reader/story_step_builder.py +168 -0
  532. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +888 -0
  533. rasa/shared/core/training_data/story_writer/__init__.py +0 -0
  534. rasa/shared/core/training_data/story_writer/story_writer.py +76 -0
  535. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +442 -0
  536. rasa/shared/core/training_data/structures.py +838 -0
  537. rasa/shared/core/training_data/visualization.html +146 -0
  538. rasa/shared/core/training_data/visualization.py +603 -0
  539. rasa/shared/data.py +192 -0
  540. rasa/shared/engine/__init__.py +0 -0
  541. rasa/shared/engine/caching.py +26 -0
  542. rasa/shared/exceptions.py +129 -0
  543. rasa/shared/importers/__init__.py +0 -0
  544. rasa/shared/importers/importer.py +705 -0
  545. rasa/shared/importers/multi_project.py +203 -0
  546. rasa/shared/importers/rasa.py +100 -0
  547. rasa/shared/importers/utils.py +34 -0
  548. rasa/shared/nlu/__init__.py +0 -0
  549. rasa/shared/nlu/constants.py +45 -0
  550. rasa/shared/nlu/interpreter.py +10 -0
  551. rasa/shared/nlu/training_data/__init__.py +0 -0
  552. rasa/shared/nlu/training_data/entities_parser.py +209 -0
  553. rasa/shared/nlu/training_data/features.py +374 -0
  554. rasa/shared/nlu/training_data/formats/__init__.py +10 -0
  555. rasa/shared/nlu/training_data/formats/dialogflow.py +162 -0
  556. rasa/shared/nlu/training_data/formats/luis.py +87 -0
  557. rasa/shared/nlu/training_data/formats/rasa.py +135 -0
  558. rasa/shared/nlu/training_data/formats/rasa_yaml.py +605 -0
  559. rasa/shared/nlu/training_data/formats/readerwriter.py +245 -0
  560. rasa/shared/nlu/training_data/formats/wit.py +52 -0
  561. rasa/shared/nlu/training_data/loading.py +137 -0
  562. rasa/shared/nlu/training_data/lookup_tables_parser.py +30 -0
  563. rasa/shared/nlu/training_data/message.py +477 -0
  564. rasa/shared/nlu/training_data/schemas/__init__.py +0 -0
  565. rasa/shared/nlu/training_data/schemas/data_schema.py +85 -0
  566. rasa/shared/nlu/training_data/schemas/nlu.yml +53 -0
  567. rasa/shared/nlu/training_data/schemas/responses.yml +70 -0
  568. rasa/shared/nlu/training_data/synonyms_parser.py +42 -0
  569. rasa/shared/nlu/training_data/training_data.py +732 -0
  570. rasa/shared/nlu/training_data/util.py +223 -0
  571. rasa/shared/providers/__init__.py +0 -0
  572. rasa/shared/providers/openai/__init__.py +0 -0
  573. rasa/shared/providers/openai/clients.py +43 -0
  574. rasa/shared/providers/openai/session_handler.py +110 -0
  575. rasa/shared/utils/__init__.py +0 -0
  576. rasa/shared/utils/cli.py +72 -0
  577. rasa/shared/utils/common.py +308 -0
  578. rasa/shared/utils/constants.py +1 -0
  579. rasa/shared/utils/io.py +403 -0
  580. rasa/shared/utils/llm.py +405 -0
  581. rasa/shared/utils/pykwalify_extensions.py +26 -0
  582. rasa/shared/utils/schemas/__init__.py +0 -0
  583. rasa/shared/utils/schemas/config.yml +2 -0
  584. rasa/shared/utils/schemas/domain.yml +142 -0
  585. rasa/shared/utils/schemas/events.py +212 -0
  586. rasa/shared/utils/schemas/model_config.yml +46 -0
  587. rasa/shared/utils/schemas/stories.yml +173 -0
  588. rasa/shared/utils/yaml.py +777 -0
  589. rasa/studio/__init__.py +0 -0
  590. rasa/studio/auth.py +252 -0
  591. rasa/studio/config.py +127 -0
  592. rasa/studio/constants.py +16 -0
  593. rasa/studio/data_handler.py +352 -0
  594. rasa/studio/download.py +350 -0
  595. rasa/studio/train.py +136 -0
  596. rasa/studio/upload.py +408 -0
  597. rasa/telemetry.py +1583 -0
  598. rasa/tracing/__init__.py +0 -0
  599. rasa/tracing/config.py +338 -0
  600. rasa/tracing/constants.py +38 -0
  601. rasa/tracing/instrumentation/__init__.py +0 -0
  602. rasa/tracing/instrumentation/attribute_extractors.py +663 -0
  603. rasa/tracing/instrumentation/instrumentation.py +939 -0
  604. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +142 -0
  605. rasa/tracing/instrumentation/metrics.py +206 -0
  606. rasa/tracing/metric_instrument_provider.py +125 -0
  607. rasa/utils/__init__.py +0 -0
  608. rasa/utils/beta.py +83 -0
  609. rasa/utils/cli.py +27 -0
  610. rasa/utils/common.py +635 -0
  611. rasa/utils/converter.py +53 -0
  612. rasa/utils/endpoints.py +303 -0
  613. rasa/utils/io.py +326 -0
  614. rasa/utils/licensing.py +319 -0
  615. rasa/utils/log_utils.py +174 -0
  616. rasa/utils/mapper.py +210 -0
  617. rasa/utils/ml_utils.py +145 -0
  618. rasa/utils/plotting.py +362 -0
  619. rasa/utils/singleton.py +23 -0
  620. rasa/utils/tensorflow/__init__.py +0 -0
  621. rasa/utils/tensorflow/callback.py +112 -0
  622. rasa/utils/tensorflow/constants.py +116 -0
  623. rasa/utils/tensorflow/crf.py +492 -0
  624. rasa/utils/tensorflow/data_generator.py +440 -0
  625. rasa/utils/tensorflow/environment.py +161 -0
  626. rasa/utils/tensorflow/exceptions.py +5 -0
  627. rasa/utils/tensorflow/layers.py +1565 -0
  628. rasa/utils/tensorflow/layers_utils.py +113 -0
  629. rasa/utils/tensorflow/metrics.py +281 -0
  630. rasa/utils/tensorflow/model_data.py +991 -0
  631. rasa/utils/tensorflow/model_data_utils.py +500 -0
  632. rasa/utils/tensorflow/models.py +936 -0
  633. rasa/utils/tensorflow/rasa_layers.py +1094 -0
  634. rasa/utils/tensorflow/transformer.py +640 -0
  635. rasa/utils/tensorflow/types.py +6 -0
  636. rasa/utils/train_utils.py +572 -0
  637. rasa/utils/yaml.py +54 -0
  638. rasa/validator.py +1035 -0
  639. rasa/version.py +3 -0
  640. rasa_pro-3.8.16.dist-info/METADATA +528 -0
  641. rasa_pro-3.8.16.dist-info/NOTICE +5 -0
  642. rasa_pro-3.8.16.dist-info/RECORD +644 -0
  643. rasa_pro-3.8.16.dist-info/WHEEL +4 -0
  644. rasa_pro-3.8.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,991 @@
1
+ import logging
2
+ from typing import (
3
+ Optional,
4
+ DefaultDict,
5
+ Dict,
6
+ Iterable,
7
+ Text,
8
+ List,
9
+ Tuple,
10
+ Any,
11
+ Union,
12
+ NamedTuple,
13
+ ItemsView,
14
+ overload,
15
+ cast,
16
+ )
17
+ from collections import defaultdict, OrderedDict
18
+
19
+ import numpy as np
20
+ import scipy.sparse
21
+ from sklearn.model_selection import train_test_split
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
def ragged_array_to_ndarray(ragged_array: Iterable[np.ndarray]) -> np.ndarray:
    """Converts a ragged array to a numpy array.

    A ragged array (also known as a jagged or irregular array) is an array of
    arrays whose members can have different lengths.
    First tries a plain conversion, which preserves the element dtype. If that
    fails because not all members have the same shape, an array of objects is
    created instead.

    Args:
        ragged_array: The member arrays. Arbitrary iterables (e.g. generators)
            are materialized once, so a failed first conversion attempt cannot
            consume them.

    Returns:
        A regular numpy array if all members share one shape, otherwise an
        array with ``dtype=object``.
    """
    # np.array does not iterate generators (it would wrap the generator object
    # itself), so materialize anything that is not already a sized sequence.
    if isinstance(ragged_array, (list, tuple, np.ndarray)):
        members = ragged_array
    else:
        members = list(ragged_array)

    try:
        return np.array(members)
    except ValueError:
        pass

    try:
        return np.array(members, dtype=object)
    except ValueError:
        # NumPy cannot build an object array when the members share their
        # leading dimension but differ further in (it tries to broadcast them
        # into one block), so fill a 1-D object array element by element.
        result = np.empty(len(members), dtype=object)
        for index, member in enumerate(members):
            result[index] = member
        return result
42
+
43
+ Next to the input numpy array of features, it also received the number of
44
+ dimensions of the features.
45
+ As our features can have 1 to 4 dimensions we might have different number of numpy
46
+ arrays stacked. The number of dimensions helps us to figure out how to handle this
47
+ particular feature array. Also, it is automatically determined whether the feature
48
+ array is sparse or not and the number of units is determined as well.
49
+
50
+ Subclassing np.array: https://numpy.org/doc/stable/user/basics.subclassing.html
51
+ """
52
+
53
+ def __new__(
54
+ cls, input_array: np.ndarray, number_of_dimensions: int
55
+ ) -> "FeatureArray":
56
+ """Create and return a new object. See help(type) for accurate signature."""
57
+ FeatureArray._validate_number_of_dimensions(number_of_dimensions, input_array)
58
+
59
+ feature_array = np.asarray(input_array).view(cls)
60
+
61
+ if number_of_dimensions <= 2:
62
+ feature_array.units = input_array.shape[-1]
63
+ feature_array.is_sparse = isinstance(input_array[0], scipy.sparse.spmatrix)
64
+ elif number_of_dimensions == 3:
65
+ feature_array.units = input_array[0].shape[-1]
66
+ feature_array.is_sparse = isinstance(input_array[0], scipy.sparse.spmatrix)
67
+ elif number_of_dimensions == 4:
68
+ feature_array.units = input_array[0][0].shape[-1]
69
+ feature_array.is_sparse = isinstance(
70
+ input_array[0][0], scipy.sparse.spmatrix
71
+ )
72
+ else:
73
+ raise ValueError(
74
+ f"Number of dimensions '{number_of_dimensions}' currently not "
75
+ f"supported."
76
+ )
77
+
78
+ feature_array.number_of_dimensions = number_of_dimensions
79
+
80
+ return feature_array
81
+
82
+ def __init__(
83
+ self, input_array: Any, number_of_dimensions: int, **kwargs: Any
84
+ ) -> None:
85
+ """Initialize. FeatureArray.
86
+
87
+ Needed in order to avoid 'Invalid keyword argument number_of_dimensions
88
+ to function FeatureArray.__init__ '
89
+ Args:
90
+ input_array: the array that contains features
91
+ number_of_dimensions: number of dimensions in input_array
92
+ """
93
+ super().__init__(**kwargs)
94
+ self.number_of_dimensions = number_of_dimensions
95
+
96
+ def __array_finalize__(self, obj: Optional[np.ndarray]) -> None:
97
+ """This method is called when the system allocates a new array from obj.
98
+
99
+ Args:
100
+ obj: A subclass (subtype) of ndarray.
101
+ """
102
+ if obj is None:
103
+ return
104
+
105
+ self.units = getattr(obj, "units", None)
106
+ self.number_of_dimensions = getattr(obj, "number_of_dimensions", None) # type: ignore[assignment] # noqa:E501
107
+ self.is_sparse = getattr(obj, "is_sparse", None)
108
+
109
+ default_attributes = {
110
+ "units": self.units,
111
+ "number_of_dimensions": self.number_of_dimensions,
112
+ "is_spare": self.is_sparse,
113
+ }
114
+ self.__dict__.update(default_attributes)
115
+
116
+ # pytype: disable=attribute-error
117
+ def __array_ufunc__(
118
+ self, ufunc: Any, method: Text, *inputs: Any, **kwargs: Any
119
+ ) -> Any:
120
+ """Overwrite this method as we are subclassing numpy array.
121
+
122
+ Args:
123
+ ufunc: The ufunc object that was called.
124
+ method: A string indicating which Ufunc method was called
125
+ (one of "__call__", "reduce", "reduceat", "accumulate", "outer",
126
+ "inner").
127
+ *inputs: A tuple of the input arguments to the ufunc.
128
+ **kwargs: Any additional arguments
129
+
130
+ Returns:
131
+ The result of the operation.
132
+ """
133
+ f = {
134
+ "reduce": ufunc.reduce,
135
+ "accumulate": ufunc.accumulate,
136
+ "reduceat": ufunc.reduceat,
137
+ "outer": ufunc.outer,
138
+ "at": ufunc.at,
139
+ "__call__": ufunc,
140
+ }
141
+ # convert the inputs to np.ndarray to prevent recursion, call the function,
142
+ # then cast it back as FeatureArray
143
+ output = FeatureArray(
144
+ f[method](*(i.view(np.ndarray) for i in inputs), **kwargs),
145
+ number_of_dimensions=kwargs["number_of_dimensions"],
146
+ )
147
+ output.__dict__ = self.__dict__ # carry forward attributes
148
+ return output
149
+
150
+ def __reduce__(self) -> Tuple[Any, Any, Any]:
151
+ """Needed in order to pickle this object.
152
+
153
+ Returns:
154
+ A tuple.
155
+ """
156
+ pickled_state = super(FeatureArray, self).__reduce__()
157
+ if isinstance(pickled_state, str):
158
+ raise TypeError("np array __reduce__ returned string instead of tuple.")
159
+ new_state = pickled_state[2] + (
160
+ self.number_of_dimensions,
161
+ self.is_sparse,
162
+ self.units,
163
+ )
164
+ return pickled_state[0], pickled_state[1], new_state
165
+
166
+ def __setstate__(self, state: Any, **kwargs: Any) -> None:
167
+ """Sets the state.
168
+
169
+ Args:
170
+ state: The state argument must be a sequence that contains the following
171
+ elements version, shape, dtype, isFortan, rawdata.
172
+ **kwargs: Any additional parameter
173
+ """
174
+ # Needed in order to load the object
175
+ self.number_of_dimensions = state[-3]
176
+ self.is_sparse = state[-2]
177
+ self.units = state[-1]
178
+ super(FeatureArray, self).__setstate__(state[0:-3], **kwargs)
179
+
180
+ # pytype: enable=attribute-error
181
+
182
+ @staticmethod
183
+ def _validate_number_of_dimensions(
184
+ number_of_dimensions: int, input_array: np.ndarray
185
+ ) -> None:
186
+ """Validates if the the input array has given number of dimensions.
187
+
188
+ Args:
189
+ number_of_dimensions: number of dimensions
190
+ input_array: input array
191
+
192
+ Raises: ValueError in case the dimensions do not match
193
+ """
194
+ _sub_array = input_array
195
+ dim = 0
196
+ # Go number_of_dimensions into the given input_array
197
+ for i in range(1, number_of_dimensions + 1):
198
+ _sub_array = _sub_array[0]
199
+ if isinstance(_sub_array, scipy.sparse.spmatrix):
200
+ dim = i
201
+ break
202
+ if isinstance(_sub_array, np.ndarray) and _sub_array.shape[0] == 0:
203
+ # sequence dimension is 0, we are dealing with "fake" features
204
+ dim = i
205
+ break
206
+
207
+ # If the resulting sub_array is sparse, the remaining number of dimensions
208
+ # should be at least 2
209
+ if isinstance(_sub_array, scipy.sparse.spmatrix):
210
+ if dim > 2:
211
+ raise ValueError(
212
+ f"Given number of dimensions '{number_of_dimensions}' does not "
213
+ f"match dimensions of given input array: {input_array}."
214
+ )
215
+ elif isinstance(_sub_array, np.ndarray) and _sub_array.shape[0] == 0:
216
+ # sequence dimension is 0, we are dealing with "fake" features,
217
+ # but they should be of dim 2
218
+ if dim > 2:
219
+ raise ValueError(
220
+ f"Given number of dimensions '{number_of_dimensions}' does not "
221
+ f"match dimensions of given input array: {input_array}."
222
+ )
223
+ # If the resulting sub_array is dense, the sub_array should be a single number
224
+ elif not np.issubdtype(type(_sub_array), np.integer) and not isinstance(
225
+ _sub_array, (np.float32, np.float64)
226
+ ):
227
+ raise ValueError(
228
+ f"Given number of dimensions '{number_of_dimensions}' does not match "
229
+ f"dimensions of given input array: {input_array}."
230
+ )
231
+
232
+
233
FeatureSignature = NamedTuple(
    "FeatureSignature",
    [
        ("is_sparse", bool),
        ("units", Optional[int]),
        ("number_of_dimensions", int),
    ],
)
FeatureSignature.__doc__ = """Signature of feature arrays.

Records, per feature array, whether it is sparse, its number of units (may be
None), and its number of dimensions.
"""
244
+
245
# Type alias for the features held by RasaModelData:
# attribute name (e.g. "text") -> feature name (e.g. "sentence") -> list of
# FeatureArrays holding the actual features (for example one array of dense and
# one array of sparse features).
# For example:
# "text" -> { "sentence": [
#   "feature array containing dense features for every training example",
#   "feature array containing sparse features for every training example"
# ]}
Data = Dict[Text, Dict[Text, List[FeatureArray]]]
253
+
254
+
255
+ class RasaModelData:
256
+ """Data object used for all RasaModels.
257
+
258
+ It contains all features needed to train the models.
259
+ 'data' is a mapping of attribute name, e.g. TEXT, INTENT, etc., and feature name,
260
+ e.g. SENTENCE, SEQUENCE, etc., to a list of feature arrays representing the actual
261
+ features.
262
+ 'label_key' and 'label_sub_key' point to the labels inside 'data'. For
263
+ example, if your intent labels are stored under INTENT -> IDS, 'label_key' would
264
+ be "INTENT" and 'label_sub_key' would be "IDS".
265
+ """
266
+
267
def __init__(
    self,
    label_key: Optional[Text] = None,
    label_sub_key: Optional[Text] = None,
    data: "Optional[Data]" = None,
) -> None:
    """Sets up the container holding all model features.

    Args:
        label_key: Key of the labels inside ``data`` (used e.g. for balancing).
        label_sub_key: Sub-key of the labels inside ``data``.
        data: Initial features. A falsy value (``None`` or an empty mapping) is
            replaced by an empty nested ``defaultdict``.
    """
    if data:
        self.data = data
    else:
        self.data = defaultdict(lambda: defaultdict(list))
    self.label_key = label_key
    self.label_sub_key = label_sub_key
    # kept up to date by the methods that add features
    self.num_examples = self.number_of_examples()
    self.sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]] = {}
286
+
287
@overload
def get(self, key: Text, sub_key: Text) -> "List[FeatureArray]":
    ...

@overload
def get(self, key: Text, sub_key: None = ...) -> "Dict[Text, List[FeatureArray]]":
    ...

def get(
    self, key: Text, sub_key: Optional[Text] = None
) -> "Union[Dict[Text, List[FeatureArray]], List[FeatureArray]]":
    """Looks up stored features by key and, optionally, sub-key.

    Args:
        key: The attribute key.
        sub_key: The optional feature sub-key.

    Returns:
        The whole attribute mapping when no sub-key is given, the list of
        feature arrays for a present sub-key, and an empty list otherwise
        (missing key, missing sub-key, or an empty-string sub-key).
    """
    if key not in self.data:
        return []

    attribute_data = self.data[key]
    if sub_key is None:
        return attribute_data

    if sub_key and sub_key in attribute_data:
        return attribute_data[sub_key]

    return []
314
+
315
def items(self) -> ItemsView:
    """Gives access to the (attribute name, attribute features) pairs.

    Returns:
        The ``items()`` view of the underlying data mapping.
    """
    mapping = self.data
    return mapping.items()
322
+
323
def values(self) -> Any:
    """Gives access to the per-attribute feature mappings.

    Returns:
        The ``values()`` view of the underlying data mapping.
    """
    mapping = self.data
    return mapping.values()
330
+
331
def keys(self, key: Optional[Text] = None) -> List[Text]:
    """Lists attribute keys, or the sub-keys of one attribute.

    Args:
        key: If given, list the sub-keys stored under this attribute instead.

    Returns:
        The attribute keys when ``key`` is None, the sub-keys of ``key`` when it
        is present, and an empty list otherwise.
    """
    if key is None:
        source = self.data
    elif key in self.data:
        source = self.data[key]
    else:
        return []
    return list(source.keys())
347
+
348
def sort(self) -> None:
    """Orders the data by attribute key and, within each attribute, by sub-key.

    Each attribute mapping is replaced in place by a sorted ``OrderedDict``,
    then the outer mapping itself is replaced by a sorted ``OrderedDict``.
    """

    def _by_name(item):
        return item[0]

    data = self.data
    for key, attribute_data in data.items():
        data[key] = OrderedDict(sorted(attribute_data.items(), key=_by_name))
    self.data = OrderedDict(sorted(data.items(), key=_by_name))
353
+
354
def first_data_example(self) -> "Data":
    """Builds a copy of the data that keeps only the first example.

    Every feature array is sliced to its first entry (``feature[:1]``), keeping
    the key/sub-key structure.

    Returns:
        A new plain dict with the sliced feature arrays.
    """
    return {
        key: {
            sub_key: cast("List[FeatureArray]", [feature[:1] for feature in features])
            for sub_key, features in attribute_data.items()
        }
        for key, attribute_data in self.data.items()
    }
367
+
368
def does_feature_exist(self, key: Text, sub_key: Optional[Text] = None) -> bool:
    """Tells whether non-empty features are stored under key (and sub-key).

    This is the negation of ``does_feature_not_exist``.

    Args:
        key: The attribute key.
        sub_key: The optional feature sub-key.

    Returns:
        True if features exist for the given keys, False otherwise.
    """
    missing = self.does_feature_not_exist(key, sub_key)
    return not missing
379
+
380
def does_feature_not_exist(self, key: Text, sub_key: Optional[Text] = None) -> bool:
    """Tells whether features are missing (or empty) under key (and sub-key).

    Args:
        key: The attribute key.
        sub_key: The optional feature sub-key; a falsy value checks only the key.

    Returns:
        True if no features exist for the given keys, False otherwise.
    """
    if key not in self.data or not self.data[key]:
        return True
    if not sub_key:
        return False
    attribute_data = self.data[key]
    return sub_key not in attribute_data or not attribute_data[sub_key]
399
+
400
def is_empty(self) -> bool:
    """Returns True when no attribute is stored in the data mapping."""
    return len(self.data) == 0
403
+
404
def number_of_examples(self, data: "Optional[Data]" = None) -> int:
    """Counts the examples, i.e. the length shared by all feature arrays.

    Args:
        data: The data to inspect; a falsy value falls back to ``self.data``.

    Raises: A ValueError if number of examples differ for different features.

    Returns:
        The common length of all feature arrays, or 0 if there are none.
    """
    source = data if data else self.data
    if not source:
        return 0

    lengths = []
    for attribute_data in source.values():
        for features in attribute_data.values():
            lengths.extend(len(feature_array) for feature_array in features)

    if not lengths:
        return 0

    # every feature array must describe the same number of examples
    if len(set(lengths)) > 1:
        raise ValueError(
            f"Number of examples differs for keys '{source.keys()}'. Number of "
            f"examples should be the same for all data."
        )

    return lengths[0]
439
+
440
def number_of_units(self, key: Text, sub_key: Text) -> int:
    """Sums the ``units`` of the non-empty feature arrays under key/sub-key.

    Args:
        key: The attribute key.
        sub_key: The feature sub-key (required).

    Returns:
        The total number of units, or 0 if key or sub-key is not present.
    """
    attribute_data = self.data[key] if key in self.data else {}
    if sub_key not in attribute_data:
        return 0

    total = 0
    for feature_array in attribute_data[sub_key]:
        if len(feature_array) == 0:
            continue
        total += feature_array.units  # type: ignore[operator]
    return total
459
+
460
def add_data(self, data: "Data", key_prefix: Optional[Text] = None) -> None:
    """Adds every feature list of ``data`` via ``add_features``.

    Args:
        data: The data to add.
        key_prefix: Optional string prepended to every attribute key.
    """
    for key, attribute_data in data.items():
        target_key = f"{key_prefix}{key}" if key_prefix else key
        for sub_key, features in attribute_data.items():
            self.add_features(target_key, sub_key, features)
473
+
474
def update_key(
    self, from_key: Text, from_sub_key: Text, to_key: Text, to_sub_key: Text
) -> None:
    """Moves the features under one key/sub-key to another key/sub-key.

    The features are removed from their old location, and the old attribute
    entry is dropped once it has no sub-keys left. Nothing happens if the
    source does not exist or if source and destination are identical.

    Args:
        from_key: current feature key
        from_sub_key: current feature sub-key
        to_key: new key for feature
        to_sub_key: new sub-key for feature
    """
    if from_key not in self.data or from_sub_key not in self.data[from_key]:
        return

    if from_key == to_key and from_sub_key == to_sub_key:
        # Moving onto itself would otherwise delete the features just written.
        return

    features = self.data[from_key].pop(from_sub_key)

    if to_key not in self.data:
        # Same nested structure the default data container uses, so later
        # add_features calls can still create new sub-keys under to_key.
        self.data[to_key] = defaultdict(list)
    self.data[to_key][to_sub_key] = features

    if not self.data[from_key]:
        del self.data[from_key]
495
+
496
def add_features(
    self, key: Text, sub_key: Text, features: "Optional[List[FeatureArray]]"
) -> None:
    """Appends the non-empty feature arrays to the data under key/sub-key.

    Empty feature arrays are skipped. If nothing is left to add, no new
    key/sub-key entry is created. Afterwards the number of examples is
    recomputed.

    Args:
        key: The key
        sub_key: The sub-key
        features: The features to add; ``None`` is a no-op.

    Raises:
        ValueError: propagated from ``number_of_examples`` if the stored
            feature arrays no longer agree on the number of examples.
    """
    if features is None:
        return

    non_empty = [feature_array for feature_array in features if len(feature_array) > 0]

    if non_empty:
        # Create missing levels explicitly instead of relying on defaultdict
        # auto-vivification: self.data may be a plain dict / OrderedDict
        # (e.g. after sort() or when passed to the constructor).
        if key not in self.data:
            self.data[key] = defaultdict(list)
        if sub_key not in self.data[key]:
            self.data[key][sub_key] = []
        self.data[key][sub_key].extend(non_empty)
    elif (
        key in self.data
        and sub_key in self.data[key]
        and not self.data[key][sub_key]
    ):
        # keep dropping an already existing but empty sub-key
        del self.data[key][sub_key]

    # update number of examples
    self.num_examples = self.number_of_examples()
520
+
521
def add_lengths(
    self, key: Text, sub_key: Text, from_key: Text, from_sub_key: Text
) -> None:
    """Adds a feature array of sequence lengths to data under the given key.

    Only the first non-empty feature array stored under
    ``from_key``/``from_sub_key`` is used. For 4-dimensional features, every
    example gets one ``[[length]]`` entry per inner element (``shape[0]`` of
    that element). Otherwise the result holds one length (``shape[0]``) per
    example. Nothing changes if no features exist under
    ``from_key``/``from_sub_key``.

    Args:
        key: The key to add the lengths to
        sub_key: The sub-key to add the lengths to
        from_key: The key to take the lengths from
        from_sub_key: The sub-key to take the lengths from
    """
    source = self.data.get(from_key)
    if not source or not source.get(from_sub_key):
        return

    # Compute the lengths before touching data[key][sub_key], so that
    # key/sub_key equal to from_key/from_sub_key cannot wipe the source first.
    lengths_list: List[FeatureArray] = []
    for features in source[from_sub_key]:
        if len(features) == 0:
            continue

        if features.number_of_dimensions == 4:
            lengths = FeatureArray(
                ragged_array_to_ndarray(
                    [
                        # add one more dim so that dialogue dim
                        # would be a sequence
                        np.array([[[x.shape[0]]] for x in _features])
                        for _features in features
                    ]
                ),
                number_of_dimensions=4,
            )
        else:
            lengths = FeatureArray(
                np.array([x.shape[0] for x in features]), number_of_dimensions=1
            )
        lengths_list.append(lengths)
        # the first non-empty feature array is sufficient
        break

    if key not in self.data:
        # self.data may be a plain dict (e.g. after sort()), so do not rely on
        # defaultdict auto-vivification.
        self.data[key] = defaultdict(list)
    self.data[key][sub_key] = lengths_list
561
+
562
def add_sparse_feature_sizes(
    self, sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]]
) -> None:
    """Stores the sparse feature sizes, replacing any previous value.

    Args:
        sparse_feature_sizes: Mapping of attribute name -> feature type ->
            list of sparse feature sizes.
    """
    self.sparse_feature_sizes = sparse_feature_sizes
573
+
574
def get_sparse_feature_sizes(self) -> Dict[Text, Dict[Text, List[int]]]:
    """Returns the stored sparse feature sizes.

    Returns:
        Mapping of attribute name -> feature type -> list of sparse feature
        sizes, as last set by ``add_sparse_feature_sizes`` (or the initial
        empty dict).
    """
    sizes = self.sparse_feature_sizes
    return sizes
585
+
586
def split(
    self, number_of_test_examples: int, random_seed: int
) -> Tuple["RasaModelData", "RasaModelData"]:
    """Create random hold out test set using stratified split.

    Without a label key/sub-key, all feature arrays are split at random.
    Otherwise the split is stratified on the label ids: examples whose label
    occurs more than once go through ``train_test_split`` with ``stratify``,
    while examples whose label occurs exactly once are collected separately
    and passed, together with the split output, to
    ``_convert_train_test_split`` (NOTE(review): presumably these end up in the
    train set; confirm there).

    Args:
        number_of_test_examples: Number of test examples.
        random_seed: Random seed.

    Returns:
        A tuple of train and test RasaModelData.

    Raises:
        ValueError: from ``_check_label_key`` if the label key is invalid, or
            from ``_check_train_test_sizes`` if the test size does not fit.
    """
    self._check_label_key()

    if self.label_key is None or self.label_sub_key is None:
        # randomly split data as no label key is set
        multi_values = [
            v
            for attribute_data in self.data.values()
            for data in attribute_data.values()
            for v in data
        ]
        # one empty placeholder per feature array, in the same order
        solo_values: List[Any] = [
            []
            for attribute_data in self.data.values()
            for data in attribute_data.values()
            for _ in data
        ]
        stratify = None
    else:
        # make sure that examples for each label value are in both split sets
        label_ids = self._create_label_ids(
            self.data[self.label_key][self.label_sub_key][0]
        )
        # label id -> number of examples with that label
        label_counts: Dict[int, int] = dict(
            zip(
                *np.unique(
                    label_ids,
                    return_counts=True,
                    axis=0,
                )
            )
        )

        self._check_train_test_sizes(number_of_test_examples, label_counts)

        # per example: how often its label occurs
        counts = np.array([label_counts[label] for label in label_ids])
        # we perform stratified train test split,
        # which ensures every label is present in the train and test data
        # this operation can be performed only for labels
        # that contain several data points
        multi_values = [
            f[counts > 1].view(FeatureArray)
            for attribute_data in self.data.values()
            for features in attribute_data.values()
            for f in features
        ]
        # collect data points that are unique for their label
        solo_values = [
            f[counts == 1]
            for attribute_data in self.data.values()
            for features in attribute_data.values()
            for f in features
        ]

        stratify = label_ids[counts > 1]

    # the iteration order of multi_values / solo_values must stay identical,
    # since the split output is matched to solo_values positionally later on
    output_values = train_test_split(
        *multi_values,
        test_size=number_of_test_examples,
        random_state=random_seed,
        stratify=stratify,
    )

    return self._convert_train_test_split(output_values, solo_values)
661
+
662
def get_signature(
    self, data: "Optional[Data]" = None
) -> "Dict[Text, Dict[Text, List[FeatureSignature]]]":
    """Describes every feature array by its sparsity, units and dimensions.

    Args:
        data: The data to describe; a falsy value falls back to ``self.data``.

    Returns:
        A dictionary of key and sub-key to a list of feature signatures
        (same structure as the data attribute).
    """
    source = data if data else self.data

    signature = {}
    for key, attribute_data in source.items():
        signature[key] = {}
        for sub_key, features in attribute_data.items():
            signature[key][sub_key] = [
                FeatureSignature(
                    feature_array.is_sparse,
                    feature_array.units,
                    feature_array.number_of_dimensions,
                )
                for feature_array in features
            ]
    return signature
686
+
687
def shuffled_data(self, data: "Data") -> "Data":
    """Returns the examples of ``data`` in a random order.

    A random permutation of ``range(self.num_examples)`` is drawn from numpy's
    global random state and applied via ``_data_for_ids``.

    Args:
        data: The data to shuffle

    Returns:
        The shuffled data.
    """
    permutation = np.random.permutation(self.num_examples)
    return self._data_for_ids(data, permutation)
698
+
699
def balanced_data(self, data: Data, batch_size: int, shuffle: bool) -> Data:
    """Mix model data to account for class imbalance.

    This batching strategy puts rare classes approximately in every other batch,
    by repeating them. Mimics stratified batching, but also takes into account
    that more populated classes should appear more often.

    The data is returned unchanged if no label key/sub-key is set or if the
    first label entry has more than one element (token-based labels).

    Args:
        data: The data.
        batch_size: The batch size.
        shuffle: Boolean indicating whether to shuffle the data or not.

    Returns:
        The balanced data.

    Raises:
        ValueError: from ``_check_label_key`` if the label key is invalid.
    """
    self._check_label_key()

    # skip balancing if labels are token based
    if (
        self.label_key is None
        or self.label_sub_key is None
        or data[self.label_key][self.label_sub_key][0][0].size > 1
    ):
        return data

    label_ids = self._create_label_ids(data[self.label_key][self.label_sub_key][0])

    unique_label_ids, counts_label_ids = np.unique(
        label_ids, return_counts=True, axis=0
    )
    num_label_ids = len(unique_label_ids)

    # group data points by their label
    # need to call every time, so that the data is shuffled inside each class
    data_by_label = self._split_by_label_ids(data, label_ids, unique_label_ids)

    # running index inside each data grouped by labels
    data_idx = [0] * num_label_ids
    # number of cycles each label was passed
    num_data_cycles = [0] * num_label_ids
    # if a label was skipped in current batch
    skipped = [False] * num_label_ids

    # key -> sub-key -> per feature array position, the list of slices
    # collected so far (concatenated into FeatureArrays at the end)
    new_data: DefaultDict[
        Text, DefaultDict[Text, List[List[FeatureArray]]]
    ] = defaultdict(lambda: defaultdict(list))

    # keep going until every label has been passed through at least once
    while min(num_data_cycles) == 0:
        if shuffle:
            indices_of_labels = np.random.permutation(num_label_ids)
        else:
            indices_of_labels = np.asarray(range(num_label_ids))

        for index in indices_of_labels:
            # labels that were already fully passed are only taken every
            # other time
            if num_data_cycles[index] > 0 and not skipped[index]:
                skipped[index] = True
                continue

            skipped[index] = False

            # slice size proportional to the label's share of the examples,
            # at least 1
            index_batch_size = (
                int(counts_label_ids[index] / self.num_examples * batch_size) + 1
            )

            for key, attribute_data in data_by_label[index].items():
                for sub_key, features in attribute_data.items():
                    for i, f in enumerate(features):
                        if len(new_data[key][sub_key]) < i + 1:
                            new_data[key][sub_key].append([])
                        new_data[key][sub_key][i].append(
                            f[data_idx[index] : data_idx[index] + index_batch_size]
                        )

            data_idx[index] += index_batch_size
            # wrap around once all examples of this label were used
            if data_idx[index] >= counts_label_ids[index]:
                num_data_cycles[index] += 1
                data_idx[index] = 0

            if min(num_data_cycles) > 0:
                break

    final_data: Data = defaultdict(lambda: defaultdict(list))
    for key, attribute_data in new_data.items():
        for sub_key, features in attribute_data.items():
            for f in features:
                final_data[key][sub_key].append(
                    FeatureArray(
                        np.concatenate(f),
                        number_of_dimensions=f[0].number_of_dimensions,
                    )
                )

    return final_data
792
+
793
def _check_train_test_sizes(
    self, number_of_test_examples: int, label_counts: Dict[Any, int]
) -> None:
    """Check whether the test data set is too large or too small.

    Args:
        number_of_test_examples: Requested number of test examples.
        label_counts: Mapping of label to its number of examples; its length
            is the number of classes.

    Raises:
        ValueError: If the remaining train set (``num_examples`` minus the
            test size) would not exceed the number of classes, or if the test
            set is smaller than the number of classes.
    """
    number_of_classes = len(label_counts)

    # NOTE: the comparison is >=, so the remaining train set must be strictly
    # larger than the number of classes.
    if number_of_test_examples >= self.num_examples - number_of_classes:
        raise ValueError(
            f"Test set of {number_of_test_examples} is too large. Remaining "
            f"train set should be at least equal to number of classes "
            f"{number_of_classes}."
        )
    if number_of_test_examples < number_of_classes:
        raise ValueError(
            f"Test set of {number_of_test_examples} is too small. It should "
            f"be at least equal to number of classes {number_of_classes}."
        )
816
+
817
@staticmethod
def _data_for_ids(data: Optional[Data], ids: np.ndarray) -> Data:
    """Select a subset of rows from every feature array in ``data``.

    Args:
        data: The model data to filter. ``None`` produces an empty result.
        ids: Index expression applied to each feature array, e.g. the boolean
            mask built by ``_split_by_label_ids``.

    Returns:
        A new nested ``defaultdict`` keyed like ``data``, holding the filtered
        feature arrays. Sub-keys without feature arrays are not created.
    """
    filtered: Data = defaultdict(lambda: defaultdict(list))
    if data is None:
        return filtered

    for key, attribute_data in data.items():
        for sub_key, features in attribute_data.items():
            selected = [feature[ids] for feature in features]
            # Only touch the defaultdict when there is something to store, so
            # empty feature lists do not create new keys.
            if selected:
                filtered[key][sub_key].extend(selected)
    return filtered
838
+
839
def _split_by_label_ids(
    self, data: Optional[Data], label_ids: np.ndarray, unique_label_ids: np.ndarray
) -> List["RasaModelData"]:
    """Group model data into one ``RasaModelData`` per distinct label.

    Args:
        data: The model data to split.
        label_ids: Label id of every example, aligned with the rows of ``data``.
        unique_label_ids: The distinct label ids; the output follows this order.

    Returns:
        One ``RasaModelData`` per entry of ``unique_label_ids``, each holding
        only the examples whose label id equals that entry.
    """
    all_label_ids = np.array(label_ids)
    return [
        RasaModelData(
            self.label_key,
            self.label_sub_key,
            self._data_for_ids(data, all_label_ids == current_label_id),
        )
        for current_label_id in unique_label_ids
    ]
863
+
864
+ def _check_label_key(self) -> None:
865
+ """Check if the label key exists.
866
+
867
+ Raises:
868
+ ValueError if the label key and sub-key is not in data.
869
+ """
870
+ if (
871
+ self.label_key is not None
872
+ and self.label_sub_key is not None
873
+ and (
874
+ self.label_key not in self.data
875
+ or self.label_sub_key not in self.data[self.label_key]
876
+ or len(self.data[self.label_key][self.label_sub_key]) > 1
877
+ )
878
+ ):
879
+ raise ValueError(
880
+ f"Key '{self.label_key}.{self.label_sub_key}' not in RasaModelData."
881
+ )
882
+
883
def _convert_train_test_split(
    self, output_values: List[Any], solo_values: List[Any]
) -> Tuple["RasaModelData", "RasaModelData"]:
    """Turn the flat output of sklearn's ``train_test_split`` into model data.

    Args:
        output_values: Flat list ``[x_train, x_val, y_train, y_val, ...]`` with
            one train/val pair per feature array, ordered like the feature
            arrays of ``self.data``.
        solo_values: One entry per feature array; each entry is concatenated
            onto the matching train part via ``_combine_features``.

    Returns:
        A ``(train, validation)`` tuple of ``RasaModelData``.
    """
    data_train: DefaultDict[
        Text, DefaultDict[Text, List[FeatureArray]]
    ] = defaultdict(lambda: defaultdict(list))
    data_val: DefaultDict[Text, DefaultDict[Text, List[Any]]] = defaultdict(
        lambda: defaultdict(list)
    )

    # The i-th feature array (in iteration order of self.data) owns
    # output_values[2 * i] (train part) and output_values[2 * i + 1] (val part).
    feature_position = 0
    for key, attribute_data in self.data.items():
        for sub_key, features in attribute_data.items():
            for feature in features:
                train_slot = 2 * feature_position
                data_train[key][sub_key].append(
                    self._combine_features(
                        output_values[train_slot],
                        solo_values[feature_position],
                        feature.number_of_dimensions,
                    )
                )
                data_val[key][sub_key].append(output_values[train_slot + 1])
                feature_position += 1

    train_data = RasaModelData(self.label_key, self.label_sub_key, data_train)
    val_data = RasaModelData(self.label_key, self.label_sub_key, data_val)
    return train_data, val_data
931
+
932
@staticmethod
def _combine_features(
    feature_1: Union[np.ndarray, scipy.sparse.spmatrix],
    feature_2: Union[np.ndarray, scipy.sparse.spmatrix],
    number_of_dimensions: Optional[int] = 1,
) -> FeatureArray:
    """Stack two feature blocks along the first axis.

    When both inputs are scipy sparse matrices they are stacked with
    ``scipy.sparse.vstack``, and an input with zero rows is simply dropped.
    Otherwise the inputs are joined with ``np.concatenate``.

    Args:
        feature_1: First block of features (placed on top).
        feature_2: Second block of features (placed below).
        number_of_dimensions: Passed through to the resulting ``FeatureArray``.

    Returns:
        The combined features wrapped in a ``FeatureArray``.
    """
    both_sparse = isinstance(feature_1, scipy.sparse.spmatrix) and isinstance(
        feature_2, scipy.sparse.spmatrix
    )

    if not both_sparse:
        combined = np.concatenate([feature_1, feature_2])
    elif feature_2.shape[0] == 0:
        combined = feature_1
    elif feature_1.shape[0] == 0:
        combined = feature_2
    else:
        combined = scipy.sparse.vstack([feature_1, feature_2])

    return FeatureArray(combined, number_of_dimensions)
961
+
962
@staticmethod
def _create_label_ids(label_ids: FeatureArray) -> np.ndarray:
    """Flatten label ids of various shapes into a one-dimensional array.

    1-D input is returned unchanged. A trailing singleton axis on 2-D input is
    dropped. Multi-column rows (2-D, or 3-D with a trailing singleton axis) are
    turned into one string per row by joining the values with spaces. The join
    is used instead of ``str(row)``, which abbreviates rows longer than 1000
    entries with an ellipsis. The idea comes from sklearn's stratified split.

    Args:
        label_ids: The label ids.

    Raises:
        ValueError: If the dimensionality of ``label_ids`` is not supported.

    Returns:
        The one-dimensional label array.
    """
    if label_ids.ndim == 1:
        return label_ids

    if label_ids.ndim == 2:
        if label_ids.shape[-1] == 1:
            return label_ids[:, 0]
        rows = label_ids
    elif label_ids.ndim == 3 and label_ids.shape[-1] == 1:
        rows = label_ids[:, :, 0]
    else:
        raise ValueError("Unsupported label_ids dimensions")

    return np.array([" ".join(row.astype("str")) for row in rows])