MindsDB 25.5.4.0__py3-none-any.whl → 25.5.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of MindsDB might be problematic.
- mindsdb/__about__.py +8 -8
- mindsdb/api/a2a/__main__.py +38 -8
- mindsdb/api/a2a/run_a2a.py +10 -53
- mindsdb/api/a2a/task_manager.py +19 -53
- mindsdb/api/executor/command_executor.py +147 -291
- mindsdb/api/http/namespaces/config.py +61 -86
- mindsdb/integrations/handlers/byom_handler/requirements.txt +1 -2
- mindsdb/integrations/handlers/lancedb_handler/requirements.txt +0 -1
- mindsdb/integrations/handlers/litellm_handler/litellm_handler.py +37 -20
- mindsdb/integrations/libs/llm/config.py +13 -0
- mindsdb/integrations/libs/llm/utils.py +37 -65
- mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +230 -227
- mindsdb/interfaces/agents/constants.py +17 -13
- mindsdb/interfaces/agents/langchain_agent.py +93 -94
- mindsdb/interfaces/knowledge_base/controller.py +230 -221
- mindsdb/utilities/config.py +43 -84
- {mindsdb-25.5.4.0.dist-info → mindsdb-25.5.4.1.dist-info}/METADATA +261 -259
- {mindsdb-25.5.4.0.dist-info → mindsdb-25.5.4.1.dist-info}/RECORD +21 -25
- mindsdb/api/a2a/a2a_client.py +0 -439
- mindsdb/api/a2a/common/client/__init__.py +0 -4
- mindsdb/api/a2a/common/client/card_resolver.py +0 -21
- mindsdb/api/a2a/common/client/client.py +0 -86
- {mindsdb-25.5.4.0.dist-info → mindsdb-25.5.4.1.dist-info}/WHEEL +0 -0
- {mindsdb-25.5.4.0.dist-info → mindsdb-25.5.4.1.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.5.4.0.dist-info → mindsdb-25.5.4.1.dist-info}/top_level.txt +0 -0
mindsdb/integrations/utilities/rag/rerankers/base_reranker.py

@@ -1,26 +1,27 @@
 from __future__ import annotations
 
+import re
 import asyncio
 import logging
 import math
 import os
 import random
 from abc import ABC
+from textwrap import dedent
 from typing import Any, List, Optional, Tuple
 
 from openai import AsyncOpenAI, AsyncAzureOpenAI
-from pydantic import field_validator
 from pydantic import BaseModel
 
 from mindsdb.integrations.utilities.rag.settings import DEFAULT_RERANKING_MODEL, DEFAULT_LLM_ENDPOINT
+from mindsdb.integrations.libs.base import BaseMLEngine
 
 log = logging.getLogger(__name__)
 
 
 class BaseLLMReranker(BaseModel, ABC):
-
     filtering_threshold: float = 0.0  # Default threshold for filtering
-    provider: str =
+    provider: str = "openai"
     model: str = DEFAULT_RERANKING_MODEL  # Model to use for reranking
     temperature: float = 0.0  # Temperature for the model
     api_key: Optional[str] = None
@@ -29,7 +30,7 @@ class BaseLLMReranker(BaseModel, ABC):
     num_docs_to_keep: Optional[int] = None  # How many of the top documents to keep after reranking & compressing.
     method: str = "multi-class"  # Scoring method: 'multi-class' or 'binary'
     _api_key_var: str = "OPENAI_API_KEY"
-    client: Optional[AsyncOpenAI] = None
+    client: Optional[AsyncOpenAI | BaseMLEngine] = None
     _semaphore: Optional[asyncio.Semaphore] = None
     max_concurrent_requests: int = 20
     max_retries: int = 3
@@ -40,33 +41,26 @@ class BaseLLMReranker(BaseModel, ABC):
 
     class Config:
         arbitrary_types_allowed = True
-
-    @field_validator('provider')
-    @classmethod
-    def validate_provider(cls, v: str) -> str:
-        allowed = {'openai', 'azure_openai'}
-        v_lower = v.lower()
-        if v_lower not in allowed:
-            raise ValueError(f"Unsupported provider: {v}.")
-        return v_lower
+        extra = "allow"
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self._semaphore = asyncio.Semaphore(self.max_concurrent_requests)
+        self._init_client()
 
-
+    def _init_client(self):
         if self.client is None:
-
             if self.provider == "azure_openai":
-
                 azure_api_key = self.api_key or os.getenv("AZURE_OPENAI_API_KEY")
                 azure_api_endpoint = self.base_url or os.environ.get("AZURE_OPENAI_ENDPOINT")
                 azure_api_version = self.api_version or os.environ.get("AZURE_OPENAI_API_VERSION")
-                self.client = AsyncAzureOpenAI(
-
-
-
-
+                self.client = AsyncAzureOpenAI(
+                    api_key=azure_api_key,
+                    azure_endpoint=azure_api_endpoint,
+                    api_version=azure_api_version,
+                    timeout=self.request_timeout,
+                    max_retries=2,
+                )
             elif self.provider == "openai":
                 api_key_var: str = "OPENAI_API_KEY"
                 openai_api_key = self.api_key or os.getenv(api_key_var)
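Side note on the hunk above: dropping the `field_validator` and setting `extra = "allow"` in the pydantic config means constructor kwargs that are not declared fields are kept on the model and exposed via `model_extra`, which is what `_call_llm` later copies to forward provider-specific options. A minimal sketch of that pydantic behavior, with illustrative names that are not part of this package:

from pydantic import BaseModel


class RerankerSettings(BaseModel):
    # Declared field; anything else passed to the constructor is retained.
    model: str = "gpt-4o"

    class Config:
        extra = "allow"


settings = RerankerSettings(model="gpt-4o", custom_llm_provider="ollama")
# Undeclared kwargs land in model_extra (pydantic v2).
print(settings.model_extra)  # {'custom_llm_provider': 'ollama'}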
@@ -74,48 +68,39 @@ class BaseLLMReranker(BaseModel, ABC):
                     raise ValueError(f"OpenAI API key not found in environment variable {api_key_var}")
 
                 base_url = self.base_url or DEFAULT_LLM_ENDPOINT
-                self.client = AsyncOpenAI(
+                self.client = AsyncOpenAI(
+                    api_key=openai_api_key, base_url=base_url, timeout=self.request_timeout, max_retries=2
+                )
 
-
-
+            else:
+                # try to use litellm
+                from mindsdb.api.executor.controllers.session_controller import SessionController
 
-
-
-            try:
-                response = await self.client.chat.completions.create(
-                    model=self.model,
-                    messages=[
-                        {"role": "system", "content": "Rate the relevance of the document to the query. Respond with 'yes' or 'no'."},
-                        {"role": "user", "content": f"Query: {query}\nDocument: {document}\nIs this document relevant?"}
-                    ],
-                    temperature=self.temperature,
-                    n=1,
-                    logprobs=True,
-                    max_tokens=1
-                )
-
-                # Extract response and logprobs
-                answer = response.choices[0].message.content
-                logprob = response.choices[0].logprobs.content[0].logprob
-                rerank_data = {
-                    "document": document,
-                    "answer": answer,
-                    "logprob": logprob
-                }
-
-                # Stream reranking update.
-                if rerank_callback is not None:
-                    rerank_callback(rerank_data)
+                session = SessionController()
+                module = session.integration_controller.get_handler_module("litellm")
 
-
+                if module is None or module.Handler is None:
+                    raise ValueError(f'Unable to use "{self.provider}" provider. Litellm handler is not installed')
 
-
-
-
-
-
-
-
+                self.client = module.Handler
+                self.method = "no-logprobs"
+
+    async def _call_llm(self, messages):
+        if self.provider in ("azure_openai", "openai"):
+            return await self.client.chat.completions.create(
+                model=self.model,
+                messages=messages,
+            )
+        else:
+            kwargs = self.model_extra.copy()
+
+            if self.base_url is not None:
+                kwargs["api_base"] = self.base_url
+
+            if self.api_key is not None:
+                kwargs["api_key"] = self.api_key
+
+            return await self.client.acompletion(model=f"{self.provider}/{self.model}", messages=messages, args=kwargs)
 
     async def _rank(self, query_document_pairs: List[Tuple[str, str]], rerank_callback=None) -> List[Tuple[str, float]]:
         ranked_results = []
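The new `_call_llm` above is the single dispatch point for scoring calls: `openai` and `azure_openai` keep the `AsyncOpenAI` chat-completions client, while any other provider string is forwarded to the bundled litellm handler under a `provider/model` identifier. A rough sketch of the same routing idea against litellm's public API (the `args=` keyword in the diff belongs to MindsDB's handler wrapper; plain litellm accepts the options directly, and the provider names here are only examples):

import litellm
from openai import AsyncOpenAI


async def call_llm(provider: str, model: str, messages: list[dict], api_key: str | None = None):
    if provider in ("openai", "azure_openai"):
        # OpenAI-compatible providers use the native async client.
        client = AsyncOpenAI(api_key=api_key)
        return await client.chat.completions.create(model=model, messages=messages)
    # litellm resolves "<provider>/<model>" strings, e.g. "ollama/llama3".
    return await litellm.acompletion(model=f"{provider}/{model}", messages=messages, api_key=api_key)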
@@ -123,30 +108,23 @@
         # Process in larger batches for better throughput
         batch_size = min(self.max_concurrent_requests * 2, len(query_document_pairs))
         for i in range(0, len(query_document_pairs), batch_size):
-            batch = query_document_pairs[i:i + batch_size]
+            batch = query_document_pairs[i : i + batch_size]
             try:
                 results = await asyncio.gather(
-                    *[
-
+                    *[
+                        self._backoff_wrapper(query=query, document=document, rerank_callback=rerank_callback)
+                        for (query, document) in batch
+                    ],
+                    return_exceptions=True,
                 )
 
                 for idx, result in enumerate(results):
                     if isinstance(result, Exception):
-                        log.error(f"Error processing document {i+idx}: {str(result)}")
+                        log.error(f"Error processing document {i + idx}: {str(result)}")
                         ranked_results.append((batch[idx][1], 0.0))
                         continue
 
-
-                    logprob = result["logprob"]
-                    prob = math.exp(logprob)
-
-                    # Convert answer to score using the model's confidence
-                    if answer.lower().strip() == "yes":
-                        score = prob  # If yes, use the model's confidence
-                    elif answer.lower().strip() == "no":
-                        score = 1 - prob  # If no, invert the confidence
-                    else:
-                        score = 0.5 * prob  # For unclear answers, reduce confidence
+                    score = result["relevance_score"]
 
                     ranked_results.append((batch[idx][1], score))
 
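`_rank` keeps the existing throughput pattern: batches of up to twice the concurrency limit are fanned out with `asyncio.gather(..., return_exceptions=True)`, a shared semaphore caps in-flight requests, and a failing document comes back as an exception object (scored 0.0) instead of aborting the batch. A standalone sketch of that pattern, with `score_one` standing in for the real LLM call:

import asyncio


async def score_one(sem: asyncio.Semaphore, doc: str) -> float:
    async with sem:  # at most max_concurrent calls in flight
        await asyncio.sleep(0.01)  # stand-in for the LLM request
        return float(len(doc))


async def score_all(docs: list[str], max_concurrent: int = 20) -> list[float]:
    sem = asyncio.Semaphore(max_concurrent)
    results = await asyncio.gather(*(score_one(sem, d) for d in docs), return_exceptions=True)
    # Failed documents get a 0.0 score instead of crashing the whole batch.
    return [0.0 if isinstance(r, Exception) else r for r in results]


print(asyncio.run(score_all(["alpha", "beta", "gamma"])))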
@@ -161,7 +139,9 @@
                     )
 
                     if can_stop_early:
-                        log.info(
+                        log.info(
+                            f"Early stopping after finding {self.num_docs_to_keep} documents with high confidence"
+                        )
                         return ranked_results
             except Exception as e:
                 # Don't let early stopping errors stop the whole process
@@ -172,114 +152,18 @@
                 continue
         return ranked_results
 
-    async def
-        await self._init_client()
-
+    async def _backoff_wrapper(self, query: str, document: str, rerank_callback=None) -> Any:
         async with self._semaphore:
             for attempt in range(self.max_retries):
                 try:
-                    … (inline multi-class prompt, few-shot examples, and logprob scoring removed here; the same logic now appears in search_relevancy_score below)
+                    if self.method == "multi-class":
+                        rerank_data = await self.search_relevancy_score(query, document)
+                    elif self.method == "no-logprobs":
+                        rerank_data = await self.search_relevancy_no_logprob(query, document)
+                    else:
+                        rerank_data = await self.search_relevancy(query, document)
+                    if rerank_callback is not None:
+                        rerank_callback(rerank_data)
                     return rerank_data
 
                 except Exception as e:
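`_backoff_wrapper`, defined in the hunk above, retries each scoring call up to `max_retries` times; the next hunk fixes its delay formula to full exponential backoff with a small random jitter, so a base delay of 1.0s waits roughly 1s, then 2s, then 4s (each plus up to 0.1s of jitter). A minimal sketch of that retry loop:

import asyncio
import random


async def with_backoff(make_call, max_retries: int = 3, retry_delay: float = 1.0):
    """Retry an async call with exponential backoff and jitter (illustrative helper)."""
    for attempt in range(max_retries):
        try:
            return await make_call()
        except Exception:
            if attempt == max_retries - 1:
                raise
            # 1s, 2s, 4s, ... plus up to 0.1s of jitter to avoid thundering herds.
            await asyncio.sleep(retry_delay * (2**attempt) + random.uniform(0, 0.1))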
@@ -287,63 +171,185 @@
                     log.error(f"Failed after {self.max_retries} attempts: {str(e)}")
                     raise
                 # Exponential backoff with jitter
-                retry_delay = self.retry_delay * (2
+                retry_delay = self.retry_delay * (2**attempt) + random.uniform(0, 0.1)
                 await asyncio.sleep(retry_delay)
 
-    async def
-
+    async def search_relevancy(self, query: str, document: str) -> Any:
+        response = await self.client.chat.completions.create(
+            model=self.model,
+            messages=[
+                {
+                    "role": "system",
+                    "content": "Rate the relevance of the document to the query. Respond with 'yes' or 'no'.",
+                },
+                {"role": "user", "content": f"Query: {query}\nDocument: {document}\nIs this document relevant?"},
+            ],
+            temperature=self.temperature,
+            n=1,
+            logprobs=True,
+            max_tokens=1,
+        )
+
+        # Extract response and logprobs
+        answer = response.choices[0].message.content
+        logprob = response.choices[0].logprobs.content[0].logprob
+
+        # Convert answer to score using the model's confidence
+        if answer.lower().strip() == "yes":
+            score = logprob  # If yes, use the model's confidence
+        elif answer.lower().strip() == "no":
+            score = 1 - logprob  # If no, invert the confidence
+        else:
+            score = 0.5 * logprob  # For unclear answers, reduce confidence
 
-
-
-
-
-            try:
-                results = await asyncio.gather(
-                    *[self.search_relevancy_score(query=query, document=document) for (query, document) in batch],
-                    return_exceptions=True
-                )
+        rerank_data = {
+            "document": document,
+            "relevance_score": score,
+        }
 
-
-                    if isinstance(result, Exception):
-                        log.error(f"Error processing document {i+idx}: {str(result)}")
-                        ranked_results.append((batch[idx][1], 0.0))
-                        continue
+        return rerank_data
 
-
-
-
-
-
-
+    async def search_relevancy_no_logprob(self, query: str, document: str) -> Any:
+        prompt = dedent(
+            f"""
+            Score the relevance between search query and user message on scale between 0 and 100 per cents.
+            Consider semantic meaning, key concepts, and contextual relevance.
+            Return ONLY a numerical score between 0 and 100 per cents. No other text. Stop after sending a number
+            Search query: {query}
+            """
+        )
 
-
-
-
-        high_scoring_docs = [r for r in ranked_results if r[1] >= self.filtering_threshold]
-        can_stop_early = (
-            self.early_stop  # Early stopping is enabled
-            and self.num_docs_to_keep  # We have a target number of docs
-            and len(high_scoring_docs) >= self.num_docs_to_keep  # Found enough good docs
-            and score >= self.early_stop_threshold  # Current doc is good enough
-        )
+        response = await self._call_llm(
+            messages=[{"role": "system", "content": prompt}, {"role": "user", "content": document}],
+        )
 
-
-        log.info(f"Early stopping after finding {self.num_docs_to_keep} documents with high confidence")
-        return ranked_results
-    except Exception as e:
-        # Don't let early stopping errors stop the whole process
-        log.warning(f"Error in early stopping check: {str(e)}")
+        answer = response.choices[0].message.content
 
-
-
+        try:
+            value = re.findall(r"[\d]+", answer)[0]
+            score = float(value) / 100
+            score = max(0.0, min(score, 1.0))
+        except (ValueError, IndexError):
+            score = 0.0
 
-
+        rerank_data = {
+            "document": document,
+            "relevance_score": score,
+        }
+
+        return rerank_data
+
+    async def search_relevancy_score(self, query: str, document: str) -> Any:
+        response = await self.client.chat.completions.create(
+            model=self.model,
+            messages=[
+                {
+                    "role": "system",
+                    "content": """
+                    You are an intelligent assistant that evaluates how relevant a given document chunk is to a user's search query.
+                    Your task is to analyze the similarity between the search query and the document chunk, and return **only the class label** that best represents the relevance:
+
+                    - "class_1": Not relevant (score between 0.0 and 0.25)
+                    - "class_2": Slightly relevant (score between 0.25 and 0.5)
+                    - "class_3": Moderately relevant (score between 0.5 and 0.75)
+                    - "class_4": Highly relevant (score between 0.75 and 1.0)
+
+                    Respond with only one of: "class_1", "class_2", "class_3", or "class_4".
+
+                    Examples:
+
+                    Search query: "How to reset a router to factory settings?"
+                    Document chunk: "Computers often come with customizable parental control settings."
+                    Score: class_1
+
+                    Search query: "Symptoms of vitamin D deficiency"
+                    Document chunk: "Vitamin D deficiency has been linked to fatigue, bone pain, and muscle weakness."
+                    Score: class_4
+
+                    Search query: "Best practices for onboarding remote employees"
+                    Document chunk: "An employee handbook can be useful for new hires, outlining company policies and benefits."
+                    Score: class_2
+
+                    Search query: "Benefits of mindfulness meditation"
+                    Document chunk: "Practicing mindfulness has shown to reduce stress and improve focus in multiple studies."
+                    Score: class_3
+
+                    Search query: "What is Kubernetes used for?"
+                    Document chunk: "Kubernetes is an open-source system for automating deployment, scaling, and management of containerized applications."
+                    Score: class_4
+
+                    Search query: "How to bake sourdough bread at home"
+                    Document chunk: "The French Revolution began in 1789 and radically transformed society."
+                    Score: class_1
+
+                    Search query: "Machine learning algorithms for image classification"
+                    Document chunk: "Convolutional Neural Networks (CNNs) are particularly effective in image classification tasks."
+                    Score: class_4
+
+                    Search query: "How to improve focus while working remotely"
+                    Document chunk: "Creating a dedicated workspace and setting a consistent schedule can significantly improve focus during remote work."
+                    Score: class_4
+
+                    Search query: "Carbon emissions from electric vehicles vs gas cars"
+                    Document chunk: "Electric vehicles produce zero emissions while driving, but battery production has environmental impacts."
+                    Score: class_3
+
+                    Search query: "Time zones in the United States"
+                    Document chunk: "The U.S. is divided into six primary time zones: Eastern, Central, Mountain, Pacific, Alaska, and Hawaii-Aleutian."
+                    Score: class_4
+                    """,
+                },
+                {
+                    "role": "user",
+                    "content": f"""
+                    Now evaluate the following pair:
+
+                    Search query: {query}
+                    Document chunk: {document}
+
+                    Which class best represents the relevance?
+                    """,
+                },
+            ],
+            temperature=self.temperature,
+            n=1,
+            logprobs=True,
+            top_logprobs=4,
+            max_tokens=3,
+        )
+
+        # Extract response and logprobs
+        token_logprobs = response.choices[0].logprobs.content
+        # Reconstruct the prediction and extract the top logprobs from the final token (e.g., "1")
+        final_token_logprob = token_logprobs[-1]
+        top_logprobs = final_token_logprob.top_logprobs
+        # Create a map of 'class_1' -> probability, using token combinations
+        class_probs = {}
+        for top_token in top_logprobs:
+            full_label = f"class_{top_token.token}"
+            prob = math.exp(top_token.logprob)
+            class_probs[full_label] = prob
+        # Optional: normalize in case some are missing
+        total_prob = sum(class_probs.values())
+        class_probs = {k: v / total_prob for k, v in class_probs.items()}
+        # Assign weights to classes
+        class_weights = {"class_1": 0.25, "class_2": 0.5, "class_3": 0.75, "class_4": 1.0}
+        # Compute the final smooth score
+        score = sum(class_weights.get(class_label, 0) * prob for class_label, prob in class_probs.items())
+        if score is not None:
+            if score > 1.0:
+                score = 1.0
+            elif score < 0.0:
+                score = 0.0
+
+        rerank_data = {"document": document, "relevance_score": score}
+        return rerank_data
 
     def get_scores(self, query: str, documents: list[str]):
         query_document_pairs = [(query, doc) for doc in documents]
         # Create event loop and run async code
         import asyncio
+
         try:
             loop = asyncio.get_running_loop()
         except RuntimeError:
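The multi-class score computed at the end of `search_relevancy_score` is an expectation over the four labels: each class weight (0.25 through 1.0) is multiplied by its normalized token probability. For instance, if the final token's top logprobs give class_3 a probability of 0.3 and class_4 a probability of 0.7 after normalization, the score is 0.75 * 0.3 + 1.0 * 0.7 = 0.925. The same arithmetic in isolation, with made-up logprobs:

import math

# Hypothetical top_logprobs for the final class token ("3" vs "4").
top_logprobs = {"3": math.log(0.3), "4": math.log(0.7)}

class_probs = {f"class_{tok}": math.exp(lp) for tok, lp in top_logprobs.items()}
total = sum(class_probs.values())
class_probs = {k: v / total for k, v in class_probs.items()}  # normalize

class_weights = {"class_1": 0.25, "class_2": 0.5, "class_3": 0.75, "class_4": 1.0}
score = sum(class_weights.get(label, 0) * p for label, p in class_probs.items())
print(round(score, 3))  # 0.925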
@@ -351,10 +357,7 @@ class BaseLLMReranker(BaseModel, ABC):
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
 
-
-            documents_and_scores = loop.run_until_complete(self._rank_score(query_document_pairs))
-        else:
-            documents_and_scores = loop.run_until_complete(self._rank(query_document_pairs))
+        documents_and_scores = loop.run_until_complete(self._rank(query_document_pairs))
 
         scores = [score for _, score in documents_and_scores]
         return scores
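`get_scores` is the synchronous entry point over the async pipeline: it reuses a running event loop when one exists, otherwise it creates one, then drives `_rank` to completion. A stripped-down version of that bridge, with a trivial coroutine standing in for `_rank` (as in the original, calling it from inside an already-running loop would still fail, since `run_until_complete` cannot re-enter a live loop):

import asyncio


async def _rank(pairs):
    # Stand-in for the real async reranking pipeline.
    return [(doc, 1.0) for _, doc in pairs]


def get_scores(query: str, documents: list[str]) -> list[float]:
    pairs = [(query, doc) for doc in documents]
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    documents_and_scores = loop.run_until_complete(_rank(pairs))
    return [score for _, score in documents_and_scores]


print(get_scores("q", ["a", "b"]))  # [1.0, 1.0]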
mindsdb/interfaces/agents/constants.py

@@ -8,16 +8,16 @@ from types import MappingProxyType
 # the same as
 # from mindsdb.integrations.handlers.openai_handler.constants import CHAT_MODELS
 OPEN_AI_CHAT_MODELS = (
-
-
-
-
-
-
-
-
-
-
+    "gpt-3.5-turbo",
+    "gpt-3.5-turbo-16k",
+    "gpt-3.5-turbo-instruct",
+    "gpt-4",
+    "gpt-4-32k",
+    "gpt-4-1106-preview",
+    "gpt-4-0125-preview",
+    "gpt-4o",
+    "o3-mini",
+    "o1-mini",
 )
 
 SUPPORTED_PROVIDERS = {
@@ -28,7 +28,8 @@ SUPPORTED_PROVIDERS = {
     "ollama",
     "nvidia_nim",
     "vllm",
-    "google"
+    "google",
+    "writer",
 }
 # Chat models
 ANTHROPIC_CHAT_MODELS = (
@@ -175,6 +176,8 @@ GOOGLE_GEMINI_CHAT_MODELS = (
     "gemini-1.5-pro",
 )
 
+WRITER_CHAT_MODELS = ("palmyra-x5", "palmyra-x4")
+
 # Define a read-only dictionary mapping providers to their models
 PROVIDER_TO_MODELS = MappingProxyType(
     {
@@ -183,6 +186,7 @@ PROVIDER_TO_MODELS = MappingProxyType(
         "openai": OPEN_AI_CHAT_MODELS,
         "nvidia_nim": NVIDIA_NIM_CHAT_MODELS,
         "google": GOOGLE_GEMINI_CHAT_MODELS,
+        "writer": WRITER_CHAT_MODELS,
     }
 )
 
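`PROVIDER_TO_MODELS` stays wrapped in `MappingProxyType`, so the new `writer` entry is readable like a normal dict while the table itself cannot be mutated at runtime. For example:

from types import MappingProxyType

PROVIDER_TO_MODELS = MappingProxyType({"writer": ("palmyra-x5", "palmyra-x4")})

print("palmyra-x4" in PROVIDER_TO_MODELS["writer"])  # True
try:
    PROVIDER_TO_MODELS["writer"] = ()  # mutation is rejected
except TypeError as err:
    print(f"read-only mapping: {err}")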
@@ -200,8 +204,8 @@ DEFAULT_TEMPERATURE = 0.0
 USER_COLUMN = "question"
 DEFAULT_EMBEDDINGS_MODEL_PROVIDER = "openai"
 DEFAULT_EMBEDDINGS_MODEL_CLASS = OpenAIEmbeddings
-DEFAULT_TIKTOKEN_MODEL_NAME = os.getenv(
-AGENT_CHUNK_POLLING_INTERVAL_SECONDS = os.getenv(
+DEFAULT_TIKTOKEN_MODEL_NAME = os.getenv("DEFAULT_TIKTOKEN_MODEL_NAME", "gpt-4")
+AGENT_CHUNK_POLLING_INTERVAL_SECONDS = os.getenv("AGENT_CHUNK_POLLING_INTERVAL_SECONDS", 1.0)
 DEFAULT_TEXT2SQL_DATABASE = "mindsdb"
 DEFAULT_AGENT_SYSTEM_PROMPT = """
 You are an AI assistant powered by MindsDB. When answering questions, follow these guidelines: