MindsDB 25.1.2.1__py3-none-any.whl → 25.1.5.0__py3-none-any.whl
This diff compares the contents of publicly released package versions as published to one of the supported registries, and is provided for informational purposes only.
Potentially problematic release.
- {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.5.0.dist-info}/METADATA +246 -255
- {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.5.0.dist-info}/RECORD +94 -83
- mindsdb/__about__.py +1 -1
- mindsdb/__main__.py +5 -3
- mindsdb/api/executor/__init__.py +0 -1
- mindsdb/api/executor/command_executor.py +2 -1
- mindsdb/api/executor/data_types/answer.py +1 -1
- mindsdb/api/executor/datahub/datanodes/datanode.py +1 -1
- mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +1 -1
- mindsdb/api/executor/datahub/datanodes/integration_datanode.py +8 -3
- mindsdb/api/executor/datahub/datanodes/project_datanode.py +9 -26
- mindsdb/api/executor/sql_query/__init__.py +1 -0
- mindsdb/api/executor/sql_query/result_set.py +36 -21
- mindsdb/api/executor/sql_query/steps/apply_predictor_step.py +1 -1
- mindsdb/api/executor/sql_query/steps/join_step.py +4 -4
- mindsdb/api/executor/sql_query/steps/map_reduce_step.py +6 -39
- mindsdb/api/executor/utilities/sql.py +2 -10
- mindsdb/api/http/namespaces/agents.py +3 -1
- mindsdb/api/http/namespaces/knowledge_bases.py +3 -3
- mindsdb/api/http/namespaces/sql.py +3 -1
- mindsdb/api/mysql/mysql_proxy/executor/mysql_executor.py +2 -1
- mindsdb/api/mysql/mysql_proxy/mysql_proxy.py +7 -0
- mindsdb/api/postgres/postgres_proxy/executor/executor.py +2 -1
- mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +2 -2
- mindsdb/integrations/handlers/chromadb_handler/requirements.txt +1 -1
- mindsdb/integrations/handlers/databricks_handler/requirements.txt +1 -1
- mindsdb/integrations/handlers/file_handler/file_handler.py +1 -1
- mindsdb/integrations/handlers/file_handler/requirements.txt +0 -4
- mindsdb/integrations/handlers/file_handler/tests/test_file_handler.py +17 -1
- mindsdb/integrations/handlers/jira_handler/jira_handler.py +15 -1
- mindsdb/integrations/handlers/jira_handler/jira_table.py +52 -31
- mindsdb/integrations/handlers/langchain_embedding_handler/fastapi_embeddings.py +82 -0
- mindsdb/integrations/handlers/langchain_embedding_handler/langchain_embedding_handler.py +8 -1
- mindsdb/integrations/handlers/langchain_handler/requirements.txt +1 -1
- mindsdb/integrations/handlers/ms_one_drive_handler/ms_one_drive_handler.py +1 -1
- mindsdb/integrations/handlers/ms_one_drive_handler/ms_one_drive_tables.py +8 -0
- mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +50 -16
- mindsdb/integrations/handlers/pinecone_handler/pinecone_handler.py +123 -72
- mindsdb/integrations/handlers/pinecone_handler/requirements.txt +1 -1
- mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +12 -6
- mindsdb/integrations/handlers/ray_serve_handler/ray_serve_handler.py +5 -3
- mindsdb/integrations/handlers/slack_handler/slack_handler.py +13 -2
- mindsdb/integrations/handlers/slack_handler/slack_tables.py +21 -1
- mindsdb/integrations/handlers/web_handler/requirements.txt +0 -1
- mindsdb/integrations/libs/ml_handler_process/learn_process.py +2 -2
- mindsdb/integrations/utilities/files/__init__.py +0 -0
- mindsdb/integrations/utilities/files/file_reader.py +258 -0
- mindsdb/integrations/utilities/handlers/api_utilities/microsoft/ms_graph_api_utilities.py +2 -1
- mindsdb/integrations/utilities/handlers/auth_utilities/microsoft/ms_graph_api_auth_utilities.py +8 -3
- mindsdb/integrations/utilities/rag/chains/map_reduce_summarizer_chain.py +5 -9
- mindsdb/integrations/utilities/rag/loaders/vector_store_loader/pgvector.py +76 -27
- mindsdb/integrations/utilities/rag/loaders/vector_store_loader/vector_store_loader.py +18 -1
- mindsdb/integrations/utilities/rag/pipelines/rag.py +74 -21
- mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +166 -108
- mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py +108 -78
- mindsdb/integrations/utilities/rag/settings.py +37 -16
- mindsdb/integrations/utilities/sql_utils.py +1 -1
- mindsdb/interfaces/agents/agents_controller.py +18 -8
- mindsdb/interfaces/agents/constants.py +1 -0
- mindsdb/interfaces/agents/langchain_agent.py +124 -157
- mindsdb/interfaces/agents/langfuse_callback_handler.py +4 -37
- mindsdb/interfaces/agents/mindsdb_database_agent.py +21 -13
- mindsdb/interfaces/chatbot/chatbot_controller.py +7 -11
- mindsdb/interfaces/chatbot/chatbot_task.py +16 -5
- mindsdb/interfaces/chatbot/memory.py +58 -13
- mindsdb/interfaces/database/integrations.py +5 -1
- mindsdb/interfaces/database/projects.py +55 -16
- mindsdb/interfaces/database/views.py +12 -25
- mindsdb/interfaces/knowledge_base/controller.py +39 -15
- mindsdb/interfaces/knowledge_base/preprocessing/document_loader.py +7 -26
- mindsdb/interfaces/model/functions.py +15 -4
- mindsdb/interfaces/model/model_controller.py +4 -7
- mindsdb/interfaces/skills/custom/text2sql/mindsdb_sql_toolkit.py +51 -40
- mindsdb/interfaces/skills/retrieval_tool.py +10 -3
- mindsdb/interfaces/skills/skill_tool.py +97 -54
- mindsdb/interfaces/skills/skills_controller.py +7 -3
- mindsdb/interfaces/skills/sql_agent.py +127 -41
- mindsdb/interfaces/storage/db.py +1 -1
- mindsdb/migrations/versions/2025-01-15_c06c35f7e8e1_project_company.py +88 -0
- mindsdb/utilities/cache.py +7 -4
- mindsdb/utilities/context.py +11 -1
- mindsdb/utilities/langfuse.py +279 -0
- mindsdb/utilities/log.py +20 -2
- mindsdb/utilities/otel/__init__.py +206 -0
- mindsdb/utilities/otel/logger.py +25 -0
- mindsdb/utilities/otel/meter.py +19 -0
- mindsdb/utilities/otel/metric_handlers/__init__.py +25 -0
- mindsdb/utilities/otel/tracer.py +16 -0
- mindsdb/utilities/partitioning.py +52 -0
- mindsdb/utilities/render/sqlalchemy_render.py +7 -1
- mindsdb/utilities/utils.py +34 -0
- mindsdb/utilities/otel.py +0 -72
- {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.5.0.dist-info}/LICENSE +0 -0
- {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.5.0.dist-info}/WHEEL +0 -0
- {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.5.0.dist-info}/top_level.txt +0 -0
mindsdb/integrations/utilities/rag/pipelines/rag.py

@@ -1,8 +1,9 @@
 from copy import copy
-from typing import Optional, Any
+from typing import Optional, Any, List

 from langchain_core.output_parsers import StrOutputParser
 from langchain.retrievers import ContextualCompressionRetriever
+from langchain_core.documents import Document

 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableSerializable
@@ -28,6 +29,23 @@ from mindsdb.interfaces.agents.langchain_agent import create_chat_model
 class LangChainRAGPipeline:
     """
     Builds a RAG pipeline using langchain LCEL components
+
+    Args:
+        retriever_runnable: Base retriever component
+        prompt_template: Template for generating responses
+        llm: Language model for generating responses
+        reranker (bool): Whether to use reranking (default: False)
+        reranker_config (RerankerConfig): Configuration for the reranker, including:
+            - model: Model to use for reranking
+            - filtering_threshold: Minimum score to keep a document
+            - num_docs_to_keep: Maximum number of documents to keep
+            - max_concurrent_requests: Maximum concurrent API requests
+            - max_retries: Number of retry attempts for failed requests
+            - retry_delay: Delay between retries
+            - early_stop (bool): Whether to enable early stopping
+            - early_stop_threshold: Confidence threshold for early stopping
+        vector_store_config (VectorStoreConfig): Vector store configuration
+        summarization_config (SummarizationConfig): Summarization configuration
     """

     def __init__(
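A hypothetical construction following the documented arguments might look like the sketch below; `my_retriever`, `my_chat_model`, and the template string are illustrative placeholders rather than MindsDB code, and `RerankerConfig` is assumed to be the pydantic config model from mindsdb/integrations/utilities/rag/settings.py.

    from mindsdb.integrations.utilities.rag.settings import RerankerConfig

    # Hypothetical usage sketch based on the docstring above.
    pipeline = LangChainRAGPipeline(
        retriever_runnable=my_retriever,   # any LangChain retriever (placeholder)
        prompt_template="Context:\n{context}\n\nQuestion: {question}",
        llm=my_chat_model,                 # any LangChain chat model (placeholder)
        reranker=True,
        reranker_config=RerankerConfig(
            filtering_threshold=0.7,       # drop documents scoring below 0.7
            num_docs_to_keep=5,            # cap the reranked set at 5
        ),
    )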
@@ -40,19 +58,15 @@ class LangChainRAGPipeline:
         vector_store_config: Optional[VectorStoreConfig] = None,
         summarization_config: Optional[SummarizationConfig] = None
     ):
-
         self.retriever_runnable = retriever_runnable
         self.prompt_template = prompt_template
         self.llm = llm
         if reranker:
             if reranker_config is None:
                 reranker_config = RerankerConfig()
-            …
-                filtering_threshold=reranker_config.filtering_threshold,
-                num_docs_to_keep=reranker_config.num_docs_to_keep
-            )
+            # Convert config to dict and initialize reranker
+            reranker_kwargs = reranker_config.model_dump(exclude_none=True)
+            self.reranker = LLMReranker(**reranker_kwargs)
         else:
             self.reranker = None
         self.summarizer = None
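The new initialization expands the whole config into keyword arguments instead of copying fields one by one; `model_dump(exclude_none=True)` is pydantic v2's serializer, so unset optional fields are dropped and `LLMReranker`'s own defaults apply. A standalone sketch of the pattern, with illustrative stand-in types:

    from typing import Optional
    from pydantic import BaseModel

    class ExampleConfig(BaseModel):  # stand-in for RerankerConfig
        model: str = "gpt-4o"
        filtering_threshold: float = 0.5
        num_docs_to_keep: Optional[int] = None  # None is dropped by exclude_none

    class Sink:  # stand-in for LLMReranker
        def __init__(self, model: str, filtering_threshold: float, num_docs_to_keep: int = 10):
            self.model = model
            self.filtering_threshold = filtering_threshold
            self.num_docs_to_keep = num_docs_to_keep

    kwargs = ExampleConfig().model_dump(exclude_none=True)
    # {'model': 'gpt-4o', 'filtering_threshold': 0.5} -- the None field is omitted,
    # so Sink falls back to its own default for num_docs_to_keep.
    sink = Sink(**kwargs)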
@@ -102,17 +116,45 @@
             raise ValueError("One of the required components (llm) is None")

         if self.reranker:
-            …
+            # Create a custom retriever that handles async operations properly
+            class AsyncRerankerRetriever(ContextualCompressionRetriever):
+                """Async-aware retriever that properly handles concurrent reranking operations."""
+
+                def __init__(self, base_retriever, reranker):
+                    super().__init__(
+                        base_compressor=reranker,
+                        base_retriever=base_retriever
+                    )
+
+                async def ainvoke(self, query: str) -> List[Document]:
+                    """Async retrieval with proper concurrency handling."""
+                    # Get initial documents
+                    if hasattr(self.base_retriever, 'ainvoke'):
+                        docs = await self.base_retriever.ainvoke(query)
+                    else:
+                        docs = await RunnablePassthrough(self.base_retriever.get_relevant_documents)(query)
+
+                    # Rerank documents
+                    if docs:
+                        docs = await self.base_compressor.acompress_documents(docs, query)
+                    return docs
+
+                def get_relevant_documents(self, query: str) -> List[Document]:
+                    """Sync wrapper for async retrieval."""
+                    import asyncio
+                    return asyncio.run(self.ainvoke(query))
+
+            # Use our custom async-aware retriever
+            self.retriever_runnable = AsyncRerankerRetriever(
+                base_retriever=copy(self.retriever_runnable),
+                reranker=self.reranker
             )

         rag_chain_from_docs = (
-            …
+            RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
+            | prompt
+            | self.llm
+            | StrOutputParser()
         )

         retrieval_chain = RunnableParallel(
@@ -125,6 +167,16 @@
         rag_chain_with_source = retrieval_chain.assign(answer=rag_chain_from_docs)
         return rag_chain_with_source

+    async def ainvoke(self, input_dict: dict) -> dict:
+        """Async invocation of the RAG pipeline."""
+        chain = self.with_returned_sources()
+        return await chain.ainvoke(input_dict)
+
+    def invoke(self, input_dict: dict) -> dict:
+        """Sync invocation of the RAG pipeline."""
+        import asyncio
+        return asyncio.run(self.ainvoke(input_dict))
+
     @classmethod
     def _apply_search_kwargs(cls, retriever: Any, search_kwargs: Optional[SearchKwargs] = None, search_type: Optional[SearchType] = None) -> Any:
         """Apply search kwargs and search type to the retriever if they exist"""
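Both new sync bridges (`get_relevant_documents` above and `invoke` here) rely on `asyncio.run()`, which raises `RuntimeError` when the calling thread already has a running event loop (a Jupyter kernel, an async web server). A loop-safe alternative, sketched here as an assumption rather than anything in this diff, hands the coroutine to a fresh loop on a worker thread in that case:

    import asyncio
    import concurrent.futures
    from typing import Any, Coroutine

    def run_coro_sync(coro: Coroutine) -> Any:
        """Hypothetical helper: block on a coroutine from sync code."""
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop in this thread: the normal asyncio.run() path works.
            return asyncio.run(coro)
        # A loop is already running: execute on a fresh loop in a worker
        # thread and block until it finishes.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()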
@@ -235,6 +287,10 @@
         )
         vector_store_retriever = vector_store_operator.vector_store.as_retriever()
         vector_store_retriever = cls._apply_search_kwargs(vector_store_retriever, config.search_kwargs, config.search_type)
+        distance_function = DistanceFunction.SQUARED_EUCLIDEAN_DISTANCE
+        if config.vector_store_config.is_sparse and config.vector_store_config.vector_size is not None:
+            # Use negative dot product for sparse retrieval.
+            distance_function = DistanceFunction.NEGATIVE_DOT_PRODUCT
         retriever = SQLRetriever(
             fallback_retriever=vector_store_retriever,
             vector_store_handler=knowledge_base_table.get_vector_db(),
@@ -242,14 +298,11 @@
             examples=retriever_config.examples,
             embeddings_model=embeddings,
             rewrite_prompt_template=retriever_config.rewrite_prompt_template,
-            …
+            metadata_filters_prompt_template=retriever_config.metadata_filters_prompt_template,
             num_retries=retriever_config.num_retries,
-            sql_prompt_template=retriever_config.sql_prompt_template,
-            query_checker_template=retriever_config.query_checker_template,
             embeddings_table=knowledge_base_table._kb.vector_database_table,
             source_table=retriever_config.source_table,
-            …
-            distance_function=DistanceFunction.SQUARED_EUCLIDEAN_DISTANCE,
+            distance_function=distance_function,
             search_kwargs=config.search_kwargs,
             llm=sql_llm
         )
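Taken together, the two hunks above replace the hard-coded squared-Euclidean distance with a per-store choice: sparse fixed-size vectors are compared by negative dot product, since their meaning lives in inner products rather than Euclidean geometry. The rule restated as a standalone sketch (the enum values are illustrative, not MindsDB source):

    from enum import Enum
    from typing import Optional

    class DistanceFunction(Enum):  # illustrative stand-in
        SQUARED_EUCLIDEAN_DISTANCE = "squared_euclidean"
        NEGATIVE_DOT_PRODUCT = "negative_dot_product"

    def pick_distance(is_sparse: bool, vector_size: Optional[int]) -> DistanceFunction:
        # Sparse fixed-size vectors are ranked by (negative) dot product;
        # squared Euclidean distance remains the dense-vector default.
        if is_sparse and vector_size is not None:
            return DistanceFunction.NEGATIVE_DOT_PRODUCT
        return DistanceFunction.SQUARED_EUCLIDEAN_DISTANCE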
mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py

@@ -4,16 +4,15 @@ import asyncio
 import logging
 import math
 import os
+import random
 from typing import Any, Dict, List, Optional, Sequence, Tuple
-
+
 from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
 from langchain_core.callbacks import Callbacks
-
-from mindsdb.integrations.utilities.rag.settings import DEFAULT_RERANKING_MODEL, DEFAULT_LLM_ENDPOINT
 from langchain_core.documents import Document
-from …
-from langchain_openai import ChatOpenAI
+from openai import AsyncOpenAI

+from mindsdb.integrations.utilities.rag.settings import DEFAULT_RERANKING_MODEL, DEFAULT_LLM_ENDPOINT

 log = logging.getLogger(__name__)
@@ -23,128 +22,187 @@ class LLMReranker(BaseDocumentCompressor):
     model: str = DEFAULT_RERANKING_MODEL  # Model to use for reranking
     temperature: float = 0.0  # Temperature for the model
     openai_api_key: Optional[str] = None
-    remove_irrelevant: bool = True  # New flag to control removal of irrelevant documents
+    remove_irrelevant: bool = True  # New flag to control removal of irrelevant documents
     base_url: str = DEFAULT_LLM_ENDPOINT
     num_docs_to_keep: Optional[int] = None  # How many of the top documents to keep after reranking & compressing.
-
     _api_key_var: str = "OPENAI_API_KEY"
-    client: Optional[…
+    client: Optional[AsyncOpenAI] = None
+    _semaphore: Optional[asyncio.Semaphore] = None
+    max_concurrent_requests: int = 20
+    max_retries: int = 3
+    retry_delay: float = 1.0
+    request_timeout: float = 20.0  # Timeout for API requests
+    early_stop: bool = True  # Whether to enable early stopping
+    early_stop_threshold: float = 0.8  # Confidence threshold for early stopping

     class Config:
         arbitrary_types_allowed = True

-    …
-        response = await client.agenerate(
-            messages=[message_history],
-            max_tokens=1
-        )
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._semaphore = asyncio.Semaphore(self.max_concurrent_requests)
+
+    async def _init_client(self):
+        if self.client is None:
+            openai_api_key = self.openai_api_key or os.getenv(self._api_key_var)
+            if not openai_api_key:
+                raise ValueError(f"OpenAI API key not found in environment variable {self._api_key_var}")
+            self.client = AsyncOpenAI(
+                api_key=openai_api_key,
+                base_url=self.base_url,
+                timeout=self.request_timeout,
+                max_retries=2  # Client-level retries
+            )

-    …
+    async def search_relevancy(self, query: str, document: str) -> Any:
+        await self._init_client()
+
+        async with self._semaphore:
+            for attempt in range(self.max_retries):
+                try:
+                    response = await self.client.chat.completions.create(
+                        model=self.model,
+                        messages=[
+                            {"role": "system", "content": "Rate the relevance of the document to the query. Respond with 'yes' or 'no'."},
+                            {"role": "user", "content": f"Query: {query}\nDocument: {document}\nIs this document relevant?"}
+                        ],
+                        temperature=self.temperature,
+                        n=1,
+                        logprobs=True,
+                        max_tokens=1
+                    )
+
+                    # Extract response and logprobs
+                    answer = response.choices[0].message.content
+                    logprob = response.choices[0].logprobs.content[0].logprob
+
+                    return {"answer": answer, "logprob": logprob}
+
+                except Exception as e:
+                    if attempt == self.max_retries - 1:
+                        log.error(f"Failed after {self.max_retries} attempts: {str(e)}")
+                        raise
+                    # Exponential backoff with jitter
+                    retry_delay = self.retry_delay * (2 ** attempt) + random.uniform(0, 0.1)
+                    await asyncio.sleep(retry_delay)

     async def _rank(self, query_document_pairs: List[Tuple[str, str]]) -> List[Tuple[str, float]]:
-        # Gather results asynchronously for all query-document pairs
-        results = await asyncio.gather(
-            *[self.search_relevancy(query=query, document=document) for (query, document) in query_document_pairs]
-        )
-
         ranked_results = []

-        …
+        # Process in larger batches for better throughput
+        batch_size = min(self.max_concurrent_requests * 2, len(query_document_pairs))
+        for i in range(0, len(query_document_pairs), batch_size):
+            batch = query_document_pairs[i:i + batch_size]
+            try:
+                results = await asyncio.gather(
+                    *[self.search_relevancy(query=query, document=document) for (query, document) in batch],
+                    return_exceptions=True
+                )
+
+                for idx, result in enumerate(results):
+                    if isinstance(result, Exception):
+                        log.error(f"Error processing document {i+idx}: {str(result)}")
+                        ranked_results.append((batch[idx][1], 0.0))
+                        continue
+
+                    answer = result["answer"]
+                    logprob = result["logprob"]
+                    prob = math.exp(logprob)
+
+                    # Convert answer to score using the model's confidence
+                    if answer.lower().strip() == "yes":
+                        score = prob  # If yes, use the model's confidence
+                    elif answer.lower().strip() == "no":
+                        score = 1 - prob  # If no, invert the confidence
+                    else:
+                        score = 0.5 * prob  # For unclear answers, reduce confidence
+
+                    ranked_results.append((batch[idx][1], score))
+
+                    # Check if we should stop early
+                    high_scoring_docs = [r for r in ranked_results if r[1] >= self.filtering_threshold]
+                    can_stop_early = (
+                        self.early_stop  # Early stopping is enabled
+                        and self.num_docs_to_keep  # We have a target number of docs
+                        and len(high_scoring_docs) >= self.num_docs_to_keep  # Found enough good docs
+                        and score >= self.early_stop_threshold  # Current doc is good enough
+                    )
+
+                    if can_stop_early:
+                        log.info(f"Early stopping after finding {self.num_docs_to_keep} documents with high confidence")
+                        return ranked_results
+
+            except Exception as e:
+                log.error(f"Batch processing error: {str(e)}")
+                continue

         return ranked_results

-    def …
+    async def acompress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
     ) -> Sequence[Document]:
-        """
-        …
+        """Async compress documents using reranking with proper error handling."""
+        if callbacks:
+            await callbacks.on_retriever_start({"query": query}, "Reranking documents")
+
+        log.info(f"Async compressing documents. Initial count: {len(documents)}")
+        if not documents:
+            if callbacks:
+                await callbacks.on_retriever_end({"documents": []})
             return []

-        doc_contents = [doc.page_content for doc in documents]
-        query_documents_pairs = [(query, doc) for doc in doc_contents]
-
-        # Create event loop and run async code
-        import asyncio
         try:
-            …
+            # Prepare query-document pairs
+            query_document_pairs = [(query, doc.page_content) for doc in documents]
+
+            if callbacks:
+                await callbacks.on_text("Starting document reranking...")
+
+            # Get ranked results
+            ranked_results = await self._rank(query_document_pairs)
+
+            # Sort by score in descending order
+            ranked_results.sort(key=lambda x: x[1], reverse=True)
+
+            # Filter based on threshold and num_docs_to_keep
+            filtered_docs = []
+            for doc, score in ranked_results:
+                if score >= self.filtering_threshold:
+                    matching_doc = next(d for d in documents if d.page_content == doc)
+                    matching_doc.metadata = {**(matching_doc.metadata or {}), "relevance_score": score}
+                    filtered_docs.append(matching_doc)
+
+                    if callbacks:
+                        await callbacks.on_text(f"Document scored {score:.2f}")
+
+                if self.num_docs_to_keep and len(filtered_docs) >= self.num_docs_to_keep:
+                    break
+
+            log.info(f"Async compression complete. Final count: {len(filtered_docs)}")
+
+            if callbacks:
+                await callbacks.on_retriever_end({"documents": filtered_docs})
+
+            return filtered_docs
+
+        except Exception as e:
+            error_msg = f"Error during async document compression: {str(e)}"
+            log.error(error_msg)
+            if callbacks:
+                await callbacks.on_retriever_error(error_msg)
+            return documents  # Return original documents on error
+
+    def compress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
+    ) -> Sequence[Document]:
+        """Sync wrapper for async compression."""
+        return asyncio.run(self.acompress_documents(documents, query, callbacks))

     @property
     def _identifying_params(self) -> Dict[str, Any]:
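For reference, the score conversion inside `_rank` exponentiates the single sampled token's logprob back into a probability and folds the yes/no answer into one relevance score. A worked restatement (illustrative, standard library only):

    import math

    def to_score(answer: str, logprob: float) -> float:
        """Illustrative copy of the yes/no logprob scoring rule above."""
        prob = math.exp(logprob)       # token logprob -> token probability
        answer = answer.lower().strip()
        if answer == "yes":
            return prob                # confident "yes" -> score near 1
        if answer == "no":
            return 1 - prob            # confident "no" -> score near 0
        return 0.5 * prob              # unexpected token -> damped score

    # "yes" sampled with 90% token probability scores 0.9; "no" at 90% scores 0.1.
    assert abs(to_score("yes", math.log(0.9)) - 0.9) < 1e-9
    assert abs(to_score("no", math.log(0.9)) - 0.1) < 1e-9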