rasa-pro 3.10.15__py3-none-any.whl → 3.11.0__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Potentially problematic release.
- rasa/__main__.py +31 -15
- rasa/api.py +12 -2
- rasa/cli/arguments/default_arguments.py +24 -4
- rasa/cli/arguments/run.py +15 -0
- rasa/cli/arguments/shell.py +5 -1
- rasa/cli/arguments/train.py +17 -9
- rasa/cli/evaluate.py +7 -7
- rasa/cli/inspect.py +19 -7
- rasa/cli/interactive.py +1 -0
- rasa/cli/project_templates/calm/config.yml +5 -7
- rasa/cli/project_templates/calm/endpoints.yml +15 -2
- rasa/cli/project_templates/tutorial/config.yml +8 -5
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
- rasa/cli/project_templates/tutorial/domain.yml +14 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +5 -0
- rasa/cli/run.py +7 -0
- rasa/cli/scaffold.py +4 -2
- rasa/cli/studio/upload.py +0 -15
- rasa/cli/train.py +14 -53
- rasa/cli/utils.py +14 -11
- rasa/cli/x.py +7 -7
- rasa/constants.py +3 -1
- rasa/core/actions/action.py +77 -33
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/actions/action_repeat_bot_messages.py +89 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +5 -1
- rasa/core/actions/http_custom_action_executor.py +4 -0
- rasa/core/agent.py +2 -2
- rasa/core/brokers/kafka.py +3 -1
- rasa/core/brokers/pika.py +3 -1
- rasa/core/channels/__init__.py +10 -6
- rasa/core/channels/channel.py +41 -4
- rasa/core/channels/development_inspector.py +150 -46
- rasa/core/channels/inspector/README.md +1 -1
- rasa/core/channels/inspector/dist/assets/{arc-b6e548fe.js → arc-bc141fb2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-fa03ac9e.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-ee67392a.js → classDiagram-936ed81e-55366915.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-9b283fae.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-8b6fcc2a.js → createText-62fc7601-b0ec81d6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-22e77f4f.js → edges-f2ad444c-6166330c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-60ffc87f.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-9dd802e4.js → flowDb-1972c806-fca3bfe4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-5fa1912f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-622a1fd2.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-e285a63a.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-f237bdca.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-4b03d70e.js → index-2c4b9a3b-f55afcdf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/index-e7cef9de.js +1317 -0
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-72a0fa5f.js → infoDiagram-736b4530-124d4a14.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-82218c41.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-78cff630.js → layout-b9885fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-5038b469.js → line-7c59abb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-c4fc4098.js → linear-4776f780.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-c33c8ea6.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-a8d03059.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-6a0e56b2.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-2dc7c7bd.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-2360fe39.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-41b9f9ad.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-0aad326f.js → stateDiagram-59f0c015-042b3137.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-9847d984.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-564d890e.js → styles-080da4f6-23ffa4fc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-38957613.js → styles-3dcbcfbf-94f59763.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-f0fc6921.js → styles-9c745c82-78a6bebc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-ef3c5a77.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-bf3e91c1.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-4d4026c0.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +18 -15
- rasa/core/channels/inspector/index.html +17 -14
- rasa/core/channels/inspector/package.json +5 -1
- rasa/core/channels/inspector/src/App.tsx +118 -68
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +11 -10
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +10 -25
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +6 -3
- rasa/core/channels/inspector/src/helpers/audiostream.ts +165 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +10 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +107 -41
- rasa/core/channels/inspector/src/helpers/utils.ts +92 -7
- rasa/core/channels/inspector/src/types.ts +21 -1
- rasa/core/channels/inspector/yarn.lock +94 -1
- rasa/core/channels/rest.py +51 -46
- rasa/core/channels/socketio.py +28 -1
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/{audiocodes.py → voice_ready/audiocodes.py} +122 -69
- rasa/core/channels/{voice_aware → voice_ready}/jambonz.py +26 -8
- rasa/core/channels/{voice_aware → voice_ready}/jambonz_protocol.py +57 -5
- rasa/core/channels/{twilio_voice.py → voice_ready/twilio_voice.py} +64 -28
- rasa/core/channels/voice_ready/utils.py +37 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
- rasa/core/channels/voice_stream/asr/azure.py +129 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
- rasa/core/channels/voice_stream/audio_bytes.py +8 -0
- rasa/core/channels/voice_stream/browser_audio.py +107 -0
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +106 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +427 -0
- rasa/core/information_retrieval/qdrant.py +1 -0
- rasa/core/nlg/contextual_response_rephraser.py +45 -17
- rasa/{nlu → core}/persistor.py +203 -68
- rasa/core/policies/enterprise_search_policy.py +119 -63
- rasa/core/policies/flows/flow_executor.py +15 -22
- rasa/core/policies/intentless_policy.py +83 -28
- rasa/core/processor.py +25 -0
- rasa/core/run.py +12 -2
- rasa/core/secrets_manager/constants.py +4 -0
- rasa/core/secrets_manager/factory.py +8 -0
- rasa/core/secrets_manager/vault.py +11 -1
- rasa/core/training/interactive.py +33 -34
- rasa/core/utils.py +47 -21
- rasa/dialogue_understanding/coexistence/llm_based_router.py +41 -14
- rasa/dialogue_understanding/commands/__init__.py +6 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +5 -0
- rasa/dialogue_understanding/generator/constants.py +2 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +47 -9
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +38 -15
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +35 -13
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +3 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +60 -13
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +53 -0
- rasa/dialogue_understanding/patterns/repeat.py +37 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +21 -1
- rasa/e2e_test/aggregate_test_stats_calculator.py +1 -11
- rasa/e2e_test/assertions.py +136 -61
- rasa/e2e_test/assertions_schema.yml +23 -0
- rasa/e2e_test/e2e_test_case.py +85 -6
- rasa/e2e_test/e2e_test_runner.py +2 -3
- rasa/engine/graph.py +0 -1
- rasa/engine/loader.py +12 -0
- rasa/engine/recipes/config_files/default_config.yml +0 -3
- rasa/engine/recipes/default_recipe.py +0 -1
- rasa/engine/recipes/graph_recipe.py +0 -1
- rasa/engine/runner/dask.py +2 -2
- rasa/engine/storage/local_model_storage.py +12 -42
- rasa/engine/storage/storage.py +1 -5
- rasa/engine/validation.py +527 -74
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +40 -0
- rasa/model_manager/model_api.py +559 -0
- rasa/model_manager/runner_service.py +286 -0
- rasa/model_manager/socket_bridge.py +146 -0
- rasa/model_manager/studio_jwt_auth.py +86 -0
- rasa/model_manager/trainer_service.py +325 -0
- rasa/model_manager/utils.py +87 -0
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +112 -0
- rasa/model_training.py +42 -23
- rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
- rasa/server.py +4 -2
- rasa/shared/constants.py +60 -8
- rasa/shared/core/constants.py +13 -0
- rasa/shared/core/domain.py +107 -50
- rasa/shared/core/events.py +29 -0
- rasa/shared/core/flows/flow.py +5 -0
- rasa/shared/core/flows/flows_list.py +19 -6
- rasa/shared/core/flows/flows_yaml_schema.json +10 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +121 -0
- rasa/shared/core/flows/yaml_flows_io.py +15 -27
- rasa/shared/core/slots.py +5 -0
- rasa/shared/importers/importer.py +59 -41
- rasa/shared/importers/multi_project.py +23 -11
- rasa/shared/importers/rasa.py +12 -3
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +3 -1
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +18 -3
- rasa/shared/nlu/training_data/training_data.py +18 -19
- rasa/shared/providers/_configs/litellm_router_client_config.py +220 -0
- rasa/shared/providers/_configs/model_group_config.py +167 -0
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
- rasa/shared/providers/_configs/utils.py +16 -0
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +13 -29
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +24 -0
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
- rasa/shared/providers/llm/_base_litellm_client.py +34 -22
- rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
- rasa/shared/providers/llm/default_litellm_llm_client.py +24 -0
- rasa/shared/providers/llm/litellm_router_llm_client.py +182 -0
- rasa/shared/providers/llm/rasa_llm_client.py +112 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +5 -29
- rasa/shared/providers/mappings.py +19 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +183 -0
- rasa/shared/providers/router/router_client.py +73 -0
- rasa/shared/utils/common.py +40 -24
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +258 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +27 -6
- rasa/shared/utils/llm.py +353 -43
- rasa/shared/utils/schemas/events.py +2 -0
- rasa/shared/utils/schemas/model_config.yml +0 -10
- rasa/shared/utils/yaml.py +181 -38
- rasa/studio/data_handler.py +3 -1
- rasa/studio/upload.py +160 -74
- rasa/telemetry.py +94 -17
- rasa/tracing/config.py +3 -1
- rasa/tracing/instrumentation/attribute_extractors.py +95 -18
- rasa/tracing/instrumentation/instrumentation.py +121 -0
- rasa/utils/common.py +5 -0
- rasa/utils/endpoints.py +27 -1
- rasa/utils/io.py +8 -16
- rasa/utils/log_utils.py +9 -2
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/validator.py +110 -4
- rasa/version.py +1 -1
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/METADATA +14 -12
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/RECORD +234 -183
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +0 -1
- rasa/core/channels/inspector/dist/assets/index-a5d3e69d.js +0 -1040
- rasa/core/channels/voice_aware/utils.py +0 -20
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +0 -407
- /rasa/core/channels/{voice_aware → voice_ready}/__init__.py +0 -0
- /rasa/core/channels/{voice_native → voice_stream}/__init__.py +0 -0
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/NOTICE +0 -0
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/WHEEL +0 -0
- {rasa_pro-3.10.15.dist-info → rasa_pro-3.11.0.dist-info}/entry_points.txt +0 -0

rasa/core/channels/voice_stream/voice_channel.py (new file)

@@ -0,0 +1,427 @@
+import asyncio
+import structlog
+import copy
+from dataclasses import asdict, dataclass
+from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Tuple
+
+from rasa.core.channels.voice_stream.util import generate_silence
+from rasa.shared.core.constants import SLOT_SILENCE_TIMEOUT
+from rasa.shared.utils.common import (
+    class_from_module_path,
+    mark_as_beta_feature,
+)
+from rasa.shared.utils.cli import print_error_and_exit
+
+from sanic.exceptions import ServerError, WebsocketClosed
+
+from rasa.core.channels import InputChannel, OutputChannel, UserMessage
+from rasa.core.channels.voice_ready.utils import CallParameters
+from rasa.core.channels.voice_ready.utils import validate_voice_license_scope
+from rasa.core.channels.voice_stream.asr.asr_engine import ASREngine
+from rasa.core.channels.voice_stream.asr.asr_event import (
+    ASREvent,
+    NewTranscript,
+    UserIsSpeaking,
+)
+from sanic import Websocket  # type: ignore
+
+from rasa.core.channels.voice_stream.asr.deepgram import DeepgramASR
+from rasa.core.channels.voice_stream.asr.azure import AzureASR
+from rasa.core.channels.voice_stream.audio_bytes import HERTZ, RasaAudioBytes
+from rasa.core.channels.voice_stream.call_state import (
+    CallState,
+    _call_state,
+    call_state,
+)
+from rasa.core.channels.voice_stream.tts.azure import AzureTTS
+from rasa.core.channels.voice_stream.tts.tts_engine import TTSEngine, TTSError
+from rasa.core.channels.voice_stream.tts.cartesia import CartesiaTTS
+from rasa.core.channels.voice_stream.tts.tts_cache import TTSCache
+from rasa.utils.io import remove_emojis
+
+logger = structlog.get_logger(__name__)
+
+
+@dataclass
+class VoiceChannelAction:
+    pass
+
+
+@dataclass
+class NewAudioAction(VoiceChannelAction):
+    audio_bytes: RasaAudioBytes
+
+
+@dataclass
+class EndConversationAction(VoiceChannelAction):
+    pass
+
+
+@dataclass
+class ContinueConversationAction(VoiceChannelAction):
+    pass
+
+
+def asr_engine_from_config(asr_config: Dict) -> ASREngine:
+    name = str(asr_config["name"])
+    asr_config = copy.copy(asr_config)
+    asr_config.pop("name")
+    if name.lower() == "deepgram":
+        return DeepgramASR.from_config_dict(asr_config)
+    if name == "azure":
+        return AzureASR.from_config_dict(asr_config)
+    else:
+        mark_as_beta_feature("Custom ASR Engine")
+        try:
+            asr_engine_class = class_from_module_path(name)
+            return asr_engine_class.from_config_dict(asr_config)
+        except NameError:
+            print_error_and_exit(
+                f"Failed to initialize ASR Engine with type '{name}'. "
+                f"Please make sure the method `from_config_dict`is implemented."
+            )
+        except TypeError as e:
+            print_error_and_exit(
+                f"Failed to initialize ASR Engine with type '{name}'. "
+                f"Invalid configuration provided. "
+                f"Error: {e}"
+            )
+
+
+def tts_engine_from_config(tts_config: Dict) -> TTSEngine:
+    name = str(tts_config["name"])
+    tts_config = copy.copy(tts_config)
+    tts_config.pop("name")
+    if name.lower() == "azure":
+        return AzureTTS.from_config_dict(tts_config)
+    elif name.lower() == "cartesia":
+        return CartesiaTTS.from_config_dict(tts_config)
+    else:
+        mark_as_beta_feature("Custom TTS Engine")
+        try:
+            tts_engine_class = class_from_module_path(name)
+            return tts_engine_class.from_config_dict(tts_config)
+        except NameError:
+            print_error_and_exit(
+                f"Failed to initialize TTS Engine with type '{name}'. "
+                f"Please make sure the method `from_config_dict`is implemented."
+            )
+        except TypeError as e:
+            print_error_and_exit(
+                f"Failed to initialize ASR Engine with type '{name}'. "
+                f"Invalid configuration provided. "
+                f"Error: {e}"
+            )
+
+
+class VoiceOutputChannel(OutputChannel):
+    def __init__(
+        self,
+        voice_websocket: Websocket,
+        tts_engine: TTSEngine,
+        tts_cache: TTSCache,
+    ):
+        super().__init__()
+        self.voice_websocket = voice_websocket
+        self.tts_engine = tts_engine
+        self.tts_cache = tts_cache
+
+        self.latest_message_id: Optional[str] = None
+
+    def rasa_audio_bytes_to_channel_bytes(
+        self, rasa_audio_bytes: RasaAudioBytes
+    ) -> bytes:
+        """Turn rasa's audio byte format into the format for the channel."""
+        raise NotImplementedError
+
+    def channel_bytes_to_message(self, recipient_id: str, channel_bytes: bytes) -> str:
+        """Wrap the bytes for the channel in the proper format."""
+        raise NotImplementedError
+
+    def create_marker_message(self, recipient_id: str) -> Tuple[str, str]:
+        """Create a marker message for a specific channel."""
+        raise NotImplementedError
+
+    async def send_marker_message(self, recipient_id: str) -> None:
+        """Send a message that marks positions in the audio stream."""
+        marker_message, mark_id = self.create_marker_message(recipient_id)
+        await self.voice_websocket.send(marker_message)
+        self.latest_message_id = mark_id
+
+    def update_silence_timeout(self) -> None:
+        """Updates the silence timeout for the session."""
+        if self.tracker_state:
+            call_state.silence_timeout = (  # type: ignore[attr-defined]
+                self.tracker_state["slots"][SLOT_SILENCE_TIMEOUT]
+            )
+
+    async def send_text_with_buttons(
+        self,
+        recipient_id: str,
+        text: str,
+        buttons: List[Dict[str, Any]],
+        **kwargs: Any,
+    ) -> None:
+        """Uses the concise button output format for voice channels."""
+        await self.send_text_with_buttons_concise(recipient_id, text, buttons, **kwargs)
+
+    async def send_text_message(
+        self, recipient_id: str, text: str, **kwargs: Any
+    ) -> None:
+        text = remove_emojis(text)
+        self.update_silence_timeout()
+        cached_audio_bytes = self.tts_cache.get(text)
+        collected_audio_bytes = RasaAudioBytes(b"")
+        seconds_marker = -1
+        if cached_audio_bytes:
+            audio_stream = self.chunk_audio(cached_audio_bytes)
+        else:
+            # Todo: make kwargs compatible with engine config
+            synth_config = self.tts_engine.config.__class__.from_dict({})
+            try:
+                audio_stream = self.tts_engine.synthesize(text, synth_config)
+            except TTSError:
+                # TODO: add message that works without tts, e.g. loading from disc
+                audio_stream = self.chunk_audio(generate_silence())
+
+        async for audio_bytes in audio_stream:
+            try:
+                await self.send_audio_bytes(recipient_id, audio_bytes)
+                full_seconds_of_audio = len(collected_audio_bytes) // HERTZ
+                if full_seconds_of_audio > seconds_marker:
+                    await self.send_marker_message(recipient_id)
+                    seconds_marker = full_seconds_of_audio
+
+            except (WebsocketClosed, ServerError):
+                # ignore sending error, and keep collecting and caching audio bytes
+                call_state.connection_failed = True  # type: ignore[attr-defined]
+            collected_audio_bytes = RasaAudioBytes(collected_audio_bytes + audio_bytes)
+        try:
+            await self.send_marker_message(recipient_id)
+        except (WebsocketClosed, ServerError):
+            # ignore sending error
+            pass
+        call_state.latest_bot_audio_id = self.latest_message_id  # type: ignore[attr-defined]
+
+        if not cached_audio_bytes:
+            self.tts_cache.put(text, collected_audio_bytes)
+
+    async def send_audio_bytes(
+        self, recipient_id: str, audio_bytes: RasaAudioBytes
+    ) -> None:
+        channel_bytes = self.rasa_audio_bytes_to_channel_bytes(audio_bytes)
+        message = self.channel_bytes_to_message(recipient_id, channel_bytes)
+        await self.voice_websocket.send(message)
+
+    async def chunk_audio(
+        self, audio_bytes: RasaAudioBytes, chunk_size: int = 2048
+    ) -> AsyncIterator[RasaAudioBytes]:
+        """Generate chunks from cached audio bytes."""
+        offset = 0
+        while offset < len(audio_bytes):
+            chunk = audio_bytes[offset : offset + chunk_size]
+            if len(chunk):
+                yield RasaAudioBytes(chunk)
+            offset += chunk_size
+        return
+
+    async def hangup(self, recipient_id: str, **kwargs: Any) -> None:
+        call_state.should_hangup = True  # type: ignore[attr-defined]
+
+
+class VoiceInputChannel(InputChannel):
+    def __init__(
+        self,
+        server_url: str,
+        asr_config: Dict,
+        tts_config: Dict,
+        monitor_silence: bool = False,
+    ):
+        validate_voice_license_scope()
+        self.server_url = server_url
+        self.asr_config = asr_config
+        self.tts_config = tts_config
+        self.monitor_silence = monitor_silence
+        self.tts_cache = TTSCache(tts_config.get("cache_size", 1000))
+
+    async def handle_silence_timeout(
+        self,
+        voice_websocket: Websocket,
+        on_new_message: Callable[[UserMessage], Awaitable[Any]],
+        tts_engine: TTSEngine,
+        call_parameters: CallParameters,
+    ) -> None:
+        timeout = call_state.silence_timeout
+        if not timeout:
+            return
+        if not self.monitor_silence:
+            return
+        logger.debug("voice_channel.silence_timeout_watch_started", timeout=timeout)
+        await asyncio.sleep(timeout)
+        logger.debug("voice_channel.silence_timeout_tripped")
+        output_channel = self.create_output_channel(voice_websocket, tts_engine)
+        message = UserMessage(
+            "/silence_timeout",
+            output_channel,
+            call_parameters.stream_id,
+            input_channel=self.name(),
+            metadata=asdict(call_parameters),
+        )
+        await on_new_message(message)
+
+    @staticmethod
+    def _cancel_silence_timeout_watcher() -> None:
+        """Cancels the silent timeout task if it exists."""
+        if call_state.silence_timeout_watcher:
+            logger.debug("voice_channel.cancelling_current_timeout_watcher_task")
+            call_state.silence_timeout_watcher.cancel()
+            call_state.silence_timeout_watcher = None  # type: ignore[attr-defined]
+
+    @classmethod
+    def from_credentials(cls, credentials: Optional[Dict[str, Any]]) -> InputChannel:
+        credentials = credentials or {}
+        return cls(
+            credentials["server_url"],
+            credentials["asr"],
+            credentials["tts"],
+            credentials.get("monitor_silence", False),
+        )
+
+    def channel_bytes_to_rasa_audio_bytes(self, input_bytes: bytes) -> RasaAudioBytes:
+        raise NotImplementedError
+
+    async def collect_call_parameters(
+        self, channel_websocket: Websocket
+    ) -> Optional[CallParameters]:
+        raise NotImplementedError
+
+    async def start_session(
+        self,
+        channel_websocket: Websocket,
+        on_new_message: Callable[[UserMessage], Awaitable[Any]],
+        tts_engine: TTSEngine,
+        call_parameters: CallParameters,
+    ) -> None:
+        output_channel = self.create_output_channel(channel_websocket, tts_engine)
+        message = UserMessage(
+            "/session_start",
+            output_channel,
+            call_parameters.stream_id,
+            input_channel=self.name(),
+            metadata=asdict(call_parameters),
+        )
+        await on_new_message(message)
+
+    def map_input_message(
+        self,
+        message: Any,
+    ) -> VoiceChannelAction:
+        """Map a channel input message to a voice channel action."""
+        raise NotImplementedError
+
+    async def run_audio_streaming(
+        self,
+        on_new_message: Callable[[UserMessage], Awaitable[Any]],
+        channel_websocket: Websocket,
+    ) -> None:
+        """Pipe input audio to ASR and consume ASR events simultaneously."""
+        _call_state.set(CallState())
+        asr_engine = asr_engine_from_config(self.asr_config)
+        tts_engine = tts_engine_from_config(self.tts_config)
+        await asr_engine.connect()
+
+        call_parameters = await self.collect_call_parameters(channel_websocket)
+        if call_parameters is None:
+            raise ValueError("Failed to extract call parameters for call.")
+        await self.start_session(
+            channel_websocket, on_new_message, tts_engine, call_parameters
+        )
+
+        async def consume_audio_bytes() -> None:
+            async for message in channel_websocket:
+                is_bot_speaking_before = call_state.is_bot_speaking
+                channel_action = self.map_input_message(message)
+                is_bot_speaking_after = call_state.is_bot_speaking
+
+                if not is_bot_speaking_before and is_bot_speaking_after:
+                    logger.debug("voice_channel.bot_started_speaking")
+                    # relevant when the bot speaks multiple messages in one turn
+                    self._cancel_silence_timeout_watcher()
+
+                # we just stopped speaking, starting a watcher for silence timeout
+                if is_bot_speaking_before and not is_bot_speaking_after:
+                    logger.debug("voice_channel.bot_stopped_speaking")
+                    self._cancel_silence_timeout_watcher()
+                    call_state.silence_timeout_watcher = (  # type: ignore[attr-defined]
+                        asyncio.create_task(
+                            self.handle_silence_timeout(
+                                channel_websocket,
+                                on_new_message,
+                                tts_engine,
+                                call_parameters,
+                            )
+                        )
+                    )
+                if isinstance(channel_action, NewAudioAction):
+                    await asr_engine.send_audio_chunks(channel_action.audio_bytes)
+                elif isinstance(channel_action, EndConversationAction):
+                    # end stream event came from the other side
+                    break
+
+        async def consume_asr_events() -> None:
+            async for event in asr_engine.stream_asr_events():
+                await self.handle_asr_event(
+                    event,
+                    channel_websocket,
+                    on_new_message,
+                    tts_engine,
+                    call_parameters,
+                )
+
+        audio_forwarding_task = asyncio.create_task(consume_audio_bytes())
+        asr_event_task = asyncio.create_task(consume_asr_events())
+        await asyncio.wait(
+            [audio_forwarding_task, asr_event_task],
+            return_when=asyncio.FIRST_COMPLETED,
+        )
+        if not audio_forwarding_task.done():
+            audio_forwarding_task.cancel()
+        if not asr_event_task.done():
+            asr_event_task.cancel()
+        await tts_engine.close_connection()
+        await asr_engine.close_connection()
+        await channel_websocket.close()
+        self._cancel_silence_timeout_watcher()
+
+    def create_output_channel(
+        self, voice_websocket: Websocket, tts_engine: TTSEngine
+    ) -> VoiceOutputChannel:
+        """Create a matching voice output channel for this voice input channel."""
+        raise NotImplementedError
+
+    async def handle_asr_event(
+        self,
+        e: ASREvent,
+        voice_websocket: Websocket,
+        on_new_message: Callable[[UserMessage], Awaitable[Any]],
+        tts_engine: TTSEngine,
+        call_parameters: CallParameters,
+    ) -> None:
+        """Handle a new event from the ASR system."""
+        if isinstance(e, NewTranscript) and e.text:
+            logger.debug(
+                "VoiceInputChannel.handle_asr_event.new_transcript", transcript=e.text
+            )
+            call_state.is_user_speaking = False  # type: ignore[attr-defined]
+            output_channel = self.create_output_channel(voice_websocket, tts_engine)
+            message = UserMessage(
+                e.text,
+                output_channel,
+                call_parameters.stream_id,
+                input_channel=self.name(),
+                metadata=asdict(call_parameters),
+            )
+            await on_new_message(message)
+        elif isinstance(e, UserIsSpeaking):
+            self._cancel_silence_timeout_watcher()
+            call_state.is_user_speaking = True  # type: ignore[attr-defined]
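
For orientation, the following is a minimal sketch (not part of the diff) of the credentials a concrete VoiceInputChannel subclass would receive. Only the key names mirror what from_credentials() and the two engine factories above actually read; all values are illustrative assumptions.

    # Hypothetical credentials dict; key names taken from from_credentials() above.
    credentials = {
        "server_url": "https://bot.example.com",      # assumed URL
        "asr": {"name": "deepgram"},                   # "deepgram", "azure", or a module path (beta)
        "tts": {"name": "azure", "cache_size": 500},   # cache_size is read by __init__ for TTSCache
        "monitor_silence": True,                       # enables the silence-timeout watcher
    }

Whether the asr/tts blocks accept further engine-specific options depends on the respective from_config_dict implementations, which are not shown in this hunk.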

rasa/core/information_retrieval/qdrant.py

@@ -62,6 +62,7 @@ class Qdrant_Store(InformationRetrieval):
             embeddings=self.embeddings,
             content_payload_key=params.get("content_payload_key", "text"),
             metadata_payload_key=params.get("metadata_payload_key", "metadata"),
+            vector_name=params.get("vector_name", None),
         )
 
     async def search(
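
The single added line suggests that a named Qdrant vector can now be selected via a vector_name entry in the connection parameters. A hedged sketch of such a params dict; the other keys are taken from the params.get(...) calls visible in the hunk, and all values are illustrative:

    params = {
        "content_payload_key": "text",
        "metadata_payload_key": "metadata",
        "vector_name": "dense",  # assumed name of a named vector in the Qdrant collection
    }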

rasa/core/nlg/contextual_response_rephraser.py

@@ -2,9 +2,9 @@ from typing import Any, Dict, Optional, Text
 
 import structlog
 from jinja2 import Template
-
 from rasa import telemetry
 from rasa.core.nlg.response import TemplatedNaturalLanguageGenerator
+from rasa.core.nlg.summarize import summarize_conversation
 from rasa.shared.constants import (
     LLM_CONFIG_KEY,
     MODEL_CONFIG_KEY,
@@ -13,10 +13,12 @@ from rasa.shared.constants import (
     PROVIDER_CONFIG_KEY,
     OPENAI_PROVIDER,
     TIMEOUT_CONFIG_KEY,
+    MODEL_GROUP_ID_CONFIG_KEY,
 )
 from rasa.shared.core.domain import KEY_RESPONSES_TEXT, Domain
 from rasa.shared.core.events import BotUttered, UserUttered
 from rasa.shared.core.trackers import DialogueStateTracker
+from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
 from rasa.shared.utils.llm import (
     DEFAULT_OPENAI_GENERATE_MODEL_NAME,
     DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
@@ -24,12 +26,12 @@ from rasa.shared.utils.llm import (
     combine_custom_and_default_config,
     get_prompt_template,
     llm_factory,
-
+    resolve_model_client_config,
+)
+from rasa.shared.utils.llm import (
+    tracker_as_readable_transcript,
 )
 from rasa.utils.endpoints import EndpointConfig
-
-from rasa.core.nlg.summarize import summarize_conversation
-
 from rasa.utils.log_utils import log_llm
 
 structlogger = structlog.get_logger()
@@ -38,7 +40,11 @@ RESPONSE_REPHRASING_KEY = "rephrase"
 
 RESPONSE_REPHRASING_TEMPLATE_KEY = "rephrase_prompt"
 
+RESPONSE_SUMMARISE_CONVERSATION_KEY = "summarize_conversation"
+
 DEFAULT_REPHRASE_ALL = False
+DEFAULT_SUMMARIZE_HISTORY = True
+DEFAULT_MAX_HISTORICAL_TURNS = 5
 
 DEFAULT_LLM_CONFIG = {
     PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
@@ -63,7 +69,9 @@ Suggested AI Response: {{suggested_response}}
 Rephrased AI Response:"""
 
 
-class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
+class ContextualResponseRephraser(
+    LLMHealthCheckMixin, TemplatedNaturalLanguageGenerator
+):
     """Generates responses based on modified templates.
 
     The templates are filled with the entities and slots that are available in the
@@ -97,11 +105,23 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
         self.trace_prompt_tokens = self.nlg_endpoint.kwargs.get(
             "trace_prompt_tokens", False
         )
-
+        self.summarize_history = self.nlg_endpoint.kwargs.get(
+            "summarize_history", DEFAULT_SUMMARIZE_HISTORY
+        )
+        self.max_historical_turns = self.nlg_endpoint.kwargs.get(
+            "max_historical_turns", DEFAULT_MAX_HISTORICAL_TURNS
+        )
+
+        self.llm_config = resolve_model_client_config(
             self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY),
+            ContextualResponseRephraser.__name__,
+        )
+
+        self.perform_llm_health_check(
+            self.llm_config,
             DEFAULT_LLM_CONFIG,
             "contextual_response_rephraser.init",
-
+            ContextualResponseRephraser.__name__,
         )
 
     def _last_message_if_human(self, tracker: DialogueStateTracker) -> Optional[str]:
@@ -131,9 +151,7 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
         Returns:
             generated text
         """
-        llm = llm_factory(
-            self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG
-        )
+        llm = llm_factory(self.llm_config, DEFAULT_LLM_CONFIG)
 
         try:
             llm_response = await llm.acompletion(prompt)
@@ -147,7 +165,7 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
     def llm_property(self, prop: str) -> Optional[str]:
         """Returns a property of the LLM provider."""
         return combine_custom_and_default_config(
-            self.
+            self.llm_config, DEFAULT_LLM_CONFIG
         ).get(prop)
 
     def custom_prompt_template(self, prompt_template: str) -> Optional[str]:
@@ -180,9 +198,7 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
         Returns:
             The history for the prompt.
         """
-        llm = llm_factory(
-            self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG
-        )
+        llm = llm_factory(self.llm_config, DEFAULT_LLM_CONFIG)
         return await summarize_conversation(tracker, llm, max_turns=5)
 
     async def rephrase(
@@ -203,13 +219,24 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
         if not (response_text := response.get(KEY_RESPONSES_TEXT)):
             return response
 
+        prompt_template_text = self._template_for_response_rephrasing(response)
+
+        # Retrieve inputs for the dynamic prompt
         latest_message = self._last_message_if_human(tracker)
         current_input = f"{USER}: {latest_message}" if latest_message else ""
 
-
+        # Only summarise conversation history if flagged
+        if self.summarize_history:
+            history = await self._create_history(tracker)
+        else:
+            # make sure the transcript/history contains the last user utterance
+            max_turns = max(self.max_historical_turns, 1)
+            history = tracker_as_readable_transcript(tracker, max_turns=max_turns)
+            # the history already contains the current input
+            current_input = ""
 
         prompt = Template(prompt_template_text).render(
-            history=
+            history=history,
             suggested_response=response_text,
             current_input=current_input,
             slots=tracker.current_slot_values(),
@@ -226,6 +253,7 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
             llm_type=self.llm_property(PROVIDER_CONFIG_KEY),
             llm_model=self.llm_property(MODEL_CONFIG_KEY)
             or self.llm_property(MODEL_NAME_CONFIG_KEY),
+            llm_model_group_id=self.llm_property(MODEL_GROUP_ID_CONFIG_KEY),
         )
         if not (updated_text := await self._generate_llm_response(prompt)):
             # If the LLM fails to generate a response, we
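
Taken together, these hunks let the rephraser either summarize the conversation history with the LLM (the default, per DEFAULT_SUMMARIZE_HISTORY = True) or fall back to a plain transcript of the last few turns. A hedged sketch of how the two new endpoint kwargs could be supplied programmatically; only the key names summarize_history and max_historical_turns come from the diff, the rest (including type="rephrase") is an assumption for illustration:

    from rasa.utils.endpoints import EndpointConfig

    # Extra keyword arguments are kept in EndpointConfig.kwargs, which is where
    # ContextualResponseRephraser.__init__ reads the two new settings from.
    nlg_endpoint = EndpointConfig(
        type="rephrase",          # assumed; selects the rephrasing NLG endpoint
        summarize_history=False,  # use a readable transcript instead of an LLM summary
        max_historical_turns=3,   # turns included in that transcript (a minimum of 1 is enforced)
    )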
|