rasa-pro 3.11.0__py3-none-any.whl → 3.11.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +396 -17
- rasa/__main__.py +15 -31
- rasa/api.py +1 -5
- rasa/cli/arguments/default_arguments.py +2 -1
- rasa/cli/arguments/shell.py +1 -5
- rasa/cli/arguments/train.py +0 -14
- rasa/cli/e2e_test.py +1 -1
- rasa/cli/evaluate.py +8 -8
- rasa/cli/inspect.py +5 -7
- rasa/cli/interactive.py +0 -1
- rasa/cli/llm_fine_tuning.py +1 -1
- rasa/cli/project_templates/calm/config.yml +7 -5
- rasa/cli/project_templates/calm/endpoints.yml +2 -15
- rasa/cli/project_templates/tutorial/config.yml +5 -8
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +0 -5
- rasa/cli/project_templates/tutorial/domain.yml +0 -14
- rasa/cli/project_templates/tutorial/endpoints.yml +0 -5
- rasa/cli/run.py +1 -1
- rasa/cli/scaffold.py +2 -4
- rasa/cli/studio/studio.py +8 -18
- rasa/cli/studio/upload.py +15 -0
- rasa/cli/train.py +0 -3
- rasa/cli/utils.py +1 -6
- rasa/cli/x.py +8 -8
- rasa/constants.py +1 -3
- rasa/core/actions/action.py +33 -75
- rasa/core/actions/e2e_stub_custom_action_executor.py +1 -5
- rasa/core/actions/http_custom_action_executor.py +0 -4
- rasa/core/channels/channel.py +0 -20
- rasa/core/channels/development_inspector.py +2 -8
- rasa/core/channels/inspector/dist/assets/{arc-bc141fb2.js → arc-6852c607.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-be2db283.js → c4Diagram-d0fbc5ce-acc952b2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-55366915.js → classDiagram-936ed81e-848a7597.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-bb529518.js → classDiagram-v2-c3cb15f1-a73d3e68.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-b0ec81d6.js → createText-62fc7601-e5ee049d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-6166330c.js → edges-f2ad444c-771e517e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-5ccc6a8e.js → erDiagram-9d236eb7-aa347178.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-fca3bfe4.js → flowDb-1972c806-651fc57d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-4739080f.js → flowDiagram-7ea5b25a-ca67804f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-587d82d8.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-7c1b0e0f.js → flowchart-elk-definition-abe16c3d-2dbc568d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-772fd050.js → ganttDiagram-9b5ea136-25a65bd8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-8eae1dc9.js → gitGraphDiagram-99d0ae7c-fdc7378d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-f55afcdf.js → index-2c4b9a3b-6f1fd606.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-e7cef9de.js → index-efdd30c1.js} +68 -68
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-124d4a14.js → infoDiagram-736b4530-cb1a041a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-7c4fae44.js → journeyDiagram-df861f2b-14609879.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-b9885fb6.js → layout-2490f52b.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-7c59abb6.js → line-40186f1f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-4776f780.js → linear-08814e93.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-2332c46c.js → mindmap-definition-beec6740-1a534584.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-8fb39303.js → pieDiagram-dbbf0591-72397b61.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3c7180a2.js → quadrantDiagram-4d7f4fd6-3bb0b6a3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-e910bcb8.js → requirementDiagram-6fc4c22a-57334f61.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-ead16c89.js → sankeyDiagram-8f13d901-111e1297.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-29a02a19.js → sequenceDiagram-b655622a-10bcfe62.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-042b3137.js → stateDiagram-59f0c015-acaf7513.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-2178c0f3.js → stateDiagram-v2-2b26beab-3ec2a235.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-23ffa4fc.js → styles-080da4f6-62730289.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-94f59763.js → styles-3dcbcfbf-5284ee76.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-78a6bebc.js → styles-9c745c82-642435e3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-eae2a6f6.js → svgDrawCommon-4835440b-b250a350.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-5c968d92.js → timeline-definition-5b62e21b-c2b147ed.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-fd3db0d5.js → xychartDiagram-2b33534f-f92cfea9.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/App.tsx +1 -1
- rasa/core/channels/inspector/src/helpers/audiostream.ts +16 -77
- rasa/core/channels/socketio.py +2 -7
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/voice_ready/audiocodes.py +4 -15
- rasa/core/channels/voice_ready/jambonz.py +4 -15
- rasa/core/channels/voice_ready/twilio_voice.py +21 -6
- rasa/core/channels/voice_ready/utils.py +5 -6
- rasa/core/channels/voice_stream/asr/asr_engine.py +1 -19
- rasa/core/channels/voice_stream/asr/asr_event.py +0 -5
- rasa/core/channels/voice_stream/asr/deepgram.py +15 -28
- rasa/core/channels/voice_stream/audio_bytes.py +0 -1
- rasa/core/channels/voice_stream/browser_audio.py +9 -32
- rasa/core/channels/voice_stream/tts/azure.py +3 -9
- rasa/core/channels/voice_stream/tts/cartesia.py +8 -12
- rasa/core/channels/voice_stream/tts/tts_engine.py +1 -11
- rasa/core/channels/voice_stream/twilio_media_streams.py +19 -28
- rasa/core/channels/voice_stream/util.py +4 -4
- rasa/core/channels/voice_stream/voice_channel.py +42 -222
- rasa/core/featurizers/single_state_featurizer.py +1 -22
- rasa/core/featurizers/tracker_featurizers.py +18 -115
- rasa/core/information_retrieval/qdrant.py +0 -1
- rasa/core/nlg/contextual_response_rephraser.py +25 -44
- rasa/core/persistor.py +34 -191
- rasa/core/policies/enterprise_search_policy.py +60 -119
- rasa/core/policies/flows/flow_executor.py +4 -7
- rasa/core/policies/intentless_policy.py +22 -82
- rasa/core/policies/ted_policy.py +33 -58
- rasa/core/policies/unexpected_intent_policy.py +7 -15
- rasa/core/processor.py +5 -32
- rasa/core/training/interactive.py +35 -34
- rasa/core/utils.py +22 -58
- rasa/dialogue_understanding/coexistence/llm_based_router.py +12 -39
- rasa/dialogue_understanding/commands/__init__.py +0 -4
- rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
- rasa/dialogue_understanding/commands/utils.py +0 -5
- rasa/dialogue_understanding/generator/constants.py +0 -2
- rasa/dialogue_understanding/generator/flow_retrieval.py +4 -49
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +23 -37
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +10 -57
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +1 -19
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +0 -3
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +10 -90
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +0 -53
- rasa/dialogue_understanding/processor/command_processor.py +1 -21
- rasa/e2e_test/assertions.py +16 -133
- rasa/e2e_test/assertions_schema.yml +0 -23
- rasa/e2e_test/e2e_test_case.py +6 -85
- rasa/e2e_test/e2e_test_runner.py +4 -6
- rasa/e2e_test/utils/io.py +1 -3
- rasa/engine/loader.py +0 -12
- rasa/engine/validation.py +11 -541
- rasa/keys +1 -0
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
- rasa/model_training.py +7 -29
- rasa/nlu/classifiers/diet_classifier.py +25 -38
- rasa/nlu/classifiers/logistic_regression_classifier.py +9 -22
- rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
- rasa/nlu/extractors/crf_entity_extractor.py +50 -93
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +16 -45
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
- rasa/nlu/tokenizers/whitespace_tokenizer.py +14 -3
- rasa/server.py +1 -3
- rasa/shared/constants.py +0 -61
- rasa/shared/core/constants.py +0 -9
- rasa/shared/core/domain.py +5 -8
- rasa/shared/core/flows/flow.py +0 -5
- rasa/shared/core/flows/flows_list.py +1 -5
- rasa/shared/core/flows/flows_yaml_schema.json +0 -10
- rasa/shared/core/flows/validation.py +0 -96
- rasa/shared/core/flows/yaml_flows_io.py +4 -13
- rasa/shared/core/slots.py +0 -5
- rasa/shared/importers/importer.py +2 -19
- rasa/shared/importers/rasa.py +1 -5
- rasa/shared/nlu/training_data/features.py +2 -120
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +3 -18
- rasa/shared/providers/_configs/azure_openai_client_config.py +3 -5
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +0 -1
- rasa/shared/providers/_configs/utils.py +0 -16
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +29 -18
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +21 -54
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +0 -24
- rasa/shared/providers/llm/_base_litellm_client.py +31 -63
- rasa/shared/providers/llm/azure_openai_llm_client.py +29 -50
- rasa/shared/providers/llm/default_litellm_llm_client.py +0 -24
- rasa/shared/providers/llm/self_hosted_llm_client.py +29 -17
- rasa/shared/providers/mappings.py +0 -19
- rasa/shared/utils/common.py +2 -37
- rasa/shared/utils/io.py +6 -28
- rasa/shared/utils/llm.py +46 -353
- rasa/shared/utils/yaml.py +82 -181
- rasa/studio/auth.py +5 -3
- rasa/studio/config.py +4 -13
- rasa/studio/constants.py +0 -1
- rasa/studio/data_handler.py +4 -13
- rasa/studio/upload.py +80 -175
- rasa/telemetry.py +17 -94
- rasa/tracing/config.py +1 -3
- rasa/tracing/instrumentation/attribute_extractors.py +17 -94
- rasa/tracing/instrumentation/instrumentation.py +0 -121
- rasa/utils/common.py +0 -5
- rasa/utils/endpoints.py +1 -27
- rasa/utils/io.py +81 -7
- rasa/utils/log_utils.py +2 -9
- rasa/utils/tensorflow/model_data.py +193 -2
- rasa/validator.py +4 -110
- rasa/version.py +1 -1
- rasa_pro-3.11.0a2.dist-info/METADATA +576 -0
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a2.dist-info}/RECORD +181 -213
- rasa/core/actions/action_repeat_bot_messages.py +0 -89
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +0 -1
- rasa/core/channels/voice_stream/asr/azure.py +0 -129
- rasa/core/channels/voice_stream/call_state.py +0 -23
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +0 -60
- rasa/dialogue_understanding/commands/user_silence_command.py +0 -59
- rasa/dialogue_understanding/patterns/repeat.py +0 -37
- rasa/dialogue_understanding/patterns/user_silence.py +0 -37
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +0 -40
- rasa/model_manager/model_api.py +0 -559
- rasa/model_manager/runner_service.py +0 -286
- rasa/model_manager/socket_bridge.py +0 -146
- rasa/model_manager/studio_jwt_auth.py +0 -86
- rasa/model_manager/trainer_service.py +0 -325
- rasa/model_manager/utils.py +0 -87
- rasa/model_manager/warm_rasa_process.py +0 -187
- rasa/model_service.py +0 -112
- rasa/shared/core/flows/utils.py +0 -39
- rasa/shared/providers/_configs/litellm_router_client_config.py +0 -220
- rasa/shared/providers/_configs/model_group_config.py +0 -167
- rasa/shared/providers/_configs/rasa_llm_client_config.py +0 -73
- rasa/shared/providers/_utils.py +0 -79
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +0 -135
- rasa/shared/providers/llm/litellm_router_llm_client.py +0 -182
- rasa/shared/providers/llm/rasa_llm_client.py +0 -112
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +0 -183
- rasa/shared/providers/router/router_client.py +0 -73
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +0 -31
- rasa/shared/utils/health_check/health_check.py +0 -258
- rasa/shared/utils/health_check/llm_health_check_mixin.py +0 -31
- rasa/utils/sanic_error_handler.py +0 -32
- rasa/utils/tensorflow/feature_array.py +0 -366
- rasa_pro-3.11.0.dist-info/METADATA +0 -198
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a2.dist-info}/entry_points.txt +0 -0
|
@@ -1,42 +1,45 @@
|
|
|
1
1
|
import importlib.resources
|
|
2
2
|
import json
|
|
3
|
+
import os
|
|
3
4
|
import re
|
|
4
5
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
|
|
6
|
+
|
|
5
7
|
import dotenv
|
|
6
8
|
import structlog
|
|
7
9
|
from jinja2 import Template
|
|
8
10
|
from pydantic import ValidationError
|
|
9
11
|
|
|
12
|
+
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
13
|
+
_LangchainEmbeddingClientAdapter,
|
|
14
|
+
)
|
|
15
|
+
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
16
|
+
|
|
10
17
|
import rasa.shared.utils.io
|
|
18
|
+
from rasa.telemetry import (
|
|
19
|
+
track_enterprise_search_policy_predict,
|
|
20
|
+
track_enterprise_search_policy_train_completed,
|
|
21
|
+
track_enterprise_search_policy_train_started,
|
|
22
|
+
)
|
|
23
|
+
from rasa.shared.exceptions import RasaException
|
|
11
24
|
from rasa.core.constants import (
|
|
12
25
|
POLICY_MAX_HISTORY,
|
|
13
26
|
POLICY_PRIORITY,
|
|
14
27
|
SEARCH_POLICY_PRIORITY,
|
|
15
28
|
UTTER_SOURCE_METADATA_KEY,
|
|
16
29
|
)
|
|
17
|
-
from rasa.core.information_retrieval import (
|
|
18
|
-
InformationRetrieval,
|
|
19
|
-
SearchResult,
|
|
20
|
-
InformationRetrievalException,
|
|
21
|
-
create_from_endpoint_config,
|
|
22
|
-
)
|
|
23
|
-
from rasa.core.information_retrieval.faiss import FAISS_Store
|
|
24
30
|
from rasa.core.policies.policy import Policy, PolicyPrediction
|
|
25
31
|
from rasa.core.utils import AvailableEndpoints
|
|
26
|
-
from rasa.dialogue_understanding.
|
|
27
|
-
|
|
32
|
+
from rasa.dialogue_understanding.patterns.internal_error import (
|
|
33
|
+
InternalErrorPatternFlowStackFrame,
|
|
28
34
|
)
|
|
29
35
|
from rasa.dialogue_understanding.patterns.cannot_handle import (
|
|
30
36
|
CannotHandlePatternFlowStackFrame,
|
|
31
37
|
)
|
|
32
|
-
from rasa.dialogue_understanding.patterns.internal_error import (
|
|
33
|
-
InternalErrorPatternFlowStackFrame,
|
|
34
|
-
)
|
|
38
|
+
from rasa.dialogue_understanding.stack.frames import PatternFlowStackFrame
|
|
35
39
|
from rasa.dialogue_understanding.stack.frames import (
|
|
36
40
|
DialogueStackFrame,
|
|
37
41
|
SearchStackFrame,
|
|
38
42
|
)
|
|
39
|
-
from rasa.dialogue_understanding.stack.frames import PatternFlowStackFrame
|
|
40
43
|
from rasa.engine.graph import ExecutionContext
|
|
41
44
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
42
45
|
from rasa.engine.storage.resource import Resource
|
|
@@ -45,13 +48,14 @@ from rasa.graph_components.providers.forms_provider import Forms
|
|
|
45
48
|
from rasa.graph_components.providers.responses_provider import Responses
|
|
46
49
|
from rasa.shared.constants import (
|
|
47
50
|
EMBEDDINGS_CONFIG_KEY,
|
|
51
|
+
LLM_API_HEALTH_CHECK_ENV_VAR,
|
|
52
|
+
LLM_CONFIG_KEY,
|
|
48
53
|
MODEL_CONFIG_KEY,
|
|
54
|
+
MODEL_NAME_CONFIG_KEY,
|
|
49
55
|
PROMPT_CONFIG_KEY,
|
|
50
56
|
PROVIDER_CONFIG_KEY,
|
|
51
57
|
OPENAI_PROVIDER,
|
|
52
58
|
TIMEOUT_CONFIG_KEY,
|
|
53
|
-
MODEL_NAME_CONFIG_KEY,
|
|
54
|
-
MODEL_GROUP_ID_CONFIG_KEY,
|
|
55
59
|
)
|
|
56
60
|
from rasa.shared.core.constants import (
|
|
57
61
|
ACTION_CANCEL_FLOW,
|
|
@@ -62,32 +66,26 @@ from rasa.shared.core.domain import Domain
|
|
|
62
66
|
from rasa.shared.core.events import Event, UserUttered, BotUttered
|
|
63
67
|
from rasa.shared.core.generator import TrackerWithCachedStates
|
|
64
68
|
from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity
|
|
65
|
-
from rasa.shared.exceptions import RasaException, FileIOException
|
|
66
69
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
67
|
-
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
68
|
-
_LangchainEmbeddingClientAdapter,
|
|
69
|
-
)
|
|
70
|
-
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
71
70
|
from rasa.shared.utils.cli import print_error_and_exit
|
|
72
|
-
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
73
|
-
EmbeddingsHealthCheckMixin,
|
|
74
|
-
)
|
|
75
|
-
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
76
71
|
from rasa.shared.utils.io import deep_container_fingerprint
|
|
77
72
|
from rasa.shared.utils.llm import (
|
|
78
73
|
DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
79
74
|
DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
80
75
|
embedder_factory,
|
|
81
76
|
get_prompt_template,
|
|
77
|
+
llm_api_health_check,
|
|
82
78
|
llm_factory,
|
|
83
79
|
sanitize_message_for_prompt,
|
|
84
80
|
tracker_as_readable_transcript,
|
|
85
|
-
|
|
81
|
+
try_instantiate_llm_client,
|
|
86
82
|
)
|
|
87
|
-
from rasa.
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
83
|
+
from rasa.core.information_retrieval.faiss import FAISS_Store
|
|
84
|
+
from rasa.core.information_retrieval import (
|
|
85
|
+
InformationRetrieval,
|
|
86
|
+
SearchResult,
|
|
87
|
+
InformationRetrievalException,
|
|
88
|
+
create_from_endpoint_config,
|
|
91
89
|
)
|
|
92
90
|
|
|
93
91
|
if TYPE_CHECKING:
|
|
@@ -132,7 +130,6 @@ DEFAULT_EMBEDDINGS_CONFIG = {
|
|
|
132
130
|
}
|
|
133
131
|
|
|
134
132
|
ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"
|
|
135
|
-
ENTERPRISE_SEARCH_CONFIG_FILE_NAME = "config.json"
|
|
136
133
|
|
|
137
134
|
SEARCH_RESULTS_METADATA_KEY = "search_results"
|
|
138
135
|
SEARCH_QUERY_METADATA_KEY = "search_query"
|
|
@@ -157,7 +154,7 @@ class VectorStoreConfigurationError(RasaException):
|
|
|
157
154
|
@DefaultV1Recipe.register(
|
|
158
155
|
DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
|
|
159
156
|
)
|
|
160
|
-
class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
157
|
+
class EnterpriseSearchPolicy(Policy):
|
|
161
158
|
"""Policy which uses a vector store and LLMs to respond to user messages.
|
|
162
159
|
|
|
163
160
|
The policy uses a vector store and LLMs to respond to user messages. The
|
|
@@ -203,35 +200,24 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
203
200
|
"""Constructs a new Policy object."""
|
|
204
201
|
super().__init__(config, model_storage, resource, execution_context, featurizer)
|
|
205
202
|
|
|
206
|
-
# Resolve LLM config
|
|
207
|
-
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
208
|
-
self.config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
209
|
-
)
|
|
210
|
-
# Resolve embeddings config
|
|
211
|
-
self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
212
|
-
self.config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
213
|
-
)
|
|
214
|
-
|
|
215
203
|
# Vector store object and configuration
|
|
216
204
|
self.vector_store = vector_store
|
|
217
|
-
self.vector_store_config =
|
|
205
|
+
self.vector_store_config = config.get(
|
|
218
206
|
VECTOR_STORE_PROPERTY, DEFAULT_VECTOR_STORE
|
|
219
207
|
)
|
|
220
|
-
|
|
221
208
|
# Embeddings configuration for encoding the search query
|
|
222
|
-
self.embeddings_config = (
|
|
223
|
-
|
|
209
|
+
self.embeddings_config = self.config.get(
|
|
210
|
+
EMBEDDINGS_CONFIG_KEY, DEFAULT_EMBEDDINGS_CONFIG
|
|
224
211
|
)
|
|
225
|
-
|
|
226
|
-
# LLM Configuration for response generation
|
|
227
|
-
self.llm_config = self.config[LLM_CONFIG_KEY] or DEFAULT_LLM_CONFIG
|
|
228
|
-
|
|
229
212
|
# Maximum number of turns to include in the prompt
|
|
230
213
|
self.max_history = self.config.get(POLICY_MAX_HISTORY)
|
|
231
214
|
|
|
232
215
|
# Maximum number of messages to include in the search query
|
|
233
216
|
self.max_messages_in_query = self.config.get(MAX_MESSAGES_IN_QUERY_KEY, 2)
|
|
234
217
|
|
|
218
|
+
# LLM Configuration for response generation
|
|
219
|
+
self.llm_config = self.config.get(LLM_CONFIG_KEY, DEFAULT_LLM_CONFIG)
|
|
220
|
+
|
|
235
221
|
# boolean to enable/disable tracing of prompt tokens
|
|
236
222
|
self.trace_prompt_tokens = self.config.get(TRACE_TOKENS_PROPERTY, False)
|
|
237
223
|
|
|
@@ -260,16 +246,9 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
260
246
|
Returns:
|
|
261
247
|
The embedder.
|
|
262
248
|
"""
|
|
263
|
-
# Copy the config so original config is not modified
|
|
264
|
-
config = config.copy()
|
|
265
|
-
# Resolve config and instantiate the embedding client
|
|
266
|
-
config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
267
|
-
config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
268
|
-
)
|
|
269
249
|
client = embedder_factory(
|
|
270
250
|
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
271
251
|
)
|
|
272
|
-
# Wrap the embedding client in the adapter
|
|
273
252
|
return _LangchainEmbeddingClientAdapter(client)
|
|
274
253
|
|
|
275
254
|
def train( # type: ignore[override]
|
|
@@ -296,9 +275,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
296
275
|
A policy must return its resource locator so that potential children nodes
|
|
297
276
|
can load the policy from the resource.
|
|
298
277
|
"""
|
|
299
|
-
# Perform health checks for both LLM and embeddings client configs
|
|
300
|
-
self._perform_health_checks(self.config, "enterprise_search_policy.train")
|
|
301
|
-
|
|
302
278
|
store_type = self.vector_store_config.get(VECTOR_STORE_TYPE_PROPERTY)
|
|
303
279
|
|
|
304
280
|
# telemetry call to track training start
|
|
@@ -318,6 +294,20 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
318
294
|
f"required environment variables. Error: {e}"
|
|
319
295
|
)
|
|
320
296
|
|
|
297
|
+
# validate llm configuration
|
|
298
|
+
llm_client = try_instantiate_llm_client(
|
|
299
|
+
self.config.get(LLM_CONFIG_KEY),
|
|
300
|
+
DEFAULT_LLM_CONFIG,
|
|
301
|
+
"enterprise_search_policy.train",
|
|
302
|
+
EnterpriseSearchPolicy.__name__,
|
|
303
|
+
)
|
|
304
|
+
if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
|
|
305
|
+
llm_api_health_check(
|
|
306
|
+
llm_client,
|
|
307
|
+
"enterprise_search_policy.train",
|
|
308
|
+
EnterpriseSearchPolicy.__name__,
|
|
309
|
+
)
|
|
310
|
+
|
|
321
311
|
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
322
312
|
logger.info("enterprise_search_policy.train.faiss")
|
|
323
313
|
with self._model_storage.write_to(self._resource) as path:
|
|
@@ -336,13 +326,9 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
336
326
|
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
337
327
|
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
338
328
|
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
339
|
-
embeddings_model_group_id=self.embeddings_config.get(
|
|
340
|
-
MODEL_GROUP_ID_CONFIG_KEY
|
|
341
|
-
),
|
|
342
329
|
llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
343
330
|
llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
344
331
|
or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
345
|
-
llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
|
|
346
332
|
citation_enabled=self.citation_enabled,
|
|
347
333
|
)
|
|
348
334
|
self.persist()
|
|
@@ -354,9 +340,6 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
354
340
|
rasa.shared.utils.io.write_text_file(
|
|
355
341
|
self.prompt_template, path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
|
|
356
342
|
)
|
|
357
|
-
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
358
|
-
path / ENTERPRISE_SEARCH_CONFIG_FILE_NAME, self.config
|
|
359
|
-
)
|
|
360
343
|
|
|
361
344
|
def _prepare_slots_for_template(
|
|
362
345
|
self, tracker: DialogueStateTracker
|
|
@@ -537,13 +520,9 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
537
520
|
embeddings_type=self.embeddings_config.get(PROVIDER_CONFIG_KEY),
|
|
538
521
|
embeddings_model=self.embeddings_config.get(MODEL_CONFIG_KEY)
|
|
539
522
|
or self.embeddings_config.get(MODEL_NAME_CONFIG_KEY),
|
|
540
|
-
embeddings_model_group_id=self.embeddings_config.get(
|
|
541
|
-
MODEL_GROUP_ID_CONFIG_KEY
|
|
542
|
-
),
|
|
543
523
|
llm_type=self.llm_config.get(PROVIDER_CONFIG_KEY),
|
|
544
524
|
llm_model=self.llm_config.get(MODEL_CONFIG_KEY)
|
|
545
525
|
or self.llm_config.get(MODEL_NAME_CONFIG_KEY),
|
|
546
|
-
llm_model_group_id=self.llm_config.get(MODEL_GROUP_ID_CONFIG_KEY),
|
|
547
526
|
citation_enabled=self.citation_enabled,
|
|
548
527
|
)
|
|
549
528
|
return self._create_prediction(
|
|
@@ -692,27 +671,12 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
692
671
|
**kwargs: Any,
|
|
693
672
|
) -> "EnterpriseSearchPolicy":
|
|
694
673
|
"""Loads a trained policy (see parent class for full docstring)."""
|
|
695
|
-
|
|
696
|
-
# Perform health checks for both LLM and embeddings client configs
|
|
697
|
-
cls._perform_health_checks(config, "enterprise_search_policy.load")
|
|
698
|
-
|
|
699
674
|
prompt_template = None
|
|
700
|
-
try:
|
|
701
|
-
with model_storage.read_from(resource) as path:
|
|
702
|
-
prompt_template = rasa.shared.utils.io.read_file(
|
|
703
|
-
path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
|
|
704
|
-
)
|
|
705
|
-
except (FileNotFoundError, FileIOException) as e:
|
|
706
|
-
logger.warning(
|
|
707
|
-
"enterprise_search_policy.load.failed", error=e, resource=resource.name
|
|
708
|
-
)
|
|
709
|
-
|
|
710
675
|
store_type = config.get(VECTOR_STORE_PROPERTY, {}).get(
|
|
711
676
|
VECTOR_STORE_TYPE_PROPERTY
|
|
712
677
|
)
|
|
713
678
|
|
|
714
679
|
embeddings = cls._create_plain_embedder(config)
|
|
715
|
-
|
|
716
680
|
logger.info("enterprise_search_policy.load", config=config)
|
|
717
681
|
if store_type == DEFAULT_VECTOR_STORE_TYPE:
|
|
718
682
|
# if a vector store is not specified,
|
|
@@ -730,6 +694,16 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
730
694
|
config_type=store_type,
|
|
731
695
|
embeddings=embeddings,
|
|
732
696
|
) # type: ignore
|
|
697
|
+
try:
|
|
698
|
+
with model_storage.read_from(resource) as path:
|
|
699
|
+
prompt_template = rasa.shared.utils.io.read_file(
|
|
700
|
+
path / ENTERPRISE_SEARCH_PROMPT_FILE_NAME
|
|
701
|
+
)
|
|
702
|
+
|
|
703
|
+
except (FileNotFoundError, FileNotFoundError) as e:
|
|
704
|
+
logger.warning(
|
|
705
|
+
"enterprise_search_policy.load.failed", error=e, resource=resource.name
|
|
706
|
+
)
|
|
733
707
|
|
|
734
708
|
return cls(
|
|
735
709
|
config,
|
|
@@ -771,23 +745,14 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
771
745
|
|
|
772
746
|
@classmethod
|
|
773
747
|
def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
|
|
774
|
-
"""Add a fingerprint of
|
|
748
|
+
"""Add a fingerprint of the knowledge base and prompt template for the graph."""
|
|
775
749
|
local_knowledge_data = cls._get_local_knowledge_data(config)
|
|
776
750
|
|
|
777
751
|
prompt_template = get_prompt_template(
|
|
778
752
|
config.get(PROMPT_CONFIG_KEY),
|
|
779
753
|
DEFAULT_ENTERPRISE_SEARCH_PROMPT_TEMPLATE,
|
|
780
754
|
)
|
|
781
|
-
|
|
782
|
-
llm_config = resolve_model_client_config(
|
|
783
|
-
config.get(LLM_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
784
|
-
)
|
|
785
|
-
embedding_config = resolve_model_client_config(
|
|
786
|
-
config.get(EMBEDDINGS_CONFIG_KEY), EnterpriseSearchPolicy.__name__
|
|
787
|
-
)
|
|
788
|
-
return deep_container_fingerprint(
|
|
789
|
-
[prompt_template, local_knowledge_data, llm_config, embedding_config]
|
|
790
|
-
)
|
|
755
|
+
return deep_container_fingerprint([prompt_template, local_knowledge_data])
|
|
791
756
|
|
|
792
757
|
@staticmethod
|
|
793
758
|
def post_process_citations(llm_answer: str) -> str:
|
|
@@ -879,27 +844,3 @@ class EnterpriseSearchPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Po
|
|
|
879
844
|
joined_sources = "\n".join(new_sources)
|
|
880
845
|
|
|
881
846
|
return joined_answer + joined_sources
|
|
882
|
-
|
|
883
|
-
@classmethod
|
|
884
|
-
def _perform_health_checks(
|
|
885
|
-
cls, config: Dict[Text, Any], log_source_method: str
|
|
886
|
-
) -> None:
|
|
887
|
-
# Perform health check of the LLM client config
|
|
888
|
-
llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
|
|
889
|
-
cls.perform_llm_health_check(
|
|
890
|
-
llm_config,
|
|
891
|
-
DEFAULT_LLM_CONFIG,
|
|
892
|
-
log_source_method,
|
|
893
|
-
EnterpriseSearchPolicy.__name__,
|
|
894
|
-
)
|
|
895
|
-
|
|
896
|
-
# Perform health check of the embeddings client config
|
|
897
|
-
embeddings_config = resolve_model_client_config(
|
|
898
|
-
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
899
|
-
)
|
|
900
|
-
cls.perform_embeddings_health_check(
|
|
901
|
-
embeddings_config,
|
|
902
|
-
DEFAULT_EMBEDDINGS_CONFIG,
|
|
903
|
-
log_source_method,
|
|
904
|
-
EnterpriseSearchPolicy.__name__,
|
|
905
|
-
)
|
|
@@ -330,27 +330,24 @@ def reset_scoped_slots(
|
|
|
330
330
|
events: List[Event] = []
|
|
331
331
|
|
|
332
332
|
not_resettable_slot_names = set()
|
|
333
|
-
flow_persistable_slots = current_flow.persisted_slots
|
|
334
333
|
|
|
335
334
|
for step in current_flow.steps_with_calls_resolved:
|
|
336
335
|
if isinstance(step, CollectInformationFlowStep):
|
|
337
336
|
# reset all slots scoped to the flow
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
_reset_slot(slot_name, tracker)
|
|
337
|
+
if step.reset_after_flow_ends:
|
|
338
|
+
_reset_slot(step.collect, tracker)
|
|
341
339
|
else:
|
|
342
|
-
not_resettable_slot_names.add(
|
|
340
|
+
not_resettable_slot_names.add(step.collect)
|
|
343
341
|
|
|
344
342
|
# slots set by the set slots step should be reset after the flow ends
|
|
345
343
|
# unless they are also used in a collect step where `reset_after_flow_ends`
|
|
346
|
-
# is set to `False`
|
|
344
|
+
# is set to `False`
|
|
347
345
|
resettable_set_slots = [
|
|
348
346
|
slot["key"]
|
|
349
347
|
for step in current_flow.steps_with_calls_resolved
|
|
350
348
|
if isinstance(step, SetSlotsFlowStep)
|
|
351
349
|
for slot in step.slots
|
|
352
350
|
if slot["key"] not in not_resettable_slot_names
|
|
353
|
-
and slot["key"] not in flow_persistable_slots
|
|
354
351
|
]
|
|
355
352
|
|
|
356
353
|
for name in resettable_set_slots:
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import importlib.resources
|
|
2
2
|
import math
|
|
3
|
+
import os
|
|
3
4
|
from dataclasses import dataclass, field
|
|
4
5
|
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Text, Tuple
|
|
5
6
|
|
|
@@ -18,7 +19,6 @@ from rasa.core.constants import (
|
|
|
18
19
|
UTTER_SOURCE_METADATA_KEY,
|
|
19
20
|
)
|
|
20
21
|
from rasa.core.policies.policy import Policy, PolicyPrediction, SupportedData
|
|
21
|
-
from rasa.dialogue_understanding.patterns.chitchat import FLOW_PATTERN_CHITCHAT
|
|
22
22
|
from rasa.dialogue_understanding.stack.frames import (
|
|
23
23
|
ChitChatStackFrame,
|
|
24
24
|
DialogueStackFrame,
|
|
@@ -32,6 +32,7 @@ from rasa.graph_components.providers.responses_provider import Responses
|
|
|
32
32
|
from rasa.shared.constants import (
|
|
33
33
|
REQUIRED_SLOTS_KEY,
|
|
34
34
|
EMBEDDINGS_CONFIG_KEY,
|
|
35
|
+
LLM_API_HEALTH_CHECK_ENV_VAR,
|
|
35
36
|
LLM_CONFIG_KEY,
|
|
36
37
|
MODEL_CONFIG_KEY,
|
|
37
38
|
MODEL_NAME_CONFIG_KEY,
|
|
@@ -39,10 +40,8 @@ from rasa.shared.constants import (
|
|
|
39
40
|
PROVIDER_CONFIG_KEY,
|
|
40
41
|
OPENAI_PROVIDER,
|
|
41
42
|
TIMEOUT_CONFIG_KEY,
|
|
42
|
-
MODEL_GROUP_ID_CONFIG_KEY,
|
|
43
43
|
)
|
|
44
44
|
from rasa.shared.core.constants import ACTION_LISTEN_NAME
|
|
45
|
-
from rasa.shared.core.constants import ACTION_TRIGGER_CHITCHAT
|
|
46
45
|
from rasa.shared.core.domain import KEY_RESPONSES_TEXT, Domain
|
|
47
46
|
from rasa.shared.core.events import (
|
|
48
47
|
ActionExecuted,
|
|
@@ -60,10 +59,6 @@ from rasa.shared.providers.embedding._langchain_embedding_client_adapter import
|
|
|
60
59
|
_LangchainEmbeddingClientAdapter,
|
|
61
60
|
)
|
|
62
61
|
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
63
|
-
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
64
|
-
EmbeddingsHealthCheckMixin,
|
|
65
|
-
)
|
|
66
|
-
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
67
62
|
from rasa.shared.utils.io import deep_container_fingerprint
|
|
68
63
|
from rasa.shared.utils.llm import (
|
|
69
64
|
AI,
|
|
@@ -74,12 +69,12 @@ from rasa.shared.utils.llm import (
|
|
|
74
69
|
combine_custom_and_default_config,
|
|
75
70
|
embedder_factory,
|
|
76
71
|
get_prompt_template,
|
|
72
|
+
llm_api_health_check,
|
|
77
73
|
llm_factory,
|
|
78
74
|
sanitize_message_for_prompt,
|
|
79
75
|
tracker_as_readable_transcript,
|
|
80
|
-
|
|
76
|
+
try_instantiate_llm_client,
|
|
81
77
|
)
|
|
82
|
-
from rasa.utils.log_utils import log_llm
|
|
83
78
|
from rasa.utils.ml_utils import (
|
|
84
79
|
extract_ai_response_examples,
|
|
85
80
|
extract_participant_messages_from_transcript,
|
|
@@ -88,6 +83,9 @@ from rasa.utils.ml_utils import (
|
|
|
88
83
|
persist_faiss_vector_store,
|
|
89
84
|
response_for_template,
|
|
90
85
|
)
|
|
86
|
+
from rasa.dialogue_understanding.patterns.chitchat import FLOW_PATTERN_CHITCHAT
|
|
87
|
+
from rasa.shared.core.constants import ACTION_TRIGGER_CHITCHAT
|
|
88
|
+
from rasa.utils.log_utils import log_llm
|
|
91
89
|
|
|
92
90
|
if TYPE_CHECKING:
|
|
93
91
|
from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
|
|
@@ -127,7 +125,6 @@ DEFAULT_INTENTLESS_PROMPT_TEMPLATE = importlib.resources.open_text(
|
|
|
127
125
|
).name
|
|
128
126
|
|
|
129
127
|
INTENTLESS_PROMPT_TEMPLATE_FILE_NAME = "intentless_policy_prompt.jinja2"
|
|
130
|
-
INTENTLESS_CONFIG_FILE_NAME = "config.json"
|
|
131
128
|
|
|
132
129
|
|
|
133
130
|
class RasaMLPolicyTrainingException(RasaCoreException):
|
|
@@ -377,7 +374,7 @@ def conversation_as_prompt(conversation: Conversation) -> str:
|
|
|
377
374
|
@DefaultV1Recipe.register(
|
|
378
375
|
DefaultV1Recipe.ComponentType.POLICY_WITH_END_TO_END_SUPPORT, is_trainable=True
|
|
379
376
|
)
|
|
380
|
-
class IntentlessPolicy(
|
|
377
|
+
class IntentlessPolicy(Policy):
|
|
381
378
|
"""Policy which uses a language model to generate the next action.
|
|
382
379
|
|
|
383
380
|
The policy uses the OpenAI API to generate the next action based on the
|
|
@@ -434,16 +431,6 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
|
434
431
|
"""Constructs a new Policy object."""
|
|
435
432
|
super().__init__(config, model_storage, resource, execution_context, featurizer)
|
|
436
433
|
|
|
437
|
-
# Resolve LLM config
|
|
438
|
-
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
439
|
-
self.config.get(LLM_CONFIG_KEY), IntentlessPolicy.__name__
|
|
440
|
-
)
|
|
441
|
-
|
|
442
|
-
# Resolve embeddings config
|
|
443
|
-
self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
444
|
-
self.config.get(EMBEDDINGS_CONFIG_KEY), IntentlessPolicy.__name__
|
|
445
|
-
)
|
|
446
|
-
|
|
447
434
|
self.nlu_abstention_threshold: float = self.config[NLU_ABSTENTION_THRESHOLD]
|
|
448
435
|
self.response_index = responses_docsearch
|
|
449
436
|
self.conversation_samples_index = samples_docsearch
|
|
@@ -460,16 +447,9 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
|
460
447
|
Returns:
|
|
461
448
|
The embedder.
|
|
462
449
|
"""
|
|
463
|
-
# Copy the config so original config is not modified
|
|
464
|
-
config.copy()
|
|
465
|
-
# Resolve config and instantiate the embedding client
|
|
466
|
-
config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
467
|
-
config.get(EMBEDDINGS_CONFIG_KEY), IntentlessPolicy.__name__
|
|
468
|
-
)
|
|
469
450
|
client = embedder_factory(
|
|
470
451
|
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
471
452
|
)
|
|
472
|
-
# Wrap the embedding client in the adapter
|
|
473
453
|
return _LangchainEmbeddingClientAdapter(client)
|
|
474
454
|
|
|
475
455
|
def embeddings_property(self, prop: str) -> Optional[str]:
|
|
@@ -510,8 +490,16 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
|
510
490
|
A policy must return its resource locator so that potential children nodes
|
|
511
491
|
can load the policy from the resource.
|
|
512
492
|
"""
|
|
513
|
-
|
|
514
|
-
|
|
493
|
+
llm_client = try_instantiate_llm_client(
|
|
494
|
+
self.config.get(LLM_CONFIG_KEY),
|
|
495
|
+
DEFAULT_LLM_CONFIG,
|
|
496
|
+
"intentless_policy.train",
|
|
497
|
+
IntentlessPolicy.__name__,
|
|
498
|
+
)
|
|
499
|
+
if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
|
|
500
|
+
llm_api_health_check(
|
|
501
|
+
llm_client, "intentless_policy.train", IntentlessPolicy.__name__
|
|
502
|
+
)
|
|
515
503
|
|
|
516
504
|
responses = filter_responses(responses, forms, flows or FlowsList([]))
|
|
517
505
|
telemetry.track_intentless_policy_train()
|
|
@@ -558,13 +546,9 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
|
558
546
|
embeddings_type=self.embeddings_property(PROVIDER_CONFIG_KEY),
|
|
559
547
|
embeddings_model=self.embeddings_property(MODEL_CONFIG_KEY)
|
|
560
548
|
or self.embeddings_property(MODEL_NAME_CONFIG_KEY),
|
|
561
|
-
embeddings_model_group_id=self.embeddings_property(
|
|
562
|
-
MODEL_GROUP_ID_CONFIG_KEY
|
|
563
|
-
),
|
|
564
549
|
llm_type=self.llm_property(PROVIDER_CONFIG_KEY),
|
|
565
550
|
llm_model=self.llm_property(MODEL_CONFIG_KEY)
|
|
566
551
|
or self.llm_property(MODEL_NAME_CONFIG_KEY),
|
|
567
|
-
llm_model_group_id=self.llm_property(MODEL_GROUP_ID_CONFIG_KEY),
|
|
568
552
|
)
|
|
569
553
|
|
|
570
554
|
self.persist()
|
|
@@ -580,9 +564,6 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
|
580
564
|
rasa.shared.utils.io.write_text_file(
|
|
581
565
|
self.prompt_template, path / INTENTLESS_PROMPT_TEMPLATE_FILE_NAME
|
|
582
566
|
)
|
|
583
|
-
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
584
|
-
path / INTENTLESS_CONFIG_FILE_NAME, self.config
|
|
585
|
-
)
|
|
586
567
|
|
|
587
568
|
async def predict_action_probabilities(
|
|
588
569
|
self,
|
|
@@ -644,13 +625,9 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
|
644
625
|
embeddings_type=self.embeddings_property(PROVIDER_CONFIG_KEY),
|
|
645
626
|
embeddings_model=self.embeddings_property(MODEL_CONFIG_KEY)
|
|
646
627
|
or self.embeddings_property(MODEL_NAME_CONFIG_KEY),
|
|
647
|
-
embeddings_model_group_id=self.embeddings_property(
|
|
648
|
-
MODEL_GROUP_ID_CONFIG_KEY
|
|
649
|
-
),
|
|
650
628
|
llm_type=self.llm_property(PROVIDER_CONFIG_KEY),
|
|
651
629
|
llm_model=self.llm_property(MODEL_CONFIG_KEY)
|
|
652
630
|
or self.llm_property(MODEL_NAME_CONFIG_KEY),
|
|
653
|
-
llm_model_group_id=self.llm_property(MODEL_GROUP_ID_CONFIG_KEY),
|
|
654
631
|
score=score,
|
|
655
632
|
)
|
|
656
633
|
|
|
@@ -674,7 +651,7 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
|
674
651
|
history: str,
|
|
675
652
|
) -> Optional[str]:
|
|
676
653
|
"""Make the llm call to generate an answer."""
|
|
677
|
-
llm = llm_factory(self.config
|
|
654
|
+
llm = llm_factory(self.config[LLM_CONFIG_KEY], DEFAULT_LLM_CONFIG)
|
|
678
655
|
inputs = {
|
|
679
656
|
"conversations": conversation_samples,
|
|
680
657
|
"responses": response_examples,
|
|
@@ -948,10 +925,6 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
|
948
925
|
**kwargs: Any,
|
|
949
926
|
) -> "IntentlessPolicy":
|
|
950
927
|
"""Loads a trained policy (see parent class for full docstring)."""
|
|
951
|
-
|
|
952
|
-
# Perform health checks of both LLM and embeddings client configs
|
|
953
|
-
cls._perform_health_checks(config, "intentless_policy.load")
|
|
954
|
-
|
|
955
928
|
responses_docsearch = None
|
|
956
929
|
samples_docsearch = None
|
|
957
930
|
prompt_template = None
|
|
@@ -972,6 +945,7 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
|
972
945
|
prompt_template = rasa.shared.utils.io.read_file(
|
|
973
946
|
path / INTENTLESS_PROMPT_TEMPLATE_FILE_NAME
|
|
974
947
|
)
|
|
948
|
+
|
|
975
949
|
except (ValueError, FileNotFoundError, FileIOException) as e:
|
|
976
950
|
structlogger.warning(
|
|
977
951
|
"intentless_policy.load.failed", error=e, resource_name=resource.name
|
|
@@ -989,43 +963,9 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
|
989
963
|
|
|
990
964
|
@classmethod
|
|
991
965
|
def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
|
|
992
|
-
"""Add a fingerprint of
|
|
966
|
+
"""Add a fingerprint of the knowledge base for the graph."""
|
|
993
967
|
prompt_template = get_prompt_template(
|
|
994
968
|
config.get(PROMPT_CONFIG_KEY),
|
|
995
969
|
DEFAULT_INTENTLESS_PROMPT_TEMPLATE,
|
|
996
970
|
)
|
|
997
|
-
|
|
998
|
-
llm_config = resolve_model_client_config(
|
|
999
|
-
config.get(LLM_CONFIG_KEY), IntentlessPolicy.__name__
|
|
1000
|
-
)
|
|
1001
|
-
embedding_config = resolve_model_client_config(
|
|
1002
|
-
config.get(EMBEDDINGS_CONFIG_KEY), IntentlessPolicy.__name__
|
|
1003
|
-
)
|
|
1004
|
-
|
|
1005
|
-
return deep_container_fingerprint(
|
|
1006
|
-
[prompt_template, llm_config, embedding_config]
|
|
1007
|
-
)
|
|
1008
|
-
|
|
1009
|
-
@classmethod
|
|
1010
|
-
def _perform_health_checks(
|
|
1011
|
-
cls, config: Dict[Text, Any], log_source_method: str
|
|
1012
|
-
) -> None:
|
|
1013
|
-
# Perform health check of the LLM client config
|
|
1014
|
-
llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
|
|
1015
|
-
cls.perform_llm_health_check(
|
|
1016
|
-
llm_config,
|
|
1017
|
-
DEFAULT_LLM_CONFIG,
|
|
1018
|
-
log_source_method,
|
|
1019
|
-
IntentlessPolicy.__name__,
|
|
1020
|
-
)
|
|
1021
|
-
|
|
1022
|
-
# Perform health check of the embeddings client config
|
|
1023
|
-
embeddings_config = resolve_model_client_config(
|
|
1024
|
-
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
1025
|
-
)
|
|
1026
|
-
cls.perform_embeddings_health_check(
|
|
1027
|
-
embeddings_config,
|
|
1028
|
-
DEFAULT_EMBEDDINGS_CONFIG,
|
|
1029
|
-
log_source_method,
|
|
1030
|
-
IntentlessPolicy.__name__,
|
|
1031
|
-
)
|
|
971
|
+
return deep_container_fingerprint(prompt_template)
|