rasa-pro 3.14.1__py3-none-any.whl → 3.14.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/constants.py +1 -0
- rasa/core/actions/action_clean_stack.py +32 -0
- rasa/core/actions/constants.py +4 -0
- rasa/core/actions/custom_action_executor.py +70 -12
- rasa/core/actions/grpc_custom_action_executor.py +41 -2
- rasa/core/actions/http_custom_action_executor.py +49 -25
- rasa/core/channels/voice_stream/browser_audio.py +3 -3
- rasa/core/channels/voice_stream/voice_channel.py +27 -17
- rasa/core/config/credentials.py +3 -3
- rasa/core/policies/flows/flow_executor.py +49 -29
- rasa/core/run.py +21 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +6 -3
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +15 -7
- rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +15 -8
- rasa/dialogue_understanding/processor/command_processor.py +13 -7
- rasa/e2e_test/e2e_config.py +4 -3
- rasa/engine/recipes/default_components.py +16 -6
- rasa/graph_components/validators/default_recipe_validator.py +10 -4
- rasa/nlu/classifiers/diet_classifier.py +2 -0
- rasa/shared/core/flows/flow.py +8 -2
- rasa/shared/core/slots.py +55 -24
- rasa/shared/providers/_configs/azure_openai_client_config.py +4 -5
- rasa/shared/providers/_configs/default_litellm_client_config.py +4 -4
- rasa/shared/providers/_configs/litellm_router_client_config.py +3 -2
- rasa/shared/providers/_configs/openai_client_config.py +5 -7
- rasa/shared/providers/_configs/rasa_llm_client_config.py +4 -4
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +4 -4
- rasa/shared/providers/llm/_base_litellm_client.py +42 -14
- rasa/shared/providers/llm/litellm_router_llm_client.py +38 -15
- rasa/shared/providers/llm/self_hosted_llm_client.py +34 -32
- rasa/shared/utils/common.py +9 -1
- rasa/shared/utils/configs.py +5 -8
- rasa/utils/common.py +9 -0
- rasa/utils/endpoints.py +6 -0
- rasa/utils/installation_utils.py +111 -0
- rasa/utils/tensorflow/callback.py +2 -0
- rasa/utils/tensorflow/models.py +3 -0
- rasa/utils/train_utils.py +2 -0
- rasa/version.py +1 -1
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/METADATA +2 -2
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/RECORD +44 -43
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/entry_points.txt +0 -0
rasa/constants.py
CHANGED
|
@@ -33,6 +33,7 @@ ENV_MCP_LOGGING_ENABLED = "MCP_LOGGING_ENABLED"
|
|
|
33
33
|
ENV_LOG_LEVEL_MATPLOTLIB = "LOG_LEVEL_MATPLOTLIB"
|
|
34
34
|
ENV_LOG_LEVEL_RABBITMQ = "LOG_LEVEL_RABBITMQ"
|
|
35
35
|
ENV_LOG_LEVEL_KAFKA = "LOG_LEVEL_KAFKA"
|
|
36
|
+
ENV_LOG_LEVEL_PYMONGO = "LOG_LEVEL_PYMONGO"
|
|
36
37
|
|
|
37
38
|
DEFAULT_SANIC_WORKERS = 1
|
|
38
39
|
ENV_SANIC_WORKERS = "SANIC_WORKERS"
|
|
@@ -4,9 +4,11 @@ from typing import Any, Dict, List, Optional
|
|
|
4
4
|
|
|
5
5
|
import structlog
|
|
6
6
|
|
|
7
|
+
import rasa.dialogue_understanding.stack.utils
|
|
7
8
|
from rasa.core.actions.action import Action
|
|
8
9
|
from rasa.core.channels import OutputChannel
|
|
9
10
|
from rasa.core.nlg import NaturalLanguageGenerator
|
|
11
|
+
from rasa.dialogue_understanding.patterns.code_change import FLOW_PATTERN_CODE_CHANGE_ID
|
|
10
12
|
from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack
|
|
11
13
|
from rasa.dialogue_understanding.stack.frames import (
|
|
12
14
|
BaseFlowStackFrame,
|
|
@@ -41,6 +43,15 @@ class ActionCleanStack(Action):
|
|
|
41
43
|
"""Clean the stack."""
|
|
42
44
|
structlogger.debug("action_clean_stack.run")
|
|
43
45
|
new_frames = []
|
|
46
|
+
top_flow_frame = rasa.dialogue_understanding.stack.utils.top_flow_frame(
|
|
47
|
+
tracker.stack, ignore_call_frames=False
|
|
48
|
+
)
|
|
49
|
+
top_user_flow_frame = (
|
|
50
|
+
rasa.dialogue_understanding.stack.utils.top_user_flow_frame(
|
|
51
|
+
tracker.stack, ignore_call_and_link_frames=False
|
|
52
|
+
)
|
|
53
|
+
)
|
|
54
|
+
|
|
44
55
|
# Set all frames to their end step, filter out any non-BaseFlowStackFrames
|
|
45
56
|
for frame in tracker.stack.frames:
|
|
46
57
|
if isinstance(frame, BaseFlowStackFrame):
|
|
@@ -56,4 +67,25 @@ class ActionCleanStack(Action):
|
|
|
56
67
|
new_frames.append(frame)
|
|
57
68
|
new_stack = DialogueStack.from_dict([frame.as_dict() for frame in new_frames])
|
|
58
69
|
|
|
70
|
+
# Check if the action is being called from within a user flow
|
|
71
|
+
if (
|
|
72
|
+
top_flow_frame
|
|
73
|
+
and top_flow_frame.flow_id != FLOW_PATTERN_CODE_CHANGE_ID
|
|
74
|
+
and top_user_flow_frame
|
|
75
|
+
and top_user_flow_frame.flow_id == top_flow_frame.flow_id
|
|
76
|
+
):
|
|
77
|
+
# The action is being called from within a user flow on the stack.
|
|
78
|
+
# If there are other frames on the stack, we need to make sure
|
|
79
|
+
# the last executed frame is the end step of the current user flow so
|
|
80
|
+
# that we can trigger pattern_completed for this user flow.
|
|
81
|
+
new_stack.pop()
|
|
82
|
+
structlogger.debug(
|
|
83
|
+
"action_clean_stack.pushing_user_frame_at_the_bottom_of_stack",
|
|
84
|
+
flow_id=top_user_flow_frame.flow_id,
|
|
85
|
+
)
|
|
86
|
+
new_stack.push(
|
|
87
|
+
top_user_flow_frame,
|
|
88
|
+
index=0,
|
|
89
|
+
)
|
|
90
|
+
|
|
59
91
|
return tracker.create_stack_updated_events(new_stack)
|
rasa/core/actions/constants.py
CHANGED
|
@@ -3,3 +3,7 @@ SELECTIVE_DOMAIN = "enable_selective_domain"
|
|
|
3
3
|
|
|
4
4
|
SSL_CLIENT_CERT_FIELD = "ssl_client_cert"
|
|
5
5
|
SSL_CLIENT_KEY_FIELD = "ssl_client_key"
|
|
6
|
+
|
|
7
|
+
# Special marker key used by EndpointConfig to indicate 449 status
|
|
8
|
+
# without raising an exception
|
|
9
|
+
MISSING_DOMAIN_MARKER = "missing_domain"
|
|
@@ -2,7 +2,10 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import abc
|
|
4
4
|
import logging
|
|
5
|
-
from
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional, Text
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel
|
|
6
9
|
|
|
7
10
|
import rasa
|
|
8
11
|
from rasa.core.actions.action_exceptions import DomainNotFound
|
|
@@ -19,6 +22,23 @@ if TYPE_CHECKING:
|
|
|
19
22
|
logger = logging.getLogger(__name__)
|
|
20
23
|
|
|
21
24
|
|
|
25
|
+
class ActionResultType(Enum):
|
|
26
|
+
SUCCESS = "success"
|
|
27
|
+
RETRY_WITH_DOMAIN = "retry_with_domain"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ActionResult(BaseModel):
|
|
31
|
+
"""Result of custom action execution.
|
|
32
|
+
|
|
33
|
+
This is used to avoid raising exceptions for expected conditions
|
|
34
|
+
like missing domain (449 status code), which would otherwise be
|
|
35
|
+
captured by tracing as errors.
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
result_type: ActionResultType
|
|
39
|
+
response: Optional[Dict[Text, Any]] = None
|
|
40
|
+
|
|
41
|
+
|
|
22
42
|
class CustomActionExecutor(abc.ABC):
|
|
23
43
|
"""Interface for custom action executors.
|
|
24
44
|
|
|
@@ -45,6 +65,34 @@ class CustomActionExecutor(abc.ABC):
|
|
|
45
65
|
"""
|
|
46
66
|
pass
|
|
47
67
|
|
|
68
|
+
async def run_with_result(
|
|
69
|
+
self,
|
|
70
|
+
tracker: "DialogueStateTracker",
|
|
71
|
+
domain: "Domain",
|
|
72
|
+
include_domain: bool = False,
|
|
73
|
+
) -> ActionResult:
|
|
74
|
+
"""Executes the custom action and returns a result.
|
|
75
|
+
|
|
76
|
+
This method is used to avoid raising exceptions for expected conditions
|
|
77
|
+
like missing domain, which would otherwise be captured by tracing as errors.
|
|
78
|
+
|
|
79
|
+
By default, this method calls the run method and wraps the response
|
|
80
|
+
for backward compatibility.
|
|
81
|
+
|
|
82
|
+
Args:
|
|
83
|
+
tracker: The current state of the dialogue.
|
|
84
|
+
domain: The domain object containing domain-specific information.
|
|
85
|
+
include_domain: If True, the domain is included in the request.
|
|
86
|
+
|
|
87
|
+
Returns:
|
|
88
|
+
ActionResult containing the response and result type.
|
|
89
|
+
"""
|
|
90
|
+
try:
|
|
91
|
+
response = await self.run(tracker, domain, include_domain)
|
|
92
|
+
return ActionResult(result_type=ActionResultType.SUCCESS, response=response)
|
|
93
|
+
except DomainNotFound:
|
|
94
|
+
return ActionResult(result_type=ActionResultType.RETRY_WITH_DOMAIN)
|
|
95
|
+
|
|
48
96
|
|
|
49
97
|
class NoEndpointCustomActionExecutor(CustomActionExecutor):
|
|
50
98
|
"""Implementation of a custom action executor when endpoint is not set.
|
|
@@ -163,13 +211,13 @@ class RetryCustomActionExecutor(CustomActionExecutor):
|
|
|
163
211
|
domain: "Domain",
|
|
164
212
|
include_domain: bool = False,
|
|
165
213
|
) -> Dict[Text, Any]:
|
|
166
|
-
"""Runs the wrapped custom action executor.
|
|
214
|
+
"""Runs the wrapped custom action executor with retry logic.
|
|
167
215
|
|
|
168
216
|
First request to the action server is made with/without the domain
|
|
169
217
|
as specified by the `include_domain` parameter.
|
|
170
218
|
|
|
171
|
-
If the action server responds with a
|
|
172
|
-
|
|
219
|
+
If the action server responds with a missing domain indication,
|
|
220
|
+
retries the request with the domain included.
|
|
173
221
|
|
|
174
222
|
Args:
|
|
175
223
|
tracker: The current state of the dialogue.
|
|
@@ -178,14 +226,24 @@ class RetryCustomActionExecutor(CustomActionExecutor):
|
|
|
178
226
|
|
|
179
227
|
Returns:
|
|
180
228
|
The response from the execution of the custom action.
|
|
229
|
+
|
|
230
|
+
Raises:
|
|
231
|
+
DomainNotFound: If the action server still requires domain after retry.
|
|
181
232
|
"""
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
233
|
+
result = await self._custom_action_executor.run_with_result(
|
|
234
|
+
tracker,
|
|
235
|
+
domain,
|
|
236
|
+
include_domain=include_domain,
|
|
237
|
+
)
|
|
238
|
+
|
|
239
|
+
if result.result_type == ActionResultType.RETRY_WITH_DOMAIN:
|
|
240
|
+
# Retry with domain included
|
|
241
|
+
result = await self._custom_action_executor.run_with_result(
|
|
190
242
|
tracker, domain, include_domain=True
|
|
191
243
|
)
|
|
244
|
+
|
|
245
|
+
# If still missing domain after retry, raise error
|
|
246
|
+
if result.result_type == ActionResultType.RETRY_WITH_DOMAIN:
|
|
247
|
+
raise DomainNotFound()
|
|
248
|
+
|
|
249
|
+
return result.response if result.response is not None else {}
|
|
@@ -11,6 +11,8 @@ from rasa_sdk.grpc_py import action_webhook_pb2, action_webhook_pb2_grpc
|
|
|
11
11
|
from rasa.core.actions.action_exceptions import DomainNotFound
|
|
12
12
|
from rasa.core.actions.constants import SSL_CLIENT_CERT_FIELD, SSL_CLIENT_KEY_FIELD
|
|
13
13
|
from rasa.core.actions.custom_action_executor import (
|
|
14
|
+
ActionResult,
|
|
15
|
+
ActionResultType,
|
|
14
16
|
CustomActionExecutor,
|
|
15
17
|
CustomActionRequestWriter,
|
|
16
18
|
)
|
|
@@ -101,13 +103,51 @@ class GRPCCustomActionExecutor(CustomActionExecutor):
|
|
|
101
103
|
|
|
102
104
|
Returns:
|
|
103
105
|
Response from the action server.
|
|
106
|
+
Returns empty dict if domain is missing.
|
|
107
|
+
|
|
108
|
+
Raises:
|
|
109
|
+
RasaException: If an error occurs while making the gRPC request
|
|
110
|
+
(other than missing domain).
|
|
104
111
|
"""
|
|
112
|
+
result = await self.run_with_result(tracker, domain, include_domain)
|
|
113
|
+
|
|
114
|
+
# Return empty dict for retry cases to avoid raising exceptions
|
|
115
|
+
# RetryCustomActionExecutor will handle the retry logic
|
|
116
|
+
if result.result_type == ActionResultType.RETRY_WITH_DOMAIN:
|
|
117
|
+
return {}
|
|
118
|
+
|
|
119
|
+
return result.response if result.response is not None else {}
|
|
120
|
+
|
|
121
|
+
async def run_with_result(
|
|
122
|
+
self,
|
|
123
|
+
tracker: "DialogueStateTracker",
|
|
124
|
+
domain: "Domain",
|
|
125
|
+
include_domain: bool = False,
|
|
126
|
+
) -> ActionResult:
|
|
127
|
+
"""Execute the custom action and return an ActionResult.
|
|
128
|
+
|
|
129
|
+
This method avoids raising DomainNotFound exception for missing domain,
|
|
130
|
+
instead returning an ActionResult with RETRY_WITH_DOMAIN type.
|
|
131
|
+
This prevents tracing from capturing this expected condition as an error.
|
|
105
132
|
|
|
133
|
+
Args:
|
|
134
|
+
tracker: Tracker for the current conversation.
|
|
135
|
+
domain: Domain of the assistant.
|
|
136
|
+
include_domain: If True, the domain is included in the request.
|
|
137
|
+
|
|
138
|
+
Returns:
|
|
139
|
+
ActionResult containing the response and result type.
|
|
140
|
+
"""
|
|
106
141
|
request = self._create_payload(
|
|
107
142
|
tracker=tracker, domain=domain, include_domain=include_domain
|
|
108
143
|
)
|
|
109
144
|
|
|
110
|
-
|
|
145
|
+
try:
|
|
146
|
+
response = self._request(request)
|
|
147
|
+
return ActionResult(result_type=ActionResultType.SUCCESS, response=response)
|
|
148
|
+
except DomainNotFound:
|
|
149
|
+
# Return retry result instead of raising DomainNotFound
|
|
150
|
+
return ActionResult(result_type=ActionResultType.RETRY_WITH_DOMAIN)
|
|
111
151
|
|
|
112
152
|
def _request(
|
|
113
153
|
self,
|
|
@@ -121,7 +161,6 @@ class GRPCCustomActionExecutor(CustomActionExecutor):
|
|
|
121
161
|
Returns:
|
|
122
162
|
Response from the action server.
|
|
123
163
|
"""
|
|
124
|
-
|
|
125
164
|
client = self._create_grpc_client()
|
|
126
165
|
metadata = self._build_metadata()
|
|
127
166
|
try:
|
|
@@ -4,8 +4,11 @@ from typing import TYPE_CHECKING, Any, Dict, Optional
|
|
|
4
4
|
|
|
5
5
|
import aiohttp
|
|
6
6
|
|
|
7
|
-
from rasa.core.actions.action_exceptions import ActionExecutionRejection
|
|
7
|
+
from rasa.core.actions.action_exceptions import ActionExecutionRejection
|
|
8
|
+
from rasa.core.actions.constants import MISSING_DOMAIN_MARKER
|
|
8
9
|
from rasa.core.actions.custom_action_executor import (
|
|
10
|
+
ActionResult,
|
|
11
|
+
ActionResultType,
|
|
9
12
|
CustomActionExecutor,
|
|
10
13
|
CustomActionRequestWriter,
|
|
11
14
|
)
|
|
@@ -18,12 +21,12 @@ from rasa.shared.core.domain import Domain
|
|
|
18
21
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
19
22
|
from rasa.shared.exceptions import RasaException
|
|
20
23
|
from rasa.utils.common import get_bool_env_variable
|
|
24
|
+
from rasa.utils.endpoints import ClientResponseError, EndpointConfig
|
|
21
25
|
|
|
22
26
|
if TYPE_CHECKING:
|
|
23
27
|
from rasa.shared.core.domain import Domain
|
|
24
28
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
25
29
|
|
|
26
|
-
from rasa.utils.endpoints import ClientResponseError, EndpointConfig
|
|
27
30
|
|
|
28
31
|
logger = logging.getLogger(__name__)
|
|
29
32
|
|
|
@@ -62,9 +65,40 @@ class HTTPCustomActionExecutor(CustomActionExecutor):
|
|
|
62
65
|
|
|
63
66
|
Returns:
|
|
64
67
|
A dictionary containing the response from the custom action endpoint.
|
|
68
|
+
Returns empty dict if domain is missing (449 status).
|
|
65
69
|
|
|
66
70
|
Raises:
|
|
67
|
-
RasaException: If an error occurs while making the HTTP request
|
|
71
|
+
RasaException: If an error occurs while making the HTTP request
|
|
72
|
+
(other than missing domain).
|
|
73
|
+
"""
|
|
74
|
+
result = await self.run_with_result(tracker, domain, include_domain)
|
|
75
|
+
|
|
76
|
+
# Return empty dict for retry cases to avoid raising exceptions
|
|
77
|
+
# RetryCustomActionExecutor will handle the retry logic
|
|
78
|
+
if result.result_type == ActionResultType.RETRY_WITH_DOMAIN:
|
|
79
|
+
return {}
|
|
80
|
+
|
|
81
|
+
return result.response if result.response is not None else {}
|
|
82
|
+
|
|
83
|
+
async def run_with_result(
|
|
84
|
+
self,
|
|
85
|
+
tracker: "DialogueStateTracker",
|
|
86
|
+
domain: Optional["Domain"] = None,
|
|
87
|
+
include_domain: bool = False,
|
|
88
|
+
) -> ActionResult:
|
|
89
|
+
"""Execute the custom action and return an ActionResult.
|
|
90
|
+
|
|
91
|
+
This method avoids raising DomainNotFound exception for 449 status code,
|
|
92
|
+
instead returning an ActionResult with RETRY_WITH_DOMAIN type.
|
|
93
|
+
This prevents tracing from capturing this expected condition as an error.
|
|
94
|
+
|
|
95
|
+
Args:
|
|
96
|
+
tracker: The current state of the dialogue.
|
|
97
|
+
domain: The domain object containing domain-specific information.
|
|
98
|
+
include_domain: If True, the domain is included in the request.
|
|
99
|
+
|
|
100
|
+
Returns:
|
|
101
|
+
ActionResult containing the response and result type.
|
|
68
102
|
"""
|
|
69
103
|
from rasa.core.actions.action import RemoteActionJSONValidator
|
|
70
104
|
|
|
@@ -77,14 +111,23 @@ class HTTPCustomActionExecutor(CustomActionExecutor):
|
|
|
77
111
|
tracker=tracker, domain=domain, include_domain=include_domain
|
|
78
112
|
)
|
|
79
113
|
|
|
80
|
-
|
|
114
|
+
assert self.action_endpoint is not None
|
|
115
|
+
response = await self.action_endpoint.request(
|
|
116
|
+
json=json_body,
|
|
117
|
+
method="post",
|
|
118
|
+
timeout=DEFAULT_REQUEST_TIMEOUT,
|
|
119
|
+
compress=self.should_compress,
|
|
120
|
+
)
|
|
121
|
+
|
|
122
|
+
# Check if we got the special marker for 449 status (missing domain)
|
|
123
|
+
if isinstance(response, dict) and response.get(MISSING_DOMAIN_MARKER):
|
|
124
|
+
return ActionResult(result_type=ActionResultType.RETRY_WITH_DOMAIN)
|
|
81
125
|
|
|
82
126
|
if response is None:
|
|
83
127
|
response = {}
|
|
84
128
|
|
|
85
129
|
RemoteActionJSONValidator.validate(response)
|
|
86
|
-
|
|
87
|
-
return response
|
|
130
|
+
return ActionResult(result_type=ActionResultType.SUCCESS, response=response)
|
|
88
131
|
|
|
89
132
|
except ClientResponseError as e:
|
|
90
133
|
if e.status == 400:
|
|
@@ -131,22 +174,3 @@ class HTTPCustomActionExecutor(CustomActionExecutor):
|
|
|
131
174
|
"and returns a 200 once the action is executed. "
|
|
132
175
|
"Error: {}".format(self.action_name, status, e)
|
|
133
176
|
)
|
|
134
|
-
|
|
135
|
-
async def _perform_request_with_retries(
|
|
136
|
-
self,
|
|
137
|
-
json_body: Dict[str, Any],
|
|
138
|
-
) -> Any:
|
|
139
|
-
"""Attempts to perform the request with retries if necessary."""
|
|
140
|
-
assert self.action_endpoint is not None
|
|
141
|
-
try:
|
|
142
|
-
return await self.action_endpoint.request(
|
|
143
|
-
json=json_body,
|
|
144
|
-
method="post",
|
|
145
|
-
timeout=DEFAULT_REQUEST_TIMEOUT,
|
|
146
|
-
compress=self.should_compress,
|
|
147
|
-
)
|
|
148
|
-
except ClientResponseError as e:
|
|
149
|
-
# Repeat the request because Domain was not in the payload
|
|
150
|
-
if e.status == 449:
|
|
151
|
-
raise DomainNotFound()
|
|
152
|
-
raise e
|
|
@@ -90,13 +90,13 @@ class BrowserAudioInputChannel(VoiceInputChannel):
|
|
|
90
90
|
self._wav_file: Optional[wave.Wave_write] = None
|
|
91
91
|
|
|
92
92
|
def _start_recording(self, call_id: str, user_id: str) -> None:
|
|
93
|
+
if not self._recording_enabled:
|
|
94
|
+
return
|
|
95
|
+
|
|
93
96
|
os.makedirs("recordings", exist_ok=True)
|
|
94
97
|
filename = f"{user_id}_{call_id}.wav"
|
|
95
98
|
file_path = os.path.join("recordings", filename)
|
|
96
99
|
|
|
97
|
-
if not self._recording_enabled:
|
|
98
|
-
return
|
|
99
|
-
|
|
100
100
|
self._wav_file = wave.open(file_path, "wb")
|
|
101
101
|
self._wav_file.setnchannels(1) # Mono audio
|
|
102
102
|
self._wav_file.setsampwidth(4) # 32-bit audio (4 bytes)
|
|
@@ -192,6 +192,8 @@ class VoiceOutputChannel(OutputChannel):
|
|
|
192
192
|
async def send_start_marker(self, recipient_id: str) -> None:
|
|
193
193
|
"""Send a marker message before the first audio chunk."""
|
|
194
194
|
# Default implementation uses the generic marker message
|
|
195
|
+
call_state.is_bot_speaking = True
|
|
196
|
+
VoiceInputChannel._cancel_silence_timeout_watcher()
|
|
195
197
|
await self.send_marker_message(recipient_id)
|
|
196
198
|
|
|
197
199
|
async def send_intermediate_marker(self, recipient_id: str) -> None:
|
|
@@ -268,11 +270,6 @@ class VoiceOutputChannel(OutputChannel):
|
|
|
268
270
|
# Track TTS start time
|
|
269
271
|
call_state.tts_start_time = time.time()
|
|
270
272
|
|
|
271
|
-
cached_audio_bytes = self.tts_cache.get(text)
|
|
272
|
-
collected_audio_bytes = RasaAudioBytes(b"")
|
|
273
|
-
seconds_marker = -1
|
|
274
|
-
last_sent_offset = 0
|
|
275
|
-
first_audio_sent = False
|
|
276
273
|
logger.debug("voice_channel.sending_audio", text=text)
|
|
277
274
|
|
|
278
275
|
# Send start marker before first chunk
|
|
@@ -285,17 +282,12 @@ class VoiceOutputChannel(OutputChannel):
|
|
|
285
282
|
allow_interruptions = kwargs.get("allow_interruptions", True)
|
|
286
283
|
call_state.channel_data["allow_interruptions"] = allow_interruptions
|
|
287
284
|
|
|
288
|
-
|
|
289
|
-
audio_stream = self.chunk_audio(cached_audio_bytes)
|
|
290
|
-
else:
|
|
291
|
-
# Todo: make kwargs compatible with engine config
|
|
292
|
-
synth_config = self.tts_engine.config.__class__.from_dict({})
|
|
293
|
-
try:
|
|
294
|
-
audio_stream = self.tts_engine.synthesize(text, synth_config)
|
|
295
|
-
except TTSError:
|
|
296
|
-
# TODO: add message that works without tts, e.g. loading from disc
|
|
297
|
-
audio_stream = self.chunk_audio(generate_silence())
|
|
285
|
+
audio_stream = await self._create_audio_stream(text)
|
|
298
286
|
|
|
287
|
+
collected_audio_bytes = RasaAudioBytes(b"")
|
|
288
|
+
last_sent_offset = 0
|
|
289
|
+
first_audio_sent = False
|
|
290
|
+
seconds_marker = -1
|
|
299
291
|
async for audio_bytes in audio_stream:
|
|
300
292
|
collected_audio_bytes = RasaAudioBytes(collected_audio_bytes + audio_bytes)
|
|
301
293
|
|
|
@@ -315,6 +307,8 @@ class VoiceOutputChannel(OutputChannel):
|
|
|
315
307
|
await self.send_audio_bytes(recipient_id, new_bytes)
|
|
316
308
|
last_sent_offset = len(collected_audio_bytes)
|
|
317
309
|
|
|
310
|
+
# seconds of audio rounded down to floor number
|
|
311
|
+
# e.g 7 // 2 = 3
|
|
318
312
|
full_seconds_of_audio = len(collected_audio_bytes) // HERTZ
|
|
319
313
|
if full_seconds_of_audio > seconds_marker:
|
|
320
314
|
await self.send_intermediate_marker(recipient_id)
|
|
@@ -348,7 +342,7 @@ class VoiceOutputChannel(OutputChannel):
|
|
|
348
342
|
pass
|
|
349
343
|
call_state.latest_bot_audio_id = self.latest_message_id
|
|
350
344
|
|
|
351
|
-
if not
|
|
345
|
+
if not self.tts_cache.get(text):
|
|
352
346
|
self.tts_cache.put(text, collected_audio_bytes)
|
|
353
347
|
|
|
354
348
|
async def send_audio_bytes(
|
|
@@ -373,6 +367,22 @@ class VoiceOutputChannel(OutputChannel):
|
|
|
373
367
|
async def hangup(self, recipient_id: str, **kwargs: Any) -> None:
|
|
374
368
|
call_state.should_hangup = True
|
|
375
369
|
|
|
370
|
+
async def _create_audio_stream(self, text: str) -> AsyncIterator[RasaAudioBytes]:
|
|
371
|
+
cached_audio_bytes = self.tts_cache.get(text)
|
|
372
|
+
|
|
373
|
+
if cached_audio_bytes:
|
|
374
|
+
audio_stream = self.chunk_audio(cached_audio_bytes)
|
|
375
|
+
else:
|
|
376
|
+
# Todo: make kwargs compatible with engine config
|
|
377
|
+
synth_config = self.tts_engine.config.__class__.from_dict({})
|
|
378
|
+
try:
|
|
379
|
+
audio_stream = self.tts_engine.synthesize(text, synth_config)
|
|
380
|
+
except TTSError:
|
|
381
|
+
# TODO: add message that works without tts, e.g. loading from disc
|
|
382
|
+
audio_stream = self.chunk_audio(generate_silence())
|
|
383
|
+
|
|
384
|
+
return audio_stream
|
|
385
|
+
|
|
376
386
|
|
|
377
387
|
class VoiceInputChannel(InputChannel):
|
|
378
388
|
# All children of this class require a voice license to be used.
|
|
@@ -555,7 +565,7 @@ class VoiceInputChannel(InputChannel):
|
|
|
555
565
|
# relevant when the bot speaks multiple messages in one turn
|
|
556
566
|
self._cancel_silence_timeout_watcher()
|
|
557
567
|
|
|
558
|
-
#
|
|
568
|
+
# bot just stopped speaking, starting a watcher for silence timeout
|
|
559
569
|
if was_bot_speaking_before and not is_bot_speaking_after:
|
|
560
570
|
logger.debug("voice_channel.bot_stopped_speaking")
|
|
561
571
|
self._cancel_silence_timeout_watcher()
|
rasa/core/config/credentials.py
CHANGED
|
@@ -5,11 +5,11 @@ from typing import Any, Dict
|
|
|
5
5
|
|
|
6
6
|
from rasa.shared.utils.yaml import read_config_file
|
|
7
7
|
|
|
8
|
+
ChannelsType = Dict[str, Dict[str, Any]]
|
|
9
|
+
|
|
8
10
|
|
|
9
11
|
class CredentialsConfig:
|
|
10
|
-
def __init__(
|
|
11
|
-
self, channels: Dict[str, Dict[str, Any]], config_file_path: Path
|
|
12
|
-
) -> None:
|
|
12
|
+
def __init__(self, channels: ChannelsType, config_file_path: Path) -> None:
|
|
13
13
|
self.channels = channels
|
|
14
14
|
self.config_file_path = config_file_path
|
|
15
15
|
|
|
@@ -357,6 +357,10 @@ def reset_scoped_slots(
|
|
|
357
357
|
flow_persistable_slots = current_flow.persisted_slots
|
|
358
358
|
|
|
359
359
|
for step in current_flow.steps_with_calls_resolved:
|
|
360
|
+
# take persisted slots from called flows into consideration
|
|
361
|
+
# before resetting slots
|
|
362
|
+
if isinstance(step, CallFlowStep) and step.called_flow_reference:
|
|
363
|
+
flow_persistable_slots.extend(step.called_flow_reference.persisted_slots)
|
|
360
364
|
if isinstance(step, CollectInformationFlowStep):
|
|
361
365
|
# reset all slots scoped to the flow
|
|
362
366
|
slot_name = step.collect
|
|
@@ -368,7 +372,22 @@ def reset_scoped_slots(
|
|
|
368
372
|
# slots set by the set slots step should be reset after the flow ends
|
|
369
373
|
# unless they are also used in a collect step where `reset_after_flow_ends`
|
|
370
374
|
# is set to `False` or set in the `persisted_slots` list.
|
|
371
|
-
resettable_set_slots =
|
|
375
|
+
resettable_set_slots = _get_resettable_set_slots(
|
|
376
|
+
current_flow, not_resettable_slot_names, flow_persistable_slots
|
|
377
|
+
)
|
|
378
|
+
for name in resettable_set_slots:
|
|
379
|
+
_reset_slot(name, tracker)
|
|
380
|
+
|
|
381
|
+
return events
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def _get_resettable_set_slots(
|
|
385
|
+
current_flow: Flow,
|
|
386
|
+
not_resettable_slot_names: set[Text],
|
|
387
|
+
flow_persistable_slots: List[Text],
|
|
388
|
+
) -> List[Text]:
|
|
389
|
+
"""Get list of slot names from SetSlotsFlowStep that should be reset."""
|
|
390
|
+
return [
|
|
372
391
|
slot["key"]
|
|
373
392
|
for step in current_flow.steps_with_calls_resolved
|
|
374
393
|
if isinstance(step, SetSlotsFlowStep)
|
|
@@ -377,11 +396,6 @@ def reset_scoped_slots(
|
|
|
377
396
|
and slot["key"] not in flow_persistable_slots
|
|
378
397
|
]
|
|
379
398
|
|
|
380
|
-
for name in resettable_set_slots:
|
|
381
|
-
_reset_slot(name, tracker)
|
|
382
|
-
|
|
383
|
-
return events
|
|
384
|
-
|
|
385
399
|
|
|
386
400
|
async def advance_flows(
|
|
387
401
|
tracker: DialogueStateTracker,
|
|
@@ -853,25 +867,7 @@ def _silence_timeout_events_for_collect_step(
|
|
|
853
867
|
input_channel_name
|
|
854
868
|
)
|
|
855
869
|
else:
|
|
856
|
-
|
|
857
|
-
credentials_config = Configuration.get_instance().credentials
|
|
858
|
-
|
|
859
|
-
if credentials_config:
|
|
860
|
-
channel_config = (
|
|
861
|
-
credentials_config.channels.get(input_channel_name)
|
|
862
|
-
if input_channel_name
|
|
863
|
-
else None
|
|
864
|
-
)
|
|
865
|
-
|
|
866
|
-
silence_timeout = (
|
|
867
|
-
channel_config.get(
|
|
868
|
-
SILENCE_TIMEOUT_CHANNEL_KEY, GLOBAL_SILENCE_TIMEOUT_DEFAULT_VALUE
|
|
869
|
-
)
|
|
870
|
-
if channel_config
|
|
871
|
-
else GLOBAL_SILENCE_TIMEOUT_DEFAULT_VALUE
|
|
872
|
-
)
|
|
873
|
-
else:
|
|
874
|
-
silence_timeout = GLOBAL_SILENCE_TIMEOUT_DEFAULT_VALUE
|
|
870
|
+
silence_timeout = _get_default_silence_timeout(tracker)
|
|
875
871
|
|
|
876
872
|
structlogger.debug(
|
|
877
873
|
"flow.step.run.use_channel_silence_timeout",
|
|
@@ -891,13 +887,37 @@ def _append_global_silence_timeout_event(
|
|
|
891
887
|
events: List[Event], tracker: DialogueStateTracker
|
|
892
888
|
) -> None:
|
|
893
889
|
current_silence_timeout = tracker.get_slot(SILENCE_TIMEOUT_SLOT)
|
|
894
|
-
|
|
895
|
-
global_silence_timeout = endpoints.interaction_handling.global_silence_timeout
|
|
890
|
+
default_silence_timeout = _get_default_silence_timeout(tracker)
|
|
896
891
|
|
|
897
|
-
if current_silence_timeout !=
|
|
892
|
+
if current_silence_timeout != default_silence_timeout:
|
|
898
893
|
events.append(
|
|
899
894
|
SlotSet(
|
|
900
895
|
SILENCE_TIMEOUT_SLOT,
|
|
901
|
-
|
|
896
|
+
default_silence_timeout,
|
|
897
|
+
)
|
|
898
|
+
)
|
|
899
|
+
|
|
900
|
+
|
|
901
|
+
def _get_default_silence_timeout(tracker: DialogueStateTracker) -> float:
|
|
902
|
+
"""Get the default silence timeout for the tracker."""
|
|
903
|
+
input_channel_name = tracker.get_latest_input_channel()
|
|
904
|
+
credentials_config = Configuration.get_instance().credentials
|
|
905
|
+
|
|
906
|
+
if credentials_config:
|
|
907
|
+
channel_config = (
|
|
908
|
+
credentials_config.channels.get(input_channel_name)
|
|
909
|
+
if input_channel_name
|
|
910
|
+
else None
|
|
911
|
+
)
|
|
912
|
+
|
|
913
|
+
silence_timeout = (
|
|
914
|
+
channel_config.get(
|
|
915
|
+
SILENCE_TIMEOUT_CHANNEL_KEY, GLOBAL_SILENCE_TIMEOUT_DEFAULT_VALUE
|
|
902
916
|
)
|
|
917
|
+
if channel_config
|
|
918
|
+
else GLOBAL_SILENCE_TIMEOUT_DEFAULT_VALUE
|
|
903
919
|
)
|
|
920
|
+
else:
|
|
921
|
+
silence_timeout = GLOBAL_SILENCE_TIMEOUT_DEFAULT_VALUE
|
|
922
|
+
|
|
923
|
+
return silence_timeout
|
rasa/core/run.py
CHANGED
|
@@ -5,6 +5,7 @@ import platform
|
|
|
5
5
|
import uuid
|
|
6
6
|
import warnings
|
|
7
7
|
from asyncio import AbstractEventLoop
|
|
8
|
+
from copy import deepcopy
|
|
8
9
|
from functools import partial
|
|
9
10
|
from typing import (
|
|
10
11
|
Any,
|
|
@@ -112,7 +113,11 @@ def _create_single_channel(
|
|
|
112
113
|
if channel in BUILTIN_CHANNELS:
|
|
113
114
|
channel_class = BUILTIN_CHANNELS[channel]
|
|
114
115
|
|
|
115
|
-
|
|
116
|
+
channel_credentials = deepcopy(credentials)
|
|
117
|
+
channel_credentials.pop(
|
|
118
|
+
"silence_timeout", None
|
|
119
|
+
) if channel_credentials else None
|
|
120
|
+
return channel_class.from_credentials(channel_credentials)
|
|
116
121
|
elif channel in channels_with_optional_deps:
|
|
117
122
|
# Channel is known but not available due to missing dependency
|
|
118
123
|
dependency = channels_with_optional_deps[channel]
|
|
@@ -328,10 +333,21 @@ def serve_application(
|
|
|
328
333
|
|
|
329
334
|
logger.info(f"Starting Rasa server on {protocol}://{interface}:{port}")
|
|
330
335
|
|
|
331
|
-
app
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
336
|
+
async def load_agent_and_check_failure(app: Sanic, loop: AbstractEventLoop) -> None:
|
|
337
|
+
"""Load agent and exit if it fails in non-debug mode."""
|
|
338
|
+
try:
|
|
339
|
+
await load_agent_on_start(
|
|
340
|
+
model_path, endpoints, remote_storage, sub_agents, app, loop
|
|
341
|
+
)
|
|
342
|
+
except Exception as e:
|
|
343
|
+
is_debug = logger.isEnabledFor(logging.DEBUG)
|
|
344
|
+
if is_debug:
|
|
345
|
+
raise e # show traceback in debug
|
|
346
|
+
# non-debug: log and exit without starting server
|
|
347
|
+
logger.error(f"Failed to load agent: {e}")
|
|
348
|
+
os._exit(1) # Any other exit method would show a traceback.
|
|
349
|
+
|
|
350
|
+
app.register_listener(load_agent_and_check_failure, "before_server_start")
|
|
335
351
|
|
|
336
352
|
app.register_listener(
|
|
337
353
|
licensing.validate_limited_server_license, "after_server_start"
|