rasa-pro 3.11.5__py3-none-any.whl → 3.11.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (29) hide show
  1. rasa/core/brokers/kafka.py +59 -20
  2. rasa/core/channels/voice_ready/audiocodes.py +57 -24
  3. rasa/core/channels/voice_stream/asr/deepgram.py +57 -16
  4. rasa/core/channels/voice_stream/browser_audio.py +4 -1
  5. rasa/core/channels/voice_stream/tts/cartesia.py +11 -2
  6. rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +1 -1
  7. rasa/core/policies/intentless_policy.py +5 -59
  8. rasa/dialogue_understanding/generator/nlu_command_adapter.py +1 -1
  9. rasa/dialogue_understanding/processor/command_processor.py +20 -5
  10. rasa/dialogue_understanding/processor/command_processor_component.py +5 -2
  11. rasa/engine/validation.py +37 -2
  12. rasa/llm_fine_tuning/conversations.py +1 -1
  13. rasa/model_training.py +2 -1
  14. rasa/shared/constants.py +4 -0
  15. rasa/shared/core/constants.py +2 -0
  16. rasa/shared/core/domain.py +12 -3
  17. rasa/shared/core/events.py +67 -0
  18. rasa/shared/core/policies/__init__.py +0 -0
  19. rasa/shared/core/policies/utils.py +87 -0
  20. rasa/shared/providers/llm/default_litellm_llm_client.py +6 -1
  21. rasa/shared/utils/schemas/events.py +2 -2
  22. rasa/tracing/instrumentation/attribute_extractors.py +2 -0
  23. rasa/version.py +1 -1
  24. {rasa_pro-3.11.5.dist-info → rasa_pro-3.11.7.dist-info}/METADATA +10 -12
  25. {rasa_pro-3.11.5.dist-info → rasa_pro-3.11.7.dist-info}/RECORD +28 -27
  26. {rasa_pro-3.11.5.dist-info → rasa_pro-3.11.7.dist-info}/WHEEL +1 -1
  27. README.md +0 -41
  28. {rasa_pro-3.11.5.dist-info → rasa_pro-3.11.7.dist-info}/NOTICE +0 -0
  29. {rasa_pro-3.11.5.dist-info → rasa_pro-3.11.7.dist-info}/entry_points.txt +0 -0
@@ -12,6 +12,7 @@ import time
12
12
 
13
13
  from rasa.core.brokers.broker import EventBroker
14
14
  from rasa.core.exceptions import KafkaProducerInitializationError
15
+ from rasa.shared.core.events import ErrorHandled
15
16
  from rasa.shared.utils.io import DEFAULT_ENCODING
16
17
  from rasa.utils.endpoints import EndpointConfig
17
18
  import rasa.shared.utils.common
@@ -119,7 +120,7 @@ class KafkaEventBroker(EventBroker):
119
120
  retry_delay_in_seconds: float = 5,
120
121
  ) -> None:
121
122
  """Publishes events."""
122
- from confluent_kafka import KafkaException
123
+ from confluent_kafka import KafkaError, KafkaException
123
124
 
124
125
  if retries == 1:
125
126
  retries = 2
@@ -143,28 +144,66 @@ class KafkaEventBroker(EventBroker):
143
144
  )
144
145
  self.producer.poll(1)
145
146
  retries -= 1
146
- except Exception as e:
147
- logger.error(
148
- f"Could not publish message to kafka url '{self.url}'. "
149
- f"Failed with error: {e}"
150
- )
151
- try:
152
- self._check_kafka_connection()
153
- except KafkaException:
154
- logger.debug("Connection to kafka lost, reconnecting...")
155
- self.producer = self._create_producer()
156
- try:
157
- self._check_kafka_connection()
158
- logger.debug("Reconnection to kafka successful")
159
- self._publish(event)
160
- return
161
- except KafkaException:
162
- pass
163
- retries -= 1
164
- time.sleep(retry_delay_in_seconds)
147
+ except Exception as exc:
148
+ if (
149
+ isinstance(exc, KafkaException)
150
+ and exc.args[0].code() == KafkaError.MSG_SIZE_TOO_LARGE
151
+ ):
152
+ logger.warning(
153
+ "Message size is too large for the Kafka broker. "
154
+ "Please check the message.max.bytes configuration. "
155
+ "Sending error event."
156
+ )
157
+
158
+ original_event_type = event.get("event", "")
159
+ sender_id = event.get("sender_id", "")
160
+ event = ErrorHandled(
161
+ error_code=KafkaError.MSG_SIZE_TOO_LARGE,
162
+ metadata={
163
+ "error_msg": f"Skipping message for event type "
164
+ f"'{original_event_type}' because "
165
+ f"of Kafka message size limit: {exc.args[0].str()}.",
166
+ "error_source": "KafkaEventBroker",
167
+ },
168
+ ).as_dict()
169
+ event.update({"sender_id": sender_id})
170
+ else:
171
+ logger.error(
172
+ f"Could not publish message to kafka url '{self.url}'. "
173
+ f"Failed with error: {exc}"
174
+ )
175
+ self._retry_publish(event, retries, retry_delay_in_seconds)
165
176
 
166
177
  logger.error("Failed to publish Kafka event.")
167
178
 
179
+ def _retry_publish(
180
+ self,
181
+ event: Dict[Text, Any],
182
+ retries: int,
183
+ retry_delay_in_seconds: float,
184
+ ) -> None:
185
+ """Retries publishing if the producer is not connected.
186
+
187
+ Args:
188
+ event: The event to publish.
189
+ """
190
+ from confluent_kafka import KafkaException
191
+
192
+ try:
193
+ self._check_kafka_connection()
194
+ except KafkaException:
195
+ logger.debug("Connection to kafka lost, reconnecting...")
196
+ self.producer = self._create_producer()
197
+ try:
198
+ self._check_kafka_connection()
199
+ logger.debug("Reconnection to kafka successful")
200
+ self._publish(event)
201
+ return
202
+ except KafkaException:
203
+ pass
204
+ retries -= 1
205
+ time.sleep(retry_delay_in_seconds)
206
+
168
207
  def _check_kafka_connection(self) -> None:
169
208
  """Verifies connection with Kafka.
170
209
 
@@ -1,28 +1,29 @@
1
1
  import asyncio
2
2
  import copy
3
- from datetime import datetime, timezone, timedelta
3
+ import hmac
4
4
  import json
5
5
  import uuid
6
6
  from collections import defaultdict
7
7
  from dataclasses import asdict
8
+ from datetime import datetime, timedelta, timezone
8
9
  from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Text, Union
9
10
 
10
11
  import structlog
11
12
  from jsonschema import ValidationError, validate
13
+ from sanic import Blueprint, response
14
+ from sanic.exceptions import NotFound, SanicException, ServerError
15
+ from sanic.request import Request
16
+ from sanic.response import HTTPResponse
17
+
12
18
  from rasa.core import jobs
13
19
  from rasa.core.channels.channel import InputChannel, OutputChannel, UserMessage
14
20
  from rasa.core.channels.voice_ready.utils import (
15
- validate_voice_license_scope,
16
21
  CallParameters,
22
+ validate_voice_license_scope,
17
23
  )
18
24
  from rasa.shared.constants import INTENT_MESSAGE_PREFIX
19
25
  from rasa.shared.core.constants import USER_INTENT_SESSION_START
20
26
  from rasa.shared.exceptions import RasaException
21
- from sanic import Blueprint, response
22
- from sanic.exceptions import NotFound, SanicException, ServerError
23
- from sanic.request import Request
24
- from sanic.response import HTTPResponse
25
-
26
27
  from rasa.utils.io import remove_emojis
27
28
 
28
29
  structlogger = structlog.get_logger()
@@ -114,11 +115,21 @@ class Conversation:
114
115
  async def handle_activities(
115
116
  self,
116
117
  message: Dict[Text, Any],
118
+ input_channel_name: str,
117
119
  output_channel: OutputChannel,
118
120
  on_new_message: Callable[[UserMessage], Awaitable[Any]],
119
121
  ) -> None:
120
122
  """Handle activities sent by Audiocodes."""
121
123
  structlogger.debug("audiocodes.handle.activities")
124
+ if input_channel_name == "":
125
+ structlogger.warning(
126
+ "audiocodes.handle.activities.empty_input_channel_name",
127
+ event_info=(
128
+ f"Audiocodes input channel name is empty "
129
+ f"for conversation {self.conversation_id}"
130
+ ),
131
+ )
132
+
122
133
  for activity in message["activities"]:
123
134
  text = None
124
135
  if activity[ACTIVITY_ID_KEY] in self.activity_ids:
@@ -142,6 +153,7 @@ class Conversation:
142
153
  metadata = self.get_metadata(activity)
143
154
  user_msg = UserMessage(
144
155
  text=text,
156
+ input_channel=input_channel_name,
145
157
  output_channel=output_channel,
146
158
  sender_id=self.conversation_id,
147
159
  metadata=metadata,
@@ -246,6 +258,9 @@ class AudiocodesInput(InputChannel):
246
258
  def _check_token(self, token: Optional[Text]) -> None:
247
259
  if not token:
248
260
  raise HttpUnauthorized("Authentication token required.")
261
+ if not hmac.compare_digest(str(token), str(self.token)):
262
+ structlogger.error("audiocodes.invalid_token", invalid_token=token)
263
+ raise HttpUnauthorized("Invalid authentication token.")
249
264
 
250
265
  def _get_conversation(
251
266
  self, token: Optional[Text], conversation_id: Text
@@ -388,7 +403,12 @@ class AudiocodesInput(InputChannel):
388
403
  # start a background task to handle activities
389
404
  self._create_task(
390
405
  conversation_id,
391
- conversation.handle_activities(request.json, ac_output, on_new_message),
406
+ conversation.handle_activities(
407
+ request.json,
408
+ input_channel_name=self.name(),
409
+ output_channel=ac_output,
410
+ on_new_message=on_new_message,
411
+ ),
392
412
  )
393
413
  return response.json(response_json)
394
414
 
@@ -401,23 +421,9 @@ class AudiocodesInput(InputChannel):
401
421
  Example of payload:
402
422
  {"conversation": <conversation_id>, "reason": Optional[Text]}.
403
423
  """
404
- self._get_conversation(request.token, conversation_id)
405
- reason = {"reason": request.json.get("reason")}
406
- await on_new_message(
407
- UserMessage(
408
- text=f"{INTENT_MESSAGE_PREFIX}session_end",
409
- output_channel=None,
410
- sender_id=conversation_id,
411
- metadata=reason,
412
- )
424
+ return await self._handle_disconnect(
425
+ request, conversation_id, on_new_message
413
426
  )
414
- del self.conversations[conversation_id]
415
- structlogger.debug(
416
- "audiocodes.disconnect",
417
- conversation=conversation_id,
418
- request=request.json,
419
- )
420
- return response.json({})
421
427
 
422
428
  @ac_webhook.route("/conversation/<conversation_id>/keepalive", methods=["POST"])
423
429
  async def keepalive(request: Request, conversation_id: Text) -> HTTPResponse:
@@ -432,6 +438,32 @@ class AudiocodesInput(InputChannel):
432
438
 
433
439
  return ac_webhook
434
440
 
441
+ async def _handle_disconnect(
442
+ self,
443
+ request: Request,
444
+ conversation_id: Text,
445
+ on_new_message: Callable[[UserMessage], Awaitable[Any]],
446
+ ) -> HTTPResponse:
447
+ """Triggered when the call is disconnected."""
448
+ self._get_conversation(request.token, conversation_id)
449
+ reason = {"reason": request.json.get("reason")}
450
+ await on_new_message(
451
+ UserMessage(
452
+ text=f"{INTENT_MESSAGE_PREFIX}session_end",
453
+ input_channel=self.name(),
454
+ output_channel=None,
455
+ sender_id=conversation_id,
456
+ metadata=reason,
457
+ )
458
+ )
459
+ del self.conversations[conversation_id]
460
+ structlogger.debug(
461
+ "audiocodes.disconnect",
462
+ conversation=conversation_id,
463
+ request=request.json,
464
+ )
465
+ return response.json({})
466
+
435
467
 
436
468
  class AudiocodesOutput(OutputChannel):
437
469
  @classmethod
@@ -439,6 +471,7 @@ class AudiocodesOutput(OutputChannel):
439
471
  return CHANNEL_NAME
440
472
 
441
473
  def __init__(self) -> None:
474
+ super().__init__()
442
475
  self.messages: List[Dict] = []
443
476
 
444
477
  async def add_message(self, message: Dict) -> None:
@@ -1,7 +1,8 @@
1
- from dataclasses import dataclass
2
- from typing import Any, Dict, Optional
3
1
  import json
4
2
  import os
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, Optional
5
+ from urllib.parse import urlencode
5
6
 
6
7
  import websockets
7
8
  from websockets.legacy.client import WebSocketClientProtocol
@@ -19,11 +20,14 @@ from rasa.shared.constants import DEEPGRAM_API_KEY_ENV_VAR
19
20
  @dataclass
20
21
  class DeepgramASRConfig(ASREngineConfig):
21
22
  endpoint: Optional[str] = None
22
- # number of miliseconds of silence to determine end of speech
23
+ # number of milliseconds of silence to determine end of speech
23
24
  endpointing: Optional[int] = None
24
25
  language: Optional[str] = None
25
26
  model: Optional[str] = None
26
27
  smart_format: Optional[bool] = None
28
+ # number of milliseconds of no new transcript to determine end of speech
29
+ # should be at least 1000 according to docs
30
+ utterance_end_ms: Optional[int] = None
27
31
 
28
32
 
29
33
  class DeepgramASR(ASREngine[DeepgramASRConfig]):
@@ -37,22 +41,35 @@ class DeepgramASR(ASREngine[DeepgramASRConfig]):
37
41
  """Connect to the ASR system."""
38
42
  deepgram_api_key = os.environ[DEEPGRAM_API_KEY_ENV_VAR]
39
43
  extra_headers = {"Authorization": f"Token {deepgram_api_key}"}
40
- api_url = self._get_api_url()
41
- query_params = self._get_query_params()
42
44
  return await websockets.connect( # type: ignore
43
- api_url + query_params,
45
+ self._get_api_url_with_query_params(),
44
46
  extra_headers=extra_headers,
45
47
  )
46
48
 
49
+ def _get_api_url_with_query_params(self) -> str:
50
+ """Combine api url and query params."""
51
+ return self._get_api_url() + self._get_query_params()
52
+
47
53
  def _get_api_url(self) -> str:
54
+ """Get the api url with the configured endpoint."""
48
55
  return f"wss://{self.config.endpoint}/v1/listen?"
49
56
 
50
57
  def _get_query_params(self) -> str:
51
- return (
52
- f"encoding=mulaw&sample_rate={HERTZ}&endpointing={self.config.endpointing}"
53
- f"&vad_events=true&language={self.config.language}&interim_results=true"
54
- f"&model={self.config.model}&smart_format={str(self.config.smart_format).lower()}"
55
- )
58
+ """Get the configured query parameters for the api."""
59
+ query_params = {
60
+ "encoding": "mulaw",
61
+ "sample_rate": HERTZ,
62
+ "endpointing": self.config.endpointing,
63
+ "vad_events": "true",
64
+ "language": self.config.language,
65
+ "interim_results": "true",
66
+ "model": self.config.model,
67
+ "smart_format": str(self.config.smart_format).lower(),
68
+ }
69
+ if self.config.utterance_end_ms and self.config.utterance_end_ms > 0:
70
+ query_params["utterance_end_ms"] = self.config.utterance_end_ms
71
+
72
+ return urlencode(query_params)
56
73
 
57
74
  async def signal_audio_done(self) -> None:
58
75
  """Signal to the ASR Api that you are done sending data."""
@@ -67,24 +84,48 @@ class DeepgramASR(ASREngine[DeepgramASRConfig]):
67
84
  def engine_event_to_asr_event(self, e: Any) -> Optional[ASREvent]:
68
85
  """Translate an engine event to a common ASREvent."""
69
86
  data = json.loads(e)
70
- if "is_final" in data:
71
- transcript = data["channel"]["alternatives"][0]["transcript"]
87
+ data_type = data["type"]
88
+ if data_type == "Results":
89
+ transcript_data = data["channel"]["alternatives"][0]
90
+ transcript = transcript_data["transcript"]
72
91
  if data["is_final"]:
73
92
  if data.get("speech_final"):
74
- full_transcript = self.accumulated_transcript + transcript
93
+ full_transcript = self.concatenate_transcripts(
94
+ self.accumulated_transcript, transcript
95
+ )
75
96
  self.accumulated_transcript = ""
76
97
  if full_transcript:
77
98
  return NewTranscript(full_transcript)
78
99
  else:
79
- self.accumulated_transcript += transcript
100
+ self.accumulated_transcript = self.concatenate_transcripts(
101
+ self.accumulated_transcript, transcript
102
+ )
80
103
  elif transcript:
81
104
  return UserIsSpeaking()
105
+ # event that comes after utterance_end_ms of no new transcript
106
+ elif data_type == "UtteranceEnd":
107
+ if self.accumulated_transcript:
108
+ transcript = self.accumulated_transcript
109
+ self.accumulated_transcript = ""
110
+ return NewTranscript(transcript)
82
111
  return None
83
112
 
84
113
  @staticmethod
85
114
  def get_default_config() -> DeepgramASRConfig:
86
- return DeepgramASRConfig("api.deepgram.com", 400, "en", "nova-2-general", True)
115
+ return DeepgramASRConfig(
116
+ endpoint="api.deepgram.com",
117
+ endpointing=400,
118
+ language="en",
119
+ model="nova-2-general",
120
+ smart_format=True,
121
+ utterance_end_ms=1000,
122
+ )
87
123
 
88
124
  @classmethod
89
125
  def from_config_dict(cls, config: Dict) -> "DeepgramASR":
90
126
  return DeepgramASR(DeepgramASRConfig.from_dict(config))
127
+
128
+ @staticmethod
129
+ def concatenate_transcripts(t1: str, t2: str) -> str:
130
+ """Concatenate two transcripts making sure there is a space between them."""
131
+ return (t1.strip() + " " + t2.strip()).strip()
@@ -102,6 +102,9 @@ class BrowserAudioInputChannel(VoiceInputChannel):
102
102
 
103
103
  @blueprint.websocket("/websocket") # type: ignore
104
104
  async def handle_message(request: Request, ws: Websocket) -> None:
105
- await self.run_audio_streaming(on_new_message, ws)
105
+ try:
106
+ await self.run_audio_streaming(on_new_message, ws)
107
+ except Exception as e:
108
+ logger.error("browser_audio.handle_message.error", error=e)
106
109
 
107
110
  return blueprint
@@ -87,13 +87,22 @@ class CartesiaTTS(TTSEngine[CartesiaTTSConfig]):
87
87
  async for data in response.content.iter_chunked(1024):
88
88
  yield self.engine_bytes_to_rasa_audio_bytes(data)
89
89
  return
90
+ elif response.status == 401:
91
+ structlogger.error(
92
+ "cartesia.synthesize.rest.unauthorized",
93
+ status_code=response.status,
94
+ )
95
+ raise TTSError(
96
+ "Unauthorized. Please make sure you have the correct API key."
97
+ )
90
98
  else:
99
+ response_text = await response.text()
91
100
  structlogger.error(
92
101
  "cartesia.synthesize.rest.failed",
93
102
  status_code=response.status,
94
- msg=response.text(),
103
+ msg=response_text,
95
104
  )
96
- raise TTSError(f"TTS failed: {response.text()}")
105
+ raise TTSError(f"TTS failed: {response_text}")
97
106
  except ClientConnectorError as e:
98
107
  raise TTSError(e)
99
108
  except TimeoutError as e:
@@ -4,7 +4,7 @@ If the answer is not known or cannot be determined from the provided documents o
4
4
  Use the following documents to answer the question:
5
5
  {% for doc in docs %}
6
6
  {{ loop.cycle("*")}}. {{ doc.metadata }}
7
- {{ doc.page_content }}
7
+ {{ doc.text }}
8
8
  {% endfor %}
9
9
 
10
10
  {% if citation_enabled %}
@@ -1,7 +1,7 @@
1
1
  import importlib.resources
2
2
  import math
3
3
  from dataclasses import dataclass, field
4
- from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Text, Tuple
4
+ from typing import Any, Dict, List, Optional, TYPE_CHECKING, Text, Tuple
5
5
 
6
6
  import structlog
7
7
  import tiktoken
@@ -18,7 +18,6 @@ from rasa.core.constants import (
18
18
  UTTER_SOURCE_METADATA_KEY,
19
19
  )
20
20
  from rasa.core.policies.policy import Policy, PolicyPrediction, SupportedData
21
- from rasa.dialogue_understanding.patterns.chitchat import FLOW_PATTERN_CHITCHAT
22
21
  from rasa.dialogue_understanding.stack.frames import (
23
22
  ChitChatStackFrame,
24
23
  DialogueStackFrame,
@@ -30,7 +29,6 @@ from rasa.engine.storage.storage import ModelStorage
30
29
  from rasa.graph_components.providers.forms_provider import Forms
31
30
  from rasa.graph_components.providers.responses_provider import Responses
32
31
  from rasa.shared.constants import (
33
- REQUIRED_SLOTS_KEY,
34
32
  EMBEDDINGS_CONFIG_KEY,
35
33
  LLM_CONFIG_KEY,
36
34
  MODEL_CONFIG_KEY,
@@ -42,7 +40,6 @@ from rasa.shared.constants import (
42
40
  MODEL_GROUP_ID_CONFIG_KEY,
43
41
  )
44
42
  from rasa.shared.core.constants import ACTION_LISTEN_NAME
45
- from rasa.shared.core.constants import ACTION_TRIGGER_CHITCHAT
46
43
  from rasa.shared.core.domain import KEY_RESPONSES_TEXT, Domain
47
44
  from rasa.shared.core.events import (
48
45
  ActionExecuted,
@@ -52,6 +49,7 @@ from rasa.shared.core.events import (
52
49
  )
53
50
  from rasa.shared.core.flows import FlowsList
54
51
  from rasa.shared.core.generator import TrackerWithCachedStates
52
+ from rasa.shared.core.policies.utils import filter_responses_for_intentless_policy
55
53
  from rasa.shared.core.trackers import DialogueStateTracker
56
54
  from rasa.shared.exceptions import FileIOException, RasaCoreException
57
55
  from rasa.shared.nlu.constants import PREDICTED_CONFIDENCE_KEY
@@ -147,59 +145,6 @@ class Conversation:
147
145
  interactions: List[Interaction] = field(default_factory=list)
148
146
 
149
147
 
150
- def collect_form_responses(forms: Forms) -> Set[Text]:
151
- """Collect responses that belong the requested slots in forms.
152
-
153
- Args:
154
- forms: the forms from the domain
155
- Returns:
156
- all utterances used in forms
157
- """
158
- form_responses = set()
159
- for _, form_info in forms.data.items():
160
- for required_slot in form_info.get(REQUIRED_SLOTS_KEY, []):
161
- form_responses.add(f"utter_ask_{required_slot}")
162
- return form_responses
163
-
164
-
165
- def filter_responses(responses: Responses, forms: Forms, flows: FlowsList) -> Responses:
166
- """Filters out responses that are unwanted for the intentless policy.
167
-
168
- This includes utterances used in flows and forms.
169
-
170
- Args:
171
- responses: the responses from the domain
172
- forms: the forms from the domain
173
- flows: all flows
174
- Returns:
175
- The remaining, relevant responses for the intentless policy.
176
- """
177
- form_responses = collect_form_responses(forms)
178
- flow_responses = flows.utterances
179
- combined_responses = form_responses | flow_responses
180
- filtered_responses = {
181
- name: variants
182
- for name, variants in responses.data.items()
183
- if name not in combined_responses
184
- }
185
-
186
- pattern_chitchat = flows.flow_by_id(FLOW_PATTERN_CHITCHAT)
187
-
188
- # The following condition is highly unlikely, but mypy requires the case
189
- # of pattern_chitchat == None to be addressed
190
- if not pattern_chitchat:
191
- return Responses(data=filtered_responses)
192
-
193
- # if action_trigger_chitchat, filter out "utter_free_chitchat_response"
194
- has_action_trigger_chitchat = pattern_chitchat.has_action_step(
195
- ACTION_TRIGGER_CHITCHAT
196
- )
197
- if has_action_trigger_chitchat:
198
- filtered_responses.pop("utter_free_chitchat_response", None)
199
-
200
- return Responses(data=filtered_responses)
201
-
202
-
203
148
  def action_from_response(
204
149
  text: Optional[str], responses: Dict[Text, List[Dict[Text, Any]]]
205
150
  ) -> Optional[str]:
@@ -513,7 +458,9 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
513
458
  # Perform health checks of both LLM and embeddings client configs
514
459
  self._perform_health_checks(self.config, "intentless_policy.train")
515
460
 
516
- responses = filter_responses(responses, forms, flows or FlowsList([]))
461
+ responses = filter_responses_for_intentless_policy(
462
+ responses, forms, flows or FlowsList([])
463
+ )
517
464
  telemetry.track_intentless_policy_train()
518
465
  response_texts = [r for r in extract_ai_response_examples(responses.data)]
519
466
 
@@ -948,7 +895,6 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
948
895
  **kwargs: Any,
949
896
  ) -> "IntentlessPolicy":
950
897
  """Loads a trained policy (see parent class for full docstring)."""
951
-
952
898
  # Perform health checks of both LLM and embeddings client configs
953
899
  cls._perform_health_checks(config, "intentless_policy.load")
954
900
 
@@ -139,7 +139,7 @@ class NLUCommandAdapter(GraphComponent, CommandGenerator):
139
139
 
140
140
  if commands:
141
141
  commands = clean_up_commands(
142
- commands, tracker, flows, self._execution_context
142
+ commands, tracker, flows, self._execution_context, domain
143
143
  )
144
144
  log_llm(
145
145
  logger=structlogger,
@@ -41,9 +41,11 @@ from rasa.shared.constants import (
41
41
  )
42
42
  from rasa.shared.core.constants import ACTION_TRIGGER_CHITCHAT, SlotMappingType
43
43
  from rasa.shared.core.constants import FLOW_HASHES_SLOT
44
+ from rasa.shared.core.domain import Domain
44
45
  from rasa.shared.core.events import Event, SlotSet
45
46
  from rasa.shared.core.flows import FlowsList
46
47
  from rasa.shared.core.flows.steps.collect import CollectInformationFlowStep
48
+ from rasa.shared.core.policies.utils import contains_intentless_policy_responses
47
49
  from rasa.shared.core.slots import Slot
48
50
  from rasa.shared.core.trackers import DialogueStateTracker
49
51
  from rasa.shared.nlu.constants import COMMANDS
@@ -182,6 +184,7 @@ def execute_commands(
182
184
  all_flows: FlowsList,
183
185
  execution_context: ExecutionContext,
184
186
  story_graph: Optional[StoryGraph] = None,
187
+ domain: Optional[Domain] = None,
185
188
  ) -> List[Event]:
186
189
  """Executes a list of commands.
187
190
 
@@ -191,6 +194,7 @@ def execute_commands(
191
194
  all_flows: All flows.
192
195
  execution_context: Information about the single graph run.
193
196
  story_graph: StoryGraph object with stories available for training.
197
+ domain: The domain of the bot.
194
198
 
195
199
  Returns:
196
200
  A list of the events that were created.
@@ -199,7 +203,7 @@ def execute_commands(
199
203
  original_tracker = tracker.copy()
200
204
 
201
205
  commands = clean_up_commands(
202
- commands, tracker, all_flows, execution_context, story_graph
206
+ commands, tracker, all_flows, execution_context, story_graph, domain
203
207
  )
204
208
 
205
209
  updated_flows = find_updated_flows(tracker, all_flows)
@@ -333,6 +337,7 @@ def clean_up_commands(
333
337
  all_flows: FlowsList,
334
338
  execution_context: ExecutionContext,
335
339
  story_graph: Optional[StoryGraph] = None,
340
+ domain: Optional[Domain] = None,
336
341
  ) -> List[Command]:
337
342
  """Clean up a list of commands.
338
343
 
@@ -348,10 +353,13 @@ def clean_up_commands(
348
353
  all_flows: All flows.
349
354
  execution_context: Information about a single graph run.
350
355
  story_graph: StoryGraph object with stories available for training.
356
+ domain: The domain of the bot.
351
357
 
352
358
  Returns:
353
359
  The cleaned up commands.
354
360
  """
361
+ domain = domain if domain else Domain.empty()
362
+
355
363
  slots_so_far, active_flow = filled_slots_for_active_flow(tracker, all_flows)
356
364
 
357
365
  clean_commands: List[Command] = []
@@ -394,7 +402,12 @@ def clean_up_commands(
394
402
  # handle chitchat command differently from other free-form answer commands
395
403
  elif isinstance(command, ChitChatAnswerCommand):
396
404
  clean_commands = clean_up_chitchat_command(
397
- clean_commands, command, all_flows, execution_context, story_graph
405
+ clean_commands,
406
+ command,
407
+ all_flows,
408
+ execution_context,
409
+ domain,
410
+ story_graph,
398
411
  )
399
412
 
400
413
  elif isinstance(command, FreeFormAnswerCommand):
@@ -590,6 +603,7 @@ def clean_up_chitchat_command(
590
603
  command: ChitChatAnswerCommand,
591
604
  flows: FlowsList,
592
605
  execution_context: ExecutionContext,
606
+ domain: Domain,
593
607
  story_graph: Optional[StoryGraph] = None,
594
608
  ) -> List[Command]:
595
609
  """Clean up a chitchat answer command.
@@ -603,6 +617,8 @@ def clean_up_chitchat_command(
603
617
  flows: All flows.
604
618
  execution_context: Information about a single graph run.
605
619
  story_graph: StoryGraph object with stories available for training.
620
+ domain: The domain of the bot.
621
+
606
622
  Returns:
607
623
  The cleaned up commands.
608
624
  """
@@ -628,10 +644,9 @@ def clean_up_chitchat_command(
628
644
  )
629
645
  defines_intentless_policy = execution_context.has_node(IntentlessPolicy)
630
646
 
631
- has_e2e_stories = True if (story_graph and story_graph.has_e2e_stories()) else False
632
-
633
647
  if (has_action_trigger_chitchat and not defines_intentless_policy) or (
634
- defines_intentless_policy and not has_e2e_stories
648
+ defines_intentless_policy
649
+ and not contains_intentless_policy_responses(flows, domain, story_graph)
635
650
  ):
636
651
  resulting_commands.insert(
637
652
  0, CannotHandleCommand(RASA_PATTERN_CANNOT_HANDLE_CHITCHAT)