rasa-pro 3.11.0a4.dev3__py3-none-any.whl → 3.11.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +22 -12
- rasa/api.py +1 -1
- rasa/cli/arguments/default_arguments.py +1 -2
- rasa/cli/arguments/shell.py +5 -1
- rasa/cli/e2e_test.py +1 -1
- rasa/cli/evaluate.py +8 -8
- rasa/cli/inspect.py +4 -4
- rasa/cli/llm_fine_tuning.py +1 -1
- rasa/cli/project_templates/calm/config.yml +5 -7
- rasa/cli/project_templates/calm/endpoints.yml +8 -0
- rasa/cli/project_templates/tutorial/config.yml +8 -5
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
- rasa/cli/project_templates/tutorial/domain.yml +14 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +7 -7
- rasa/cli/run.py +1 -1
- rasa/cli/scaffold.py +4 -2
- rasa/cli/utils.py +5 -0
- rasa/cli/x.py +8 -8
- rasa/constants.py +1 -1
- rasa/core/channels/channel.py +3 -0
- rasa/core/channels/inspector/dist/assets/{arc-6852c607.js → arc-bc141fb2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-acc952b2.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-848a7597.js → classDiagram-936ed81e-55366915.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-a73d3e68.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-e5ee049d.js → createText-62fc7601-b0ec81d6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-771e517e.js → edges-f2ad444c-6166330c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-aa347178.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-651fc57d.js → flowDb-1972c806-fca3bfe4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-ca67804f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-2dbc568d.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-25a65bd8.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-fdc7378d.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-6f1fd606.js → index-2c4b9a3b-f55afcdf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-efdd30c1.js → index-e7cef9de.js} +68 -68
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-cb1a041a.js → infoDiagram-736b4530-124d4a14.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-14609879.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-2490f52b.js → layout-b9885fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-40186f1f.js → line-7c59abb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-08814e93.js → linear-4776f780.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-1a534584.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-72397b61.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3bb0b6a3.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-57334f61.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-111e1297.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-10bcfe62.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-acaf7513.js → stateDiagram-59f0c015-042b3137.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-3ec2a235.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-62730289.js → styles-080da4f6-23ffa4fc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-5284ee76.js → styles-3dcbcfbf-94f59763.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-642435e3.js → styles-9c745c82-78a6bebc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-b250a350.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-c2b147ed.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-f92cfea9.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/App.tsx +1 -1
- rasa/core/channels/inspector/src/helpers/audiostream.ts +77 -16
- rasa/core/channels/socketio.py +2 -1
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/voice_ready/jambonz.py +2 -2
- rasa/core/channels/voice_stream/asr/asr_event.py +5 -0
- rasa/core/channels/voice_stream/asr/azure.py +122 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +16 -6
- rasa/core/channels/voice_stream/audio_bytes.py +1 -0
- rasa/core/channels/voice_stream/browser_audio.py +31 -8
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/azure.py +6 -2
- rasa/core/channels/voice_stream/tts/cartesia.py +10 -6
- rasa/core/channels/voice_stream/tts/tts_engine.py +1 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +27 -18
- rasa/core/channels/voice_stream/util.py +4 -4
- rasa/core/channels/voice_stream/voice_channel.py +177 -39
- rasa/core/featurizers/single_state_featurizer.py +22 -1
- rasa/core/featurizers/tracker_featurizers.py +115 -18
- rasa/core/nlg/contextual_response_rephraser.py +16 -22
- rasa/core/persistor.py +86 -39
- rasa/core/policies/enterprise_search_policy.py +159 -60
- rasa/core/policies/flows/flow_executor.py +7 -4
- rasa/core/policies/intentless_policy.py +120 -22
- rasa/core/policies/ted_policy.py +58 -33
- rasa/core/policies/unexpected_intent_policy.py +15 -7
- rasa/core/processor.py +25 -0
- rasa/core/training/interactive.py +34 -35
- rasa/core/utils.py +8 -3
- rasa/dialogue_understanding/coexistence/llm_based_router.py +58 -16
- rasa/dialogue_understanding/commands/change_flow_command.py +6 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +5 -0
- rasa/dialogue_understanding/generator/constants.py +4 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +65 -3
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +68 -26
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +57 -8
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +64 -7
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +39 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/e2e_test/e2e_test_runner.py +4 -2
- rasa/e2e_test/utils/io.py +1 -1
- rasa/engine/validation.py +297 -7
- rasa/model_manager/config.py +15 -3
- rasa/model_manager/model_api.py +15 -7
- rasa/model_manager/runner_service.py +8 -6
- rasa/model_manager/socket_bridge.py +6 -3
- rasa/model_manager/trainer_service.py +7 -5
- rasa/model_manager/utils.py +28 -7
- rasa/model_service.py +6 -2
- rasa/model_training.py +2 -0
- rasa/nlu/classifiers/diet_classifier.py +38 -25
- rasa/nlu/classifiers/logistic_regression_classifier.py +22 -9
- rasa/nlu/classifiers/sklearn_intent_classifier.py +37 -16
- rasa/nlu/extractors/crf_entity_extractor.py +93 -50
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -16
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +52 -17
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +5 -3
- rasa/shared/constants.py +36 -3
- rasa/shared/core/constants.py +7 -0
- rasa/shared/core/domain.py +26 -0
- rasa/shared/core/flows/flow.py +5 -0
- rasa/shared/core/flows/flows_yaml_schema.json +10 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +96 -0
- rasa/shared/core/slots.py +5 -0
- rasa/shared/nlu/training_data/features.py +120 -2
- rasa/shared/providers/_configs/azure_openai_client_config.py +5 -3
- rasa/shared/providers/_configs/litellm_router_client_config.py +200 -0
- rasa/shared/providers/_configs/model_group_config.py +167 -0
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
- rasa/shared/providers/_configs/utils.py +16 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +12 -15
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
- rasa/shared/providers/llm/_base_litellm_client.py +31 -30
- rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
- rasa/shared/providers/llm/litellm_router_llm_client.py +127 -0
- rasa/shared/providers/llm/rasa_llm_client.py +112 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +1 -1
- rasa/shared/providers/mappings.py +19 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +149 -0
- rasa/shared/providers/router/router_client.py +73 -0
- rasa/shared/utils/common.py +8 -0
- rasa/shared/utils/health_check.py +533 -0
- rasa/shared/utils/io.py +28 -6
- rasa/shared/utils/llm.py +350 -46
- rasa/shared/utils/yaml.py +11 -13
- rasa/studio/upload.py +64 -20
- rasa/telemetry.py +80 -17
- rasa/tracing/instrumentation/attribute_extractors.py +74 -17
- rasa/utils/io.py +0 -66
- rasa/utils/log_utils.py +9 -2
- rasa/utils/tensorflow/feature_array.py +366 -0
- rasa/utils/tensorflow/model_data.py +2 -193
- rasa/validator.py +70 -0
- rasa/version.py +1 -1
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/METADATA +10 -10
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/RECORD +162 -146
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-587d82d8.js +0 -1
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/entry_points.txt +0 -0
|
@@ -1,45 +1,44 @@
|
|
|
1
1
|
import importlib.resources
|
|
2
2
|
import json
|
|
3
|
-
import os
|
|
4
3
|
import re
|
|
5
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
|
|
6
|
-
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text, Tuple
|
|
7
5
|
import dotenv
|
|
8
6
|
import structlog
|
|
9
7
|
from jinja2 import Template
|
|
10
8
|
from pydantic import ValidationError
|
|
11
9
|
|
|
12
|
-
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
13
|
-
_LangchainEmbeddingClientAdapter,
|
|
14
|
-
)
|
|
15
|
-
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
16
|
-
|
|
17
10
|
import rasa.shared.utils.io
|
|
18
|
-
from rasa.telemetry import (
|
|
19
|
-
track_enterprise_search_policy_predict,
|
|
20
|
-
track_enterprise_search_policy_train_completed,
|
|
21
|
-
track_enterprise_search_policy_train_started,
|
|
22
|
-
)
|
|
23
|
-
from rasa.shared.exceptions import RasaException
|
|
24
11
|
from rasa.core.constants import (
|
|
25
12
|
POLICY_MAX_HISTORY,
|
|
26
13
|
POLICY_PRIORITY,
|
|
27
14
|
SEARCH_POLICY_PRIORITY,
|
|
28
15
|
UTTER_SOURCE_METADATA_KEY,
|
|
29
16
|
)
|
|
17
|
+
from rasa.core.information_retrieval import (
|
|
18
|
+
InformationRetrieval,
|
|
19
|
+
SearchResult,
|
|
20
|
+
InformationRetrievalException,
|
|
21
|
+
create_from_endpoint_config,
|
|
22
|
+
)
|
|
23
|
+
from rasa.core.information_retrieval.faiss import FAISS_Store
|
|
30
24
|
from rasa.core.policies.policy import Policy, PolicyPrediction
|
|
31
25
|
from rasa.core.utils import AvailableEndpoints
|
|
32
|
-
from rasa.dialogue_understanding.
|
|
33
|
-
|
|
26
|
+
from rasa.dialogue_understanding.generator.constants import (
|
|
27
|
+
LLM_CONFIG_KEY,
|
|
28
|
+
TRAINED_MODEL_NAME_CONFIG_KEY,
|
|
29
|
+
TRAINED_EMBEDDINGS_CONFIG_KEY,
|
|
34
30
|
)
|
|
35
31
|
from rasa.dialogue_understanding.patterns.cannot_handle import (
|
|
36
32
|
CannotHandlePatternFlowStackFrame,
|
|
37
33
|
)
|
|
38
|
-
from rasa.dialogue_understanding.
|
|
34
|
+
from rasa.dialogue_understanding.patterns.internal_error import (
|
|
35
|
+
InternalErrorPatternFlowStackFrame,
|
|
36
|
+
)
|
|
39
37
|
from rasa.dialogue_understanding.stack.frames import (
|
|
40
38
|
DialogueStackFrame,
|
|
41
39
|
SearchStackFrame,
|
|
42
40
|
)
|
|
41
|
+
from rasa.dialogue_understanding.stack.frames import PatternFlowStackFrame
|
|
43
42
|
from rasa.engine.graph import ExecutionContext
|
|
44
43
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
45
44
|
from rasa.engine.storage.resource import Resource
|
|
@@ -48,14 +47,13 @@ from rasa.graph_components.providers.forms_provider import Forms
|
|
|
48
47
|
from rasa.graph_components.providers.responses_provider import Responses
|
|
49
48
|
from rasa.shared.constants import (
|
|
50
49
|
EMBEDDINGS_CONFIG_KEY,
|
|
51
|
-
LLM_API_HEALTH_CHECK_ENV_VAR,
|
|
52
|
-
LLM_CONFIG_KEY,
|
|
53
50
|
MODEL_CONFIG_KEY,
|
|
54
|
-
MODEL_NAME_CONFIG_KEY,
|
|
55
51
|
PROMPT_CONFIG_KEY,
|
|
56
52
|
PROVIDER_CONFIG_KEY,
|
|
57
53
|
OPENAI_PROVIDER,
|
|
58
54
|
TIMEOUT_CONFIG_KEY,
|
|
55
|
+
MODEL_NAME_CONFIG_KEY,
|
|
56
|
+
MODEL_GROUP_CONFIG_KEY,
|
|
59
57
|
)
|
|
60
58
|
from rasa.shared.core.constants import (
|
|
61
59
|
ACTION_CANCEL_FLOW,
|
|
@@ -66,7 +64,12 @@ from rasa.shared.core.domain import Domain
|
|
|
66
64
|
from rasa.shared.core.events import Event, UserUttered, BotUttered
|
|
67
65
|
from rasa.shared.core.generator import TrackerWithCachedStates
|
|
68
66
|
from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity
|
|
67
|
+
from rasa.shared.exceptions import RasaException, FileIOException
|
|
69
68
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
69
|
+
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
70
|
+
_LangchainEmbeddingClientAdapter,
|
|
71
|
+
)
|
|
72
|
+
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
70
73
|
from rasa.shared.utils.cli import print_error_and_exit
|
|
71
74
|
from rasa.shared.utils.io import deep_container_fingerprint
|
|
72
75
|
from rasa.shared.utils.llm import (
|
|
@@ -74,18 +77,21 @@ from rasa.shared.utils.llm import (
|
|
|
74
77
|
DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
75
78
|
embedder_factory,
|
|
76
79
|
get_prompt_template,
|
|
77
|
-
llm_api_health_check,
|
|
78
80
|
llm_factory,
|
|
79
81
|
sanitize_message_for_prompt,
|
|
80
82
|
tracker_as_readable_transcript,
|
|
81
|
-
|
|
83
|
+
resolve_model_client_config,
|
|
82
84
|
)
|
|
83
|
-
from rasa.
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
85
|
+
from rasa.shared.utils.health_check import (
|
|
86
|
+
perform_training_time_llm_health_check,
|
|
87
|
+
perform_training_time_embeddings_health_check,
|
|
88
|
+
perform_inference_time_llm_health_check,
|
|
89
|
+
perform_inference_time_embeddings_health_check,
|
|
90
|
+
)
|
|
91
|
+
from rasa.telemetry import (
|
|
92
|
+
track_enterprise_search_policy_predict,
|
|
93
|
+
track_enterprise_search_policy_train_completed,
|
|
94
|
+
track_enterprise_search_policy_train_started,
|
|
89
95
|
)
|
|
90
96
|
|
|
91
97
|
if TYPE_CHECKING:
|
|
@@ -130,6 +136,7 @@ DEFAULT_EMBEDDINGS_CONFIG = {
|
|
|
130
136
|
}
|
|
131
137
|
|
|
132
138
|
ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"
|
|
139
|
+
ENTERPRISE_SEARCH_CONFIG_FILE_NAME = "config.json"
|
|
133
140
|
|
|
134
141
|
SEARCH_RESULTS_METADATA_KEY = "search_results"
|
|
135
142
|
SEARCH_QUERY_METADATA_KEY = "search_query"
|
|
@@ -200,24 +207,35 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
200
207
|
"""Constructs a new Policy object."""
|
|
201
208
|
super().__init__(config, model_storage, resource, execution_context, featurizer)
|
|
202
209
|
|
|
210
|
+
# Resolve LLM config
|
|
211
|
+
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
212
|
+
self.config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
213
|
+
)
|
|
214
|
+
# Resolve embeddings config
|
|
215
|
+
self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
216
|
+
self.config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
217
|
+
)
|
|
218
|
+
|
|
203
219
|
# Vector store object and configuration
|
|
204
220
|
self.vector_store = vector_store
|
|
205
|
-
self.vector_store_config = config.get(
|
|
221
|
+
self.vector_store_config = self.config.get(
|
|
206
222
|
VECTOR_STORE_PROPERTY, DEFAULT_VECTOR_STORE
|
|
207
223
|
)
|
|
224
|
+
|
|
208
225
|
# Embeddings configuration for encoding the search query
|
|
209
|
-
self.embeddings_config =
|
|
210
|
-
EMBEDDINGS_CONFIG_KEY
|
|
226
|
+
self.embeddings_config = (
|
|
227
|
+
self.config[EMBEDDINGS_CONFIG_KEY] or DEFAULT_EMBEDDINGS_CONFIG
|
|
211
228
|
)
|
|
229
|
+
|
|
230
|
+
# LLM Configuration for response generation
|
|
231
|
+
self.llm_config = self.config[LLM_CONFIG_KEY] or DEFAULT_LLM_CONFIG
|
|
232
|
+
|
|
212
233
|
# Maximum number of turns to include in the prompt
|
|
213
234
|
self.max_history = self.config.get(POLICY_MAX_HISTORY)
|
|
214
235
|
|
|
215
236
|
# Maximum number of messages to include in the search query
|
|
216
237
|
self.max_messages_in_query = self.config.get(MAX_MESSAGES_IN_QUERY_KEY, 2)
|
|
217
238
|
|
|
218
|
-
# LLM Configuration for response generation
|
|
219
|
-
self.llm_config = self.config.get(LLM_CONFIG_KEY, DEFAULT_LLM_CONFIG)
|
|
220
|
-
|
|
221
239
|
# boolean to enable/disable tracing of prompt tokens
|
|
222
240
|
self.trace_prompt_tokens = self.config.get(TRACE_TOKENS_PROPERTY, False)
|
|
223
241
|
|
|
@@ -246,9 +264,16 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
246
264
|
Returns:
|
|
247
265
|
The embedder.
|
|
248
266
|
"""
|
|
267
|
+
# Copy the config so original config is not modified
|
|
268
|
+
config = config.copy()
|
|
269
|
+
# Resolve config and instantiate the embedding client
|
|
270
|
+
config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
271
|
+
config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
272
|
+
)
|
|
249
273
|
client = embedder_factory(
|
|
250
274
|
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
251
275
|
)
|
|
276
|
+
# Wrap the embedding client in the adapter
|
|
252
277
|
return _LangchainEmbeddingClientAdapter(client)
|
|
253
278
|
|
|
254
279
|
def train( # type: ignore[override]
|
|
@@ -294,19 +319,10 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
294
319
|
f"required environment variables. Error: {e}"
|
|
295
320
|
)
|
|
296
321
|
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
self.config
|
|
300
|
-
|
|
301
|
-
"enterprise_search_policy.train",
|
|
302
|
-
EnterpriseSearchPolicy.__name__,
|
|
303
|
-
)
|
|
304
|
-
if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
|
|
305
|
-
llm_api_health_check(
|
|
306
|
-
llm_client,
|
|
307
|
-
"enterprise_search_policy.train",
|
|
308
|
-
EnterpriseSearchPolicy.__name__,
|
|
309
|
-
)
|
|
322
|
+
(
|
|
323
|
+
self.config[TRAINED_MODEL_NAME_CONFIG_KEY],
|
|
324
|
+
self.config[TRAINED_EMBEDDINGS_CONFIG_KEY],
|
|
325
|
+
) = self._perform_training_time_health_checks()
|
|
310
326
|
|
|
311
327
|
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
312
328
|
logger.info("enterprise_search_policy.train.faiss")
|
|
@@ -326,9 +342,13 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
326
342
|
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
327
343
|
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
328
344
|
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
345
|
+
embeddings_model_group_id=self.embeddings_config.get(
|
|
346
|
+
MODEL_GROUP_CONFIG_KEY
|
|
347
|
+
),
|
|
329
348
|
llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
330
349
|
llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
331
350
|
or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
351
|
+
llm_model_group_id=self.llm_config.get(MODEL_GROUP_CONFIG_KEY),
|
|
332
352
|
citation_enabled=self.citation_enabled,
|
|
333
353
|
)
|
|
334
354
|
self.persist()
|
|
@@ -340,6 +360,9 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
340
360
|
rasa.shared.utils.io.write_text_file(
|
|
341
361
|
self.prompt_template, path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
|
|
342
362
|
)
|
|
363
|
+
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
364
|
+
path / ENTERPRISE_SEARCH_CONFIG_FILE_NAME, self.config
|
|
365
|
+
)
|
|
343
366
|
|
|
344
367
|
def _prepare_slots_for_template(
|
|
345
368
|
self, tracker: DialogueStateTracker
|
|
@@ -520,9 +543,13 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
520
543
|
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
521
544
|
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
522
545
|
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
546
|
+
embeddings_model_group_id=self.embeddings_config.get(
|
|
547
|
+
MODEL_GROUP_CONFIG_KEY
|
|
548
|
+
),
|
|
523
549
|
llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
524
550
|
llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
525
551
|
or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
552
|
+
llm_model_group_id=self.llm_config.get(MODEL_GROUP_CONFIG_KEY),
|
|
526
553
|
citation_enabled=self.citation_enabled,
|
|
527
554
|
)
|
|
528
555
|
return self._create_prediction(
|
|
@@ -672,11 +699,26 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
672
699
|
) -> "EnterpriseSearchPolicy":
|
|
673
700
|
"""Loads a trained policy (see parent class for full docstring)."""
|
|
674
701
|
prompt_template = None
|
|
702
|
+
persisted_config = None
|
|
703
|
+
try:
|
|
704
|
+
with model_storage.read_from(resource) as path:
|
|
705
|
+
prompt_template = rasa.shared.utils.io.read_file(
|
|
706
|
+
path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
|
|
707
|
+
)
|
|
708
|
+
persisted_config = rasa.shared.utils.io.read_json_file(
|
|
709
|
+
path / ENTERPRISE_SEARCH_CONFIG_FILE_NAME
|
|
710
|
+
)
|
|
711
|
+
except (FileNotFoundError, FileIOException) as e:
|
|
712
|
+
logger.warning(
|
|
713
|
+
"enterprise_search_policy.load.failed", error=e, resource=resource.name
|
|
714
|
+
)
|
|
715
|
+
|
|
675
716
|
store_type = config.get(VECTOR_STORE_PROPERTY, {}).get(
|
|
676
717
|
VECTOR_STORE_TYPE_PROPERTY
|
|
677
718
|
)
|
|
678
719
|
|
|
679
720
|
embeddings = cls._create_plain_embedder(config)
|
|
721
|
+
|
|
680
722
|
logger.info("enterprise_search_policy.load", config=config)
|
|
681
723
|
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
682
724
|
# if a vector store is not specified,
|
|
@@ -694,18 +736,8 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
694
736
|
config_type=store_type,
|
|
695
737
|
embeddings=embeddings,
|
|
696
738
|
) # type: ignore
|
|
697
|
-
try:
|
|
698
|
-
with model_storage.read_from(resource) as path:
|
|
699
|
-
prompt_template = rasa.shared.utils.io.read_file(
|
|
700
|
-
path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
|
|
701
|
-
)
|
|
702
739
|
|
|
703
|
-
|
|
704
|
-
logger.warning(
|
|
705
|
-
"enterprise_search_policy.load.failed", error=e, resource=resource.name
|
|
706
|
-
)
|
|
707
|
-
|
|
708
|
-
return cls(
|
|
740
|
+
policy = cls(
|
|
709
741
|
config,
|
|
710
742
|
model_storage,
|
|
711
743
|
resource,
|
|
@@ -714,6 +746,14 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
714
746
|
prompt_template=prompt_template,
|
|
715
747
|
)
|
|
716
748
|
|
|
749
|
+
cls._perform_inference_time_health_checks(
|
|
750
|
+
persisted_config,
|
|
751
|
+
policy.config.get(LLM_CONFIG_KEY),
|
|
752
|
+
policy.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
753
|
+
)
|
|
754
|
+
|
|
755
|
+
return policy
|
|
756
|
+
|
|
717
757
|
@classmethod
|
|
718
758
|
def _get_local_knowledge_data(cls, config: Dict[str, Any]) -> Optional[List[str]]:
|
|
719
759
|
"""This is required only for local knowledge base types.
|
|
@@ -745,14 +785,23 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
745
785
|
|
|
746
786
|
@classmethod
|
|
747
787
|
def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
|
|
748
|
-
"""Add a fingerprint of
|
|
788
|
+
"""Add a fingerprint of enterprise search policy for the graph."""
|
|
749
789
|
local_knowledge_data = cls._get_local_knowledge_data(config)
|
|
750
790
|
|
|
751
791
|
prompt_template = get_prompt_template(
|
|
752
792
|
config.get(PROMPT_CONFIG_KEY),
|
|
753
793
|
DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
|
|
754
794
|
)
|
|
755
|
-
|
|
795
|
+
|
|
796
|
+
llm_config = resolve_model_client_config(
|
|
797
|
+
config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
798
|
+
)
|
|
799
|
+
embedding_config = resolve_model_client_config(
|
|
800
|
+
config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
801
|
+
)
|
|
802
|
+
return deep_container_fingerprint(
|
|
803
|
+
[prompt_template, local_knowledge_data, llm_config, embedding_config]
|
|
804
|
+
)
|
|
756
805
|
|
|
757
806
|
@staticmethod
|
|
758
807
|
def post_process_citations(llm_answer: str) -> str:
|
|
@@ -844,3 +893,53 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
844
893
|
joined_sources = "\n".join(new_sources)
|
|
845
894
|
|
|
846
895
|
return joined_answer + joined_sources
|
|
896
|
+
|
|
897
|
+
def _perform_training_time_health_checks(
|
|
898
|
+
self,
|
|
899
|
+
) -> Tuple[Optional[str], Optional[str]]:
|
|
900
|
+
train_model_name = perform_training_time_llm_health_check(
|
|
901
|
+
self.config.get(LLM_CONFIG_KEY),
|
|
902
|
+
DEFAULT_LLM_CONFIG,
|
|
903
|
+
"enterprise_search_policy.train",
|
|
904
|
+
EnterpriseSearchPolicy.__name__,
|
|
905
|
+
)
|
|
906
|
+
train_embedding_name = perform_training_time_embeddings_health_check(
|
|
907
|
+
self.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
908
|
+
DEFAULT_EMBEDDINGS_CONFIG,
|
|
909
|
+
"enterprise_search_policy.train",
|
|
910
|
+
EnterpriseSearchPolicy.__name__,
|
|
911
|
+
)
|
|
912
|
+
return train_model_name, train_embedding_name
|
|
913
|
+
|
|
914
|
+
@classmethod
|
|
915
|
+
def _perform_inference_time_health_checks(
|
|
916
|
+
cls,
|
|
917
|
+
persisted_config: Optional[Dict[str, Any]],
|
|
918
|
+
resolved_llm_config: Optional[Dict[str, Any]],
|
|
919
|
+
resolved_embeddings_config: Optional[Dict[str, Any]],
|
|
920
|
+
) -> None:
|
|
921
|
+
train_model_name = (
|
|
922
|
+
persisted_config.get(TRAINED_MODEL_NAME_CONFIG_KEY, None)
|
|
923
|
+
if persisted_config
|
|
924
|
+
else None
|
|
925
|
+
)
|
|
926
|
+
perform_inference_time_llm_health_check(
|
|
927
|
+
resolved_llm_config,
|
|
928
|
+
DEFAULT_LLM_CONFIG,
|
|
929
|
+
train_model_name,
|
|
930
|
+
"enterprise_search_policy.load",
|
|
931
|
+
EnterpriseSearchPolicy.__name__,
|
|
932
|
+
)
|
|
933
|
+
|
|
934
|
+
train_embeddings_name = (
|
|
935
|
+
persisted_config.get(TRAINED_EMBEDDINGS_CONFIG_KEY, None)
|
|
936
|
+
if persisted_config
|
|
937
|
+
else None
|
|
938
|
+
)
|
|
939
|
+
perform_inference_time_embeddings_health_check(
|
|
940
|
+
resolved_embeddings_config,
|
|
941
|
+
DEFAULT_EMBEDDINGS_CONFIG,
|
|
942
|
+
train_embeddings_name,
|
|
943
|
+
"enterprise_search_policy.load",
|
|
944
|
+
EnterpriseSearchPolicy.__name__,
|
|
945
|
+
)
|
|
@@ -330,24 +330,27 @@ def reset_scoped_slots(
|
|
|
330
330
|
events: List[Event] = []
|
|
331
331
|
|
|
332
332
|
not_resettable_slot_names = set()
|
|
333
|
+
flow_persistable_slots = current_flow.persisted_slots
|
|
333
334
|
|
|
334
335
|
for step in current_flow.steps_with_calls_resolved:
|
|
335
336
|
if isinstance(step, CollectInformationFlowStep):
|
|
336
337
|
# reset all slots scoped to the flow
|
|
337
|
-
|
|
338
|
-
|
|
338
|
+
slot_name = step.collect
|
|
339
|
+
if step.reset_after_flow_ends and slot_name not in flow_persistable_slots:
|
|
340
|
+
_reset_slot(slot_name, tracker)
|
|
339
341
|
else:
|
|
340
|
-
not_resettable_slot_names.add(
|
|
342
|
+
not_resettable_slot_names.add(slot_name)
|
|
341
343
|
|
|
342
344
|
# slots set by the set slots step should be reset after the flow ends
|
|
343
345
|
# unless they are also used in a collect step where `reset_after_flow_ends`
|
|
344
|
-
# is set to `False`
|
|
346
|
+
# is set to `False` or set in the `persisted_slots` list.
|
|
345
347
|
resettable_set_slots = [
|
|
346
348
|
slot["key"]
|
|
347
349
|
for step in current_flow.steps_with_calls_resolved
|
|
348
350
|
if isinstance(step, SetSlotsFlowStep)
|
|
349
351
|
for slot in step.slots
|
|
350
352
|
if slot["key"] not in not_resettable_slot_names
|
|
353
|
+
and slot["key"] not in flow_persistable_slots
|
|
351
354
|
]
|
|
352
355
|
|
|
353
356
|
for name in resettable_set_slots:
|