MindsDB 25.6.4.0__py3-none-any.whl → 25.7.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of MindsDB might be problematic.

Files changed (61)
  1. mindsdb/__about__.py +1 -1
  2. mindsdb/__main__.py +53 -94
  3. mindsdb/api/a2a/agent.py +30 -206
  4. mindsdb/api/a2a/common/server/server.py +26 -27
  5. mindsdb/api/a2a/task_manager.py +93 -227
  6. mindsdb/api/a2a/utils.py +21 -0
  7. mindsdb/api/executor/command_executor.py +8 -6
  8. mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +1 -1
  9. mindsdb/api/executor/datahub/datanodes/integration_datanode.py +9 -11
  10. mindsdb/api/executor/datahub/datanodes/system_tables.py +1 -1
  11. mindsdb/api/executor/planner/query_prepare.py +68 -87
  12. mindsdb/api/executor/sql_query/steps/fetch_dataframe.py +6 -1
  13. mindsdb/api/executor/sql_query/steps/union_step.py +11 -9
  14. mindsdb/api/executor/utilities/sql.py +97 -21
  15. mindsdb/api/http/namespaces/agents.py +126 -201
  16. mindsdb/api/http/namespaces/config.py +12 -1
  17. mindsdb/api/http/namespaces/file.py +49 -24
  18. mindsdb/api/mcp/start.py +45 -31
  19. mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +45 -52
  20. mindsdb/integrations/handlers/huggingface_handler/__init__.py +17 -12
  21. mindsdb/integrations/handlers/huggingface_handler/finetune.py +223 -223
  22. mindsdb/integrations/handlers/huggingface_handler/huggingface_handler.py +383 -383
  23. mindsdb/integrations/handlers/huggingface_handler/requirements.txt +7 -6
  24. mindsdb/integrations/handlers/huggingface_handler/requirements_cpu.txt +7 -6
  25. mindsdb/integrations/handlers/huggingface_handler/settings.py +25 -25
  26. mindsdb/integrations/handlers/litellm_handler/litellm_handler.py +22 -15
  27. mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +244 -141
  28. mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +1 -1
  29. mindsdb/integrations/handlers/salesforce_handler/salesforce_handler.py +3 -2
  30. mindsdb/integrations/handlers/salesforce_handler/salesforce_tables.py +1 -1
  31. mindsdb/integrations/handlers/statsforecast_handler/requirements.txt +1 -0
  32. mindsdb/integrations/handlers/statsforecast_handler/requirements_extra.txt +1 -0
  33. mindsdb/integrations/libs/keyword_search_base.py +41 -0
  34. mindsdb/integrations/libs/vectordatabase_handler.py +114 -84
  35. mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +36 -42
  36. mindsdb/integrations/utilities/sql_utils.py +11 -0
  37. mindsdb/interfaces/agents/agents_controller.py +29 -9
  38. mindsdb/interfaces/agents/langchain_agent.py +7 -5
  39. mindsdb/interfaces/agents/mcp_client_agent.py +4 -4
  40. mindsdb/interfaces/agents/mindsdb_database_agent.py +10 -43
  41. mindsdb/interfaces/data_catalog/data_catalog_reader.py +3 -1
  42. mindsdb/interfaces/database/projects.py +1 -3
  43. mindsdb/interfaces/functions/controller.py +54 -64
  44. mindsdb/interfaces/functions/to_markdown.py +47 -14
  45. mindsdb/interfaces/knowledge_base/controller.py +228 -110
  46. mindsdb/interfaces/knowledge_base/evaluate.py +18 -6
  47. mindsdb/interfaces/knowledge_base/executor.py +346 -0
  48. mindsdb/interfaces/knowledge_base/llm_client.py +5 -6
  49. mindsdb/interfaces/knowledge_base/preprocessing/document_preprocessor.py +20 -45
  50. mindsdb/interfaces/knowledge_base/preprocessing/models.py +36 -69
  51. mindsdb/interfaces/skills/custom/text2sql/mindsdb_kb_tools.py +2 -0
  52. mindsdb/interfaces/skills/sql_agent.py +181 -130
  53. mindsdb/interfaces/storage/db.py +9 -7
  54. mindsdb/utilities/config.py +58 -40
  55. mindsdb/utilities/exception.py +58 -7
  56. mindsdb/utilities/security.py +54 -11
  57. {mindsdb-25.6.4.0.dist-info → mindsdb-25.7.2.0.dist-info}/METADATA +245 -259
  58. {mindsdb-25.6.4.0.dist-info → mindsdb-25.7.2.0.dist-info}/RECORD +61 -58
  59. {mindsdb-25.6.4.0.dist-info → mindsdb-25.7.2.0.dist-info}/WHEEL +0 -0
  60. {mindsdb-25.6.4.0.dist-info → mindsdb-25.7.2.0.dist-info}/licenses/LICENSE +0 -0
  61. {mindsdb-25.6.4.0.dist-info → mindsdb-25.7.2.0.dist-info}/top_level.txt +0 -0
mindsdb/integrations/handlers/postgres_handler/postgres_handler.py
@@ -476,7 +476,7 @@ class PostgresHandler(MetaDatabaseHandler):
         config = self._make_connection_args()
         config["autocommit"] = True
 
-        conn = psycopg.connect(connect_timeout=10, **config)
+        conn = psycopg.connect(**config)
 
         # create db trigger
         trigger_name = f"mdb_notify_{table_name}"
mindsdb/integrations/handlers/salesforce_handler/salesforce_handler.py
@@ -271,10 +271,11 @@ class SalesforceHandler(MetaAPIHandler):
 
         # Retrieve the metadata for all Salesforce resources.
         main_metadata = connection.sobjects.describe()
-
         if table_names:
             # Filter the metadata for the specified tables.
-            main_metadata = [resource for resource in main_metadata["sobjects"] if resource["name"] in table_names]
+            main_metadata = [
+                resource for resource in main_metadata["sobjects"] if resource["name"].lower() in table_names
+            ]
         else:
             main_metadata = main_metadata["sobjects"]
 
mindsdb/integrations/handlers/salesforce_handler/salesforce_tables.py
@@ -165,7 +165,7 @@ def create_table_class(resource_name: Text) -> MetaAPIResource:
            client = self.handler.connect()
 
            resource_metadata = next(
-                (resource for resource in main_metadata if resource["name"] == resource_name),
+                (resource for resource in main_metadata if resource["name"].lower() == resource_name),
            )
 
            # Get row count if Id column is aggregatable.
mindsdb/integrations/handlers/statsforecast_handler/requirements.txt
@@ -1 +1,2 @@
 statsforecast==1.6.0
+scipy==1.15.3
mindsdb/integrations/handlers/statsforecast_handler/requirements_extra.txt
@@ -1 +1,2 @@
 statsforecast==1.6.0
+scipy==1.15.3
mindsdb/integrations/libs/keyword_search_base.py
@@ -0,0 +1,41 @@
+from mindsdb_sql_parser.ast import Select
+from typing import List
+import pandas as pd
+
+from mindsdb.integrations.utilities.sql_utils import FilterCondition, KeywordSearchArgs
+
+
+class KeywordSearchBase:
+    """
+    Base class for keyword search integrations.
+    This class provides a common interface for keyword search functionality.
+    """
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def dispatch_keyword_select(
+        self, query: Select, conditions: List[FilterCondition] = None, keyword_search_args: KeywordSearchArgs = None
+    ):
+        """Dispatches a keyword search select query to the appropriate method."""
+        raise NotImplementedError()
+
+    def keyword_select(
+        self,
+        table_name: str,
+        columns: List[str] = None,
+        conditions: List[FilterCondition] = None,
+        offset: int = None,
+        limit: int = None,
+    ) -> pd.DataFrame:
+        """Select data from table
+
+        Args:
+            table_name (str): table name
+            columns (List[str]): columns to select
+            conditions (List[FilterCondition]): conditions to select
+
+        Returns:
+            HandlerResponse
+        """
+        raise NotImplementedError()
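
To make the new interface concrete, here is a minimal sketch of how a handler might implement it; the handler name, connection object, and SQL are illustrative assumptions, not code from this release.

# Hypothetical subclass of the new KeywordSearchBase interface (names are made up).
from typing import List

import pandas as pd

from mindsdb.integrations.libs.keyword_search_base import KeywordSearchBase
from mindsdb.integrations.utilities.sql_utils import FilterCondition


class ExampleKeywordSearchHandler(KeywordSearchBase):
    def __init__(self, connection, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.connection = connection  # assumed DB-API style connection

    def keyword_select(
        self,
        table_name: str,
        columns: List[str] = None,
        conditions: List[FilterCondition] = None,
        offset: int = None,
        limit: int = None,
    ) -> pd.DataFrame:
        # A real implementation would translate `conditions`, `offset` and `limit`
        # into the engine's full-text search syntax; this sketch only selects columns.
        cols = ", ".join(columns) if columns else "*"
        sql = f"SELECT {cols} FROM {table_name}"
        if limit is not None:
            sql += f" LIMIT {limit}"
        return pd.read_sql(sql, self.connection)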
mindsdb/integrations/libs/vectordatabase_handler.py
@@ -2,6 +2,7 @@ import ast
 import hashlib
 from enum import Enum
 from typing import Dict, List, Optional
+import datetime as dt
 
 import pandas as pd
 from mindsdb_sql_parser.ast import (
@@ -20,7 +21,7 @@ from mindsdb_sql_parser.ast.base import ASTNode
 
 from mindsdb.integrations.libs.response import RESPONSE_TYPE, HandlerResponse
 from mindsdb.utilities import log
-from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator
+from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator, KeywordSearchArgs
 
 from mindsdb.integrations.utilities.query_traversal import query_traversal
 from .base import BaseHandler
@@ -28,6 +29,9 @@ from .base import BaseHandler
 LOG = log.getLogger(__name__)
 
 
+class VectorHandlerException(Exception): ...
+
+
 class TableField(Enum):
     """
     Enum for table fields.
@@ -43,9 +47,9 @@ class TableField(Enum):
 
 
 class DistanceFunction(Enum):
-    SQUARED_EUCLIDEAN_DISTANCE = '<->',
-    NEGATIVE_DOT_PRODUCT = '<#>',
-    COSINE_DISTANCE = '<=>'
+    SQUARED_EUCLIDEAN_DISTANCE = ("<->",)
+    NEGATIVE_DOT_PRODUCT = ("<#>",)
+    COSINE_DISTANCE = "<=>"
 
 
 class VectorStoreHandler(BaseHandler):
@@ -118,9 +122,7 @@ class VectorStoreHandler(BaseHandler):
                right_hand = [item.value for item in node.args[1].items]
            else:
                raise Exception(f"Unsupported right hand side: {node.args[1]}")
-            conditions.append(
-                FilterCondition(column=left_hand, op=op, value=right_hand)
-            )
+            conditions.append(FilterCondition(column=left_hand, op=op, value=right_hand))
 
        query_traversal(where_statement, _extract_comparison_conditions)
 
@@ -129,15 +131,23 @@ class VectorStoreHandler(BaseHandler):
 
        return conditions
 
-    def _convert_metadata_filters(self, conditions):
+    def _convert_metadata_filters(self, conditions, allowed_metadata_columns=None):
        if conditions is None:
            return
        # try to treat conditions that are not in TableField as metadata conditions
        for condition in conditions:
-            if not self._is_condition_allowed(condition):
-                condition.column = (
-                    TableField.METADATA.value + "." + condition.column
-                )
+            if self._is_metadata_condition(condition):
+                # check restriction
+                if allowed_metadata_columns is not None:
+                    # system columns are underscored, skip them
+                    if condition.column.lower() not in allowed_metadata_columns and not condition.column.startswith(
+                        "_"
+                    ):
+                        raise ValueError(f"Column is not found: {condition.column}")
+
+                # convert if required
+                if not condition.column.startswith(TableField.METADATA.value):
+                    condition.column = TableField.METADATA.value + "." + condition.column
 
    def _is_columns_allowed(self, columns: List[str]) -> bool:
        """
@@ -146,16 +156,11 @@ class VectorStoreHandler(BaseHandler):
        allowed_columns = set([col["name"] for col in self.SCHEMA])
        return set(columns).issubset(allowed_columns)
 
-    def _is_condition_allowed(self, condition: FilterCondition) -> bool:
+    def _is_metadata_condition(self, condition: FilterCondition) -> bool:
        allowed_field_values = set([field.value for field in TableField])
        if condition.column in allowed_field_values:
-            return True
-        else:
-            # check if column is a metadata column
-            if condition.column.startswith(TableField.METADATA.value):
-                return True
-            else:
-                return False
+            return False
+        return True
 
    def _dispatch_create_table(self, query: CreateTable):
        """
@@ -184,17 +189,12 @@ class VectorStoreHandler(BaseHandler):
        columns = [column.name for column in query.columns]
 
        if not self._is_columns_allowed(columns):
-            raise Exception(
-                f"Columns {columns} not allowed."
-                f"Allowed columns are {[col['name'] for col in self.SCHEMA]}"
-            )
+            raise Exception(f"Columns {columns} not allowed.Allowed columns are {[col['name'] for col in self.SCHEMA]}")
 
        # get content column if it is present
        if TableField.CONTENT.value in columns:
            content_col_index = columns.index("content")
-            content = [
-                self._value_or_self(row[content_col_index]) for row in query.values
-            ]
+            content = [self._value_or_self(row[content_col_index]) for row in query.values]
        else:
            content = None
 
@@ -209,19 +209,13 @@ class VectorStoreHandler(BaseHandler):
        # get embeddings column if it is present
        if TableField.EMBEDDINGS.value in columns:
            embeddings_col_index = columns.index("embeddings")
-            embeddings = [
-                ast.literal_eval(self._value_or_self(row[embeddings_col_index]))
-                for row in query.values
-            ]
+            embeddings = [ast.literal_eval(self._value_or_self(row[embeddings_col_index])) for row in query.values]
        else:
            raise Exception("Embeddings column is required!")
 
        if TableField.METADATA.value in columns:
            metadata_col_index = columns.index("metadata")
-            metadata = [
-                ast.literal_eval(self._value_or_self(row[metadata_col_index]))
-                for row in query.values
-            ]
+            metadata = [ast.literal_eval(self._value_or_self(row[metadata_col_index])) for row in query.values]
        else:
            metadata = None
 
@@ -277,6 +271,15 @@ class VectorStoreHandler(BaseHandler):
 
        return self.do_upsert(table_name, df)
 
+    def set_metadata_cur_time(self, df, col_name):
+        metadata_col = TableField.METADATA.value
+        cur_date = dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+        def set_time(meta):
+            meta[col_name] = cur_date
+
+        df[metadata_col].apply(set_time)
+
    def do_upsert(self, table_name, df):
        """Upsert data into table, handling document updates and deletions.
 
@@ -289,6 +292,7 @@ class VectorStoreHandler(BaseHandler):
        2. Updated documents: Delete old chunks and insert new ones
        """
        id_col = TableField.ID.value
+        metadata_col = TableField.METADATA.value
        content_col = TableField.CONTENT.value
 
        def gen_hash(v):
@@ -309,37 +313,48 @@ class VectorStoreHandler(BaseHandler):
        # id is string TODO is it ok?
        df[id_col] = df[id_col].apply(str)
 
-        if hasattr(self, 'upsert'):
+        # set updated_at
+        self.set_metadata_cur_time(df, "_updated_at")
+
+        if hasattr(self, "upsert"):
            self.upsert(table_name, df)
            return
 
        # find existing ids
-        res = self.select(
+        df_existed = self.select(
            table_name,
-            columns=[id_col],
-            conditions=[
-                FilterCondition(column=id_col, op=FilterOperator.IN, value=list(df[id_col]))
-            ]
+            columns=[id_col, metadata_col],
+            conditions=[FilterCondition(column=id_col, op=FilterOperator.IN, value=list(df[id_col]))],
        )
-        existed_ids = list(res[id_col])
+        existed_ids = list(df_existed[id_col])
 
        # update existed
        df_update = df[df[id_col].isin(existed_ids)]
        df_insert = df[~df[id_col].isin(existed_ids)]
 
        if not df_update.empty:
+            # get values of existed `created_at` and return them to metadata
+            created_dates = {row[id_col]: row[metadata_col].get("_created_at") for _, row in df_existed.iterrows()}
+
+            def keep_created_at(row):
+                val = created_dates.get(row[id_col])
+                if val:
+                    row[metadata_col]["_created_at"] = val
+                return row
+
+            df_update.apply(keep_created_at, axis=1)
+
            try:
                self.update(table_name, df_update, [id_col])
            except NotImplementedError:
                # not implemented? do it with delete and insert
-                conditions = [FilterCondition(
-                    column=id_col,
-                    op=FilterOperator.IN,
-                    value=list(df[id_col])
-                )]
+                conditions = [FilterCondition(column=id_col, op=FilterOperator.IN, value=list(df[id_col]))]
                self.delete(table_name, conditions)
                self.insert(table_name, df_update)
        if not df_insert.empty:
+            # set created_at
+            self.set_metadata_cur_time(df_insert, "_created_at")
+
            self.insert(table_name, df_insert)
 
    def dispatch_delete(self, query: Delete, conditions: List[FilterCondition] = None):
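
The net effect of the new timestamp bookkeeping is easiest to see on a toy DataFrame; the rows below are made up, and the snippet only imitates what set_metadata_cur_time and the _created_at handling do to the metadata column, not the handler code itself.

# Toy illustration of the upsert timestamp handling (made-up rows, not handler code).
import datetime as dt

import pandas as pd

df = pd.DataFrame(
    {
        "id": ["doc1:chunk1", "doc1:chunk2"],
        "content": ["first chunk", "second chunk"],
        "metadata": [{"source": "a.txt"}, {"source": "a.txt"}],
    }
)

now = dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# every upserted row gets "_updated_at"; brand-new rows also get "_created_at",
# while rows that already exist keep the "_created_at" stored in the vector store
df["metadata"].apply(lambda meta: meta.update({"_updated_at": now, "_created_at": now}))
print(df["metadata"].tolist())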
@@ -356,42 +371,66 @@ class VectorStoreHandler(BaseHandler):
        # dispatch delete
        return self.delete(table_name, conditions=conditions)
 
-    def dispatch_select(self, query: Select, conditions: List[FilterCondition] = None):
+    def dispatch_select(
+        self,
+        query: Select,
+        conditions: Optional[List[FilterCondition]] = None,
+        allowed_metadata_columns: List[str] = None,
+        keyword_search_args: Optional[KeywordSearchArgs] = None,
+    ):
        """
-        Dispatch select query to the appropriate method.
+        Dispatches a select query to the appropriate method, handling both
+        standard selections and keyword searches based on the provided arguments.
        """
-        # parse key arguments
+        # 1. Parse common query arguments
        table_name = query.from_table.parts[-1]
-        # if targets are star, select all columns
+
+        # If targets are a star (*), select all schema columns
        if isinstance(query.targets[0], Star):
            columns = [col["name"] for col in self.SCHEMA]
        else:
            columns = [col.parts[-1] for col in query.targets]
 
+        # 2. Validate columns
        if not self._is_columns_allowed(columns):
-            raise Exception(
-                f"Columns {columns} not allowed."
-                f"Allowed columns are {[col['name'] for col in self.SCHEMA]}"
-            )
+            allowed_cols = [col["name"] for col in self.SCHEMA]
+            raise Exception(f"Columns {columns} not allowed. Allowed columns are {allowed_cols}")
 
-        # check if columns are allowed
+        # 3. Extract and process conditions
        if conditions is None:
            where_statement = query.where
            conditions = self.extract_conditions(where_statement)
-        self._convert_metadata_filters(conditions)
+        self._convert_metadata_filters(conditions, allowed_metadata_columns=allowed_metadata_columns)
 
-        # get offset and limit
+        # 4. Get offset and limit
        offset = query.offset.value if query.offset is not None else None
        limit = query.limit.value if query.limit is not None else None
 
-        # dispatch select
-        return self.select(
-            table_name,
-            columns=columns,
-            conditions=conditions,
-            offset=offset,
-            limit=limit,
-        )
+        # 5. Conditionally dispatch to the correct select method
+        if keyword_search_args:
+            # It's a keyword search
+            return self.keyword_select(
+                table_name,
+                columns=columns,
+                conditions=conditions,
+                offset=offset,
+                limit=limit,
+                keyword_search_args=keyword_search_args,
+            )
+        else:
+            # It's a standard select
+            try:
+                return self.select(
+                    table_name,
+                    columns=columns,
+                    conditions=conditions,
+                    offset=offset,
+                    limit=limit,
+                )
+
+            except Exception as e:
+                handler_engine = self.__class__.name
+                raise VectorHandlerException(f"Error in {handler_engine} database: {e}")
 
    def _dispatch(self, query: ASTNode) -> HandlerResponse:
        """
@@ -408,10 +447,7 @@ class VectorStoreHandler(BaseHandler):
        if type(query) in dispatch_router:
            resp = dispatch_router[type(query)](query)
            if resp is not None:
-                return HandlerResponse(
-                    resp_type=RESPONSE_TYPE.TABLE,
-                    data_frame=resp
-                )
+                return HandlerResponse(resp_type=RESPONSE_TYPE.TABLE, data_frame=resp)
            else:
                return HandlerResponse(resp_type=RESPONSE_TYPE.OK)
 
@@ -455,9 +491,7 @@ class VectorStoreHandler(BaseHandler):
        """
        raise NotImplementedError()
 
-    def insert(
-        self, table_name: str, data: pd.DataFrame
-    ) -> HandlerResponse:
+    def insert(self, table_name: str, data: pd.DataFrame) -> HandlerResponse:
        """Insert data into table
 
        Args:
@@ -470,9 +504,7 @@ class VectorStoreHandler(BaseHandler):
        """
        raise NotImplementedError()
 
-    def update(
-        self, table_name: str, data: pd.DataFrame, key_columns: List[str] = None
-    ):
+    def update(self, table_name: str, data: pd.DataFrame, key_columns: List[str] = None):
        """Update data in table
 
        Args:
@@ -485,9 +517,7 @@ class VectorStoreHandler(BaseHandler):
        """
        raise NotImplementedError()
 
-    def delete(
-        self, table_name: str, conditions: List[FilterCondition] = None
-    ) -> HandlerResponse:
+    def delete(self, table_name: str, conditions: List[FilterCondition] = None) -> HandlerResponse:
        """Delete data from table
 
        Args:
@@ -535,9 +565,9 @@ class VectorStoreHandler(BaseHandler):
        query: str = None,
        metadata: Dict[str, str] = None,
        distance_function=DistanceFunction.COSINE_DISTANCE,
-        **kwargs
+        **kwargs,
    ) -> pd.DataFrame:
-        '''
+        """
        Executes a hybrid search, combining semantic search and one or both of keyword/metadata search.
 
        For insight on the query construction, see: https://docs.pgvecto.rs/use-case/hybrid-search.html#advanced-search-merge-the-results-of-full-text-search-and-vector-search.
@@ -551,11 +581,11 @@ class VectorStoreHandler(BaseHandler):
 
        Returns:
            df(pd.DataFrame): Hybrid search result, sorted by hybrid search rank
-        '''
-        raise NotImplementedError(f'Hybrid search not supported for VectorStoreHandler {self.name}')
+        """
+        raise NotImplementedError(f"Hybrid search not supported for VectorStoreHandler {self.name}")
 
    def create_index(self, *args, **kwargs):
        """
        Create an index on the specified table.
        """
-        raise NotImplementedError(f'create_index not supported for VectorStoreHandler {self.name}')
+        raise NotImplementedError(f"create_index not supported for VectorStoreHandler {self.name}")
mindsdb/integrations/utilities/rag/rerankers/base_reranker.py
@@ -33,7 +33,7 @@ class BaseLLMReranker(BaseModel, ABC):
    client: Optional[AsyncOpenAI | BaseMLEngine] = None
    _semaphore: Optional[asyncio.Semaphore] = None
    max_concurrent_requests: int = 20
-    max_retries: int = 3
+    max_retries: int = 2
    retry_delay: float = 1.0
    request_timeout: float = 20.0  # Timeout for API requests
    early_stop: bool = True  # Whether to enable early stopping
@@ -100,7 +100,7 @@ class BaseLLMReranker(BaseModel, ABC):
        if self.api_key is not None:
            kwargs["api_key"] = self.api_key
 
-        return await self.client.acompletion(model=f"{self.provider}/{self.model}", messages=messages, args=kwargs)
+        return await self.client.acompletion(self.provider, model=self.model, messages=messages, args=kwargs)
 
    async def _rank(self, query_document_pairs: List[Tuple[str, str]], rerank_callback=None) -> List[Tuple[str, float]]:
        ranked_results = []
@@ -109,47 +109,41 @@ class BaseLLMReranker(BaseModel, ABC):
        batch_size = min(self.max_concurrent_requests * 2, len(query_document_pairs))
        for i in range(0, len(query_document_pairs), batch_size):
            batch = query_document_pairs[i : i + batch_size]
-            try:
-                results = await asyncio.gather(
-                    *[
-                        self._backoff_wrapper(query=query, document=document, rerank_callback=rerank_callback)
-                        for (query, document) in batch
-                    ],
-                    return_exceptions=True,
-                )
 
-                for idx, result in enumerate(results):
-                    if isinstance(result, Exception):
-                        log.error(f"Error processing document {i + idx}: {str(result)}")
-                        ranked_results.append((batch[idx][1], 0.0))
-                        continue
-
-                    score = result["relevance_score"]
-
-                    ranked_results.append((batch[idx][1], score))
-
-                    # Check if we should stop early
-                    try:
-                        high_scoring_docs = [r for r in ranked_results if r[1] >= self.filtering_threshold]
-                        can_stop_early = (
-                            self.early_stop  # Early stopping is enabled
-                            and self.num_docs_to_keep  # We have a target number of docs
-                            and len(high_scoring_docs) >= self.num_docs_to_keep  # Found enough good docs
-                            and score >= self.early_stop_threshold  # Current doc is good enough
-                        )
-
-                        if can_stop_early:
-                            log.info(
-                                f"Early stopping after finding {self.num_docs_to_keep} documents with high confidence"
-                            )
-                            return ranked_results
-                    except Exception as e:
-                        # Don't let early stopping errors stop the whole process
-                        log.warning(f"Error in early stopping check: {str(e)}")
-
-            except Exception as e:
-                log.error(f"Batch processing error: {str(e)}")
-                continue
+            results = await asyncio.gather(
+                *[
+                    self._backoff_wrapper(query=query, document=document, rerank_callback=rerank_callback)
+                    for (query, document) in batch
+                ],
+                return_exceptions=True,
+            )
+
+            for idx, result in enumerate(results):
+                if isinstance(result, Exception):
+                    log.error(f"Error processing document {i + idx}: {str(result)}")
+                    raise RuntimeError(f"Error during reranking: {result}")
+
+                score = result["relevance_score"]
+
+                ranked_results.append((batch[idx][1], score))
+
+                # Check if we should stop early
+                try:
+                    high_scoring_docs = [r for r in ranked_results if r[1] >= self.filtering_threshold]
+                    can_stop_early = (
+                        self.early_stop  # Early stopping is enabled
+                        and self.num_docs_to_keep  # We have a target number of docs
+                        and len(high_scoring_docs) >= self.num_docs_to_keep  # Found enough good docs
+                        and score >= self.early_stop_threshold  # Current doc is good enough
+                    )
+
+                    if can_stop_early:
+                        log.info(f"Early stopping after finding {self.num_docs_to_keep} documents with high confidence")
+                        return ranked_results
+                except Exception as e:
+                    # Don't let early stopping errors stop the whole process
+                    log.warning(f"Error in early stopping check: {str(e)}")
+
        return ranked_results
 
    async def _backoff_wrapper(self, query: str, document: str, rerank_callback=None) -> Any:
mindsdb/integrations/utilities/sql_utils.py
@@ -60,6 +60,17 @@ class FilterCondition:
    """
 
 
+class KeywordSearchArgs:
+    def __init__(self, column: str, query: str):
+        """
+        Args:
+            column: The column to search in.
+            query: The search query string.
+        """
+        self.column = column
+        self.query = query
+
+
 class SortColumn:
    def __init__(self, column: str, ascending: bool = True):
        self.column = column