rasa-pro 3.10.16__py3-none-any.whl → 3.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +31 -15
- rasa/api.py +12 -2
- rasa/cli/arguments/default_arguments.py +24 -4
- rasa/cli/arguments/run.py +15 -0
- rasa/cli/arguments/shell.py +5 -1
- rasa/cli/arguments/train.py +17 -9
- rasa/cli/evaluate.py +7 -7
- rasa/cli/inspect.py +19 -7
- rasa/cli/interactive.py +1 -0
- rasa/cli/llm_fine_tuning.py +11 -14
- rasa/cli/project_templates/calm/config.yml +5 -7
- rasa/cli/project_templates/calm/endpoints.yml +15 -2
- rasa/cli/project_templates/tutorial/config.yml +8 -5
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
- rasa/cli/project_templates/tutorial/domain.yml +14 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +5 -0
- rasa/cli/run.py +7 -0
- rasa/cli/scaffold.py +4 -2
- rasa/cli/studio/upload.py +0 -15
- rasa/cli/train.py +14 -53
- rasa/cli/utils.py +14 -11
- rasa/cli/x.py +7 -7
- rasa/constants.py +3 -1
- rasa/core/actions/action.py +77 -33
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/actions/action_repeat_bot_messages.py +89 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +5 -1
- rasa/core/actions/http_custom_action_executor.py +4 -0
- rasa/core/agent.py +2 -2
- rasa/core/brokers/kafka.py +3 -1
- rasa/core/brokers/pika.py +3 -1
- rasa/core/channels/__init__.py +10 -6
- rasa/core/channels/channel.py +41 -4
- rasa/core/channels/development_inspector.py +150 -46
- rasa/core/channels/inspector/README.md +1 -1
- rasa/core/channels/inspector/dist/assets/{arc-b6e548fe.js → arc-bc141fb2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-fa03ac9e.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-ee67392a.js → classDiagram-936ed81e-55366915.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-9b283fae.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-8b6fcc2a.js → createText-62fc7601-b0ec81d6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-22e77f4f.js → edges-f2ad444c-6166330c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-60ffc87f.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-9dd802e4.js → flowDb-1972c806-fca3bfe4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-5fa1912f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-622a1fd2.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-e285a63a.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-f237bdca.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-4b03d70e.js → index-2c4b9a3b-f55afcdf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/index-e7cef9de.js +1317 -0
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-72a0fa5f.js → infoDiagram-736b4530-124d4a14.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-82218c41.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-78cff630.js → layout-b9885fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-5038b469.js → line-7c59abb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-c4fc4098.js → linear-4776f780.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-c33c8ea6.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-a8d03059.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-6a0e56b2.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-2dc7c7bd.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-2360fe39.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-41b9f9ad.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-0aad326f.js → stateDiagram-59f0c015-042b3137.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-9847d984.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-564d890e.js → styles-080da4f6-23ffa4fc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-38957613.js → styles-3dcbcfbf-94f59763.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-f0fc6921.js → styles-9c745c82-78a6bebc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-ef3c5a77.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-bf3e91c1.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-4d4026c0.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +18 -17
- rasa/core/channels/inspector/index.html +17 -16
- rasa/core/channels/inspector/package.json +5 -1
- rasa/core/channels/inspector/src/App.tsx +118 -68
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +11 -10
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +10 -25
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +6 -3
- rasa/core/channels/inspector/src/helpers/audiostream.ts +165 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +10 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +107 -41
- rasa/core/channels/inspector/src/helpers/utils.ts +92 -7
- rasa/core/channels/inspector/src/types.ts +21 -1
- rasa/core/channels/inspector/yarn.lock +94 -1
- rasa/core/channels/rest.py +51 -46
- rasa/core/channels/socketio.py +28 -1
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/{audiocodes.py → voice_ready/audiocodes.py} +122 -69
- rasa/core/channels/{voice_aware → voice_ready}/jambonz.py +26 -8
- rasa/core/channels/{voice_aware → voice_ready}/jambonz_protocol.py +57 -5
- rasa/core/channels/{twilio_voice.py → voice_ready/twilio_voice.py} +64 -28
- rasa/core/channels/voice_ready/utils.py +37 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +89 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +18 -0
- rasa/core/channels/voice_stream/asr/azure.py +129 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +90 -0
- rasa/core/channels/voice_stream/audio_bytes.py +8 -0
- rasa/core/channels/voice_stream/browser_audio.py +107 -0
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +106 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +118 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +58 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +173 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +427 -0
- rasa/core/information_retrieval/qdrant.py +1 -0
- rasa/core/nlg/contextual_response_rephraser.py +45 -17
- rasa/{nlu → core}/persistor.py +203 -68
- rasa/core/policies/enterprise_search_policy.py +119 -63
- rasa/core/policies/flows/flow_executor.py +15 -22
- rasa/core/policies/intentless_policy.py +83 -28
- rasa/core/processor.py +25 -0
- rasa/core/run.py +12 -2
- rasa/core/secrets_manager/constants.py +4 -0
- rasa/core/secrets_manager/factory.py +8 -0
- rasa/core/secrets_manager/vault.py +11 -1
- rasa/core/training/interactive.py +33 -34
- rasa/core/utils.py +47 -21
- rasa/dialogue_understanding/coexistence/llm_based_router.py +41 -14
- rasa/dialogue_understanding/commands/__init__.py +6 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +60 -0
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +5 -0
- rasa/dialogue_understanding/generator/constants.py +2 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +47 -9
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +38 -15
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +35 -13
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +3 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +60 -13
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +53 -0
- rasa/dialogue_understanding/patterns/repeat.py +37 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +21 -1
- rasa/e2e_test/aggregate_test_stats_calculator.py +1 -11
- rasa/e2e_test/assertions.py +136 -61
- rasa/e2e_test/assertions_schema.yml +23 -0
- rasa/e2e_test/e2e_test_case.py +85 -6
- rasa/e2e_test/e2e_test_runner.py +2 -3
- rasa/e2e_test/utils/e2e_yaml_utils.py +1 -1
- rasa/engine/graph.py +3 -10
- rasa/engine/loader.py +12 -0
- rasa/engine/recipes/config_files/default_config.yml +0 -3
- rasa/engine/recipes/default_recipe.py +0 -1
- rasa/engine/recipes/graph_recipe.py +0 -1
- rasa/engine/runner/dask.py +2 -2
- rasa/engine/storage/local_model_storage.py +12 -42
- rasa/engine/storage/storage.py +1 -5
- rasa/engine/validation.py +527 -74
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +40 -0
- rasa/model_manager/model_api.py +559 -0
- rasa/model_manager/runner_service.py +286 -0
- rasa/model_manager/socket_bridge.py +146 -0
- rasa/model_manager/studio_jwt_auth.py +86 -0
- rasa/model_manager/trainer_service.py +325 -0
- rasa/model_manager/utils.py +87 -0
- rasa/model_manager/warm_rasa_process.py +187 -0
- rasa/model_service.py +112 -0
- rasa/model_training.py +42 -23
- rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
- rasa/server.py +4 -2
- rasa/shared/constants.py +60 -8
- rasa/shared/core/constants.py +13 -0
- rasa/shared/core/domain.py +107 -50
- rasa/shared/core/events.py +29 -0
- rasa/shared/core/flows/flow.py +5 -0
- rasa/shared/core/flows/flows_list.py +19 -6
- rasa/shared/core/flows/flows_yaml_schema.json +10 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +121 -0
- rasa/shared/core/flows/yaml_flows_io.py +15 -27
- rasa/shared/core/slots.py +5 -0
- rasa/shared/importers/importer.py +59 -41
- rasa/shared/importers/multi_project.py +23 -11
- rasa/shared/importers/rasa.py +12 -3
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +3 -1
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +18 -3
- rasa/shared/nlu/training_data/training_data.py +18 -19
- rasa/shared/providers/_configs/litellm_router_client_config.py +220 -0
- rasa/shared/providers/_configs/model_group_config.py +167 -0
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
- rasa/shared/providers/_configs/utils.py +16 -0
- rasa/shared/providers/_utils.py +79 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +13 -29
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +24 -0
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
- rasa/shared/providers/llm/_base_litellm_client.py +34 -22
- rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
- rasa/shared/providers/llm/default_litellm_llm_client.py +24 -0
- rasa/shared/providers/llm/litellm_router_llm_client.py +182 -0
- rasa/shared/providers/llm/rasa_llm_client.py +112 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +5 -29
- rasa/shared/providers/mappings.py +19 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +183 -0
- rasa/shared/providers/router/router_client.py +73 -0
- rasa/shared/utils/common.py +40 -24
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +258 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +27 -6
- rasa/shared/utils/llm.py +354 -44
- rasa/shared/utils/schemas/events.py +2 -0
- rasa/shared/utils/schemas/model_config.yml +0 -10
- rasa/shared/utils/yaml.py +181 -38
- rasa/studio/data_handler.py +3 -1
- rasa/studio/upload.py +160 -74
- rasa/telemetry.py +94 -17
- rasa/tracing/config.py +3 -1
- rasa/tracing/instrumentation/attribute_extractors.py +95 -18
- rasa/tracing/instrumentation/instrumentation.py +121 -0
- rasa/utils/common.py +5 -0
- rasa/utils/endpoints.py +27 -1
- rasa/utils/io.py +8 -16
- rasa/utils/log_utils.py +9 -2
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/validator.py +110 -16
- rasa/version.py +1 -1
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/METADATA +16 -14
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/RECORD +236 -185
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +0 -1
- rasa/core/channels/inspector/dist/assets/index-a5d3e69d.js +0 -1040
- rasa/core/channels/voice_aware/utils.py +0 -20
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +0 -407
- /rasa/core/channels/{voice_aware → voice_ready}/__init__.py +0 -0
- /rasa/core/channels/{voice_native → voice_stream}/__init__.py +0 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/NOTICE +0 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/WHEEL +0 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0.dist-info}/entry_points.txt +0 -0
|
@@ -2,43 +2,41 @@ import importlib.resources
|
|
|
2
2
|
import json
|
|
3
3
|
import re
|
|
4
4
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
|
|
5
|
-
|
|
6
5
|
import dotenv
|
|
7
6
|
import structlog
|
|
8
7
|
from jinja2 import Template
|
|
9
8
|
from pydantic import ValidationError
|
|
10
9
|
|
|
11
|
-
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
12
|
-
_LangchainEmbeddingClientAdapter,
|
|
13
|
-
)
|
|
14
|
-
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
15
|
-
|
|
16
10
|
import rasa.shared.utils.io
|
|
17
|
-
from rasa.telemetry import (
|
|
18
|
-
track_enterprise_search_policy_predict,
|
|
19
|
-
track_enterprise_search_policy_train_completed,
|
|
20
|
-
track_enterprise_search_policy_train_started,
|
|
21
|
-
)
|
|
22
|
-
from rasa.shared.exceptions import RasaException
|
|
23
11
|
from rasa.core.constants import (
|
|
24
12
|
POLICY_MAX_HISTORY,
|
|
25
13
|
POLICY_PRIORITY,
|
|
26
14
|
SEARCH_POLICY_PRIORITY,
|
|
27
15
|
UTTER_SOURCE_METADATA_KEY,
|
|
28
16
|
)
|
|
17
|
+
from rasa.core.information_retrieval import (
|
|
18
|
+
InformationRetrieval,
|
|
19
|
+
SearchResult,
|
|
20
|
+
InformationRetrievalException,
|
|
21
|
+
create_from_endpoint_config,
|
|
22
|
+
)
|
|
23
|
+
from rasa.core.information_retrieval.faiss import FAISS_Store
|
|
29
24
|
from rasa.core.policies.policy import Policy, PolicyPrediction
|
|
30
25
|
from rasa.core.utils import AvailableEndpoints
|
|
31
|
-
from rasa.dialogue_understanding.
|
|
32
|
-
|
|
26
|
+
from rasa.dialogue_understanding.generator.constants import (
|
|
27
|
+
LLM_CONFIG_KEY,
|
|
33
28
|
)
|
|
34
29
|
from rasa.dialogue_understanding.patterns.cannot_handle import (
|
|
35
30
|
CannotHandlePatternFlowStackFrame,
|
|
36
31
|
)
|
|
37
|
-
from rasa.dialogue_understanding.
|
|
32
|
+
from rasa.dialogue_understanding.patterns.internal_error import (
|
|
33
|
+
InternalErrorPatternFlowStackFrame,
|
|
34
|
+
)
|
|
38
35
|
from rasa.dialogue_understanding.stack.frames import (
|
|
39
36
|
DialogueStackFrame,
|
|
40
37
|
SearchStackFrame,
|
|
41
38
|
)
|
|
39
|
+
from rasa.dialogue_understanding.stack.frames import PatternFlowStackFrame
|
|
42
40
|
from rasa.engine.graph import ExecutionContext
|
|
43
41
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
44
42
|
from rasa.engine.storage.resource import Resource
|
|
@@ -47,13 +45,13 @@ from rasa.graph_components.providers.forms_provider import Forms
|
|
|
47
45
|
from rasa.graph_components.providers.responses_provider import Responses
|
|
48
46
|
from rasa.shared.constants import (
|
|
49
47
|
EMBEDDINGS_CONFIG_KEY,
|
|
50
|
-
LLM_CONFIG_KEY,
|
|
51
48
|
MODEL_CONFIG_KEY,
|
|
52
|
-
MODEL_NAME_CONFIG_KEY,
|
|
53
49
|
PROMPT_CONFIG_KEY,
|
|
54
50
|
PROVIDER_CONFIG_KEY,
|
|
55
51
|
OPENAI_PROVIDER,
|
|
56
52
|
TIMEOUT_CONFIG_KEY,
|
|
53
|
+
MODEL_NAME_CONFIG_KEY,
|
|
54
|
+
MODEL_GROUP_ID_CONFIG_KEY,
|
|
57
55
|
)
|
|
58
56
|
from rasa.shared.core.constants import (
|
|
59
57
|
ACTION_CANCEL_FLOW,
|
|
@@ -64,8 +62,17 @@ from rasa.shared.core.domain import Domain
|
|
|
64
62
|
from rasa.shared.core.events import Event, UserUttered, BotUttered
|
|
65
63
|
from rasa.shared.core.generator import TrackerWithCachedStates
|
|
66
64
|
from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity
|
|
65
|
+
from rasa.shared.exceptions import RasaException, FileIOException
|
|
67
66
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
67
|
+
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
68
|
+
_LangchainEmbeddingClientAdapter,
|
|
69
|
+
)
|
|
70
|
+
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
68
71
|
from rasa.shared.utils.cli import print_error_and_exit
|
|
72
|
+
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
73
|
+
EmbeddingsHealthCheckMixin,
|
|
74
|
+
)
|
|
75
|
+
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
69
76
|
from rasa.shared.utils.io import deep_container_fingerprint
|
|
70
77
|
from rasa.shared.utils.llm import (
|
|
71
78
|
DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
@@ -75,15 +82,12 @@ from rasa.shared.utils.llm import (
|
|
|
75
82
|
llm_factory,
|
|
76
83
|
sanitize_message_for_prompt,
|
|
77
84
|
tracker_as_readable_transcript,
|
|
78
|
-
|
|
79
|
-
try_instantiate_embedder,
|
|
85
|
+
resolve_model_client_config,
|
|
80
86
|
)
|
|
81
|
-
from rasa.
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
InformationRetrievalException,
|
|
86
|
-
create_from_endpoint_config,
|
|
87
|
+
from rasa.telemetry import (
|
|
88
|
+
track_enterprise_search_policy_predict,
|
|
89
|
+
track_enterprise_search_policy_train_completed,
|
|
90
|
+
track_enterprise_search_policy_train_started,
|
|
87
91
|
)
|
|
88
92
|
|
|
89
93
|
if TYPE_CHECKING:
|
|
@@ -128,6 +132,7 @@ DEFAULT_EMBEDDINGS_CONFIG = {
|
|
|
128
132
|
}
|
|
129
133
|
|
|
130
134
|
ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"
|
|
135
|
+
ENTERPRISE_SEARCH_CONFIG_FILE_NAME = "config.json"
|
|
131
136
|
|
|
132
137
|
SEARCH_RESULTS_METADATA_KEY = "search_results"
|
|
133
138
|
SEARCH_QUERY_METADATA_KEY = "search_query"
|
|
@@ -152,7 +157,7 @@ class VectorStoreConfigurationError(RasaException):
|
|
|
152
157
|
@DefaultV1Recipe.register(
|
|
153
158
|
DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
|
|
154
159
|
)
|
|
155
|
-
class EnterpriseSearchPolicy(Policy):
|
|
160
|
+
class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
156
161
|
"""Policy which uses a vector store and LLMs to respond to user messages.
|
|
157
162
|
|
|
158
163
|
The policy uses a vector store and LLMs to respond to user messages. The
|
|
@@ -198,24 +203,35 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
198
203
|
"""Constructs a new Policy object."""
|
|
199
204
|
super().__init__(config, model_storage, resource, execution_context, featurizer)
|
|
200
205
|
|
|
206
|
+
# Resolve LLM config
|
|
207
|
+
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
208
|
+
self.config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
209
|
+
)
|
|
210
|
+
# Resolve embeddings config
|
|
211
|
+
self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
212
|
+
self.config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
213
|
+
)
|
|
214
|
+
|
|
201
215
|
# Vector store object and configuration
|
|
202
216
|
self.vector_store = vector_store
|
|
203
|
-
self.vector_store_config = config.get(
|
|
217
|
+
self.vector_store_config = self.config.get(
|
|
204
218
|
VECTOR_STORE_PROPERTY, DEFAULT_VECTOR_STORE
|
|
205
219
|
)
|
|
220
|
+
|
|
206
221
|
# Embeddings configuration for encoding the search query
|
|
207
|
-
self.embeddings_config =
|
|
208
|
-
EMBEDDINGS_CONFIG_KEY
|
|
222
|
+
self.embeddings_config = (
|
|
223
|
+
self.config[EMBEDDINGS_CONFIG_KEY] or DEFAULT_EMBEDDINGS_CONFIG
|
|
209
224
|
)
|
|
225
|
+
|
|
226
|
+
# LLM Configuration for response generation
|
|
227
|
+
self.llm_config = self.config[LLM_CONFIG_KEY] or DEFAULT_LLM_CONFIG
|
|
228
|
+
|
|
210
229
|
# Maximum number of turns to include in the prompt
|
|
211
230
|
self.max_history = self.config.get(POLICY_MAX_HISTORY)
|
|
212
231
|
|
|
213
232
|
# Maximum number of messages to include in the search query
|
|
214
233
|
self.max_messages_in_query = self.config.get(MAX_MESSAGES_IN_QUERY_KEY, 2)
|
|
215
234
|
|
|
216
|
-
# LLM Configuration for response generation
|
|
217
|
-
self.llm_config = self.config.get(LLM_CONFIG_KEY, DEFAULT_LLM_CONFIG)
|
|
218
|
-
|
|
219
235
|
# boolean to enable/disable tracing of prompt tokens
|
|
220
236
|
self.trace_prompt_tokens = self.config.get(TRACE_TOKENS_PROPERTY, False)
|
|
221
237
|
|
|
@@ -244,9 +260,16 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
244
260
|
Returns:
|
|
245
261
|
The embedder.
|
|
246
262
|
"""
|
|
263
|
+
# Copy the config so original config is not modified
|
|
264
|
+
config = config.copy()
|
|
265
|
+
# Resolve config and instantiate the embedding client
|
|
266
|
+
config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
267
|
+
config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
268
|
+
)
|
|
247
269
|
client = embedder_factory(
|
|
248
270
|
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
249
271
|
)
|
|
272
|
+
# Wrap the embedding client in the adapter
|
|
250
273
|
return _LangchainEmbeddingClientAdapter(client)
|
|
251
274
|
|
|
252
275
|
def train( # type: ignore[override]
|
|
@@ -273,6 +296,9 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
273
296
|
A policy must return its resource locator so that potential children nodes
|
|
274
297
|
can load the policy from the resource.
|
|
275
298
|
"""
|
|
299
|
+
# Perform health checks for both LLM and embeddings client configs
|
|
300
|
+
self._perform_health_checks(self.config, "enterprise_search_policy.train")
|
|
301
|
+
|
|
276
302
|
store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
|
|
277
303
|
|
|
278
304
|
# telemetry call to track training start
|
|
@@ -292,14 +318,6 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
292
318
|
f"required environment variables. Error: {e}"
|
|
293
319
|
)
|
|
294
320
|
|
|
295
|
-
# validate llm configuration
|
|
296
|
-
try_instantiate_llm_client(
|
|
297
|
-
self.config.get(LLM_CONFIG_KEY),
|
|
298
|
-
DEFAULT_LLM_CONFIG,
|
|
299
|
-
"enterprise_search_policy.train",
|
|
300
|
-
"EnterpriseSearchPolicy",
|
|
301
|
-
)
|
|
302
|
-
|
|
303
321
|
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
304
322
|
logger.info("enterprise_search_policy.train.faiss")
|
|
305
323
|
with self._model_storage.write_to(self._resource) as path:
|
|
@@ -318,9 +336,13 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
318
336
|
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
319
337
|
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
320
338
|
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
339
|
+
embeddings_model_group_id=self.embeddings_config.get(
|
|
340
|
+
MODEL_GROUP_ID_CONFIG_KEY
|
|
341
|
+
),
|
|
321
342
|
llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
322
343
|
llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
323
344
|
or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
345
|
+
llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
|
|
324
346
|
citation_enabled=self.citation_enabled,
|
|
325
347
|
)
|
|
326
348
|
self.persist()
|
|
@@ -332,6 +354,9 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
332
354
|
rasa.shared.utils.io.write_text_file(
|
|
333
355
|
self.prompt_template, path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
|
|
334
356
|
)
|
|
357
|
+
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
358
|
+
path / ENTERPRISE_SEARCH_CONFIG_FILE_NAME, self.config
|
|
359
|
+
)
|
|
335
360
|
|
|
336
361
|
def _prepare_slots_for_template(
|
|
337
362
|
self, tracker: DialogueStateTracker
|
|
@@ -512,9 +537,13 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
512
537
|
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
513
538
|
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
514
539
|
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
540
|
+
embeddings_model_group_id=self.embeddings_config.get(
|
|
541
|
+
MODEL_GROUP_ID_CONFIG_KEY
|
|
542
|
+
),
|
|
515
543
|
llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
516
544
|
llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
517
545
|
or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
546
|
+
llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
|
|
518
547
|
citation_enabled=self.citation_enabled,
|
|
519
548
|
)
|
|
520
549
|
return self._create_prediction(
|
|
@@ -662,25 +691,28 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
662
691
|
execution_context: ExecutionContext,
|
|
663
692
|
**kwargs: Any,
|
|
664
693
|
) -> "EnterpriseSearchPolicy":
|
|
665
|
-
try_instantiate_llm_client(
|
|
666
|
-
config.get(LLM_CONFIG_KEY),
|
|
667
|
-
DEFAULT_LLM_CONFIG,
|
|
668
|
-
"enterprise_search_policy.load",
|
|
669
|
-
EnterpriseSearchPolicy.__name__,
|
|
670
|
-
)
|
|
671
|
-
try_instantiate_embedder(
|
|
672
|
-
config.get(EMBEDDINGS_CONFIG_KEY),
|
|
673
|
-
DEFAULT_EMBEDDINGS_CONFIG,
|
|
674
|
-
"enterprise_search_policy.load",
|
|
675
|
-
EnterpriseSearchPolicy.__name__,
|
|
676
|
-
)
|
|
677
694
|
"""Loads a trained policy (see parent class for full docstring)."""
|
|
695
|
+
|
|
696
|
+
# Perform health checks for both LLM and embeddings client configs
|
|
697
|
+
cls._perform_health_checks(config, "enterprise_search_policy.load")
|
|
698
|
+
|
|
678
699
|
prompt_template = None
|
|
700
|
+
try:
|
|
701
|
+
with model_storage.read_from(resource) as path:
|
|
702
|
+
prompt_template = rasa.shared.utils.io.read_file(
|
|
703
|
+
path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
|
|
704
|
+
)
|
|
705
|
+
except (FileNotFoundError, FileIOException) as e:
|
|
706
|
+
logger.warning(
|
|
707
|
+
"enterprise_search_policy.load.failed", error=e, resource=resource.name
|
|
708
|
+
)
|
|
709
|
+
|
|
679
710
|
store_type = config.get(VECTOR_STORE_PROPERTY, {}).get(
|
|
680
711
|
VECTOR_STORE_TYPE_PROPERTY
|
|
681
712
|
)
|
|
682
713
|
|
|
683
714
|
embeddings = cls._create_plain_embedder(config)
|
|
715
|
+
|
|
684
716
|
logger.info("enterprise_search_policy.load", config=config)
|
|
685
717
|
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
686
718
|
# if a vector store is not specified,
|
|
@@ -698,16 +730,7 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
698
730
|
config_type=store_type,
|
|
699
731
|
embeddings=embeddings,
|
|
700
732
|
) # type: ignore
|
|
701
|
-
try:
|
|
702
|
-
with model_storage.read_from(resource) as path:
|
|
703
|
-
prompt_template = rasa.shared.utils.io.read_file(
|
|
704
|
-
path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
|
|
705
|
-
)
|
|
706
733
|
|
|
707
|
-
except (FileNotFoundError, FileNotFoundError) as e:
|
|
708
|
-
logger.warning(
|
|
709
|
-
"enterprise_search_policy.load.failed", error=e, resource=resource.name
|
|
710
|
-
)
|
|
711
734
|
return cls(
|
|
712
735
|
config,
|
|
713
736
|
model_storage,
|
|
@@ -748,14 +771,23 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
748
771
|
|
|
749
772
|
@classmethod
|
|
750
773
|
def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
|
|
751
|
-
"""Add a fingerprint of
|
|
774
|
+
"""Add a fingerprint of enterprise search policy for the graph."""
|
|
752
775
|
local_knowledge_data = cls._get_local_knowledge_data(config)
|
|
753
776
|
|
|
754
777
|
prompt_template = get_prompt_template(
|
|
755
778
|
config.get(PROMPT_CONFIG_KEY),
|
|
756
779
|
DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
|
|
757
780
|
)
|
|
758
|
-
|
|
781
|
+
|
|
782
|
+
llm_config = resolve_model_client_config(
|
|
783
|
+
config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
784
|
+
)
|
|
785
|
+
embedding_config = resolve_model_client_config(
|
|
786
|
+
config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
787
|
+
)
|
|
788
|
+
return deep_container_fingerprint(
|
|
789
|
+
[prompt_template, local_knowledge_data, llm_config, embedding_config]
|
|
790
|
+
)
|
|
759
791
|
|
|
760
792
|
@staticmethod
|
|
761
793
|
def post_process_citations(llm_answer: str) -> str:
|
|
@@ -847,3 +879,27 @@ class EnterpriseSearchPolicy(Policy):
|
|
|
847
879
|
joined_sources = "\n".join(new_sources)
|
|
848
880
|
|
|
849
881
|
return joined_answer + joined_sources
|
|
882
|
+
|
|
883
|
+
@classmethod
|
|
884
|
+
def _perform_health_checks(
|
|
885
|
+
cls, config: Dict[Text, Any], log_source_method: str
|
|
886
|
+
) -> None:
|
|
887
|
+
# Perform health check of the LLM client config
|
|
888
|
+
llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
|
|
889
|
+
cls.perform_llm_health_check(
|
|
890
|
+
llm_config,
|
|
891
|
+
DEFAULT_LLM_CONFIG,
|
|
892
|
+
log_source_method,
|
|
893
|
+
EnterpriseSearchPolicy.__name__,
|
|
894
|
+
)
|
|
895
|
+
|
|
896
|
+
# Perform health check of the embeddings client config
|
|
897
|
+
embeddings_config = resolve_model_client_config(
|
|
898
|
+
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
899
|
+
)
|
|
900
|
+
cls.perform_embeddings_health_check(
|
|
901
|
+
embeddings_config,
|
|
902
|
+
DEFAULT_EMBEDDINGS_CONFIG,
|
|
903
|
+
log_source_method,
|
|
904
|
+
EnterpriseSearchPolicy.__name__,
|
|
905
|
+
)
|
|
@@ -330,24 +330,27 @@ def reset_scoped_slots(
|
|
|
330
330
|
events: List[Event] = []
|
|
331
331
|
|
|
332
332
|
not_resettable_slot_names = set()
|
|
333
|
+
flow_persistable_slots = current_flow.persisted_slots
|
|
333
334
|
|
|
334
335
|
for step in current_flow.steps_with_calls_resolved:
|
|
335
336
|
if isinstance(step, CollectInformationFlowStep):
|
|
336
337
|
# reset all slots scoped to the flow
|
|
337
|
-
|
|
338
|
-
|
|
338
|
+
slot_name = step.collect
|
|
339
|
+
if step.reset_after_flow_ends and slot_name not in flow_persistable_slots:
|
|
340
|
+
_reset_slot(slot_name, tracker)
|
|
339
341
|
else:
|
|
340
|
-
not_resettable_slot_names.add(
|
|
342
|
+
not_resettable_slot_names.add(slot_name)
|
|
341
343
|
|
|
342
344
|
# slots set by the set slots step should be reset after the flow ends
|
|
343
345
|
# unless they are also used in a collect step where `reset_after_flow_ends`
|
|
344
|
-
# is set to `False`
|
|
346
|
+
# is set to `False` or set in the `persisted_slots` list.
|
|
345
347
|
resettable_set_slots = [
|
|
346
348
|
slot["key"]
|
|
347
349
|
for step in current_flow.steps_with_calls_resolved
|
|
348
350
|
if isinstance(step, SetSlotsFlowStep)
|
|
349
351
|
for slot in step.slots
|
|
350
352
|
if slot["key"] not in not_resettable_slot_names
|
|
353
|
+
and slot["key"] not in flow_persistable_slots
|
|
351
354
|
]
|
|
352
355
|
|
|
353
356
|
for name in resettable_set_slots:
|
|
@@ -484,8 +487,7 @@ def validate_collect_step(
|
|
|
484
487
|
step: CollectInformationFlowStep,
|
|
485
488
|
stack: DialogueStack,
|
|
486
489
|
available_actions: List[str],
|
|
487
|
-
slots: Dict[
|
|
488
|
-
flow_name: str,
|
|
490
|
+
slots: Dict[Text, Slot],
|
|
489
491
|
) -> bool:
|
|
490
492
|
"""Validate that a collect step can be executed.
|
|
491
493
|
|
|
@@ -508,12 +510,12 @@ def validate_collect_step(
|
|
|
508
510
|
slot_name=step.collect,
|
|
509
511
|
)
|
|
510
512
|
|
|
511
|
-
cancel_flow_and_push_internal_error(stack
|
|
513
|
+
cancel_flow_and_push_internal_error(stack)
|
|
512
514
|
|
|
513
515
|
return False
|
|
514
516
|
|
|
515
517
|
|
|
516
|
-
def cancel_flow_and_push_internal_error(stack: DialogueStack
|
|
518
|
+
def cancel_flow_and_push_internal_error(stack: DialogueStack) -> None:
|
|
517
519
|
"""Cancel the top user flow and push the internal error pattern."""
|
|
518
520
|
top_frame = stack.top()
|
|
519
521
|
|
|
@@ -525,7 +527,7 @@ def cancel_flow_and_push_internal_error(stack: DialogueStack, flow_name: str) ->
|
|
|
525
527
|
canceled_frames = CancelFlowCommand.select_canceled_frames(stack)
|
|
526
528
|
stack.push(
|
|
527
529
|
CancelPatternFlowStackFrame(
|
|
528
|
-
canceled_name=
|
|
530
|
+
canceled_name=top_frame.flow_id,
|
|
529
531
|
canceled_frames=canceled_frames,
|
|
530
532
|
)
|
|
531
533
|
)
|
|
@@ -537,7 +539,6 @@ def validate_custom_slot_mappings(
|
|
|
537
539
|
stack: DialogueStack,
|
|
538
540
|
tracker: DialogueStateTracker,
|
|
539
541
|
available_actions: List[str],
|
|
540
|
-
flow_name: str,
|
|
541
542
|
) -> bool:
|
|
542
543
|
"""Validate a slot with custom mappings.
|
|
543
544
|
|
|
@@ -558,7 +559,7 @@ def validate_custom_slot_mappings(
|
|
|
558
559
|
action=step.collect_action,
|
|
559
560
|
collect=step.collect,
|
|
560
561
|
)
|
|
561
|
-
cancel_flow_and_push_internal_error(stack
|
|
562
|
+
cancel_flow_and_push_internal_error(stack)
|
|
562
563
|
return False
|
|
563
564
|
|
|
564
565
|
return True
|
|
@@ -598,12 +599,7 @@ def run_step(
|
|
|
598
599
|
|
|
599
600
|
if isinstance(step, CollectInformationFlowStep):
|
|
600
601
|
return _run_collect_information_step(
|
|
601
|
-
available_actions,
|
|
602
|
-
initial_events,
|
|
603
|
-
stack,
|
|
604
|
-
step,
|
|
605
|
-
tracker,
|
|
606
|
-
flow.readable_name(),
|
|
602
|
+
available_actions, initial_events, stack, step, tracker
|
|
607
603
|
)
|
|
608
604
|
|
|
609
605
|
elif isinstance(step, ActionFlowStep):
|
|
@@ -723,18 +719,15 @@ def _run_collect_information_step(
|
|
|
723
719
|
stack: DialogueStack,
|
|
724
720
|
step: CollectInformationFlowStep,
|
|
725
721
|
tracker: DialogueStateTracker,
|
|
726
|
-
flow_name: str,
|
|
727
722
|
) -> FlowStepResult:
|
|
728
|
-
is_step_valid = validate_collect_step(
|
|
729
|
-
step, stack, available_actions, tracker.slots, flow_name
|
|
730
|
-
)
|
|
723
|
+
is_step_valid = validate_collect_step(step, stack, available_actions, tracker.slots)
|
|
731
724
|
|
|
732
725
|
if not is_step_valid:
|
|
733
726
|
# if we return any other FlowStepResult, the assistant will stay silent
|
|
734
727
|
# instead of triggering the internal error pattern
|
|
735
728
|
return ContinueFlowWithNextStep(events=initial_events)
|
|
736
729
|
is_mapping_valid = validate_custom_slot_mappings(
|
|
737
|
-
step, stack, tracker, available_actions
|
|
730
|
+
step, stack, tracker, available_actions
|
|
738
731
|
)
|
|
739
732
|
|
|
740
733
|
if not is_mapping_valid:
|