MindsDB 25.9.3rc1__py3-none-any.whl → 25.10.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MindsDB might be problematic. Click here for more details.
- mindsdb/__about__.py +1 -1
- mindsdb/__main__.py +1 -9
- mindsdb/api/a2a/__init__.py +1 -1
- mindsdb/api/a2a/agent.py +9 -1
- mindsdb/api/a2a/common/server/server.py +4 -0
- mindsdb/api/a2a/common/server/task_manager.py +8 -1
- mindsdb/api/a2a/common/types.py +66 -0
- mindsdb/api/a2a/task_manager.py +50 -0
- mindsdb/api/common/middleware.py +1 -1
- mindsdb/api/executor/command_executor.py +49 -36
- mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +7 -13
- mindsdb/api/executor/datahub/datanodes/integration_datanode.py +2 -2
- mindsdb/api/executor/datahub/datanodes/system_tables.py +2 -1
- mindsdb/api/executor/planner/query_prepare.py +2 -20
- mindsdb/api/executor/utilities/sql.py +5 -4
- mindsdb/api/http/initialize.py +76 -60
- mindsdb/api/http/namespaces/agents.py +0 -3
- mindsdb/api/http/namespaces/chatbots.py +0 -5
- mindsdb/api/http/namespaces/file.py +2 -0
- mindsdb/api/http/namespaces/handlers.py +10 -5
- mindsdb/api/http/namespaces/knowledge_bases.py +20 -0
- mindsdb/api/http/namespaces/sql.py +2 -2
- mindsdb/api/http/start.py +2 -2
- mindsdb/api/mysql/mysql_proxy/utilities/dump.py +8 -2
- mindsdb/integrations/handlers/byom_handler/byom_handler.py +2 -10
- mindsdb/integrations/handlers/databricks_handler/databricks_handler.py +98 -46
- mindsdb/integrations/handlers/druid_handler/druid_handler.py +32 -40
- mindsdb/integrations/handlers/gitlab_handler/gitlab_handler.py +5 -2
- mindsdb/integrations/handlers/mssql_handler/mssql_handler.py +438 -100
- mindsdb/integrations/handlers/mssql_handler/requirements_odbc.txt +3 -0
- mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +235 -3
- mindsdb/integrations/handlers/oracle_handler/__init__.py +2 -0
- mindsdb/integrations/handlers/oracle_handler/connection_args.py +7 -1
- mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +321 -16
- mindsdb/integrations/handlers/oracle_handler/requirements.txt +1 -1
- mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +2 -2
- mindsdb/integrations/handlers/zendesk_handler/zendesk_tables.py +144 -111
- mindsdb/integrations/libs/response.py +2 -2
- mindsdb/integrations/utilities/handlers/auth_utilities/snowflake/__init__.py +1 -0
- mindsdb/integrations/utilities/handlers/auth_utilities/snowflake/snowflake_jwt_gen.py +151 -0
- mindsdb/integrations/utilities/rag/rerankers/base_reranker.py +24 -21
- mindsdb/interfaces/agents/agents_controller.py +0 -2
- mindsdb/interfaces/data_catalog/data_catalog_loader.py +6 -7
- mindsdb/interfaces/data_catalog/data_catalog_reader.py +15 -4
- mindsdb/interfaces/database/data_handlers_cache.py +190 -0
- mindsdb/interfaces/database/database.py +3 -3
- mindsdb/interfaces/database/integrations.py +1 -121
- mindsdb/interfaces/database/projects.py +2 -6
- mindsdb/interfaces/database/views.py +1 -4
- mindsdb/interfaces/jobs/jobs_controller.py +0 -4
- mindsdb/interfaces/jobs/scheduler.py +0 -1
- mindsdb/interfaces/knowledge_base/controller.py +197 -108
- mindsdb/interfaces/knowledge_base/evaluate.py +36 -41
- mindsdb/interfaces/knowledge_base/executor.py +11 -0
- mindsdb/interfaces/knowledge_base/llm_client.py +51 -17
- mindsdb/interfaces/model/model_controller.py +4 -4
- mindsdb/interfaces/skills/custom/text2sql/mindsdb_sql_toolkit.py +4 -10
- mindsdb/interfaces/skills/skills_controller.py +1 -4
- mindsdb/interfaces/storage/db.py +16 -6
- mindsdb/interfaces/triggers/triggers_controller.py +1 -3
- mindsdb/utilities/config.py +19 -2
- mindsdb/utilities/exception.py +2 -2
- mindsdb/utilities/json_encoder.py +24 -10
- mindsdb/utilities/render/sqlalchemy_render.py +15 -14
- mindsdb/utilities/starters.py +0 -10
- {mindsdb-25.9.3rc1.dist-info → mindsdb-25.10.0rc1.dist-info}/METADATA +276 -264
- {mindsdb-25.9.3rc1.dist-info → mindsdb-25.10.0rc1.dist-info}/RECORD +70 -84
- mindsdb/api/postgres/__init__.py +0 -0
- mindsdb/api/postgres/postgres_proxy/__init__.py +0 -0
- mindsdb/api/postgres/postgres_proxy/executor/__init__.py +0 -1
- mindsdb/api/postgres/postgres_proxy/executor/executor.py +0 -182
- mindsdb/api/postgres/postgres_proxy/postgres_packets/__init__.py +0 -0
- mindsdb/api/postgres/postgres_proxy/postgres_packets/errors.py +0 -322
- mindsdb/api/postgres/postgres_proxy/postgres_packets/postgres_fields.py +0 -34
- mindsdb/api/postgres/postgres_proxy/postgres_packets/postgres_message.py +0 -31
- mindsdb/api/postgres/postgres_proxy/postgres_packets/postgres_message_formats.py +0 -1265
- mindsdb/api/postgres/postgres_proxy/postgres_packets/postgres_message_identifiers.py +0 -31
- mindsdb/api/postgres/postgres_proxy/postgres_packets/postgres_packets.py +0 -265
- mindsdb/api/postgres/postgres_proxy/postgres_proxy.py +0 -477
- mindsdb/api/postgres/postgres_proxy/utilities/__init__.py +0 -10
- mindsdb/api/postgres/start.py +0 -11
- mindsdb/integrations/handlers/mssql_handler/tests/__init__.py +0 -0
- mindsdb/integrations/handlers/mssql_handler/tests/test_mssql_handler.py +0 -169
- mindsdb/integrations/handlers/oracle_handler/tests/__init__.py +0 -0
- mindsdb/integrations/handlers/oracle_handler/tests/test_oracle_handler.py +0 -32
- {mindsdb-25.9.3rc1.dist-info → mindsdb-25.10.0rc1.dist-info}/WHEEL +0 -0
- {mindsdb-25.9.3rc1.dist-info → mindsdb-25.10.0rc1.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.9.3rc1.dist-info → mindsdb-25.10.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -10,7 +10,6 @@ from pydantic import BaseModel, ValidationError
|
|
|
10
10
|
from sqlalchemy.orm.attributes import flag_modified
|
|
11
11
|
|
|
12
12
|
from mindsdb_sql_parser.ast import BinaryOperation, Constant, Identifier, Select, Update, Delete, Star
|
|
13
|
-
from mindsdb_sql_parser.ast.mindsdb import CreatePredictor
|
|
14
13
|
from mindsdb_sql_parser import parse_sql
|
|
15
14
|
|
|
16
15
|
from mindsdb.integrations.libs.keyword_search_base import KeywordSearchBase
|
|
@@ -23,6 +22,7 @@ from mindsdb.integrations.libs.vectordatabase_handler import (
|
|
|
23
22
|
VectorStoreHandler,
|
|
24
23
|
)
|
|
25
24
|
from mindsdb.integrations.utilities.handler_utils import get_api_key
|
|
25
|
+
from mindsdb.integrations.utilities.handlers.auth_utilities.snowflake import get_validated_jwt
|
|
26
26
|
|
|
27
27
|
from mindsdb.interfaces.agents.constants import DEFAULT_EMBEDDINGS_MODEL_CLASS, MAX_INSERT_BATCH_SIZE
|
|
28
28
|
from mindsdb.interfaces.agents.langchain_agent import create_chat_model, get_llm_provider
|
|
@@ -42,6 +42,7 @@ from mindsdb.api.executor.command_executor import ExecuteCommands
|
|
|
42
42
|
from mindsdb.api.executor.utilities.sql import query_df
|
|
43
43
|
from mindsdb.utilities import log
|
|
44
44
|
from mindsdb.integrations.utilities.rag.rerankers.base_reranker import BaseLLMReranker
|
|
45
|
+
from mindsdb.interfaces.knowledge_base.llm_client import LLMClient
|
|
45
46
|
|
|
46
47
|
logger = log.getLogger(__name__)
|
|
47
48
|
|
|
@@ -72,6 +73,10 @@ def get_model_params(model_params: dict, default_config_key: str):
|
|
|
72
73
|
if not isinstance(model_params, dict):
|
|
73
74
|
raise ValueError("Model parameters must be passed as a JSON object")
|
|
74
75
|
|
|
76
|
+
# if provider mismatches - don't use default values
|
|
77
|
+
if "provider" in model_params and model_params["provider"] != combined_model_params.get("provider"):
|
|
78
|
+
return model_params
|
|
79
|
+
|
|
75
80
|
combined_model_params.update(model_params)
|
|
76
81
|
|
|
77
82
|
combined_model_params.pop("use_default_llm", None)
|
|
@@ -142,6 +147,28 @@ def to_json(obj):
|
|
|
142
147
|
return obj
|
|
143
148
|
|
|
144
149
|
|
|
150
|
+
def rotate_provider_api_key(params):
|
|
151
|
+
"""
|
|
152
|
+
Check api key for specific providers. At the moment it checks and updated jwt token of snowflake provider
|
|
153
|
+
:param params: input params, can be modified by this function
|
|
154
|
+
:return: a new api key if it is refreshed
|
|
155
|
+
"""
|
|
156
|
+
provider = params.get("provider").lower()
|
|
157
|
+
|
|
158
|
+
if provider == "snowflake":
|
|
159
|
+
api_key = params.get("api_key")
|
|
160
|
+
api_key2 = get_validated_jwt(
|
|
161
|
+
api_key,
|
|
162
|
+
account=params.get("snowflake_account_id"),
|
|
163
|
+
user=params.get("user"),
|
|
164
|
+
private_key=params.get("private_key"),
|
|
165
|
+
)
|
|
166
|
+
if api_key2 != api_key:
|
|
167
|
+
# update keys
|
|
168
|
+
params["api_key"] = api_key2
|
|
169
|
+
return api_key2
|
|
170
|
+
|
|
171
|
+
|
|
145
172
|
class KnowledgeBaseTable:
|
|
146
173
|
"""
|
|
147
174
|
Knowledge base table interface
|
|
@@ -194,6 +221,22 @@ class KnowledgeBaseTable:
|
|
|
194
221
|
executor = KnowledgeBaseQueryExecutor(self)
|
|
195
222
|
df = executor.run(query)
|
|
196
223
|
|
|
224
|
+
# copy metadata to columns
|
|
225
|
+
if "metadata" in df.columns:
|
|
226
|
+
meta_columns = self._get_allowed_metadata_columns()
|
|
227
|
+
if meta_columns:
|
|
228
|
+
meta_data = pd.json_normalize(df["metadata"])
|
|
229
|
+
# exclude absent columns and used colunns
|
|
230
|
+
df_columns = list(df.columns)
|
|
231
|
+
meta_columns = list(set(meta_columns).intersection(meta_data.columns).difference(df_columns))
|
|
232
|
+
|
|
233
|
+
# add columns
|
|
234
|
+
df = df.join(meta_data[meta_columns])
|
|
235
|
+
|
|
236
|
+
# put metadata in the end
|
|
237
|
+
df_columns.remove("metadata")
|
|
238
|
+
df = df[df_columns + meta_columns + ["metadata"]]
|
|
239
|
+
|
|
197
240
|
if (
|
|
198
241
|
query_copy.group_by is not None
|
|
199
242
|
or query_copy.order_by is not None
|
|
@@ -384,7 +427,7 @@ class KnowledgeBaseTable:
|
|
|
384
427
|
# if relevance filtering method is strictly GREATER THAN we filter the df
|
|
385
428
|
if gt_filtering:
|
|
386
429
|
relevance_scores = TableField.RELEVANCE.value
|
|
387
|
-
df = df[relevance_scores > relevance_threshold]
|
|
430
|
+
df = df[df[relevance_scores] > relevance_threshold]
|
|
388
431
|
|
|
389
432
|
return df
|
|
390
433
|
|
|
@@ -402,6 +445,7 @@ class KnowledgeBaseTable:
|
|
|
402
445
|
return [col.lower() for col in columns]
|
|
403
446
|
|
|
404
447
|
def score_documents(self, query_text, documents, reranking_model_params):
|
|
448
|
+
rotate_provider_api_key(reranking_model_params)
|
|
405
449
|
reranker = get_reranking_model_from_params(reranking_model_params)
|
|
406
450
|
return reranker.get_scores(query_text, documents)
|
|
407
451
|
|
|
@@ -412,6 +456,15 @@ class KnowledgeBaseTable:
|
|
|
412
456
|
if reranking_model_params and query_text and len(df) > 0 and not disable_reranking:
|
|
413
457
|
# Use reranker for relevance score
|
|
414
458
|
|
|
459
|
+
new_api_key = rotate_provider_api_key(reranking_model_params)
|
|
460
|
+
if new_api_key:
|
|
461
|
+
# update key
|
|
462
|
+
if "reranking_model" not in self._kb.params:
|
|
463
|
+
self._kb.params["reranking_model"] = {}
|
|
464
|
+
self._kb.params["reranking_model"]["api_key"] = new_api_key
|
|
465
|
+
flag_modified(self._kb, "params")
|
|
466
|
+
db.session.commit()
|
|
467
|
+
|
|
415
468
|
# Apply custom filtering threshold if provided
|
|
416
469
|
if relevance_threshold is not None:
|
|
417
470
|
reranking_model_params["filtering_threshold"] = relevance_threshold
|
|
@@ -864,10 +917,12 @@ class KnowledgeBaseTable:
|
|
|
864
917
|
model_id = self._kb.embedding_model_id
|
|
865
918
|
|
|
866
919
|
if model_id is None:
|
|
867
|
-
# call litellm handler
|
|
868
920
|
messages = list(df[TableField.CONTENT.value])
|
|
869
921
|
embedding_params = get_model_params(self._kb.params.get("embedding_model", {}), "default_embedding_model")
|
|
870
|
-
|
|
922
|
+
|
|
923
|
+
llm_client = LLMClient(embedding_params, session=self.session)
|
|
924
|
+
results = llm_client.embeddings(messages)
|
|
925
|
+
|
|
871
926
|
results = [[val] for val in results]
|
|
872
927
|
return pd.DataFrame(results, columns=[TableField.EMBEDDINGS.value])
|
|
873
928
|
|
|
@@ -1053,6 +1108,26 @@ class KnowledgeBaseController:
|
|
|
1053
1108
|
def __init__(self, session) -> None:
|
|
1054
1109
|
self.session = session
|
|
1055
1110
|
|
|
1111
|
+
def _check_kb_input_params(self, params):
|
|
1112
|
+
# check names and types KB params
|
|
1113
|
+
try:
|
|
1114
|
+
KnowledgeBaseInputParams.model_validate(params)
|
|
1115
|
+
except ValidationError as e:
|
|
1116
|
+
problems = []
|
|
1117
|
+
for error in e.errors():
|
|
1118
|
+
parameter = ".".join([str(i) for i in error["loc"]])
|
|
1119
|
+
param_type = error["type"]
|
|
1120
|
+
if param_type == "extra_forbidden":
|
|
1121
|
+
msg = f"Parameter '{parameter}' is not allowed"
|
|
1122
|
+
else:
|
|
1123
|
+
msg = f"Error in '{parameter}' (type: {param_type}): {error['msg']}. Input: {repr(error['input'])}"
|
|
1124
|
+
problems.append(msg)
|
|
1125
|
+
|
|
1126
|
+
msg = "\n".join(problems)
|
|
1127
|
+
if len(problems) > 1:
|
|
1128
|
+
msg = "\n" + msg
|
|
1129
|
+
raise ValueError(f"Problem with knowledge base parameters: {msg}") from e
|
|
1130
|
+
|
|
1056
1131
|
def add(
|
|
1057
1132
|
self,
|
|
1058
1133
|
name: str,
|
|
@@ -1070,36 +1145,18 @@ class KnowledgeBaseController:
|
|
|
1070
1145
|
:param is_sparse: Whether to use sparse vectors for embeddings
|
|
1071
1146
|
:param vector_size: Optional size specification for vectors, required when is_sparse=True
|
|
1072
1147
|
"""
|
|
1073
|
-
if not name.islower():
|
|
1074
|
-
raise ValueError(f"The name must be in lower case: {name}")
|
|
1075
1148
|
|
|
1076
1149
|
# fill variables
|
|
1077
1150
|
params = variables_controller.fill_parameters(params)
|
|
1078
1151
|
|
|
1079
|
-
try:
|
|
1080
|
-
KnowledgeBaseInputParams.model_validate(params)
|
|
1081
|
-
except ValidationError as e:
|
|
1082
|
-
problems = []
|
|
1083
|
-
for error in e.errors():
|
|
1084
|
-
parameter = ".".join([str(i) for i in error["loc"]])
|
|
1085
|
-
param_type = error["type"]
|
|
1086
|
-
if param_type == "extra_forbidden":
|
|
1087
|
-
msg = f"Parameter '{parameter}' is not allowed"
|
|
1088
|
-
else:
|
|
1089
|
-
msg = f"Error in '{parameter}' (type: {param_type}): {error['msg']}. Input: {repr(error['input'])}"
|
|
1090
|
-
problems.append(msg)
|
|
1091
|
-
|
|
1092
|
-
msg = "\n".join(problems)
|
|
1093
|
-
if len(problems) > 1:
|
|
1094
|
-
msg = "\n" + msg
|
|
1095
|
-
raise ValueError(f"Problem with knowledge base parameters: {msg}") from e
|
|
1096
|
-
|
|
1097
1152
|
# Validate preprocessing config first if provided
|
|
1098
1153
|
if preprocessing_config is not None:
|
|
1099
1154
|
PreprocessingConfig(**preprocessing_config) # Validate before storing
|
|
1100
1155
|
params = params or {}
|
|
1101
1156
|
params["preprocessing"] = preprocessing_config
|
|
1102
1157
|
|
|
1158
|
+
self._check_kb_input_params(params)
|
|
1159
|
+
|
|
1103
1160
|
# Check if vector_size is provided when using sparse vectors
|
|
1104
1161
|
is_sparse = params.get("is_sparse")
|
|
1105
1162
|
vector_size = params.get("vector_size")
|
|
@@ -1110,8 +1167,6 @@ class KnowledgeBaseController:
|
|
|
1110
1167
|
project = self.session.database_controller.get_project(project_name)
|
|
1111
1168
|
project_id = project.id
|
|
1112
1169
|
|
|
1113
|
-
# not difference between cases in sql
|
|
1114
|
-
name = name.lower()
|
|
1115
1170
|
# check if knowledge base already exists
|
|
1116
1171
|
kb = self.get(name, project_id)
|
|
1117
1172
|
if kb is not None:
|
|
@@ -1123,42 +1178,25 @@ class KnowledgeBaseController:
|
|
|
1123
1178
|
params["embedding_model"] = embedding_params
|
|
1124
1179
|
|
|
1125
1180
|
# if model_name is None: # Legacy
|
|
1126
|
-
|
|
1181
|
+
self._check_embedding_model(
|
|
1127
1182
|
project.name,
|
|
1128
1183
|
params=embedding_params,
|
|
1129
1184
|
kb_name=name,
|
|
1130
1185
|
)
|
|
1131
|
-
if model_name is not None:
|
|
1132
|
-
params["created_embedding_model"] = model_name
|
|
1133
|
-
|
|
1134
|
-
embedding_model_id = None
|
|
1135
|
-
if model_name is not None:
|
|
1136
|
-
model = self.session.model_controller.get_model(name=model_name, project_name=project.name)
|
|
1137
|
-
model_record = db.Predictor.query.get(model["id"])
|
|
1138
|
-
embedding_model_id = model_record.id
|
|
1139
|
-
|
|
1140
|
-
if model_record.learn_args.get("using", {}).get("sparse"):
|
|
1141
|
-
is_sparse = True
|
|
1142
1186
|
|
|
1143
1187
|
# if params.get("reranking_model", {}) is bool and False we evaluate it to empty dictionary
|
|
1144
1188
|
reranking_model_params = params.get("reranking_model", {})
|
|
1145
1189
|
|
|
1146
1190
|
if isinstance(reranking_model_params, bool) and not reranking_model_params:
|
|
1147
1191
|
params["reranking_model"] = {}
|
|
1148
|
-
# if params.get("reranking_model", {}) is string and false in any case we evaluate it to empty dictionary
|
|
1149
|
-
if isinstance(reranking_model_params, str) and reranking_model_params.lower() == "false":
|
|
1150
|
-
params["reranking_model"] = {}
|
|
1151
1192
|
|
|
1152
1193
|
reranking_model_params = get_model_params(reranking_model_params, "default_reranking_model")
|
|
1153
1194
|
params["reranking_model"] = reranking_model_params
|
|
1154
1195
|
if reranking_model_params:
|
|
1155
1196
|
# Get reranking model from params.
|
|
1156
1197
|
# This is called here to check validaity of the parameters.
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
reranker.get_scores("test", ["test"])
|
|
1160
|
-
except (ValueError, RuntimeError) as e:
|
|
1161
|
-
raise RuntimeError(f"Problem with reranker config: {e}") from e
|
|
1198
|
+
rotate_provider_api_key(reranking_model_params)
|
|
1199
|
+
self._test_reranking(reranking_model_params)
|
|
1162
1200
|
|
|
1163
1201
|
# search for the vector database table
|
|
1164
1202
|
if storage is None:
|
|
@@ -1211,13 +1249,115 @@ class KnowledgeBaseController:
|
|
|
1211
1249
|
project_id=project_id,
|
|
1212
1250
|
vector_database_id=vector_database_id,
|
|
1213
1251
|
vector_database_table=vector_table_name,
|
|
1214
|
-
embedding_model_id=
|
|
1252
|
+
embedding_model_id=None,
|
|
1215
1253
|
params=params,
|
|
1216
1254
|
)
|
|
1217
1255
|
db.session.add(kb)
|
|
1218
1256
|
db.session.commit()
|
|
1219
1257
|
return kb
|
|
1220
1258
|
|
|
1259
|
+
def update(
|
|
1260
|
+
self,
|
|
1261
|
+
name: str,
|
|
1262
|
+
project_name: str,
|
|
1263
|
+
params: dict,
|
|
1264
|
+
preprocessing_config: Optional[dict] = None,
|
|
1265
|
+
) -> db.KnowledgeBase:
|
|
1266
|
+
"""
|
|
1267
|
+
Update the knowledge base
|
|
1268
|
+
:param name: The name of the knowledge base
|
|
1269
|
+
:param project_name: Current project name
|
|
1270
|
+
:param params: The parameters to update
|
|
1271
|
+
:param preprocessing_config: Optional preprocessing configuration to validate and store
|
|
1272
|
+
"""
|
|
1273
|
+
|
|
1274
|
+
# fill variables
|
|
1275
|
+
params = variables_controller.fill_parameters(params)
|
|
1276
|
+
|
|
1277
|
+
# Validate preprocessing config first if provided
|
|
1278
|
+
if preprocessing_config is not None:
|
|
1279
|
+
PreprocessingConfig(**preprocessing_config) # Validate before storing
|
|
1280
|
+
params = params or {}
|
|
1281
|
+
params["preprocessing"] = preprocessing_config
|
|
1282
|
+
|
|
1283
|
+
self._check_kb_input_params(params)
|
|
1284
|
+
|
|
1285
|
+
# get project id
|
|
1286
|
+
project = self.session.database_controller.get_project(project_name)
|
|
1287
|
+
project_id = project.id
|
|
1288
|
+
|
|
1289
|
+
# get existed KB
|
|
1290
|
+
kb = self.get(name.lower(), project_id)
|
|
1291
|
+
if kb is None:
|
|
1292
|
+
raise EntityNotExistsError("Knowledge base doesn't exists", name)
|
|
1293
|
+
|
|
1294
|
+
if "embedding_model" in params:
|
|
1295
|
+
new_config = params["embedding_model"]
|
|
1296
|
+
# update embedding
|
|
1297
|
+
embed_params = kb.params.get("embedding_model", {})
|
|
1298
|
+
if not embed_params:
|
|
1299
|
+
# maybe old version of KB
|
|
1300
|
+
raise ValueError("No embedding config to update")
|
|
1301
|
+
|
|
1302
|
+
# some parameters are not allowed to update
|
|
1303
|
+
for key in ("provider", "model_name"):
|
|
1304
|
+
if key in new_config and new_config[key] != embed_params.get(key):
|
|
1305
|
+
raise ValueError(f"You can't update '{key}' setting")
|
|
1306
|
+
|
|
1307
|
+
embed_params.update(new_config)
|
|
1308
|
+
|
|
1309
|
+
self._check_embedding_model(
|
|
1310
|
+
project.name,
|
|
1311
|
+
params=embed_params,
|
|
1312
|
+
kb_name=name,
|
|
1313
|
+
)
|
|
1314
|
+
kb.params["embedding_model"] = embed_params
|
|
1315
|
+
|
|
1316
|
+
if "reranking_model" in params:
|
|
1317
|
+
new_config = params["reranking_model"]
|
|
1318
|
+
# update embedding
|
|
1319
|
+
rerank_params = kb.params.get("reranking_model", {})
|
|
1320
|
+
|
|
1321
|
+
if new_config is False:
|
|
1322
|
+
# disable reranking
|
|
1323
|
+
rerank_params = {}
|
|
1324
|
+
elif "provider" in new_config and new_config["provider"] != rerank_params.get("provider"):
|
|
1325
|
+
# use new config (and include default config)
|
|
1326
|
+
rerank_params = get_model_params(new_config, "default_reranking_model")
|
|
1327
|
+
else:
|
|
1328
|
+
# update current config
|
|
1329
|
+
rerank_params.update(new_config)
|
|
1330
|
+
|
|
1331
|
+
if rerank_params:
|
|
1332
|
+
self._test_reranking(rerank_params)
|
|
1333
|
+
|
|
1334
|
+
kb.params["reranking_model"] = rerank_params
|
|
1335
|
+
|
|
1336
|
+
# update other keys
|
|
1337
|
+
for key in ["id_column", "metadata_columns", "content_columns", "preprocessing"]:
|
|
1338
|
+
if key in params:
|
|
1339
|
+
kb.params[key] = params[key]
|
|
1340
|
+
|
|
1341
|
+
flag_modified(kb, "params")
|
|
1342
|
+
db.session.commit()
|
|
1343
|
+
|
|
1344
|
+
return self.get(name.lower(), project_id)
|
|
1345
|
+
|
|
1346
|
+
def _test_reranking(self, params):
|
|
1347
|
+
try:
|
|
1348
|
+
reranker = get_reranking_model_from_params(params)
|
|
1349
|
+
reranker.get_scores("test", ["test"])
|
|
1350
|
+
except (ValueError, RuntimeError) as e:
|
|
1351
|
+
if params["provider"] in ("azure_openai", "openai") and params.get("method") != "no-logprobs":
|
|
1352
|
+
# check with no-logprobs
|
|
1353
|
+
params["method"] = "no-logprobs"
|
|
1354
|
+
self._test_reranking(params)
|
|
1355
|
+
logger.warning(
|
|
1356
|
+
f"logprobs is not supported for this model: {params.get('model_name')}. using no-logprobs mode"
|
|
1357
|
+
)
|
|
1358
|
+
else:
|
|
1359
|
+
raise RuntimeError(f"Problem with reranker config: {e}") from e
|
|
1360
|
+
|
|
1221
1361
|
def _create_persistent_pgvector(self, params=None):
|
|
1222
1362
|
"""Create default vector database for knowledge base, if not specified"""
|
|
1223
1363
|
vector_store_name = "kb_pgvector_store"
|
|
@@ -1244,11 +1384,11 @@ class KnowledgeBaseController:
|
|
|
1244
1384
|
self.session.integration_controller.add(vector_store_name, engine, connection_args)
|
|
1245
1385
|
return vector_store_name
|
|
1246
1386
|
|
|
1247
|
-
def
|
|
1248
|
-
"""
|
|
1249
|
-
model_name = f"kb_embedding_{kb_name}"
|
|
1387
|
+
def _check_embedding_model(self, project_name, params: dict = None, kb_name=""):
|
|
1388
|
+
"""check embedding model for knowledge base"""
|
|
1250
1389
|
|
|
1251
|
-
#
|
|
1390
|
+
# if mindsdb model from old KB exists - drop it
|
|
1391
|
+
model_name = f"kb_embedding_{kb_name}"
|
|
1252
1392
|
try:
|
|
1253
1393
|
model = self.session.model_controller.get_model(model_name, project_name=project_name)
|
|
1254
1394
|
if model is not None:
|
|
@@ -1260,63 +1400,18 @@ class KnowledgeBaseController:
|
|
|
1260
1400
|
raise ValueError("'provider' parameter is required for embedding model")
|
|
1261
1401
|
|
|
1262
1402
|
# check available providers
|
|
1263
|
-
avail_providers = ("openai", "azure_openai", "bedrock", "gemini", "google")
|
|
1403
|
+
avail_providers = ("openai", "azure_openai", "bedrock", "gemini", "google", "ollama")
|
|
1264
1404
|
if params["provider"] not in avail_providers:
|
|
1265
1405
|
raise ValueError(
|
|
1266
1406
|
f"Wrong embedding provider: {params['provider']}. Available providers: {', '.join(avail_providers)}"
|
|
1267
1407
|
)
|
|
1268
1408
|
|
|
1269
|
-
|
|
1270
|
-
# try use litellm
|
|
1271
|
-
try:
|
|
1272
|
-
KnowledgeBaseTable.call_litellm_embedding(self.session, params, ["test"])
|
|
1273
|
-
except Exception as e:
|
|
1274
|
-
raise RuntimeError(f"Problem with embedding model config: {e}") from e
|
|
1275
|
-
return
|
|
1276
|
-
|
|
1277
|
-
params = copy.deepcopy(params)
|
|
1278
|
-
if "provider" in params:
|
|
1279
|
-
engine = params.pop("provider").lower()
|
|
1280
|
-
|
|
1281
|
-
api_key = get_api_key(engine, params, strict=False)
|
|
1282
|
-
if api_key is None:
|
|
1283
|
-
if "api_key" in params:
|
|
1284
|
-
params.pop("api_key")
|
|
1285
|
-
else:
|
|
1286
|
-
raise ValueError("'api_key' parameter is required for embedding model")
|
|
1287
|
-
|
|
1288
|
-
if engine == "azure_openai":
|
|
1289
|
-
engine = "openai"
|
|
1290
|
-
params["provider"] = "azure"
|
|
1291
|
-
|
|
1292
|
-
if engine == "openai":
|
|
1293
|
-
if "question_column" not in params:
|
|
1294
|
-
params["question_column"] = "content"
|
|
1295
|
-
if api_key:
|
|
1296
|
-
params[f"{engine}_api_key"] = api_key
|
|
1297
|
-
if "api_key" in params:
|
|
1298
|
-
params.pop("api_key")
|
|
1299
|
-
if "base_url" in params:
|
|
1300
|
-
params["api_base"] = params.pop("base_url")
|
|
1301
|
-
|
|
1302
|
-
params["engine"] = engine
|
|
1303
|
-
params["join_learn_process"] = True
|
|
1304
|
-
params["mode"] = "embedding"
|
|
1305
|
-
|
|
1306
|
-
# Include API key if provided.
|
|
1307
|
-
statement = CreatePredictor(
|
|
1308
|
-
name=Identifier(parts=[project_name, model_name]),
|
|
1309
|
-
using=params,
|
|
1310
|
-
targets=[Identifier(parts=[TableField.EMBEDDINGS.value])],
|
|
1311
|
-
)
|
|
1409
|
+
llm_client = LLMClient(params, session=self.session)
|
|
1312
1410
|
|
|
1313
|
-
|
|
1314
|
-
|
|
1315
|
-
|
|
1316
|
-
|
|
1317
|
-
if record["STATUS"] == "error":
|
|
1318
|
-
raise ValueError("Embedding model error:" + record["ERROR"])
|
|
1319
|
-
return model_name
|
|
1411
|
+
try:
|
|
1412
|
+
llm_client.embeddings(["test"])
|
|
1413
|
+
except Exception as e:
|
|
1414
|
+
raise RuntimeError(f"Problem with embedding model config: {e}") from e
|
|
1320
1415
|
|
|
1321
1416
|
def delete(self, name: str, project_name: int, if_exists: bool = False) -> None:
|
|
1322
1417
|
"""
|
|
@@ -1422,12 +1517,6 @@ class KnowledgeBaseController:
|
|
|
1422
1517
|
kb_table = self.get_table(table_name, project_id)
|
|
1423
1518
|
kb_table.create_index()
|
|
1424
1519
|
|
|
1425
|
-
def update(self, name: str, project_id: int, **kwargs) -> db.KnowledgeBase:
|
|
1426
|
-
"""
|
|
1427
|
-
Update a knowledge base record
|
|
1428
|
-
"""
|
|
1429
|
-
raise NotImplementedError()
|
|
1430
|
-
|
|
1431
1520
|
def evaluate(self, table_name: str, project_name: str, params: dict = None) -> pd.DataFrame:
|
|
1432
1521
|
"""
|
|
1433
1522
|
Run evaluate and/or create test data for evaluation
|
|
@@ -2,6 +2,7 @@ import json
|
|
|
2
2
|
import math
|
|
3
3
|
import re
|
|
4
4
|
import time
|
|
5
|
+
import copy
|
|
5
6
|
from typing import List
|
|
6
7
|
|
|
7
8
|
import pandas as pd
|
|
@@ -10,6 +11,7 @@ import datetime as dt
|
|
|
10
11
|
from mindsdb.api.executor.sql_query.result_set import ResultSet
|
|
11
12
|
from mindsdb_sql_parser import Identifier, Select, Constant, Star, parse_sql, BinaryOperation
|
|
12
13
|
from mindsdb.utilities import log
|
|
14
|
+
from mindsdb.utilities.config import config
|
|
13
15
|
|
|
14
16
|
from mindsdb.interfaces.knowledge_base.llm_client import LLMClient
|
|
15
17
|
|
|
@@ -105,7 +107,12 @@ class EvaluateBase:
|
|
|
105
107
|
if llm_params is None:
|
|
106
108
|
llm_params = self.kb._kb.params.get("reranking_model")
|
|
107
109
|
|
|
108
|
-
|
|
110
|
+
params = copy.deepcopy(config.get("default_llm", {}))
|
|
111
|
+
|
|
112
|
+
if llm_params:
|
|
113
|
+
params.update(llm_params)
|
|
114
|
+
|
|
115
|
+
self.llm_client = LLMClient(params)
|
|
109
116
|
|
|
110
117
|
def generate_test_data(self, gen_params: dict) -> pd.DataFrame:
|
|
111
118
|
# Extract source data (from users query or from KB itself) and call `generate` to get test data
|
|
@@ -241,6 +248,26 @@ class EvaluateBase:
|
|
|
241
248
|
|
|
242
249
|
return cls(session, kb_table).run_evaluate(params)
|
|
243
250
|
|
|
251
|
+
def generate_question_answer(self, text: str) -> (str, str):
|
|
252
|
+
messages = [
|
|
253
|
+
{"role": "system", "content": GENERATE_QA_SYSTEM_PROMPT},
|
|
254
|
+
{"role": "user", "content": f"\n\nText:\n{text}\n\n"},
|
|
255
|
+
]
|
|
256
|
+
answer = self.llm_client.completion(messages, json_output=True)[0]
|
|
257
|
+
|
|
258
|
+
# Sanitize the response by removing markdown code block formatting like ```json
|
|
259
|
+
sanitized_answer = sanitize_json_response(answer)
|
|
260
|
+
|
|
261
|
+
try:
|
|
262
|
+
output = json.loads(sanitized_answer)
|
|
263
|
+
except json.JSONDecodeError:
|
|
264
|
+
raise ValueError(f"Could not parse response from LLM: {answer}")
|
|
265
|
+
|
|
266
|
+
if "query" not in output or "reference_answer" not in output:
|
|
267
|
+
raise ValueError("Cant find question/answer in LLM response")
|
|
268
|
+
|
|
269
|
+
return output.get("query"), output.get("reference_answer")
|
|
270
|
+
|
|
244
271
|
|
|
245
272
|
class EvaluateRerank(EvaluateBase):
|
|
246
273
|
"""
|
|
@@ -268,28 +295,12 @@ class EvaluateRerank(EvaluateBase):
|
|
|
268
295
|
df["id"] = df.index
|
|
269
296
|
return df
|
|
270
297
|
|
|
271
|
-
def generate_question_answer(self, text: str) -> (str, str):
|
|
272
|
-
messages = [
|
|
273
|
-
{"role": "system", "content": GENERATE_QA_SYSTEM_PROMPT},
|
|
274
|
-
{"role": "user", "content": f"\n\nText:\n{text}\n\n"},
|
|
275
|
-
]
|
|
276
|
-
answer = self.llm_client.completion(messages, json_output=True)
|
|
277
|
-
|
|
278
|
-
# Sanitize the response by removing markdown code block formatting like ```json
|
|
279
|
-
sanitized_answer = sanitize_json_response(answer)
|
|
280
|
-
|
|
281
|
-
try:
|
|
282
|
-
output = json.loads(sanitized_answer)
|
|
283
|
-
except json.JSONDecodeError:
|
|
284
|
-
raise ValueError(f"Could not parse response from LLM: {answer}")
|
|
285
|
-
|
|
286
|
-
if "query" not in output or "reference_answer" not in output:
|
|
287
|
-
raise ValueError("Cant find question/answer in LLM response")
|
|
288
|
-
|
|
289
|
-
return output.get("query"), output.get("reference_answer")
|
|
290
|
-
|
|
291
298
|
def evaluate(self, test_data: pd.DataFrame) -> pd.DataFrame:
|
|
292
299
|
json_to_log_list = []
|
|
300
|
+
if {"question", "answer"} - set(test_data.columns):
|
|
301
|
+
raise KeyError(
|
|
302
|
+
f'Test data must contain "question" and "answer" columns. Columns in the provided test data: {list(test_data.columns)}'
|
|
303
|
+
)
|
|
293
304
|
questions = test_data.to_dict("records")
|
|
294
305
|
|
|
295
306
|
for i, item in enumerate(questions):
|
|
@@ -483,28 +494,12 @@ class EvaluateDocID(EvaluateBase):
|
|
|
483
494
|
df = pd.DataFrame(qa_data)
|
|
484
495
|
return df
|
|
485
496
|
|
|
486
|
-
def generate_question_answer(self, text: str) -> (str, str):
|
|
487
|
-
messages = [
|
|
488
|
-
{"role": "system", "content": GENERATE_QA_SYSTEM_PROMPT},
|
|
489
|
-
{"role": "user", "content": f"\n\nText:\n{text}\n\n"},
|
|
490
|
-
]
|
|
491
|
-
answer = self.llm_client.completion(messages, json_output=True)
|
|
492
|
-
|
|
493
|
-
# Sanitize the response by removing markdown code block formatting like ```json
|
|
494
|
-
sanitized_answer = sanitize_json_response(answer)
|
|
495
|
-
|
|
496
|
-
try:
|
|
497
|
-
output = json.loads(sanitized_answer)
|
|
498
|
-
except json.JSONDecodeError:
|
|
499
|
-
raise ValueError(f"Could not parse response from LLM: {answer}")
|
|
500
|
-
|
|
501
|
-
if "query" not in output or "reference_answer" not in output:
|
|
502
|
-
raise ValueError("Cant find question/answer in LLM response")
|
|
503
|
-
|
|
504
|
-
return output.get("query"), output.get("reference_answer")
|
|
505
|
-
|
|
506
497
|
def evaluate(self, test_data: pd.DataFrame) -> pd.DataFrame:
|
|
507
498
|
stats = []
|
|
499
|
+
if {"question", "doc_id"} - set(test_data.columns):
|
|
500
|
+
raise KeyError(
|
|
501
|
+
f'Test data must contain "question" and "doc_id" columns. Columns in the provided test data: {list(test_data.columns)}'
|
|
502
|
+
)
|
|
508
503
|
questions = test_data.to_dict("records")
|
|
509
504
|
|
|
510
505
|
for i, item in enumerate(questions):
|
|
@@ -43,7 +43,18 @@ class KnowledgeBaseQueryExecutor:
|
|
|
43
43
|
if isinstance(node, BinaryOperation):
|
|
44
44
|
if isinstance(node.args[0], Identifier):
|
|
45
45
|
parts = node.args[0].parts
|
|
46
|
+
|
|
47
|
+
# map chunk_content to content
|
|
48
|
+
if parts[0].lower() == "chunk_content":
|
|
49
|
+
parts[0] = self.content_column
|
|
50
|
+
|
|
46
51
|
if len(parts) == 1 and parts[0].lower() == self.content_column:
|
|
52
|
+
if "LIKE" in node.op.upper():
|
|
53
|
+
# remove '%'
|
|
54
|
+
arg = node.args[1]
|
|
55
|
+
if isinstance(arg, Constant) and isinstance(arg.value, str):
|
|
56
|
+
arg.value = arg.value.strip(" %")
|
|
57
|
+
|
|
47
58
|
return True
|
|
48
59
|
return False
|
|
49
60
|
|