MindsDB 25.4.4.0__py3-none-any.whl → 25.4.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of MindsDB might be problematic. Click here for more details.

Files changed (37) hide show
  1. mindsdb/__about__.py +1 -1
  2. mindsdb/api/executor/command_executor.py +12 -2
  3. mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +2 -1
  4. mindsdb/api/executor/planner/query_plan.py +1 -0
  5. mindsdb/api/executor/planner/query_planner.py +5 -0
  6. mindsdb/api/executor/sql_query/sql_query.py +24 -8
  7. mindsdb/api/executor/sql_query/steps/apply_predictor_step.py +20 -3
  8. mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +3 -1
  9. mindsdb/api/http/namespaces/config.py +19 -11
  10. mindsdb/integrations/handlers/openai_handler/helpers.py +3 -5
  11. mindsdb/integrations/handlers/openai_handler/openai_handler.py +20 -8
  12. mindsdb/integrations/handlers/togetherai_handler/__about__.py +9 -0
  13. mindsdb/integrations/handlers/togetherai_handler/__init__.py +20 -0
  14. mindsdb/integrations/handlers/togetherai_handler/creation_args.py +14 -0
  15. mindsdb/integrations/handlers/togetherai_handler/icon.svg +15 -0
  16. mindsdb/integrations/handlers/togetherai_handler/model_using_args.py +5 -0
  17. mindsdb/integrations/handlers/togetherai_handler/requirements.txt +2 -0
  18. mindsdb/integrations/handlers/togetherai_handler/settings.py +33 -0
  19. mindsdb/integrations/handlers/togetherai_handler/togetherai_handler.py +234 -0
  20. mindsdb/integrations/utilities/handler_utils.py +4 -0
  21. mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +360 -0
  22. mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +6 -346
  23. mindsdb/interfaces/functions/controller.py +3 -2
  24. mindsdb/interfaces/knowledge_base/controller.py +89 -75
  25. mindsdb/interfaces/query_context/context_controller.py +55 -15
  26. mindsdb/interfaces/query_context/query_task.py +19 -0
  27. mindsdb/interfaces/storage/db.py +2 -2
  28. mindsdb/interfaces/tasks/task_monitor.py +5 -1
  29. mindsdb/interfaces/tasks/task_thread.py +6 -0
  30. mindsdb/migrations/versions/2025-04-22_53502b6d63bf_query_database.py +27 -0
  31. mindsdb/utilities/config.py +12 -1
  32. mindsdb/utilities/context.py +1 -0
  33. {mindsdb-25.4.4.0.dist-info → mindsdb-25.4.5.0.dist-info}/METADATA +229 -226
  34. {mindsdb-25.4.4.0.dist-info → mindsdb-25.4.5.0.dist-info}/RECORD +37 -26
  35. {mindsdb-25.4.4.0.dist-info → mindsdb-25.4.5.0.dist-info}/WHEEL +1 -1
  36. {mindsdb-25.4.4.0.dist-info → mindsdb-25.4.5.0.dist-info}/licenses/LICENSE +0 -0
  37. {mindsdb-25.4.4.0.dist-info → mindsdb-25.4.5.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,234 @@
1
+ import os
2
+ import textwrap
3
+ from typing import Optional, Dict, Any
4
+ import requests
5
+ import pandas as pd
6
+ from openai import OpenAI, AuthenticationError
7
+ from mindsdb.integrations.handlers.openai_handler import Handler as OpenAIHandler
8
+ from mindsdb.integrations.utilities.handler_utils import get_api_key
9
+ from mindsdb.integrations.handlers.togetherai_handler.settings import (
10
+ togetherai_handler_config,
11
+ )
12
+
13
+ from mindsdb.utilities import log
14
+
15
+ logger = log.getLogger(__name__)
16
+
17
+
18
class TogetherAIHandler(OpenAIHandler):
    """
    This handler handles connection to the TogetherAI.

    It subclasses the OpenAI handler and replaces the API base URL, the default
    models, the default mode and the supported modes with the values defined in
    `togetherai_handler_config`.
    """

    name = "togetherai"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.generative = True
        self.api_base = togetherai_handler_config.BASE_URL
        self.default_model = togetherai_handler_config.DEFAULT_MODEL
        self.default_embedding_model = togetherai_handler_config.DEFAULT_EMBEDDING_MODEL
        self.default_mode = togetherai_handler_config.DEFAULT_MODE
        self.supported_modes = togetherai_handler_config.SUPPORTED_MODES

    @staticmethod
    def _check_client_connection(client: OpenAI):
        """
        Check the TogetherAI engine client connection by listing models.

        Args:
            client (OpenAI): OpenAI client configured with the TogetherAI API credentials.

        Raises:
            Exception: If the client connection (API key) is invalid.

        Returns:
            None
        """

        try:
            TogetherAIHandler._get_supported_models(client.api_key, client.base_url)
        except Exception as e:
            # Chain explicitly so the original failure stays attached to the wrapper.
            raise Exception(f"Something went wrong: {e}") from e

    def create_engine(self, connection_args):
        """
        Validate the TogetherAI API credentials on engine creation.

        The check only runs when `togetherai_api_key` is present in the
        connection arguments; without a key nothing is validated here.

        Args:
            connection_args (dict): Connection arguments.

        Raises:
            Exception: If the handler is not configured with valid API credentials.

        Returns:
            None
        """

        connection_args = {k.lower(): v for k, v in connection_args.items()}
        api_key = connection_args.get("togetherai_api_key")
        if api_key is not None:
            api_base = connection_args.get("api_base") or os.environ.get(
                "TOGETHERAI_API_BASE", togetherai_handler_config.BASE_URL
            )
            client = self._get_client(api_key=api_key, base_url=api_base)
            TogetherAIHandler._check_client_connection(client)

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        """
        Validate the TogetherAI API credentials on model creation.

        Args:
            target (str): Target column, not required for LLMs.
            args (dict): Handler arguments.
            kwargs (dict): Handler keyword arguments; `handler_storage` is required.

        Raises:
            Exception: If the USING clause is missing or inconsistent, or if the
                handler is not configured with valid API credentials.

        Returns:
            None
        """
        # `args` defaults to None, so guard before the membership test.
        if not args or "using" not in args:
            raise Exception(
                "TogetherAI engine require a USING clause! Refer to its documentation for more details"
            )
        args = args["using"]

        if not set(args.keys()) & {"question_column", "prompt_template", "prompt"}:
            raise Exception(
                "One of `question_column`, `prompt_template` or `prompt` is required for this engine."
            )

        # The first key of each group selects a prompting style; the styles are
        # mutually exclusive.
        keys_collection = [
            ["prompt_template"],
            ["question_column", "context_column"],
            ["prompt", "user_column", "assistant_column"],
        ]
        for keys in keys_collection:
            if keys[0] in args and any(
                x[0] in args for x in keys_collection if x != keys
            ):
                raise Exception(
                    textwrap.dedent(
                        """\
                        Please provide one of
                        1) a `prompt_template`
                        2) a `question_column` and an optional `context_column`
                        3) a `prompt`, `user_column` and `assistant_column`
                        """
                    )
                )

        engine_storage = kwargs["handler_storage"]
        connection_args = engine_storage.get_connection_args()
        api_key = get_api_key("togetherai", args, engine_storage=engine_storage)
        api_base = connection_args.get("api_base") or os.environ.get(
            "TOGETHERAI_API_BASE", togetherai_handler_config.BASE_URL
        )
        client = TogetherAIHandler._get_client(api_key=api_key, base_url=api_base)
        TogetherAIHandler._check_client_connection(client)

    def create(self, target, args: Dict = None, **kwargs: Any) -> None:
        """
        Create a model for TogetherAI engine.

        Fills in the default mode and model name when they are not given and
        rejects unsupported modes or unknown model names. The arguments are
        written to model storage in a `finally` clause, so they are persisted
        even when validation raises.

        Args:
            target (str): Target column, not required for LLMs.
            args (dict): Handler arguments.
            kwargs (dict): Handler keyword arguments.

        Raises:
            Exception: If the handler is not configured with valid API credentials,
                or if the requested mode or model name is not supported.

        Returns:
            None
        """
        args = args["using"]
        args["target"] = target
        try:
            api_key = get_api_key(self.api_key_name, args, self.engine_storage)
            connection_args = self.engine_storage.get_connection_args()
            # Precedence: USING clause, engine connection args, environment, handler default.
            api_base = (
                args.get("api_base")
                or connection_args.get("api_base")
                or os.environ.get("TOGETHERAI_API_BASE")
                or self.api_base
            )
            available_models = self._get_supported_models(api_key, api_base)

            if args.get("mode") is None:
                args["mode"] = self.default_mode
            elif args["mode"] not in self.supported_modes:
                raise Exception(
                    f"Invalid operation mode. Please use one of {self.supported_modes}"
                )

            if not args.get("model_name"):
                if args["mode"] == "embedding":
                    args["model_name"] = self.default_embedding_model
                else:
                    args["model_name"] = self.default_model
            elif args["model_name"] not in available_models:
                raise Exception(
                    f"Invalid model name. Please use one of {available_models}"
                )
        finally:
            self.model_storage.json_set("args", args)

    def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame:
        """
        Call the TogetherAI engine to predict the next token.

        The list of models is fetched from the API on every call and assigned
        to `chat_completion_models` before delegating to the parent handler.

        Args:
            df (pd.DataFrame): Input data.
            args (dict): Handler arguments.

        Returns:
            pd.DataFrame: Predicted data.
        """

        # NOTE(review): unlike `create`, this always queries `self.api_base` and
        # ignores any `api_base` override — confirm that this is intended.
        api_key = get_api_key("togetherai", args, engine_storage=self.engine_storage)
        supported_models = self._get_supported_models(api_key, self.api_base)
        self.chat_completion_models = supported_models
        return super().predict(df, args)

    @staticmethod
    def _get_supported_models(api_key, base_url, timeout=30):
        """
        Get the list of supported models from the TogetherAI engine.

        Args:
            api_key (str): TogetherAI API key.
            base_url (str): TogetherAI API base URL. Callers also pass
                `client.base_url`, so any object with a string form is accepted.
            timeout (float): Seconds to wait for the HTTP request before giving up.

        Raises:
            AuthenticationError: If the API answers with HTTP 401.
            Exception: If the API answers with any other non-200 status.

        Returns:
            list: List of supported models.
        """

        # Normalise to a string without a trailing slash so the joined URL never
        # contains a double slash before `models`.
        list_model_endpoint = f"{str(base_url).rstrip('/')}/models"
        headers = {
            "accept": "application/json",
            "authorization": f"Bearer {api_key}",
        }
        # Without a timeout `requests` can wait indefinitely on a stalled connection.
        response = requests.get(
            url=list_model_endpoint, headers=headers, timeout=timeout
        )

        if response.status_code == 200:
            return [model["id"] for model in response.json()]
        if response.status_code == 401:
            # openai's status errors take keyword-only `response` and `body`;
            # omitting them raises TypeError instead of the intended error.
            # NOTE(review): a `requests` response is passed here and relies on it
            # exposing `.request`, `.status_code` and `.headers` like the httpx
            # response openai expects — confirm against the installed openai version.
            raise AuthenticationError(
                message="Invalid API key", response=response, body=None
            )
        raise Exception(f"Failed to get supported models: {response.text}")

    def finetune(
        self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None
    ) -> None:
        """Fine-tuning is not available for this engine; always raises NotImplementedError."""
        raise NotImplementedError("Fine-tuning is not supported for TogetherAI engine")
@@ -63,6 +63,10 @@ def get_api_key(
63
63
  if f"{api_name.lower()}_api_key" in api_cfg:
64
64
  return api_cfg[f"{api_name.lower()}_api_key"]
65
65
 
66
+ # 6
67
+ if 'api_keys' in create_args and api_name in create_args['api_keys']:
68
+ return create_args['api_keys'][api_name]
69
+
66
70
  if strict:
67
71
  raise Exception(
68
72
  f"Missing API key '{api_name.lower()}_api_key'. Either re-create this ML_ENGINE specifying the '{api_name.lower()}_api_key' parameter, or re-create this model and pass the API key with `USING` syntax."
@@ -0,0 +1,360 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ import math
6
+ import os
7
+ import random
8
+ from abc import ABC
9
+ from typing import Any, List, Optional, Tuple
10
+
11
+ from openai import AsyncOpenAI, AsyncAzureOpenAI
12
+ from pydantic import field_validator
13
+ from pydantic import BaseModel
14
+
15
+ from mindsdb.integrations.utilities.rag.settings import DEFAULT_RERANKING_MODEL, DEFAULT_LLM_ENDPOINT
16
+
17
+ log = logging.getLogger(__name__)
18
+
19
+
20
class BaseLLMReranker(BaseModel, ABC):
    """
    Score documents against a query by asking a chat-completion model.

    Two scoring methods are implemented: a yes/no prompt whose first-token
    log-probability is turned into a confidence (`_rank`), and a four-class
    prompt whose class probabilities are combined into a weighted score
    (`_rank_score`). `get_scores` is the synchronous entry point and picks the
    method from `self.method`.
    """

    filtering_threshold: float = 0.0  # Default threshold for filtering
    provider: str = 'openai'
    model: str = DEFAULT_RERANKING_MODEL  # Model to use for reranking
    temperature: float = 0.0  # Temperature for the model
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    api_version: Optional[str] = None
    num_docs_to_keep: Optional[int] = None  # How many of the top documents to keep after reranking & compressing.
    method: str = "multi-class"  # Scoring method: 'multi-class' or 'binary'
    _api_key_var: str = "OPENAI_API_KEY"
    client: Optional[AsyncOpenAI] = None
    _semaphore: Optional[asyncio.Semaphore] = None
    max_concurrent_requests: int = 20
    max_retries: int = 3
    retry_delay: float = 1.0
    request_timeout: float = 20.0  # Timeout for API requests
    early_stop: bool = True  # Whether to enable early stopping
    early_stop_threshold: float = 0.8  # Confidence threshold for early stopping

    class Config:
        arbitrary_types_allowed = True

    @field_validator('provider')
    @classmethod
    def validate_provider(cls, v: str) -> str:
        """Lower-case the provider name and reject anything but the supported providers."""
        allowed = {'openai', 'azure_openai'}
        v_lower = v.lower()
        if v_lower not in allowed:
            raise ValueError(f"Unsupported provider: {v}.")
        return v_lower

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Caps the number of in-flight API requests.
        self._semaphore = asyncio.Semaphore(self.max_concurrent_requests)

    async def _init_client(self):
        """
        Create the async API client on first use.

        Explicit field values win over environment variables. For the `openai`
        provider a missing API key raises ValueError.
        """
        if self.client is not None:
            return

        if self.provider == "azure_openai":
            azure_api_key = self.api_key or os.getenv("AZURE_OPENAI_API_KEY")
            azure_api_endpoint = self.base_url or os.environ.get("AZURE_OPENAI_ENDPOINT")
            azure_api_version = self.api_version or os.environ.get("AZURE_OPENAI_API_VERSION")
            self.client = AsyncAzureOpenAI(
                api_key=azure_api_key,
                azure_endpoint=azure_api_endpoint,
                api_version=azure_api_version,
                timeout=self.request_timeout,
                max_retries=2,
            )
        elif self.provider == "openai":
            api_key_var: str = "OPENAI_API_KEY"
            openai_api_key = self.api_key or os.getenv(api_key_var)
            if not openai_api_key:
                raise ValueError(f"OpenAI API key not found in environment variable {api_key_var}")

            base_url = self.base_url or DEFAULT_LLM_ENDPOINT
            self.client = AsyncOpenAI(
                api_key=openai_api_key,
                base_url=base_url,
                timeout=self.request_timeout,
                max_retries=2,
            )

    async def search_relevancy(self, query: str, document: str, rerank_callback=None) -> Any:
        """
        Ask the model a yes/no relevance question for one query/document pair.

        Args:
            query: Search query.
            document: Document text to judge.
            rerank_callback: Optional callable invoked with the result dict.

        Returns:
            Dict with keys `document`, `answer` and `logprob` (log-probability of
            the first generated token).

        Raises:
            Exception: The last error, once `max_retries` attempts have failed.
        """
        await self._init_client()

        async with self._semaphore:
            for attempt in range(self.max_retries):
                try:
                    response = await self.client.chat.completions.create(
                        model=self.model,
                        messages=[
                            {"role": "system", "content": "Rate the relevance of the document to the query. Respond with 'yes' or 'no'."},
                            {"role": "user", "content": f"Query: {query}\nDocument: {document}\nIs this document relevant?"}
                        ],
                        temperature=self.temperature,
                        n=1,
                        logprobs=True,
                        max_tokens=1
                    )

                    # Extract response and logprobs
                    answer = response.choices[0].message.content
                    logprob = response.choices[0].logprobs.content[0].logprob
                    rerank_data = {
                        "document": document,
                        "answer": answer,
                        "logprob": logprob
                    }

                    # Stream reranking update.
                    if rerank_callback is not None:
                        rerank_callback(rerank_data)

                    return rerank_data

                except Exception as e:
                    if attempt == self.max_retries - 1:
                        log.error(f"Failed after {self.max_retries} attempts: {str(e)}")
                        raise
                    # Exponential backoff with jitter
                    retry_delay = self.retry_delay * (2 ** attempt) + random.uniform(0, 0.1)
                    await asyncio.sleep(retry_delay)

    @staticmethod
    def _binary_confidence(result: Any) -> float:
        """Turn a yes/no answer and its log-probability into a relevance score."""
        answer = result["answer"].lower().strip()
        prob = math.exp(result["logprob"])
        if answer == "yes":
            return prob  # If yes, use the model's confidence
        if answer == "no":
            return 1 - prob  # If no, invert the confidence
        return 0.5 * prob  # For unclear answers, reduce confidence

    @staticmethod
    def _clamped_relevance(result: Any) -> Any:
        """Return the result's `relevance_score`, limited to the range 0.0-1.0 when it is not None."""
        score = result["relevance_score"]
        if score is not None:
            if score > 1.0:
                score = 1.0
            elif score < 0.0:
                score = 0.0
        return score

    def _should_stop_early(self, ranked_results: List[Tuple[str, float]], score: Any) -> bool:
        """
        Decide whether enough good documents were found to stop ranking.

        Errors raised by the check are logged and treated as "keep going" so
        that they never abort the ranking itself.
        """
        try:
            high_scoring_docs = [r for r in ranked_results if r[1] >= self.filtering_threshold]
            return bool(
                self.early_stop  # Early stopping is enabled
                and self.num_docs_to_keep  # We have a target number of docs
                and len(high_scoring_docs) >= self.num_docs_to_keep  # Found enough good docs
                and score >= self.early_stop_threshold  # Current doc is good enough
            )
        except Exception as e:
            # Don't let early stopping errors stop the whole process
            log.warning(f"Error in early stopping check: {str(e)}")
            return False

    async def _score_in_batches(self, query_document_pairs: List[Tuple[str, str]], fetch: Any, to_score: Any) -> List[Tuple[str, float]]:
        """
        Shared batching loop behind `_rank` and `_rank_score`.

        Args:
            query_document_pairs: (query, document) tuples to score.
            fetch: Callable `(query, document)` returning the coroutine that
                queries the model.
            to_score: Callable mapping one fetched result to a score.

        Returns:
            (document, score) tuples in input order. A document whose request
            or score conversion failed gets 0.0, so positions stay aligned with
            the input. The list is shorter than the input only when early
            stopping triggers.
        """
        ranked_results: List[Tuple[str, float]] = []
        if not query_document_pairs:
            # An empty input would give a batch size of 0, and range() rejects a zero step.
            return ranked_results

        # Process in larger batches for better throughput
        batch_size = max(1, min(self.max_concurrent_requests * 2, len(query_document_pairs)))
        for start in range(0, len(query_document_pairs), batch_size):
            batch = query_document_pairs[start:start + batch_size]
            try:
                results = await asyncio.gather(
                    *[fetch(query, document) for (query, document) in batch],
                    return_exceptions=True
                )
            except Exception as e:
                log.error(f"Batch processing error: {str(e)}")
                # Keep one entry per document so later scores do not shift position.
                ranked_results.extend((document, 0.0) for (_, document) in batch)
                continue

            for idx, result in enumerate(results):
                document = batch[idx][1]
                if isinstance(result, BaseException):
                    log.error(f"Error processing document {start + idx}: {str(result)}")
                    ranked_results.append((document, 0.0))
                    continue

                try:
                    score = to_score(result)
                except Exception as e:
                    # A malformed result only costs this document its score,
                    # not the rest of the batch.
                    log.error(f"Error processing document {start + idx}: {str(e)}")
                    ranked_results.append((document, 0.0))
                    continue

                ranked_results.append((document, score))

                # Check if we should stop early
                if self._should_stop_early(ranked_results, score):
                    log.info(f"Early stopping after finding {self.num_docs_to_keep} documents with high confidence")
                    return ranked_results

        return ranked_results

    async def _rank(self, query_document_pairs: List[Tuple[str, str]], rerank_callback=None) -> List[Tuple[str, float]]:
        """
        Score pairs with the binary (yes/no) method.

        Args:
            query_document_pairs: (query, document) tuples to score.
            rerank_callback: Optional callable forwarded to `search_relevancy`.

        Returns:
            (document, score) tuples; an empty list for empty input.
        """
        return await self._score_in_batches(
            query_document_pairs,
            lambda query, document: self.search_relevancy(
                query=query, document=document, rerank_callback=rerank_callback
            ),
            self._binary_confidence,
        )

    async def search_relevancy_score(self, query: str, document: str) -> Any:
        """
        Ask the model to classify relevance into one of four classes.

        The probabilities of the candidate tokens for the final generated token
        are normalised and combined with fixed class weights into one score.

        Args:
            query: Search query.
            document: Document chunk to judge.

        Returns:
            Dict with keys `document`, `answer` (the class label returned by the
            model) and `relevance_score`.

        Raises:
            Exception: The last error, once `max_retries` attempts have failed.
        """
        await self._init_client()

        async with self._semaphore:
            for attempt in range(self.max_retries):
                try:
                    response = await self.client.chat.completions.create(
                        model=self.model,
                        messages=[
                            {"role": "system", "content": """
                                You are an intelligent assistant that evaluates how relevant a given document chunk is to a user's search query.
                                Your task is to analyze the similarity between the search query and the document chunk, and return **only the class label** that best represents the relevance:

                                - "class_1": Not relevant (score between 0.0 and 0.25)
                                - "class_2": Slightly relevant (score between 0.25 and 0.5)
                                - "class_3": Moderately relevant (score between 0.5 and 0.75)
                                - "class_4": Highly relevant (score between 0.75 and 1.0)

                                Respond with only one of: "class_1", "class_2", "class_3", or "class_4".

                                Examples:

                                Search query: "How to reset a router to factory settings?"
                                Document chunk: "Computers often come with customizable parental control settings."
                                Score: class_1

                                Search query: "Symptoms of vitamin D deficiency"
                                Document chunk: "Vitamin D deficiency has been linked to fatigue, bone pain, and muscle weakness."
                                Score: class_4

                                Search query: "Best practices for onboarding remote employees"
                                Document chunk: "An employee handbook can be useful for new hires, outlining company policies and benefits."
                                Score: class_2

                                Search query: "Benefits of mindfulness meditation"
                                Document chunk: "Practicing mindfulness has shown to reduce stress and improve focus in multiple studies."
                                Score: class_3

                                Search query: "What is Kubernetes used for?"
                                Document chunk: "Kubernetes is an open-source system for automating deployment, scaling, and management of containerized applications."
                                Score: class_4

                                Search query: "How to bake sourdough bread at home"
                                Document chunk: "The French Revolution began in 1789 and radically transformed society."
                                Score: class_1

                                Search query: "Machine learning algorithms for image classification"
                                Document chunk: "Convolutional Neural Networks (CNNs) are particularly effective in image classification tasks."
                                Score: class_4

                                Search query: "How to improve focus while working remotely"
                                Document chunk: "Creating a dedicated workspace and setting a consistent schedule can significantly improve focus during remote work."
                                Score: class_4

                                Search query: "Carbon emissions from electric vehicles vs gas cars"
                                Document chunk: "Electric vehicles produce zero emissions while driving, but battery production has environmental impacts."
                                Score: class_3

                                Search query: "Time zones in the United States"
                                Document chunk: "The U.S. is divided into six primary time zones: Eastern, Central, Mountain, Pacific, Alaska, and Hawaii-Aleutian."
                                Score: class_4
                                """},

                            {"role": "user", "content": f"""
                                Now evaluate the following pair:

                                Search query: {query}
                                Document chunk: {document}

                                Which class best represents the relevance?
                                """}
                        ],
                        temperature=self.temperature,
                        n=1,
                        logprobs=True,
                        top_logprobs=4,
                        max_tokens=3
                    )

                    # Extract response and logprobs
                    class_label = response.choices[0].message.content.strip()
                    token_logprobs = response.choices[0].logprobs.content
                    # Reconstruct the prediction and extract the top logprobs from the final token (e.g., "1")
                    final_token_logprob = token_logprobs[-1]
                    top_logprobs = final_token_logprob.top_logprobs
                    # Create a map of 'class_1' -> probability, using token combinations
                    class_probs = {
                        f"class_{top_token.token}": math.exp(top_token.logprob)
                        for top_token in top_logprobs
                    }
                    # Optional: normalize in case some are missing
                    total_prob = sum(class_probs.values())
                    if total_prob > 0:
                        class_probs = {k: v / total_prob for k, v in class_probs.items()}
                    # Assign weights to classes
                    class_weights = {
                        "class_1": 0.25,
                        "class_2": 0.5,
                        "class_3": 0.75,
                        "class_4": 1.0
                    }
                    # Compute the final smooth score; labels without a weight contribute 0.
                    relevance_score = sum(
                        class_weights.get(label, 0) * prob for label, prob in class_probs.items()
                    )
                    rerank_data = {
                        "document": document,
                        "answer": class_label,
                        "relevance_score": relevance_score
                    }
                    return rerank_data

                except Exception as e:
                    if attempt == self.max_retries - 1:
                        log.error(f"Failed after {self.max_retries} attempts: {str(e)}")
                        raise
                    # Exponential backoff with jitter
                    retry_delay = self.retry_delay * (2 ** attempt) + random.uniform(0, 0.1)
                    await asyncio.sleep(retry_delay)

    async def _rank_score(self, query_document_pairs: List[Tuple[str, str]]) -> List[Tuple[str, float]]:
        """
        Score pairs with the multi-class method.

        Args:
            query_document_pairs: (query, document) tuples to score.

        Returns:
            (document, score) tuples; an empty list for empty input.
        """
        return await self._score_in_batches(
            query_document_pairs,
            lambda query, document: self.search_relevancy_score(query=query, document=document),
            self._clamped_relevance,
        )

    def get_scores(self, query: str, documents: list[str]):
        """
        Synchronously score every document against the query.

        Args:
            query: Search query.
            documents: Document texts to score.

        Returns:
            List of scores as produced by the selected ranking method.
        """
        query_document_pairs = [(query, doc) for doc in documents]
        # Create event loop and run async code
        try:
            # NOTE(review): if a loop is already running in this thread,
            # `run_until_complete` below raises RuntimeError — confirm this
            # method is only called from synchronous code.
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # If no running loop exists, create a new one
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        if self.method == "multi-class":  # default 'multi-class' method
            documents_and_scores = loop.run_until_complete(self._rank_score(query_document_pairs))
        else:
            documents_and_scores = loop.run_until_complete(self._rank(query_document_pairs))

        scores = [score for _, score in documents_and_scores]
        return scores