rasa-pro 3.13.0.dev1__py3-none-any.whl → 3.13.0.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of rasa-pro has been flagged as a potentially problematic release.

Files changed (58)
  1. rasa/core/actions/action.py +0 -6
  2. rasa/core/channels/voice_ready/audiocodes.py +52 -17
  3. rasa/core/channels/voice_stream/audiocodes.py +53 -9
  4. rasa/core/channels/voice_stream/genesys.py +146 -16
  5. rasa/core/information_retrieval/faiss.py +6 -1
  6. rasa/core/information_retrieval/information_retrieval.py +40 -2
  7. rasa/core/information_retrieval/milvus.py +7 -2
  8. rasa/core/information_retrieval/qdrant.py +7 -2
  9. rasa/core/policies/enterprise_search_policy.py +61 -301
  10. rasa/core/policies/flows/flow_executor.py +3 -38
  11. rasa/core/processor.py +27 -6
  12. rasa/core/utils.py +53 -0
  13. rasa/dialogue_understanding/commands/cancel_flow_command.py +4 -59
  14. rasa/dialogue_understanding/commands/start_flow_command.py +0 -41
  15. rasa/dialogue_understanding/generator/command_generator.py +67 -0
  16. rasa/dialogue_understanding/generator/command_parser.py +1 -1
  17. rasa/dialogue_understanding/generator/llm_based_command_generator.py +4 -13
  18. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_template.jinja2 +1 -1
  19. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +20 -1
  20. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +7 -0
  21. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +0 -61
  22. rasa/dialogue_understanding/processor/command_processor.py +7 -65
  23. rasa/dialogue_understanding/stack/utils.py +0 -38
  24. rasa/dialogue_understanding_test/io.py +13 -8
  25. rasa/document_retrieval/__init__.py +0 -0
  26. rasa/document_retrieval/constants.py +32 -0
  27. rasa/document_retrieval/document_post_processor.py +351 -0
  28. rasa/document_retrieval/document_post_processor_prompt_template.jinja2 +0 -0
  29. rasa/document_retrieval/document_retriever.py +333 -0
  30. rasa/document_retrieval/knowledge_base_connectors/__init__.py +0 -0
  31. rasa/document_retrieval/knowledge_base_connectors/api_connector.py +39 -0
  32. rasa/document_retrieval/knowledge_base_connectors/knowledge_base_connector.py +34 -0
  33. rasa/document_retrieval/knowledge_base_connectors/vector_store_connector.py +226 -0
  34. rasa/document_retrieval/query_rewriter.py +234 -0
  35. rasa/document_retrieval/query_rewriter_prompt_template.jinja2 +8 -0
  36. rasa/engine/recipes/default_components.py +2 -0
  37. rasa/shared/core/constants.py +0 -8
  38. rasa/shared/core/domain.py +12 -3
  39. rasa/shared/core/flows/flow.py +0 -17
  40. rasa/shared/core/flows/flows_yaml_schema.json +3 -38
  41. rasa/shared/core/flows/steps/collect.py +5 -18
  42. rasa/shared/core/flows/utils.py +1 -16
  43. rasa/shared/core/slot_mappings.py +11 -5
  44. rasa/shared/nlu/constants.py +0 -1
  45. rasa/shared/utils/common.py +11 -1
  46. rasa/shared/utils/llm.py +1 -1
  47. rasa/tracing/instrumentation/attribute_extractors.py +10 -7
  48. rasa/tracing/instrumentation/instrumentation.py +12 -12
  49. rasa/validator.py +1 -123
  50. rasa/version.py +1 -1
  51. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/METADATA +1 -1
  52. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/RECORD +55 -47
  53. rasa/core/actions/action_handle_digressions.py +0 -164
  54. rasa/dialogue_understanding/commands/handle_digressions_command.py +0 -144
  55. rasa/dialogue_understanding/patterns/handle_digressions.py +0 -81
  56. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/NOTICE +0 -0
  57. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/WHEEL +0 -0
  58. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/entry_points.txt +0 -0
rasa/document_retrieval/query_rewriter.py ADDED
@@ -0,0 +1,234 @@
+import importlib.resources
+from enum import Enum
+from typing import Any, Dict, Optional
+
+import structlog
+from jinja2 import Template
+
+import rasa.shared.utils.io
+from rasa.engine.storage.resource import Resource
+from rasa.engine.storage.storage import ModelStorage
+from rasa.shared.constants import (
+    LLM_CONFIG_KEY,
+    MODEL_CONFIG_KEY,
+    OPENAI_PROVIDER,
+    PROMPT_TEMPLATE_CONFIG_KEY,
+    PROVIDER_CONFIG_KEY,
+    TEXT,
+    TIMEOUT_CONFIG_KEY,
+)
+from rasa.shared.core.trackers import DialogueStateTracker
+from rasa.shared.exceptions import FileIOException, ProviderClientAPIException
+from rasa.shared.nlu.training_data.message import Message
+from rasa.shared.providers.llm.llm_client import LLMClient
+from rasa.shared.providers.llm.llm_response import LLMResponse
+from rasa.shared.utils.health_check.health_check import perform_llm_health_check
+from rasa.shared.utils.health_check.llm_health_check_mixin import (
+    LLMHealthCheckMixin,
+)
+from rasa.shared.utils.llm import (
+    AI,
+    DEFAULT_OPENAI_GENERATE_MODEL_NAME,
+    DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
+    USER,
+    get_prompt_template,
+    llm_factory,
+    resolve_model_client_config,
+    sanitize_message_for_prompt,
+    tracker_as_readable_transcript,
+)
+
+QUERY_REWRITER_PROMPT_FILE_NAME = "query_rewriter_prompt_template.jinja2"
+MAX_TURNS = "max_turns"
+DEFAULT_LLM_CONFIG = {
+    PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
+    MODEL_CONFIG_KEY: DEFAULT_OPENAI_GENERATE_MODEL_NAME,
+    "temperature": 0.3,
+    "max_tokens": DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
+    TIMEOUT_CONFIG_KEY: 5,
+}
+DEFAULT_QUERY_REWRITER_PROMPT_TEMPLATE = importlib.resources.read_text(
+    "rasa.document_retrieval",
+    "query_rewriter_prompt_template.jinja2",
+)
+
+TYPE_CONFIG_KEY = "type"
+
+structlogger = structlog.get_logger()
+
+
+class QueryRewritingType(Enum):
+    PLAIN = "PLAIN"
+    CONCATENATED_TURNS = "CONCATENATED_TURNS"
+    REPHRASE = "REPHRASE"
+    KEYWORD_EXTRACTION = "KEYWORD_EXTRACTION"
+
+    def __str__(self) -> str:
+        return self.value
+
+
+class QueryRewriter(LLMHealthCheckMixin):
+    @classmethod
+    def get_default_config(cls) -> Dict[str, Any]:
+        """The default config for the query rewriter."""
+        return {
+            TYPE_CONFIG_KEY: QueryRewritingType.PLAIN,
+            MAX_TURNS: 0,
+            LLM_CONFIG_KEY: DEFAULT_LLM_CONFIG,
+            PROMPT_TEMPLATE_CONFIG_KEY: DEFAULT_QUERY_REWRITER_PROMPT_TEMPLATE,
+        }
+
+    def __init__(
+        self,
+        config: Dict[str, Any],
+        model_storage: ModelStorage,
+        resource: Resource,
+        prompt_template: Optional[str] = None,
+    ):
+        self.config = {**self.get_default_config(), **config}
+        self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
+            self.config.get(LLM_CONFIG_KEY), QueryRewriter.__name__
+        )
+        self.prompt_template = prompt_template or get_prompt_template(
+            config.get(PROMPT_TEMPLATE_CONFIG_KEY),
+            DEFAULT_QUERY_REWRITER_PROMPT_TEMPLATE,
+        )
+
+        self._model_storage = model_storage
+        self._resource = resource
+
+    @classmethod
+    def load(
+        cls,
+        config: Dict[str, Any],
+        model_storage: ModelStorage,
+        resource: Resource,
+        **kwargs: Any,
+    ) -> "QueryRewriter":
+        """Load query rewriter."""
+        llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
+        perform_llm_health_check(
+            llm_config,
+            DEFAULT_LLM_CONFIG,
+            "query_rewriter.load",
+            QueryRewriter.__name__,
+        )
+
+        # load prompt template
+        prompt_template = None
+        try:
+            with model_storage.read_from(resource) as path:
+                prompt_template = rasa.shared.utils.io.read_file(
+                    path / QUERY_REWRITER_PROMPT_FILE_NAME
+                )
+        except (FileNotFoundError, FileIOException) as e:
+            structlogger.warning(
+                "query_rewriter.load_prompt_template.failed",
+                error=e,
+                resource=resource.name,
+            )
+
+        return QueryRewriter(config, model_storage, resource, prompt_template)
+
+    def persist(self) -> None:
+        with self._model_storage.write_to(self._resource) as path:
+            rasa.shared.utils.io.write_text_file(
+                self.prompt_template, path / QUERY_REWRITER_PROMPT_FILE_NAME
+            )
+
+    @staticmethod
+    def _concatenate_turns(
+        message: Message, tracker: DialogueStateTracker, max_turns: int
+    ) -> str:
+        transcript = tracker_as_readable_transcript(tracker, max_turns=max_turns)
+        transcript += "\nUSER: " + message.get(TEXT)
+        return transcript
+
+    @staticmethod
+    async def _invoke_llm(prompt: str, llm: LLMClient) -> Optional[LLMResponse]:
+        try:
+            return await llm.acompletion(prompt)
+        except Exception as e:
+            # unfortunately, langchain does not wrap LLM exceptions which means
+            # we have to catch all exceptions here
+            structlogger.error("query_rewriter.llm.error", error=e)
+            raise ProviderClientAPIException(
+                message="LLM call exception", original_exception=e
+            )
+
+    async def _rephrase_message(
+        self, message: Message, tracker: DialogueStateTracker, max_turns: int = 5
+    ) -> str:
+        llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
+
+        transcript = tracker_as_readable_transcript(
+            tracker, max_turns=max_turns, ai_prefix="ASSISTANT"
+        )
+
+        inputs = {
+            "conversation": transcript,
+            "user_message": message.get(TEXT),
+        }
+
+        prompt = Template(self.prompt_template).render(**inputs)
+        llm_response = await self._invoke_llm(prompt, llm)
+        llm_response = LLMResponse.ensure_llm_response(llm_response)
+
+        return llm_response.choices[0]
+
+    @staticmethod
+    def _keyword_extraction(
+        message: Message, tracker: DialogueStateTracker, max_turns: int = 5
+    ) -> str:
+        import spacy
+
+        nlp = spacy.load("en_core_web_md")
+
+        transcript = tracker_as_readable_transcript(tracker, max_turns=max_turns)
+        transcript = transcript.replace(USER, "")
+        transcript = transcript.replace(AI, "")
+
+        doc = nlp(transcript)
+
+        keywords = set()
+        for token in doc:
+            # Extract nouns and proper nouns
+            if token.pos_ in ["NOUN", "PROPN"]:
+                keywords.add(token.lemma_)
+
+        for ent in doc.ents:
+            # Add named entities as keywords
+            keywords.add(ent.text)
+
+        # Remove stop words and punctuation
+        keywords = {
+            word
+            for word in keywords
+            if word.lower() not in nlp.Defaults.stop_words and word.isalpha()
+        }
+
+        if keywords:
+            return message.get(TEXT) + " " + " ".join(keywords)
+        else:
+            return message.get(TEXT)
+
+    async def prepare_search_query(
+        self, message: Message, tracker: DialogueStateTracker
+    ) -> str:
+        query_rewriting_type = self.config[TYPE_CONFIG_KEY]
+        max_turns: int = self.config[MAX_TURNS]
+
+        query: str
+
+        if query_rewriting_type == QueryRewritingType.CONCATENATED_TURNS.value:
+            query = self._concatenate_turns(message, tracker, max_turns)
+        elif query_rewriting_type == QueryRewritingType.KEYWORD_EXTRACTION.value:
+            query = self._keyword_extraction(message, tracker, max_turns)
+        elif query_rewriting_type == QueryRewritingType.REPHRASE.value:
+            query = await self._rephrase_message(message, tracker, max_turns)
+        elif query_rewriting_type == QueryRewritingType.PLAIN.value:
+            query = message.get(TEXT)
+        else:
+            raise ValueError(f"Invalid query rewriting type: {query_rewriting_type}")
+
+        return sanitize_message_for_prompt(query)
rasa/document_retrieval/query_rewriter_prompt_template.jinja2 ADDED
@@ -0,0 +1,8 @@
+Conversation Context:
+{{conversation}}
+
+User's Latest Request: {{user_message}}
+
+Instruction: Rephrase the user's latest request into a concise search query for a knowledge base.
+
+Search Query:
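
For context, this is the template that `QueryRewriter._rephrase_message` renders with the `conversation` and `user_message` variables shown above. A minimal rendering sketch in plain Python (only jinja2 assumed; the transcript and message values are invented for illustration):

from jinja2 import Template

# The template shipped in the diff above, inlined here so the sketch is self-contained.
template_text = (
    "Conversation Context:\n"
    "{{conversation}}\n"
    "\n"
    "User's Latest Request: {{user_message}}\n"
    "\n"
    "Instruction: Rephrase the user's latest request into a concise search query for a knowledge base.\n"
    "\n"
    "Search Query:"
)

# Invented stand-ins for the output of tracker_as_readable_transcript(..., ai_prefix="ASSISTANT")
# and the latest user message.
inputs = {
    "conversation": "USER: I ordered a laptop last week\nASSISTANT: Happy to help. What do you need?",
    "user_message": "when will my order arrive?",
}

print(Template(template_text).render(**inputs))  # the prompt that would be sent to the LLM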
rasa/engine/recipes/default_components.py CHANGED
@@ -13,6 +13,7 @@ from rasa.dialogue_understanding.generator import (
     LLMCommandGenerator,
 )
 from rasa.dialogue_understanding.generator.nlu_command_adapter import NLUCommandAdapter
+from rasa.document_retrieval.document_retriever import DocumentRetriever
 from rasa.nlu.classifiers.diet_classifier import DIETClassifier
 from rasa.nlu.classifiers.fallback_classifier import FallbackClassifier
 from rasa.nlu.classifiers.keyword_intent_classifier import KeywordIntentClassifier
@@ -92,4 +93,5 @@ DEFAULT_COMPONENTS = [
     FlowPolicy,
     EnterpriseSearchPolicy,
     IntentlessPolicy,
+    DocumentRetriever,
 ]
rasa/shared/core/constants.py CHANGED
@@ -52,8 +52,6 @@ ACTION_TRIGGER_CHITCHAT = "action_trigger_chitchat"
 ACTION_RESET_ROUTING = "action_reset_routing"
 ACTION_HANGUP = "action_hangup"
 ACTION_REPEAT_BOT_MESSAGES = "action_repeat_bot_messages"
-ACTION_BLOCK_DIGRESSION = "action_block_digression"
-ACTION_CONTINUE_DIGRESSION = "action_continue_digression"
 
 ACTION_METADATA_EXECUTION_SUCCESS = "execution_success"
 ACTION_METADATA_EXECUTION_ERROR_MESSAGE = "execution_error_message"
@@ -84,8 +82,6 @@ DEFAULT_ACTION_NAMES = [
     ACTION_RESET_ROUTING,
     ACTION_HANGUP,
     ACTION_REPEAT_BOT_MESSAGES,
-    ACTION_BLOCK_DIGRESSION,
-    ACTION_CONTINUE_DIGRESSION,
 ]
 
 ACTION_SHOULD_SEND_DOMAIN = "send_domain"
@@ -205,8 +201,4 @@ CLASSIFIER_NAME_FALLBACK = "FallbackClassifier"
 
 POLICIES_THAT_EXTRACT_ENTITIES = {"TEDPolicy"}
 
-# digression constants
-KEY_ASK_CONFIRM_DIGRESSIONS = "ask_confirm_digressions"
-KEY_BLOCK_DIGRESSIONS = "block_digressions"
-
 ERROR_CODE_KEY = "error_code"
rasa/shared/core/domain.py CHANGED
@@ -1678,6 +1678,14 @@ class Domain:
         """Write domain to a file."""
         as_yaml = self.as_yaml()
         rasa.shared.utils.io.write_text_file(as_yaml, filename)
+        # run the check again on the written domain to catch any errors
+        # that may have been missed in the user defined domain files
+        structlogger.info(
+            "domain.persist.domain_written_to_file",
+            event_info="The entire domain content has been written to file.",
+            filename=filename,
+        )
+        Domain.is_domain_file(filename)
 
     def as_yaml(self) -> Text:
         """Dump the `Domain` object as a YAML string.
@@ -1972,17 +1980,18 @@ class Domain:
 
         try:
             content = read_yaml_file(filename, expand_env_vars=cls.expand_env_vars)
-        except (RasaException, YamlSyntaxException):
-            structlogger.warning(
+        except (RasaException, YamlSyntaxException) as error:
+            structlogger.error(
                 "domain.cannot_load_domain_file",
                 file=filename,
+                error=error,
                 event_info=(
                     f"The file {filename} could not be loaded as domain file. "
                     f"You can use https://yamlchecker.com/ to validate "
                     f"the YAML syntax of your file."
                 ),
            )
-            return False
+            raise RasaException(f"Domain could not be loaded: {error}")
 
         return any(key in content for key in ALL_DOMAIN_KEYS)
 
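Note the behavioral change here: `Domain.is_domain_file` used to log a warning and return False when a file could not be parsed; it now raises. A hedged sketch of what a caller would see (assumes rasa-pro is installed; the file name is invented and stands for an existing file with broken YAML):

from rasa.shared.core.domain import Domain
from rasa.shared.exceptions import RasaException

try:
    # "broken_domain.yml" stands for an existing domain file whose YAML cannot be parsed.
    Domain.is_domain_file("broken_domain.yml")
except RasaException as exc:
    # Previously this call logged a warning and returned False; now the caller must handle the exception.
    print(f"domain could not be loaded: {exc}")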
rasa/shared/core/flows/flow.py CHANGED
@@ -13,10 +13,6 @@ from pypred import Predicate
 import rasa.shared.utils.io
 from rasa.engine.language import Language
 from rasa.shared.constants import RASA_DEFAULT_FLOW_PATTERN_PREFIX
-from rasa.shared.core.constants import (
-    KEY_ASK_CONFIRM_DIGRESSIONS,
-    KEY_BLOCK_DIGRESSIONS,
-)
 from rasa.shared.core.flows.constants import (
     KEY_ALWAYS_INCLUDE_IN_PROMPT,
     KEY_DESCRIPTION,
@@ -52,7 +48,6 @@ from rasa.shared.core.flows.steps.constants import (
     START_STEP,
 )
 from rasa.shared.core.flows.steps.continuation import ContinueFlowStep
-from rasa.shared.core.flows.utils import extract_digression_prop
 from rasa.shared.core.slots import Slot
 
 structlogger = structlog.get_logger()
@@ -94,10 +89,6 @@ class Flow:
     """The path to the file where the flow is stored."""
     persisted_slots: List[str] = field(default_factory=list)
     """The list of slots that should be persisted after the flow ends."""
-    ask_confirm_digressions: List[str] = field(default_factory=list)
-    """The flow ids for which the assistant should ask for confirmation."""
-    block_digressions: List[str] = field(default_factory=list)
-    """The flow ids that the assistant should block from digressing to."""
     run_pattern_completed: bool = True
     """Whether the pattern_completed flow should be run after the flow ends."""
 
@@ -138,10 +129,6 @@ class Flow:
             # data. When the model is trained, take the provided file_path.
             file_path=data.get(KEY_FILE_PATH) if KEY_FILE_PATH in data else file_path,
             persisted_slots=data.get(KEY_PERSISTED_SLOTS, []),
-            ask_confirm_digressions=extract_digression_prop(
-                KEY_ASK_CONFIRM_DIGRESSIONS, data
-            ),
-            block_digressions=extract_digression_prop(KEY_BLOCK_DIGRESSIONS, data),
             run_pattern_completed=data.get(KEY_RUN_PATTERN_COMPLETED, True),
             translation=extract_translations(
                 translation_data=data.get(KEY_TRANSLATION, {})
@@ -220,10 +207,6 @@ class Flow:
             data[KEY_FILE_PATH] = self.file_path
         if self.persisted_slots:
             data[KEY_PERSISTED_SLOTS] = self.persisted_slots
-        if self.ask_confirm_digressions:
-            data[KEY_ASK_CONFIRM_DIGRESSIONS] = self.ask_confirm_digressions
-        if self.block_digressions:
-            data[KEY_BLOCK_DIGRESSIONS] = self.block_digressions
         if self.run_pattern_completed is not None:
             data["run_pattern_completed"] = self.run_pattern_completed
         if self.translation:
rasa/shared/core/flows/flows_yaml_schema.json CHANGED
@@ -217,15 +217,12 @@
         "reset_after_flow_ends": {
           "type": "boolean"
         },
-        "ask_confirm_digressions": {
-          "$ref": "#/$defs/ask_confirm_digressions"
-        },
-        "block_digressions": {
-          "$ref": "#/$defs/block_digressions"
-        },
         "utter": {
           "type": "string"
         },
+        "force_slot_filling": {
+          "type": "boolean"
+        },
         "rejections": {
           "type": "array",
           "schema_name": "list of rejections",
@@ -253,32 +250,6 @@
         }
       }
     },
-    "ask_confirm_digressions": {
-      "oneOf": [
-        {
-          "type": "boolean"
-        },
-        {
-          "type": "array",
-          "items": {
-            "type": "string"
-          }
-        }
-      ]
-    },
-    "block_digressions": {
-      "oneOf": [
-        {
-          "type": "boolean"
-        },
-        {
-          "type": "array",
-          "items": {
-            "type": "string"
-          }
-        }
-      ]
-    },
     "flow": {
       "required": [
         "steps",
@@ -340,12 +311,6 @@
     "persisted_slots": {
       "$ref": "#/$defs/persisted_slots"
     },
-    "ask_confirm_digressions": {
-      "$ref": "#/$defs/ask_confirm_digressions"
-    },
-    "block_digressions": {
-      "$ref": "#/$defs/block_digressions"
-    },
     "run_pattern_completed": {
       "type": "boolean"
     }
rasa/shared/core/flows/steps/collect.py CHANGED
@@ -1,15 +1,10 @@
 from __future__ import annotations
 
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import Any, Dict, List, Set, Text
 
 from rasa.shared.constants import ACTION_ASK_PREFIX, UTTER_ASK_PREFIX
-from rasa.shared.core.constants import (
-    KEY_ASK_CONFIRM_DIGRESSIONS,
-    KEY_BLOCK_DIGRESSIONS,
-)
 from rasa.shared.core.flows.flow_step import FlowStep
-from rasa.shared.core.flows.utils import extract_digression_prop
 from rasa.shared.core.slots import SlotRejection
 
 
@@ -29,10 +24,8 @@ class CollectInformationFlowStep(FlowStep):
     """Whether to always ask the question even if the slot is already filled."""
     reset_after_flow_ends: bool = True
     """Whether to reset the slot value at the end of the flow."""
-    ask_confirm_digressions: List[str] = field(default_factory=list)
-    """The flow id digressions for which the assistant should ask for confirmation."""
-    block_digressions: List[str] = field(default_factory=list)
-    """The flow id digressions that should be blocked during the flow step."""
+    force_slot_filling: bool = False
+    """Whether to keep only the SetSlot command for the collected slot."""
 
     @classmethod
     def from_json(
@@ -60,10 +53,7 @@ class CollectInformationFlowStep(FlowStep):
                 SlotRejection.from_dict(rejection)
                 for rejection in data.get("rejections", [])
             ],
-            ask_confirm_digressions=extract_digression_prop(
-                KEY_ASK_CONFIRM_DIGRESSIONS, data
-            ),
-            block_digressions=extract_digression_prop(KEY_BLOCK_DIGRESSIONS, data),
+            force_slot_filling=data.get("force_slot_filling", False),
             **base.__dict__,
         )
 
@@ -79,10 +69,7 @@ class CollectInformationFlowStep(FlowStep):
         data["ask_before_filling"] = self.ask_before_filling
         data["reset_after_flow_ends"] = self.reset_after_flow_ends
         data["rejections"] = [rejection.as_dict() for rejection in self.rejections]
-        data["ask_confirm_digressions"] = self.ask_confirm_digressions
-        data["block_digressions"] = (
-            self.block_digressions if self.block_digressions else False
-        )
+        data["force_slot_filling"] = self.force_slot_filling
 
         return data
 
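The new `force_slot_filling` flag is documented as keeping only the SetSlot command for the collected slot. A rough, hypothetical illustration of that intent using plain dicts (the real command objects and the actual filtering live in the command generator and processor changes listed above, not here; "set slot" matches the SET_SLOT_COMMAND constant further down in this diff):

from typing import Any, Dict, List

def keep_only_forced_set_slot(
    commands: List[Dict[str, Any]], collected_slot: str
) -> List[Dict[str, Any]]:
    # Keep only the "set slot" command that targets the slot currently being collected.
    return [
        command
        for command in commands
        if command.get("command") == "set slot"
        and command.get("name") == collected_slot
    ]

predicted = [
    {"command": "set slot", "name": "email", "value": "jo@example.com"},
    {"command": "start flow", "flow": "transfer_money"},
]
print(keep_only_forced_set_slot(predicted, "email"))
# [{'command': 'set slot', 'name': 'email', 'value': 'jo@example.com'}]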
rasa/shared/core/flows/utils.py CHANGED
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Dict, List, Set, Text
+from typing import TYPE_CHECKING, Any, Dict, Set, Text
 
 from rasa.shared.utils.io import raise_deprecation_warning
 
@@ -8,7 +8,6 @@ if TYPE_CHECKING:
 
 RESET_PROPERTY_NAME = "reset_after_flow_ends"
 PERSIST_PROPERTY_NAME = "persisted_slots"
-ALL_LABEL = "ALL"
 
 
 def warn_deprecated_collect_step_config() -> None:
@@ -45,20 +44,6 @@ def get_invalid_slot_persistence_config_error_message(
     )
 
 
-def extract_digression_prop(prop: str, data: Dict[str, Any]) -> List[str]:
-    """Extracts the digression property from the data.
-
-    There can be two types of properties: ask_confirm_digressions and
-    block_digressions.
-    """
-    digression_property = data.get(prop, [])
-
-    if isinstance(digression_property, bool):
-        digression_property = [ALL_LABEL] if digression_property else []
-
-    return digression_property
-
-
 def extract_translations(
     translation_data: Dict[Text, Any],
 ) -> Dict[Text, "FlowLanguageTranslation"]:
rasa/shared/core/slot_mappings.py CHANGED
@@ -648,12 +648,14 @@ class SlotFillingManager:
         output_channel: "OutputChannel",
         nlg: "NaturalLanguageGenerator",
         recreate_tracker: bool = False,
+        slot_events: Optional[List[Event]] = None,
     ) -> List[Event]:
        from rasa.core.actions.action import RemoteAction
        from rasa.shared.core.trackers import DialogueStateTracker
        from rasa.utils.endpoints import ClientResponseError
 
-        slot_events: List[Event] = []
+        validated_slot_events: List[Event] = []
+        slot_events = slot_events if slot_events is not None else []
         remote_action = RemoteAction(custom_action, self._action_endpoint)
         disallowed_types = set()
 
@@ -673,9 +675,9 @@ class SlotFillingManager:
             )
             for event in custom_events:
                 if isinstance(event, SlotSet):
-                    slot_events.append(event)
+                    validated_slot_events.append(event)
                 elif isinstance(event, BotUttered):
-                    slot_events.append(event)
+                    validated_slot_events.append(event)
                 else:
                     disallowed_types.add(event.type_name)
         except (RasaException, ClientResponseError) as e:
@@ -699,7 +701,7 @@ class SlotFillingManager:
                 f"updated with this event.",
             )
 
-        return slot_events
+        return validated_slot_events
 
     async def execute_validation_action(
         self,
@@ -722,7 +724,7 @@ class SlotFillingManager:
             return cast(List[Event], slot_events)
 
         validate_events = await self._run_custom_action(
-            ACTION_VALIDATE_SLOT_MAPPINGS, output_channel, nlg, recreate_tracker=True
+            ACTION_VALIDATE_SLOT_MAPPINGS,
+            output_channel,
+            nlg,
+            recreate_tracker=True,
+            slot_events=cast(List[Event], slot_events),
         )
         validated_slot_names = [
             event.key for event in validate_events if isinstance(event, SlotSet)
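
For reference, the renamed filtering above keeps only SlotSet and BotUttered events returned by the custom action and records every other event type as disallowed. A reduced stand-alone sketch of that partitioning with stand-in event classes (the real rasa event classes are not needed to see the shape of the logic):

from dataclasses import dataclass
from typing import List, Set, Tuple, Union

@dataclass
class SlotSet:
    key: str
    value: str
    type_name: str = "slot"

@dataclass
class BotUttered:
    text: str
    type_name: str = "bot"

@dataclass
class FollowupAction:
    name: str
    type_name: str = "followup"

Event = Union[SlotSet, BotUttered, FollowupAction]

def split_custom_action_events(events: List[Event]) -> Tuple[List[Event], Set[str]]:
    # Keep SlotSet and BotUttered as validated slot events; record other type names.
    validated_slot_events: List[Event] = []
    disallowed_types: Set[str] = set()
    for event in events:
        if isinstance(event, (SlotSet, BotUttered)):
            validated_slot_events.append(event)
        else:
            disallowed_types.add(event.type_name)
    return validated_slot_events, disallowed_types

events: List[Event] = [
    SlotSet("email", "jo@example.com"),
    FollowupAction("action_listen"),
    BotUttered("Thanks, got it."),
]
print(split_custom_action_events(events))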
rasa/shared/nlu/constants.py CHANGED
@@ -55,4 +55,3 @@ SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE = True
 SINGLE_ENTITY_ALLOWED_INTERLEAVING_CHARSET = {".", ",", " ", ";"}
 
 SET_SLOT_COMMAND = "set slot"
-HANDLE_DIGRESSIONS_COMMAND = "handle digressions"
rasa/shared/utils/common.py CHANGED
@@ -7,7 +7,17 @@ import os
 import pkgutil
 import sys
 from types import ModuleType
-from typing import Any, Callable, Collection, Dict, List, Optional, Sequence, Text, Type
+from typing import (
+    Any,
+    Callable,
+    Collection,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Text,
+    Type,
+)
 
 import rasa.shared.utils.io
 from rasa.exceptions import MissingDependencyException
rasa/shared/utils/llm.py CHANGED
@@ -762,7 +762,7 @@ def allowed_values_for_slot(slot: Slot) -> Union[str, None]:
     if isinstance(slot, BooleanSlot):
         return str([True, False])
     if isinstance(slot, CategoricalSlot):
-        return str([v for v in slot.values if v != "__other__"])
+        return str([v for v in slot.values if v != "__other__"] + ["other"])
     else:
         return None
 
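The effect of this one-line change, sketched with invented slot values (the Slot classes themselves are not needed to see the difference):

values = ["credit", "debit", "__other__"]

before = str([v for v in values if v != "__other__"])
after = str([v for v in values if v != "__other__"] + ["other"])

print(before)  # ['credit', 'debit']
print(after)   # ['credit', 'debit', 'other']  ("other" is now listed explicitly)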
rasa/tracing/instrumentation/attribute_extractors.py CHANGED
@@ -326,7 +326,7 @@ def extract_attrs_for_command(
 def extract_llm_config(
     self: Any,
     default_llm_config: Dict[str, Any],
-    default_embeddings_config: Dict[str, Any],
+    default_embeddings_config: Optional[Dict[str, Any]],
 ) -> Dict[str, Any]:
     if isinstance(self, ContextualResponseRephraser):
         # ContextualResponseRephraser is not a graph component, so it's
@@ -346,8 +346,12 @@ def extract_llm_config(
             default_embeddings_config,
         )
     else:
-        embeddings_property = combine_custom_and_default_config(
-            config.get(EMBEDDINGS_CONFIG_KEY), default_embeddings_config
+        embeddings_property = (
+            combine_custom_and_default_config(
+                config.get(EMBEDDINGS_CONFIG_KEY), default_embeddings_config
+            )
+            if default_embeddings_config is not None
+            else {}
         )
 
     attributes = {
@@ -402,7 +406,7 @@ def extract_attrs_for_contextual_response_rephraser(
         self,
         default_llm_config=DEFAULT_LLM_CONFIG,
         # rephraser is not using embeddings
-        default_embeddings_config={},
+        default_embeddings_config=None,
     )
 
     return extend_attributes_with_prompt_tokens_length(self, attributes, prompt)
@@ -418,7 +422,7 @@ def extract_attrs_for_create_history(
         self,
         default_llm_config=DEFAULT_LLM_CONFIG,
         # rephraser is not using embeddings
-        default_embeddings_config={},
+        default_embeddings_config=None,
    )
 
 
@@ -771,14 +775,13 @@ def extract_attrs_for_enterprise_search_generate_llm_answer(
     self: "EnterpriseSearchPolicy", llm: "BaseLLM", prompt: str
 ) -> Dict[str, Any]:
     from rasa.core.policies.enterprise_search_policy import (
-        DEFAULT_EMBEDDINGS_CONFIG,
         DEFAULT_LLM_CONFIG,
     )
 
     attributes = extract_llm_config(
         self,
         default_llm_config=DEFAULT_LLM_CONFIG,
-        default_embeddings_config=DEFAULT_EMBEDDINGS_CONFIG,
+        default_embeddings_config=None,
     )
 
     return extend_attributes_with_prompt_tokens_length(self, attributes, prompt)
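
The added None handling in `extract_llm_config`, reduced to the bare conditional. In this sketch the helper is a stand-in for the real `combine_custom_and_default_config` and the config values are invented:

from typing import Any, Dict, Optional

def combine_custom_and_default_config(
    custom: Optional[Dict[str, Any]], default: Dict[str, Any]
) -> Dict[str, Any]:
    # Stand-in for the real helper: default values overridden by custom ones.
    return {**default, **(custom or {})}

def embeddings_attribute(
    custom: Optional[Dict[str, Any]], default: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    # When no default embeddings config is passed, fall back to an empty dict
    # instead of combining with a default that does not exist.
    return (
        combine_custom_and_default_config(custom, default)
        if default is not None
        else {}
    )

print(embeddings_attribute({"model": "text-embedding-3-small"}, {"provider": "openai"}))
print(embeddings_attribute({"model": "text-embedding-3-small"}, None))  # {}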