rasa-pro 3.13.7__py3-none-any.whl → 3.14.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (178) hide show
  1. rasa/agents/__init__.py +0 -0
  2. rasa/agents/agent_factory.py +122 -0
  3. rasa/agents/agent_manager.py +162 -0
  4. rasa/agents/constants.py +31 -0
  5. rasa/agents/core/__init__.py +0 -0
  6. rasa/agents/core/agent_protocol.py +108 -0
  7. rasa/agents/core/types.py +70 -0
  8. rasa/agents/exceptions.py +8 -0
  9. rasa/agents/protocol/__init__.py +5 -0
  10. rasa/agents/protocol/a2a/__init__.py +0 -0
  11. rasa/agents/protocol/a2a/a2a_agent.py +51 -0
  12. rasa/agents/protocol/mcp/__init__.py +0 -0
  13. rasa/agents/protocol/mcp/mcp_base_agent.py +697 -0
  14. rasa/agents/protocol/mcp/mcp_open_agent.py +275 -0
  15. rasa/agents/protocol/mcp/mcp_task_agent.py +447 -0
  16. rasa/agents/schemas/__init__.py +6 -0
  17. rasa/agents/schemas/agent_input.py +24 -0
  18. rasa/agents/schemas/agent_output.py +26 -0
  19. rasa/agents/schemas/agent_tool_result.py +51 -0
  20. rasa/agents/schemas/agent_tool_schema.py +112 -0
  21. rasa/agents/templates/__init__.py +0 -0
  22. rasa/agents/templates/mcp_open_agent_prompt_template.jinja2 +15 -0
  23. rasa/agents/templates/mcp_task_agent_prompt_template.jinja2 +13 -0
  24. rasa/agents/utils.py +72 -0
  25. rasa/api.py +5 -0
  26. rasa/cli/arguments/default_arguments.py +12 -0
  27. rasa/cli/arguments/run.py +2 -0
  28. rasa/cli/dialogue_understanding_test.py +4 -0
  29. rasa/cli/e2e_test.py +4 -0
  30. rasa/cli/inspect.py +3 -0
  31. rasa/cli/llm_fine_tuning.py +5 -0
  32. rasa/cli/run.py +4 -0
  33. rasa/cli/shell.py +3 -0
  34. rasa/cli/train.py +2 -2
  35. rasa/constants.py +6 -0
  36. rasa/core/actions/action.py +69 -39
  37. rasa/core/actions/action_run_slot_rejections.py +1 -1
  38. rasa/core/agent.py +16 -0
  39. rasa/core/available_agents.py +196 -0
  40. rasa/core/available_endpoints.py +30 -0
  41. rasa/core/channels/development_inspector.py +47 -14
  42. rasa/core/channels/inspector/dist/assets/{arc-0b11fe30.js → arc-2e78c586.js} +1 -1
  43. rasa/core/channels/inspector/dist/assets/{blockDiagram-38ab4fdb-9eef30a7.js → blockDiagram-38ab4fdb-806b712e.js} +1 -1
  44. rasa/core/channels/inspector/dist/assets/{c4Diagram-3d4e48cf-03e94f28.js → c4Diagram-3d4e48cf-0745efa9.js} +1 -1
  45. rasa/core/channels/inspector/dist/assets/channel-c436ca7c.js +1 -0
  46. rasa/core/channels/inspector/dist/assets/{classDiagram-70f12bd4-95c09eba.js → classDiagram-70f12bd4-7bd1082b.js} +1 -1
  47. rasa/core/channels/inspector/dist/assets/{classDiagram-v2-f2320105-38e8446c.js → classDiagram-v2-f2320105-d937ba49.js} +1 -1
  48. rasa/core/channels/inspector/dist/assets/clone-50dd656b.js +1 -0
  49. rasa/core/channels/inspector/dist/assets/{createText-2e5e7dd3-57dc3038.js → createText-2e5e7dd3-a2a564ca.js} +1 -1
  50. rasa/core/channels/inspector/dist/assets/{edges-e0da2a9e-4bac0545.js → edges-e0da2a9e-b5256940.js} +1 -1
  51. rasa/core/channels/inspector/dist/assets/{erDiagram-9861fffd-81795c90.js → erDiagram-9861fffd-e6883ad2.js} +1 -1
  52. rasa/core/channels/inspector/dist/assets/{flowDb-956e92f1-89489ae6.js → flowDb-956e92f1-e576fc02.js} +1 -1
  53. rasa/core/channels/inspector/dist/assets/{flowDiagram-66a62f08-cd152627.js → flowDiagram-66a62f08-2e298d01.js} +1 -1
  54. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-2b2aeaf8.js +1 -0
  55. rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-4a651766-3da369bc.js → flowchart-elk-definition-4a651766-dd7b150a.js} +1 -1
  56. rasa/core/channels/inspector/dist/assets/{ganttDiagram-c361ad54-85ec16f8.js → ganttDiagram-c361ad54-5b79575c.js} +1 -1
  57. rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-72cf32ee-495bc140.js → gitGraphDiagram-72cf32ee-3016f40a.js} +1 -1
  58. rasa/core/channels/inspector/dist/assets/{graph-1ec4d266.js → graph-3e19170f.js} +1 -1
  59. rasa/core/channels/inspector/dist/assets/index-1bd9135e.js +1353 -0
  60. rasa/core/channels/inspector/dist/assets/{index-3862675e-0a0e97c9.js → index-3862675e-eb9c86de.js} +1 -1
  61. rasa/core/channels/inspector/dist/assets/{infoDiagram-f8f76790-4d54bcde.js → infoDiagram-f8f76790-b4280e4d.js} +1 -1
  62. rasa/core/channels/inspector/dist/assets/{journeyDiagram-49397b02-dc097114.js → journeyDiagram-49397b02-556091f8.js} +1 -1
  63. rasa/core/channels/inspector/dist/assets/{layout-1a08981e.js → layout-08436411.js} +1 -1
  64. rasa/core/channels/inspector/dist/assets/{line-95f7f1d3.js → line-683c4f3b.js} +1 -1
  65. rasa/core/channels/inspector/dist/assets/{linear-97e69543.js → linear-cee6d791.js} +1 -1
  66. rasa/core/channels/inspector/dist/assets/{mindmap-definition-fc14e90a-8c71ff03.js → mindmap-definition-fc14e90a-a0bf0b1a.js} +1 -1
  67. rasa/core/channels/inspector/dist/assets/{pieDiagram-8a3498a8-f14c71c7.js → pieDiagram-8a3498a8-3730d5c4.js} +1 -1
  68. rasa/core/channels/inspector/dist/assets/{quadrantDiagram-120e2f19-f1d3c9ff.js → quadrantDiagram-120e2f19-12a20fed.js} +1 -1
  69. rasa/core/channels/inspector/dist/assets/{requirementDiagram-deff3bca-bfa2412f.js → requirementDiagram-deff3bca-b9732102.js} +1 -1
  70. rasa/core/channels/inspector/dist/assets/{sankeyDiagram-04a897e0-53f2c97b.js → sankeyDiagram-04a897e0-a2e72776.js} +1 -1
  71. rasa/core/channels/inspector/dist/assets/{sequenceDiagram-704730f1-319d7c0e.js → sequenceDiagram-704730f1-8b7a76bb.js} +1 -1
  72. rasa/core/channels/inspector/dist/assets/{stateDiagram-587899a1-76a09418.js → stateDiagram-587899a1-e65853ac.js} +1 -1
  73. rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-d93cdb3a-a67f15d4.js → stateDiagram-v2-d93cdb3a-6f58a44b.js} +1 -1
  74. rasa/core/channels/inspector/dist/assets/{styles-6aaf32cf-0654e7c3.js → styles-6aaf32cf-df25b934.js} +1 -1
  75. rasa/core/channels/inspector/dist/assets/{styles-9a916d00-1394bb9d.js → styles-9a916d00-88357141.js} +1 -1
  76. rasa/core/channels/inspector/dist/assets/{styles-c10674c1-e4c5bdae.js → styles-c10674c1-d600174d.js} +1 -1
  77. rasa/core/channels/inspector/dist/assets/{svgDrawCommon-08f97a94-50957104.js → svgDrawCommon-08f97a94-4adc3e0b.js} +1 -1
  78. rasa/core/channels/inspector/dist/assets/{timeline-definition-85554ec2-b0885a6a.js → timeline-definition-85554ec2-42816fa1.js} +1 -1
  79. rasa/core/channels/inspector/dist/assets/{xychartDiagram-e933f94c-79e6541a.js → xychartDiagram-e933f94c-621eb66a.js} +1 -1
  80. rasa/core/channels/inspector/dist/index.html +2 -2
  81. rasa/core/channels/inspector/index.html +1 -1
  82. rasa/core/channels/inspector/src/App.tsx +53 -7
  83. rasa/core/channels/inspector/src/components/Chat.tsx +3 -2
  84. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +1 -1
  85. rasa/core/channels/inspector/src/components/DialogueStack.tsx +7 -5
  86. rasa/core/channels/inspector/src/components/LatencyDisplay.tsx +268 -0
  87. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +6 -2
  88. rasa/core/channels/inspector/src/helpers/audio/audiostream.ts +8 -3
  89. rasa/core/channels/inspector/src/helpers/formatters.ts +24 -3
  90. rasa/core/channels/inspector/src/theme/base/styles.ts +19 -1
  91. rasa/core/channels/inspector/src/types.ts +12 -0
  92. rasa/core/channels/studio_chat.py +125 -34
  93. rasa/core/channels/voice_ready/twilio_voice.py +1 -1
  94. rasa/core/channels/voice_stream/audiocodes.py +9 -6
  95. rasa/core/channels/voice_stream/browser_audio.py +39 -4
  96. rasa/core/channels/voice_stream/call_state.py +13 -2
  97. rasa/core/channels/voice_stream/genesys.py +16 -13
  98. rasa/core/channels/voice_stream/jambonz.py +13 -11
  99. rasa/core/channels/voice_stream/twilio_media_streams.py +14 -13
  100. rasa/core/channels/voice_stream/util.py +11 -1
  101. rasa/core/channels/voice_stream/voice_channel.py +101 -29
  102. rasa/core/constants.py +4 -0
  103. rasa/core/nlg/contextual_response_rephraser.py +11 -7
  104. rasa/core/nlg/generator.py +21 -5
  105. rasa/core/nlg/response.py +43 -6
  106. rasa/core/nlg/translate.py +8 -0
  107. rasa/core/policies/enterprise_search_policy.py +4 -2
  108. rasa/core/policies/flow_policy.py +2 -2
  109. rasa/core/policies/flows/flow_executor.py +374 -35
  110. rasa/core/policies/flows/mcp_tool_executor.py +240 -0
  111. rasa/core/processor.py +6 -1
  112. rasa/core/run.py +8 -1
  113. rasa/core/utils.py +21 -1
  114. rasa/dialogue_understanding/commands/__init__.py +8 -0
  115. rasa/dialogue_understanding/commands/cancel_flow_command.py +97 -4
  116. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +11 -0
  117. rasa/dialogue_understanding/commands/continue_agent_command.py +91 -0
  118. rasa/dialogue_understanding/commands/knowledge_answer_command.py +11 -0
  119. rasa/dialogue_understanding/commands/restart_agent_command.py +146 -0
  120. rasa/dialogue_understanding/commands/start_flow_command.py +129 -8
  121. rasa/dialogue_understanding/commands/utils.py +6 -2
  122. rasa/dialogue_understanding/generator/command_parser.py +4 -0
  123. rasa/dialogue_understanding/generator/llm_based_command_generator.py +50 -12
  124. rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +61 -0
  125. rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +61 -0
  126. rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v3_claude_3_5_sonnet_20240620_template.jinja2 +81 -0
  127. rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v3_gpt_4o_2024_11_20_template.jinja2 +81 -0
  128. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +7 -6
  129. rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +7 -6
  130. rasa/dialogue_understanding/generator/single_step/single_step_based_llm_command_generator.py +41 -2
  131. rasa/dialogue_understanding/patterns/continue_interrupted.py +163 -1
  132. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +51 -7
  133. rasa/dialogue_understanding/stack/dialogue_stack.py +123 -2
  134. rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +57 -0
  135. rasa/dialogue_understanding/stack/utils.py +3 -2
  136. rasa/dialogue_understanding_test/du_test_runner.py +7 -2
  137. rasa/dialogue_understanding_test/du_test_schema.yml +3 -3
  138. rasa/e2e_test/e2e_test_runner.py +5 -0
  139. rasa/e2e_test/e2e_test_schema.yml +3 -3
  140. rasa/model_manager/model_api.py +1 -1
  141. rasa/model_manager/socket_bridge.py +8 -2
  142. rasa/server.py +10 -0
  143. rasa/shared/agents/__init__.py +0 -0
  144. rasa/shared/agents/utils.py +35 -0
  145. rasa/shared/constants.py +5 -0
  146. rasa/shared/core/constants.py +12 -1
  147. rasa/shared/core/domain.py +5 -5
  148. rasa/shared/core/events.py +319 -0
  149. rasa/shared/core/flows/flows_list.py +2 -2
  150. rasa/shared/core/flows/flows_yaml_schema.json +101 -186
  151. rasa/shared/core/flows/steps/call.py +51 -5
  152. rasa/shared/core/flows/validation.py +45 -7
  153. rasa/shared/core/flows/yaml_flows_io.py +3 -3
  154. rasa/shared/providers/llm/_base_litellm_client.py +39 -7
  155. rasa/shared/providers/llm/litellm_router_llm_client.py +8 -4
  156. rasa/shared/providers/llm/llm_client.py +7 -3
  157. rasa/shared/providers/llm/llm_response.py +49 -0
  158. rasa/shared/providers/llm/self_hosted_llm_client.py +8 -4
  159. rasa/shared/utils/common.py +2 -1
  160. rasa/shared/utils/llm.py +28 -5
  161. rasa/shared/utils/mcp/__init__.py +0 -0
  162. rasa/shared/utils/mcp/server_connection.py +157 -0
  163. rasa/shared/utils/schemas/events.py +42 -0
  164. rasa/studio/upload.py +4 -7
  165. rasa/tracing/instrumentation/instrumentation.py +4 -2
  166. rasa/utils/common.py +53 -0
  167. rasa/utils/licensing.py +21 -10
  168. rasa/utils/plotting.py +1 -1
  169. rasa/version.py +1 -1
  170. {rasa_pro-3.13.7.dist-info → rasa_pro-3.14.0.dev1.dist-info}/METADATA +16 -15
  171. {rasa_pro-3.13.7.dist-info → rasa_pro-3.14.0.dev1.dist-info}/RECORD +174 -137
  172. rasa/core/channels/inspector/dist/assets/channel-51d02e9e.js +0 -1
  173. rasa/core/channels/inspector/dist/assets/clone-cc738fa6.js +0 -1
  174. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-0c716443.js +0 -1
  175. rasa/core/channels/inspector/dist/assets/index-c804b295.js +0 -1335
  176. {rasa_pro-3.13.7.dist-info → rasa_pro-3.14.0.dev1.dist-info}/NOTICE +0 -0
  177. {rasa_pro-3.13.7.dist-info → rasa_pro-3.14.0.dev1.dist-info}/WHEEL +0 -0
  178. {rasa_pro-3.13.7.dist-info → rasa_pro-3.14.0.dev1.dist-info}/entry_points.txt +0 -0
@@ -21,7 +21,7 @@ from rasa.shared.providers._ssl_verification_utils import (
21
21
  ensure_ssl_certificates_for_litellm_non_openai_based_clients,
22
22
  ensure_ssl_certificates_for_litellm_openai_based_clients,
23
23
  )
24
- from rasa.shared.providers.llm.llm_response import LLMResponse, LLMUsage
24
+ from rasa.shared.providers.llm.llm_response import LLMResponse, LLMToolCall, LLMUsage
25
25
  from rasa.shared.utils.io import resolve_environment_variables, suppress_logs
26
26
 
27
27
  structlogger = structlog.get_logger()
@@ -126,7 +126,9 @@ class _BaseLiteLLMClient:
126
126
  raise ProviderClientValidationError(event_info)
127
127
 
128
128
  @suppress_logs(log_level=logging.WARNING)
129
- def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
129
+ def completion(
130
+ self, messages: Union[List[dict], List[str], str], **kwargs: Any
131
+ ) -> LLMResponse:
130
132
  """Synchronously generate completions for given list of messages.
131
133
 
132
134
  Args:
@@ -138,6 +140,7 @@ class _BaseLiteLLMClient:
138
140
  - a list of messages. Each message is a string and will be formatted
139
141
  as a user message.
140
142
  - a single message as a string which will be formatted as user message.
143
+ **kwargs: Additional parameters to pass to the completion call.
141
144
 
142
145
  Returns:
143
146
  List of message completions.
@@ -147,15 +150,19 @@ class _BaseLiteLLMClient:
147
150
  """
148
151
  try:
149
152
  formatted_messages = self._get_formatted_messages(messages)
150
- arguments = resolve_environment_variables(self._completion_fn_args)
151
- response = completion(messages=formatted_messages, **arguments)
153
+ arguments = cast(
154
+ Dict[str, Any], resolve_environment_variables(self._completion_fn_args)
155
+ )
156
+ response = completion(
157
+ messages=formatted_messages, **{**arguments, **kwargs}
158
+ )
152
159
  return self._format_response(response)
153
160
  except Exception as e:
154
161
  raise ProviderClientAPIException(e)
155
162
 
156
163
  @suppress_logs(log_level=logging.WARNING)
157
164
  async def acompletion(
158
- self, messages: Union[List[dict], List[str], str]
165
+ self, messages: Union[List[dict], List[str], str], **kwargs: Any
159
166
  ) -> LLMResponse:
160
167
  """Asynchronously generate completions for given list of messages.
161
168
 
@@ -168,6 +175,7 @@ class _BaseLiteLLMClient:
168
175
  - a list of messages. Each message is a string and will be formatted
169
176
  as a user message.
170
177
  - a single message as a string which will be formatted as user message.
178
+ **kwargs: Additional parameters to pass to the completion call.
171
179
 
172
180
  Returns:
173
181
  List of message completions.
@@ -177,8 +185,12 @@ class _BaseLiteLLMClient:
177
185
  """
178
186
  try:
179
187
  formatted_messages = self._get_formatted_messages(messages)
180
- arguments = resolve_environment_variables(self._completion_fn_args)
181
- response = await acompletion(messages=formatted_messages, **arguments)
188
+ arguments = cast(
189
+ Dict[str, Any], resolve_environment_variables(self._completion_fn_args)
190
+ )
191
+ response = await acompletion(
192
+ messages=formatted_messages, **{**arguments, **kwargs}
193
+ )
182
194
  return self._format_response(response)
183
195
  except Exception as e:
184
196
  message = ""
@@ -246,12 +258,32 @@ class _BaseLiteLLMClient:
246
258
  else 0
247
259
  )
248
260
  formatted_response.usage = LLMUsage(prompt_tokens, completion_tokens)
261
+
262
+ # Extract tool calls from all choices
263
+ formatted_response.tool_calls = self._extract_tool_calls(response)
264
+
249
265
  structlogger.debug(
250
266
  "base_litellm_client.formatted_response",
251
267
  formatted_response=formatted_response.to_dict(),
252
268
  )
253
269
  return formatted_response
254
270
 
271
+ def _extract_tool_calls(self, response: Any) -> List[LLMToolCall]:
272
+ """Extract tool calls from response choices.
273
+
274
+ Args:
275
+ response: List of response choices from LiteLLM
276
+
277
+ Returns:
278
+ List of LLMToolCall objects, empty if no tool calls found
279
+ """
280
+ return [
281
+ LLMToolCall.from_litellm(tool_call)
282
+ for choice in response.choices
283
+ if choice.message.tool_calls
284
+ for tool_call in choice.message.tool_calls
285
+ ]
286
+
255
287
  def _format_text_completion_response(self, response: Any) -> LLMResponse:
256
288
  """Parses the LiteLLM text completion response to Rasa format."""
257
289
  formatted_response = LLMResponse(
@@ -122,7 +122,9 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
122
122
  raise ProviderClientAPIException(e)
123
123
 
124
124
  @suppress_logs(log_level=logging.WARNING)
125
- def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
125
+ def completion(
126
+ self, messages: Union[List[dict], List[str], str], **kwargs: Any
127
+ ) -> LLMResponse:
126
128
  """
127
129
  Synchronously generate completions for given list of messages.
128
130
 
@@ -140,6 +142,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
140
142
  - a list of messages. Each message is a string and will be formatted
141
143
  as a user message.
142
144
  - a single message as a string which will be formatted as user message.
145
+ **kwargs: Additional parameters to pass to the completion call.
143
146
  Returns:
144
147
  List of message completions.
145
148
  Raises:
@@ -150,7 +153,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
150
153
  try:
151
154
  formatted_messages = self._format_messages(messages)
152
155
  response = self.router_client.completion(
153
- messages=formatted_messages, **self._completion_fn_args
156
+ messages=formatted_messages, **{**self._completion_fn_args, **kwargs}
154
157
  )
155
158
  return self._format_response(response)
156
159
  except Exception as e:
@@ -158,7 +161,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
158
161
 
159
162
  @suppress_logs(log_level=logging.WARNING)
160
163
  async def acompletion(
161
- self, messages: Union[List[dict], List[str], str]
164
+ self, messages: Union[List[dict], List[str], str], **kwargs: Any
162
165
  ) -> LLMResponse:
163
166
  """
164
167
  Asynchronously generate completions for given list of messages.
@@ -177,6 +180,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
177
180
  - a list of messages. Each message is a string and will be formatted
178
181
  as a user message.
179
182
  - a single message as a string which will be formatted as user message.
183
+ **kwargs: Additional parameters to pass to the completion call.
180
184
  Returns:
181
185
  List of message completions.
182
186
  Raises:
@@ -187,7 +191,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
187
191
  try:
188
192
  formatted_messages = self._format_messages(messages)
189
193
  response = await self.router_client.acompletion(
190
- messages=formatted_messages, **self._completion_fn_args
194
+ messages=formatted_messages, **{**self._completion_fn_args, **kwargs}
191
195
  )
192
196
  return self._format_response(response)
193
197
  except Exception as e:
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Dict, List, Protocol, Union, runtime_checkable
3
+ from typing import Any, Dict, List, Protocol, Union, runtime_checkable
4
4
 
5
5
  from rasa.shared.providers.llm.llm_response import LLMResponse
6
6
 
@@ -32,7 +32,9 @@ class LLMClient(Protocol):
32
32
  """
33
33
  ...
34
34
 
35
- def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
35
+ def completion(
36
+ self, messages: Union[List[dict], List[str], str], **kwargs: Any
37
+ ) -> LLMResponse:
36
38
  """
37
39
  Synchronously generate completions for given list of messages.
38
40
 
@@ -48,13 +50,14 @@ class LLMClient(Protocol):
48
50
  - a list of messages. Each message is a string and will be formatted
49
51
  as a user message.
50
52
  - a single message as a string which will be formatted as user message.
53
+ **kwargs: Additional parameters to pass to the completion call.
51
54
  Returns:
52
55
  LLMResponse
53
56
  """
54
57
  ...
55
58
 
56
59
  async def acompletion(
57
- self, messages: Union[List[dict], List[str], str]
60
+ self, messages: Union[List[dict], List[str], str], **kwargs: Any
58
61
  ) -> LLMResponse:
59
62
  """
60
63
  Asynchronously generate completions for given list of messages.
@@ -71,6 +74,7 @@ class LLMClient(Protocol):
71
74
  - a list of messages. Each message is a string and will be formatted
72
75
  as a user message.
73
76
  - a single message as a string which will be formatted as user message.
77
+ **kwargs: Additional parameters to pass to the completion call.
74
78
  Returns:
75
79
  LLMResponse
76
80
  """
@@ -1,9 +1,14 @@
1
1
  import functools
2
+ import json
2
3
  import time
3
4
  from dataclasses import asdict, dataclass, field
4
5
  from typing import Any, Awaitable, Callable, Dict, List, Optional, Text, Union
5
6
 
6
7
  import structlog
8
+ from litellm.utils import ChatCompletionMessageToolCall
9
+ from pydantic import BaseModel
10
+
11
+ from rasa.shared.constants import KEY_TOOL_CALLS
7
12
 
8
13
  structlogger = structlog.get_logger()
9
14
 
@@ -38,6 +43,37 @@ class LLMUsage:
38
43
  return asdict(self)
39
44
 
40
45
 
46
+ class LLMToolCall(BaseModel):
47
+ """A class representing a response from an LLM tool call."""
48
+
49
+ id: str
50
+ """The ID of the tool call."""
51
+
52
+ tool_name: str
53
+ """The name of the tool that was called."""
54
+
55
+ tool_args: Dict[str, Any]
56
+ """The arguments passed to the tool call."""
57
+
58
+ type: str = "function"
59
+ """The type of the tool call."""
60
+
61
+ @classmethod
62
+ def from_dict(cls, data: Dict[Text, Any]) -> "LLMToolCall":
63
+ """Creates an LLMToolResponse from a dictionary."""
64
+ return cls(**data)
65
+
66
+ @classmethod
67
+ def from_litellm(cls, data: ChatCompletionMessageToolCall) -> "LLMToolCall":
68
+ """Creates an LLMToolResponse from a dictionary."""
69
+ return cls(
70
+ id=data.id,
71
+ tool_name=data.function.name,
72
+ tool_args=json.loads(data.function.arguments),
73
+ type=data.type,
74
+ )
75
+
76
+
41
77
  @dataclass
42
78
  class LLMResponse:
43
79
  id: str
@@ -62,12 +98,22 @@ class LLMResponse:
62
98
  latency: Optional[float] = None
63
99
  """Optional field to store the latency of the LLM API call."""
64
100
 
101
+ tool_calls: Optional[List[LLMToolCall]] = None
102
+ """The list of tool calls the model generated for the input prompt."""
103
+
65
104
  @classmethod
66
105
  def from_dict(cls, data: Dict[Text, Any]) -> "LLMResponse":
67
106
  """Creates an LLMResponse from a dictionary."""
68
107
  usage_data = data.get("usage", {})
69
108
  usage_obj = LLMUsage.from_dict(usage_data) if usage_data else None
70
109
 
110
+ tool_calls_data = data.get(KEY_TOOL_CALLS, [])
111
+ tool_calls_obj = (
112
+ [LLMToolCall.from_dict(tool) for tool in tool_calls_data]
113
+ if tool_calls_data
114
+ else None
115
+ )
116
+
71
117
  return cls(
72
118
  id=data["id"],
73
119
  choices=data["choices"],
@@ -76,6 +122,7 @@ class LLMResponse:
76
122
  usage=usage_obj,
77
123
  additional_info=data.get("additional_info"),
78
124
  latency=data.get("latency"),
125
+ tool_calls=tool_calls_obj,
79
126
  )
80
127
 
81
128
  @classmethod
@@ -92,6 +139,8 @@ class LLMResponse:
92
139
  result = asdict(self)
93
140
  if self.usage:
94
141
  result["usage"] = self.usage.to_dict()
142
+ if self.tool_calls:
143
+ result[KEY_TOOL_CALLS] = [tool.model_dump() for tool in self.tool_calls]
95
144
  return result
96
145
 
97
146
 
@@ -237,7 +237,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
237
237
  raise ProviderClientAPIException(e)
238
238
 
239
239
  async def acompletion(
240
- self, messages: Union[List[dict], List[str], str]
240
+ self, messages: Union[List[dict], List[str], str], **kwargs: Any
241
241
  ) -> LLMResponse:
242
242
  """Asynchronous completion of the model with the given messages.
243
243
 
@@ -255,15 +255,18 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
255
255
  - a list of messages. Each message is a string and will be formatted
256
256
  as a user message.
257
257
  - a single message as a string which will be formatted as user message.
258
+ **kwargs: Additional parameters to pass to the completion call.
258
259
 
259
260
  Returns:
260
261
  The completion response.
261
262
  """
262
263
  if self._use_chat_completions_endpoint:
263
- return await super().acompletion(messages)
264
+ return await super().acompletion(messages, **kwargs)
264
265
  return await self._atext_completion(messages)
265
266
 
266
- def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
267
+ def completion(
268
+ self, messages: Union[List[dict], List[str], str], **kwargs: Any
269
+ ) -> LLMResponse:
267
270
  """Completion of the model with the given messages.
268
271
 
269
272
  Method overrides the base class method to call the appropriate
@@ -273,12 +276,13 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
273
276
 
274
277
  Args:
275
278
  messages: The messages to be used for completion.
279
+ **kwargs: Additional parameters to pass to the completion call.
276
280
 
277
281
  Returns:
278
282
  The completion response.
279
283
  """
280
284
  if self._use_chat_completions_endpoint:
281
- return super().completion(messages)
285
+ return super().completion(messages, **kwargs)
282
286
  return self._text_completion(messages)
283
287
 
284
288
  @staticmethod
@@ -373,7 +373,8 @@ def validate_environment(
373
373
  importlib.import_module(p)
374
374
  except ImportError:
375
375
  raise MissingDependencyException(
376
- f"Missing package for {component_name}: {p}"
376
+ f"Missing dependency for {component_name}: {p}. "
377
+ f"Please ensure the correct package is installed."
377
378
  )
378
379
 
379
380
 
rasa/shared/utils/llm.py CHANGED
@@ -49,7 +49,15 @@ from rasa.shared.constants import (
49
49
  RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_TOO_LONG,
50
50
  ROUTER_CONFIG_KEY,
51
51
  )
52
- from rasa.shared.core.events import BotUttered, UserUttered
52
+ from rasa.shared.core.events import (
53
+ AgentCancelled,
54
+ AgentCompleted,
55
+ AgentInterrupted,
56
+ AgentResumed,
57
+ AgentStarted,
58
+ BotUttered,
59
+ UserUttered,
60
+ )
53
61
  from rasa.shared.core.slots import BooleanSlot, CategoricalSlot, Slot
54
62
  from rasa.shared.engine.caching import get_local_cache_location
55
63
  from rasa.shared.exceptions import (
@@ -112,7 +120,7 @@ DEPLOYMENT_CENTRIC_PROVIDERS = [AZURE_OPENAI_PROVIDER]
112
120
 
113
121
  # Placeholder messages used in the transcript for
114
122
  # instances where user input results in an error
115
- ERROR_PLACEHOLDER = {
123
+ ERROR_PLACEHOLDER: Dict[str, str] = {
116
124
  RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_TOO_LONG: "[User sent really long message]",
117
125
  RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_EMPTY: "",
118
126
  "default": "[User input triggered an error]",
@@ -225,6 +233,7 @@ def tracker_as_readable_transcript(
225
233
  ai_prefix: str = AI,
226
234
  max_turns: Optional[int] = 20,
227
235
  turns_wrapper: Optional[Callable[[List[str]], List[str]]] = None,
236
+ highlight_agent_turns: bool = False,
228
237
  ) -> str:
229
238
  """Creates a readable dialogue from a tracker.
230
239
 
@@ -234,6 +243,7 @@ def tracker_as_readable_transcript(
234
243
  ai_prefix: the prefix to use for ai utterances
235
244
  max_turns: the maximum number of turns to include in the transcript
236
245
  turns_wrapper: optional function to wrap the turns in a custom way
246
+ highlight_agent_turns: whether to highlight agent turns in the transcript
237
247
 
238
248
  Example:
239
249
  >>> tracker = Tracker(
@@ -251,7 +261,9 @@ def tracker_as_readable_transcript(
251
261
  Returns:
252
262
  A string representing the transcript of the tracker
253
263
  """
254
- transcript = []
264
+ transcript: List[str] = []
265
+
266
+ current_ai_prefix = ai_prefix
255
267
 
256
268
  # using `applied_events` rather than `events` means that only events after the
257
269
  # most recent `Restart` or `SessionStarted` are included in the transcript
@@ -266,9 +278,20 @@ def tracker_as_readable_transcript(
266
278
  else:
267
279
  message = sanitize_message_for_prompt(event.text)
268
280
  transcript.append(f"{human_prefix}: {message}")
269
-
270
281
  elif isinstance(event, BotUttered):
271
- transcript.append(f"{ai_prefix}: {sanitize_message_for_prompt(event.text)}")
282
+ transcript.append(
283
+ f"{current_ai_prefix}: {sanitize_message_for_prompt(event.text)}"
284
+ )
285
+
286
+ if highlight_agent_turns:
287
+ if isinstance(event, AgentStarted) or isinstance(event, AgentResumed):
288
+ current_ai_prefix = event.agent_id
289
+ elif (
290
+ isinstance(event, AgentCompleted)
291
+ or isinstance(event, AgentCancelled)
292
+ or isinstance(event, AgentInterrupted)
293
+ ):
294
+ current_ai_prefix = ai_prefix
272
295
 
273
296
  # turns_wrapper to count multiple utterances by bot/user as single turn
274
297
  if turns_wrapper:
File without changes
@@ -0,0 +1,157 @@
1
+ """MCP server connection utilities."""
2
+
3
+ import asyncio
4
+ from contextlib import AsyncExitStack
5
+ from typing import Any, Dict, Optional
6
+
7
+ import structlog
8
+ from mcp import ClientSession
9
+ from mcp.client.streamable_http import streamablehttp_client
10
+
11
+ structlogger = structlog.get_logger()
12
+
13
+
14
class MCPServerConnection:
    """
    Manages connection to an MCP server.

    This class handles the lifecycle of connections to MCP servers,
    including connection establishment, session management, and cleanup.
    """

    def __init__(self, server_name: str, server_url: str, server_type: str):
        """
        Initialize the MCP server connection.

        Args:
            server_name: Server name to identify the server
            server_url: Server URL
            server_type: Server type (currently only 'http' is supported)
        """
        self.server_name = server_name
        self.server_url = server_url
        self.server_type = server_type
        # Set by connect(); None while disconnected.
        self.session: Optional["ClientSession"] = None
        # Owns the async contexts (transport streams + session) so they can be
        # closed together; recreated per connection (see connect()).
        self.exit_stack: Optional[AsyncExitStack] = None

    @classmethod
    def from_config(cls, server_config: Dict[str, Any]) -> "MCPServerConnection":
        """Initialize the MCP server connection from a configuration dictionary.

        Args:
            server_config: Mapping with required keys "name" and "url";
                the optional "type" key defaults to "http".
        """
        return cls(
            server_config["name"],
            server_config["url"],
            server_config.get("type", "http"),
        )

    async def connect(self) -> None:
        """Establish connection to the MCP server.

        Raises:
            ValueError: If the server type is not supported.
            ConnectionError: If connection fails.
        """
        if self.server_type != "http":
            raise ValueError(f"Unsupported server type: {self.server_type}")

        # Create a new exit stack for this connection to avoid task boundary issues
        self.exit_stack = AsyncExitStack()

        try:
            read_stream, write_stream, _ = await self.exit_stack.enter_async_context(
                streamablehttp_client(self.server_url)
            )
            self.session = await self.exit_stack.enter_async_context(
                ClientSession(read_stream, write_stream)
            )
            await self.session.initialize()
        except asyncio.CancelledError as e:
            event_info = f"Connection to MCP server {self.server_name} was cancelled."
            structlogger.error(
                "mcp_server_connection.connect.connection_cancelled",
                event_info=event_info,
                server_name=self.server_name,
                server_url=self.server_url,
            )
            # Clean up on cancellation
            await self._cleanup()
            # NOTE(review): converting CancelledError into ConnectionError hides
            # the cancellation from asyncio's task machinery — confirm callers
            # rely on this before re-raising CancelledError instead.
            raise ConnectionError(e) from e

        except Exception as e:
            event_info = f"Failed to connect to MCP server {self.server_name}: {e}"
            structlogger.error(
                "mcp_server_connection.connect.connection_failed",
                event_info=event_info,
                server_name=self.server_name,
                server_url=self.server_url,
            )
            # Clean up on error
            await self._cleanup()
            # Chain the original exception so the root cause stays visible
            # in tracebacks.
            raise ConnectionError(e) from e

    async def ensure_active_session(self) -> "ClientSession":
        """
        Ensure an active session is available.

        If no session exists or the current session is inactive,
        a new connection will be established.

        Returns:
            Active ClientSession instance.
        """
        if self.session is None:
            await self.connect()
            structlogger.info(
                "mcp_server_connection.ensure_active_session.no_session",
                server_name=self.server_name,
                server_url=self.server_url,
                event_info=(
                    "No session found, connecting to the server "
                    f"`{self.server_name}` @ `{self.server_url}`"
                ),
            )
        if self.session:
            try:
                # Liveness probe; any failure triggers a reconnect below.
                await self.session.send_ping()
            except Exception as e:
                structlogger.error(
                    "mcp_server_connection.ensure_active_session.ping_failed",
                    error=str(e),
                    server_name=self.server_name,
                    server_url=self.server_url,
                    event_info=(
                        "Ping failed, trying to reconnect to the server "
                        f"`{self.server_name}` @ `{self.server_url}`"
                    ),
                )
                await self.connect()
                structlogger.info(
                    "mcp_server_connection.ensure_active_session.reconnected",
                    server_name=self.server_name,
                    server_url=self.server_url,
                    event_info=(
                        "Reconnected to the server "
                        f"`{self.server_name}` @ `{self.server_url}`"
                    ),
                )
        assert self.session is not None  # Ensures type for mypy
        return self.session

    async def close(self) -> None:
        """Close the connection and clean up resources."""
        await self._cleanup()

    async def _cleanup(self) -> None:
        """Internal cleanup method to safely close resources."""
        if self.exit_stack:
            try:
                await self.exit_stack.aclose()
            except Exception as e:
                # Log cleanup errors but don't raise them
                structlogger.warning(
                    "mcp_server_connection.cleanup.failed",
                    server_name=self.server_name,
                    error=str(e),
                )
            finally:
                self.exit_stack = None
                self.session = None
@@ -160,6 +160,43 @@ FLOW_CANCELLED = {
160
160
  "step_id": {"type": "string"},
161
161
  }
162
162
  }
163
def _agent_event_schema(event_name, **extra_properties):
    """Build the JSON schema fragment shared by agent lifecycle events.

    Every agent event carries an `event` constant plus string `agent_id`
    and `flow_id`; event-specific string fields are appended via kwargs.
    """
    properties = {
        "event": {"const": event_name},
        "agent_id": {"type": "string"},
        "flow_id": {"type": "string"},
    }
    properties.update(extra_properties)
    return {"properties": properties}


AGENT_STARTED = _agent_event_schema("agent_started")
AGENT_COMPLETED = _agent_event_schema("agent_completed", status={"type": "string"})
AGENT_INTERRUPTED = _agent_event_schema("agent_interrupted")
AGENT_RESUMED = _agent_event_schema("agent_resumed")
AGENT_CANCELLED = _agent_event_schema("agent_cancelled", reason={"type": "string"})
163
200
  DIALOGUE_STACK_UPDATED = {
164
201
  "properties": {"event": {"const": "stack"}, "update": {"type": "string"}}
165
202
  }
@@ -204,6 +241,11 @@ EVENT_SCHEMA = {
204
241
  FLOW_RESUMED,
205
242
  FLOW_COMPLETED,
206
243
  FLOW_CANCELLED,
244
+ AGENT_STARTED,
245
+ AGENT_COMPLETED,
246
+ AGENT_INTERRUPTED,
247
+ AGENT_RESUMED,
248
+ AGENT_CANCELLED,
207
249
  DIALOGUE_STACK_UPDATED,
208
250
  ROUTING_SESSION_ENDED,
209
251
  SESSION_ENDED,
rasa/studio/upload.py CHANGED
@@ -115,10 +115,9 @@ def run_validation(args: argparse.Namespace) -> None:
115
115
  """
116
116
  from rasa.validator import Validator
117
117
 
118
- training_data_paths = args.data if isinstance(args.data, list) else [args.data]
119
118
  training_data_importer = TrainingDataImporter.load_from_dict(
120
119
  domain_path=args.domain,
121
- training_data_paths=training_data_paths,
120
+ training_data_paths=[args.data],
122
121
  config_path=args.config,
123
122
  expand_env_vars=False,
124
123
  )
@@ -264,9 +263,8 @@ def build_calm_import_parts(
264
263
  domain_from_files = importer.get_user_domain().as_dict()
265
264
  domain = extract_values(domain_from_files, DOMAIN_KEYS)
266
265
 
267
- training_data_paths = data_path if isinstance(data_path, list) else [str(data_path)]
268
266
  flow_importer = FlowSyncImporter.load_from_dict(
269
- training_data_paths=training_data_paths, expand_env_vars=False
267
+ training_data_paths=[str(data_path)], expand_env_vars=False
270
268
  )
271
269
 
272
270
  flows = list(flow_importer.get_user_flows())
@@ -274,7 +272,7 @@ def build_calm_import_parts(
274
272
  flows = read_yaml(flows_yaml, expand_env_vars=False)
275
273
 
276
274
  nlu_importer = TrainingDataImporter.load_from_dict(
277
- training_data_paths=training_data_paths, expand_env_vars=False
275
+ training_data_paths=[str(data_path)], expand_env_vars=False
278
276
  )
279
277
  nlu_data = nlu_importer.get_nlu_data()
280
278
  nlu_examples = nlu_data.filter_training_examples(
@@ -351,10 +349,9 @@ def upload_nlu_assistant(
351
349
  "rasa.studio.upload.nlu_data_read",
352
350
  event_info="Found DM1 assistant data, parsing...",
353
351
  )
354
- training_data_paths = args.data if isinstance(args.data, list) else [args.data]
355
352
  importer = TrainingDataImporter.load_from_dict(
356
353
  domain_path=args.domain,
357
- training_data_paths=training_data_paths,
354
+ training_data_paths=[args.data],
358
355
  config_path=args.config,
359
356
  expand_env_vars=False,
360
357
  )