rasa-pro 3.9.18__py3-none-any.whl → 3.10.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +0 -374
- rasa/__init__.py +1 -2
- rasa/__main__.py +5 -0
- rasa/anonymization/anonymization_rule_executor.py +2 -2
- rasa/api.py +27 -23
- rasa/cli/arguments/data.py +27 -2
- rasa/cli/arguments/default_arguments.py +25 -3
- rasa/cli/arguments/run.py +9 -9
- rasa/cli/arguments/train.py +11 -3
- rasa/cli/data.py +70 -8
- rasa/cli/e2e_test.py +104 -431
- rasa/cli/evaluate.py +1 -1
- rasa/cli/interactive.py +1 -0
- rasa/cli/llm_fine_tuning.py +398 -0
- rasa/cli/project_templates/calm/endpoints.yml +1 -1
- rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
- rasa/cli/run.py +15 -14
- rasa/cli/scaffold.py +10 -8
- rasa/cli/studio/studio.py +35 -5
- rasa/cli/train.py +56 -8
- rasa/cli/utils.py +22 -5
- rasa/cli/x.py +1 -1
- rasa/constants.py +7 -1
- rasa/core/actions/action.py +98 -49
- rasa/core/actions/action_run_slot_rejections.py +4 -1
- rasa/core/actions/custom_action_executor.py +9 -6
- rasa/core/actions/direct_custom_actions_executor.py +80 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
- rasa/core/actions/grpc_custom_action_executor.py +2 -2
- rasa/core/actions/http_custom_action_executor.py +6 -5
- rasa/core/agent.py +21 -17
- rasa/core/channels/__init__.py +2 -0
- rasa/core/channels/audiocodes.py +1 -16
- rasa/core/channels/voice_aware/__init__.py +0 -0
- rasa/core/channels/voice_aware/jambonz.py +103 -0
- rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
- rasa/core/channels/voice_aware/utils.py +20 -0
- rasa/core/channels/voice_native/__init__.py +0 -0
- rasa/core/constants.py +6 -1
- rasa/core/information_retrieval/faiss.py +7 -4
- rasa/core/information_retrieval/information_retrieval.py +8 -0
- rasa/core/information_retrieval/milvus.py +9 -2
- rasa/core/information_retrieval/qdrant.py +1 -1
- rasa/core/nlg/contextual_response_rephraser.py +32 -10
- rasa/core/nlg/summarize.py +4 -3
- rasa/core/policies/enterprise_search_policy.py +113 -45
- rasa/core/policies/flows/flow_executor.py +122 -76
- rasa/core/policies/intentless_policy.py +83 -29
- rasa/core/processor.py +72 -54
- rasa/core/run.py +5 -4
- rasa/core/tracker_store.py +8 -4
- rasa/core/training/interactive.py +1 -1
- rasa/core/utils.py +56 -57
- rasa/dialogue_understanding/coexistence/llm_based_router.py +53 -13
- rasa/dialogue_understanding/commands/__init__.py +6 -0
- rasa/dialogue_understanding/commands/restart_command.py +58 -0
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +40 -0
- rasa/dialogue_understanding/generator/constants.py +10 -3
- rasa/dialogue_understanding/generator/flow_retrieval.py +21 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +13 -3
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +134 -90
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +47 -7
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +127 -41
- rasa/dialogue_understanding/patterns/restart.py +37 -0
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +16 -3
- rasa/dialogue_understanding/processor/command_processor_component.py +6 -2
- rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
- rasa/e2e_test/assertions.py +1223 -0
- rasa/e2e_test/assertions_schema.yml +106 -0
- rasa/e2e_test/constants.py +20 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +131 -8
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +26 -6
- rasa/e2e_test/e2e_test_runner.py +493 -71
- rasa/e2e_test/e2e_test_schema.yml +96 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +598 -0
- rasa/e2e_test/utils/validation.py +80 -0
- rasa/engine/graph.py +9 -3
- rasa/engine/recipes/default_components.py +0 -2
- rasa/engine/recipes/default_recipe.py +10 -2
- rasa/engine/storage/local_model_storage.py +40 -12
- rasa/engine/validation.py +78 -1
- rasa/env.py +9 -0
- rasa/graph_components/providers/story_graph_provider.py +59 -6
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/model_training.py +56 -16
- rasa/nlu/persistor.py +157 -36
- rasa/server.py +45 -10
- rasa/shared/constants.py +76 -16
- rasa/shared/core/domain.py +27 -19
- rasa/shared/core/events.py +28 -2
- rasa/shared/core/flows/flow.py +208 -13
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flows_list.py +33 -11
- rasa/shared/core/flows/flows_yaml_schema.json +269 -193
- rasa/shared/core/flows/validation.py +112 -25
- rasa/shared/core/flows/yaml_flows_io.py +149 -10
- rasa/shared/core/trackers.py +6 -0
- rasa/shared/core/training_data/structures.py +20 -0
- rasa/shared/core/training_data/visualization.html +2 -2
- rasa/shared/exceptions.py +4 -0
- rasa/shared/importers/importer.py +64 -16
- rasa/shared/nlu/constants.py +2 -0
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +183 -0
- rasa/shared/providers/_configs/client_config.py +57 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
- rasa/shared/providers/_configs/openai_client_config.py +175 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +176 -0
- rasa/shared/providers/_configs/utils.py +101 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +259 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +251 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
- rasa/shared/providers/llm/llm_client.py +76 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +155 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +293 -0
- rasa/shared/providers/mappings.py +75 -0
- rasa/shared/utils/cli.py +30 -0
- rasa/shared/utils/io.py +65 -2
- rasa/shared/utils/llm.py +246 -200
- rasa/shared/utils/yaml.py +121 -15
- rasa/studio/auth.py +6 -4
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/download.py +19 -13
- rasa/studio/train.py +2 -3
- rasa/studio/upload.py +19 -11
- rasa/telemetry.py +113 -58
- rasa/tracing/instrumentation/attribute_extractors.py +32 -17
- rasa/utils/common.py +18 -19
- rasa/utils/endpoints.py +7 -4
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +9 -1
- rasa/utils/ml_utils.py +4 -2
- rasa/validator.py +213 -3
- rasa/version.py +1 -1
- rasa_pro-3.10.16.dist-info/METADATA +196 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/RECORD +179 -113
- rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
- rasa/shared/providers/openai/clients.py +0 -43
- rasa/shared/providers/openai/session_handler.py +0 -110
- rasa_pro-3.9.18.dist-info/METADATA +0 -563
- /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
- /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/NOTICE +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/WHEEL +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,344 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
import json
|
|
3
|
+
import uuid
|
|
4
|
+
from typing import Any, Awaitable, Callable, Dict, List, Text
|
|
5
|
+
|
|
6
|
+
import structlog
|
|
7
|
+
from rasa.core.channels.channel import UserMessage
|
|
8
|
+
from sanic import Websocket # type: ignore[attr-defined]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# module-level structlog logger used by the message handlers in this module
structlogger = structlog.get_logger()
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class NewSessionMessage:
    """Notification from jambonz that a new call session has been opened."""

    # identifier of the call, taken from the payload's "call_sid" key
    call_sid: str
    # identifier of the jambonz message, taken from the payload's "msgid" key
    message_id: str

    @staticmethod
    def from_message(message: Dict[str, Any]) -> "NewSessionMessage":
        """Build an instance from a decoded ``session:new`` payload.

        Keys that are absent from the payload produce ``None`` fields.
        """
        sid = message.get("call_sid")
        msg_id = message.get("msgid")
        return NewSessionMessage(call_sid=sid, message_id=msg_id)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class Transcript:
    """One candidate transcription of a caller utterance."""

    # the recognized text
    text: str
    # confidence reported for this candidate
    confidence: float


@dataclass
class TranscriptResult:
    """Outcome of a gather call, holding zero or more candidate transcripts.

    Instances are built either from a speech-recognition payload or from a
    DTMF (keypad digits) payload.
    """

    call_sid: str
    message_id: str
    # False when the payload marks the speech result as not final
    is_final: bool
    transcripts: List[Transcript] = field(default_factory=list)

    @staticmethod
    def from_speech_result(message: Dict[str, Any]) -> "TranscriptResult":
        """Parse a speech gather payload.

        ``data.speech.is_final`` defaults to ``True``. Every entry of
        ``data.speech.alternatives`` becomes a ``Transcript`` in payload order.
        A missing ``transcript`` key yields ``""`` and a missing
        ``confidence`` key yields ``1.0``.
        """
        speech = message.get("data", {}).get("speech", {})
        candidates = [
            Transcript(alt.get("transcript", ""), alt.get("confidence", 1.0))
            for alt in speech.get("alternatives", [])
        ]
        return TranscriptResult(
            message.get("call_sid"),
            message.get("msgid"),
            speech.get("is_final", True),
            transcripts=candidates,
        )

    @staticmethod
    def from_dtmf_result(message: Dict[str, Any]) -> "TranscriptResult":
        """Parse a DTMF gather payload.

        The value of ``data.digits`` is converted to ``str`` and used as the
        single transcript, with confidence 1.0. The result is always final.
        """
        digits = message.get("data", {}).get("digits", "")
        return TranscriptResult(
            message.get("call_sid"),
            message.get("msgid"),
            is_final=True,
            transcripts=[Transcript(str(digits), 1.0)],
        )
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@dataclass
class CallStatusChanged:
    """Notification from jambonz that the status of a call has changed."""

    call_sid: str
    # new status value, read from the payload's "data.call_status" key
    status: str

    @staticmethod
    def from_message(message: Dict[str, Any]) -> "CallStatusChanged":
        """Build an instance from a decoded ``call:status`` payload.

        Missing keys produce ``None`` fields.
        """
        new_status = message.get("data", {}).get("call_status")
        return CallStatusChanged(message.get("call_sid"), new_status)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@dataclass
class SessionReconnect:
    """Notification from jambonz that an existing session reconnected."""

    call_sid: str

    @staticmethod
    def from_message(message: Dict[str, Any]) -> "SessionReconnect":
        """Build an instance from a decoded ``session:reconnect`` payload."""
        sid = message.get("call_sid")
        return SessionReconnect(sid)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@dataclass
class VerbStatusChanged:
    """Notification from jambonz that the status of a verb has changed."""

    call_sid: str
    # the following fields are read from the payload's "data" object
    event: str
    id: str
    name: str

    @staticmethod
    def from_message(message: Dict[str, Any]) -> "VerbStatusChanged":
        """Build an instance from a decoded ``verb:status`` payload.

        Missing keys produce ``None`` fields.
        """
        data = message.get("data", {})
        return VerbStatusChanged(
            message.get("call_sid"),
            data.get("event"),
            data.get("id"),
            data.get("name"),
        )
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@dataclass
class GatherTimeout:
    """Notification from jambonz that a gather timed out."""

    call_sid: str

    @staticmethod
    def from_message(message: Dict[str, Any]) -> "GatherTimeout":
        """Build an instance from a gather hook payload with reason ``timeout``."""
        sid = message.get("call_sid")
        return GatherTimeout(sid)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
async def websocket_message_handler(
    message_dump: str,
    on_new_message: Callable[[UserMessage], Awaitable[Any]],
    ws: Websocket,
) -> None:
    """Decode one raw websocket message from jambonz and route it to its handler.

    Args:
        message_dump: JSON-encoded message text received on the websocket.
        on_new_message: Coroutine callback that receives the ``UserMessage``
            objects built for new sessions and completed gathers.
        ws: Websocket used to send commands back to jambonz.

    Raises:
        json.JSONDecodeError: If ``message_dump`` is not valid JSON. This
            error is not caught here.
    """
    message = json.loads(message_dump)
    message_type = message.get("type")

    if message_type == "session:new":
        await handle_new_session(
            NewSessionMessage.from_message(message), on_new_message, ws
        )
    elif message_type == "session:reconnect":
        await handle_session_reconnect(SessionReconnect.from_message(message))
    elif message_type == "call:status":
        await handle_call_status(CallStatusChanged.from_message(message))
    elif message_type == "verb:hook" and message.get("hook") == "/gather":
        reason = message.get("data", {}).get("reason")
        if reason == "speechDetected":
            await handle_gather_completed(
                TranscriptResult.from_speech_result(message), on_new_message, ws
            )
        elif reason == "timeout":
            await handle_gather_timeout(GatherTimeout.from_message(message), ws)
        elif reason == "dtmfDetected":
            # keypad input is treated like ordinary user input, confidence 1.0
            await handle_gather_completed(
                TranscriptResult.from_dtmf_result(message), on_new_message, ws
            )
        else:
            structlogger.debug(
                "jambonz.websocket.message.verb_hook",
                call_sid=message.get("call_sid"),
                reason=reason,
                message=message,
            )
    elif message_type == "verb:status":
        await handle_verb_status(VerbStatusChanged.from_message(message))
    elif message_type == "jambonz:error":
        # jambonz ran into a fatal error handling the call; jambonz itself
        # terminates the call afterwards, we only log it
        structlogger.error("jambonz.websocket.message.error", message=message)
    else:
        structlogger.warning("jambonz.websocket.message.unknown_type", message=message)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
async def handle_new_session(
    message: NewSessionMessage,
    on_new_message: Callable[[UserMessage], Awaitable[Any]],
    ws: Websocket,
) -> None:
    """Start the conversation for a newly opened jambonz call session.

    Sends an ack for the session message, feeds a ``/session_start`` user
    message to ``on_new_message``, then queues a gather command.
    """
    # local import, presumably to avoid a circular import with the jambonz
    # channel module (TODO confirm)
    from rasa.core.channels.voice_aware.jambonz import JambonzWebsocketOutput

    call_sid = message.call_sid
    structlogger.debug("jambonz.websocket.message.new_call", call_sid=call_sid)
    session_start = UserMessage(
        text="/session_start",
        output_channel=JambonzWebsocketOutput(ws, call_sid),
        sender_id=call_sid,
        metadata={},
    )
    await send_config_ack(message.message_id, ws)
    await on_new_message(session_start)
    await send_gather_input(ws)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
async def handle_gather_completed(
    transcript_result: TranscriptResult,
    on_new_message: Callable[[UserMessage], Awaitable[Any]],
    ws: Websocket,
) -> None:
    """Handle the result of a gather command we have sent to jambonz.

    Non-final results are logged and ignored without queueing a new gather.
    For a final result, the first transcript is taken as the most likely one
    and forwarded to ``on_new_message``. If there are no transcripts, a
    warning is logged instead. In both final cases a new gather command is
    queued afterwards.
    """
    from rasa.core.channels.voice_aware.jambonz import JambonzWebsocketOutput

    call_sid = transcript_result.call_sid
    transcripts = transcript_result.transcripts

    if not transcript_result.is_final:
        # wait for the final transcript; partial ones are only logged
        structlogger.debug(
            "jambonz.websocket.message.transcript_partial",
            call_sid=call_sid,
            number_of_transcripts=len(transcripts),
        )
        return

    if not transcripts:
        structlogger.warning(
            "jambonz.websocket.message.no_transcript",
            call_sid=call_sid,
        )
    else:
        best = transcripts[0]
        user_msg = UserMessage(
            text=best.text,
            output_channel=JambonzWebsocketOutput(ws, call_sid),
            sender_id=call_sid,
            metadata={},
        )
        structlogger.debug(
            "jambonz.websocket.message.transcript",
            call_sid=call_sid,
            transcript=best.text,
            confidence=best.confidence,
            number_of_transcripts=len(transcripts),
        )
        await on_new_message(user_msg)

    await send_gather_input(ws)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
async def handle_gather_timeout(gather_timeout: GatherTimeout, ws: Websocket) -> None:
    """React to a gather that timed out.

    Sends a fixed apology text to jambonz and queues a new gather command.
    """
    structlogger.debug(
        "jambonz.websocket.message.gather_timeout",
        call_sid=gather_timeout.call_sid,
    )
    # TODO: figure out how to handle timeouts
    apology = "I'm sorry, I didn't catch that."
    await send_ws_text_message(ws, apology)
    await send_gather_input(ws)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
async def handle_call_status(call_status: CallStatusChanged) -> None:
    """Log a call status change reported by jambonz; nothing else is done."""
    sid, status = call_status.call_sid, call_status.status
    structlogger.debug(
        "jambonz.websocket.message.call_status_changed",
        call_sid=sid,
        message=status,
    )
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
async def handle_session_reconnect(session_reconnect: SessionReconnect) -> None:
    """Log a session reconnect reported by jambonz.

    Nothing else needs to happen at the moment. A reconnect occurs when
    jambonz loses the websocket connection and connects again.
    """
    sid = session_reconnect.call_sid
    structlogger.debug("jambonz.websocket.message.session_reconnect", call_sid=sid)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
async def handle_verb_status(verb_status: VerbStatusChanged) -> None:
    """Log a verb status change reported by jambonz; nothing else is done."""
    vs = verb_status
    structlogger.debug(
        "jambonz.websocket.message.verb_status_changed",
        call_sid=vs.call_sid,
        event_type=vs.event,
        id=vs.id,
        name=vs.name,
    )
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
async def send_config_ack(message_id: str, ws: Websocket) -> None:
    """Acknowledge a jambonz message and send the session configuration.

    Args:
        message_id: The ``msgid`` of the message being acknowledged.
        ws: Websocket the JSON-encoded ack is sent on.
    """
    ack = {
        "type": "ack",
        "msgid": message_id,
        # configuration sent with the ack: enable event notifications
        "data": [{"config": {"notifyEvents": True}}],
    }
    await ws.send(json.dumps(ack))
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
async def send_gather_input(ws: Websocket) -> None:
    """Queue a jambonz ``gather`` command accepting speech or digit input.

    Results are delivered to the ``/gather`` action hook.
    """
    gather_verb = {
        "input": ["speech", "digits"],
        "minDigits": 1,
        # a fresh id for every gather command
        "id": uuid.uuid4().hex,
        "actionHook": "/gather",
    }
    command = {
        "type": "command",
        "command": "redirect",
        "queueCommand": True,
        "data": [{"gather": gather_verb}],
    }
    await ws.send(json.dumps(command))
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
async def send_ws_text_message(ws: Websocket, text: Text) -> None:
    """Queue a jambonz ``say`` command carrying ``text`` on the websocket."""
    say_verb = {
        # id can be used for status notifications
        "id": uuid.uuid4().hex,
        "text": text,
    }
    command = {
        "type": "command",
        "command": "redirect",
        "queueCommand": True,
        "data": [{"say": say_verb}],
    }
    await ws.send(json.dumps(command))
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import structlog
|
|
2
|
+
|
|
3
|
+
from rasa.utils.licensing import (
|
|
4
|
+
PRODUCT_AREA,
|
|
5
|
+
VOICE_SCOPE,
|
|
6
|
+
validate_license_from_env,
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
# module-level structlog logger used by the license validation below
structlogger = structlog.get_logger()
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def validate_voice_license_scope() -> None:
    """Validate that the license from the environment includes the voice scope.

    The product area passed to ``validate_license_from_env`` is
    ``PRODUCT_AREA`` and ``VOICE_SCOPE`` joined by a single space.
    """
    structlogger.info(
        "Validating current Rasa Pro license scope which must include "
        f"the '{VOICE_SCOPE}' scope to use the voice channel."
    )
    required_scope = f"{PRODUCT_AREA} {VOICE_SCOPE}"
    validate_license_from_env(product_area=required_scope)
|
|
File without changes
|
rasa/core/constants.py
CHANGED
|
@@ -61,7 +61,6 @@ SEARCH_POLICY_PRIORITY = CHAT_POLICY_PRIORITY + 1
|
|
|
61
61
|
# flow policy priority
|
|
62
62
|
FLOW_POLICY_PRIORITY = SEARCH_POLICY_PRIORITY + 1
|
|
63
63
|
|
|
64
|
-
|
|
65
64
|
DIALOGUE = "dialogue"
|
|
66
65
|
|
|
67
66
|
# RabbitMQ message property header added to events published using `rasa export`
|
|
@@ -105,3 +104,9 @@ DEFAULT_TEMPLATE_ENGINE = RASA_FORMAT_TEMPLATE_ENGINE
|
|
|
105
104
|
# configuration parameter used to specify the template engine to use
|
|
106
105
|
# for a response
|
|
107
106
|
TEMPLATE_ENGINE_CONFIG_KEY = "template"
|
|
107
|
+
|
|
108
|
+
# metadata keys for bot utterance events
|
|
109
|
+
UTTER_SOURCE_METADATA_KEY = "utter_source"
|
|
110
|
+
DOMAIN_GROUND_TRUTH_METADATA_KEY = "domain_ground_truth"
|
|
111
|
+
ACTIVE_FLOW_METADATA_KEY = "active_flow"
|
|
112
|
+
STEP_ID_METADATA_KEY = "step_id"
|
|
@@ -2,9 +2,10 @@ from pathlib import Path
|
|
|
2
2
|
from typing import TYPE_CHECKING, List, Optional, Text, Any, Dict
|
|
3
3
|
|
|
4
4
|
import structlog
|
|
5
|
-
from langchain.document_loaders import DirectoryLoader, TextLoader
|
|
6
5
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
7
|
-
from
|
|
6
|
+
from langchain_community.document_loaders.text import TextLoader
|
|
7
|
+
from langchain_community.document_loaders.directory import DirectoryLoader
|
|
8
|
+
from langchain_community.vectorstores.faiss import FAISS
|
|
8
9
|
from rasa.utils.endpoints import EndpointConfig
|
|
9
10
|
|
|
10
11
|
from rasa.core.information_retrieval import (
|
|
@@ -46,7 +47,9 @@ class FAISS_Store(InformationRetrieval):
|
|
|
46
47
|
logger.info(
|
|
47
48
|
"information_retrieval.faiss_store.load_index", path=path.absolute()
|
|
48
49
|
)
|
|
49
|
-
self.index = FAISS.load_local(
|
|
50
|
+
self.index = FAISS.load_local(
|
|
51
|
+
str(path), embeddings, allow_dangerous_deserialization=True
|
|
52
|
+
)
|
|
50
53
|
|
|
51
54
|
@staticmethod
|
|
52
55
|
def load_documents(docs_folder: str) -> List["Document"]:
|
|
@@ -114,7 +117,7 @@ class FAISS_Store(InformationRetrieval):
|
|
|
114
117
|
) -> SearchResultList:
|
|
115
118
|
logger.debug("information_retrieval.faiss_store.search", query=query)
|
|
116
119
|
try:
|
|
117
|
-
documents = await self.index.as_retriever().
|
|
120
|
+
documents = await self.index.as_retriever().ainvoke(query)
|
|
118
121
|
except Exception as exc:
|
|
119
122
|
raise InformationRetrievalException from exc
|
|
120
123
|
|
|
@@ -19,6 +19,14 @@ logger = structlog.get_logger()
|
|
|
19
19
|
|
|
20
20
|
@dataclass
|
|
21
21
|
class SearchResult:
|
|
22
|
+
"""A search result object.
|
|
23
|
+
|
|
24
|
+
Attributes:
|
|
25
|
+
text: The text content of the retrieved document result.
|
|
26
|
+
metadata: The metadata associated with the document result.
|
|
27
|
+
score: The score of the search result.
|
|
28
|
+
"""
|
|
29
|
+
|
|
22
30
|
text: str
|
|
23
31
|
metadata: dict
|
|
24
32
|
score: Optional[float] = None
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from typing import Text, Any, Dict
|
|
2
2
|
|
|
3
3
|
import structlog
|
|
4
|
-
from
|
|
4
|
+
from langchain_community.vectorstores.milvus import Milvus
|
|
5
5
|
from rasa.utils.endpoints import EndpointConfig
|
|
6
6
|
|
|
7
7
|
from rasa.core.information_retrieval import (
|
|
@@ -48,5 +48,12 @@ class Milvus_Store(InformationRetrieval):
|
|
|
48
48
|
except Exception as exc:
|
|
49
49
|
raise InformationRetrievalException from exc
|
|
50
50
|
|
|
51
|
-
|
|
51
|
+
scores = [score for _, score in hits]
|
|
52
|
+
logger.debug(
|
|
53
|
+
"information_retrieval.milvus_store.search_results_before_threshold",
|
|
54
|
+
scores=scores,
|
|
55
|
+
)
|
|
56
|
+
# Milvus uses Euclidean distance metric by default
|
|
57
|
+
# so the lower the score, the better the match.
|
|
58
|
+
filtered_hits = [doc for doc, score in hits if score <= threshold]
|
|
52
59
|
return SearchResultList.from_document_list(filtered_hits)
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from typing import Text, Any, Dict
|
|
2
2
|
|
|
3
3
|
import structlog
|
|
4
|
-
from
|
|
4
|
+
from langchain_community.vectorstores.qdrant import Qdrant
|
|
5
5
|
from pydantic import ValidationError
|
|
6
6
|
from qdrant_client import QdrantClient
|
|
7
7
|
from rasa.utils.endpoints import EndpointConfig
|
|
@@ -5,6 +5,15 @@ from jinja2 import Template
|
|
|
5
5
|
|
|
6
6
|
from rasa import telemetry
|
|
7
7
|
from rasa.core.nlg.response import TemplatedNaturalLanguageGenerator
|
|
8
|
+
from rasa.shared.constants import (
|
|
9
|
+
LLM_CONFIG_KEY,
|
|
10
|
+
MODEL_CONFIG_KEY,
|
|
11
|
+
MODEL_NAME_CONFIG_KEY,
|
|
12
|
+
PROMPT_CONFIG_KEY,
|
|
13
|
+
PROVIDER_CONFIG_KEY,
|
|
14
|
+
OPENAI_PROVIDER,
|
|
15
|
+
TIMEOUT_CONFIG_KEY,
|
|
16
|
+
)
|
|
8
17
|
from rasa.shared.core.domain import KEY_RESPONSES_TEXT, Domain
|
|
9
18
|
from rasa.shared.core.events import BotUttered, UserUttered
|
|
10
19
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
@@ -15,6 +24,7 @@ from rasa.shared.utils.llm import (
|
|
|
15
24
|
combine_custom_and_default_config,
|
|
16
25
|
get_prompt_template,
|
|
17
26
|
llm_factory,
|
|
27
|
+
try_instantiate_llm_client,
|
|
18
28
|
)
|
|
19
29
|
from rasa.utils.endpoints import EndpointConfig
|
|
20
30
|
|
|
@@ -31,11 +41,11 @@ RESPONSE_REPHRASING_TEMPLATE_KEY = "rephrase_prompt"
|
|
|
31
41
|
DEFAULT_REPHRASE_ALL = False
|
|
32
42
|
|
|
33
43
|
DEFAULT_LLM_CONFIG = {
|
|
34
|
-
|
|
35
|
-
|
|
44
|
+
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
45
|
+
MODEL_CONFIG_KEY: DEFAULT_OPENAI_GENERATE_MODEL_NAME,
|
|
36
46
|
"temperature": 0.3,
|
|
37
|
-
"model_name": DEFAULT_OPENAI_GENERATE_MODEL_NAME,
|
|
38
47
|
"max_tokens": DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
|
|
48
|
+
TIMEOUT_CONFIG_KEY: 5,
|
|
39
49
|
}
|
|
40
50
|
|
|
41
51
|
DEFAULT_RESPONSE_VARIATION_PROMPT_TEMPLATE = """The following is a conversation with
|
|
@@ -78,7 +88,7 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
|
78
88
|
|
|
79
89
|
self.nlg_endpoint = endpoint_config
|
|
80
90
|
self.prompt_template = get_prompt_template(
|
|
81
|
-
self.nlg_endpoint.kwargs.get(
|
|
91
|
+
self.nlg_endpoint.kwargs.get(PROMPT_CONFIG_KEY),
|
|
82
92
|
DEFAULT_RESPONSE_VARIATION_PROMPT_TEMPLATE,
|
|
83
93
|
)
|
|
84
94
|
self.rephrase_all = self.nlg_endpoint.kwargs.get(
|
|
@@ -87,6 +97,12 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
|
87
97
|
self.trace_prompt_tokens = self.nlg_endpoint.kwargs.get(
|
|
88
98
|
"trace_prompt_tokens", False
|
|
89
99
|
)
|
|
100
|
+
try_instantiate_llm_client(
|
|
101
|
+
self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY),
|
|
102
|
+
DEFAULT_LLM_CONFIG,
|
|
103
|
+
"contextual_response_rephraser.init",
|
|
104
|
+
"ContextualResponseRephraser",
|
|
105
|
+
)
|
|
90
106
|
|
|
91
107
|
def _last_message_if_human(self, tracker: DialogueStateTracker) -> Optional[str]:
|
|
92
108
|
"""Returns the latest message from the tracker.
|
|
@@ -115,10 +131,13 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
|
115
131
|
Returns:
|
|
116
132
|
generated text
|
|
117
133
|
"""
|
|
118
|
-
llm = llm_factory(
|
|
134
|
+
llm = llm_factory(
|
|
135
|
+
self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG
|
|
136
|
+
)
|
|
119
137
|
|
|
120
138
|
try:
|
|
121
|
-
|
|
139
|
+
llm_response = await llm.acompletion(prompt)
|
|
140
|
+
return llm_response.choices[0]
|
|
122
141
|
except Exception as e:
|
|
123
142
|
# unfortunately, langchain does not wrap LLM exceptions which means
|
|
124
143
|
# we have to catch all exceptions here
|
|
@@ -128,7 +147,7 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
|
128
147
|
def llm_property(self, prop: str) -> Optional[str]:
|
|
129
148
|
"""Returns a property of the LLM provider."""
|
|
130
149
|
return combine_custom_and_default_config(
|
|
131
|
-
self.nlg_endpoint.kwargs.get(
|
|
150
|
+
self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG
|
|
132
151
|
).get(prop)
|
|
133
152
|
|
|
134
153
|
def custom_prompt_template(self, prompt_template: str) -> Optional[str]:
|
|
@@ -161,7 +180,9 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
|
161
180
|
Returns:
|
|
162
181
|
The history for the prompt.
|
|
163
182
|
"""
|
|
164
|
-
llm = llm_factory(
|
|
183
|
+
llm = llm_factory(
|
|
184
|
+
self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG
|
|
185
|
+
)
|
|
165
186
|
return await summarize_conversation(tracker, llm, max_turns=5)
|
|
166
187
|
|
|
167
188
|
async def rephrase(
|
|
@@ -202,8 +223,9 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
|
202
223
|
telemetry.track_response_rephrase(
|
|
203
224
|
rephrase_all=self.rephrase_all,
|
|
204
225
|
custom_prompt_template=self.custom_prompt_template(prompt_template_text),
|
|
205
|
-
llm_type=self.llm_property(
|
|
206
|
-
llm_model=self.llm_property(
|
|
226
|
+
llm_type=self.llm_property(PROVIDER_CONFIG_KEY),
|
|
227
|
+
llm_model=self.llm_property(MODEL_CONFIG_KEY)
|
|
228
|
+
or self.llm_property(MODEL_NAME_CONFIG_KEY),
|
|
207
229
|
)
|
|
208
230
|
if not (updated_text := await self._generate_llm_response(prompt)):
|
|
209
231
|
# If the LLM fails to generate a response, we
|
rasa/core/nlg/summarize.py
CHANGED
|
@@ -2,8 +2,8 @@ from typing import Optional
|
|
|
2
2
|
|
|
3
3
|
import structlog
|
|
4
4
|
from jinja2 import Template
|
|
5
|
-
from langchain.llms.base import BaseLLM
|
|
6
5
|
from rasa.core.tracker_store import DialogueStateTracker
|
|
6
|
+
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
7
7
|
from rasa.shared.utils.llm import (
|
|
8
8
|
tracker_as_readable_transcript,
|
|
9
9
|
)
|
|
@@ -43,7 +43,7 @@ def _create_summarization_prompt(
|
|
|
43
43
|
|
|
44
44
|
async def summarize_conversation(
|
|
45
45
|
tracker: DialogueStateTracker,
|
|
46
|
-
llm:
|
|
46
|
+
llm: LLMClient,
|
|
47
47
|
max_turns: Optional[int] = MAX_TURNS_DEFAULT,
|
|
48
48
|
) -> str:
|
|
49
49
|
"""Summarizes the dialogue using the LLM.
|
|
@@ -58,7 +58,8 @@ async def summarize_conversation(
|
|
|
58
58
|
"""
|
|
59
59
|
prompt = _create_summarization_prompt(tracker, max_turns)
|
|
60
60
|
try:
|
|
61
|
-
|
|
61
|
+
llm_response = await llm.acompletion(prompt)
|
|
62
|
+
summarization = llm_response.choices[0].strip()
|
|
62
63
|
structlogger.debug(
|
|
63
64
|
"summarization.success", summarization=summarization, prompt=prompt
|
|
64
65
|
)
|