rasa-pro 3.11.0a4.dev3__py3-none-any.whl → 3.11.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +22 -12
- rasa/api.py +1 -1
- rasa/cli/arguments/default_arguments.py +1 -2
- rasa/cli/arguments/shell.py +5 -1
- rasa/cli/e2e_test.py +1 -1
- rasa/cli/evaluate.py +8 -8
- rasa/cli/inspect.py +6 -4
- rasa/cli/llm_fine_tuning.py +1 -1
- rasa/cli/project_templates/calm/config.yml +5 -7
- rasa/cli/project_templates/calm/endpoints.yml +8 -0
- rasa/cli/project_templates/tutorial/config.yml +8 -5
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
- rasa/cli/project_templates/tutorial/domain.yml +14 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +7 -7
- rasa/cli/run.py +1 -1
- rasa/cli/scaffold.py +4 -2
- rasa/cli/studio/studio.py +18 -8
- rasa/cli/utils.py +5 -0
- rasa/cli/x.py +8 -8
- rasa/constants.py +1 -1
- rasa/core/actions/action_repeat_bot_messages.py +17 -0
- rasa/core/channels/channel.py +20 -0
- rasa/core/channels/inspector/dist/assets/{arc-6852c607.js → arc-bc141fb2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-acc952b2.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-848a7597.js → classDiagram-936ed81e-55366915.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-a73d3e68.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-e5ee049d.js → createText-62fc7601-b0ec81d6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-771e517e.js → edges-f2ad444c-6166330c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-aa347178.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-651fc57d.js → flowDb-1972c806-fca3bfe4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-ca67804f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-2dbc568d.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-25a65bd8.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-fdc7378d.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-6f1fd606.js → index-2c4b9a3b-f55afcdf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-efdd30c1.js → index-e7cef9de.js} +68 -68
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-cb1a041a.js → infoDiagram-736b4530-124d4a14.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-14609879.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-2490f52b.js → layout-b9885fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-40186f1f.js → line-7c59abb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-08814e93.js → linear-4776f780.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-1a534584.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-72397b61.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3bb0b6a3.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-57334f61.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-111e1297.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-10bcfe62.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-acaf7513.js → stateDiagram-59f0c015-042b3137.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-3ec2a235.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-62730289.js → styles-080da4f6-23ffa4fc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-5284ee76.js → styles-3dcbcfbf-94f59763.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-642435e3.js → styles-9c745c82-78a6bebc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-b250a350.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-c2b147ed.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-f92cfea9.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/App.tsx +1 -1
- rasa/core/channels/inspector/src/helpers/audiostream.ts +77 -16
- rasa/core/channels/socketio.py +2 -1
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/voice_ready/audiocodes.py +12 -0
- rasa/core/channels/voice_ready/jambonz.py +15 -4
- rasa/core/channels/voice_ready/twilio_voice.py +6 -21
- rasa/core/channels/voice_stream/asr/asr_event.py +5 -0
- rasa/core/channels/voice_stream/asr/azure.py +122 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +16 -6
- rasa/core/channels/voice_stream/audio_bytes.py +1 -0
- rasa/core/channels/voice_stream/browser_audio.py +31 -8
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/azure.py +6 -2
- rasa/core/channels/voice_stream/tts/cartesia.py +10 -6
- rasa/core/channels/voice_stream/tts/tts_engine.py +1 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +27 -18
- rasa/core/channels/voice_stream/util.py +4 -4
- rasa/core/channels/voice_stream/voice_channel.py +189 -39
- rasa/core/featurizers/single_state_featurizer.py +22 -1
- rasa/core/featurizers/tracker_featurizers.py +115 -18
- rasa/core/nlg/contextual_response_rephraser.py +32 -30
- rasa/core/persistor.py +86 -39
- rasa/core/policies/enterprise_search_policy.py +119 -60
- rasa/core/policies/flows/flow_executor.py +7 -4
- rasa/core/policies/intentless_policy.py +78 -22
- rasa/core/policies/ted_policy.py +58 -33
- rasa/core/policies/unexpected_intent_policy.py +15 -7
- rasa/core/processor.py +25 -0
- rasa/core/training/interactive.py +34 -35
- rasa/core/utils.py +8 -3
- rasa/dialogue_understanding/coexistence/llm_based_router.py +39 -12
- rasa/dialogue_understanding/commands/change_flow_command.py +6 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +5 -0
- rasa/dialogue_understanding/generator/constants.py +2 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +49 -4
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +37 -23
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +57 -10
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +19 -1
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +71 -11
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +39 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +21 -1
- rasa/e2e_test/e2e_test_case.py +85 -6
- rasa/e2e_test/e2e_test_runner.py +4 -2
- rasa/e2e_test/utils/io.py +1 -1
- rasa/engine/validation.py +316 -10
- rasa/model_manager/config.py +15 -3
- rasa/model_manager/model_api.py +15 -7
- rasa/model_manager/runner_service.py +8 -6
- rasa/model_manager/socket_bridge.py +6 -3
- rasa/model_manager/trainer_service.py +7 -5
- rasa/model_manager/utils.py +28 -7
- rasa/model_service.py +9 -2
- rasa/model_training.py +2 -0
- rasa/nlu/classifiers/diet_classifier.py +38 -25
- rasa/nlu/classifiers/logistic_regression_classifier.py +22 -9
- rasa/nlu/classifiers/sklearn_intent_classifier.py +37 -16
- rasa/nlu/extractors/crf_entity_extractor.py +93 -50
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -16
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +52 -17
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +5 -3
- rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
- rasa/server.py +3 -1
- rasa/shared/constants.py +36 -3
- rasa/shared/core/constants.py +7 -0
- rasa/shared/core/domain.py +26 -0
- rasa/shared/core/flows/flow.py +5 -0
- rasa/shared/core/flows/flows_list.py +5 -1
- rasa/shared/core/flows/flows_yaml_schema.json +10 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +96 -0
- rasa/shared/core/slots.py +5 -0
- rasa/shared/nlu/training_data/features.py +120 -2
- rasa/shared/providers/_configs/azure_openai_client_config.py +5 -3
- rasa/shared/providers/_configs/litellm_router_client_config.py +200 -0
- rasa/shared/providers/_configs/model_group_config.py +167 -0
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
- rasa/shared/providers/_configs/utils.py +16 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +18 -29
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
- rasa/shared/providers/llm/_base_litellm_client.py +37 -31
- rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
- rasa/shared/providers/llm/litellm_router_llm_client.py +127 -0
- rasa/shared/providers/llm/rasa_llm_client.py +112 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +1 -1
- rasa/shared/providers/mappings.py +19 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +149 -0
- rasa/shared/providers/router/router_client.py +73 -0
- rasa/shared/utils/common.py +8 -0
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +256 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +28 -6
- rasa/shared/utils/llm.py +353 -46
- rasa/shared/utils/yaml.py +111 -73
- rasa/studio/auth.py +3 -5
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/upload.py +81 -26
- rasa/telemetry.py +92 -17
- rasa/tracing/config.py +2 -0
- rasa/tracing/instrumentation/attribute_extractors.py +94 -17
- rasa/tracing/instrumentation/instrumentation.py +121 -0
- rasa/utils/common.py +5 -0
- rasa/utils/io.py +7 -81
- rasa/utils/log_utils.py +9 -2
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/utils/tensorflow/feature_array.py +366 -0
- rasa/utils/tensorflow/model_data.py +2 -193
- rasa/validator.py +70 -0
- rasa/version.py +1 -1
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/METADATA +11 -10
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/RECORD +183 -163
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-587d82d8.js +0 -1
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/entry_points.txt +0 -0
|
@@ -1,45 +1,42 @@
|
|
|
1
1
|
import importlib.resources
|
|
2
2
|
import json
|
|
3
|
-
import os
|
|
4
3
|
import re
|
|
5
4
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
|
|
6
|
-
|
|
7
5
|
import dotenv
|
|
8
6
|
import structlog
|
|
9
7
|
from jinja2 import Template
|
|
10
8
|
from pydantic import ValidationError
|
|
11
9
|
|
|
12
|
-
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
13
|
-
_LangchainEmbeddingClientAdapter,
|
|
14
|
-
)
|
|
15
|
-
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
16
|
-
|
|
17
10
|
import rasa.shared.utils.io
|
|
18
|
-
from rasa.telemetry import (
|
|
19
|
-
track_enterprise_search_policy_predict,
|
|
20
|
-
track_enterprise_search_policy_train_completed,
|
|
21
|
-
track_enterprise_search_policy_train_started,
|
|
22
|
-
)
|
|
23
|
-
from rasa.shared.exceptions import RasaException
|
|
24
11
|
from rasa.core.constants import (
|
|
25
12
|
POLICY_MAX_HISTORY,
|
|
26
13
|
POLICY_PRIORITY,
|
|
27
14
|
SEARCH_POLICY_PRIORITY,
|
|
28
15
|
UTTER_SOURCE_METADATA_KEY,
|
|
29
16
|
)
|
|
17
|
+
from rasa.core.information_retrieval import (
|
|
18
|
+
InformationRetrieval,
|
|
19
|
+
SearchResult,
|
|
20
|
+
InformationRetrievalException,
|
|
21
|
+
create_from_endpoint_config,
|
|
22
|
+
)
|
|
23
|
+
from rasa.core.information_retrieval.faiss import FAISS_Store
|
|
30
24
|
from rasa.core.policies.policy import Policy, PolicyPrediction
|
|
31
25
|
from rasa.core.utils import AvailableEndpoints
|
|
32
|
-
from rasa.dialogue_understanding.
|
|
33
|
-
|
|
26
|
+
from rasa.dialogue_understanding.generator.constants import (
|
|
27
|
+
LLM_CONFIG_KEY,
|
|
34
28
|
)
|
|
35
29
|
from rasa.dialogue_understanding.patterns.cannot_handle import (
|
|
36
30
|
CannotHandlePatternFlowStackFrame,
|
|
37
31
|
)
|
|
38
|
-
from rasa.dialogue_understanding.
|
|
32
|
+
from rasa.dialogue_understanding.patterns.internal_error import (
|
|
33
|
+
InternalErrorPatternFlowStackFrame,
|
|
34
|
+
)
|
|
39
35
|
from rasa.dialogue_understanding.stack.frames import (
|
|
40
36
|
DialogueStackFrame,
|
|
41
37
|
SearchStackFrame,
|
|
42
38
|
)
|
|
39
|
+
from rasa.dialogue_understanding.stack.frames import PatternFlowStackFrame
|
|
43
40
|
from rasa.engine.graph import ExecutionContext
|
|
44
41
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
45
42
|
from rasa.engine.storage.resource import Resource
|
|
@@ -48,14 +45,13 @@ from rasa.graph_components.providers.forms_provider import Forms
|
|
|
48
45
|
from rasa.graph_components.providers.responses_provider import Responses
|
|
49
46
|
from rasa.shared.constants import (
|
|
50
47
|
EMBEDDINGS_CONFIG_KEY,
|
|
51
|
-
LLM_API_HEALTH_CHECK_ENV_VAR,
|
|
52
|
-
LLM_CONFIG_KEY,
|
|
53
48
|
MODEL_CONFIG_KEY,
|
|
54
|
-
MODEL_NAME_CONFIG_KEY,
|
|
55
49
|
PROMPT_CONFIG_KEY,
|
|
56
50
|
PROVIDER_CONFIG_KEY,
|
|
57
51
|
OPENAI_PROVIDER,
|
|
58
52
|
TIMEOUT_CONFIG_KEY,
|
|
53
|
+
MODEL_NAME_CONFIG_KEY,
|
|
54
|
+
MODEL_GROUP_CONFIG_KEY,
|
|
59
55
|
)
|
|
60
56
|
from rasa.shared.core.constants import (
|
|
61
57
|
ACTION_CANCEL_FLOW,
|
|
@@ -66,26 +62,32 @@ from rasa.shared.core.domain import Domain
|
|
|
66
62
|
from rasa.shared.core.events import Event, UserUttered, BotUttered
|
|
67
63
|
from rasa.shared.core.generator import TrackerWithCachedStates
|
|
68
64
|
from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity
|
|
65
|
+
from rasa.shared.exceptions import RasaException, FileIOException
|
|
69
66
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
67
|
+
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
68
|
+
_LangchainEmbeddingClientAdapter,
|
|
69
|
+
)
|
|
70
|
+
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
70
71
|
from rasa.shared.utils.cli import print_error_and_exit
|
|
72
|
+
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
73
|
+
EmbeddingsHealthCheckMixin,
|
|
74
|
+
)
|
|
75
|
+
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
71
76
|
from rasa.shared.utils.io import deep_container_fingerprint
|
|
72
77
|
from rasa.shared.utils.llm import (
|
|
73
78
|
DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
74
79
|
DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
75
80
|
embedder_factory,
|
|
76
81
|
get_prompt_template,
|
|
77
|
-
llm_api_health_check,
|
|
78
82
|
llm_factory,
|
|
79
83
|
sanitize_message_for_prompt,
|
|
80
84
|
tracker_as_readable_transcript,
|
|
81
|
-
|
|
85
|
+
resolve_model_client_config,
|
|
82
86
|
)
|
|
83
|
-
from rasa.
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
InformationRetrievalException,
|
|
88
|
-
create_from_endpoint_config,
|
|
87
|
+
from rasa.telemetry import (
|
|
88
|
+
track_enterprise_search_policy_predict,
|
|
89
|
+
track_enterprise_search_policy_train_completed,
|
|
90
|
+
track_enterprise_search_policy_train_started,
|
|
89
91
|
)
|
|
90
92
|
|
|
91
93
|
if TYPE_CHECKING:
|
|
@@ -130,6 +132,7 @@ DEFAULT_EMBEDDINGS_CONFIG = {
|
|
|
130
132
|
}
|
|
131
133
|
|
|
132
134
|
ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"
|
|
135
|
+
ENTERPRISE_SEARCH_CONFIG_FILE_NAME = "config.json"
|
|
133
136
|
|
|
134
137
|
SEARCH_RESULTS_METADATA_KEY = "search_results"
|
|
135
138
|
SEARCH_QUERY_METADATA_KEY = "search_query"
|
|
@@ -154,7 +157,7 @@ class VectorStoreConfigurationError(RasaException):
|
|
|
154
157
|
@DefaultV1Recipe.register(
|
|
155
158
|
DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
|
|
156
159
|
)
|
|
157
|
-
class EnterpriseSearchPolicy(Policy):
|
|
160
|
+
class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
158
161
|
"""Policy which uses a vector store and LLMs to respond to user messages.
|
|
159
162
|
|
|
160
163
|
The policy uses a vector store and LLMs to respond to user messages. The
|
|
@@ -200,24 +203,35 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
200
203
|
"""Constructs a new Policy object."""
|
|
201
204
|
super().__init__(config, model_storage, resource, execution_context, featurizer)
|
|
202
205
|
|
|
206
|
+
# Resolve LLM config
|
|
207
|
+
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
208
|
+
self.config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
209
|
+
)
|
|
210
|
+
# Resolve embeddings config
|
|
211
|
+
self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
212
|
+
self.config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
213
|
+
)
|
|
214
|
+
|
|
203
215
|
# Vector store object and configuration
|
|
204
216
|
self.vector_store = vector_store
|
|
205
|
-
self.vector_store_config = config.get(
|
|
217
|
+
self.vector_store_config = self.config.get(
|
|
206
218
|
VECTOR_STORE_PROPERTY, DEFAULT_VECTOR_STORE
|
|
207
219
|
)
|
|
220
|
+
|
|
208
221
|
# Embeddings configuration for encoding the search query
|
|
209
|
-
self.embeddings_config =
|
|
210
|
-
EMBEDDINGS_CONFIG_KEY
|
|
222
|
+
self.embeddings_config = (
|
|
223
|
+
self.config[EMBEDDINGS_CONFIG_KEY] or DEFAULT_EMBEDDINGS_CONFIG
|
|
211
224
|
)
|
|
225
|
+
|
|
226
|
+
# LLM Configuration for response generation
|
|
227
|
+
self.llm_config = self.config[LLM_CONFIG_KEY] or DEFAULT_LLM_CONFIG
|
|
228
|
+
|
|
212
229
|
# Maximum number of turns to include in the prompt
|
|
213
230
|
self.max_history = self.config.get(POLICY_MAX_HISTORY)
|
|
214
231
|
|
|
215
232
|
# Maximum number of messages to include in the search query
|
|
216
233
|
self.max_messages_in_query = self.config.get(MAX_MESSAGES_IN_QUERY_KEY, 2)
|
|
217
234
|
|
|
218
|
-
# LLM Configuration for response generation
|
|
219
|
-
self.llm_config = self.config.get(LLM_CONFIG_KEY, DEFAULT_LLM_CONFIG)
|
|
220
|
-
|
|
221
235
|
# boolean to enable/disable tracing of prompt tokens
|
|
222
236
|
self.trace_prompt_tokens = self.config.get(TRACE_TOKENS_PROPERTY, False)
|
|
223
237
|
|
|
@@ -246,9 +260,16 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
246
260
|
Returns:
|
|
247
261
|
The embedder.
|
|
248
262
|
"""
|
|
263
|
+
# Copy the config so original config is not modified
|
|
264
|
+
config = config.copy()
|
|
265
|
+
# Resolve config and instantiate the embedding client
|
|
266
|
+
config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
267
|
+
config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
268
|
+
)
|
|
249
269
|
client = embedder_factory(
|
|
250
270
|
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
251
271
|
)
|
|
272
|
+
# Wrap the embedding client in the adapter
|
|
252
273
|
return _LangchainEmbeddingClientAdapter(client)
|
|
253
274
|
|
|
254
275
|
def train( # type: ignore[override]
|
|
@@ -275,6 +296,9 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
275
296
|
A policy must return its resource locator so that potential children nodes
|
|
276
297
|
can load the policy from the resource.
|
|
277
298
|
"""
|
|
299
|
+
# Perform health checks for both LLM and embeddings client configs
|
|
300
|
+
self._perform_health_checks(self.config, "enterprise_search_policy.train")
|
|
301
|
+
|
|
278
302
|
store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
|
|
279
303
|
|
|
280
304
|
# telemetry call to track training start
|
|
@@ -294,20 +318,6 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
294
318
|
f"required environment variables. Error: {e}"
|
|
295
319
|
)
|
|
296
320
|
|
|
297
|
-
# validate llm configuration
|
|
298
|
-
llm_client = try_instantiate_llm_client(
|
|
299
|
-
self.config.get(LLM_CONFIG_KEY),
|
|
300
|
-
DEFAULT_LLM_CONFIG,
|
|
301
|
-
"enterprise_search_policy.train",
|
|
302
|
-
EnterpriseSearchPolicy.__name__,
|
|
303
|
-
)
|
|
304
|
-
if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
|
|
305
|
-
llm_api_health_check(
|
|
306
|
-
llm_client,
|
|
307
|
-
"enterprise_search_policy.train",
|
|
308
|
-
EnterpriseSearchPolicy.__name__,
|
|
309
|
-
)
|
|
310
|
-
|
|
311
321
|
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
312
322
|
logger.info("enterprise_search_policy.train.faiss")
|
|
313
323
|
with self._model_storage.write_to(self._resource) as path:
|
|
@@ -326,9 +336,13 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
326
336
|
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
327
337
|
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
328
338
|
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
339
|
+
embeddings_model_group_id=self.embeddings_config.get(
|
|
340
|
+
MODEL_GROUP_CONFIG_KEY
|
|
341
|
+
),
|
|
329
342
|
llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
330
343
|
llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
331
344
|
or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
345
|
+
llm_model_group_id=self.llm_config.get(MODEL_GROUP_CONFIG_KEY),
|
|
332
346
|
citation_enabled=self.citation_enabled,
|
|
333
347
|
)
|
|
334
348
|
self.persist()
|
|
@@ -340,6 +354,9 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
340
354
|
rasa.shared.utils.io.write_text_file(
|
|
341
355
|
self.prompt_template, path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
|
|
342
356
|
)
|
|
357
|
+
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
358
|
+
path / ENTERPRISE_SEARCH_CONFIG_FILE_NAME, self.config
|
|
359
|
+
)
|
|
343
360
|
|
|
344
361
|
def _prepare_slots_for_template(
|
|
345
362
|
self, tracker: DialogueStateTracker
|
|
@@ -520,9 +537,13 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
520
537
|
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
521
538
|
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
522
539
|
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
540
|
+
embeddings_model_group_id=self.embeddings_config.get(
|
|
541
|
+
MODEL_GROUP_CONFIG_KEY
|
|
542
|
+
),
|
|
523
543
|
llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
524
544
|
llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
525
545
|
or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
546
|
+
llm_model_group_id=self.llm_config.get(MODEL_GROUP_CONFIG_KEY),
|
|
526
547
|
citation_enabled=self.citation_enabled,
|
|
527
548
|
)
|
|
528
549
|
return self._create_prediction(
|
|
@@ -671,12 +692,27 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
671
692
|
**kwargs: Any,
|
|
672
693
|
) -> "EnterpriseSearchPolicy":
|
|
673
694
|
"""Loads a trained policy (see parent class for full docstring)."""
|
|
695
|
+
|
|
696
|
+
# Perform health checks for both LLM and embeddings client configs
|
|
697
|
+
cls._perform_health_checks(config, "enterprise_search_policy.load")
|
|
698
|
+
|
|
674
699
|
prompt_template = None
|
|
700
|
+
try:
|
|
701
|
+
with model_storage.read_from(resource) as path:
|
|
702
|
+
prompt_template = rasa.shared.utils.io.read_file(
|
|
703
|
+
path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
|
|
704
|
+
)
|
|
705
|
+
except (FileNotFoundError, FileIOException) as e:
|
|
706
|
+
logger.warning(
|
|
707
|
+
"enterprise_search_policy.load.failed", error=e, resource=resource.name
|
|
708
|
+
)
|
|
709
|
+
|
|
675
710
|
store_type = config.get(VECTOR_STORE_PROPERTY, {}).get(
|
|
676
711
|
VECTOR_STORE_TYPE_PROPERTY
|
|
677
712
|
)
|
|
678
713
|
|
|
679
714
|
embeddings = cls._create_plain_embedder(config)
|
|
715
|
+
|
|
680
716
|
logger.info("enterprise_search_policy.load", config=config)
|
|
681
717
|
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
682
718
|
# if a vector store is not specified,
|
|
@@ -694,16 +730,6 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
694
730
|
config_type=store_type,
|
|
695
731
|
embeddings=embeddings,
|
|
696
732
|
) # type: ignore
|
|
697
|
-
try:
|
|
698
|
-
with model_storage.read_from(resource) as path:
|
|
699
|
-
prompt_template = rasa.shared.utils.io.read_file(
|
|
700
|
-
path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
|
|
701
|
-
)
|
|
702
|
-
|
|
703
|
-
except (FileNotFoundError, FileNotFoundError) as e:
|
|
704
|
-
logger.warning(
|
|
705
|
-
"enterprise_search_policy.load.failed", error=e, resource=resource.name
|
|
706
|
-
)
|
|
707
733
|
|
|
708
734
|
return cls(
|
|
709
735
|
config,
|
|
@@ -745,14 +771,23 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
745
771
|
|
|
746
772
|
@classmethod
|
|
747
773
|
def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
|
|
748
|
-
"""Add a fingerprint of
|
|
774
|
+
"""Add a fingerprint of enterprise search policy for the graph."""
|
|
749
775
|
local_knowledge_data = cls._get_local_knowledge_data(config)
|
|
750
776
|
|
|
751
777
|
prompt_template = get_prompt_template(
|
|
752
778
|
config.get(PROMPT_CONFIG_KEY),
|
|
753
779
|
DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
|
|
754
780
|
)
|
|
755
|
-
|
|
781
|
+
|
|
782
|
+
llm_config = resolve_model_client_config(
|
|
783
|
+
config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
784
|
+
)
|
|
785
|
+
embedding_config = resolve_model_client_config(
|
|
786
|
+
config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
787
|
+
)
|
|
788
|
+
return deep_container_fingerprint(
|
|
789
|
+
[prompt_template, local_knowledge_data, llm_config, embedding_config]
|
|
790
|
+
)
|
|
756
791
|
|
|
757
792
|
@staticmethod
|
|
758
793
|
def post_process_citations(llm_answer: str) -> str:
|
|
@@ -844,3 +879,27 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
844
879
|
joined_sources = "\n".join(new_sources)
|
|
845
880
|
|
|
846
881
|
return joined_answer + joined_sources
|
|
882
|
+
|
|
883
|
+
@classmethod
|
|
884
|
+
def _perform_health_checks(
|
|
885
|
+
cls, config: Dict[Text, Any], log_source_method: str
|
|
886
|
+
) -> None:
|
|
887
|
+
# Perform health check of the LLM client config
|
|
888
|
+
llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
|
|
889
|
+
cls.perform_llm_health_check(
|
|
890
|
+
llm_config,
|
|
891
|
+
DEFAULT_LLM_CONFIG,
|
|
892
|
+
log_source_method,
|
|
893
|
+
EnterpriseSearchPolicy.__name__,
|
|
894
|
+
)
|
|
895
|
+
|
|
896
|
+
# Perform health check of the embeddings client config
|
|
897
|
+
embeddings_config = resolve_model_client_config(
|
|
898
|
+
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
899
|
+
)
|
|
900
|
+
cls.perform_embeddings_health_check(
|
|
901
|
+
embeddings_config,
|
|
902
|
+
DEFAULT_EMBEDDINGS_CONFIG,
|
|
903
|
+
log_source_method,
|
|
904
|
+
EnterpriseSearchPolicy.__name__,
|
|
905
|
+
)
|
|
@@ -330,24 +330,27 @@ def reset_scoped_slots(
|
|
|
330
330
|
events: List[Event] = []
|
|
331
331
|
|
|
332
332
|
not_resettable_slot_names = set()
|
|
333
|
+
flow_persistable_slots = current_flow.persisted_slots
|
|
333
334
|
|
|
334
335
|
for step in current_flow.steps_with_calls_resolved:
|
|
335
336
|
if isinstance(step, CollectInformationFlowStep):
|
|
336
337
|
# reset all slots scoped to the flow
|
|
337
|
-
|
|
338
|
-
|
|
338
|
+
slot_name = step.collect
|
|
339
|
+
if step.reset_after_flow_ends and slot_name not in flow_persistable_slots:
|
|
340
|
+
_reset_slot(slot_name, tracker)
|
|
339
341
|
else:
|
|
340
|
-
not_resettable_slot_names.add(
|
|
342
|
+
not_resettable_slot_names.add(slot_name)
|
|
341
343
|
|
|
342
344
|
# slots set by the set slots step should be reset after the flow ends
|
|
343
345
|
# unless they are also used in a collect step where `reset_after_flow_ends`
|
|
344
|
-
# is set to `False`
|
|
346
|
+
# is set to `False` or set in the `persisted_slots` list.
|
|
345
347
|
resettable_set_slots = [
|
|
346
348
|
slot["key"]
|
|
347
349
|
for step in current_flow.steps_with_calls_resolved
|
|
348
350
|
if isinstance(step, SetSlotsFlowStep)
|
|
349
351
|
for slot in step.slots
|
|
350
352
|
if slot["key"] not in not_resettable_slot_names
|
|
353
|
+
and slot["key"] not in flow_persistable_slots
|
|
351
354
|
]
|
|
352
355
|
|
|
353
356
|
for name in resettable_set_slots:
|
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import importlib.resources
|
|
2
2
|
import math
|
|
3
|
-
import os
|
|
4
3
|
from dataclasses import dataclass, field
|
|
5
4
|
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Text, Tuple
|
|
6
5
|
|
|
@@ -19,6 +18,7 @@ from rasa.core.constants import (
|
|
|
19
18
|
UTTER_SOURCE_METADATA_KEY,
|
|
20
19
|
)
|
|
21
20
|
from rasa.core.policies.policy import Policy, PolicyPrediction, SupportedData
|
|
21
|
+
from rasa.dialogue_understanding.patterns.chitchat import FLOW_PATTERN_CHITCHAT
|
|
22
22
|
from rasa.dialogue_understanding.stack.frames import (
|
|
23
23
|
ChitChatStackFrame,
|
|
24
24
|
DialogueStackFrame,
|
|
@@ -32,7 +32,6 @@ from rasa.graph_components.providers.responses_provider import Responses
|
|
|
32
32
|
from rasa.shared.constants import (
|
|
33
33
|
REQUIRED_SLOTS_KEY,
|
|
34
34
|
EMBEDDINGS_CONFIG_KEY,
|
|
35
|
-
LLM_API_HEALTH_CHECK_ENV_VAR,
|
|
36
35
|
LLM_CONFIG_KEY,
|
|
37
36
|
MODEL_CONFIG_KEY,
|
|
38
37
|
MODEL_NAME_CONFIG_KEY,
|
|
@@ -40,8 +39,10 @@ from rasa.shared.constants import (
|
|
|
40
39
|
PROVIDER_CONFIG_KEY,
|
|
41
40
|
OPENAI_PROVIDER,
|
|
42
41
|
TIMEOUT_CONFIG_KEY,
|
|
42
|
+
MODEL_GROUP_CONFIG_KEY,
|
|
43
43
|
)
|
|
44
44
|
from rasa.shared.core.constants import ACTION_LISTEN_NAME
|
|
45
|
+
from rasa.shared.core.constants import ACTION_TRIGGER_CHITCHAT
|
|
45
46
|
from rasa.shared.core.domain import KEY_RESPONSES_TEXT, Domain
|
|
46
47
|
from rasa.shared.core.events import (
|
|
47
48
|
ActionExecuted,
|
|
@@ -59,6 +60,10 @@ from rasa.shared.providers.embedding._langchain_embedding_client_adapter import
|
|
|
59
60
|
_LangchainEmbeddingClientAdapter,
|
|
60
61
|
)
|
|
61
62
|
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
63
|
+
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
64
|
+
EmbeddingsHealthCheckMixin,
|
|
65
|
+
)
|
|
66
|
+
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
62
67
|
from rasa.shared.utils.io import deep_container_fingerprint
|
|
63
68
|
from rasa.shared.utils.llm import (
|
|
64
69
|
AI,
|
|
@@ -69,12 +74,12 @@ from rasa.shared.utils.llm import (
|
|
|
69
74
|
combine_custom_and_default_config,
|
|
70
75
|
embedder_factory,
|
|
71
76
|
get_prompt_template,
|
|
72
|
-
llm_api_health_check,
|
|
73
77
|
llm_factory,
|
|
74
78
|
sanitize_message_for_prompt,
|
|
75
79
|
tracker_as_readable_transcript,
|
|
76
|
-
|
|
80
|
+
resolve_model_client_config,
|
|
77
81
|
)
|
|
82
|
+
from rasa.utils.log_utils import log_llm
|
|
78
83
|
from rasa.utils.ml_utils import (
|
|
79
84
|
extract_ai_response_examples,
|
|
80
85
|
extract_participant_messages_from_transcript,
|
|
@@ -83,9 +88,6 @@ from rasa.utils.ml_utils import (
|
|
|
83
88
|
persist_faiss_vector_store,
|
|
84
89
|
response_for_template,
|
|
85
90
|
)
|
|
86
|
-
from rasa.dialogue_understanding.patterns.chitchat import FLOW_PATTERN_CHITCHAT
|
|
87
|
-
from rasa.shared.core.constants import ACTION_TRIGGER_CHITCHAT
|
|
88
|
-
from rasa.utils.log_utils import log_llm
|
|
89
91
|
|
|
90
92
|
if TYPE_CHECKING:
|
|
91
93
|
from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
|
|
@@ -125,6 +127,7 @@ DEFAULT_INTENTLESS_PROMPT_TEMPLATE = importlib.resources.open_text(
|
|
|
125
127
|
).name
|
|
126
128
|
|
|
127
129
|
INTENTLESS_PROMPT_TEMPLATE_FILE_NAME = "intentless_policy_prompt.jinja2"
|
|
130
|
+
INTENTLESS_CONFIG_FILE_NAME = "config.json"
|
|
128
131
|
|
|
129
132
|
|
|
130
133
|
class RasaMLPolicyTrainingException(RasaCoreException):
|
|
@@ -374,7 +377,7 @@ def conversation_as_prompt(conversation: Conversation) -> str:
|
|
|
374
377
|
@DefaultV1Recipe.register(
|
|
375
378
|
DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
|
|
376
379
|
)
|
|
377
|
-
class IntentlessPolicy(Policy):
|
|
380
|
+
class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
378
381
|
"""Policy which uses a language model to generate the next action.
|
|
379
382
|
|
|
380
383
|
The policy uses the OpenAI API to generate the next action based on the
|
|
@@ -431,6 +434,16 @@ class IntentlessPolicy(Policy):
|
|
|
431
434
|
"""Constructs a new Policy object."""
|
|
432
435
|
super().__init__(config, model_storage, resource, execution_context, featurizer)
|
|
433
436
|
|
|
437
|
+
# Resolve LLM config
|
|
438
|
+
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
439
|
+
self.config.get(LLM_CONFIG_KEY), IntentlessPolicy.__name__
|
|
440
|
+
)
|
|
441
|
+
|
|
442
|
+
# Resolve embeddings config
|
|
443
|
+
self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
444
|
+
self.config.get(EMBEDDINGS_CONFIG_KEY), IntentlessPolicy.__name__
|
|
445
|
+
)
|
|
446
|
+
|
|
434
447
|
self.nlu_abstention_threshold: float = self.config[NLU_ABSTENTION_THRESHOLD]
|
|
435
448
|
self.response_index = responses_docsearch
|
|
436
449
|
self.conversation_samples_index = samples_docsearch
|
|
@@ -447,9 +460,16 @@ class IntentlessPolicy(Policy):
|
|
|
447
460
|
Returns:
|
|
448
461
|
The embedder.
|
|
449
462
|
"""
|
|
463
|
+
# Copy the config so original config is not modified
|
|
464
|
+
config.copy()
|
|
465
|
+
# Resolve config and instantiate the embedding client
|
|
466
|
+
config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
467
|
+
config.get(EMBEDDINGS_CONFIG_KEY), IntentlessPolicy.__name__
|
|
468
|
+
)
|
|
450
469
|
client = embedder_factory(
|
|
451
470
|
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
452
471
|
)
|
|
472
|
+
# Wrap the embedding client in the adapter
|
|
453
473
|
return _LangchainEmbeddingClientAdapter(client)
|
|
454
474
|
|
|
455
475
|
def embeddings_property(self, prop: str) -> Optional[str]:
|
|
@@ -490,16 +510,8 @@ class IntentlessPolicy(Policy):
|
|
|
490
510
|
A policy must return its resource locator so that potential children nodes
|
|
491
511
|
can load the policy from the resource.
|
|
492
512
|
"""
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
DEFAULT_LLM_CONFIG,
|
|
496
|
-
"intentless_policy.train",
|
|
497
|
-
IntentlessPolicy.__name__,
|
|
498
|
-
)
|
|
499
|
-
if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
|
|
500
|
-
llm_api_health_check(
|
|
501
|
-
llm_client, "intentless_policy.train", IntentlessPolicy.__name__
|
|
502
|
-
)
|
|
513
|
+
# Perform health checks of both LLM and embeddings client configs
|
|
514
|
+
self._perform_health_checks(self.config, "intentless_policy.train")
|
|
503
515
|
|
|
504
516
|
responses = filter_responses(responses, forms, flows or FlowsList([]))
|
|
505
517
|
telemetry.track_intentless_policy_train()
|
|
@@ -546,9 +558,11 @@ class IntentlessPolicy(Policy):
|
|
|
546
558
|
embeddings_type=self.embeddings_property(PROVIDER_CONFIG_KEY),
|
|
547
559
|
embeddings_model=self.embeddings_property(MODEL_CONFIG_KEY)
|
|
548
560
|
or self.embeddings_property(MODEL_NAME_CONFIG_KEY),
|
|
561
|
+
embeddings_model_group_id=self.embeddings_property(MODEL_GROUP_CONFIG_KEY),
|
|
549
562
|
llm_type=self.llm_property(PROVIDER_CONFIG_KEY),
|
|
550
563
|
llm_model=self.llm_property(MODEL_CONFIG_KEY)
|
|
551
564
|
or self.llm_property(MODEL_NAME_CONFIG_KEY),
|
|
565
|
+
llm_model_group_id=self.llm_property(MODEL_GROUP_CONFIG_KEY),
|
|
552
566
|
)
|
|
553
567
|
|
|
554
568
|
self.persist()
|
|
@@ -564,6 +578,9 @@ class IntentlessPolicy(Policy):
|
|
|
564
578
|
rasa.shared.utils.io.write_text_file(
|
|
565
579
|
self.prompt_template, path / INTENTLESS_PROMPT_TEMPLATE_FILE_NAME
|
|
566
580
|
)
|
|
581
|
+
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
582
|
+
path / INTENTLESS_CONFIG_FILE_NAME, self.config
|
|
583
|
+
)
|
|
567
584
|
|
|
568
585
|
async def predict_action_probabilities(
|
|
569
586
|
self,
|
|
@@ -625,9 +642,11 @@ class IntentlessPolicy(Policy):
|
|
|
625
642
|
embeddings_type=self.embeddings_property(PROVIDER_CONFIG_KEY),
|
|
626
643
|
embeddings_model=self.embeddings_property(MODEL_CONFIG_KEY)
|
|
627
644
|
or self.embeddings_property(MODEL_NAME_CONFIG_KEY),
|
|
645
|
+
embeddings_model_group_id=self.embeddings_property(MODEL_GROUP_CONFIG_KEY),
|
|
628
646
|
llm_type=self.llm_property(PROVIDER_CONFIG_KEY),
|
|
629
647
|
llm_model=self.llm_property(MODEL_CONFIG_KEY)
|
|
630
648
|
or self.llm_property(MODEL_NAME_CONFIG_KEY),
|
|
649
|
+
llm_model_group_id=self.llm_property(MODEL_GROUP_CONFIG_KEY),
|
|
631
650
|
score=score,
|
|
632
651
|
)
|
|
633
652
|
|
|
@@ -651,7 +670,7 @@ class IntentlessPolicy(Policy):
|
|
|
651
670
|
history: str,
|
|
652
671
|
) -> Optional[str]:
|
|
653
672
|
"""Make the llm call to generate an answer."""
|
|
654
|
-
llm = llm_factory(self.config
|
|
673
|
+
llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
|
|
655
674
|
inputs = {
|
|
656
675
|
"conversations": conversation_samples,
|
|
657
676
|
"responses": response_examples,
|
|
@@ -925,6 +944,10 @@ class IntentlessPolicy(Policy):
|
|
|
925
944
|
**kwargs: Any,
|
|
926
945
|
) -> "IntentlessPolicy":
|
|
927
946
|
"""Loads a trained policy (see parent class for full docstring)."""
|
|
947
|
+
|
|
948
|
+
# Perform health checks of both LLM and embeddings client configs
|
|
949
|
+
cls._perform_health_checks(config, "intentless_policy.load")
|
|
950
|
+
|
|
928
951
|
responses_docsearch = None
|
|
929
952
|
samples_docsearch = None
|
|
930
953
|
prompt_template = None
|
|
@@ -945,7 +968,6 @@ class IntentlessPolicy(Policy):
|
|
|
945
968
|
prompt_template = rasa.shared.utils.io.read_file(
|
|
946
969
|
path / INTENTLESS_PROMPT_TEMPLATE_FILE_NAME
|
|
947
970
|
)
|
|
948
|
-
|
|
949
971
|
except (ValueError, FileNotFoundError, FileIOException) as e:
|
|
950
972
|
structlogger.warning(
|
|
951
973
|
"intentless_policy.load.failed", error=e, resource_name=resource.name
|
|
@@ -963,9 +985,43 @@ class IntentlessPolicy(Policy):
|
|
|
963
985
|
|
|
964
986
|
@classmethod
|
|
965
987
|
def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
|
|
966
|
-
"""Add a fingerprint of
|
|
988
|
+
"""Add a fingerprint of intentless policy for the graph."""
|
|
967
989
|
prompt_template = get_prompt_template(
|
|
968
990
|
config.get(PROMPT_CONFIG_KEY),
|
|
969
991
|
DEFAULT_INTENTLESS_PROMPT_TEMPLATE,
|
|
970
992
|
)
|
|
971
|
-
|
|
993
|
+
|
|
994
|
+
llm_config = resolve_model_client_config(
|
|
995
|
+
config.get(LLM_CONFIG_KEY), IntentlessPolicy.__name__
|
|
996
|
+
)
|
|
997
|
+
embedding_config = resolve_model_client_config(
|
|
998
|
+
config.get(EMBEDDINGS_CONFIG_KEY), IntentlessPolicy.__name__
|
|
999
|
+
)
|
|
1000
|
+
|
|
1001
|
+
return deep_container_fingerprint(
|
|
1002
|
+
[prompt_template, llm_config, embedding_config]
|
|
1003
|
+
)
|
|
1004
|
+
|
|
1005
|
+
@classmethod
|
|
1006
|
+
def _perform_health_checks(
|
|
1007
|
+
cls, config: Dict[Text, Any], log_source_method: str
|
|
1008
|
+
) -> None:
|
|
1009
|
+
# Perform health check of the LLM client config
|
|
1010
|
+
llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
|
|
1011
|
+
cls.perform_llm_health_check(
|
|
1012
|
+
llm_config,
|
|
1013
|
+
DEFAULT_LLM_CONFIG,
|
|
1014
|
+
log_source_method,
|
|
1015
|
+
IntentlessPolicy.__name__,
|
|
1016
|
+
)
|
|
1017
|
+
|
|
1018
|
+
# Perform health check of the embeddings client config
|
|
1019
|
+
embeddings_config = resolve_model_client_config(
|
|
1020
|
+
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
1021
|
+
)
|
|
1022
|
+
cls.perform_embeddings_health_check(
|
|
1023
|
+
embeddings_config,
|
|
1024
|
+
DEFAULT_EMBEDDINGS_CONFIG,
|
|
1025
|
+
log_source_method,
|
|
1026
|
+
IntentlessPolicy.__name__,
|
|
1027
|
+
)
|