MindsDB 25.7.1.0__py3-none-any.whl → 25.7.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of MindsDB has been flagged as potentially problematic; see the registry's advisory for this release for more details.

@@ -1,17 +1,19 @@
1
1
  import os
2
2
  import copy
3
- from typing import Dict, List, Optional
3
+ from typing import Dict, List, Optional, Any, Text
4
4
  import json
5
5
  import decimal
6
6
 
7
7
  import pandas as pd
8
8
  import numpy as np
9
+ from pydantic import BaseModel, ValidationError
9
10
  from sqlalchemy.orm.attributes import flag_modified
10
11
 
11
12
  from mindsdb_sql_parser.ast import BinaryOperation, Constant, Identifier, Select, Update, Delete, Star
12
13
  from mindsdb_sql_parser.ast.mindsdb import CreatePredictor
13
14
  from mindsdb_sql_parser import parse_sql
14
15
 
16
+ from mindsdb.integrations.libs.keyword_search_base import KeywordSearchBase
15
17
  from mindsdb.integrations.utilities.query_traversal import query_traversal
16
18
 
17
19
  import mindsdb.interfaces.storage.db as db
@@ -37,7 +39,7 @@ from mindsdb.interfaces.knowledge_base.evaluate import EvaluateBase
37
39
  from mindsdb.interfaces.knowledge_base.executor import KnowledgeBaseQueryExecutor
38
40
  from mindsdb.interfaces.model.functions import PredictorRecordNotFound
39
41
  from mindsdb.utilities.exception import EntityExistsError, EntityNotExistsError
40
- from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator
42
+ from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator, KeywordSearchArgs
41
43
  from mindsdb.utilities.config import config
42
44
  from mindsdb.utilities.context import context as ctx
43
45
 
@@ -49,6 +51,20 @@ from mindsdb.integrations.utilities.rag.rerankers.base_reranker import BaseLLMRe
49
51
  logger = log.getLogger(__name__)
50
52
 
51
53
 
54
class KnowledgeBaseInputParams(BaseModel):
    """Schema for the parameters accepted when creating a knowledge base.

    Every field is optional. Any key not declared here is rejected during
    validation (``extra="forbid"``), so a misspelled parameter name produces a
    validation error instead of being silently ignored.
    """

    # list of column names, or None when not provided
    metadata_columns: List[str] | None = None
    # list of column names, or None when not provided
    content_columns: List[str] | None = None
    id_column: str | None = None
    kb_no_upsert: bool = False
    # free-form mapping of model settings; keys are not validated here
    embedding_model: Dict[Text, Any] | None = None
    is_sparse: bool = False
    vector_size: int | None = None
    # free-form mapping of model settings; keys are not validated here
    reranking_model: Dict[Text, Any] | None = None

    # Pydantic v2 configuration. The nested ``class Config`` form is the v1
    # style and is deprecated in v2 (it triggers a deprecation warning at class
    # creation). A plain dict is accepted because ``ConfigDict`` is a TypedDict.
    model_config = {"extra": "forbid"}
66
+
67
+
52
68
  def get_model_params(model_params: dict, default_config_key: str):
53
69
  """
54
70
  Get model parameters by combining default config with user provided parameters.
@@ -101,7 +117,10 @@ def get_reranking_model_from_params(reranking_model_params: dict):
101
117
 
102
118
  if "api_key" not in params_copy:
103
119
  params_copy["api_key"] = get_api_key(provider, params_copy, strict=False)
104
- params_copy["model"] = params_copy.pop("model_name", None)
120
+
121
+ if "model_name" not in params_copy:
122
+ raise ValueError("'model_name' must be provided for reranking model")
123
+ params_copy["model"] = params_copy.pop("model_name")
105
124
 
106
125
  return BaseLLMReranker(**params_copy)
107
126
 
@@ -179,17 +198,20 @@ class KnowledgeBaseTable:
179
198
  df = executor.run(query)
180
199
 
181
200
  if (
182
- query.group_by is not None
183
- or query.order_by is not None
184
- or query.having is not None
185
- or query.distinct is True
186
- or len(query.targets) != 1
187
- or not isinstance(query.targets[0], Star)
201
+ query_copy.group_by is not None
202
+ or query_copy.order_by is not None
203
+ or query_copy.having is not None
204
+ or query_copy.distinct is True
205
+ or len(query_copy.targets) != 1
206
+ or not isinstance(query_copy.targets[0], Star)
188
207
  ):
189
208
  query_copy.where = None
190
209
  if "metadata" in df.columns:
191
210
  df["metadata"] = df["metadata"].apply(to_json)
192
211
 
212
+ if query_copy.from_table is None:
213
+ query_copy.from_table = Identifier(parts=[self._kb.name])
214
+
193
215
  df = query_df(df, query_copy, session=self.session)
194
216
 
195
217
  return df
@@ -218,8 +240,12 @@ class KnowledgeBaseTable:
218
240
 
219
241
  # extract values from conditions and prepare for vectordb
220
242
  conditions = []
243
+ keyword_search_conditions = []
244
+ keyword_search_cols_and_values = []
221
245
  query_text = None
222
246
  relevance_threshold = None
247
+ reranking_enabled_flag = True
248
+ hybrid_search_enabled_flag = False
223
249
  query_conditions = db_handler.extract_conditions(query.where)
224
250
  if query_conditions is not None:
225
251
  for item in query_conditions:
@@ -235,9 +261,17 @@ class KnowledgeBaseTable:
235
261
  logger.error(error_msg)
236
262
  raise ValueError(error_msg)
237
263
  elif item.column == "reranking":
264
+ reranking_enabled_flag = item.value
265
+ # cast to boolean
266
+ if isinstance(reranking_enabled_flag, str):
267
+ reranking_enabled_flag = reranking_enabled_flag.lower() not in ("false")
268
+ elif item.column == "hybrid_search":
269
+ hybrid_search_enabled_flag = item.value
270
+ # cast to boolean
271
+ if isinstance(hybrid_search_enabled_flag, str):
272
+ hybrid_search_enabled_flag = hybrid_search_enabled_flag.lower() not in ("false")
238
273
  if item.value is False or (isinstance(item.value, str) and item.value.lower() == "false"):
239
274
  disable_reranking = True
240
-
241
275
  elif item.column == "relevance" and item.op.value != FilterOperator.GREATER_THAN_OR_EQUAL.value:
242
276
  raise ValueError(
243
277
  f"Invalid operator for relevance: {item.op.value}. Only GREATER_THAN_OR_EQUAL is allowed."
@@ -253,8 +287,16 @@ class KnowledgeBaseTable:
253
287
  op=FilterOperator.EQUAL,
254
288
  )
255
289
  )
290
+ keyword_search_cols_and_values.append((TableField.CONTENT.value, item.value))
256
291
  else:
257
292
  conditions.append(item)
293
+ keyword_search_conditions.append(item) # keyword search conditions do not use embeddings
294
+
295
+ if len(keyword_search_cols_and_values) > 1:
296
+ raise ValueError(
297
+ "Multiple content columns found in query conditions. "
298
+ "Only one content column is allowed for keyword search."
299
+ )
258
300
 
259
301
  logger.debug(f"Extracted query text: {query_text}")
260
302
 
@@ -272,9 +314,42 @@ class KnowledgeBaseTable:
272
314
  allowed_metadata_columns = self._get_allowed_metadata_columns()
273
315
  df = db_handler.dispatch_select(query, conditions, allowed_metadata_columns=allowed_metadata_columns)
274
316
  df = self.addapt_result_columns(df)
275
-
276
317
  logger.debug(f"Query returned {len(df)} rows")
277
318
  logger.debug(f"Columns in response: {df.columns.tolist()}")
319
+
320
+ if hybrid_search_enabled_flag and not isinstance(db_handler, KeywordSearchBase):
321
+ raise ValueError(f"Hybrid search is enabled but the db_handler {type(db_handler)} does not support it. ")
322
+ # check if db_handler inherits from KeywordSearchBase
323
+ if hybrid_search_enabled_flag and isinstance(db_handler, KeywordSearchBase):
324
+ # If query_text is present, use it for keyword search
325
+ logger.debug(f"Performing keyword search with query text: {query_text}")
326
+ keyword_search_args = KeywordSearchArgs(query=query_text, column=TableField.CONTENT.value)
327
+ keyword_query_obj = copy.deepcopy(query)
328
+
329
+ keyword_query_obj.targets = [
330
+ Identifier(TableField.ID.value),
331
+ Identifier(TableField.CONTENT.value),
332
+ Identifier(TableField.METADATA.value),
333
+ ]
334
+
335
+ df_keyword_select = db_handler.dispatch_select(
336
+ keyword_query_obj, keyword_search_conditions, keyword_search_args=keyword_search_args
337
+ )
338
+ df_keyword_select = self.addapt_result_columns(df_keyword_select)
339
+ logger.debug(f"Keyword search returned {len(df_keyword_select)} rows")
340
+ logger.debug(f"Columns in keyword search response: {df_keyword_select.columns.tolist()}")
341
+ # ensure df and df_keyword_select have exactly the same columns
342
+ if not df_keyword_select.empty:
343
+ if set(df.columns) != set(df_keyword_select.columns):
344
+ raise ValueError(
345
+ f"Keyword search returned different columns: {df_keyword_select.columns} "
346
+ f"than expected: {df.columns}"
347
+ )
348
+ df = pd.concat([df, df_keyword_select], ignore_index=True)
349
+ # if chunk_id column exists remove duplicates based on chunk_id
350
+ if "chunk_id" in df.columns:
351
+ df = df.drop_duplicates(subset=["chunk_id"])
352
+
278
353
  # Check if we have a rerank_model configured in KB params
279
354
  df = self.add_relevance(df, query_text, relevance_threshold, disable_reranking)
280
355
 
@@ -736,8 +811,7 @@ class KnowledgeBaseTable:
736
811
  if model_id is None:
737
812
  # call litellm handler
738
813
  messages = list(df[TableField.CONTENT.value])
739
- embedding_params = copy.deepcopy(config.get("default_embedding_model", {}))
740
- embedding_params.update(self._kb.params["embedding_model"])
814
+ embedding_params = get_model_params(self._kb.params.get("embedding_model", {}), "default_embedding_model")
741
815
  results = self.call_litellm_embedding(self.session, embedding_params, messages)
742
816
  results = [[val] for val in results]
743
817
  return pd.DataFrame(results, columns=[TableField.EMBEDDINGS.value])
@@ -783,6 +857,9 @@ class KnowledgeBaseTable:
783
857
  def call_litellm_embedding(session, model_params, messages):
784
858
  args = copy.deepcopy(model_params)
785
859
 
860
+ if "model_name" not in args:
861
+ raise ValueError("'model_name' must be provided for embedding model")
862
+
786
863
  llm_model = args.pop("model_name")
787
864
  engine = args.pop("provider")
788
865
 
@@ -936,6 +1013,24 @@ class KnowledgeBaseController:
936
1013
  # fill variables
937
1014
  params = variables_controller.fill_parameters(params)
938
1015
 
1016
+ try:
1017
+ KnowledgeBaseInputParams.model_validate(params)
1018
+ except ValidationError as e:
1019
+ problems = []
1020
+ for error in e.errors():
1021
+ parameter = ".".join([str(i) for i in error["loc"]])
1022
+ param_type = error["type"]
1023
+ if param_type == "extra_forbidden":
1024
+ msg = f"Parameter '{parameter}' is not allowed"
1025
+ else:
1026
+ msg = f"Error in '{parameter}' (type: {param_type}): {error['msg']}. Input: {repr(error['input'])}"
1027
+ problems.append(msg)
1028
+
1029
+ msg = "\n".join(problems)
1030
+ if len(problems) > 1:
1031
+ msg = "\n" + msg
1032
+ raise ValueError(f"Problem with knowledge base parameters: {msg}")
1033
+
939
1034
  # Validate preprocessing config first if provided
940
1035
  if preprocessing_config is not None:
941
1036
  PreprocessingConfig(**preprocessing_config) # Validate before storing
@@ -961,24 +1056,6 @@ class KnowledgeBaseController:
961
1056
  return kb
962
1057
  raise EntityExistsError("Knowledge base already exists", name)
963
1058
 
964
- embedding_params = copy.deepcopy(config.get("default_embedding_model", {}))
965
-
966
- # Legacy
967
- # model_name = None
968
- # model_project = project
969
- # if embedding_model:
970
- # model_name = embedding_model.parts[-1]
971
- # if len(embedding_model.parts) > 1:
972
- # model_project = self.session.database_controller.get_project(embedding_model.parts[-2])
973
-
974
- # elif "embedding_model" in params:
975
- # if isinstance(params["embedding_model"], str):
976
- # # it is model name
977
- # model_name = params["embedding_model"]
978
- # else:
979
- # # it is params for model
980
- # embedding_params.update(params["embedding_model"])
981
-
982
1059
  embedding_params = get_model_params(params.get("embedding_model", {}), "default_embedding_model")
983
1060
 
984
1061
  # if model_name is None: # Legacy
@@ -1009,7 +1086,11 @@ class KnowledgeBaseController:
1009
1086
  if reranking_model_params:
1010
1087
  # Get reranking model from params.
1011
1088
  # This is called here to check validaity of the parameters.
1012
- get_reranking_model_from_params(reranking_model_params)
1089
+ try:
1090
+ reranker = get_reranking_model_from_params(reranking_model_params)
1091
+ reranker.get_scores("test", ["test"])
1092
+ except (ValueError, RuntimeError) as e:
1093
+ raise RuntimeError(f"Problem with reranker config: {e}")
1013
1094
 
1014
1095
  # search for the vector database table
1015
1096
  if storage is None:
@@ -1102,15 +1183,26 @@ class KnowledgeBaseController:
1102
1183
  except PredictorRecordNotFound:
1103
1184
  pass
1104
1185
 
1105
- if params.get("provider", None) not in ("openai", "azure_openai"):
1186
+ if "provider" not in params:
1187
+ raise ValueError("'provider' parameter is required for embedding model")
1188
+
1189
+ if params["provider"] not in ("openai", "azure_openai"):
1106
1190
  # try use litellm
1107
- KnowledgeBaseTable.call_litellm_embedding(self.session, params, ["test"])
1191
+ try:
1192
+ KnowledgeBaseTable.call_litellm_embedding(self.session, params, ["test"])
1193
+ except Exception as e:
1194
+ raise RuntimeError(f"Problem with embedding model config: {e}")
1108
1195
  return
1109
1196
 
1110
1197
  if "provider" in params:
1111
1198
  engine = params.pop("provider").lower()
1112
1199
 
1113
- api_key = get_api_key(engine, params, strict=False) or params.pop("api_key")
1200
+ api_key = get_api_key(engine, params, strict=False)
1201
+ if api_key is None:
1202
+ if "api_key" in params:
1203
+ params.pop("api_key")
1204
+ else:
1205
+ raise ValueError("'api_key' parameter is required for embedding model")
1114
1206
 
1115
1207
  if engine == "azure_openai":
1116
1208
  engine = "openai"
@@ -90,7 +90,7 @@ class EvaluateBase:
90
90
  df = response.data_frame
91
91
 
92
92
  if "content" not in df.columns:
93
- raise ValueError("`content` column isn't found in source data")
93
+ raise ValueError(f"`content` column isn't found in provided sql: {gen_params['from_sql']}")
94
94
 
95
95
  df.rename(columns={"content": "chunk_content"}, inplace=True)
96
96
  else:
@@ -186,7 +186,7 @@ class EvaluateBase:
186
186
  to_table = params["save_to"]
187
187
  if isinstance(to_table, str):
188
188
  to_table = Identifier(to_table)
189
- self.save_to_table(to_table, scores)
189
+ self.save_to_table(to_table, scores.copy())
190
190
 
191
191
  return scores
192
192
 
@@ -28,6 +28,13 @@ def _merge_configs(original_config: dict, override_config: dict) -> dict:
28
28
  return original_config
29
29
 
30
30
 
31
+ def _overwrite_configs(original_config: dict, override_config: dict) -> dict:
32
+ """Overwrite original config with override config."""
33
+ for key in list(override_config.keys()):
34
+ original_config[key] = override_config[key]
35
+ return original_config
36
+
37
+
31
38
  def create_data_dir(path: Path) -> None:
32
39
  """Create a directory and checks that it is writable.
33
40
 
@@ -196,6 +203,15 @@ class Config:
196
203
  "host": "0.0.0.0", # API server binds to all interfaces by default
197
204
  "port": "8000",
198
205
  },
206
+ "a2a": {
207
+ "host": api_host,
208
+ "port": 47338,
209
+ "mindsdb_host": "localhost",
210
+ "mindsdb_port": 47334,
211
+ "agent_name": "my_agent",
212
+ "project_name": "mindsdb",
213
+ "enabled": False,
214
+ },
199
215
  },
200
216
  "cache": {"type": "local"},
201
217
  "ml_task_queue": {"type": "local"},
@@ -209,15 +225,6 @@ class Config:
209
225
  "default_llm": {},
210
226
  "default_embedding_model": {},
211
227
  "default_reranking_model": {},
212
- "a2a": {
213
- "host": "localhost",
214
- "port": 47338,
215
- "mindsdb_host": "localhost",
216
- "mindsdb_port": 47334,
217
- "agent_name": "my_agent",
218
- "project_name": "mindsdb",
219
- "enabled": False,
220
- },
221
228
  "data_catalog": {
222
229
  "enabled": False,
223
230
  },
@@ -243,12 +250,11 @@ class Config:
243
250
  """Collect config values from env vars to self._env_config"""
244
251
  self._env_config = {
245
252
  "logging": {"handlers": {"console": {}, "file": {}}},
246
- "api": {"http": {"server": {}}},
253
+ "api": {"http": {"server": {}}, "a2a": {}},
247
254
  "auth": {},
248
255
  "paths": {},
249
256
  "permanent_storage": {},
250
257
  "ml_task_queue": {},
251
- "a2a": {},
252
258
  }
253
259
 
254
260
  # region storage root path
@@ -390,7 +396,7 @@ class Config:
390
396
  )
391
397
 
392
398
  if a2a_config:
393
- self._env_config["a2a"] = a2a_config
399
+ self._env_config["api"]["a2a"] = a2a_config
394
400
  # endregion
395
401
 
396
402
  def fetch_auto_config(self) -> bool:
@@ -457,47 +463,36 @@ class Config:
457
463
  _merge_configs(new_config, self._env_config)
458
464
 
459
465
  # Apply command-line arguments for A2A
460
- cmd_args_config = {}
466
+ a2a_config = {}
461
467
 
462
468
  # Check for A2A command-line arguments
463
469
  if hasattr(self.cmd_args, "a2a_host") and self.cmd_args.a2a_host is not None:
464
- if "a2a" not in cmd_args_config:
465
- cmd_args_config["a2a"] = {}
466
- cmd_args_config["a2a"]["host"] = self.cmd_args.a2a_host
470
+ a2a_config["host"] = self.cmd_args.a2a_host
467
471
 
468
472
  if hasattr(self.cmd_args, "a2a_port") and self.cmd_args.a2a_port is not None:
469
- if "a2a" not in cmd_args_config:
470
- cmd_args_config["a2a"] = {}
471
- cmd_args_config["a2a"]["port"] = self.cmd_args.a2a_port
473
+ a2a_config["port"] = self.cmd_args.a2a_port
472
474
 
473
475
  if hasattr(self.cmd_args, "mindsdb_host") and self.cmd_args.mindsdb_host is not None:
474
- if "a2a" not in cmd_args_config:
475
- cmd_args_config["a2a"] = {}
476
- cmd_args_config["a2a"]["mindsdb_host"] = self.cmd_args.mindsdb_host
476
+ a2a_config["mindsdb_host"] = self.cmd_args.mindsdb_host
477
477
 
478
478
  if hasattr(self.cmd_args, "mindsdb_port") and self.cmd_args.mindsdb_port is not None:
479
- if "a2a" not in cmd_args_config:
480
- cmd_args_config["a2a"] = {}
481
- cmd_args_config["a2a"]["mindsdb_port"] = self.cmd_args.mindsdb_port
479
+ a2a_config["mindsdb_port"] = self.cmd_args.mindsdb_port
482
480
 
483
481
  if hasattr(self.cmd_args, "agent_name") and self.cmd_args.agent_name is not None:
484
- if "a2a" not in cmd_args_config:
485
- cmd_args_config["a2a"] = {}
486
- cmd_args_config["a2a"]["agent_name"] = self.cmd_args.agent_name
482
+ a2a_config["agent_name"] = self.cmd_args.agent_name
487
483
 
488
484
  if hasattr(self.cmd_args, "project_name") and self.cmd_args.project_name is not None:
489
- if "a2a" not in cmd_args_config:
490
- cmd_args_config["a2a"] = {}
491
- cmd_args_config["a2a"]["project_name"] = self.cmd_args.project_name
485
+ a2a_config["project_name"] = self.cmd_args.project_name
492
486
 
493
487
  # Merge command-line args config with highest priority
494
- if cmd_args_config:
495
- _merge_configs(new_config, cmd_args_config)
488
+ if a2a_config:
489
+ _merge_configs(new_config, {"api": {"a2a": a2a_config}})
496
490
 
497
491
  # Ensure A2A port is never 0, which would prevent the A2A API from starting
498
- if "a2a" in new_config and isinstance(new_config["a2a"], dict):
499
- if "port" in new_config["a2a"] and (new_config["a2a"]["port"] == 0 or new_config["a2a"]["port"] is None):
500
- new_config["a2a"]["port"] = 47338 # Use the default port value
492
+ a2a_config = new_config["api"].get("a2a")
493
+ if a2a_config is not None and isinstance(a2a_config, dict):
494
+ if "port" in a2a_config and (a2a_config["port"] == 0 or a2a_config["port"] is None):
495
+ a2a_config["port"] = 47338 # Use the default port value
501
496
 
502
497
  # region create dirs
503
498
  for key, value in new_config["paths"].items():
@@ -522,11 +517,23 @@ class Config:
522
517
  self.ensure_auto_config_is_relevant()
523
518
  return self._config
524
519
 
525
- def update(self, data: dict) -> None:
526
- """Update calues in `auto` config"""
520
+ def update(self, data: dict, overwrite: bool = False) -> None:
521
+ """
522
+ Update values in `auto` config.
523
+ Args:
524
+ data (dict): data to update in `auto` config.
525
+ overwrite (bool): if True, overwrite existing keys, otherwise merge them.
526
+ - False (default): Merge recursively. Existing nested dictionaries are preserved
527
+ and only the specified keys in `data` are updated.
528
+ - True: Overwrite completely. Existing keys are replaced entirely with values
529
+ from `data`, discarding any nested structure not present in `data`.
530
+ """
527
531
  self.ensure_auto_config_is_relevant()
528
532
 
529
- _merge_configs(self._auto_config, data)
533
+ if overwrite:
534
+ _overwrite_configs(self._auto_config, data)
535
+ else:
536
+ _merge_configs(self._auto_config, data)
530
537
 
531
538
  self.auto_config_path.write_text(json.dumps(self._auto_config, indent=4))
532
539
 
@@ -40,6 +40,7 @@ def format_db_error_message(
40
40
  db_type: str | None = None,
41
41
  db_error_msg: str | None = None,
42
42
  failed_query: str | None = None,
43
+ is_external: bool = True,
43
44
  ) -> str:
44
45
  """Format the error message for the database query.
45
46
 
@@ -48,11 +49,21 @@ def format_db_error_message(
48
49
  db_type (str | None): The type of the database.
49
50
  db_error_msg (str | None): The error message.
50
51
  failed_query (str | None): The failed query.
52
+ is_external (bool): True if error appeared in external database, False if in internal duckdb
51
53
 
52
54
  Returns:
53
55
  str: The formatted error message.
54
56
  """
55
57
  error_message = "Failed to execute external database query during query processing."
58
+ if is_external:
59
+ error_message = (
60
+ "An error occurred while executing a derived query on the external "
61
+ "database during processing of your original SQL query."
62
+ )
63
+ else:
64
+ error_message = (
65
+ "An error occurred while processing an internally generated query derived from your original SQL statement."
66
+ )
56
67
  if db_name is not None or db_type is not None:
57
68
  error_message += "\n\nDatabase Details:"
58
69
  if db_name is not None: