rasa-pro 3.8.18__py3-none-any.whl → 3.9.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (278) hide show
  1. README.md +6 -42
  2. rasa/__main__.py +14 -9
  3. rasa/anonymization/anonymization_pipeline.py +0 -1
  4. rasa/anonymization/anonymization_rule_executor.py +3 -3
  5. rasa/anonymization/utils.py +4 -3
  6. rasa/api.py +2 -2
  7. rasa/cli/arguments/default_arguments.py +1 -1
  8. rasa/cli/arguments/run.py +2 -2
  9. rasa/cli/arguments/test.py +1 -1
  10. rasa/cli/arguments/train.py +10 -10
  11. rasa/cli/e2e_test.py +27 -7
  12. rasa/cli/export.py +0 -1
  13. rasa/cli/license.py +3 -3
  14. rasa/cli/project_templates/calm/actions/action_template.py +1 -1
  15. rasa/cli/project_templates/calm/config.yml +1 -1
  16. rasa/cli/project_templates/calm/credentials.yml +1 -1
  17. rasa/cli/project_templates/calm/data/flows/add_contact.yml +1 -1
  18. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +1 -1
  19. rasa/cli/project_templates/calm/domain/add_contact.yml +8 -2
  20. rasa/cli/project_templates/calm/domain/list_contacts.yml +3 -0
  21. rasa/cli/project_templates/calm/domain/remove_contact.yml +9 -2
  22. rasa/cli/project_templates/calm/domain/shared.yml +5 -0
  23. rasa/cli/project_templates/calm/endpoints.yml +4 -4
  24. rasa/cli/project_templates/default/actions/actions.py +1 -1
  25. rasa/cli/project_templates/default/config.yml +5 -5
  26. rasa/cli/project_templates/default/credentials.yml +1 -1
  27. rasa/cli/project_templates/default/endpoints.yml +4 -4
  28. rasa/cli/project_templates/default/tests/test_stories.yml +1 -1
  29. rasa/cli/project_templates/tutorial/config.yml +1 -1
  30. rasa/cli/project_templates/tutorial/credentials.yml +1 -1
  31. rasa/cli/project_templates/tutorial/data/patterns.yml +6 -0
  32. rasa/cli/project_templates/tutorial/domain.yml +4 -0
  33. rasa/cli/project_templates/tutorial/endpoints.yml +6 -6
  34. rasa/cli/run.py +0 -1
  35. rasa/cli/scaffold.py +3 -2
  36. rasa/cli/studio/download.py +11 -0
  37. rasa/cli/studio/studio.py +180 -24
  38. rasa/cli/studio/upload.py +0 -8
  39. rasa/cli/telemetry.py +18 -6
  40. rasa/cli/utils.py +21 -10
  41. rasa/cli/x.py +3 -2
  42. rasa/constants.py +1 -1
  43. rasa/core/actions/action.py +90 -315
  44. rasa/core/actions/action_exceptions.py +24 -0
  45. rasa/core/actions/constants.py +3 -0
  46. rasa/core/actions/custom_action_executor.py +188 -0
  47. rasa/core/actions/forms.py +11 -7
  48. rasa/core/actions/grpc_custom_action_executor.py +251 -0
  49. rasa/core/actions/http_custom_action_executor.py +140 -0
  50. rasa/core/actions/loops.py +3 -0
  51. rasa/core/actions/two_stage_fallback.py +1 -1
  52. rasa/core/agent.py +2 -4
  53. rasa/core/brokers/pika.py +1 -2
  54. rasa/core/channels/audiocodes.py +1 -1
  55. rasa/core/channels/botframework.py +0 -1
  56. rasa/core/channels/callback.py +0 -1
  57. rasa/core/channels/console.py +6 -8
  58. rasa/core/channels/development_inspector.py +1 -1
  59. rasa/core/channels/facebook.py +0 -3
  60. rasa/core/channels/hangouts.py +0 -6
  61. rasa/core/channels/inspector/dist/assets/{arc-5623b6dc.js → arc-b6e548fe.js} +1 -1
  62. rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-685c106a.js → c4Diagram-d0fbc5ce-fa03ac9e.js} +1 -1
  63. rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-8cbed007.js → classDiagram-936ed81e-ee67392a.js} +1 -1
  64. rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-5889cf12.js → classDiagram-v2-c3cb15f1-9b283fae.js} +1 -1
  65. rasa/core/channels/inspector/dist/assets/{createText-62fc7601-24c249d7.js → createText-62fc7601-8b6fcc2a.js} +1 -1
  66. rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-7dd06a75.js → edges-f2ad444c-22e77f4f.js} +1 -1
  67. rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-62c1e54c.js → erDiagram-9d236eb7-60ffc87f.js} +1 -1
  68. rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-ce49b86f.js → flowDb-1972c806-9dd802e4.js} +1 -1
  69. rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-4067e48f.js → flowDiagram-7ea5b25a-5fa1912f.js} +1 -1
  70. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +1 -0
  71. rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-59fe4051.js → flowchart-elk-definition-abe16c3d-622a1fd2.js} +1 -1
  72. rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-47e3a43b.js → ganttDiagram-9b5ea136-e285a63a.js} +1 -1
  73. rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-5a2ac0d9.js → gitGraphDiagram-99d0ae7c-f237bdca.js} +1 -1
  74. rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-dfb8efc4.js → index-2c4b9a3b-4b03d70e.js} +1 -1
  75. rasa/core/channels/inspector/dist/assets/{index-268a75c0.js → index-a5d3e69d.js} +4 -4
  76. rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-b0c470f2.js → infoDiagram-736b4530-72a0fa5f.js} +1 -1
  77. rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-2edb829a.js → journeyDiagram-df861f2b-82218c41.js} +1 -1
  78. rasa/core/channels/inspector/dist/assets/{layout-b6873d69.js → layout-78cff630.js} +1 -1
  79. rasa/core/channels/inspector/dist/assets/{line-1efc5781.js → line-5038b469.js} +1 -1
  80. rasa/core/channels/inspector/dist/assets/{linear-661e9b94.js → linear-c4fc4098.js} +1 -1
  81. rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-2d2e727f.js → mindmap-definition-beec6740-c33c8ea6.js} +1 -1
  82. rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-9d3ea93d.js → pieDiagram-dbbf0591-a8d03059.js} +1 -1
  83. rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-06a178a2.js → quadrantDiagram-4d7f4fd6-6a0e56b2.js} +1 -1
  84. rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-0bfedffc.js → requirementDiagram-6fc4c22a-2dc7c7bd.js} +1 -1
  85. rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-d76d0a04.js → sankeyDiagram-8f13d901-2360fe39.js} +1 -1
  86. rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-37bb4341.js → sequenceDiagram-b655622a-41b9f9ad.js} +1 -1
  87. rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-f52f7f57.js → stateDiagram-59f0c015-0aad326f.js} +1 -1
  88. rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-4a986a20.js → stateDiagram-v2-2b26beab-9847d984.js} +1 -1
  89. rasa/core/channels/inspector/dist/assets/{styles-080da4f6-7dd9ae12.js → styles-080da4f6-564d890e.js} +1 -1
  90. rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-46e1ca14.js → styles-3dcbcfbf-38957613.js} +1 -1
  91. rasa/core/channels/inspector/dist/assets/{styles-9c745c82-4a97439a.js → styles-9c745c82-f0fc6921.js} +1 -1
  92. rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-823917a3.js → svgDrawCommon-4835440b-ef3c5a77.js} +1 -1
  93. rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-9ea72896.js → timeline-definition-5b62e21b-bf3e91c1.js} +1 -1
  94. rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-b631a8b6.js → xychartDiagram-2b33534f-4d4026c0.js} +1 -1
  95. rasa/core/channels/inspector/dist/index.html +1 -1
  96. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +10 -0
  97. rasa/core/channels/inspector/src/helpers/formatters.test.ts +4 -7
  98. rasa/core/channels/inspector/src/helpers/formatters.ts +3 -2
  99. rasa/core/channels/rest.py +36 -21
  100. rasa/core/channels/rocketchat.py +0 -1
  101. rasa/core/channels/socketio.py +1 -1
  102. rasa/core/channels/telegram.py +3 -3
  103. rasa/core/channels/webexteams.py +0 -1
  104. rasa/core/concurrent_lock_store.py +1 -1
  105. rasa/core/evaluation/marker_base.py +1 -3
  106. rasa/core/evaluation/marker_stats.py +1 -2
  107. rasa/core/featurizers/single_state_featurizer.py +3 -26
  108. rasa/core/featurizers/tracker_featurizers.py +18 -122
  109. rasa/core/information_retrieval/__init__.py +7 -0
  110. rasa/core/information_retrieval/faiss.py +9 -4
  111. rasa/core/information_retrieval/information_retrieval.py +64 -7
  112. rasa/core/information_retrieval/milvus.py +7 -14
  113. rasa/core/information_retrieval/qdrant.py +8 -15
  114. rasa/core/lock_store.py +0 -1
  115. rasa/core/migrate.py +1 -2
  116. rasa/core/nlg/callback.py +3 -4
  117. rasa/core/policies/enterprise_search_policy.py +86 -22
  118. rasa/core/policies/enterprise_search_prompt_template.jinja2 +4 -41
  119. rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
  120. rasa/core/policies/flows/flow_executor.py +104 -2
  121. rasa/core/policies/intentless_policy.py +7 -9
  122. rasa/core/policies/memoization.py +3 -3
  123. rasa/core/policies/policy.py +18 -9
  124. rasa/core/policies/rule_policy.py +8 -11
  125. rasa/core/policies/ted_policy.py +61 -88
  126. rasa/core/policies/unexpected_intent_policy.py +8 -17
  127. rasa/core/processor.py +136 -47
  128. rasa/core/run.py +41 -25
  129. rasa/core/secrets_manager/endpoints.py +2 -2
  130. rasa/core/secrets_manager/vault.py +6 -8
  131. rasa/core/test.py +3 -5
  132. rasa/core/tracker_store.py +49 -14
  133. rasa/core/train.py +1 -3
  134. rasa/core/training/interactive.py +9 -6
  135. rasa/core/utils.py +5 -10
  136. rasa/dialogue_understanding/coexistence/intent_based_router.py +11 -4
  137. rasa/dialogue_understanding/coexistence/llm_based_router.py +2 -3
  138. rasa/dialogue_understanding/commands/__init__.py +4 -0
  139. rasa/dialogue_understanding/commands/can_not_handle_command.py +9 -0
  140. rasa/dialogue_understanding/commands/cancel_flow_command.py +9 -0
  141. rasa/dialogue_understanding/commands/change_flow_command.py +38 -0
  142. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +9 -0
  143. rasa/dialogue_understanding/commands/clarify_command.py +9 -0
  144. rasa/dialogue_understanding/commands/correct_slots_command.py +9 -0
  145. rasa/dialogue_understanding/commands/error_command.py +12 -0
  146. rasa/dialogue_understanding/commands/handle_code_change_command.py +9 -0
  147. rasa/dialogue_understanding/commands/human_handoff_command.py +9 -0
  148. rasa/dialogue_understanding/commands/knowledge_answer_command.py +9 -0
  149. rasa/dialogue_understanding/commands/noop_command.py +9 -0
  150. rasa/dialogue_understanding/commands/set_slot_command.py +38 -3
  151. rasa/dialogue_understanding/commands/skip_question_command.py +9 -0
  152. rasa/dialogue_understanding/commands/start_flow_command.py +9 -0
  153. rasa/dialogue_understanding/generator/__init__.py +16 -1
  154. rasa/dialogue_understanding/generator/command_generator.py +92 -6
  155. rasa/dialogue_understanding/generator/constants.py +18 -0
  156. rasa/dialogue_understanding/generator/flow_retrieval.py +7 -5
  157. rasa/dialogue_understanding/generator/llm_based_command_generator.py +467 -0
  158. rasa/dialogue_understanding/generator/llm_command_generator.py +39 -609
  159. rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
  160. rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
  161. rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
  162. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +827 -0
  163. rasa/dialogue_understanding/generator/nlu_command_adapter.py +69 -8
  164. rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
  165. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +345 -0
  166. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +36 -31
  167. rasa/dialogue_understanding/processor/command_processor.py +112 -3
  168. rasa/e2e_test/constants.py +1 -0
  169. rasa/e2e_test/e2e_test_case.py +44 -0
  170. rasa/e2e_test/e2e_test_runner.py +114 -11
  171. rasa/e2e_test/e2e_test_schema.yml +18 -0
  172. rasa/engine/caching.py +0 -1
  173. rasa/engine/graph.py +18 -6
  174. rasa/engine/recipes/config_files/default_config.yml +3 -3
  175. rasa/engine/recipes/default_components.py +1 -1
  176. rasa/engine/recipes/default_recipe.py +4 -5
  177. rasa/engine/recipes/recipe.py +1 -1
  178. rasa/engine/runner/dask.py +3 -9
  179. rasa/engine/storage/local_model_storage.py +0 -2
  180. rasa/engine/validation.py +179 -145
  181. rasa/exceptions.py +2 -2
  182. rasa/graph_components/validators/default_recipe_validator.py +3 -5
  183. rasa/hooks.py +0 -1
  184. rasa/model.py +1 -1
  185. rasa/model_training.py +1 -0
  186. rasa/nlu/classifiers/diet_classifier.py +33 -52
  187. rasa/nlu/classifiers/logistic_regression_classifier.py +9 -22
  188. rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
  189. rasa/nlu/extractors/crf_entity_extractor.py +54 -97
  190. rasa/nlu/extractors/duckling_entity_extractor.py +1 -1
  191. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +1 -5
  192. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +0 -4
  193. rasa/nlu/featurizers/featurizer.py +1 -1
  194. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +18 -49
  195. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +26 -64
  196. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
  197. rasa/nlu/persistor.py +68 -26
  198. rasa/nlu/selectors/response_selector.py +7 -10
  199. rasa/nlu/test.py +0 -3
  200. rasa/nlu/utils/hugging_face/registry.py +1 -1
  201. rasa/nlu/utils/spacy_utils.py +1 -3
  202. rasa/server.py +22 -7
  203. rasa/shared/constants.py +12 -1
  204. rasa/shared/core/command_payload_reader.py +109 -0
  205. rasa/shared/core/constants.py +4 -5
  206. rasa/shared/core/domain.py +57 -56
  207. rasa/shared/core/events.py +4 -7
  208. rasa/shared/core/flows/flow.py +9 -0
  209. rasa/shared/core/flows/flows_list.py +12 -0
  210. rasa/shared/core/flows/steps/action.py +7 -2
  211. rasa/shared/core/generator.py +12 -11
  212. rasa/shared/core/slot_mappings.py +315 -24
  213. rasa/shared/core/slots.py +4 -2
  214. rasa/shared/core/trackers.py +32 -14
  215. rasa/shared/core/training_data/loading.py +0 -1
  216. rasa/shared/core/training_data/story_reader/story_reader.py +3 -3
  217. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +11 -11
  218. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +5 -3
  219. rasa/shared/core/training_data/structures.py +1 -1
  220. rasa/shared/core/training_data/visualization.py +1 -1
  221. rasa/shared/data.py +58 -1
  222. rasa/shared/exceptions.py +36 -2
  223. rasa/shared/importers/importer.py +1 -2
  224. rasa/shared/importers/rasa.py +0 -1
  225. rasa/shared/nlu/constants.py +2 -0
  226. rasa/shared/nlu/training_data/entities_parser.py +1 -2
  227. rasa/shared/nlu/training_data/features.py +2 -120
  228. rasa/shared/nlu/training_data/formats/dialogflow.py +3 -2
  229. rasa/shared/nlu/training_data/formats/rasa_yaml.py +3 -5
  230. rasa/shared/nlu/training_data/formats/readerwriter.py +0 -1
  231. rasa/shared/nlu/training_data/message.py +13 -0
  232. rasa/shared/nlu/training_data/training_data.py +0 -2
  233. rasa/shared/providers/openai/session_handler.py +2 -2
  234. rasa/shared/utils/constants.py +3 -0
  235. rasa/shared/utils/io.py +11 -1
  236. rasa/shared/utils/llm.py +1 -2
  237. rasa/shared/utils/pykwalify_extensions.py +1 -0
  238. rasa/shared/utils/schemas/domain.yml +3 -0
  239. rasa/shared/utils/yaml.py +44 -35
  240. rasa/studio/auth.py +26 -10
  241. rasa/studio/constants.py +2 -0
  242. rasa/studio/data_handler.py +114 -107
  243. rasa/studio/download.py +160 -27
  244. rasa/studio/results_logger.py +137 -0
  245. rasa/studio/train.py +6 -7
  246. rasa/studio/upload.py +159 -134
  247. rasa/telemetry.py +188 -34
  248. rasa/tracing/config.py +18 -3
  249. rasa/tracing/constants.py +26 -2
  250. rasa/tracing/instrumentation/attribute_extractors.py +50 -41
  251. rasa/tracing/instrumentation/instrumentation.py +290 -44
  252. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +7 -5
  253. rasa/tracing/instrumentation/metrics.py +109 -21
  254. rasa/tracing/metric_instrument_provider.py +83 -3
  255. rasa/utils/cli.py +2 -1
  256. rasa/utils/common.py +1 -1
  257. rasa/utils/endpoints.py +1 -2
  258. rasa/utils/io.py +72 -6
  259. rasa/utils/licensing.py +246 -31
  260. rasa/utils/ml_utils.py +1 -1
  261. rasa/utils/tensorflow/data_generator.py +1 -1
  262. rasa/utils/tensorflow/environment.py +1 -1
  263. rasa/utils/tensorflow/model_data.py +201 -12
  264. rasa/utils/tensorflow/model_data_utils.py +499 -500
  265. rasa/utils/tensorflow/models.py +5 -6
  266. rasa/utils/tensorflow/rasa_layers.py +15 -15
  267. rasa/utils/train_utils.py +1 -1
  268. rasa/utils/url_tools.py +53 -0
  269. rasa/validator.py +305 -3
  270. rasa/version.py +1 -1
  271. {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/METADATA +25 -61
  272. {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/RECORD +276 -259
  273. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-85583a23.js +0 -1
  274. rasa/utils/tensorflow/feature_array.py +0 -370
  275. /rasa/dialogue_understanding/generator/{command_prompt_template.jinja2 → single_step/command_prompt_template.jinja2} +0 -0
  276. {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/NOTICE +0 -0
  277. {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/WHEEL +0 -0
  278. {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/entry_points.txt +0 -0
@@ -3,6 +3,8 @@ from __future__ import annotations
3
3
  from typing import Any, Dict, Text, List, Optional
4
4
 
5
5
  from jinja2 import Template
6
+ from rasa.dialogue_understanding.commands import CancelFlowCommand
7
+ from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
6
8
  from structlog.contextvars import (
7
9
  bound_contextvars,
8
10
  )
@@ -17,6 +19,9 @@ from rasa.core.policies.flows.flow_step_result import (
17
19
  FlowStepResult,
18
20
  PauseFlowReturnPrediction,
19
21
  )
22
+ from rasa.dialogue_understanding.patterns.internal_error import (
23
+ InternalErrorPatternFlowStackFrame,
24
+ )
20
25
  from rasa.dialogue_understanding.patterns.search import SearchPatternFlowStackFrame
21
26
  from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack
22
27
  from rasa.dialogue_understanding.stack.frames import (
@@ -42,7 +47,7 @@ from rasa.dialogue_understanding.stack.utils import (
42
47
 
43
48
  from pypred import Predicate
44
49
 
45
- from rasa.shared.core.constants import ACTION_LISTEN_NAME
50
+ from rasa.shared.core.constants import ACTION_LISTEN_NAME, SlotMappingType
46
51
  from rasa.shared.core.events import (
47
52
  Event,
48
53
  FlowCompleted,
@@ -72,6 +77,7 @@ from rasa.shared.core.flows.flow import (
72
77
  FlowStep,
73
78
  )
74
79
  from rasa.shared.core.flows.steps.collect import SlotRejection
80
+ from rasa.shared.core.slots import Slot
75
81
  from rasa.shared.core.trackers import (
76
82
  DialogueStateTracker,
77
83
  )
@@ -397,7 +403,6 @@ def advance_flows_until_next_action(
397
403
  number_of_steps_taken = 0
398
404
 
399
405
  while isinstance(step_result, ContinueFlowWithNextStep):
400
-
401
406
  number_of_steps_taken += 1
402
407
  if number_of_steps_taken > MAX_NUMBER_OF_STEPS:
403
408
  raise FlowCircuitBreakerTrippedException(
@@ -467,6 +472,87 @@ def advance_flows_until_next_action(
467
472
  return FlowActionPrediction(None, 0.0, events=gathered_events)
468
473
 
469
474
 
475
+ def validate_collect_step(
476
+ step: CollectInformationFlowStep,
477
+ stack: DialogueStack,
478
+ available_actions: List[str],
479
+ slots: Dict[Text, Slot],
480
+ ) -> bool:
481
+ """Validate that a collect step can be executed.
482
+
483
+ A collect step can be executed if either the `utter_ask` or the `action_ask` is
484
+ defined in the domain. If neither is defined, the collect step can still be
485
+ executed if the slot has an initial value defined in the domain, which would cause
486
+ the step to be skipped."""
487
+ slot = slots.get(step.collect)
488
+ slot_has_initial_value_defined = slot and slot.initial_value is not None
489
+ if (
490
+ slot_has_initial_value_defined
491
+ or step.utter in available_actions
492
+ or step.collect_action in available_actions
493
+ ):
494
+ return True
495
+
496
+ structlogger.error(
497
+ "flow.step.run.collect_missing_utter_or_collect_action",
498
+ slot_name=step.collect,
499
+ )
500
+
501
+ cancel_flow_and_push_internal_error(stack)
502
+
503
+ return False
504
+
505
+
506
+ def cancel_flow_and_push_internal_error(stack: DialogueStack) -> None:
507
+ """Cancel the top user flow and push the internal error pattern."""
508
+ top_frame = stack.top()
509
+
510
+ if isinstance(top_frame, BaseFlowStackFrame):
511
+ # we need to first cancel the top user flow
512
+ # because we cannot collect one of its slots
513
+ # and therefore should not proceed with the flow
514
+ # after triggering pattern_internal_error
515
+ canceled_frames = CancelFlowCommand.select_canceled_frames(stack)
516
+ stack.push(
517
+ CancelPatternFlowStackFrame(
518
+ canceled_name=top_frame.flow_id,
519
+ canceled_frames=canceled_frames,
520
+ )
521
+ )
522
+ stack.push(InternalErrorPatternFlowStackFrame())
523
+
524
+
525
+ def validate_custom_slot_mappings(
526
+ step: CollectInformationFlowStep,
527
+ stack: DialogueStack,
528
+ tracker: DialogueStateTracker,
529
+ available_actions: List[str],
530
+ ) -> bool:
531
+ """Validate a slot with custom mappings.
532
+
533
+ If invalid, trigger pattern_internal_error and return False.
534
+ """
535
+ slot = tracker.slots.get(step.collect, None)
536
+ slot_mappings = slot.mappings if slot else []
537
+ for mapping in slot_mappings:
538
+ if (
539
+ mapping.get("type") == SlotMappingType.CUSTOM.value
540
+ and mapping.get("action") is None
541
+ ):
542
+ # this is a slot that must be filled by a custom action
543
+ # check if collect_action exists
544
+ if step.collect_action not in available_actions:
545
+ structlogger.error(
546
+ "flow.step.run.collect_action_not_found_for_custom_slot_mapping",
547
+ action=step.collect_action,
548
+ collect=step.collect,
549
+ )
550
+ cancel_flow_and_push_internal_error(stack)
551
+ return False
552
+
553
+ return True
554
+
555
+
470
556
  def run_step(
471
557
  step: FlowStep,
472
558
  flow: Flow,
@@ -500,6 +586,22 @@ def run_step(
500
586
  initial_events.append(FlowStarted(flow.id))
501
587
 
502
588
  if isinstance(step, CollectInformationFlowStep):
589
+ is_step_valid = validate_collect_step(
590
+ step, stack, available_actions, tracker.slots
591
+ )
592
+ if not is_step_valid:
593
+ # if we return any other FlowStepResult, the assistant will stay silent
594
+ # instead of triggering the internal error pattern
595
+ return ContinueFlowWithNextStep(events=initial_events)
596
+
597
+ is_mapping_valid = validate_custom_slot_mappings(
598
+ step, stack, tracker, available_actions
599
+ )
600
+ if not is_mapping_valid:
601
+ # if we return any other FlowStepResult, the assistant will stay silent
602
+ # instead of triggering the internal error pattern
603
+ return ContinueFlowWithNextStep(events=initial_events)
604
+
503
605
  structlogger.debug("flow.step.run.collect")
504
606
  trigger_pattern_ask_collect_information(
505
607
  step.collect, stack, step.rejections, step.utter, step.collect_action
@@ -3,7 +3,6 @@ import math
3
3
  from dataclasses import dataclass, field
4
4
  from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Text, Tuple
5
5
 
6
- import rasa.shared.utils.io
7
6
  import structlog
8
7
  import tiktoken
9
8
  from jinja2 import Template
@@ -11,6 +10,7 @@ from langchain.docstore.document import Document
11
10
  from langchain.schema.embeddings import Embeddings
12
11
  from langchain.vectorstores import FAISS
13
12
 
13
+ import rasa.shared.utils.io
14
14
  from rasa import telemetry
15
15
  from rasa.core.constants import (
16
16
  CHAT_POLICY_PRIORITY,
@@ -56,7 +56,6 @@ from rasa.shared.utils.llm import (
56
56
  sanitize_message_for_prompt,
57
57
  tracker_as_readable_transcript,
58
58
  )
59
-
60
59
  from rasa.utils.ml_utils import (
61
60
  extract_ai_response_examples,
62
61
  extract_participant_messages_from_transcript,
@@ -65,7 +64,6 @@ from rasa.utils.ml_utils import (
65
64
  persist_faiss_vector_store,
66
65
  response_for_template,
67
66
  )
68
-
69
67
  from rasa.utils.log_utils import log_llm
70
68
 
71
69
  if TYPE_CHECKING:
@@ -543,7 +541,9 @@ class IntentlessPolicy(Policy):
543
541
  Returns:
544
542
  The prediction.
545
543
  """
546
- if not self.supports_current_stack_frame(tracker):
544
+ if not self.supports_current_stack_frame(
545
+ tracker
546
+ ) or self.should_abstain_in_coexistence(tracker, True):
547
547
  return self._prediction(self._default_predictions(domain))
548
548
 
549
549
  if tracker.has_bot_message_after_latest_user_message():
@@ -670,7 +670,7 @@ class IntentlessPolicy(Policy):
670
670
  if tracker.latest_message.text.startswith("/"):
671
671
  # we don't want to generate a response if the user is trying to
672
672
  # execute a "command" - this should be handled by the regex
673
- # intent classifier in rasa open source.
673
+ # intent classifier in rasa pro.
674
674
  structlogger.debug("intentless_policy.prediction.skip_slash")
675
675
  return None, 0.0
676
676
 
@@ -863,7 +863,7 @@ class IntentlessPolicy(Policy):
863
863
  """
864
864
  result = self._default_predictions(domain)
865
865
  if action_name:
866
- result[domain.index_for_action(action_name)] = score # type: ignore[assignment] # noqa: E501
866
+ result[domain.index_for_action(action_name)] = score # type: ignore[assignment]
867
867
  return result
868
868
 
869
869
  @classmethod
@@ -892,9 +892,7 @@ class IntentlessPolicy(Policy):
892
892
  # normalized. unfortunatley langchain doesn't persist / load
893
893
  # this parameter.
894
894
  if responses_docsearch:
895
- responses_docsearch._normalize_L2 = (
896
- True # pylint: disable=protected-access
897
- )
895
+ responses_docsearch._normalize_L2 = True # pylint: disable=protected-access
898
896
  prompt_template = rasa.shared.utils.io.read_file(
899
897
  path / INTENTLESS_PROMPT_TEMPLATE_FILE_NAME
900
898
  )
@@ -419,9 +419,9 @@ class AugmentedMemoizationPolicy(MemoizationPolicy):
419
419
  logger.debug("Launch DeLorean...")
420
420
 
421
421
  # Truncate the tracker based on `max_history`
422
- truncated_tracker: Optional[
423
- DialogueStateTracker
424
- ] = _trim_tracker_by_max_history(tracker, self.config[POLICY_MAX_HISTORY])
422
+ truncated_tracker: Optional[DialogueStateTracker] = (
423
+ _trim_tracker_by_max_history(tracker, self.config[POLICY_MAX_HISTORY])
424
+ )
425
425
  truncated_tracker = self._strip_leading_events_until_action_executed(
426
426
  truncated_tracker
427
427
  )
@@ -1,12 +1,10 @@
1
1
  from __future__ import annotations
2
+
2
3
  import abc
3
4
  import copy
4
5
  import logging
5
6
  from enum import Enum
6
7
  from pathlib import Path
7
-
8
- from rasa.shared.constants import ROUTE_TO_CALM_SLOT
9
- from rasa.shared.core.events import Event
10
8
  from typing import (
11
9
  Any,
12
10
  List,
@@ -21,6 +19,8 @@ from typing import (
21
19
 
22
20
  import numpy as np
23
21
 
22
+ from rasa.shared.constants import ROUTE_TO_CALM_SLOT
23
+ from rasa.shared.core.events import Event
24
24
  from rasa.engine.graph import GraphComponent, ExecutionContext
25
25
  from rasa.engine.storage.resource import Resource
26
26
  from rasa.engine.storage.storage import ModelStorage
@@ -40,14 +40,12 @@ from rasa.core.constants import (
40
40
  from rasa.shared.core.constants import USER, SLOTS, PREVIOUS_ACTION, ACTIVE_LOOP
41
41
  import rasa.shared.utils.common
42
42
 
43
-
44
43
  if TYPE_CHECKING:
45
44
  from rasa.shared.nlu.training_data.features import Features
46
45
  from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
47
46
  from rasa.core.featurizers.tracker_featurizers import MaxHistoryTrackerFeaturizer
48
47
  from rasa.dialogue_understanding.stack.frames import DialogueStackFrame
49
48
 
50
-
51
49
  logger = logging.getLogger(__name__)
52
50
 
53
51
  TrackerListTypeVar = TypeVar(
@@ -137,10 +135,20 @@ class Policy(GraphComponent):
137
135
  def should_abstain_in_coexistence(
138
136
  self, tracker: DialogueStateTracker, is_calm_policy: bool
139
137
  ) -> bool:
140
- """Whether a policy should abstain making predictions in coexistence."""
138
+ """Whether a policy should abstain making predictions in coexistence.
139
+
140
+ A calm policy should run when the routing slot is set to True.
141
+ A nlu-based policy should run when the routing slot is set to False or None.
142
+ """
143
+ if is_calm_policy:
144
+ return tracker.has_coexistence_routing_slot and (
145
+ tracker.get_slot(ROUTE_TO_CALM_SLOT) is False
146
+ or tracker.get_slot(ROUTE_TO_CALM_SLOT) is None
147
+ )
148
+
141
149
  return (
142
150
  tracker.has_coexistence_routing_slot
143
- and tracker.get_slot(ROUTE_TO_CALM_SLOT) != is_calm_policy
151
+ and tracker.get_slot(ROUTE_TO_CALM_SLOT) is True
144
152
  )
145
153
 
146
154
  def __init__(
@@ -299,8 +307,9 @@ class Policy(GraphComponent):
299
307
  max_training_samples = kwargs.get("max_training_samples")
300
308
  if max_training_samples is not None:
301
309
  logger.debug(
302
- "Limit training data to {} training samples."
303
- "".format(max_training_samples)
310
+ "Limit training data to {} training samples.".format(
311
+ max_training_samples
312
+ )
304
313
  )
305
314
  state_features = state_features[:max_training_samples]
306
315
  label_ids = label_ids[:max_training_samples]
@@ -60,7 +60,7 @@ logger = logging.getLogger(__name__)
60
60
  structlogger = structlog.get_logger()
61
61
 
62
62
 
63
- # These are Rasa Open Source default actions and overrule everything at any time.
63
+ # These are Rasa Pro default actions and overrule everything at any time.
64
64
  DEFAULT_ACTION_MAPPINGS = {
65
65
  USER_INTENT_RESTART: ACTION_RESTART_NAME,
66
66
  USER_INTENT_BACK: ACTION_BACK_NAME,
@@ -271,15 +271,13 @@ class RulePolicy(MemoizationPolicy):
271
271
  if (
272
272
  # loop is predicted after action_listen in unhappy path,
273
273
  # therefore no validation is needed
274
- is_prev_action_listen_in_state(states[-1])
275
- and action == active_loop
274
+ is_prev_action_listen_in_state(states[-1]) and action == active_loop
276
275
  ):
277
276
  lookup[feature_key] = LOOP_WAS_INTERRUPTED
278
277
  elif (
279
278
  # some action other than active_loop is predicted in unhappy path,
280
279
  # therefore active_loop shouldn't be predicted by the rule
281
- not is_prev_action_listen_in_state(states[-1])
282
- and action != active_loop
280
+ not is_prev_action_listen_in_state(states[-1]) and action != active_loop
283
281
  ):
284
282
  lookup[feature_key] = DO_NOT_PREDICT_LOOP_ACTION
285
283
  return lookup
@@ -777,10 +775,10 @@ class RulePolicy(MemoizationPolicy):
777
775
  trackers_as_actions = rule_trackers_as_actions + story_trackers_as_actions
778
776
 
779
777
  # negative rules are not anti-rules, they are auxiliary to actual rules
780
- self.lookup[
781
- RULES_FOR_LOOP_UNHAPPY_PATH
782
- ] = self._create_loop_unhappy_lookup_from_states(
783
- trackers_as_states, trackers_as_actions
778
+ self.lookup[RULES_FOR_LOOP_UNHAPPY_PATH] = (
779
+ self._create_loop_unhappy_lookup_from_states(
780
+ trackers_as_states, trackers_as_actions
781
+ )
784
782
  )
785
783
 
786
784
  def train(
@@ -955,7 +953,6 @@ class RulePolicy(MemoizationPolicy):
955
953
  def _find_action_from_loop_happy_path(
956
954
  tracker: DialogueStateTracker,
957
955
  ) -> Tuple[Optional[Text], Optional[Text]]:
958
-
959
956
  active_loop_name = tracker.active_loop_name
960
957
  if active_loop_name is None:
961
958
  return None, None
@@ -1132,7 +1129,7 @@ class RulePolicy(MemoizationPolicy):
1132
1129
  tracker, domain, use_text_for_last_user_input=True
1133
1130
  )
1134
1131
 
1135
- # Rasa Open Source default actions overrule anything. If users want to achieve
1132
+ # Rasa Pro default actions overrule anything. If users want to achieve
1136
1133
  # the same, they need to write a rule or make sure that their loop rejects
1137
1134
  # accordingly.
1138
1135
  (
@@ -1,15 +1,15 @@
1
1
  from __future__ import annotations
2
-
3
2
  import logging
3
+
4
+ from rasa.engine.recipes.default_recipe import DefaultV1Recipe
4
5
  from pathlib import Path
5
6
  from collections import defaultdict
6
7
  import contextlib
7
- from typing import Any, List, Optional, Text, Dict, Tuple, Union, Type
8
8
 
9
9
  import numpy as np
10
10
  import tensorflow as tf
11
+ from typing import Any, List, Optional, Text, Dict, Tuple, Union, Type
11
12
 
12
- from rasa.engine.recipes.default_recipe import DefaultV1Recipe
13
13
  from rasa.engine.graph import ExecutionContext
14
14
  from rasa.engine.storage.resource import Resource
15
15
  from rasa.engine.storage.storage import ModelStorage
@@ -49,22 +49,18 @@ from rasa.shared.core.generator import TrackerWithCachedStates
49
49
  from rasa.shared.core.events import EntitiesAdded, Event
50
50
  from rasa.shared.core.domain import Domain
51
51
  from rasa.shared.nlu.training_data.message import Message
52
- from rasa.shared.nlu.training_data.features import (
53
- Features,
54
- save_features,
55
- load_features,
56
- )
52
+ from rasa.shared.nlu.training_data.features import Features
57
53
  import rasa.shared.utils.io
58
54
  import rasa.utils.io
59
55
  from rasa.utils import train_utils
60
- from rasa.utils.tensorflow.feature_array import (
61
- FeatureArray,
62
- serialize_nested_feature_arrays,
63
- deserialize_nested_feature_arrays,
64
- )
65
56
  from rasa.utils.tensorflow.models import RasaModel, TransformerRasaModel
66
57
  from rasa.utils.tensorflow import rasa_layers
67
- from rasa.utils.tensorflow.model_data import RasaModelData, FeatureSignature, Data
58
+ from rasa.utils.tensorflow.model_data import (
59
+ RasaModelData,
60
+ FeatureSignature,
61
+ FeatureArray,
62
+ Data,
63
+ )
68
64
  from rasa.utils.tensorflow.model_data_utils import convert_to_data_format
69
65
  from rasa.utils.tensorflow.constants import (
70
66
  LABEL,
@@ -472,7 +468,7 @@ class TEDPolicy(Policy):
472
468
 
473
469
  @staticmethod
474
470
  def _should_extract_entities(
475
- entity_tags: List[List[Dict[Text, List[Features]]]]
471
+ entity_tags: List[List[Dict[Text, List[Features]]]],
476
472
  ) -> bool:
477
473
  for turns_tags in entity_tags:
478
474
  for turn_tags in turns_tags:
@@ -965,32 +961,22 @@ class TEDPolicy(Policy):
965
961
  model_path: Path where model is to be persisted
966
962
  """
967
963
  model_filename = self._metadata_filename()
968
- rasa.shared.utils.io.dump_obj_as_json_to_file(
969
- model_path / f"{model_filename}.priority.json", self.priority
964
+ rasa.utils.io.json_pickle(
965
+ model_path / f"{model_filename}.priority.pkl", self.priority
970
966
  )
971
- rasa.shared.utils.io.dump_obj_as_json_to_file(
972
- model_path / f"{model_filename}.meta.json", self.config
973
- )
974
- # save data example
975
- serialize_nested_feature_arrays(
976
- self.data_example,
977
- str(model_path / f"{model_filename}.data_example.st"),
978
- str(model_path / f"{model_filename}.data_example_metadata.json"),
967
+ rasa.utils.io.pickle_dump(
968
+ model_path / f"{model_filename}.meta.pkl", self.config
979
969
  )
980
- # save label data
981
- serialize_nested_feature_arrays(
982
- dict(self._label_data.data) if self._label_data is not None else {},
983
- str(model_path / f"{model_filename}.label_data.st"),
984
- str(model_path / f"{model_filename}.label_data_metadata.json"),
970
+ rasa.utils.io.pickle_dump(
971
+ model_path / f"{model_filename}.data_example.pkl", self.data_example
985
972
  )
986
- # save fake features
987
- metadata = save_features(
988
- self.fake_features, str(model_path / f"{model_filename}.fake_features.st")
973
+ rasa.utils.io.pickle_dump(
974
+ model_path / f"{model_filename}.fake_features.pkl", self.fake_features
989
975
  )
990
- rasa.shared.utils.io.dump_obj_as_json_to_file(
991
- model_path / f"{model_filename}.fake_features_metadata.json", metadata
976
+ rasa.utils.io.pickle_dump(
977
+ model_path / f"{model_filename}.label_data.pkl",
978
+ dict(self._label_data.data) if self._label_data is not None else {},
992
979
  )
993
-
994
980
  entity_tag_specs = (
995
981
  [tag_spec._asdict() for tag_spec in self._entity_tag_specs]
996
982
  if self._entity_tag_specs
@@ -1008,29 +994,18 @@ class TEDPolicy(Policy):
1008
994
  model_path: Path where model is to be persisted.
1009
995
  """
1010
996
  tf_model_file = model_path / f"{cls._metadata_filename()}.tf_model"
1011
-
1012
- # load data example
1013
- loaded_data = deserialize_nested_feature_arrays(
1014
- str(model_path / f"{cls._metadata_filename()}.data_example.st"),
1015
- str(model_path / f"{cls._metadata_filename()}.data_example_metadata.json"),
997
+ loaded_data = rasa.utils.io.pickle_load(
998
+ model_path / f"{cls._metadata_filename()}.data_example.pkl"
1016
999
  )
1017
- # load label data
1018
- loaded_label_data = deserialize_nested_feature_arrays(
1019
- str(model_path / f"{cls._metadata_filename()}.label_data.st"),
1020
- str(model_path / f"{cls._metadata_filename()}.label_data_metadata.json"),
1000
+ label_data = rasa.utils.io.pickle_load(
1001
+ model_path / f"{cls._metadata_filename()}.label_data.pkl"
1021
1002
  )
1022
- label_data = RasaModelData(data=loaded_label_data)
1023
-
1024
- # load fake features
1025
- metadata = rasa.shared.utils.io.read_json_file(
1026
- model_path / f"{cls._metadata_filename()}.fake_features_metadata.json"
1027
- )
1028
- fake_features = load_features(
1029
- str(model_path / f"{cls._metadata_filename()}.fake_features.st"), metadata
1003
+ fake_features = rasa.utils.io.pickle_load(
1004
+ model_path / f"{cls._metadata_filename()}.fake_features.pkl"
1030
1005
  )
1031
-
1032
- priority = rasa.shared.utils.io.read_json_file(
1033
- model_path / f"{cls._metadata_filename()}.priority.json"
1006
+ label_data = RasaModelData(data=label_data)
1007
+ priority = rasa.utils.io.json_unpickle(
1008
+ model_path / f"{cls._metadata_filename()}.priority.pkl"
1034
1009
  )
1035
1010
  entity_tag_specs = rasa.shared.utils.io.read_json_file(
1036
1011
  model_path / f"{cls._metadata_filename()}.entity_tag_specs.json"
@@ -1048,8 +1023,8 @@ class TEDPolicy(Policy):
1048
1023
  )
1049
1024
  for tag_spec in entity_tag_specs
1050
1025
  ]
1051
- model_config = rasa.shared.utils.io.read_json_file(
1052
- model_path / f"{cls._metadata_filename()}.meta.json"
1026
+ model_config = rasa.utils.io.pickle_load(
1027
+ model_path / f"{cls._metadata_filename()}.meta.pkl"
1053
1028
  )
1054
1029
 
1055
1030
  return {
@@ -1095,7 +1070,7 @@ class TEDPolicy(Policy):
1095
1070
  ) -> TEDPolicy:
1096
1071
  featurizer = TrackerFeaturizer.load(model_path)
1097
1072
 
1098
- if not (model_path / f"{cls._metadata_filename()}.data_example.st").is_file():
1073
+ if not (model_path / f"{cls._metadata_filename()}.data_example.pkl").is_file():
1099
1074
  return cls(
1100
1075
  config,
1101
1076
  model_storage,
@@ -1117,7 +1092,7 @@ class TEDPolicy(Policy):
1117
1092
 
1118
1093
  model = None
1119
1094
 
1120
- with (contextlib.nullcontext() if config["use_gpu"] else tf.device("/cpu:0")):
1095
+ with contextlib.nullcontext() if config["use_gpu"] else tf.device("/cpu:0"):
1121
1096
  model = cls._load_tf_model(
1122
1097
  model_utilities,
1123
1098
  model_data_example,
@@ -1291,19 +1266,19 @@ class TED(TransformerRasaModel):
1291
1266
  )
1292
1267
  self._prepare_encoding_layers(name)
1293
1268
 
1294
- self._tf_layers[
1295
- f"transformer.{DIALOGUE}"
1296
- ] = rasa_layers.prepare_transformer_layer(
1297
- attribute_name=DIALOGUE,
1298
- config=self.config,
1299
- num_layers=self.config[NUM_TRANSFORMER_LAYERS][DIALOGUE],
1300
- units=self.config[TRANSFORMER_SIZE][DIALOGUE],
1301
- drop_rate=self.config[DROP_RATE_DIALOGUE],
1302
- # use bidirectional transformer, because
1303
- # we will invert dialogue sequence so that the last turn is located
1304
- # at the first position and would always have
1305
- # exactly the same positional encoding
1306
- unidirectional=not self.max_history_featurizer_is_used,
1269
+ self._tf_layers[f"transformer.{DIALOGUE}"] = (
1270
+ rasa_layers.prepare_transformer_layer(
1271
+ attribute_name=DIALOGUE,
1272
+ config=self.config,
1273
+ num_layers=self.config[NUM_TRANSFORMER_LAYERS][DIALOGUE],
1274
+ units=self.config[TRANSFORMER_SIZE][DIALOGUE],
1275
+ drop_rate=self.config[DROP_RATE_DIALOGUE],
1276
+ # use bidirectional transformer, because
1277
+ # we will invert dialogue sequence so that the last turn is located
1278
+ # at the first position and would always have
1279
+ # exactly the same positional encoding
1280
+ unidirectional=not self.max_history_featurizer_is_used,
1281
+ )
1307
1282
  )
1308
1283
 
1309
1284
  self._prepare_label_classification_layers(DIALOGUE)
@@ -1333,23 +1308,23 @@ class TED(TransformerRasaModel):
1333
1308
  # Attributes with sequence-level features also have sentence-level features,
1334
1309
  # all these need to be combined and further processed.
1335
1310
  if attribute_name in SEQUENCE_FEATURES_TO_ENCODE:
1336
- self._tf_layers[
1337
- f"sequence_layer.{attribute_name}"
1338
- ] = rasa_layers.RasaSequenceLayer(
1339
- attribute_name, attribute_signature, config_to_use
1311
+ self._tf_layers[f"sequence_layer.{attribute_name}"] = (
1312
+ rasa_layers.RasaSequenceLayer(
1313
+ attribute_name, attribute_signature, config_to_use
1314
+ )
1340
1315
  )
1341
1316
  # Attributes without sequence-level features require some actual feature
1342
1317
  # processing only if they have sentence-level features. Attributes with no
1343
1318
  # sequence- and sentence-level features (dialogue, entity_tags, label) are
1344
1319
  # skipped here.
1345
1320
  elif SENTENCE in attribute_signature:
1346
- self._tf_layers[
1347
- f"sparse_dense_concat_layer.{attribute_name}"
1348
- ] = rasa_layers.ConcatenateSparseDenseFeatures(
1349
- attribute=attribute_name,
1350
- feature_type=SENTENCE,
1351
- feature_type_signature=attribute_signature[SENTENCE],
1352
- config=config_to_use,
1321
+ self._tf_layers[f"sparse_dense_concat_layer.{attribute_name}"] = (
1322
+ rasa_layers.ConcatenateSparseDenseFeatures(
1323
+ attribute=attribute_name,
1324
+ feature_type=SENTENCE,
1325
+ feature_type_signature=attribute_signature[SENTENCE],
1326
+ config=config_to_use,
1327
+ )
1353
1328
  )
1354
1329
 
1355
1330
  def _prepare_encoding_layers(self, name: Text) -> None:
@@ -1385,7 +1360,7 @@ class TED(TransformerRasaModel):
1385
1360
 
1386
1361
  @staticmethod
1387
1362
  def _compute_dialogue_indices(
1388
- tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]]
1363
+ tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
1389
1364
  ) -> None:
1390
1365
  dialogue_lengths = tf.cast(tf_batch_data[DIALOGUE][LENGTH][0], dtype=tf.int32)
1391
1366
  # wrap in a list, because that's the structure of tf_batch_data
@@ -1424,7 +1399,7 @@ class TED(TransformerRasaModel):
1424
1399
 
1425
1400
  @staticmethod
1426
1401
  def _collect_label_attribute_encodings(
1427
- all_labels_encoded: Dict[Text, tf.Tensor]
1402
+ all_labels_encoded: Dict[Text, tf.Tensor],
1428
1403
  ) -> tf.Tensor:
1429
1404
  # Initialize with at least one attribute first
1430
1405
  # so that the subsequent TF ops are simplified.
@@ -1953,7 +1928,6 @@ class TED(TransformerRasaModel):
1953
1928
  text_output: tf.Tensor,
1954
1929
  text_sequence_lengths: tf.Tensor,
1955
1930
  ) -> tf.Tensor:
1956
-
1957
1931
  text_transformed, text_mask, text_sequence_lengths = self._reshape_for_entities(
1958
1932
  tf_batch_data,
1959
1933
  dialogue_transformer_output,
@@ -2156,7 +2130,6 @@ class TED(TransformerRasaModel):
2156
2130
  text_output: tf.Tensor,
2157
2131
  text_sequence_lengths: tf.Tensor,
2158
2132
  ) -> Tuple[tf.Tensor, tf.Tensor]:
2159
-
2160
2133
  text_transformed, _, text_sequence_lengths = self._reshape_for_entities(
2161
2134
  tf_batch_data,
2162
2135
  dialogue_transformer_output,