rasa-pro 3.15.0a1__py3-none-any.whl → 3.15.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (50)
  1. rasa/builder/constants.py +5 -0
  2. rasa/builder/copilot/models.py +80 -28
  3. rasa/builder/download.py +110 -0
  4. rasa/builder/evaluator/__init__.py +0 -0
  5. rasa/builder/evaluator/constants.py +15 -0
  6. rasa/builder/evaluator/copilot_executor.py +89 -0
  7. rasa/builder/evaluator/dataset/models.py +173 -0
  8. rasa/builder/evaluator/exceptions.py +4 -0
  9. rasa/builder/evaluator/response_classification/__init__.py +0 -0
  10. rasa/builder/evaluator/response_classification/constants.py +66 -0
  11. rasa/builder/evaluator/response_classification/evaluator.py +346 -0
  12. rasa/builder/evaluator/response_classification/langfuse_runner.py +463 -0
  13. rasa/builder/evaluator/response_classification/models.py +61 -0
  14. rasa/builder/evaluator/scripts/__init__.py +0 -0
  15. rasa/builder/evaluator/scripts/run_response_classification_evaluator.py +152 -0
  16. rasa/builder/jobs.py +208 -1
  17. rasa/builder/logging_utils.py +25 -24
  18. rasa/builder/main.py +6 -1
  19. rasa/builder/models.py +23 -0
  20. rasa/builder/project_generator.py +29 -10
  21. rasa/builder/service.py +104 -22
  22. rasa/builder/training_service.py +13 -1
  23. rasa/builder/validation_service.py +2 -1
  24. rasa/core/actions/action_clean_stack.py +32 -0
  25. rasa/core/actions/constants.py +4 -0
  26. rasa/core/actions/custom_action_executor.py +70 -12
  27. rasa/core/actions/grpc_custom_action_executor.py +41 -2
  28. rasa/core/actions/http_custom_action_executor.py +49 -25
  29. rasa/core/channels/voice_stream/voice_channel.py +14 -2
  30. rasa/dialogue_understanding/generator/llm_based_command_generator.py +6 -3
  31. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +15 -7
  32. rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +15 -8
  33. rasa/dialogue_understanding/processor/command_processor.py +49 -7
  34. rasa/shared/providers/_configs/azure_openai_client_config.py +4 -5
  35. rasa/shared/providers/_configs/default_litellm_client_config.py +4 -4
  36. rasa/shared/providers/_configs/litellm_router_client_config.py +3 -2
  37. rasa/shared/providers/_configs/openai_client_config.py +5 -7
  38. rasa/shared/providers/_configs/rasa_llm_client_config.py +4 -4
  39. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +4 -4
  40. rasa/shared/providers/llm/_base_litellm_client.py +42 -14
  41. rasa/shared/providers/llm/litellm_router_llm_client.py +38 -15
  42. rasa/shared/providers/llm/self_hosted_llm_client.py +34 -32
  43. rasa/shared/utils/configs.py +5 -8
  44. rasa/utils/endpoints.py +6 -0
  45. rasa/version.py +1 -1
  46. {rasa_pro-3.15.0a1.dist-info → rasa_pro-3.15.0a3.dist-info}/METADATA +12 -12
  47. {rasa_pro-3.15.0a1.dist-info → rasa_pro-3.15.0a3.dist-info}/RECORD +50 -37
  48. {rasa_pro-3.15.0a1.dist-info → rasa_pro-3.15.0a3.dist-info}/NOTICE +0 -0
  49. {rasa_pro-3.15.0a1.dist-info → rasa_pro-3.15.0a3.dist-info}/WHEEL +0 -0
  50. {rasa_pro-3.15.0a1.dist-info → rasa_pro-3.15.0a3.dist-info}/entry_points.txt +0 -0
@@ -99,10 +99,16 @@ async def try_load_existing_agent(project_folder: str) -> Optional[Agent]:
99
99
  available_endpoints = Configuration.initialise_endpoints(
100
100
  endpoints_path=Path(project_folder) / DEFAULT_ENDPOINTS_PATH
101
101
  ).endpoints
102
+ # Get available sub agents for agent loading
103
+ _sub_agents = Configuration.initialise_sub_agents(
104
+ sub_agents_path=None
105
+ ).available_agents
102
106
 
103
107
  # Load the agent
104
108
  agent = await load_agent(
105
- model_path=latest_model_path, endpoints=available_endpoints
109
+ model_path=latest_model_path,
110
+ endpoints=available_endpoints,
111
+ sub_agents=_sub_agents,
106
112
  )
107
113
 
108
114
  if agent and agent.is_ready():
@@ -133,6 +139,9 @@ async def _train_model(
133
139
  try:
134
140
  structlogger.info("training.started")
135
141
 
142
+ # init sub agents using the default path
143
+ Configuration.initialise_sub_agents(sub_agents_path=None)
144
+
136
145
  training_result = await train(
137
146
  domain="",
138
147
  config=str(config_file),
@@ -160,6 +169,8 @@ async def _load_agent(model_path: str, endpoints_file: Path) -> Agent:
160
169
  available_endpoints = Configuration.initialise_endpoints(
161
170
  endpoints_path=endpoints_file
162
171
  ).endpoints
172
+ _sub_agents = Configuration.get_instance().available_agents
173
+
163
174
  if available_endpoints is None:
164
175
  raise AgentLoadError("No endpoints available for agent loading")
165
176
 
@@ -173,6 +184,7 @@ async def _load_agent(model_path: str, endpoints_file: Path) -> Agent:
173
184
  model_path=model_path,
174
185
  remote_storage=None,
175
186
  endpoints=available_endpoints,
187
+ sub_agents=_sub_agents,
176
188
  )
177
189
 
178
190
  if agent_instance is None:
@@ -24,7 +24,7 @@ def _mock_sys_exit() -> Generator[Dict[str, bool], Any, None]:
24
24
  was_sys_exit_called["value"] = True
25
25
 
26
26
  original_exit = sys.exit
27
- sys.exit = sys_exit_mock # type: ignore[assignment]
27
+ sys.exit = sys_exit_mock # type: ignore
28
28
 
29
29
  try:
30
30
  yield was_sys_exit_called
@@ -50,6 +50,7 @@ async def validate_project(importer: TrainingDataImporter) -> Optional[str]:
50
50
  from rasa.core.config.configuration import Configuration
51
51
 
52
52
  Configuration.initialise_empty()
53
+ Configuration.initialise_sub_agents(sub_agents_path=None)
53
54
 
54
55
  validate_files(
55
56
  fail_on_warnings=config.VALIDATION_FAIL_ON_WARNINGS,
@@ -4,9 +4,11 @@ from typing import Any, Dict, List, Optional
4
4
 
5
5
  import structlog
6
6
 
7
+ import rasa.dialogue_understanding.stack.utils
7
8
  from rasa.core.actions.action import Action
8
9
  from rasa.core.channels import OutputChannel
9
10
  from rasa.core.nlg import NaturalLanguageGenerator
11
+ from rasa.dialogue_understanding.patterns.code_change import FLOW_PATTERN_CODE_CHANGE_ID
10
12
  from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack
11
13
  from rasa.dialogue_understanding.stack.frames import (
12
14
  BaseFlowStackFrame,
@@ -41,6 +43,15 @@ class ActionCleanStack(Action):
41
43
  """Clean the stack."""
42
44
  structlogger.debug("action_clean_stack.run")
43
45
  new_frames = []
46
+ top_flow_frame = rasa.dialogue_understanding.stack.utils.top_flow_frame(
47
+ tracker.stack, ignore_call_frames=False
48
+ )
49
+ top_user_flow_frame = (
50
+ rasa.dialogue_understanding.stack.utils.top_user_flow_frame(
51
+ tracker.stack, ignore_call_and_link_frames=False
52
+ )
53
+ )
54
+
44
55
  # Set all frames to their end step, filter out any non-BaseFlowStackFrames
45
56
  for frame in tracker.stack.frames:
46
57
  if isinstance(frame, BaseFlowStackFrame):
@@ -56,4 +67,25 @@ class ActionCleanStack(Action):
56
67
  new_frames.append(frame)
57
68
  new_stack = DialogueStack.from_dict([frame.as_dict() for frame in new_frames])
58
69
 
70
+ # Check if the action is being called from within a user flow
71
+ if (
72
+ top_flow_frame
73
+ and top_flow_frame.flow_id != FLOW_PATTERN_CODE_CHANGE_ID
74
+ and top_user_flow_frame
75
+ and top_user_flow_frame.flow_id == top_flow_frame.flow_id
76
+ ):
77
+ # The action is being called from within a user flow on the stack.
78
+ # If there are other frames on the stack, we need to make sure
79
+ # the last executed frame is the end step of the current user flow so
80
+ # that we can trigger pattern_completed for this user flow.
81
+ new_stack.pop()
82
+ structlogger.debug(
83
+ "action_clean_stack.pushing_user_frame_at_the_bottom_of_stack",
84
+ flow_id=top_user_flow_frame.flow_id,
85
+ )
86
+ new_stack.push(
87
+ top_user_flow_frame,
88
+ index=0,
89
+ )
90
+
59
91
  return tracker.create_stack_updated_events(new_stack)
@@ -3,3 +3,7 @@ SELECTIVE_DOMAIN = "enable_selective_domain"
3
3
 
4
4
  SSL_CLIENT_CERT_FIELD = "ssl_client_cert"
5
5
  SSL_CLIENT_KEY_FIELD = "ssl_client_key"
6
+
7
+ # Special marker key used by EndpointConfig to indicate 449 status
8
+ # without raising an exception
9
+ MISSING_DOMAIN_MARKER = "missing_domain"
@@ -2,7 +2,10 @@ from __future__ import annotations
2
2
 
3
3
  import abc
4
4
  import logging
5
- from typing import TYPE_CHECKING, Any, Dict, Text
5
+ from enum import Enum
6
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Text
7
+
8
+ from pydantic import BaseModel
6
9
 
7
10
  import rasa
8
11
  from rasa.core.actions.action_exceptions import DomainNotFound
@@ -19,6 +22,23 @@ if TYPE_CHECKING:
19
22
  logger = logging.getLogger(__name__)
20
23
 
21
24
 
25
+ class ActionResultType(Enum):
26
+ SUCCESS = "success"
27
+ RETRY_WITH_DOMAIN = "retry_with_domain"
28
+
29
+
30
+ class ActionResult(BaseModel):
31
+ """Result of custom action execution.
32
+
33
+ This is used to avoid raising exceptions for expected conditions
34
+ like missing domain (449 status code), which would otherwise be
35
+ captured by tracing as errors.
36
+ """
37
+
38
+ result_type: ActionResultType
39
+ response: Optional[Dict[Text, Any]] = None
40
+
41
+
22
42
  class CustomActionExecutor(abc.ABC):
23
43
  """Interface for custom action executors.
24
44
 
@@ -45,6 +65,34 @@ class CustomActionExecutor(abc.ABC):
45
65
  """
46
66
  pass
47
67
 
68
+ async def run_with_result(
69
+ self,
70
+ tracker: "DialogueStateTracker",
71
+ domain: "Domain",
72
+ include_domain: bool = False,
73
+ ) -> ActionResult:
74
+ """Executes the custom action and returns a result.
75
+
76
+ This method is used to avoid raising exceptions for expected conditions
77
+ like missing domain, which would otherwise be captured by tracing as errors.
78
+
79
+ By default, this method calls the run method and wraps the response
80
+ for backward compatibility.
81
+
82
+ Args:
83
+ tracker: The current state of the dialogue.
84
+ domain: The domain object containing domain-specific information.
85
+ include_domain: If True, the domain is included in the request.
86
+
87
+ Returns:
88
+ ActionResult containing the response and result type.
89
+ """
90
+ try:
91
+ response = await self.run(tracker, domain, include_domain)
92
+ return ActionResult(result_type=ActionResultType.SUCCESS, response=response)
93
+ except DomainNotFound:
94
+ return ActionResult(result_type=ActionResultType.RETRY_WITH_DOMAIN)
95
+
48
96
 
49
97
  class NoEndpointCustomActionExecutor(CustomActionExecutor):
50
98
  """Implementation of a custom action executor when endpoint is not set.
@@ -163,13 +211,13 @@ class RetryCustomActionExecutor(CustomActionExecutor):
163
211
  domain: "Domain",
164
212
  include_domain: bool = False,
165
213
  ) -> Dict[Text, Any]:
166
- """Runs the wrapped custom action executor.
214
+ """Runs the wrapped custom action executor with retry logic.
167
215
 
168
216
  First request to the action server is made with/without the domain
169
217
  as specified by the `include_domain` parameter.
170
218
 
171
- If the action server responds with a `DomainNotFound` error, by running the
172
- custom action executor again with the domain information.
219
+ If the action server responds with a missing domain indication,
220
+ retries the request with the domain included.
173
221
 
174
222
  Args:
175
223
  tracker: The current state of the dialogue.
@@ -178,14 +226,24 @@ class RetryCustomActionExecutor(CustomActionExecutor):
178
226
 
179
227
  Returns:
180
228
  The response from the execution of the custom action.
229
+
230
+ Raises:
231
+ DomainNotFound: If the action server still requires domain after retry.
181
232
  """
182
- try:
183
- return await self._custom_action_executor.run(
184
- tracker,
185
- domain,
186
- include_domain=include_domain,
187
- )
188
- except DomainNotFound:
189
- return await self._custom_action_executor.run(
233
+ result = await self._custom_action_executor.run_with_result(
234
+ tracker,
235
+ domain,
236
+ include_domain=include_domain,
237
+ )
238
+
239
+ if result.result_type == ActionResultType.RETRY_WITH_DOMAIN:
240
+ # Retry with domain included
241
+ result = await self._custom_action_executor.run_with_result(
190
242
  tracker, domain, include_domain=True
191
243
  )
244
+
245
+ # If still missing domain after retry, raise error
246
+ if result.result_type == ActionResultType.RETRY_WITH_DOMAIN:
247
+ raise DomainNotFound()
248
+
249
+ return result.response if result.response is not None else {}
@@ -11,6 +11,8 @@ from rasa_sdk.grpc_py import action_webhook_pb2, action_webhook_pb2_grpc
11
11
  from rasa.core.actions.action_exceptions import DomainNotFound
12
12
  from rasa.core.actions.constants import SSL_CLIENT_CERT_FIELD, SSL_CLIENT_KEY_FIELD
13
13
  from rasa.core.actions.custom_action_executor import (
14
+ ActionResult,
15
+ ActionResultType,
14
16
  CustomActionExecutor,
15
17
  CustomActionRequestWriter,
16
18
  )
@@ -101,13 +103,51 @@ class GRPCCustomActionExecutor(CustomActionExecutor):
101
103
 
102
104
  Returns:
103
105
  Response from the action server.
106
+ Returns empty dict if domain is missing.
107
+
108
+ Raises:
109
+ RasaException: If an error occurs while making the gRPC request
110
+ (other than missing domain).
104
111
  """
112
+ result = await self.run_with_result(tracker, domain, include_domain)
113
+
114
+ # Return empty dict for retry cases to avoid raising exceptions
115
+ # RetryCustomActionExecutor will handle the retry logic
116
+ if result.result_type == ActionResultType.RETRY_WITH_DOMAIN:
117
+ return {}
118
+
119
+ return result.response if result.response is not None else {}
120
+
121
+ async def run_with_result(
122
+ self,
123
+ tracker: "DialogueStateTracker",
124
+ domain: "Domain",
125
+ include_domain: bool = False,
126
+ ) -> ActionResult:
127
+ """Execute the custom action and return an ActionResult.
128
+
129
+ This method avoids raising DomainNotFound exception for missing domain,
130
+ instead returning an ActionResult with RETRY_WITH_DOMAIN type.
131
+ This prevents tracing from capturing this expected condition as an error.
105
132
 
133
+ Args:
134
+ tracker: Tracker for the current conversation.
135
+ domain: Domain of the assistant.
136
+ include_domain: If True, the domain is included in the request.
137
+
138
+ Returns:
139
+ ActionResult containing the response and result type.
140
+ """
106
141
  request = self._create_payload(
107
142
  tracker=tracker, domain=domain, include_domain=include_domain
108
143
  )
109
144
 
110
- return self._request(request)
145
+ try:
146
+ response = self._request(request)
147
+ return ActionResult(result_type=ActionResultType.SUCCESS, response=response)
148
+ except DomainNotFound:
149
+ # Return retry result instead of raising DomainNotFound
150
+ return ActionResult(result_type=ActionResultType.RETRY_WITH_DOMAIN)
111
151
 
112
152
  def _request(
113
153
  self,
@@ -121,7 +161,6 @@ class GRPCCustomActionExecutor(CustomActionExecutor):
121
161
  Returns:
122
162
  Response from the action server.
123
163
  """
124
-
125
164
  client = self._create_grpc_client()
126
165
  metadata = self._build_metadata()
127
166
  try:
@@ -4,8 +4,11 @@ from typing import TYPE_CHECKING, Any, Dict, Optional
4
4
 
5
5
  import aiohttp
6
6
 
7
- from rasa.core.actions.action_exceptions import ActionExecutionRejection, DomainNotFound
7
+ from rasa.core.actions.action_exceptions import ActionExecutionRejection
8
+ from rasa.core.actions.constants import MISSING_DOMAIN_MARKER
8
9
  from rasa.core.actions.custom_action_executor import (
10
+ ActionResult,
11
+ ActionResultType,
9
12
  CustomActionExecutor,
10
13
  CustomActionRequestWriter,
11
14
  )
@@ -18,12 +21,12 @@ from rasa.shared.core.domain import Domain
18
21
  from rasa.shared.core.trackers import DialogueStateTracker
19
22
  from rasa.shared.exceptions import RasaException
20
23
  from rasa.utils.common import get_bool_env_variable
24
+ from rasa.utils.endpoints import ClientResponseError, EndpointConfig
21
25
 
22
26
  if TYPE_CHECKING:
23
27
  from rasa.shared.core.domain import Domain
24
28
  from rasa.shared.core.trackers import DialogueStateTracker
25
29
 
26
- from rasa.utils.endpoints import ClientResponseError, EndpointConfig
27
30
 
28
31
  logger = logging.getLogger(__name__)
29
32
 
@@ -62,9 +65,40 @@ class HTTPCustomActionExecutor(CustomActionExecutor):
62
65
 
63
66
  Returns:
64
67
  A dictionary containing the response from the custom action endpoint.
68
+ Returns empty dict if domain is missing (449 status).
65
69
 
66
70
  Raises:
67
- RasaException: If an error occurs while making the HTTP request.
71
+ RasaException: If an error occurs while making the HTTP request
72
+ (other than missing domain).
73
+ """
74
+ result = await self.run_with_result(tracker, domain, include_domain)
75
+
76
+ # Return empty dict for retry cases to avoid raising exceptions
77
+ # RetryCustomActionExecutor will handle the retry logic
78
+ if result.result_type == ActionResultType.RETRY_WITH_DOMAIN:
79
+ return {}
80
+
81
+ return result.response if result.response is not None else {}
82
+
83
+ async def run_with_result(
84
+ self,
85
+ tracker: "DialogueStateTracker",
86
+ domain: Optional["Domain"] = None,
87
+ include_domain: bool = False,
88
+ ) -> ActionResult:
89
+ """Execute the custom action and return an ActionResult.
90
+
91
+ This method avoids raising DomainNotFound exception for 449 status code,
92
+ instead returning an ActionResult with RETRY_WITH_DOMAIN type.
93
+ This prevents tracing from capturing this expected condition as an error.
94
+
95
+ Args:
96
+ tracker: The current state of the dialogue.
97
+ domain: The domain object containing domain-specific information.
98
+ include_domain: If True, the domain is included in the request.
99
+
100
+ Returns:
101
+ ActionResult containing the response and result type.
68
102
  """
69
103
  from rasa.core.actions.action import RemoteActionJSONValidator
70
104
 
@@ -77,14 +111,23 @@ class HTTPCustomActionExecutor(CustomActionExecutor):
77
111
  tracker=tracker, domain=domain, include_domain=include_domain
78
112
  )
79
113
 
80
- response = await self._perform_request_with_retries(json_body)
114
+ assert self.action_endpoint is not None
115
+ response = await self.action_endpoint.request(
116
+ json=json_body,
117
+ method="post",
118
+ timeout=DEFAULT_REQUEST_TIMEOUT,
119
+ compress=self.should_compress,
120
+ )
121
+
122
+ # Check if we got the special marker for 449 status (missing domain)
123
+ if isinstance(response, dict) and response.get(MISSING_DOMAIN_MARKER):
124
+ return ActionResult(result_type=ActionResultType.RETRY_WITH_DOMAIN)
81
125
 
82
126
  if response is None:
83
127
  response = {}
84
128
 
85
129
  RemoteActionJSONValidator.validate(response)
86
-
87
- return response
130
+ return ActionResult(result_type=ActionResultType.SUCCESS, response=response)
88
131
 
89
132
  except ClientResponseError as e:
90
133
  if e.status == 400:
@@ -131,22 +174,3 @@ class HTTPCustomActionExecutor(CustomActionExecutor):
131
174
  "and returns a 200 once the action is executed. "
132
175
  "Error: {}".format(self.action_name, status, e)
133
176
  )
134
-
135
- async def _perform_request_with_retries(
136
- self,
137
- json_body: Dict[str, Any],
138
- ) -> Any:
139
- """Attempts to perform the request with retries if necessary."""
140
- assert self.action_endpoint is not None
141
- try:
142
- return await self.action_endpoint.request(
143
- json=json_body,
144
- method="post",
145
- timeout=DEFAULT_REQUEST_TIMEOUT,
146
- compress=self.should_compress,
147
- )
148
- except ClientResponseError as e:
149
- # Repeat the request because Domain was not in the payload
150
- if e.status == 449:
151
- raise DomainNotFound()
152
- raise e
@@ -373,6 +373,14 @@ class VoiceOutputChannel(OutputChannel):
373
373
  async def hangup(self, recipient_id: str, **kwargs: Any) -> None:
374
374
  call_state.should_hangup = True
375
375
 
376
+ async def send_turn_end_marker(self, recipient_id: str) -> None:
377
+ """Send a marker to indicate the bot has finished its turn.
378
+
379
+ Used internally by Rasa during conversation simulations.
380
+ This is called after all bot messages in a turn have been sent.
381
+ """
382
+ pass
383
+
376
384
 
377
385
  class VoiceInputChannel(InputChannel):
378
386
  # All children of this class require a voice license to be used.
@@ -471,14 +479,16 @@ class VoiceInputChannel(InputChannel):
471
479
  call_parameters: CallParameters,
472
480
  ) -> None:
473
481
  output_channel = self.create_output_channel(channel_websocket, tts_engine)
482
+ sender_id = self.get_sender_id(call_parameters)
474
483
  message = UserMessage(
475
484
  text=USER_CONVERSATION_SESSION_START,
476
485
  output_channel=output_channel,
477
- sender_id=self.get_sender_id(call_parameters),
486
+ sender_id=sender_id,
478
487
  input_channel=self.name(),
479
488
  metadata=asdict(call_parameters),
480
489
  )
481
490
  await on_new_message(message)
491
+ await output_channel.send_turn_end_marker(sender_id)
482
492
 
483
493
  def map_input_message(
484
494
  self,
@@ -646,14 +656,16 @@ class VoiceInputChannel(InputChannel):
646
656
  call_state.rasa_processing_start_time = time.time()
647
657
 
648
658
  output_channel = self.create_output_channel(voice_websocket, tts_engine)
659
+ sender_id = self.get_sender_id(call_parameters)
649
660
  message = UserMessage(
650
661
  text=e.text,
651
662
  output_channel=output_channel,
652
- sender_id=self.get_sender_id(call_parameters),
663
+ sender_id=sender_id,
653
664
  input_channel=self.name(),
654
665
  metadata=asdict(call_parameters),
655
666
  )
656
667
  await on_new_message(message)
668
+ await output_channel.send_turn_end_marker(sender_id)
657
669
  elif isinstance(e, UserIsSpeaking):
658
670
  # Track when user starts speaking for ASR latency calculation
659
671
  if not call_state.is_user_speaking:
@@ -62,7 +62,9 @@ structlogger = structlog.get_logger()
62
62
  class LLMBasedCommandGenerator(
63
63
  LLMHealthCheckMixin, GraphComponent, CommandGenerator, ABC
64
64
  ):
65
- """An abstract class defining interface and common functionality
65
+ """This class provides common functionality for all LLM-based command generators.
66
+
67
+ An abstract class defining interface and common functionality
66
68
  of an LLM-based command generators.
67
69
  """
68
70
 
@@ -174,8 +176,9 @@ class LLMBasedCommandGenerator(
174
176
  def train(
175
177
  self, training_data: TrainingData, flows: FlowsList, domain: Domain
176
178
  ) -> Resource:
177
- """Train the llm based command generator. Stores all flows into a vector
178
- store.
179
+ """Trains the LLM-based command generator and prepares flow retrieval data.
180
+
181
+ Stores all flows into a vector store.
179
182
  """
180
183
  self.perform_llm_health_check(
181
184
  self.config.get(LLM_CONFIG_KEY),
@@ -168,6 +168,20 @@ class CompactLLMCommandGenerator(SingleStepBasedLLMCommandGenerator):
168
168
  if prompt_template is not None:
169
169
  return prompt_template
170
170
 
171
+ # Try to load the template from the given path or fallback to the default for
172
+ # the component.
173
+ custom_prompt_template_path = config.get(PROMPT_TEMPLATE_CONFIG_KEY)
174
+ if custom_prompt_template_path is not None:
175
+ custom_prompt_template = get_prompt_template(
176
+ custom_prompt_template_path,
177
+ None, # Default will be based on the model
178
+ log_source_component=log_source_component,
179
+ log_source_method=log_context,
180
+ )
181
+ if custom_prompt_template is not None:
182
+ return custom_prompt_template
183
+
184
+ # Fallback to the default prompt template based on the model.
171
185
  default_command_prompt_template = get_default_prompt_template_based_on_model(
172
186
  llm_config=config.get(LLM_CONFIG_KEY, {}) or {},
173
187
  model_prompt_mapping=cls.get_model_prompt_mapper(),
@@ -177,10 +191,4 @@ class CompactLLMCommandGenerator(SingleStepBasedLLMCommandGenerator):
177
191
  log_source_method=log_context,
178
192
  )
179
193
 
180
- # Return the prompt template either from the config or the default prompt.
181
- return get_prompt_template(
182
- config.get(PROMPT_TEMPLATE_CONFIG_KEY),
183
- default_command_prompt_template,
184
- log_source_component=log_source_component,
185
- log_source_method=log_context,
186
- )
194
+ return default_command_prompt_template
@@ -165,7 +165,20 @@ class SearchReadyLLMCommandGenerator(SingleStepBasedLLMCommandGenerator):
165
165
  if prompt_template is not None:
166
166
  return prompt_template
167
167
 
168
- # Get the default prompt template based on the model name.
168
+ # Try to load the template from the given path or fallback to the default for
169
+ # the component.
170
+ custom_prompt_template_path = config.get(PROMPT_TEMPLATE_CONFIG_KEY)
171
+ if custom_prompt_template_path is not None:
172
+ custom_prompt_template = get_prompt_template(
173
+ custom_prompt_template_path,
174
+ None, # Default will be based on the model
175
+ log_source_component=log_source_component,
176
+ log_source_method=log_context,
177
+ )
178
+ if custom_prompt_template is not None:
179
+ return custom_prompt_template
180
+
181
+ # Fallback to the default prompt template based on the model.
169
182
  default_command_prompt_template = get_default_prompt_template_based_on_model(
170
183
  llm_config=config.get(LLM_CONFIG_KEY, {}) or {},
171
184
  model_prompt_mapping=cls.get_model_prompt_mapper(),
@@ -175,10 +188,4 @@ class SearchReadyLLMCommandGenerator(SingleStepBasedLLMCommandGenerator):
175
188
  log_source_method=log_context,
176
189
  )
177
190
 
178
- # Return the prompt template either from the config or the default prompt.
179
- return get_prompt_template(
180
- config.get(PROMPT_TEMPLATE_CONFIG_KEY),
181
- default_command_prompt_template,
182
- log_source_component=log_source_component,
183
- log_source_method=log_context,
184
- )
191
+ return default_command_prompt_template
@@ -1,3 +1,4 @@
1
+ import os
1
2
  from typing import Dict, List, Optional, Set, Type
2
3
 
3
4
  import structlog
@@ -70,6 +71,8 @@ from rasa.shared.nlu.constants import COMMANDS
70
71
 
71
72
  structlogger = structlog.get_logger()
72
73
 
74
+ CLARIFY_ON_MULTIPLE_START_FLOWS_ENV_VAR_NAME = "CLARIFY_ON_MULTIPLE_START_FLOWS"
75
+
73
76
 
74
77
  def contains_command(commands: List[Command], typ: Type[Command]) -> bool:
75
78
  """Check if a list of commands contains a command of a given type.
@@ -499,10 +502,24 @@ def clean_up_commands(
499
502
  else:
500
503
  clean_commands.append(command)
501
504
 
505
+ clean_commands = _process_multiple_start_flow_commands(clean_commands, tracker)
506
+
507
+ # ensure that there is only one command of a certain command type
508
+ clean_commands = ensure_max_number_of_command_type(
509
+ clean_commands, CannotHandleCommand, 1
510
+ )
511
+ clean_commands = ensure_max_number_of_command_type(
512
+ clean_commands, RepeatBotMessagesCommand, 1
513
+ )
514
+ clean_commands = ensure_max_number_of_command_type(
515
+ clean_commands, ChitChatAnswerCommand, 1
516
+ )
517
+
502
518
  # Replace CannotHandleCommands with ContinueAgentCommand when an agent is active
503
519
  # to keep the agent running, but preserve chitchat
504
520
  clean_commands = _replace_cannot_handle_with_continue_agent(clean_commands, tracker)
505
521
 
522
+ # filter out cannot handle commands if there are other commands present
506
523
  # when coexistence is enabled, by default there will be a SetSlotCommand
507
524
  # for the ROUTE_TO_CALM_SLOT slot.
508
525
  if tracker.has_coexistence_routing_slot and len(clean_commands) > 2:
@@ -510,12 +527,6 @@ def clean_up_commands(
510
527
  elif not tracker.has_coexistence_routing_slot and len(clean_commands) > 1:
511
528
  clean_commands = filter_cannot_handle_command(clean_commands)
512
529
 
513
- clean_commands = ensure_max_number_of_command_type(
514
- clean_commands, RepeatBotMessagesCommand, 1
515
- )
516
- clean_commands = ensure_max_number_of_command_type(
517
- clean_commands, ContinueAgentCommand, 1
518
- )
519
530
  structlogger.debug(
520
531
  "command_processor.clean_up_commands.final_commands",
521
532
  command=clean_commands,
@@ -526,6 +537,37 @@ def clean_up_commands(
526
537
  return clean_commands
527
538
 
528
539
 
540
+ def _process_multiple_start_flow_commands(
541
+ commands: List[Command],
542
+ tracker: DialogueStateTracker,
543
+ ) -> List[Command]:
544
+ """Process multiple start flow commands.
545
+
546
+ If there are multiple start flow commands, no active flows and the
547
+ CLARIFY_ON_MULTIPLE_START_FLOWS env var is enabled, we replace the
548
+ start flow commands with a clarify command.
549
+ """
550
+ start_flow_candidates = filter_start_flow_commands(commands)
551
+ clarify_enabled = (
552
+ os.getenv("CLARIFY_ON_MULTIPLE_START_FLOWS", "false").lower() == "true"
553
+ )
554
+
555
+ if clarify_enabled and len(start_flow_candidates) > 1 and tracker.stack.is_empty():
556
+ # replace the start flow commands with a clarify command
557
+ commands = [
558
+ command for command in commands if not isinstance(command, StartFlowCommand)
559
+ ]
560
+ # avoid adding duplicate clarify commands
561
+ if not any(isinstance(c, ClarifyCommand) for c in commands):
562
+ structlogger.debug(
563
+ "command_processor.clean_up_commands.trigger_clarify_for_multiple_start_flows",
564
+ candidate_flows=start_flow_candidates,
565
+ )
566
+ commands.append(ClarifyCommand(options=start_flow_candidates))
567
+
568
+ return commands
569
+
570
+
529
571
  def _get_slots_eligible_for_correction(tracker: DialogueStateTracker) -> Set[str]:
530
572
  """Get all slots that are eligible for correction.
531
573
 
@@ -580,7 +622,7 @@ def clean_up_start_flow_command(
580
622
  # drop a start flow command if the starting flow is equal
581
623
  # to the currently active flow
582
624
  structlogger.debug(
583
- "command_processor.clean_up_commands." "skip_command_flow_already_active",
625
+ "command_processor.clean_up_commands.skip_command_flow_already_active",
584
626
  command=command,
585
627
  )
586
628
  return clean_commands