rasa-pro 3.11.0__py3-none-any.whl → 3.11.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +396 -17
- rasa/__main__.py +15 -31
- rasa/api.py +1 -5
- rasa/cli/arguments/default_arguments.py +2 -1
- rasa/cli/arguments/shell.py +1 -5
- rasa/cli/arguments/train.py +0 -14
- rasa/cli/e2e_test.py +1 -1
- rasa/cli/evaluate.py +8 -8
- rasa/cli/inspect.py +5 -7
- rasa/cli/interactive.py +0 -1
- rasa/cli/llm_fine_tuning.py +1 -1
- rasa/cli/project_templates/calm/config.yml +7 -5
- rasa/cli/project_templates/calm/endpoints.yml +2 -15
- rasa/cli/project_templates/tutorial/config.yml +5 -8
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +0 -5
- rasa/cli/project_templates/tutorial/domain.yml +0 -14
- rasa/cli/project_templates/tutorial/endpoints.yml +0 -5
- rasa/cli/run.py +1 -1
- rasa/cli/scaffold.py +2 -4
- rasa/cli/studio/studio.py +8 -18
- rasa/cli/studio/upload.py +15 -0
- rasa/cli/train.py +0 -3
- rasa/cli/utils.py +1 -6
- rasa/cli/x.py +8 -8
- rasa/constants.py +1 -3
- rasa/core/actions/action.py +33 -75
- rasa/core/actions/e2e_stub_custom_action_executor.py +1 -5
- rasa/core/actions/http_custom_action_executor.py +0 -4
- rasa/core/channels/channel.py +0 -20
- rasa/core/channels/development_inspector.py +2 -8
- rasa/core/channels/inspector/dist/assets/{arc-bc141fb2.js → arc-6852c607.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-be2db283.js → c4Diagram-d0fbc5ce-acc952b2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-55366915.js → classDiagram-936ed81e-848a7597.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-bb529518.js → classDiagram-v2-c3cb15f1-a73d3e68.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-b0ec81d6.js → createText-62fc7601-e5ee049d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-6166330c.js → edges-f2ad444c-771e517e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-5ccc6a8e.js → erDiagram-9d236eb7-aa347178.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-fca3bfe4.js → flowDb-1972c806-651fc57d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-4739080f.js → flowDiagram-7ea5b25a-ca67804f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-587d82d8.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-7c1b0e0f.js → flowchart-elk-definition-abe16c3d-2dbc568d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-772fd050.js → ganttDiagram-9b5ea136-25a65bd8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-8eae1dc9.js → gitGraphDiagram-99d0ae7c-fdc7378d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-f55afcdf.js → index-2c4b9a3b-6f1fd606.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-e7cef9de.js → index-efdd30c1.js} +68 -68
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-124d4a14.js → infoDiagram-736b4530-cb1a041a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-7c4fae44.js → journeyDiagram-df861f2b-14609879.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-b9885fb6.js → layout-2490f52b.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-7c59abb6.js → line-40186f1f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-4776f780.js → linear-08814e93.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-2332c46c.js → mindmap-definition-beec6740-1a534584.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-8fb39303.js → pieDiagram-dbbf0591-72397b61.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3c7180a2.js → quadrantDiagram-4d7f4fd6-3bb0b6a3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-e910bcb8.js → requirementDiagram-6fc4c22a-57334f61.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-ead16c89.js → sankeyDiagram-8f13d901-111e1297.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-29a02a19.js → sequenceDiagram-b655622a-10bcfe62.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-042b3137.js → stateDiagram-59f0c015-acaf7513.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-2178c0f3.js → stateDiagram-v2-2b26beab-3ec2a235.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-23ffa4fc.js → styles-080da4f6-62730289.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-94f59763.js → styles-3dcbcfbf-5284ee76.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-78a6bebc.js → styles-9c745c82-642435e3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-eae2a6f6.js → svgDrawCommon-4835440b-b250a350.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-5c968d92.js → timeline-definition-5b62e21b-c2b147ed.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-fd3db0d5.js → xychartDiagram-2b33534f-f92cfea9.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/App.tsx +1 -1
- rasa/core/channels/inspector/src/helpers/audiostream.ts +16 -77
- rasa/core/channels/socketio.py +2 -7
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/voice_ready/audiocodes.py +4 -15
- rasa/core/channels/voice_ready/jambonz.py +4 -15
- rasa/core/channels/voice_ready/twilio_voice.py +21 -6
- rasa/core/channels/voice_ready/utils.py +5 -6
- rasa/core/channels/voice_stream/asr/asr_engine.py +1 -19
- rasa/core/channels/voice_stream/asr/asr_event.py +0 -5
- rasa/core/channels/voice_stream/asr/deepgram.py +15 -28
- rasa/core/channels/voice_stream/audio_bytes.py +0 -1
- rasa/core/channels/voice_stream/browser_audio.py +9 -32
- rasa/core/channels/voice_stream/tts/azure.py +3 -9
- rasa/core/channels/voice_stream/tts/cartesia.py +8 -12
- rasa/core/channels/voice_stream/tts/tts_engine.py +1 -11
- rasa/core/channels/voice_stream/twilio_media_streams.py +19 -28
- rasa/core/channels/voice_stream/util.py +4 -4
- rasa/core/channels/voice_stream/voice_channel.py +42 -222
- rasa/core/featurizers/single_state_featurizer.py +1 -22
- rasa/core/featurizers/tracker_featurizers.py +18 -115
- rasa/core/information_retrieval/qdrant.py +0 -1
- rasa/core/nlg/contextual_response_rephraser.py +25 -44
- rasa/core/persistor.py +34 -191
- rasa/core/policies/enterprise_search_policy.py +60 -119
- rasa/core/policies/flows/flow_executor.py +4 -7
- rasa/core/policies/intentless_policy.py +22 -82
- rasa/core/policies/ted_policy.py +33 -58
- rasa/core/policies/unexpected_intent_policy.py +7 -15
- rasa/core/processor.py +5 -32
- rasa/core/training/interactive.py +35 -34
- rasa/core/utils.py +22 -58
- rasa/dialogue_understanding/coexistence/llm_based_router.py +12 -39
- rasa/dialogue_understanding/commands/__init__.py +0 -4
- rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
- rasa/dialogue_understanding/commands/utils.py +0 -5
- rasa/dialogue_understanding/generator/constants.py +0 -2
- rasa/dialogue_understanding/generator/flow_retrieval.py +4 -49
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +23 -37
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +10 -57
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +1 -19
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +0 -3
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +10 -90
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +0 -53
- rasa/dialogue_understanding/processor/command_processor.py +1 -21
- rasa/e2e_test/assertions.py +16 -133
- rasa/e2e_test/assertions_schema.yml +0 -23
- rasa/e2e_test/e2e_test_case.py +6 -85
- rasa/e2e_test/e2e_test_runner.py +4 -6
- rasa/e2e_test/utils/io.py +1 -3
- rasa/engine/loader.py +0 -12
- rasa/engine/validation.py +11 -541
- rasa/keys +1 -0
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
- rasa/model_training.py +7 -29
- rasa/nlu/classifiers/diet_classifier.py +25 -38
- rasa/nlu/classifiers/logistic_regression_classifier.py +9 -22
- rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
- rasa/nlu/extractors/crf_entity_extractor.py +50 -93
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +16 -45
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
- rasa/nlu/tokenizers/whitespace_tokenizer.py +14 -3
- rasa/server.py +1 -3
- rasa/shared/constants.py +0 -61
- rasa/shared/core/constants.py +0 -9
- rasa/shared/core/domain.py +5 -8
- rasa/shared/core/flows/flow.py +0 -5
- rasa/shared/core/flows/flows_list.py +1 -5
- rasa/shared/core/flows/flows_yaml_schema.json +0 -10
- rasa/shared/core/flows/validation.py +0 -96
- rasa/shared/core/flows/yaml_flows_io.py +4 -13
- rasa/shared/core/slots.py +0 -5
- rasa/shared/importers/importer.py +2 -19
- rasa/shared/importers/rasa.py +1 -5
- rasa/shared/nlu/training_data/features.py +2 -120
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +3 -18
- rasa/shared/providers/_configs/azure_openai_client_config.py +3 -5
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +0 -1
- rasa/shared/providers/_configs/utils.py +0 -16
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +29 -18
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +21 -54
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +0 -24
- rasa/shared/providers/llm/_base_litellm_client.py +31 -63
- rasa/shared/providers/llm/azure_openai_llm_client.py +29 -50
- rasa/shared/providers/llm/default_litellm_llm_client.py +0 -24
- rasa/shared/providers/llm/self_hosted_llm_client.py +29 -17
- rasa/shared/providers/mappings.py +0 -19
- rasa/shared/utils/common.py +2 -37
- rasa/shared/utils/io.py +6 -28
- rasa/shared/utils/llm.py +46 -353
- rasa/shared/utils/yaml.py +82 -181
- rasa/studio/auth.py +5 -3
- rasa/studio/config.py +4 -13
- rasa/studio/constants.py +0 -1
- rasa/studio/data_handler.py +4 -13
- rasa/studio/upload.py +80 -175
- rasa/telemetry.py +17 -94
- rasa/tracing/config.py +1 -3
- rasa/tracing/instrumentation/attribute_extractors.py +17 -94
- rasa/tracing/instrumentation/instrumentation.py +0 -121
- rasa/utils/common.py +0 -5
- rasa/utils/endpoints.py +1 -27
- rasa/utils/io.py +81 -7
- rasa/utils/log_utils.py +2 -9
- rasa/utils/tensorflow/model_data.py +193 -2
- rasa/validator.py +4 -110
- rasa/version.py +1 -1
- rasa_pro-3.11.0a2.dist-info/METADATA +576 -0
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a2.dist-info}/RECORD +181 -213
- rasa/core/actions/action_repeat_bot_messages.py +0 -89
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +0 -1
- rasa/core/channels/voice_stream/asr/azure.py +0 -129
- rasa/core/channels/voice_stream/call_state.py +0 -23
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +0 -60
- rasa/dialogue_understanding/commands/user_silence_command.py +0 -59
- rasa/dialogue_understanding/patterns/repeat.py +0 -37
- rasa/dialogue_understanding/patterns/user_silence.py +0 -37
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +0 -40
- rasa/model_manager/model_api.py +0 -559
- rasa/model_manager/runner_service.py +0 -286
- rasa/model_manager/socket_bridge.py +0 -146
- rasa/model_manager/studio_jwt_auth.py +0 -86
- rasa/model_manager/trainer_service.py +0 -325
- rasa/model_manager/utils.py +0 -87
- rasa/model_manager/warm_rasa_process.py +0 -187
- rasa/model_service.py +0 -112
- rasa/shared/core/flows/utils.py +0 -39
- rasa/shared/providers/_configs/litellm_router_client_config.py +0 -220
- rasa/shared/providers/_configs/model_group_config.py +0 -167
- rasa/shared/providers/_configs/rasa_llm_client_config.py +0 -73
- rasa/shared/providers/_utils.py +0 -79
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +0 -135
- rasa/shared/providers/llm/litellm_router_llm_client.py +0 -182
- rasa/shared/providers/llm/rasa_llm_client.py +0 -112
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +0 -183
- rasa/shared/providers/router/router_client.py +0 -73
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +0 -31
- rasa/shared/utils/health_check/health_check.py +0 -258
- rasa/shared/utils/health_check/llm_health_check_mixin.py +0 -31
- rasa/utils/sanic_error_handler.py +0 -32
- rasa/utils/tensorflow/feature_array.py +0 -366
- rasa_pro-3.11.0.dist-info/METADATA +0 -198
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a2.dist-info}/entry_points.txt +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import importlib
|
|
4
|
+
import os
|
|
4
5
|
from typing import Any, Dict, List, Optional
|
|
5
6
|
|
|
6
7
|
import structlog
|
|
@@ -15,14 +16,13 @@ from rasa.dialogue_understanding.coexistence.constants import (
|
|
|
15
16
|
)
|
|
16
17
|
from rasa.dialogue_understanding.commands import Command, SetSlotCommand
|
|
17
18
|
from rasa.dialogue_understanding.commands.noop_command import NoopCommand
|
|
18
|
-
from rasa.dialogue_understanding.generator.constants import
|
|
19
|
-
LLM_CONFIG_KEY,
|
|
20
|
-
)
|
|
19
|
+
from rasa.dialogue_understanding.generator.constants import LLM_CONFIG_KEY
|
|
21
20
|
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
22
21
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
23
22
|
from rasa.engine.storage.resource import Resource
|
|
24
23
|
from rasa.engine.storage.storage import ModelStorage
|
|
25
24
|
from rasa.shared.constants import (
|
|
25
|
+
LLM_API_HEALTH_CHECK_ENV_VAR,
|
|
26
26
|
ROUTE_TO_CALM_SLOT,
|
|
27
27
|
PROMPT_CONFIG_KEY,
|
|
28
28
|
PROVIDER_CONFIG_KEY,
|
|
@@ -35,13 +35,12 @@ from rasa.shared.exceptions import InvalidConfigException, FileIOException
|
|
|
35
35
|
from rasa.shared.nlu.constants import COMMANDS, TEXT
|
|
36
36
|
from rasa.shared.nlu.training_data.message import Message
|
|
37
37
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
38
|
-
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
39
|
-
from rasa.shared.utils.io import deep_container_fingerprint
|
|
40
38
|
from rasa.shared.utils.llm import (
|
|
41
39
|
DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
42
40
|
get_prompt_template,
|
|
41
|
+
llm_api_health_check,
|
|
43
42
|
llm_factory,
|
|
44
|
-
|
|
43
|
+
try_instantiate_llm_client,
|
|
45
44
|
)
|
|
46
45
|
from rasa.utils.log_utils import log_llm
|
|
47
46
|
|
|
@@ -49,7 +48,6 @@ LLM_BASED_ROUTER_PROMPT_FILE_NAME = "llm_based_router_prompt.jinja2"
|
|
|
49
48
|
DEFAULT_COMMAND_PROMPT_TEMPLATE = importlib.resources.read_text(
|
|
50
49
|
"rasa.dialogue_understanding.coexistence", "router_template.jinja2"
|
|
51
50
|
)
|
|
52
|
-
LLM_BASED_ROUTER_CONFIG_FILE_NAME = "config.json"
|
|
53
51
|
|
|
54
52
|
# Token ids for gpt 3.5 and gpt 4 corresponding to space + capitalized Letter
|
|
55
53
|
A_TO_C_TOKEN_IDS_CHATGPT = [
|
|
@@ -76,7 +74,7 @@ structlogger = structlog.get_logger()
|
|
|
76
74
|
],
|
|
77
75
|
is_trainable=True,
|
|
78
76
|
)
|
|
79
|
-
class LLMBasedRouter(
|
|
77
|
+
class LLMBasedRouter(GraphComponent):
|
|
80
78
|
@staticmethod
|
|
81
79
|
def get_default_config() -> Dict[str, Any]:
|
|
82
80
|
"""The component's default config (see parent class for full docstring)."""
|
|
@@ -98,9 +96,6 @@ class LLMBasedRouter(LLMHealthCheckMixin, GraphComponent):
|
|
|
98
96
|
prompt_template: Optional[str] = None,
|
|
99
97
|
) -> None:
|
|
100
98
|
self.config = {**self.get_default_config(), **config}
|
|
101
|
-
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
102
|
-
self.config.get(LLM_CONFIG_KEY), LLMBasedRouter.__name__
|
|
103
|
-
)
|
|
104
99
|
|
|
105
100
|
self.prompt_template = (
|
|
106
101
|
prompt_template
|
|
@@ -134,18 +129,20 @@ class LLMBasedRouter(LLMHealthCheckMixin, GraphComponent):
|
|
|
134
129
|
rasa.shared.utils.io.write_text_file(
|
|
135
130
|
self.prompt_template, path / LLM_BASED_ROUTER_PROMPT_FILE_NAME
|
|
136
131
|
)
|
|
137
|
-
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
138
|
-
path / LLM_BASED_ROUTER_CONFIG_FILE_NAME, self.config
|
|
139
|
-
)
|
|
140
132
|
|
|
141
133
|
def train(self, training_data: TrainingData) -> Resource:
|
|
142
134
|
"""Train the intent classifier on a data set."""
|
|
143
|
-
|
|
135
|
+
# Validate llm configuration
|
|
136
|
+
llm_client = try_instantiate_llm_client(
|
|
144
137
|
self.config.get(LLM_CONFIG_KEY),
|
|
145
138
|
DEFAULT_LLM_CONFIG,
|
|
146
139
|
"llm_based_router.train",
|
|
147
140
|
LLMBasedRouter.__name__,
|
|
148
141
|
)
|
|
142
|
+
if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
|
|
143
|
+
llm_api_health_check(
|
|
144
|
+
llm_client, "llm_based_router.train", LLMBasedRouter.__name__
|
|
145
|
+
)
|
|
149
146
|
|
|
150
147
|
self.persist()
|
|
151
148
|
return self._resource
|
|
@@ -160,16 +157,6 @@ class LLMBasedRouter(LLMHealthCheckMixin, GraphComponent):
|
|
|
160
157
|
**kwargs: Any,
|
|
161
158
|
) -> "LLMBasedRouter":
|
|
162
159
|
"""Loads trained component (see parent class for full docstring)."""
|
|
163
|
-
|
|
164
|
-
# Perform health check on the resolved LLM client config
|
|
165
|
-
llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
|
|
166
|
-
cls.perform_llm_health_check(
|
|
167
|
-
llm_config,
|
|
168
|
-
DEFAULT_LLM_CONFIG,
|
|
169
|
-
"llm_based_router.load",
|
|
170
|
-
LLMBasedRouter.__name__,
|
|
171
|
-
)
|
|
172
|
-
|
|
173
160
|
prompt_template = None
|
|
174
161
|
try:
|
|
175
162
|
with model_storage.read_from(resource) as path:
|
|
@@ -311,17 +298,3 @@ class LLMBasedRouter(LLMHealthCheckMixin, GraphComponent):
|
|
|
311
298
|
# we have to catch all exceptions here
|
|
312
299
|
structlogger.error("llm_based_router.llm.error", error=e)
|
|
313
300
|
return None
|
|
314
|
-
|
|
315
|
-
@classmethod
|
|
316
|
-
def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
|
|
317
|
-
"""Add a fingerprint of llm based router for the graph."""
|
|
318
|
-
prompt_template = get_prompt_template(
|
|
319
|
-
config.get(PROMPT_CONFIG_KEY),
|
|
320
|
-
DEFAULT_COMMAND_PROMPT_TEMPLATE,
|
|
321
|
-
)
|
|
322
|
-
|
|
323
|
-
llm_config = resolve_model_client_config(
|
|
324
|
-
config.get(LLM_CONFIG_KEY), LLMBasedRouter.__name__
|
|
325
|
-
)
|
|
326
|
-
|
|
327
|
-
return deep_container_fingerprint([prompt_template, llm_config])
|
|
@@ -33,9 +33,6 @@ from rasa.dialogue_understanding.commands.session_start_command import (
|
|
|
33
33
|
SessionStartCommand,
|
|
34
34
|
)
|
|
35
35
|
from rasa.dialogue_understanding.commands.session_end_command import SessionEndCommand
|
|
36
|
-
from rasa.dialogue_understanding.commands.repeat_bot_messages_command import (
|
|
37
|
-
RepeatBotMessagesCommand,
|
|
38
|
-
)
|
|
39
36
|
|
|
40
37
|
__all__ = [
|
|
41
38
|
"Command",
|
|
@@ -56,6 +53,5 @@ __all__ = [
|
|
|
56
53
|
"ChangeFlowCommand",
|
|
57
54
|
"SessionStartCommand",
|
|
58
55
|
"SessionEndCommand",
|
|
59
|
-
"RepeatBotMessagesCommand",
|
|
60
56
|
"RestartCommand",
|
|
61
57
|
]
|
|
@@ -36,9 +36,3 @@ class ChangeFlowCommand(Command):
|
|
|
36
36
|
# the change flow command is not actually pushing anything to the tracker,
|
|
37
37
|
# but it is predicted by the MultiStepLLMCommandGenerator and used internally
|
|
38
38
|
return []
|
|
39
|
-
|
|
40
|
-
def __eq__(self, other: Any) -> bool:
|
|
41
|
-
return isinstance(other, ChangeFlowCommand)
|
|
42
|
-
|
|
43
|
-
def __hash__(self) -> int:
|
|
44
|
-
return hash(self.command())
|
|
@@ -11,7 +11,6 @@ from rasa.dialogue_understanding.commands import (
|
|
|
11
11
|
SkipQuestionCommand,
|
|
12
12
|
RestartCommand,
|
|
13
13
|
)
|
|
14
|
-
from rasa.dialogue_understanding.commands.user_silence_command import UserSilenceCommand
|
|
15
14
|
from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
|
|
16
15
|
from rasa.dialogue_understanding.patterns.cannot_handle import (
|
|
17
16
|
CannotHandlePatternFlowStackFrame,
|
|
@@ -28,13 +27,9 @@ from rasa.dialogue_understanding.patterns.session_start import (
|
|
|
28
27
|
from rasa.dialogue_understanding.patterns.skip_question import (
|
|
29
28
|
SkipQuestionPatternFlowStackFrame,
|
|
30
29
|
)
|
|
31
|
-
from rasa.dialogue_understanding.patterns.user_silence import (
|
|
32
|
-
UserSilencePatternFlowStackFrame,
|
|
33
|
-
)
|
|
34
30
|
|
|
35
31
|
triggerable_pattern_to_command_class: Dict[str, Type[Command]] = {
|
|
36
32
|
SessionStartPatternFlowStackFrame.flow_id: SessionStartCommand,
|
|
37
|
-
UserSilencePatternFlowStackFrame.flow_id: UserSilenceCommand,
|
|
38
33
|
CancelPatternFlowStackFrame.flow_id: CancelFlowCommand,
|
|
39
34
|
ChitchatPatternFlowStackFrame.flow_id: ChitChatAnswerCommand,
|
|
40
35
|
HumanHandoffPatternFlowStackFrame.flow_id: HumanHandoffCommand,
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""
|
|
2
|
+
The module is primarily centered around the `FlowRetrieval` class which handles the
|
|
2
3
|
initialization, configuration validation, vector store management, and flow retrieval
|
|
3
4
|
logic. It integrates components for managing embeddings, vector stores, and
|
|
4
5
|
flow-specific templates, facilitating semantic search functionalities.
|
|
@@ -26,10 +27,8 @@ from langchain.docstore.document import Document
|
|
|
26
27
|
from langchain.schema.embeddings import Embeddings
|
|
27
28
|
from langchain_community.vectorstores.faiss import FAISS
|
|
28
29
|
from langchain_community.vectorstores.utils import DistanceStrategy
|
|
29
|
-
|
|
30
30
|
from rasa.engine.storage.resource import Resource
|
|
31
31
|
from rasa.engine.storage.storage import ModelStorage
|
|
32
|
-
import rasa.shared.utils.io
|
|
33
32
|
from rasa.shared.constants import (
|
|
34
33
|
EMBEDDINGS_CONFIG_KEY,
|
|
35
34
|
PROVIDER_CONFIG_KEY,
|
|
@@ -38,15 +37,12 @@ from rasa.shared.constants import (
|
|
|
38
37
|
from rasa.shared.core.domain import Domain
|
|
39
38
|
from rasa.shared.core.flows import FlowsList
|
|
40
39
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
41
|
-
from rasa.shared.exceptions import ProviderClientAPIException
|
|
42
40
|
from rasa.shared.nlu.constants import TEXT, FLOWS_FROM_SEMANTIC_SEARCH
|
|
43
41
|
from rasa.shared.nlu.training_data.message import Message
|
|
42
|
+
from rasa.shared.exceptions import ProviderClientAPIException
|
|
44
43
|
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
45
44
|
_LangchainEmbeddingClientAdapter,
|
|
46
45
|
)
|
|
47
|
-
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
48
|
-
EmbeddingsHealthCheckMixin,
|
|
49
|
-
)
|
|
50
46
|
from rasa.shared.utils.llm import (
|
|
51
47
|
tracker_as_readable_transcript,
|
|
52
48
|
embedder_factory,
|
|
@@ -54,15 +50,12 @@ from rasa.shared.utils.llm import (
|
|
|
54
50
|
USER,
|
|
55
51
|
get_prompt_template,
|
|
56
52
|
allowed_values_for_slot,
|
|
57
|
-
resolve_model_client_config,
|
|
58
53
|
)
|
|
59
54
|
|
|
60
55
|
DEFAULT_FLOW_DOCUMENT_TEMPLATE = importlib.resources.read_text(
|
|
61
56
|
"rasa.dialogue_understanding.generator", "flow_document_template.jinja2"
|
|
62
57
|
)
|
|
63
58
|
|
|
64
|
-
FLOW_RETRIEVAL_CONFIG_FILE_NAME = "flow_retrieval_config.json"
|
|
65
|
-
|
|
66
59
|
DEFAULT_EMBEDDINGS_CONFIG = {
|
|
67
60
|
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
68
61
|
"model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
@@ -80,7 +73,7 @@ DEFAULT_SHOULD_EMBED_SLOTS = True
|
|
|
80
73
|
structlogger = structlog.get_logger()
|
|
81
74
|
|
|
82
75
|
|
|
83
|
-
class FlowRetrieval
|
|
76
|
+
class FlowRetrieval:
|
|
84
77
|
@classmethod
|
|
85
78
|
def get_default_config(cls) -> Dict[str, Any]:
|
|
86
79
|
"""The default config for the flow retrieval."""
|
|
@@ -99,9 +92,6 @@ class FlowRetrieval(EmbeddingsHealthCheckMixin):
|
|
|
99
92
|
):
|
|
100
93
|
config = {**self.get_default_config(), **config}
|
|
101
94
|
self.config = self.validate_config(config)
|
|
102
|
-
self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
103
|
-
self.config.get(EMBEDDINGS_CONFIG_KEY), FlowRetrieval.__name__
|
|
104
|
-
)
|
|
105
95
|
self.vector_store: Optional[FAISS] = None
|
|
106
96
|
self.flow_document_template = get_prompt_template(
|
|
107
97
|
None, DEFAULT_FLOW_DOCUMENT_TEMPLATE
|
|
@@ -150,18 +140,6 @@ class FlowRetrieval(EmbeddingsHealthCheckMixin):
|
|
|
150
140
|
**kwargs: Any,
|
|
151
141
|
) -> "FlowRetrieval":
|
|
152
142
|
"""Load flow retrieval with previously populated FAISS vector store."""
|
|
153
|
-
|
|
154
|
-
# Perform health check on resolved embedding client config
|
|
155
|
-
embeddings_config = resolve_model_client_config(
|
|
156
|
-
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
157
|
-
)
|
|
158
|
-
cls.perform_embeddings_health_check(
|
|
159
|
-
embeddings_config,
|
|
160
|
-
DEFAULT_EMBEDDINGS_CONFIG,
|
|
161
|
-
"flow_retrieval.load",
|
|
162
|
-
FlowRetrieval.__name__,
|
|
163
|
-
)
|
|
164
|
-
|
|
165
143
|
# initialize base flow retrieval
|
|
166
144
|
flow_retrieval = FlowRetrieval(config, model_storage, resource)
|
|
167
145
|
# load vector store
|
|
@@ -169,7 +147,6 @@ class FlowRetrieval(EmbeddingsHealthCheckMixin):
|
|
|
169
147
|
flow_retrieval.config, model_storage, resource
|
|
170
148
|
)
|
|
171
149
|
flow_retrieval.vector_store = vector_store
|
|
172
|
-
|
|
173
150
|
return flow_retrieval
|
|
174
151
|
|
|
175
152
|
@classmethod
|
|
@@ -201,21 +178,13 @@ class FlowRetrieval(EmbeddingsHealthCheckMixin):
|
|
|
201
178
|
Returns:
|
|
202
179
|
The embedder.
|
|
203
180
|
"""
|
|
204
|
-
# Copy the config so original config is not modified
|
|
205
|
-
config = config.copy()
|
|
206
|
-
# Resolve config and instantiate the embedding client
|
|
207
|
-
config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
208
|
-
config.get(EMBEDDINGS_CONFIG_KEY), FlowRetrieval.__name__
|
|
209
|
-
)
|
|
210
181
|
client = embedder_factory(
|
|
211
182
|
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
212
183
|
)
|
|
213
|
-
# Wrap the embedding client in the adapter
|
|
214
184
|
return _LangchainEmbeddingClientAdapter(client)
|
|
215
185
|
|
|
216
186
|
def persist(self) -> None:
|
|
217
187
|
self._persist_vector_store()
|
|
218
|
-
self._persist_config()
|
|
219
188
|
|
|
220
189
|
def _persist_vector_store(self) -> None:
|
|
221
190
|
"""Persists the FAISS vector store."""
|
|
@@ -228,12 +197,6 @@ class FlowRetrieval(EmbeddingsHealthCheckMixin):
|
|
|
228
197
|
event_info="Vector store is None, not persisted.",
|
|
229
198
|
)
|
|
230
199
|
|
|
231
|
-
def _persist_config(self) -> None:
|
|
232
|
-
with self._model_storage.write_to(self._resource) as path:
|
|
233
|
-
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
234
|
-
path / FLOW_RETRIEVAL_CONFIG_FILE_NAME, self.config
|
|
235
|
-
)
|
|
236
|
-
|
|
237
200
|
def populate(self, flows: FlowsList, domain: Domain) -> None:
|
|
238
201
|
"""Populates the vector store with embeddings generated from
|
|
239
202
|
documents based on the flow descriptions, and flow slots
|
|
@@ -243,14 +206,6 @@ class FlowRetrieval(EmbeddingsHealthCheckMixin):
|
|
|
243
206
|
flows: List of flows to populate the vector store with.
|
|
244
207
|
domain: The domain containing relevant slot information.
|
|
245
208
|
"""
|
|
246
|
-
# Perform health check before populating the vector store with flows
|
|
247
|
-
self.perform_embeddings_health_check(
|
|
248
|
-
self.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
249
|
-
DEFAULT_EMBEDDINGS_CONFIG,
|
|
250
|
-
"flow_retrieval.train",
|
|
251
|
-
FlowRetrieval.__name__,
|
|
252
|
-
)
|
|
253
|
-
|
|
254
209
|
flows_to_embedd = flows.exclude_link_only_flows()
|
|
255
210
|
embeddings = self._create_embedder(self.config)
|
|
256
211
|
documents = self._generate_flow_documents(flows_to_embedd, domain)
|
|
@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
|
|
2
2
|
from functools import lru_cache
|
|
3
3
|
from typing import Dict, Any, List, Optional, Tuple, Union, Text
|
|
4
4
|
|
|
5
|
+
import os
|
|
5
6
|
import structlog
|
|
6
7
|
from jinja2 import Template
|
|
7
8
|
|
|
@@ -16,13 +17,13 @@ from rasa.dialogue_understanding.generator.constants import (
|
|
|
16
17
|
LLM_CONFIG_KEY,
|
|
17
18
|
FLOW_RETRIEVAL_KEY,
|
|
18
19
|
FLOW_RETRIEVAL_ACTIVE_KEY,
|
|
19
|
-
FLOW_RETRIEVAL_FLOW_THRESHOLD,
|
|
20
20
|
)
|
|
21
21
|
from rasa.dialogue_understanding.generator.flow_retrieval import FlowRetrieval
|
|
22
22
|
from rasa.engine.graph import GraphComponent, ExecutionContext
|
|
23
23
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
24
24
|
from rasa.engine.storage.resource import Resource
|
|
25
25
|
from rasa.engine.storage.storage import ModelStorage
|
|
26
|
+
from rasa.shared.constants import LLM_API_HEALTH_CHECK_ENV_VAR
|
|
26
27
|
from rasa.shared.core.domain import Domain
|
|
27
28
|
from rasa.shared.core.flows import FlowStep, Flow, FlowsList
|
|
28
29
|
from rasa.shared.core.flows.steps.collect import CollectInformationFlowStep
|
|
@@ -32,11 +33,11 @@ from rasa.shared.exceptions import ProviderClientAPIException
|
|
|
32
33
|
from rasa.shared.nlu.constants import FLOWS_IN_PROMPT
|
|
33
34
|
from rasa.shared.nlu.training_data.message import Message
|
|
34
35
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
35
|
-
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
36
36
|
from rasa.shared.utils.llm import (
|
|
37
37
|
allowed_values_for_slot,
|
|
38
|
+
llm_api_health_check,
|
|
38
39
|
llm_factory,
|
|
39
|
-
|
|
40
|
+
try_instantiate_llm_client,
|
|
40
41
|
)
|
|
41
42
|
from rasa.utils.log_utils import log_llm
|
|
42
43
|
|
|
@@ -49,9 +50,7 @@ structlogger = structlog.get_logger()
|
|
|
49
50
|
],
|
|
50
51
|
is_trainable=True,
|
|
51
52
|
)
|
|
52
|
-
class LLMBasedCommandGenerator(
|
|
53
|
-
LLMHealthCheckMixin, GraphComponent, CommandGenerator, ABC
|
|
54
|
-
):
|
|
53
|
+
class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
55
54
|
"""An abstract class defining interface and common functionality
|
|
56
55
|
of an LLM-based command generators.
|
|
57
56
|
"""
|
|
@@ -65,9 +64,6 @@ class LLMBasedCommandGenerator(
|
|
|
65
64
|
) -> None:
|
|
66
65
|
super().__init__(config)
|
|
67
66
|
self.config = {**self.get_default_config(), **config}
|
|
68
|
-
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
69
|
-
self.config.get(LLM_CONFIG_KEY), LLMBasedCommandGenerator.__name__
|
|
70
|
-
)
|
|
71
67
|
self._model_storage = model_storage
|
|
72
68
|
self._resource = resource
|
|
73
69
|
self.flow_retrieval: Optional[FlowRetrieval]
|
|
@@ -77,9 +73,17 @@ class LLMBasedCommandGenerator(
|
|
|
77
73
|
self.config[FLOW_RETRIEVAL_KEY], model_storage, resource
|
|
78
74
|
)
|
|
79
75
|
structlogger.info("llm_based_command_generator.flow_retrieval.enabled")
|
|
80
|
-
self.config[FLOW_RETRIEVAL_KEY] = self.flow_retrieval.config
|
|
81
76
|
else:
|
|
82
77
|
self.flow_retrieval = None
|
|
78
|
+
structlogger.warn(
|
|
79
|
+
"llm_based_command_generator.flow_retrieval.disabled",
|
|
80
|
+
event_info=(
|
|
81
|
+
"Disabling flow retrieval can cause issues when there are a "
|
|
82
|
+
"large number of flows to be included in the prompt. For more"
|
|
83
|
+
"information see:\n"
|
|
84
|
+
"https://rasa.com/docs/rasa-pro/concepts/dialogue-understanding#how-the-llmcommandgenerator-works"
|
|
85
|
+
),
|
|
86
|
+
)
|
|
83
87
|
|
|
84
88
|
### Abstract methods
|
|
85
89
|
@staticmethod
|
|
@@ -167,32 +171,18 @@ class LLMBasedCommandGenerator(
|
|
|
167
171
|
"""Train the llm based command generator. Stores all flows into a vector
|
|
168
172
|
store.
|
|
169
173
|
"""
|
|
170
|
-
|
|
174
|
+
# Validate llm configuration
|
|
175
|
+
llm_client = try_instantiate_llm_client(
|
|
171
176
|
self.config.get(LLM_CONFIG_KEY),
|
|
172
177
|
DEFAULT_LLM_CONFIG,
|
|
173
178
|
"llm_based_command_generator.train",
|
|
174
179
|
LLMBasedCommandGenerator.__name__,
|
|
175
180
|
)
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
structlogger.warn(
|
|
182
|
-
"llm_based_command_generator.flow_retrieval.disabled",
|
|
183
|
-
event_info=(
|
|
184
|
-
f"You have {len(flows.user_flows)} user flows but flow "
|
|
185
|
-
f"retrieval is disabled. "
|
|
186
|
-
f"It is recommended to enable flow retrieval if the "
|
|
187
|
-
f"total number of user flows exceed "
|
|
188
|
-
f"{FLOW_RETRIEVAL_FLOW_THRESHOLD}. "
|
|
189
|
-
f"Keeping it disabled can result in deterioration of "
|
|
190
|
-
f"command generator's functional "
|
|
191
|
-
f"performance and higher costs because of increased "
|
|
192
|
-
f"number of tokens in the prompt. For more"
|
|
193
|
-
"information see:\n"
|
|
194
|
-
"https://rasa.com/docs/rasa-pro/concepts/dialogue-understanding#how-the-llmcommandgenerator-works"
|
|
195
|
-
),
|
|
181
|
+
if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
|
|
182
|
+
llm_api_health_check(
|
|
183
|
+
llm_client,
|
|
184
|
+
"llm_based_command_generator.train",
|
|
185
|
+
LLMBasedCommandGenerator.__name__,
|
|
196
186
|
)
|
|
197
187
|
|
|
198
188
|
# flow retrieval is populated with only user-defined flows
|
|
@@ -202,11 +192,10 @@ class LLMBasedCommandGenerator(
|
|
|
202
192
|
except Exception as e:
|
|
203
193
|
structlogger.error(
|
|
204
194
|
"llm_based_command_generator.train.failed",
|
|
205
|
-
event_info="Flow retrieval store is inaccessible.",
|
|
195
|
+
event_info=("Flow retrieval store is inaccessible."),
|
|
206
196
|
error=e,
|
|
207
197
|
)
|
|
208
198
|
raise
|
|
209
|
-
|
|
210
199
|
self.persist()
|
|
211
200
|
return self._resource
|
|
212
201
|
|
|
@@ -244,10 +233,7 @@ class LLMBasedCommandGenerator(
|
|
|
244
233
|
|
|
245
234
|
@classmethod
|
|
246
235
|
def load_flow_retrival(
|
|
247
|
-
cls,
|
|
248
|
-
config: Dict[str, Any],
|
|
249
|
-
model_storage: ModelStorage,
|
|
250
|
-
resource: Resource,
|
|
236
|
+
cls, config: Dict[Text, Any], model_storage: ModelStorage, resource: Resource
|
|
251
237
|
) -> Optional[FlowRetrieval]:
|
|
252
238
|
"""Load the FlowRetrieval component if it is enabled in the configuration."""
|
|
253
239
|
enable_flow_retrieval = config.get(FLOW_RETRIEVAL_KEY, {}).get(
|
|
@@ -24,7 +24,6 @@ from rasa.dialogue_understanding.generator.constants import (
|
|
|
24
24
|
LLM_CONFIG_KEY,
|
|
25
25
|
USER_INPUT_CONFIG_KEY,
|
|
26
26
|
FLOW_RETRIEVAL_KEY,
|
|
27
|
-
DEFAULT_LLM_CONFIG,
|
|
28
27
|
)
|
|
29
28
|
from rasa.dialogue_understanding.generator.flow_retrieval import FlowRetrieval
|
|
30
29
|
from rasa.dialogue_understanding.generator.llm_based_command_generator import (
|
|
@@ -40,10 +39,7 @@ from rasa.engine.graph import ExecutionContext
|
|
|
40
39
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
41
40
|
from rasa.engine.storage.resource import Resource
|
|
42
41
|
from rasa.engine.storage.storage import ModelStorage
|
|
43
|
-
from rasa.shared.constants import (
|
|
44
|
-
RASA_PATTERN_CANNOT_HANDLE_NOT_SUPPORTED,
|
|
45
|
-
EMBEDDINGS_CONFIG_KEY,
|
|
46
|
-
)
|
|
42
|
+
from rasa.shared.constants import RASA_PATTERN_CANNOT_HANDLE_NOT_SUPPORTED
|
|
47
43
|
from rasa.shared.constants import ROUTE_TO_CALM_SLOT
|
|
48
44
|
from rasa.shared.core.flows import FlowStep, Flow, FlowsList
|
|
49
45
|
from rasa.shared.core.flows.steps.collect import CollectInformationFlowStep
|
|
@@ -57,7 +53,6 @@ from rasa.shared.utils.llm import (
|
|
|
57
53
|
tracker_as_readable_transcript,
|
|
58
54
|
sanitize_message_for_prompt,
|
|
59
55
|
allowed_values_for_slot,
|
|
60
|
-
resolve_model_client_config,
|
|
61
56
|
)
|
|
62
57
|
|
|
63
58
|
# multistep template keys
|
|
@@ -75,7 +70,6 @@ DEFAULT_HANDLE_FLOWS_TEMPLATE = importlib.resources.read_text(
|
|
|
75
70
|
DEFAULT_FILL_SLOTS_TEMPLATE = importlib.resources.read_text(
|
|
76
71
|
"rasa.dialogue_understanding.generator.multi_step", "fill_slots_prompt.jinja2"
|
|
77
72
|
).strip()
|
|
78
|
-
MULTI_STEP_LLM_COMMAND_GENERATOR_CONFIG_FILE = "config.json"
|
|
79
73
|
|
|
80
74
|
# dictionary of template names and associated file names and default values
|
|
81
75
|
PROMPT_TEMPLATES = {
|
|
@@ -144,18 +138,7 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
144
138
|
**kwargs: Any,
|
|
145
139
|
) -> "MultiStepLLMCommandGenerator":
|
|
146
140
|
"""Loads trained component (see parent class for full docstring)."""
|
|
147
|
-
|
|
148
|
-
# Perform health check of the LLM client config
|
|
149
|
-
llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
|
|
150
|
-
cls.perform_llm_health_check(
|
|
151
|
-
llm_config,
|
|
152
|
-
DEFAULT_LLM_CONFIG,
|
|
153
|
-
"multi_step_llm_command_generator.load",
|
|
154
|
-
MultiStepLLMCommandGenerator.__name__,
|
|
155
|
-
)
|
|
156
|
-
|
|
157
141
|
prompts = cls._load_prompt_templates(model_storage, resource)
|
|
158
|
-
|
|
159
142
|
# init base command generator
|
|
160
143
|
command_generator = cls(config, model_storage, resource, prompts)
|
|
161
144
|
# load flow retrieval if enabled
|
|
@@ -163,13 +146,13 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
163
146
|
command_generator.flow_retrieval = cls.load_flow_retrival(
|
|
164
147
|
command_generator.config, model_storage, resource
|
|
165
148
|
)
|
|
166
|
-
|
|
167
149
|
return command_generator
|
|
168
150
|
|
|
169
151
|
def persist(self) -> None:
|
|
170
152
|
"""Persist this component to disk for future loading."""
|
|
153
|
+
# persist prompt template
|
|
171
154
|
self._persist_prompt_templates()
|
|
172
|
-
|
|
155
|
+
# persist flow retrieval
|
|
173
156
|
if self.flow_retrieval is not None:
|
|
174
157
|
self.flow_retrieval.persist()
|
|
175
158
|
|
|
@@ -246,9 +229,9 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
246
229
|
commands: List[Command] = []
|
|
247
230
|
|
|
248
231
|
slot_set_re = re.compile(
|
|
249
|
-
r"""SetSlot\(
|
|
232
|
+
r"""SetSlot\((\"?[a-zA-Z_][a-zA-Z0-9_-]*?\"?), ?(.*)\)"""
|
|
250
233
|
)
|
|
251
|
-
start_flow_re = re.compile(r"StartFlow\(
|
|
234
|
+
start_flow_re = re.compile(r"StartFlow\(([a-zA-Z0-9_-]+?)\)")
|
|
252
235
|
change_flow_re = re.compile(r"ChangeFlow\(\)")
|
|
253
236
|
cancel_flow_re = re.compile(r"CancelFlow\(\)")
|
|
254
237
|
chitchat_re = re.compile(r"ChitChat\(\)")
|
|
@@ -297,19 +280,9 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
297
280
|
commands.append(HumanHandoffCommand())
|
|
298
281
|
elif match := clarify_re.search(action):
|
|
299
282
|
options = sorted([opt.strip() for opt in match.group(1).split(",")])
|
|
300
|
-
# Remove surrounding quotes if present
|
|
301
|
-
cleaned_options = []
|
|
302
|
-
for flow in options:
|
|
303
|
-
if (flow.startswith('"') and flow.endswith('"')) or (
|
|
304
|
-
flow.startswith("'") and flow.endswith("'")
|
|
305
|
-
):
|
|
306
|
-
cleaned_options.append(flow[1:-1])
|
|
307
|
-
else:
|
|
308
|
-
cleaned_options.append(flow)
|
|
309
|
-
# check if flow is valid
|
|
310
283
|
valid_options = [
|
|
311
284
|
flow
|
|
312
|
-
for flow in cleaned_options
|
|
285
|
+
for flow in options
|
|
313
286
|
if flow in flows.user_flow_ids
|
|
314
287
|
and flow not in user_flows_on_the_stack(tracker.stack)
|
|
315
288
|
]
|
|
@@ -320,13 +293,6 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
320
293
|
elif change_flow_re.search(action):
|
|
321
294
|
commands.append(ChangeFlowCommand())
|
|
322
295
|
|
|
323
|
-
if not commands:
|
|
324
|
-
structlogger.debug(
|
|
325
|
-
"multi_step_llm_command_generator.parse_commands",
|
|
326
|
-
message="No commands were parsed from the LLM actions.",
|
|
327
|
-
actions=actions,
|
|
328
|
-
)
|
|
329
|
-
|
|
330
296
|
return commands
|
|
331
297
|
|
|
332
298
|
### Helper methods
|
|
@@ -402,13 +368,6 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
402
368
|
file_path = path / file_name
|
|
403
369
|
rasa.shared.utils.io.write_text_file(template, file_path)
|
|
404
370
|
|
|
405
|
-
def _persist_config(self) -> None:
|
|
406
|
-
"""Persist config as a source of truth for resolved clients."""
|
|
407
|
-
with self._model_storage.write_to(self._resource) as path:
|
|
408
|
-
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
409
|
-
path / MULTI_STEP_LLM_COMMAND_GENERATOR_CONFIG_FILE, self.config
|
|
410
|
-
)
|
|
411
|
-
|
|
412
371
|
async def _predict_commands_with_multi_step(
|
|
413
372
|
self,
|
|
414
373
|
message: Message,
|
|
@@ -802,17 +761,11 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
802
761
|
.get(FILE_PATH_KEY),
|
|
803
762
|
DEFAULT_FILL_SLOTS_TEMPLATE,
|
|
804
763
|
)
|
|
805
|
-
|
|
806
|
-
llm_config = resolve_model_client_config(
|
|
807
|
-
config.get(LLM_CONFIG_KEY), MultiStepLLMCommandGenerator.__name__
|
|
808
|
-
)
|
|
809
|
-
embedding_config = resolve_model_client_config(
|
|
810
|
-
config.get(FLOW_RETRIEVAL_KEY, {}).get(EMBEDDINGS_CONFIG_KEY),
|
|
811
|
-
FlowRetrieval.__name__,
|
|
812
|
-
)
|
|
813
|
-
|
|
814
764
|
return deep_container_fingerprint(
|
|
815
|
-
[
|
|
765
|
+
[
|
|
766
|
+
handle_flows_template,
|
|
767
|
+
fill_slots_template,
|
|
768
|
+
]
|
|
816
769
|
)
|
|
817
770
|
|
|
818
771
|
@staticmethod
|
|
@@ -19,7 +19,6 @@ from rasa.engine.storage.storage import ModelStorage
|
|
|
19
19
|
from rasa.shared.constants import ROUTE_TO_CALM_SLOT
|
|
20
20
|
from rasa.shared.core.domain import Domain
|
|
21
21
|
from rasa.shared.core.flows.flows_list import FlowsList
|
|
22
|
-
from rasa.shared.core.flows.steps import CollectInformationFlowStep
|
|
23
22
|
from rasa.shared.core.slot_mappings import (
|
|
24
23
|
SlotFillingManager,
|
|
25
24
|
extract_slot_value,
|
|
@@ -218,24 +217,7 @@ def _issue_set_slot_commands(
|
|
|
218
217
|
commands: List[Command] = []
|
|
219
218
|
domain = domain if domain else Domain.empty()
|
|
220
219
|
slot_filling_manager = SlotFillingManager(domain, tracker, message)
|
|
221
|
-
|
|
222
|
-
# only use slots that don't have ask_before_filling set to True
|
|
223
|
-
available_slot_names = flows.available_slot_names(ask_before_filling=False)
|
|
224
|
-
|
|
225
|
-
# check if the current step is a CollectInformationFlowStep
|
|
226
|
-
# in case it has ask_before_filling set to True, we need to add the
|
|
227
|
-
# slot to the available_slot_names
|
|
228
|
-
if tracker.active_flow:
|
|
229
|
-
flow = flows.flow_by_id(tracker.active_flow)
|
|
230
|
-
step_id = tracker.current_step_id
|
|
231
|
-
if flow is not None:
|
|
232
|
-
current_step = flow.step_by_id(step_id)
|
|
233
|
-
if (
|
|
234
|
-
current_step
|
|
235
|
-
and isinstance(current_step, CollectInformationFlowStep)
|
|
236
|
-
and current_step.ask_before_filling
|
|
237
|
-
):
|
|
238
|
-
available_slot_names.add(current_step.collect)
|
|
220
|
+
available_slot_names = flows.available_slot_names()
|
|
239
221
|
|
|
240
222
|
for _, slot in tracker.slots.items():
|
|
241
223
|
# if a slot is not collected in available flows,
|
|
@@ -41,9 +41,6 @@ Based on this information generate a list of actions you want to take. Your job
|
|
|
41
41
|
* Responding to knowledge-oriented user messages, described by "SearchAndReply()"
|
|
42
42
|
* Responding to a casual, non-task-oriented user message, described by "ChitChat()".
|
|
43
43
|
* Handing off to a human, in case the user seems frustrated or explicitly asks to speak to one, described by "HumanHandoff()".
|
|
44
|
-
{% if is_repeat_command_enabled %}
|
|
45
|
-
* Repeat the last bot messages, described by "RepeatLastBotMessages()". This is useful when the user asks to repeat the last bot messages.
|
|
46
|
-
{% endif %}
|
|
47
44
|
|
|
48
45
|
===
|
|
49
46
|
Write out the actions you want to take, one per line, in the order they should take place.
|