rasa-pro 3.10.16__py3-none-any.whl → 3.11.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +396 -17
- rasa/api.py +9 -3
- rasa/cli/arguments/default_arguments.py +23 -2
- rasa/cli/arguments/run.py +15 -0
- rasa/cli/arguments/train.py +3 -9
- rasa/cli/e2e_test.py +1 -1
- rasa/cli/evaluate.py +1 -1
- rasa/cli/inspect.py +8 -4
- rasa/cli/llm_fine_tuning.py +12 -15
- rasa/cli/run.py +8 -1
- rasa/cli/studio/studio.py +8 -18
- rasa/cli/train.py +11 -53
- rasa/cli/utils.py +8 -10
- rasa/cli/x.py +1 -1
- rasa/constants.py +1 -1
- rasa/core/actions/action.py +2 -0
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/agent.py +2 -2
- rasa/core/brokers/kafka.py +3 -1
- rasa/core/brokers/pika.py +3 -1
- rasa/core/channels/__init__.py +8 -6
- rasa/core/channels/channel.py +21 -4
- rasa/core/channels/development_inspector.py +143 -46
- rasa/core/channels/inspector/README.md +1 -1
- rasa/core/channels/inspector/dist/assets/{arc-b6e548fe.js → arc-86942a71.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-fa03ac9e.js → c4Diagram-d0fbc5ce-b0290676.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-ee67392a.js → classDiagram-936ed81e-f6405f6e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-9b283fae.js → classDiagram-v2-c3cb15f1-ef61ac77.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-8b6fcc2a.js → createText-62fc7601-f0411e58.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-22e77f4f.js → edges-f2ad444c-7dcc4f3b.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-60ffc87f.js → erDiagram-9d236eb7-e0c092d7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-9dd802e4.js → flowDb-1972c806-fba2e3ce.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-5fa1912f.js → flowDiagram-7ea5b25a-7a70b71a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-24a5f41a.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-622a1fd2.js → flowchart-elk-definition-abe16c3d-00a59b68.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-e285a63a.js → ganttDiagram-9b5ea136-293c91fa.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-f237bdca.js → gitGraphDiagram-99d0ae7c-07b2d68c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-4b03d70e.js → index-2c4b9a3b-bc959fbd.js} +1 -1
- rasa/core/channels/inspector/dist/assets/index-3a8a5a28.js +1317 -0
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-72a0fa5f.js → infoDiagram-736b4530-4a350f72.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-82218c41.js → journeyDiagram-df861f2b-af464fb7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-78cff630.js → layout-0071f036.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-5038b469.js → line-2f73cc83.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-c4fc4098.js → linear-f014b4cc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-c33c8ea6.js → mindmap-definition-beec6740-d2426fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-a8d03059.js → pieDiagram-dbbf0591-776f01a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-6a0e56b2.js → quadrantDiagram-4d7f4fd6-82e00b57.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-2dc7c7bd.js → requirementDiagram-6fc4c22a-ea13c6bb.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-2360fe39.js → sankeyDiagram-8f13d901-1feca7e9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-41b9f9ad.js → sequenceDiagram-b655622a-070c61d2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-0aad326f.js → stateDiagram-59f0c015-24f46263.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-9847d984.js → stateDiagram-v2-2b26beab-c9056051.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-564d890e.js → styles-080da4f6-08abc34a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-38957613.js → styles-3dcbcfbf-bc74c25a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-f0fc6921.js → styles-9c745c82-4e5d66de.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-ef3c5a77.js → svgDrawCommon-4835440b-849c4517.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-bf3e91c1.js → timeline-definition-5b62e21b-d0fb1598.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-4d4026c0.js → xychartDiagram-2b33534f-04d115e2.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +18 -17
- rasa/core/channels/inspector/index.html +17 -16
- rasa/core/channels/inspector/package.json +5 -1
- rasa/core/channels/inspector/src/App.tsx +117 -67
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +11 -10
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +10 -25
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +1 -1
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +10 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +107 -41
- rasa/core/channels/inspector/src/helpers/utils.ts +92 -7
- rasa/core/channels/inspector/src/types.ts +21 -1
- rasa/core/channels/inspector/yarn.lock +94 -1
- rasa/core/channels/rest.py +51 -46
- rasa/core/channels/socketio.py +22 -0
- rasa/core/channels/{audiocodes.py → voice_ready/audiocodes.py} +110 -68
- rasa/core/channels/{voice_aware → voice_ready}/jambonz.py +11 -4
- rasa/core/channels/{voice_aware → voice_ready}/jambonz_protocol.py +57 -5
- rasa/core/channels/{twilio_voice.py → voice_ready/twilio_voice.py} +58 -7
- rasa/core/channels/{voice_aware → voice_ready}/utils.py +16 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +71 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +13 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +77 -0
- rasa/core/channels/voice_stream/audio_bytes.py +7 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +100 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +114 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +48 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +164 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +247 -0
- rasa/core/featurizers/single_state_featurizer.py +1 -22
- rasa/core/featurizers/tracker_featurizers.py +18 -115
- rasa/core/nlg/contextual_response_rephraser.py +11 -2
- rasa/{nlu → core}/persistor.py +16 -38
- rasa/core/policies/enterprise_search_policy.py +12 -15
- rasa/core/policies/flows/flow_executor.py +8 -18
- rasa/core/policies/intentless_policy.py +10 -15
- rasa/core/policies/ted_policy.py +33 -58
- rasa/core/policies/unexpected_intent_policy.py +7 -15
- rasa/core/processor.py +13 -64
- rasa/core/run.py +11 -1
- rasa/core/secrets_manager/constants.py +4 -0
- rasa/core/secrets_manager/factory.py +8 -0
- rasa/core/secrets_manager/vault.py +11 -1
- rasa/core/training/interactive.py +1 -1
- rasa/core/utils.py +1 -11
- rasa/dialogue_understanding/coexistence/llm_based_router.py +10 -10
- rasa/dialogue_understanding/commands/__init__.py +2 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +0 -7
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -3
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +3 -28
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +1 -19
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +4 -37
- rasa/e2e_test/aggregate_test_stats_calculator.py +1 -11
- rasa/e2e_test/assertions.py +6 -48
- rasa/e2e_test/e2e_test_runner.py +6 -9
- rasa/e2e_test/utils/e2e_yaml_utils.py +1 -1
- rasa/e2e_test/utils/io.py +1 -3
- rasa/engine/graph.py +3 -10
- rasa/engine/recipes/config_files/default_config.yml +0 -3
- rasa/engine/recipes/default_recipe.py +0 -1
- rasa/engine/recipes/graph_recipe.py +0 -1
- rasa/engine/runner/dask.py +2 -2
- rasa/engine/storage/local_model_storage.py +12 -42
- rasa/engine/storage/storage.py +1 -5
- rasa/engine/validation.py +1 -78
- rasa/keys +1 -0
- rasa/model_training.py +13 -16
- rasa/nlu/classifiers/diet_classifier.py +25 -38
- rasa/nlu/classifiers/logistic_regression_classifier.py +9 -22
- rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
- rasa/nlu/extractors/crf_entity_extractor.py +50 -93
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +16 -45
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
- rasa/server.py +1 -1
- rasa/shared/constants.py +3 -12
- rasa/shared/core/constants.py +4 -0
- rasa/shared/core/domain.py +101 -47
- rasa/shared/core/events.py +29 -0
- rasa/shared/core/flows/flows_list.py +20 -11
- rasa/shared/core/flows/validation.py +25 -0
- rasa/shared/core/flows/yaml_flows_io.py +3 -24
- rasa/shared/importers/importer.py +40 -39
- rasa/shared/importers/multi_project.py +23 -11
- rasa/shared/importers/rasa.py +7 -2
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +3 -1
- rasa/shared/nlu/training_data/features.py +2 -120
- rasa/shared/nlu/training_data/training_data.py +18 -19
- rasa/shared/providers/_configs/azure_openai_client_config.py +3 -5
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +1 -6
- rasa/shared/providers/llm/_base_litellm_client.py +11 -31
- rasa/shared/providers/llm/self_hosted_llm_client.py +3 -15
- rasa/shared/utils/common.py +3 -22
- rasa/shared/utils/io.py +0 -1
- rasa/shared/utils/llm.py +30 -27
- rasa/shared/utils/schemas/events.py +2 -0
- rasa/shared/utils/schemas/model_config.yml +0 -10
- rasa/shared/utils/yaml.py +44 -0
- rasa/studio/auth.py +5 -3
- rasa/studio/config.py +4 -13
- rasa/studio/constants.py +0 -1
- rasa/studio/data_handler.py +3 -10
- rasa/studio/upload.py +8 -17
- rasa/tracing/instrumentation/attribute_extractors.py +1 -1
- rasa/utils/io.py +66 -0
- rasa/utils/tensorflow/model_data.py +193 -2
- rasa/validator.py +0 -12
- rasa/version.py +1 -1
- rasa_pro-3.11.0a1.dist-info/METADATA +576 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0a1.dist-info}/RECORD +181 -164
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +0 -1
- rasa/core/channels/inspector/dist/assets/index-a5d3e69d.js +0 -1040
- rasa/utils/tensorflow/feature_array.py +0 -366
- rasa_pro-3.10.16.dist-info/METADATA +0 -196
- /rasa/core/channels/{voice_aware → voice_ready}/__init__.py +0 -0
- /rasa/core/channels/{voice_native → voice_stream}/__init__.py +0 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0a1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0a1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0a1.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import AsyncIterator, Dict, Optional
|
|
3
|
+
import os
|
|
4
|
+
import aiohttp
|
|
5
|
+
import structlog
|
|
6
|
+
from aiohttp import ClientConnectorError
|
|
7
|
+
|
|
8
|
+
from rasa.core.channels.voice_stream.tts.tts_engine import (
|
|
9
|
+
TTSEngineConfig,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
from rasa.core.channels.voice_stream.audio_bytes import RasaAudioBytes
|
|
13
|
+
from rasa.core.channels.voice_stream.tts.tts_engine import TTSEngine, TTSError
|
|
14
|
+
from rasa.shared.exceptions import ConnectionException
|
|
15
|
+
|
|
16
|
+
structlogger = structlog.get_logger()
|
|
17
|
+
|
|
18
|
+
CARTESIA_API_KEY = "CARTESIA_API_KEY"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class CartesiaTTSConfig(TTSEngineConfig):
    """Configuration for the Cartesia TTS engine.

    Extends the shared engine config (``language``, ``voice``) with
    Cartesia-specific settings. All fields default to ``None`` so configs
    can be merged, with ``None`` meaning "not set".
    """

    # Cartesia model identifier, e.g. "sonic-english" (see get_default_config).
    model_id: Optional[str] = None
    # Value sent in the "Cartesia-Version" API header, e.g. "2024-06-10".
    version: Optional[str] = None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class CartesiaTTS(TTSEngine[CartesiaTTSConfig]):
    """TTS engine backed by the Cartesia streaming REST API."""

    # Class-shared HTTP session so connections are pooled across instances.
    session: Optional[aiohttp.ClientSession] = None

    def __init__(self, config: Optional[CartesiaTTSConfig] = None):
        super().__init__(config)
        # Have to create this class-shared session lazily at run time otherwise
        # the async event loop doesn't work
        if self.__class__.session is None or self.__class__.session.closed:
            self.__class__.session = aiohttp.ClientSession()

    @staticmethod
    def get_tts_endpoint() -> str:
        """Create the endpoint string for cartesia."""
        return "https://api.cartesia.ai/tts/bytes"

    @staticmethod
    def get_request_body(text: str, config: CartesiaTTSConfig) -> Dict:
        """Create the request body for cartesia."""
        # more info on payload:
        # https://docs.cartesia.ai/reference/api-reference/rest/stream-speech-bytes
        return {
            "model_id": config.model_id,
            "transcript": text,
            "language": config.language,
            "voice": {
                "mode": "id",
                "id": config.voice,
            },
            # Output must match the internal channel format: raw 8kHz mu-law.
            "output_format": {
                "container": "raw",
                "encoding": "pcm_mulaw",
                "sample_rate": 8000,
            },
        }

    @staticmethod
    def get_request_headers(config: CartesiaTTSConfig) -> dict[str, str]:
        """Build authentication and content headers for the Cartesia request.

        The API key is read from the CARTESIA_API_KEY environment variable.
        """
        cartesia_api_key = os.environ.get(CARTESIA_API_KEY)
        # NOTE(review): if the env var or config.version is unset, str(None)
        # sends the literal "None" — confirm upstream validation exists.
        return {
            "Cartesia-Version": str(config.version),
            "Content-Type": "application/json",
            "X-API-Key": str(cartesia_api_key),
        }

    async def synthesize(
        self, text: str, config: Optional[CartesiaTTSConfig] = None
    ) -> AsyncIterator[RasaAudioBytes]:
        """Generate speech from text using a remote TTS system.

        Yields audio chunks as they stream in.

        Raises:
            ConnectionException: if the shared client session is missing.
            TTSError: if the API call fails or the connection is refused.
        """
        config = self.config.merge(config)
        payload = self.get_request_body(text, config)
        headers = self.get_request_headers(config)
        url = self.get_tts_endpoint()
        if self.session is None:
            raise ConnectionException("Client session is not initialized")
        try:
            async with self.session.post(
                url, headers=headers, json=payload, chunked=True
            ) as response:
                if 200 <= response.status < 300:
                    async for data in response.content.iter_chunked(1024):
                        yield self.engine_bytes_to_rasa_audio_bytes(data)
                    return
                else:
                    # response.text() is a coroutine and must be awaited;
                    # previously the coroutine object itself was logged/raised.
                    error_text = await response.text()
                    structlogger.error(
                        # was "azure.synthesize.rest.failed" — copy-paste
                        # from the Azure engine; this is the Cartesia engine.
                        "cartesia.synthesize.rest.failed",
                        status_code=response.status,
                        msg=error_text,
                    )
                    raise TTSError(f"TTS failed: {error_text}")
        except ClientConnectorError as e:
            raise TTSError(e) from e

    def engine_bytes_to_rasa_audio_bytes(self, chunk: bytes) -> RasaAudioBytes:
        """Convert the generated tts audio bytes into rasa audio bytes."""
        # Cartesia already returns raw 8kHz mu-law, so no transcoding needed.
        return RasaAudioBytes(chunk)

    @staticmethod
    def get_default_config() -> CartesiaTTSConfig:
        """Default English voice/model used when no config is supplied."""
        return CartesiaTTSConfig(
            language="en",
            voice="248be419-c632-4f23-adf1-5324ed7dbf1d",
            model_id="sonic-english",
            version="2024-06-10",
        )

    @classmethod
    def from_config_dict(cls, config: Dict) -> "CartesiaTTS":
        """Alternate constructor from a raw config dictionary."""
        return cls(CartesiaTTSConfig.from_dict(config))
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
from collections import OrderedDict
|
|
3
|
+
import logging
|
|
4
|
+
from rasa.core.channels.voice_stream.audio_bytes import RasaAudioBytes
|
|
5
|
+
|
|
6
|
+
logger = logging.getLogger(__name__)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TTSCache:
    """An LRU Cache for TTS based on pythons OrderedDict."""

    def __init__(self, max_size: int):
        # Most recently used entries live at the end of the OrderedDict.
        self.cache: OrderedDict[str, RasaAudioBytes] = OrderedDict()
        self.max_size = max_size

    def get(self, text: str) -> Optional[RasaAudioBytes]:
        """Return cached audio for ``text`` (refreshing recency) or None."""
        if text in self.cache:
            self.cache.move_to_end(text)
            return self.cache[text]
        return None

    def put(self, text: str, audio_bytes: RasaAudioBytes) -> None:
        """Store audio for ``text``, evicting the least recently used entry."""
        self.cache[text] = audio_bytes
        self.cache.move_to_end(text)
        while len(self.cache) > self.max_size:
            # popitem(last=False) removes the oldest (least recently used) key.
            self.cache.popitem(last=False)
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from typing import AsyncIterator, Dict, Generic, Optional, Type, TypeVar
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
4
|
+
from rasa.core.channels.voice_stream.audio_bytes import RasaAudioBytes
|
|
5
|
+
from rasa.core.channels.voice_stream.util import MergeableConfig
|
|
6
|
+
from rasa.shared.exceptions import RasaException
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TTSError(RasaException):
    """Raised when text-to-speech synthesis fails (API or connection error)."""

    pass
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Config type handled by a concrete engine subclass.
T = TypeVar("T", bound="TTSEngineConfig")
# Concrete engine type, used to type the from_config_dict alternate constructor.
E = TypeVar("E", bound="TTSEngine")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class TTSEngineConfig(MergeableConfig):
    """Base configuration shared by all TTS engines.

    Fields default to ``None`` so configs can be layered via
    ``MergeableConfig.merge``, where ``None`` means "not set".
    """

    # Language code for synthesis, e.g. "en".
    language: Optional[str] = None
    # Engine-specific voice identifier.
    voice: Optional[str] = None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class TTSEngine(Generic[T]):
    """Interface for streaming text-to-speech engines.

    Concrete engines override ``synthesize`` and
    ``engine_bytes_to_rasa_audio_bytes`` and provide
    ``get_default_config`` / ``from_config_dict``.
    """

    def __init__(self, config: Optional[T] = None):
        # Engine defaults, overridden by any non-None values in `config`.
        self.config = self.get_default_config().merge(config)

    async def close_connection(self) -> None:
        """Cleanup the connection if necessary."""
        return

    async def synthesize(
        self, text: str, config: Optional[T] = None
    ) -> AsyncIterator[RasaAudioBytes]:
        """Generate speech from text using a remote TTS system."""
        # The `yield` makes this an async generator; the base implementation
        # produces a single empty chunk. Subclasses must also be generators.
        yield RasaAudioBytes(b"")

    def engine_bytes_to_rasa_audio_bytes(self, chunk: bytes) -> RasaAudioBytes:
        """Convert the generated tts audio bytes into rasa audio bytes."""
        raise NotImplementedError

    @staticmethod
    def get_default_config() -> T:
        """Get the default config for this component."""
        raise NotImplementedError

    @classmethod
    def from_config_dict(cls: Type[E], config: Dict) -> E:
        """Alternate constructor: build an engine from a raw config dict."""
        raise NotImplementedError
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import json
|
|
3
|
+
import structlog
|
|
4
|
+
from typing import Any, Awaitable, Callable, Dict, List, Optional, Text
|
|
5
|
+
import uuid
|
|
6
|
+
|
|
7
|
+
from sanic import Blueprint, HTTPResponse, Request, response
|
|
8
|
+
from sanic import Websocket # type: ignore
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
from rasa.core.channels import UserMessage
|
|
12
|
+
from rasa.core.channels.voice_ready.utils import CallParameters
|
|
13
|
+
from rasa.core.channels.voice_stream.tts.tts_engine import TTSEngine
|
|
14
|
+
from rasa.core.channels.voice_stream.audio_bytes import RasaAudioBytes
|
|
15
|
+
from rasa.core.channels.voice_stream.voice_channel import (
|
|
16
|
+
EndConversationAction,
|
|
17
|
+
NewAudioAction,
|
|
18
|
+
VoiceChannelAction,
|
|
19
|
+
ContinueConversationAction,
|
|
20
|
+
VoiceInputChannel,
|
|
21
|
+
VoiceOutputChannel,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
structlogger = structlog.get_logger()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def map_call_params(data: Dict[Text, Any]) -> CallParameters:
    """Map the twilio stream parameters to the CallParameters dataclass."""
    # Custom parameters are set by our /webhook TwiML response below.
    custom = data["start"]["customParameters"]
    return CallParameters(
        call_id=custom.get("call_id", ""),
        user_phone=custom.get("user_phone", ""),
        bot_phone=custom.get("bot_phone", ""),
        direction=custom.get("direction"),
        stream_id=data["streamSid"],
    )
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class TwilioMediaStreamsOutputChannel(VoiceOutputChannel):
    """Streams synthesized audio back over a Twilio Media Streams websocket."""

    @classmethod
    def name(cls) -> str:
        return "twilio_media_streams"

    def rasa_audio_bytes_to_channel_bytes(
        self, rasa_audio_bytes: RasaAudioBytes
    ) -> bytes:
        # Twilio expects audio payloads base64-encoded.
        return base64.b64encode(rasa_audio_bytes)

    def channel_bytes_to_messages(
        self, recipient_id: str, channel_bytes: bytes
    ) -> List[Any]:
        """Wrap encoded audio into Twilio "media" + "mark" websocket messages.

        The trailing mark is echoed back by Twilio and lets us track which
        message has finished playing (stored in ``latest_message_id``).
        """
        marker = uuid.uuid4().hex
        media_event = {
            "event": "media",
            "streamSid": recipient_id,
            "media": {
                "payload": channel_bytes.decode("utf-8"),
            },
        }
        mark_event = {
            "event": "mark",
            "streamSid": recipient_id,
            "mark": {"name": marker},
        }
        self.latest_message_id = marker
        return [json.dumps(media_event), json.dumps(mark_event)]
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class TwilioMediaStreamsInputChannel(VoiceInputChannel):
    """Input channel receiving call audio via Twilio Media Streams websockets."""

    @classmethod
    def name(cls) -> str:
        return "twilio_media_streams"

    def channel_bytes_to_rasa_audio_bytes(self, input_bytes: bytes) -> RasaAudioBytes:
        # Twilio delivers audio payloads base64-encoded.
        return RasaAudioBytes(base64.b64decode(input_bytes))

    async def collect_call_parameters(
        self, channel_websocket: Websocket
    ) -> Optional[CallParameters]:
        """Consume messages until the "start" event with call metadata arrives.

        Returns None if the websocket closes before a "start" event is seen.
        """
        async for message in channel_websocket:
            data = json.loads(message)
            if data["event"] == "start":
                # retrieve parameters set in the webhook - contains info about the
                # caller
                return map_call_params(data)
        return None

    def map_input_message(
        self,
        message: Any,
    ) -> VoiceChannelAction:
        """Translate a raw Twilio websocket message into a channel action."""
        data = json.loads(message)
        if data["event"] == "media":
            audio_bytes = self.channel_bytes_to_rasa_audio_bytes(
                data["media"]["payload"]
            )
            return NewAudioAction(audio_bytes)
        elif data["event"] == "stop":
            return EndConversationAction()
        elif data["event"] == "mark":
            # Presumably Twilio echoes "mark" events once playback reaches the
            # marker, so matching `hangup_after` means the farewell audio has
            # finished — confirm against the Twilio Media Streams docs.
            # NOTE(review): `self.hangup_after` is not set anywhere visible in
            # this module; verify the VoiceInputChannel base defines it.
            if data["mark"]["name"] == self.hangup_after:
                structlogger.debug("twilio_streams.hangup", marker=self.hangup_after)
                return EndConversationAction()
        return ContinueConversationAction()

    def create_output_channel(
        self, voice_websocket: Websocket, tts_engine: TTSEngine
    ) -> VoiceOutputChannel:
        """Build the paired output channel sharing this channel's TTS cache."""
        return TwilioMediaStreamsOutputChannel(
            voice_websocket,
            tts_engine,
            self.tts_cache,
        )

    def websocket_stream_url(self) -> str:
        """Returns the websocket stream URL."""
        # depending on the config value, the url might contain http as a
        # protocol or not - we'll make sure both work
        # NOTE(review): str.replace swaps every "http" occurrence, not just the
        # scheme — a server_url containing "http" in its path would break.
        if self.server_url.startswith("http"):
            base_url = self.server_url.replace("http", "ws")
        else:
            base_url = f"wss://{self.server_url}"
        return f"{base_url}/webhooks/twilio_media_streams/websocket"

    def blueprint(
        self, on_new_message: Callable[[UserMessage], Awaitable[Any]]
    ) -> Blueprint:
        """Defines a Sanic blueprint."""
        # NOTE(review): the blueprint name "socketio_webhook" appears copied
        # from the socketio channel — confirm whether it is safe to rename.
        blueprint = Blueprint("socketio_webhook", __name__)

        @blueprint.route("/", methods=["GET"])
        async def health(_: Request) -> HTTPResponse:
            # Simple liveness probe.
            return response.json({"status": "ok"})

        @blueprint.route("/webhook", methods=["POST"])
        async def receive(request: Request) -> HTTPResponse:
            # Twilio's voice webhook: answer with TwiML that connects the call
            # audio to our websocket endpoint below.
            from twilio.twiml.voice_response import Connect, VoiceResponse

            voice_response = VoiceResponse()
            start = Connect()
            stream = start.stream(url=self.websocket_stream_url())
            # pass information about the call to the webhook - so we can
            # store it in the input channel
            stream.parameter(name="call_id", value=request.form.get("CallSid", None))
            stream.parameter(name="user_phone", value=request.form.get("From", None))
            stream.parameter(name="bot_phone", value=request.form.get("To", None))
            stream.parameter(
                name="direction", value=request.form.get("Direction", None)
            )

            voice_response.append(start)

            return response.text(str(voice_response), content_type="text/xml")

        @blueprint.websocket("/websocket")  # type: ignore
        async def handle_message(request: Request, ws: Websocket) -> None:
            # Hand the socket to the shared streaming loop in the base class.
            await self.run_audio_streaming(on_new_message, ws)

        return blueprint
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import wave
|
|
2
|
+
import audioop
|
|
3
|
+
from dataclasses import asdict, dataclass
|
|
4
|
+
from typing import Optional, Type, TypeVar
|
|
5
|
+
|
|
6
|
+
import structlog
|
|
7
|
+
|
|
8
|
+
from rasa.core.channels.voice_stream.audio_bytes import RasaAudioBytes
|
|
9
|
+
from rasa.shared.exceptions import RasaException
|
|
10
|
+
|
|
11
|
+
structlogger = structlog.get_logger()
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def read_wav_to_rasa_audio_bytes(file_name: str) -> Optional[RasaAudioBytes]:
    """Reads rasa audio bytes from a file.

    Converts arbitrary wav content to the internal channel format:
    mono, 8 bit, 8kHz, mu-law encoded.

    Args:
        file_name: path to a ``.wav`` file.

    Raises:
        RasaException: if the file does not have a ``.wav`` extension.
    """
    if not file_name.endswith(".wav"):
        raise RasaException("Should only read .wav files with this method.")
    # Context manager closes the file handle (previously leaked).
    with wave.open(file_name, "rb") as wave_object:
        wave_data = wave_object.readframes(wave_object.getnframes())
        n_channels = wave_object.getnchannels()
        sample_width = wave_object.getsampwidth()
        frame_rate = wave_object.getframerate()
    # NOTE(review): `audioop` is deprecated since Python 3.11 and removed in
    # 3.13 — this needs a replacement before upgrading the interpreter.
    if n_channels != 1:
        wave_data = audioop.tomono(wave_data, sample_width, 1, 1)
    if sample_width != 1:
        wave_data = audioop.lin2lin(wave_data, sample_width, 1)
        # 8 bit is unsigned
        # wave_data = audioop.bias(wave_data, 1, 128)
    if frame_rate != 8000:
        wave_data, _ = audioop.ratecv(wave_data, 1, 1, frame_rate, 8000, None)
    wave_data = audioop.lin2ulaw(wave_data, 1)
    return RasaAudioBytes(wave_data)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def generate_silence(length_in_seconds: float = 1.0) -> RasaAudioBytes:
    """Produce a buffer of zero bytes covering the given duration at 8kHz."""
    # NOTE(review): in mu-law encoding digital silence is usually 0xFF, not
    # 0x00 — confirm 0x00 is intended here.
    number_of_samples = int(length_in_seconds * 8000)
    return RasaAudioBytes(b"\00" * number_of_samples)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Bound so merge/from_dict return the concrete subclass type.
T = TypeVar("T", bound="MergeableConfig")


@dataclass
class MergeableConfig:
    """Base dataclass for configs that can be layered over each other."""

    def __init__(self) -> None:
        pass

    def merge(self: T, other: Optional[T]) -> T:
        """Merges two configs while dropping None values of the second config."""
        if other is None:
            return self
        # Values from `other` win, but only where they are actually set.
        overrides = {
            key: value for key, value in asdict(other).items() if value is not None
        }
        combined = {**asdict(self), **overrides}
        return self.from_dict(combined)

    @classmethod
    def from_dict(cls: Type[T], data: dict[str, Optional[str]]) -> T:
        """Build a config instance from a plain dictionary of field values."""
        return cls(**data)
|
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
import copy
|
|
4
|
+
from dataclasses import asdict, dataclass
|
|
5
|
+
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
from sanic.exceptions import ServerError, WebsocketClosed
|
|
8
|
+
|
|
9
|
+
from rasa.core.channels import InputChannel, OutputChannel, UserMessage
|
|
10
|
+
from rasa.core.channels.voice_ready.utils import CallParameters
|
|
11
|
+
from rasa.core.channels.voice_stream.asr.asr_engine import ASREngine
|
|
12
|
+
from rasa.core.channels.voice_stream.asr.asr_event import ASREvent, NewTranscript
|
|
13
|
+
from sanic import Websocket # type: ignore
|
|
14
|
+
|
|
15
|
+
from rasa.core.channels.voice_stream.asr.deepgram import DeepgramASR
|
|
16
|
+
from rasa.core.channels.voice_stream.audio_bytes import RasaAudioBytes
|
|
17
|
+
from rasa.core.channels.voice_stream.tts.azure import AzureTTS
|
|
18
|
+
from rasa.core.channels.voice_stream.tts.tts_engine import TTSEngine, TTSError
|
|
19
|
+
from rasa.core.channels.voice_stream.tts.cartesia import CartesiaTTS
|
|
20
|
+
from rasa.core.channels.voice_stream.tts.tts_cache import TTSCache
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class VoiceChannelAction:
    """Base class for actions derived from incoming voice channel messages."""

    pass


@dataclass
class NewAudioAction(VoiceChannelAction):
    """Carries a chunk of newly received caller audio."""

    # Audio in the internal format (presumably 8kHz mu-law, cf. util.py — confirm).
    audio_bytes: RasaAudioBytes


@dataclass
class EndConversationAction(VoiceChannelAction):
    """Signals that the streaming conversation should be terminated."""

    pass


@dataclass
class ContinueConversationAction(VoiceChannelAction):
    """No-op: keep streaming without any further handling."""

    pass
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def asr_engine_from_config(asr_config: Dict) -> "ASREngine":
    """Instantiate the ASR engine selected by the config.

    Args:
        asr_config: must contain a ``name`` key selecting the engine; all
            remaining keys are forwarded to the engine's own config.

    Raises:
        NotImplementedError: if the named engine is unknown.
    """
    name = str(asr_config["name"]).lower()
    # Work on a copy so the caller's dict keeps its "name" key.
    asr_config = copy.copy(asr_config)
    asr_config.pop("name")
    if name == "deepgram":
        return DeepgramASR.from_config_dict(asr_config)
    else:
        # Message added for consistency with tts_engine_from_config.
        raise NotImplementedError(f"ASR engine {name} is not implemented")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def tts_engine_from_config(tts_config: Dict) -> TTSEngine:
    """Instantiate the TTS engine selected by the config's ``name`` key."""
    name = str(tts_config["name"]).lower()
    # Shallow copy so the caller's dict keeps its "name" key.
    engine_config = copy.copy(tts_config)
    engine_config.pop("name")
    if name == "azure":
        return AzureTTS.from_config_dict(engine_config)
    if name == "cartesia":
        return CartesiaTTS.from_config_dict(engine_config)
    raise NotImplementedError(f"TTS engine {name} is not implemented")
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class VoiceOutputChannel(OutputChannel):
    """Base output channel that streams synthesized speech over a websocket.

    Text responses are synthesized via the configured TTS engine; the full
    audio is cached keyed by text so repeated responses skip synthesis.
    Subclasses implement the channel-specific encoding and message framing.
    """

    def __init__(
        self,
        voice_websocket: Websocket,
        tts_engine: TTSEngine,
        tts_cache: TTSCache,
    ):
        self.voice_websocket = voice_websocket
        self.tts_engine = tts_engine
        self.tts_cache = tts_cache

        # Set when sending fails or hangup() is called; presumably read by the
        # surrounding input channel to terminate the call — confirm.
        self.should_hangup = False
        self.latest_message_id: Optional[str] = None

    def rasa_audio_bytes_to_channel_bytes(
        self, rasa_audio_bytes: RasaAudioBytes
    ) -> bytes:
        """Encode internal audio bytes into the channel's wire format."""
        raise NotImplementedError

    def channel_bytes_to_messages(
        self, recipient_id: str, channel_bytes: bytes
    ) -> List[Any]:
        """Wrap encoded audio into the websocket message(s) for this channel."""
        raise NotImplementedError

    async def send_text_message(
        self, recipient_id: str, text: str, **kwargs: Any
    ) -> None:
        """Synthesize ``text`` and stream the resulting audio chunks."""
        cached_audio_bytes = self.tts_cache.get(text)

        if cached_audio_bytes:
            await self.send_audio_bytes(recipient_id, cached_audio_bytes)
            return
        collected_audio_bytes = RasaAudioBytes(b"")
        # Todo: make kwargs compatible with engine config
        synth_config = self.tts_engine.config.__class__.from_dict({})
        audio_stream = self.tts_engine.synthesize(text, synth_config)
        try:
            # TTSError surfaces while *iterating*: synthesize is an async
            # generator, so calling it runs no body code. The previous
            # try/except around the call alone could never catch the error
            # and left `audio_stream` unbound on its except path.
            async for audio_bytes in audio_stream:
                try:
                    await self.send_audio_bytes(recipient_id, audio_bytes)
                except (WebsocketClosed, ServerError):
                    # ignore sending error, and keep collecting and caching audio bytes
                    self.should_hangup = True

                collected_audio_bytes = RasaAudioBytes(
                    collected_audio_bytes + audio_bytes
                )
        except TTSError:
            # TODO: add message that works without tts, e.g. loading from disc
            # Skip caching: a failed synthesis would poison the cache with
            # truncated audio.
            return

        self.tts_cache.put(text, collected_audio_bytes)

    async def send_audio_bytes(
        self, recipient_id: str, audio_bytes: RasaAudioBytes
    ) -> None:
        """Encode and send raw audio over the websocket."""
        channel_bytes = self.rasa_audio_bytes_to_channel_bytes(audio_bytes)
        for message in self.channel_bytes_to_messages(recipient_id, channel_bytes):
            await self.voice_websocket.send(message)

    async def hangup(self, recipient_id: str, **kwargs: Any) -> None:
        """Flag the call to be terminated."""
        self.should_hangup = True
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class VoiceInputChannel(InputChannel):
    """Base class for voice input channels that stream audio over a websocket.

    Concrete channels implement the mapping between the channel's wire
    format and Rasa's internal audio representation, call-metadata
    extraction, and output-channel creation.
    """

    def __init__(self, server_url: str, asr_config: Dict, tts_config: Dict):
        self.server_url = server_url
        self.asr_config = asr_config
        self.tts_config = tts_config
        self.tts_cache = TTSCache(tts_config.get("cache_size", 1000))

        # if set to a value, call will be hungup after marker is reached
        self.hangup_after: Optional[str] = None

    @classmethod
    def from_credentials(cls, credentials: Optional[Dict[str, Any]]) -> InputChannel:
        """Create the channel from its credentials configuration entry."""
        credentials = credentials or {}
        return cls(credentials["server_url"], credentials["asr"], credentials["tts"])

    def channel_bytes_to_rasa_audio_bytes(self, input_bytes: bytes) -> RasaAudioBytes:
        """Convert channel wire-format audio into Rasa's audio byte format."""
        raise NotImplementedError

    async def collect_call_parameters(
        self, channel_websocket: Websocket
    ) -> Optional[CallParameters]:
        """Extract call metadata from the incoming websocket stream."""
        raise NotImplementedError

    async def start_session(
        self,
        channel_websocket: Websocket,
        on_new_message: Callable[[UserMessage], Awaitable[Any]],
        tts_engine: TTSEngine,
        call_parameters: CallParameters,
    ) -> None:
        """Kick off the conversation by sending a `/session_start` message."""
        output_channel = self.create_output_channel(channel_websocket, tts_engine)
        message = UserMessage(
            "/session_start",
            output_channel,
            call_parameters.stream_id,
            input_channel=self.name(),
            metadata=asdict(call_parameters),
        )
        await on_new_message(message)

    def map_input_message(
        self,
        message: Any,
    ) -> VoiceChannelAction:
        """Map a channel input message to a voice channel action."""
        raise NotImplementedError

    async def run_audio_streaming(
        self,
        on_new_message: Callable[[UserMessage], Awaitable[Any]],
        channel_websocket: Websocket,
    ) -> None:
        """Pipe input audio to ASR and consume ASR events simultaneously."""
        asr_engine = asr_engine_from_config(self.asr_config)
        tts_engine = tts_engine_from_config(self.tts_config)
        await asr_engine.connect()

        call_parameters = await self.collect_call_parameters(channel_websocket)
        if call_parameters is None:
            raise ValueError("Failed to extract call parameters for call.")
        await self.start_session(
            channel_websocket, on_new_message, tts_engine, call_parameters
        )

        async def consume_audio_bytes() -> None:
            async for message in channel_websocket:
                channel_action = self.map_input_message(message)
                if isinstance(channel_action, NewAudioAction):
                    await asr_engine.send_audio_chunks(channel_action.audio_bytes)
                elif isinstance(channel_action, EndConversationAction):
                    # end stream event came from the other side
                    break

        async def consume_asr_events() -> None:
            async for event in asr_engine.stream_asr_events():
                await self.handle_asr_event(
                    event,
                    channel_websocket,
                    on_new_message,
                    tts_engine,
                    call_parameters,
                )

        # Bug fix: asyncio.wait() must receive tasks, not bare coroutines
        # (deprecated since Python 3.8, a TypeError on 3.11+). Wrap both
        # consumers in tasks and cancel whichever is still pending once the
        # first one finishes, so it doesn't leak past the call.
        consumers = [
            asyncio.create_task(consume_audio_bytes()),
            asyncio.create_task(consume_asr_events()),
        ]
        _, pending = await asyncio.wait(
            consumers,
            return_when=asyncio.FIRST_COMPLETED,
        )
        for task in pending:
            task.cancel()
        await tts_engine.close_connection()
        await asr_engine.close_connection()

    def create_output_channel(
        self, voice_websocket: Websocket, tts_engine: TTSEngine
    ) -> VoiceOutputChannel:
        """Create a matching voice output channel for this voice input channel."""
        raise NotImplementedError

    async def handle_asr_event(
        self,
        e: ASREvent,
        voice_websocket: Websocket,
        on_new_message: Callable[[UserMessage], Awaitable[Any]],
        tts_engine: TTSEngine,
        call_parameters: CallParameters,
    ) -> None:
        """Handle a new event from the ASR system."""
        if isinstance(e, NewTranscript) and e.text:
            logger.info(f"New transcript: {e.text}")
            output_channel = self.create_output_channel(voice_websocket, tts_engine)
            message = UserMessage(
                e.text,
                output_channel,
                call_parameters.stream_id,
                input_channel=self.name(),
                metadata=asdict(call_parameters),
            )
            await on_new_message(message)

            # Propagate a hangup request raised while the bot responded; the
            # call is ended once this message id has been played out.
            if output_channel.should_hangup:
                self.hangup_after = output_channel.latest_message_id
|
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
import logging
|
|
2
|
-
from typing import List, Optional, Dict, Text, Set, Any
|
|
3
|
-
|
|
4
2
|
import numpy as np
|
|
5
3
|
import scipy.sparse
|
|
4
|
+
from typing import List, Optional, Dict, Text, Set, Any
|
|
6
5
|
|
|
7
6
|
from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
|
|
8
7
|
from rasa.nlu.extractors.extractor import EntityTagSpec
|
|
@@ -361,26 +360,6 @@ class SingleStateFeaturizer:
|
|
|
361
360
|
for action in domain.action_names_or_texts
|
|
362
361
|
]
|
|
363
362
|
|
|
364
|
-
def to_dict(self) -> Dict[str, Any]:
|
|
365
|
-
return {
|
|
366
|
-
"action_texts": self.action_texts,
|
|
367
|
-
"entity_tag_specs": self.entity_tag_specs,
|
|
368
|
-
"feature_states": self._default_feature_states,
|
|
369
|
-
}
|
|
370
|
-
|
|
371
|
-
@classmethod
|
|
372
|
-
def create_from_dict(
|
|
373
|
-
cls, data: Dict[str, Any]
|
|
374
|
-
) -> Optional["SingleStateFeaturizer"]:
|
|
375
|
-
if not data:
|
|
376
|
-
return None
|
|
377
|
-
|
|
378
|
-
featurizer = SingleStateFeaturizer()
|
|
379
|
-
featurizer.action_texts = data["action_texts"]
|
|
380
|
-
featurizer._default_feature_states = data["feature_states"]
|
|
381
|
-
featurizer.entity_tag_specs = data["entity_tag_specs"]
|
|
382
|
-
return featurizer
|
|
383
|
-
|
|
384
363
|
|
|
385
364
|
class IntentTokenizerSingleStateFeaturizer(SingleStateFeaturizer):
|
|
386
365
|
"""A SingleStateFeaturizer for use with policies that predict intent labels."""
|