rasa-pro 3.11.0a4.dev3__py3-none-any.whl → 3.11.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +22 -12
- rasa/api.py +1 -1
- rasa/cli/arguments/default_arguments.py +1 -2
- rasa/cli/arguments/shell.py +5 -1
- rasa/cli/e2e_test.py +1 -1
- rasa/cli/evaluate.py +8 -8
- rasa/cli/inspect.py +6 -4
- rasa/cli/llm_fine_tuning.py +1 -1
- rasa/cli/project_templates/calm/config.yml +5 -7
- rasa/cli/project_templates/calm/endpoints.yml +8 -0
- rasa/cli/project_templates/tutorial/config.yml +8 -5
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
- rasa/cli/project_templates/tutorial/domain.yml +14 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +7 -7
- rasa/cli/run.py +1 -1
- rasa/cli/scaffold.py +4 -2
- rasa/cli/studio/studio.py +18 -8
- rasa/cli/utils.py +5 -0
- rasa/cli/x.py +8 -8
- rasa/constants.py +1 -1
- rasa/core/actions/action_repeat_bot_messages.py +17 -0
- rasa/core/channels/channel.py +20 -0
- rasa/core/channels/inspector/dist/assets/{arc-6852c607.js → arc-bc141fb2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-acc952b2.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-848a7597.js → classDiagram-936ed81e-55366915.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-a73d3e68.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-e5ee049d.js → createText-62fc7601-b0ec81d6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-771e517e.js → edges-f2ad444c-6166330c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-aa347178.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-651fc57d.js → flowDb-1972c806-fca3bfe4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-ca67804f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-2dbc568d.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-25a65bd8.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-fdc7378d.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-6f1fd606.js → index-2c4b9a3b-f55afcdf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-efdd30c1.js → index-e7cef9de.js} +68 -68
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-cb1a041a.js → infoDiagram-736b4530-124d4a14.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-14609879.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-2490f52b.js → layout-b9885fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-40186f1f.js → line-7c59abb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-08814e93.js → linear-4776f780.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-1a534584.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-72397b61.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3bb0b6a3.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-57334f61.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-111e1297.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-10bcfe62.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-acaf7513.js → stateDiagram-59f0c015-042b3137.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-3ec2a235.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-62730289.js → styles-080da4f6-23ffa4fc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-5284ee76.js → styles-3dcbcfbf-94f59763.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-642435e3.js → styles-9c745c82-78a6bebc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-b250a350.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-c2b147ed.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-f92cfea9.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/App.tsx +1 -1
- rasa/core/channels/inspector/src/helpers/audiostream.ts +77 -16
- rasa/core/channels/socketio.py +2 -1
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/voice_ready/audiocodes.py +12 -0
- rasa/core/channels/voice_ready/jambonz.py +15 -4
- rasa/core/channels/voice_ready/twilio_voice.py +6 -21
- rasa/core/channels/voice_stream/asr/asr_event.py +5 -0
- rasa/core/channels/voice_stream/asr/azure.py +122 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +16 -6
- rasa/core/channels/voice_stream/audio_bytes.py +1 -0
- rasa/core/channels/voice_stream/browser_audio.py +31 -8
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/azure.py +6 -2
- rasa/core/channels/voice_stream/tts/cartesia.py +10 -6
- rasa/core/channels/voice_stream/tts/tts_engine.py +1 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +27 -18
- rasa/core/channels/voice_stream/util.py +4 -4
- rasa/core/channels/voice_stream/voice_channel.py +189 -39
- rasa/core/featurizers/single_state_featurizer.py +22 -1
- rasa/core/featurizers/tracker_featurizers.py +115 -18
- rasa/core/nlg/contextual_response_rephraser.py +32 -30
- rasa/core/persistor.py +86 -39
- rasa/core/policies/enterprise_search_policy.py +119 -60
- rasa/core/policies/flows/flow_executor.py +7 -4
- rasa/core/policies/intentless_policy.py +78 -22
- rasa/core/policies/ted_policy.py +58 -33
- rasa/core/policies/unexpected_intent_policy.py +15 -7
- rasa/core/processor.py +25 -0
- rasa/core/training/interactive.py +34 -35
- rasa/core/utils.py +8 -3
- rasa/dialogue_understanding/coexistence/llm_based_router.py +39 -12
- rasa/dialogue_understanding/commands/change_flow_command.py +6 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +5 -0
- rasa/dialogue_understanding/generator/constants.py +2 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +49 -4
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +37 -23
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +57 -10
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +19 -1
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +71 -11
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +39 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +21 -1
- rasa/e2e_test/e2e_test_case.py +85 -6
- rasa/e2e_test/e2e_test_runner.py +4 -2
- rasa/e2e_test/utils/io.py +1 -1
- rasa/engine/validation.py +316 -10
- rasa/model_manager/config.py +15 -3
- rasa/model_manager/model_api.py +15 -7
- rasa/model_manager/runner_service.py +8 -6
- rasa/model_manager/socket_bridge.py +6 -3
- rasa/model_manager/trainer_service.py +7 -5
- rasa/model_manager/utils.py +28 -7
- rasa/model_service.py +9 -2
- rasa/model_training.py +2 -0
- rasa/nlu/classifiers/diet_classifier.py +38 -25
- rasa/nlu/classifiers/logistic_regression_classifier.py +22 -9
- rasa/nlu/classifiers/sklearn_intent_classifier.py +37 -16
- rasa/nlu/extractors/crf_entity_extractor.py +93 -50
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -16
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +52 -17
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +5 -3
- rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
- rasa/server.py +3 -1
- rasa/shared/constants.py +36 -3
- rasa/shared/core/constants.py +7 -0
- rasa/shared/core/domain.py +26 -0
- rasa/shared/core/flows/flow.py +5 -0
- rasa/shared/core/flows/flows_list.py +5 -1
- rasa/shared/core/flows/flows_yaml_schema.json +10 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +96 -0
- rasa/shared/core/slots.py +5 -0
- rasa/shared/nlu/training_data/features.py +120 -2
- rasa/shared/providers/_configs/azure_openai_client_config.py +5 -3
- rasa/shared/providers/_configs/litellm_router_client_config.py +200 -0
- rasa/shared/providers/_configs/model_group_config.py +167 -0
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
- rasa/shared/providers/_configs/utils.py +16 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +18 -29
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
- rasa/shared/providers/llm/_base_litellm_client.py +37 -31
- rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
- rasa/shared/providers/llm/litellm_router_llm_client.py +127 -0
- rasa/shared/providers/llm/rasa_llm_client.py +112 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +1 -1
- rasa/shared/providers/mappings.py +19 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +149 -0
- rasa/shared/providers/router/router_client.py +73 -0
- rasa/shared/utils/common.py +8 -0
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +256 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +28 -6
- rasa/shared/utils/llm.py +353 -46
- rasa/shared/utils/yaml.py +111 -73
- rasa/studio/auth.py +3 -5
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/upload.py +81 -26
- rasa/telemetry.py +92 -17
- rasa/tracing/config.py +2 -0
- rasa/tracing/instrumentation/attribute_extractors.py +94 -17
- rasa/tracing/instrumentation/instrumentation.py +121 -0
- rasa/utils/common.py +5 -0
- rasa/utils/io.py +7 -81
- rasa/utils/log_utils.py +9 -2
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/utils/tensorflow/feature_array.py +366 -0
- rasa/utils/tensorflow/model_data.py +2 -193
- rasa/validator.py +70 -0
- rasa/version.py +1 -1
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/METADATA +11 -10
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/RECORD +183 -163
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-587d82d8.js +0 -1
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/entry_points.txt +0 -0
|
@@ -1,5 +1,4 @@
|
|
|
1
|
-
"""
|
|
2
|
-
The module is primarily centered around the `FlowRetrieval` class which handles the
|
|
1
|
+
"""The module is primarily centered around the `FlowRetrieval` class which handles the
|
|
3
2
|
initialization, configuration validation, vector store management, and flow retrieval
|
|
4
3
|
logic. It integrates components for managing embeddings, vector stores, and
|
|
5
4
|
flow-specific templates, facilitating semantic search functionalities.
|
|
@@ -27,8 +26,10 @@ from langchain.docstore.document import Document
|
|
|
27
26
|
from langchain.schema.embeddings import Embeddings
|
|
28
27
|
from langchain_community.vectorstores.faiss import FAISS
|
|
29
28
|
from langchain_community.vectorstores.utils import DistanceStrategy
|
|
29
|
+
|
|
30
30
|
from rasa.engine.storage.resource import Resource
|
|
31
31
|
from rasa.engine.storage.storage import ModelStorage
|
|
32
|
+
import rasa.shared.utils.io
|
|
32
33
|
from rasa.shared.constants import (
|
|
33
34
|
EMBEDDINGS_CONFIG_KEY,
|
|
34
35
|
PROVIDER_CONFIG_KEY,
|
|
@@ -37,12 +38,15 @@ from rasa.shared.constants import (
|
|
|
37
38
|
from rasa.shared.core.domain import Domain
|
|
38
39
|
from rasa.shared.core.flows import FlowsList
|
|
39
40
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
41
|
+
from rasa.shared.exceptions import ProviderClientAPIException
|
|
40
42
|
from rasa.shared.nlu.constants import TEXT, FLOWS_FROM_SEMANTIC_SEARCH
|
|
41
43
|
from rasa.shared.nlu.training_data.message import Message
|
|
42
|
-
from rasa.shared.exceptions import ProviderClientAPIException
|
|
43
44
|
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
44
45
|
_LangchainEmbeddingClientAdapter,
|
|
45
46
|
)
|
|
47
|
+
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
48
|
+
EmbeddingsHealthCheckMixin,
|
|
49
|
+
)
|
|
46
50
|
from rasa.shared.utils.llm import (
|
|
47
51
|
tracker_as_readable_transcript,
|
|
48
52
|
embedder_factory,
|
|
@@ -50,12 +54,15 @@ from rasa.shared.utils.llm import (
|
|
|
50
54
|
USER,
|
|
51
55
|
get_prompt_template,
|
|
52
56
|
allowed_values_for_slot,
|
|
57
|
+
resolve_model_client_config,
|
|
53
58
|
)
|
|
54
59
|
|
|
55
60
|
DEFAULT_FLOW_DOCUMENT_TEMPLATE = importlib.resources.read_text(
|
|
56
61
|
"rasa.dialogue_understanding.generator", "flow_document_template.jinja2"
|
|
57
62
|
)
|
|
58
63
|
|
|
64
|
+
FLOW_RETRIEVAL_CONFIG_FILE_NAME = "flow_retrieval_config.json"
|
|
65
|
+
|
|
59
66
|
DEFAULT_EMBEDDINGS_CONFIG = {
|
|
60
67
|
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
61
68
|
"model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
@@ -73,7 +80,7 @@ DEFAULT_SHOULD_EMBED_SLOTS = True
|
|
|
73
80
|
structlogger = structlog.get_logger()
|
|
74
81
|
|
|
75
82
|
|
|
76
|
-
class FlowRetrieval:
|
|
83
|
+
class FlowRetrieval(EmbeddingsHealthCheckMixin):
|
|
77
84
|
@classmethod
|
|
78
85
|
def get_default_config(cls) -> Dict[str, Any]:
|
|
79
86
|
"""The default config for the flow retrieval."""
|
|
@@ -92,6 +99,9 @@ class FlowRetrieval:
|
|
|
92
99
|
):
|
|
93
100
|
config = {**self.get_default_config(), **config}
|
|
94
101
|
self.config = self.validate_config(config)
|
|
102
|
+
self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
103
|
+
self.config.get(EMBEDDINGS_CONFIG_KEY), FlowRetrieval.__name__
|
|
104
|
+
)
|
|
95
105
|
self.vector_store: Optional[FAISS] = None
|
|
96
106
|
self.flow_document_template = get_prompt_template(
|
|
97
107
|
None, DEFAULT_FLOW_DOCUMENT_TEMPLATE
|
|
@@ -140,6 +150,18 @@ class FlowRetrieval:
|
|
|
140
150
|
**kwargs: Any,
|
|
141
151
|
) -> "FlowRetrieval":
|
|
142
152
|
"""Load flow retrieval with previously populated FAISS vector store."""
|
|
153
|
+
|
|
154
|
+
# Perform health check on resolved embedding client config
|
|
155
|
+
embeddings_config = resolve_model_client_config(
|
|
156
|
+
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
157
|
+
)
|
|
158
|
+
cls.perform_embeddings_health_check(
|
|
159
|
+
embeddings_config,
|
|
160
|
+
DEFAULT_EMBEDDINGS_CONFIG,
|
|
161
|
+
"flow_retrieval.load",
|
|
162
|
+
FlowRetrieval.__name__,
|
|
163
|
+
)
|
|
164
|
+
|
|
143
165
|
# initialize base flow retrieval
|
|
144
166
|
flow_retrieval = FlowRetrieval(config, model_storage, resource)
|
|
145
167
|
# load vector store
|
|
@@ -147,6 +169,7 @@ class FlowRetrieval:
|
|
|
147
169
|
flow_retrieval.config, model_storage, resource
|
|
148
170
|
)
|
|
149
171
|
flow_retrieval.vector_store = vector_store
|
|
172
|
+
|
|
150
173
|
return flow_retrieval
|
|
151
174
|
|
|
152
175
|
@classmethod
|
|
@@ -178,13 +201,21 @@ class FlowRetrieval:
|
|
|
178
201
|
Returns:
|
|
179
202
|
The embedder.
|
|
180
203
|
"""
|
|
204
|
+
# Copy the config so original config is not modified
|
|
205
|
+
config = config.copy()
|
|
206
|
+
# Resolve config and instantiate the embedding client
|
|
207
|
+
config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
208
|
+
config.get(EMBEDDINGS_CONFIG_KEY), FlowRetrieval.__name__
|
|
209
|
+
)
|
|
181
210
|
client = embedder_factory(
|
|
182
211
|
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
183
212
|
)
|
|
213
|
+
# Wrap the embedding client in the adapter
|
|
184
214
|
return _LangchainEmbeddingClientAdapter(client)
|
|
185
215
|
|
|
186
216
|
def persist(self) -> None:
|
|
187
217
|
self._persist_vector_store()
|
|
218
|
+
self._persist_config()
|
|
188
219
|
|
|
189
220
|
def _persist_vector_store(self) -> None:
|
|
190
221
|
"""Persists the FAISS vector store."""
|
|
@@ -197,6 +228,12 @@ class FlowRetrieval:
|
|
|
197
228
|
event_info="Vector store is None, not persisted.",
|
|
198
229
|
)
|
|
199
230
|
|
|
231
|
+
def _persist_config(self) -> None:
|
|
232
|
+
with self._model_storage.write_to(self._resource) as path:
|
|
233
|
+
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
234
|
+
path / FLOW_RETRIEVAL_CONFIG_FILE_NAME, self.config
|
|
235
|
+
)
|
|
236
|
+
|
|
200
237
|
def populate(self, flows: FlowsList, domain: Domain) -> None:
|
|
201
238
|
"""Populates the vector store with embeddings generated from
|
|
202
239
|
documents based on the flow descriptions, and flow slots
|
|
@@ -206,6 +243,14 @@ class FlowRetrieval:
|
|
|
206
243
|
flows: List of flows to populate the vector store with.
|
|
207
244
|
domain: The domain containing relevant slot information.
|
|
208
245
|
"""
|
|
246
|
+
# Perform health check before populating the vector store with flows
|
|
247
|
+
self.perform_embeddings_health_check(
|
|
248
|
+
self.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
249
|
+
DEFAULT_EMBEDDINGS_CONFIG,
|
|
250
|
+
"flow_retrieval.train",
|
|
251
|
+
FlowRetrieval.__name__,
|
|
252
|
+
)
|
|
253
|
+
|
|
209
254
|
flows_to_embedd = flows.exclude_link_only_flows()
|
|
210
255
|
embeddings = self._create_embedder(self.config)
|
|
211
256
|
documents = self._generate_flow_documents(flows_to_embedd, domain)
|
|
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
|
|
|
2
2
|
from functools import lru_cache
|
|
3
3
|
from typing import Dict, Any, List, Optional, Tuple, Union, Text
|
|
4
4
|
|
|
5
|
-
import os
|
|
6
5
|
import structlog
|
|
7
6
|
from jinja2 import Template
|
|
8
7
|
|
|
@@ -17,13 +16,13 @@ from rasa.dialogue_understanding.generator.constants import (
|
|
|
17
16
|
LLM_CONFIG_KEY,
|
|
18
17
|
FLOW_RETRIEVAL_KEY,
|
|
19
18
|
FLOW_RETRIEVAL_ACTIVE_KEY,
|
|
19
|
+
FLOW_RETRIEVAL_FLOW_THRESHOLD,
|
|
20
20
|
)
|
|
21
21
|
from rasa.dialogue_understanding.generator.flow_retrieval import FlowRetrieval
|
|
22
22
|
from rasa.engine.graph import GraphComponent, ExecutionContext
|
|
23
23
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
24
24
|
from rasa.engine.storage.resource import Resource
|
|
25
25
|
from rasa.engine.storage.storage import ModelStorage
|
|
26
|
-
from rasa.shared.constants import LLM_API_HEALTH_CHECK_ENV_VAR
|
|
27
26
|
from rasa.shared.core.domain import Domain
|
|
28
27
|
from rasa.shared.core.flows import FlowStep, Flow, FlowsList
|
|
29
28
|
from rasa.shared.core.flows.steps.collect import CollectInformationFlowStep
|
|
@@ -33,11 +32,11 @@ from rasa.shared.exceptions import ProviderClientAPIException
|
|
|
33
32
|
from rasa.shared.nlu.constants import FLOWS_IN_PROMPT
|
|
34
33
|
from rasa.shared.nlu.training_data.message import Message
|
|
35
34
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
35
|
+
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
36
36
|
from rasa.shared.utils.llm import (
|
|
37
37
|
allowed_values_for_slot,
|
|
38
|
-
llm_api_health_check,
|
|
39
38
|
llm_factory,
|
|
40
|
-
|
|
39
|
+
resolve_model_client_config,
|
|
41
40
|
)
|
|
42
41
|
from rasa.utils.log_utils import log_llm
|
|
43
42
|
|
|
@@ -50,7 +49,9 @@ structlogger = structlog.get_logger()
|
|
|
50
49
|
],
|
|
51
50
|
is_trainable=True,
|
|
52
51
|
)
|
|
53
|
-
class LLMBasedCommandGenerator(
|
|
52
|
+
class LLMBasedCommandGenerator(
|
|
53
|
+
LLMHealthCheckMixin, GraphComponent, CommandGenerator, ABC
|
|
54
|
+
):
|
|
54
55
|
"""An abstract class defining interface and common functionality
|
|
55
56
|
of an LLM-based command generators.
|
|
56
57
|
"""
|
|
@@ -64,6 +65,9 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
|
64
65
|
) -> None:
|
|
65
66
|
super().__init__(config)
|
|
66
67
|
self.config = {**self.get_default_config(), **config}
|
|
68
|
+
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
69
|
+
self.config.get(LLM_CONFIG_KEY), LLMBasedCommandGenerator.__name__
|
|
70
|
+
)
|
|
67
71
|
self._model_storage = model_storage
|
|
68
72
|
self._resource = resource
|
|
69
73
|
self.flow_retrieval: Optional[FlowRetrieval]
|
|
@@ -73,17 +77,9 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
|
73
77
|
self.config[FLOW_RETRIEVAL_KEY], model_storage, resource
|
|
74
78
|
)
|
|
75
79
|
structlogger.info("llm_based_command_generator.flow_retrieval.enabled")
|
|
80
|
+
self.config[FLOW_RETRIEVAL_KEY] = self.flow_retrieval.config
|
|
76
81
|
else:
|
|
77
82
|
self.flow_retrieval = None
|
|
78
|
-
structlogger.warn(
|
|
79
|
-
"llm_based_command_generator.flow_retrieval.disabled",
|
|
80
|
-
event_info=(
|
|
81
|
-
"Disabling flow retrieval can cause issues when there are a "
|
|
82
|
-
"large number of flows to be included in the prompt. For more"
|
|
83
|
-
"information see:\n"
|
|
84
|
-
"https://rasa.com/docs/rasa-pro/concepts/dialogue-understanding#how-the-llmcommandgenerator-works"
|
|
85
|
-
),
|
|
86
|
-
)
|
|
87
83
|
|
|
88
84
|
### Abstract methods
|
|
89
85
|
@staticmethod
|
|
@@ -171,18 +167,32 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
|
171
167
|
"""Train the llm based command generator. Stores all flows into a vector
|
|
172
168
|
store.
|
|
173
169
|
"""
|
|
174
|
-
|
|
175
|
-
llm_client = try_instantiate_llm_client(
|
|
170
|
+
self.perform_llm_health_check(
|
|
176
171
|
self.config.get(LLM_CONFIG_KEY),
|
|
177
172
|
DEFAULT_LLM_CONFIG,
|
|
178
173
|
"llm_based_command_generator.train",
|
|
179
174
|
LLMBasedCommandGenerator.__name__,
|
|
180
175
|
)
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
176
|
+
|
|
177
|
+
if (
|
|
178
|
+
self.flow_retrieval is None
|
|
179
|
+
and len(flows.user_flows) > FLOW_RETRIEVAL_FLOW_THRESHOLD
|
|
180
|
+
):
|
|
181
|
+
structlogger.warn(
|
|
182
|
+
"llm_based_command_generator.flow_retrieval.disabled",
|
|
183
|
+
event_info=(
|
|
184
|
+
f"You have {len(flows.user_flows)} user flows but flow "
|
|
185
|
+
f"retrieval is disabled. "
|
|
186
|
+
f"It is recommended to enable flow retrieval if the "
|
|
187
|
+
f"total number of user flows exceed "
|
|
188
|
+
f"{FLOW_RETRIEVAL_FLOW_THRESHOLD}. "
|
|
189
|
+
f"Keeping it disabled can result in deterioration of "
|
|
190
|
+
f"command generator's functional "
|
|
191
|
+
f"performance and higher costs because of increased "
|
|
192
|
+
f"number of tokens in the prompt. For more"
|
|
193
|
+
"information see:\n"
|
|
194
|
+
"https://rasa.com/docs/rasa-pro/concepts/dialogue-understanding#how-the-llmcommandgenerator-works"
|
|
195
|
+
),
|
|
186
196
|
)
|
|
187
197
|
|
|
188
198
|
# flow retrieval is populated with only user-defined flows
|
|
@@ -192,10 +202,11 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
|
192
202
|
except Exception as e:
|
|
193
203
|
structlogger.error(
|
|
194
204
|
"llm_based_command_generator.train.failed",
|
|
195
|
-
event_info=
|
|
205
|
+
event_info="Flow retrieval store is inaccessible.",
|
|
196
206
|
error=e,
|
|
197
207
|
)
|
|
198
208
|
raise
|
|
209
|
+
|
|
199
210
|
self.persist()
|
|
200
211
|
return self._resource
|
|
201
212
|
|
|
@@ -233,7 +244,10 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
|
233
244
|
|
|
234
245
|
@classmethod
|
|
235
246
|
def load_flow_retrival(
|
|
236
|
-
cls,
|
|
247
|
+
cls,
|
|
248
|
+
config: Dict[str, Any],
|
|
249
|
+
model_storage: ModelStorage,
|
|
250
|
+
resource: Resource,
|
|
237
251
|
) -> Optional[FlowRetrieval]:
|
|
238
252
|
"""Load the FlowRetrieval component if it is enabled in the configuration."""
|
|
239
253
|
enable_flow_retrieval = config.get(FLOW_RETRIEVAL_KEY, {}).get(
|
|
@@ -24,6 +24,7 @@ from rasa.dialogue_understanding.generator.constants import (
|
|
|
24
24
|
LLM_CONFIG_KEY,
|
|
25
25
|
USER_INPUT_CONFIG_KEY,
|
|
26
26
|
FLOW_RETRIEVAL_KEY,
|
|
27
|
+
DEFAULT_LLM_CONFIG,
|
|
27
28
|
)
|
|
28
29
|
from rasa.dialogue_understanding.generator.flow_retrieval import FlowRetrieval
|
|
29
30
|
from rasa.dialogue_understanding.generator.llm_based_command_generator import (
|
|
@@ -39,7 +40,10 @@ from rasa.engine.graph import ExecutionContext
|
|
|
39
40
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
40
41
|
from rasa.engine.storage.resource import Resource
|
|
41
42
|
from rasa.engine.storage.storage import ModelStorage
|
|
42
|
-
from rasa.shared.constants import
|
|
43
|
+
from rasa.shared.constants import (
|
|
44
|
+
RASA_PATTERN_CANNOT_HANDLE_NOT_SUPPORTED,
|
|
45
|
+
EMBEDDINGS_CONFIG_KEY,
|
|
46
|
+
)
|
|
43
47
|
from rasa.shared.constants import ROUTE_TO_CALM_SLOT
|
|
44
48
|
from rasa.shared.core.flows import FlowStep, Flow, FlowsList
|
|
45
49
|
from rasa.shared.core.flows.steps.collect import CollectInformationFlowStep
|
|
@@ -53,6 +57,7 @@ from rasa.shared.utils.llm import (
|
|
|
53
57
|
tracker_as_readable_transcript,
|
|
54
58
|
sanitize_message_for_prompt,
|
|
55
59
|
allowed_values_for_slot,
|
|
60
|
+
resolve_model_client_config,
|
|
56
61
|
)
|
|
57
62
|
|
|
58
63
|
# multistep template keys
|
|
@@ -70,6 +75,7 @@ DEFAULT_HANDLE_FLOWS_TEMPLATE = importlib.resources.read_text(
|
|
|
70
75
|
DEFAULT_FILL_SLOTS_TEMPLATE = importlib.resources.read_text(
|
|
71
76
|
"rasa.dialogue_understanding.generator.multi_step", "fill_slots_prompt.jinja2"
|
|
72
77
|
).strip()
|
|
78
|
+
MULTI_STEP_LLM_COMMAND_GENERATOR_CONFIG_FILE = "config.json"
|
|
73
79
|
|
|
74
80
|
# dictionary of template names and associated file names and default values
|
|
75
81
|
PROMPT_TEMPLATES = {
|
|
@@ -138,7 +144,18 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
138
144
|
**kwargs: Any,
|
|
139
145
|
) -> "MultiStepLLMCommandGenerator":
|
|
140
146
|
"""Loads trained component (see parent class for full docstring)."""
|
|
147
|
+
|
|
148
|
+
# Perform health check of the LLM client config
|
|
149
|
+
llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
|
|
150
|
+
cls.perform_llm_health_check(
|
|
151
|
+
llm_config,
|
|
152
|
+
DEFAULT_LLM_CONFIG,
|
|
153
|
+
"multi_step_llm_command_generator.load",
|
|
154
|
+
MultiStepLLMCommandGenerator.__name__,
|
|
155
|
+
)
|
|
156
|
+
|
|
141
157
|
prompts = cls._load_prompt_templates(model_storage, resource)
|
|
158
|
+
|
|
142
159
|
# init base command generator
|
|
143
160
|
command_generator = cls(config, model_storage, resource, prompts)
|
|
144
161
|
# load flow retrieval if enabled
|
|
@@ -146,13 +163,13 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
146
163
|
command_generator.flow_retrieval = cls.load_flow_retrival(
|
|
147
164
|
command_generator.config, model_storage, resource
|
|
148
165
|
)
|
|
166
|
+
|
|
149
167
|
return command_generator
|
|
150
168
|
|
|
151
169
|
def persist(self) -> None:
|
|
152
170
|
"""Persist this component to disk for future loading."""
|
|
153
|
-
# persist prompt template
|
|
154
171
|
self._persist_prompt_templates()
|
|
155
|
-
|
|
172
|
+
self._persist_config()
|
|
156
173
|
if self.flow_retrieval is not None:
|
|
157
174
|
self.flow_retrieval.persist()
|
|
158
175
|
|
|
@@ -229,9 +246,9 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
229
246
|
commands: List[Command] = []
|
|
230
247
|
|
|
231
248
|
slot_set_re = re.compile(
|
|
232
|
-
r"""SetSlot\(
|
|
249
|
+
r"""SetSlot\(['"]?([a-zA-Z_][a-zA-Z0-9_-]*)['"]?, ?['"]?(.*)['"]?\)"""
|
|
233
250
|
)
|
|
234
|
-
start_flow_re = re.compile(r"StartFlow\(([a-zA-Z0-9_-]
|
|
251
|
+
start_flow_re = re.compile(r"StartFlow\(['\"]?([a-zA-Z0-9_-]+)['\"]?\)")
|
|
235
252
|
change_flow_re = re.compile(r"ChangeFlow\(\)")
|
|
236
253
|
cancel_flow_re = re.compile(r"CancelFlow\(\)")
|
|
237
254
|
chitchat_re = re.compile(r"ChitChat\(\)")
|
|
@@ -280,9 +297,19 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
280
297
|
commands.append(HumanHandoffCommand())
|
|
281
298
|
elif match := clarify_re.search(action):
|
|
282
299
|
options = sorted([opt.strip() for opt in match.group(1).split(",")])
|
|
300
|
+
# Remove surrounding quotes if present
|
|
301
|
+
cleaned_options = []
|
|
302
|
+
for flow in options:
|
|
303
|
+
if (flow.startswith('"') and flow.endswith('"')) or (
|
|
304
|
+
flow.startswith("'") and flow.endswith("'")
|
|
305
|
+
):
|
|
306
|
+
cleaned_options.append(flow[1:-1])
|
|
307
|
+
else:
|
|
308
|
+
cleaned_options.append(flow)
|
|
309
|
+
# check if flow is valid
|
|
283
310
|
valid_options = [
|
|
284
311
|
flow
|
|
285
|
-
for flow in
|
|
312
|
+
for flow in cleaned_options
|
|
286
313
|
if flow in flows.user_flow_ids
|
|
287
314
|
and flow not in user_flows_on_the_stack(tracker.stack)
|
|
288
315
|
]
|
|
@@ -293,6 +320,13 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
293
320
|
elif change_flow_re.search(action):
|
|
294
321
|
commands.append(ChangeFlowCommand())
|
|
295
322
|
|
|
323
|
+
if not commands:
|
|
324
|
+
structlogger.debug(
|
|
325
|
+
"multi_step_llm_command_generator.parse_commands",
|
|
326
|
+
message="No commands were parsed from the LLM actions.",
|
|
327
|
+
actions=actions,
|
|
328
|
+
)
|
|
329
|
+
|
|
296
330
|
return commands
|
|
297
331
|
|
|
298
332
|
### Helper methods
|
|
@@ -368,6 +402,13 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
368
402
|
file_path = path / file_name
|
|
369
403
|
rasa.shared.utils.io.write_text_file(template, file_path)
|
|
370
404
|
|
|
405
|
+
def _persist_config(self) -> None:
|
|
406
|
+
"""Persist config as a source of truth for resolved clients."""
|
|
407
|
+
with self._model_storage.write_to(self._resource) as path:
|
|
408
|
+
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
409
|
+
path / MULTI_STEP_LLM_COMMAND_GENERATOR_CONFIG_FILE, self.config
|
|
410
|
+
)
|
|
411
|
+
|
|
371
412
|
async def _predict_commands_with_multi_step(
|
|
372
413
|
self,
|
|
373
414
|
message: Message,
|
|
@@ -761,11 +802,17 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
761
802
|
.get(FILE_PATH_KEY),
|
|
762
803
|
DEFAULT_FILL_SLOTS_TEMPLATE,
|
|
763
804
|
)
|
|
805
|
+
|
|
806
|
+
llm_config = resolve_model_client_config(
|
|
807
|
+
config.get(LLM_CONFIG_KEY), MultiStepLLMCommandGenerator.__name__
|
|
808
|
+
)
|
|
809
|
+
embedding_config = resolve_model_client_config(
|
|
810
|
+
config.get(FLOW_RETRIEVAL_KEY, {}).get(EMBEDDINGS_CONFIG_KEY),
|
|
811
|
+
FlowRetrieval.__name__,
|
|
812
|
+
)
|
|
813
|
+
|
|
764
814
|
return deep_container_fingerprint(
|
|
765
|
-
[
|
|
766
|
-
handle_flows_template,
|
|
767
|
-
fill_slots_template,
|
|
768
|
-
]
|
|
815
|
+
[handle_flows_template, fill_slots_template, llm_config, embedding_config]
|
|
769
816
|
)
|
|
770
817
|
|
|
771
818
|
@staticmethod
|
|
@@ -19,6 +19,7 @@ from rasa.engine.storage.storage import ModelStorage
|
|
|
19
19
|
from rasa.shared.constants import ROUTE_TO_CALM_SLOT
|
|
20
20
|
from rasa.shared.core.domain import Domain
|
|
21
21
|
from rasa.shared.core.flows.flows_list import FlowsList
|
|
22
|
+
from rasa.shared.core.flows.steps import CollectInformationFlowStep
|
|
22
23
|
from rasa.shared.core.slot_mappings import (
|
|
23
24
|
SlotFillingManager,
|
|
24
25
|
extract_slot_value,
|
|
@@ -217,7 +218,24 @@ def _issue_set_slot_commands(
|
|
|
217
218
|
commands: List[Command] = []
|
|
218
219
|
domain = domain if domain else Domain.empty()
|
|
219
220
|
slot_filling_manager = SlotFillingManager(domain, tracker, message)
|
|
220
|
-
|
|
221
|
+
|
|
222
|
+
# only use slots that don't have ask_before_filling set to True
|
|
223
|
+
available_slot_names = flows.available_slot_names(ask_before_filling=False)
|
|
224
|
+
|
|
225
|
+
# check if the current step is a CollectInformationFlowStep
|
|
226
|
+
# in case it has ask_before_filling set to True, we need to add the
|
|
227
|
+
# slot to the available_slot_names
|
|
228
|
+
if tracker.active_flow:
|
|
229
|
+
flow = flows.flow_by_id(tracker.active_flow)
|
|
230
|
+
step_id = tracker.current_step_id
|
|
231
|
+
if flow is not None:
|
|
232
|
+
current_step = flow.step_by_id(step_id)
|
|
233
|
+
if (
|
|
234
|
+
current_step
|
|
235
|
+
and isinstance(current_step, CollectInformationFlowStep)
|
|
236
|
+
and current_step.ask_before_filling
|
|
237
|
+
):
|
|
238
|
+
available_slot_names.add(current_step.collect)
|
|
221
239
|
|
|
222
240
|
for _, slot in tracker.slots.items():
|
|
223
241
|
# if a slot is not collected in available flows,
|
|
@@ -22,6 +22,7 @@ from rasa.dialogue_understanding.generator.constants import (
|
|
|
22
22
|
LLM_CONFIG_KEY,
|
|
23
23
|
USER_INPUT_CONFIG_KEY,
|
|
24
24
|
FLOW_RETRIEVAL_KEY,
|
|
25
|
+
DEFAULT_LLM_CONFIG,
|
|
25
26
|
)
|
|
26
27
|
from rasa.dialogue_understanding.generator.flow_retrieval import (
|
|
27
28
|
FlowRetrieval,
|
|
@@ -38,6 +39,7 @@ from rasa.shared.constants import (
|
|
|
38
39
|
ROUTE_TO_CALM_SLOT,
|
|
39
40
|
PROMPT_CONFIG_KEY,
|
|
40
41
|
PROMPT_TEMPLATE_CONFIG_KEY,
|
|
42
|
+
EMBEDDINGS_CONFIG_KEY,
|
|
41
43
|
)
|
|
42
44
|
from rasa.shared.core.flows import FlowsList
|
|
43
45
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
@@ -49,9 +51,10 @@ from rasa.shared.utils.llm import (
|
|
|
49
51
|
get_prompt_template,
|
|
50
52
|
tracker_as_readable_transcript,
|
|
51
53
|
sanitize_message_for_prompt,
|
|
54
|
+
resolve_model_client_config,
|
|
52
55
|
)
|
|
53
|
-
from rasa.utils.log_utils import log_llm
|
|
54
56
|
from rasa.utils.beta import ensure_beta_feature_is_enabled, BetaNotEnabledException
|
|
57
|
+
from rasa.utils.log_utils import log_llm
|
|
55
58
|
|
|
56
59
|
COMMAND_PROMPT_FILE_NAME = "command_prompt.jinja2"
|
|
57
60
|
|
|
@@ -59,6 +62,7 @@ DEFAULT_COMMAND_PROMPT_TEMPLATE = importlib.resources.read_text(
|
|
|
59
62
|
"rasa.dialogue_understanding.generator.single_step",
|
|
60
63
|
"command_prompt_template.jinja2",
|
|
61
64
|
)
|
|
65
|
+
SINGLE_STEP_LLM_COMMAND_GENERATOR_CONFIG_FILE = "config.json"
|
|
62
66
|
|
|
63
67
|
structlogger = structlog.get_logger()
|
|
64
68
|
|
|
@@ -132,10 +136,21 @@ class SingleStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
132
136
|
**kwargs: Any,
|
|
133
137
|
) -> "SingleStepLLMCommandGenerator":
|
|
134
138
|
"""Loads trained component (see parent class for full docstring)."""
|
|
139
|
+
|
|
140
|
+
# Perform health check of the LLM API endpoint
|
|
141
|
+
llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
|
|
142
|
+
cls.perform_llm_health_check(
|
|
143
|
+
llm_config,
|
|
144
|
+
DEFAULT_LLM_CONFIG,
|
|
145
|
+
"single_step_llm_command_generator.load",
|
|
146
|
+
SingleStepLLMCommandGenerator.__name__,
|
|
147
|
+
)
|
|
148
|
+
|
|
135
149
|
# load prompt template from the model storage.
|
|
136
150
|
prompt_template = cls.load_prompt_template_from_model_storage(
|
|
137
151
|
model_storage, resource, COMMAND_PROMPT_FILE_NAME
|
|
138
152
|
)
|
|
153
|
+
|
|
139
154
|
# init base command generator
|
|
140
155
|
command_generator = cls(config, model_storage, resource, prompt_template)
|
|
141
156
|
# load flow retrieval if enabled
|
|
@@ -143,18 +158,29 @@ class SingleStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
143
158
|
command_generator.flow_retrieval = cls.load_flow_retrival(
|
|
144
159
|
command_generator.config, model_storage, resource
|
|
145
160
|
)
|
|
161
|
+
|
|
146
162
|
return command_generator
|
|
147
163
|
|
|
148
164
|
def persist(self) -> None:
|
|
149
165
|
"""Persist this component to disk for future loading."""
|
|
150
|
-
|
|
166
|
+
self._persist_prompt_template()
|
|
167
|
+
self._persist_config()
|
|
168
|
+
if self.flow_retrieval is not None:
|
|
169
|
+
self.flow_retrieval.persist()
|
|
170
|
+
|
|
171
|
+
def _persist_prompt_template(self) -> None:
|
|
172
|
+
"""Persist prompt template for future loading."""
|
|
151
173
|
with self._model_storage.write_to(self._resource) as path:
|
|
152
174
|
rasa.shared.utils.io.write_text_file(
|
|
153
175
|
self.prompt_template, path / COMMAND_PROMPT_FILE_NAME
|
|
154
176
|
)
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
177
|
+
|
|
178
|
+
def _persist_config(self) -> None:
|
|
179
|
+
"""Persist config as a source of truth for resolved clients."""
|
|
180
|
+
with self._model_storage.write_to(self._resource) as path:
|
|
181
|
+
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
182
|
+
path / SINGLE_STEP_LLM_COMMAND_GENERATOR_CONFIG_FILE, self.config
|
|
183
|
+
)
|
|
158
184
|
|
|
159
185
|
async def predict_commands(
|
|
160
186
|
self,
|
|
@@ -187,6 +213,12 @@ class SingleStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
187
213
|
|
|
188
214
|
if not commands:
|
|
189
215
|
# no commands are parsed or there's an invalid command
|
|
216
|
+
structlogger.warning(
|
|
217
|
+
"single_step_llm_command_generator.predict_commands",
|
|
218
|
+
message="No commands were predicted as the LLM response could "
|
|
219
|
+
"not be parsed or the LLM responded with an invalid command."
|
|
220
|
+
"Returning a CannotHandleCommand instead.",
|
|
221
|
+
)
|
|
190
222
|
commands = [CannotHandleCommand()]
|
|
191
223
|
|
|
192
224
|
if tracker.has_coexistence_routing_slot:
|
|
@@ -287,14 +319,16 @@ class SingleStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
287
319
|
|
|
288
320
|
commands: List[Command] = []
|
|
289
321
|
|
|
290
|
-
slot_set_re = re.compile(
|
|
291
|
-
|
|
322
|
+
slot_set_re = re.compile(
|
|
323
|
+
r"""SetSlot\(['"]?([a-zA-Z_][a-zA-Z0-9_-]*)['"]?, ?['"]?(.*)['"]?\)"""
|
|
324
|
+
)
|
|
325
|
+
start_flow_re = re.compile(r"StartFlow\(['\"]?([a-zA-Z0-9_-]+)['\"]?\)")
|
|
292
326
|
cancel_flow_re = re.compile(r"CancelFlow\(\)")
|
|
293
327
|
chitchat_re = re.compile(r"ChitChat\(\)")
|
|
294
328
|
skip_question_re = re.compile(r"SkipQuestion\(\)")
|
|
295
329
|
knowledge_re = re.compile(r"SearchAndReply\(\)")
|
|
296
330
|
humand_handoff_re = re.compile(r"HumanHandoff\(\)")
|
|
297
|
-
clarify_re = re.compile(r"Clarify\(([a-zA-Z0-9_, ]+)\)")
|
|
331
|
+
clarify_re = re.compile(r"Clarify\(([\"\'a-zA-Z0-9_, ]+)\)")
|
|
298
332
|
repeat_re = re.compile(r"RepeatLastBotMessages\(\)")
|
|
299
333
|
|
|
300
334
|
for action in actions.strip().splitlines():
|
|
@@ -326,19 +360,36 @@ class SingleStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
326
360
|
commands.append(RepeatBotMessagesCommand())
|
|
327
361
|
elif match := clarify_re.search(action):
|
|
328
362
|
options = sorted([opt.strip() for opt in match.group(1).split(",")])
|
|
363
|
+
# Remove surrounding quotes if present
|
|
364
|
+
cleaned_options = []
|
|
365
|
+
for flow in options:
|
|
366
|
+
if (flow.startswith('"') and flow.endswith('"')) or (
|
|
367
|
+
flow.startswith("'") and flow.endswith("'")
|
|
368
|
+
):
|
|
369
|
+
cleaned_options.append(flow[1:-1])
|
|
370
|
+
else:
|
|
371
|
+
cleaned_options.append(flow)
|
|
372
|
+
# check if flow is valid
|
|
329
373
|
valid_options = [
|
|
330
|
-
flow for flow in
|
|
374
|
+
flow for flow in cleaned_options if flow in flows.user_flow_ids
|
|
331
375
|
]
|
|
332
376
|
if len(set(valid_options)) == 1:
|
|
333
377
|
commands.extend(cls.start_flow_by_name(valid_options[0], flows))
|
|
334
378
|
elif len(valid_options) > 1:
|
|
335
379
|
commands.append(ClarifyCommand(valid_options))
|
|
336
380
|
|
|
381
|
+
if not commands:
|
|
382
|
+
structlogger.debug(
|
|
383
|
+
"single_step_llm_command_generator.parse_commands",
|
|
384
|
+
message="No commands were parsed from the LLM actions.",
|
|
385
|
+
actions=actions,
|
|
386
|
+
)
|
|
387
|
+
|
|
337
388
|
return commands
|
|
338
389
|
|
|
339
390
|
@classmethod
|
|
340
391
|
def fingerprint_addon(cls: Any, config: Dict[str, Any]) -> Optional[str]:
|
|
341
|
-
"""Add a fingerprint
|
|
392
|
+
"""Add a fingerprint for the graph."""
|
|
342
393
|
config_prompt = (
|
|
343
394
|
config.get(PROMPT_CONFIG_KEY)
|
|
344
395
|
or config.get(PROMPT_TEMPLATE_CONFIG_KEY)
|
|
@@ -348,7 +399,16 @@ class SingleStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
348
399
|
config_prompt,
|
|
349
400
|
DEFAULT_COMMAND_PROMPT_TEMPLATE,
|
|
350
401
|
)
|
|
351
|
-
|
|
402
|
+
llm_config = resolve_model_client_config(
|
|
403
|
+
config.get(LLM_CONFIG_KEY), SingleStepLLMCommandGenerator.__name__
|
|
404
|
+
)
|
|
405
|
+
embedding_config = resolve_model_client_config(
|
|
406
|
+
config.get(FLOW_RETRIEVAL_KEY, {}).get(EMBEDDINGS_CONFIG_KEY),
|
|
407
|
+
FlowRetrieval.__name__,
|
|
408
|
+
)
|
|
409
|
+
return deep_container_fingerprint(
|
|
410
|
+
[prompt_template, llm_config, embedding_config]
|
|
411
|
+
)
|
|
352
412
|
|
|
353
413
|
### Helper methods
|
|
354
414
|
def render_template(
|