rasa-pro 3.13.13__py3-none-any.whl → 3.13.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/constants.py +1 -0
- rasa/core/actions/action_clean_stack.py +32 -0
- rasa/core/actions/constants.py +4 -0
- rasa/core/actions/custom_action_executor.py +70 -12
- rasa/core/actions/grpc_custom_action_executor.py +41 -2
- rasa/core/actions/http_custom_action_executor.py +49 -25
- rasa/core/channels/voice_stream/voice_channel.py +26 -16
- rasa/core/policies/flows/flow_executor.py +20 -6
- rasa/core/run.py +0 -1
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +6 -3
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +15 -8
- rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +15 -8
- rasa/dialogue_understanding/processor/command_processor.py +12 -3
- rasa/e2e_test/e2e_config.py +4 -3
- rasa/model_manager/socket_bridge.py +1 -2
- rasa/shared/core/flows/flow.py +8 -2
- rasa/shared/core/slots.py +55 -24
- rasa/shared/providers/_configs/azure_openai_client_config.py +4 -5
- rasa/shared/providers/_configs/default_litellm_client_config.py +4 -4
- rasa/shared/providers/_configs/litellm_router_client_config.py +3 -2
- rasa/shared/providers/_configs/openai_client_config.py +5 -7
- rasa/shared/providers/_configs/rasa_llm_client_config.py +4 -4
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +4 -4
- rasa/shared/providers/llm/_base_litellm_client.py +42 -13
- rasa/shared/providers/llm/litellm_router_llm_client.py +37 -15
- rasa/shared/providers/llm/self_hosted_llm_client.py +34 -32
- rasa/shared/utils/configs.py +5 -8
- rasa/utils/common.py +9 -0
- rasa/utils/endpoints.py +6 -0
- rasa/version.py +1 -1
- {rasa_pro-3.13.13.dist-info → rasa_pro-3.13.14.dist-info}/METADATA +2 -2
- {rasa_pro-3.13.13.dist-info → rasa_pro-3.13.14.dist-info}/RECORD +35 -35
- {rasa_pro-3.13.13.dist-info → rasa_pro-3.13.14.dist-info}/NOTICE +0 -0
- {rasa_pro-3.13.13.dist-info → rasa_pro-3.13.14.dist-info}/WHEEL +0 -0
- {rasa_pro-3.13.13.dist-info → rasa_pro-3.13.14.dist-info}/entry_points.txt +0 -0
rasa/constants.py
CHANGED
|
@@ -27,6 +27,7 @@ ENV_LOG_LEVEL_LIBRARIES = "LOG_LEVEL_LIBRARIES"
|
|
|
27
27
|
ENV_LOG_LEVEL_MATPLOTLIB = "LOG_LEVEL_MATPLOTLIB"
|
|
28
28
|
ENV_LOG_LEVEL_RABBITMQ = "LOG_LEVEL_RABBITMQ"
|
|
29
29
|
ENV_LOG_LEVEL_KAFKA = "LOG_LEVEL_KAFKA"
|
|
30
|
+
ENV_LOG_LEVEL_PYMONGO = "LOG_LEVEL_PYMONGO"
|
|
30
31
|
|
|
31
32
|
DEFAULT_SANIC_WORKERS = 1
|
|
32
33
|
ENV_SANIC_WORKERS = "SANIC_WORKERS"
|
|
@@ -4,9 +4,11 @@ from typing import Any, Dict, List, Optional
|
|
|
4
4
|
|
|
5
5
|
import structlog
|
|
6
6
|
|
|
7
|
+
import rasa.dialogue_understanding.stack.utils
|
|
7
8
|
from rasa.core.actions.action import Action
|
|
8
9
|
from rasa.core.channels import OutputChannel
|
|
9
10
|
from rasa.core.nlg import NaturalLanguageGenerator
|
|
11
|
+
from rasa.dialogue_understanding.patterns.code_change import FLOW_PATTERN_CODE_CHANGE_ID
|
|
10
12
|
from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack
|
|
11
13
|
from rasa.dialogue_understanding.stack.frames import (
|
|
12
14
|
BaseFlowStackFrame,
|
|
@@ -41,6 +43,15 @@ class ActionCleanStack(Action):
|
|
|
41
43
|
"""Clean the stack."""
|
|
42
44
|
structlogger.debug("action_clean_stack.run")
|
|
43
45
|
new_frames = []
|
|
46
|
+
top_flow_frame = rasa.dialogue_understanding.stack.utils.top_flow_frame(
|
|
47
|
+
tracker.stack, ignore_call_frames=False
|
|
48
|
+
)
|
|
49
|
+
top_user_flow_frame = (
|
|
50
|
+
rasa.dialogue_understanding.stack.utils.top_user_flow_frame(
|
|
51
|
+
tracker.stack, ignore_call_and_link_frames=False
|
|
52
|
+
)
|
|
53
|
+
)
|
|
54
|
+
|
|
44
55
|
# Set all frames to their end step, filter out any non-BaseFlowStackFrames
|
|
45
56
|
for frame in tracker.stack.frames:
|
|
46
57
|
if isinstance(frame, BaseFlowStackFrame):
|
|
@@ -56,4 +67,25 @@ class ActionCleanStack(Action):
|
|
|
56
67
|
new_frames.append(frame)
|
|
57
68
|
new_stack = DialogueStack.from_dict([frame.as_dict() for frame in new_frames])
|
|
58
69
|
|
|
70
|
+
# Check if the action is being called from within a user flow
|
|
71
|
+
if (
|
|
72
|
+
top_flow_frame
|
|
73
|
+
and top_flow_frame.flow_id != FLOW_PATTERN_CODE_CHANGE_ID
|
|
74
|
+
and top_user_flow_frame
|
|
75
|
+
and top_user_flow_frame.flow_id == top_flow_frame.flow_id
|
|
76
|
+
):
|
|
77
|
+
# The action is being called from within a user flow on the stack.
|
|
78
|
+
# If there are other frames on the stack, we need to make sure
|
|
79
|
+
# the last executed frame is the end step of the current user flow so
|
|
80
|
+
# that we can trigger pattern_completed for this user flow.
|
|
81
|
+
new_stack.pop()
|
|
82
|
+
structlogger.debug(
|
|
83
|
+
"action_clean_stack.pushing_user_frame_at_the_bottom_of_stack",
|
|
84
|
+
flow_id=top_user_flow_frame.flow_id,
|
|
85
|
+
)
|
|
86
|
+
new_stack.push(
|
|
87
|
+
top_user_flow_frame,
|
|
88
|
+
index=0,
|
|
89
|
+
)
|
|
90
|
+
|
|
59
91
|
return tracker.create_stack_updated_events(new_stack)
|
rasa/core/actions/constants.py
CHANGED
|
@@ -3,3 +3,7 @@ SELECTIVE_DOMAIN = "enable_selective_domain"
|
|
|
3
3
|
|
|
4
4
|
SSL_CLIENT_CERT_FIELD = "ssl_client_cert"
|
|
5
5
|
SSL_CLIENT_KEY_FIELD = "ssl_client_key"
|
|
6
|
+
|
|
7
|
+
# Special marker key used by EndpointConfig to indicate 449 status
|
|
8
|
+
# without raising an exception
|
|
9
|
+
MISSING_DOMAIN_MARKER = "missing_domain"
|
|
@@ -2,7 +2,10 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import abc
|
|
4
4
|
import logging
|
|
5
|
-
from
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional, Text
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel
|
|
6
9
|
|
|
7
10
|
import rasa
|
|
8
11
|
from rasa.core.actions.action_exceptions import DomainNotFound
|
|
@@ -19,6 +22,23 @@ if TYPE_CHECKING:
|
|
|
19
22
|
logger = logging.getLogger(__name__)
|
|
20
23
|
|
|
21
24
|
|
|
25
|
+
class ActionResultType(Enum):
|
|
26
|
+
SUCCESS = "success"
|
|
27
|
+
RETRY_WITH_DOMAIN = "retry_with_domain"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ActionResult(BaseModel):
|
|
31
|
+
"""Result of custom action execution.
|
|
32
|
+
|
|
33
|
+
This is used to avoid raising exceptions for expected conditions
|
|
34
|
+
like missing domain (449 status code), which would otherwise be
|
|
35
|
+
captured by tracing as errors.
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
result_type: ActionResultType
|
|
39
|
+
response: Optional[Dict[Text, Any]] = None
|
|
40
|
+
|
|
41
|
+
|
|
22
42
|
class CustomActionExecutor(abc.ABC):
|
|
23
43
|
"""Interface for custom action executors.
|
|
24
44
|
|
|
@@ -45,6 +65,34 @@ class CustomActionExecutor(abc.ABC):
|
|
|
45
65
|
"""
|
|
46
66
|
pass
|
|
47
67
|
|
|
68
|
+
async def run_with_result(
|
|
69
|
+
self,
|
|
70
|
+
tracker: "DialogueStateTracker",
|
|
71
|
+
domain: "Domain",
|
|
72
|
+
include_domain: bool = False,
|
|
73
|
+
) -> ActionResult:
|
|
74
|
+
"""Executes the custom action and returns a result.
|
|
75
|
+
|
|
76
|
+
This method is used to avoid raising exceptions for expected conditions
|
|
77
|
+
like missing domain, which would otherwise be captured by tracing as errors.
|
|
78
|
+
|
|
79
|
+
By default, this method calls the run method and wraps the response
|
|
80
|
+
for backward compatibility.
|
|
81
|
+
|
|
82
|
+
Args:
|
|
83
|
+
tracker: The current state of the dialogue.
|
|
84
|
+
domain: The domain object containing domain-specific information.
|
|
85
|
+
include_domain: If True, the domain is included in the request.
|
|
86
|
+
|
|
87
|
+
Returns:
|
|
88
|
+
ActionResult containing the response and result type.
|
|
89
|
+
"""
|
|
90
|
+
try:
|
|
91
|
+
response = await self.run(tracker, domain, include_domain)
|
|
92
|
+
return ActionResult(result_type=ActionResultType.SUCCESS, response=response)
|
|
93
|
+
except DomainNotFound:
|
|
94
|
+
return ActionResult(result_type=ActionResultType.RETRY_WITH_DOMAIN)
|
|
95
|
+
|
|
48
96
|
|
|
49
97
|
class NoEndpointCustomActionExecutor(CustomActionExecutor):
|
|
50
98
|
"""Implementation of a custom action executor when endpoint is not set.
|
|
@@ -163,13 +211,13 @@ class RetryCustomActionExecutor(CustomActionExecutor):
|
|
|
163
211
|
domain: "Domain",
|
|
164
212
|
include_domain: bool = False,
|
|
165
213
|
) -> Dict[Text, Any]:
|
|
166
|
-
"""Runs the wrapped custom action executor.
|
|
214
|
+
"""Runs the wrapped custom action executor with retry logic.
|
|
167
215
|
|
|
168
216
|
First request to the action server is made with/without the domain
|
|
169
217
|
as specified by the `include_domain` parameter.
|
|
170
218
|
|
|
171
|
-
If the action server responds with a
|
|
172
|
-
|
|
219
|
+
If the action server responds with a missing domain indication,
|
|
220
|
+
retries the request with the domain included.
|
|
173
221
|
|
|
174
222
|
Args:
|
|
175
223
|
tracker: The current state of the dialogue.
|
|
@@ -178,14 +226,24 @@ class RetryCustomActionExecutor(CustomActionExecutor):
|
|
|
178
226
|
|
|
179
227
|
Returns:
|
|
180
228
|
The response from the execution of the custom action.
|
|
229
|
+
|
|
230
|
+
Raises:
|
|
231
|
+
DomainNotFound: If the action server still requires domain after retry.
|
|
181
232
|
"""
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
233
|
+
result = await self._custom_action_executor.run_with_result(
|
|
234
|
+
tracker,
|
|
235
|
+
domain,
|
|
236
|
+
include_domain=include_domain,
|
|
237
|
+
)
|
|
238
|
+
|
|
239
|
+
if result.result_type == ActionResultType.RETRY_WITH_DOMAIN:
|
|
240
|
+
# Retry with domain included
|
|
241
|
+
result = await self._custom_action_executor.run_with_result(
|
|
190
242
|
tracker, domain, include_domain=True
|
|
191
243
|
)
|
|
244
|
+
|
|
245
|
+
# If still missing domain after retry, raise error
|
|
246
|
+
if result.result_type == ActionResultType.RETRY_WITH_DOMAIN:
|
|
247
|
+
raise DomainNotFound()
|
|
248
|
+
|
|
249
|
+
return result.response if result.response is not None else {}
|
|
@@ -11,6 +11,8 @@ from rasa_sdk.grpc_py import action_webhook_pb2, action_webhook_pb2_grpc
|
|
|
11
11
|
from rasa.core.actions.action_exceptions import DomainNotFound
|
|
12
12
|
from rasa.core.actions.constants import SSL_CLIENT_CERT_FIELD, SSL_CLIENT_KEY_FIELD
|
|
13
13
|
from rasa.core.actions.custom_action_executor import (
|
|
14
|
+
ActionResult,
|
|
15
|
+
ActionResultType,
|
|
14
16
|
CustomActionExecutor,
|
|
15
17
|
CustomActionRequestWriter,
|
|
16
18
|
)
|
|
@@ -101,13 +103,51 @@ class GRPCCustomActionExecutor(CustomActionExecutor):
|
|
|
101
103
|
|
|
102
104
|
Returns:
|
|
103
105
|
Response from the action server.
|
|
106
|
+
Returns empty dict if domain is missing.
|
|
107
|
+
|
|
108
|
+
Raises:
|
|
109
|
+
RasaException: If an error occurs while making the gRPC request
|
|
110
|
+
(other than missing domain).
|
|
104
111
|
"""
|
|
112
|
+
result = await self.run_with_result(tracker, domain, include_domain)
|
|
113
|
+
|
|
114
|
+
# Return empty dict for retry cases to avoid raising exceptions
|
|
115
|
+
# RetryCustomActionExecutor will handle the retry logic
|
|
116
|
+
if result.result_type == ActionResultType.RETRY_WITH_DOMAIN:
|
|
117
|
+
return {}
|
|
118
|
+
|
|
119
|
+
return result.response if result.response is not None else {}
|
|
120
|
+
|
|
121
|
+
async def run_with_result(
|
|
122
|
+
self,
|
|
123
|
+
tracker: "DialogueStateTracker",
|
|
124
|
+
domain: "Domain",
|
|
125
|
+
include_domain: bool = False,
|
|
126
|
+
) -> ActionResult:
|
|
127
|
+
"""Execute the custom action and return an ActionResult.
|
|
128
|
+
|
|
129
|
+
This method avoids raising DomainNotFound exception for missing domain,
|
|
130
|
+
instead returning an ActionResult with RETRY_WITH_DOMAIN type.
|
|
131
|
+
This prevents tracing from capturing this expected condition as an error.
|
|
105
132
|
|
|
133
|
+
Args:
|
|
134
|
+
tracker: Tracker for the current conversation.
|
|
135
|
+
domain: Domain of the assistant.
|
|
136
|
+
include_domain: If True, the domain is included in the request.
|
|
137
|
+
|
|
138
|
+
Returns:
|
|
139
|
+
ActionResult containing the response and result type.
|
|
140
|
+
"""
|
|
106
141
|
request = self._create_payload(
|
|
107
142
|
tracker=tracker, domain=domain, include_domain=include_domain
|
|
108
143
|
)
|
|
109
144
|
|
|
110
|
-
|
|
145
|
+
try:
|
|
146
|
+
response = self._request(request)
|
|
147
|
+
return ActionResult(result_type=ActionResultType.SUCCESS, response=response)
|
|
148
|
+
except DomainNotFound:
|
|
149
|
+
# Return retry result instead of raising DomainNotFound
|
|
150
|
+
return ActionResult(result_type=ActionResultType.RETRY_WITH_DOMAIN)
|
|
111
151
|
|
|
112
152
|
def _request(
|
|
113
153
|
self,
|
|
@@ -121,7 +161,6 @@ class GRPCCustomActionExecutor(CustomActionExecutor):
|
|
|
121
161
|
Returns:
|
|
122
162
|
Response from the action server.
|
|
123
163
|
"""
|
|
124
|
-
|
|
125
164
|
client = self._create_grpc_client()
|
|
126
165
|
metadata = self._build_metadata()
|
|
127
166
|
try:
|
|
@@ -4,8 +4,11 @@ from typing import TYPE_CHECKING, Any, Dict, Optional
|
|
|
4
4
|
|
|
5
5
|
import aiohttp
|
|
6
6
|
|
|
7
|
-
from rasa.core.actions.action_exceptions import ActionExecutionRejection
|
|
7
|
+
from rasa.core.actions.action_exceptions import ActionExecutionRejection
|
|
8
|
+
from rasa.core.actions.constants import MISSING_DOMAIN_MARKER
|
|
8
9
|
from rasa.core.actions.custom_action_executor import (
|
|
10
|
+
ActionResult,
|
|
11
|
+
ActionResultType,
|
|
9
12
|
CustomActionExecutor,
|
|
10
13
|
CustomActionRequestWriter,
|
|
11
14
|
)
|
|
@@ -18,12 +21,12 @@ from rasa.shared.core.domain import Domain
|
|
|
18
21
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
19
22
|
from rasa.shared.exceptions import RasaException
|
|
20
23
|
from rasa.utils.common import get_bool_env_variable
|
|
24
|
+
from rasa.utils.endpoints import ClientResponseError, EndpointConfig
|
|
21
25
|
|
|
22
26
|
if TYPE_CHECKING:
|
|
23
27
|
from rasa.shared.core.domain import Domain
|
|
24
28
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
25
29
|
|
|
26
|
-
from rasa.utils.endpoints import ClientResponseError, EndpointConfig
|
|
27
30
|
|
|
28
31
|
logger = logging.getLogger(__name__)
|
|
29
32
|
|
|
@@ -62,9 +65,40 @@ class HTTPCustomActionExecutor(CustomActionExecutor):
|
|
|
62
65
|
|
|
63
66
|
Returns:
|
|
64
67
|
A dictionary containing the response from the custom action endpoint.
|
|
68
|
+
Returns empty dict if domain is missing (449 status).
|
|
65
69
|
|
|
66
70
|
Raises:
|
|
67
|
-
RasaException: If an error occurs while making the HTTP request
|
|
71
|
+
RasaException: If an error occurs while making the HTTP request
|
|
72
|
+
(other than missing domain).
|
|
73
|
+
"""
|
|
74
|
+
result = await self.run_with_result(tracker, domain, include_domain)
|
|
75
|
+
|
|
76
|
+
# Return empty dict for retry cases to avoid raising exceptions
|
|
77
|
+
# RetryCustomActionExecutor will handle the retry logic
|
|
78
|
+
if result.result_type == ActionResultType.RETRY_WITH_DOMAIN:
|
|
79
|
+
return {}
|
|
80
|
+
|
|
81
|
+
return result.response if result.response is not None else {}
|
|
82
|
+
|
|
83
|
+
async def run_with_result(
|
|
84
|
+
self,
|
|
85
|
+
tracker: "DialogueStateTracker",
|
|
86
|
+
domain: Optional["Domain"] = None,
|
|
87
|
+
include_domain: bool = False,
|
|
88
|
+
) -> ActionResult:
|
|
89
|
+
"""Execute the custom action and return an ActionResult.
|
|
90
|
+
|
|
91
|
+
This method avoids raising DomainNotFound exception for 449 status code,
|
|
92
|
+
instead returning an ActionResult with RETRY_WITH_DOMAIN type.
|
|
93
|
+
This prevents tracing from capturing this expected condition as an error.
|
|
94
|
+
|
|
95
|
+
Args:
|
|
96
|
+
tracker: The current state of the dialogue.
|
|
97
|
+
domain: The domain object containing domain-specific information.
|
|
98
|
+
include_domain: If True, the domain is included in the request.
|
|
99
|
+
|
|
100
|
+
Returns:
|
|
101
|
+
ActionResult containing the response and result type.
|
|
68
102
|
"""
|
|
69
103
|
from rasa.core.actions.action import RemoteActionJSONValidator
|
|
70
104
|
|
|
@@ -77,14 +111,23 @@ class HTTPCustomActionExecutor(CustomActionExecutor):
|
|
|
77
111
|
tracker=tracker, domain=domain, include_domain=include_domain
|
|
78
112
|
)
|
|
79
113
|
|
|
80
|
-
|
|
114
|
+
assert self.action_endpoint is not None
|
|
115
|
+
response = await self.action_endpoint.request(
|
|
116
|
+
json=json_body,
|
|
117
|
+
method="post",
|
|
118
|
+
timeout=DEFAULT_REQUEST_TIMEOUT,
|
|
119
|
+
compress=self.should_compress,
|
|
120
|
+
)
|
|
121
|
+
|
|
122
|
+
# Check if we got the special marker for 449 status (missing domain)
|
|
123
|
+
if isinstance(response, dict) and response.get(MISSING_DOMAIN_MARKER):
|
|
124
|
+
return ActionResult(result_type=ActionResultType.RETRY_WITH_DOMAIN)
|
|
81
125
|
|
|
82
126
|
if response is None:
|
|
83
127
|
response = {}
|
|
84
128
|
|
|
85
129
|
RemoteActionJSONValidator.validate(response)
|
|
86
|
-
|
|
87
|
-
return response
|
|
130
|
+
return ActionResult(result_type=ActionResultType.SUCCESS, response=response)
|
|
88
131
|
|
|
89
132
|
except ClientResponseError as e:
|
|
90
133
|
if e.status == 400:
|
|
@@ -131,22 +174,3 @@ class HTTPCustomActionExecutor(CustomActionExecutor):
|
|
|
131
174
|
"and returns a 200 once the action is executed. "
|
|
132
175
|
"Error: {}".format(self.action_name, status, e)
|
|
133
176
|
)
|
|
134
|
-
|
|
135
|
-
async def _perform_request_with_retries(
|
|
136
|
-
self,
|
|
137
|
-
json_body: Dict[str, Any],
|
|
138
|
-
) -> Any:
|
|
139
|
-
"""Attempts to perform the request with retries if necessary."""
|
|
140
|
-
assert self.action_endpoint is not None
|
|
141
|
-
try:
|
|
142
|
-
return await self.action_endpoint.request(
|
|
143
|
-
json=json_body,
|
|
144
|
-
method="post",
|
|
145
|
-
timeout=DEFAULT_REQUEST_TIMEOUT,
|
|
146
|
-
compress=self.should_compress,
|
|
147
|
-
)
|
|
148
|
-
except ClientResponseError as e:
|
|
149
|
-
# Repeat the request because Domain was not in the payload
|
|
150
|
-
if e.status == 449:
|
|
151
|
-
raise DomainNotFound()
|
|
152
|
-
raise e
|
|
@@ -176,6 +176,8 @@ class VoiceOutputChannel(OutputChannel):
|
|
|
176
176
|
async def send_start_marker(self, recipient_id: str) -> None:
|
|
177
177
|
"""Send a marker message before the first audio chunk."""
|
|
178
178
|
# Default implementation uses the generic marker message
|
|
179
|
+
call_state.is_bot_speaking = True # type: ignore[attr-defined]
|
|
180
|
+
VoiceInputChannel._cancel_silence_timeout_watcher()
|
|
179
181
|
await self.send_marker_message(recipient_id)
|
|
180
182
|
|
|
181
183
|
async def send_intermediate_marker(self, recipient_id: str) -> None:
|
|
@@ -212,10 +214,6 @@ class VoiceOutputChannel(OutputChannel):
|
|
|
212
214
|
) -> None:
|
|
213
215
|
text = remove_emojis(text)
|
|
214
216
|
self.update_silence_timeout()
|
|
215
|
-
cached_audio_bytes = self.tts_cache.get(text)
|
|
216
|
-
collected_audio_bytes = RasaAudioBytes(b"")
|
|
217
|
-
seconds_marker = -1
|
|
218
|
-
last_sent_offset = 0
|
|
219
217
|
logger.debug("voice_channel.sending_audio", text=text)
|
|
220
218
|
|
|
221
219
|
# Send start marker before first chunk
|
|
@@ -224,17 +222,11 @@ class VoiceOutputChannel(OutputChannel):
|
|
|
224
222
|
except (WebsocketClosed, ServerError):
|
|
225
223
|
call_state.connection_failed = True # type: ignore[attr-defined]
|
|
226
224
|
|
|
227
|
-
|
|
228
|
-
audio_stream = self.chunk_audio(cached_audio_bytes)
|
|
229
|
-
else:
|
|
230
|
-
# Todo: make kwargs compatible with engine config
|
|
231
|
-
synth_config = self.tts_engine.config.__class__.from_dict({})
|
|
232
|
-
try:
|
|
233
|
-
audio_stream = self.tts_engine.synthesize(text, synth_config)
|
|
234
|
-
except TTSError:
|
|
235
|
-
# TODO: add message that works without tts, e.g. loading from disc
|
|
236
|
-
audio_stream = self.chunk_audio(generate_silence())
|
|
225
|
+
audio_stream = await self._create_audio_stream(text)
|
|
237
226
|
|
|
227
|
+
collected_audio_bytes = RasaAudioBytes(b"")
|
|
228
|
+
last_sent_offset = 0
|
|
229
|
+
seconds_marker = -1
|
|
238
230
|
async for audio_bytes in audio_stream:
|
|
239
231
|
collected_audio_bytes = RasaAudioBytes(collected_audio_bytes + audio_bytes)
|
|
240
232
|
|
|
@@ -249,6 +241,8 @@ class VoiceOutputChannel(OutputChannel):
|
|
|
249
241
|
await self.send_audio_bytes(recipient_id, new_bytes)
|
|
250
242
|
last_sent_offset = len(collected_audio_bytes)
|
|
251
243
|
|
|
244
|
+
# seconds of audio rounded down to floor number
|
|
245
|
+
# e.g 7 // 2 = 3
|
|
252
246
|
full_seconds_of_audio = len(collected_audio_bytes) // HERTZ
|
|
253
247
|
if full_seconds_of_audio > seconds_marker:
|
|
254
248
|
await self.send_intermediate_marker(recipient_id)
|
|
@@ -275,7 +269,7 @@ class VoiceOutputChannel(OutputChannel):
|
|
|
275
269
|
pass
|
|
276
270
|
call_state.latest_bot_audio_id = self.latest_message_id # type: ignore[attr-defined]
|
|
277
271
|
|
|
278
|
-
if not
|
|
272
|
+
if not self.tts_cache.get(text):
|
|
279
273
|
self.tts_cache.put(text, collected_audio_bytes)
|
|
280
274
|
|
|
281
275
|
async def send_audio_bytes(
|
|
@@ -300,6 +294,22 @@ class VoiceOutputChannel(OutputChannel):
|
|
|
300
294
|
async def hangup(self, recipient_id: str, **kwargs: Any) -> None:
|
|
301
295
|
call_state.should_hangup = True # type: ignore[attr-defined]
|
|
302
296
|
|
|
297
|
+
async def _create_audio_stream(self, text: str) -> AsyncIterator[RasaAudioBytes]:
|
|
298
|
+
cached_audio_bytes = self.tts_cache.get(text)
|
|
299
|
+
|
|
300
|
+
if cached_audio_bytes:
|
|
301
|
+
audio_stream = self.chunk_audio(cached_audio_bytes)
|
|
302
|
+
else:
|
|
303
|
+
# Todo: make kwargs compatible with engine config
|
|
304
|
+
synth_config = self.tts_engine.config.__class__.from_dict({})
|
|
305
|
+
try:
|
|
306
|
+
audio_stream = self.tts_engine.synthesize(text, synth_config)
|
|
307
|
+
except TTSError:
|
|
308
|
+
# TODO: add message that works without tts, e.g. loading from disc
|
|
309
|
+
audio_stream = self.chunk_audio(generate_silence())
|
|
310
|
+
|
|
311
|
+
return audio_stream
|
|
312
|
+
|
|
303
313
|
|
|
304
314
|
class VoiceInputChannel(InputChannel):
|
|
305
315
|
# All children of this class require a voice license to be used.
|
|
@@ -435,7 +445,7 @@ class VoiceInputChannel(InputChannel):
|
|
|
435
445
|
# relevant when the bot speaks multiple messages in one turn
|
|
436
446
|
self._cancel_silence_timeout_watcher()
|
|
437
447
|
|
|
438
|
-
#
|
|
448
|
+
# bot just stopped speaking, starting a watcher for silence timeout
|
|
439
449
|
if was_bot_speaking_before and not is_bot_speaking_after:
|
|
440
450
|
logger.debug("voice_channel.bot_stopped_speaking")
|
|
441
451
|
self._cancel_silence_timeout_watcher()
|
|
@@ -333,6 +333,10 @@ def reset_scoped_slots(
|
|
|
333
333
|
flow_persistable_slots = current_flow.persisted_slots
|
|
334
334
|
|
|
335
335
|
for step in current_flow.steps_with_calls_resolved:
|
|
336
|
+
# take persisted slots from called flows into consideration
|
|
337
|
+
# before resetting slots
|
|
338
|
+
if isinstance(step, CallFlowStep) and step.called_flow_reference:
|
|
339
|
+
flow_persistable_slots.extend(step.called_flow_reference.persisted_slots)
|
|
336
340
|
if isinstance(step, CollectInformationFlowStep):
|
|
337
341
|
# reset all slots scoped to the flow
|
|
338
342
|
slot_name = step.collect
|
|
@@ -344,7 +348,22 @@ def reset_scoped_slots(
|
|
|
344
348
|
# slots set by the set slots step should be reset after the flow ends
|
|
345
349
|
# unless they are also used in a collect step where `reset_after_flow_ends`
|
|
346
350
|
# is set to `False` or set in the `persisted_slots` list.
|
|
347
|
-
resettable_set_slots =
|
|
351
|
+
resettable_set_slots = _get_resettable_set_slots(
|
|
352
|
+
current_flow, not_resettable_slot_names, flow_persistable_slots
|
|
353
|
+
)
|
|
354
|
+
for name in resettable_set_slots:
|
|
355
|
+
_reset_slot(name, tracker)
|
|
356
|
+
|
|
357
|
+
return events
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def _get_resettable_set_slots(
|
|
361
|
+
current_flow: Flow,
|
|
362
|
+
not_resettable_slot_names: set[Text],
|
|
363
|
+
flow_persistable_slots: List[Text],
|
|
364
|
+
) -> List[Text]:
|
|
365
|
+
"""Get list of slot names from SetSlotsFlowStep that should be reset."""
|
|
366
|
+
return [
|
|
348
367
|
slot["key"]
|
|
349
368
|
for step in current_flow.steps_with_calls_resolved
|
|
350
369
|
if isinstance(step, SetSlotsFlowStep)
|
|
@@ -353,11 +372,6 @@ def reset_scoped_slots(
|
|
|
353
372
|
and slot["key"] not in flow_persistable_slots
|
|
354
373
|
]
|
|
355
374
|
|
|
356
|
-
for name in resettable_set_slots:
|
|
357
|
-
_reset_slot(name, tracker)
|
|
358
|
-
|
|
359
|
-
return events
|
|
360
|
-
|
|
361
375
|
|
|
362
376
|
def advance_flows(
|
|
363
377
|
tracker: DialogueStateTracker, available_actions: List[str], flows: FlowsList
|
rasa/core/run.py
CHANGED
|
@@ -60,7 +60,9 @@ structlogger = structlog.get_logger()
|
|
|
60
60
|
class LLMBasedCommandGenerator(
|
|
61
61
|
LLMHealthCheckMixin, GraphComponent, CommandGenerator, ABC
|
|
62
62
|
):
|
|
63
|
-
"""
|
|
63
|
+
"""This class provides common functionality for all LLM-based command generators.
|
|
64
|
+
|
|
65
|
+
An abstract class defining interface and common functionality
|
|
64
66
|
of an LLM-based command generators.
|
|
65
67
|
"""
|
|
66
68
|
|
|
@@ -172,8 +174,9 @@ class LLMBasedCommandGenerator(
|
|
|
172
174
|
def train(
|
|
173
175
|
self, training_data: TrainingData, flows: FlowsList, domain: Domain
|
|
174
176
|
) -> Resource:
|
|
175
|
-
"""
|
|
176
|
-
|
|
177
|
+
"""Trains the LLM-based command generator and prepares flow retrieval data.
|
|
178
|
+
|
|
179
|
+
Stores all flows into a vector store.
|
|
177
180
|
"""
|
|
178
181
|
self.perform_llm_health_check(
|
|
179
182
|
self.config.get(LLM_CONFIG_KEY),
|
|
@@ -132,7 +132,20 @@ class CompactLLMCommandGenerator(SingleStepBasedLLMCommandGenerator):
|
|
|
132
132
|
if prompt_template is not None:
|
|
133
133
|
return prompt_template
|
|
134
134
|
|
|
135
|
-
#
|
|
135
|
+
# Try to load the template from the given path or fallback to the default for
|
|
136
|
+
# the component.
|
|
137
|
+
custom_prompt_template_path = config.get(PROMPT_TEMPLATE_CONFIG_KEY)
|
|
138
|
+
if custom_prompt_template_path is not None:
|
|
139
|
+
custom_prompt_template = get_prompt_template(
|
|
140
|
+
custom_prompt_template_path,
|
|
141
|
+
None, # Default will be based on the model
|
|
142
|
+
log_source_component=log_source_component,
|
|
143
|
+
log_source_method=log_context,
|
|
144
|
+
)
|
|
145
|
+
if custom_prompt_template is not None:
|
|
146
|
+
return custom_prompt_template
|
|
147
|
+
|
|
148
|
+
# Fallback to the default prompt template based on the model.
|
|
136
149
|
default_command_prompt_template = get_default_prompt_template_based_on_model(
|
|
137
150
|
llm_config=config.get(LLM_CONFIG_KEY, {}) or {},
|
|
138
151
|
model_prompt_mapping=cls.get_model_prompt_mapper(),
|
|
@@ -142,10 +155,4 @@ class CompactLLMCommandGenerator(SingleStepBasedLLMCommandGenerator):
|
|
|
142
155
|
log_source_method=log_context,
|
|
143
156
|
)
|
|
144
157
|
|
|
145
|
-
|
|
146
|
-
return get_prompt_template(
|
|
147
|
-
config.get(PROMPT_TEMPLATE_CONFIG_KEY),
|
|
148
|
-
default_command_prompt_template,
|
|
149
|
-
log_source_component=log_source_component,
|
|
150
|
-
log_source_method=log_context,
|
|
151
|
-
)
|
|
158
|
+
return default_command_prompt_template
|
|
@@ -128,7 +128,20 @@ class SearchReadyLLMCommandGenerator(SingleStepBasedLLMCommandGenerator):
|
|
|
128
128
|
if prompt_template is not None:
|
|
129
129
|
return prompt_template
|
|
130
130
|
|
|
131
|
-
#
|
|
131
|
+
# Try to load the template from the given path or fallback to the default for
|
|
132
|
+
# the component.
|
|
133
|
+
custom_prompt_template_path = config.get(PROMPT_TEMPLATE_CONFIG_KEY)
|
|
134
|
+
if custom_prompt_template_path is not None:
|
|
135
|
+
custom_prompt_template = get_prompt_template(
|
|
136
|
+
custom_prompt_template_path,
|
|
137
|
+
None, # Default will be based on the model
|
|
138
|
+
log_source_component=log_source_component,
|
|
139
|
+
log_source_method=log_context,
|
|
140
|
+
)
|
|
141
|
+
if custom_prompt_template is not None:
|
|
142
|
+
return custom_prompt_template
|
|
143
|
+
|
|
144
|
+
# Fallback to the default prompt template based on the model.
|
|
132
145
|
default_command_prompt_template = get_default_prompt_template_based_on_model(
|
|
133
146
|
llm_config=config.get(LLM_CONFIG_KEY, {}) or {},
|
|
134
147
|
model_prompt_mapping=cls.get_model_prompt_mapper(),
|
|
@@ -138,10 +151,4 @@ class SearchReadyLLMCommandGenerator(SingleStepBasedLLMCommandGenerator):
|
|
|
138
151
|
log_source_method=log_context,
|
|
139
152
|
)
|
|
140
153
|
|
|
141
|
-
|
|
142
|
-
return get_prompt_template(
|
|
143
|
-
config.get(PROMPT_TEMPLATE_CONFIG_KEY),
|
|
144
|
-
default_command_prompt_template,
|
|
145
|
-
log_source_component=log_source_component,
|
|
146
|
-
log_source_method=log_context,
|
|
147
|
-
)
|
|
154
|
+
return default_command_prompt_template
|
|
@@ -476,6 +476,18 @@ def clean_up_commands(
|
|
|
476
476
|
else:
|
|
477
477
|
clean_commands.append(command)
|
|
478
478
|
|
|
479
|
+
# ensure that there is only one command of a certain command type
|
|
480
|
+
clean_commands = ensure_max_number_of_command_type(
|
|
481
|
+
clean_commands, CannotHandleCommand, 1
|
|
482
|
+
)
|
|
483
|
+
clean_commands = ensure_max_number_of_command_type(
|
|
484
|
+
clean_commands, RepeatBotMessagesCommand, 1
|
|
485
|
+
)
|
|
486
|
+
clean_commands = ensure_max_number_of_command_type(
|
|
487
|
+
clean_commands, ChitChatAnswerCommand, 1
|
|
488
|
+
)
|
|
489
|
+
|
|
490
|
+
# filter out cannot handle commands if there are other commands present
|
|
479
491
|
# when coexistence is enabled, by default there will be a SetSlotCommand
|
|
480
492
|
# for the ROUTE_TO_CALM_SLOT slot.
|
|
481
493
|
if tracker.has_coexistence_routing_slot and len(clean_commands) > 2:
|
|
@@ -483,9 +495,6 @@ def clean_up_commands(
|
|
|
483
495
|
elif not tracker.has_coexistence_routing_slot and len(clean_commands) > 1:
|
|
484
496
|
clean_commands = filter_cannot_handle_command(clean_commands)
|
|
485
497
|
|
|
486
|
-
clean_commands = ensure_max_number_of_command_type(
|
|
487
|
-
clean_commands, RepeatBotMessagesCommand, 1
|
|
488
|
-
)
|
|
489
498
|
structlogger.debug(
|
|
490
499
|
"command_processor.clean_up_commands.final_commands",
|
|
491
500
|
command=clean_commands,
|
rasa/e2e_test/e2e_config.py
CHANGED
|
@@ -72,9 +72,10 @@ class LLMJudgeConfig(BaseModel):
|
|
|
72
72
|
|
|
73
73
|
llm_config = resolve_model_client_config(llm_config)
|
|
74
74
|
llm_config, llm_extra_parameters = cls.extract_attributes(llm_config)
|
|
75
|
-
|
|
76
|
-
llm_config
|
|
77
|
-
|
|
75
|
+
if not llm_config:
|
|
76
|
+
llm_config = combine_custom_and_default_config(
|
|
77
|
+
llm_config, cls.get_default_llm_config()
|
|
78
|
+
)
|
|
78
79
|
embeddings_config = resolve_model_client_config(embeddings)
|
|
79
80
|
embeddings_config, embeddings_extra_parameters = cls.extract_attributes(
|
|
80
81
|
embeddings_config
|
|
@@ -2,8 +2,7 @@ import json
|
|
|
2
2
|
from typing import Any, Dict, Optional
|
|
3
3
|
|
|
4
4
|
import structlog
|
|
5
|
-
from socketio import AsyncServer
|
|
6
|
-
from socketio.asyncio_client import AsyncClient
|
|
5
|
+
from socketio import AsyncClient, AsyncServer # type: ignore[attr-defined]
|
|
7
6
|
from socketio.exceptions import ConnectionRefusedError
|
|
8
7
|
|
|
9
8
|
from rasa.model_manager.runner_service import BotSession
|