MindsDB 25.1.2.1__py3-none-any.whl → 25.1.4.0__py3-none-any.whl
This diff shows the changes between two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of MindsDB has been flagged as possibly problematic.
- {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.4.0.dist-info}/METADATA +244 -242
- {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.4.0.dist-info}/RECORD +76 -67
- mindsdb/__about__.py +1 -1
- mindsdb/__main__.py +5 -3
- mindsdb/api/executor/__init__.py +0 -1
- mindsdb/api/executor/command_executor.py +2 -1
- mindsdb/api/executor/data_types/answer.py +1 -1
- mindsdb/api/executor/datahub/datanodes/integration_datanode.py +7 -2
- mindsdb/api/executor/datahub/datanodes/project_datanode.py +8 -1
- mindsdb/api/executor/sql_query/__init__.py +1 -0
- mindsdb/api/executor/sql_query/result_set.py +36 -21
- mindsdb/api/executor/sql_query/steps/apply_predictor_step.py +1 -1
- mindsdb/api/executor/sql_query/steps/join_step.py +4 -4
- mindsdb/api/executor/sql_query/steps/map_reduce_step.py +6 -39
- mindsdb/api/executor/utilities/sql.py +2 -10
- mindsdb/api/http/namespaces/knowledge_bases.py +3 -3
- mindsdb/api/http/namespaces/sql.py +3 -1
- mindsdb/api/mysql/mysql_proxy/executor/mysql_executor.py +2 -1
- mindsdb/api/mysql/mysql_proxy/mysql_proxy.py +7 -0
- mindsdb/api/postgres/postgres_proxy/executor/executor.py +2 -1
- mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +2 -2
- mindsdb/integrations/handlers/chromadb_handler/requirements.txt +1 -1
- mindsdb/integrations/handlers/file_handler/file_handler.py +1 -1
- mindsdb/integrations/handlers/file_handler/tests/test_file_handler.py +17 -1
- mindsdb/integrations/handlers/jira_handler/jira_handler.py +15 -1
- mindsdb/integrations/handlers/jira_handler/jira_table.py +52 -31
- mindsdb/integrations/handlers/langchain_embedding_handler/fastapi_embeddings.py +82 -0
- mindsdb/integrations/handlers/langchain_embedding_handler/langchain_embedding_handler.py +8 -1
- mindsdb/integrations/handlers/langchain_handler/requirements.txt +1 -1
- mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +48 -16
- mindsdb/integrations/handlers/pinecone_handler/pinecone_handler.py +123 -72
- mindsdb/integrations/handlers/pinecone_handler/requirements.txt +1 -1
- mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +12 -6
- mindsdb/integrations/handlers/slack_handler/slack_handler.py +13 -2
- mindsdb/integrations/handlers/slack_handler/slack_tables.py +21 -1
- mindsdb/integrations/libs/ml_handler_process/learn_process.py +1 -1
- mindsdb/integrations/utilities/rag/loaders/vector_store_loader/pgvector.py +76 -27
- mindsdb/integrations/utilities/rag/loaders/vector_store_loader/vector_store_loader.py +18 -1
- mindsdb/integrations/utilities/rag/pipelines/rag.py +73 -18
- mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +166 -108
- mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py +36 -14
- mindsdb/integrations/utilities/rag/settings.py +8 -2
- mindsdb/integrations/utilities/sql_utils.py +1 -1
- mindsdb/interfaces/agents/agents_controller.py +3 -5
- mindsdb/interfaces/agents/langchain_agent.py +112 -150
- mindsdb/interfaces/agents/langfuse_callback_handler.py +0 -37
- mindsdb/interfaces/agents/mindsdb_database_agent.py +15 -13
- mindsdb/interfaces/chatbot/chatbot_controller.py +7 -11
- mindsdb/interfaces/chatbot/chatbot_task.py +16 -5
- mindsdb/interfaces/chatbot/memory.py +58 -13
- mindsdb/interfaces/database/projects.py +17 -15
- mindsdb/interfaces/database/views.py +12 -25
- mindsdb/interfaces/knowledge_base/controller.py +39 -15
- mindsdb/interfaces/model/functions.py +15 -4
- mindsdb/interfaces/model/model_controller.py +4 -7
- mindsdb/interfaces/skills/custom/text2sql/mindsdb_sql_toolkit.py +47 -38
- mindsdb/interfaces/skills/retrieval_tool.py +10 -3
- mindsdb/interfaces/skills/skill_tool.py +97 -53
- mindsdb/interfaces/skills/sql_agent.py +77 -36
- mindsdb/interfaces/storage/db.py +1 -1
- mindsdb/migrations/versions/2025-01-15_c06c35f7e8e1_project_company.py +88 -0
- mindsdb/utilities/cache.py +7 -4
- mindsdb/utilities/context.py +11 -1
- mindsdb/utilities/langfuse.py +264 -0
- mindsdb/utilities/log.py +20 -2
- mindsdb/utilities/otel/__init__.py +206 -0
- mindsdb/utilities/otel/logger.py +25 -0
- mindsdb/utilities/otel/meter.py +19 -0
- mindsdb/utilities/otel/metric_handlers/__init__.py +25 -0
- mindsdb/utilities/otel/tracer.py +16 -0
- mindsdb/utilities/partitioning.py +52 -0
- mindsdb/utilities/render/sqlalchemy_render.py +7 -1
- mindsdb/utilities/utils.py +34 -0
- mindsdb/utilities/otel.py +0 -72
- {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.4.0.dist-info}/LICENSE +0 -0
- {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.4.0.dist-info}/WHEEL +0 -0
- {MindsDB-25.1.2.1.dist-info → MindsDB-25.1.4.0.dist-info}/top_level.txt +0 -0
mindsdb/integrations/utilities/rag/loaders/vector_store_loader/pgvector.py

@@ -1,9 +1,9 @@
-from typing import Any, List, Optional, Dict
+from typing import Any, List, Union, Optional, Dict
 
 from langchain_community.vectorstores import PGVector
 from langchain_community.vectorstores.pgvector import Base
 
-from pgvector.sqlalchemy import Vector
+from pgvector.sqlalchemy import SPARSEVEC, Vector
 import sqlalchemy as sa
 from sqlalchemy.dialects.postgresql import JSON
 

@@ -15,9 +15,17 @@ _generated_sa_tables = {}
 
 class PGVectorMDB(PGVector):
     """
- …
+    langchain_community.vectorstores.PGVector adapted for mindsdb vector store table structure
     """
 
+    def __init__(self, *args, is_sparse: bool = False, vector_size: Optional[int] = None, **kwargs):
+        # todo get is_sparse and vector_size from kb vector table
+        self.is_sparse = is_sparse
+        if is_sparse and vector_size is None:
+            raise ValueError("vector_size is required when is_sparse=True")
+        self.vector_size = vector_size
+        super().__init__(*args, **kwargs)
+
     def __post_init__(
         self,
     ) -> None:

@@ -32,53 +40,94 @@ class PGVectorMDB(PGVector):
             __tablename__ = collection_name
 
             id = sa.Column(sa.Integer, primary_key=True)
-            embedding
- …
+            embedding = sa.Column(
+                "embeddings",
+                SPARSEVEC() if self.is_sparse else Vector() if self.vector_size is None else
+                SPARSEVEC(self.vector_size) if self.is_sparse else Vector(self.vector_size)
+            )
+            document = sa.Column("content", sa.String, nullable=True)
+            cmetadata = sa.Column("metadata", JSON, nullable=True)
 
         _generated_sa_tables[collection_name] = EmbeddingStore
 
        self.EmbeddingStore = _generated_sa_tables[collection_name]
 
    def __query_collection(
- …
+        self,
+        embedding: Union[List[float], Dict[int, float], str],
+        k: int = 4,
+        filter: Optional[Dict[str, str]] = None,
    ) -> List[Any]:
        """Query the collection."""
        with Session(self._bind) as session:
- …
+            if self.is_sparse:
+                # Sparse vectors: expect string in format "{key:value,...}/size" or dictionary
+                if isinstance(embedding, dict):
+                    from pgvector.utils import SparseVector
+                    embedding = SparseVector(embedding, self.vector_size)
+                    embedding_str = embedding.to_text()
+                elif isinstance(embedding, str):
+                    # Use string as is - it should already be in the correct format
+                    embedding_str = embedding
+                # Use inner product for sparse vectors
+                distance_op = "<#>"
+                # For inner product, larger values are better matches
+                order_direction = "ASC"
+            else:
+                # Dense vectors: expect string in JSON array format or list of floats
+                if isinstance(embedding, list):
+                    embedding_str = f"[{','.join(str(x) for x in embedding)}]"
+                elif isinstance(embedding, str):
+                    embedding_str = embedding
+                # Use cosine similarity for dense vectors
+                distance_op = "<=>"
+                # For cosine similarity, smaller values are better matches
+                order_direction = "ASC"
+
+            # Use SQL directly for vector comparison
+            query = sa.text(
+                f"""
+                SELECT t.*, t.embeddings {distance_op} '{embedding_str}' as distance
+                FROM {self.collection_name} t
+                ORDER BY distance {order_direction}
+                LIMIT {k}
+                """
            )
- …
+            results = session.execute(query).all()
+
+            # Convert results to the expected format
+            formatted_results = []
+            for rec in results:
+                metadata = rec.metadata if bool(rec.metadata) else {0: 0}
+                embedding_store = self.EmbeddingStore()
+                embedding_store.document = rec.content
+                embedding_store.cmetadata = metadata
+                result = type(
+                    'Result', (), {
+                        'EmbeddingStore': embedding_store,
+                        'distance': rec.distance
+                    }
+                )
+                formatted_results.append(result)
 
- …
+            return formatted_results
 
    # aliases for different langchain versions
    def _PGVector__query_collection(self, *args, **kwargs):
+
        return self.__query_collection(*args, **kwargs)
 
    def _query_collection(self, *args, **kwargs):
        return self.__query_collection(*args, **kwargs)
 
    def create_collection(self):
-        raise RuntimeError(
+        raise RuntimeError("Forbidden")
 
    def delete_collection(self):
-        raise RuntimeError(
+        raise RuntimeError("Forbidden")
 
    def delete(self, *args, **kwargs):
-        raise RuntimeError(
+        raise RuntimeError("Forbidden")
 
    def add_embeddings(self, *args, **kwargs):
-        raise RuntimeError(
+        raise RuntimeError("Forbidden")
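To make the new constructor surface concrete, here is a hedged usage sketch. It assumes a reachable Postgres instance with the pgvector extension and an existing MindsDB-style embeddings table; the connection string, collection name, and the FakeEmbeddings stand-in are illustrative placeholders, not values taken from the diff.

```python
# Illustrative sketch only: the DSN, table name, and embedding model below are
# placeholders; only the keyword arguments mirror the diff above.
from langchain_community.embeddings import FakeEmbeddings

from mindsdb.integrations.utilities.rag.loaders.vector_store_loader.pgvector import PGVectorMDB

store = PGVectorMDB(
    connection_string="postgresql+psycopg2://user:pass@localhost:5432/mindsdb",  # placeholder DSN
    collection_name="my_kb_embeddings",            # placeholder embeddings table
    embedding_function=FakeEmbeddings(size=384),   # stand-in for a real embedding model
    is_sparse=True,        # new in this release: query sparse embeddings
    vector_size=30522,     # required whenever is_sparse=True (see the ValueError above)
)

# Standard LangChain PGVector API; dense stores use the cosine operator "<=>",
# sparse stores switch to the inner-product operator "<#>" in __query_collection.
docs = store.similarity_search("what changed in 25.1.4.0?", k=4)
```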
mindsdb/integrations/utilities/rag/loaders/vector_store_loader/vector_store_loader.py

@@ -7,6 +7,7 @@ from pydantic import BaseModel
 
 from mindsdb.integrations.utilities.rag.settings import VectorStoreType, VectorStoreConfig
 from mindsdb.integrations.utilities.rag.loaders.vector_store_loader.MDBVectorStore import MDBVectorStore
+from mindsdb.integrations.utilities.rag.loaders.vector_store_loader.pgvector import PGVectorMDB
 from mindsdb.utilities import log
 
 

@@ -28,6 +29,20 @@ class VectorStoreLoader(BaseModel):
         Loads the vector store based on the provided config and embeddings model
         :return:
         """
+        if self.config.is_sparse is not None and self.config.vector_size is not None and self.config.kb_table is not None:
+            # Only use PGVector store for sparse vectors.
+            db_handler = self.config.kb_table.get_vector_db()
+            db_args = db_handler.connection_args
+            # Assume we are always using PGVector & psycopg2.
+            connection_str = f"postgresql+psycopg2://{db_args.get('user')}:{db_args.get('password')}@{db_args.get('host')}:{db_args.get('port')}/{db_args.get('dbname', db_args.get('database'))}"
+
+            return PGVectorMDB(
+                connection_string=connection_str,
+                collection_name=self.config.kb_table._kb.vector_database_table,
+                embedding_function=self.embedding_model,
+                is_sparse=self.config.is_sparse,
+                vector_size=self.config.vector_size
+            )
         return MDBVectorStore(kb_table=self.config.kb_table)
 
 

@@ -56,5 +71,7 @@ class VectorStoreFactory:
         return PGVectorMDB(
             connection_string=settings.connection_string,
             collection_name=settings.collection_name,
-            embedding_function=embedding_model
+            embedding_function=embedding_model,
+            is_sparse=settings.is_sparse,
+            vector_size=settings.vector_size
         )
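For context, the new branch above builds a psycopg2 DSN directly from the vector database handler's connection arguments. A minimal, runnable sketch of that string construction follows; the argument values are made-up examples, not data from the diff.

```python
# Illustration of the connection string built by the new PGVector branch above.
# The db_args values are invented examples; the f-string mirrors the diff.
db_args = {
    "user": "mdb",
    "password": "secret",
    "host": "localhost",
    "port": 5432,
    "dbname": "vector_store",
}

connection_str = (
    f"postgresql+psycopg2://{db_args.get('user')}:{db_args.get('password')}"
    f"@{db_args.get('host')}:{db_args.get('port')}"
    f"/{db_args.get('dbname', db_args.get('database'))}"
)

print(connection_str)  # postgresql+psycopg2://mdb:secret@localhost:5432/vector_store
```

Note that `db_args.get('dbname', db_args.get('database'))` accepts either `dbname` or `database` as the key for the database name, matching the two spellings used by Postgres-style handlers.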
mindsdb/integrations/utilities/rag/pipelines/rag.py

@@ -1,8 +1,9 @@
 from copy import copy
-from typing import Optional, Any
+from typing import Optional, Any, List
 
 from langchain_core.output_parsers import StrOutputParser
 from langchain.retrievers import ContextualCompressionRetriever
+from langchain_core.documents import Document
 
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableSerializable

@@ -28,6 +29,23 @@ from mindsdb.interfaces.agents.langchain_agent import create_chat_model
 class LangChainRAGPipeline:
     """
     Builds a RAG pipeline using langchain LCEL components
+
+    Args:
+        retriever_runnable: Base retriever component
+        prompt_template: Template for generating responses
+        llm: Language model for generating responses
+        reranker (bool): Whether to use reranking (default: False)
+        reranker_config (RerankerConfig): Configuration for the reranker, including:
+            - model: Model to use for reranking
+            - filtering_threshold: Minimum score to keep a document
+            - num_docs_to_keep: Maximum number of documents to keep
+            - max_concurrent_requests: Maximum concurrent API requests
+            - max_retries: Number of retry attempts for failed requests
+            - retry_delay: Delay between retries
+            - early_stop (bool): Whether to enable early stopping
+            - early_stop_threshold: Confidence threshold for early stopping
+        vector_store_config (VectorStoreConfig): Vector store configuration
+        summarization_config (SummarizationConfig): Summarization configuration
     """
 
     def __init__(

@@ -40,19 +58,15 @@ class LangChainRAGPipeline:
         vector_store_config: Optional[VectorStoreConfig] = None,
         summarization_config: Optional[SummarizationConfig] = None
     ):
- …
         self.retriever_runnable = retriever_runnable
         self.prompt_template = prompt_template
         self.llm = llm
         if reranker:
             if reranker_config is None:
                 reranker_config = RerankerConfig()
- …
-                filtering_threshold=reranker_config.filtering_threshold,
-                num_docs_to_keep=reranker_config.num_docs_to_keep
-            )
+            # Convert config to dict and initialize reranker
+            reranker_kwargs = reranker_config.model_dump(exclude_none=True)
+            self.reranker = LLMReranker(**reranker_kwargs)
         else:
             self.reranker = None
         self.summarizer = None

@@ -102,17 +116,45 @@ class LangChainRAGPipeline:
             raise ValueError("One of the required components (llm) is None")
 
         if self.reranker:
- …
+            # Create a custom retriever that handles async operations properly
+            class AsyncRerankerRetriever(ContextualCompressionRetriever):
+                """Async-aware retriever that properly handles concurrent reranking operations."""
+
+                def __init__(self, base_retriever, reranker):
+                    super().__init__(
+                        base_compressor=reranker,
+                        base_retriever=base_retriever
+                    )
+
+                async def ainvoke(self, query: str) -> List[Document]:
+                    """Async retrieval with proper concurrency handling."""
+                    # Get initial documents
+                    if hasattr(self.base_retriever, 'ainvoke'):
+                        docs = await self.base_retriever.ainvoke(query)
+                    else:
+                        docs = await RunnablePassthrough(self.base_retriever.get_relevant_documents)(query)
+
+                    # Rerank documents
+                    if docs:
+                        docs = await self.base_compressor.acompress_documents(docs, query)
+                    return docs
+
+                def get_relevant_documents(self, query: str) -> List[Document]:
+                    """Sync wrapper for async retrieval."""
+                    import asyncio
+                    return asyncio.run(self.ainvoke(query))
+
+            # Use our custom async-aware retriever
+            self.retriever_runnable = AsyncRerankerRetriever(
+                base_retriever=copy(self.retriever_runnable),
+                reranker=self.reranker
            )
 
         rag_chain_from_docs = (
- …
+            RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
+            | prompt
+            | self.llm
+            | StrOutputParser()
         )
 
         retrieval_chain = RunnableParallel(

@@ -125,6 +167,16 @@ class LangChainRAGPipeline:
         rag_chain_with_source = retrieval_chain.assign(answer=rag_chain_from_docs)
         return rag_chain_with_source
 
+    async def ainvoke(self, input_dict: dict) -> dict:
+        """Async invocation of the RAG pipeline."""
+        chain = self.with_returned_sources()
+        return await chain.ainvoke(input_dict)
+
+    def invoke(self, input_dict: dict) -> dict:
+        """Sync invocation of the RAG pipeline."""
+        import asyncio
+        return asyncio.run(self.ainvoke(input_dict))
+
     @classmethod
     def _apply_search_kwargs(cls, retriever: Any, search_kwargs: Optional[SearchKwargs] = None, search_type: Optional[SearchType] = None) -> Any:
         """Apply search kwargs and search type to the retriever if they exist"""

@@ -235,6 +287,10 @@ class LangChainRAGPipeline:
         )
         vector_store_retriever = vector_store_operator.vector_store.as_retriever()
         vector_store_retriever = cls._apply_search_kwargs(vector_store_retriever, config.search_kwargs, config.search_type)
+        distance_function = DistanceFunction.SQUARED_EUCLIDEAN_DISTANCE
+        if config.vector_store_config.is_sparse and config.vector_store_config.vector_size is not None:
+            # Use negative dot product for sparse retrieval.
+            distance_function = DistanceFunction.NEGATIVE_DOT_PRODUCT
         retriever = SQLRetriever(
             fallback_retriever=vector_store_retriever,
             vector_store_handler=knowledge_base_table.get_vector_db(),

@@ -248,8 +304,7 @@
             query_checker_template=retriever_config.query_checker_template,
             embeddings_table=knowledge_base_table._kb.vector_database_table,
             source_table=retriever_config.source_table,
- …
-            distance_function=DistanceFunction.SQUARED_EUCLIDEAN_DISTANCE,
+            distance_function=distance_function,
             search_kwargs=config.search_kwargs,
             llm=sql_llm
         )
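The constructor change above replaces individually passed reranker arguments with a pydantic `model_dump(exclude_none=True)` expansion into `LLMReranker(**kwargs)`. A minimal sketch of that pattern follows; `RerankerConfigSketch` and its defaults are invented stand-ins for MindsDB's actual RerankerConfig, which has more fields.

```python
# Sketch of the config-to-kwargs expansion used in __init__ above, assuming a
# pydantic model whose field names match LLMReranker's attributes.
from typing import Optional
from pydantic import BaseModel


class RerankerConfigSketch(BaseModel):   # stand-in for mindsdb's RerankerConfig
    model: Optional[str] = None
    filtering_threshold: float = 0.5     # invented default for illustration
    num_docs_to_keep: Optional[int] = None
    early_stop: bool = True


config = RerankerConfigSketch(num_docs_to_keep=5)
kwargs = config.model_dump(exclude_none=True)   # unset optional fields are dropped
print(kwargs)  # {'filtering_threshold': 0.5, 'num_docs_to_keep': 5, 'early_stop': True}
# In the pipeline these kwargs are passed straight to LLMReranker(**kwargs).
```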
mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py

@@ -4,16 +4,15 @@ import asyncio
 import logging
 import math
 import os
+import random
 from typing import Any, Dict, List, Optional, Sequence, Tuple
- …
+
 from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
 from langchain_core.callbacks import Callbacks
- …
-from mindsdb.integrations.utilities.rag.settings import DEFAULT_RERANKING_MODEL, DEFAULT_LLM_ENDPOINT
 from langchain_core.documents import Document
-from
-from langchain_openai import ChatOpenAI
+from openai import AsyncOpenAI
 
+from mindsdb.integrations.utilities.rag.settings import DEFAULT_RERANKING_MODEL, DEFAULT_LLM_ENDPOINT
 
 log = logging.getLogger(__name__)
 

@@ -23,128 +22,187 @@ class LLMReranker(BaseDocumentCompressor):
     model: str = DEFAULT_RERANKING_MODEL  # Model to use for reranking
     temperature: float = 0.0  # Temperature for the model
     openai_api_key: Optional[str] = None
-    remove_irrelevant: bool = True # New flag to control removal of irrelevant documents
+    remove_irrelevant: bool = True  # New flag to control removal of irrelevant documents
     base_url: str = DEFAULT_LLM_ENDPOINT
     num_docs_to_keep: Optional[int] = None  # How many of the top documents to keep after reranking & compressing.
- …
     _api_key_var: str = "OPENAI_API_KEY"
-    client: Optional[
+    client: Optional[AsyncOpenAI] = None
+    _semaphore: Optional[asyncio.Semaphore] = None
+    max_concurrent_requests: int = 20
+    max_retries: int = 3
+    retry_delay: float = 1.0
+    request_timeout: float = 20.0  # Timeout for API requests
+    early_stop: bool = True  # Whether to enable early stopping
+    early_stop_threshold: float = 0.8  # Confidence threshold for early stopping
 
     class Config:
         arbitrary_types_allowed = True
 
- …
-        response = await client.agenerate(
-            messages=[message_history],
-            max_tokens=1
-        )
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._semaphore = asyncio.Semaphore(self.max_concurrent_requests)
+
+    async def _init_client(self):
+        if self.client is None:
+            openai_api_key = self.openai_api_key or os.getenv(self._api_key_var)
+            if not openai_api_key:
+                raise ValueError(f"OpenAI API key not found in environment variable {self._api_key_var}")
+            self.client = AsyncOpenAI(
+                api_key=openai_api_key,
+                base_url=self.base_url,
+                timeout=self.request_timeout,
+                max_retries=2  # Client-level retries
+            )
 
- …
+    async def search_relevancy(self, query: str, document: str) -> Any:
+        await self._init_client()
+
+        async with self._semaphore:
+            for attempt in range(self.max_retries):
+                try:
+                    response = await self.client.chat.completions.create(
+                        model=self.model,
+                        messages=[
+                            {"role": "system", "content": "Rate the relevance of the document to the query. Respond with 'yes' or 'no'."},
+                            {"role": "user", "content": f"Query: {query}\nDocument: {document}\nIs this document relevant?"}
+                        ],
+                        temperature=self.temperature,
+                        n=1,
+                        logprobs=True,
+                        max_tokens=1
+                    )
+
+                    # Extract response and logprobs
+                    answer = response.choices[0].message.content
+                    logprob = response.choices[0].logprobs.content[0].logprob
+
+                    return {"answer": answer, "logprob": logprob}
+
+                except Exception as e:
+                    if attempt == self.max_retries - 1:
+                        log.error(f"Failed after {self.max_retries} attempts: {str(e)}")
+                        raise
+                    # Exponential backoff with jitter
+                    retry_delay = self.retry_delay * (2 ** attempt) + random.uniform(0, 0.1)
+                    await asyncio.sleep(retry_delay)
 
     async def _rank(self, query_document_pairs: List[Tuple[str, str]]) -> List[Tuple[str, float]]:
-        # Gather results asynchronously for all query-document pairs
-        results = await asyncio.gather(
-            *[self.search_relevancy(query=query, document=document) for (query, document) in query_document_pairs]
-        )
-
         ranked_results = []
 
- …
+        # Process in larger batches for better throughput
+        batch_size = min(self.max_concurrent_requests * 2, len(query_document_pairs))
+        for i in range(0, len(query_document_pairs), batch_size):
+            batch = query_document_pairs[i:i + batch_size]
+            try:
+                results = await asyncio.gather(
+                    *[self.search_relevancy(query=query, document=document) for (query, document) in batch],
+                    return_exceptions=True
+                )
+
+                for idx, result in enumerate(results):
+                    if isinstance(result, Exception):
+                        log.error(f"Error processing document {i+idx}: {str(result)}")
+                        ranked_results.append((batch[idx][1], 0.0))
+                        continue
+
+                    answer = result["answer"]
+                    logprob = result["logprob"]
+                    prob = math.exp(logprob)
+
+                    # Convert answer to score using the model's confidence
+                    if answer.lower().strip() == "yes":
+                        score = prob  # If yes, use the model's confidence
+                    elif answer.lower().strip() == "no":
+                        score = 1 - prob  # If no, invert the confidence
+                    else:
+                        score = 0.5 * prob  # For unclear answers, reduce confidence
+
+                    ranked_results.append((batch[idx][1], score))
+
+                    # Check if we should stop early
+                    high_scoring_docs = [r for r in ranked_results if r[1] >= self.filtering_threshold]
+                    can_stop_early = (
+                        self.early_stop  # Early stopping is enabled
+                        and self.num_docs_to_keep  # We have a target number of docs
+                        and len(high_scoring_docs) >= self.num_docs_to_keep  # Found enough good docs
+                        and score >= self.early_stop_threshold  # Current doc is good enough
+                    )
+
+                    if can_stop_early:
+                        log.info(f"Early stopping after finding {self.num_docs_to_keep} documents with high confidence")
+                        return ranked_results
+
+            except Exception as e:
+                log.error(f"Batch processing error: {str(e)}")
+                continue
 
         return ranked_results
 
-    def
- …
+    async def acompress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
     ) -> Sequence[Document]:
-        """
- …
+        """Async compress documents using reranking with proper error handling."""
+        if callbacks:
+            await callbacks.on_retriever_start({"query": query}, "Reranking documents")
+
+        log.info(f"Async compressing documents. Initial count: {len(documents)}")
+        if not documents:
+            if callbacks:
+                await callbacks.on_retriever_end({"documents": []})
             return []
 
-        doc_contents = [doc.page_content for doc in documents]
-        query_documents_pairs = [(query, doc) for doc in doc_contents]
-
-        # Create event loop and run async code
-        import asyncio
         try:
- …
+            # Prepare query-document pairs
+            query_document_pairs = [(query, doc.page_content) for doc in documents]
+
+            if callbacks:
+                await callbacks.on_text("Starting document reranking...")
+
+            # Get ranked results
+            ranked_results = await self._rank(query_document_pairs)
+
+            # Sort by score in descending order
+            ranked_results.sort(key=lambda x: x[1], reverse=True)
+
+            # Filter based on threshold and num_docs_to_keep
+            filtered_docs = []
+            for doc, score in ranked_results:
+                if score >= self.filtering_threshold:
+                    matching_doc = next(d for d in documents if d.page_content == doc)
+                    matching_doc.metadata = {**(matching_doc.metadata or {}), "relevance_score": score}
+                    filtered_docs.append(matching_doc)
+
+                    if callbacks:
+                        await callbacks.on_text(f"Document scored {score:.2f}")
+
+                if self.num_docs_to_keep and len(filtered_docs) >= self.num_docs_to_keep:
+                    break
+
+            log.info(f"Async compression complete. Final count: {len(filtered_docs)}")
+
+            if callbacks:
+                await callbacks.on_retriever_end({"documents": filtered_docs})
+
+            return filtered_docs
+
+        except Exception as e:
+            error_msg = f"Error during async document compression: {str(e)}"
+            log.error(error_msg)
+            if callbacks:
+                await callbacks.on_retriever_error(error_msg)
+            return documents  # Return original documents on error
+
+    def compress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
+    ) -> Sequence[Document]:
+        """Sync wrapper for async compression."""
+        return asyncio.run(self.acompress_documents(documents, query, callbacks))
 
     @property
     def _identifying_params(self) -> Dict[str, Any]: