MindsDB 25.7.1.0__py3-none-any.whl → 25.7.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of MindsDB was flagged as possibly problematic by the registry scanner.
- mindsdb/__about__.py +1 -1
- mindsdb/__main__.py +54 -95
- mindsdb/api/a2a/agent.py +30 -206
- mindsdb/api/a2a/common/server/server.py +26 -27
- mindsdb/api/a2a/task_manager.py +93 -227
- mindsdb/api/a2a/utils.py +21 -0
- mindsdb/api/executor/command_executor.py +7 -2
- mindsdb/api/executor/datahub/datanodes/integration_datanode.py +5 -1
- mindsdb/api/executor/utilities/sql.py +97 -21
- mindsdb/api/http/namespaces/agents.py +127 -202
- mindsdb/api/http/namespaces/config.py +12 -1
- mindsdb/integrations/handlers/litellm_handler/litellm_handler.py +11 -1
- mindsdb/integrations/handlers/llama_index_handler/requirements.txt +1 -1
- mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +94 -1
- mindsdb/integrations/handlers/s3_handler/s3_handler.py +72 -70
- mindsdb/integrations/handlers/salesforce_handler/salesforce_handler.py +4 -3
- mindsdb/integrations/handlers/salesforce_handler/salesforce_tables.py +12 -3
- mindsdb/integrations/handlers/slack_handler/slack_tables.py +141 -161
- mindsdb/integrations/handlers/youtube_handler/youtube_tables.py +183 -55
- mindsdb/integrations/libs/keyword_search_base.py +41 -0
- mindsdb/integrations/libs/vectordatabase_handler.py +35 -14
- mindsdb/integrations/utilities/sql_utils.py +11 -0
- mindsdb/interfaces/agents/agents_controller.py +2 -2
- mindsdb/interfaces/data_catalog/data_catalog_loader.py +18 -4
- mindsdb/interfaces/database/projects.py +1 -3
- mindsdb/interfaces/functions/controller.py +54 -64
- mindsdb/interfaces/functions/to_markdown.py +47 -14
- mindsdb/interfaces/knowledge_base/controller.py +134 -35
- mindsdb/interfaces/knowledge_base/evaluate.py +53 -10
- mindsdb/interfaces/knowledge_base/llm_client.py +3 -3
- mindsdb/interfaces/knowledge_base/preprocessing/document_preprocessor.py +21 -13
- mindsdb/utilities/config.py +46 -39
- mindsdb/utilities/exception.py +11 -0
- {mindsdb-25.7.1.0.dist-info → mindsdb-25.7.3.0.dist-info}/METADATA +236 -236
- {mindsdb-25.7.1.0.dist-info → mindsdb-25.7.3.0.dist-info}/RECORD +38 -36
- {mindsdb-25.7.1.0.dist-info → mindsdb-25.7.3.0.dist-info}/WHEEL +0 -0
- {mindsdb-25.7.1.0.dist-info → mindsdb-25.7.3.0.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.7.1.0.dist-info → mindsdb-25.7.3.0.dist-info}/top_level.txt +0 -0
mindsdb/interfaces/functions/controller.py

@@ -7,15 +7,15 @@ from mindsdb.utilities.config import config


 def python_to_duckdb_type(py_type):
-    if py_type == 'int':
+    if py_type == "int":
         return BIGINT
-    elif py_type == 'float':
+    elif py_type == "float":
         return DOUBLE
-    elif py_type == 'str':
+    elif py_type == "str":
         return VARCHAR
-    elif py_type == 'bool':
+    elif py_type == "bool":
         return BOOLEAN
-    elif py_type == 'bytes':
+    elif py_type == "bytes":
         return BLOB
     else:
         # Unknown
@@ -53,8 +53,8 @@ class BYOMFunctionsController:
         # first run
         self.byom_engines = []
         for name, info in self.session.integration_controller.get_all().items():
-            if info['type'] == 'ml' and info['engine'] == 'byom':
-                if info['connection_data'].get('mode') == 'custom_function':
+            if info["type"] == "ml" and info["engine"] == "byom":
+                if info["connection_data"].get("mode") == "custom_function":
                     self.byom_engines.append(name)
         return self.byom_engines

@@ -63,7 +63,7 @@ class BYOMFunctionsController:
             ml_handler = self.session.integration_controller.get_ml_handler(engine)

             storage = HandlerStorage(ml_handler.integration_id)
-            methods = storage.json_get('methods')
+            methods = storage.json_get("methods")
             self.byom_methods[engine] = methods
             self.byom_handlers[engine] = ml_handler

@@ -81,7 +81,7 @@ class BYOMFunctionsController:
             # do nothing
             return

-        new_name = f'{node.namespace}_{fnc_name}'
+        new_name = f"{node.namespace}_{fnc_name}"
        node.op = new_name

         if new_name in self.callbacks:
@@ -91,16 +91,13 @@ class BYOMFunctionsController:
         def callback(*args):
             return self.method_call(engine, fnc_name, args)

-        input_types = [
-            param['type']
-            for param in methods[fnc_name]['input_params']
-        ]
+        input_types = [param["type"] for param in methods[fnc_name]["input_params"]]

         meta = {
-            'name': new_name,
-            'callback': callback,
-            'input_types': input_types,
-            'output_type': methods[fnc_name]['output_type'],
+            "name": new_name,
+            "callback": callback,
+            "input_types": input_types,
+            "output_type": methods[fnc_name]["output_type"],
         }

         self.callbacks[new_name] = meta
@@ -114,7 +111,6 @@ class BYOMFunctionsController:


 class FunctionController(BYOMFunctionsController):
-
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

@@ -124,10 +120,10 @@ class FunctionController(BYOMFunctionsController):
             return meta

         # builtin functions
-        if node.op.lower() == 'llm':
+        if node.op.lower() == "llm":
             return self.llm_call_function(node)

-        elif node.op.lower() == 'to_markdown':
+        elif node.op.lower() == "to_markdown":
             return self.to_markdown_call_function(node)

     def llm_call_function(self, node):
@@ -141,70 +137,74 @@ class FunctionController(BYOMFunctionsController):
         try:
             from langchain_core.messages import HumanMessage
             from mindsdb.interfaces.agents.langchain_agent import create_chat_model
+
             llm = create_chat_model(chat_model_params)
         except Exception as e:
-            raise RuntimeError(f'Unable to use LLM function, check ENV variables: {e}')
+            raise RuntimeError(f"Unable to use LLM function, check ENV variables: {e}")

         def callback(question):
             resp = llm([HumanMessage(question)])
             return resp.content

-        meta = {
-            'name': name,
-            'callback': callback,
-            'input_types': ['str'],
-            'output_type': 'str'
-        }
+        meta = {"name": name, "callback": callback, "input_types": ["str"], "output_type": "str"}
         self.callbacks[name] = meta
         return meta

     def to_markdown_call_function(self, node):
         # load on-demand because lib is heavy
         from mindsdb.interfaces.functions.to_markdown import ToMarkdown
+
         name = node.op.lower()

         if name in self.callbacks:
             return self.callbacks[name]

-        def callback(file_path_or_url):
-            chat_model_params = self._parse_chat_model_params('TO_MARKDOWN_FUNCTION_')
-
+        def prepare_chat_model_params(chat_model_params: dict) -> dict:
+            """
+            Parepares the chat model parameters for the ToMarkdown function.
+            """
             params_copy = copy.deepcopy(chat_model_params)
-            params_copy['model'] = params_copy.pop('model_name')
-            params_copy.pop('api_keys')
-            params_copy.pop('provider')
+            params_copy["model"] = params_copy.pop("model_name")
+
+            # Set the base_url for the Google provider.
+            if params_copy["provider"] == "google" and "base_url" not in params_copy:
+                params_copy["base_url"] = "https://generativelanguage.googleapis.com/v1beta/"
+
+            params_copy.pop("api_keys")
+            params_copy.pop("provider")
+
+            return params_copy
+
+        def callback(file_path_or_url):
+            chat_model_params = self._parse_chat_model_params("TO_MARKDOWN_FUNCTION_")
+            chat_model_params = prepare_chat_model_params(chat_model_params)

             to_markdown = ToMarkdown()
-            return to_markdown.call(file_path_or_url, **params_copy)
+            return to_markdown.call(file_path_or_url, **chat_model_params)

-        meta = {
-            'name': name,
-            'callback': callback,
-            'input_types': ['str'],
-            'output_type': 'str'
-        }
+        meta = {"name": name, "callback": callback, "input_types": ["str"], "output_type": "str"}
         self.callbacks[name] = meta
         return meta

-    def _parse_chat_model_params(self, param_prefix: str = 'LLM_FUNCTION_'):
+    def _parse_chat_model_params(self, param_prefix: str = "LLM_FUNCTION_"):
         """
         Parses the environment variables for chat model parameters.
         """
         chat_model_params = config.get("default_llm") or {}
         for k, v in os.environ.items():
             if k.startswith(param_prefix):
-                param_name = k[len(param_prefix):]
-                if param_name == 'MODEL':
-                    chat_model_params['model_name'] = v
+                param_name = k[len(param_prefix) :]
+                if param_name == "MODEL":
+                    chat_model_params["model_name"] = v
                 else:
                     chat_model_params[param_name.lower()] = v

-        if 'provider' not in chat_model_params:
-            chat_model_params['provider'] = 'openai'
+        if "provider" not in chat_model_params:
+            chat_model_params["provider"] = "openai"

-        if 'api_key' in chat_model_params:
+        if "api_key" in chat_model_params:
             # move to api_keys dict
-            chat_model_params["api_keys"] = {chat_model_params['provider']: chat_model_params['api_key']}
+            chat_model_params["api_keys"] = {chat_model_params["provider"]: chat_model_params["api_key"]}

         return chat_model_params

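For context, `_parse_chat_model_params` builds its parameter dict by scanning the environment for a name prefix. A minimal standalone sketch of that pattern (the env var values here are placeholders, not real keys):

    import os

    # placeholder values, illustrating the LLM_FUNCTION_ prefix convention
    os.environ["LLM_FUNCTION_MODEL"] = "gpt-4o"
    os.environ["LLM_FUNCTION_API_KEY"] = "sk-placeholder"

    prefix = "LLM_FUNCTION_"
    params = {}
    for k, v in os.environ.items():
        if k.startswith(prefix):
            name = k[len(prefix):]
            # MODEL maps to model_name; any other suffix is lowercased as-is
            params["model_name" if name == "MODEL" else name.lower()] = v

    params.setdefault("provider", "openai")
    print(params)  # e.g. {'model_name': 'gpt-4o', 'api_key': 'sk-placeholder', 'provider': 'openai'}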
@@ -215,33 +215,23 @@ class DuckDBFunctions:
         self.functions = {}

     def check_function(self, node):
-
         meta = self.controller.check_function(node)
         if meta is None:
             return

-        name = meta['name']
+        name = meta["name"]

         if name in self.functions:
             return

-        input_types = [
-            python_to_duckdb_type(param)
-            for param in meta['input_types']
-        ]
+        input_types = [python_to_duckdb_type(param) for param in meta["input_types"]]

         self.functions[name] = {
-            'callback': function_maker(len(input_types), meta['callback']),
-            'input': input_types,
-            'output': python_to_duckdb_type(meta['output_type']),
+            "callback": function_maker(len(input_types), meta["callback"]),
+            "input": input_types,
+            "output": python_to_duckdb_type(meta["output_type"]),
         }

     def register(self, connection):
         for name, info in self.functions.items():
-            connection.create_function(
-                name,
-                info['callback'],
-                info['input'],
-                info['output'],
-                null_handling="special"
-            )
+            connection.create_function(name, info["callback"], info["input"], info["output"], null_handling="special")
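The reformatted `register` above calls DuckDB's Python UDF API directly. A self-contained sketch of the same registration pattern, assuming only the duckdb package (the `shout` name and callback are illustrative, standing in for a BYOM/LLM callback):

    import duckdb
    from duckdb.typing import VARCHAR

    def shout(text):
        # toy callback; real callbacks come from the controller's meta dict
        return text.upper()

    con = duckdb.connect()
    # null_handling="special" passes NULL inputs through to the callback
    # instead of short-circuiting them, matching the call in register()
    con.create_function("shout", shout, [VARCHAR], VARCHAR, null_handling="special")
    print(con.sql("SELECT shout('hello')").fetchall())  # [('HELLO',)]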
mindsdb/interfaces/functions/to_markdown.py

@@ -2,6 +2,7 @@ from io import BytesIO
 import os
 from typing import Union
 from urllib.parse import urlparse
+import xml.etree.ElementTree as ET

 from aipdf import ocr
 import mimetypes
@@ -12,6 +13,7 @@ class ToMarkdown:
     """
     Extracts the content of documents of various formats in markdown format.
     """
+
     def __init__(self):
         """
         Initializes the ToMarkdown class.
@@ -24,24 +26,28 @@ class ToMarkdown:
         file_extension = self._get_file_extension(file_path_or_url)
         file_content = self._get_file_content(file_path_or_url)

-        if file_extension == '.pdf':
+        if file_extension == ".pdf":
             return self._pdf_to_markdown(file_content, **kwargs)
+
+        elif file_extension in (".xml", ".nessus"):
+            return self._xml_to_markdown(file_content, **kwargs)
+
         else:
             raise ValueError(f"Unsupported file type: {file_extension}.")

-    def _get_file_content(self, file_path_or_url: str) -> Union[requests.Response, BytesIO]:
+    def _get_file_content(self, file_path_or_url: str) -> BytesIO:
         """
         Retrieves the content of a file.
         """
         parsed_url = urlparse(file_path_or_url)
-        if parsed_url.scheme in ('http', 'https'):
+        if parsed_url.scheme in ("http", "https"):
             response = requests.get(file_path_or_url)
             if response.status_code == 200:
-                return response
+                return BytesIO(response.content)
             else:
-                raise RuntimeError(f'Unable to retrieve file from URL: {file_path_or_url}')
+                raise RuntimeError(f"Unable to retrieve file from URL: {file_path_or_url}")
         else:
-            with open(file_path_or_url, 'rb') as file:
+            with open(file_path_or_url, "rb") as file:
                 return BytesIO(file.read())

     def _get_file_extension(self, file_path_or_url: str) -> str:
@@ -49,13 +55,13 @@ class ToMarkdown:
         Retrieves the file extension from a file path or URL.
         """
         parsed_url = urlparse(file_path_or_url)
-        if parsed_url.scheme in ('http', 'https'):
+        if parsed_url.scheme in ("http", "https"):
             try:
                 # Make a HEAD request to get headers without downloading the file.
                 response = requests.head(file_path_or_url, allow_redirects=True)
-                content_type = response.headers.get('Content-Type', '')
+                content_type = response.headers.get("Content-Type", "")
                 if content_type:
-                    ext = mimetypes.guess_extension(content_type.split(';')[0].strip())
+                    ext = mimetypes.guess_extension(content_type.split(";")[0].strip())
                     if ext:
                         return ext

@@ -64,16 +70,43 @@ class ToMarkdown:
                     if ext:
                         return ext
             except requests.RequestException:
-                raise RuntimeError(f'Unable to retrieve file extension from URL: {file_path_or_url}')
+                raise RuntimeError(f"Unable to retrieve file extension from URL: {file_path_or_url}")
         else:
             return os.path.splitext(file_path_or_url)[1]

-    def _pdf_to_markdown(self, file_content: Union[requests.Response, BytesIO], **kwargs) -> str:
+    def _pdf_to_markdown(self, file_content: Union[requests.Response, BytesIO], **kwargs) -> str:
         """
         Converts a PDF file to markdown.
         """
-        if isinstance(file_content, requests.Response):
-            file_content = BytesIO(file_content.content)
-
         markdown_pages = ocr(file_content, **kwargs)
         return "\n\n---\n\n".join(markdown_pages)
+
+    def _xml_to_markdown(self, file_content: Union[requests.Response, BytesIO], **kwargs) -> str:
+        """
+        Converts an XML (or Nessus) file to markdown.
+        """
+
+        def parse_element(element: ET.Element, depth: int = 0) -> str:
+            """
+            Recursively parses an XML element and converts it to markdown.
+            """
+            markdown = []
+            heading = "#" * (depth + 1)
+
+            markdown.append(f"{heading} {element.tag}")
+
+            for key, val in element.attrib.items():
+                markdown.append(f"- **{key}**: {val}")
+
+            text = (element.text or "").strip()
+            if text:
+                markdown.append(f"\n{text}\n")
+
+            for child in element:
+                markdown.append(parse_element(child, depth + 1))
+
+            return "\n".join(markdown)
+
+        root = ET.fromstring(file_content.read().decode("utf-8"))
+        markdown_content = parse_element(root)
+        return markdown_content
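Applied to a toy document, the recursive walk added in `_xml_to_markdown` turns nesting depth into heading levels. The same logic, reduced to a standalone sketch (the sample XML is invented for illustration):

    import xml.etree.ElementTree as ET

    def parse_element(element: ET.Element, depth: int = 0) -> str:
        markdown = [f"{'#' * (depth + 1)} {element.tag}"]
        for key, val in element.attrib.items():
            markdown.append(f"- **{key}**: {val}")
        text = (element.text or "").strip()
        if text:
            markdown.append(f"\n{text}\n")
        for child in element:
            markdown.append(parse_element(child, depth + 1))
        return "\n".join(markdown)

    root = ET.fromstring('<scan name="demo"><host ip="10.0.0.1">up</host></scan>')
    print(parse_element(root))
    # # scan
    # - **name**: demo
    # ## host
    # - **ip**: 10.0.0.1
    #
    # up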
mindsdb/interfaces/knowledge_base/controller.py

@@ -1,17 +1,19 @@
 import os
 import copy
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Any, Text
 import json
 import decimal

 import pandas as pd
 import numpy as np
+from pydantic import BaseModel, ValidationError
 from sqlalchemy.orm.attributes import flag_modified

 from mindsdb_sql_parser.ast import BinaryOperation, Constant, Identifier, Select, Update, Delete, Star
 from mindsdb_sql_parser.ast.mindsdb import CreatePredictor
 from mindsdb_sql_parser import parse_sql

+from mindsdb.integrations.libs.keyword_search_base import KeywordSearchBase
 from mindsdb.integrations.utilities.query_traversal import query_traversal

 import mindsdb.interfaces.storage.db as db
@@ -37,7 +39,7 @@ from mindsdb.interfaces.knowledge_base.evaluate import EvaluateBase
 from mindsdb.interfaces.knowledge_base.executor import KnowledgeBaseQueryExecutor
 from mindsdb.interfaces.model.functions import PredictorRecordNotFound
 from mindsdb.utilities.exception import EntityExistsError, EntityNotExistsError
-from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator
+from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator, KeywordSearchArgs
 from mindsdb.utilities.config import config
 from mindsdb.utilities.context import context as ctx

@@ -49,6 +51,20 @@ from mindsdb.integrations.utilities.rag.rerankers.base_reranker import BaseLLMReranker
 logger = log.getLogger(__name__)


+class KnowledgeBaseInputParams(BaseModel):
+    metadata_columns: List[str] | None = None
+    content_columns: List[str] | None = None
+    id_column: str | None = None
+    kb_no_upsert: bool = False
+    embedding_model: Dict[Text, Any] | None = None
+    is_sparse: bool = False
+    vector_size: int | None = None
+    reranking_model: Dict[Text, Any] | None = None
+
+    class Config:
+        extra = "forbid"
+
+
 def get_model_params(model_params: dict, default_config_key: str):
     """
     Get model parameters by combining default config with user provided parameters.
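The new `KnowledgeBaseInputParams` model is what turns unrecognized knowledge-base parameters into hard errors, via pydantic's extra = "forbid". A cut-down sketch of that mechanism, assuming pydantic v2 (which provides the `model_validate` / `ValidationError` imported above); the model here keeps a single field:

    from typing import List, Optional
    from pydantic import BaseModel, ValidationError

    class InputParams(BaseModel):
        content_columns: Optional[List[str]] = None

        class Config:
            extra = "forbid"

    try:
        InputParams.model_validate({"content_colums": ["body"]})  # deliberate typo
    except ValidationError as e:
        print(e.errors()[0]["type"])  # extra_forbidden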
@@ -101,7 +117,10 @@ def get_reranking_model_from_params(reranking_model_params: dict):

     if "api_key" not in params_copy:
         params_copy["api_key"] = get_api_key(provider, params_copy, strict=False)

+    if "model_name" not in params_copy:
+        raise ValueError("'model_name' must be provided for reranking model")
+
     params_copy["model"] = params_copy.pop("model_name")

     return BaseLLMReranker(**params_copy)
@@ -179,17 +198,20 @@ class KnowledgeBaseTable:
         df = executor.run(query)

         if (
             query_copy.group_by is not None
             or query_copy.order_by is not None
             or query_copy.having is not None
             or query_copy.distinct is True
             or len(query_copy.targets) != 1
             or not isinstance(query_copy.targets[0], Star)
         ):
             query_copy.where = None
         if "metadata" in df.columns:
             df["metadata"] = df["metadata"].apply(to_json)

+        if query_copy.from_table is None:
+            query_copy.from_table = Identifier(parts=[self._kb.name])
+
         df = query_df(df, query_copy, session=self.session)

         return df
@@ -218,8 +240,12 @@ class KnowledgeBaseTable:

         # extract values from conditions and prepare for vectordb
         conditions = []
+        keyword_search_conditions = []
+        keyword_search_cols_and_values = []
         query_text = None
         relevance_threshold = None
+        reranking_enabled_flag = True
+        hybrid_search_enabled_flag = False
         query_conditions = db_handler.extract_conditions(query.where)
         if query_conditions is not None:
             for item in query_conditions:
@@ -235,9 +261,17 @@ class KnowledgeBaseTable:
                         logger.error(error_msg)
                         raise ValueError(error_msg)
                 elif item.column == "reranking":
+                    reranking_enabled_flag = item.value
+                    # cast to boolean
+                    if isinstance(reranking_enabled_flag, str):
+                        reranking_enabled_flag = reranking_enabled_flag.lower() not in ("false")
+                elif item.column == "hybrid_search":
+                    hybrid_search_enabled_flag = item.value
+                    # cast to boolean
+                    if isinstance(hybrid_search_enabled_flag, str):
+                        hybrid_search_enabled_flag = hybrid_search_enabled_flag.lower() not in ("false")
                     if item.value is False or (isinstance(item.value, str) and item.value.lower() == "false"):
                         disable_reranking = True
-
                 elif item.column == "relevance" and item.op.value != FilterOperator.GREATER_THAN_OR_EQUAL.value:
                     raise ValueError(
                         f"Invalid operator for relevance: {item.op.value}. Only GREATER_THAN_OR_EQUAL is allowed."
@@ -253,8 +287,16 @@ class KnowledgeBaseTable:
                             op=FilterOperator.EQUAL,
                         )
                     )
+                    keyword_search_cols_and_values.append((TableField.CONTENT.value, item.value))
                 else:
                     conditions.append(item)
+                    keyword_search_conditions.append(item)  # keyword search conditions do not use embeddings
+
+        if len(keyword_search_cols_and_values) > 1:
+            raise ValueError(
+                "Multiple content columns found in query conditions. "
+                "Only one content column is allowed for keyword search."
+            )

         logger.debug(f"Extracted query text: {query_text}")

@@ -272,9 +314,42 @@ class KnowledgeBaseTable:
         allowed_metadata_columns = self._get_allowed_metadata_columns()
         df = db_handler.dispatch_select(query, conditions, allowed_metadata_columns=allowed_metadata_columns)
         df = self.addapt_result_columns(df)
-
         logger.debug(f"Query returned {len(df)} rows")
         logger.debug(f"Columns in response: {df.columns.tolist()}")
+
+        if hybrid_search_enabled_flag and not isinstance(db_handler, KeywordSearchBase):
+            raise ValueError(f"Hybrid search is enabled but the db_handler {type(db_handler)} does not support it. ")
+        # check if db_handler inherits from KeywordSearchBase
+        if hybrid_search_enabled_flag and isinstance(db_handler, KeywordSearchBase):
+            # If query_text is present, use it for keyword search
+            logger.debug(f"Performing keyword search with query text: {query_text}")
+            keyword_search_args = KeywordSearchArgs(query=query_text, column=TableField.CONTENT.value)
+            keyword_query_obj = copy.deepcopy(query)
+
+            keyword_query_obj.targets = [
+                Identifier(TableField.ID.value),
+                Identifier(TableField.CONTENT.value),
+                Identifier(TableField.METADATA.value),
+            ]
+
+            df_keyword_select = db_handler.dispatch_select(
+                keyword_query_obj, keyword_search_conditions, keyword_search_args=keyword_search_args
+            )
+            df_keyword_select = self.addapt_result_columns(df_keyword_select)
+            logger.debug(f"Keyword search returned {len(df_keyword_select)} rows")
+            logger.debug(f"Columns in keyword search response: {df_keyword_select.columns.tolist()}")
+            # ensure df and df_keyword_select have exactly the same columns
+            if not df_keyword_select.empty:
+                if set(df.columns) != set(df_keyword_select.columns):
+                    raise ValueError(
+                        f"Keyword search returned different columns: {df_keyword_select.columns} "
+                        f"than expected: {df.columns}"
+                    )
+                df = pd.concat([df, df_keyword_select], ignore_index=True)
+                # if chunk_id column exists remove duplicates based on chunk_id
+                if "chunk_id" in df.columns:
+                    df = df.drop_duplicates(subset=["chunk_id"])
+
         # Check if we have a rerank_model configured in KB params
         df = self.add_relevance(df, query_text, relevance_threshold, disable_reranking)

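The merge at the end of the hybrid-search branch above is a plain pandas union with chunk-level dedup. A toy sketch of that step (the column values are invented):

    import pandas as pd

    vector_hits = pd.DataFrame({"chunk_id": ["a", "b"], "content": ["x", "y"]})
    keyword_hits = pd.DataFrame({"chunk_id": ["b", "c"], "content": ["y", "z"]})

    merged = pd.concat([vector_hits, keyword_hits], ignore_index=True)
    merged = merged.drop_duplicates(subset=["chunk_id"])  # keeps the first (vector-side) row
    print(merged["chunk_id"].tolist())  # ['a', 'b', 'c']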
@@ -736,8 +811,7 @@ class KnowledgeBaseTable:
         if model_id is None:
             # call litellm handler
             messages = list(df[TableField.CONTENT.value])
-            embedding_params = copy.deepcopy(config.get("default_embedding_model", {}))
-            embedding_params.update(self._kb.params["embedding_model"])
+            embedding_params = get_model_params(self._kb.params.get("embedding_model", {}), "default_embedding_model")
             results = self.call_litellm_embedding(self.session, embedding_params, messages)
             results = [[val] for val in results]
             return pd.DataFrame(results, columns=[TableField.EMBEDDINGS.value])
@@ -783,6 +857,9 @@ class KnowledgeBaseTable:
     def call_litellm_embedding(session, model_params, messages):
         args = copy.deepcopy(model_params)

+        if "model_name" not in args:
+            raise ValueError("'model_name' must be provided for embedding model")
+
         llm_model = args.pop("model_name")
         engine = args.pop("provider")

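`call_litellm_embedding` now fails fast on a missing model_name instead of surfacing a bare KeyError from pop. For orientation, a hedged sketch of the kind of call the litellm path ultimately makes, assuming the litellm package and a provider API key in the environment (model name illustrative):

    import litellm

    resp = litellm.embedding(model="text-embedding-3-small", input=["hello world"])
    vector = resp.data[0]["embedding"]
    print(len(vector))  # model-dependent dimensionality, e.g. 1536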
@@ -936,6 +1013,24 @@ class KnowledgeBaseController:
         # fill variables
         params = variables_controller.fill_parameters(params)

+        try:
+            KnowledgeBaseInputParams.model_validate(params)
+        except ValidationError as e:
+            problems = []
+            for error in e.errors():
+                parameter = ".".join([str(i) for i in error["loc"]])
+                param_type = error["type"]
+                if param_type == "extra_forbidden":
+                    msg = f"Parameter '{parameter}' is not allowed"
+                else:
+                    msg = f"Error in '{parameter}' (type: {param_type}): {error['msg']}. Input: {repr(error['input'])}"
+                problems.append(msg)
+
+            msg = "\n".join(problems)
+            if len(problems) > 1:
+                msg = "\n" + msg
+            raise ValueError(f"Problem with knowledge base parameters: {msg}")
+
         # Validate preprocessing config first if provided
         if preprocessing_config is not None:
             PreprocessingConfig(**preprocessing_config)  # Validate before storing
@@ -961,24 +1056,6 @@ class KnowledgeBaseController:
             return kb
         raise EntityExistsError("Knowledge base already exists", name)

-        embedding_params = copy.deepcopy(config.get("default_embedding_model", {}))
-
-        # Legacy
-        # model_name = None
-        # model_project = project
-        # if embedding_model:
-        #     model_name = embedding_model.parts[-1]
-        #     if len(embedding_model.parts) > 1:
-        #         model_project = self.session.database_controller.get_project(embedding_model.parts[-2])
-
-        # elif "embedding_model" in params:
-        #     if isinstance(params["embedding_model"], str):
-        #         # it is model name
-        #         model_name = params["embedding_model"]
-        #     else:
-        #         # it is params for model
-        #         embedding_params.update(params["embedding_model"])
-
         embedding_params = get_model_params(params.get("embedding_model", {}), "default_embedding_model")

         # if model_name is None: # Legacy
@@ -1009,7 +1086,11 @@ class KnowledgeBaseController:
         if reranking_model_params:
             # Get reranking model from params.
             # This is called here to check validaity of the parameters.
-            get_reranking_model_from_params(reranking_model_params)
+            try:
+                reranker = get_reranking_model_from_params(reranking_model_params)
+                reranker.get_scores("test", ["test"])
+            except (ValueError, RuntimeError) as e:
+                raise RuntimeError(f"Problem with reranker config: {e}")

         # search for the vector database table
         if storage is None:
@@ -1102,15 +1183,33 @@ class KnowledgeBaseController:
         except PredictorRecordNotFound:
             pass

-        if params.get("provider") not in ("openai", "azure_openai"):
+        if "provider" not in params:
+            raise ValueError("'provider' parameter is required for embedding model")
+
+        # check available providers
+        avail_providers = ("openai", "azure_openai", "bedrock", "gemini", "google")
+        if params["provider"] not in avail_providers:
+            raise ValueError(
+                f"Wrong embedding provider: {params['provider']}. Available providers: {', '.join(avail_providers)}"
+            )
+
+        if params["provider"] not in ("openai", "azure_openai"):
             # try use litellm
-
+            try:
+                KnowledgeBaseTable.call_litellm_embedding(self.session, params, ["test"])
+            except Exception as e:
+                raise RuntimeError(f"Problem with embedding model config: {e}")
             return

         if "provider" in params:
             engine = params.pop("provider").lower()

-        api_key = get_api_key(engine, params, strict=False)
+        api_key = get_api_key(engine, params, strict=False)
+        if api_key is None:
+            if "api_key" in params:
+                params.pop("api_key")
+            else:
+                raise ValueError("'api_key' parameter is required for embedding model")

         if engine == "azure_openai":
             engine = "openai"