MindsDB 25.7.1.0__py3-none-any.whl → 25.7.3.0__py3-none-any.whl

This diff compares the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of MindsDB has been flagged as potentially problematic.

Files changed (38)
  1. mindsdb/__about__.py +1 -1
  2. mindsdb/__main__.py +54 -95
  3. mindsdb/api/a2a/agent.py +30 -206
  4. mindsdb/api/a2a/common/server/server.py +26 -27
  5. mindsdb/api/a2a/task_manager.py +93 -227
  6. mindsdb/api/a2a/utils.py +21 -0
  7. mindsdb/api/executor/command_executor.py +7 -2
  8. mindsdb/api/executor/datahub/datanodes/integration_datanode.py +5 -1
  9. mindsdb/api/executor/utilities/sql.py +97 -21
  10. mindsdb/api/http/namespaces/agents.py +127 -202
  11. mindsdb/api/http/namespaces/config.py +12 -1
  12. mindsdb/integrations/handlers/litellm_handler/litellm_handler.py +11 -1
  13. mindsdb/integrations/handlers/llama_index_handler/requirements.txt +1 -1
  14. mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +94 -1
  15. mindsdb/integrations/handlers/s3_handler/s3_handler.py +72 -70
  16. mindsdb/integrations/handlers/salesforce_handler/salesforce_handler.py +4 -3
  17. mindsdb/integrations/handlers/salesforce_handler/salesforce_tables.py +12 -3
  18. mindsdb/integrations/handlers/slack_handler/slack_tables.py +141 -161
  19. mindsdb/integrations/handlers/youtube_handler/youtube_tables.py +183 -55
  20. mindsdb/integrations/libs/keyword_search_base.py +41 -0
  21. mindsdb/integrations/libs/vectordatabase_handler.py +35 -14
  22. mindsdb/integrations/utilities/sql_utils.py +11 -0
  23. mindsdb/interfaces/agents/agents_controller.py +2 -2
  24. mindsdb/interfaces/data_catalog/data_catalog_loader.py +18 -4
  25. mindsdb/interfaces/database/projects.py +1 -3
  26. mindsdb/interfaces/functions/controller.py +54 -64
  27. mindsdb/interfaces/functions/to_markdown.py +47 -14
  28. mindsdb/interfaces/knowledge_base/controller.py +134 -35
  29. mindsdb/interfaces/knowledge_base/evaluate.py +53 -10
  30. mindsdb/interfaces/knowledge_base/llm_client.py +3 -3
  31. mindsdb/interfaces/knowledge_base/preprocessing/document_preprocessor.py +21 -13
  32. mindsdb/utilities/config.py +46 -39
  33. mindsdb/utilities/exception.py +11 -0
  34. {mindsdb-25.7.1.0.dist-info → mindsdb-25.7.3.0.dist-info}/METADATA +236 -236
  35. {mindsdb-25.7.1.0.dist-info → mindsdb-25.7.3.0.dist-info}/RECORD +38 -36
  36. {mindsdb-25.7.1.0.dist-info → mindsdb-25.7.3.0.dist-info}/WHEEL +0 -0
  37. {mindsdb-25.7.1.0.dist-info → mindsdb-25.7.3.0.dist-info}/licenses/LICENSE +0 -0
  38. {mindsdb-25.7.1.0.dist-info → mindsdb-25.7.3.0.dist-info}/top_level.txt +0 -0
mindsdb/interfaces/functions/controller.py

@@ -7,15 +7,15 @@ from mindsdb.utilities.config import config
 
 
 def python_to_duckdb_type(py_type):
-    if py_type == 'int':
+    if py_type == "int":
         return BIGINT
-    elif py_type == 'float':
+    elif py_type == "float":
         return DOUBLE
-    elif py_type == 'str':
+    elif py_type == "str":
         return VARCHAR
-    elif py_type == 'bool':
+    elif py_type == "bool":
         return BOOLEAN
-    elif py_type == 'bytes':
+    elif py_type == "bytes":
         return BLOB
     else:
         # Unknown
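
As a quick sanity check of the mapping (an illustrative sketch, not part of the diff; it assumes the BIGINT/VARCHAR constants are the usual ones from duckdb.typing, which matches how DuckDBFunctions uses them below):

    from duckdb.typing import BIGINT, VARCHAR
    from mindsdb.interfaces.functions.controller import python_to_duckdb_type

    assert python_to_duckdb_type("int") == BIGINT
    assert python_to_duckdb_type("str") == VARCHAR
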
@@ -53,8 +53,8 @@ class BYOMFunctionsController:
             # first run
             self.byom_engines = []
             for name, info in self.session.integration_controller.get_all().items():
-                if info['type'] == 'ml' and info['engine'] == 'byom':
-                    if info['connection_data'].get('mode') == 'custom_function':
+                if info["type"] == "ml" and info["engine"] == "byom":
+                    if info["connection_data"].get("mode") == "custom_function":
                         self.byom_engines.append(name)
         return self.byom_engines
 
@@ -63,7 +63,7 @@ class BYOMFunctionsController:
         ml_handler = self.session.integration_controller.get_ml_handler(engine)
 
         storage = HandlerStorage(ml_handler.integration_id)
-        methods = storage.json_get('methods')
+        methods = storage.json_get("methods")
         self.byom_methods[engine] = methods
         self.byom_handlers[engine] = ml_handler
 
@@ -81,7 +81,7 @@ class BYOMFunctionsController:
             # do nothing
             return
 
-        new_name = f'{node.namespace}_{fnc_name}'
+        new_name = f"{node.namespace}_{fnc_name}"
         node.op = new_name
 
         if new_name in self.callbacks:
@@ -91,16 +91,13 @@ class BYOMFunctionsController:
         def callback(*args):
             return self.method_call(engine, fnc_name, args)
 
-        input_types = [
-            param['type']
-            for param in methods[fnc_name]['input_params']
-        ]
+        input_types = [param["type"] for param in methods[fnc_name]["input_params"]]
 
         meta = {
-            'name': new_name,
-            'callback': callback,
-            'input_types': input_types,
-            'output_type': methods[fnc_name]['output_type']
+            "name": new_name,
+            "callback": callback,
+            "input_types": input_types,
+            "output_type": methods[fnc_name]["output_type"],
         }
 
         self.callbacks[new_name] = meta
@@ -114,7 +111,6 @@ class BYOMFunctionsController:
 
 
 class FunctionController(BYOMFunctionsController):
-
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
@@ -124,10 +120,10 @@ class FunctionController(BYOMFunctionsController):
             return meta
 
         # builtin functions
-        if node.op.lower() == 'llm':
+        if node.op.lower() == "llm":
             return self.llm_call_function(node)
 
-        elif node.op.lower() == 'to_markdown':
+        elif node.op.lower() == "to_markdown":
             return self.to_markdown_call_function(node)
 
     def llm_call_function(self, node):
@@ -141,70 +137,74 @@ class FunctionController(BYOMFunctionsController):
         try:
            from langchain_core.messages import HumanMessage
            from mindsdb.interfaces.agents.langchain_agent import create_chat_model
+
            llm = create_chat_model(chat_model_params)
         except Exception as e:
-            raise RuntimeError(f'Unable to use LLM function, check ENV variables: {e}')
+            raise RuntimeError(f"Unable to use LLM function, check ENV variables: {e}")
 
         def callback(question):
             resp = llm([HumanMessage(question)])
             return resp.content
 
-        meta = {
-            'name': name,
-            'callback': callback,
-            'input_types': ['str'],
-            'output_type': 'str'
-        }
+        meta = {"name": name, "callback": callback, "input_types": ["str"], "output_type": "str"}
         self.callbacks[name] = meta
         return meta
 
     def to_markdown_call_function(self, node):
         # load on-demand because lib is heavy
         from mindsdb.interfaces.functions.to_markdown import ToMarkdown
+
         name = node.op.lower()
 
         if name in self.callbacks:
             return self.callbacks[name]
 
-        def callback(file_path_or_url):
-            chat_model_params = self._parse_chat_model_params('TO_MARKDOWN_FUNCTION_')
-
+        def prepare_chat_model_params(chat_model_params: dict) -> dict:
+            """
+            Parepares the chat model parameters for the ToMarkdown function.
+            """
             params_copy = copy.deepcopy(chat_model_params)
-            params_copy['model'] = params_copy.pop('model_name')
-            params_copy.pop('api_keys')
-            params_copy.pop('provider')
+            params_copy["model"] = params_copy.pop("model_name")
+
+            # Set the base_url for the Google provider.
+            if params_copy["provider"] == "google" and "base_url" not in params_copy:
+                params_copy["base_url"] = "https://generativelanguage.googleapis.com/v1beta/"
+
+            params_copy.pop("api_keys")
+            params_copy.pop("provider")
+
+            return params_copy
+
+        def callback(file_path_or_url):
+            chat_model_params = self._parse_chat_model_params("TO_MARKDOWN_FUNCTION_")
+            chat_model_params = prepare_chat_model_params(chat_model_params)
 
             to_markdown = ToMarkdown()
-            return to_markdown.call(file_path_or_url, **params_copy)
+            return to_markdown.call(file_path_or_url, **chat_model_params)
 
-        meta = {
-            'name': name,
-            'callback': callback,
-            'input_types': ['str'],
-            'output_type': 'str'
-        }
+        meta = {"name": name, "callback": callback, "input_types": ["str"], "output_type": "str"}
         self.callbacks[name] = meta
         return meta
 
-    def _parse_chat_model_params(self, param_prefix: str = 'LLM_FUNCTION_'):
+    def _parse_chat_model_params(self, param_prefix: str = "LLM_FUNCTION_"):
         """
         Parses the environment variables for chat model parameters.
         """
         chat_model_params = config.get("default_llm") or {}
         for k, v in os.environ.items():
             if k.startswith(param_prefix):
-                param_name = k[len(param_prefix):]
-                if param_name == 'MODEL':
-                    chat_model_params['model_name'] = v
+                param_name = k[len(param_prefix) :]
+                if param_name == "MODEL":
+                    chat_model_params["model_name"] = v
                 else:
                     chat_model_params[param_name.lower()] = v
 
-        if 'provider' not in chat_model_params:
-            chat_model_params['provider'] = 'openai'
+        if "provider" not in chat_model_params:
+            chat_model_params["provider"] = "openai"
 
-        if 'api_key' in chat_model_params:
+        if "api_key" in chat_model_params:
             # move to api_keys dict
-            chat_model_params["api_keys"] = {chat_model_params['provider']: chat_model_params['api_key']}
+            chat_model_params["api_keys"] = {chat_model_params["provider"]: chat_model_params["api_key"]}
 
         return chat_model_params
 
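
For context on _parse_chat_model_params: it folds LLM_FUNCTION_* environment variables into the default_llm config. A standalone sketch of the same parsing behavior (env values are invented; the real method also reads MindsDB's config object):

    import os

    os.environ["LLM_FUNCTION_MODEL"] = "gpt-4o"        # illustrative
    os.environ["LLM_FUNCTION_API_KEY"] = "sk-example"  # illustrative

    params, prefix = {}, "LLM_FUNCTION_"
    for k, v in os.environ.items():
        if k.startswith(prefix):
            name = k[len(prefix):]
            params["model_name" if name == "MODEL" else name.lower()] = v
    params.setdefault("provider", "openai")
    if "api_key" in params:
        # mirrored from the method above: the key is copied into api_keys
        params["api_keys"] = {params["provider"]: params["api_key"]}

    print(params)  # key order may vary:
    # {'model_name': 'gpt-4o', 'api_key': 'sk-example', 'provider': 'openai',
    #  'api_keys': {'openai': 'sk-example'}}
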
@@ -215,33 +215,23 @@ class DuckDBFunctions:
         self.functions = {}
 
     def check_function(self, node):
-
         meta = self.controller.check_function(node)
         if meta is None:
             return
 
-        name = meta['name']
+        name = meta["name"]
 
         if name in self.functions:
             return
 
-        input_types = [
-            python_to_duckdb_type(param)
-            for param in meta['input_types']
-        ]
+        input_types = [python_to_duckdb_type(param) for param in meta["input_types"]]
 
         self.functions[name] = {
-            'callback': function_maker(len(input_types), meta['callback']),
-            'input': input_types,
-            'output': python_to_duckdb_type(meta['output_type'])
+            "callback": function_maker(len(input_types), meta["callback"]),
+            "input": input_types,
+            "output": python_to_duckdb_type(meta["output_type"]),
         }
 
     def register(self, connection):
         for name, info in self.functions.items():
-            connection.create_function(
-                name,
-                info['callback'],
-                info['input'],
-                info['output'],
-                null_handling="special"
-            )
+            connection.create_function(name, info["callback"], info["input"], info["output"], null_handling="special")
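
The register step uses DuckDB's Python UDF API. A minimal standalone sketch of the same call pattern (the shout function and its name are invented for illustration):

    import duckdb
    from duckdb.typing import VARCHAR

    def shout(text):
        # stand-in for a MindsDB callback; with null_handling="special",
        # the function must handle NULL (None) inputs itself
        return None if text is None else text.upper()

    con = duckdb.connect()
    con.create_function("shout", shout, [VARCHAR], VARCHAR, null_handling="special")
    print(con.sql("SELECT shout('hello')").fetchall())  # [('HELLO',)]
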
mindsdb/interfaces/functions/to_markdown.py

@@ -2,6 +2,7 @@ from io import BytesIO
 import os
 from typing import Union
 from urllib.parse import urlparse
+import xml.etree.ElementTree as ET
 
 from aipdf import ocr
 import mimetypes
@@ -12,6 +13,7 @@ class ToMarkdown:
     """
     Extracts the content of documents of various formats in markdown format.
     """
+
     def __init__(self):
         """
         Initializes the ToMarkdown class.
@@ -24,24 +26,28 @@ class ToMarkdown:
         file_extension = self._get_file_extension(file_path_or_url)
         file_content = self._get_file_content(file_path_or_url)
 
-        if file_extension == '.pdf':
+        if file_extension == ".pdf":
             return self._pdf_to_markdown(file_content, **kwargs)
+
+        elif file_extension in (".xml", ".nessus"):
+            return self._xml_to_markdown(file_content, **kwargs)
+
         else:
             raise ValueError(f"Unsupported file type: {file_extension}.")
 
-    def _get_file_content(self, file_path_or_url: str) -> str:
+    def _get_file_content(self, file_path_or_url: str) -> BytesIO:
         """
         Retrieves the content of a file.
         """
         parsed_url = urlparse(file_path_or_url)
-        if parsed_url.scheme in ('http', 'https'):
+        if parsed_url.scheme in ("http", "https"):
             response = requests.get(file_path_or_url)
             if response.status_code == 200:
-                return response
+                return BytesIO(response.content)
             else:
-                raise RuntimeError(f'Unable to retrieve file from URL: {file_path_or_url}')
+                raise RuntimeError(f"Unable to retrieve file from URL: {file_path_or_url}")
         else:
-            with open(file_path_or_url, 'rb') as file:
+            with open(file_path_or_url, "rb") as file:
                 return BytesIO(file.read())
 
     def _get_file_extension(self, file_path_or_url: str) -> str:
@@ -49,13 +55,13 @@ class ToMarkdown:
         Retrieves the file extension from a file path or URL.
         """
         parsed_url = urlparse(file_path_or_url)
-        if parsed_url.scheme in ('http', 'https'):
+        if parsed_url.scheme in ("http", "https"):
             try:
                 # Make a HEAD request to get headers without downloading the file.
                 response = requests.head(file_path_or_url, allow_redirects=True)
-                content_type = response.headers.get('Content-Type', '')
+                content_type = response.headers.get("Content-Type", "")
                 if content_type:
-                    ext = mimetypes.guess_extension(content_type.split(';')[0].strip())
+                    ext = mimetypes.guess_extension(content_type.split(";")[0].strip())
                     if ext:
                         return ext
 
@@ -64,16 +70,43 @@ class ToMarkdown:
                 if ext:
                     return ext
             except requests.RequestException:
-                raise RuntimeError(f'Unable to retrieve file extension from URL: {file_path_or_url}')
+                raise RuntimeError(f"Unable to retrieve file extension from URL: {file_path_or_url}")
         else:
             return os.path.splitext(file_path_or_url)[1]
 
-    def _pdf_to_markdown(self, file_content: Union[requests.Response, bytes], **kwargs) -> str:
+    def _pdf_to_markdown(self, file_content: Union[requests.Response, BytesIO], **kwargs) -> str:
         """
         Converts a PDF file to markdown.
         """
-        if isinstance(file_content, requests.Response):
-            file_content = BytesIO(file_content.content)
-
         markdown_pages = ocr(file_content, **kwargs)
         return "\n\n---\n\n".join(markdown_pages)
+
+    def _xml_to_markdown(self, file_content: Union[requests.Response, BytesIO], **kwargs) -> str:
+        """
+        Converts an XML (or Nessus) file to markdown.
+        """
+
+        def parse_element(element: ET.Element, depth: int = 0) -> str:
+            """
+            Recursively parses an XML element and converts it to markdown.
+            """
+            markdown = []
+            heading = "#" * (depth + 1)
+
+            markdown.append(f"{heading} {element.tag}")
+
+            for key, val in element.attrib.items():
+                markdown.append(f"- **{key}**: {val}")
+
+            text = (element.text or "").strip()
+            if text:
+                markdown.append(f"\n{text}\n")
+
+            for child in element:
+                markdown.append(parse_element(child, depth + 1))
+
+            return "\n".join(markdown)
+
+        root = ET.fromstring(file_content.read().decode("utf-8"))
+        markdown_content = parse_element(root)
+        return markdown_content
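
The recursion maps element depth to heading level, attributes to bullet items, and element text to body paragraphs. A quick sketch of the output shape on a made-up document (calling the private method directly, purely for illustration):

    from io import BytesIO
    from mindsdb.interfaces.functions.to_markdown import ToMarkdown

    xml = b'<report version="1.0"><host name="db1">Open port 5432</host></report>'
    print(ToMarkdown()._xml_to_markdown(BytesIO(xml)))
    # # report
    # - **version**: 1.0
    # ## host
    # - **name**: db1
    #
    # Open port 5432
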
mindsdb/interfaces/knowledge_base/controller.py

@@ -1,17 +1,19 @@
 import os
 import copy
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Any, Text
 import json
 import decimal
 
 import pandas as pd
 import numpy as np
+from pydantic import BaseModel, ValidationError
 from sqlalchemy.orm.attributes import flag_modified
 
 from mindsdb_sql_parser.ast import BinaryOperation, Constant, Identifier, Select, Update, Delete, Star
 from mindsdb_sql_parser.ast.mindsdb import CreatePredictor
 from mindsdb_sql_parser import parse_sql
 
+from mindsdb.integrations.libs.keyword_search_base import KeywordSearchBase
 from mindsdb.integrations.utilities.query_traversal import query_traversal
 
 import mindsdb.interfaces.storage.db as db
@@ -37,7 +39,7 @@ from mindsdb.interfaces.knowledge_base.evaluate import EvaluateBase
 from mindsdb.interfaces.knowledge_base.executor import KnowledgeBaseQueryExecutor
 from mindsdb.interfaces.model.functions import PredictorRecordNotFound
 from mindsdb.utilities.exception import EntityExistsError, EntityNotExistsError
-from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator
+from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator, KeywordSearchArgs
 from mindsdb.utilities.config import config
 from mindsdb.utilities.context import context as ctx
 
@@ -49,6 +51,20 @@ from mindsdb.integrations.utilities.rag.rerankers.base_reranker import BaseLLMReranker
 logger = log.getLogger(__name__)
 
 
+class KnowledgeBaseInputParams(BaseModel):
+    metadata_columns: List[str] | None = None
+    content_columns: List[str] | None = None
+    id_column: str | None = None
+    kb_no_upsert: bool = False
+    embedding_model: Dict[Text, Any] | None = None
+    is_sparse: bool = False
+    vector_size: int | None = None
+    reranking_model: Dict[Text, Any] | None = None
+
+    class Config:
+        extra = "forbid"
+
+
 def get_model_params(model_params: dict, default_config_key: str):
     """
     Get model parameters by combining default config with user provided parameters.
@@ -101,7 +117,10 @@ def get_reranking_model_from_params(reranking_model_params: dict):
 
     if "api_key" not in params_copy:
         params_copy["api_key"] = get_api_key(provider, params_copy, strict=False)
-    params_copy["model"] = params_copy.pop("model_name", None)
+
+    if "model_name" not in params_copy:
+        raise ValueError("'model_name' must be provided for reranking model")
+    params_copy["model"] = params_copy.pop("model_name")
 
     return BaseLLMReranker(**params_copy)
 
@@ -179,17 +198,20 @@ class KnowledgeBaseTable:
         df = executor.run(query)
 
         if (
-            query.group_by is not None
-            or query.order_by is not None
-            or query.having is not None
-            or query.distinct is True
-            or len(query.targets) != 1
-            or not isinstance(query.targets[0], Star)
+            query_copy.group_by is not None
+            or query_copy.order_by is not None
+            or query_copy.having is not None
+            or query_copy.distinct is True
+            or len(query_copy.targets) != 1
+            or not isinstance(query_copy.targets[0], Star)
         ):
             query_copy.where = None
         if "metadata" in df.columns:
             df["metadata"] = df["metadata"].apply(to_json)
 
+        if query_copy.from_table is None:
+            query_copy.from_table = Identifier(parts=[self._kb.name])
+
         df = query_df(df, query_copy, session=self.session)
 
         return df
@@ -218,8 +240,12 @@ class KnowledgeBaseTable:
 
         # extract values from conditions and prepare for vectordb
         conditions = []
+        keyword_search_conditions = []
+        keyword_search_cols_and_values = []
         query_text = None
         relevance_threshold = None
+        reranking_enabled_flag = True
+        hybrid_search_enabled_flag = False
         query_conditions = db_handler.extract_conditions(query.where)
         if query_conditions is not None:
             for item in query_conditions:
@@ -235,9 +261,17 @@ class KnowledgeBaseTable:
                     logger.error(error_msg)
                     raise ValueError(error_msg)
                 elif item.column == "reranking":
+                    reranking_enabled_flag = item.value
+                    # cast to boolean
+                    if isinstance(reranking_enabled_flag, str):
+                        reranking_enabled_flag = reranking_enabled_flag.lower() not in ("false")
+                elif item.column == "hybrid_search":
+                    hybrid_search_enabled_flag = item.value
+                    # cast to boolean
+                    if isinstance(hybrid_search_enabled_flag, str):
+                        hybrid_search_enabled_flag = hybrid_search_enabled_flag.lower() not in ("false")
                     if item.value is False or (isinstance(item.value, str) and item.value.lower() == "false"):
                         disable_reranking = True
-
                 elif item.column == "relevance" and item.op.value != FilterOperator.GREATER_THAN_OR_EQUAL.value:
                     raise ValueError(
                         f"Invalid operator for relevance: {item.op.value}. Only GREATER_THAN_OR_EQUAL is allowed."
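
One reading note on the cast above: `value.lower() not in ("false")` is a substring test against the string "false", because ("false") is not a tuple (that would be ("false",)). It gives the intended answer for the exact values "true" and "false", which is presumably what reaches it here; the sketch below shows where the two forms diverge:

    value = "false"
    print(value.lower() not in ("false"))   # False (substring test)
    print(value.lower() not in ("false",))  # False (tuple membership)

    value = "als"  # contrived input, shown only to illustrate the difference
    print(value.lower() not in ("false"))   # False: "als" occurs inside "false"
    print(value.lower() not in ("false",))  # True: "als" != "false"
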
@@ -253,8 +287,16 @@ class KnowledgeBaseTable:
                             op=FilterOperator.EQUAL,
                         )
                     )
+                    keyword_search_cols_and_values.append((TableField.CONTENT.value, item.value))
                 else:
                     conditions.append(item)
+                    keyword_search_conditions.append(item)  # keyword search conditions do not use embeddings
+
+        if len(keyword_search_cols_and_values) > 1:
+            raise ValueError(
+                "Multiple content columns found in query conditions. "
+                "Only one content column is allowed for keyword search."
+            )
 
         logger.debug(f"Extracted query text: {query_text}")
 
@@ -272,9 +314,42 @@ class KnowledgeBaseTable:
         allowed_metadata_columns = self._get_allowed_metadata_columns()
         df = db_handler.dispatch_select(query, conditions, allowed_metadata_columns=allowed_metadata_columns)
         df = self.addapt_result_columns(df)
-
         logger.debug(f"Query returned {len(df)} rows")
         logger.debug(f"Columns in response: {df.columns.tolist()}")
+
+        if hybrid_search_enabled_flag and not isinstance(db_handler, KeywordSearchBase):
+            raise ValueError(f"Hybrid search is enabled but the db_handler {type(db_handler)} does not support it. ")
+        # check if db_handler inherits from KeywordSearchBase
+        if hybrid_search_enabled_flag and isinstance(db_handler, KeywordSearchBase):
+            # If query_text is present, use it for keyword search
+            logger.debug(f"Performing keyword search with query text: {query_text}")
+            keyword_search_args = KeywordSearchArgs(query=query_text, column=TableField.CONTENT.value)
+            keyword_query_obj = copy.deepcopy(query)
+
+            keyword_query_obj.targets = [
+                Identifier(TableField.ID.value),
+                Identifier(TableField.CONTENT.value),
+                Identifier(TableField.METADATA.value),
+            ]
+
+            df_keyword_select = db_handler.dispatch_select(
+                keyword_query_obj, keyword_search_conditions, keyword_search_args=keyword_search_args
+            )
+            df_keyword_select = self.addapt_result_columns(df_keyword_select)
+            logger.debug(f"Keyword search returned {len(df_keyword_select)} rows")
+            logger.debug(f"Columns in keyword search response: {df_keyword_select.columns.tolist()}")
+            # ensure df and df_keyword_select have exactly the same columns
+            if not df_keyword_select.empty:
+                if set(df.columns) != set(df_keyword_select.columns):
+                    raise ValueError(
+                        f"Keyword search returned different columns: {df_keyword_select.columns} "
+                        f"than expected: {df.columns}"
+                    )
+                df = pd.concat([df, df_keyword_select], ignore_index=True)
+                # if chunk_id column exists remove duplicates based on chunk_id
+                if "chunk_id" in df.columns:
+                    df = df.drop_duplicates(subset=["chunk_id"])
+
         # Check if we have a rerank_model configured in KB params
         df = self.add_relevance(df, query_text, relevance_threshold, disable_reranking)
 
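
So with hybrid_search = true in the WHERE clause, the handler runs a second keyword-only select and unions it with the semantic results, de-duplicating on chunk_id. The merge itself reduces to this pandas pattern (toy frames standing in for real KB rows):

    import pandas as pd

    df_semantic = pd.DataFrame({"chunk_id": [1, 2], "content": ["a", "b"]})
    df_keyword = pd.DataFrame({"chunk_id": [2, 3], "content": ["b", "c"]})

    df = pd.concat([df_semantic, df_keyword], ignore_index=True)
    df = df.drop_duplicates(subset=["chunk_id"])
    print(df["chunk_id"].tolist())  # [1, 2, 3]
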
@@ -736,8 +811,7 @@ class KnowledgeBaseTable:
         if model_id is None:
             # call litellm handler
             messages = list(df[TableField.CONTENT.value])
-            embedding_params = copy.deepcopy(config.get("default_embedding_model", {}))
-            embedding_params.update(self._kb.params["embedding_model"])
+            embedding_params = get_model_params(self._kb.params.get("embedding_model", {}), "default_embedding_model")
             results = self.call_litellm_embedding(self.session, embedding_params, messages)
             results = [[val] for val in results]
             return pd.DataFrame(results, columns=[TableField.EMBEDDINGS.value])
@@ -783,6 +857,9 @@ class KnowledgeBaseTable:
     def call_litellm_embedding(session, model_params, messages):
         args = copy.deepcopy(model_params)
 
+        if "model_name" not in args:
+            raise ValueError("'model_name' must be provided for embedding model")
+
         llm_model = args.pop("model_name")
         engine = args.pop("provider")
 
@@ -936,6 +1013,24 @@ class KnowledgeBaseController:
         # fill variables
         params = variables_controller.fill_parameters(params)
 
+        try:
+            KnowledgeBaseInputParams.model_validate(params)
+        except ValidationError as e:
+            problems = []
+            for error in e.errors():
+                parameter = ".".join([str(i) for i in error["loc"]])
+                param_type = error["type"]
+                if param_type == "extra_forbidden":
+                    msg = f"Parameter '{parameter}' is not allowed"
+                else:
+                    msg = f"Error in '{parameter}' (type: {param_type}): {error['msg']}. Input: {repr(error['input'])}"
+                problems.append(msg)
+
+            msg = "\n".join(problems)
+            if len(problems) > 1:
+                msg = "\n" + msg
+            raise ValueError(f"Problem with knowledge base parameters: {msg}")
+
         # Validate preprocessing config first if provided
         if preprocessing_config is not None:
             PreprocessingConfig(**preprocessing_config)  # Validate before storing
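
In effect, a typo in CREATE KNOWLEDGE BASE parameters now fails fast with a readable message instead of being silently ignored. A sketch of what the validation rejects (assuming pydantic v2, which the model_validate/errors() calls imply; the misspelled key is deliberate):

    from pydantic import ValidationError
    from mindsdb.interfaces.knowledge_base.controller import KnowledgeBaseInputParams

    try:
        KnowledgeBaseInputParams.model_validate({"id_column": "id", "contnet_columns": ["body"]})
    except ValidationError as e:
        print(e.errors()[0]["type"])  # 'extra_forbidden', surfaced to the user as
        # "Parameter 'contnet_columns' is not allowed"
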
@@ -961,24 +1056,6 @@ class KnowledgeBaseController:
                 return kb
             raise EntityExistsError("Knowledge base already exists", name)
 
-        embedding_params = copy.deepcopy(config.get("default_embedding_model", {}))
-
-        # Legacy
-        # model_name = None
-        # model_project = project
-        # if embedding_model:
-        #     model_name = embedding_model.parts[-1]
-        #     if len(embedding_model.parts) > 1:
-        #         model_project = self.session.database_controller.get_project(embedding_model.parts[-2])
-
-        # elif "embedding_model" in params:
-        #     if isinstance(params["embedding_model"], str):
-        #         # it is model name
-        #         model_name = params["embedding_model"]
-        #     else:
-        #         # it is params for model
-        #         embedding_params.update(params["embedding_model"])
-
         embedding_params = get_model_params(params.get("embedding_model", {}), "default_embedding_model")
 
         # if model_name is None:  # Legacy
@@ -1009,7 +1086,11 @@ class KnowledgeBaseController:
         if reranking_model_params:
             # Get reranking model from params.
             # This is called here to check validaity of the parameters.
-            get_reranking_model_from_params(reranking_model_params)
+            try:
+                reranker = get_reranking_model_from_params(reranking_model_params)
+                reranker.get_scores("test", ["test"])
+            except (ValueError, RuntimeError) as e:
+                raise RuntimeError(f"Problem with reranker config: {e}")
 
         # search for the vector database table
         if storage is None:
@@ -1102,15 +1183,33 @@ class KnowledgeBaseController:
         except PredictorRecordNotFound:
             pass
 
-        if params.get("provider", None) not in ("openai", "azure_openai"):
+        if "provider" not in params:
+            raise ValueError("'provider' parameter is required for embedding model")
+
+        # check available providers
+        avail_providers = ("openai", "azure_openai", "bedrock", "gemini", "google")
+        if params["provider"] not in avail_providers:
+            raise ValueError(
+                f"Wrong embedding provider: {params['provider']}. Available providers: {', '.join(avail_providers)}"
+            )
+
+        if params["provider"] not in ("openai", "azure_openai"):
             # try use litellm
-            KnowledgeBaseTable.call_litellm_embedding(self.session, params, ["test"])
+            try:
+                KnowledgeBaseTable.call_litellm_embedding(self.session, params, ["test"])
+            except Exception as e:
+                raise RuntimeError(f"Problem with embedding model config: {e}")
             return
 
         if "provider" in params:
             engine = params.pop("provider").lower()
 
-            api_key = get_api_key(engine, params, strict=False) or params.pop("api_key")
+            api_key = get_api_key(engine, params, strict=False)
+            if api_key is None:
+                if "api_key" in params:
+                    params.pop("api_key")
+                else:
+                    raise ValueError("'api_key' parameter is required for embedding model")
 
             if engine == "azure_openai":
                 engine = "openai"
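
Taken together with the model_name and api_key checks above, an embedding_model params dict now has to look roughly like this to pass validation (values are placeholders):

    embedding_model = {
        "provider": "openai",  # one of: openai, azure_openai, bedrock, gemini, google
        "model_name": "text-embedding-3-small",
        "api_key": "sk-example",  # or resolvable via get_api_key / environment
    }
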