rasa-pro 3.11.3a1.dev7__py3-none-any.whl → 3.12.0.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/cli/arguments/default_arguments.py +1 -1
- rasa/cli/dialogue_understanding_test.py +251 -0
- rasa/core/actions/action.py +7 -16
- rasa/core/channels/__init__.py +0 -2
- rasa/core/channels/socketio.py +23 -1
- rasa/core/nlg/contextual_response_rephraser.py +9 -62
- rasa/core/policies/enterprise_search_policy.py +12 -77
- rasa/core/policies/flows/flow_executor.py +2 -26
- rasa/core/processor.py +8 -11
- rasa/dialogue_understanding/generator/command_generator.py +49 -43
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +5 -5
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -2
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +15 -34
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +6 -11
- rasa/dialogue_understanding/utils.py +1 -8
- rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
- rasa/dialogue_understanding_test/constants.py +2 -0
- rasa/dialogue_understanding_test/du_test_runner.py +93 -0
- rasa/dialogue_understanding_test/io.py +54 -0
- rasa/dialogue_understanding_test/validation.py +22 -0
- rasa/e2e_test/e2e_test_runner.py +9 -7
- rasa/hooks.py +9 -15
- rasa/model_manager/socket_bridge.py +2 -7
- rasa/model_manager/warm_rasa_process.py +4 -9
- rasa/plugin.py +0 -11
- rasa/shared/constants.py +2 -21
- rasa/shared/core/events.py +8 -8
- rasa/shared/nlu/constants.py +0 -3
- rasa/shared/providers/_configs/azure_entra_id_client_creds.py +40 -0
- rasa/shared/providers/_configs/azure_entra_id_config.py +533 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +131 -15
- rasa/shared/providers/_configs/client_config.py +3 -1
- rasa/shared/providers/_configs/default_litellm_client_config.py +9 -7
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +13 -11
- rasa/shared/providers/_configs/litellm_router_client_config.py +12 -10
- rasa/shared/providers/_configs/model_group_config.py +11 -5
- rasa/shared/providers/_configs/oauth_config.py +33 -0
- rasa/shared/providers/_configs/openai_client_config.py +14 -12
- rasa/shared/providers/_configs/rasa_llm_client_config.py +5 -3
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +12 -11
- rasa/shared/providers/constants.py +6 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +30 -7
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +5 -2
- rasa/shared/providers/llm/_base_litellm_client.py +6 -4
- rasa/shared/providers/llm/azure_openai_llm_client.py +88 -34
- rasa/shared/providers/llm/default_litellm_llm_client.py +4 -2
- rasa/shared/providers/llm/litellm_router_llm_client.py +23 -3
- rasa/shared/providers/llm/llm_client.py +4 -2
- rasa/shared/providers/llm/llm_response.py +1 -42
- rasa/shared/providers/llm/openai_llm_client.py +11 -5
- rasa/shared/providers/llm/rasa_llm_client.py +13 -5
- rasa/shared/providers/llm/self_hosted_llm_client.py +17 -10
- rasa/shared/providers/router/_base_litellm_router_client.py +10 -8
- rasa/shared/providers/router/router_client.py +3 -1
- rasa/shared/utils/llm.py +16 -12
- rasa/shared/utils/schemas/events.py +1 -1
- rasa/tracing/instrumentation/attribute_extractors.py +0 -2
- rasa/version.py +1 -1
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/METADATA +2 -1
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/RECORD +63 -56
- rasa/core/channels/studio_chat.py +0 -192
- rasa/dialogue_understanding/constants.py +0 -1
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/entry_points.txt +0 -0
|
@@ -13,7 +13,7 @@ from rasa.shared.constants import (
|
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
def add_model_param(
|
|
16
|
-
parser: argparse.ArgumentParser,
|
|
16
|
+
parser: Union[argparse.ArgumentParser, argparse._ActionsContainer],
|
|
17
17
|
model_name: Text = "Rasa",
|
|
18
18
|
add_positional_arg: bool = True,
|
|
19
19
|
default: Optional[Text] = DEFAULT_MODELS_PATH,
|
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import asyncio
|
|
3
|
+
import datetime
|
|
4
|
+
import sys
|
|
5
|
+
from typing import List
|
|
6
|
+
|
|
7
|
+
import structlog
|
|
8
|
+
|
|
9
|
+
import rasa.cli.utils
|
|
10
|
+
from rasa.cli import SubParsersAction
|
|
11
|
+
from rasa.cli.arguments.default_arguments import (
|
|
12
|
+
add_endpoint_param,
|
|
13
|
+
add_model_param,
|
|
14
|
+
add_remote_storage_param,
|
|
15
|
+
)
|
|
16
|
+
from rasa.core.exceptions import AgentNotReady
|
|
17
|
+
from rasa.core.utils import AvailableEndpoints
|
|
18
|
+
from rasa.dialogue_understanding_test.command_metric_calculation import (
|
|
19
|
+
calculate_command_metrics,
|
|
20
|
+
)
|
|
21
|
+
from rasa.dialogue_understanding_test.constants import (
|
|
22
|
+
DEFAULT_INPUT_TESTS_PATH,
|
|
23
|
+
KEY_STUB_CUSTOM_ACTIONS,
|
|
24
|
+
)
|
|
25
|
+
from rasa.dialogue_understanding_test.du_test_result import (
|
|
26
|
+
DialogueUnderstandingTestResult,
|
|
27
|
+
)
|
|
28
|
+
from rasa.dialogue_understanding_test.du_test_runner import (
|
|
29
|
+
DialogueUnderstandingTestRunner,
|
|
30
|
+
)
|
|
31
|
+
from rasa.dialogue_understanding_test.io import (
|
|
32
|
+
read_test_suite,
|
|
33
|
+
write_test_results_to_file,
|
|
34
|
+
print_test_results,
|
|
35
|
+
)
|
|
36
|
+
from rasa.dialogue_understanding_test.validation import (
|
|
37
|
+
validate_cli_arguments,
|
|
38
|
+
validate_test_cases,
|
|
39
|
+
)
|
|
40
|
+
from rasa.e2e_test.e2e_test_case import TestSuite
|
|
41
|
+
from rasa.exceptions import RasaException
|
|
42
|
+
from rasa.shared.constants import DEFAULT_ENDPOINTS_PATH
|
|
43
|
+
from rasa.utils.beta import ensure_beta_feature_is_enabled
|
|
44
|
+
from rasa.utils.endpoints import EndpointConfig
|
|
45
|
+
|
|
46
|
+
# Name of the environment variable passed to `ensure_beta_feature_is_enabled`
# to gate the beta dialogue understanding (DU) test command.
RASA_PRO_BETA_DIALOGUE_UNDERSTANDING_TEST_ENV_VAR_NAME = (
    "RASA_PRO_BETA_" + "DIALOGUE_UNDERSTANDING_TEST"
)

# Module-wide structured logger.
structlogger = structlog.get_logger()
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def add_subparser(
    subparsers: SubParsersAction, parents: List[argparse.ArgumentParser]
) -> None:
    """Attach the `du` sub-command to the already registered `rasa test` parser.

    Every registered parser whose `prog` is ``"rasa test"`` is inspected. The
    dialogue understanding parser is inserted into the first of its
    sub-parser actions that exposes a ``choices`` mapping.

    Args:
        subparsers: subparser we are going to attach to
        parents: Parent parsers, needed to ensure tree structure in argparse

    Raises:
        RasaException: if no `rasa test` parser with a suitable sub-parser
            action could be found.
    """
    rasa_test_parsers = (
        candidate
        for candidate in subparsers.choices.values()
        if candidate.prog == "rasa test"
    )

    for rasa_test_parser in rasa_test_parsers:
        du_parser = create_du_test_subparser(parents)

        # argparse has no public API for adding a sub-command to an existing
        # parser, so the private `_subparsers` container is inspected directly.
        target_action = next(
            (
                candidate_action
                for candidate_action in rasa_test_parser._subparsers._actions
                if candidate_action.choices is not None
            ),
            None,
        )
        if target_action is not None:
            target_action.choices["du"] = du_parser
            return

    # If we get here, we couldn't hook the subparser to `rasa test`
    raise RasaException(
        "Hooking the dialogue understanding (du) test subparser to "
        "`rasa test` command could not be completed. "
        "Cannot run dialogue understanding testing."
    )
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def create_du_test_subparser(
    parents: List[argparse.ArgumentParser],
) -> argparse.ArgumentParser:
    """Build the stand-alone parser used for `rasa test du`.

    Args:
        parents: Parent parsers whose arguments the new parser inherits.

    Returns:
        A parser whose default ``func`` is `execute_dialogue_understanding_tests`,
        populated with the testing options followed by the bot options.
    """
    parser_options = dict(
        prog="rasa test du",
        parents=parents,
        conflict_handler="resolve",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Runs dialogue understanding testing.",
    )
    du_parser = argparse.ArgumentParser(**parser_options)
    du_parser.set_defaults(func=execute_dialogue_understanding_tests)

    # Testing settings are registered before the bot settings.
    for register_arguments in (add_du_test_arguments, add_bot_arguments):
        register_arguments(du_parser)

    return du_parser
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def add_bot_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the model, endpoints and remote-storage options.

    All options are placed in a "Bot Settings" argument group. The model path
    is added as an option only (no positional argument).

    Args:
        parser: Parser that receives the new argument group.
    """
    bot_group = parser.add_argument_group("Bot Settings")
    endpoints_help = (
        "Configuration file for the model server and the connectors as a "
        "yml file."
    )

    add_model_param(bot_group, add_positional_arg=False)
    add_endpoint_param(bot_group, help_text=endpoints_help)
    add_remote_storage_param(bot_group)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def add_du_test_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the "Testing Settings" options of `rasa test du`.

    Adds an optional positional path to the test cases, plus the
    ``--output-file``, ``--no-output`` and ``--output-prompt`` options.

    Args:
        parser: Parser that receives the new argument group.
    """
    testing_group = parser.add_argument_group("Testing Settings")

    # The positional's destination keeps its dashes, so it has to be read back
    # with `getattr(args, "path-to-test-cases", ...)`.
    testing_group.add_argument(
        "path-to-test-cases",
        nargs="?",
        type=str,
        default=DEFAULT_INPUT_TESTS_PATH,
        help="Input file or folder containing dialogue understanding test cases.",
    )

    # The timestamp in the default file name is taken when this function runs,
    # i.e. when the parser is built, not when the tests are executed.
    created_at = datetime.datetime.now()
    testing_group.add_argument(
        "--output-file",
        type=str,
        default=f"dialogue_understanding_test_{created_at:%Y%m%d-%H%M%S}.yml",
        help="Path to the output file to write the results to.",
    )

    boolean_flags = (
        ("--no-output", "If set, no output file will be written to disk."),
        (
            "--output-prompt",
            "If set, the dialogue understanding test output will contain "
            "prompts for each failure.",
        ),
    )
    for flag_name, flag_help in boolean_flags:
        testing_group.add_argument(flag_name, action="store_true", help=flag_help)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def execute_dialogue_understanding_tests(args: argparse.Namespace) -> None:
    """Run the dialogue understanding tests.

    Checks that the beta feature flag is enabled and validates the CLI
    arguments. It then prepares the endpoints and the test suite, runs every
    test case through a `DialogueUnderstandingTestRunner`, and prints the
    results. The results are also written to ``args.output_file`` unless
    ``args.no_output`` is set.

    Args:
        args: Commandline arguments.
    """
    ensure_beta_feature_is_enabled(
        "Dialogue Understanding (DU) Testing",
        env_flag=RASA_PRO_BETA_DIALOGUE_UNDERSTANDING_TEST_ENV_VAR_NAME,
    )

    # basic validation of the passed CLI arguments
    validate_cli_arguments(args)

    # initialization of endpoints
    endpoints = set_up_available_endpoints(args)

    # read test cases from the given path
    test_suite = get_valid_test_suite(args)

    # setup stub custom actions if they are used
    set_up_stub_custom_actions(test_suite, endpoints)

    # set up the test runner, e.g. start the agent
    try:
        test_runner = DialogueUnderstandingTestRunner(
            endpoints=endpoints,
            model_path=args.model,
            model_server=endpoints.model,
            remote_storage=args.remote_storage,
        )
    except AgentNotReady as error:
        structlogger.error(
            "rasa.dialogue_understanding_test.agent_not_ready", message=error.message
        )
        sys.exit(1)

    # run the actual test cases
    test_results = asyncio.run(
        test_runner.run_tests(
            test_suite.test_cases, test_suite.fixtures, test_suite.metadata
        )
    )

    # evaluate test results
    # `split_test_results` returns `(passed, failed)`. Unpacking it in the
    # opposite order would swap passed and failed tests in the report below.
    passed_tests, failed_tests = split_test_results(test_results)
    command_metrics = calculate_command_metrics(test_results)

    # write results to console and file
    print_test_results(failed_tests, passed_tests, command_metrics, args.output_prompt)
    if not args.no_output:
        write_test_results_to_file(
            failed_tests,
            passed_tests,
            command_metrics,
            args.output_file,
            args.output_prompt,
        )
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def get_valid_test_suite(args: argparse.Namespace) -> TestSuite:
    """Load the test suite referenced on the command line and validate it.

    Args:
        args: Parsed CLI arguments. The positional ``path-to-test-cases`` is
            read via `getattr` because its name contains dashes; it falls back
            to `DEFAULT_INPUT_TESTS_PATH` when absent.

    Returns:
        The loaded test suite, after `validate_test_cases` has been run on
        its test cases.
    """
    test_suite = read_test_suite(
        getattr(args, "path-to-test-cases", DEFAULT_INPUT_TESTS_PATH)
    )
    validate_test_cases(test_suite.test_cases)
    return test_suite
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def set_up_available_endpoints(args: argparse.Namespace) -> AvailableEndpoints:
    """Load the endpoint configuration used by the test runner.

    Side effect: ``args.endpoints`` is overwritten with the validated path.

    Args:
        args: Parsed CLI arguments providing ``endpoints``.

    Returns:
        The loaded endpoints, with tracker store, lock store, event broker and
        NLG endpoint cleared.
    """
    validated_endpoints_path = rasa.cli.utils.get_validated_path(
        args.endpoints, "endpoints", DEFAULT_ENDPOINTS_PATH, True
    )
    args.endpoints = validated_endpoints_path
    endpoints = AvailableEndpoints.get_instance(validated_endpoints_path)

    # Only the action server, model and nlu endpoints are kept. Clearing the
    # stores/broker ensures an InMemoryTrackerStore is used instead of a
    # production one, and NLG is not needed for dialogue understanding tests.
    for disabled_endpoint in ("tracker_store", "lock_store", "event_broker", "nlg"):
        setattr(endpoints, disabled_endpoint, None)

    return endpoints
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def set_up_stub_custom_actions(
    test_suite: TestSuite, endpoints: AvailableEndpoints
) -> None:
    """Expose the suite's stub custom actions to the action endpoint.

    Does nothing when the suite defines no stubs. Otherwise an empty
    `EndpointConfig` is created if no action endpoint is configured, and the
    stubs are stored in its ``kwargs`` under `KEY_STUB_CUSTOM_ACTIONS`.

    Args:
        test_suite: The loaded test suite.
        endpoints: The endpoints used by the test runner; modified in place.
    """
    stub_actions = test_suite.stub_custom_actions
    if not stub_actions:
        return

    if not endpoints.action:
        endpoints.action = EndpointConfig()

    endpoints.action.kwargs[KEY_STUB_CUSTOM_ACTIONS] = stub_actions
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def split_test_results(
    results: List[DialogueUnderstandingTestResult],
) -> tuple[
    List[DialogueUnderstandingTestResult], List[DialogueUnderstandingTestResult]
]:
    """Partition test results by their ``passed`` flag.

    Args:
        results: Results of the executed test cases.

    Returns:
        A ``(passed, failed)`` tuple of lists; both keep the input order.
    """
    passed = []
    failed = []
    for result in results:
        bucket = passed if result.passed else failed
        bucket.append(result)

    return passed, failed
|
rasa/core/actions/action.py
CHANGED
|
@@ -17,15 +17,6 @@ from jsonschema import Draft202012Validator
|
|
|
17
17
|
|
|
18
18
|
import rasa.core
|
|
19
19
|
import rasa.shared.utils.io
|
|
20
|
-
from rasa.shared.constants import (
|
|
21
|
-
TEXT,
|
|
22
|
-
ELEMENTS,
|
|
23
|
-
QUICK_REPLIES,
|
|
24
|
-
BUTTONS,
|
|
25
|
-
ATTACHMENT,
|
|
26
|
-
IMAGE,
|
|
27
|
-
CUSTOM,
|
|
28
|
-
)
|
|
29
20
|
from rasa.core.actions.custom_action_executor import (
|
|
30
21
|
CustomActionExecutor,
|
|
31
22
|
NoEndpointCustomActionExecutor,
|
|
@@ -264,18 +255,18 @@ def action_for_name_or_text(
|
|
|
264
255
|
def create_bot_utterance(message: Dict[Text, Any]) -> BotUttered:
|
|
265
256
|
"""Create BotUttered event from message."""
|
|
266
257
|
bot_message = BotUttered(
|
|
267
|
-
text=message.pop(
|
|
258
|
+
text=message.pop("text", None),
|
|
268
259
|
data={
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
260
|
+
"elements": message.pop("elements", None),
|
|
261
|
+
"quick_replies": message.pop("quick_replies", None),
|
|
262
|
+
"buttons": message.pop("buttons", None),
|
|
272
263
|
# for legacy / compatibility reasons we need to set the image
|
|
273
264
|
# to be the attachment if there is no other attachment (the
|
|
274
265
|
# `.get` is intentional - no `pop` as we still need the image`
|
|
275
266
|
# property to set it in the following line)
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
267
|
+
"attachment": message.pop("attachment", None) or message.get("image", None),
|
|
268
|
+
"image": message.pop("image", None),
|
|
269
|
+
"custom": message.pop("custom", None),
|
|
279
270
|
},
|
|
280
271
|
metadata=message,
|
|
281
272
|
)
|
rasa/core/channels/__init__.py
CHANGED
|
@@ -31,7 +31,6 @@ from rasa.core.channels.vier_cvg import CVGInput
|
|
|
31
31
|
from rasa.core.channels.voice_stream.twilio_media_streams import (
|
|
32
32
|
TwilioMediaStreamsInputChannel,
|
|
33
33
|
)
|
|
34
|
-
from rasa.core.channels.studio_chat import StudioChatInput
|
|
35
34
|
|
|
36
35
|
input_channel_classes: List[Type[InputChannel]] = [
|
|
37
36
|
CmdlineInput,
|
|
@@ -54,7 +53,6 @@ input_channel_classes: List[Type[InputChannel]] = [
|
|
|
54
53
|
JambonzVoiceReadyInput,
|
|
55
54
|
TwilioMediaStreamsInputChannel,
|
|
56
55
|
BrowserAudioInputChannel,
|
|
57
|
-
StudioChatInput,
|
|
58
56
|
]
|
|
59
57
|
|
|
60
58
|
# Mapping from an input channel name to its class to allow name based lookup.
|
rasa/core/channels/socketio.py
CHANGED
|
@@ -54,9 +54,31 @@ class SocketIOOutput(OutputChannel):
|
|
|
54
54
|
super().__init__()
|
|
55
55
|
self.sio = sio
|
|
56
56
|
self.bot_message_evt = bot_message_evt
|
|
57
|
+
self.last_event_timestamp = (
|
|
58
|
+
-1
|
|
59
|
+
) # Initialize with -1 to send all events on first message
|
|
60
|
+
|
|
61
|
+
def _get_new_events(self) -> List[Dict[Text, Any]]:
|
|
62
|
+
"""Get events that are newer than the last sent event."""
|
|
63
|
+
events = self.tracker_state.get("events", []) if self.tracker_state else []
|
|
64
|
+
new_events = [
|
|
65
|
+
event for event in events if event["timestamp"] > self.last_event_timestamp
|
|
66
|
+
]
|
|
67
|
+
if new_events:
|
|
68
|
+
self.last_event_timestamp = new_events[-1]["timestamp"]
|
|
69
|
+
return new_events
|
|
57
70
|
|
|
58
71
|
async def _send_message(self, socket_id: Text, response: Any) -> None:
|
|
59
72
|
"""Sends a message to the recipient using the bot event."""
|
|
73
|
+
# send tracker state (contains stack, slots and more)
|
|
74
|
+
await self.sio.emit("tracker_state", self.tracker_state, room=socket_id)
|
|
75
|
+
|
|
76
|
+
# send new events
|
|
77
|
+
new_events = self._get_new_events()
|
|
78
|
+
if new_events:
|
|
79
|
+
await self.sio.emit("rasa_events", new_events, room=socket_id)
|
|
80
|
+
|
|
81
|
+
# send bot response
|
|
60
82
|
await self.sio.emit(self.bot_message_evt, response, room=socket_id)
|
|
61
83
|
|
|
62
84
|
async def send_text_message(
|
|
@@ -192,7 +214,7 @@ class SocketIOInput(InputChannel):
|
|
|
192
214
|
|
|
193
215
|
def blueprint(
|
|
194
216
|
self, on_new_message: Callable[[UserMessage], Awaitable[Any]]
|
|
195
|
-
) ->
|
|
217
|
+
) -> Blueprint:
|
|
196
218
|
"""Defines a Sanic blueprint."""
|
|
197
219
|
# Workaround so that socketio works with requests from other origins.
|
|
198
220
|
# https://github.com/miguelgrinberg/python-socketio/issues/205#issuecomment-493769183
|
|
@@ -2,7 +2,6 @@ from typing import Any, Dict, Optional, Text
|
|
|
2
2
|
|
|
3
3
|
import structlog
|
|
4
4
|
from jinja2 import Template
|
|
5
|
-
|
|
6
5
|
from rasa import telemetry
|
|
7
6
|
from rasa.core.nlg.response import TemplatedNaturalLanguageGenerator
|
|
8
7
|
from rasa.core.nlg.summarize import summarize_conversation
|
|
@@ -19,14 +18,6 @@ from rasa.shared.constants import (
|
|
|
19
18
|
from rasa.shared.core.domain import KEY_RESPONSES_TEXT, Domain
|
|
20
19
|
from rasa.shared.core.events import BotUttered, UserUttered
|
|
21
20
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
22
|
-
from rasa.shared.nlu.constants import (
|
|
23
|
-
PROMPTS,
|
|
24
|
-
KEY_USER_PROMPT,
|
|
25
|
-
KEY_LLM_RESPONSE_METADATA,
|
|
26
|
-
KEY_PROMPT_NAME,
|
|
27
|
-
KEY_COMPONENT_NAME,
|
|
28
|
-
)
|
|
29
|
-
from rasa.shared.providers.llm.llm_response import LLMResponse
|
|
30
21
|
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
31
22
|
from rasa.shared.utils.llm import (
|
|
32
23
|
DEFAULT_OPENAI_GENERATE_MODEL_NAME,
|
|
@@ -133,39 +124,6 @@ class ContextualResponseRephraser(
|
|
|
133
124
|
ContextualResponseRephraser.__name__,
|
|
134
125
|
)
|
|
135
126
|
|
|
136
|
-
@classmethod
|
|
137
|
-
def _add_prompt_and_llm_metadata_to_response(
|
|
138
|
-
cls,
|
|
139
|
-
response: Dict[str, Any],
|
|
140
|
-
prompt_name: str,
|
|
141
|
-
user_prompt: str,
|
|
142
|
-
llm_response: Optional["LLMResponse"] = None,
|
|
143
|
-
) -> Dict[str, Any]:
|
|
144
|
-
"""Stores the prompt and LLMResponse metadata to response.
|
|
145
|
-
|
|
146
|
-
Args:
|
|
147
|
-
response: The response to add the prompt and LLMResponse metadata to.
|
|
148
|
-
prompt_name: A name identifying prompt usage.
|
|
149
|
-
user_prompt: The user prompt that was sent to the LLM.
|
|
150
|
-
llm_response: The response object from the LLM (None if no response).
|
|
151
|
-
"""
|
|
152
|
-
from rasa.dialogue_understanding.utils import record_commands_and_prompts
|
|
153
|
-
|
|
154
|
-
if not record_commands_and_prompts:
|
|
155
|
-
return response
|
|
156
|
-
|
|
157
|
-
prompt_data: Dict[Text, Any] = {
|
|
158
|
-
KEY_COMPONENT_NAME: cls.__name__,
|
|
159
|
-
KEY_PROMPT_NAME: prompt_name,
|
|
160
|
-
KEY_USER_PROMPT: user_prompt,
|
|
161
|
-
KEY_LLM_RESPONSE_METADATA: llm_response.to_dict() if llm_response else None,
|
|
162
|
-
}
|
|
163
|
-
|
|
164
|
-
prompts = response.get(PROMPTS, [])
|
|
165
|
-
prompts.append(prompt_data)
|
|
166
|
-
response[PROMPTS] = prompts
|
|
167
|
-
return response
|
|
168
|
-
|
|
169
127
|
def _last_message_if_human(self, tracker: DialogueStateTracker) -> Optional[str]:
|
|
170
128
|
"""Returns the latest message from the tracker.
|
|
171
129
|
|
|
@@ -184,21 +142,20 @@ class ContextualResponseRephraser(
|
|
|
184
142
|
return None
|
|
185
143
|
return None
|
|
186
144
|
|
|
187
|
-
async def _generate_llm_response(self, prompt: str) -> Optional[
|
|
188
|
-
"""
|
|
189
|
-
Use LLM to generate a response, returning an LLMResponse object
|
|
190
|
-
containing both the generated text (choices) and metadata.
|
|
145
|
+
async def _generate_llm_response(self, prompt: str) -> Optional[str]:
|
|
146
|
+
"""Use LLM to generate a response.
|
|
191
147
|
|
|
192
148
|
Args:
|
|
193
|
-
prompt:
|
|
149
|
+
prompt: the prompt to send to the LLM
|
|
194
150
|
|
|
195
151
|
Returns:
|
|
196
|
-
|
|
152
|
+
generated text
|
|
197
153
|
"""
|
|
198
154
|
llm = llm_factory(self.llm_config, DEFAULT_LLM_CONFIG)
|
|
199
155
|
|
|
200
156
|
try:
|
|
201
|
-
|
|
157
|
+
llm_response = await llm.acompletion(prompt)
|
|
158
|
+
return llm_response.choices[0]
|
|
202
159
|
except Exception as e:
|
|
203
160
|
# unfortunately, langchain does not wrap LLM exceptions which means
|
|
204
161
|
# we have to catch all exceptions here
|
|
@@ -298,21 +255,11 @@ class ContextualResponseRephraser(
|
|
|
298
255
|
or self.llm_property(MODEL_NAME_CONFIG_KEY),
|
|
299
256
|
llm_model_group_id=self.llm_property(MODEL_GROUP_ID_CONFIG_KEY),
|
|
300
257
|
)
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
response = self._add_prompt_and_llm_metadata_to_response(
|
|
305
|
-
response=response,
|
|
306
|
-
prompt_name="rephrase_prompt",
|
|
307
|
-
user_prompt=prompt,
|
|
308
|
-
llm_response=llm_response,
|
|
309
|
-
)
|
|
310
|
-
|
|
311
|
-
if not (llm_response and llm_response.choices and llm_response.choices[0]):
|
|
312
|
-
# If the LLM fails to generate a response, return the original response.
|
|
258
|
+
if not (updated_text := await self._generate_llm_response(prompt)):
|
|
259
|
+
# If the LLM fails to generate a response, we
|
|
260
|
+
# return the original response.
|
|
313
261
|
return response
|
|
314
262
|
|
|
315
|
-
updated_text = llm_response.choices[0]
|
|
316
263
|
structlogger.debug(
|
|
317
264
|
"nlg.rewrite.complete",
|
|
318
265
|
response_text=response_text,
|
|
@@ -2,7 +2,6 @@ import importlib.resources
|
|
|
2
2
|
import json
|
|
3
3
|
import re
|
|
4
4
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
|
|
5
|
-
|
|
6
5
|
import dotenv
|
|
7
6
|
import structlog
|
|
8
7
|
from jinja2 import Template
|
|
@@ -64,19 +63,11 @@ from rasa.shared.core.events import Event, UserUttered, BotUttered
|
|
|
64
63
|
from rasa.shared.core.generator import TrackerWithCachedStates
|
|
65
64
|
from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity
|
|
66
65
|
from rasa.shared.exceptions import RasaException, FileIOException
|
|
67
|
-
from rasa.shared.nlu.constants import (
|
|
68
|
-
PROMPTS,
|
|
69
|
-
KEY_USER_PROMPT,
|
|
70
|
-
KEY_LLM_RESPONSE_METADATA,
|
|
71
|
-
KEY_PROMPT_NAME,
|
|
72
|
-
KEY_COMPONENT_NAME,
|
|
73
|
-
)
|
|
74
66
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
75
67
|
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
76
68
|
_LangchainEmbeddingClientAdapter,
|
|
77
69
|
)
|
|
78
70
|
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
79
|
-
from rasa.shared.providers.llm.llm_response import LLMResponse
|
|
80
71
|
from rasa.shared.utils.cli import print_error_and_exit
|
|
81
72
|
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
82
73
|
EmbeddingsHealthCheckMixin,
|
|
@@ -281,43 +272,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
281
272
|
# Wrap the embedding client in the adapter
|
|
282
273
|
return _LangchainEmbeddingClientAdapter(client)
|
|
283
274
|
|
|
284
|
-
@classmethod
|
|
285
|
-
def _add_prompt_and_llm_response_to_latest_message(
|
|
286
|
-
cls,
|
|
287
|
-
tracker: DialogueStateTracker,
|
|
288
|
-
prompt_name: str,
|
|
289
|
-
user_prompt: str,
|
|
290
|
-
llm_response: Optional[LLMResponse] = None,
|
|
291
|
-
) -> None:
|
|
292
|
-
"""Stores the prompt and LLMResponse metadata in the tracker.
|
|
293
|
-
|
|
294
|
-
Args:
|
|
295
|
-
tracker: The DialogueStateTracker containing the current conversation state.
|
|
296
|
-
prompt_name: A name identifying prompt usage.
|
|
297
|
-
user_prompt: The user prompt that was sent to the LLM.
|
|
298
|
-
llm_response: The response object from the LLM (None if no response).
|
|
299
|
-
"""
|
|
300
|
-
from rasa.dialogue_understanding.utils import record_commands_and_prompts
|
|
301
|
-
|
|
302
|
-
if not record_commands_and_prompts:
|
|
303
|
-
return
|
|
304
|
-
|
|
305
|
-
if not tracker.latest_message:
|
|
306
|
-
return
|
|
307
|
-
|
|
308
|
-
parse_data = tracker.latest_message.parse_data
|
|
309
|
-
if PROMPTS not in parse_data:
|
|
310
|
-
parse_data[PROMPTS] = [] # type: ignore[literal-required]
|
|
311
|
-
|
|
312
|
-
prompt_data: Dict[Text, Any] = {
|
|
313
|
-
KEY_COMPONENT_NAME: cls.__name__,
|
|
314
|
-
KEY_PROMPT_NAME: prompt_name,
|
|
315
|
-
KEY_USER_PROMPT: user_prompt,
|
|
316
|
-
KEY_LLM_RESPONSE_METADATA: llm_response.to_dict() if llm_response else None,
|
|
317
|
-
}
|
|
318
|
-
|
|
319
|
-
parse_data[PROMPTS].append(prompt_data) # type: ignore[literal-required]
|
|
320
|
-
|
|
321
275
|
def train( # type: ignore[override]
|
|
322
276
|
self,
|
|
323
277
|
training_trackers: List[TrackerWithCachedStates],
|
|
@@ -544,27 +498,13 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
544
498
|
|
|
545
499
|
if self.use_llm:
|
|
546
500
|
prompt = self._render_prompt(tracker, documents.results)
|
|
547
|
-
|
|
548
|
-
llm_response = LLMResponse.ensure_llm_response(llm_response)
|
|
549
|
-
|
|
550
|
-
self._add_prompt_and_llm_response_to_latest_message(
|
|
551
|
-
tracker=tracker,
|
|
552
|
-
prompt_name="enterprise_search_prompt",
|
|
553
|
-
user_prompt=prompt,
|
|
554
|
-
llm_response=llm_response,
|
|
555
|
-
)
|
|
501
|
+
llm_answer = await self._generate_llm_answer(llm, prompt)
|
|
556
502
|
|
|
557
|
-
if
|
|
558
|
-
|
|
559
|
-
response = None
|
|
560
|
-
else:
|
|
561
|
-
llm_answer = llm_response.choices[0]
|
|
503
|
+
if self.citation_enabled:
|
|
504
|
+
llm_answer = self.post_process_citations(llm_answer)
|
|
562
505
|
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
logger.debug(f"{logger_key}.llm_answer", llm_answer=llm_answer)
|
|
567
|
-
response = llm_answer
|
|
506
|
+
logger.debug(f"{logger_key}.llm_answer", llm_answer=llm_answer)
|
|
507
|
+
response = llm_answer
|
|
568
508
|
else:
|
|
569
509
|
response = documents.results[0].metadata.get("answer", None)
|
|
570
510
|
if not response:
|
|
@@ -576,6 +516,7 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
576
516
|
"enterprise_search_policy.predict_action_probabilities.no_llm",
|
|
577
517
|
search_results=documents,
|
|
578
518
|
)
|
|
519
|
+
|
|
579
520
|
if response is None:
|
|
580
521
|
return self._create_prediction_internal_error(domain, tracker)
|
|
581
522
|
|
|
@@ -640,18 +581,10 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
640
581
|
|
|
641
582
|
async def _generate_llm_answer(
|
|
642
583
|
self, llm: LLMClient, prompt: Text
|
|
643
|
-
) -> Optional[
|
|
644
|
-
"""Fetches an LLM completion for the provided prompt.
|
|
645
|
-
|
|
646
|
-
Args:
|
|
647
|
-
llm: The LLM client used to get the completion.
|
|
648
|
-
prompt: The prompt text to send to the model.
|
|
649
|
-
|
|
650
|
-
Returns:
|
|
651
|
-
An LLMResponse object, or None if the call fails.
|
|
652
|
-
"""
|
|
584
|
+
) -> Optional[Text]:
|
|
653
585
|
try:
|
|
654
|
-
|
|
586
|
+
llm_response = await llm.acompletion(prompt)
|
|
587
|
+
llm_answer = llm_response.choices[0]
|
|
655
588
|
except Exception as e:
|
|
656
589
|
# unfortunately, langchain does not wrap LLM exceptions which means
|
|
657
590
|
# we have to catch all exceptions here
|
|
@@ -659,7 +592,9 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
659
592
|
"enterprise_search_policy._generate_llm_answer.llm_error",
|
|
660
593
|
error=e,
|
|
661
594
|
)
|
|
662
|
-
|
|
595
|
+
llm_answer = None
|
|
596
|
+
|
|
597
|
+
return llm_answer
|
|
663
598
|
|
|
664
599
|
def _create_prediction(
|
|
665
600
|
self,
|