rasa-pro 3.14.0a20__py3-none-any.whl → 3.14.0a23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (331) hide show
  1. rasa/__main__.py +15 -3
  2. rasa/agents/__init__.py +0 -0
  3. rasa/agents/agent_factory.py +122 -0
  4. rasa/agents/agent_manager.py +211 -0
  5. rasa/agents/constants.py +43 -0
  6. rasa/agents/core/__init__.py +0 -0
  7. rasa/agents/core/agent_protocol.py +107 -0
  8. rasa/agents/core/types.py +81 -0
  9. rasa/agents/exceptions.py +38 -0
  10. rasa/agents/protocol/__init__.py +5 -0
  11. rasa/agents/protocol/a2a/__init__.py +0 -0
  12. rasa/agents/protocol/a2a/a2a_agent.py +879 -0
  13. rasa/agents/protocol/mcp/__init__.py +0 -0
  14. rasa/agents/protocol/mcp/mcp_base_agent.py +726 -0
  15. rasa/agents/protocol/mcp/mcp_open_agent.py +327 -0
  16. rasa/agents/protocol/mcp/mcp_task_agent.py +522 -0
  17. rasa/agents/schemas/__init__.py +13 -0
  18. rasa/agents/schemas/agent_input.py +38 -0
  19. rasa/agents/schemas/agent_output.py +26 -0
  20. rasa/agents/schemas/agent_tool_result.py +65 -0
  21. rasa/agents/schemas/agent_tool_schema.py +186 -0
  22. rasa/agents/templates/__init__.py +0 -0
  23. rasa/agents/templates/mcp_open_agent_prompt_template.jinja2 +20 -0
  24. rasa/agents/templates/mcp_task_agent_prompt_template.jinja2 +22 -0
  25. rasa/agents/utils.py +206 -0
  26. rasa/agents/validation.py +485 -0
  27. rasa/api.py +24 -9
  28. rasa/builder/config.py +6 -2
  29. rasa/builder/guardrails/{lakera.py → clients.py} +55 -5
  30. rasa/builder/guardrails/constants.py +3 -0
  31. rasa/builder/guardrails/models.py +45 -10
  32. rasa/builder/guardrails/policy_checker.py +324 -0
  33. rasa/builder/guardrails/utils.py +42 -276
  34. rasa/builder/llm_service.py +32 -5
  35. rasa/builder/models.py +1 -0
  36. rasa/builder/project_generator.py +6 -1
  37. rasa/builder/service.py +16 -13
  38. rasa/builder/training_service.py +18 -24
  39. rasa/builder/validation_service.py +1 -1
  40. rasa/cli/arguments/default_arguments.py +12 -0
  41. rasa/cli/arguments/run.py +2 -0
  42. rasa/cli/arguments/train.py +2 -0
  43. rasa/cli/data.py +10 -8
  44. rasa/cli/dialogue_understanding_test.py +10 -7
  45. rasa/cli/e2e_test.py +9 -6
  46. rasa/cli/evaluate.py +4 -2
  47. rasa/cli/export.py +5 -2
  48. rasa/cli/inspect.py +8 -4
  49. rasa/cli/interactive.py +5 -4
  50. rasa/cli/llm_fine_tuning.py +11 -6
  51. rasa/cli/project_templates/finance/actions/general/__init__.py +0 -0
  52. rasa/cli/project_templates/finance/actions/general/action_human_handoff.py +49 -0
  53. rasa/cli/project_templates/finance/data/general/bot_challenge.yml +6 -0
  54. rasa/cli/project_templates/finance/data/general/goodbye.yml +1 -1
  55. rasa/cli/project_templates/finance/data/general/human_handoff.yml +1 -1
  56. rasa/cli/project_templates/finance/data/system/patterns/pattern_session_start.yml +1 -1
  57. rasa/cli/project_templates/finance/domain/general/_shared.yml +0 -14
  58. rasa/cli/project_templates/finance/domain/general/bot_challenge.yml +4 -0
  59. rasa/cli/project_templates/finance/domain/general/goodbye.yml +7 -0
  60. rasa/cli/project_templates/finance/domain/general/human_handoff.yml +3 -6
  61. rasa/cli/project_templates/finance/domain/general/welcome.yml +29 -1
  62. rasa/cli/project_templates/finance/tests/e2e_test_cases/accounts/check_balance.yml +9 -0
  63. rasa/cli/project_templates/finance/tests/e2e_test_cases/accounts/download_statements.yml +43 -0
  64. rasa/cli/project_templates/finance/tests/e2e_test_cases/cards/block_card.yml +55 -0
  65. rasa/cli/project_templates/finance/tests/e2e_test_cases/general/bot_challenge.yml +8 -0
  66. rasa/cli/project_templates/finance/tests/e2e_test_cases/general/feedback.yml +46 -0
  67. rasa/cli/project_templates/finance/tests/e2e_test_cases/general/goodbye.yml +9 -0
  68. rasa/cli/project_templates/finance/tests/e2e_test_cases/general/hello.yml +8 -0
  69. rasa/cli/project_templates/finance/tests/e2e_test_cases/general/human_handoff.yml +35 -0
  70. rasa/cli/project_templates/finance/tests/e2e_test_cases/general/patterns.yml +22 -0
  71. rasa/cli/project_templates/finance/tests/e2e_test_cases/transfers/transfer_money.yml +56 -0
  72. rasa/cli/project_templates/telco/tests/e2e_test_cases/general/feedback.yml +1 -1
  73. rasa/cli/project_templates/telco/tests/e2e_test_cases/general/hello.yml +1 -1
  74. rasa/cli/project_templates/telco/tests/e2e_test_cases/general/human_handoff.yml +1 -1
  75. rasa/cli/project_templates/telco/tests/e2e_test_cases/general/patterns.yml +1 -1
  76. rasa/cli/project_templates/tutorial/credentials.yml +10 -0
  77. rasa/cli/run.py +12 -10
  78. rasa/cli/scaffold.py +4 -4
  79. rasa/cli/shell.py +9 -5
  80. rasa/cli/studio/studio.py +1 -1
  81. rasa/cli/test.py +34 -14
  82. rasa/cli/train.py +41 -28
  83. rasa/cli/utils.py +1 -393
  84. rasa/cli/validation/__init__.py +0 -0
  85. rasa/cli/validation/bot_config.py +223 -0
  86. rasa/cli/validation/config_path_validation.py +257 -0
  87. rasa/cli/x.py +8 -4
  88. rasa/constants.py +7 -1
  89. rasa/core/actions/action.py +51 -10
  90. rasa/core/actions/action_run_slot_rejections.py +1 -1
  91. rasa/core/actions/direct_custom_actions_executor.py +9 -2
  92. rasa/core/actions/grpc_custom_action_executor.py +1 -1
  93. rasa/core/agent.py +19 -2
  94. rasa/core/available_agents.py +229 -0
  95. rasa/core/brokers/kafka.py +1 -1
  96. rasa/core/channels/__init__.py +82 -35
  97. rasa/core/channels/development_inspector.py +3 -3
  98. rasa/core/channels/inspector/README.md +25 -13
  99. rasa/core/channels/inspector/dist/assets/{arc-35222594.js → arc-6177260a.js} +1 -1
  100. rasa/core/channels/inspector/dist/assets/{blockDiagram-38ab4fdb-a0efbfd3.js → blockDiagram-38ab4fdb-b054f038.js} +1 -1
  101. rasa/core/channels/inspector/dist/assets/{c4Diagram-3d4e48cf-0584c0f2.js → c4Diagram-3d4e48cf-f25427d5.js} +1 -1
  102. rasa/core/channels/inspector/dist/assets/channel-bf9cbb34.js +1 -0
  103. rasa/core/channels/inspector/dist/assets/{classDiagram-70f12bd4-39f40dbe.js → classDiagram-70f12bd4-c7a2af53.js} +1 -1
  104. rasa/core/channels/inspector/dist/assets/{classDiagram-v2-f2320105-1ad755f3.js → classDiagram-v2-f2320105-58db65c0.js} +1 -1
  105. rasa/core/channels/inspector/dist/assets/clone-8f9083bb.js +1 -0
  106. rasa/core/channels/inspector/dist/assets/{createText-2e5e7dd3-b0f4f0fe.js → createText-2e5e7dd3-088372e2.js} +1 -1
  107. rasa/core/channels/inspector/dist/assets/{edges-e0da2a9e-9039bff9.js → edges-e0da2a9e-58676240.js} +1 -1
  108. rasa/core/channels/inspector/dist/assets/{erDiagram-9861fffd-65c9b127.js → erDiagram-9861fffd-0c14d7c6.js} +1 -1
  109. rasa/core/channels/inspector/dist/assets/{flowDb-956e92f1-4f08b38e.js → flowDb-956e92f1-ea63f85c.js} +1 -1
  110. rasa/core/channels/inspector/dist/assets/{flowDiagram-66a62f08-e95c362a.js → flowDiagram-66a62f08-a2af48cd.js} +1 -1
  111. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-9ecd5b59.js +1 -0
  112. rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-4a651766-703c3015.js → flowchart-elk-definition-4a651766-6937abe7.js} +1 -1
  113. rasa/core/channels/inspector/dist/assets/{ganttDiagram-c361ad54-699328ea.js → ganttDiagram-c361ad54-7473f357.js} +1 -1
  114. rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-72cf32ee-04cf4b05.js → gitGraphDiagram-72cf32ee-d0c9405e.js} +1 -1
  115. rasa/core/channels/inspector/dist/assets/{graph-ee94449e.js → graph-0a6f8466.js} +1 -1
  116. rasa/core/channels/inspector/dist/assets/{index-3862675e-940162b4.js → index-3862675e-7610671a.js} +1 -1
  117. rasa/core/channels/inspector/dist/assets/index-74e01d94.js +1354 -0
  118. rasa/core/channels/inspector/dist/assets/{infoDiagram-f8f76790-c79c2866.js → infoDiagram-f8f76790-be397dc7.js} +1 -1
  119. rasa/core/channels/inspector/dist/assets/{journeyDiagram-49397b02-84489d30.js → journeyDiagram-49397b02-4cefbf62.js} +1 -1
  120. rasa/core/channels/inspector/dist/assets/{layout-a9aa9858.js → layout-e7fbc2bf.js} +1 -1
  121. rasa/core/channels/inspector/dist/assets/{line-eb73cf26.js → line-a8aa457c.js} +1 -1
  122. rasa/core/channels/inspector/dist/assets/{linear-b3399f9a.js → linear-3351e0d2.js} +1 -1
  123. rasa/core/channels/inspector/dist/assets/{mindmap-definition-fc14e90a-b095bf1a.js → mindmap-definition-fc14e90a-b8cbf605.js} +1 -1
  124. rasa/core/channels/inspector/dist/assets/{pieDiagram-8a3498a8-07644b66.js → pieDiagram-8a3498a8-f327f774.js} +1 -1
  125. rasa/core/channels/inspector/dist/assets/{quadrantDiagram-120e2f19-573a3f9c.js → quadrantDiagram-120e2f19-2854c591.js} +1 -1
  126. rasa/core/channels/inspector/dist/assets/{requirementDiagram-deff3bca-d457e1e1.js → requirementDiagram-deff3bca-964985d5.js} +1 -1
  127. rasa/core/channels/inspector/dist/assets/{sankeyDiagram-04a897e0-9d26e1a2.js → sankeyDiagram-04a897e0-edeb4f33.js} +1 -1
  128. rasa/core/channels/inspector/dist/assets/{sequenceDiagram-704730f1-3a9cde10.js → sequenceDiagram-704730f1-fcf70125.js} +1 -1
  129. rasa/core/channels/inspector/dist/assets/{stateDiagram-587899a1-4f3e8cec.js → stateDiagram-587899a1-0e770395.js} +1 -1
  130. rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-d93cdb3a-e617e5bf.js → stateDiagram-v2-d93cdb3a-af8dcd22.js} +1 -1
  131. rasa/core/channels/inspector/dist/assets/{styles-6aaf32cf-eab30d2f.js → styles-6aaf32cf-36a9e70d.js} +1 -1
  132. rasa/core/channels/inspector/dist/assets/{styles-9a916d00-09994be2.js → styles-9a916d00-884a8b5b.js} +1 -1
  133. rasa/core/channels/inspector/dist/assets/{styles-c10674c1-b7110364.js → styles-c10674c1-dc097813.js} +1 -1
  134. rasa/core/channels/inspector/dist/assets/{svgDrawCommon-08f97a94-3ebc92ad.js → svgDrawCommon-08f97a94-5a2c7eed.js} +1 -1
  135. rasa/core/channels/inspector/dist/assets/{timeline-definition-85554ec2-7d13d2f2.js → timeline-definition-85554ec2-e89c4f6e.js} +1 -1
  136. rasa/core/channels/inspector/dist/assets/{xychartDiagram-e933f94c-488385e1.js → xychartDiagram-e933f94c-afb6fe56.js} +1 -1
  137. rasa/core/channels/inspector/dist/index.html +1 -1
  138. rasa/core/channels/inspector/package.json +18 -18
  139. rasa/core/channels/inspector/src/App.tsx +29 -4
  140. rasa/core/channels/inspector/src/components/DialogueAgentStack.tsx +108 -0
  141. rasa/core/channels/inspector/src/components/{DialogueStack.tsx → DialogueHistoryStack.tsx} +4 -2
  142. rasa/core/channels/inspector/src/helpers/audio/audiostream.ts +7 -4
  143. rasa/core/channels/inspector/src/helpers/formatters.test.ts +4 -0
  144. rasa/core/channels/inspector/src/helpers/formatters.ts +24 -3
  145. rasa/core/channels/inspector/src/helpers/utils.test.ts +127 -0
  146. rasa/core/channels/inspector/src/helpers/utils.ts +66 -1
  147. rasa/core/channels/inspector/src/theme/base/styles.ts +19 -1
  148. rasa/core/channels/inspector/src/types.ts +21 -0
  149. rasa/core/channels/inspector/yarn.lock +336 -189
  150. rasa/core/channels/studio_chat.py +6 -6
  151. rasa/core/channels/telegram.py +4 -9
  152. rasa/core/channels/voice_stream/browser_audio.py +2 -0
  153. rasa/core/channels/voice_stream/genesys.py +1 -1
  154. rasa/core/channels/voice_stream/tts/deepgram.py +140 -0
  155. rasa/core/channels/voice_stream/twilio_media_streams.py +5 -1
  156. rasa/core/channels/voice_stream/voice_channel.py +3 -0
  157. rasa/core/config/__init__.py +0 -0
  158. rasa/core/{available_endpoints.py → config/available_endpoints.py} +51 -16
  159. rasa/core/config/configuration.py +260 -0
  160. rasa/core/config/credentials.py +19 -0
  161. rasa/core/config/message_procesing_config.py +34 -0
  162. rasa/core/constants.py +5 -0
  163. rasa/core/iam_credentials_providers/aws_iam_credentials_providers.py +88 -3
  164. rasa/core/iam_credentials_providers/credentials_provider_protocol.py +2 -1
  165. rasa/core/lock_store.py +6 -4
  166. rasa/core/nlg/generator.py +1 -1
  167. rasa/core/policies/enterprise_search_policy.py +5 -3
  168. rasa/core/policies/flow_policy.py +4 -4
  169. rasa/core/policies/flows/agent_executor.py +632 -0
  170. rasa/core/policies/flows/flow_executor.py +137 -76
  171. rasa/core/policies/flows/mcp_tool_executor.py +298 -0
  172. rasa/core/policies/intentless_policy.py +1 -1
  173. rasa/core/policies/ted_policy.py +20 -12
  174. rasa/core/policies/unexpected_intent_policy.py +6 -0
  175. rasa/core/processor.py +68 -44
  176. rasa/core/redis_connection_factory.py +78 -20
  177. rasa/core/run.py +37 -8
  178. rasa/core/test.py +4 -0
  179. rasa/core/tracker_stores/sql_tracker_store.py +1 -1
  180. rasa/core/tracker_stores/tracker_store.py +3 -7
  181. rasa/core/train.py +1 -1
  182. rasa/core/training/interactive.py +20 -18
  183. rasa/core/training/story_conflict.py +5 -5
  184. rasa/core/utils.py +22 -23
  185. rasa/dialogue_understanding/commands/__init__.py +8 -0
  186. rasa/dialogue_understanding/commands/cancel_flow_command.py +19 -5
  187. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +21 -2
  188. rasa/dialogue_understanding/commands/clarify_command.py +20 -2
  189. rasa/dialogue_understanding/commands/continue_agent_command.py +91 -0
  190. rasa/dialogue_understanding/commands/knowledge_answer_command.py +21 -2
  191. rasa/dialogue_understanding/commands/restart_agent_command.py +162 -0
  192. rasa/dialogue_understanding/commands/start_flow_command.py +68 -7
  193. rasa/dialogue_understanding/commands/utils.py +124 -2
  194. rasa/dialogue_understanding/generator/command_parser.py +4 -0
  195. rasa/dialogue_understanding/generator/llm_based_command_generator.py +50 -12
  196. rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
  197. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +1 -1
  198. rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +66 -0
  199. rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +66 -0
  200. rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v3_claude_3_5_sonnet_20240620_template.jinja2 +89 -0
  201. rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v3_gpt_4o_2024_11_20_template.jinja2 +88 -0
  202. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +42 -7
  203. rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +40 -3
  204. rasa/dialogue_understanding/generator/single_step/single_step_based_llm_command_generator.py +20 -3
  205. rasa/dialogue_understanding/patterns/cancel.py +27 -6
  206. rasa/dialogue_understanding/patterns/clarify.py +3 -14
  207. rasa/dialogue_understanding/patterns/continue_interrupted.py +239 -6
  208. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +46 -8
  209. rasa/dialogue_understanding/processor/command_processor.py +136 -15
  210. rasa/dialogue_understanding/stack/dialogue_stack.py +98 -2
  211. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +57 -0
  212. rasa/dialogue_understanding/stack/utils.py +57 -3
  213. rasa/dialogue_understanding/utils.py +24 -4
  214. rasa/dialogue_understanding_test/du_test_runner.py +8 -3
  215. rasa/e2e_test/e2e_test_runner.py +13 -3
  216. rasa/engine/caching.py +2 -2
  217. rasa/engine/constants.py +1 -1
  218. rasa/engine/loader.py +12 -0
  219. rasa/engine/recipes/default_components.py +138 -49
  220. rasa/engine/recipes/default_recipe.py +108 -11
  221. rasa/engine/runner/dask.py +8 -5
  222. rasa/engine/validation.py +19 -6
  223. rasa/graph_components/validators/default_recipe_validator.py +86 -28
  224. rasa/hooks.py +5 -5
  225. rasa/llm_fine_tuning/utils.py +2 -2
  226. rasa/model_training.py +60 -47
  227. rasa/nlu/classifiers/diet_classifier.py +198 -98
  228. rasa/nlu/classifiers/logistic_regression_classifier.py +1 -4
  229. rasa/nlu/classifiers/mitie_intent_classifier.py +3 -0
  230. rasa/nlu/classifiers/sklearn_intent_classifier.py +1 -3
  231. rasa/nlu/extractors/crf_entity_extractor.py +9 -10
  232. rasa/nlu/extractors/mitie_entity_extractor.py +3 -0
  233. rasa/nlu/extractors/spacy_entity_extractor.py +3 -0
  234. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +4 -0
  235. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +5 -0
  236. rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +2 -0
  237. rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +3 -0
  238. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +4 -2
  239. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +4 -0
  240. rasa/nlu/selectors/response_selector.py +10 -2
  241. rasa/nlu/tokenizers/jieba_tokenizer.py +3 -4
  242. rasa/nlu/tokenizers/mitie_tokenizer.py +3 -2
  243. rasa/nlu/tokenizers/spacy_tokenizer.py +3 -2
  244. rasa/nlu/utils/mitie_utils.py +3 -0
  245. rasa/nlu/utils/spacy_utils.py +3 -2
  246. rasa/plugin.py +8 -8
  247. rasa/privacy/privacy_manager.py +12 -3
  248. rasa/server.py +15 -3
  249. rasa/shared/agents/__init__.py +0 -0
  250. rasa/shared/agents/auth/__init__.py +0 -0
  251. rasa/shared/agents/auth/agent_auth_factory.py +105 -0
  252. rasa/shared/agents/auth/agent_auth_manager.py +92 -0
  253. rasa/shared/agents/auth/auth_strategy/__init__.py +19 -0
  254. rasa/shared/agents/auth/auth_strategy/agent_auth_strategy.py +52 -0
  255. rasa/shared/agents/auth/auth_strategy/api_key_auth_strategy.py +42 -0
  256. rasa/shared/agents/auth/auth_strategy/bearer_token_auth_strategy.py +28 -0
  257. rasa/shared/agents/auth/auth_strategy/oauth2_auth_strategy.py +167 -0
  258. rasa/shared/agents/auth/constants.py +12 -0
  259. rasa/shared/agents/auth/types.py +12 -0
  260. rasa/shared/agents/utils.py +35 -0
  261. rasa/shared/constants.py +8 -0
  262. rasa/shared/core/constants.py +16 -1
  263. rasa/shared/core/domain.py +0 -7
  264. rasa/shared/core/events.py +327 -0
  265. rasa/shared/core/flows/constants.py +5 -0
  266. rasa/shared/core/flows/flow.py +1 -1
  267. rasa/shared/core/flows/flows_list.py +21 -5
  268. rasa/shared/core/flows/flows_yaml_schema.json +119 -184
  269. rasa/shared/core/flows/steps/call.py +49 -5
  270. rasa/shared/core/flows/steps/collect.py +98 -13
  271. rasa/shared/core/flows/validation.py +372 -8
  272. rasa/shared/core/flows/yaml_flows_io.py +3 -2
  273. rasa/shared/core/slots.py +2 -2
  274. rasa/shared/core/trackers.py +5 -2
  275. rasa/shared/exceptions.py +16 -0
  276. rasa/shared/importers/rasa.py +1 -1
  277. rasa/shared/importers/utils.py +9 -3
  278. rasa/shared/providers/llm/_base_litellm_client.py +41 -9
  279. rasa/shared/providers/llm/litellm_router_llm_client.py +8 -4
  280. rasa/shared/providers/llm/llm_client.py +7 -3
  281. rasa/shared/providers/llm/llm_response.py +66 -0
  282. rasa/shared/providers/llm/self_hosted_llm_client.py +8 -4
  283. rasa/shared/utils/common.py +24 -0
  284. rasa/shared/utils/health_check/health_check.py +7 -3
  285. rasa/shared/utils/llm.py +39 -16
  286. rasa/shared/utils/mcp/__init__.py +0 -0
  287. rasa/shared/utils/mcp/server_connection.py +247 -0
  288. rasa/shared/utils/mcp/utils.py +20 -0
  289. rasa/shared/utils/schemas/events.py +42 -0
  290. rasa/shared/utils/yaml.py +3 -1
  291. rasa/studio/pull/pull.py +3 -2
  292. rasa/studio/train.py +8 -7
  293. rasa/studio/upload.py +3 -6
  294. rasa/telemetry.py +69 -5
  295. rasa/tracing/config.py +45 -12
  296. rasa/tracing/constants.py +14 -0
  297. rasa/tracing/instrumentation/attribute_extractors.py +142 -9
  298. rasa/tracing/instrumentation/instrumentation.py +626 -21
  299. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +4 -4
  300. rasa/tracing/instrumentation/metrics.py +32 -0
  301. rasa/tracing/metric_instrument_provider.py +68 -0
  302. rasa/utils/common.py +92 -1
  303. rasa/utils/endpoints.py +11 -2
  304. rasa/utils/log_utils.py +96 -5
  305. rasa/utils/ml_utils.py +1 -1
  306. rasa/utils/pypred.py +38 -0
  307. rasa/utils/tensorflow/__init__.py +7 -0
  308. rasa/utils/tensorflow/callback.py +136 -101
  309. rasa/utils/tensorflow/crf.py +1 -1
  310. rasa/utils/tensorflow/data_generator.py +21 -8
  311. rasa/utils/tensorflow/layers.py +21 -11
  312. rasa/utils/tensorflow/metrics.py +7 -3
  313. rasa/utils/tensorflow/models.py +56 -8
  314. rasa/utils/tensorflow/rasa_layers.py +8 -6
  315. rasa/utils/tensorflow/transformer.py +2 -3
  316. rasa/utils/train_utils.py +54 -24
  317. rasa/validator.py +17 -13
  318. rasa/version.py +1 -1
  319. {rasa_pro-3.14.0a20.dist-info → rasa_pro-3.14.0a23.dist-info}/METADATA +48 -42
  320. {rasa_pro-3.14.0a20.dist-info → rasa_pro-3.14.0a23.dist-info}/RECORD +323 -251
  321. rasa/builder/scrape_rasa_docs.py +0 -97
  322. rasa/cli/project_templates/finance/data/general/agent_details.yml +0 -6
  323. rasa/cli/project_templates/finance/domain/_system/patterns/pattern_session_start.yml +0 -11
  324. rasa/cli/project_templates/finance/domain/general/agent_details.yml +0 -31
  325. rasa/core/channels/inspector/dist/assets/channel-8e08bed9.js +0 -1
  326. rasa/core/channels/inspector/dist/assets/clone-78c82dea.js +0 -1
  327. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-2b08f601.js +0 -1
  328. rasa/core/channels/inspector/dist/assets/index-c941dcb3.js +0 -1336
  329. {rasa_pro-3.14.0a20.dist-info → rasa_pro-3.14.0a23.dist-info}/NOTICE +0 -0
  330. {rasa_pro-3.14.0a20.dist-info → rasa_pro-3.14.0a23.dist-info}/WHEEL +0 -0
  331. {rasa_pro-3.14.0a20.dist-info → rasa_pro-3.14.0a23.dist-info}/entry_points.txt +0 -0
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text, Tuple
5
5
 
6
6
  import structlog
7
7
  import tiktoken
8
- from deprecated import deprecated # type: ignore[import]
8
+ from deprecated import deprecated # type: ignore[import-untyped]
9
9
  from jinja2 import Template
10
10
  from langchain.docstore.document import Document
11
11
  from langchain.schema.embeddings import Embeddings
@@ -58,6 +58,7 @@ from rasa.shared.nlu.training_data.features import (
58
58
  save_features,
59
59
  )
60
60
  from rasa.shared.nlu.training_data.message import Message
61
+ from rasa.shared.utils.io import raise_deprecation_warning
61
62
  from rasa.utils import train_utils
62
63
  from rasa.utils.tensorflow import rasa_layers
63
64
  from rasa.utils.tensorflow.constants import (
@@ -80,7 +81,6 @@ from rasa.utils.tensorflow.constants import (
80
81
  EMBEDDING_DIMENSION,
81
82
  ENCODING_DIMENSION,
82
83
  ENTITY_RECOGNITION,
83
- EPOCH_OVERRIDE,
84
84
  EPOCHS,
85
85
  EVAL_NUM_EPOCHS,
86
86
  EVAL_NUM_EXAMPLES,
@@ -363,6 +363,9 @@ class TEDPolicy(Policy):
363
363
  entity_tag_specs: Optional[List[EntityTagSpec]] = None,
364
364
  ) -> None:
365
365
  """Declares instance variables with default values."""
366
+ raise_deprecation_warning(
367
+ "TEDPolicy is deprecated and will be removed in a future version."
368
+ )
366
369
  super().__init__(
367
370
  config, model_storage, resource, execution_context, featurizer=featurizer
368
371
  )
@@ -668,6 +671,7 @@ class TEDPolicy(Policy):
668
671
  self.model.compile(
669
672
  optimizer=tf.keras.optimizers.Adam(self.config[LEARNING_RATE])
670
673
  )
674
+
671
675
  (
672
676
  data_generator,
673
677
  validation_data_generator,
@@ -943,14 +947,16 @@ class TEDPolicy(Policy):
943
947
 
944
948
  with self._model_storage.write_to(self._resource) as model_path:
945
949
  model_filename = self._metadata_filename()
946
- tf_model_file = model_path / f"{model_filename}.tf_model"
950
+ tf_model_file = model_path / f"{model_filename}.weights.h5"
947
951
 
948
952
  rasa.shared.utils.io.create_directory_for_file(tf_model_file)
949
953
 
950
954
  self.featurizer.persist(model_path)
951
955
 
952
956
  if self.config[CHECKPOINT_MODEL] and self.tmp_checkpoint_dir:
953
- self.model.load_weights(self.tmp_checkpoint_dir / "checkpoint.tf_model")
957
+ self.model.load_weights(
958
+ self.tmp_checkpoint_dir / "checkpoint.weights.h5"
959
+ )
954
960
  # Save an empty file to flag that this model has been
955
961
  # produced using checkpointing
956
962
  checkpoint_marker = model_path / f"{model_filename}.from_checkpoint.pkl"
@@ -1009,7 +1015,7 @@ class TEDPolicy(Policy):
1009
1015
  Args:
1010
1016
  model_path: Path where model is to be persisted.
1011
1017
  """
1012
- tf_model_file = model_path / f"{cls._metadata_filename()}.tf_model"
1018
+ tf_model_file = model_path / f"{cls._metadata_filename()}.weights.h5"
1013
1019
 
1014
1020
  # load data example
1015
1021
  loaded_data = deserialize_nested_feature_arrays(
@@ -1109,8 +1115,6 @@ class TEDPolicy(Policy):
1109
1115
  model_utilities = cls._load_model_utilities(model_path)
1110
1116
 
1111
1117
  config = cls._update_loaded_params(config)
1112
- if execution_context.is_finetuning and EPOCH_OVERRIDE in config:
1113
- config[EPOCHS] = config.get(EPOCH_OVERRIDE)
1114
1118
 
1115
1119
  (
1116
1120
  model_data_example,
@@ -1125,7 +1129,6 @@ class TEDPolicy(Policy):
1125
1129
  model_data_example,
1126
1130
  predict_data_example,
1127
1131
  featurizer,
1128
- execution_context.is_finetuning,
1129
1132
  )
1130
1133
 
1131
1134
  return cls._load_policy_with_model(
@@ -1167,7 +1170,6 @@ class TEDPolicy(Policy):
1167
1170
  model_data_example: RasaModelData,
1168
1171
  predict_data_example: RasaModelData,
1169
1172
  featurizer: TrackerFeaturizer,
1170
- should_finetune: bool,
1171
1173
  ) -> TED:
1172
1174
  model = cls.model_class().load(
1173
1175
  str(model_utilities["tf_model_file"]),
@@ -1180,7 +1182,9 @@ class TEDPolicy(Policy):
1180
1182
  ),
1181
1183
  label_data=model_utilities["label_data"],
1182
1184
  entity_tag_specs=model_utilities["entity_tag_specs"],
1183
- finetune_mode=should_finetune,
1185
+ # This feature is no longer supported as the updated version
1186
+ # of Keras does not allow updating a compiled model anymore.
1187
+ finetune_mode=False,
1184
1188
  )
1185
1189
  return model
1186
1190
 
@@ -1463,7 +1467,7 @@ class TED(TransformerRasaModel):
1463
1467
 
1464
1468
  dialogue_transformed, attention_weights = self._tf_layers[
1465
1469
  f"transformer.{DIALOGUE}"
1466
- ](dialogue_in, 1 - mask, self._training)
1470
+ ](dialogue_in, 1 - mask, training=self._training)
1467
1471
  dialogue_transformed = tf.nn.gelu(dialogue_transformed)
1468
1472
 
1469
1473
  if self.max_history_featurizer_is_used:
@@ -1708,7 +1712,7 @@ class TED(TransformerRasaModel):
1708
1712
 
1709
1713
  if attribute in SENTENCE_FEATURES_TO_ENCODE + LABEL_FEATURES_TO_ENCODE:
1710
1714
  attribute_features = self._tf_layers[f"encoding_layer.{attribute}"](
1711
- attribute_features, self._training
1715
+ attribute_features, training=self._training
1712
1716
  )
1713
1717
 
1714
1718
  # attribute features have shape
@@ -2102,7 +2106,11 @@ class TED(TransformerRasaModel):
2102
2106
  predictions = {
2103
2107
  "scores": scores,
2104
2108
  "similarities": sim_all,
2105
- DIAGNOSTIC_DATA: {"attention_weights": attention_weights},
2109
+ DIAGNOSTIC_DATA: {
2110
+ "attention_weights": attention_weights.numpy()
2111
+ if attention_weights is not None and hasattr(attention_weights, "numpy")
2112
+ else attention_weights,
2113
+ },
2106
2114
  }
2107
2115
 
2108
2116
  if (
@@ -54,6 +54,7 @@ from rasa.shared.nlu.constants import (
54
54
  )
55
55
  from rasa.shared.nlu.training_data.features import Features
56
56
  from rasa.shared.utils import common
57
+ from rasa.shared.utils.io import raise_deprecation_warning
57
58
  from rasa.utils import train_utils
58
59
  from rasa.utils.tensorflow import layers
59
60
  from rasa.utils.tensorflow.constants import (
@@ -300,6 +301,10 @@ class UnexpecTEDIntentPolicy(TEDPolicy):
300
301
  label_quantiles: Optional[Dict[int, List[float]]] = None,
301
302
  ):
302
303
  """Declares instance variables with default values."""
304
+ raise_deprecation_warning(
305
+ "UnexpecTEDIntentPolicy is deprecated and "
306
+ "will be removed in a future version."
307
+ )
303
308
  # Set all invalid / non configurable parameters
304
309
  config[ENTITY_RECOGNITION] = False
305
310
  config[BILOU_FLAG] = False
@@ -624,6 +629,7 @@ class UnexpecTEDIntentPolicy(TEDPolicy):
624
629
  query_intent = (
625
630
  last_user_uttered_event.intent_name
626
631
  if last_user_uttered_event is not None
632
+ and isinstance(last_user_uttered_event, UserUttered)
627
633
  else ""
628
634
  )
629
635
  is_unlikely_intent = self._check_unlikely_intent(
rasa/core/processor.py CHANGED
@@ -1,6 +1,5 @@
1
1
  import copy
2
2
  import inspect
3
- import logging
4
3
  import os
5
4
  import re
6
5
  import tarfile
@@ -69,6 +68,7 @@ from rasa.shared.constants import (
69
68
  UTTER_PREFIX,
70
69
  )
71
70
  from rasa.shared.core.constants import (
71
+ ACTION_AGENT_REQUEST_USER_INPUT_NAME,
72
72
  ACTION_CORRECT_FLOW_SLOT,
73
73
  ACTION_EXTRACT_SLOTS,
74
74
  ACTION_LISTEN_NAME,
@@ -113,10 +113,9 @@ from rasa.utils.common import TempDirectoryPath, get_temp_dir_name
113
113
  from rasa.utils.endpoints import EndpointConfig
114
114
 
115
115
  if TYPE_CHECKING:
116
- from rasa.core.available_endpoints import AvailableEndpoints
116
+ from rasa.core.config.available_endpoints import AvailableEndpoints
117
117
  from rasa.privacy.privacy_manager import BackgroundPrivacyManager
118
118
 
119
- logger = logging.getLogger(__name__)
120
119
  structlogger = structlog.get_logger()
121
120
 
122
121
  MAX_NUMBER_OF_PREDICTIONS = int(os.environ.get("MAX_NUMBER_OF_PREDICTIONS", "10"))
@@ -190,7 +189,11 @@ class MessageProcessor:
190
189
  except TypeError:
191
190
  raise ModelNotFound(f"Model {model_path} can not be loaded.")
192
191
 
193
- logger.info(f"Loading model {model_tar}...")
192
+ structlogger.info(
193
+ "rasa.core.processor.load_model",
194
+ event_info="Loading model.",
195
+ model_path=model_tar,
196
+ )
194
197
  with TempDirectoryPath(get_temp_dir_name()) as temporary_directory:
195
198
  try:
196
199
  metadata, runner = loader.load_predict_graph_runner(
@@ -365,8 +368,10 @@ class MessageProcessor:
365
368
  `ActionSessionStart`.
366
369
  """
367
370
  if not tracker.applied_events() or self._has_session_expired(tracker):
368
- logger.debug(
369
- f"Starting a new session for conversation ID '{tracker.sender_id}'."
371
+ structlogger.debug(
372
+ "rasa.core.processor._update_tracker_session",
373
+ event_info="Starting a new session.",
374
+ sender_id=tracker.sender_id,
370
375
  )
371
376
 
372
377
  action_session_start = self._get_action(ACTION_SESSION_START_NAME)
@@ -598,9 +603,11 @@ class MessageProcessor:
598
603
  prediction.max_confidence_index, self.domain, self.action_endpoint
599
604
  )
600
605
 
601
- logger.debug(
602
- f"Predicted next action '{action.name()}' with confidence "
603
- f"{prediction.max_confidence:.2f}."
606
+ structlogger.debug(
607
+ "rasa.core.processor.predict_next_with_tracker_if_should",
608
+ event_info="Predicted next action.",
609
+ action=action.name(),
610
+ confidence=prediction.max_confidence,
604
611
  )
605
612
 
606
613
  return action, prediction
@@ -650,8 +657,10 @@ class MessageProcessor:
650
657
  and self._has_message_after_reminder(tracker, reminder_event)
651
658
  or not self._is_reminder_still_valid(tracker, reminder_event)
652
659
  ):
653
- logger.debug(
654
- f"Canceled reminder because it is outdated ({reminder_event})."
660
+ structlogger.debug(
661
+ "rasa.core.processor.handle_reminder",
662
+ event_info="Canceled reminder because it is outdated.",
663
+ reminder_event=reminder_event,
655
664
  )
656
665
  else:
657
666
  intent = reminder_event.intent
@@ -731,7 +740,7 @@ class MessageProcessor:
731
740
  if not self.domain or self.domain.is_empty():
732
741
  return
733
742
 
734
- intent = parse_data["intent"][INTENT_NAME_KEY]
743
+ intent = parse_data[INTENT][INTENT_NAME_KEY]
735
744
  if intent and intent not in self.domain.intents:
736
745
  rasa.shared.utils.io.raise_warning(
737
746
  f"Parsed an intent '{intent}' "
@@ -740,7 +749,7 @@ class MessageProcessor:
740
749
  docs=DOCS_URL_DOMAINS,
741
750
  )
742
751
 
743
- entities = parse_data["entities"] or []
752
+ entities = parse_data[ENTITIES] or []
744
753
  for element in entities:
745
754
  entity = element["entity"]
746
755
  if entity and entity not in self.domain.entities:
@@ -824,9 +833,9 @@ class MessageProcessor:
824
833
  self._update_full_retrieval_intent(parse_data)
825
834
  structlogger.debug(
826
835
  "processor.message.parse",
827
- parse_data_text=copy.deepcopy(parse_data["text"]),
828
- parse_data_intent=parse_data["intent"],
829
- parse_data_entities=copy.deepcopy(parse_data["entities"]),
836
+ parse_data_text=copy.deepcopy(parse_data[TEXT]),
837
+ parse_data_intent=parse_data[INTENT],
838
+ parse_data_entities=copy.deepcopy(parse_data[ENTITIES]),
830
839
  )
831
840
 
832
841
  self._check_for_unseen_features(parse_data)
@@ -975,7 +984,7 @@ class MessageProcessor:
975
984
  f"invalid intent: {parse_data[INTENT]['name']}. "
976
985
  f"Returning CannotHandleCommand() as a fallback."
977
986
  ),
978
- invalid_intent=parse_data[INTENT]["name"],
987
+ invalid_intent=parse_data[INTENT][INTENT_NAME_KEY],
979
988
  )
980
989
  commands.append(
981
990
  CannotHandleCommand(RASA_PATTERN_CANNOT_HANDLE_INVALID_INTENT)
@@ -985,7 +994,7 @@ class MessageProcessor:
985
994
 
986
995
  def _contains_undefined_intent(self, message: Message) -> bool:
987
996
  """Checks if the message contains an undefined intent."""
988
- intent_name = message.get(INTENT, {}).get("name")
997
+ intent_name = message.get(INTENT, {}).get(INTENT_NAME_KEY)
989
998
  return intent_name is not None and intent_name not in self.domain.intents
990
999
 
991
1000
  async def _parse_message_with_graph(
@@ -1035,8 +1044,8 @@ class MessageProcessor:
1035
1044
  tracker.update(
1036
1045
  UserUttered(
1037
1046
  message.text,
1038
- parse_data["intent"],
1039
- parse_data["entities"],
1047
+ parse_data[INTENT],
1048
+ parse_data[ENTITIES],
1040
1049
  parse_data,
1041
1050
  input_channel=message.input_channel,
1042
1051
  message_id=message.message_id,
@@ -1045,13 +1054,16 @@ class MessageProcessor:
1045
1054
  self.domain,
1046
1055
  )
1047
1056
 
1048
- if parse_data["entities"]:
1057
+ if parse_data[ENTITIES]:
1049
1058
  self._log_slots(tracker)
1050
1059
 
1051
1060
  plugin_manager().hook.after_new_user_message(tracker=tracker)
1052
1061
 
1053
- logger.debug(
1054
- f"Logged UserUtterance - tracker now has {len(tracker.events)} events."
1062
+ structlogger.debug(
1063
+ "rasa.core.processor.handle_message_with_tracker",
1064
+ event_info="Logged UserUtterance.",
1065
+ user_message=message.text,
1066
+ number_of_events=len(tracker.events),
1055
1067
  )
1056
1068
 
1057
1069
  @staticmethod
@@ -1166,9 +1178,11 @@ class MessageProcessor:
1166
1178
  tracker
1167
1179
  )
1168
1180
  except ActionLimitReached:
1169
- logger.warning(
1170
- "Circuit breaker tripped. Stopped predicting "
1171
- f"more actions for sender '{tracker.sender_id}'."
1181
+ structlogger.warning(
1182
+ "rasa.core.processor.run_prediction_loop",
1183
+ event_info="Circuit breaker tripped. Stopped predicting more "
1184
+ "actions.",
1185
+ sender_id=tracker.sender_id,
1172
1186
  )
1173
1187
  if self.on_circuit_break:
1174
1188
  # call a registered callback
@@ -1176,9 +1190,11 @@ class MessageProcessor:
1176
1190
  break
1177
1191
 
1178
1192
  if prediction.is_end_to_end_prediction:
1179
- logger.debug(
1180
- f"An end-to-end prediction was made which has triggered the 2nd "
1181
- f"execution of the default action '{ACTION_EXTRACT_SLOTS}'."
1193
+ structlogger.debug(
1194
+ "rasa.core.processor.run_prediction_loop",
1195
+ event_info="An end-to-end prediction was made which has "
1196
+ "triggered the 2nd execution of the default action.",
1197
+ action=ACTION_EXTRACT_SLOTS,
1182
1198
  )
1183
1199
  tracker = await self.run_action_extract_slots(output_channel, tracker)
1184
1200
 
@@ -1197,7 +1213,11 @@ class MessageProcessor:
1197
1213
  `False` if `action_name` is `ACTION_LISTEN_NAME` or
1198
1214
  `ACTION_SESSION_START_NAME`, otherwise `True`.
1199
1215
  """
1200
- return action_name not in (ACTION_LISTEN_NAME, ACTION_SESSION_START_NAME)
1216
+ return action_name not in (
1217
+ ACTION_LISTEN_NAME,
1218
+ ACTION_SESSION_START_NAME,
1219
+ ACTION_AGENT_REQUEST_USER_INPUT_NAME,
1220
+ )
1201
1221
 
1202
1222
  async def execute_side_effects(
1203
1223
  self,
@@ -1390,10 +1410,11 @@ class MessageProcessor:
1390
1410
  )
1391
1411
 
1392
1412
  if any(isinstance(e, UserUttered) for e in events):
1393
- logger.debug(
1394
- f"A `UserUttered` event was returned by executing "
1413
+ structlogger.debug(
1414
+ "rasa.core.processor.run_action",
1415
+ message="A `UserUttered` event was returned by executing "
1395
1416
  f"action '{action.name()}'. This will run the default action "
1396
- f"'{ACTION_EXTRACT_SLOTS}'."
1417
+ f"'{ACTION_EXTRACT_SLOTS}'.",
1397
1418
  )
1398
1419
  tracker = await self.run_action_extract_slots(output_channel, tracker)
1399
1420
 
@@ -1499,11 +1520,9 @@ class MessageProcessor:
1499
1520
  # tracker has never expired if sessions are disabled
1500
1521
  return False
1501
1522
 
1502
- user_uttered_event: Optional[UserUttered] = tracker.get_last_event_for(
1503
- UserUttered
1504
- )
1523
+ user_uttered_event = tracker.get_last_event_for(UserUttered)
1505
1524
 
1506
- if not user_uttered_event:
1525
+ if not user_uttered_event or not isinstance(user_uttered_event, UserUttered):
1507
1526
  # there is no user event so far so the session should not be considered
1508
1527
  # expired
1509
1528
  return False
@@ -1514,9 +1533,10 @@ class MessageProcessor:
1514
1533
  > self.domain.session_config.session_expiration_time
1515
1534
  )
1516
1535
  if has_expired:
1517
- logger.debug(
1518
- f"The latest session for conversation ID '{tracker.sender_id}' has "
1519
- f"expired."
1536
+ structlogger.debug(
1537
+ "rasa.core.processor.has_session_expired",
1538
+ event_info="The latest session has expired.",
1539
+ sender_id=tracker.sender_id,
1520
1540
  )
1521
1541
 
1522
1542
  return has_expired
@@ -1542,10 +1562,14 @@ class MessageProcessor:
1542
1562
  )
1543
1563
  return prediction
1544
1564
 
1545
- logger.error(
1546
- f"Trying to run unknown follow-up action '{followup_action}'. "
1547
- "Instead of running that, Rasa Pro will ignore the action "
1548
- "and predict the next action."
1565
+ structlogger.error(
1566
+ "rasa.core.processor.predict_next_with_tracker",
1567
+ event_info="Trying to run unknown follow-up action.",
1568
+ message=(
1569
+ "Trying to run unknown follow-up action. Instead of running "
1570
+ "that, Rasa Pro will ignore the action and predict the next action."
1571
+ ),
1572
+ followup_action=followup_action,
1549
1573
  )
1550
1574
 
1551
1575
  target = self.model_metadata.core_target
@@ -1,10 +1,18 @@
1
+ import os
1
2
  from enum import Enum
2
3
  from typing import Any, Dict, List, Optional, Text, Tuple, Union
3
4
 
4
5
  import redis
5
6
  import structlog
6
- from pydantic import BaseModel
7
-
7
+ from pydantic import BaseModel, ConfigDict
8
+
9
+ from rasa.core.constants import AWS_ELASTICACHE_CLUSTER_NAME_ENV_VAR_NAME
10
+ from rasa.core.iam_credentials_providers.credentials_provider_protocol import (
11
+ IAMCredentialsProvider,
12
+ IAMCredentialsProviderInput,
13
+ SupportedServiceType,
14
+ create_iam_credentials_provider,
15
+ )
8
16
  from rasa.shared.exceptions import ConnectionException, RasaException
9
17
 
10
18
  structlogger = structlog.getLogger(__name__)
@@ -23,6 +31,8 @@ class DeploymentMode(Enum):
23
31
  class StandardRedisConfig(BaseModel):
24
32
  """Base configuration for Redis connections."""
25
33
 
34
+ model_config = ConfigDict(arbitrary_types_allowed=True)
35
+
26
36
  host: Text = "localhost"
27
37
  port: int = 6379
28
38
  username: Optional[Text] = None
@@ -34,6 +44,7 @@ class StandardRedisConfig(BaseModel):
34
44
  db: int = 0
35
45
  socket_timeout: float = DEFAULT_SOCKET_TIMEOUT_IN_SECONDS
36
46
  decode_responses: bool = False
47
+ iam_credentials_provider: Optional[IAMCredentialsProvider] = None
37
48
 
38
49
 
39
50
  class ClusterRedisConfig(StandardRedisConfig):
@@ -104,6 +115,14 @@ class RedisConnectionFactory:
104
115
  deployment_mode_enum, config.endpoints, config.host, config.port
105
116
  )
106
117
 
118
+ iam_credentials_provider = create_iam_credentials_provider(
119
+ IAMCredentialsProviderInput(
120
+ service_name=SupportedServiceType.LOCK_STORE,
121
+ username=config.username,
122
+ cluster_name=os.getenv(AWS_ELASTICACHE_CLUSTER_NAME_ENV_VAR_NAME),
123
+ )
124
+ )
125
+
107
126
  if deployment_mode_enum == DeploymentMode.CLUSTER:
108
127
  cls._log_cluster_db_warning(deployment_mode_enum, config.db)
109
128
  cluster_config = ClusterRedisConfig(
@@ -117,6 +136,7 @@ class RedisConnectionFactory:
117
136
  socket_timeout=config.socket_timeout,
118
137
  decode_responses=config.decode_responses,
119
138
  endpoints=parsed_endpoints,
139
+ iam_credentials_provider=iam_credentials_provider,
120
140
  )
121
141
  return cls._create_cluster_connection(cluster_config)
122
142
  elif deployment_mode_enum == DeploymentMode.SENTINEL:
@@ -131,6 +151,7 @@ class RedisConnectionFactory:
131
151
  "socket_timeout": config.socket_timeout,
132
152
  "decode_responses": config.decode_responses,
133
153
  "endpoints": parsed_endpoints,
154
+ "iam_credentials_provider": iam_credentials_provider,
134
155
  }
135
156
 
136
157
  if config.sentinel_service is not None:
@@ -151,6 +172,7 @@ class RedisConnectionFactory:
151
172
  db=config.db,
152
173
  socket_timeout=config.socket_timeout,
153
174
  decode_responses=config.decode_responses,
175
+ iam_credentials_provider=iam_credentials_provider,
154
176
  )
155
177
  return cls._create_standard_connection(standard_config)
156
178
 
@@ -279,18 +301,31 @@ class RedisConnectionFactory:
279
301
  )
280
302
 
281
303
  cluster_nodes = [ClusterNode(host, port) for host, port in config.endpoints]
304
+
305
+ common_config_kwargs = {
306
+ "startup_nodes": cluster_nodes,
307
+ "ssl": config.use_ssl,
308
+ "ssl_certfile": config.ssl_certfile,
309
+ "ssl_keyfile": config.ssl_keyfile,
310
+ "ssl_ca_certs": config.ssl_ca_certs,
311
+ "socket_timeout": config.socket_timeout,
312
+ "decode_responses": config.decode_responses,
313
+ }
314
+
282
315
  try:
283
- redis_cluster: redis.RedisCluster = redis.RedisCluster(
284
- startup_nodes=cluster_nodes,
285
- username=config.username,
286
- password=config.password,
287
- ssl=config.use_ssl,
288
- ssl_certfile=config.ssl_certfile,
289
- ssl_keyfile=config.ssl_keyfile,
290
- ssl_ca_certs=config.ssl_ca_certs,
291
- socket_timeout=config.socket_timeout,
292
- decode_responses=config.decode_responses,
293
- )
316
+ if config.iam_credentials_provider is not None:
317
+ structlogger.debug("redis_connection_factory.cluster_iam_auth_enabled")
318
+
319
+ redis_cluster: redis.RedisCluster = redis.RedisCluster(
320
+ credential_provider=config.iam_credentials_provider,
321
+ **common_config_kwargs,
322
+ )
323
+ else:
324
+ redis_cluster = redis.RedisCluster(
325
+ username=config.username,
326
+ password=config.password,
327
+ **common_config_kwargs,
328
+ )
294
329
  except Exception as e:
295
330
  raise ConnectionException(f"Error initializing Redis Cluster: {e}")
296
331
 
@@ -324,15 +359,25 @@ class RedisConnectionFactory:
324
359
  )
325
360
 
326
361
  # Configuration for Sentinel connection
327
- sentinel_kwargs = {
328
- "username": config.username,
329
- "password": config.password,
362
+ connection_kwargs: Dict[str, Any] = {
330
363
  "socket_timeout": config.socket_timeout,
331
364
  }
332
365
 
366
+ sentinel_kwargs: Optional[Dict] = None
367
+ if config.iam_credentials_provider is not None:
368
+ structlogger.debug("redis_connection_factory.sentinel_iam_auth_enabled")
369
+ sentinel_kwargs = {"credential_provider": config.iam_credentials_provider}
370
+ else:
371
+ connection_kwargs.update(
372
+ {
373
+ "username": config.username,
374
+ "password": config.password,
375
+ }
376
+ )
377
+
333
378
  # SSL configuration
334
379
  if config.use_ssl:
335
- sentinel_kwargs.update(
380
+ connection_kwargs.update(
336
381
  {
337
382
  "ssl": config.use_ssl,
338
383
  "ssl_certfile": config.ssl_certfile,
@@ -350,7 +395,9 @@ class RedisConnectionFactory:
350
395
 
351
396
  # Create Sentinel instance
352
397
  try:
353
- sentinel = Sentinel(config.endpoints, **sentinel_kwargs)
398
+ sentinel = Sentinel(
399
+ config.endpoints, sentinel_kwargs=sentinel_kwargs, **connection_kwargs
400
+ )
354
401
  master = sentinel.master_for(config.sentinel_service, **client_kwargs)
355
402
 
356
403
  # Test the connection
@@ -383,16 +430,27 @@ class RedisConnectionFactory:
383
430
  "host": config.host,
384
431
  "port": int(config.port),
385
432
  "db": config.db,
386
- "password": config.password,
387
433
  "socket_timeout": float(config.socket_timeout),
388
434
  "ssl": config.use_ssl,
389
435
  "ssl_certfile": config.ssl_certfile,
390
436
  "ssl_keyfile": config.ssl_keyfile,
391
437
  "ssl_ca_certs": config.ssl_ca_certs,
392
- "username": config.username,
393
438
  "decode_responses": config.decode_responses,
394
439
  }
395
440
 
441
+ if config.iam_credentials_provider is not None:
442
+ structlogger.debug("redis_connection_factory.standard_iam_auth_enabled")
443
+ connection_args.update(
444
+ {"credential_provider": config.iam_credentials_provider}
445
+ )
446
+ else:
447
+ connection_args.update(
448
+ {
449
+ "password": config.password,
450
+ "username": config.username,
451
+ }
452
+ )
453
+
396
454
  try:
397
455
  standard_redis = redis.StrictRedis(**connection_args)
398
456
  except Exception as e: