rasa-pro 3.12.6.dev2__py3-none-any.whl → 3.13.0.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic.
- rasa/__init__.py +0 -6
- rasa/cli/scaffold.py +1 -1
- rasa/core/actions/action.py +38 -34
- rasa/core/actions/action_run_slot_rejections.py +1 -1
- rasa/core/channels/studio_chat.py +16 -43
- rasa/core/channels/voice_ready/audiocodes.py +46 -17
- rasa/core/information_retrieval/faiss.py +68 -7
- rasa/core/information_retrieval/information_retrieval.py +40 -2
- rasa/core/information_retrieval/milvus.py +7 -2
- rasa/core/information_retrieval/qdrant.py +7 -2
- rasa/core/nlg/contextual_response_rephraser.py +11 -27
- rasa/core/nlg/generator.py +5 -21
- rasa/core/nlg/response.py +6 -43
- rasa/core/nlg/summarize.py +1 -15
- rasa/core/nlg/translate.py +0 -8
- rasa/core/policies/enterprise_search_policy.py +64 -316
- rasa/core/policies/flows/flow_executor.py +3 -38
- rasa/core/policies/intentless_policy.py +4 -17
- rasa/core/policies/policy.py +0 -2
- rasa/core/processor.py +27 -6
- rasa/core/utils.py +53 -0
- rasa/dialogue_understanding/coexistence/llm_based_router.py +4 -18
- rasa/dialogue_understanding/commands/cancel_flow_command.py +4 -59
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +2 -2
- rasa/dialogue_understanding/commands/start_flow_command.py +0 -41
- rasa/dialogue_understanding/generator/command_generator.py +67 -0
- rasa/dialogue_understanding/generator/command_parser.py +1 -1
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +7 -23
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -3
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_template.jinja2 +1 -1
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +1 -1
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +24 -2
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +8 -12
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +0 -61
- rasa/dialogue_understanding/processor/command_processor.py +7 -65
- rasa/dialogue_understanding/stack/utils.py +0 -38
- rasa/dialogue_understanding_test/command_metric_calculation.py +7 -40
- rasa/dialogue_understanding_test/command_metrics.py +38 -0
- rasa/dialogue_understanding_test/du_test_case.py +58 -25
- rasa/dialogue_understanding_test/du_test_result.py +228 -132
- rasa/dialogue_understanding_test/du_test_runner.py +10 -1
- rasa/dialogue_understanding_test/io.py +48 -16
- rasa/document_retrieval/__init__.py +0 -0
- rasa/document_retrieval/constants.py +32 -0
- rasa/document_retrieval/document_post_processor.py +351 -0
- rasa/document_retrieval/document_post_processor_prompt_template.jinja2 +0 -0
- rasa/document_retrieval/document_retriever.py +333 -0
- rasa/document_retrieval/knowledge_base_connectors/__init__.py +0 -0
- rasa/document_retrieval/knowledge_base_connectors/api_connector.py +39 -0
- rasa/document_retrieval/knowledge_base_connectors/knowledge_base_connector.py +34 -0
- rasa/document_retrieval/knowledge_base_connectors/vector_store_connector.py +226 -0
- rasa/document_retrieval/query_rewriter.py +234 -0
- rasa/document_retrieval/query_rewriter_prompt_template.jinja2 +8 -0
- rasa/engine/recipes/default_components.py +2 -0
- rasa/hooks.py +0 -55
- rasa/model_manager/model_api.py +1 -1
- rasa/model_manager/socket_bridge.py +0 -7
- rasa/shared/constants.py +0 -5
- rasa/shared/core/constants.py +0 -8
- rasa/shared/core/domain.py +12 -3
- rasa/shared/core/flows/flow.py +0 -17
- rasa/shared/core/flows/flows_yaml_schema.json +3 -38
- rasa/shared/core/flows/steps/collect.py +5 -18
- rasa/shared/core/flows/utils.py +1 -16
- rasa/shared/core/slot_mappings.py +11 -5
- rasa/shared/core/slots.py +1 -1
- rasa/shared/core/trackers.py +4 -10
- rasa/shared/nlu/constants.py +0 -1
- rasa/shared/providers/constants.py +0 -9
- rasa/shared/providers/llm/_base_litellm_client.py +4 -14
- rasa/shared/providers/llm/default_litellm_llm_client.py +2 -2
- rasa/shared/providers/llm/litellm_router_llm_client.py +7 -17
- rasa/shared/providers/llm/llm_client.py +15 -24
- rasa/shared/providers/llm/self_hosted_llm_client.py +2 -10
- rasa/shared/utils/common.py +11 -1
- rasa/shared/utils/health_check/health_check.py +1 -7
- rasa/shared/utils/llm.py +1 -1
- rasa/tracing/instrumentation/attribute_extractors.py +50 -17
- rasa/tracing/instrumentation/instrumentation.py +12 -12
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +1 -2
- rasa/utils/licensing.py +0 -15
- rasa/validator.py +1 -123
- rasa/version.py +1 -1
- {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/METADATA +2 -3
- {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/RECORD +88 -80
- rasa/core/actions/action_handle_digressions.py +0 -164
- rasa/dialogue_understanding/commands/handle_digressions_command.py +0 -144
- rasa/dialogue_understanding/patterns/handle_digressions.py +0 -81
- rasa/monkey_patches.py +0 -91
- {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/entry_points.txt +0 -0
rasa/document_retrieval/document_post_processor.py
@@ -0,0 +1,351 @@
import asyncio
import importlib.resources
from enum import Enum
from functools import lru_cache
from typing import Any, Dict, Optional, Text

import structlog
from jinja2 import Template

import rasa.shared.utils.io
from rasa.core.information_retrieval import SearchResult, SearchResultList
from rasa.dialogue_understanding.utils import add_prompt_to_message_parse_data
from rasa.engine.storage.resource import Resource
from rasa.engine.storage.storage import ModelStorage
from rasa.shared.constants import (
    LLM_CONFIG_KEY,
    MODEL_CONFIG_KEY,
    OPENAI_PROVIDER,
    PROMPT_TEMPLATE_CONFIG_KEY,
    PROVIDER_CONFIG_KEY,
    TEXT,
    TIMEOUT_CONFIG_KEY,
)
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.exceptions import FileIOException, ProviderClientAPIException
from rasa.shared.nlu.training_data.message import Message
from rasa.shared.providers.llm.llm_client import LLMClient
from rasa.shared.providers.llm.llm_response import LLMResponse
from rasa.shared.utils.health_check.health_check import perform_llm_health_check
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
from rasa.shared.utils.llm import (
    DEFAULT_OPENAI_GENERATE_MODEL_NAME,
    DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
    get_prompt_template,
    llm_factory,
    resolve_model_client_config,
    tracker_as_readable_transcript,
)

TYPE_CONFIG_KEY = "type"
EMBEDDING_MODEL_KEY = "embedding_model_name"

DEFAULT_EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
DOCUMENT_POST_PROCESSOR_PROMPT_FILE_NAME = (
    "document_post_processor_prompt_template.jinja2"
)
DEFAULT_LLM_CONFIG = {
    PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
    MODEL_CONFIG_KEY: DEFAULT_OPENAI_GENERATE_MODEL_NAME,
    "temperature": 0.3,
    "max_tokens": DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
    TIMEOUT_CONFIG_KEY: 5,
}
DEFAULT_DOCUMENT_POST_PROCESSOR_PROMPT_TEMPLATE = importlib.resources.read_text(
    "rasa.document_retrieval",
    "document_post_processor_prompt_template.jinja2",
)

structlogger = structlog.get_logger()


class PostProcessingType(Enum):
    PLAIN = "PLAIN"
    AGGREGATED_SUMMARY = "AGGREGATED_SUMMARY"
    INDIVIDUAL_SUMMARIES = "INDIVIDUAL_SUMMARIES"
    BINARY_LLM = "BINARY_LLM"
    BINARY_EMBEDDING_MODEL = "BINARY_EMBEDDING_MODEL"
    FINAL_ANSWER = "FINAL_ANSWER"

    def __str__(self) -> str:
        return self.value


class DocumentPostProcessor(LLMHealthCheckMixin):
    @classmethod
    def get_default_config(cls) -> Dict[str, Any]:
        """The default config for the document post processor."""
        return {
            TYPE_CONFIG_KEY: PostProcessingType.PLAIN.value,
            LLM_CONFIG_KEY: DEFAULT_LLM_CONFIG,
            PROMPT_TEMPLATE_CONFIG_KEY: DEFAULT_DOCUMENT_POST_PROCESSOR_PROMPT_TEMPLATE,
        }

    def __init__(
        self,
        config: Dict[str, Any],
        model_storage: ModelStorage,
        resource: Resource,
        prompt_template: Optional[str] = None,
    ):
        self.config = {**self.get_default_config(), **config}
        self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
            self.config.get(LLM_CONFIG_KEY), DocumentPostProcessor.__name__
        )
        self.prompt_template = prompt_template or get_prompt_template(
            config.get(PROMPT_TEMPLATE_CONFIG_KEY),
            DEFAULT_DOCUMENT_POST_PROCESSOR_PROMPT_TEMPLATE,
        )

        self._model_storage = model_storage
        self._resource = resource

    @classmethod
    def load(
        cls,
        config: Dict[Text, Any],
        model_storage: ModelStorage,
        resource: Resource,
        **kwargs: Any,
    ) -> "DocumentPostProcessor":
        """Load document post processor."""
        llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
        perform_llm_health_check(
            llm_config,
            DEFAULT_LLM_CONFIG,
            "document_post_processor.load",
            DocumentPostProcessor.__name__,
        )

        # load prompt template
        prompt_template = None
        try:
            with model_storage.read_from(resource) as path:
                prompt_template = rasa.shared.utils.io.read_file(
                    path / DOCUMENT_POST_PROCESSOR_PROMPT_FILE_NAME
                )
        except (FileNotFoundError, FileIOException) as e:
            structlogger.warning(
                "document_post_processor.load_prompt_template.failed",
                error=e,
                resource=resource.name,
            )

        return DocumentPostProcessor(config, model_storage, resource, prompt_template)

    def persist(self) -> None:
        with self._model_storage.write_to(self._resource) as path:
            rasa.shared.utils.io.write_text_file(
                self.prompt_template, path / DOCUMENT_POST_PROCESSOR_PROMPT_FILE_NAME
            )

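    # Dispatch on the "type" config key: PLAIN returns the hits unchanged,
    # AGGREGATED_SUMMARY collapses all hits into a single LLM summary,
    # INDIVIDUAL_SUMMARIES summarizes each hit concurrently, BINARY_LLM and
    # BINARY_EMBEDDING_MODEL reduce the hits to a YES/NO relevance verdict,
    # and FINAL_ANSWER drafts the user-facing reply directly.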
    async def process_documents(
        self,
        message: Message,
        search_query: str,
        documents: SearchResultList,
        tracker: DialogueStateTracker,
    ) -> SearchResultList:
        processing_type = self.config.get(TYPE_CONFIG_KEY)

        llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)

        if processing_type == PostProcessingType.AGGREGATED_SUMMARY.value:
            return await self._create_aggregated_summary(documents, llm)

        elif processing_type == PostProcessingType.INDIVIDUAL_SUMMARIES.value:
            return await self._create_individual_summaries(documents, llm)

        elif processing_type == PostProcessingType.BINARY_LLM.value:
            return await self._check_documents_relevance_to_user_query(
                message, search_query, documents, llm, tracker
            )

        elif processing_type == PostProcessingType.BINARY_EMBEDDING_MODEL.value:
            return (
                await self._check_documents_relevance_to_user_query_using_modern_bert(
                    search_query,
                    documents,
                )
            )

        elif processing_type == PostProcessingType.PLAIN.value:
            return documents

        elif processing_type == PostProcessingType.FINAL_ANSWER.value:
            return await self._generate_final_answer(message, documents, llm, tracker)

        else:
            raise ValueError(f"Invalid postprocessing type: {processing_type}")

    @lru_cache
    def compile_template(self, template: str) -> Template:
        """Compile the prompt template.

        Compiling the template is an expensive operation,
        so we cache the result.
        """
        return Template(template)

    def render_prompt(self, data: Dict) -> str:
        # TODO: This should probably be fixed, as the default prompt template is empty
        # If there are default templates for summarization they should be created,
        # and ideally be initialized based on the processing type.
        prompt_template = get_prompt_template(
            self.config.get(PROMPT_TEMPLATE_CONFIG_KEY),
            DEFAULT_DOCUMENT_POST_PROCESSOR_PROMPT_TEMPLATE,
        )
        return self.compile_template(prompt_template).render(**data)

    async def _invoke_llm(self, prompt: str, llm: LLMClient) -> Optional[LLMResponse]:
        try:
            return await llm.acompletion(prompt)
        except Exception as e:
            # unfortunately, langchain does not wrap LLM exceptions which means
            # we have to catch all exceptions here
            structlogger.error("document_post_processor.llm.error", error=e)
            raise ProviderClientAPIException(
                message="LLM call exception", original_exception=e
            )

    async def _create_aggregated_summary(
        self, documents: SearchResultList, llm: LLMClient
    ) -> SearchResultList:
        prompt = self.render_prompt(
            {"retrieval_results": [doc.text for doc in documents.results]}
        )

        llm_response = await self._invoke_llm(prompt, llm)
        aggregated_summary = LLMResponse.ensure_llm_response(llm_response)

        aggregated_result = SearchResult(
            text=aggregated_summary.choices[0], metadata={}
        )

        return SearchResultList(results=[aggregated_result], metadata={})

    async def _create_individual_summaries(
        self, documents: SearchResultList, llm: LLMClient
    ) -> SearchResultList:
        tasks = []

        for doc in documents.results:
            prompt = self.render_prompt({"retrieval_results": doc.text})
            tasks.append(asyncio.create_task(self._invoke_llm(prompt, llm)))

        llm_responses = await asyncio.gather(*tasks)
        summarized_contents = [
            LLMResponse.ensure_llm_response(summary) for summary in llm_responses
        ]

        results = [
            SearchResult(text=summary.choices[0], metadata={})
            for summary in summarized_contents
        ]
        return SearchResultList(results=results, metadata={})

    async def _check_documents_relevance_to_user_query(
        self,
        message: Message,
        search_query: str,
        documents: SearchResultList,
        llm: LLMClient,
        tracker: DialogueStateTracker,
    ) -> SearchResultList:
        # If no documents were retrieved from the vector store, the
        # documents seem to be irrelevant. Respond with "NO".
        if not documents.results:
            return SearchResultList(
                results=[
                    SearchResult(
                        text="NO",
                        metadata={},
                    )
                ],
                metadata={},
            )

        prompt_data = {
            "search_query": search_query,
            "relevant_documents": documents,
            "conversation": tracker_as_readable_transcript(tracker, max_turns=10),
        }

        prompt = self.render_prompt(prompt_data)

        llm_response = await self._invoke_llm(prompt, llm)
        documents_relevance = LLMResponse.ensure_llm_response(llm_response)

        aggregated_result = SearchResult(
            text=documents_relevance.choices[0],
            metadata={},
        )

        add_prompt_to_message_parse_data(
            message=message,
            component_name=self.__class__.__name__,
            prompt_name="document_post_processor",
            user_prompt=prompt,
            llm_response=llm_response,
        )
        structlogger.debug(
            "document_post_processor._check_documents_relevance_to_user_query",
            prompt=prompt,
            documents=[d.text for d in documents.results],
            llm_response=llm_response,
        )

        return SearchResultList(results=[aggregated_result], metadata={})

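    # NOTE: torch and sentence-transformers are imported lazily, and a fresh
    # SentenceTransformer is instantiated on every call; the encoder is not
    # cached between requests.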
    async def _check_documents_relevance_to_user_query_using_modern_bert(
        self,
        search_query: str,
        documents: SearchResultList,
        threshold: float = 0.5,
    ) -> SearchResultList:
        import torch
        from sentence_transformers import SentenceTransformer

        self.model = SentenceTransformer(
            self.config.get(EMBEDDING_MODEL_KEY, DEFAULT_EMBEDDING_MODEL_NAME),
            trust_remote_code=True,
        )

        query_embeddings = self.model.encode(["search_query: " + search_query])
        doc_embeddings = self.model.encode(
            ["search_document: " + doc.text for doc in documents.results]
        )

        similarities = self.model.similarity(query_embeddings, doc_embeddings)

        is_any_doc_relevant = torch.any(similarities > threshold).item()

        return SearchResultList(
            results=[
                SearchResult(text="YES" if is_any_doc_relevant else "NO", metadata={})
            ],
            metadata={},
        )

    async def _generate_final_answer(
        self,
        message: Message,
        documents: SearchResultList,
        llm: LLMClient,
        tracker: DialogueStateTracker,
    ) -> SearchResultList:
        input = {
            "current_conversation": tracker_as_readable_transcript(tracker),
            "relevant_documents": documents,
            "user_message": message.get(TEXT),
        }
        prompt = self.render_prompt(input)
        response = await self._invoke_llm(prompt, llm)
        response_text = response.choices[0] if response else ""
        search_result = SearchResult(text=response_text, metadata={})
        results = SearchResultList(
            results=[search_result],
            metadata={},
        )
        return results
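The BINARY_EMBEDDING_MODEL branch above gates the retrieved documents on embedding similarity instead of an extra LLM call. Below is a minimal standalone sketch of that technique using sentence-transformers directly rather than the rasa classes; the model name and the 0.5 threshold mirror the defaults above, util.cos_sim stands in for model.similarity, and the query/document prefixes copy the method's convention:

# Standalone sketch, not rasa code: embedding-based relevance gate.
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

query = "how long do refunds take?"
docs = [
    "Refunds are processed within 5 business days.",
    "Our offices are closed on public holidays.",
]

query_emb = model.encode(["search_query: " + query])
doc_embs = model.encode(["search_document: " + d for d in docs])

# Cosine similarities of shape (1, len(docs)); one score above the
# threshold is enough to call the whole retrieved set relevant.
scores = util.cos_sim(query_emb, doc_embs)
print("YES" if bool((scores > 0.5).any().item()) else "NO")

As in the diff, the verdict collapses to a single YES/NO result, so downstream components only learn whether anything relevant was found, not which document matched.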
File without changes
rasa/document_retrieval/document_retriever.py
@@ -0,0 +1,333 @@
from __future__ import annotations

import datetime
import time
import uuid
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text

import structlog

from rasa.core.information_retrieval.faiss import FAISS_Store
from rasa.dialogue_understanding.utils import add_prompt_to_message_parse_data
from rasa.document_retrieval.constants import (
    CONNECTOR_CONFIG_KEY,
    DEFAULT_K,
    DEFAULT_THRESHOLD,
    K_CONFIG_KEY,
    POST_PROCESSED_DOCUMENTS_KEY,
    POST_PROCESSING_CONFIG_KEY,
    QUERY_REWRITING_CONFIG_KEY,
    RETRIEVED_DOCUMENTS_KEY,
    SEARCH_QUERY_KEY,
    SOURCE_PROPERTY,
    THRESHOLD_CONFIG_KEY,
    USE_LLM_PROPERTY,
    VECTOR_STORE_CONFIG_KEY,
    VECTOR_STORE_TYPE_CONFIG_KEY,
)
from rasa.document_retrieval.document_post_processor import DocumentPostProcessor
from rasa.document_retrieval.knowledge_base_connectors.api_connector import APIConnector
from rasa.document_retrieval.knowledge_base_connectors.knowledge_base_connector import (
    KnowledgeBaseConnector,
)
from rasa.document_retrieval.knowledge_base_connectors.vector_store_connector import (
    DEFAULT_EMBEDDINGS_CONFIG,
    VectorStoreConnector,
    VectorStoreType,
)
from rasa.document_retrieval.query_rewriter import QueryRewriter
from rasa.engine.graph import ExecutionContext, GraphComponent
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
from rasa.engine.storage.resource import Resource
from rasa.engine.storage.storage import ModelStorage
from rasa.shared.constants import (
    EMBEDDINGS_CONFIG_KEY,
)
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.nlu.training_data.message import Message
from rasa.shared.nlu.training_data.training_data import TrainingData
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
    _LangchainEmbeddingClientAdapter,
)
from rasa.shared.providers.llm.llm_response import LLMResponse
from rasa.shared.utils.llm import (
    embedder_factory,
    resolve_model_client_config,
)

if TYPE_CHECKING:
    from langchain.schema.embeddings import Embeddings

structlogger = structlog.get_logger()


class ConnectorType(Enum):
    API = "API"
    VECTOR_STORE = "VECTOR_STORE"

    def __str__(self) -> str:
        return self.value


@DefaultV1Recipe.register(
    [
        DefaultV1Recipe.ComponentType.COEXISTENCE_ROUTER,
    ],
    is_trainable=True,
)
class DocumentRetriever(GraphComponent):
    @staticmethod
    def get_default_config() -> Dict[str, Any]:
        """The component's default config (see parent class for full docstring)."""
        return {
            THRESHOLD_CONFIG_KEY: DEFAULT_THRESHOLD,
            K_CONFIG_KEY: DEFAULT_K,
            CONNECTOR_CONFIG_KEY: ConnectorType.VECTOR_STORE.value,
            EMBEDDINGS_CONFIG_KEY: DEFAULT_EMBEDDINGS_CONFIG,
            VECTOR_STORE_CONFIG_KEY: {
                VECTOR_STORE_TYPE_CONFIG_KEY: VectorStoreType.FAISS.value,
            },
        }

    def __init__(
        self,
        config: Dict[str, Any],
        model_storage: ModelStorage,
        resource: Resource,
        query_rewriter: Optional[QueryRewriter] = None,
        document_post_processor: Optional[DocumentPostProcessor] = None,
        knowledge_base_connector: Optional[KnowledgeBaseConnector] = None,
    ) -> None:
        self.config = {**self.get_default_config(), **config}
        self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
            self.config.get(EMBEDDINGS_CONFIG_KEY), DocumentRetriever.__name__
        )

        self._model_storage = model_storage
        self._resource = resource

        # Disable query rewriting and post processing if they are not set
        query_rewriting_config = config.get(
            QUERY_REWRITING_CONFIG_KEY, {"type": "PLAIN"}
        )
        post_processing_config = config.get(
            POST_PROCESSING_CONFIG_KEY, {"type": "PLAIN"}
        )

        self.query_rewriter = query_rewriter or QueryRewriter(
            query_rewriting_config, model_storage, resource
        )
        self.document_post_processor = document_post_processor or DocumentPostProcessor(
            post_processing_config, model_storage, resource
        )
        self.knowledge_base_connector = (
            knowledge_base_connector or self.initialize_knowledge_base_connector()
        )

        self.use_llm = self.config.get(USE_LLM_PROPERTY, False)

    def persist(self) -> None:
        """Persist this component to disk for future loading."""
        self.query_rewriter.persist()
        self.document_post_processor.persist()

    @classmethod
    def _create_plain_embedder(cls, config: Dict[Text, Any]) -> "Embeddings":
        """Creates an embedder based on the given configuration.

        Returns:
            The embedder.
        """
        # Copy the config so original config is not modified
        config = config.copy()
        # Resolve config and instantiate the embedding client
        config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
            config.get(EMBEDDINGS_CONFIG_KEY), DocumentRetriever.__name__
        )
        client = embedder_factory(
            config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
        )
        # Wrap the embedding client in the adapter
        return _LangchainEmbeddingClientAdapter(client)

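    # Only FAISS-backed vector stores are (re)indexed at training time; any
    # other store type falls through to persist() unchanged.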
    def train(self, training_data: TrainingData) -> Resource:
        """Train the document retriever on a data set."""
        store_type = self.config.get(VECTOR_STORE_CONFIG_KEY, {}).get(
            VECTOR_STORE_TYPE_CONFIG_KEY
        )
        if store_type == VectorStoreType.FAISS.value:
            structlogger.info("document_retriever.train.faiss")
            embeddings = self._create_plain_embedder(self.config)
            with self._model_storage.write_to(self._resource) as path:
                self.vector_store = FAISS_Store(
                    docs_folder=self.config.get(VECTOR_STORE_CONFIG_KEY, {}).get(
                        SOURCE_PROPERTY
                    ),
                    embeddings=embeddings,
                    index_path=path,
                    create_index=True,
                    use_llm=self.use_llm,
                )
        self.persist()
        return self._resource

    @classmethod
    def load(
        cls,
        config: Dict[str, Any],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        **kwargs: Any,
    ) -> "DocumentRetriever":
        """Loads trained component (see parent class for full docstring)."""
        # Load query rewriter and document post processor

        # Disable query rewriting and post processing if they are not set
        query_rewriting_config = config.get(
            QUERY_REWRITING_CONFIG_KEY, {"type": "PLAIN"}
        )
        post_processing_config = config.get(
            POST_PROCESSING_CONFIG_KEY, {"type": "PLAIN"}
        )

        query_rewriter = QueryRewriter.load(
            query_rewriting_config, model_storage, resource
        )
        document_post_processor = DocumentPostProcessor.load(
            post_processing_config, model_storage, resource
        )

        connector_type = config.get(CONNECTOR_CONFIG_KEY)
        knowledge_base_connector: KnowledgeBaseConnector

        if connector_type == ConnectorType.VECTOR_STORE.value:
            knowledge_base_connector = VectorStoreConnector.load(
                config, model_storage, resource
            )
        elif connector_type == ConnectorType.API.value:
            knowledge_base_connector = APIConnector.load(
                config, model_storage, resource
            )
        else:
            raise ValueError(f"Invalid knowledge base connector: {connector_type}")

        return cls(
            config,
            model_storage,
            resource,
            query_rewriter,
            document_post_processor,
            knowledge_base_connector,
        )

    @classmethod
    def create(
        cls,
        config: Dict[str, Any],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ) -> DocumentRetriever:
        """Creates component (see parent class for full docstring)."""
        return cls(config, model_storage, resource)

    def initialize_knowledge_base_connector(self) -> KnowledgeBaseConnector:
        connector_type = self.config.get(CONNECTOR_CONFIG_KEY)

        if connector_type == ConnectorType.VECTOR_STORE.value:
            return VectorStoreConnector(
                self.config,
                self._model_storage,
                self._resource,
            )
        elif connector_type == ConnectorType.API.value:
            return APIConnector(self.config)
        else:
            raise ValueError(f"Invalid knowledge base connector: {connector_type}")

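    # Per message: rewrite the query, fetch documents from the connector,
    # post-process them, and record every stage on the message.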
    async def process(
        self,
        messages: List[Message],
        tracker: Optional[DialogueStateTracker] = None,
    ) -> List[Message]:
        """Process a list of messages."""
        self.knowledge_base_connector.connect_or_raise()

        for message in messages:
            start = time.time()

            # Prepare search query
            search_query = await self.query_rewriter.prepare_search_query(
                message, tracker
            )
            message.set(
                SEARCH_QUERY_KEY,
                search_query,
                add_to_output=True,
            )

            # Retrieve documents
            search_result = await self.knowledge_base_connector.retrieve_documents(
                search_query,
                self.config[K_CONFIG_KEY] or DEFAULT_K,
                self.config[THRESHOLD_CONFIG_KEY] or DEFAULT_THRESHOLD,
                tracker,
            )

            if search_result is None:
                message.set(
                    RETRIEVED_DOCUMENTS_KEY,
                    [],
                    add_to_output=True,
                )
                message.set(
                    POST_PROCESSED_DOCUMENTS_KEY,
                    [],
                    add_to_output=True,
                )
                continue

            message.set(
                RETRIEVED_DOCUMENTS_KEY,
                search_result.to_dict(),
                add_to_output=True,
            )

            # Post process documents
            final_search_result = await self.document_post_processor.process_documents(
                message, search_query, search_result, tracker
            )
            message.set(
                POST_PROCESSED_DOCUMENTS_KEY,
                final_search_result.to_dict(),
                add_to_output=True,
            )

            structlogger.debug(
                "document_retriever.process",
                search_query=search_query,
                search_result=search_result.to_dict(),
                final_search_result=final_search_result.to_dict(),
            )

            end = time.time()
            add_prompt_to_message_parse_data(
                message,
                DocumentRetriever.__name__,
                "document_retriever_process",
                user_prompt="Dummy prompt for document retriever process.",
                llm_response=LLMResponse(
                    id=str(uuid.uuid4()),
                    choices=[
                        f"search_query: {search_query}\n"
                        f"retrieved_documents: {search_result.to_dict()}\n"
                        f"post_processed_documents: {final_search_result.to_dict()}",
                    ],
                    created=int(datetime.datetime.now().timestamp()),
                    latency=end - start,
                ),
            )

        return messages
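The process loop above fixes the per-message contract of the new retriever: rewrite the query, fetch documents, post-process them, and record each stage on the message. A minimal standalone sketch of that contract with stubbed stand-ins follows (none of these classes or key names are the real rasa API; they are illustrative only):

import asyncio
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class StubMessage:
    text: str
    data: Dict[str, Any] = field(default_factory=dict)

    def set(self, key: str, value: Any) -> None:
        self.data[key] = value


class StubRetrieverPipeline:
    corpus = [
        "Refunds are processed within 5 business days.",
        "Support is available from 9 to 5.",
    ]

    async def rewrite(self, message: StubMessage) -> str:
        # "PLAIN" rewriting: pass the user text through unchanged.
        return message.text

    async def retrieve(self, query: str, k: int) -> Optional[List[str]]:
        # Naive keyword match standing in for the knowledge base connector.
        words = query.lower().split()
        hits = [d for d in self.corpus if any(w in d.lower() for w in words)]
        return hits[:k] or None

    async def process(self, messages: List[StubMessage]) -> List[StubMessage]:
        for message in messages:
            query = await self.rewrite(message)
            message.set("search_query", query)
            hits = await self.retrieve(query, k=4)
            if hits is None:
                # As in the diff: no hits short-circuits post-processing.
                message.set("retrieved_documents", [])
                message.set("post_processed_documents", [])
                continue
            message.set("retrieved_documents", hits)
            # "PLAIN" post-processing returns the documents unchanged.
            message.set("post_processed_documents", hits)
        return messages


msgs = asyncio.run(StubRetrieverPipeline().process([StubMessage("refunds timeline")]))
print(msgs[0].data)

The real component additionally packages the three stages into a dummy LLMResponse via add_prompt_to_message_parse_data, so the intermediate results surface in prompt-tracing tooling.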
File without changes