rasa-pro 3.10.16__py3-none-any.whl → 3.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro has been flagged as potentially problematic; see the package's registry page for more details.
- rasa/__main__.py +31 -15
- rasa/api.py +12 -2
- rasa/cli/arguments/default_arguments.py +24 -4
- rasa/cli/arguments/run.py +15 -0
- rasa/cli/arguments/shell.py +5 -1
- rasa/cli/arguments/train.py +17 -9
- rasa/cli/evaluate.py +7 -7
- rasa/cli/inspect.py +19 -7
- rasa/cli/interactive.py +1 -0
- rasa/cli/llm_fine_tuning.py +11 -14
- rasa/cli/project_templates/calm/config.yml +5 -7
- rasa/cli/project_templates/calm/endpoints.yml +15 -2
- rasa/cli/project_templates/tutorial/config.yml +8 -5
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
- rasa/cli/project_templates/tutorial/domain.yml +14 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +5 -0
- rasa/cli/run.py +7 -0
- rasa/cli/scaffold.py +4 -2
- rasa/cli/studio/upload.py +0 -15
- rasa/cli/train.py +14 -53
- rasa/cli/utils.py +14 -11
- rasa/cli/x.py +7 -7
- rasa/constants.py +3 -1
- rasa/core/actions/action.py +77 -33
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/actions/action_repeat_bot_messages.py +89 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +5 -1
- rasa/core/actions/http_custom_action_executor.py +4 -0
- rasa/core/agent.py +2 -2
- rasa/core/brokers/kafka.py +3 -1
- rasa/core/brokers/pika.py +3 -1
- rasa/core/channels/__init__.py +10 -6
- rasa/core/channels/channel.py +41 -4
- rasa/core/channels/development_inspector.py +150 -46
- rasa/core/channels/inspector/README.md +1 -1
- rasa/core/channels/inspector/dist/assets/{arc-b6e548fe.js → arc-bc141fb2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-fa03ac9e.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-ee67392a.js → classDiagram-936ed81e-55366915.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-9b283fae.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-8b6fcc2a.js → createText-62fc7601-b0ec81d6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-22e77f4f.js → edges-f2ad444c-6166330c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-60ffc87f.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-9dd802e4.js → flowDb-1972c806-fca3bfe4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-5fa1912f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-622a1fd2.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-e285a63a.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-f237bdca.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-4b03d70e.js → index-2c4b9a3b-f55afcdf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/index-e7cef9de.js +1317 -0
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-72a0fa5f.js → infoDiagram-736b4530-124d4a14.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-82218c41.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-78cff630.js → layout-b9885fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-5038b469.js → line-7c59abb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-c4fc4098.js → linear-4776f780.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-c33c8ea6.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-a8d03059.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-6a0e56b2.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-2dc7c7bd.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-2360fe39.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-41b9f9ad.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-0aad326f.js → stateDiagram-59f0c015-042b3137.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-9847d984.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-564d890e.js → styles-080da4f6-23ffa4fc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-38957613.js → styles-3dcbcfbf-94f59763.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-f0fc6921.js → styles-9c745c82-78a6bebc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-ef3c5a77.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-bf3e91c1.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-4d4026c0.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +18 -17
- rasa/core/channels/inspector/index.html +17 -16
- rasa/core/channels/inspector/package.json +5 -1
- rasa/core/channels/inspector/src/App.tsx +118 -68
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +11 -10
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +10 -25
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +6 -3
- rasa/core/channels/inspector/src/helpers/audiostream.ts +165 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +10 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +107 -41
- rasa/core/channels/inspector/src/helpers/utils.ts +92 -7
- rasa/core/channels/inspector/src/types.ts +21 -1
- rasa/core/channels/inspector/yarn.lock +94 -1
- rasa/core/channels/rest.py +51 -46
- rasa/core/channels/socketio.py +28 -1
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/{audiocodes.py → voice_ready/audiocodes.py} +122 -69
- rasa/core/channels/{voice_aware → voice_ready}/jambonz.py +26 -8
- rasa/core/channels/{voice_aware → voice_ready}/jambonz_protocol.py +57 -5
- rasa/core/channels/{twilio_voice.py → voice_ready/twilio_voice.py} +64 -28
- rasa/core/channels/voice_ready/utils.py +37 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
- rasa/core/channels/voice_stream/asr/azure.py +129 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
- rasa/core/channels/voice_stream/audio_bytes.py +8 -0
- rasa/core/channels/voice_stream/browser_audio.py +107 -0
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +106 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +427 -0
- rasa/core/information_retrieval/qdrant.py +1 -0
- rasa/core/nlg/contextual_response_rephraser.py +45 -17
- rasa/{nlu → core}/persistor.py +203 -68
- rasa/core/policies/enterprise_search_policy.py +119 -63
- rasa/core/policies/flows/flow_executor.py +15 -22
- rasa/core/policies/intentless_policy.py +83 -28
- rasa/core/processor.py +25 -0
- rasa/core/run.py +12 -2
- rasa/core/secrets_manager/constants.py +4 -0
- rasa/core/secrets_manager/factory.py +8 -0
- rasa/core/secrets_manager/vault.py +11 -1
- rasa/core/training/interactive.py +33 -34
- rasa/core/utils.py +47 -21
- rasa/dialogue_understanding/coexistence/llm_based_router.py +41 -14
- rasa/dialogue_understanding/commands/__init__.py +6 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +5 -0
- rasa/dialogue_understanding/generator/constants.py +2 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +47 -9
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +38 -15
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +35 -13
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +3 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +60 -13
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +53 -0
- rasa/dialogue_understanding/patterns/repeat.py +37 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +21 -1
- rasa/e2e_test/aggregate_test_stats_calculator.py +1 -11
- rasa/e2e_test/assertions.py +136 -61
- rasa/e2e_test/assertions_schema.yml +23 -0
- rasa/e2e_test/e2e_test_case.py +85 -6
- rasa/e2e_test/e2e_test_runner.py +2 -3
- rasa/e2e_test/utils/e2e_yaml_utils.py +1 -1
- rasa/engine/graph.py +3 -10
- rasa/engine/loader.py +12 -0
- rasa/engine/recipes/config_files/default_config.yml +0 -3
- rasa/engine/recipes/default_recipe.py +0 -1
- rasa/engine/recipes/graph_recipe.py +0 -1
- rasa/engine/runner/dask.py +2 -2
- rasa/engine/storage/local_model_storage.py +12 -42
- rasa/engine/storage/storage.py +1 -5
- rasa/engine/validation.py +527 -74
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +40 -0
- rasa/model_manager/model_api.py +559 -0
- rasa/model_manager/runner_service.py +286 -0
- rasa/model_manager/socket_bridge.py +146 -0
- rasa/model_manager/studio_jwt_auth.py +86 -0
- rasa/model_manager/trainer_service.py +325 -0
- rasa/model_manager/utils.py +87 -0
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +112 -0
- rasa/model_training.py +42 -23
- rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
- rasa/server.py +4 -2
- rasa/shared/constants.py +60 -8
- rasa/shared/core/constants.py +13 -0
- rasa/shared/core/domain.py +107 -50
- rasa/shared/core/events.py +29 -0
- rasa/shared/core/flows/flow.py +5 -0
- rasa/shared/core/flows/flows_list.py +19 -6
- rasa/shared/core/flows/flows_yaml_schema.json +10 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +121 -0
- rasa/shared/core/flows/yaml_flows_io.py +15 -27
- rasa/shared/core/slots.py +5 -0
- rasa/shared/importers/importer.py +59 -41
- rasa/shared/importers/multi_project.py +23 -11
- rasa/shared/importers/rasa.py +12 -3
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +3 -1
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +18 -3
- rasa/shared/nlu/training_data/training_data.py +18 -19
- rasa/shared/providers/_configs/litellm_router_client_config.py +220 -0
- rasa/shared/providers/_configs/model_group_config.py +167 -0
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
- rasa/shared/providers/_configs/utils.py +16 -0
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +13 -29
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +24 -0
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
- rasa/shared/providers/llm/_base_litellm_client.py +34 -22
- rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
- rasa/shared/providers/llm/default_litellm_llm_client.py +24 -0
- rasa/shared/providers/llm/litellm_router_llm_client.py +182 -0
- rasa/shared/providers/llm/rasa_llm_client.py +112 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +5 -29
- rasa/shared/providers/mappings.py +19 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +183 -0
- rasa/shared/providers/router/router_client.py +73 -0
- rasa/shared/utils/common.py +40 -24
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +258 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +27 -6
- rasa/shared/utils/llm.py +354 -44
- rasa/shared/utils/schemas/events.py +2 -0
- rasa/shared/utils/schemas/model_config.yml +0 -10
- rasa/shared/utils/yaml.py +181 -38
- rasa/studio/data_handler.py +3 -1
- rasa/studio/upload.py +160 -74
- rasa/telemetry.py +94 -17
- rasa/tracing/config.py +3 -1
- rasa/tracing/instrumentation/attribute_extractors.py +95 -18
- rasa/tracing/instrumentation/instrumentation.py +121 -0
- rasa/utils/common.py +5 -0
- rasa/utils/endpoints.py +27 -1
- rasa/utils/io.py +8 -16
- rasa/utils/log_utils.py +9 -2
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/validator.py +110 -16
- rasa/version.py +1 -1
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/METADATA +16 -14
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/RECORD +236 -185
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +0 -1
- rasa/core/channels/inspector/dist/assets/index-a5d3e69d.js +0 -1040
- rasa/core/channels/voice_aware/utils.py +0 -20
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +0 -407
- /rasa/core/channels/{voice_aware → voice_ready}/__init__.py +0 -0
- /rasa/core/channels/{voice_native → voice_stream}/__init__.py +0 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/NOTICE +0 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/WHEEL +0 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/entry_points.txt +0 -0
rasa/core/utils.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
1
|
+
import structlog
|
|
1
2
|
import logging
|
|
2
3
|
import os
|
|
3
4
|
from pathlib import Path
|
|
4
5
|
from socket import SOCK_DGRAM, SOCK_STREAM
|
|
5
|
-
from typing import Any, Dict, Optional, Set, TYPE_CHECKING, Text, Tuple, Union
|
|
6
|
+
from typing import Any, Dict, Optional, List, Set, TYPE_CHECKING, Text, Tuple, Union
|
|
6
7
|
|
|
7
8
|
import numpy as np
|
|
8
9
|
from sanic import Sanic
|
|
@@ -19,14 +20,18 @@ from rasa.core.constants import (
|
|
|
19
20
|
from rasa.core.lock_store import LockStore, RedisLockStore, InMemoryLockStore
|
|
20
21
|
from rasa.shared.constants import DEFAULT_ENDPOINTS_PATH, TCP_PROTOCOL
|
|
21
22
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
22
|
-
from rasa.utils.endpoints import
|
|
23
|
+
from rasa.utils.endpoints import (
|
|
24
|
+
EndpointConfig,
|
|
25
|
+
read_endpoint_config,
|
|
26
|
+
read_property_config_from_endpoints_file,
|
|
27
|
+
)
|
|
23
28
|
from rasa.utils.io import write_yaml
|
|
24
29
|
|
|
25
30
|
if TYPE_CHECKING:
|
|
26
31
|
from rasa.core.nlg import NaturalLanguageGenerator
|
|
27
32
|
from rasa.shared.core.domain import Domain
|
|
28
33
|
|
|
29
|
-
|
|
34
|
+
structlogger = structlog.get_logger()
|
|
30
35
|
|
|
31
36
|
|
|
32
37
|
def configure_file_logging(
|
|
@@ -124,15 +129,17 @@ def list_routes(app: Sanic) -> Dict[Text, Text]:
|
|
|
124
129
|
for arg in route._params:
|
|
125
130
|
options[arg] = f"[{arg}]"
|
|
126
131
|
|
|
127
|
-
|
|
132
|
+
name = route.name.replace("rasa_server.", "")
|
|
133
|
+
methods = ",".join(route.methods)
|
|
128
134
|
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
output[name] = line
|
|
135
|
+
full_endpoint = "/" + "/".join(endpoint)
|
|
136
|
+
line = unquote(f"{full_endpoint:50s} {methods:30s} {name}")
|
|
137
|
+
output[name] = line
|
|
133
138
|
|
|
134
139
|
url_table = "\n".join(output[url] for url in sorted(output))
|
|
135
|
-
|
|
140
|
+
structlogger.debug(
|
|
141
|
+
"server.routes", event_info=f"Available web server routes: \n{url_table}"
|
|
142
|
+
)
|
|
136
143
|
|
|
137
144
|
return output
|
|
138
145
|
|
|
@@ -186,6 +193,9 @@ class AvailableEndpoints:
|
|
|
186
193
|
lock_store = read_endpoint_config(endpoint_file, endpoint_type="lock_store")
|
|
187
194
|
event_broker = read_endpoint_config(endpoint_file, endpoint_type="event_broker")
|
|
188
195
|
vector_store = read_endpoint_config(endpoint_file, endpoint_type="vector_store")
|
|
196
|
+
model_groups = read_property_config_from_endpoints_file(
|
|
197
|
+
endpoint_file, property_name="model_groups"
|
|
198
|
+
)
|
|
189
199
|
|
|
190
200
|
return cls(
|
|
191
201
|
nlg,
|
|
@@ -196,6 +206,7 @@ class AvailableEndpoints:
|
|
|
196
206
|
lock_store,
|
|
197
207
|
event_broker,
|
|
198
208
|
vector_store,
|
|
209
|
+
model_groups,
|
|
199
210
|
)
|
|
200
211
|
|
|
201
212
|
def __init__(
|
|
@@ -208,6 +219,7 @@ class AvailableEndpoints:
|
|
|
208
219
|
lock_store: Optional[EndpointConfig] = None,
|
|
209
220
|
event_broker: Optional[EndpointConfig] = None,
|
|
210
221
|
vector_store: Optional[EndpointConfig] = None,
|
|
222
|
+
model_groups: Optional[List[Dict[str, Any]]] = None,
|
|
211
223
|
) -> None:
|
|
212
224
|
"""Create an `AvailableEndpoints` object."""
|
|
213
225
|
self.model = model
|
|
@@ -218,6 +230,7 @@ class AvailableEndpoints:
|
|
|
218
230
|
self.lock_store = lock_store
|
|
219
231
|
self.event_broker = event_broker
|
|
220
232
|
self.vector_store = vector_store
|
|
233
|
+
self.model_groups = model_groups
|
|
221
234
|
|
|
222
235
|
@classmethod
|
|
223
236
|
def get_instance(cls, endpoint_file: Optional[Text] = None) -> "AvailableEndpoints":
|
|
@@ -273,17 +286,22 @@ def number_of_sanic_workers(lock_store: Union[EndpointConfig, LockStore, None])
|
|
|
273
286
|
"""
|
|
274
287
|
|
|
275
288
|
def _log_and_get_default_number_of_workers() -> int:
|
|
276
|
-
|
|
277
|
-
|
|
289
|
+
structlogger.debug(
|
|
290
|
+
"server.worker.set_count",
|
|
291
|
+
number_of_workers=DEFAULT_SANIC_WORKERS,
|
|
292
|
+
event_info=f"Using the default number of Sanic workers "
|
|
293
|
+
f"({DEFAULT_SANIC_WORKERS}).",
|
|
278
294
|
)
|
|
279
295
|
return DEFAULT_SANIC_WORKERS
|
|
280
296
|
|
|
281
297
|
try:
|
|
282
298
|
env_value = int(os.environ.get(ENV_SANIC_WORKERS, DEFAULT_SANIC_WORKERS))
|
|
283
299
|
except ValueError:
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
300
|
+
structlogger.error(
|
|
301
|
+
"server.worker.set_count.error",
|
|
302
|
+
number_of_workers=os.environ[ENV_SANIC_WORKERS],
|
|
303
|
+
event_info=f"Cannot convert environment variable `{ENV_SANIC_WORKERS}` "
|
|
304
|
+
f"to int ('{os.environ[ENV_SANIC_WORKERS]}').",
|
|
287
305
|
)
|
|
288
306
|
return _log_and_get_default_number_of_workers()
|
|
289
307
|
|
|
@@ -291,20 +309,28 @@ def number_of_sanic_workers(lock_store: Union[EndpointConfig, LockStore, None])
|
|
|
291
309
|
return _log_and_get_default_number_of_workers()
|
|
292
310
|
|
|
293
311
|
if env_value < 1:
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
312
|
+
structlogger.warning(
|
|
313
|
+
"server.worker.set_count.error_less_than_one",
|
|
314
|
+
number_of_workers=env_value,
|
|
315
|
+
event_info=f"Cannot set number of Sanic workers to the desired value "
|
|
316
|
+
f"({env_value}). The number of workers must be at least 1.",
|
|
297
317
|
)
|
|
298
318
|
return _log_and_get_default_number_of_workers()
|
|
299
319
|
|
|
300
320
|
if _lock_store_is_multi_worker_compatible(lock_store):
|
|
301
|
-
|
|
321
|
+
structlogger.debug(
|
|
322
|
+
"server.worker.set_count.success",
|
|
323
|
+
event_info=f"Using {env_value} Sanic workers.",
|
|
324
|
+
num_workers=env_value,
|
|
325
|
+
)
|
|
302
326
|
return env_value
|
|
303
327
|
|
|
304
|
-
|
|
305
|
-
|
|
328
|
+
structlogger.warning(
|
|
329
|
+
"server.worker.set_count.error_no_lock_store",
|
|
330
|
+
event_info=f"Unable to assign desired number of Sanic workers ({env_value}) as "
|
|
306
331
|
f"no `RedisLockStore` or custom `LockStore` endpoint "
|
|
307
|
-
f"configuration has been found."
|
|
332
|
+
f"configuration has been found.",
|
|
333
|
+
num_workers=env_value,
|
|
308
334
|
)
|
|
309
335
|
return _log_and_get_default_number_of_workers()
|
|
310
336
|
|
|
@@ -15,7 +15,9 @@ from rasa.dialogue_understanding.coexistence.constants import (
|
|
|
15
15
|
)
|
|
16
16
|
from rasa.dialogue_understanding.commands import Command, SetSlotCommand
|
|
17
17
|
from rasa.dialogue_understanding.commands.noop_command import NoopCommand
|
|
18
|
-
from rasa.dialogue_understanding.generator.constants import
|
|
18
|
+
from rasa.dialogue_understanding.generator.constants import (
|
|
19
|
+
LLM_CONFIG_KEY,
|
|
20
|
+
)
|
|
19
21
|
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
20
22
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
21
23
|
from rasa.engine.storage.resource import Resource
|
|
@@ -33,11 +35,13 @@ from rasa.shared.exceptions import InvalidConfigException, FileIOException
|
|
|
33
35
|
from rasa.shared.nlu.constants import COMMANDS, TEXT
|
|
34
36
|
from rasa.shared.nlu.training_data.message import Message
|
|
35
37
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
38
|
+
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
39
|
+
from rasa.shared.utils.io import deep_container_fingerprint
|
|
36
40
|
from rasa.shared.utils.llm import (
|
|
37
41
|
DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
38
42
|
get_prompt_template,
|
|
39
43
|
llm_factory,
|
|
40
|
-
|
|
44
|
+
resolve_model_client_config,
|
|
41
45
|
)
|
|
42
46
|
from rasa.utils.log_utils import log_llm
|
|
43
47
|
|
|
@@ -45,6 +49,7 @@ LLM_BASED_ROUTER_PROMPT_FILE_NAME = "llm_based_router_prompt.jinja2"
|
|
|
45
49
|
DEFAULT_COMMAND_PROMPT_TEMPLATE = importlib.resources.read_text(
|
|
46
50
|
"rasa.dialogue_understanding.coexistence", "router_template.jinja2"
|
|
47
51
|
)
|
|
52
|
+
LLM_BASED_ROUTER_CONFIG_FILE_NAME = "config.json"
|
|
48
53
|
|
|
49
54
|
# Token ids for gpt 3.5 and gpt 4 corresponding to space + capitalized Letter
|
|
50
55
|
A_TO_C_TOKEN_IDS_CHATGPT = [
|
|
@@ -71,7 +76,7 @@ structlogger = structlog.get_logger()
|
|
|
71
76
|
],
|
|
72
77
|
is_trainable=True,
|
|
73
78
|
)
|
|
74
|
-
class LLMBasedRouter(GraphComponent):
|
|
79
|
+
class LLMBasedRouter(LLMHealthCheckMixin, GraphComponent):
|
|
75
80
|
@staticmethod
|
|
76
81
|
def get_default_config() -> Dict[str, Any]:
|
|
77
82
|
"""The component's default config (see parent class for full docstring)."""
|
|
@@ -93,6 +98,9 @@ class LLMBasedRouter(GraphComponent):
|
|
|
93
98
|
prompt_template: Optional[str] = None,
|
|
94
99
|
) -> None:
|
|
95
100
|
self.config = {**self.get_default_config(), **config}
|
|
101
|
+
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
102
|
+
self.config.get(LLM_CONFIG_KEY), LLMBasedRouter.__name__
|
|
103
|
+
)
|
|
96
104
|
|
|
97
105
|
self.prompt_template = (
|
|
98
106
|
prompt_template
|
|
@@ -126,15 +134,17 @@ class LLMBasedRouter(GraphComponent):
|
|
|
126
134
|
rasa.shared.utils.io.write_text_file(
|
|
127
135
|
self.prompt_template, path / LLM_BASED_ROUTER_PROMPT_FILE_NAME
|
|
128
136
|
)
|
|
137
|
+
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
138
|
+
path / LLM_BASED_ROUTER_CONFIG_FILE_NAME, self.config
|
|
139
|
+
)
|
|
129
140
|
|
|
130
141
|
def train(self, training_data: TrainingData) -> Resource:
|
|
131
142
|
"""Train the intent classifier on a data set."""
|
|
132
|
-
|
|
133
|
-
try_instantiate_llm_client(
|
|
143
|
+
self.perform_llm_health_check(
|
|
134
144
|
self.config.get(LLM_CONFIG_KEY),
|
|
135
145
|
DEFAULT_LLM_CONFIG,
|
|
136
146
|
"llm_based_router.train",
|
|
137
|
-
|
|
147
|
+
LLMBasedRouter.__name__,
|
|
138
148
|
)
|
|
139
149
|
|
|
140
150
|
self.persist()
|
|
@@ -150,6 +160,16 @@ class LLMBasedRouter(GraphComponent):
|
|
|
150
160
|
**kwargs: Any,
|
|
151
161
|
) -> "LLMBasedRouter":
|
|
152
162
|
"""Loads trained component (see parent class for full docstring)."""
|
|
163
|
+
|
|
164
|
+
# Perform health check on the resolved LLM client config
|
|
165
|
+
llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
|
|
166
|
+
cls.perform_llm_health_check(
|
|
167
|
+
llm_config,
|
|
168
|
+
DEFAULT_LLM_CONFIG,
|
|
169
|
+
"llm_based_router.load",
|
|
170
|
+
LLMBasedRouter.__name__,
|
|
171
|
+
)
|
|
172
|
+
|
|
153
173
|
prompt_template = None
|
|
154
174
|
try:
|
|
155
175
|
with model_storage.read_from(resource) as path:
|
|
@@ -161,14 +181,7 @@ class LLMBasedRouter(GraphComponent):
|
|
|
161
181
|
"llm_based_router.load.failed", error=e, resource=resource.name
|
|
162
182
|
)
|
|
163
183
|
|
|
164
|
-
|
|
165
|
-
try_instantiate_llm_client(
|
|
166
|
-
router.config.get(LLM_CONFIG_KEY),
|
|
167
|
-
DEFAULT_LLM_CONFIG,
|
|
168
|
-
"llm_based_router.load",
|
|
169
|
-
LLMBasedRouter.__name__,
|
|
170
|
-
)
|
|
171
|
-
return router
|
|
184
|
+
return cls(config, model_storage, resource, prompt_template=prompt_template)
|
|
172
185
|
|
|
173
186
|
@classmethod
|
|
174
187
|
def create(
|
|
@@ -298,3 +311,17 @@ class LLMBasedRouter(GraphComponent):
|
|
|
298
311
|
# we have to catch all exceptions here
|
|
299
312
|
structlogger.error("llm_based_router.llm.error", error=e)
|
|
300
313
|
return None
|
|
314
|
+
|
|
315
|
+
@classmethod
def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
    """Compute an extra fingerprint for the LLM based router.

    The fingerprint is derived from two inputs: the prompt template returned
    by `get_prompt_template` and the LLM client config returned by
    `resolve_model_client_config`.

    Args:
        config: The component configuration.

    Returns:
        The fingerprint of the resolved prompt template and LLM client config.
    """
    # Prompt: the configured prompt entry, with DEFAULT_COMMAND_PROMPT_TEMPLATE
    # passed as the fallback (fallback semantics live in get_prompt_template).
    template = get_prompt_template(
        config.get(PROMPT_CONFIG_KEY),
        DEFAULT_COMMAND_PROMPT_TEMPLATE,
    )

    # Fingerprint the *resolved* client config (same resolution call as the
    # component's constructor) instead of the raw `llm` entry.
    resolved_llm_config = resolve_model_client_config(
        config.get(LLM_CONFIG_KEY), LLMBasedRouter.__name__
    )

    fingerprint_inputs = [template, resolved_llm_config]
    return deep_container_fingerprint(fingerprint_inputs)
|
|
@@ -32,6 +32,10 @@ from rasa.dialogue_understanding.commands.change_flow_command import ChangeFlowC
|
|
|
32
32
|
from rasa.dialogue_understanding.commands.session_start_command import (
|
|
33
33
|
SessionStartCommand,
|
|
34
34
|
)
|
|
35
|
+
from rasa.dialogue_understanding.commands.session_end_command import SessionEndCommand
|
|
36
|
+
from rasa.dialogue_understanding.commands.repeat_bot_messages_command import (
|
|
37
|
+
RepeatBotMessagesCommand,
|
|
38
|
+
)
|
|
35
39
|
|
|
36
40
|
__all__ = [
|
|
37
41
|
"Command",
|
|
@@ -51,5 +55,7 @@ __all__ = [
|
|
|
51
55
|
"NoopCommand",
|
|
52
56
|
"ChangeFlowCommand",
|
|
53
57
|
"SessionStartCommand",
|
|
58
|
+
"SessionEndCommand",
|
|
59
|
+
"RepeatBotMessagesCommand",
|
|
54
60
|
"RestartCommand",
|
|
55
61
|
]
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Dict, List
|
|
5
|
+
from rasa.dialogue_understanding.commands import Command
|
|
6
|
+
from rasa.dialogue_understanding.patterns.repeat import (
|
|
7
|
+
RepeatBotMessagesPatternFlowStackFrame,
|
|
8
|
+
)
|
|
9
|
+
from rasa.shared.core.events import Event
|
|
10
|
+
from rasa.shared.core.flows import FlowsList
|
|
11
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class RepeatBotMessagesCommand(Command):
    """A command to indicate that the bot should repeat its last messages.

    Running this command does not re-send anything by itself: it pushes a
    `RepeatBotMessagesPatternFlowStackFrame` onto the dialogue stack.
    NOTE(review): the actual repetition is presumably performed by the flow
    that pattern frame starts — confirm against the pattern definition.
    """

    @classmethod
    def command(cls) -> str:
        """Returns the command type."""
        return "repeat"

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> RepeatBotMessagesCommand:
        """Converts the dictionary to a command.

        The command has no fields, so the content of `data` is ignored.

        Args:
            data: The dictionary representation of the command (unused).

        Returns:
            The converted command.
        """
        # Instantiate through `cls` so subclasses deserialize to their own type.
        return cls()

    def run_command_on_tracker(
        self,
        tracker: DialogueStateTracker,
        all_flows: FlowsList,
        original_tracker: DialogueStateTracker,
    ) -> List[Event]:
        """Runs the command on the tracker.

        Pushes a repeat-bot-messages pattern frame onto the tracker's dialogue
        stack and returns the events that record the updated stack.

        Args:
            tracker: The tracker to run the command on.
            all_flows: All flows in the assistant (unused).
            original_tracker: The tracker before any command was executed
                (unused).

        Returns:
            The events to apply to the tracker.
        """
        stack = tracker.stack
        stack.push(RepeatBotMessagesPatternFlowStackFrame())
        return tracker.create_stack_updated_events(stack)

    def __hash__(self) -> int:
        # The command has no fields, so every instance hashes the same.
        return hash(self.command())

    def __eq__(self, other: object) -> bool:
        # Field-less command: any two instances of this type are equal.
        return isinstance(other, RepeatBotMessagesCommand)
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Dict, List
|
|
5
|
+
from rasa.dialogue_understanding.commands import Command
|
|
6
|
+
from rasa.shared.core.events import Event, SessionEnded
|
|
7
|
+
from rasa.shared.core.flows import FlowsList
|
|
8
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class SessionEndCommand(Command):
    """A command to indicate the end of a session."""

    @classmethod
    def command(cls) -> str:
        """Returns the command type."""
        return "session end"

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> SessionEndCommand:
        """Converts the dictionary to a command.

        The command has no fields, so the content of `data` is ignored.

        Args:
            data: The dictionary representation of the command (unused).

        Returns:
            The converted command.
        """
        # Instantiate through `cls` so subclasses deserialize to their own type.
        return cls()

    def run_command_on_tracker(
        self,
        tracker: DialogueStateTracker,
        all_flows: FlowsList,
        original_tracker: DialogueStateTracker,
    ) -> List[Event]:
        """Runs the command on the tracker.

        Emits a single `SessionEnded` event. Its metadata starts with a default
        `_reason` of "user disconnected". The metadata of the tracker's latest
        message is merged on top when present, so a `_reason` supplied there
        overrides the default.

        Args:
            tracker: The tracker to run the command on.
            all_flows: All flows in the assistant (unused).
            original_tracker: The tracker before any command was executed
                (unused).

        Returns:
            The events to apply to the tracker.
        """
        metadata = {"_reason": "user disconnected"}

        # Merge metadata sent by the channel connector, if any; merged last so
        # channel-provided keys take precedence over the default above.
        latest_message = tracker.latest_message
        if latest_message:
            metadata.update(latest_message.metadata or {})

        return [SessionEnded(metadata=metadata)]

    def __hash__(self) -> int:
        # The command has no fields, so every instance hashes the same.
        return hash(self.command())

    def __eq__(self, other: object) -> bool:
        # Field-less command: any two instances of this type are equal.
        return isinstance(other, SessionEndCommand)
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Dict, List
|
|
5
|
+
from rasa.dialogue_understanding.commands import Command
|
|
6
|
+
from rasa.dialogue_understanding.patterns.user_silence import (
|
|
7
|
+
UserSilencePatternFlowStackFrame,
|
|
8
|
+
)
|
|
9
|
+
from rasa.shared.core.events import Event
|
|
10
|
+
from rasa.shared.core.flows import FlowsList
|
|
11
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class UserSilenceCommand(Command):
    """A command to indicate user silence.

    The command carries no fields: all instances compare equal and share the
    same hash (derived from the command type string).
    """

    @classmethod
    def command(cls) -> str:
        """Returns the command type."""
        return "user silence"

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> UserSilenceCommand:
        """Converts the dictionary to a command.

        Args:
            data: The serialized command. It is not read, because the command
                has no fields.

        Returns:
            A new instance of this command class.
        """
        # `cls()` instead of the hard-coded class so subclasses round-trip
        # as themselves.
        return cls()

    def run_command_on_tracker(
        self,
        tracker: DialogueStateTracker,
        all_flows: FlowsList,
        original_tracker: DialogueStateTracker,
    ) -> List[Event]:
        """Runs the command on the tracker.

        Pushes a `UserSilencePatternFlowStackFrame` onto the tracker's dialogue
        stack.

        Args:
            tracker: The tracker to run the command on.
            all_flows: All flows in the assistant (not used by this command).
            original_tracker: The tracker before any command was executed
                (not used by this command).

        Returns:
            The events produced by `tracker.create_stack_updated_events` for
            the stack with the user-silence frame on top.
        """
        stack = tracker.stack
        stack.push(UserSilencePatternFlowStackFrame())
        return tracker.create_stack_updated_events(stack)

    def __hash__(self) -> int:
        return hash(self.command())

    def __eq__(self, other: object) -> bool:
        # No fields to compare: any two user-silence commands are equal.
        return isinstance(other, UserSilenceCommand)
|
|
@@ -11,6 +11,7 @@ from rasa.dialogue_understanding.commands import (
|
|
|
11
11
|
SkipQuestionCommand,
|
|
12
12
|
RestartCommand,
|
|
13
13
|
)
|
|
14
|
+
from rasa.dialogue_understanding.commands.user_silence_command import UserSilenceCommand
|
|
14
15
|
from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
|
|
15
16
|
from rasa.dialogue_understanding.patterns.cannot_handle import (
|
|
16
17
|
CannotHandlePatternFlowStackFrame,
|
|
@@ -27,9 +28,13 @@ from rasa.dialogue_understanding.patterns.session_start import (
|
|
|
27
28
|
from rasa.dialogue_understanding.patterns.skip_question import (
|
|
28
29
|
SkipQuestionPatternFlowStackFrame,
|
|
29
30
|
)
|
|
31
|
+
from rasa.dialogue_understanding.patterns.user_silence import (
|
|
32
|
+
UserSilencePatternFlowStackFrame,
|
|
33
|
+
)
|
|
30
34
|
|
|
31
35
|
triggerable_pattern_to_command_class: Dict[str, Type[Command]] = {
|
|
32
36
|
SessionStartPatternFlowStackFrame.flow_id: SessionStartCommand,
|
|
37
|
+
UserSilencePatternFlowStackFrame.flow_id: UserSilenceCommand,
|
|
33
38
|
CancelPatternFlowStackFrame.flow_id: CancelFlowCommand,
|
|
34
39
|
ChitchatPatternFlowStackFrame.flow_id: ChitChatAnswerCommand,
|
|
35
40
|
HumanHandoffPatternFlowStackFrame.flow_id: HumanHandoffCommand,
|
|
@@ -1,5 +1,4 @@
|
|
|
1
|
-
"""
|
|
2
|
-
The module is primarily centered around the `FlowRetrieval` class which handles the
|
|
1
|
+
"""The module is primarily centered around the `FlowRetrieval` class which handles the
|
|
3
2
|
initialization, configuration validation, vector store management, and flow retrieval
|
|
4
3
|
logic. It integrates components for managing embeddings, vector stores, and
|
|
5
4
|
flow-specific templates, facilitating semantic search functionalities.
|
|
@@ -27,8 +26,10 @@ from langchain.docstore.document import Document
|
|
|
27
26
|
from langchain.schema.embeddings import Embeddings
|
|
28
27
|
from langchain_community.vectorstores.faiss import FAISS
|
|
29
28
|
from langchain_community.vectorstores.utils import DistanceStrategy
|
|
29
|
+
|
|
30
30
|
from rasa.engine.storage.resource import Resource
|
|
31
31
|
from rasa.engine.storage.storage import ModelStorage
|
|
32
|
+
import rasa.shared.utils.io
|
|
32
33
|
from rasa.shared.constants import (
|
|
33
34
|
EMBEDDINGS_CONFIG_KEY,
|
|
34
35
|
PROVIDER_CONFIG_KEY,
|
|
@@ -37,12 +38,15 @@ from rasa.shared.constants import (
|
|
|
37
38
|
from rasa.shared.core.domain import Domain
|
|
38
39
|
from rasa.shared.core.flows import FlowsList
|
|
39
40
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
41
|
+
from rasa.shared.exceptions import ProviderClientAPIException
|
|
40
42
|
from rasa.shared.nlu.constants import TEXT, FLOWS_FROM_SEMANTIC_SEARCH
|
|
41
43
|
from rasa.shared.nlu.training_data.message import Message
|
|
42
|
-
from rasa.shared.exceptions import ProviderClientAPIException
|
|
43
44
|
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
44
45
|
_LangchainEmbeddingClientAdapter,
|
|
45
46
|
)
|
|
47
|
+
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
48
|
+
EmbeddingsHealthCheckMixin,
|
|
49
|
+
)
|
|
46
50
|
from rasa.shared.utils.llm import (
|
|
47
51
|
tracker_as_readable_transcript,
|
|
48
52
|
embedder_factory,
|
|
@@ -50,13 +54,15 @@ from rasa.shared.utils.llm import (
|
|
|
50
54
|
USER,
|
|
51
55
|
get_prompt_template,
|
|
52
56
|
allowed_values_for_slot,
|
|
53
|
-
|
|
57
|
+
resolve_model_client_config,
|
|
54
58
|
)
|
|
55
59
|
|
|
56
60
|
DEFAULT_FLOW_DOCUMENT_TEMPLATE = importlib.resources.read_text(
|
|
57
61
|
"rasa.dialogue_understanding.generator", "flow_document_template.jinja2"
|
|
58
62
|
)
|
|
59
63
|
|
|
64
|
+
FLOW_RETRIEVAL_CONFIG_FILE_NAME = "flow_retrieval_config.json"
|
|
65
|
+
|
|
60
66
|
DEFAULT_EMBEDDINGS_CONFIG = {
|
|
61
67
|
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
62
68
|
"model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
@@ -74,7 +80,7 @@ DEFAULT_SHOULD_EMBED_SLOTS = True
|
|
|
74
80
|
structlogger = structlog.get_logger()
|
|
75
81
|
|
|
76
82
|
|
|
77
|
-
class FlowRetrieval:
|
|
83
|
+
class FlowRetrieval(EmbeddingsHealthCheckMixin):
|
|
78
84
|
@classmethod
|
|
79
85
|
def get_default_config(cls) -> Dict[str, Any]:
|
|
80
86
|
"""The default config for the flow retrieval."""
|
|
@@ -93,6 +99,9 @@ class FlowRetrieval:
|
|
|
93
99
|
):
|
|
94
100
|
config = {**self.get_default_config(), **config}
|
|
95
101
|
self.config = self.validate_config(config)
|
|
102
|
+
self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
103
|
+
self.config.get(EMBEDDINGS_CONFIG_KEY), FlowRetrieval.__name__
|
|
104
|
+
)
|
|
96
105
|
self.vector_store: Optional[FAISS] = None
|
|
97
106
|
self.flow_document_template = get_prompt_template(
|
|
98
107
|
None, DEFAULT_FLOW_DOCUMENT_TEMPLATE
|
|
@@ -141,19 +150,26 @@ class FlowRetrieval:
|
|
|
141
150
|
**kwargs: Any,
|
|
142
151
|
) -> "FlowRetrieval":
|
|
143
152
|
"""Load flow retrieval with previously populated FAISS vector store."""
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
153
|
+
|
|
154
|
+
# Perform health check on resolved embedding client config
|
|
155
|
+
embeddings_config = resolve_model_client_config(
|
|
156
|
+
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
157
|
+
)
|
|
158
|
+
cls.perform_embeddings_health_check(
|
|
159
|
+
embeddings_config,
|
|
148
160
|
DEFAULT_EMBEDDINGS_CONFIG,
|
|
149
161
|
"flow_retrieval.load",
|
|
150
162
|
FlowRetrieval.__name__,
|
|
151
163
|
)
|
|
164
|
+
|
|
165
|
+
# initialize base flow retrieval
|
|
166
|
+
flow_retrieval = FlowRetrieval(config, model_storage, resource)
|
|
152
167
|
# load vector store
|
|
153
168
|
vector_store = cls._load_vector_store(
|
|
154
169
|
flow_retrieval.config, model_storage, resource
|
|
155
170
|
)
|
|
156
171
|
flow_retrieval.vector_store = vector_store
|
|
172
|
+
|
|
157
173
|
return flow_retrieval
|
|
158
174
|
|
|
159
175
|
@classmethod
|
|
@@ -185,13 +201,21 @@ class FlowRetrieval:
|
|
|
185
201
|
Returns:
|
|
186
202
|
The embedder.
|
|
187
203
|
"""
|
|
204
|
+
# Copy the config so original config is not modified
|
|
205
|
+
config = config.copy()
|
|
206
|
+
# Resolve config and instantiate the embedding client
|
|
207
|
+
config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
208
|
+
config.get(EMBEDDINGS_CONFIG_KEY), FlowRetrieval.__name__
|
|
209
|
+
)
|
|
188
210
|
client = embedder_factory(
|
|
189
211
|
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
190
212
|
)
|
|
213
|
+
# Wrap the embedding client in the adapter
|
|
191
214
|
return _LangchainEmbeddingClientAdapter(client)
|
|
192
215
|
|
|
193
216
|
def persist(self) -> None:
    """Persist the component: the FAISS vector store first, then the config."""
    for persist_step in (self._persist_vector_store, self._persist_config):
        persist_step()
|
|
195
219
|
|
|
196
220
|
def _persist_vector_store(self) -> None:
|
|
197
221
|
"""Persists the FAISS vector store."""
|
|
@@ -204,6 +228,12 @@ class FlowRetrieval:
|
|
|
204
228
|
event_info="Vector store is None, not persisted.",
|
|
205
229
|
)
|
|
206
230
|
|
|
231
|
+
def _persist_config(self) -> None:
    """Write `self.config` as JSON into this component's model-storage resource.

    The file is named by `FLOW_RETRIEVAL_CONFIG_FILE_NAME` and placed in the
    directory yielded by `self._model_storage.write_to(self._resource)`.
    """
    with self._model_storage.write_to(self._resource) as resource_dir:
        config_file = resource_dir / FLOW_RETRIEVAL_CONFIG_FILE_NAME
        rasa.shared.utils.io.dump_obj_as_json_to_file(config_file, self.config)
|
|
236
|
+
|
|
207
237
|
def populate(self, flows: FlowsList, domain: Domain) -> None:
|
|
208
238
|
"""Populates the vector store with embeddings generated from
|
|
209
239
|
documents based on the flow descriptions, and flow slots
|
|
@@ -213,6 +243,14 @@ class FlowRetrieval:
|
|
|
213
243
|
flows: List of flows to populate the vector store with.
|
|
214
244
|
domain: The domain containing relevant slot information.
|
|
215
245
|
"""
|
|
246
|
+
# Perform health check before populating the vector store with flows
|
|
247
|
+
self.perform_embeddings_health_check(
|
|
248
|
+
self.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
249
|
+
DEFAULT_EMBEDDINGS_CONFIG,
|
|
250
|
+
"flow_retrieval.train",
|
|
251
|
+
FlowRetrieval.__name__,
|
|
252
|
+
)
|
|
253
|
+
|
|
216
254
|
flows_to_embedd = flows.exclude_link_only_flows()
|
|
217
255
|
embeddings = self._create_embedder(self.config)
|
|
218
256
|
documents = self._generate_flow_documents(flows_to_embedd, domain)
|