rasa-pro 3.11.0rc1__py3-none-any.whl → 3.11.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/cli/inspect.py +2 -0
- rasa/cli/studio/studio.py +18 -8
- rasa/core/actions/action_repeat_bot_messages.py +17 -0
- rasa/core/channels/channel.py +17 -0
- rasa/core/channels/voice_ready/audiocodes.py +12 -0
- rasa/core/channels/voice_ready/jambonz.py +13 -2
- rasa/core/channels/voice_ready/twilio_voice.py +6 -21
- rasa/core/channels/voice_stream/voice_channel.py +13 -1
- rasa/core/nlg/contextual_response_rephraser.py +18 -10
- rasa/core/policies/enterprise_search_policy.py +27 -67
- rasa/core/policies/intentless_policy.py +25 -67
- rasa/dialogue_understanding/coexistence/llm_based_router.py +18 -33
- rasa/dialogue_understanding/generator/constants.py +0 -2
- rasa/dialogue_understanding/generator/flow_retrieval.py +33 -50
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -40
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +18 -20
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +19 -1
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +24 -21
- rasa/dialogue_understanding/processor/command_processor.py +21 -1
- rasa/e2e_test/e2e_test_case.py +85 -6
- rasa/engine/validation.py +57 -41
- rasa/model_service.py +3 -0
- rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
- rasa/server.py +3 -1
- rasa/shared/core/flows/flows_list.py +5 -1
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +6 -14
- rasa/shared/providers/llm/_base_litellm_client.py +6 -1
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +256 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/llm.py +5 -2
- rasa/shared/utils/yaml.py +102 -62
- rasa/studio/auth.py +3 -5
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/upload.py +21 -10
- rasa/telemetry.py +12 -0
- rasa/tracing/config.py +2 -0
- rasa/tracing/instrumentation/attribute_extractors.py +20 -0
- rasa/tracing/instrumentation/instrumentation.py +121 -0
- rasa/utils/common.py +5 -0
- rasa/utils/io.py +8 -16
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/version.py +1 -1
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/METADATA +3 -2
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/RECORD +51 -47
- rasa/shared/utils/health_check.py +0 -533
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc2.dist-info}/entry_points.txt +0 -0
|
@@ -18,10 +18,6 @@ from rasa.core.constants import (
|
|
|
18
18
|
UTTER_SOURCE_METADATA_KEY,
|
|
19
19
|
)
|
|
20
20
|
from rasa.core.policies.policy import Policy, PolicyPrediction, SupportedData
|
|
21
|
-
from rasa.dialogue_understanding.generator.constants import (
|
|
22
|
-
TRAINED_MODEL_NAME_CONFIG_KEY,
|
|
23
|
-
TRAINED_EMBEDDINGS_CONFIG_KEY,
|
|
24
|
-
)
|
|
25
21
|
from rasa.dialogue_understanding.patterns.chitchat import FLOW_PATTERN_CHITCHAT
|
|
26
22
|
from rasa.dialogue_understanding.stack.frames import (
|
|
27
23
|
ChitChatStackFrame,
|
|
@@ -64,6 +60,10 @@ from rasa.shared.providers.embedding._langchain_embedding_client_adapter import
|
|
|
64
60
|
_LangchainEmbeddingClientAdapter,
|
|
65
61
|
)
|
|
66
62
|
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
63
|
+
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
64
|
+
EmbeddingsHealthCheckMixin,
|
|
65
|
+
)
|
|
66
|
+
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
67
67
|
from rasa.shared.utils.io import deep_container_fingerprint
|
|
68
68
|
from rasa.shared.utils.llm import (
|
|
69
69
|
AI,
|
|
@@ -79,12 +79,6 @@ from rasa.shared.utils.llm import (
|
|
|
79
79
|
tracker_as_readable_transcript,
|
|
80
80
|
resolve_model_client_config,
|
|
81
81
|
)
|
|
82
|
-
from rasa.shared.utils.health_check import (
|
|
83
|
-
perform_training_time_llm_health_check,
|
|
84
|
-
perform_training_time_embeddings_health_check,
|
|
85
|
-
perform_inference_time_llm_health_check,
|
|
86
|
-
perform_inference_time_embeddings_health_check,
|
|
87
|
-
)
|
|
88
82
|
from rasa.utils.log_utils import log_llm
|
|
89
83
|
from rasa.utils.ml_utils import (
|
|
90
84
|
extract_ai_response_examples,
|
|
@@ -383,7 +377,7 @@ def conversation_as_prompt(conversation: Conversation) -> str:
|
|
|
383
377
|
@DefaultV1Recipe.register(
|
|
384
378
|
DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
|
|
385
379
|
)
|
|
386
|
-
class IntentlessPolicy(Policy):
|
|
380
|
+
class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
387
381
|
"""Policy which uses a language model to generate the next action.
|
|
388
382
|
|
|
389
383
|
The policy uses the OpenAI API to generate the next action based on the
|
|
@@ -516,10 +510,8 @@ class IntentlessPolicy(Policy):
|
|
|
516
510
|
A policy must return its resource locator so that potential children nodes
|
|
517
511
|
can load the policy from the resource.
|
|
518
512
|
"""
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
self.config[TRAINED_EMBEDDINGS_CONFIG_KEY],
|
|
522
|
-
) = self._perform_training_time_health_checks()
|
|
513
|
+
# Perform health checks of both LLM and embeddings client configs
|
|
514
|
+
self._perform_health_checks(self.config, "intentless_policy.train")
|
|
523
515
|
|
|
524
516
|
responses = filter_responses(responses, forms, flows or FlowsList([]))
|
|
525
517
|
telemetry.track_intentless_policy_train()
|
|
@@ -952,10 +944,13 @@ class IntentlessPolicy(Policy):
|
|
|
952
944
|
**kwargs: Any,
|
|
953
945
|
) -> "IntentlessPolicy":
|
|
954
946
|
"""Loads a trained policy (see parent class for full docstring)."""
|
|
947
|
+
|
|
948
|
+
# Perform health checks of both LLM and embeddings client configs
|
|
949
|
+
cls._perform_health_checks(config, "intentless_policy.load")
|
|
950
|
+
|
|
955
951
|
responses_docsearch = None
|
|
956
952
|
samples_docsearch = None
|
|
957
953
|
prompt_template = None
|
|
958
|
-
persisted_config = None
|
|
959
954
|
try:
|
|
960
955
|
with model_storage.read_from(resource) as path:
|
|
961
956
|
responses_docsearch = load_faiss_vector_store(
|
|
@@ -973,15 +968,12 @@ class IntentlessPolicy(Policy):
|
|
|
973
968
|
prompt_template = rasa.shared.utils.io.read_file(
|
|
974
969
|
path / INTENTLESS_PROMPT_TEMPLATE_FILE_NAME
|
|
975
970
|
)
|
|
976
|
-
persisted_config = rasa.shared.utils.io.read_json_file(
|
|
977
|
-
path / INTENTLESS_CONFIG_FILE_NAME
|
|
978
|
-
)
|
|
979
971
|
except (ValueError, FileNotFoundError, FileIOException) as e:
|
|
980
972
|
structlogger.warning(
|
|
981
973
|
"intentless_policy.load.failed", error=e, resource_name=resource.name
|
|
982
974
|
)
|
|
983
975
|
|
|
984
|
-
|
|
976
|
+
return cls(
|
|
985
977
|
config,
|
|
986
978
|
model_storage,
|
|
987
979
|
resource,
|
|
@@ -991,14 +983,6 @@ class IntentlessPolicy(Policy):
|
|
|
991
983
|
prompt_template=prompt_template,
|
|
992
984
|
)
|
|
993
985
|
|
|
994
|
-
cls._perform_inference_time_health_checks(
|
|
995
|
-
persisted_config,
|
|
996
|
-
policy.config.get(LLM_CONFIG_KEY),
|
|
997
|
-
policy.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
998
|
-
)
|
|
999
|
-
|
|
1000
|
-
return policy
|
|
1001
|
-
|
|
1002
986
|
@classmethod
|
|
1003
987
|
def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
|
|
1004
988
|
"""Add a fingerprint of intentless policy for the graph."""
|
|
@@ -1018,52 +1002,26 @@ class IntentlessPolicy(Policy):
|
|
|
1018
1002
|
[prompt_template, llm_config, embedding_config]
|
|
1019
1003
|
)
|
|
1020
1004
|
|
|
1021
|
-
def _perform_training_time_health_checks(
|
|
1022
|
-
self,
|
|
1023
|
-
) -> Tuple[Optional[str], Optional[str]]:
|
|
1024
|
-
train_model_name = perform_training_time_llm_health_check(
|
|
1025
|
-
self.config.get(LLM_CONFIG_KEY),
|
|
1026
|
-
DEFAULT_LLM_CONFIG,
|
|
1027
|
-
"intentless_policy.train",
|
|
1028
|
-
IntentlessPolicy.__name__,
|
|
1029
|
-
)
|
|
1030
|
-
train_embedding_name = perform_training_time_embeddings_health_check(
|
|
1031
|
-
self.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
1032
|
-
DEFAULT_EMBEDDINGS_CONFIG,
|
|
1033
|
-
"intentless_policy.train",
|
|
1034
|
-
IntentlessPolicy.__name__,
|
|
1035
|
-
)
|
|
1036
|
-
return train_model_name, train_embedding_name
|
|
1037
|
-
|
|
1038
1005
|
@classmethod
|
|
1039
|
-
def
|
|
1040
|
-
cls,
|
|
1041
|
-
persisted_config: Optional[Dict[str, Any]],
|
|
1042
|
-
resolved_llm_config: Optional[Dict[str, Any]],
|
|
1043
|
-
resolved_embeddings_config: Optional[Dict[str, Any]],
|
|
1006
|
+
def _perform_health_checks(
|
|
1007
|
+
cls, config: Dict[Text, Any], log_source_method: str
|
|
1044
1008
|
) -> None:
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
)
|
|
1050
|
-
perform_inference_time_llm_health_check(
|
|
1051
|
-
resolved_llm_config,
|
|
1009
|
+
# Perform health check of the LLM client config
|
|
1010
|
+
llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
|
|
1011
|
+
cls.perform_llm_health_check(
|
|
1012
|
+
llm_config,
|
|
1052
1013
|
DEFAULT_LLM_CONFIG,
|
|
1053
|
-
|
|
1054
|
-
"intentless_policy.load",
|
|
1014
|
+
log_source_method,
|
|
1055
1015
|
IntentlessPolicy.__name__,
|
|
1056
1016
|
)
|
|
1057
1017
|
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
else None
|
|
1018
|
+
# Perform health check of the embeddings client config
|
|
1019
|
+
embeddings_config = resolve_model_client_config(
|
|
1020
|
+
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
1062
1021
|
)
|
|
1063
|
-
|
|
1064
|
-
|
|
1022
|
+
cls.perform_embeddings_health_check(
|
|
1023
|
+
embeddings_config,
|
|
1065
1024
|
DEFAULT_EMBEDDINGS_CONFIG,
|
|
1066
|
-
|
|
1067
|
-
"intentless_policy.load",
|
|
1025
|
+
log_source_method,
|
|
1068
1026
|
IntentlessPolicy.__name__,
|
|
1069
1027
|
)
|
|
@@ -17,7 +17,6 @@ from rasa.dialogue_understanding.commands import Command, SetSlotCommand
|
|
|
17
17
|
from rasa.dialogue_understanding.commands.noop_command import NoopCommand
|
|
18
18
|
from rasa.dialogue_understanding.generator.constants import (
|
|
19
19
|
LLM_CONFIG_KEY,
|
|
20
|
-
TRAINED_MODEL_NAME_CONFIG_KEY,
|
|
21
20
|
)
|
|
22
21
|
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
23
22
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
@@ -36,6 +35,7 @@ from rasa.shared.exceptions import InvalidConfigException, FileIOException
|
|
|
36
35
|
from rasa.shared.nlu.constants import COMMANDS, TEXT
|
|
37
36
|
from rasa.shared.nlu.training_data.message import Message
|
|
38
37
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
38
|
+
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
39
39
|
from rasa.shared.utils.io import deep_container_fingerprint
|
|
40
40
|
from rasa.shared.utils.llm import (
|
|
41
41
|
DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
@@ -43,10 +43,6 @@ from rasa.shared.utils.llm import (
|
|
|
43
43
|
llm_factory,
|
|
44
44
|
resolve_model_client_config,
|
|
45
45
|
)
|
|
46
|
-
from rasa.shared.utils.health_check import (
|
|
47
|
-
perform_training_time_llm_health_check,
|
|
48
|
-
perform_inference_time_llm_health_check,
|
|
49
|
-
)
|
|
50
46
|
from rasa.utils.log_utils import log_llm
|
|
51
47
|
|
|
52
48
|
LLM_BASED_ROUTER_PROMPT_FILE_NAME = "llm_based_router_prompt.jinja2"
|
|
@@ -80,7 +76,7 @@ structlogger = structlog.get_logger()
|
|
|
80
76
|
],
|
|
81
77
|
is_trainable=True,
|
|
82
78
|
)
|
|
83
|
-
class LLMBasedRouter(GraphComponent):
|
|
79
|
+
class LLMBasedRouter(LLMHealthCheckMixin, GraphComponent):
|
|
84
80
|
@staticmethod
|
|
85
81
|
def get_default_config() -> Dict[str, Any]:
|
|
86
82
|
"""The component's default config (see parent class for full docstring)."""
|
|
@@ -144,13 +140,11 @@ class LLMBasedRouter(GraphComponent):
|
|
|
144
140
|
|
|
145
141
|
def train(self, training_data: TrainingData) -> Resource:
|
|
146
142
|
"""Train the intent classifier on a data set."""
|
|
147
|
-
self.
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
LLMBasedRouter.__name__,
|
|
153
|
-
)
|
|
143
|
+
self.perform_llm_health_check(
|
|
144
|
+
self.config.get(LLM_CONFIG_KEY),
|
|
145
|
+
DEFAULT_LLM_CONFIG,
|
|
146
|
+
"llm_based_router.train",
|
|
147
|
+
LLMBasedRouter.__name__,
|
|
154
148
|
)
|
|
155
149
|
|
|
156
150
|
self.persist()
|
|
@@ -166,37 +160,28 @@ class LLMBasedRouter(GraphComponent):
|
|
|
166
160
|
**kwargs: Any,
|
|
167
161
|
) -> "LLMBasedRouter":
|
|
168
162
|
"""Loads trained component (see parent class for full docstring)."""
|
|
163
|
+
|
|
164
|
+
# Perform health check on the resolved LLM client config
|
|
165
|
+
llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
|
|
166
|
+
cls.perform_llm_health_check(
|
|
167
|
+
llm_config,
|
|
168
|
+
DEFAULT_LLM_CONFIG,
|
|
169
|
+
"llm_based_router.load",
|
|
170
|
+
LLMBasedRouter.__name__,
|
|
171
|
+
)
|
|
172
|
+
|
|
169
173
|
prompt_template = None
|
|
170
|
-
persisted_config = None
|
|
171
174
|
try:
|
|
172
175
|
with model_storage.read_from(resource) as path:
|
|
173
176
|
prompt_template = rasa.shared.utils.io.read_file(
|
|
174
177
|
path / LLM_BASED_ROUTER_PROMPT_FILE_NAME
|
|
175
178
|
)
|
|
176
|
-
persisted_config = rasa.shared.utils.io.read_json_file(
|
|
177
|
-
path / LLM_BASED_ROUTER_CONFIG_FILE_NAME
|
|
178
|
-
)
|
|
179
179
|
except (FileNotFoundError, FileIOException) as e:
|
|
180
180
|
structlogger.warning(
|
|
181
181
|
"llm_based_router.load.failed", error=e, resource=resource.name
|
|
182
182
|
)
|
|
183
183
|
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
train_model_name = (
|
|
187
|
-
persisted_config.get(TRAINED_MODEL_NAME_CONFIG_KEY, None)
|
|
188
|
-
if persisted_config
|
|
189
|
-
else None
|
|
190
|
-
)
|
|
191
|
-
perform_inference_time_llm_health_check(
|
|
192
|
-
router.config.get(LLM_CONFIG_KEY),
|
|
193
|
-
DEFAULT_LLM_CONFIG,
|
|
194
|
-
train_model_name,
|
|
195
|
-
"llm_based_router.load",
|
|
196
|
-
LLMBasedRouter.__name__,
|
|
197
|
-
)
|
|
198
|
-
|
|
199
|
-
return router
|
|
184
|
+
return cls(config, model_storage, resource, prompt_template=prompt_template)
|
|
200
185
|
|
|
201
186
|
@classmethod
|
|
202
187
|
def create(
|
|
@@ -18,8 +18,6 @@ DEFAULT_LLM_CONFIG = {
|
|
|
18
18
|
}
|
|
19
19
|
|
|
20
20
|
LLM_CONFIG_KEY = "llm"
|
|
21
|
-
TRAINED_MODEL_NAME_CONFIG_KEY = "trained_llm_model_name"
|
|
22
|
-
TRAINED_EMBEDDINGS_CONFIG_KEY = "trained_embeddings_model_name"
|
|
23
21
|
USER_INPUT_CONFIG_KEY = "user_input"
|
|
24
22
|
|
|
25
23
|
FLOW_RETRIEVAL_KEY = "flow_retrieval"
|
|
@@ -27,12 +27,9 @@ from langchain.schema.embeddings import Embeddings
|
|
|
27
27
|
from langchain_community.vectorstores.faiss import FAISS
|
|
28
28
|
from langchain_community.vectorstores.utils import DistanceStrategy
|
|
29
29
|
|
|
30
|
-
from rasa.dialogue_understanding.generator.constants import (
|
|
31
|
-
TRAINED_EMBEDDINGS_CONFIG_KEY,
|
|
32
|
-
)
|
|
33
30
|
from rasa.engine.storage.resource import Resource
|
|
34
31
|
from rasa.engine.storage.storage import ModelStorage
|
|
35
|
-
|
|
32
|
+
import rasa.shared.utils.io
|
|
36
33
|
from rasa.shared.constants import (
|
|
37
34
|
EMBEDDINGS_CONFIG_KEY,
|
|
38
35
|
PROVIDER_CONFIG_KEY,
|
|
@@ -41,12 +38,15 @@ from rasa.shared.constants import (
|
|
|
41
38
|
from rasa.shared.core.domain import Domain
|
|
42
39
|
from rasa.shared.core.flows import FlowsList
|
|
43
40
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
44
|
-
from rasa.shared.exceptions import ProviderClientAPIException
|
|
41
|
+
from rasa.shared.exceptions import ProviderClientAPIException
|
|
45
42
|
from rasa.shared.nlu.constants import TEXT, FLOWS_FROM_SEMANTIC_SEARCH
|
|
46
43
|
from rasa.shared.nlu.training_data.message import Message
|
|
47
44
|
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
48
45
|
_LangchainEmbeddingClientAdapter,
|
|
49
46
|
)
|
|
47
|
+
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
48
|
+
EmbeddingsHealthCheckMixin,
|
|
49
|
+
)
|
|
50
50
|
from rasa.shared.utils.llm import (
|
|
51
51
|
tracker_as_readable_transcript,
|
|
52
52
|
embedder_factory,
|
|
@@ -56,11 +56,6 @@ from rasa.shared.utils.llm import (
|
|
|
56
56
|
allowed_values_for_slot,
|
|
57
57
|
resolve_model_client_config,
|
|
58
58
|
)
|
|
59
|
-
from rasa.shared.utils.health_check import (
|
|
60
|
-
perform_training_time_embeddings_health_check,
|
|
61
|
-
perform_inference_time_embeddings_health_check,
|
|
62
|
-
)
|
|
63
|
-
from rasa.shared.utils.io import dump_obj_as_json_to_file, read_json_file
|
|
64
59
|
|
|
65
60
|
DEFAULT_FLOW_DOCUMENT_TEMPLATE = importlib.resources.read_text(
|
|
66
61
|
"rasa.dialogue_understanding.generator", "flow_document_template.jinja2"
|
|
@@ -85,7 +80,7 @@ DEFAULT_SHOULD_EMBED_SLOTS = True
|
|
|
85
80
|
structlogger = structlog.get_logger()
|
|
86
81
|
|
|
87
82
|
|
|
88
|
-
class FlowRetrieval:
|
|
83
|
+
class FlowRetrieval(EmbeddingsHealthCheckMixin):
|
|
89
84
|
@classmethod
|
|
90
85
|
def get_default_config(cls) -> Dict[str, Any]:
|
|
91
86
|
"""The default config for the flow retrieval."""
|
|
@@ -94,7 +89,6 @@ class FlowRetrieval:
|
|
|
94
89
|
MAX_FLOWS_FROM_SEMANTIC_SEARCH_KEY: DEFAULT_MAX_FLOWS_FROM_SEMANTIC_SEARCH,
|
|
95
90
|
TURNS_TO_EMBED_KEY: DEFAULT_TURNS_TO_EMBED,
|
|
96
91
|
SHOULD_EMBED_SLOTS_KEY: DEFAULT_SHOULD_EMBED_SLOTS,
|
|
97
|
-
TRAINED_EMBEDDINGS_CONFIG_KEY: None,
|
|
98
92
|
}
|
|
99
93
|
|
|
100
94
|
def __init__(
|
|
@@ -147,16 +141,6 @@ class FlowRetrieval:
|
|
|
147
141
|
|
|
148
142
|
return config
|
|
149
143
|
|
|
150
|
-
def train(self) -> None:
|
|
151
|
-
self.config[TRAINED_EMBEDDINGS_CONFIG_KEY] = (
|
|
152
|
-
perform_training_time_embeddings_health_check(
|
|
153
|
-
self.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
154
|
-
DEFAULT_EMBEDDINGS_CONFIG,
|
|
155
|
-
"flow_retrieval.train",
|
|
156
|
-
FlowRetrieval.__name__,
|
|
157
|
-
)
|
|
158
|
-
)
|
|
159
|
-
|
|
160
144
|
@classmethod
|
|
161
145
|
def load(
|
|
162
146
|
cls,
|
|
@@ -166,6 +150,18 @@ class FlowRetrieval:
|
|
|
166
150
|
**kwargs: Any,
|
|
167
151
|
) -> "FlowRetrieval":
|
|
168
152
|
"""Load flow retrieval with previously populated FAISS vector store."""
|
|
153
|
+
|
|
154
|
+
# Perform health check on resolved embedding client config
|
|
155
|
+
embeddings_config = resolve_model_client_config(
|
|
156
|
+
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
157
|
+
)
|
|
158
|
+
cls.perform_embeddings_health_check(
|
|
159
|
+
embeddings_config,
|
|
160
|
+
DEFAULT_EMBEDDINGS_CONFIG,
|
|
161
|
+
"flow_retrieval.load",
|
|
162
|
+
FlowRetrieval.__name__,
|
|
163
|
+
)
|
|
164
|
+
|
|
169
165
|
# initialize base flow retrieval
|
|
170
166
|
flow_retrieval = FlowRetrieval(config, model_storage, resource)
|
|
171
167
|
# load vector store
|
|
@@ -174,30 +170,6 @@ class FlowRetrieval:
|
|
|
174
170
|
)
|
|
175
171
|
flow_retrieval.vector_store = vector_store
|
|
176
172
|
|
|
177
|
-
persisted_config = None
|
|
178
|
-
try:
|
|
179
|
-
with model_storage.read_from(resource) as path:
|
|
180
|
-
persisted_config = read_json_file(
|
|
181
|
-
path / FLOW_RETRIEVAL_CONFIG_FILE_NAME
|
|
182
|
-
)
|
|
183
|
-
except (FileNotFoundError, FileIOException) as e:
|
|
184
|
-
structlogger.warning(
|
|
185
|
-
"flow_retrieval.load.failed", error=e, resource=resource.name
|
|
186
|
-
)
|
|
187
|
-
|
|
188
|
-
train_embeddings_name = (
|
|
189
|
-
persisted_config.get(TRAINED_EMBEDDINGS_CONFIG_KEY, None)
|
|
190
|
-
if persisted_config
|
|
191
|
-
else None
|
|
192
|
-
)
|
|
193
|
-
perform_inference_time_embeddings_health_check(
|
|
194
|
-
flow_retrieval.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
195
|
-
DEFAULT_EMBEDDINGS_CONFIG,
|
|
196
|
-
train_embeddings_name,
|
|
197
|
-
"flow_retrieval.load",
|
|
198
|
-
FlowRetrieval.__name__,
|
|
199
|
-
)
|
|
200
|
-
|
|
201
173
|
return flow_retrieval
|
|
202
174
|
|
|
203
175
|
@classmethod
|
|
@@ -243,10 +215,7 @@ class FlowRetrieval:
|
|
|
243
215
|
|
|
244
216
|
def persist(self) -> None:
|
|
245
217
|
self._persist_vector_store()
|
|
246
|
-
|
|
247
|
-
dump_obj_as_json_to_file(
|
|
248
|
-
path / FLOW_RETRIEVAL_CONFIG_FILE_NAME, self.config
|
|
249
|
-
)
|
|
218
|
+
self._persist_config()
|
|
250
219
|
|
|
251
220
|
def _persist_vector_store(self) -> None:
|
|
252
221
|
"""Persists the FAISS vector store."""
|
|
@@ -259,6 +228,12 @@ class FlowRetrieval:
|
|
|
259
228
|
event_info="Vector store is None, not persisted.",
|
|
260
229
|
)
|
|
261
230
|
|
|
231
|
+
def _persist_config(self) -> None:
|
|
232
|
+
with self._model_storage.write_to(self._resource) as path:
|
|
233
|
+
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
234
|
+
path / FLOW_RETRIEVAL_CONFIG_FILE_NAME, self.config
|
|
235
|
+
)
|
|
236
|
+
|
|
262
237
|
def populate(self, flows: FlowsList, domain: Domain) -> None:
|
|
263
238
|
"""Populates the vector store with embeddings generated from
|
|
264
239
|
documents based on the flow descriptions, and flow slots
|
|
@@ -268,6 +243,14 @@ class FlowRetrieval:
|
|
|
268
243
|
flows: List of flows to populate the vector store with.
|
|
269
244
|
domain: The domain containing relevant slot information.
|
|
270
245
|
"""
|
|
246
|
+
# Perform health check before populating the vector store with flows
|
|
247
|
+
self.perform_embeddings_health_check(
|
|
248
|
+
self.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
249
|
+
DEFAULT_EMBEDDINGS_CONFIG,
|
|
250
|
+
"flow_retrieval.train",
|
|
251
|
+
FlowRetrieval.__name__,
|
|
252
|
+
)
|
|
253
|
+
|
|
271
254
|
flows_to_embedd = flows.exclude_link_only_flows()
|
|
272
255
|
embeddings = self._create_embedder(self.config)
|
|
273
256
|
documents = self._generate_flow_documents(flows_to_embedd, domain)
|
|
@@ -17,7 +17,6 @@ from rasa.dialogue_understanding.generator.constants import (
|
|
|
17
17
|
FLOW_RETRIEVAL_KEY,
|
|
18
18
|
FLOW_RETRIEVAL_ACTIVE_KEY,
|
|
19
19
|
FLOW_RETRIEVAL_FLOW_THRESHOLD,
|
|
20
|
-
TRAINED_MODEL_NAME_CONFIG_KEY,
|
|
21
20
|
)
|
|
22
21
|
from rasa.dialogue_understanding.generator.flow_retrieval import FlowRetrieval
|
|
23
22
|
from rasa.engine.graph import GraphComponent, ExecutionContext
|
|
@@ -33,27 +32,26 @@ from rasa.shared.exceptions import ProviderClientAPIException
|
|
|
33
32
|
from rasa.shared.nlu.constants import FLOWS_IN_PROMPT
|
|
34
33
|
from rasa.shared.nlu.training_data.message import Message
|
|
35
34
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
35
|
+
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
36
36
|
from rasa.shared.utils.llm import (
|
|
37
37
|
allowed_values_for_slot,
|
|
38
38
|
llm_factory,
|
|
39
39
|
resolve_model_client_config,
|
|
40
40
|
)
|
|
41
|
-
from rasa.shared.utils.health_check import perform_training_time_llm_health_check
|
|
42
41
|
from rasa.utils.log_utils import log_llm
|
|
43
42
|
|
|
44
43
|
structlogger = structlog.get_logger()
|
|
45
44
|
|
|
46
45
|
|
|
47
|
-
LLM_BASED_COMMAND_GENERATOR_CONFIG_FILE = "config.json"
|
|
48
|
-
|
|
49
|
-
|
|
50
46
|
@DefaultV1Recipe.register(
|
|
51
47
|
[
|
|
52
48
|
DefaultV1Recipe.ComponentType.COMMAND_GENERATOR,
|
|
53
49
|
],
|
|
54
50
|
is_trainable=True,
|
|
55
51
|
)
|
|
56
|
-
class LLMBasedCommandGenerator(
|
|
52
|
+
class LLMBasedCommandGenerator(
|
|
53
|
+
LLMHealthCheckMixin, GraphComponent, CommandGenerator, ABC
|
|
54
|
+
):
|
|
57
55
|
"""An abstract class defining interface and common functionality
|
|
58
56
|
of an LLM-based command generators.
|
|
59
57
|
"""
|
|
@@ -106,11 +104,7 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
|
106
104
|
@abstractmethod
|
|
107
105
|
def persist(self) -> None:
|
|
108
106
|
"""Persist the component to disk for future loading."""
|
|
109
|
-
|
|
110
|
-
with self._model_storage.write_to(self._resource) as path:
|
|
111
|
-
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
112
|
-
path / LLM_BASED_COMMAND_GENERATOR_CONFIG_FILE, self.config
|
|
113
|
-
)
|
|
107
|
+
pass
|
|
114
108
|
|
|
115
109
|
@abstractmethod
|
|
116
110
|
async def predict_commands(
|
|
@@ -173,13 +167,11 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
|
173
167
|
"""Train the llm based command generator. Stores all flows into a vector
|
|
174
168
|
store.
|
|
175
169
|
"""
|
|
176
|
-
self.
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
LLMBasedCommandGenerator.__name__,
|
|
182
|
-
)
|
|
170
|
+
self.perform_llm_health_check(
|
|
171
|
+
self.config.get(LLM_CONFIG_KEY),
|
|
172
|
+
DEFAULT_LLM_CONFIG,
|
|
173
|
+
"llm_based_command_generator.train",
|
|
174
|
+
LLMBasedCommandGenerator.__name__,
|
|
183
175
|
)
|
|
184
176
|
|
|
185
177
|
if (
|
|
@@ -210,12 +202,11 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
|
210
202
|
except Exception as e:
|
|
211
203
|
structlogger.error(
|
|
212
204
|
"llm_based_command_generator.train.failed",
|
|
213
|
-
event_info="Flow retrieval store
|
|
205
|
+
event_info="Flow retrieval store is inaccessible.",
|
|
214
206
|
error=e,
|
|
215
207
|
)
|
|
216
208
|
raise
|
|
217
|
-
|
|
218
|
-
self.flow_retrieval.train()
|
|
209
|
+
|
|
219
210
|
self.persist()
|
|
220
211
|
return self._resource
|
|
221
212
|
|
|
@@ -251,25 +242,6 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
|
251
242
|
)
|
|
252
243
|
return None
|
|
253
244
|
|
|
254
|
-
@classmethod
|
|
255
|
-
def load_config_from_model_storage(
|
|
256
|
-
cls,
|
|
257
|
-
model_storage: ModelStorage,
|
|
258
|
-
resource: Resource,
|
|
259
|
-
) -> Optional[Text]:
|
|
260
|
-
try:
|
|
261
|
-
with model_storage.read_from(resource) as path:
|
|
262
|
-
return rasa.shared.utils.io.read_json_file(
|
|
263
|
-
path / LLM_BASED_COMMAND_GENERATOR_CONFIG_FILE
|
|
264
|
-
)
|
|
265
|
-
except (FileNotFoundError, FileIOException) as e:
|
|
266
|
-
structlogger.warning(
|
|
267
|
-
"llm_based_command_generator.load_config.failed",
|
|
268
|
-
error=e,
|
|
269
|
-
resource=resource.name,
|
|
270
|
-
)
|
|
271
|
-
return None
|
|
272
|
-
|
|
273
245
|
@classmethod
|
|
274
246
|
def load_flow_retrival(
|
|
275
247
|
cls,
|
|
@@ -24,7 +24,6 @@ from rasa.dialogue_understanding.generator.constants import (
|
|
|
24
24
|
LLM_CONFIG_KEY,
|
|
25
25
|
USER_INPUT_CONFIG_KEY,
|
|
26
26
|
FLOW_RETRIEVAL_KEY,
|
|
27
|
-
TRAINED_MODEL_NAME_CONFIG_KEY,
|
|
28
27
|
DEFAULT_LLM_CONFIG,
|
|
29
28
|
)
|
|
30
29
|
from rasa.dialogue_understanding.generator.flow_retrieval import FlowRetrieval
|
|
@@ -60,7 +59,6 @@ from rasa.shared.utils.llm import (
|
|
|
60
59
|
allowed_values_for_slot,
|
|
61
60
|
resolve_model_client_config,
|
|
62
61
|
)
|
|
63
|
-
from rasa.shared.utils.health_check import perform_inference_time_llm_health_check
|
|
64
62
|
|
|
65
63
|
# multistep template keys
|
|
66
64
|
HANDLE_FLOWS_KEY = "handle_flows"
|
|
@@ -77,6 +75,7 @@ DEFAULT_HANDLE_FLOWS_TEMPLATE = importlib.resources.read_text(
|
|
|
77
75
|
DEFAULT_FILL_SLOTS_TEMPLATE = importlib.resources.read_text(
|
|
78
76
|
"rasa.dialogue_understanding.generator.multi_step", "fill_slots_prompt.jinja2"
|
|
79
77
|
).strip()
|
|
78
|
+
MULTI_STEP_LLM_COMMAND_GENERATOR_CONFIG_FILE = "config.json"
|
|
80
79
|
|
|
81
80
|
# dictionary of template names and associated file names and default values
|
|
82
81
|
PROMPT_TEMPLATES = {
|
|
@@ -145,15 +144,18 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
145
144
|
**kwargs: Any,
|
|
146
145
|
) -> "MultiStepLLMCommandGenerator":
|
|
147
146
|
"""Loads trained component (see parent class for full docstring)."""
|
|
148
|
-
prompts = cls._load_prompt_templates(model_storage, resource)
|
|
149
147
|
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
148
|
+
# Perform health check of the LLM client config
|
|
149
|
+
llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
|
|
150
|
+
cls.perform_llm_health_check(
|
|
151
|
+
llm_config,
|
|
152
|
+
DEFAULT_LLM_CONFIG,
|
|
153
|
+
"multi_step_llm_command_generator.load",
|
|
154
|
+
MultiStepLLMCommandGenerator.__name__,
|
|
155
155
|
)
|
|
156
156
|
|
|
157
|
+
prompts = cls._load_prompt_templates(model_storage, resource)
|
|
158
|
+
|
|
157
159
|
# init base command generator
|
|
158
160
|
command_generator = cls(config, model_storage, resource, prompts)
|
|
159
161
|
# load flow retrieval if enabled
|
|
@@ -162,23 +164,12 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
162
164
|
command_generator.config, model_storage, resource
|
|
163
165
|
)
|
|
164
166
|
|
|
165
|
-
perform_inference_time_llm_health_check(
|
|
166
|
-
command_generator.config.get(LLM_CONFIG_KEY),
|
|
167
|
-
DEFAULT_LLM_CONFIG,
|
|
168
|
-
train_model_name,
|
|
169
|
-
"multi_step_llm_command_generator.load",
|
|
170
|
-
MultiStepLLMCommandGenerator.__name__,
|
|
171
|
-
)
|
|
172
|
-
|
|
173
167
|
return command_generator
|
|
174
168
|
|
|
175
169
|
def persist(self) -> None:
|
|
176
170
|
"""Persist this component to disk for future loading."""
|
|
177
|
-
super().persist()
|
|
178
|
-
|
|
179
|
-
# persist prompt template
|
|
180
171
|
self._persist_prompt_templates()
|
|
181
|
-
|
|
172
|
+
self._persist_config()
|
|
182
173
|
if self.flow_retrieval is not None:
|
|
183
174
|
self.flow_retrieval.persist()
|
|
184
175
|
|
|
@@ -411,6 +402,13 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
|
|
|
411
402
|
file_path = path / file_name
|
|
412
403
|
rasa.shared.utils.io.write_text_file(template, file_path)
|
|
413
404
|
|
|
405
|
+
def _persist_config(self) -> None:
|
|
406
|
+
"""Persist config as a source of truth for resolved clients."""
|
|
407
|
+
with self._model_storage.write_to(self._resource) as path:
|
|
408
|
+
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
409
|
+
path / MULTI_STEP_LLM_COMMAND_GENERATOR_CONFIG_FILE, self.config
|
|
410
|
+
)
|
|
411
|
+
|
|
414
412
|
async def _predict_commands_with_multi_step(
|
|
415
413
|
self,
|
|
416
414
|
message: Message,
|
|
@@ -19,6 +19,7 @@ from rasa.engine.storage.storage import ModelStorage
|
|
|
19
19
|
from rasa.shared.constants import ROUTE_TO_CALM_SLOT
|
|
20
20
|
from rasa.shared.core.domain import Domain
|
|
21
21
|
from rasa.shared.core.flows.flows_list import FlowsList
|
|
22
|
+
from rasa.shared.core.flows.steps import CollectInformationFlowStep
|
|
22
23
|
from rasa.shared.core.slot_mappings import (
|
|
23
24
|
SlotFillingManager,
|
|
24
25
|
extract_slot_value,
|
|
@@ -217,7 +218,24 @@ def _issue_set_slot_commands(
|
|
|
217
218
|
commands: List[Command] = []
|
|
218
219
|
domain = domain if domain else Domain.empty()
|
|
219
220
|
slot_filling_manager = SlotFillingManager(domain, tracker, message)
|
|
220
|
-
|
|
221
|
+
|
|
222
|
+
# only use slots that don't have ask_before_filling set to True
|
|
223
|
+
available_slot_names = flows.available_slot_names(ask_before_filling=False)
|
|
224
|
+
|
|
225
|
+
# check if the current step is a CollectInformationFlowStep
|
|
226
|
+
# in case it has ask_before_filling set to True, we need to add the
|
|
227
|
+
# slot to the available_slot_names
|
|
228
|
+
if tracker.active_flow:
|
|
229
|
+
flow = flows.flow_by_id(tracker.active_flow)
|
|
230
|
+
step_id = tracker.current_step_id
|
|
231
|
+
if flow is not None:
|
|
232
|
+
current_step = flow.step_by_id(step_id)
|
|
233
|
+
if (
|
|
234
|
+
current_step
|
|
235
|
+
and isinstance(current_step, CollectInformationFlowStep)
|
|
236
|
+
and current_step.ask_before_filling
|
|
237
|
+
):
|
|
238
|
+
available_slot_names.add(current_step.collect)
|
|
221
239
|
|
|
222
240
|
for _, slot in tracker.slots.items():
|
|
223
241
|
# if a slot is not collected in available flows,
|