MindsDB 25.1.2.1__py3-none-any.whl → 25.1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of MindsDB might be problematic.

Files changed (77)
  1. {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.4.0.dist-info}/METADATA +244 -242
  2. {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.4.0.dist-info}/RECORD +76 -67
  3. mindsdb/__about__.py +1 -1
  4. mindsdb/__main__.py +5 -3
  5. mindsdb/api/executor/__init__.py +0 -1
  6. mindsdb/api/executor/command_executor.py +2 -1
  7. mindsdb/api/executor/data_types/answer.py +1 -1
  8. mindsdb/api/executor/datahub/datanodes/integration_datanode.py +7 -2
  9. mindsdb/api/executor/datahub/datanodes/project_datanode.py +8 -1
  10. mindsdb/api/executor/sql_query/__init__.py +1 -0
  11. mindsdb/api/executor/sql_query/result_set.py +36 -21
  12. mindsdb/api/executor/sql_query/steps/apply_predictor_step.py +1 -1
  13. mindsdb/api/executor/sql_query/steps/join_step.py +4 -4
  14. mindsdb/api/executor/sql_query/steps/map_reduce_step.py +6 -39
  15. mindsdb/api/executor/utilities/sql.py +2 -10
  16. mindsdb/api/http/namespaces/knowledge_bases.py +3 -3
  17. mindsdb/api/http/namespaces/sql.py +3 -1
  18. mindsdb/api/mysql/mysql_proxy/executor/mysql_executor.py +2 -1
  19. mindsdb/api/mysql/mysql_proxy/mysql_proxy.py +7 -0
  20. mindsdb/api/postgres/postgres_proxy/executor/executor.py +2 -1
  21. mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +2 -2
  22. mindsdb/integrations/handlers/chromadb_handler/requirements.txt +1 -1
  23. mindsdb/integrations/handlers/file_handler/file_handler.py +1 -1
  24. mindsdb/integrations/handlers/file_handler/tests/test_file_handler.py +17 -1
  25. mindsdb/integrations/handlers/jira_handler/jira_handler.py +15 -1
  26. mindsdb/integrations/handlers/jira_handler/jira_table.py +52 -31
  27. mindsdb/integrations/handlers/langchain_embedding_handler/fastapi_embeddings.py +82 -0
  28. mindsdb/integrations/handlers/langchain_embedding_handler/langchain_embedding_handler.py +8 -1
  29. mindsdb/integrations/handlers/langchain_handler/requirements.txt +1 -1
  30. mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +48 -16
  31. mindsdb/integrations/handlers/pinecone_handler/pinecone_handler.py +123 -72
  32. mindsdb/integrations/handlers/pinecone_handler/requirements.txt +1 -1
  33. mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +12 -6
  34. mindsdb/integrations/handlers/slack_handler/slack_handler.py +13 -2
  35. mindsdb/integrations/handlers/slack_handler/slack_tables.py +21 -1
  36. mindsdb/integrations/libs/ml_handler_process/learn_process.py +1 -1
  37. mindsdb/integrations/utilities/rag/loaders/vector_store_loader/pgvector.py +76 -27
  38. mindsdb/integrations/utilities/rag/loaders/vector_store_loader/vector_store_loader.py +18 -1
  39. mindsdb/integrations/utilities/rag/pipelines/rag.py +73 -18
  40. mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +166 -108
  41. mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py +36 -14
  42. mindsdb/integrations/utilities/rag/settings.py +8 -2
  43. mindsdb/integrations/utilities/sql_utils.py +1 -1
  44. mindsdb/interfaces/agents/agents_controller.py +3 -5
  45. mindsdb/interfaces/agents/langchain_agent.py +112 -150
  46. mindsdb/interfaces/agents/langfuse_callback_handler.py +0 -37
  47. mindsdb/interfaces/agents/mindsdb_database_agent.py +15 -13
  48. mindsdb/interfaces/chatbot/chatbot_controller.py +7 -11
  49. mindsdb/interfaces/chatbot/chatbot_task.py +16 -5
  50. mindsdb/interfaces/chatbot/memory.py +58 -13
  51. mindsdb/interfaces/database/projects.py +17 -15
  52. mindsdb/interfaces/database/views.py +12 -25
  53. mindsdb/interfaces/knowledge_base/controller.py +39 -15
  54. mindsdb/interfaces/model/functions.py +15 -4
  55. mindsdb/interfaces/model/model_controller.py +4 -7
  56. mindsdb/interfaces/skills/custom/text2sql/mindsdb_sql_toolkit.py +47 -38
  57. mindsdb/interfaces/skills/retrieval_tool.py +10 -3
  58. mindsdb/interfaces/skills/skill_tool.py +97 -53
  59. mindsdb/interfaces/skills/sql_agent.py +77 -36
  60. mindsdb/interfaces/storage/db.py +1 -1
  61. mindsdb/migrations/versions/2025-01-15_c06c35f7e8e1_project_company.py +88 -0
  62. mindsdb/utilities/cache.py +7 -4
  63. mindsdb/utilities/context.py +11 -1
  64. mindsdb/utilities/langfuse.py +264 -0
  65. mindsdb/utilities/log.py +20 -2
  66. mindsdb/utilities/otel/__init__.py +206 -0
  67. mindsdb/utilities/otel/logger.py +25 -0
  68. mindsdb/utilities/otel/meter.py +19 -0
  69. mindsdb/utilities/otel/metric_handlers/__init__.py +25 -0
  70. mindsdb/utilities/otel/tracer.py +16 -0
  71. mindsdb/utilities/partitioning.py +52 -0
  72. mindsdb/utilities/render/sqlalchemy_render.py +7 -1
  73. mindsdb/utilities/utils.py +34 -0
  74. mindsdb/utilities/otel.py +0 -72
  75. {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.4.0.dist-info}/LICENSE +0 -0
  76. {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.4.0.dist-info}/WHEEL +0 -0
  77. {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.4.0.dist-info}/top_level.txt +0 -0

mindsdb/integrations/utilities/rag/loaders/vector_store_loader/pgvector.py

@@ -1,9 +1,9 @@
-from typing import Any, List, Optional, Dict
+from typing import Any, List, Union, Optional, Dict
 
 from langchain_community.vectorstores import PGVector
 from langchain_community.vectorstores.pgvector import Base
 
-from pgvector.sqlalchemy import Vector
+from pgvector.sqlalchemy import SPARSEVEC, Vector
 import sqlalchemy as sa
 from sqlalchemy.dialects.postgresql import JSON
 
@@ -15,9 +15,17 @@ _generated_sa_tables = {}
 
 class PGVectorMDB(PGVector):
     """
-    langchain_community.vectorstores.PGVector adapted for mindsdb vector store table structure
+    langchain_community.vectorstores.PGVector adapted for mindsdb vector store table structure
     """
 
+    def __init__(self, *args, is_sparse: bool = False, vector_size: Optional[int] = None, **kwargs):
+        # todo get is_sparse and vector_size from kb vector table
+        self.is_sparse = is_sparse
+        if is_sparse and vector_size is None:
+            raise ValueError("vector_size is required when is_sparse=True")
+        self.vector_size = vector_size
+        super().__init__(*args, **kwargs)
+
     def __post_init__(
         self,
     ) -> None:
@@ -32,53 +40,94 @@ class PGVectorMDB(PGVector):
                 __tablename__ = collection_name
 
                 id = sa.Column(sa.Integer, primary_key=True)
-                embedding: Vector = sa.Column('embeddings', Vector())
-                document = sa.Column('content', sa.String, nullable=True)
-                cmetadata = sa.Column('metadata', JSON, nullable=True)
+                embedding = sa.Column(
+                    "embeddings",
+                    SPARSEVEC() if self.is_sparse else Vector() if self.vector_size is None else
+                    SPARSEVEC(self.vector_size) if self.is_sparse else Vector(self.vector_size)
+                )
+                document = sa.Column("content", sa.String, nullable=True)
+                cmetadata = sa.Column("metadata", JSON, nullable=True)
 
             _generated_sa_tables[collection_name] = EmbeddingStore
 
         self.EmbeddingStore = _generated_sa_tables[collection_name]
 
     def __query_collection(
-        self,
-        embedding: List[float],
-        k: int = 4,
-        filter: Optional[Dict[str, str]] = None,
+        self,
+        embedding: Union[List[float], Dict[int, float], str],
+        k: int = 4,
+        filter: Optional[Dict[str, str]] = None,
     ) -> List[Any]:
         """Query the collection."""
         with Session(self._bind) as session:
-
-            results: List[Any] = (
-                session.query(
-                    self.EmbeddingStore,
-                    self.distance_strategy(embedding).label("distance"),
-                )
-                .order_by(sa.asc("distance"))
-                .limit(k)
-                .all()
+            if self.is_sparse:
+                # Sparse vectors: expect string in format "{key:value,...}/size" or dictionary
+                if isinstance(embedding, dict):
+                    from pgvector.utils import SparseVector
+                    embedding = SparseVector(embedding, self.vector_size)
+                    embedding_str = embedding.to_text()
+                elif isinstance(embedding, str):
+                    # Use string as is - it should already be in the correct format
+                    embedding_str = embedding
+                # Use inner product for sparse vectors
+                distance_op = "<#>"
+                # For inner product, larger values are better matches
+                order_direction = "ASC"
+            else:
+                # Dense vectors: expect string in JSON array format or list of floats
+                if isinstance(embedding, list):
+                    embedding_str = f"[{','.join(str(x) for x in embedding)}]"
+                elif isinstance(embedding, str):
+                    embedding_str = embedding
+                # Use cosine similarity for dense vectors
+                distance_op = "<=>"
+                # For cosine similarity, smaller values are better matches
+                order_direction = "ASC"
+
+            # Use SQL directly for vector comparison
+            query = sa.text(
+                f"""
+                SELECT t.*, t.embeddings {distance_op} '{embedding_str}' as distance
+                FROM {self.collection_name} t
+                ORDER BY distance {order_direction}
+                LIMIT {k}
+                """
             )
-            for rec, _ in results:
-                if not bool(rec.cmetadata):
-                    rec.cmetadata = {0: 0}
+            results = session.execute(query).all()
+
+            # Convert results to the expected format
+            formatted_results = []
+            for rec in results:
+                metadata = rec.metadata if bool(rec.metadata) else {0: 0}
+                embedding_store = self.EmbeddingStore()
+                embedding_store.document = rec.content
+                embedding_store.cmetadata = metadata
+                result = type(
+                    'Result', (), {
+                        'EmbeddingStore': embedding_store,
+                        'distance': rec.distance
+                    }
+                )
+                formatted_results.append(result)
 
-        return results
+            return formatted_results
 
     # aliases for different langchain versions
     def _PGVector__query_collection(self, *args, **kwargs):
+
         return self.__query_collection(*args, **kwargs)
 
     def _query_collection(self, *args, **kwargs):
        return self.__query_collection(*args, **kwargs)
 
     def create_collection(self):
-        raise RuntimeError('Forbidden')
+        raise RuntimeError("Forbidden")
 
     def delete_collection(self):
-        raise RuntimeError('Forbidden')
+        raise RuntimeError("Forbidden")
 
     def delete(self, *args, **kwargs):
-        raise RuntimeError('Forbidden')
+        raise RuntimeError("Forbidden")
 
     def add_embeddings(self, *args, **kwargs):
-        raise RuntimeError('Forbidden')
+        raise RuntimeError("Forbidden")
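
A note on the new column definition above: Python's conditional expression is right-associative, so when is_sparse is true the chain always yields an unsized SPARSEVEC(), and the SPARSEVEC(self.vector_size) branch is unreachable. The sketch below shows a flatter way to express the apparent intent; it is an illustration only (embedding_column_type is a hypothetical helper, not part of this release), reusing the pgvector.sqlalchemy types the diff imports. For context, pgvector's <#> operator is the negative inner product (so ascending order still ranks best matches first) and <=> is cosine distance.

    from typing import Optional

    from pgvector.sqlalchemy import SPARSEVEC, Vector

    def embedding_column_type(is_sparse: bool, vector_size: Optional[int]):
        """Pick the pgvector column type for the embeddings column."""
        if is_sparse:
            # __init__ above guarantees vector_size is set when is_sparse=True.
            return SPARSEVEC(vector_size)
        # Dense embeddings: size the column only when the size is known.
        return Vector() if vector_size is None else Vector(vector_size)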

mindsdb/integrations/utilities/rag/loaders/vector_store_loader/vector_store_loader.py

@@ -7,6 +7,7 @@ from pydantic import BaseModel
 
 from mindsdb.integrations.utilities.rag.settings import VectorStoreType, VectorStoreConfig
 from mindsdb.integrations.utilities.rag.loaders.vector_store_loader.MDBVectorStore import MDBVectorStore
+from mindsdb.integrations.utilities.rag.loaders.vector_store_loader.pgvector import PGVectorMDB
 from mindsdb.utilities import log
 
 
@@ -28,6 +29,20 @@ class VectorStoreLoader(BaseModel):
         Loads the vector store based on the provided config and embeddings model
         :return:
         """
+        if self.config.is_sparse is not None and self.config.vector_size is not None and self.config.kb_table is not None:
+            # Only use PGVector store for sparse vectors.
+            db_handler = self.config.kb_table.get_vector_db()
+            db_args = db_handler.connection_args
+            # Assume we are always using PGVector & psycopg2.
+            connection_str = f"postgresql+psycopg2://{db_args.get('user')}:{db_args.get('password')}@{db_args.get('host')}:{db_args.get('port')}/{db_args.get('dbname', db_args.get('database'))}"
+
+            return PGVectorMDB(
+                connection_string=connection_str,
+                collection_name=self.config.kb_table._kb.vector_database_table,
+                embedding_function=self.embedding_model,
+                is_sparse=self.config.is_sparse,
+                vector_size=self.config.vector_size
+            )
         return MDBVectorStore(kb_table=self.config.kb_table)
 
 
@@ -56,5 +71,7 @@ class VectorStoreFactory:
             return PGVectorMDB(
                 connection_string=settings.connection_string,
                 collection_name=settings.collection_name,
-                embedding_function=embedding_model
+                embedding_function=embedding_model,
+                is_sparse=settings.is_sparse,
+                vector_size=settings.vector_size
             )
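
One caveat with the f-string connection string added above: it breaks when credentials contain reserved URL characters such as @ or /. A sketch of the same URL built with SQLAlchemy's URL.create, which quotes each component; the db_args dict here is an illustrative stand-in for the handler's connection_args, not code from this release:

    from sqlalchemy.engine import URL

    # Example connection args in the same shape the loader reads them.
    db_args = {"user": "mindsdb", "password": "p@ss/word",
               "host": "localhost", "port": 5432, "dbname": "vectors"}

    connection_url = URL.create(
        drivername="postgresql+psycopg2",
        username=db_args.get("user"),
        password=db_args.get("password"),  # quoted safely by URL.create
        host=db_args.get("host"),
        port=db_args.get("port"),
        database=db_args.get("dbname", db_args.get("database")),
    )
    # Pass connection_url to create_engine() directly; for a raw string use
    # connection_url.render_as_string(hide_password=False).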

mindsdb/integrations/utilities/rag/pipelines/rag.py

@@ -1,8 +1,9 @@
 from copy import copy
-from typing import Optional, Any
+from typing import Optional, Any, List
 
 from langchain_core.output_parsers import StrOutputParser
 from langchain.retrievers import ContextualCompressionRetriever
+from langchain_core.documents import Document
 
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableSerializable
@@ -28,6 +29,23 @@ from mindsdb.interfaces.agents.langchain_agent import create_chat_model
 class LangChainRAGPipeline:
     """
     Builds a RAG pipeline using langchain LCEL components
+
+    Args:
+        retriever_runnable: Base retriever component
+        prompt_template: Template for generating responses
+        llm: Language model for generating responses
+        reranker (bool): Whether to use reranking (default: False)
+        reranker_config (RerankerConfig): Configuration for the reranker, including:
+            - model: Model to use for reranking
+            - filtering_threshold: Minimum score to keep a document
+            - num_docs_to_keep: Maximum number of documents to keep
+            - max_concurrent_requests: Maximum concurrent API requests
+            - max_retries: Number of retry attempts for failed requests
+            - retry_delay: Delay between retries
+            - early_stop (bool): Whether to enable early stopping
+            - early_stop_threshold: Confidence threshold for early stopping
+        vector_store_config (VectorStoreConfig): Vector store configuration
+        summarization_config (SummarizationConfig): Summarization configuration
     """
 
     def __init__(
@@ -40,19 +58,15 @@ class LangChainRAGPipeline:
             vector_store_config: Optional[VectorStoreConfig] = None,
             summarization_config: Optional[SummarizationConfig] = None
     ):
-
         self.retriever_runnable = retriever_runnable
         self.prompt_template = prompt_template
         self.llm = llm
         if reranker:
             if reranker_config is None:
                 reranker_config = RerankerConfig()
-            self.reranker = LLMReranker(
-                model=reranker_config.model,
-                base_url=reranker_config.base_url,
-                filtering_threshold=reranker_config.filtering_threshold,
-                num_docs_to_keep=reranker_config.num_docs_to_keep
-            )
+            # Convert config to dict and initialize reranker
+            reranker_kwargs = reranker_config.model_dump(exclude_none=True)
+            self.reranker = LLMReranker(**reranker_kwargs)
         else:
             self.reranker = None
         self.summarizer = None
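
The model_dump(exclude_none=True) call above forwards every populated RerankerConfig field to LLMReranker as keyword arguments, so new reranker options no longer have to be threaded through by hand. A minimal standalone illustration of the pattern (Pydantic v2; DemoConfig and its fields are placeholders, not the real RerankerConfig):

    from typing import Optional

    from pydantic import BaseModel

    class DemoConfig(BaseModel):
        model: str = "some-reranking-model"
        filtering_threshold: float = 0.5
        num_docs_to_keep: Optional[int] = None  # None -> dropped by exclude_none

    kwargs = DemoConfig(filtering_threshold=0.7).model_dump(exclude_none=True)
    # {'model': 'some-reranking-model', 'filtering_threshold': 0.7}
    # num_docs_to_keep is omitted, so the consumer's own default still applies.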
@@ -102,17 +116,45 @@
             raise ValueError("One of the required components (llm) is None")
 
         if self.reranker:
-            reranker = self.reranker
-            retriever = copy(self.retriever_runnable)
-            self.retriever_runnable = ContextualCompressionRetriever(
-                base_compressor=reranker, base_retriever=retriever
+            # Create a custom retriever that handles async operations properly
+            class AsyncRerankerRetriever(ContextualCompressionRetriever):
+                """Async-aware retriever that properly handles concurrent reranking operations."""
+
+                def __init__(self, base_retriever, reranker):
+                    super().__init__(
+                        base_compressor=reranker,
+                        base_retriever=base_retriever
+                    )
+
+                async def ainvoke(self, query: str) -> List[Document]:
+                    """Async retrieval with proper concurrency handling."""
+                    # Get initial documents
+                    if hasattr(self.base_retriever, 'ainvoke'):
+                        docs = await self.base_retriever.ainvoke(query)
+                    else:
+                        docs = await RunnablePassthrough(self.base_retriever.get_relevant_documents)(query)
+
+                    # Rerank documents
+                    if docs:
+                        docs = await self.base_compressor.acompress_documents(docs, query)
+                    return docs
+
+                def get_relevant_documents(self, query: str) -> List[Document]:
+                    """Sync wrapper for async retrieval."""
+                    import asyncio
+                    return asyncio.run(self.ainvoke(query))
+
+            # Use our custom async-aware retriever
+            self.retriever_runnable = AsyncRerankerRetriever(
+                base_retriever=copy(self.retriever_runnable),
+                reranker=self.reranker
             )
 
         rag_chain_from_docs = (
-            RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))  # noqa: E126, E122
-            | prompt
-            | self.llm
-            | StrOutputParser()
+            RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
+            | prompt
+            | self.llm
+            | StrOutputParser()
         )
 
         retrieval_chain = RunnableParallel(
@@ -125,6 +167,16 @@
         rag_chain_with_source = retrieval_chain.assign(answer=rag_chain_from_docs)
         return rag_chain_with_source
 
+    async def ainvoke(self, input_dict: dict) -> dict:
+        """Async invocation of the RAG pipeline."""
+        chain = self.with_returned_sources()
+        return await chain.ainvoke(input_dict)
+
+    def invoke(self, input_dict: dict) -> dict:
+        """Sync invocation of the RAG pipeline."""
+        import asyncio
+        return asyncio.run(self.ainvoke(input_dict))
+
     @classmethod
     def _apply_search_kwargs(cls, retriever: Any, search_kwargs: Optional[SearchKwargs] = None, search_type: Optional[SearchType] = None) -> Any:
         """Apply search kwargs and search type to the retriever if they exist"""
@@ -235,6 +287,10 @@
             )
             vector_store_retriever = vector_store_operator.vector_store.as_retriever()
             vector_store_retriever = cls._apply_search_kwargs(vector_store_retriever, config.search_kwargs, config.search_type)
+            distance_function = DistanceFunction.SQUARED_EUCLIDEAN_DISTANCE
+            if config.vector_store_config.is_sparse and config.vector_store_config.vector_size is not None:
+                # Use negative dot product for sparse retrieval.
+                distance_function = DistanceFunction.NEGATIVE_DOT_PRODUCT
             retriever = SQLRetriever(
                 fallback_retriever=vector_store_retriever,
                 vector_store_handler=knowledge_base_table.get_vector_db(),
@@ -248,8 +304,7 @@
                 query_checker_template=retriever_config.query_checker_template,
                 embeddings_table=knowledge_base_table._kb.vector_database_table,
                 source_table=retriever_config.source_table,
-                # Currently only similarity search is supported.
-                distance_function=DistanceFunction.SQUARED_EUCLIDEAN_DISTANCE,
+                distance_function=distance_function,
                 search_kwargs=config.search_kwargs,
                 llm=sql_llm
             )

mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py

@@ -4,16 +4,15 @@ import asyncio
 import logging
 import math
 import os
+import random
 from typing import Any, Dict, List, Optional, Sequence, Tuple
-from uuid import uuid4
+
 from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
 from langchain_core.callbacks import Callbacks
-
-from mindsdb.integrations.utilities.rag.settings import DEFAULT_RERANKING_MODEL, DEFAULT_LLM_ENDPOINT
 from langchain_core.documents import Document
-from langchain_core.messages import HumanMessage, SystemMessage
-from langchain_openai import ChatOpenAI
+from openai import AsyncOpenAI
 
+from mindsdb.integrations.utilities.rag.settings import DEFAULT_RERANKING_MODEL, DEFAULT_LLM_ENDPOINT
 
 log = logging.getLogger(__name__)
 
@@ -23,128 +22,187 @@ class LLMReranker(BaseDocumentCompressor):
     model: str = DEFAULT_RERANKING_MODEL  # Model to use for reranking
     temperature: float = 0.0  # Temperature for the model
     openai_api_key: Optional[str] = None
-    remove_irrelevant: bool = True  # New flag to control removal of irrelevant documents,
+    remove_irrelevant: bool = True  # New flag to control removal of irrelevant documents
     base_url: str = DEFAULT_LLM_ENDPOINT
     num_docs_to_keep: Optional[int] = None  # How many of the top documents to keep after reranking & compressing.
-
     _api_key_var: str = "OPENAI_API_KEY"
-    client: Optional[Any] = None
+    client: Optional[AsyncOpenAI] = None
+    _semaphore: Optional[asyncio.Semaphore] = None
+    max_concurrent_requests: int = 20
+    max_retries: int = 3
+    retry_delay: float = 1.0
+    request_timeout: float = 20.0  # Timeout for API requests
+    early_stop: bool = True  # Whether to enable early stopping
+    early_stop_threshold: float = 0.8  # Confidence threshold for early stopping
 
     class Config:
         arbitrary_types_allowed = True
 
-    async def search_relevancy(self, query: str, document: str) -> Any:
-        openai_api_key = self.openai_api_key or os.getenv(self._api_key_var)
-
-        # Initialize the ChatOpenAI client
-        client = ChatOpenAI(openai_api_base=self.base_url, api_key=openai_api_key, model=self.model, temperature=0,
-                            logprobs=True)
-
-        # Create the message history for the conversation
-        message_history = [
-            SystemMessage(
-                content="""Your task is to classify whether the document is relevant to the search query provided below. Answer just "YES" or "NO"."""),
-            HumanMessage(content=f"""Document: ```{document}```; Search query: ```{query}```""")
-        ]
-
-        # Generate the response using LangChain's chat model
-        response = await client.agenerate(
-            messages=[message_history],
-            max_tokens=1
-        )
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._semaphore = asyncio.Semaphore(self.max_concurrent_requests)
+
+    async def _init_client(self):
+        if self.client is None:
+            openai_api_key = self.openai_api_key or os.getenv(self._api_key_var)
+            if not openai_api_key:
+                raise ValueError(f"OpenAI API key not found in environment variable {self._api_key_var}")
+            self.client = AsyncOpenAI(
+                api_key=openai_api_key,
+                base_url=self.base_url,
+                timeout=self.request_timeout,
+                max_retries=2  # Client-level retries
+            )
 
-        # Return the response from the model
-        return response.generations[0]
+    async def search_relevancy(self, query: str, document: str) -> Any:
+        await self._init_client()
+
+        async with self._semaphore:
+            for attempt in range(self.max_retries):
+                try:
+                    response = await self.client.chat.completions.create(
+                        model=self.model,
+                        messages=[
+                            {"role": "system", "content": "Rate the relevance of the document to the query. Respond with 'yes' or 'no'."},
+                            {"role": "user", "content": f"Query: {query}\nDocument: {document}\nIs this document relevant?"}
+                        ],
+                        temperature=self.temperature,
+                        n=1,
+                        logprobs=True,
+                        max_tokens=1
+                    )
+
+                    # Extract response and logprobs
+                    answer = response.choices[0].message.content
+                    logprob = response.choices[0].logprobs.content[0].logprob
+
+                    return {"answer": answer, "logprob": logprob}
+
+                except Exception as e:
+                    if attempt == self.max_retries - 1:
+                        log.error(f"Failed after {self.max_retries} attempts: {str(e)}")
+                        raise
+                    # Exponential backoff with jitter
+                    retry_delay = self.retry_delay * (2 ** attempt) + random.uniform(0, 0.1)
+                    await asyncio.sleep(retry_delay)
 
     async def _rank(self, query_document_pairs: List[Tuple[str, str]]) -> List[Tuple[str, float]]:
-        # Gather results asynchronously for all query-document pairs
-        results = await asyncio.gather(
-            *[self.search_relevancy(query=query, document=document) for (query, document) in query_document_pairs]
-        )
-
         ranked_results = []
 
-        for idx, result in enumerate(results):
-            # Extract the log probability (assuming logprobs are provided in LangChain response)
-            msg = result[0].message
-            logprob = msg.response_metadata['logprobs']['content'][0]['logprob']
-            prob = math.exp(logprob)
-            answer = result[0].message.content  # The model's "YES" or "NO" response
-
-            # Calculate the score based on the model's response
-            if answer == "YES":
-                score = prob
-            elif answer.lower().strip().startswith("y"):
-                score = prob
-            elif answer == "NO":
-                score = 1 - prob
-            elif answer.lower().strip().startswith("n"):
-                score = 1 - prob
-            else:
-                score = 0.0  # Default if something unexpected happens
-
-            # Append the document and score to the result
-            ranked_results.append((query_document_pairs[idx][1], score))  # (document, score)
+        # Process in larger batches for better throughput
+        batch_size = min(self.max_concurrent_requests * 2, len(query_document_pairs))
+        for i in range(0, len(query_document_pairs), batch_size):
+            batch = query_document_pairs[i:i + batch_size]
+            try:
+                results = await asyncio.gather(
+                    *[self.search_relevancy(query=query, document=document) for (query, document) in batch],
+                    return_exceptions=True
+                )
+
+                for idx, result in enumerate(results):
+                    if isinstance(result, Exception):
+                        log.error(f"Error processing document {i+idx}: {str(result)}")
+                        ranked_results.append((batch[idx][1], 0.0))
+                        continue
+
+                    answer = result["answer"]
+                    logprob = result["logprob"]
+                    prob = math.exp(logprob)
+
+                    # Convert answer to score using the model's confidence
+                    if answer.lower().strip() == "yes":
+                        score = prob  # If yes, use the model's confidence
+                    elif answer.lower().strip() == "no":
+                        score = 1 - prob  # If no, invert the confidence
+                    else:
+                        score = 0.5 * prob  # For unclear answers, reduce confidence
+
+                    ranked_results.append((batch[idx][1], score))
+
+                    # Check if we should stop early
+                    high_scoring_docs = [r for r in ranked_results if r[1] >= self.filtering_threshold]
+                    can_stop_early = (
+                        self.early_stop  # Early stopping is enabled
+                        and self.num_docs_to_keep  # We have a target number of docs
+                        and len(high_scoring_docs) >= self.num_docs_to_keep  # Found enough good docs
+                        and score >= self.early_stop_threshold  # Current doc is good enough
+                    )
+
+                    if can_stop_early:
+                        log.info(f"Early stopping after finding {self.num_docs_to_keep} documents with high confidence")
+                        return ranked_results
+
+            except Exception as e:
+                log.error(f"Batch processing error: {str(e)}")
+                continue
 
         return ranked_results
 
-    def compress_documents(
-        self,
-        documents: Sequence[Document],
-        query: str,
-        callbacks: Optional[Callbacks] = None,
+    async def acompress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
     ) -> Sequence[Document]:
-        """Compress documents using OpenAI's rerank capability with individual document assessment."""
-        log.info(f"Compressing documents. Initial count: {len(documents)}")
-        if len(documents) == 0:
-            log.warning("No documents to compress. Returning empty list.")
+        """Async compress documents using reranking with proper error handling."""
+        if callbacks:
+            await callbacks.on_retriever_start({"query": query}, "Reranking documents")
+
+        log.info(f"Async compressing documents. Initial count: {len(documents)}")
+        if not documents:
+            if callbacks:
+                await callbacks.on_retriever_end({"documents": []})
             return []
 
-        doc_contents = [doc.page_content for doc in documents]
-        query_documents_pairs = [(query, doc) for doc in doc_contents]
-
-        # Create event loop and run async code
-        import asyncio
         try:
-            rankings = asyncio.get_event_loop().run_until_complete(self._rank(query_documents_pairs))
-        except RuntimeError:
-            # If no event loop is available, create a new one
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-            rankings = loop.run_until_complete(self._rank(query_documents_pairs))
-
-        compressed = []
-        for ind, ranking in enumerate(rankings):
-            doc = documents[ind]
-            document_text, score = ranking
-            doc.metadata["relevance_score"] = score
-            doc.metadata["is_relevant"] = score > self.filtering_threshold
-            # Add the document to the compressed list if it is relevant or if we are not removing irrelevant documents
-            if not self.remove_irrelevant:
-                compressed.append(doc)
-            elif doc.metadata["is_relevant"]:
-                compressed.append(doc)
-
-        log.info(f"Compression complete. {len(compressed)} documents returned")
-        if not compressed:
-            log.warning("No documents found after compression")
-
-        if self.num_docs_to_keep is not None:
-            # Sort by relevance score with highest first.
-            compressed.sort(
-                key=lambda d: d.metadata.get('relevance_score', 0) if d.metadata else 0,
-                reverse=True
-            )
-            compressed = compressed[:self.num_docs_to_keep]
-
-        # Handle retrieval callbacks to account for reranked & compressed docs.
-        callbacks = callbacks if callbacks else []
-        run_id = uuid4().hex
-        if not isinstance(callbacks, list):
-            callbacks = callbacks.handlers
-        for callback in callbacks:
-            callback.on_retriever_end(compressed, run_id=run_id)
-        return compressed
+            # Prepare query-document pairs
+            query_document_pairs = [(query, doc.page_content) for doc in documents]
+
+            if callbacks:
+                await callbacks.on_text("Starting document reranking...")
+
+            # Get ranked results
+            ranked_results = await self._rank(query_document_pairs)
+
+            # Sort by score in descending order
+            ranked_results.sort(key=lambda x: x[1], reverse=True)
+
+            # Filter based on threshold and num_docs_to_keep
+            filtered_docs = []
+            for doc, score in ranked_results:
+                if score >= self.filtering_threshold:
+                    matching_doc = next(d for d in documents if d.page_content == doc)
+                    matching_doc.metadata = {**(matching_doc.metadata or {}), "relevance_score": score}
+                    filtered_docs.append(matching_doc)
+
+                    if callbacks:
+                        await callbacks.on_text(f"Document scored {score:.2f}")
+
+                if self.num_docs_to_keep and len(filtered_docs) >= self.num_docs_to_keep:
+                    break
+
+            log.info(f"Async compression complete. Final count: {len(filtered_docs)}")
+
+            if callbacks:
+                await callbacks.on_retriever_end({"documents": filtered_docs})
+
+            return filtered_docs
+
+        except Exception as e:
+            error_msg = f"Error during async document compression: {str(e)}"
+            log.error(error_msg)
+            if callbacks:
+                await callbacks.on_retriever_error(error_msg)
+            return documents  # Return original documents on error
+
+    def compress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
+    ) -> Sequence[Document]:
+        """Sync wrapper for async compression."""
+        return asyncio.run(self.acompress_documents(documents, query, callbacks))
 
     @property
     def _identifying_params(self) -> Dict[str, Any]:
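
For reference, the scoring logic in the new _rank() converts the single yes/no completion token and its log-probability into a relevance score. The arithmetic restated as a small self-contained function (relevance_score is an illustrative name; the example values use exp(-0.105) ≈ 0.90):

    import math

    def relevance_score(answer: str, logprob: float) -> float:
        """Turn a yes/no token and its logprob into a 0..1 relevance score."""
        prob = math.exp(logprob)  # token probability
        token = answer.lower().strip()
        if token == "yes":
            return prob        # confident "yes" -> high score
        if token == "no":
            return 1 - prob    # confident "no" -> low score
        return 0.5 * prob      # unexpected token -> damped score

    assert abs(relevance_score("yes", -0.105) - 0.900) < 1e-3
    assert abs(relevance_score("no", -0.105) - 0.100) < 1e-3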