MindsDB-25.1.2.1-py3-none-any.whl → MindsDB-25.1.5.0-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of MindsDB might be problematic.

Files changed (95)
  1. {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.5.0.dist-info}/METADATA +246 -255
  2. {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.5.0.dist-info}/RECORD +94 -83
  3. mindsdb/__about__.py +1 -1
  4. mindsdb/__main__.py +5 -3
  5. mindsdb/api/executor/__init__.py +0 -1
  6. mindsdb/api/executor/command_executor.py +2 -1
  7. mindsdb/api/executor/data_types/answer.py +1 -1
  8. mindsdb/api/executor/datahub/datanodes/datanode.py +1 -1
  9. mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +1 -1
  10. mindsdb/api/executor/datahub/datanodes/integration_datanode.py +8 -3
  11. mindsdb/api/executor/datahub/datanodes/project_datanode.py +9 -26
  12. mindsdb/api/executor/sql_query/__init__.py +1 -0
  13. mindsdb/api/executor/sql_query/result_set.py +36 -21
  14. mindsdb/api/executor/sql_query/steps/apply_predictor_step.py +1 -1
  15. mindsdb/api/executor/sql_query/steps/join_step.py +4 -4
  16. mindsdb/api/executor/sql_query/steps/map_reduce_step.py +6 -39
  17. mindsdb/api/executor/utilities/sql.py +2 -10
  18. mindsdb/api/http/namespaces/agents.py +3 -1
  19. mindsdb/api/http/namespaces/knowledge_bases.py +3 -3
  20. mindsdb/api/http/namespaces/sql.py +3 -1
  21. mindsdb/api/mysql/mysql_proxy/executor/mysql_executor.py +2 -1
  22. mindsdb/api/mysql/mysql_proxy/mysql_proxy.py +7 -0
  23. mindsdb/api/postgres/postgres_proxy/executor/executor.py +2 -1
  24. mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +2 -2
  25. mindsdb/integrations/handlers/chromadb_handler/requirements.txt +1 -1
  26. mindsdb/integrations/handlers/databricks_handler/requirements.txt +1 -1
  27. mindsdb/integrations/handlers/file_handler/file_handler.py +1 -1
  28. mindsdb/integrations/handlers/file_handler/requirements.txt +0 -4
  29. mindsdb/integrations/handlers/file_handler/tests/test_file_handler.py +17 -1
  30. mindsdb/integrations/handlers/jira_handler/jira_handler.py +15 -1
  31. mindsdb/integrations/handlers/jira_handler/jira_table.py +52 -31
  32. mindsdb/integrations/handlers/langchain_embedding_handler/fastapi_embeddings.py +82 -0
  33. mindsdb/integrations/handlers/langchain_embedding_handler/langchain_embedding_handler.py +8 -1
  34. mindsdb/integrations/handlers/langchain_handler/requirements.txt +1 -1
  35. mindsdb/integrations/handlers/ms_one_drive_handler/ms_one_drive_handler.py +1 -1
  36. mindsdb/integrations/handlers/ms_one_drive_handler/ms_one_drive_tables.py +8 -0
  37. mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +50 -16
  38. mindsdb/integrations/handlers/pinecone_handler/pinecone_handler.py +123 -72
  39. mindsdb/integrations/handlers/pinecone_handler/requirements.txt +1 -1
  40. mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +12 -6
  41. mindsdb/integrations/handlers/ray_serve_handler/ray_serve_handler.py +5 -3
  42. mindsdb/integrations/handlers/slack_handler/slack_handler.py +13 -2
  43. mindsdb/integrations/handlers/slack_handler/slack_tables.py +21 -1
  44. mindsdb/integrations/handlers/web_handler/requirements.txt +0 -1
  45. mindsdb/integrations/libs/ml_handler_process/learn_process.py +2 -2
  46. mindsdb/integrations/utilities/files/__init__.py +0 -0
  47. mindsdb/integrations/utilities/files/file_reader.py +258 -0
  48. mindsdb/integrations/utilities/handlers/api_utilities/microsoft/ms_graph_api_utilities.py +2 -1
  49. mindsdb/integrations/utilities/handlers/auth_utilities/microsoft/ms_graph_api_auth_utilities.py +8 -3
  50. mindsdb/integrations/utilities/rag/chains/map_reduce_summarizer_chain.py +5 -9
  51. mindsdb/integrations/utilities/rag/loaders/vector_store_loader/pgvector.py +76 -27
  52. mindsdb/integrations/utilities/rag/loaders/vector_store_loader/vector_store_loader.py +18 -1
  53. mindsdb/integrations/utilities/rag/pipelines/rag.py +74 -21
  54. mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +166 -108
  55. mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py +108 -78
  56. mindsdb/integrations/utilities/rag/settings.py +37 -16
  57. mindsdb/integrations/utilities/sql_utils.py +1 -1
  58. mindsdb/interfaces/agents/agents_controller.py +18 -8
  59. mindsdb/interfaces/agents/constants.py +1 -0
  60. mindsdb/interfaces/agents/langchain_agent.py +124 -157
  61. mindsdb/interfaces/agents/langfuse_callback_handler.py +4 -37
  62. mindsdb/interfaces/agents/mindsdb_database_agent.py +21 -13
  63. mindsdb/interfaces/chatbot/chatbot_controller.py +7 -11
  64. mindsdb/interfaces/chatbot/chatbot_task.py +16 -5
  65. mindsdb/interfaces/chatbot/memory.py +58 -13
  66. mindsdb/interfaces/database/integrations.py +5 -1
  67. mindsdb/interfaces/database/projects.py +55 -16
  68. mindsdb/interfaces/database/views.py +12 -25
  69. mindsdb/interfaces/knowledge_base/controller.py +39 -15
  70. mindsdb/interfaces/knowledge_base/preprocessing/document_loader.py +7 -26
  71. mindsdb/interfaces/model/functions.py +15 -4
  72. mindsdb/interfaces/model/model_controller.py +4 -7
  73. mindsdb/interfaces/skills/custom/text2sql/mindsdb_sql_toolkit.py +51 -40
  74. mindsdb/interfaces/skills/retrieval_tool.py +10 -3
  75. mindsdb/interfaces/skills/skill_tool.py +97 -54
  76. mindsdb/interfaces/skills/skills_controller.py +7 -3
  77. mindsdb/interfaces/skills/sql_agent.py +127 -41
  78. mindsdb/interfaces/storage/db.py +1 -1
  79. mindsdb/migrations/versions/2025-01-15_c06c35f7e8e1_project_company.py +88 -0
  80. mindsdb/utilities/cache.py +7 -4
  81. mindsdb/utilities/context.py +11 -1
  82. mindsdb/utilities/langfuse.py +279 -0
  83. mindsdb/utilities/log.py +20 -2
  84. mindsdb/utilities/otel/__init__.py +206 -0
  85. mindsdb/utilities/otel/logger.py +25 -0
  86. mindsdb/utilities/otel/meter.py +19 -0
  87. mindsdb/utilities/otel/metric_handlers/__init__.py +25 -0
  88. mindsdb/utilities/otel/tracer.py +16 -0
  89. mindsdb/utilities/partitioning.py +52 -0
  90. mindsdb/utilities/render/sqlalchemy_render.py +7 -1
  91. mindsdb/utilities/utils.py +34 -0
  92. mindsdb/utilities/otel.py +0 -72
  93. {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.5.0.dist-info}/LICENSE +0 -0
  94. {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.5.0.dist-info}/WHEEL +0 -0
  95. {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.5.0.dist-info}/top_level.txt +0 -0
mindsdb/integrations/utilities/rag/pipelines/rag.py
@@ -1,8 +1,9 @@
 from copy import copy
-from typing import Optional, Any
+from typing import Optional, Any, List

 from langchain_core.output_parsers import StrOutputParser
 from langchain.retrievers import ContextualCompressionRetriever
+from langchain_core.documents import Document

 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableSerializable
@@ -28,6 +29,23 @@ from mindsdb.interfaces.agents.langchain_agent import create_chat_model
 class LangChainRAGPipeline:
     """
     Builds a RAG pipeline using langchain LCEL components
+
+    Args:
+        retriever_runnable: Base retriever component
+        prompt_template: Template for generating responses
+        llm: Language model for generating responses
+        reranker (bool): Whether to use reranking (default: False)
+        reranker_config (RerankerConfig): Configuration for the reranker, including:
+            - model: Model to use for reranking
+            - filtering_threshold: Minimum score to keep a document
+            - num_docs_to_keep: Maximum number of documents to keep
+            - max_concurrent_requests: Maximum concurrent API requests
+            - max_retries: Number of retry attempts for failed requests
+            - retry_delay: Delay between retries
+            - early_stop (bool): Whether to enable early stopping
+            - early_stop_threshold: Confidence threshold for early stopping
+        vector_store_config (VectorStoreConfig): Vector store configuration
+        summarization_config (SummarizationConfig): Summarization configuration
     """

     def __init__(
@@ -40,19 +58,15 @@ class LangChainRAGPipeline:
             vector_store_config: Optional[VectorStoreConfig] = None,
             summarization_config: Optional[SummarizationConfig] = None
     ):
-
         self.retriever_runnable = retriever_runnable
         self.prompt_template = prompt_template
         self.llm = llm
         if reranker:
             if reranker_config is None:
                 reranker_config = RerankerConfig()
-            self.reranker = LLMReranker(
-                model=reranker_config.model,
-                base_url=reranker_config.base_url,
-                filtering_threshold=reranker_config.filtering_threshold,
-                num_docs_to_keep=reranker_config.num_docs_to_keep
-            )
+            # Convert config to dict and initialize reranker
+            reranker_kwargs = reranker_config.model_dump(exclude_none=True)
+            self.reranker = LLMReranker(**reranker_kwargs)
         else:
             self.reranker = None
         self.summarizer = None
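
The constructor change above replaces hand-copied keyword arguments with model_dump(exclude_none=True), so every field the config sets is forwarded automatically. A minimal sketch of the pattern, assuming RerankerConfig is a pydantic v2 BaseModel (which model_dump implies); the fields and defaults below are illustrative, not taken from this release:

    from typing import Optional
    from pydantic import BaseModel

    class RerankerConfig(BaseModel):
        model: str = "gpt-4o"             # illustrative default
        filtering_threshold: float = 0.5  # illustrative default
        num_docs_to_keep: Optional[int] = None
        base_url: Optional[str] = None

    config = RerankerConfig(num_docs_to_keep=5)
    # exclude_none=True drops unset Optional fields, so the reranker
    # keeps its own defaults for anything the config leaves out.
    print(config.model_dump(exclude_none=True))
    # -> {'model': 'gpt-4o', 'filtering_threshold': 0.5, 'num_docs_to_keep': 5}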
@@ -102,17 +116,45 @@ class LangChainRAGPipeline:
         raise ValueError("One of the required components (llm) is None")

         if self.reranker:
-            reranker = self.reranker
-            retriever = copy(self.retriever_runnable)
-            self.retriever_runnable = ContextualCompressionRetriever(
-                base_compressor=reranker, base_retriever=retriever
+            # Create a custom retriever that handles async operations properly
+            class AsyncRerankerRetriever(ContextualCompressionRetriever):
+                """Async-aware retriever that properly handles concurrent reranking operations."""
+
+                def __init__(self, base_retriever, reranker):
+                    super().__init__(
+                        base_compressor=reranker,
+                        base_retriever=base_retriever
+                    )
+
+                async def ainvoke(self, query: str) -> List[Document]:
+                    """Async retrieval with proper concurrency handling."""
+                    # Get initial documents
+                    if hasattr(self.base_retriever, 'ainvoke'):
+                        docs = await self.base_retriever.ainvoke(query)
+                    else:
+                        docs = await RunnablePassthrough(self.base_retriever.get_relevant_documents)(query)
+
+                    # Rerank documents
+                    if docs:
+                        docs = await self.base_compressor.acompress_documents(docs, query)
+                    return docs
+
+                def get_relevant_documents(self, query: str) -> List[Document]:
+                    """Sync wrapper for async retrieval."""
+                    import asyncio
+                    return asyncio.run(self.ainvoke(query))
+
+            # Use our custom async-aware retriever
+            self.retriever_runnable = AsyncRerankerRetriever(
+                base_retriever=copy(self.retriever_runnable),
+                reranker=self.reranker
             )

         rag_chain_from_docs = (
-            RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))  # noqa: E126, E122
-            | prompt
-            | self.llm
-            | StrOutputParser()
+            RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
+            | prompt
+            | self.llm
+            | StrOutputParser()
         )

         retrieval_chain = RunnableParallel(
@@ -125,6 +167,16 @@ class LangChainRAGPipeline:
         rag_chain_with_source = retrieval_chain.assign(answer=rag_chain_from_docs)
         return rag_chain_with_source

+    async def ainvoke(self, input_dict: dict) -> dict:
+        """Async invocation of the RAG pipeline."""
+        chain = self.with_returned_sources()
+        return await chain.ainvoke(input_dict)
+
+    def invoke(self, input_dict: dict) -> dict:
+        """Sync invocation of the RAG pipeline."""
+        import asyncio
+        return asyncio.run(self.ainvoke(input_dict))
+
     @classmethod
     def _apply_search_kwargs(cls, retriever: Any, search_kwargs: Optional[SearchKwargs] = None, search_type: Optional[SearchType] = None) -> Any:
         """Apply search kwargs and search type to the retriever if they exist"""
@@ -235,6 +287,10 @@ class LangChainRAGPipeline:
         )
         vector_store_retriever = vector_store_operator.vector_store.as_retriever()
         vector_store_retriever = cls._apply_search_kwargs(vector_store_retriever, config.search_kwargs, config.search_type)
+        distance_function = DistanceFunction.SQUARED_EUCLIDEAN_DISTANCE
+        if config.vector_store_config.is_sparse and config.vector_store_config.vector_size is not None:
+            # Use negative dot product for sparse retrieval.
+            distance_function = DistanceFunction.NEGATIVE_DOT_PRODUCT
         retriever = SQLRetriever(
             fallback_retriever=vector_store_retriever,
             vector_store_handler=knowledge_base_table.get_vector_db(),
@@ -242,14 +298,11 @@ class LangChainRAGPipeline:
             examples=retriever_config.examples,
             embeddings_model=embeddings,
             rewrite_prompt_template=retriever_config.rewrite_prompt_template,
-            retry_prompt_template=retriever_config.query_retry_template,
+            metadata_filters_prompt_template=retriever_config.metadata_filters_prompt_template,
             num_retries=retriever_config.num_retries,
-            sql_prompt_template=retriever_config.sql_prompt_template,
-            query_checker_template=retriever_config.query_checker_template,
             embeddings_table=knowledge_base_table._kb.vector_database_table,
             source_table=retriever_config.source_table,
-            # Currently only similarity search is supported.
-            distance_function=DistanceFunction.SQUARED_EUCLIDEAN_DISTANCE,
+            distance_function=distance_function,
             search_kwargs=config.search_kwargs,
             llm=sql_llm
         )
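
The distance_function switch above relies on vector stores ranking results by ascending distance: negating the dot product turns largest-similarity-first into smallest-distance-first. A small pure-Python sketch of the two orderings (illustrative stand-ins, not the handler's implementation):

    def squared_euclidean(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))

    def negative_dot_product(a, b):
        return -sum(x * y for x, y in zip(a, b))

    query = [1.0, 0.0, 2.0]
    close, far = [1.0, 0.0, 2.0], [0.5, 0.0, 0.1]
    # The better match comes first (smaller value) under either metric:
    assert squared_euclidean(query, close) < squared_euclidean(query, far)
    assert negative_dot_product(query, close) < negative_dot_product(query, far)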
mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py
@@ -4,16 +4,15 @@ import asyncio
 import logging
 import math
 import os
+import random
 from typing import Any, Dict, List, Optional, Sequence, Tuple
-from uuid import uuid4
+
 from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
 from langchain_core.callbacks import Callbacks
-
-from mindsdb.integrations.utilities.rag.settings import DEFAULT_RERANKING_MODEL, DEFAULT_LLM_ENDPOINT
 from langchain_core.documents import Document
-from langchain_core.messages import HumanMessage, SystemMessage
-from langchain_openai import ChatOpenAI
+from openai import AsyncOpenAI

+from mindsdb.integrations.utilities.rag.settings import DEFAULT_RERANKING_MODEL, DEFAULT_LLM_ENDPOINT

 log = logging.getLogger(__name__)

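
The rewritten reranker below caps its concurrent OpenAI requests with an asyncio.Semaphore. A self-contained sketch of that pattern using the release's default of 20 concurrent requests; the guarded coroutine is a stand-in for the actual chat.completions call:

    import asyncio

    async def guarded_call(sem: asyncio.Semaphore, i: int) -> int:
        async with sem:  # at most 20 coroutines run this block at once
            await asyncio.sleep(0.01)  # stand-in for one reranking API call
            return i

    async def main() -> None:
        sem = asyncio.Semaphore(20)  # matches the max_concurrent_requests default
        results = await asyncio.gather(*(guarded_call(sem, i) for i in range(100)))
        print(len(results))  # 100

    asyncio.run(main())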
@@ -23,128 +22,187 @@ class LLMReranker(BaseDocumentCompressor):
     model: str = DEFAULT_RERANKING_MODEL  # Model to use for reranking
     temperature: float = 0.0  # Temperature for the model
     openai_api_key: Optional[str] = None
-    remove_irrelevant: bool = True  # New flag to control removal of irrelevant documents,
+    remove_irrelevant: bool = True  # New flag to control removal of irrelevant documents
     base_url: str = DEFAULT_LLM_ENDPOINT
     num_docs_to_keep: Optional[int] = None  # How many of the top documents to keep after reranking & compressing.
-
     _api_key_var: str = "OPENAI_API_KEY"
-    client: Optional[Any] = None
+    client: Optional[AsyncOpenAI] = None
+    _semaphore: Optional[asyncio.Semaphore] = None
+    max_concurrent_requests: int = 20
+    max_retries: int = 3
+    retry_delay: float = 1.0
+    request_timeout: float = 20.0  # Timeout for API requests
+    early_stop: bool = True  # Whether to enable early stopping
+    early_stop_threshold: float = 0.8  # Confidence threshold for early stopping

     class Config:
         arbitrary_types_allowed = True

-    async def search_relevancy(self, query: str, document: str) -> Any:
-        openai_api_key = self.openai_api_key or os.getenv(self._api_key_var)
-
-        # Initialize the ChatOpenAI client
-        client = ChatOpenAI(openai_api_base=self.base_url, api_key=openai_api_key, model=self.model, temperature=0,
-                            logprobs=True)
-
-        # Create the message history for the conversation
-        message_history = [
-            SystemMessage(
-                content="""Your task is to classify whether the document is relevant to the search query provided below. Answer just "YES" or "NO"."""),
-            HumanMessage(content=f"""Document: ```{document}```; Search query: ```{query}```""")
-        ]
-
-        # Generate the response using LangChain's chat model
-        response = await client.agenerate(
-            messages=[message_history],
-            max_tokens=1
-        )
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._semaphore = asyncio.Semaphore(self.max_concurrent_requests)
+
+    async def _init_client(self):
+        if self.client is None:
+            openai_api_key = self.openai_api_key or os.getenv(self._api_key_var)
+            if not openai_api_key:
+                raise ValueError(f"OpenAI API key not found in environment variable {self._api_key_var}")
+            self.client = AsyncOpenAI(
+                api_key=openai_api_key,
+                base_url=self.base_url,
+                timeout=self.request_timeout,
+                max_retries=2  # Client-level retries
+            )

-        # Return the response from the model
-        return response.generations[0]
+    async def search_relevancy(self, query: str, document: str) -> Any:
+        await self._init_client()
+
+        async with self._semaphore:
+            for attempt in range(self.max_retries):
+                try:
+                    response = await self.client.chat.completions.create(
+                        model=self.model,
+                        messages=[
+                            {"role": "system", "content": "Rate the relevance of the document to the query. Respond with 'yes' or 'no'."},
+                            {"role": "user", "content": f"Query: {query}\nDocument: {document}\nIs this document relevant?"}
+                        ],
+                        temperature=self.temperature,
+                        n=1,
+                        logprobs=True,
+                        max_tokens=1
+                    )
+
+                    # Extract response and logprobs
+                    answer = response.choices[0].message.content
+                    logprob = response.choices[0].logprobs.content[0].logprob
+
+                    return {"answer": answer, "logprob": logprob}
+
+                except Exception as e:
+                    if attempt == self.max_retries - 1:
+                        log.error(f"Failed after {self.max_retries} attempts: {str(e)}")
+                        raise
+                    # Exponential backoff with jitter
+                    retry_delay = self.retry_delay * (2 ** attempt) + random.uniform(0, 0.1)
+                    await asyncio.sleep(retry_delay)

     async def _rank(self, query_document_pairs: List[Tuple[str, str]]) -> List[Tuple[str, float]]:
-        # Gather results asynchronously for all query-document pairs
-        results = await asyncio.gather(
-            *[self.search_relevancy(query=query, document=document) for (query, document) in query_document_pairs]
-        )
-
         ranked_results = []

-        for idx, result in enumerate(results):
-            # Extract the log probability (assuming logprobs are provided in LangChain response)
-            msg = result[0].message
-            logprob = msg.response_metadata['logprobs']['content'][0]['logprob']
-            prob = math.exp(logprob)
-            answer = result[0].message.content  # The model's "YES" or "NO" response
-
-            # Calculate the score based on the model's response
-            if answer == "YES":
-                score = prob
-            elif answer.lower().strip().startswith("y"):
-                score = prob
-            elif answer == "NO":
-                score = 1 - prob
-            elif answer.lower().strip().startswith("n"):
-                score = 1 - prob
-            else:
-                score = 0.0  # Default if something unexpected happens
-
-            # Append the document and score to the result
-            ranked_results.append((query_document_pairs[idx][1], score))  # (document, score)
+        # Process in larger batches for better throughput
+        batch_size = min(self.max_concurrent_requests * 2, len(query_document_pairs))
+        for i in range(0, len(query_document_pairs), batch_size):
+            batch = query_document_pairs[i:i + batch_size]
+            try:
+                results = await asyncio.gather(
+                    *[self.search_relevancy(query=query, document=document) for (query, document) in batch],
+                    return_exceptions=True
+                )
+
+                for idx, result in enumerate(results):
+                    if isinstance(result, Exception):
+                        log.error(f"Error processing document {i+idx}: {str(result)}")
+                        ranked_results.append((batch[idx][1], 0.0))
+                        continue
+
+                    answer = result["answer"]
+                    logprob = result["logprob"]
+                    prob = math.exp(logprob)
+
+                    # Convert answer to score using the model's confidence
+                    if answer.lower().strip() == "yes":
+                        score = prob  # If yes, use the model's confidence
+                    elif answer.lower().strip() == "no":
+                        score = 1 - prob  # If no, invert the confidence
+                    else:
+                        score = 0.5 * prob  # For unclear answers, reduce confidence
+
+                    ranked_results.append((batch[idx][1], score))
+
+                    # Check if we should stop early
+                    high_scoring_docs = [r for r in ranked_results if r[1] >= self.filtering_threshold]
+                    can_stop_early = (
+                        self.early_stop  # Early stopping is enabled
+                        and self.num_docs_to_keep  # We have a target number of docs
+                        and len(high_scoring_docs) >= self.num_docs_to_keep  # Found enough good docs
+                        and score >= self.early_stop_threshold  # Current doc is good enough
+                    )
+
+                    if can_stop_early:
+                        log.info(f"Early stopping after finding {self.num_docs_to_keep} documents with high confidence")
+                        return ranked_results
+
+            except Exception as e:
+                log.error(f"Batch processing error: {str(e)}")
+                continue

         return ranked_results

-    def compress_documents(
-        self,
-        documents: Sequence[Document],
-        query: str,
-        callbacks: Optional[Callbacks] = None,
+    async def acompress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
     ) -> Sequence[Document]:
-        """Compress documents using OpenAI's rerank capability with individual document assessment."""
-        log.info(f"Compressing documents. Initial count: {len(documents)}")
-        if len(documents) == 0:
-            log.warning("No documents to compress. Returning empty list.")
+        """Async compress documents using reranking with proper error handling."""
+        if callbacks:
+            await callbacks.on_retriever_start({"query": query}, "Reranking documents")
+
+        log.info(f"Async compressing documents. Initial count: {len(documents)}")
+        if not documents:
+            if callbacks:
+                await callbacks.on_retriever_end({"documents": []})
             return []

-        doc_contents = [doc.page_content for doc in documents]
-        query_documents_pairs = [(query, doc) for doc in doc_contents]
-
-        # Create event loop and run async code
-        import asyncio
         try:
-            rankings = asyncio.get_event_loop().run_until_complete(self._rank(query_documents_pairs))
-        except RuntimeError:
-            # If no event loop is available, create a new one
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-            rankings = loop.run_until_complete(self._rank(query_documents_pairs))
-
-        compressed = []
-        for ind, ranking in enumerate(rankings):
-            doc = documents[ind]
-            document_text, score = ranking
-            doc.metadata["relevance_score"] = score
-            doc.metadata["is_relevant"] = score > self.filtering_threshold
-            # Add the document to the compressed list if it is relevant or if we are not removing irrelevant documents
-            if not self.remove_irrelevant:
-                compressed.append(doc)
-            elif doc.metadata["is_relevant"]:
-                compressed.append(doc)
-
-        log.info(f"Compression complete. {len(compressed)} documents returned")
-        if not compressed:
-            log.warning("No documents found after compression")
-
-        if self.num_docs_to_keep is not None:
-            # Sort by relevance score with highest first.
-            compressed.sort(
-                key=lambda d: d.metadata.get('relevance_score', 0) if d.metadata else 0,
-                reverse=True
-            )
-            compressed = compressed[:self.num_docs_to_keep]
-
-        # Handle retrieval callbacks to account for reranked & compressed docs.
-        callbacks = callbacks if callbacks else []
-        run_id = uuid4().hex
-        if not isinstance(callbacks, list):
-            callbacks = callbacks.handlers
-        for callback in callbacks:
-            callback.on_retriever_end(compressed, run_id=run_id)
-        return compressed
+            # Prepare query-document pairs
+            query_document_pairs = [(query, doc.page_content) for doc in documents]
+
+            if callbacks:
+                await callbacks.on_text("Starting document reranking...")
+
+            # Get ranked results
+            ranked_results = await self._rank(query_document_pairs)
+
+            # Sort by score in descending order
+            ranked_results.sort(key=lambda x: x[1], reverse=True)
+
+            # Filter based on threshold and num_docs_to_keep
+            filtered_docs = []
+            for doc, score in ranked_results:
+                if score >= self.filtering_threshold:
+                    matching_doc = next(d for d in documents if d.page_content == doc)
+                    matching_doc.metadata = {**(matching_doc.metadata or {}), "relevance_score": score}
+                    filtered_docs.append(matching_doc)
+
+                    if callbacks:
+                        await callbacks.on_text(f"Document scored {score:.2f}")
+
+                    if self.num_docs_to_keep and len(filtered_docs) >= self.num_docs_to_keep:
+                        break
+
+            log.info(f"Async compression complete. Final count: {len(filtered_docs)}")
+
+            if callbacks:
+                await callbacks.on_retriever_end({"documents": filtered_docs})
+
+            return filtered_docs
+
+        except Exception as e:
+            error_msg = f"Error during async document compression: {str(e)}"
+            log.error(error_msg)
+            if callbacks:
+                await callbacks.on_retriever_error(error_msg)
+            return documents  # Return original documents on error
+
+    def compress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
+    ) -> Sequence[Document]:
+        """Sync wrapper for async compression."""
+        return asyncio.run(self.acompress_documents(documents, query, callbacks))

     @property
     def _identifying_params(self) -> Dict[str, Any]:
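
A worked example of the scoring rule in _rank() above: the answer token's logprob is the model's confidence in its own yes/no, so exp(logprob) is kept for yes and inverted for no. This is a standalone reimplementation for illustration only:

    import math

    def relevance_score(answer: str, logprob: float) -> float:
        prob = math.exp(logprob)
        token = answer.lower().strip()
        if token == "yes":
            return prob        # confident yes -> high score
        if token == "no":
            return 1 - prob    # confident no -> low score
        return 0.5 * prob      # unclear answer -> down-weighted

    print(round(relevance_score("yes", -0.05), 3))  # 0.951
    print(round(relevance_score("no", -0.05), 3))   # 0.049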