rasa-pro 3.11.0rc1__py3-none-any.whl → 3.11.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/cli/inspect.py +2 -0
- rasa/cli/studio/studio.py +18 -8
- rasa/core/actions/action_repeat_bot_messages.py +17 -0
- rasa/core/channels/channel.py +17 -0
- rasa/core/channels/development_inspector.py +4 -1
- rasa/core/channels/voice_ready/audiocodes.py +15 -4
- rasa/core/channels/voice_ready/jambonz.py +13 -2
- rasa/core/channels/voice_ready/twilio_voice.py +6 -21
- rasa/core/channels/voice_stream/asr/asr_event.py +1 -1
- rasa/core/channels/voice_stream/asr/azure.py +5 -7
- rasa/core/channels/voice_stream/asr/deepgram.py +13 -11
- rasa/core/channels/voice_stream/voice_channel.py +61 -19
- rasa/core/nlg/contextual_response_rephraser.py +20 -12
- rasa/core/policies/enterprise_search_policy.py +32 -72
- rasa/core/policies/intentless_policy.py +34 -72
- rasa/dialogue_understanding/coexistence/llm_based_router.py +18 -33
- rasa/dialogue_understanding/generator/constants.py +0 -2
- rasa/dialogue_understanding/generator/flow_retrieval.py +33 -50
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -40
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +18 -20
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +19 -1
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +26 -22
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +9 -0
- rasa/dialogue_understanding/processor/command_processor.py +21 -1
- rasa/e2e_test/e2e_test_case.py +85 -6
- rasa/engine/validation.py +88 -60
- rasa/model_service.py +3 -0
- rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
- rasa/server.py +3 -1
- rasa/shared/constants.py +5 -5
- rasa/shared/core/constants.py +1 -1
- rasa/shared/core/domain.py +0 -26
- rasa/shared/core/flows/flows_list.py +5 -1
- rasa/shared/providers/_configs/litellm_router_client_config.py +29 -9
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +6 -14
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +1 -1
- rasa/shared/providers/llm/_base_litellm_client.py +32 -1
- rasa/shared/providers/llm/litellm_router_llm_client.py +56 -1
- rasa/shared/providers/llm/self_hosted_llm_client.py +4 -28
- rasa/shared/providers/router/_base_litellm_router_client.py +35 -1
- rasa/shared/utils/common.py +1 -1
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +256 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/llm.py +5 -2
- rasa/shared/utils/yaml.py +102 -62
- rasa/studio/auth.py +3 -5
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/upload.py +21 -10
- rasa/telemetry.py +15 -1
- rasa/tracing/config.py +3 -1
- rasa/tracing/instrumentation/attribute_extractors.py +20 -0
- rasa/tracing/instrumentation/instrumentation.py +121 -0
- rasa/utils/common.py +5 -0
- rasa/utils/io.py +8 -16
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/version.py +1 -1
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/METADATA +3 -2
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/RECORD +65 -61
- rasa/shared/utils/health_check.py +0 -533
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0rc1.dist-info → rasa_pro-3.11.0rc3.dist-info}/entry_points.txt +0 -0
|
@@ -2,7 +2,6 @@ from typing import Any, Dict, Optional, Text
|
|
|
2
2
|
|
|
3
3
|
import structlog
|
|
4
4
|
from jinja2 import Template
|
|
5
|
-
|
|
6
5
|
from rasa import telemetry
|
|
7
6
|
from rasa.core.nlg.response import TemplatedNaturalLanguageGenerator
|
|
8
7
|
from rasa.core.nlg.summarize import summarize_conversation
|
|
@@ -14,11 +13,12 @@ from rasa.shared.constants import (
|
|
|
14
13
|
PROVIDER_CONFIG_KEY,
|
|
15
14
|
OPENAI_PROVIDER,
|
|
16
15
|
TIMEOUT_CONFIG_KEY,
|
|
17
|
-
|
|
16
|
+
MODEL_GROUP_ID_CONFIG_KEY,
|
|
18
17
|
)
|
|
19
18
|
from rasa.shared.core.domain import KEY_RESPONSES_TEXT, Domain
|
|
20
19
|
from rasa.shared.core.events import BotUttered, UserUttered
|
|
21
20
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
21
|
+
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
22
22
|
from rasa.shared.utils.llm import (
|
|
23
23
|
DEFAULT_OPENAI_GENERATE_MODEL_NAME,
|
|
24
24
|
DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
|
|
@@ -28,7 +28,6 @@ from rasa.shared.utils.llm import (
|
|
|
28
28
|
llm_factory,
|
|
29
29
|
resolve_model_client_config,
|
|
30
30
|
)
|
|
31
|
-
from rasa.shared.utils.health_check import perform_training_time_llm_health_check
|
|
32
31
|
from rasa.shared.utils.llm import (
|
|
33
32
|
tracker_as_readable_transcript,
|
|
34
33
|
)
|
|
@@ -44,6 +43,8 @@ RESPONSE_REPHRASING_TEMPLATE_KEY = "rephrase_prompt"
|
|
|
44
43
|
RESPONSE_SUMMARISE_CONVERSATION_KEY = "summarize_conversation"
|
|
45
44
|
|
|
46
45
|
DEFAULT_REPHRASE_ALL = False
|
|
46
|
+
DEFAULT_SUMMARIZE_HISTORY = True
|
|
47
|
+
DEFAULT_MAX_HISTORICAL_TURNS = 5
|
|
47
48
|
|
|
48
49
|
DEFAULT_LLM_CONFIG = {
|
|
49
50
|
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
@@ -68,7 +69,9 @@ Suggested AI Response: {{suggested_response}}
|
|
|
68
69
|
Rephrased AI Response:"""
|
|
69
70
|
|
|
70
71
|
|
|
71
|
-
class ContextualResponseRephraser(
|
|
72
|
+
class ContextualResponseRephraser(
|
|
73
|
+
LLMHealthCheckMixin, TemplatedNaturalLanguageGenerator
|
|
74
|
+
):
|
|
72
75
|
"""Generates responses based on modified templates.
|
|
73
76
|
|
|
74
77
|
The templates are filled with the entities and slots that are available in the
|
|
@@ -102,13 +105,19 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
|
102
105
|
self.trace_prompt_tokens = self.nlg_endpoint.kwargs.get(
|
|
103
106
|
"trace_prompt_tokens", False
|
|
104
107
|
)
|
|
108
|
+
self.summarize_history = self.nlg_endpoint.kwargs.get(
|
|
109
|
+
"summarize_history", DEFAULT_SUMMARIZE_HISTORY
|
|
110
|
+
)
|
|
111
|
+
self.max_historical_turns = self.nlg_endpoint.kwargs.get(
|
|
112
|
+
"max_historical_turns", DEFAULT_MAX_HISTORICAL_TURNS
|
|
113
|
+
)
|
|
105
114
|
|
|
106
115
|
self.llm_config = resolve_model_client_config(
|
|
107
116
|
self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY),
|
|
108
117
|
ContextualResponseRephraser.__name__,
|
|
109
118
|
)
|
|
110
119
|
|
|
111
|
-
|
|
120
|
+
self.perform_llm_health_check(
|
|
112
121
|
self.llm_config,
|
|
113
122
|
DEFAULT_LLM_CONFIG,
|
|
114
123
|
"contextual_response_rephraser.init",
|
|
@@ -213,18 +222,17 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
|
213
222
|
prompt_template_text = self._template_for_response_rephrasing(response)
|
|
214
223
|
|
|
215
224
|
# Retrieve inputs for the dynamic prompt
|
|
216
|
-
transcript = tracker_as_readable_transcript(tracker, max_turns=5)
|
|
217
225
|
latest_message = self._last_message_if_human(tracker)
|
|
218
226
|
current_input = f"{USER}: {latest_message}" if latest_message else ""
|
|
219
227
|
|
|
220
228
|
# Only summarise conversation history if flagged
|
|
221
|
-
|
|
222
|
-
RESPONSE_SUMMARISE_CONVERSATION_KEY, False
|
|
223
|
-
)
|
|
224
|
-
if summarize_conversation_flag:
|
|
229
|
+
if self.summarize_history:
|
|
225
230
|
history = await self._create_history(tracker)
|
|
226
231
|
else:
|
|
227
|
-
history
|
|
232
|
+
# make sure the transcript/history contains the last user utterance
|
|
233
|
+
max_turns = max(self.max_historical_turns, 1)
|
|
234
|
+
history = tracker_as_readable_transcript(tracker, max_turns=max_turns)
|
|
235
|
+
# the history already contains the current input
|
|
228
236
|
current_input = ""
|
|
229
237
|
|
|
230
238
|
prompt = Template(prompt_template_text).render(
|
|
@@ -245,7 +253,7 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
|
245
253
|
llm_type=self.llm_property(PROVIDER_CONFIG_KEY),
|
|
246
254
|
llm_model=self.llm_property(MODEL_CONFIG_KEY)
|
|
247
255
|
or self.llm_property(MODEL_NAME_CONFIG_KEY),
|
|
248
|
-
llm_model_group_id=self.llm_property(
|
|
256
|
+
llm_model_group_id=self.llm_property(MODEL_GROUP_ID_CONFIG_KEY),
|
|
249
257
|
)
|
|
250
258
|
if not (updated_text := await self._generate_llm_response(prompt)):
|
|
251
259
|
# If the LLM fails to generate a response, we
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import importlib.resources
|
|
2
2
|
import json
|
|
3
3
|
import re
|
|
4
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
|
|
5
5
|
import dotenv
|
|
6
6
|
import structlog
|
|
7
7
|
from jinja2 import Template
|
|
@@ -25,8 +25,6 @@ from rasa.core.policies.policy import Policy, PolicyPrediction
|
|
|
25
25
|
from rasa.core.utils import AvailableEndpoints
|
|
26
26
|
from rasa.dialogue_understanding.generator.constants import (
|
|
27
27
|
LLM_CONFIG_KEY,
|
|
28
|
-
TRAINED_MODEL_NAME_CONFIG_KEY,
|
|
29
|
-
TRAINED_EMBEDDINGS_CONFIG_KEY,
|
|
30
28
|
)
|
|
31
29
|
from rasa.dialogue_understanding.patterns.cannot_handle import (
|
|
32
30
|
CannotHandlePatternFlowStackFrame,
|
|
@@ -53,7 +51,7 @@ from rasa.shared.constants import (
|
|
|
53
51
|
OPENAI_PROVIDER,
|
|
54
52
|
TIMEOUT_CONFIG_KEY,
|
|
55
53
|
MODEL_NAME_CONFIG_KEY,
|
|
56
|
-
|
|
54
|
+
MODEL_GROUP_ID_CONFIG_KEY,
|
|
57
55
|
)
|
|
58
56
|
from rasa.shared.core.constants import (
|
|
59
57
|
ACTION_CANCEL_FLOW,
|
|
@@ -71,6 +69,10 @@ from rasa.shared.providers.embedding._langchain_embedding_client_adapter import
|
|
|
71
69
|
)
|
|
72
70
|
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
73
71
|
from rasa.shared.utils.cli import print_error_and_exit
|
|
72
|
+
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
73
|
+
EmbeddingsHealthCheckMixin,
|
|
74
|
+
)
|
|
75
|
+
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
74
76
|
from rasa.shared.utils.io import deep_container_fingerprint
|
|
75
77
|
from rasa.shared.utils.llm import (
|
|
76
78
|
DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
@@ -82,12 +84,6 @@ from rasa.shared.utils.llm import (
|
|
|
82
84
|
tracker_as_readable_transcript,
|
|
83
85
|
resolve_model_client_config,
|
|
84
86
|
)
|
|
85
|
-
from rasa.shared.utils.health_check import (
|
|
86
|
-
perform_training_time_llm_health_check,
|
|
87
|
-
perform_training_time_embeddings_health_check,
|
|
88
|
-
perform_inference_time_llm_health_check,
|
|
89
|
-
perform_inference_time_embeddings_health_check,
|
|
90
|
-
)
|
|
91
87
|
from rasa.telemetry import (
|
|
92
88
|
track_enterprise_search_policy_predict,
|
|
93
89
|
track_enterprise_search_policy_train_completed,
|
|
@@ -161,7 +157,7 @@ class VectorStoreConfigurationError(RasaException):
|
|
|
161
157
|
@DefaultV1Recipe.register(
|
|
162
158
|
DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
|
|
163
159
|
)
|
|
164
|
-
class EnterpriseSearchPolicy(Policy):
|
|
160
|
+
class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
165
161
|
"""Policy which uses a vector store and LLMs to respond to user messages.
|
|
166
162
|
|
|
167
163
|
The policy uses a vector store and LLMs to respond to user messages. The
|
|
@@ -300,6 +296,9 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
300
296
|
A policy must return its resource locator so that potential children nodes
|
|
301
297
|
can load the policy from the resource.
|
|
302
298
|
"""
|
|
299
|
+
# Perform health checks for both LLM and embeddings client configs
|
|
300
|
+
self._perform_health_checks(self.config, "enterprise_search_policy.train")
|
|
301
|
+
|
|
303
302
|
store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
|
|
304
303
|
|
|
305
304
|
# telemetry call to track training start
|
|
@@ -319,11 +318,6 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
319
318
|
f"required environment variables. Error: {e}"
|
|
320
319
|
)
|
|
321
320
|
|
|
322
|
-
(
|
|
323
|
-
self.config[TRAINED_MODEL_NAME_CONFIG_KEY],
|
|
324
|
-
self.config[TRAINED_EMBEDDINGS_CONFIG_KEY],
|
|
325
|
-
) = self._perform_training_time_health_checks()
|
|
326
|
-
|
|
327
321
|
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
328
322
|
logger.info("enterprise_search_policy.train.faiss")
|
|
329
323
|
with self._model_storage.write_to(self._resource) as path:
|
|
@@ -343,12 +337,12 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
343
337
|
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
344
338
|
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
345
339
|
embeddings_model_group_id=self.embeddings_config.get(
|
|
346
|
-
|
|
340
|
+
MODEL_GROUP_ID_CONFIG_KEY
|
|
347
341
|
),
|
|
348
342
|
llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
349
343
|
llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
350
344
|
or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
351
|
-
llm_model_group_id=self.llm_config.get(
|
|
345
|
+
llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
|
|
352
346
|
citation_enabled=self.citation_enabled,
|
|
353
347
|
)
|
|
354
348
|
self.persist()
|
|
@@ -544,12 +538,12 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
544
538
|
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
545
539
|
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
546
540
|
embeddings_model_group_id=self.embeddings_config.get(
|
|
547
|
-
|
|
541
|
+
MODEL_GROUP_ID_CONFIG_KEY
|
|
548
542
|
),
|
|
549
543
|
llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
550
544
|
llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
551
545
|
or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
552
|
-
llm_model_group_id=self.llm_config.get(
|
|
546
|
+
llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
|
|
553
547
|
citation_enabled=self.citation_enabled,
|
|
554
548
|
)
|
|
555
549
|
return self._create_prediction(
|
|
@@ -698,16 +692,16 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
698
692
|
**kwargs: Any,
|
|
699
693
|
) -> "EnterpriseSearchPolicy":
|
|
700
694
|
"""Loads a trained policy (see parent class for full docstring)."""
|
|
695
|
+
|
|
696
|
+
# Perform health checks for both LLM and embeddings client configs
|
|
697
|
+
cls._perform_health_checks(config, "enterprise_search_policy.load")
|
|
698
|
+
|
|
701
699
|
prompt_template = None
|
|
702
|
-
persisted_config = None
|
|
703
700
|
try:
|
|
704
701
|
with model_storage.read_from(resource) as path:
|
|
705
702
|
prompt_template = rasa.shared.utils.io.read_file(
|
|
706
703
|
path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
|
|
707
704
|
)
|
|
708
|
-
persisted_config = rasa.shared.utils.io.read_json_file(
|
|
709
|
-
path / ENTERPRISE_SEARCH_CONFIG_FILE_NAME
|
|
710
|
-
)
|
|
711
705
|
except (FileNotFoundError, FileIOException) as e:
|
|
712
706
|
logger.warning(
|
|
713
707
|
"enterprise_search_policy.load.failed", error=e, resource=resource.name
|
|
@@ -737,7 +731,7 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
737
731
|
embeddings=embeddings,
|
|
738
732
|
) # type: ignore
|
|
739
733
|
|
|
740
|
-
|
|
734
|
+
return cls(
|
|
741
735
|
config,
|
|
742
736
|
model_storage,
|
|
743
737
|
resource,
|
|
@@ -746,14 +740,6 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
746
740
|
prompt_template=prompt_template,
|
|
747
741
|
)
|
|
748
742
|
|
|
749
|
-
cls._perform_inference_time_health_checks(
|
|
750
|
-
persisted_config,
|
|
751
|
-
policy.config.get(LLM_CONFIG_KEY),
|
|
752
|
-
policy.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
753
|
-
)
|
|
754
|
-
|
|
755
|
-
return policy
|
|
756
|
-
|
|
757
743
|
@classmethod
|
|
758
744
|
def _get_local_knowledge_data(cls, config: Dict[str, Any]) -> Optional[List[str]]:
|
|
759
745
|
"""This is required only for local knowledge base types.
|
|
@@ -894,52 +880,26 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
894
880
|
|
|
895
881
|
return joined_answer + joined_sources
|
|
896
882
|
|
|
897
|
-
def _perform_training_time_health_checks(
|
|
898
|
-
self,
|
|
899
|
-
) -> Tuple[Optional[str], Optional[str]]:
|
|
900
|
-
train_model_name = perform_training_time_llm_health_check(
|
|
901
|
-
self.config.get(LLM_CONFIG_KEY),
|
|
902
|
-
DEFAULT_LLM_CONFIG,
|
|
903
|
-
"enterprise_search_policy.train",
|
|
904
|
-
EnterpriseSearchPolicy.__name__,
|
|
905
|
-
)
|
|
906
|
-
train_embedding_name = perform_training_time_embeddings_health_check(
|
|
907
|
-
self.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
908
|
-
DEFAULT_EMBEDDINGS_CONFIG,
|
|
909
|
-
"enterprise_search_policy.train",
|
|
910
|
-
EnterpriseSearchPolicy.__name__,
|
|
911
|
-
)
|
|
912
|
-
return train_model_name, train_embedding_name
|
|
913
|
-
|
|
914
883
|
@classmethod
|
|
915
|
-
def
|
|
916
|
-
cls,
|
|
917
|
-
persisted_config: Optional[Dict[str, Any]],
|
|
918
|
-
resolved_llm_config: Optional[Dict[str, Any]],
|
|
919
|
-
resolved_embeddings_config: Optional[Dict[str, Any]],
|
|
884
|
+
def _perform_health_checks(
|
|
885
|
+
cls, config: Dict[Text, Any], log_source_method: str
|
|
920
886
|
) -> None:
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
)
|
|
926
|
-
perform_inference_time_llm_health_check(
|
|
927
|
-
resolved_llm_config,
|
|
887
|
+
# Perform health check of the LLM client config
|
|
888
|
+
llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
|
|
889
|
+
cls.perform_llm_health_check(
|
|
890
|
+
llm_config,
|
|
928
891
|
DEFAULT_LLM_CONFIG,
|
|
929
|
-
|
|
930
|
-
"enterprise_search_policy.load",
|
|
892
|
+
log_source_method,
|
|
931
893
|
EnterpriseSearchPolicy.__name__,
|
|
932
894
|
)
|
|
933
895
|
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
else None
|
|
896
|
+
# Perform health check of the embeddings client config
|
|
897
|
+
embeddings_config = resolve_model_client_config(
|
|
898
|
+
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
938
899
|
)
|
|
939
|
-
|
|
940
|
-
|
|
900
|
+
cls.perform_embeddings_health_check(
|
|
901
|
+
embeddings_config,
|
|
941
902
|
DEFAULT_EMBEDDINGS_CONFIG,
|
|
942
|
-
|
|
943
|
-
"enterprise_search_policy.load",
|
|
903
|
+
log_source_method,
|
|
944
904
|
EnterpriseSearchPolicy.__name__,
|
|
945
905
|
)
|
|
@@ -18,10 +18,6 @@ from rasa.core.constants import (
|
|
|
18
18
|
UTTER_SOURCE_METADATA_KEY,
|
|
19
19
|
)
|
|
20
20
|
from rasa.core.policies.policy import Policy, PolicyPrediction, SupportedData
|
|
21
|
-
from rasa.dialogue_understanding.generator.constants import (
|
|
22
|
-
TRAINED_MODEL_NAME_CONFIG_KEY,
|
|
23
|
-
TRAINED_EMBEDDINGS_CONFIG_KEY,
|
|
24
|
-
)
|
|
25
21
|
from rasa.dialogue_understanding.patterns.chitchat import FLOW_PATTERN_CHITCHAT
|
|
26
22
|
from rasa.dialogue_understanding.stack.frames import (
|
|
27
23
|
ChitChatStackFrame,
|
|
@@ -43,7 +39,7 @@ from rasa.shared.constants import (
|
|
|
43
39
|
PROVIDER_CONFIG_KEY,
|
|
44
40
|
OPENAI_PROVIDER,
|
|
45
41
|
TIMEOUT_CONFIG_KEY,
|
|
46
|
-
|
|
42
|
+
MODEL_GROUP_ID_CONFIG_KEY,
|
|
47
43
|
)
|
|
48
44
|
from rasa.shared.core.constants import ACTION_LISTEN_NAME
|
|
49
45
|
from rasa.shared.core.constants import ACTION_TRIGGER_CHITCHAT
|
|
@@ -64,6 +60,10 @@ from rasa.shared.providers.embedding._langchain_embedding_client_adapter import
|
|
|
64
60
|
_LangchainEmbeddingClientAdapter,
|
|
65
61
|
)
|
|
66
62
|
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
63
|
+
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
64
|
+
EmbeddingsHealthCheckMixin,
|
|
65
|
+
)
|
|
66
|
+
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
67
67
|
from rasa.shared.utils.io import deep_container_fingerprint
|
|
68
68
|
from rasa.shared.utils.llm import (
|
|
69
69
|
AI,
|
|
@@ -79,12 +79,6 @@ from rasa.shared.utils.llm import (
|
|
|
79
79
|
tracker_as_readable_transcript,
|
|
80
80
|
resolve_model_client_config,
|
|
81
81
|
)
|
|
82
|
-
from rasa.shared.utils.health_check import (
|
|
83
|
-
perform_training_time_llm_health_check,
|
|
84
|
-
perform_training_time_embeddings_health_check,
|
|
85
|
-
perform_inference_time_llm_health_check,
|
|
86
|
-
perform_inference_time_embeddings_health_check,
|
|
87
|
-
)
|
|
88
82
|
from rasa.utils.log_utils import log_llm
|
|
89
83
|
from rasa.utils.ml_utils import (
|
|
90
84
|
extract_ai_response_examples,
|
|
@@ -383,7 +377,7 @@ def conversation_as_prompt(conversation: Conversation) -> str:
|
|
|
383
377
|
@DefaultV1Recipe.register(
|
|
384
378
|
DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
|
|
385
379
|
)
|
|
386
|
-
class IntentlessPolicy(Policy):
|
|
380
|
+
class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
387
381
|
"""Policy which uses a language model to generate the next action.
|
|
388
382
|
|
|
389
383
|
The policy uses the OpenAI API to generate the next action based on the
|
|
@@ -516,10 +510,8 @@ class IntentlessPolicy(Policy):
|
|
|
516
510
|
A policy must return its resource locator so that potential children nodes
|
|
517
511
|
can load the policy from the resource.
|
|
518
512
|
"""
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
self.config[TRAINED_EMBEDDINGS_CONFIG_KEY],
|
|
522
|
-
) = self._perform_training_time_health_checks()
|
|
513
|
+
# Perform health checks of both LLM and embeddings client configs
|
|
514
|
+
self._perform_health_checks(self.config, "intentless_policy.train")
|
|
523
515
|
|
|
524
516
|
responses = filter_responses(responses, forms, flows or FlowsList([]))
|
|
525
517
|
telemetry.track_intentless_policy_train()
|
|
@@ -566,11 +558,13 @@ class IntentlessPolicy(Policy):
|
|
|
566
558
|
embeddings_type=self.embeddings_property(PROVIDER_CONFIG_KEY),
|
|
567
559
|
embeddings_model=self.embeddings_property(MODEL_CONFIG_KEY)
|
|
568
560
|
or self.embeddings_property(MODEL_NAME_CONFIG_KEY),
|
|
569
|
-
embeddings_model_group_id=self.embeddings_property(
|
|
561
|
+
embeddings_model_group_id=self.embeddings_property(
|
|
562
|
+
MODEL_GROUP_ID_CONFIG_KEY
|
|
563
|
+
),
|
|
570
564
|
llm_type=self.llm_property(PROVIDER_CONFIG_KEY),
|
|
571
565
|
llm_model=self.llm_property(MODEL_CONFIG_KEY)
|
|
572
566
|
or self.llm_property(MODEL_NAME_CONFIG_KEY),
|
|
573
|
-
llm_model_group_id=self.llm_property(
|
|
567
|
+
llm_model_group_id=self.llm_property(MODEL_GROUP_ID_CONFIG_KEY),
|
|
574
568
|
)
|
|
575
569
|
|
|
576
570
|
self.persist()
|
|
@@ -650,11 +644,13 @@ class IntentlessPolicy(Policy):
|
|
|
650
644
|
embeddings_type=self.embeddings_property(PROVIDER_CONFIG_KEY),
|
|
651
645
|
embeddings_model=self.embeddings_property(MODEL_CONFIG_KEY)
|
|
652
646
|
or self.embeddings_property(MODEL_NAME_CONFIG_KEY),
|
|
653
|
-
embeddings_model_group_id=self.embeddings_property(
|
|
647
|
+
embeddings_model_group_id=self.embeddings_property(
|
|
648
|
+
MODEL_GROUP_ID_CONFIG_KEY
|
|
649
|
+
),
|
|
654
650
|
llm_type=self.llm_property(PROVIDER_CONFIG_KEY),
|
|
655
651
|
llm_model=self.llm_property(MODEL_CONFIG_KEY)
|
|
656
652
|
or self.llm_property(MODEL_NAME_CONFIG_KEY),
|
|
657
|
-
llm_model_group_id=self.llm_property(
|
|
653
|
+
llm_model_group_id=self.llm_property(MODEL_GROUP_ID_CONFIG_KEY),
|
|
658
654
|
score=score,
|
|
659
655
|
)
|
|
660
656
|
|
|
@@ -952,10 +948,13 @@ class IntentlessPolicy(Policy):
|
|
|
952
948
|
**kwargs: Any,
|
|
953
949
|
) -> "IntentlessPolicy":
|
|
954
950
|
"""Loads a trained policy (see parent class for full docstring)."""
|
|
951
|
+
|
|
952
|
+
# Perform health checks of both LLM and embeddings client configs
|
|
953
|
+
cls._perform_health_checks(config, "intentless_policy.load")
|
|
954
|
+
|
|
955
955
|
responses_docsearch = None
|
|
956
956
|
samples_docsearch = None
|
|
957
957
|
prompt_template = None
|
|
958
|
-
persisted_config = None
|
|
959
958
|
try:
|
|
960
959
|
with model_storage.read_from(resource) as path:
|
|
961
960
|
responses_docsearch = load_faiss_vector_store(
|
|
@@ -973,15 +972,12 @@ class IntentlessPolicy(Policy):
|
|
|
973
972
|
prompt_template = rasa.shared.utils.io.read_file(
|
|
974
973
|
path / INTENTLESS_PROMPT_TEMPLATE_FILE_NAME
|
|
975
974
|
)
|
|
976
|
-
persisted_config = rasa.shared.utils.io.read_json_file(
|
|
977
|
-
path / INTENTLESS_CONFIG_FILE_NAME
|
|
978
|
-
)
|
|
979
975
|
except (ValueError, FileNotFoundError, FileIOException) as e:
|
|
980
976
|
structlogger.warning(
|
|
981
977
|
"intentless_policy.load.failed", error=e, resource_name=resource.name
|
|
982
978
|
)
|
|
983
979
|
|
|
984
|
-
|
|
980
|
+
return cls(
|
|
985
981
|
config,
|
|
986
982
|
model_storage,
|
|
987
983
|
resource,
|
|
@@ -991,14 +987,6 @@ class IntentlessPolicy(Policy):
|
|
|
991
987
|
prompt_template=prompt_template,
|
|
992
988
|
)
|
|
993
989
|
|
|
994
|
-
cls._perform_inference_time_health_checks(
|
|
995
|
-
persisted_config,
|
|
996
|
-
policy.config.get(LLM_CONFIG_KEY),
|
|
997
|
-
policy.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
998
|
-
)
|
|
999
|
-
|
|
1000
|
-
return policy
|
|
1001
|
-
|
|
1002
990
|
@classmethod
|
|
1003
991
|
def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
|
|
1004
992
|
"""Add a fingerprint of intentless policy for the graph."""
|
|
@@ -1018,52 +1006,26 @@ class IntentlessPolicy(Policy):
|
|
|
1018
1006
|
[prompt_template, llm_config, embedding_config]
|
|
1019
1007
|
)
|
|
1020
1008
|
|
|
1021
|
-
def _perform_training_time_health_checks(
|
|
1022
|
-
self,
|
|
1023
|
-
) -> Tuple[Optional[str], Optional[str]]:
|
|
1024
|
-
train_model_name = perform_training_time_llm_health_check(
|
|
1025
|
-
self.config.get(LLM_CONFIG_KEY),
|
|
1026
|
-
DEFAULT_LLM_CONFIG,
|
|
1027
|
-
"intentless_policy.train",
|
|
1028
|
-
IntentlessPolicy.__name__,
|
|
1029
|
-
)
|
|
1030
|
-
train_embedding_name = perform_training_time_embeddings_health_check(
|
|
1031
|
-
self.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
1032
|
-
DEFAULT_EMBEDDINGS_CONFIG,
|
|
1033
|
-
"intentless_policy.train",
|
|
1034
|
-
IntentlessPolicy.__name__,
|
|
1035
|
-
)
|
|
1036
|
-
return train_model_name, train_embedding_name
|
|
1037
|
-
|
|
1038
1009
|
@classmethod
|
|
1039
|
-
def
|
|
1040
|
-
cls,
|
|
1041
|
-
persisted_config: Optional[Dict[str, Any]],
|
|
1042
|
-
resolved_llm_config: Optional[Dict[str, Any]],
|
|
1043
|
-
resolved_embeddings_config: Optional[Dict[str, Any]],
|
|
1010
|
+
def _perform_health_checks(
|
|
1011
|
+
cls, config: Dict[Text, Any], log_source_method: str
|
|
1044
1012
|
) -> None:
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
)
|
|
1050
|
-
perform_inference_time_llm_health_check(
|
|
1051
|
-
resolved_llm_config,
|
|
1013
|
+
# Perform health check of the LLM client config
|
|
1014
|
+
llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
|
|
1015
|
+
cls.perform_llm_health_check(
|
|
1016
|
+
llm_config,
|
|
1052
1017
|
DEFAULT_LLM_CONFIG,
|
|
1053
|
-
|
|
1054
|
-
"intentless_policy.load",
|
|
1018
|
+
log_source_method,
|
|
1055
1019
|
IntentlessPolicy.__name__,
|
|
1056
1020
|
)
|
|
1057
1021
|
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
else None
|
|
1022
|
+
# Perform health check of the embeddings client config
|
|
1023
|
+
embeddings_config = resolve_model_client_config(
|
|
1024
|
+
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
1062
1025
|
)
|
|
1063
|
-
|
|
1064
|
-
|
|
1026
|
+
cls.perform_embeddings_health_check(
|
|
1027
|
+
embeddings_config,
|
|
1065
1028
|
DEFAULT_EMBEDDINGS_CONFIG,
|
|
1066
|
-
|
|
1067
|
-
"intentless_policy.load",
|
|
1029
|
+
log_source_method,
|
|
1068
1030
|
IntentlessPolicy.__name__,
|
|
1069
1031
|
)
|
|
@@ -17,7 +17,6 @@ from rasa.dialogue_understanding.commands import Command, SetSlotCommand
|
|
|
17
17
|
from rasa.dialogue_understanding.commands.noop_command import NoopCommand
|
|
18
18
|
from rasa.dialogue_understanding.generator.constants import (
|
|
19
19
|
LLM_CONFIG_KEY,
|
|
20
|
-
TRAINED_MODEL_NAME_CONFIG_KEY,
|
|
21
20
|
)
|
|
22
21
|
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
23
22
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
@@ -36,6 +35,7 @@ from rasa.shared.exceptions import InvalidConfigException, FileIOException
|
|
|
36
35
|
from rasa.shared.nlu.constants import COMMANDS, TEXT
|
|
37
36
|
from rasa.shared.nlu.training_data.message import Message
|
|
38
37
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
38
|
+
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
39
39
|
from rasa.shared.utils.io import deep_container_fingerprint
|
|
40
40
|
from rasa.shared.utils.llm import (
|
|
41
41
|
DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
@@ -43,10 +43,6 @@ from rasa.shared.utils.llm import (
|
|
|
43
43
|
llm_factory,
|
|
44
44
|
resolve_model_client_config,
|
|
45
45
|
)
|
|
46
|
-
from rasa.shared.utils.health_check import (
|
|
47
|
-
perform_training_time_llm_health_check,
|
|
48
|
-
perform_inference_time_llm_health_check,
|
|
49
|
-
)
|
|
50
46
|
from rasa.utils.log_utils import log_llm
|
|
51
47
|
|
|
52
48
|
LLM_BASED_ROUTER_PROMPT_FILE_NAME = "llm_based_router_prompt.jinja2"
|
|
@@ -80,7 +76,7 @@ structlogger = structlog.get_logger()
|
|
|
80
76
|
],
|
|
81
77
|
is_trainable=True,
|
|
82
78
|
)
|
|
83
|
-
class LLMBasedRouter(GraphComponent):
|
|
79
|
+
class LLMBasedRouter(LLMHealthCheckMixin, GraphComponent):
|
|
84
80
|
@staticmethod
|
|
85
81
|
def get_default_config() -> Dict[str, Any]:
|
|
86
82
|
"""The component's default config (see parent class for full docstring)."""
|
|
@@ -144,13 +140,11 @@ class LLMBasedRouter(GraphComponent):
|
|
|
144
140
|
|
|
145
141
|
def train(self, training_data: TrainingData) -> Resource:
|
|
146
142
|
"""Train the intent classifier on a data set."""
|
|
147
|
-
self.
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
LLMBasedRouter.__name__,
|
|
153
|
-
)
|
|
143
|
+
self.perform_llm_health_check(
|
|
144
|
+
self.config.get(LLM_CONFIG_KEY),
|
|
145
|
+
DEFAULT_LLM_CONFIG,
|
|
146
|
+
"llm_based_router.train",
|
|
147
|
+
LLMBasedRouter.__name__,
|
|
154
148
|
)
|
|
155
149
|
|
|
156
150
|
self.persist()
|
|
@@ -166,37 +160,28 @@ class LLMBasedRouter(GraphComponent):
|
|
|
166
160
|
**kwargs: Any,
|
|
167
161
|
) -> "LLMBasedRouter":
|
|
168
162
|
"""Loads trained component (see parent class for full docstring)."""
|
|
163
|
+
|
|
164
|
+
# Perform health check on the resolved LLM client config
|
|
165
|
+
llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
|
|
166
|
+
cls.perform_llm_health_check(
|
|
167
|
+
llm_config,
|
|
168
|
+
DEFAULT_LLM_CONFIG,
|
|
169
|
+
"llm_based_router.load",
|
|
170
|
+
LLMBasedRouter.__name__,
|
|
171
|
+
)
|
|
172
|
+
|
|
169
173
|
prompt_template = None
|
|
170
|
-
persisted_config = None
|
|
171
174
|
try:
|
|
172
175
|
with model_storage.read_from(resource) as path:
|
|
173
176
|
prompt_template = rasa.shared.utils.io.read_file(
|
|
174
177
|
path / LLM_BASED_ROUTER_PROMPT_FILE_NAME
|
|
175
178
|
)
|
|
176
|
-
persisted_config = rasa.shared.utils.io.read_json_file(
|
|
177
|
-
path / LLM_BASED_ROUTER_CONFIG_FILE_NAME
|
|
178
|
-
)
|
|
179
179
|
except (FileNotFoundError, FileIOException) as e:
|
|
180
180
|
structlogger.warning(
|
|
181
181
|
"llm_based_router.load.failed", error=e, resource=resource.name
|
|
182
182
|
)
|
|
183
183
|
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
train_model_name = (
|
|
187
|
-
persisted_config.get(TRAINED_MODEL_NAME_CONFIG_KEY, None)
|
|
188
|
-
if persisted_config
|
|
189
|
-
else None
|
|
190
|
-
)
|
|
191
|
-
perform_inference_time_llm_health_check(
|
|
192
|
-
router.config.get(LLM_CONFIG_KEY),
|
|
193
|
-
DEFAULT_LLM_CONFIG,
|
|
194
|
-
train_model_name,
|
|
195
|
-
"llm_based_router.load",
|
|
196
|
-
LLMBasedRouter.__name__,
|
|
197
|
-
)
|
|
198
|
-
|
|
199
|
-
return router
|
|
184
|
+
return cls(config, model_storage, resource, prompt_template=prompt_template)
|
|
200
185
|
|
|
201
186
|
@classmethod
|
|
202
187
|
def create(
|
|
@@ -18,8 +18,6 @@ DEFAULT_LLM_CONFIG = {
|
|
|
18
18
|
}
|
|
19
19
|
|
|
20
20
|
LLM_CONFIG_KEY = "llm"
|
|
21
|
-
TRAINED_MODEL_NAME_CONFIG_KEY = "trained_llm_model_name"
|
|
22
|
-
TRAINED_EMBEDDINGS_CONFIG_KEY = "trained_embeddings_model_name"
|
|
23
21
|
USER_INPUT_CONFIG_KEY = "user_input"
|
|
24
22
|
|
|
25
23
|
FLOW_RETRIEVAL_KEY = "flow_retrieval"
|