rasa-pro 3.11.3a1.dev7__py3-none-any.whl → 3.12.0.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rasa-pro might be problematic.
Files changed (65)
  1. rasa/cli/arguments/default_arguments.py +1 -1
  2. rasa/cli/dialogue_understanding_test.py +251 -0
  3. rasa/core/actions/action.py +7 -16
  4. rasa/core/channels/__init__.py +0 -2
  5. rasa/core/channels/socketio.py +23 -1
  6. rasa/core/nlg/contextual_response_rephraser.py +9 -62
  7. rasa/core/policies/enterprise_search_policy.py +12 -77
  8. rasa/core/policies/flows/flow_executor.py +2 -26
  9. rasa/core/processor.py +8 -11
  10. rasa/dialogue_understanding/generator/command_generator.py +49 -43
  11. rasa/dialogue_understanding/generator/llm_based_command_generator.py +5 -5
  12. rasa/dialogue_understanding/generator/llm_command_generator.py +1 -2
  13. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +15 -34
  14. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +6 -11
  15. rasa/dialogue_understanding/utils.py +1 -8
  16. rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
  17. rasa/dialogue_understanding_test/constants.py +2 -0
  18. rasa/dialogue_understanding_test/du_test_runner.py +93 -0
  19. rasa/dialogue_understanding_test/io.py +54 -0
  20. rasa/dialogue_understanding_test/validation.py +22 -0
  21. rasa/e2e_test/e2e_test_runner.py +9 -7
  22. rasa/hooks.py +9 -15
  23. rasa/model_manager/socket_bridge.py +2 -7
  24. rasa/model_manager/warm_rasa_process.py +4 -9
  25. rasa/plugin.py +0 -11
  26. rasa/shared/constants.py +2 -21
  27. rasa/shared/core/events.py +8 -8
  28. rasa/shared/nlu/constants.py +0 -3
  29. rasa/shared/providers/_configs/azure_entra_id_client_creds.py +40 -0
  30. rasa/shared/providers/_configs/azure_entra_id_config.py +533 -0
  31. rasa/shared/providers/_configs/azure_openai_client_config.py +131 -15
  32. rasa/shared/providers/_configs/client_config.py +3 -1
  33. rasa/shared/providers/_configs/default_litellm_client_config.py +9 -7
  34. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +13 -11
  35. rasa/shared/providers/_configs/litellm_router_client_config.py +12 -10
  36. rasa/shared/providers/_configs/model_group_config.py +11 -5
  37. rasa/shared/providers/_configs/oauth_config.py +33 -0
  38. rasa/shared/providers/_configs/openai_client_config.py +14 -12
  39. rasa/shared/providers/_configs/rasa_llm_client_config.py +5 -3
  40. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +12 -11
  41. rasa/shared/providers/constants.py +6 -0
  42. rasa/shared/providers/embedding/azure_openai_embedding_client.py +30 -7
  43. rasa/shared/providers/embedding/litellm_router_embedding_client.py +5 -2
  44. rasa/shared/providers/llm/_base_litellm_client.py +6 -4
  45. rasa/shared/providers/llm/azure_openai_llm_client.py +88 -34
  46. rasa/shared/providers/llm/default_litellm_llm_client.py +4 -2
  47. rasa/shared/providers/llm/litellm_router_llm_client.py +23 -3
  48. rasa/shared/providers/llm/llm_client.py +4 -2
  49. rasa/shared/providers/llm/llm_response.py +1 -42
  50. rasa/shared/providers/llm/openai_llm_client.py +11 -5
  51. rasa/shared/providers/llm/rasa_llm_client.py +13 -5
  52. rasa/shared/providers/llm/self_hosted_llm_client.py +17 -10
  53. rasa/shared/providers/router/_base_litellm_router_client.py +10 -8
  54. rasa/shared/providers/router/router_client.py +3 -1
  55. rasa/shared/utils/llm.py +16 -12
  56. rasa/shared/utils/schemas/events.py +1 -1
  57. rasa/tracing/instrumentation/attribute_extractors.py +0 -2
  58. rasa/version.py +1 -1
  59. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/METADATA +2 -1
  60. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/RECORD +63 -56
  61. rasa/core/channels/studio_chat.py +0 -192
  62. rasa/dialogue_understanding/constants.py +0 -1
  63. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/NOTICE +0 -0
  64. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/WHEEL +0 -0
  65. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/entry_points.txt +0 -0
rasa/cli/arguments/default_arguments.py
@@ -13,7 +13,7 @@ from rasa.shared.constants import (
 
 
 def add_model_param(
-    parser: argparse.ArgumentParser,
+    parser: Union[argparse.ArgumentParser, argparse._ActionsContainer],
     model_name: Text = "Rasa",
    add_positional_arg: bool = True,
    default: Optional[Text] = DEFAULT_MODELS_PATH,
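
The widened signature lets callers pass an argparse argument group instead of a full parser. A minimal sketch (standard library only) of why `argparse._ActionsContainer` is the right common type; the `--model` flag here is illustrative:

import argparse

parser = argparse.ArgumentParser(prog="demo")
group = parser.add_argument_group("Bot Settings")

# ArgumentParser and the object returned by add_argument_group() both
# inherit from argparse._ActionsContainer, so both expose add_argument().
assert isinstance(parser, argparse._ActionsContainer)
assert isinstance(group, argparse._ActionsContainer)

group.add_argument("--model", default="models", help="Path to a trained model.")
print(parser.parse_args(["--model", "models/latest"]).model)  # models/latest

This is the pattern the new `rasa test du` module below relies on when it calls `add_model_param(bot_arguments, add_positional_arg=False)` on an argument group.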

rasa/cli/dialogue_understanding_test.py
@@ -0,0 +1,251 @@
+import argparse
+import asyncio
+import datetime
+import sys
+from typing import List
+
+import structlog
+
+import rasa.cli.utils
+from rasa.cli import SubParsersAction
+from rasa.cli.arguments.default_arguments import (
+    add_endpoint_param,
+    add_model_param,
+    add_remote_storage_param,
+)
+from rasa.core.exceptions import AgentNotReady
+from rasa.core.utils import AvailableEndpoints
+from rasa.dialogue_understanding_test.command_metric_calculation import (
+    calculate_command_metrics,
+)
+from rasa.dialogue_understanding_test.constants import (
+    DEFAULT_INPUT_TESTS_PATH,
+    KEY_STUB_CUSTOM_ACTIONS,
+)
+from rasa.dialogue_understanding_test.du_test_result import (
+    DialogueUnderstandingTestResult,
+)
+from rasa.dialogue_understanding_test.du_test_runner import (
+    DialogueUnderstandingTestRunner,
+)
+from rasa.dialogue_understanding_test.io import (
+    read_test_suite,
+    write_test_results_to_file,
+    print_test_results,
+)
+from rasa.dialogue_understanding_test.validation import (
+    validate_cli_arguments,
+    validate_test_cases,
+)
+from rasa.e2e_test.e2e_test_case import TestSuite
+from rasa.exceptions import RasaException
+from rasa.shared.constants import DEFAULT_ENDPOINTS_PATH
+from rasa.utils.beta import ensure_beta_feature_is_enabled
+from rasa.utils.endpoints import EndpointConfig
+
+RASA_PRO_BETA_DIALOGUE_UNDERSTANDING_TEST_ENV_VAR_NAME = (
+    "RASA_PRO_BETA_DIALOGUE_UNDERSTANDING_TEST"
+)
+
+structlogger = structlog.get_logger()
+
+
+def add_subparser(
+    subparsers: SubParsersAction, parents: List[argparse.ArgumentParser]
+) -> None:
+    """Add the dialogue understanding test subparser to `rasa test`.
+
+    Args:
+        subparsers: subparser we are going to attach to
+        parents: Parent parsers, needed to ensure tree structure in argparse
+    """
+    for subparser in subparsers.choices.values():
+        if subparser.prog == "rasa test":
+            du_test_subparser = create_du_test_subparser(parents)
+
+            for action in subparser._subparsers._actions:
+                if action.choices is not None:
+                    action.choices["du"] = du_test_subparser
+            return
+
+    # If we get here, we couldn't hook the subparser to `rasa test`
+    raise RasaException(
+        "Hooking the dialogue understanding (du) test subparser to "
+        "`rasa test` command could not be completed. "
+        "Cannot run dialogue understanding testing."
+    )
+
+
+def create_du_test_subparser(
+    parents: List[argparse.ArgumentParser],
+) -> argparse.ArgumentParser:
+    """Create dialogue understanding test subparser."""
+    du_test_subparser = argparse.ArgumentParser(
+        prog="rasa test du",
+        parents=parents,
+        conflict_handler="resolve",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+        description="Runs dialogue understanding testing.",
+    )
+
+    du_test_subparser.set_defaults(func=execute_dialogue_understanding_tests)
+
+    add_du_test_arguments(du_test_subparser)
+    add_bot_arguments(du_test_subparser)
+
+    return du_test_subparser
+
+
+def add_bot_arguments(parser: argparse.ArgumentParser) -> None:
+    bot_arguments = parser.add_argument_group("Bot Settings")
+    add_model_param(bot_arguments, add_positional_arg=False)
+    add_endpoint_param(
+        bot_arguments,
+        help_text="Configuration file for the model server and the connectors as a "
+        "yml file.",
+    )
+    add_remote_storage_param(bot_arguments)
+
+
+def add_du_test_arguments(parser: argparse.ArgumentParser) -> None:
+    """Arguments for running dialogue understanding tests."""
+    du_arguments = parser.add_argument_group("Testing Settings")
+    du_arguments.add_argument(
+        "path-to-test-cases",
+        nargs="?",
+        type=str,
+        default=DEFAULT_INPUT_TESTS_PATH,
+        help="Input file or folder containing dialogue understanding test cases.",
+    )
+    du_arguments.add_argument(
+        "--output-file",
+        type=str,
+        default="dialogue_understanding_test_{date:%Y%m%d-%H%M%S}.yml".format(
+            date=datetime.datetime.now()
+        ),
+        help="Path to the output file to write the results to.",
+    )
+    du_arguments.add_argument(
+        "--no-output",
+        action="store_true",
+        help="If set, no output file will be written to disk.",
+    )
+    du_arguments.add_argument(
+        "--output-prompt",
+        action="store_true",
+        help="If set, the dialogue understanding test output will contain "
+        "prompts for each failure.",
+    )
+
+
+def execute_dialogue_understanding_tests(args: argparse.Namespace) -> None:
+    """Run the dialogue understanding tests.
+
+    Args:
+        args: Commandline arguments.
+    """
+    ensure_beta_feature_is_enabled(
+        "Dialogue Understanding (DU) Testing",
+        env_flag=RASA_PRO_BETA_DIALOGUE_UNDERSTANDING_TEST_ENV_VAR_NAME,
+    )
+
+    # basic validation of the passed CLI arguments
+    validate_cli_arguments(args)
+
+    # initialization of endpoints
+    endpoints = set_up_available_endpoints(args)
+
+    # read test cases from the given path
+    test_suite = get_valid_test_suite(args)
+
+    # setup stub custom actions if they are used
+    set_up_stub_custom_actions(test_suite, endpoints)
+
+    # set up the test runner, e.g. start the agent
+    try:
+        test_runner = DialogueUnderstandingTestRunner(
+            endpoints=endpoints,
+            model_path=args.model,
+            model_server=endpoints.model,
+            remote_storage=args.remote_storage,
+        )
+    except AgentNotReady as error:
+        structlogger.error(
+            "rasa.dialogue_understanding_test.agent_not_ready", message=error.message
+        )
+        sys.exit(1)
+
+    # run the actual test cases
+    test_results = asyncio.run(
+        test_runner.run_tests(
+            test_suite.test_cases, test_suite.fixtures, test_suite.metadata
+        )
+    )
+
+    # evaluate test results
+    failed_tests, passed_tests = split_test_results(test_results)
+    command_metrics = calculate_command_metrics(test_results)
+
+    # write results to console and file
+    print_test_results(failed_tests, passed_tests, command_metrics, args.output_prompt)
+    if not args.no_output:
+        write_test_results_to_file(
+            failed_tests,
+            passed_tests,
+            command_metrics,
+            args.output_file,
+            args.output_prompt,
+        )
+
+
+def get_valid_test_suite(args: argparse.Namespace) -> TestSuite:
+    """Read the test cases from the given test case path and validate them."""
+    path_to_test_cases = getattr(args, "path-to-test-cases", DEFAULT_INPUT_TESTS_PATH)
+    test_suite = read_test_suite(path_to_test_cases)
+    validate_test_cases(test_suite.test_cases)
+    return test_suite
+
+
+def set_up_available_endpoints(args: argparse.Namespace) -> AvailableEndpoints:
+    """Set up the available endpoints for the test runner."""
+    args.endpoints = rasa.cli.utils.get_validated_path(
+        args.endpoints, "endpoints", DEFAULT_ENDPOINTS_PATH, True
+    )
+    endpoints = AvailableEndpoints.get_instance(args.endpoints)
+
+    # Ignore all endpoints apart from action server, model, and nlu
+    # to ensure InMemoryTrackerStore is being used instead of production
+    # tracker store
+    endpoints.tracker_store = None
+    endpoints.lock_store = None
+    endpoints.event_broker = None
+
+    # disable nlg endpoint as we don't need it for dialogue understanding tests
+    endpoints.nlg = None
+
+    return endpoints
+
+
+def set_up_stub_custom_actions(
+    test_suite: TestSuite, endpoints: AvailableEndpoints
+) -> None:
+    """Set up the stub custom actions if they are used."""
+    if test_suite.stub_custom_actions:
+        if not endpoints.action:
+            endpoints.action = EndpointConfig()
+
+        endpoints.action.kwargs[KEY_STUB_CUSTOM_ACTIONS] = (
+            test_suite.stub_custom_actions
+        )
+
+
+def split_test_results(
+    results: List[DialogueUnderstandingTestResult],
+) -> tuple[
+    List[DialogueUnderstandingTestResult], List[DialogueUnderstandingTestResult]
+]:
+    """Split the test results into passed and failed test cases."""
+    passed_cases = [r for r in results if r.passed]
+    failed_cases = [r for r in results if not r.passed]
+
+    return passed_cases, failed_cases
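
With the beta flag `RASA_PRO_BETA_DIALOGUE_UNDERSTANDING_TEST` enabled, the new subparser is reachable as `rasa test du <path-to-test-cases>`. One detail worth illustrating is the `split_test_results` contract: it returns passed cases first, then failed ones. A self-contained sketch with a stand-in result class (the real type is `DialogueUnderstandingTestResult`):

from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class FakeResult:  # stand-in for DialogueUnderstandingTestResult
    name: str
    passed: bool

def split_test_results(
    results: List[FakeResult],
) -> Tuple[List[FakeResult], List[FakeResult]]:
    """Split the test results into passed and failed test cases."""
    passed_cases = [r for r in results if r.passed]
    failed_cases = [r for r in results if not r.passed]
    return passed_cases, failed_cases

passed, failed = split_test_results(
    [FakeResult("greet", True), FakeResult("transfer_money", False)]
)
assert [r.name for r in passed] == ["greet"]
assert [r.name for r in failed] == ["transfer_money"]

Note that `execute_dialogue_understanding_tests` above unpacks the return value as `failed_tests, passed_tests`, the opposite order from what the function returns.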

rasa/core/actions/action.py
@@ -17,15 +17,6 @@ from jsonschema import Draft202012Validator
 
 import rasa.core
 import rasa.shared.utils.io
-from rasa.shared.constants import (
-    TEXT,
-    ELEMENTS,
-    QUICK_REPLIES,
-    BUTTONS,
-    ATTACHMENT,
-    IMAGE,
-    CUSTOM,
-)
 from rasa.core.actions.custom_action_executor import (
     CustomActionExecutor,
     NoEndpointCustomActionExecutor,
@@ -264,18 +255,18 @@ def action_for_name_or_text(
 def create_bot_utterance(message: Dict[Text, Any]) -> BotUttered:
     """Create BotUttered event from message."""
     bot_message = BotUttered(
-        text=message.pop(TEXT, None),
+        text=message.pop("text", None),
         data={
-            ELEMENTS: message.pop(ELEMENTS, None),
-            QUICK_REPLIES: message.pop(QUICK_REPLIES, None),
-            BUTTONS: message.pop(BUTTONS, None),
+            "elements": message.pop("elements", None),
+            "quick_replies": message.pop("quick_replies", None),
+            "buttons": message.pop("buttons", None),
             # for legacy / compatibility reasons we need to set the image
             # to be the attachment if there is no other attachment (the
             # `.get` is intentional - no `pop` as we still need the image`
             # property to set it in the following line)
-            ATTACHMENT: message.pop(ATTACHMENT, None) or message.get(IMAGE, None),
-            IMAGE: message.pop(IMAGE, None),
-            CUSTOM: message.pop(CUSTOM, None),
+            "attachment": message.pop("attachment", None) or message.get("image", None),
+            "image": message.pop("image", None),
+            "custom": message.pop("custom", None),
         },
         metadata=message,
     )
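
Since `create_bot_utterance` now works with plain string keys, the attachment fallback is easy to demonstrate: when a message carries an `image` but no `attachment`, the image URL is mirrored into the `attachment` field (the `.get` instead of `.pop` keeps the key alive for the following line). A hedged sketch, assuming `BotUttered` exposes the `text` and `data` it was constructed with:

from rasa.core.actions.action import create_bot_utterance

message = {"text": "here you go", "image": "https://example.com/cat.png"}
event = create_bot_utterance(message)

assert event.text == "here you go"
assert event.data["attachment"] == "https://example.com/cat.png"  # mirrored from image
assert event.data["image"] == "https://example.com/cat.png"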

rasa/core/channels/__init__.py
@@ -31,7 +31,6 @@ from rasa.core.channels.vier_cvg import CVGInput
 from rasa.core.channels.voice_stream.twilio_media_streams import (
     TwilioMediaStreamsInputChannel,
 )
-from rasa.core.channels.studio_chat import StudioChatInput
 
 input_channel_classes: List[Type[InputChannel]] = [
     CmdlineInput,
@@ -54,7 +53,6 @@ input_channel_classes: List[Type[InputChannel]] = [
     JambonzVoiceReadyInput,
     TwilioMediaStreamsInputChannel,
     BrowserAudioInputChannel,
-    StudioChatInput,
 ]
 
 # Mapping from an input channel name to its class to allow name based lookup.

rasa/core/channels/socketio.py
@@ -54,9 +54,31 @@ class SocketIOOutput(OutputChannel):
         super().__init__()
         self.sio = sio
         self.bot_message_evt = bot_message_evt
+        self.last_event_timestamp = (
+            -1
+        )  # Initialize with -1 to send all events on first message
+
+    def _get_new_events(self) -> List[Dict[Text, Any]]:
+        """Get events that are newer than the last sent event."""
+        events = self.tracker_state.get("events", []) if self.tracker_state else []
+        new_events = [
+            event for event in events if event["timestamp"] > self.last_event_timestamp
+        ]
+        if new_events:
+            self.last_event_timestamp = new_events[-1]["timestamp"]
+        return new_events
 
     async def _send_message(self, socket_id: Text, response: Any) -> None:
         """Sends a message to the recipient using the bot event."""
+        # send tracker state (contains stack, slots and more)
+        await self.sio.emit("tracker_state", self.tracker_state, room=socket_id)
+
+        # send new events
+        new_events = self._get_new_events()
+        if new_events:
+            await self.sio.emit("rasa_events", new_events, room=socket_id)
+
+        # send bot response
         await self.sio.emit(self.bot_message_evt, response, room=socket_id)
 
     async def send_text_message(
@@ -192,7 +214,7 @@ class SocketIOInput(InputChannel):
 
     def blueprint(
         self, on_new_message: Callable[[UserMessage], Awaitable[Any]]
-    ) -> SocketBlueprint:
+    ) -> Blueprint:
         """Defines a Sanic blueprint."""
         # Workaround so that socketio works with requests from other origins.
         # https://github.com/miguelgrinberg/python-socketio/issues/205#issuecomment-493769183
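
On the wire this means a Socket.IO client now receives a `tracker_state` event and, when new events exist, a `rasa_events` event before each bot message. A minimal listener sketch using the python-socketio client; the server URL and payload shapes are assumptions, and the bot message event name follows whatever `bot_message_evt` is configured to (commonly `bot_uttered`):

import asyncio

import socketio

sio = socketio.AsyncClient()

@sio.on("tracker_state")
async def on_tracker_state(state):
    # full tracker state: stack, slots, events, and more
    print("slots:", (state or {}).get("slots"))

@sio.on("rasa_events")
async def on_rasa_events(events):
    # only events newer than the previously emitted timestamp
    print("new events:", [e.get("event") for e in events])

@sio.on("bot_uttered")
async def on_bot_uttered(message):
    print("bot:", message.get("text"))

async def main():
    await sio.connect("http://localhost:5005")  # assumed local Rasa server
    await sio.wait()

asyncio.run(main())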

rasa/core/nlg/contextual_response_rephraser.py
@@ -2,7 +2,6 @@ from typing import Any, Dict, Optional, Text
 
 import structlog
 from jinja2 import Template
-
 from rasa import telemetry
 from rasa.core.nlg.response import TemplatedNaturalLanguageGenerator
 from rasa.core.nlg.summarize import summarize_conversation
@@ -19,14 +18,6 @@ from rasa.shared.constants import (
 from rasa.shared.core.domain import KEY_RESPONSES_TEXT, Domain
 from rasa.shared.core.events import BotUttered, UserUttered
 from rasa.shared.core.trackers import DialogueStateTracker
-from rasa.shared.nlu.constants import (
-    PROMPTS,
-    KEY_USER_PROMPT,
-    KEY_LLM_RESPONSE_METADATA,
-    KEY_PROMPT_NAME,
-    KEY_COMPONENT_NAME,
-)
-from rasa.shared.providers.llm.llm_response import LLMResponse
 from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
 from rasa.shared.utils.llm import (
     DEFAULT_OPENAI_GENERATE_MODEL_NAME,
@@ -133,39 +124,6 @@ class ContextualResponseRephraser(
             ContextualResponseRephraser.__name__,
         )
 
-    @classmethod
-    def _add_prompt_and_llm_metadata_to_response(
-        cls,
-        response: Dict[str, Any],
-        prompt_name: str,
-        user_prompt: str,
-        llm_response: Optional["LLMResponse"] = None,
-    ) -> Dict[str, Any]:
-        """Stores the prompt and LLMResponse metadata to response.
-
-        Args:
-            response: The response to add the prompt and LLMResponse metadata to.
-            prompt_name: A name identifying prompt usage.
-            user_prompt: The user prompt that was sent to the LLM.
-            llm_response: The response object from the LLM (None if no response).
-        """
-        from rasa.dialogue_understanding.utils import record_commands_and_prompts
-
-        if not record_commands_and_prompts:
-            return response
-
-        prompt_data: Dict[Text, Any] = {
-            KEY_COMPONENT_NAME: cls.__name__,
-            KEY_PROMPT_NAME: prompt_name,
-            KEY_USER_PROMPT: user_prompt,
-            KEY_LLM_RESPONSE_METADATA: llm_response.to_dict() if llm_response else None,
-        }
-
-        prompts = response.get(PROMPTS, [])
-        prompts.append(prompt_data)
-        response[PROMPTS] = prompts
-        return response
-
     def _last_message_if_human(self, tracker: DialogueStateTracker) -> Optional[str]:
         """Returns the latest message from the tracker.
 
@@ -184,21 +142,20 @@ class ContextualResponseRephraser(
             return None
         return None
 
-    async def _generate_llm_response(self, prompt: str) -> Optional[LLMResponse]:
-        """
-        Use LLM to generate a response, returning an LLMResponse object
-        containing both the generated text (choices) and metadata.
+    async def _generate_llm_response(self, prompt: str) -> Optional[str]:
+        """Use LLM to generate a response.
 
         Args:
-            prompt: The prompt to send to the LLM.
+            prompt: the prompt to send to the LLM
 
         Returns:
-            An LLMResponse object if successful, otherwise None.
+            generated text
         """
         llm = llm_factory(self.llm_config, DEFAULT_LLM_CONFIG)
 
         try:
-            return await llm.acompletion(prompt)
+            llm_response = await llm.acompletion(prompt)
+            return llm_response.choices[0]
         except Exception as e:
             # unfortunately, langchain does not wrap LLM exceptions which means
             # we have to catch all exceptions here
@@ -298,21 +255,11 @@ class ContextualResponseRephraser(
             or self.llm_property(MODEL_NAME_CONFIG_KEY),
             llm_model_group_id=self.llm_property(MODEL_GROUP_ID_CONFIG_KEY),
         )
-        llm_response = await self._generate_llm_response(prompt)
-        llm_response = LLMResponse.ensure_llm_response(llm_response)
-
-        response = self._add_prompt_and_llm_metadata_to_response(
-            response=response,
-            prompt_name="rephrase_prompt",
-            user_prompt=prompt,
-            llm_response=llm_response,
-        )
-
-        if not (llm_response and llm_response.choices and llm_response.choices[0]):
-            # If the LLM fails to generate a response, return the original response.
+        if not (updated_text := await self._generate_llm_response(prompt)):
+            # If the LLM fails to generate a response, we
+            # return the original response.
             return response
 
-        updated_text = llm_response.choices[0]
         structlogger.debug(
             "nlg.rewrite.complete",
             response_text=response_text,
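
This file and `enterprise_search_policy.py` below revert the same refactor: the LLM helpers return the first completion string (or `None` on failure) instead of an `LLMResponse` wrapper, and the prompt/metadata recording helpers are dropped. A sketch of the resulting caller pattern, with a stub standing in for the client that `llm_factory` returns:

import asyncio
from typing import List, Optional

class StubLLM:
    """Stand-in for the client returned by llm_factory."""

    class _Response:
        choices: List[str] = ["Sure, I can help with that."]

    async def acompletion(self, prompt: str) -> "StubLLM._Response":
        return self._Response()

async def generate_llm_response(prompt: str) -> Optional[str]:
    llm = StubLLM()
    try:
        llm_response = await llm.acompletion(prompt)
        return llm_response.choices[0]
    except Exception:
        # on any client error the caller falls back to the original text
        return None

async def main() -> None:
    if not (updated_text := await generate_llm_response("Rephrase: hello")):
        print("fallback to templated response")
    else:
        print(updated_text)

asyncio.run(main())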

rasa/core/policies/enterprise_search_policy.py
@@ -2,7 +2,6 @@ import importlib.resources
 import json
 import re
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
-
 import dotenv
 import structlog
 from jinja2 import Template
@@ -64,19 +63,11 @@ from rasa.shared.core.events import Event, UserUttered, BotUttered
 from rasa.shared.core.generator import TrackerWithCachedStates
 from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity
 from rasa.shared.exceptions import RasaException, FileIOException
-from rasa.shared.nlu.constants import (
-    PROMPTS,
-    KEY_USER_PROMPT,
-    KEY_LLM_RESPONSE_METADATA,
-    KEY_PROMPT_NAME,
-    KEY_COMPONENT_NAME,
-)
 from rasa.shared.nlu.training_data.training_data import TrainingData
 from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
     _LangchainEmbeddingClientAdapter,
 )
 from rasa.shared.providers.llm.llm_client import LLMClient
-from rasa.shared.providers.llm.llm_response import LLMResponse
 from rasa.shared.utils.cli import print_error_and_exit
 from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
     EmbeddingsHealthCheckMixin,
@@ -281,43 +272,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
         # Wrap the embedding client in the adapter
         return _LangchainEmbeddingClientAdapter(client)
 
-    @classmethod
-    def _add_prompt_and_llm_response_to_latest_message(
-        cls,
-        tracker: DialogueStateTracker,
-        prompt_name: str,
-        user_prompt: str,
-        llm_response: Optional[LLMResponse] = None,
-    ) -> None:
-        """Stores the prompt and LLMResponse metadata in the tracker.
-
-        Args:
-            tracker: The DialogueStateTracker containing the current conversation state.
-            prompt_name: A name identifying prompt usage.
-            user_prompt: The user prompt that was sent to the LLM.
-            llm_response: The response object from the LLM (None if no response).
-        """
-        from rasa.dialogue_understanding.utils import record_commands_and_prompts
-
-        if not record_commands_and_prompts:
-            return
-
-        if not tracker.latest_message:
-            return
-
-        parse_data = tracker.latest_message.parse_data
-        if PROMPTS not in parse_data:
-            parse_data[PROMPTS] = []  # type: ignore[literal-required]
-
-        prompt_data: Dict[Text, Any] = {
-            KEY_COMPONENT_NAME: cls.__name__,
-            KEY_PROMPT_NAME: prompt_name,
-            KEY_USER_PROMPT: user_prompt,
-            KEY_LLM_RESPONSE_METADATA: llm_response.to_dict() if llm_response else None,
-        }
-
-        parse_data[PROMPTS].append(prompt_data)  # type: ignore[literal-required]
-
     def train(  # type: ignore[override]
         self,
         training_trackers: List[TrackerWithCachedStates],
@@ -544,27 +498,13 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
 
         if self.use_llm:
             prompt = self._render_prompt(tracker, documents.results)
-            llm_response = await self._generate_llm_answer(llm, prompt)
-            llm_response = LLMResponse.ensure_llm_response(llm_response)
-
-            self._add_prompt_and_llm_response_to_latest_message(
-                tracker=tracker,
-                prompt_name="enterprise_search_prompt",
-                user_prompt=prompt,
-                llm_response=llm_response,
-            )
+            llm_answer = await self._generate_llm_answer(llm, prompt)
 
-            if llm_response is None or not llm_response.choices:
-                logger.debug(f"{logger_key}.no_llm_response")
-                response = None
-            else:
-                llm_answer = llm_response.choices[0]
+            if self.citation_enabled:
+                llm_answer = self.post_process_citations(llm_answer)
 
-                if self.citation_enabled:
-                    llm_answer = self.post_process_citations(llm_answer)
-
-                logger.debug(f"{logger_key}.llm_answer", llm_answer=llm_answer)
-                response = llm_answer
+            logger.debug(f"{logger_key}.llm_answer", llm_answer=llm_answer)
+            response = llm_answer
         else:
             response = documents.results[0].metadata.get("answer", None)
             if not response:
@@ -576,6 +516,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
                 "enterprise_search_policy.predict_action_probabilities.no_llm",
                 search_results=documents,
             )
+
         if response is None:
             return self._create_prediction_internal_error(domain, tracker)
 
@@ -640,18 +581,10 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
 
     async def _generate_llm_answer(
        self, llm: LLMClient, prompt: Text
-    ) -> Optional[LLMResponse]:
-        """Fetches an LLM completion for the provided prompt.
-
-        Args:
-            llm: The LLM client used to get the completion.
-            prompt: The prompt text to send to the model.
-
-        Returns:
-            An LLMResponse object, or None if the call fails.
-        """
+    ) -> Optional[Text]:
         try:
-            return await llm.acompletion(prompt)
+            llm_response = await llm.acompletion(prompt)
+            llm_answer = llm_response.choices[0]
         except Exception as e:
             # unfortunately, langchain does not wrap LLM exceptions which means
             # we have to catch all exceptions here
@@ -659,7 +592,9 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
                 "enterprise_search_policy._generate_llm_answer.llm_error",
                 error=e,
             )
-            return None
+            llm_answer = None
+
+        return llm_answer
 
     def _create_prediction(
         self,