rasa-pro 3.13.0.dev2__py3-none-any.whl → 3.13.0.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +3 -1
- rasa/cli/inspect.py +8 -4
- rasa/cli/project_templates/default/config.yml +5 -32
- rasa/cli/project_templates/{calm → default}/e2e_tests/cancelations/user_cancels_during_a_correction.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/corrections/user_corrects_contact_handle.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/corrections/user_corrects_contact_name.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/happy_paths/user_lists_contacts.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/happy_paths/user_removes_contact.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/happy_paths/user_removes_contact_from_list.yml +1 -1
- rasa/cli/project_templates/default/endpoints.yml +18 -2
- rasa/cli/run.py +10 -6
- rasa/cli/scaffold.py +3 -4
- rasa/cli/studio/download.py +1 -1
- rasa/cli/studio/upload.py +0 -6
- rasa/cli/utils.py +7 -0
- rasa/core/channels/channel.py +93 -0
- rasa/core/channels/inspector/dist/assets/{arc-c7691751.js → arc-9f75cc3b.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{blockDiagram-38ab4fdb-ab99dff7.js → blockDiagram-38ab4fdb-7f34db23.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-3d4e48cf-08c35a6b.js → c4Diagram-3d4e48cf-948bab2c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/channel-dfa68278.js +1 -0
- rasa/core/channels/inspector/dist/assets/{classDiagram-70f12bd4-9e9c71c9.js → classDiagram-70f12bd4-53b0dd0e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-f2320105-15e7e2bf.js → classDiagram-v2-f2320105-fdf789e7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/clone-edb7f119.js +1 -0
- rasa/core/channels/inspector/dist/assets/{createText-2e5e7dd3-9c105cb1.js → createText-2e5e7dd3-87c4ece5.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-e0da2a9e-77e89e48.js → edges-e0da2a9e-5a8b0749.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9861fffd-7a011646.js → erDiagram-9861fffd-66da90e2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-956e92f1-b6f105ac.js → flowDb-956e92f1-10044f05.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-66a62f08-ce4f18c2.js → flowDiagram-66a62f08-f338f66a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-65e7c670.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-4a651766-cb5f6da4.js → flowchart-elk-definition-4a651766-b13140aa.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-c361ad54-e4d19e28.js → ganttDiagram-c361ad54-f2b4a55a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-72cf32ee-727b1c33.js → gitGraphDiagram-72cf32ee-dedc298d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{graph-6e2ab9a7.js → graph-4ede11ff.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-3862675e-84ec700f.js → index-3862675e-65549d37.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-098a1a24.js → index-3a23e736.js} +142 -129
- rasa/core/channels/inspector/dist/assets/{infoDiagram-f8f76790-78dda442.js → infoDiagram-f8f76790-65439671.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-49397b02-f1cc6dd1.js → journeyDiagram-49397b02-56d03d98.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-d98dcd0c.js → layout-dd48f7f4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-838e3d82.js → line-1569ad2c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-eae72406.js → linear-48bf4935.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-fc14e90a-c96fd84b.js → mindmap-definition-fc14e90a-688504c1.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-8a3498a8-c936d4e2.js → pieDiagram-8a3498a8-78b6d7e6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-120e2f19-b338eb8f.js → quadrantDiagram-120e2f19-048b84b3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-deff3bca-c6b6c0d5.js → requirementDiagram-deff3bca-dd67f107.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-04a897e0-b9372e19.js → sankeyDiagram-04a897e0-8128436e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-704730f1-479e0a3f.js → sequenceDiagram-704730f1-1a0d1461.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-587899a1-fd26eebc.js → stateDiagram-587899a1-46d388ed.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-d93cdb3a-3233e0ae.js → stateDiagram-v2-d93cdb3a-ea42951a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-6aaf32cf-1fdd392b.js → styles-6aaf32cf-7427ed0c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9a916d00-6d7bfa1b.js → styles-9a916d00-ff5e5a16.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-c10674c1-f86aab11.js → styles-c10674c1-7b3680cf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-08f97a94-e3e49d7a.js → svgDrawCommon-08f97a94-f860f2ad.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-85554ec2-6fe08b4d.js → timeline-definition-85554ec2-2eebf0c8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-e933f94c-c2e06fd6.js → xychartDiagram-e933f94c-5d7f4e96.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/App.tsx +3 -2
- rasa/core/channels/inspector/src/components/Chat.tsx +23 -2
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +2 -5
- rasa/core/channels/inspector/src/helpers/conversation.ts +16 -0
- rasa/core/channels/inspector/src/types.ts +1 -1
- rasa/core/channels/voice_ready/audiocodes.py +41 -15
- rasa/core/channels/voice_ready/jambonz.py +25 -5
- rasa/core/channels/voice_ready/jambonz_protocol.py +4 -0
- rasa/core/channels/voice_ready/twilio_voice.py +48 -1
- rasa/core/channels/voice_stream/tts/azure.py +11 -2
- rasa/core/channels/voice_stream/twilio_media_streams.py +101 -26
- rasa/core/channels/voice_stream/voice_channel.py +28 -2
- rasa/core/concurrent_lock_store.py +24 -10
- rasa/core/information_retrieval/faiss.py +7 -68
- rasa/core/information_retrieval/information_retrieval.py +2 -40
- rasa/core/information_retrieval/milvus.py +2 -7
- rasa/core/information_retrieval/qdrant.py +2 -7
- rasa/core/lock_store.py +151 -60
- rasa/core/nlg/contextual_response_rephraser.py +3 -0
- rasa/core/policies/enterprise_search_policy.py +310 -61
- rasa/core/policies/intentless_policy.py +3 -0
- rasa/dialogue_understanding/coexistence/llm_based_router.py +8 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +2 -2
- rasa/dialogue_understanding/generator/command_parser.py +1 -1
- rasa/dialogue_understanding/generator/flow_retrieval.py +1 -4
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +1 -2
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +13 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_template.jinja2 +1 -1
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +1 -1
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +2 -24
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +22 -17
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +27 -12
- rasa/dialogue_understanding_test/du_test_case.py +16 -8
- rasa/dialogue_understanding_test/io.py +8 -13
- rasa/e2e_test/utils/validation.py +3 -3
- rasa/engine/recipes/default_components.py +0 -2
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +3 -0
- rasa/plugin.py +0 -3
- rasa/shared/constants.py +1 -0
- rasa/shared/core/domain.py +165 -11
- rasa/shared/core/flows/flow.py +155 -131
- rasa/shared/core/flows/flow_step.py +19 -3
- rasa/shared/core/flows/flow_step_links.py +15 -0
- rasa/shared/core/flows/flow_step_sequence.py +6 -0
- rasa/shared/core/flows/nlu_trigger.py +13 -0
- rasa/shared/core/flows/steps/action.py +7 -4
- rasa/shared/core/flows/steps/call.py +11 -4
- rasa/shared/core/flows/steps/collect.py +27 -6
- rasa/shared/core/flows/steps/internal.py +6 -1
- rasa/shared/core/flows/steps/link.py +7 -4
- rasa/shared/core/flows/steps/no_operation.py +7 -4
- rasa/shared/core/flows/steps/set_slots.py +8 -4
- rasa/shared/core/flows/yaml_flows_io.py +106 -5
- rasa/shared/importers/importer.py +8 -0
- rasa/shared/providers/_utils.py +83 -0
- rasa/shared/providers/llm/_base_litellm_client.py +6 -3
- rasa/shared/providers/llm/azure_openai_llm_client.py +6 -68
- rasa/shared/providers/router/_base_litellm_router_client.py +53 -1
- rasa/shared/utils/common.py +42 -0
- rasa/shared/utils/constants.py +3 -0
- rasa/shared/utils/llm.py +70 -24
- rasa/studio/download/domains.py +49 -0
- rasa/studio/download/download.py +439 -0
- rasa/studio/download/flows.py +359 -0
- rasa/studio/results_logger.py +6 -1
- rasa/studio/upload.py +69 -5
- rasa/tracing/instrumentation/attribute_extractors.py +7 -10
- rasa/tracing/instrumentation/instrumentation.py +12 -12
- rasa/utils/common.py +36 -0
- rasa/utils/endpoints.py +22 -1
- rasa/utils/licensing.py +1 -1
- rasa/validator.py +1 -2
- rasa/version.py +1 -1
- {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev5.dist-info}/METADATA +7 -7
- {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev5.dist-info}/RECORD +149 -166
- rasa/cli/project_templates/calm/config.yml +0 -10
- rasa/cli/project_templates/calm/credentials.yml +0 -33
- rasa/cli/project_templates/calm/endpoints.yml +0 -58
- rasa/cli/project_templates/default/actions/actions.py +0 -27
- rasa/cli/project_templates/default/data/nlu.yml +0 -91
- rasa/cli/project_templates/default/data/rules.yml +0 -13
- rasa/cli/project_templates/default/data/stories.yml +0 -30
- rasa/cli/project_templates/default/domain.yml +0 -34
- rasa/cli/project_templates/default/tests/test_stories.yml +0 -91
- rasa/core/channels/inspector/dist/assets/channel-11268142.js +0 -1
- rasa/core/channels/inspector/dist/assets/clone-ff7f2ce7.js +0 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-cba7ae20.js +0 -1
- rasa/document_retrieval/__init__.py +0 -0
- rasa/document_retrieval/constants.py +0 -32
- rasa/document_retrieval/document_post_processor.py +0 -351
- rasa/document_retrieval/document_post_processor_prompt_template.jinja2 +0 -0
- rasa/document_retrieval/document_retriever.py +0 -333
- rasa/document_retrieval/knowledge_base_connectors/__init__.py +0 -0
- rasa/document_retrieval/knowledge_base_connectors/api_connector.py +0 -39
- rasa/document_retrieval/knowledge_base_connectors/knowledge_base_connector.py +0 -34
- rasa/document_retrieval/knowledge_base_connectors/vector_store_connector.py +0 -226
- rasa/document_retrieval/query_rewriter.py +0 -234
- rasa/document_retrieval/query_rewriter_prompt_template.jinja2 +0 -8
- rasa/studio/download.py +0 -489
- /rasa/cli/project_templates/{calm → default}/actions/action_template.py +0 -0
- /rasa/cli/project_templates/{calm → default}/actions/add_contact.py +0 -0
- /rasa/cli/project_templates/{calm → default}/actions/db.py +0 -0
- /rasa/cli/project_templates/{calm → default}/actions/list_contacts.py +0 -0
- /rasa/cli/project_templates/{calm → default}/actions/remove_contact.py +0 -0
- /rasa/cli/project_templates/{calm → default}/data/flows/add_contact.yml +0 -0
- /rasa/cli/project_templates/{calm → default}/data/flows/list_contacts.yml +0 -0
- /rasa/cli/project_templates/{calm → default}/data/flows/remove_contact.yml +0 -0
- /rasa/cli/project_templates/{calm → default}/db/contacts.json +0 -0
- /rasa/cli/project_templates/{calm → default}/domain/add_contact.yml +0 -0
- /rasa/cli/project_templates/{calm → default}/domain/list_contacts.yml +0 -0
- /rasa/cli/project_templates/{calm → default}/domain/remove_contact.yml +0 -0
- /rasa/cli/project_templates/{calm → default}/domain/shared.yml +0 -0
- /rasa/{cli/project_templates/calm/actions → studio/download}/__init__.py +0 -0
- {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev5.dist-info}/NOTICE +0 -0
- {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev5.dist-info}/WHEEL +0 -0
- {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev5.dist-info}/entry_points.txt +0 -0
|
@@ -1,39 +0,0 @@
|
|
|
1
|
-
from typing import Any, Dict, Optional
|
|
2
|
-
|
|
3
|
-
from rasa.core.information_retrieval import SearchResultList
|
|
4
|
-
from rasa.document_retrieval.knowledge_base_connectors.knowledge_base_connector import (
|
|
5
|
-
KnowledgeBaseConnector,
|
|
6
|
-
)
|
|
7
|
-
from rasa.engine.storage.resource import Resource
|
|
8
|
-
from rasa.engine.storage.storage import ModelStorage
|
|
9
|
-
from rasa.shared.core.trackers import DialogueStateTracker
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class APIConnector(KnowledgeBaseConnector):
|
|
13
|
-
def __init__(self, config: Dict[str, Any]) -> None:
|
|
14
|
-
self.config = config
|
|
15
|
-
|
|
16
|
-
@classmethod
|
|
17
|
-
def load(
|
|
18
|
-
cls,
|
|
19
|
-
config: Dict[str, Any],
|
|
20
|
-
model_storage: ModelStorage,
|
|
21
|
-
resource: Resource,
|
|
22
|
-
**kwargs: Any,
|
|
23
|
-
) -> "APIConnector":
|
|
24
|
-
# TODO implement
|
|
25
|
-
return APIConnector(config)
|
|
26
|
-
|
|
27
|
-
async def retrieve_documents(
|
|
28
|
-
self,
|
|
29
|
-
search_query: str,
|
|
30
|
-
k: int,
|
|
31
|
-
threshold: float,
|
|
32
|
-
tracker: Optional[DialogueStateTracker],
|
|
33
|
-
) -> Optional[SearchResultList]:
|
|
34
|
-
# TODO implement
|
|
35
|
-
return SearchResultList(results=[], metadata={})
|
|
36
|
-
|
|
37
|
-
def connect_or_raise(self) -> None:
|
|
38
|
-
# TODO implement
|
|
39
|
-
return None
|
|
@@ -1,34 +0,0 @@
|
|
|
1
|
-
from abc import ABC, abstractmethod
|
|
2
|
-
from typing import Any, Dict, Optional
|
|
3
|
-
|
|
4
|
-
from rasa.core.information_retrieval import SearchResultList
|
|
5
|
-
from rasa.engine.storage.resource import Resource
|
|
6
|
-
from rasa.engine.storage.storage import ModelStorage
|
|
7
|
-
from rasa.shared.core.trackers import DialogueStateTracker
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class KnowledgeBaseConnector(ABC):
|
|
11
|
-
@abstractmethod
|
|
12
|
-
def connect_or_raise(self) -> None:
|
|
13
|
-
pass
|
|
14
|
-
|
|
15
|
-
@abstractmethod
|
|
16
|
-
async def retrieve_documents(
|
|
17
|
-
self,
|
|
18
|
-
search_query: str,
|
|
19
|
-
k: int,
|
|
20
|
-
threshold: float,
|
|
21
|
-
tracker: Optional[DialogueStateTracker],
|
|
22
|
-
) -> Optional[SearchResultList]:
|
|
23
|
-
pass
|
|
24
|
-
|
|
25
|
-
@classmethod
|
|
26
|
-
@abstractmethod
|
|
27
|
-
def load(
|
|
28
|
-
cls,
|
|
29
|
-
config: Dict[str, Any],
|
|
30
|
-
model_storage: ModelStorage,
|
|
31
|
-
resource: Resource,
|
|
32
|
-
**kwargs: Any,
|
|
33
|
-
) -> "KnowledgeBaseConnector":
|
|
34
|
-
pass
|
|
@@ -1,226 +0,0 @@
|
|
|
1
|
-
import copy
|
|
2
|
-
from enum import Enum
|
|
3
|
-
from typing import TYPE_CHECKING, Any, Dict, Optional
|
|
4
|
-
|
|
5
|
-
import structlog
|
|
6
|
-
|
|
7
|
-
from rasa.core.information_retrieval import (
|
|
8
|
-
InformationRetrieval,
|
|
9
|
-
InformationRetrievalException,
|
|
10
|
-
SearchResultList,
|
|
11
|
-
create_from_endpoint_config,
|
|
12
|
-
)
|
|
13
|
-
from rasa.core.information_retrieval.faiss import FAISS_Store
|
|
14
|
-
from rasa.document_retrieval.constants import (
|
|
15
|
-
DEFAULT_EMBEDDINGS_CONFIG,
|
|
16
|
-
DEFAULT_VECTOR_STORE,
|
|
17
|
-
DEFAULT_VECTOR_STORE_TYPE,
|
|
18
|
-
VECTOR_STORE_CONFIG_KEY,
|
|
19
|
-
VECTOR_STORE_TYPE_CONFIG_KEY,
|
|
20
|
-
)
|
|
21
|
-
from rasa.document_retrieval.knowledge_base_connectors.knowledge_base_connector import (
|
|
22
|
-
KnowledgeBaseConnector,
|
|
23
|
-
)
|
|
24
|
-
from rasa.engine.storage.resource import Resource
|
|
25
|
-
from rasa.engine.storage.storage import ModelStorage
|
|
26
|
-
from rasa.shared.constants import EMBEDDINGS_CONFIG_KEY
|
|
27
|
-
from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity
|
|
28
|
-
from rasa.shared.exceptions import RasaException
|
|
29
|
-
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
30
|
-
_LangchainEmbeddingClientAdapter,
|
|
31
|
-
)
|
|
32
|
-
from rasa.shared.utils.health_check.embeddings_health_check_mixin import (
|
|
33
|
-
EmbeddingsHealthCheckMixin,
|
|
34
|
-
)
|
|
35
|
-
from rasa.shared.utils.health_check.health_check import perform_embeddings_health_check
|
|
36
|
-
from rasa.shared.utils.llm import embedder_factory, resolve_model_client_config
|
|
37
|
-
|
|
38
|
-
if TYPE_CHECKING:
|
|
39
|
-
from langchain.schema.embeddings import Embeddings
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
structlogger = structlog.get_logger()
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
class VectorStoreConnectionError(RasaException):
|
|
46
|
-
"""Exception raised for errors in connecting to the vector store."""
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
class VectorStoreConfigurationError(RasaException):
|
|
50
|
-
"""Exception raised for errors in vector store configuration."""
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
class VectorStoreType(Enum):
|
|
54
|
-
FAISS = "FAISS"
|
|
55
|
-
QDRANT = "QDRANT"
|
|
56
|
-
MILVUS = "MILVUS"
|
|
57
|
-
|
|
58
|
-
def __str__(self) -> str:
|
|
59
|
-
return self.value
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
class VectorStoreConnector(KnowledgeBaseConnector, EmbeddingsHealthCheckMixin):
|
|
63
|
-
def __init__(
|
|
64
|
-
self,
|
|
65
|
-
config: Dict[str, Any],
|
|
66
|
-
model_storage: ModelStorage,
|
|
67
|
-
resource: Resource,
|
|
68
|
-
vector_store: Optional[InformationRetrieval] = None,
|
|
69
|
-
) -> None:
|
|
70
|
-
self.config = config
|
|
71
|
-
self.vector_store_type = config.get(VECTOR_STORE_CONFIG_KEY, {}).get(
|
|
72
|
-
VECTOR_STORE_TYPE_CONFIG_KEY
|
|
73
|
-
)
|
|
74
|
-
|
|
75
|
-
# Vector store object and configuration
|
|
76
|
-
self.vector_store = vector_store
|
|
77
|
-
self.vector_store_config = self.config.get(
|
|
78
|
-
VECTOR_STORE_CONFIG_KEY, DEFAULT_VECTOR_STORE
|
|
79
|
-
)
|
|
80
|
-
|
|
81
|
-
# Embeddings configuration for encoding the search query
|
|
82
|
-
self.embeddings_config = (
|
|
83
|
-
self.config[EMBEDDINGS_CONFIG_KEY] or DEFAULT_EMBEDDINGS_CONFIG
|
|
84
|
-
)
|
|
85
|
-
|
|
86
|
-
self._model_storage = model_storage
|
|
87
|
-
self._resource = resource
|
|
88
|
-
|
|
89
|
-
@classmethod
|
|
90
|
-
def _create_plain_embedder(cls, config: Dict[str, Any]) -> "Embeddings":
|
|
91
|
-
"""Creates an embedder based on the given configuration.
|
|
92
|
-
|
|
93
|
-
Returns:
|
|
94
|
-
The embedder.
|
|
95
|
-
"""
|
|
96
|
-
# Copy the config so original config is not modified
|
|
97
|
-
config = copy.deepcopy(config)
|
|
98
|
-
# Resolve config and instantiate the embedding client
|
|
99
|
-
config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
100
|
-
config.get(EMBEDDINGS_CONFIG_KEY), VectorStoreConnector.__name__
|
|
101
|
-
)
|
|
102
|
-
client = embedder_factory(
|
|
103
|
-
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
104
|
-
)
|
|
105
|
-
# Wrap the embedding client in the adapter
|
|
106
|
-
return _LangchainEmbeddingClientAdapter(client)
|
|
107
|
-
|
|
108
|
-
@classmethod
|
|
109
|
-
def load(
|
|
110
|
-
cls,
|
|
111
|
-
config: Dict[str, Any],
|
|
112
|
-
model_storage: ModelStorage,
|
|
113
|
-
resource: Resource,
|
|
114
|
-
**kwargs: Any,
|
|
115
|
-
) -> "VectorStoreConnector":
|
|
116
|
-
# Perform health check on the resolved embeddings client config
|
|
117
|
-
embedding_config = resolve_model_client_config(
|
|
118
|
-
config.get(EMBEDDINGS_CONFIG_KEY, {})
|
|
119
|
-
)
|
|
120
|
-
perform_embeddings_health_check(
|
|
121
|
-
embedding_config,
|
|
122
|
-
DEFAULT_EMBEDDINGS_CONFIG,
|
|
123
|
-
"vector_store_connector.load",
|
|
124
|
-
VectorStoreConnector.__name__,
|
|
125
|
-
)
|
|
126
|
-
|
|
127
|
-
store_type = config.get(VECTOR_STORE_CONFIG_KEY, {}).get(
|
|
128
|
-
VECTOR_STORE_TYPE_CONFIG_KEY
|
|
129
|
-
)
|
|
130
|
-
embeddings = cls._create_plain_embedder(config)
|
|
131
|
-
|
|
132
|
-
structlogger.info("vector_store_connector.load", config=config)
|
|
133
|
-
if store_type == VectorStoreType.FAISS.value:
|
|
134
|
-
# if a vector store is not specified,
|
|
135
|
-
# default to using FAISS with the index stored in the model
|
|
136
|
-
# TODO figure out a way to get path without context manager
|
|
137
|
-
with model_storage.read_from(resource) as path:
|
|
138
|
-
vector_store = FAISS_Store(
|
|
139
|
-
embeddings=embeddings,
|
|
140
|
-
index_path=path,
|
|
141
|
-
docs_folder=None,
|
|
142
|
-
create_index=False,
|
|
143
|
-
)
|
|
144
|
-
else:
|
|
145
|
-
vector_store = create_from_endpoint_config(
|
|
146
|
-
config_type=store_type,
|
|
147
|
-
embeddings=embeddings,
|
|
148
|
-
) # type: ignore
|
|
149
|
-
|
|
150
|
-
return cls(
|
|
151
|
-
config=config,
|
|
152
|
-
model_storage=model_storage,
|
|
153
|
-
resource=resource,
|
|
154
|
-
vector_store=vector_store,
|
|
155
|
-
)
|
|
156
|
-
|
|
157
|
-
def connect_or_raise(self) -> None:
|
|
158
|
-
"""Connects to the vector store or raises an exception.
|
|
159
|
-
|
|
160
|
-
Raise exceptions for the following cases:
|
|
161
|
-
- The configuration is not specified
|
|
162
|
-
- Unable to connect to the vector store
|
|
163
|
-
|
|
164
|
-
Args:
|
|
165
|
-
endpoints: Endpoints configuration.
|
|
166
|
-
"""
|
|
167
|
-
if self.vector_store_type == VectorStoreType.FAISS.value:
|
|
168
|
-
return
|
|
169
|
-
from rasa.core.utils import AvailableEndpoints
|
|
170
|
-
|
|
171
|
-
endpoints = AvailableEndpoints.get_instance()
|
|
172
|
-
|
|
173
|
-
config = endpoints.vector_store if endpoints else None
|
|
174
|
-
store_type = self.config.get(VECTOR_STORE_CONFIG_KEY, {}).get(
|
|
175
|
-
VECTOR_STORE_TYPE_CONFIG_KEY
|
|
176
|
-
)
|
|
177
|
-
|
|
178
|
-
if config is None and store_type != DEFAULT_VECTOR_STORE_TYPE:
|
|
179
|
-
structlogger.error("vector_store_connector._connect_or_raise.no_config")
|
|
180
|
-
raise VectorStoreConfigurationError(
|
|
181
|
-
"""No vector store specified. Please specify a vector
|
|
182
|
-
store in the endpoints configuration."""
|
|
183
|
-
)
|
|
184
|
-
try:
|
|
185
|
-
self.vector_store.connect(config) # type: ignore
|
|
186
|
-
except Exception as e:
|
|
187
|
-
structlogger.error(
|
|
188
|
-
"vector_store_connector._connect_or_raise.connect_error",
|
|
189
|
-
error=e,
|
|
190
|
-
config=config,
|
|
191
|
-
)
|
|
192
|
-
raise VectorStoreConnectionError(
|
|
193
|
-
f"Unable to connect to the vector store. Error: {e}"
|
|
194
|
-
)
|
|
195
|
-
|
|
196
|
-
async def retrieve_documents(
|
|
197
|
-
self,
|
|
198
|
-
search_query: str,
|
|
199
|
-
k: int,
|
|
200
|
-
threshold: float,
|
|
201
|
-
tracker: Optional[DialogueStateTracker],
|
|
202
|
-
) -> Optional[SearchResultList]:
|
|
203
|
-
if self.vector_store is None:
|
|
204
|
-
return None
|
|
205
|
-
|
|
206
|
-
try:
|
|
207
|
-
self.connect_or_raise()
|
|
208
|
-
except (VectorStoreConfigurationError, VectorStoreConnectionError) as e:
|
|
209
|
-
structlogger.error("vector_store_connector.connection_error", error=e)
|
|
210
|
-
return None
|
|
211
|
-
|
|
212
|
-
if tracker is not None:
|
|
213
|
-
tracker_state = tracker.current_state(EventVerbosity.AFTER_RESTART)
|
|
214
|
-
else:
|
|
215
|
-
tracker_state = {}
|
|
216
|
-
|
|
217
|
-
try:
|
|
218
|
-
return await self.vector_store.search(
|
|
219
|
-
query=search_query,
|
|
220
|
-
threshold=threshold,
|
|
221
|
-
tracker_state=tracker_state,
|
|
222
|
-
k=k,
|
|
223
|
-
)
|
|
224
|
-
except InformationRetrievalException as e:
|
|
225
|
-
structlogger.error("vector_store.search_error", error=e)
|
|
226
|
-
return None
|
|
@@ -1,234 +0,0 @@
|
|
|
1
|
-
import importlib.resources
|
|
2
|
-
from enum import Enum
|
|
3
|
-
from typing import Any, Dict, Optional
|
|
4
|
-
|
|
5
|
-
import structlog
|
|
6
|
-
from jinja2 import Template
|
|
7
|
-
|
|
8
|
-
import rasa.shared.utils.io
|
|
9
|
-
from rasa.engine.storage.resource import Resource
|
|
10
|
-
from rasa.engine.storage.storage import ModelStorage
|
|
11
|
-
from rasa.shared.constants import (
|
|
12
|
-
LLM_CONFIG_KEY,
|
|
13
|
-
MODEL_CONFIG_KEY,
|
|
14
|
-
OPENAI_PROVIDER,
|
|
15
|
-
PROMPT_TEMPLATE_CONFIG_KEY,
|
|
16
|
-
PROVIDER_CONFIG_KEY,
|
|
17
|
-
TEXT,
|
|
18
|
-
TIMEOUT_CONFIG_KEY,
|
|
19
|
-
)
|
|
20
|
-
from rasa.shared.core.trackers import DialogueStateTracker
|
|
21
|
-
from rasa.shared.exceptions import FileIOException, ProviderClientAPIException
|
|
22
|
-
from rasa.shared.nlu.training_data.message import Message
|
|
23
|
-
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
24
|
-
from rasa.shared.providers.llm.llm_response import LLMResponse
|
|
25
|
-
from rasa.shared.utils.health_check.health_check import perform_llm_health_check
|
|
26
|
-
from rasa.shared.utils.health_check.llm_health_check_mixin import (
|
|
27
|
-
LLMHealthCheckMixin,
|
|
28
|
-
)
|
|
29
|
-
from rasa.shared.utils.llm import (
|
|
30
|
-
AI,
|
|
31
|
-
DEFAULT_OPENAI_GENERATE_MODEL_NAME,
|
|
32
|
-
DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
|
|
33
|
-
USER,
|
|
34
|
-
get_prompt_template,
|
|
35
|
-
llm_factory,
|
|
36
|
-
resolve_model_client_config,
|
|
37
|
-
sanitize_message_for_prompt,
|
|
38
|
-
tracker_as_readable_transcript,
|
|
39
|
-
)
|
|
40
|
-
|
|
41
|
-
QUERY_REWRITER_PROMPT_FILE_NAME = "query_rewriter_prompt_template.jinja2"
|
|
42
|
-
MAX_TURNS = "max_turns"
|
|
43
|
-
DEFAULT_LLM_CONFIG = {
|
|
44
|
-
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
45
|
-
MODEL_CONFIG_KEY: DEFAULT_OPENAI_GENERATE_MODEL_NAME,
|
|
46
|
-
"temperature": 0.3,
|
|
47
|
-
"max_tokens": DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
|
|
48
|
-
TIMEOUT_CONFIG_KEY: 5,
|
|
49
|
-
}
|
|
50
|
-
DEFAULT_QUERY_REWRITER_PROMPT_TEMPLATE = importlib.resources.read_text(
|
|
51
|
-
"rasa.document_retrieval",
|
|
52
|
-
"query_rewriter_prompt_template.jinja2",
|
|
53
|
-
)
|
|
54
|
-
|
|
55
|
-
TYPE_CONFIG_KEY = "type"
|
|
56
|
-
|
|
57
|
-
structlogger = structlog.get_logger()
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
class QueryRewritingType(Enum):
|
|
61
|
-
PLAIN = "PLAIN"
|
|
62
|
-
CONCATENATED_TURNS = "CONCATENATED_TURNS"
|
|
63
|
-
REPHRASE = "REPHRASE"
|
|
64
|
-
KEYWORD_EXTRACTION = "KEYWORD_EXTRACTION"
|
|
65
|
-
|
|
66
|
-
def __str__(self) -> str:
|
|
67
|
-
return self.value
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
class QueryRewriter(LLMHealthCheckMixin):
|
|
71
|
-
@classmethod
|
|
72
|
-
def get_default_config(cls) -> Dict[str, Any]:
|
|
73
|
-
"""The default config for the query rewriter."""
|
|
74
|
-
return {
|
|
75
|
-
TYPE_CONFIG_KEY: QueryRewritingType.PLAIN,
|
|
76
|
-
MAX_TURNS: 0,
|
|
77
|
-
LLM_CONFIG_KEY: DEFAULT_LLM_CONFIG,
|
|
78
|
-
PROMPT_TEMPLATE_CONFIG_KEY: DEFAULT_QUERY_REWRITER_PROMPT_TEMPLATE,
|
|
79
|
-
}
|
|
80
|
-
|
|
81
|
-
def __init__(
|
|
82
|
-
self,
|
|
83
|
-
config: Dict[str, Any],
|
|
84
|
-
model_storage: ModelStorage,
|
|
85
|
-
resource: Resource,
|
|
86
|
-
prompt_template: Optional[str] = None,
|
|
87
|
-
):
|
|
88
|
-
self.config = {**self.get_default_config(), **config}
|
|
89
|
-
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
90
|
-
self.config.get(LLM_CONFIG_KEY), QueryRewriter.__name__
|
|
91
|
-
)
|
|
92
|
-
self.prompt_template = prompt_template or get_prompt_template(
|
|
93
|
-
config.get(PROMPT_TEMPLATE_CONFIG_KEY),
|
|
94
|
-
DEFAULT_QUERY_REWRITER_PROMPT_TEMPLATE,
|
|
95
|
-
)
|
|
96
|
-
|
|
97
|
-
self._model_storage = model_storage
|
|
98
|
-
self._resource = resource
|
|
99
|
-
|
|
100
|
-
@classmethod
|
|
101
|
-
def load(
|
|
102
|
-
cls,
|
|
103
|
-
config: Dict[str, Any],
|
|
104
|
-
model_storage: ModelStorage,
|
|
105
|
-
resource: Resource,
|
|
106
|
-
**kwargs: Any,
|
|
107
|
-
) -> "QueryRewriter":
|
|
108
|
-
"""Load query rewriter."""
|
|
109
|
-
llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
|
|
110
|
-
perform_llm_health_check(
|
|
111
|
-
llm_config,
|
|
112
|
-
DEFAULT_LLM_CONFIG,
|
|
113
|
-
"query_rewriter.load",
|
|
114
|
-
QueryRewriter.__name__,
|
|
115
|
-
)
|
|
116
|
-
|
|
117
|
-
# load prompt template
|
|
118
|
-
prompt_template = None
|
|
119
|
-
try:
|
|
120
|
-
with model_storage.read_from(resource) as path:
|
|
121
|
-
prompt_template = rasa.shared.utils.io.read_file(
|
|
122
|
-
path / QUERY_REWRITER_PROMPT_FILE_NAME
|
|
123
|
-
)
|
|
124
|
-
except (FileNotFoundError, FileIOException) as e:
|
|
125
|
-
structlogger.warning(
|
|
126
|
-
"query_rewriter.load_prompt_template.failed",
|
|
127
|
-
error=e,
|
|
128
|
-
resource=resource.name,
|
|
129
|
-
)
|
|
130
|
-
|
|
131
|
-
return QueryRewriter(config, model_storage, resource, prompt_template)
|
|
132
|
-
|
|
133
|
-
def persist(self) -> None:
|
|
134
|
-
with self._model_storage.write_to(self._resource) as path:
|
|
135
|
-
rasa.shared.utils.io.write_text_file(
|
|
136
|
-
self.prompt_template, path / QUERY_REWRITER_PROMPT_FILE_NAME
|
|
137
|
-
)
|
|
138
|
-
|
|
139
|
-
@staticmethod
|
|
140
|
-
def _concatenate_turns(
|
|
141
|
-
message: Message, tracker: DialogueStateTracker, max_turns: int
|
|
142
|
-
) -> str:
|
|
143
|
-
transcript = tracker_as_readable_transcript(tracker, max_turns=max_turns)
|
|
144
|
-
transcript += "\nUSER: " + message.get(TEXT)
|
|
145
|
-
return transcript
|
|
146
|
-
|
|
147
|
-
@staticmethod
|
|
148
|
-
async def _invoke_llm(prompt: str, llm: LLMClient) -> Optional[LLMResponse]:
|
|
149
|
-
try:
|
|
150
|
-
return await llm.acompletion(prompt)
|
|
151
|
-
except Exception as e:
|
|
152
|
-
# unfortunately, langchain does not wrap LLM exceptions which means
|
|
153
|
-
# we have to catch all exceptions here
|
|
154
|
-
structlogger.error("query_rewriter.llm.error", error=e)
|
|
155
|
-
raise ProviderClientAPIException(
|
|
156
|
-
message="LLM call exception", original_exception=e
|
|
157
|
-
)
|
|
158
|
-
|
|
159
|
-
async def _rephrase_message(
|
|
160
|
-
self, message: Message, tracker: DialogueStateTracker, max_turns: int = 5
|
|
161
|
-
) -> str:
|
|
162
|
-
llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
|
|
163
|
-
|
|
164
|
-
transcript = tracker_as_readable_transcript(
|
|
165
|
-
tracker, max_turns=max_turns, ai_prefix="ASSISTANT"
|
|
166
|
-
)
|
|
167
|
-
|
|
168
|
-
inputs = {
|
|
169
|
-
"conversation": transcript,
|
|
170
|
-
"user_message": message.get(TEXT),
|
|
171
|
-
}
|
|
172
|
-
|
|
173
|
-
prompt = Template(self.prompt_template).render(**inputs)
|
|
174
|
-
llm_response = await self._invoke_llm(prompt, llm)
|
|
175
|
-
llm_response = LLMResponse.ensure_llm_response(llm_response)
|
|
176
|
-
|
|
177
|
-
return llm_response.choices[0]
|
|
178
|
-
|
|
179
|
-
@staticmethod
|
|
180
|
-
def _keyword_extraction(
|
|
181
|
-
message: Message, tracker: DialogueStateTracker, max_turns: int = 5
|
|
182
|
-
) -> str:
|
|
183
|
-
import spacy
|
|
184
|
-
|
|
185
|
-
nlp = spacy.load("en_core_web_md")
|
|
186
|
-
|
|
187
|
-
transcript = tracker_as_readable_transcript(tracker, max_turns=max_turns)
|
|
188
|
-
transcript = transcript.replace(USER, "")
|
|
189
|
-
transcript = transcript.replace(AI, "")
|
|
190
|
-
|
|
191
|
-
doc = nlp(transcript)
|
|
192
|
-
|
|
193
|
-
keywords = set()
|
|
194
|
-
for token in doc:
|
|
195
|
-
# Extract nouns and proper nouns
|
|
196
|
-
if token.pos_ in ["NOUN", "PROPN"]:
|
|
197
|
-
keywords.add(token.lemma_)
|
|
198
|
-
|
|
199
|
-
for ent in doc.ents:
|
|
200
|
-
# Add named entities as keywords
|
|
201
|
-
keywords.add(ent.text)
|
|
202
|
-
|
|
203
|
-
# Remove stop words and punctuation
|
|
204
|
-
keywords = {
|
|
205
|
-
word
|
|
206
|
-
for word in keywords
|
|
207
|
-
if word.lower() not in nlp.Defaults.stop_words and word.isalpha()
|
|
208
|
-
}
|
|
209
|
-
|
|
210
|
-
if keywords:
|
|
211
|
-
return message.get(TEXT) + " " + " ".join(keywords)
|
|
212
|
-
else:
|
|
213
|
-
return message.get(TEXT)
|
|
214
|
-
|
|
215
|
-
async def prepare_search_query(
|
|
216
|
-
self, message: Message, tracker: DialogueStateTracker
|
|
217
|
-
) -> str:
|
|
218
|
-
query_rewriting_type = self.config[TYPE_CONFIG_KEY]
|
|
219
|
-
max_turns: int = self.config[MAX_TURNS]
|
|
220
|
-
|
|
221
|
-
query: str
|
|
222
|
-
|
|
223
|
-
if query_rewriting_type == QueryRewritingType.CONCATENATED_TURNS.value:
|
|
224
|
-
query = self._concatenate_turns(message, tracker, max_turns)
|
|
225
|
-
elif query_rewriting_type == QueryRewritingType.KEYWORD_EXTRACTION.value:
|
|
226
|
-
query = self._keyword_extraction(message, tracker, max_turns)
|
|
227
|
-
elif query_rewriting_type == QueryRewritingType.REPHRASE.value:
|
|
228
|
-
query = await self._rephrase_message(message, tracker, max_turns)
|
|
229
|
-
elif query_rewriting_type == QueryRewritingType.PLAIN.value:
|
|
230
|
-
query = message.get(TEXT)
|
|
231
|
-
else:
|
|
232
|
-
raise ValueError(f"Invalid query rewriting type: {query_rewriting_type}")
|
|
233
|
-
|
|
234
|
-
return sanitize_message_for_prompt(query)
|