rasa-pro 3.11.0__py3-none-any.whl → 3.11.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +396 -17
- rasa/__main__.py +15 -31
- rasa/api.py +1 -5
- rasa/cli/arguments/default_arguments.py +2 -1
- rasa/cli/arguments/shell.py +1 -5
- rasa/cli/arguments/train.py +0 -14
- rasa/cli/e2e_test.py +1 -1
- rasa/cli/evaluate.py +8 -8
- rasa/cli/inspect.py +5 -7
- rasa/cli/interactive.py +0 -1
- rasa/cli/llm_fine_tuning.py +1 -1
- rasa/cli/project_templates/calm/config.yml +7 -5
- rasa/cli/project_templates/calm/endpoints.yml +2 -15
- rasa/cli/project_templates/tutorial/config.yml +5 -8
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +0 -5
- rasa/cli/project_templates/tutorial/domain.yml +0 -14
- rasa/cli/project_templates/tutorial/endpoints.yml +0 -5
- rasa/cli/run.py +1 -1
- rasa/cli/scaffold.py +2 -4
- rasa/cli/studio/studio.py +8 -18
- rasa/cli/studio/upload.py +15 -0
- rasa/cli/train.py +0 -3
- rasa/cli/utils.py +1 -6
- rasa/cli/x.py +8 -8
- rasa/constants.py +1 -3
- rasa/core/actions/action.py +33 -75
- rasa/core/actions/e2e_stub_custom_action_executor.py +1 -5
- rasa/core/actions/http_custom_action_executor.py +0 -4
- rasa/core/channels/channel.py +0 -20
- rasa/core/channels/development_inspector.py +2 -8
- rasa/core/channels/inspector/dist/assets/{arc-bc141fb2.js → arc-6852c607.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-be2db283.js → c4Diagram-d0fbc5ce-acc952b2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-55366915.js → classDiagram-936ed81e-848a7597.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-bb529518.js → classDiagram-v2-c3cb15f1-a73d3e68.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-b0ec81d6.js → createText-62fc7601-e5ee049d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-6166330c.js → edges-f2ad444c-771e517e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-5ccc6a8e.js → erDiagram-9d236eb7-aa347178.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-fca3bfe4.js → flowDb-1972c806-651fc57d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-4739080f.js → flowDiagram-7ea5b25a-ca67804f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-587d82d8.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-7c1b0e0f.js → flowchart-elk-definition-abe16c3d-2dbc568d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-772fd050.js → ganttDiagram-9b5ea136-25a65bd8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-8eae1dc9.js → gitGraphDiagram-99d0ae7c-fdc7378d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-f55afcdf.js → index-2c4b9a3b-6f1fd606.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-e7cef9de.js → index-efdd30c1.js} +68 -68
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-124d4a14.js → infoDiagram-736b4530-cb1a041a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-7c4fae44.js → journeyDiagram-df861f2b-14609879.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-b9885fb6.js → layout-2490f52b.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-7c59abb6.js → line-40186f1f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-4776f780.js → linear-08814e93.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-2332c46c.js → mindmap-definition-beec6740-1a534584.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-8fb39303.js → pieDiagram-dbbf0591-72397b61.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3c7180a2.js → quadrantDiagram-4d7f4fd6-3bb0b6a3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-e910bcb8.js → requirementDiagram-6fc4c22a-57334f61.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-ead16c89.js → sankeyDiagram-8f13d901-111e1297.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-29a02a19.js → sequenceDiagram-b655622a-10bcfe62.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-042b3137.js → stateDiagram-59f0c015-acaf7513.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-2178c0f3.js → stateDiagram-v2-2b26beab-3ec2a235.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-23ffa4fc.js → styles-080da4f6-62730289.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-94f59763.js → styles-3dcbcfbf-5284ee76.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-78a6bebc.js → styles-9c745c82-642435e3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-eae2a6f6.js → svgDrawCommon-4835440b-b250a350.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-5c968d92.js → timeline-definition-5b62e21b-c2b147ed.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-fd3db0d5.js → xychartDiagram-2b33534f-f92cfea9.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/App.tsx +1 -1
- rasa/core/channels/inspector/src/helpers/audiostream.ts +16 -77
- rasa/core/channels/socketio.py +2 -7
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/voice_ready/audiocodes.py +4 -15
- rasa/core/channels/voice_ready/jambonz.py +4 -15
- rasa/core/channels/voice_ready/twilio_voice.py +21 -6
- rasa/core/channels/voice_ready/utils.py +5 -6
- rasa/core/channels/voice_stream/asr/asr_engine.py +1 -19
- rasa/core/channels/voice_stream/asr/asr_event.py +0 -5
- rasa/core/channels/voice_stream/asr/deepgram.py +15 -28
- rasa/core/channels/voice_stream/audio_bytes.py +0 -1
- rasa/core/channels/voice_stream/browser_audio.py +9 -32
- rasa/core/channels/voice_stream/tts/azure.py +3 -9
- rasa/core/channels/voice_stream/tts/cartesia.py +8 -12
- rasa/core/channels/voice_stream/tts/tts_engine.py +1 -11
- rasa/core/channels/voice_stream/twilio_media_streams.py +19 -28
- rasa/core/channels/voice_stream/util.py +4 -4
- rasa/core/channels/voice_stream/voice_channel.py +42 -222
- rasa/core/featurizers/single_state_featurizer.py +1 -22
- rasa/core/featurizers/tracker_featurizers.py +18 -115
- rasa/core/information_retrieval/qdrant.py +0 -1
- rasa/core/nlg/contextual_response_rephraser.py +25 -44
- rasa/core/persistor.py +34 -191
- rasa/core/policies/enterprise_search_policy.py +60 -119
- rasa/core/policies/flows/flow_executor.py +4 -7
- rasa/core/policies/intentless_policy.py +22 -82
- rasa/core/policies/ted_policy.py +33 -58
- rasa/core/policies/unexpected_intent_policy.py +7 -15
- rasa/core/processor.py +5 -32
- rasa/core/training/interactive.py +35 -34
- rasa/core/utils.py +22 -58
- rasa/dialogue_understanding/coexistence/llm_based_router.py +12 -39
- rasa/dialogue_understanding/commands/__init__.py +0 -4
- rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
- rasa/dialogue_understanding/commands/utils.py +0 -5
- rasa/dialogue_understanding/generator/constants.py +0 -2
- rasa/dialogue_understanding/generator/flow_retrieval.py +4 -49
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +23 -37
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +10 -57
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +1 -19
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +0 -3
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +10 -90
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +0 -53
- rasa/dialogue_understanding/processor/command_processor.py +1 -21
- rasa/e2e_test/assertions.py +16 -133
- rasa/e2e_test/assertions_schema.yml +0 -23
- rasa/e2e_test/e2e_test_case.py +6 -85
- rasa/e2e_test/e2e_test_runner.py +4 -6
- rasa/e2e_test/utils/io.py +1 -3
- rasa/engine/loader.py +0 -12
- rasa/engine/validation.py +11 -541
- rasa/keys +1 -0
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
- rasa/model_training.py +7 -29
- rasa/nlu/classifiers/diet_classifier.py +25 -38
- rasa/nlu/classifiers/logistic_regression_classifier.py +9 -22
- rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
- rasa/nlu/extractors/crf_entity_extractor.py +50 -93
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +16 -45
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
- rasa/nlu/tokenizers/whitespace_tokenizer.py +14 -3
- rasa/server.py +1 -3
- rasa/shared/constants.py +0 -61
- rasa/shared/core/constants.py +0 -9
- rasa/shared/core/domain.py +5 -8
- rasa/shared/core/flows/flow.py +0 -5
- rasa/shared/core/flows/flows_list.py +1 -5
- rasa/shared/core/flows/flows_yaml_schema.json +0 -10
- rasa/shared/core/flows/validation.py +0 -96
- rasa/shared/core/flows/yaml_flows_io.py +4 -13
- rasa/shared/core/slots.py +0 -5
- rasa/shared/importers/importer.py +2 -19
- rasa/shared/importers/rasa.py +1 -5
- rasa/shared/nlu/training_data/features.py +2 -120
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +3 -18
- rasa/shared/providers/_configs/azure_openai_client_config.py +3 -5
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +0 -1
- rasa/shared/providers/_configs/utils.py +0 -16
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +29 -18
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +21 -54
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +0 -24
- rasa/shared/providers/llm/_base_litellm_client.py +31 -63
- rasa/shared/providers/llm/azure_openai_llm_client.py +29 -50
- rasa/shared/providers/llm/default_litellm_llm_client.py +0 -24
- rasa/shared/providers/llm/self_hosted_llm_client.py +29 -17
- rasa/shared/providers/mappings.py +0 -19
- rasa/shared/utils/common.py +2 -37
- rasa/shared/utils/io.py +6 -28
- rasa/shared/utils/llm.py +46 -353
- rasa/shared/utils/yaml.py +82 -181
- rasa/studio/auth.py +5 -3
- rasa/studio/config.py +4 -13
- rasa/studio/constants.py +0 -1
- rasa/studio/data_handler.py +4 -13
- rasa/studio/upload.py +80 -175
- rasa/telemetry.py +17 -94
- rasa/tracing/config.py +1 -3
- rasa/tracing/instrumentation/attribute_extractors.py +17 -94
- rasa/tracing/instrumentation/instrumentation.py +0 -121
- rasa/utils/common.py +0 -5
- rasa/utils/endpoints.py +1 -27
- rasa/utils/io.py +81 -7
- rasa/utils/log_utils.py +2 -9
- rasa/utils/tensorflow/model_data.py +193 -2
- rasa/validator.py +4 -110
- rasa/version.py +1 -1
- rasa_pro-3.11.0a2.dist-info/METADATA +576 -0
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a2.dist-info}/RECORD +181 -213
- rasa/core/actions/action_repeat_bot_messages.py +0 -89
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +0 -1
- rasa/core/channels/voice_stream/asr/azure.py +0 -129
- rasa/core/channels/voice_stream/call_state.py +0 -23
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +0 -60
- rasa/dialogue_understanding/commands/user_silence_command.py +0 -59
- rasa/dialogue_understanding/patterns/repeat.py +0 -37
- rasa/dialogue_understanding/patterns/user_silence.py +0 -37
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +0 -40
- rasa/model_manager/model_api.py +0 -559
- rasa/model_manager/runner_service.py +0 -286
- rasa/model_manager/socket_bridge.py +0 -146
- rasa/model_manager/studio_jwt_auth.py +0 -86
- rasa/model_manager/trainer_service.py +0 -325
- rasa/model_manager/utils.py +0 -87
- rasa/model_manager/warm_rasa_process.py +0 -187
- rasa/model_service.py +0 -112
- rasa/shared/core/flows/utils.py +0 -39
- rasa/shared/providers/_configs/litellm_router_client_config.py +0 -220
- rasa/shared/providers/_configs/model_group_config.py +0 -167
- rasa/shared/providers/_configs/rasa_llm_client_config.py +0 -73
- rasa/shared/providers/_utils.py +0 -79
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +0 -135
- rasa/shared/providers/llm/litellm_router_llm_client.py +0 -182
- rasa/shared/providers/llm/rasa_llm_client.py +0 -112
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +0 -183
- rasa/shared/providers/router/router_client.py +0 -73
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +0 -31
- rasa/shared/utils/health_check/health_check.py +0 -258
- rasa/shared/utils/health_check/llm_health_check_mixin.py +0 -31
- rasa/utils/sanic_error_handler.py +0 -32
- rasa/utils/tensorflow/feature_array.py +0 -366
- rasa_pro-3.11.0.dist-info/METADATA +0 -198
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a2.dist-info}/entry_points.txt +0 -0
|
@@ -1,45 +1,25 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
import
|
|
2
|
+
import logging
|
|
3
3
|
import copy
|
|
4
4
|
from dataclasses import asdict, dataclass
|
|
5
|
-
from typing import Any,
|
|
6
|
-
|
|
7
|
-
from rasa.core.channels.voice_stream.util import generate_silence
|
|
8
|
-
from rasa.shared.core.constants import SLOT_SILENCE_TIMEOUT
|
|
9
|
-
from rasa.shared.utils.common import (
|
|
10
|
-
class_from_module_path,
|
|
11
|
-
mark_as_beta_feature,
|
|
12
|
-
)
|
|
13
|
-
from rasa.shared.utils.cli import print_error_and_exit
|
|
5
|
+
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
|
14
6
|
|
|
15
7
|
from sanic.exceptions import ServerError, WebsocketClosed
|
|
16
8
|
|
|
17
9
|
from rasa.core.channels import InputChannel, OutputChannel, UserMessage
|
|
18
10
|
from rasa.core.channels.voice_ready.utils import CallParameters
|
|
19
|
-
from rasa.core.channels.voice_ready.utils import validate_voice_license_scope
|
|
20
11
|
from rasa.core.channels.voice_stream.asr.asr_engine import ASREngine
|
|
21
|
-
from rasa.core.channels.voice_stream.asr.asr_event import
|
|
22
|
-
ASREvent,
|
|
23
|
-
NewTranscript,
|
|
24
|
-
UserIsSpeaking,
|
|
25
|
-
)
|
|
12
|
+
from rasa.core.channels.voice_stream.asr.asr_event import ASREvent, NewTranscript
|
|
26
13
|
from sanic import Websocket # type: ignore
|
|
27
14
|
|
|
28
15
|
from rasa.core.channels.voice_stream.asr.deepgram import DeepgramASR
|
|
29
|
-
from rasa.core.channels.voice_stream.
|
|
30
|
-
from rasa.core.channels.voice_stream.audio_bytes import HERTZ, RasaAudioBytes
|
|
31
|
-
from rasa.core.channels.voice_stream.call_state import (
|
|
32
|
-
CallState,
|
|
33
|
-
_call_state,
|
|
34
|
-
call_state,
|
|
35
|
-
)
|
|
16
|
+
from rasa.core.channels.voice_stream.audio_bytes import RasaAudioBytes
|
|
36
17
|
from rasa.core.channels.voice_stream.tts.azure import AzureTTS
|
|
37
18
|
from rasa.core.channels.voice_stream.tts.tts_engine import TTSEngine, TTSError
|
|
38
19
|
from rasa.core.channels.voice_stream.tts.cartesia import CartesiaTTS
|
|
39
20
|
from rasa.core.channels.voice_stream.tts.tts_cache import TTSCache
|
|
40
|
-
from rasa.utils.io import remove_emojis
|
|
41
21
|
|
|
42
|
-
logger =
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
43
23
|
|
|
44
24
|
|
|
45
25
|
@dataclass
|
|
@@ -63,55 +43,25 @@ class ContinueConversationAction(VoiceChannelAction):
|
|
|
63
43
|
|
|
64
44
|
|
|
65
45
|
def asr_engine_from_config(asr_config: Dict) -> ASREngine:
|
|
66
|
-
name = str(asr_config["name"])
|
|
46
|
+
name = str(asr_config["name"]).lower()
|
|
67
47
|
asr_config = copy.copy(asr_config)
|
|
68
48
|
asr_config.pop("name")
|
|
69
|
-
if name
|
|
49
|
+
if name == "deepgram":
|
|
70
50
|
return DeepgramASR.from_config_dict(asr_config)
|
|
71
|
-
if name == "azure":
|
|
72
|
-
return AzureASR.from_config_dict(asr_config)
|
|
73
51
|
else:
|
|
74
|
-
|
|
75
|
-
try:
|
|
76
|
-
asr_engine_class = class_from_module_path(name)
|
|
77
|
-
return asr_engine_class.from_config_dict(asr_config)
|
|
78
|
-
except NameError:
|
|
79
|
-
print_error_and_exit(
|
|
80
|
-
f"Failed to initialize ASR Engine with type '{name}'. "
|
|
81
|
-
f"Please make sure the method `from_config_dict`is implemented."
|
|
82
|
-
)
|
|
83
|
-
except TypeError as e:
|
|
84
|
-
print_error_and_exit(
|
|
85
|
-
f"Failed to initialize ASR Engine with type '{name}'. "
|
|
86
|
-
f"Invalid configuration provided. "
|
|
87
|
-
f"Error: {e}"
|
|
88
|
-
)
|
|
52
|
+
raise NotImplementedError
|
|
89
53
|
|
|
90
54
|
|
|
91
55
|
def tts_engine_from_config(tts_config: Dict) -> TTSEngine:
|
|
92
|
-
name = str(tts_config["name"])
|
|
56
|
+
name = str(tts_config["name"]).lower()
|
|
93
57
|
tts_config = copy.copy(tts_config)
|
|
94
58
|
tts_config.pop("name")
|
|
95
|
-
if name
|
|
59
|
+
if name == "azure":
|
|
96
60
|
return AzureTTS.from_config_dict(tts_config)
|
|
97
|
-
elif name
|
|
61
|
+
elif name == "cartesia":
|
|
98
62
|
return CartesiaTTS.from_config_dict(tts_config)
|
|
99
63
|
else:
|
|
100
|
-
|
|
101
|
-
try:
|
|
102
|
-
tts_engine_class = class_from_module_path(name)
|
|
103
|
-
return tts_engine_class.from_config_dict(tts_config)
|
|
104
|
-
except NameError:
|
|
105
|
-
print_error_and_exit(
|
|
106
|
-
f"Failed to initialize TTS Engine with type '{name}'. "
|
|
107
|
-
f"Please make sure the method `from_config_dict`is implemented."
|
|
108
|
-
)
|
|
109
|
-
except TypeError as e:
|
|
110
|
-
print_error_and_exit(
|
|
111
|
-
f"Failed to initialize ASR Engine with type '{name}'. "
|
|
112
|
-
f"Invalid configuration provided. "
|
|
113
|
-
f"Error: {e}"
|
|
114
|
-
)
|
|
64
|
+
raise NotImplementedError(f"TTS engine {name} is not implemented")
|
|
115
65
|
|
|
116
66
|
|
|
117
67
|
class VoiceOutputChannel(OutputChannel):
|
|
@@ -121,171 +71,75 @@ class VoiceOutputChannel(OutputChannel):
|
|
|
121
71
|
tts_engine: TTSEngine,
|
|
122
72
|
tts_cache: TTSCache,
|
|
123
73
|
):
|
|
124
|
-
super().__init__()
|
|
125
74
|
self.voice_websocket = voice_websocket
|
|
126
75
|
self.tts_engine = tts_engine
|
|
127
76
|
self.tts_cache = tts_cache
|
|
128
77
|
|
|
78
|
+
self.should_hangup = False
|
|
129
79
|
self.latest_message_id: Optional[str] = None
|
|
130
80
|
|
|
131
81
|
def rasa_audio_bytes_to_channel_bytes(
|
|
132
82
|
self, rasa_audio_bytes: RasaAudioBytes
|
|
133
83
|
) -> bytes:
|
|
134
|
-
"""Turn rasa's audio byte format into the format for the channel."""
|
|
135
84
|
raise NotImplementedError
|
|
136
85
|
|
|
137
|
-
def
|
|
138
|
-
|
|
86
|
+
def channel_bytes_to_messages(
|
|
87
|
+
self, recipient_id: str, channel_bytes: bytes
|
|
88
|
+
) -> List[Any]:
|
|
139
89
|
raise NotImplementedError
|
|
140
90
|
|
|
141
|
-
def create_marker_message(self, recipient_id: str) -> Tuple[str, str]:
|
|
142
|
-
"""Create a marker message for a specific channel."""
|
|
143
|
-
raise NotImplementedError
|
|
144
|
-
|
|
145
|
-
async def send_marker_message(self, recipient_id: str) -> None:
|
|
146
|
-
"""Send a message that marks positions in the audio stream."""
|
|
147
|
-
marker_message, mark_id = self.create_marker_message(recipient_id)
|
|
148
|
-
await self.voice_websocket.send(marker_message)
|
|
149
|
-
self.latest_message_id = mark_id
|
|
150
|
-
|
|
151
|
-
def update_silence_timeout(self) -> None:
|
|
152
|
-
"""Updates the silence timeout for the session."""
|
|
153
|
-
if self.tracker_state:
|
|
154
|
-
call_state.silence_timeout = ( # type: ignore[attr-defined]
|
|
155
|
-
self.tracker_state["slots"][SLOT_SILENCE_TIMEOUT]
|
|
156
|
-
)
|
|
157
|
-
|
|
158
|
-
async def send_text_with_buttons(
|
|
159
|
-
self,
|
|
160
|
-
recipient_id: str,
|
|
161
|
-
text: str,
|
|
162
|
-
buttons: List[Dict[str, Any]],
|
|
163
|
-
**kwargs: Any,
|
|
164
|
-
) -> None:
|
|
165
|
-
"""Uses the concise button output format for voice channels."""
|
|
166
|
-
await self.send_text_with_buttons_concise(recipient_id, text, buttons, **kwargs)
|
|
167
|
-
|
|
168
91
|
async def send_text_message(
|
|
169
92
|
self, recipient_id: str, text: str, **kwargs: Any
|
|
170
93
|
) -> None:
|
|
171
|
-
text = remove_emojis(text)
|
|
172
|
-
self.update_silence_timeout()
|
|
173
94
|
cached_audio_bytes = self.tts_cache.get(text)
|
|
174
|
-
collected_audio_bytes = RasaAudioBytes(b"")
|
|
175
|
-
seconds_marker = -1
|
|
176
|
-
if cached_audio_bytes:
|
|
177
|
-
audio_stream = self.chunk_audio(cached_audio_bytes)
|
|
178
|
-
else:
|
|
179
|
-
# Todo: make kwargs compatible with engine config
|
|
180
|
-
synth_config = self.tts_engine.config.__class__.from_dict({})
|
|
181
|
-
try:
|
|
182
|
-
audio_stream = self.tts_engine.synthesize(text, synth_config)
|
|
183
|
-
except TTSError:
|
|
184
|
-
# TODO: add message that works without tts, e.g. loading from disc
|
|
185
|
-
audio_stream = self.chunk_audio(generate_silence())
|
|
186
95
|
|
|
96
|
+
if cached_audio_bytes:
|
|
97
|
+
await self.send_audio_bytes(recipient_id, cached_audio_bytes)
|
|
98
|
+
return
|
|
99
|
+
collected_audio_bytes = RasaAudioBytes(b"")
|
|
100
|
+
# Todo: make kwargs compatible with engine config
|
|
101
|
+
synth_config = self.tts_engine.config.__class__.from_dict({})
|
|
102
|
+
try:
|
|
103
|
+
audio_stream = self.tts_engine.synthesize(text, synth_config)
|
|
104
|
+
except TTSError:
|
|
105
|
+
# TODO: add message that works without tts, e.g. loading from disc
|
|
106
|
+
pass
|
|
187
107
|
async for audio_bytes in audio_stream:
|
|
188
108
|
try:
|
|
189
109
|
await self.send_audio_bytes(recipient_id, audio_bytes)
|
|
190
|
-
full_seconds_of_audio = len(collected_audio_bytes) // HERTZ
|
|
191
|
-
if full_seconds_of_audio > seconds_marker:
|
|
192
|
-
await self.send_marker_message(recipient_id)
|
|
193
|
-
seconds_marker = full_seconds_of_audio
|
|
194
|
-
|
|
195
110
|
except (WebsocketClosed, ServerError):
|
|
196
111
|
# ignore sending error, and keep collecting and caching audio bytes
|
|
197
|
-
|
|
112
|
+
self.should_hangup = True
|
|
113
|
+
|
|
198
114
|
collected_audio_bytes = RasaAudioBytes(collected_audio_bytes + audio_bytes)
|
|
199
|
-
try:
|
|
200
|
-
await self.send_marker_message(recipient_id)
|
|
201
|
-
except (WebsocketClosed, ServerError):
|
|
202
|
-
# ignore sending error
|
|
203
|
-
pass
|
|
204
|
-
call_state.latest_bot_audio_id = self.latest_message_id # type: ignore[attr-defined]
|
|
205
115
|
|
|
206
|
-
|
|
207
|
-
self.tts_cache.put(text, collected_audio_bytes)
|
|
116
|
+
self.tts_cache.put(text, collected_audio_bytes)
|
|
208
117
|
|
|
209
118
|
async def send_audio_bytes(
|
|
210
119
|
self, recipient_id: str, audio_bytes: RasaAudioBytes
|
|
211
120
|
) -> None:
|
|
212
121
|
channel_bytes = self.rasa_audio_bytes_to_channel_bytes(audio_bytes)
|
|
213
|
-
message
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
async def chunk_audio(
|
|
217
|
-
self, audio_bytes: RasaAudioBytes, chunk_size: int = 2048
|
|
218
|
-
) -> AsyncIterator[RasaAudioBytes]:
|
|
219
|
-
"""Generate chunks from cached audio bytes."""
|
|
220
|
-
offset = 0
|
|
221
|
-
while offset < len(audio_bytes):
|
|
222
|
-
chunk = audio_bytes[offset : offset + chunk_size]
|
|
223
|
-
if len(chunk):
|
|
224
|
-
yield RasaAudioBytes(chunk)
|
|
225
|
-
offset += chunk_size
|
|
226
|
-
return
|
|
122
|
+
for message in self.channel_bytes_to_messages(recipient_id, channel_bytes):
|
|
123
|
+
await self.voice_websocket.send(message)
|
|
227
124
|
|
|
228
125
|
async def hangup(self, recipient_id: str, **kwargs: Any) -> None:
|
|
229
|
-
|
|
126
|
+
self.should_hangup = True
|
|
230
127
|
|
|
231
128
|
|
|
232
129
|
class VoiceInputChannel(InputChannel):
|
|
233
|
-
def __init__(
|
|
234
|
-
self,
|
|
235
|
-
server_url: str,
|
|
236
|
-
asr_config: Dict,
|
|
237
|
-
tts_config: Dict,
|
|
238
|
-
monitor_silence: bool = False,
|
|
239
|
-
):
|
|
240
|
-
validate_voice_license_scope()
|
|
130
|
+
def __init__(self, server_url: str, asr_config: Dict, tts_config: Dict):
|
|
241
131
|
self.server_url = server_url
|
|
242
132
|
self.asr_config = asr_config
|
|
243
133
|
self.tts_config = tts_config
|
|
244
|
-
self.monitor_silence = monitor_silence
|
|
245
134
|
self.tts_cache = TTSCache(tts_config.get("cache_size", 1000))
|
|
246
135
|
|
|
247
|
-
|
|
248
|
-
self
|
|
249
|
-
voice_websocket: Websocket,
|
|
250
|
-
on_new_message: Callable[[UserMessage], Awaitable[Any]],
|
|
251
|
-
tts_engine: TTSEngine,
|
|
252
|
-
call_parameters: CallParameters,
|
|
253
|
-
) -> None:
|
|
254
|
-
timeout = call_state.silence_timeout
|
|
255
|
-
if not timeout:
|
|
256
|
-
return
|
|
257
|
-
if not self.monitor_silence:
|
|
258
|
-
return
|
|
259
|
-
logger.debug("voice_channel.silence_timeout_watch_started", timeout=timeout)
|
|
260
|
-
await asyncio.sleep(timeout)
|
|
261
|
-
logger.debug("voice_channel.silence_timeout_tripped")
|
|
262
|
-
output_channel = self.create_output_channel(voice_websocket, tts_engine)
|
|
263
|
-
message = UserMessage(
|
|
264
|
-
"/silence_timeout",
|
|
265
|
-
output_channel,
|
|
266
|
-
call_parameters.stream_id,
|
|
267
|
-
input_channel=self.name(),
|
|
268
|
-
metadata=asdict(call_parameters),
|
|
269
|
-
)
|
|
270
|
-
await on_new_message(message)
|
|
271
|
-
|
|
272
|
-
@staticmethod
|
|
273
|
-
def _cancel_silence_timeout_watcher() -> None:
|
|
274
|
-
"""Cancels the silent timeout task if it exists."""
|
|
275
|
-
if call_state.silence_timeout_watcher:
|
|
276
|
-
logger.debug("voice_channel.cancelling_current_timeout_watcher_task")
|
|
277
|
-
call_state.silence_timeout_watcher.cancel()
|
|
278
|
-
call_state.silence_timeout_watcher = None # type: ignore[attr-defined]
|
|
136
|
+
# if set to a value, call will be hungup after marker is reached
|
|
137
|
+
self.hangup_after: Optional[str] = None
|
|
279
138
|
|
|
280
139
|
@classmethod
|
|
281
140
|
def from_credentials(cls, credentials: Optional[Dict[str, Any]]) -> InputChannel:
|
|
282
141
|
credentials = credentials or {}
|
|
283
|
-
return cls(
|
|
284
|
-
credentials["server_url"],
|
|
285
|
-
credentials["asr"],
|
|
286
|
-
credentials["tts"],
|
|
287
|
-
credentials.get("monitor_silence", False),
|
|
288
|
-
)
|
|
142
|
+
return cls(credentials["server_url"], credentials["asr"], credentials["tts"])
|
|
289
143
|
|
|
290
144
|
def channel_bytes_to_rasa_audio_bytes(self, input_bytes: bytes) -> RasaAudioBytes:
|
|
291
145
|
raise NotImplementedError
|
|
@@ -325,7 +179,6 @@ class VoiceInputChannel(InputChannel):
|
|
|
325
179
|
channel_websocket: Websocket,
|
|
326
180
|
) -> None:
|
|
327
181
|
"""Pipe input audio to ASR and consume ASR events simultaneously."""
|
|
328
|
-
_call_state.set(CallState())
|
|
329
182
|
asr_engine = asr_engine_from_config(self.asr_config)
|
|
330
183
|
tts_engine = tts_engine_from_config(self.tts_config)
|
|
331
184
|
await asr_engine.connect()
|
|
@@ -339,29 +192,7 @@ class VoiceInputChannel(InputChannel):
|
|
|
339
192
|
|
|
340
193
|
async def consume_audio_bytes() -> None:
|
|
341
194
|
async for message in channel_websocket:
|
|
342
|
-
is_bot_speaking_before = call_state.is_bot_speaking
|
|
343
195
|
channel_action = self.map_input_message(message)
|
|
344
|
-
is_bot_speaking_after = call_state.is_bot_speaking
|
|
345
|
-
|
|
346
|
-
if not is_bot_speaking_before and is_bot_speaking_after:
|
|
347
|
-
logger.debug("voice_channel.bot_started_speaking")
|
|
348
|
-
# relevant when the bot speaks multiple messages in one turn
|
|
349
|
-
self._cancel_silence_timeout_watcher()
|
|
350
|
-
|
|
351
|
-
# we just stopped speaking, starting a watcher for silence timeout
|
|
352
|
-
if is_bot_speaking_before and not is_bot_speaking_after:
|
|
353
|
-
logger.debug("voice_channel.bot_stopped_speaking")
|
|
354
|
-
self._cancel_silence_timeout_watcher()
|
|
355
|
-
call_state.silence_timeout_watcher = ( # type: ignore[attr-defined]
|
|
356
|
-
asyncio.create_task(
|
|
357
|
-
self.handle_silence_timeout(
|
|
358
|
-
channel_websocket,
|
|
359
|
-
on_new_message,
|
|
360
|
-
tts_engine,
|
|
361
|
-
call_parameters,
|
|
362
|
-
)
|
|
363
|
-
)
|
|
364
|
-
)
|
|
365
196
|
if isinstance(channel_action, NewAudioAction):
|
|
366
197
|
await asr_engine.send_audio_chunks(channel_action.audio_bytes)
|
|
367
198
|
elif isinstance(channel_action, EndConversationAction):
|
|
@@ -378,20 +209,12 @@ class VoiceInputChannel(InputChannel):
|
|
|
378
209
|
call_parameters,
|
|
379
210
|
)
|
|
380
211
|
|
|
381
|
-
audio_forwarding_task = asyncio.create_task(consume_audio_bytes())
|
|
382
|
-
asr_event_task = asyncio.create_task(consume_asr_events())
|
|
383
212
|
await asyncio.wait(
|
|
384
|
-
[
|
|
213
|
+
[consume_audio_bytes(), consume_asr_events()],
|
|
385
214
|
return_when=asyncio.FIRST_COMPLETED,
|
|
386
215
|
)
|
|
387
|
-
if not audio_forwarding_task.done():
|
|
388
|
-
audio_forwarding_task.cancel()
|
|
389
|
-
if not asr_event_task.done():
|
|
390
|
-
asr_event_task.cancel()
|
|
391
216
|
await tts_engine.close_connection()
|
|
392
217
|
await asr_engine.close_connection()
|
|
393
|
-
await channel_websocket.close()
|
|
394
|
-
self._cancel_silence_timeout_watcher()
|
|
395
218
|
|
|
396
219
|
def create_output_channel(
|
|
397
220
|
self, voice_websocket: Websocket, tts_engine: TTSEngine
|
|
@@ -409,10 +232,7 @@ class VoiceInputChannel(InputChannel):
|
|
|
409
232
|
) -> None:
|
|
410
233
|
"""Handle a new event from the ASR system."""
|
|
411
234
|
if isinstance(e, NewTranscript) and e.text:
|
|
412
|
-
logger.
|
|
413
|
-
"VoiceInputChannel.handle_asr_event.new_transcript", transcript=e.text
|
|
414
|
-
)
|
|
415
|
-
call_state.is_user_speaking = False # type: ignore[attr-defined]
|
|
235
|
+
logger.info(f"New transcript: {e.text}")
|
|
416
236
|
output_channel = self.create_output_channel(voice_websocket, tts_engine)
|
|
417
237
|
message = UserMessage(
|
|
418
238
|
e.text,
|
|
@@ -422,6 +242,6 @@ class VoiceInputChannel(InputChannel):
|
|
|
422
242
|
metadata=asdict(call_parameters),
|
|
423
243
|
)
|
|
424
244
|
await on_new_message(message)
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
245
|
+
|
|
246
|
+
if output_channel.should_hangup:
|
|
247
|
+
self.hangup_after = output_channel.latest_message_id
|
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
import logging
|
|
2
|
-
from typing import List, Optional, Dict, Text, Set, Any
|
|
3
|
-
|
|
4
2
|
import numpy as np
|
|
5
3
|
import scipy.sparse
|
|
4
|
+
from typing import List, Optional, Dict, Text, Set, Any
|
|
6
5
|
|
|
7
6
|
from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
|
|
8
7
|
from rasa.nlu.extractors.extractor import EntityTagSpec
|
|
@@ -361,26 +360,6 @@ class SingleStateFeaturizer:
|
|
|
361
360
|
for action in domain.action_names_or_texts
|
|
362
361
|
]
|
|
363
362
|
|
|
364
|
-
def to_dict(self) -> Dict[str, Any]:
|
|
365
|
-
return {
|
|
366
|
-
"action_texts": self.action_texts,
|
|
367
|
-
"entity_tag_specs": self.entity_tag_specs,
|
|
368
|
-
"feature_states": self._default_feature_states,
|
|
369
|
-
}
|
|
370
|
-
|
|
371
|
-
@classmethod
|
|
372
|
-
def create_from_dict(
|
|
373
|
-
cls, data: Dict[str, Any]
|
|
374
|
-
) -> Optional["SingleStateFeaturizer"]:
|
|
375
|
-
if not data:
|
|
376
|
-
return None
|
|
377
|
-
|
|
378
|
-
featurizer = SingleStateFeaturizer()
|
|
379
|
-
featurizer.action_texts = data["action_texts"]
|
|
380
|
-
featurizer._default_feature_states = data["feature_states"]
|
|
381
|
-
featurizer.entity_tag_specs = data["entity_tag_specs"]
|
|
382
|
-
return featurizer
|
|
383
|
-
|
|
384
363
|
|
|
385
364
|
class IntentTokenizerSingleStateFeaturizer(SingleStateFeaturizer):
|
|
386
365
|
"""A SingleStateFeaturizer for use with policies that predict intent labels."""
|
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import logging
|
|
4
|
-
from abc import abstractmethod
|
|
5
|
-
from collections import defaultdict
|
|
6
2
|
from pathlib import Path
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from abc import abstractmethod
|
|
5
|
+
import jsonpickle
|
|
6
|
+
import logging
|
|
7
|
+
|
|
8
|
+
from tqdm import tqdm
|
|
7
9
|
from typing import (
|
|
8
10
|
Tuple,
|
|
9
11
|
List,
|
|
@@ -16,30 +18,25 @@ from typing import (
|
|
|
16
18
|
Set,
|
|
17
19
|
DefaultDict,
|
|
18
20
|
cast,
|
|
19
|
-
Type,
|
|
20
|
-
Callable,
|
|
21
|
-
ClassVar,
|
|
22
21
|
)
|
|
23
|
-
|
|
24
22
|
import numpy as np
|
|
25
|
-
from tqdm import tqdm
|
|
26
23
|
|
|
24
|
+
from rasa.core.featurizers.single_state_featurizer import SingleStateFeaturizer
|
|
25
|
+
from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
|
|
26
|
+
from rasa.core.exceptions import InvalidTrackerFeaturizerUsageError
|
|
27
27
|
import rasa.shared.core.trackers
|
|
28
28
|
import rasa.shared.utils.io
|
|
29
|
-
from rasa.
|
|
30
|
-
from rasa.
|
|
31
|
-
from rasa.core.
|
|
29
|
+
from rasa.shared.nlu.constants import TEXT, INTENT, ENTITIES, ACTION_NAME
|
|
30
|
+
from rasa.shared.nlu.training_data.features import Features
|
|
31
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
32
|
+
from rasa.shared.core.domain import State, Domain
|
|
33
|
+
from rasa.shared.core.events import Event, ActionExecuted, UserUttered
|
|
32
34
|
from rasa.shared.core.constants import (
|
|
33
35
|
USER,
|
|
34
36
|
ACTION_UNLIKELY_INTENT_NAME,
|
|
35
37
|
PREVIOUS_ACTION,
|
|
36
38
|
)
|
|
37
|
-
from rasa.shared.core.domain import State, Domain
|
|
38
|
-
from rasa.shared.core.events import Event, ActionExecuted, UserUttered
|
|
39
|
-
from rasa.shared.core.trackers import DialogueStateTracker
|
|
40
39
|
from rasa.shared.exceptions import RasaException
|
|
41
|
-
from rasa.shared.nlu.constants import TEXT, INTENT, ENTITIES, ACTION_NAME
|
|
42
|
-
from rasa.shared.nlu.training_data.features import Features
|
|
43
40
|
from rasa.utils.tensorflow.constants import LABEL_PAD_ID
|
|
44
41
|
from rasa.utils.tensorflow.model_data import ragged_array_to_ndarray
|
|
45
42
|
|
|
@@ -67,10 +64,6 @@ class InvalidStory(RasaException):
|
|
|
67
64
|
class TrackerFeaturizer:
|
|
68
65
|
"""Base class for actual tracker featurizers."""
|
|
69
66
|
|
|
70
|
-
# Class registry to store all subclasses
|
|
71
|
-
_registry: ClassVar[Dict[str, Type["TrackerFeaturizer"]]] = {}
|
|
72
|
-
_featurizer_type: str = "TrackerFeaturizer"
|
|
73
|
-
|
|
74
67
|
def __init__(
|
|
75
68
|
self, state_featurizer: Optional[SingleStateFeaturizer] = None
|
|
76
69
|
) -> None:
|
|
@@ -81,36 +74,6 @@ class TrackerFeaturizer:
|
|
|
81
74
|
"""
|
|
82
75
|
self.state_featurizer = state_featurizer
|
|
83
76
|
|
|
84
|
-
@classmethod
|
|
85
|
-
def register(cls, featurizer_type: str) -> Callable:
|
|
86
|
-
"""Decorator to register featurizer subclasses."""
|
|
87
|
-
|
|
88
|
-
def wrapper(subclass: Type["TrackerFeaturizer"]) -> Type["TrackerFeaturizer"]:
|
|
89
|
-
cls._registry[featurizer_type] = subclass
|
|
90
|
-
# Store the type identifier in the class for serialization
|
|
91
|
-
subclass._featurizer_type = featurizer_type
|
|
92
|
-
return subclass
|
|
93
|
-
|
|
94
|
-
return wrapper
|
|
95
|
-
|
|
96
|
-
@classmethod
|
|
97
|
-
def from_dict(cls, data: Dict[str, Any]) -> "TrackerFeaturizer":
|
|
98
|
-
"""Create featurizer instance from dictionary."""
|
|
99
|
-
featurizer_type = data.pop("type")
|
|
100
|
-
|
|
101
|
-
if featurizer_type not in cls._registry:
|
|
102
|
-
raise ValueError(f"Unknown featurizer type: {featurizer_type}")
|
|
103
|
-
|
|
104
|
-
# Get the correct subclass and instantiate it
|
|
105
|
-
subclass = cls._registry[featurizer_type]
|
|
106
|
-
return subclass.create_from_dict(data)
|
|
107
|
-
|
|
108
|
-
@classmethod
|
|
109
|
-
@abstractmethod
|
|
110
|
-
def create_from_dict(cls, data: Dict[str, Any]) -> "TrackerFeaturizer":
|
|
111
|
-
"""Each subclass must implement its own creation from dict method."""
|
|
112
|
-
pass
|
|
113
|
-
|
|
114
77
|
@staticmethod
|
|
115
78
|
def _create_states(
|
|
116
79
|
tracker: DialogueStateTracker,
|
|
@@ -502,7 +465,9 @@ class TrackerFeaturizer:
|
|
|
502
465
|
self.state_featurizer.entity_tag_specs = []
|
|
503
466
|
|
|
504
467
|
# noinspection PyTypeChecker
|
|
505
|
-
rasa.shared.utils.io.
|
|
468
|
+
rasa.shared.utils.io.write_text_file(
|
|
469
|
+
str(jsonpickle.encode(self)), featurizer_file
|
|
470
|
+
)
|
|
506
471
|
|
|
507
472
|
@staticmethod
|
|
508
473
|
def load(path: Union[Text, Path]) -> Optional[TrackerFeaturizer]:
|
|
@@ -516,17 +481,7 @@ class TrackerFeaturizer:
|
|
|
516
481
|
"""
|
|
517
482
|
featurizer_file = Path(path) / FEATURIZER_FILE
|
|
518
483
|
if featurizer_file.is_file():
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
if "type" not in data:
|
|
522
|
-
logger.error(
|
|
523
|
-
f"Couldn't load featurizer for policy. "
|
|
524
|
-
f"File '{featurizer_file}' does not contain all "
|
|
525
|
-
f"necessary information. 'type' is missing."
|
|
526
|
-
)
|
|
527
|
-
return None
|
|
528
|
-
|
|
529
|
-
return TrackerFeaturizer.from_dict(data)
|
|
484
|
+
return jsonpickle.decode(rasa.shared.utils.io.read_file(featurizer_file))
|
|
530
485
|
|
|
531
486
|
logger.error(
|
|
532
487
|
f"Couldn't load featurizer for policy. "
|
|
@@ -553,16 +508,7 @@ class TrackerFeaturizer:
|
|
|
553
508
|
)
|
|
554
509
|
]
|
|
555
510
|
|
|
556
|
-
def to_dict(self) -> Dict[str, Any]:
|
|
557
|
-
return {
|
|
558
|
-
"type": self.__class__._featurizer_type,
|
|
559
|
-
"state_featurizer": (
|
|
560
|
-
self.state_featurizer.to_dict() if self.state_featurizer else None
|
|
561
|
-
),
|
|
562
|
-
}
|
|
563
|
-
|
|
564
511
|
|
|
565
|
-
@TrackerFeaturizer.register("FullDialogueTrackerFeaturizer")
|
|
566
512
|
class FullDialogueTrackerFeaturizer(TrackerFeaturizer):
|
|
567
513
|
"""Creates full dialogue training data for time distributed architectures.
|
|
568
514
|
|
|
@@ -700,20 +646,7 @@ class FullDialogueTrackerFeaturizer(TrackerFeaturizer):
|
|
|
700
646
|
|
|
701
647
|
return trackers_as_states
|
|
702
648
|
|
|
703
|
-
def to_dict(self) -> Dict[str, Any]:
|
|
704
|
-
return super().to_dict()
|
|
705
649
|
|
|
706
|
-
@classmethod
|
|
707
|
-
def create_from_dict(cls, data: Dict[str, Any]) -> "FullDialogueTrackerFeaturizer":
|
|
708
|
-
state_featurizer = SingleStateFeaturizer.create_from_dict(
|
|
709
|
-
data["state_featurizer"]
|
|
710
|
-
)
|
|
711
|
-
return cls(
|
|
712
|
-
state_featurizer,
|
|
713
|
-
)
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
@TrackerFeaturizer.register("MaxHistoryTrackerFeaturizer")
|
|
717
650
|
class MaxHistoryTrackerFeaturizer(TrackerFeaturizer):
|
|
718
651
|
"""Truncates the tracker history into `max_history` long sequences.
|
|
719
652
|
|
|
@@ -951,25 +884,7 @@ class MaxHistoryTrackerFeaturizer(TrackerFeaturizer):
|
|
|
951
884
|
|
|
952
885
|
return trackers_as_states
|
|
953
886
|
|
|
954
|
-
def to_dict(self) -> Dict[str, Any]:
|
|
955
|
-
data = super().to_dict()
|
|
956
|
-
data.update(
|
|
957
|
-
{
|
|
958
|
-
"remove_duplicates": self.remove_duplicates,
|
|
959
|
-
"max_history": self.max_history,
|
|
960
|
-
}
|
|
961
|
-
)
|
|
962
|
-
return data
|
|
963
|
-
|
|
964
|
-
@classmethod
|
|
965
|
-
def create_from_dict(cls, data: Dict[str, Any]) -> "MaxHistoryTrackerFeaturizer":
|
|
966
|
-
state_featurizer = SingleStateFeaturizer.create_from_dict(
|
|
967
|
-
data["state_featurizer"]
|
|
968
|
-
)
|
|
969
|
-
return cls(state_featurizer, data["max_history"], data["remove_duplicates"])
|
|
970
887
|
|
|
971
|
-
|
|
972
|
-
@TrackerFeaturizer.register("IntentMaxHistoryTrackerFeaturizer")
|
|
973
888
|
class IntentMaxHistoryTrackerFeaturizer(MaxHistoryTrackerFeaturizer):
|
|
974
889
|
"""Truncates the tracker history into `max_history` long sequences.
|
|
975
890
|
|
|
@@ -1244,18 +1159,6 @@ class IntentMaxHistoryTrackerFeaturizer(MaxHistoryTrackerFeaturizer):
|
|
|
1244
1159
|
|
|
1245
1160
|
return trackers_as_states
|
|
1246
1161
|
|
|
1247
|
-
def to_dict(self) -> Dict[str, Any]:
|
|
1248
|
-
return super().to_dict()
|
|
1249
|
-
|
|
1250
|
-
@classmethod
|
|
1251
|
-
def create_from_dict(
|
|
1252
|
-
cls, data: Dict[str, Any]
|
|
1253
|
-
) -> "IntentMaxHistoryTrackerFeaturizer":
|
|
1254
|
-
state_featurizer = SingleStateFeaturizer.create_from_dict(
|
|
1255
|
-
data["state_featurizer"]
|
|
1256
|
-
)
|
|
1257
|
-
return cls(state_featurizer, data["max_history"], data["remove_duplicates"])
|
|
1258
|
-
|
|
1259
1162
|
|
|
1260
1163
|
def _is_prev_action_unlikely_intent_in_state(state: State) -> bool:
|
|
1261
1164
|
prev_action_name = state.get(PREVIOUS_ACTION, {}).get(ACTION_NAME)
|
|
@@ -62,7 +62,6 @@ class Qdrant_Store(InformationRetrieval):
|
|
|
62
62
|
embeddings=self.embeddings,
|
|
63
63
|
content_payload_key=params.get("content_payload_key", "text"),
|
|
64
64
|
metadata_payload_key=params.get("metadata_payload_key", "metadata"),
|
|
65
|
-
vector_name=params.get("vector_name", None),
|
|
66
65
|
)
|
|
67
66
|
|
|
68
67
|
async def search(
|