rasa-pro 3.11.0__py3-none-any.whl → 3.11.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +396 -17
- rasa/__main__.py +15 -31
- rasa/api.py +1 -5
- rasa/cli/arguments/default_arguments.py +2 -1
- rasa/cli/arguments/shell.py +1 -5
- rasa/cli/arguments/train.py +0 -14
- rasa/cli/e2e_test.py +1 -1
- rasa/cli/evaluate.py +8 -8
- rasa/cli/inspect.py +7 -15
- rasa/cli/interactive.py +0 -1
- rasa/cli/llm_fine_tuning.py +1 -1
- rasa/cli/project_templates/calm/config.yml +7 -5
- rasa/cli/project_templates/calm/endpoints.yml +2 -15
- rasa/cli/project_templates/tutorial/config.yml +5 -8
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +0 -5
- rasa/cli/project_templates/tutorial/domain.yml +0 -14
- rasa/cli/project_templates/tutorial/endpoints.yml +0 -5
- rasa/cli/run.py +1 -1
- rasa/cli/scaffold.py +2 -4
- rasa/cli/studio/studio.py +8 -18
- rasa/cli/studio/upload.py +15 -0
- rasa/cli/train.py +0 -3
- rasa/cli/utils.py +1 -6
- rasa/cli/x.py +8 -8
- rasa/constants.py +1 -3
- rasa/core/actions/action.py +33 -75
- rasa/core/actions/e2e_stub_custom_action_executor.py +1 -5
- rasa/core/actions/http_custom_action_executor.py +0 -4
- rasa/core/channels/__init__.py +0 -2
- rasa/core/channels/channel.py +0 -20
- rasa/core/channels/development_inspector.py +3 -10
- rasa/core/channels/inspector/dist/assets/{arc-bc141fb2.js → arc-86942a71.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-be2db283.js → c4Diagram-d0fbc5ce-b0290676.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-55366915.js → classDiagram-936ed81e-f6405f6e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-bb529518.js → classDiagram-v2-c3cb15f1-ef61ac77.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-b0ec81d6.js → createText-62fc7601-f0411e58.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-6166330c.js → edges-f2ad444c-7dcc4f3b.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-5ccc6a8e.js → erDiagram-9d236eb7-e0c092d7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-fca3bfe4.js → flowDb-1972c806-fba2e3ce.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-4739080f.js → flowDiagram-7ea5b25a-7a70b71a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-24a5f41a.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-7c1b0e0f.js → flowchart-elk-definition-abe16c3d-00a59b68.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-772fd050.js → ganttDiagram-9b5ea136-293c91fa.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-8eae1dc9.js → gitGraphDiagram-99d0ae7c-07b2d68c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-f55afcdf.js → index-2c4b9a3b-bc959fbd.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-e7cef9de.js → index-3a8a5a28.js} +143 -143
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-124d4a14.js → infoDiagram-736b4530-4a350f72.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-7c4fae44.js → journeyDiagram-df861f2b-af464fb7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-b9885fb6.js → layout-0071f036.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-7c59abb6.js → line-2f73cc83.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-4776f780.js → linear-f014b4cc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-2332c46c.js → mindmap-definition-beec6740-d2426fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-8fb39303.js → pieDiagram-dbbf0591-776f01a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3c7180a2.js → quadrantDiagram-4d7f4fd6-82e00b57.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-e910bcb8.js → requirementDiagram-6fc4c22a-ea13c6bb.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-ead16c89.js → sankeyDiagram-8f13d901-1feca7e9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-29a02a19.js → sequenceDiagram-b655622a-070c61d2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-042b3137.js → stateDiagram-59f0c015-24f46263.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-2178c0f3.js → stateDiagram-v2-2b26beab-c9056051.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-23ffa4fc.js → styles-080da4f6-08abc34a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-94f59763.js → styles-3dcbcfbf-bc74c25a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-78a6bebc.js → styles-9c745c82-4e5d66de.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-eae2a6f6.js → svgDrawCommon-4835440b-849c4517.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-5c968d92.js → timeline-definition-5b62e21b-d0fb1598.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-fd3db0d5.js → xychartDiagram-2b33534f-04d115e2.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/App.tsx +1 -1
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +3 -6
- rasa/core/channels/socketio.py +2 -7
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/voice_ready/audiocodes.py +4 -15
- rasa/core/channels/voice_ready/jambonz.py +4 -15
- rasa/core/channels/voice_ready/twilio_voice.py +21 -6
- rasa/core/channels/voice_ready/utils.py +5 -6
- rasa/core/channels/voice_stream/asr/asr_engine.py +1 -19
- rasa/core/channels/voice_stream/asr/asr_event.py +0 -5
- rasa/core/channels/voice_stream/asr/deepgram.py +15 -28
- rasa/core/channels/voice_stream/audio_bytes.py +0 -1
- rasa/core/channels/voice_stream/tts/azure.py +3 -9
- rasa/core/channels/voice_stream/tts/cartesia.py +8 -12
- rasa/core/channels/voice_stream/tts/tts_engine.py +1 -11
- rasa/core/channels/voice_stream/twilio_media_streams.py +19 -28
- rasa/core/channels/voice_stream/util.py +4 -4
- rasa/core/channels/voice_stream/voice_channel.py +42 -222
- rasa/core/featurizers/single_state_featurizer.py +1 -22
- rasa/core/featurizers/tracker_featurizers.py +18 -115
- rasa/core/information_retrieval/qdrant.py +0 -1
- rasa/core/nlg/contextual_response_rephraser.py +25 -44
- rasa/core/persistor.py +34 -191
- rasa/core/policies/enterprise_search_policy.py +60 -119
- rasa/core/policies/flows/flow_executor.py +4 -7
- rasa/core/policies/intentless_policy.py +22 -82
- rasa/core/policies/ted_policy.py +33 -58
- rasa/core/policies/unexpected_intent_policy.py +7 -15
- rasa/core/processor.py +13 -89
- rasa/core/run.py +2 -2
- rasa/core/training/interactive.py +35 -34
- rasa/core/utils.py +22 -58
- rasa/dialogue_understanding/coexistence/llm_based_router.py +12 -39
- rasa/dialogue_understanding/commands/__init__.py +0 -4
- rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
- rasa/dialogue_understanding/commands/utils.py +0 -5
- rasa/dialogue_understanding/generator/constants.py +0 -2
- rasa/dialogue_understanding/generator/flow_retrieval.py +4 -49
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +23 -37
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +10 -57
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +1 -19
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +0 -3
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +10 -90
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +0 -53
- rasa/dialogue_understanding/processor/command_processor.py +1 -21
- rasa/e2e_test/assertions.py +16 -133
- rasa/e2e_test/assertions_schema.yml +0 -23
- rasa/e2e_test/e2e_test_case.py +6 -85
- rasa/e2e_test/e2e_test_runner.py +4 -6
- rasa/e2e_test/utils/io.py +1 -3
- rasa/engine/loader.py +0 -12
- rasa/engine/validation.py +11 -541
- rasa/keys +1 -0
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
- rasa/model_training.py +7 -29
- rasa/nlu/classifiers/diet_classifier.py +25 -38
- rasa/nlu/classifiers/logistic_regression_classifier.py +9 -22
- rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
- rasa/nlu/extractors/crf_entity_extractor.py +50 -93
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +16 -45
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
- rasa/nlu/tokenizers/whitespace_tokenizer.py +14 -3
- rasa/server.py +1 -3
- rasa/shared/constants.py +0 -61
- rasa/shared/core/constants.py +0 -9
- rasa/shared/core/domain.py +5 -8
- rasa/shared/core/flows/flow.py +0 -5
- rasa/shared/core/flows/flows_list.py +1 -5
- rasa/shared/core/flows/flows_yaml_schema.json +0 -10
- rasa/shared/core/flows/validation.py +0 -96
- rasa/shared/core/flows/yaml_flows_io.py +4 -13
- rasa/shared/core/slots.py +0 -5
- rasa/shared/importers/importer.py +2 -19
- rasa/shared/importers/rasa.py +1 -5
- rasa/shared/nlu/training_data/features.py +2 -120
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +3 -18
- rasa/shared/providers/_configs/azure_openai_client_config.py +3 -5
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +0 -1
- rasa/shared/providers/_configs/utils.py +0 -16
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +29 -18
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +21 -54
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +0 -24
- rasa/shared/providers/llm/_base_litellm_client.py +31 -63
- rasa/shared/providers/llm/azure_openai_llm_client.py +29 -50
- rasa/shared/providers/llm/default_litellm_llm_client.py +0 -24
- rasa/shared/providers/llm/self_hosted_llm_client.py +29 -17
- rasa/shared/providers/mappings.py +0 -19
- rasa/shared/utils/common.py +2 -37
- rasa/shared/utils/io.py +6 -28
- rasa/shared/utils/llm.py +46 -353
- rasa/shared/utils/yaml.py +82 -181
- rasa/studio/auth.py +5 -3
- rasa/studio/config.py +4 -13
- rasa/studio/constants.py +0 -1
- rasa/studio/data_handler.py +4 -13
- rasa/studio/upload.py +80 -175
- rasa/telemetry.py +17 -94
- rasa/tracing/config.py +1 -3
- rasa/tracing/instrumentation/attribute_extractors.py +17 -94
- rasa/tracing/instrumentation/instrumentation.py +0 -121
- rasa/utils/common.py +0 -5
- rasa/utils/endpoints.py +1 -27
- rasa/utils/io.py +81 -7
- rasa/utils/log_utils.py +2 -9
- rasa/utils/tensorflow/model_data.py +193 -2
- rasa/validator.py +4 -110
- rasa/version.py +1 -1
- rasa_pro-3.11.0a1.dist-info/METADATA +576 -0
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a1.dist-info}/RECORD +182 -216
- rasa/core/actions/action_repeat_bot_messages.py +0 -89
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +0 -1
- rasa/core/channels/inspector/src/helpers/audiostream.ts +0 -165
- rasa/core/channels/voice_stream/asr/azure.py +0 -129
- rasa/core/channels/voice_stream/browser_audio.py +0 -107
- rasa/core/channels/voice_stream/call_state.py +0 -23
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +0 -60
- rasa/dialogue_understanding/commands/user_silence_command.py +0 -59
- rasa/dialogue_understanding/patterns/repeat.py +0 -37
- rasa/dialogue_understanding/patterns/user_silence.py +0 -37
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +0 -40
- rasa/model_manager/model_api.py +0 -559
- rasa/model_manager/runner_service.py +0 -286
- rasa/model_manager/socket_bridge.py +0 -146
- rasa/model_manager/studio_jwt_auth.py +0 -86
- rasa/model_manager/trainer_service.py +0 -325
- rasa/model_manager/utils.py +0 -87
- rasa/model_manager/warm_rasa_process.py +0 -187
- rasa/model_service.py +0 -112
- rasa/shared/core/flows/utils.py +0 -39
- rasa/shared/providers/_configs/litellm_router_client_config.py +0 -220
- rasa/shared/providers/_configs/model_group_config.py +0 -167
- rasa/shared/providers/_configs/rasa_llm_client_config.py +0 -73
- rasa/shared/providers/_utils.py +0 -79
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +0 -135
- rasa/shared/providers/llm/litellm_router_llm_client.py +0 -182
- rasa/shared/providers/llm/rasa_llm_client.py +0 -112
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +0 -183
- rasa/shared/providers/router/router_client.py +0 -73
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +0 -31
- rasa/shared/utils/health_check/health_check.py +0 -258
- rasa/shared/utils/health_check/llm_health_check_mixin.py +0 -31
- rasa/utils/sanic_error_handler.py +0 -32
- rasa/utils/tensorflow/feature_array.py +0 -366
- rasa_pro-3.11.0.dist-info/METADATA +0 -198
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a1.dist-info}/entry_points.txt +0 -0
rasa/core/utils.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
1
|
-
import structlog
|
|
2
1
|
import logging
|
|
3
2
|
import os
|
|
4
3
|
from pathlib import Path
|
|
5
4
|
from socket import SOCK_DGRAM, SOCK_STREAM
|
|
6
|
-
from typing import Any, Dict, Optional,
|
|
5
|
+
from typing import Any, Dict, Optional, Set, TYPE_CHECKING, Text, Tuple, Union
|
|
7
6
|
|
|
8
7
|
import numpy as np
|
|
9
8
|
from sanic import Sanic
|
|
@@ -20,18 +19,14 @@ from rasa.core.constants import (
|
|
|
20
19
|
from rasa.core.lock_store import LockStore, RedisLockStore, InMemoryLockStore
|
|
21
20
|
from rasa.shared.constants import DEFAULT_ENDPOINTS_PATH, TCP_PROTOCOL
|
|
22
21
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
23
|
-
from rasa.utils.endpoints import
|
|
24
|
-
EndpointConfig,
|
|
25
|
-
read_endpoint_config,
|
|
26
|
-
read_property_config_from_endpoints_file,
|
|
27
|
-
)
|
|
22
|
+
from rasa.utils.endpoints import EndpointConfig, read_endpoint_config
|
|
28
23
|
from rasa.utils.io import write_yaml
|
|
29
24
|
|
|
30
25
|
if TYPE_CHECKING:
|
|
31
26
|
from rasa.core.nlg import NaturalLanguageGenerator
|
|
32
27
|
from rasa.shared.core.domain import Domain
|
|
33
28
|
|
|
34
|
-
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
35
30
|
|
|
36
31
|
|
|
37
32
|
def configure_file_logging(
|
|
@@ -129,17 +124,15 @@ def list_routes(app: Sanic) -> Dict[Text, Text]:
|
|
|
129
124
|
for arg in route._params:
|
|
130
125
|
options[arg] = f"[{arg}]"
|
|
131
126
|
|
|
132
|
-
|
|
133
|
-
methods = ",".join(route.methods)
|
|
127
|
+
handlers = [(next(iter(route.methods)), route.name.replace("rasa_server.", ""))]
|
|
134
128
|
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
129
|
+
for method, name in handlers:
|
|
130
|
+
full_endpoint = "/" + "/".join(endpoint)
|
|
131
|
+
line = unquote(f"{full_endpoint:50s} {method:30s} {name}")
|
|
132
|
+
output[name] = line
|
|
138
133
|
|
|
139
134
|
url_table = "\n".join(output[url] for url in sorted(output))
|
|
140
|
-
|
|
141
|
-
"server.routes", event_info=f"Available web server routes: \n{url_table}"
|
|
142
|
-
)
|
|
135
|
+
logger.debug(f"Available web server routes: \n{url_table}")
|
|
143
136
|
|
|
144
137
|
return output
|
|
145
138
|
|
|
@@ -178,8 +171,6 @@ def is_limit_reached(num_messages: int, limit: Optional[int]) -> bool:
|
|
|
178
171
|
class AvailableEndpoints:
|
|
179
172
|
"""Collection of configured endpoints."""
|
|
180
173
|
|
|
181
|
-
_instance = None
|
|
182
|
-
|
|
183
174
|
@classmethod
|
|
184
175
|
def read_endpoints(cls, endpoint_file: Text) -> "AvailableEndpoints":
|
|
185
176
|
"""Read the different endpoints from a yaml file."""
|
|
@@ -193,9 +184,6 @@ class AvailableEndpoints:
|
|
|
193
184
|
lock_store = read_endpoint_config(endpoint_file, endpoint_type="lock_store")
|
|
194
185
|
event_broker = read_endpoint_config(endpoint_file, endpoint_type="event_broker")
|
|
195
186
|
vector_store = read_endpoint_config(endpoint_file, endpoint_type="vector_store")
|
|
196
|
-
model_groups = read_property_config_from_endpoints_file(
|
|
197
|
-
endpoint_file, property_name="model_groups"
|
|
198
|
-
)
|
|
199
187
|
|
|
200
188
|
return cls(
|
|
201
189
|
nlg,
|
|
@@ -206,7 +194,6 @@ class AvailableEndpoints:
|
|
|
206
194
|
lock_store,
|
|
207
195
|
event_broker,
|
|
208
196
|
vector_store,
|
|
209
|
-
model_groups,
|
|
210
197
|
)
|
|
211
198
|
|
|
212
199
|
def __init__(
|
|
@@ -219,7 +206,6 @@ class AvailableEndpoints:
|
|
|
219
206
|
lock_store: Optional[EndpointConfig] = None,
|
|
220
207
|
event_broker: Optional[EndpointConfig] = None,
|
|
221
208
|
vector_store: Optional[EndpointConfig] = None,
|
|
222
|
-
model_groups: Optional[List[Dict[str, Any]]] = None,
|
|
223
209
|
) -> None:
|
|
224
210
|
"""Create an `AvailableEndpoints` object."""
|
|
225
211
|
self.model = model
|
|
@@ -230,15 +216,6 @@ class AvailableEndpoints:
|
|
|
230
216
|
self.lock_store = lock_store
|
|
231
217
|
self.event_broker = event_broker
|
|
232
218
|
self.vector_store = vector_store
|
|
233
|
-
self.model_groups = model_groups
|
|
234
|
-
|
|
235
|
-
@classmethod
|
|
236
|
-
def get_instance(cls, endpoint_file: Optional[Text] = None) -> "AvailableEndpoints":
|
|
237
|
-
"""Get the singleton instance of AvailableEndpoints."""
|
|
238
|
-
# Ensure that the instance is initialized only once.
|
|
239
|
-
if cls._instance is None:
|
|
240
|
-
cls._instance = cls.read_endpoints(endpoint_file)
|
|
241
|
-
return cls._instance
|
|
242
219
|
|
|
243
220
|
|
|
244
221
|
def read_endpoints_from_path(
|
|
@@ -257,7 +234,7 @@ def read_endpoints_from_path(
|
|
|
257
234
|
endpoints_config_path = cli_utils.get_validated_path(
|
|
258
235
|
endpoints_path, "endpoints", DEFAULT_ENDPOINTS_PATH, True
|
|
259
236
|
)
|
|
260
|
-
return AvailableEndpoints.
|
|
237
|
+
return AvailableEndpoints.read_endpoints(endpoints_config_path)
|
|
261
238
|
|
|
262
239
|
|
|
263
240
|
def _lock_store_is_multi_worker_compatible(
|
|
@@ -286,22 +263,17 @@ def number_of_sanic_workers(lock_store: Union[EndpointConfig, LockStore, None])
|
|
|
286
263
|
"""
|
|
287
264
|
|
|
288
265
|
def _log_and_get_default_number_of_workers() -> int:
|
|
289
|
-
|
|
290
|
-
"
|
|
291
|
-
number_of_workers=DEFAULT_SANIC_WORKERS,
|
|
292
|
-
event_info=f"Using the default number of Sanic workers "
|
|
293
|
-
f"({DEFAULT_SANIC_WORKERS}).",
|
|
266
|
+
logger.debug(
|
|
267
|
+
f"Using the default number of Sanic workers ({DEFAULT_SANIC_WORKERS})."
|
|
294
268
|
)
|
|
295
269
|
return DEFAULT_SANIC_WORKERS
|
|
296
270
|
|
|
297
271
|
try:
|
|
298
272
|
env_value = int(os.environ.get(ENV_SANIC_WORKERS, DEFAULT_SANIC_WORKERS))
|
|
299
273
|
except ValueError:
|
|
300
|
-
|
|
301
|
-
"
|
|
302
|
-
|
|
303
|
-
event_info=f"Cannot convert environment variable `{ENV_SANIC_WORKERS}` "
|
|
304
|
-
f"to int ('{os.environ[ENV_SANIC_WORKERS]}').",
|
|
274
|
+
logger.error(
|
|
275
|
+
f"Cannot convert environment variable `{ENV_SANIC_WORKERS}` "
|
|
276
|
+
f"to int ('{os.environ[ENV_SANIC_WORKERS]}')."
|
|
305
277
|
)
|
|
306
278
|
return _log_and_get_default_number_of_workers()
|
|
307
279
|
|
|
@@ -309,28 +281,20 @@ def number_of_sanic_workers(lock_store: Union[EndpointConfig, LockStore, None])
|
|
|
309
281
|
return _log_and_get_default_number_of_workers()
|
|
310
282
|
|
|
311
283
|
if env_value < 1:
|
|
312
|
-
|
|
313
|
-
"
|
|
314
|
-
|
|
315
|
-
event_info=f"Cannot set number of Sanic workers to the desired value "
|
|
316
|
-
f"({env_value}). The number of workers must be at least 1.",
|
|
284
|
+
logger.debug(
|
|
285
|
+
f"Cannot set number of Sanic workers to the desired value "
|
|
286
|
+
f"({env_value}). The number of workers must be at least 1."
|
|
317
287
|
)
|
|
318
288
|
return _log_and_get_default_number_of_workers()
|
|
319
289
|
|
|
320
290
|
if _lock_store_is_multi_worker_compatible(lock_store):
|
|
321
|
-
|
|
322
|
-
"server.worker.set_count.success",
|
|
323
|
-
event_info=f"Using {env_value} Sanic workers.",
|
|
324
|
-
num_workers=env_value,
|
|
325
|
-
)
|
|
291
|
+
logger.debug(f"Using {env_value} Sanic workers.")
|
|
326
292
|
return env_value
|
|
327
293
|
|
|
328
|
-
|
|
329
|
-
"
|
|
330
|
-
event_info=f"Unable to assign desired number of Sanic workers ({env_value}) as "
|
|
294
|
+
logger.debug(
|
|
295
|
+
f"Unable to assign desired number of Sanic workers ({env_value}) as "
|
|
331
296
|
f"no `RedisLockStore` or custom `LockStore` endpoint "
|
|
332
|
-
f"configuration has been found."
|
|
333
|
-
num_workers=env_value,
|
|
297
|
+
f"configuration has been found."
|
|
334
298
|
)
|
|
335
299
|
return _log_and_get_default_number_of_workers()
|
|
336
300
|
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import importlib
|
|
4
|
+
import os
|
|
4
5
|
from typing import Any, Dict, List, Optional
|
|
5
6
|
|
|
6
7
|
import structlog
|
|
@@ -15,14 +16,13 @@ from rasa.dialogue_understanding.coexistence.constants import (
|
|
|
15
16
|
)
|
|
16
17
|
from rasa.dialogue_understanding.commands import Command, SetSlotCommand
|
|
17
18
|
from rasa.dialogue_understanding.commands.noop_command import NoopCommand
|
|
18
|
-
from rasa.dialogue_understanding.generator.constants import
|
|
19
|
-
LLM_CONFIG_KEY,
|
|
20
|
-
)
|
|
19
|
+
from rasa.dialogue_understanding.generator.constants import LLM_CONFIG_KEY
|
|
21
20
|
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
22
21
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
23
22
|
from rasa.engine.storage.resource import Resource
|
|
24
23
|
from rasa.engine.storage.storage import ModelStorage
|
|
25
24
|
from rasa.shared.constants import (
|
|
25
|
+
LLM_API_HEALTH_CHECK_ENV_VAR,
|
|
26
26
|
ROUTE_TO_CALM_SLOT,
|
|
27
27
|
PROMPT_CONFIG_KEY,
|
|
28
28
|
PROVIDER_CONFIG_KEY,
|
|
@@ -35,13 +35,12 @@ from rasa.shared.exceptions import InvalidConfigException, FileIOException
|
|
|
35
35
|
from rasa.shared.nlu.constants import COMMANDS, TEXT
|
|
36
36
|
from rasa.shared.nlu.training_data.message import Message
|
|
37
37
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
38
|
-
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
39
|
-
from rasa.shared.utils.io import deep_container_fingerprint
|
|
40
38
|
from rasa.shared.utils.llm import (
|
|
41
39
|
DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
42
40
|
get_prompt_template,
|
|
41
|
+
llm_api_health_check,
|
|
43
42
|
llm_factory,
|
|
44
|
-
|
|
43
|
+
try_instantiate_llm_client,
|
|
45
44
|
)
|
|
46
45
|
from rasa.utils.log_utils import log_llm
|
|
47
46
|
|
|
@@ -49,7 +48,6 @@ LLM_BASED_ROUTER_PROMPT_FILE_NAME = "llm_based_router_prompt.jinja2"
|
|
|
49
48
|
DEFAULT_COMMAND_PROMPT_TEMPLATE = importlib.resources.read_text(
|
|
50
49
|
"rasa.dialogue_understanding.coexistence", "router_template.jinja2"
|
|
51
50
|
)
|
|
52
|
-
LLM_BASED_ROUTER_CONFIG_FILE_NAME = "config.json"
|
|
53
51
|
|
|
54
52
|
# Token ids for gpt 3.5 and gpt 4 corresponding to space + capitalized Letter
|
|
55
53
|
A_TO_C_TOKEN_IDS_CHATGPT = [
|
|
@@ -76,7 +74,7 @@ structlogger = structlog.get_logger()
|
|
|
76
74
|
],
|
|
77
75
|
is_trainable=True,
|
|
78
76
|
)
|
|
79
|
-
class LLMBasedRouter(
|
|
77
|
+
class LLMBasedRouter(GraphComponent):
|
|
80
78
|
@staticmethod
|
|
81
79
|
def get_default_config() -> Dict[str, Any]:
|
|
82
80
|
"""The component's default config (see parent class for full docstring)."""
|
|
@@ -98,9 +96,6 @@ class LLMBasedRouter(LLMHealthCheckMixin, GraphComponent):
|
|
|
98
96
|
prompt_template: Optional[str] = None,
|
|
99
97
|
) -> None:
|
|
100
98
|
self.config = {**self.get_default_config(), **config}
|
|
101
|
-
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
102
|
-
self.config.get(LLM_CONFIG_KEY), LLMBasedRouter.__name__
|
|
103
|
-
)
|
|
104
99
|
|
|
105
100
|
self.prompt_template = (
|
|
106
101
|
prompt_template
|
|
@@ -134,18 +129,20 @@ class LLMBasedRouter(LLMHealthCheckMixin, GraphComponent):
|
|
|
134
129
|
rasa.shared.utils.io.write_text_file(
|
|
135
130
|
self.prompt_template, path / LLM_BASED_ROUTER_PROMPT_FILE_NAME
|
|
136
131
|
)
|
|
137
|
-
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
138
|
-
path / LLM_BASED_ROUTER_CONFIG_FILE_NAME, self.config
|
|
139
|
-
)
|
|
140
132
|
|
|
141
133
|
def train(self, training_data: TrainingData) -> Resource:
|
|
142
134
|
"""Train the intent classifier on a data set."""
|
|
143
|
-
|
|
135
|
+
# Validate llm configuration
|
|
136
|
+
llm_client = try_instantiate_llm_client(
|
|
144
137
|
self.config.get(LLM_CONFIG_KEY),
|
|
145
138
|
DEFAULT_LLM_CONFIG,
|
|
146
139
|
"llm_based_router.train",
|
|
147
140
|
LLMBasedRouter.__name__,
|
|
148
141
|
)
|
|
142
|
+
if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
|
|
143
|
+
llm_api_health_check(
|
|
144
|
+
llm_client, "llm_based_router.train", LLMBasedRouter.__name__
|
|
145
|
+
)
|
|
149
146
|
|
|
150
147
|
self.persist()
|
|
151
148
|
return self._resource
|
|
@@ -160,16 +157,6 @@ class LLMBasedRouter(LLMHealthCheckMixin, GraphComponent):
|
|
|
160
157
|
**kwargs: Any,
|
|
161
158
|
) -> "LLMBasedRouter":
|
|
162
159
|
"""Loads trained component (see parent class for full docstring)."""
|
|
163
|
-
|
|
164
|
-
# Perform health check on the resolved LLM client config
|
|
165
|
-
llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
|
|
166
|
-
cls.perform_llm_health_check(
|
|
167
|
-
llm_config,
|
|
168
|
-
DEFAULT_LLM_CONFIG,
|
|
169
|
-
"llm_based_router.load",
|
|
170
|
-
LLMBasedRouter.__name__,
|
|
171
|
-
)
|
|
172
|
-
|
|
173
160
|
prompt_template = None
|
|
174
161
|
try:
|
|
175
162
|
with model_storage.read_from(resource) as path:
|
|
@@ -311,17 +298,3 @@ class LLMBasedRouter(LLMHealthCheckMixin, GraphComponent):
|
|
|
311
298
|
# we have to catch all exceptions here
|
|
312
299
|
structlogger.error("llm_based_router.llm.error", error=e)
|
|
313
300
|
return None
|
|
314
|
-
|
|
315
|
-
@classmethod
|
|
316
|
-
def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
|
|
317
|
-
"""Add a fingerprint of llm based router for the graph."""
|
|
318
|
-
prompt_template = get_prompt_template(
|
|
319
|
-
config.get(PROMPT_CONFIG_KEY),
|
|
320
|
-
DEFAULT_COMMAND_PROMPT_TEMPLATE,
|
|
321
|
-
)
|
|
322
|
-
|
|
323
|
-
llm_config = resolve_model_client_config(
|
|
324
|
-
config.get(LLM_CONFIG_KEY), LLMBasedRouter.__name__
|
|
325
|
-
)
|
|
326
|
-
|
|
327
|
-
return deep_container_fingerprint([prompt_template, llm_config])
|
|
@@ -33,9 +33,6 @@ from rasa.dialogue_understanding.commands.session_start_command import (
|
|
|
33
33
|
SessionStartCommand,
|
|
34
34
|
)
|
|
35
35
|
from rasa.dialogue_understanding.commands.session_end_command import SessionEndCommand
|
|
36
|
-
from rasa.dialogue_understanding.commands.repeat_bot_messages_command import (
|
|
37
|
-
RepeatBotMessagesCommand,
|
|
38
|
-
)
|
|
39
36
|
|
|
40
37
|
__all__ = [
|
|
41
38
|
"Command",
|
|
@@ -56,6 +53,5 @@ __all__ = [
|
|
|
56
53
|
"ChangeFlowCommand",
|
|
57
54
|
"SessionStartCommand",
|
|
58
55
|
"SessionEndCommand",
|
|
59
|
-
"RepeatBotMessagesCommand",
|
|
60
56
|
"RestartCommand",
|
|
61
57
|
]
|
|
@@ -36,9 +36,3 @@ class ChangeFlowCommand(Command):
|
|
|
36
36
|
# the change flow command is not actually pushing anything to the tracker,
|
|
37
37
|
# but it is predicted by the MultiStepLLMCommandGenerator and used internally
|
|
38
38
|
return []
|
|
39
|
-
|
|
40
|
-
def __eq__(self, other: Any) -> bool:
|
|
41
|
-
return isinstance(other, ChangeFlowCommand)
|
|
42
|
-
|
|
43
|
-
def __hash__(self) -> int:
|
|
44
|
-
return hash(self.command())
|
|
@@ -11,7 +11,6 @@ from rasa.dialogue_understanding.commands import (
|
|
|
11
11
|
SkipQuestionCommand,
|
|
12
12
|
RestartCommand,
|
|
13
13
|
)
|
|
14
|
-
from rasa.dialogue_understanding.commands.user_silence_command import UserSilenceCommand
|
|
15
14
|
from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
|
|
16
15
|
from rasa.dialogue_understanding.patterns.cannot_handle import (
|
|
17
16
|
CannotHandlePatternFlowStackFrame,
|
|
@@ -28,13 +27,9 @@ from rasa.dialogue_understanding.patterns.session_start import (
|
|
|
28
27
|
from rasa.dialogue_understanding.patterns.skip_question import (
|
|
29
28
|
SkipQuestionPatternFlowStackFrame,
|
|
30
29
|
)
|
|
31
|
-
from rasa.dialogue_understanding.patterns.user_silence import (
|
|
32
|
-
UserSilencePatternFlowStackFrame,
|
|
33
|
-
)
|
|
34
30
|
|
|
35
31
|
triggerable_pattern_to_command_class: Dict[str, Type[Command]] = {
|
|
36
32
|
SessionStartPatternFlowStackFrame.flow_id: SessionStartCommand,
|
|
37
|
-
UserSilencePatternFlowStackFrame.flow_id: UserSilenceCommand,
|
|
38
33
|
CancelPatternFlowStackFrame.flow_id: CancelFlowCommand,
|
|
39
34
|
ChitchatPatternFlowStackFrame.flow_id: ChitChatAnswerCommand,
|
|
40
35
|
HumanHandoffPatternFlowStackFrame.flow_id: HumanHandoffCommand,
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""
|
|
2
|
+
The module is primarily centered around the `FlowRetrieval` class which handles the
|
|
2
3
|
initialization, configuration validation, vector store management, and flow retrieval
|
|
3
4
|
logic. It integrates components for managing embeddings, vector stores, and
|
|
4
5
|
flow-specific templates, facilitating semantic search functionalities.
|
|
@@ -26,10 +27,8 @@ from langchain.docstore.document import Document
|
|
|
26
27
|
from langchain.schema.embeddings import Embeddings
|
|
27
28
|
from langchain_community.vectorstores.faiss import FAISS
|
|
28
29
|
from langchain_community.vectorstores.utils import DistanceStrategy
|
|
29
|
-
|
|
30
30
|
from rasa.engine.storage.resource import Resource
|
|
31
31
|
from rasa.engine.storage.storage import ModelStorage
|
|
32
|
-
import rasa.shared.utils.io
|
|
33
32
|
from rasa.shared.constants import (
|
|
34
33
|
EMBEDDINGS_CONFIG_KEY,
|
|
35
34
|
PROVIDER_CONFIG_KEY,
|
|
@@ -38,15 +37,12 @@ from rasa.shared.constants import (
|
|
|
38
37
|
from rasa.shared.core.domain import Domain
|
|
39
38
|
from rasa.shared.core.flows import FlowsList
|
|
40
39
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
41
|
-
from rasa.shared.exceptions import ProviderClientAPIException
|
|
42
40
|
from rasa.shared.nlu.constants import TEXT, FLOWS_FROM_SEMANTIC_SEARCH
|
|
43
41
|
from rasa.shared.nlu.training_data.message import Message
|
|
42
|
+
from rasa.shared.exceptions import ProviderClientAPIException
|
|
44
43
|
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
45
44
|
_LangchainEmbeddingClientAdapter,
|
|
46
45
|
)
|
|
47
|
-
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
48
|
-
EmbeddingsHealthCheckMixin,
|
|
49
|
-
)
|
|
50
46
|
from rasa.shared.utils.llm import (
|
|
51
47
|
tracker_as_readable_transcript,
|
|
52
48
|
embedder_factory,
|
|
@@ -54,15 +50,12 @@ from rasa.shared.utils.llm import (
|
|
|
54
50
|
USER,
|
|
55
51
|
get_prompt_template,
|
|
56
52
|
allowed_values_for_slot,
|
|
57
|
-
resolve_model_client_config,
|
|
58
53
|
)
|
|
59
54
|
|
|
60
55
|
DEFAULT_FLOW_DOCUMENT_TEMPLATE = importlib.resources.read_text(
|
|
61
56
|
"rasa.dialogue_understanding.generator", "flow_document_template.jinja2"
|
|
62
57
|
)
|
|
63
58
|
|
|
64
|
-
FLOW_RETRIEVAL_CONFIG_FILE_NAME = "flow_retrieval_config.json"
|
|
65
|
-
|
|
66
59
|
DEFAULT_EMBEDDINGS_CONFIG = {
|
|
67
60
|
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
68
61
|
"model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
@@ -80,7 +73,7 @@ DEFAULT_SHOULD_EMBED_SLOTS = True
|
|
|
80
73
|
structlogger = structlog.get_logger()
|
|
81
74
|
|
|
82
75
|
|
|
83
|
-
class FlowRetrieval(EmbeddingsHealthCheckMixin):
|
|
76
|
+
class FlowRetrieval:
|
|
84
77
|
@classmethod
|
|
85
78
|
def get_default_config(cls) -> Dict[str, Any]:
|
|
86
79
|
"""The default config for the flow retrieval."""
|
|
@@ -99,9 +92,6 @@ class FlowRetrieval(EmbeddingsHealthCheckMixin):
|
|
|
99
92
|
):
|
|
100
93
|
config = {**self.get_default_config(), **config}
|
|
101
94
|
self.config = self.validate_config(config)
|
|
102
|
-
self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
103
|
-
self.config.get(EMBEDDINGS_CONFIG_KEY), FlowRetrieval.__name__
|
|
104
|
-
)
|
|
105
95
|
self.vector_store: Optional[FAISS] = None
|
|
106
96
|
self.flow_document_template = get_prompt_template(
|
|
107
97
|
None, DEFAULT_FLOW_DOCUMENT_TEMPLATE
|
|
@@ -150,18 +140,6 @@ class FlowRetrieval(EmbeddingsHealthCheckMixin):
|
|
|
150
140
|
**kwargs: Any,
|
|
151
141
|
) -> "FlowRetrieval":
|
|
152
142
|
"""Load flow retrieval with previously populated FAISS vector store."""
|
|
153
|
-
|
|
154
|
-
# Perform health check on resolved embedding client config
|
|
155
|
-
embeddings_config = resolve_model_client_config(
|
|
156
|
-
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
157
|
-
)
|
|
158
|
-
cls.perform_embeddings_health_check(
|
|
159
|
-
embeddings_config,
|
|
160
|
-
DEFAULT_EMBEDDINGS_CONFIG,
|
|
161
|
-
"flow_retrieval.load",
|
|
162
|
-
FlowRetrieval.__name__,
|
|
163
|
-
)
|
|
164
|
-
|
|
165
143
|
# initialize base flow retrieval
|
|
166
144
|
flow_retrieval = FlowRetrieval(config, model_storage, resource)
|
|
167
145
|
# load vector store
|
|
@@ -169,7 +147,6 @@ class FlowRetrieval(EmbeddingsHealthCheckMixin):
|
|
|
169
147
|
flow_retrieval.config, model_storage, resource
|
|
170
148
|
)
|
|
171
149
|
flow_retrieval.vector_store = vector_store
|
|
172
|
-
|
|
173
150
|
return flow_retrieval
|
|
174
151
|
|
|
175
152
|
@classmethod
|
|
@@ -201,21 +178,13 @@ class FlowRetrieval(EmbeddingsHealthCheckMixin):
|
|
|
201
178
|
Returns:
|
|
202
179
|
The embedder.
|
|
203
180
|
"""
|
|
204
|
-
# Copy the config so original config is not modified
|
|
205
|
-
config = config.copy()
|
|
206
|
-
# Resolve config and instantiate the embedding client
|
|
207
|
-
config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
208
|
-
config.get(EMBEDDINGS_CONFIG_KEY), FlowRetrieval.__name__
|
|
209
|
-
)
|
|
210
181
|
client = embedder_factory(
|
|
211
182
|
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
212
183
|
)
|
|
213
|
-
# Wrap the embedding client in the adapter
|
|
214
184
|
return _LangchainEmbeddingClientAdapter(client)
|
|
215
185
|
|
|
216
186
|
def persist(self) -> None:
|
|
217
187
|
self._persist_vector_store()
|
|
218
|
-
self._persist_config()
|
|
219
188
|
|
|
220
189
|
def _persist_vector_store(self) -> None:
|
|
221
190
|
"""Persists the FAISS vector store."""
|
|
@@ -228,12 +197,6 @@ class FlowRetrieval(EmbeddingsHealthCheckMixin):
|
|
|
228
197
|
event_info="Vector store is None, not persisted.",
|
|
229
198
|
)
|
|
230
199
|
|
|
231
|
-
def _persist_config(self) -> None:
|
|
232
|
-
with self._model_storage.write_to(self._resource) as path:
|
|
233
|
-
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
234
|
-
path / FLOW_RETRIEVAL_CONFIG_FILE_NAME, self.config
|
|
235
|
-
)
|
|
236
|
-
|
|
237
200
|
def populate(self, flows: FlowsList, domain: Domain) -> None:
|
|
238
201
|
"""Populates the vector store with embeddings generated from
|
|
239
202
|
documents based on the flow descriptions, and flow slots
|
|
@@ -243,14 +206,6 @@ class FlowRetrieval(EmbeddingsHealthCheckMixin):
|
|
|
243
206
|
flows: List of flows to populate the vector store with.
|
|
244
207
|
domain: The domain containing relevant slot information.
|
|
245
208
|
"""
|
|
246
|
-
# Perform health check before populating the vector store with flows
|
|
247
|
-
self.perform_embeddings_health_check(
|
|
248
|
-
self.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
249
|
-
DEFAULT_EMBEDDINGS_CONFIG,
|
|
250
|
-
"flow_retrieval.train",
|
|
251
|
-
FlowRetrieval.__name__,
|
|
252
|
-
)
|
|
253
|
-
|
|
254
209
|
flows_to_embedd = flows.exclude_link_only_flows()
|
|
255
210
|
embeddings = self._create_embedder(self.config)
|
|
256
211
|
documents = self._generate_flow_documents(flows_to_embedd, domain)
|
|
@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
|
|
2
2
|
from functools import lru_cache
|
|
3
3
|
from typing import Dict, Any, List, Optional, Tuple, Union, Text
|
|
4
4
|
|
|
5
|
+
import os
|
|
5
6
|
import structlog
|
|
6
7
|
from jinja2 import Template
|
|
7
8
|
|
|
@@ -16,13 +17,13 @@ from rasa.dialogue_understanding.generator.constants import (
|
|
|
16
17
|
LLM_CONFIG_KEY,
|
|
17
18
|
FLOW_RETRIEVAL_KEY,
|
|
18
19
|
FLOW_RETRIEVAL_ACTIVE_KEY,
|
|
19
|
-
FLOW_RETRIEVAL_FLOW_THRESHOLD,
|
|
20
20
|
)
|
|
21
21
|
from rasa.dialogue_understanding.generator.flow_retrieval import FlowRetrieval
|
|
22
22
|
from rasa.engine.graph import GraphComponent, ExecutionContext
|
|
23
23
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
24
24
|
from rasa.engine.storage.resource import Resource
|
|
25
25
|
from rasa.engine.storage.storage import ModelStorage
|
|
26
|
+
from rasa.shared.constants import LLM_API_HEALTH_CHECK_ENV_VAR
|
|
26
27
|
from rasa.shared.core.domain import Domain
|
|
27
28
|
from rasa.shared.core.flows import FlowStep, Flow, FlowsList
|
|
28
29
|
from rasa.shared.core.flows.steps.collect import CollectInformationFlowStep
|
|
@@ -32,11 +33,11 @@ from rasa.shared.exceptions import ProviderClientAPIException
|
|
|
32
33
|
from rasa.shared.nlu.constants import FLOWS_IN_PROMPT
|
|
33
34
|
from rasa.shared.nlu.training_data.message import Message
|
|
34
35
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
35
|
-
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
36
36
|
from rasa.shared.utils.llm import (
|
|
37
37
|
allowed_values_for_slot,
|
|
38
|
+
llm_api_health_check,
|
|
38
39
|
llm_factory,
|
|
39
|
-
resolve_model_client_config,
|
40
|
+
try_instantiate_llm_client,
|
|
40
41
|
)
|
|
41
42
|
from rasa.utils.log_utils import log_llm
|
|
42
43
|
|
|
@@ -49,9 +50,7 @@ structlogger = structlog.get_logger()
|
|
|
49
50
|
],
|
|
50
51
|
is_trainable=True,
|
|
51
52
|
)
|
|
52
|
-
class LLMBasedCommandGenerator(
|
|
53
|
-
LLMHealthCheckMixin, GraphComponent, CommandGenerator, ABC
|
|
54
|
-
):
|
|
53
|
+
class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
55
54
|
"""An abstract class defining interface and common functionality
|
|
56
55
|
of an LLM-based command generators.
|
|
57
56
|
"""
|
|
@@ -65,9 +64,6 @@ class LLMBasedCommandGenerator(
|
|
|
65
64
|
) -> None:
|
|
66
65
|
super().__init__(config)
|
|
67
66
|
self.config = {**self.get_default_config(), **config}
|
|
68
|
-
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
69
|
-
self.config.get(LLM_CONFIG_KEY), LLMBasedCommandGenerator.__name__
|
|
70
|
-
)
|
|
71
67
|
self._model_storage = model_storage
|
|
72
68
|
self._resource = resource
|
|
73
69
|
self.flow_retrieval: Optional[FlowRetrieval]
|
|
@@ -77,9 +73,17 @@ class LLMBasedCommandGenerator(
|
|
|
77
73
|
self.config[FLOW_RETRIEVAL_KEY], model_storage, resource
|
|
78
74
|
)
|
|
79
75
|
structlogger.info("llm_based_command_generator.flow_retrieval.enabled")
|
|
80
|
-
self.config[FLOW_RETRIEVAL_KEY] = self.flow_retrieval.config
|
|
81
76
|
else:
|
|
82
77
|
self.flow_retrieval = None
|
|
78
|
+
structlogger.warn(
|
|
79
|
+
"llm_based_command_generator.flow_retrieval.disabled",
|
|
80
|
+
event_info=(
|
|
81
|
+
"Disabling flow retrieval can cause issues when there are a "
|
|
82
|
+
"large number of flows to be included in the prompt. For more"
|
|
83
|
+
"information see:\n"
|
|
84
|
+
"https://rasa.com/docs/rasa-pro/concepts/dialogue-understanding#how-the-llmcommandgenerator-works"
|
|
85
|
+
),
|
|
86
|
+
)
|
|
83
87
|
|
|
84
88
|
### Abstract methods
|
|
85
89
|
@staticmethod
|
|
@@ -167,32 +171,18 @@ class LLMBasedCommandGenerator(
|
|
|
167
171
|
"""Train the llm based command generator. Stores all flows into a vector
|
|
168
172
|
store.
|
|
169
173
|
"""
|
|
170
|
-
|
|
174
|
+
# Validate llm configuration
|
|
175
|
+
llm_client = try_instantiate_llm_client(
|
|
171
176
|
self.config.get(LLM_CONFIG_KEY),
|
|
172
177
|
DEFAULT_LLM_CONFIG,
|
|
173
178
|
"llm_based_command_generator.train",
|
|
174
179
|
LLMBasedCommandGenerator.__name__,
|
|
175
180
|
)
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
structlogger.warn(
|
|
182
|
-
"llm_based_command_generator.flow_retrieval.disabled",
|
|
183
|
-
event_info=(
|
|
184
|
-
f"You have {len(flows.user_flows)} user flows but flow "
|
|
185
|
-
f"retrieval is disabled. "
|
|
186
|
-
f"It is recommended to enable flow retrieval if the "
|
|
187
|
-
f"total number of user flows exceed "
|
|
188
|
-
f"{FLOW_RETRIEVAL_FLOW_THRESHOLD}. "
|
|
189
|
-
f"Keeping it disabled can result in deterioration of "
|
|
190
|
-
f"command generator's functional "
|
|
191
|
-
f"performance and higher costs because of increased "
|
|
192
|
-
f"number of tokens in the prompt. For more"
|
|
193
|
-
"information see:\n"
|
|
194
|
-
"https://rasa.com/docs/rasa-pro/concepts/dialogue-understanding#how-the-llmcommandgenerator-works"
|
|
195
|
-
),
|
|
181
|
+
if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
|
|
182
|
+
llm_api_health_check(
|
|
183
|
+
llm_client,
|
|
184
|
+
"llm_based_command_generator.train",
|
|
185
|
+
LLMBasedCommandGenerator.__name__,
|
|
196
186
|
)
|
|
197
187
|
|
|
198
188
|
# flow retrieval is populated with only user-defined flows
|
|
@@ -202,11 +192,10 @@ class LLMBasedCommandGenerator(
|
|
|
202
192
|
except Exception as e:
|
|
203
193
|
structlogger.error(
|
|
204
194
|
"llm_based_command_generator.train.failed",
|
|
205
|
-
event_info="Flow retrieval store is inaccessible.",
|
|
195
|
+
event_info=("Flow retrieval store is inaccessible."),
|
|
206
196
|
error=e,
|
|
207
197
|
)
|
|
208
198
|
raise
|
|
209
|
-
|
|
210
199
|
self.persist()
|
|
211
200
|
return self._resource
|
|
212
201
|
|
|
@@ -244,10 +233,7 @@ class LLMBasedCommandGenerator(
|
|
|
244
233
|
|
|
245
234
|
@classmethod
|
|
246
235
|
def load_flow_retrival(
|
|
247
|
-
cls,
|
|
248
|
-
config: Dict[str, Any],
|
|
249
|
-
model_storage: ModelStorage,
|
|
250
|
-
resource: Resource,
|
|
236
|
+
cls, config: Dict[Text, Any], model_storage: ModelStorage, resource: Resource
|
|
251
237
|
) -> Optional[FlowRetrieval]:
|
|
252
238
|
"""Load the FlowRetrieval component if it is enabled in the configuration."""
|
|
253
239
|
enable_flow_retrieval = config.get(FLOW_RETRIEVAL_KEY, {}).get(
|