MindsDB 25.4.3.2__py3-none-any.whl → 25.4.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MindsDB might be problematic. Click here for more details.
- mindsdb/__about__.py +1 -1
- mindsdb/__main__.py +18 -4
- mindsdb/api/executor/command_executor.py +12 -2
- mindsdb/api/executor/data_types/response_type.py +1 -0
- mindsdb/api/executor/datahub/classes/tables_row.py +3 -10
- mindsdb/api/executor/datahub/datanodes/datanode.py +7 -2
- mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +44 -10
- mindsdb/api/executor/datahub/datanodes/integration_datanode.py +57 -38
- mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +2 -1
- mindsdb/api/executor/datahub/datanodes/project_datanode.py +39 -7
- mindsdb/api/executor/datahub/datanodes/system_tables.py +116 -109
- mindsdb/api/executor/planner/query_plan.py +1 -0
- mindsdb/api/executor/planner/query_planner.py +15 -1
- mindsdb/api/executor/planner/steps.py +8 -2
- mindsdb/api/executor/sql_query/sql_query.py +24 -8
- mindsdb/api/executor/sql_query/steps/apply_predictor_step.py +25 -8
- mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +4 -2
- mindsdb/api/executor/sql_query/steps/insert_step.py +2 -1
- mindsdb/api/executor/sql_query/steps/prepare_steps.py +2 -3
- mindsdb/api/http/namespaces/config.py +19 -11
- mindsdb/api/litellm/start.py +82 -0
- mindsdb/api/mysql/mysql_proxy/libs/constants/mysql.py +133 -0
- mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +7 -2
- mindsdb/integrations/handlers/chromadb_handler/settings.py +1 -0
- mindsdb/integrations/handlers/mssql_handler/mssql_handler.py +13 -4
- mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +14 -5
- mindsdb/integrations/handlers/openai_handler/helpers.py +3 -5
- mindsdb/integrations/handlers/openai_handler/openai_handler.py +20 -8
- mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +14 -4
- mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +34 -19
- mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +21 -18
- mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +14 -4
- mindsdb/integrations/handlers/togetherai_handler/__about__.py +9 -0
- mindsdb/integrations/handlers/togetherai_handler/__init__.py +20 -0
- mindsdb/integrations/handlers/togetherai_handler/creation_args.py +14 -0
- mindsdb/integrations/handlers/togetherai_handler/icon.svg +15 -0
- mindsdb/integrations/handlers/togetherai_handler/model_using_args.py +5 -0
- mindsdb/integrations/handlers/togetherai_handler/requirements.txt +2 -0
- mindsdb/integrations/handlers/togetherai_handler/settings.py +33 -0
- mindsdb/integrations/handlers/togetherai_handler/togetherai_handler.py +234 -0
- mindsdb/integrations/handlers/web_handler/urlcrawl_helpers.py +1 -1
- mindsdb/integrations/libs/response.py +80 -32
- mindsdb/integrations/utilities/handler_utils.py +4 -0
- mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +360 -0
- mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +8 -153
- mindsdb/interfaces/agents/litellm_server.py +345 -0
- mindsdb/interfaces/agents/mcp_client_agent.py +252 -0
- mindsdb/interfaces/agents/run_mcp_agent.py +205 -0
- mindsdb/interfaces/functions/controller.py +3 -2
- mindsdb/interfaces/knowledge_base/controller.py +106 -82
- mindsdb/interfaces/query_context/context_controller.py +55 -15
- mindsdb/interfaces/query_context/query_task.py +19 -0
- mindsdb/interfaces/skills/skill_tool.py +7 -1
- mindsdb/interfaces/skills/sql_agent.py +8 -3
- mindsdb/interfaces/storage/db.py +2 -2
- mindsdb/interfaces/tasks/task_monitor.py +5 -1
- mindsdb/interfaces/tasks/task_thread.py +6 -0
- mindsdb/migrations/versions/2025-04-22_53502b6d63bf_query_database.py +27 -0
- mindsdb/utilities/config.py +20 -2
- mindsdb/utilities/context.py +1 -0
- mindsdb/utilities/starters.py +7 -0
- {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/METADATA +226 -221
- {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/RECORD +67 -53
- {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/WHEEL +1 -1
- mindsdb/integrations/handlers/snowflake_handler/tests/test_snowflake_handler.py +0 -230
- /mindsdb/{integrations/handlers/snowflake_handler/tests → api/litellm}/__init__.py +0 -0
- {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.4.3.2.dist-info → mindsdb-25.4.5.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,360 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
import math
|
|
6
|
+
import os
|
|
7
|
+
import random
|
|
8
|
+
from abc import ABC
|
|
9
|
+
from typing import Any, List, Optional, Tuple
|
|
10
|
+
|
|
11
|
+
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
|
12
|
+
from pydantic import field_validator
|
|
13
|
+
from pydantic import BaseModel
|
|
14
|
+
|
|
15
|
+
from mindsdb.integrations.utilities.rag.settings import DEFAULT_RERANKING_MODEL, DEFAULT_LLM_ENDPOINT
|
|
16
|
+
|
|
17
|
+
# Module-wide logger, named after this module so logging config can target it.
log = logging.getLogger(name=__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class BaseLLMReranker(BaseModel, ABC):
    """Base class for rerankers that score query/document relevance with an LLM.

    Every (query, document) pair is sent to an OpenAI or Azure OpenAI chat model, and the
    returned log-probabilities are turned into a relevance score in [0, 1]. ``method``
    selects how:

    * ``"multi-class"`` (default) - the model picks one of four relevance classes. The score
      is the probability-weighted average of the class weights (0.25 / 0.5 / 0.75 / 1.0).
    * any other value (e.g. ``"binary"``) - the model answers "yes"/"no". The score is the
      confidence of "yes" (or one minus the confidence of "no").

    Requests run concurrently, bounded by ``max_concurrent_requests``. Each request is
    retried up to ``max_retries`` times with exponential backoff plus jitter.
    """

    filtering_threshold: float = 0.0  # Score threshold for filtering; also used by the early-stop check
    provider: str = 'openai'  # 'openai' or 'azure_openai' (normalized to lower case by the validator)
    model: str = DEFAULT_RERANKING_MODEL  # Model to use for reranking
    temperature: float = 0.0  # Temperature for the model
    api_key: Optional[str] = None  # Falls back to an environment variable, see _init_client
    base_url: Optional[str] = None  # OpenAI base URL, or the Azure endpoint for 'azure_openai'
    api_version: Optional[str] = None  # Only used for 'azure_openai'
    num_docs_to_keep: Optional[int] = None  # How many of the top documents to keep after reranking & compressing.
    method: str = "multi-class"  # Scoring method: 'multi-class' or 'binary'
    _api_key_var: str = "OPENAI_API_KEY"  # Env var consulted for the OpenAI key when api_key is unset
    client: Optional[AsyncOpenAI] = None  # Created lazily by _init_client
    _semaphore: Optional[asyncio.Semaphore] = None  # Caps in-flight LLM requests
    max_concurrent_requests: int = 20
    max_retries: int = 3  # Attempts per document in search_relevancy / search_relevancy_score
    retry_delay: float = 1.0  # Base delay (seconds) for exponential backoff
    request_timeout: float = 20.0  # Timeout for API requests
    early_stop: bool = True  # Whether to enable early stopping
    early_stop_threshold: float = 0.8  # Confidence threshold for early stopping

    class Config:
        arbitrary_types_allowed = True

    @field_validator('provider')
    @classmethod
    def validate_provider(cls, v: str) -> str:
        """Accept 'openai' or 'azure_openai' (case-insensitive) and return it lower-cased.

        Raises:
            ValueError: for any other provider name.
        """
        allowed = {'openai', 'azure_openai'}
        v_lower = v.lower()
        if v_lower not in allowed:
            raise ValueError(f"Unsupported provider: {v}.")
        return v_lower

    def __init__(self, **kwargs):
        """Build the model and create the semaphore that bounds concurrent requests."""
        super().__init__(**kwargs)
        self._semaphore = asyncio.Semaphore(self.max_concurrent_requests)

    async def _init_client(self):
        """Create the async API client on first use; later calls are no-ops.

        For 'azure_openai', ``api_key`` / ``base_url`` / ``api_version`` fall back to the
        AZURE_OPENAI_API_KEY / AZURE_OPENAI_ENDPOINT / AZURE_OPENAI_API_VERSION env vars.
        For 'openai', ``api_key`` falls back to the env var named by ``_api_key_var``, and
        ``base_url`` falls back to DEFAULT_LLM_ENDPOINT.

        Raises:
            ValueError: if provider is 'openai' and no API key can be found.
        """
        if self.client is None:

            if self.provider == "azure_openai":

                azure_api_key = self.api_key or os.getenv("AZURE_OPENAI_API_KEY")
                azure_api_endpoint = self.base_url or os.environ.get("AZURE_OPENAI_ENDPOINT")
                azure_api_version = self.api_version or os.environ.get("AZURE_OPENAI_API_VERSION")
                # SDK-level retries, on top of the retry loop in the search methods.
                self.client = AsyncAzureOpenAI(api_key=azure_api_key,
                                               azure_endpoint=azure_api_endpoint,
                                               api_version=azure_api_version,
                                               timeout=self.request_timeout,
                                               max_retries=2)
            elif self.provider == "openai":
                api_key_var = self._api_key_var
                openai_api_key = self.api_key or os.getenv(api_key_var)
                if not openai_api_key:
                    raise ValueError(f"OpenAI API key not found in environment variable {api_key_var}")

                base_url = self.base_url or DEFAULT_LLM_ENDPOINT
                self.client = AsyncOpenAI(api_key=openai_api_key, base_url=base_url, timeout=self.request_timeout, max_retries=2)

    async def search_relevancy(self, query: str, document: str, rerank_callback=None) -> Any:
        """Ask the model whether *document* is relevant to *query* (single-token yes/no).

        Args:
            query: the search query.
            document: the document text to judge.
            rerank_callback: optional callable invoked with the result dict on success.

        Returns:
            ``{"document": document, "answer": <model text>, "logprob": <logprob of that token>}``.

        Raises:
            The last exception once all ``max_retries`` attempts have failed.
        """
        await self._init_client()

        async with self._semaphore:
            for attempt in range(self.max_retries):
                try:
                    response = await self.client.chat.completions.create(
                        model=self.model,
                        messages=[
                            {"role": "system", "content": "Rate the relevance of the document to the query. Respond with 'yes' or 'no'."},
                            {"role": "user", "content": f"Query: {query}\nDocument: {document}\nIs this document relevant?"}
                        ],
                        temperature=self.temperature,
                        n=1,
                        logprobs=True,
                        max_tokens=1
                    )

                    # Extract response and logprobs
                    answer = response.choices[0].message.content
                    logprob = response.choices[0].logprobs.content[0].logprob
                    rerank_data = {
                        "document": document,
                        "answer": answer,
                        "logprob": logprob
                    }

                    # Stream reranking update.
                    if rerank_callback is not None:
                        rerank_callback(rerank_data)

                    return rerank_data

                except Exception as e:
                    if attempt == self.max_retries - 1:
                        log.error(f"Failed after {self.max_retries} attempts: {str(e)}")
                        raise
                    # Exponential backoff with jitter
                    retry_delay = self.retry_delay * (2 ** attempt) + random.uniform(0, 0.1)
                    await asyncio.sleep(retry_delay)

    @staticmethod
    def _yes_no_score(result: dict) -> float:
        """Convert a ``search_relevancy`` result into a score using the token's confidence.

        "yes" -> p, "no" -> 1 - p, anything else (including a missing answer) -> 0.5 * p,
        where p = exp(logprob).
        """
        # The API can return no text content (e.g. a refusal); treat it as an unclear answer.
        answer = (result["answer"] or "").lower().strip()
        prob = math.exp(result["logprob"])
        if answer == "yes":
            return prob  # If yes, use the model's confidence
        if answer == "no":
            return 1 - prob  # If no, invert the confidence
        return 0.5 * prob  # For unclear answers, reduce confidence

    def _should_stop_early(self, ranked_results: List[Tuple[str, float]], score: float) -> bool:
        """Return True once enough documents have been found to stop ranking.

        Requires early stopping to be enabled and ``num_docs_to_keep`` to be set. At least
        ``num_docs_to_keep`` results must score >= ``filtering_threshold``, and the latest
        *score* must be >= ``early_stop_threshold``.
        """
        high_scoring_docs = [r for r in ranked_results if r[1] >= self.filtering_threshold]
        return bool(
            self.early_stop  # Early stopping is enabled
            and self.num_docs_to_keep  # We have a target number of docs
            and len(high_scoring_docs) >= self.num_docs_to_keep  # Found enough good docs
            and score >= self.early_stop_threshold  # Current doc is good enough
        )

    async def _rank(self, query_document_pairs: List[Tuple[str, str]], rerank_callback=None) -> List[Tuple[str, float]]:
        """Score pairs with the yes/no method.

        Returns ``(document, score)`` tuples in input order. A document whose request or
        scoring fails gets 0.0. If early stopping triggers, the list ends at the document
        that triggered it.
        """
        ranked_results = []
        if not query_document_pairs:
            # Nothing to score; also avoids range() with a step of 0 below.
            return ranked_results

        # Process in larger batches for better throughput
        batch_size = min(self.max_concurrent_requests * 2, len(query_document_pairs))
        for i in range(0, len(query_document_pairs), batch_size):
            batch = query_document_pairs[i:i + batch_size]
            try:
                results = await asyncio.gather(
                    *[self.search_relevancy(query=query, document=document, rerank_callback=rerank_callback) for (query, document) in batch],
                    return_exceptions=True
                )

                for idx, result in enumerate(results):
                    document = batch[idx][1]
                    if isinstance(result, Exception):
                        log.error(f"Error processing document {i+idx}: {str(result)}")
                        ranked_results.append((document, 0.0))
                        continue

                    # Score each result on its own so one malformed response cannot
                    # drop the rest of the batch (which would misalign get_scores output).
                    try:
                        score = self._yes_no_score(result)
                    except Exception as e:
                        log.error(f"Error scoring document {i+idx}: {str(e)}")
                        ranked_results.append((document, 0.0))
                        continue

                    ranked_results.append((document, score))

                    # Check if we should stop early
                    try:
                        if self._should_stop_early(ranked_results, score):
                            log.info(f"Early stopping after finding {self.num_docs_to_keep} documents with high confidence")
                            return ranked_results
                    except Exception as e:
                        # Don't let early stopping errors stop the whole process
                        log.warning(f"Error in early stopping check: {str(e)}")

            except Exception as e:
                log.error(f"Batch processing error: {str(e)}")
                continue
        return ranked_results

    @staticmethod
    def _class_probs_to_score(top_logprobs) -> float:
        """Turn the top log-probabilities of the class-digit token into a score in [0, 1].

        Each candidate token (whitespace-stripped) is mapped to ``class_<token>``. Tokens
        that do not name one of the four classes are ignored, so they do not dilute the
        normalization. The result is the probability-weighted average of the class weights,
        or 0.0 when no candidate names a class.
        """
        # Assign weights to classes
        class_weights = {
            "class_1": 0.25,
            "class_2": 0.5,
            "class_3": 0.75,
            "class_4": 1.0
        }
        class_probs = {}
        for top_token in top_logprobs:
            label = f"class_{top_token.token.strip()}"
            if label not in class_weights:
                continue
            # The same digit can appear with different surrounding whitespace (e.g. "4", " 4").
            class_probs[label] = class_probs.get(label, 0.0) + math.exp(top_token.logprob)

        total_prob = sum(class_probs.values())
        if total_prob <= 0:
            return 0.0
        # Compute the final smooth score over the normalized class distribution
        return sum(class_weights[label] * prob / total_prob for label, prob in class_probs.items())

    async def search_relevancy_score(self, query: str, document: str) -> Any:
        """Ask the model to classify *document*'s relevance to *query* into four classes.

        Returns:
            ``{"document": document, "answer": <class label text>, "relevance_score": float}``.
            ``relevance_score`` is derived from the top-4 logprobs of the final output token.

        Raises:
            The last exception once all ``max_retries`` attempts have failed.
        """
        await self._init_client()

        async with self._semaphore:
            for attempt in range(self.max_retries):
                try:
                    response = await self.client.chat.completions.create(
                        model=self.model,
                        messages=[
                            {"role": "system", "content": """
You are an intelligent assistant that evaluates how relevant a given document chunk is to a user's search query.
Your task is to analyze the similarity between the search query and the document chunk, and return **only the class label** that best represents the relevance:

- "class_1": Not relevant (score between 0.0 and 0.25)
- "class_2": Slightly relevant (score between 0.25 and 0.5)
- "class_3": Moderately relevant (score between 0.5 and 0.75)
- "class_4": Highly relevant (score between 0.75 and 1.0)

Respond with only one of: "class_1", "class_2", "class_3", or "class_4".

Examples:

Search query: "How to reset a router to factory settings?"
Document chunk: "Computers often come with customizable parental control settings."
Score: class_1

Search query: "Symptoms of vitamin D deficiency"
Document chunk: "Vitamin D deficiency has been linked to fatigue, bone pain, and muscle weakness."
Score: class_4

Search query: "Best practices for onboarding remote employees"
Document chunk: "An employee handbook can be useful for new hires, outlining company policies and benefits."
Score: class_2

Search query: "Benefits of mindfulness meditation"
Document chunk: "Practicing mindfulness has shown to reduce stress and improve focus in multiple studies."
Score: class_3

Search query: "What is Kubernetes used for?"
Document chunk: "Kubernetes is an open-source system for automating deployment, scaling, and management of containerized applications."
Score: class_4

Search query: "How to bake sourdough bread at home"
Document chunk: "The French Revolution began in 1789 and radically transformed society."
Score: class_1

Search query: "Machine learning algorithms for image classification"
Document chunk: "Convolutional Neural Networks (CNNs) are particularly effective in image classification tasks."
Score: class_4

Search query: "How to improve focus while working remotely"
Document chunk: "Creating a dedicated workspace and setting a consistent schedule can significantly improve focus during remote work."
Score: class_4

Search query: "Carbon emissions from electric vehicles vs gas cars"
Document chunk: "Electric vehicles produce zero emissions while driving, but battery production has environmental impacts."
Score: class_3

Search query: "Time zones in the United States"
Document chunk: "The U.S. is divided into six primary time zones: Eastern, Central, Mountain, Pacific, Alaska, and Hawaii-Aleutian."
Score: class_4
"""},

                            {"role": "user", "content": f"""
Now evaluate the following pair:

Search query: {query}
Document chunk: {document}

Which class best represents the relevance?
"""}
                        ],
                        temperature=self.temperature,
                        n=1,
                        logprobs=True,
                        top_logprobs=4,
                        max_tokens=3
                    )

                    # Extract response and logprobs
                    class_label = response.choices[0].message.content.strip()
                    token_logprobs = response.choices[0].logprobs.content
                    # The class digit is expected to be the final token (e.g. "1" in "class_1").
                    final_token_logprob = token_logprobs[-1]
                    relevance_score = self._class_probs_to_score(final_token_logprob.top_logprobs)
                    rerank_data = {
                        "document": document,
                        "answer": class_label,
                        "relevance_score": relevance_score
                    }
                    return rerank_data

                except Exception as e:
                    if attempt == self.max_retries - 1:
                        log.error(f"Failed after {self.max_retries} attempts: {str(e)}")
                        raise
                    # Exponential backoff with jitter
                    retry_delay = self.retry_delay * (2 ** attempt) + random.uniform(0, 0.1)
                    await asyncio.sleep(retry_delay)

    async def _rank_score(self, query_document_pairs: List[Tuple[str, str]]) -> List[Tuple[str, float]]:
        """Score pairs with the multi-class method.

        Returns ``(document, score)`` tuples in input order, with scores clamped to [0, 1].
        A document whose request or scoring fails gets 0.0. If early stopping triggers, the
        list ends at the document that triggered it.
        """
        ranked_results = []
        if not query_document_pairs:
            # Nothing to score; also avoids range() with a step of 0 below.
            return ranked_results

        # Process in larger batches for better throughput
        batch_size = min(self.max_concurrent_requests * 2, len(query_document_pairs))
        for i in range(0, len(query_document_pairs), batch_size):
            batch = query_document_pairs[i:i + batch_size]
            try:
                results = await asyncio.gather(
                    *[self.search_relevancy_score(query=query, document=document) for (query, document) in batch],
                    return_exceptions=True
                )

                for idx, result in enumerate(results):
                    document = batch[idx][1]
                    if isinstance(result, Exception):
                        log.error(f"Error processing document {i+idx}: {str(result)}")
                        ranked_results.append((document, 0.0))
                        continue

                    # Handle each result on its own so one malformed result cannot drop the
                    # rest of the batch.
                    try:
                        score = result["relevance_score"]
                    except Exception as e:
                        log.error(f"Error scoring document {i+idx}: {str(e)}")
                        ranked_results.append((document, 0.0))
                        continue

                    if score is not None:
                        score = min(max(score, 0.0), 1.0)

                    ranked_results.append((document, score))
                    # Check if we should stop early
                    try:
                        if self._should_stop_early(ranked_results, score):
                            log.info(f"Early stopping after finding {self.num_docs_to_keep} documents with high confidence")
                            return ranked_results
                    except Exception as e:
                        # Don't let early stopping errors stop the whole process
                        log.warning(f"Error in early stopping check: {str(e)}")

            except Exception as e:
                log.error(f"Batch processing error: {str(e)}")
                continue

        return ranked_results

    def get_scores(self, query: str, documents: list[str]) -> list[float]:
        """Synchronously score *documents* against *query* with the configured ``method``.

        Returns scores in the same order as *documents*. Early stopping (see
        ``_should_stop_early``) can end ranking before every document is scored, in which
        case the returned list is shorter than *documents*.
        """
        if not documents:
            return []

        query_document_pairs = [(query, doc) for doc in documents]
        # Create event loop and run async code
        try:
            # NOTE: run_until_complete() raises on an already-running loop unless that loop
            # is re-entrant (e.g. patched with nest_asyncio).
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # If no running loop exists, create a new one.
            # NOTE(review): a new loop is created (and never closed) on every call, while
            # the client and the semaphore from __init__ live on the instance and may stay
            # bound to the first loop they ran on; confirm instances are not reused.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        if self.method == "multi-class":  # default 'multi-class' method
            documents_and_scores = loop.run_until_complete(self._rank_score(query_document_pairs))
        else:
            documents_and_scores = loop.run_until_complete(self._rank(query_document_pairs))

        scores = [score for _, score in documents_and_scores]
        return scores
|
|
@@ -2,153 +2,22 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import logging
|
|
5
|
-
import
|
|
6
|
-
import os
|
|
7
|
-
import random
|
|
8
|
-
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
|
5
|
+
from typing import Any, Dict, Optional, Sequence
|
|
9
6
|
|
|
10
7
|
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
|
|
11
8
|
from langchain_core.callbacks import Callbacks, dispatch_custom_event
|
|
12
9
|
from langchain_core.documents import Document
|
|
13
|
-
from openai import AsyncOpenAI
|
|
14
10
|
|
|
15
|
-
from mindsdb.integrations.utilities.rag.
|
|
11
|
+
from mindsdb.integrations.utilities.rag.rerankers.base_reranker import BaseLLMReranker
|
|
16
12
|
|
|
17
13
|
log = logging.getLogger(__name__)
|
|
18
14
|
|
|
19
15
|
|
|
20
|
-
class LLMReranker(BaseDocumentCompressor):
|
|
21
|
-
filtering_threshold: float = 0.0 # Default threshold for filtering
|
|
22
|
-
model: str = DEFAULT_RERANKING_MODEL # Model to use for reranking
|
|
23
|
-
temperature: float = 0.0 # Temperature for the model
|
|
24
|
-
openai_api_key: Optional[str] = None
|
|
16
|
+
class LLMReranker(BaseDocumentCompressor, BaseLLMReranker):
|
|
25
17
|
remove_irrelevant: bool = True # New flag to control removal of irrelevant documents
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
client: Optional[AsyncOpenAI] = None
|
|
30
|
-
_semaphore: Optional[asyncio.Semaphore] = None
|
|
31
|
-
max_concurrent_requests: int = 20
|
|
32
|
-
max_retries: int = 3
|
|
33
|
-
retry_delay: float = 1.0
|
|
34
|
-
request_timeout: float = 20.0 # Timeout for API requests
|
|
35
|
-
early_stop: bool = True # Whether to enable early stopping
|
|
36
|
-
early_stop_threshold: float = 0.8 # Confidence threshold for early stopping
|
|
37
|
-
|
|
38
|
-
class Config:
|
|
39
|
-
arbitrary_types_allowed = True
|
|
40
|
-
|
|
41
|
-
def __init__(self, **kwargs):
|
|
42
|
-
super().__init__(**kwargs)
|
|
43
|
-
self._semaphore = asyncio.Semaphore(self.max_concurrent_requests)
|
|
44
|
-
|
|
45
|
-
async def _init_client(self):
|
|
46
|
-
if self.client is None:
|
|
47
|
-
openai_api_key = self.openai_api_key or os.getenv(self._api_key_var)
|
|
48
|
-
if not openai_api_key:
|
|
49
|
-
raise ValueError(f"OpenAI API key not found in environment variable {self._api_key_var}")
|
|
50
|
-
self.client = AsyncOpenAI(
|
|
51
|
-
api_key=openai_api_key,
|
|
52
|
-
base_url=self.base_url,
|
|
53
|
-
timeout=self.request_timeout,
|
|
54
|
-
max_retries=2 # Client-level retries
|
|
55
|
-
)
|
|
56
|
-
|
|
57
|
-
async def search_relevancy(self, query: str, document: str, custom_event: bool = True) -> Any:
|
|
58
|
-
await self._init_client()
|
|
59
|
-
|
|
60
|
-
async with self._semaphore:
|
|
61
|
-
for attempt in range(self.max_retries):
|
|
62
|
-
try:
|
|
63
|
-
response = await self.client.chat.completions.create(
|
|
64
|
-
model=self.model,
|
|
65
|
-
messages=[
|
|
66
|
-
{"role": "system", "content": "Rate the relevance of the document to the query. Respond with 'yes' or 'no'."},
|
|
67
|
-
{"role": "user", "content": f"Query: {query}\nDocument: {document}\nIs this document relevant?"}
|
|
68
|
-
],
|
|
69
|
-
temperature=self.temperature,
|
|
70
|
-
n=1,
|
|
71
|
-
logprobs=True,
|
|
72
|
-
max_tokens=1
|
|
73
|
-
)
|
|
74
|
-
|
|
75
|
-
# Extract response and logprobs
|
|
76
|
-
answer = response.choices[0].message.content
|
|
77
|
-
logprob = response.choices[0].logprobs.content[0].logprob
|
|
78
|
-
rerank_data = {
|
|
79
|
-
"document": document,
|
|
80
|
-
"answer": answer,
|
|
81
|
-
"logprob": logprob
|
|
82
|
-
}
|
|
83
|
-
|
|
84
|
-
# Stream reranking update.
|
|
85
|
-
if custom_event:
|
|
86
|
-
dispatch_custom_event("rerank", rerank_data)
|
|
87
|
-
return rerank_data
|
|
88
|
-
|
|
89
|
-
except Exception as e:
|
|
90
|
-
if attempt == self.max_retries - 1:
|
|
91
|
-
log.error(f"Failed after {self.max_retries} attempts: {str(e)}")
|
|
92
|
-
raise
|
|
93
|
-
# Exponential backoff with jitter
|
|
94
|
-
retry_delay = self.retry_delay * (2 ** attempt) + random.uniform(0, 0.1)
|
|
95
|
-
await asyncio.sleep(retry_delay)
|
|
96
|
-
|
|
97
|
-
async def _rank(self, query_document_pairs: List[Tuple[str, str]], custom_event: bool = True) -> List[Tuple[str, float]]:
|
|
98
|
-
ranked_results = []
|
|
99
|
-
|
|
100
|
-
# Process in larger batches for better throughput
|
|
101
|
-
batch_size = min(self.max_concurrent_requests * 2, len(query_document_pairs))
|
|
102
|
-
for i in range(0, len(query_document_pairs), batch_size):
|
|
103
|
-
batch = query_document_pairs[i:i + batch_size]
|
|
104
|
-
try:
|
|
105
|
-
results = await asyncio.gather(
|
|
106
|
-
*[self.search_relevancy(query=query, document=document, custom_event=custom_event) for (query, document) in batch],
|
|
107
|
-
return_exceptions=True
|
|
108
|
-
)
|
|
109
|
-
|
|
110
|
-
for idx, result in enumerate(results):
|
|
111
|
-
if isinstance(result, Exception):
|
|
112
|
-
log.error(f"Error processing document {i+idx}: {str(result)}")
|
|
113
|
-
ranked_results.append((batch[idx][1], 0.0))
|
|
114
|
-
continue
|
|
115
|
-
|
|
116
|
-
answer = result["answer"]
|
|
117
|
-
logprob = result["logprob"]
|
|
118
|
-
prob = math.exp(logprob)
|
|
119
|
-
|
|
120
|
-
# Convert answer to score using the model's confidence
|
|
121
|
-
if answer.lower().strip() == "yes":
|
|
122
|
-
score = prob # If yes, use the model's confidence
|
|
123
|
-
elif answer.lower().strip() == "no":
|
|
124
|
-
score = 1 - prob # If no, invert the confidence
|
|
125
|
-
else:
|
|
126
|
-
score = 0.5 * prob # For unclear answers, reduce confidence
|
|
127
|
-
|
|
128
|
-
ranked_results.append((batch[idx][1], score))
|
|
129
|
-
|
|
130
|
-
# Check if we should stop early
|
|
131
|
-
try:
|
|
132
|
-
high_scoring_docs = [r for r in ranked_results if r[1] >= self.filtering_threshold]
|
|
133
|
-
can_stop_early = (
|
|
134
|
-
self.early_stop # Early stopping is enabled
|
|
135
|
-
and self.num_docs_to_keep # We have a target number of docs
|
|
136
|
-
and len(high_scoring_docs) >= self.num_docs_to_keep # Found enough good docs
|
|
137
|
-
and score >= self.early_stop_threshold # Current doc is good enough
|
|
138
|
-
)
|
|
139
|
-
|
|
140
|
-
if can_stop_early:
|
|
141
|
-
log.info(f"Early stopping after finding {self.num_docs_to_keep} documents with high confidence")
|
|
142
|
-
return ranked_results
|
|
143
|
-
except Exception as e:
|
|
144
|
-
# Don't let early stopping errors stop the whole process
|
|
145
|
-
log.warning(f"Error in early stopping check: {str(e)}")
|
|
146
|
-
|
|
147
|
-
except Exception as e:
|
|
148
|
-
log.error(f"Batch processing error: {str(e)}")
|
|
149
|
-
continue
|
|
150
|
-
|
|
151
|
-
return ranked_results
|
|
18
|
+
|
|
19
|
+
def _dispatch_rerank_event(self, data):
|
|
20
|
+
dispatch_custom_event("rerank", data)
|
|
152
21
|
|
|
153
22
|
async def acompress_documents(
|
|
154
23
|
self,
|
|
@@ -177,7 +46,7 @@ class LLMReranker(BaseDocumentCompressor):
|
|
|
177
46
|
await callbacks.on_text("Starting document reranking...")
|
|
178
47
|
|
|
179
48
|
# Get ranked results
|
|
180
|
-
ranked_results = await self._rank(query_document_pairs)
|
|
49
|
+
ranked_results = await self._rank(query_document_pairs, rerank_callback=self._dispatch_rerank_event)
|
|
181
50
|
|
|
182
51
|
# Sort by score in descending order
|
|
183
52
|
ranked_results.sort(key=lambda x: x[1], reverse=True)
|
|
@@ -226,19 +95,5 @@ class LLMReranker(BaseDocumentCompressor):
|
|
|
226
95
|
"model": self.model,
|
|
227
96
|
"temperature": self.temperature,
|
|
228
97
|
"remove_irrelevant": self.remove_irrelevant,
|
|
98
|
+
"method": self.method,
|
|
229
99
|
}
|
|
230
|
-
|
|
231
|
-
def get_scores(self, query: str, documents: list[str], custom_event: bool = False):
|
|
232
|
-
query_document_pairs = [(query, doc) for doc in documents]
|
|
233
|
-
# Create event loop and run async code
|
|
234
|
-
import asyncio
|
|
235
|
-
try:
|
|
236
|
-
loop = asyncio.get_running_loop()
|
|
237
|
-
except RuntimeError:
|
|
238
|
-
# If no running loop exists, create a new one
|
|
239
|
-
loop = asyncio.new_event_loop()
|
|
240
|
-
asyncio.set_event_loop(loop)
|
|
241
|
-
|
|
242
|
-
documents_and_scores = loop.run_until_complete(self._rank(query_document_pairs, custom_event=custom_event))
|
|
243
|
-
scores = [score for _, score in documents_and_scores]
|
|
244
|
-
return scores
|