MindsDB 25.4.1.0__py3-none-any.whl → 25.4.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MindsDB might be problematic. Click here for more details.
- mindsdb/__about__.py +1 -1
- mindsdb/api/executor/command_executor.py +91 -61
- mindsdb/api/executor/data_types/answer.py +9 -12
- mindsdb/api/executor/datahub/classes/response.py +11 -0
- mindsdb/api/executor/datahub/datanodes/datanode.py +4 -4
- mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +10 -11
- mindsdb/api/executor/datahub/datanodes/integration_datanode.py +22 -16
- mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +43 -1
- mindsdb/api/executor/datahub/datanodes/project_datanode.py +20 -20
- mindsdb/api/executor/planner/plan_join.py +2 -2
- mindsdb/api/executor/planner/query_plan.py +1 -0
- mindsdb/api/executor/planner/query_planner.py +86 -14
- mindsdb/api/executor/planner/steps.py +11 -2
- mindsdb/api/executor/sql_query/result_set.py +10 -7
- mindsdb/api/executor/sql_query/sql_query.py +69 -84
- mindsdb/api/executor/sql_query/steps/__init__.py +1 -0
- mindsdb/api/executor/sql_query/steps/delete_step.py +2 -3
- mindsdb/api/executor/sql_query/steps/fetch_dataframe.py +5 -3
- mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +288 -0
- mindsdb/api/executor/sql_query/steps/insert_step.py +2 -2
- mindsdb/api/executor/sql_query/steps/prepare_steps.py +2 -2
- mindsdb/api/executor/sql_query/steps/subselect_step.py +20 -8
- mindsdb/api/executor/sql_query/steps/update_step.py +4 -6
- mindsdb/api/http/namespaces/sql.py +4 -1
- mindsdb/api/mysql/mysql_proxy/data_types/mysql_packets/ok_packet.py +1 -1
- mindsdb/api/mysql/mysql_proxy/executor/mysql_executor.py +4 -27
- mindsdb/api/mysql/mysql_proxy/libs/constants/mysql.py +1 -0
- mindsdb/api/mysql/mysql_proxy/mysql_proxy.py +38 -37
- mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +23 -13
- mindsdb/integrations/handlers/langchain_embedding_handler/langchain_embedding_handler.py +17 -16
- mindsdb/integrations/handlers/langchain_handler/langchain_handler.py +1 -0
- mindsdb/integrations/handlers/mssql_handler/mssql_handler.py +1 -1
- mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +3 -2
- mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +4 -4
- mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +26 -16
- mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +36 -7
- mindsdb/integrations/handlers/redshift_handler/redshift_handler.py +1 -1
- mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +18 -11
- mindsdb/integrations/libs/llm/config.py +11 -1
- mindsdb/integrations/libs/llm/utils.py +12 -0
- mindsdb/integrations/libs/ml_handler_process/learn_process.py +1 -2
- mindsdb/integrations/libs/response.py +9 -4
- mindsdb/integrations/libs/vectordatabase_handler.py +17 -5
- mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +8 -98
- mindsdb/interfaces/agents/constants.py +12 -1
- mindsdb/interfaces/agents/langchain_agent.py +6 -0
- mindsdb/interfaces/database/log.py +8 -9
- mindsdb/interfaces/database/projects.py +1 -5
- mindsdb/interfaces/functions/controller.py +59 -17
- mindsdb/interfaces/functions/to_markdown.py +194 -0
- mindsdb/interfaces/jobs/jobs_controller.py +3 -3
- mindsdb/interfaces/knowledge_base/controller.py +223 -97
- mindsdb/interfaces/knowledge_base/preprocessing/document_preprocessor.py +3 -14
- mindsdb/interfaces/query_context/context_controller.py +224 -1
- mindsdb/interfaces/storage/db.py +23 -0
- mindsdb/migrations/versions/2025-03-21_fda503400e43_queries.py +45 -0
- mindsdb/utilities/context_executor.py +1 -1
- mindsdb/utilities/partitioning.py +35 -20
- {mindsdb-25.4.1.0.dist-info → mindsdb-25.4.2.1.dist-info}/METADATA +227 -224
- {mindsdb-25.4.1.0.dist-info → mindsdb-25.4.2.1.dist-info}/RECORD +63 -59
- {mindsdb-25.4.1.0.dist-info → mindsdb-25.4.2.1.dist-info}/WHEEL +0 -0
- {mindsdb-25.4.1.0.dist-info → mindsdb-25.4.2.1.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.4.1.0.dist-info → mindsdb-25.4.2.1.dist-info}/top_level.txt +0 -0
|
@@ -27,6 +27,8 @@ from mindsdb.integrations.libs.vectordatabase_handler import (
|
|
|
27
27
|
)
|
|
28
28
|
from mindsdb.integrations.utilities.rag.rag_pipeline_builder import RAG
|
|
29
29
|
from mindsdb.integrations.utilities.rag.config_loader import load_rag_config
|
|
30
|
+
from mindsdb.integrations.utilities.handler_utils import get_api_key
|
|
31
|
+
from mindsdb.integrations.handlers.langchain_embedding_handler.langchain_embedding_handler import construct_model_from_args, row_to_document
|
|
30
32
|
|
|
31
33
|
from mindsdb.interfaces.agents.constants import DEFAULT_EMBEDDINGS_MODEL_CLASS
|
|
32
34
|
from mindsdb.interfaces.agents.langchain_agent import create_chat_model, get_llm_provider
|
|
@@ -35,6 +37,8 @@ from mindsdb.interfaces.knowledge_base.preprocessing.models import Preprocessing
|
|
|
35
37
|
from mindsdb.interfaces.knowledge_base.preprocessing.document_preprocessor import PreprocessorFactory
|
|
36
38
|
from mindsdb.interfaces.model.functions import PredictorRecordNotFound
|
|
37
39
|
from mindsdb.utilities.exception import EntityExistsError, EntityNotExistsError
|
|
40
|
+
from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator
|
|
41
|
+
from mindsdb.utilities.context import context as ctx
|
|
38
42
|
|
|
39
43
|
from mindsdb.api.executor.command_executor import ExecuteCommands
|
|
40
44
|
from mindsdb.utilities import log
|
|
@@ -49,6 +53,42 @@ KB_TO_VECTORDB_COLUMNS = {
|
|
|
49
53
|
}
|
|
50
54
|
|
|
51
55
|
|
|
56
|
+
def get_embedding_model_from_params(embedding_model_params: dict):
    """
    Create embedding model from parameters.

    The input dict is deep-copied first, so the caller's parameters are never modified.

    Args:
        embedding_model_params: embedding model settings. ``provider`` is required.
            ``api_key`` and ``model_name`` are optional; every remaining key is
            forwarded unchanged to ``construct_model_from_args``.

    Returns:
        The object returned by ``construct_model_from_args`` for the prepared arguments.

    Raises:
        ValueError: if ``provider`` is missing, empty or not a string.
    """
    params_copy = copy.deepcopy(embedding_model_params)

    # Validate before calling .lower(): a missing provider used to surface as an
    # opaque AttributeError on None instead of a clear parameter error.
    provider = params_copy.pop('provider', None)
    if not isinstance(provider, str) or not provider:
        raise ValueError("'provider' is required for the embedding model and must be a non-empty string.")
    provider = provider.lower()

    api_key = get_api_key(provider, params_copy, strict=False) or params_copy.get('api_key')
    # Underscores are replaced because the provider name ultimately gets mapped to a class name.
    # This is mostly to support Azure OpenAI (azure_openai); the mapped class name is 'AzureOpenAIEmbeddings'.
    params_copy['class'] = provider.replace('_', '')
    if provider == 'azure_openai':
        # Azure OpenAI expects the api_key to be passed as 'openai_api_key'.
        params_copy['openai_api_key'] = api_key
    else:
        params_copy[f"{provider}_api_key"] = api_key
    params_copy.pop('api_key', None)
    # 'model_name' wins when given; otherwise keep an explicitly passed 'model'
    # instead of overwriting it with None.
    params_copy['model'] = params_copy.pop('model_name', params_copy.get('model'))

    return construct_model_from_args(params_copy)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def get_reranking_model_from_params(reranking_model_params: dict):
    """
    Create reranking model from parameters.

    The input dict is deep-copied first, so the caller's parameters are never modified.

    Args:
        reranking_model_params: reranking model settings. ``provider`` defaults to
            ``"openai"`` when it is absent or ``None``. ``api_key`` and ``model_name``
            are optional; every remaining key is passed to ``LLMReranker``.

    Returns:
        An ``LLMReranker`` built from the prepared arguments.

    Raises:
        ValueError: if a provider other than OpenAI is requested.
    """
    params_copy = copy.deepcopy(reranking_model_params)

    # An explicit `provider: None` (or a non-string value) used to crash on .lower();
    # treat None like "not given" and let anything else fail the provider check below.
    provider = params_copy.pop('provider', None)
    provider = "openai" if provider is None else str(provider).lower()
    if provider != 'openai':
        raise ValueError("Only OpenAI provider is supported for the reranking model.")

    params_copy[f"{provider}_api_key"] = get_api_key(provider, params_copy, strict=False) or params_copy.get('api_key')
    params_copy.pop('api_key', None)
    # 'model_name' wins when given; otherwise keep an explicitly passed 'model'
    # instead of overwriting it with None.
    params_copy['model'] = params_copy.pop('model_name', params_copy.get('model'))

    return LLMReranker(**params_copy)
|
|
90
|
+
|
|
91
|
+
|
|
52
92
|
class KnowledgeBaseTable:
|
|
53
93
|
"""
|
|
54
94
|
Knowledge base table interface
|
|
@@ -85,88 +125,125 @@ class KnowledgeBaseTable:
|
|
|
85
125
|
"""
|
|
86
126
|
logger.debug(f"Processing select query: {query}")
|
|
87
127
|
|
|
88
|
-
#
|
|
89
|
-
|
|
90
|
-
|
|
128
|
+
# Extract the content query text for potential reranking
|
|
129
|
+
|
|
130
|
+
db_handler = self.get_vector_db()
|
|
91
131
|
|
|
132
|
+
logger.debug("Replaced content with embeddings in where clause")
|
|
92
133
|
# set table name
|
|
93
134
|
query.from_table = Identifier(parts=[self._kb.vector_database_table])
|
|
94
135
|
logger.debug(f"Set table name to: {self._kb.vector_database_table}")
|
|
95
136
|
|
|
96
|
-
|
|
97
|
-
targets = []
|
|
137
|
+
requested_kb_columns = []
|
|
98
138
|
for target in query.targets:
|
|
99
139
|
if isinstance(target, Star):
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
140
|
+
requested_kb_columns = None
|
|
141
|
+
break
|
|
142
|
+
else:
|
|
143
|
+
requested_kb_columns.append(target.parts[-1].lower())
|
|
144
|
+
|
|
145
|
+
query.targets = [
|
|
146
|
+
Identifier(TableField.ID.value),
|
|
147
|
+
Identifier(TableField.CONTENT.value),
|
|
148
|
+
Identifier(TableField.METADATA.value),
|
|
149
|
+
Identifier(TableField.DISTANCE.value),
|
|
150
|
+
]
|
|
109
151
|
|
|
110
152
|
# Get response from vector db
|
|
111
|
-
db_handler = self.get_vector_db()
|
|
112
153
|
logger.debug(f"Using vector db handler: {type(db_handler)}")
|
|
113
154
|
|
|
114
|
-
conditions
|
|
155
|
+
# extract values from conditions and prepare for vectordb
|
|
156
|
+
conditions = []
|
|
157
|
+
query_text = None
|
|
158
|
+
reranking_threshold = None
|
|
159
|
+
query_conditions = db_handler.extract_conditions(query.where)
|
|
160
|
+
if query_conditions is not None:
|
|
161
|
+
for item in query_conditions:
|
|
162
|
+
if item.column == "reranking_threshold" and item.op.value == "=":
|
|
163
|
+
try:
|
|
164
|
+
reranking_threshold = float(item.value)
|
|
165
|
+
# Validate range: must be between 0 and 1
|
|
166
|
+
if not (0 <= reranking_threshold <= 1):
|
|
167
|
+
raise ValueError(f"reranking_threshold must be between 0 and 1, got: {reranking_threshold}")
|
|
168
|
+
logger.debug(f"Found reranking_threshold in query: {reranking_threshold}")
|
|
169
|
+
except (ValueError, TypeError) as e:
|
|
170
|
+
error_msg = f"Invalid reranking_threshold value: {item.value}. {str(e)}"
|
|
171
|
+
logger.error(error_msg)
|
|
172
|
+
raise ValueError(error_msg)
|
|
173
|
+
elif item.column == TableField.CONTENT.value:
|
|
174
|
+
query_text = item.value
|
|
175
|
+
|
|
176
|
+
# replace content with embeddings
|
|
177
|
+
conditions.append(FilterCondition(
|
|
178
|
+
column=TableField.EMBEDDINGS.value,
|
|
179
|
+
value=self._content_to_embeddings(item.value),
|
|
180
|
+
op=FilterOperator.EQUAL,
|
|
181
|
+
))
|
|
182
|
+
else:
|
|
183
|
+
conditions.append(item)
|
|
184
|
+
|
|
185
|
+
logger.debug(f"Extracted query text: {query_text}")
|
|
186
|
+
|
|
115
187
|
self.addapt_conditions_columns(conditions)
|
|
116
188
|
df = db_handler.dispatch_select(query, conditions)
|
|
189
|
+
df = self.addapt_result_columns(df)
|
|
117
190
|
|
|
118
|
-
|
|
191
|
+
logger.debug(f"Query returned {len(df)} rows")
|
|
192
|
+
logger.debug(f"Columns in response: {df.columns.tolist()}")
|
|
193
|
+
# Check if we have a rerank_model configured in KB params
|
|
119
194
|
|
|
120
|
-
|
|
121
|
-
logger.debug(f"Columns in response: {df.columns.tolist()}")
|
|
122
|
-
# Log a sample of IDs to help diagnose issues
|
|
123
|
-
if not df.empty:
|
|
124
|
-
logger.debug(f"Sample of IDs in response: {df['id'].head().tolist()}")
|
|
125
|
-
else:
|
|
126
|
-
logger.warning("Query returned no data")
|
|
195
|
+
df = self.add_relevance(df, query_text, reranking_threshold)
|
|
127
196
|
|
|
128
|
-
|
|
129
|
-
if
|
|
197
|
+
# filter by targets
|
|
198
|
+
if requested_kb_columns is not None:
|
|
199
|
+
df = df[requested_kb_columns]
|
|
200
|
+
return df
|
|
201
|
+
|
|
202
|
+
def add_relevance(self, df, query_text, reranking_threshold=None):
    """
    Add a relevance column to the search result and sort by it (best first).

    Relevance is taken from the reranking model when the knowledge base has a
    ``reranking_model`` configured, a query text is available and ``df`` is not
    empty. Otherwise, or when reranking fails, it is derived from the vector
    distance as ``1 / (1 + distance)``. When neither source is available the
    relevance column is filled with ``None``.

    Args:
        df: result dataframe; the reranking path reads its ``chunk_content`` column,
            the distance path reads its ``distance`` column.
        query_text: text the documents are scored against; may be ``None``.
        reranking_threshold: optional minimum relevance. It overrides the reranker's
            ``filtering_threshold`` for this call, or filters distance-based relevance.

    Returns:
        The dataframe with the relevance column added, filtered by the threshold
        where one applies, and sorted by relevance in descending order.
    """
    relevance_column = TableField.RELEVANCE.value

    reranking_model_params = self._kb.params.get("reranking_model")
    if reranking_model_params and query_text and len(df) > 0:
        # Use reranker for relevance score
        try:
            # Work on a copy: the threshold is a per-query setting and must not be
            # written back into the knowledge base's stored params.
            reranking_model_params = dict(reranking_model_params)
            # Apply custom filtering threshold if provided
            if reranking_threshold is not None:
                reranking_model_params["filtering_threshold"] = reranking_threshold
                logger.info(f"Using custom filtering threshold: {reranking_threshold}")

            # Keep credentials out of the log output.
            loggable_params = {
                key: value for key, value in reranking_model_params.items()
                if 'api_key' not in str(key)
            }
            logger.info(f"Using knowledge reranking model from params: {loggable_params}")

            reranker = get_reranking_model_from_params(reranking_model_params)
            # Get documents to rerank
            documents = df['chunk_content'].tolist()
            scores = reranker.get_scores(query_text, documents)
            # Add scores as the relevance column
            df[relevance_column] = scores

            # Filter by threshold
            scores_array = np.array(scores)
            df = df[scores_array > reranker.filtering_threshold]
            logger.debug(f"Applied reranking with params: {loggable_params}")
        except Exception as e:
            logger.error(f"Error during reranking: {str(e)}")
            # Fallback to distance-based relevance
            if 'distance' in df.columns:
                df[relevance_column] = 1 / (1 + df['distance'])
            else:
                logger.info("No distance or reranker available")
                # The column must exist, otherwise the sort below raises KeyError.
                df[relevance_column] = None

    elif 'distance' in df.columns:
        # Calculate relevance from distance
        logger.info("Calculating relevance from vector distance")
        df[relevance_column] = 1 / (1 + df['distance'])
        if reranking_threshold is not None:
            df = df[df[relevance_column] > reranking_threshold]

    else:
        df[relevance_column] = None
        df['distance'] = None
    # Sort by relevance
    df = df.sort_values(by=relevance_column, ascending=False)
    return df
|
|
171
248
|
|
|
172
249
|
def addapt_conditions_columns(self, conditions):
|
|
@@ -186,7 +263,9 @@ class KnowledgeBaseTable:
|
|
|
186
263
|
|
|
187
264
|
columns = list(df.columns)
|
|
188
265
|
# update id, get from metadata
|
|
189
|
-
df[TableField.ID.value] = df[TableField.METADATA.value].apply(
|
|
266
|
+
df[TableField.ID.value] = df[TableField.METADATA.value].apply(
|
|
267
|
+
lambda m: None if m is None else m.get('original_row_id')
|
|
268
|
+
)
|
|
190
269
|
|
|
191
270
|
# id on first place
|
|
192
271
|
return df[[TableField.ID.value] + columns]
|
|
@@ -276,7 +355,9 @@ class KnowledgeBaseTable:
|
|
|
276
355
|
|
|
277
356
|
# send to vectordb
|
|
278
357
|
db_handler = self.get_vector_db()
|
|
279
|
-
db_handler.
|
|
358
|
+
conditions = db_handler.extract_conditions(query.where)
|
|
359
|
+
self.addapt_conditions_columns(conditions)
|
|
360
|
+
db_handler.dispatch_update(query, conditions)
|
|
280
361
|
|
|
281
362
|
def delete_query(self, query: Delete):
|
|
282
363
|
"""
|
|
@@ -332,6 +413,16 @@ class KnowledgeBaseTable:
|
|
|
332
413
|
if df.empty:
|
|
333
414
|
return
|
|
334
415
|
|
|
416
|
+
try:
|
|
417
|
+
run_query_id = ctx.run_query_id
|
|
418
|
+
# Link current KB to running query (where KB is used to insert data)
|
|
419
|
+
if run_query_id is not None:
|
|
420
|
+
self._kb.query_id = run_query_id
|
|
421
|
+
db.session.commit()
|
|
422
|
+
|
|
423
|
+
except AttributeError:
|
|
424
|
+
...
|
|
425
|
+
|
|
335
426
|
# First adapt column names to identify content and metadata columns
|
|
336
427
|
adapted_df = self._adapt_column_names(df)
|
|
337
428
|
content_columns = self._kb.params.get('content_columns', [TableField.CONTENT.value])
|
|
@@ -536,36 +627,48 @@ class KnowledgeBaseTable:
|
|
|
536
627
|
if df.empty:
|
|
537
628
|
return pd.DataFrame([], columns=[TableField.EMBEDDINGS.value])
|
|
538
629
|
|
|
630
|
+
# keep only content
|
|
631
|
+
df = df[[TableField.CONTENT.value]]
|
|
632
|
+
|
|
539
633
|
model_id = self._kb.embedding_model_id
|
|
540
|
-
|
|
541
|
-
|
|
634
|
+
if model_id:
|
|
635
|
+
# get the input columns
|
|
636
|
+
model_rec = db.session.query(db.Predictor).filter_by(id=model_id).first()
|
|
542
637
|
|
|
543
|
-
|
|
544
|
-
|
|
638
|
+
assert model_rec is not None, f"Model not found: {model_id}"
|
|
639
|
+
model_project = db.session.query(db.Project).filter_by(id=model_rec.project_id).first()
|
|
545
640
|
|
|
546
|
-
|
|
641
|
+
project_datanode = self.session.datahub.get(model_project.name)
|
|
547
642
|
|
|
548
|
-
|
|
549
|
-
|
|
643
|
+
model_using = model_rec.learn_args.get('using', {})
|
|
644
|
+
input_col = model_using.get('question_column')
|
|
645
|
+
if input_col is None:
|
|
646
|
+
input_col = model_using.get('input_column')
|
|
550
647
|
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
if input_col is None:
|
|
554
|
-
input_col = model_using.get('input_column')
|
|
648
|
+
if input_col is not None and input_col != TableField.CONTENT.value:
|
|
649
|
+
df = df.rename(columns={TableField.CONTENT.value: input_col})
|
|
555
650
|
|
|
556
|
-
|
|
557
|
-
|
|
651
|
+
df_out = project_datanode.predict(
|
|
652
|
+
model_name=model_rec.name,
|
|
653
|
+
df=df,
|
|
654
|
+
params=self.model_params
|
|
655
|
+
)
|
|
558
656
|
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
657
|
+
target = model_rec.to_predict[0]
|
|
658
|
+
if target != TableField.EMBEDDINGS.value:
|
|
659
|
+
# adapt output for vectordb
|
|
660
|
+
df_out = df_out.rename(columns={target: TableField.EMBEDDINGS.value})
|
|
661
|
+
|
|
662
|
+
elif self._kb.params.get('embedding_model'):
|
|
663
|
+
embedding_model = get_embedding_model_from_params(self._kb.params.get('embedding_model'))
|
|
664
|
+
|
|
665
|
+
df_texts = df.apply(row_to_document, axis=1)
|
|
666
|
+
embeddings = embedding_model.embed_documents(df_texts.tolist())
|
|
667
|
+
df_out = df.copy().assign(**{TableField.EMBEDDINGS.value: embeddings})
|
|
668
|
+
|
|
669
|
+
else:
|
|
670
|
+
raise ValueError("No embedding model found for the knowledge base.")
|
|
564
671
|
|
|
565
|
-
target = model_rec.to_predict[0]
|
|
566
|
-
if target != TableField.EMBEDDINGS.value:
|
|
567
|
-
# adapt output for vectordb
|
|
568
|
-
df_out = df_out.rename(columns={target: TableField.EMBEDDINGS.value})
|
|
569
672
|
df_out = df_out[[TableField.EMBEDDINGS.value]]
|
|
570
673
|
|
|
571
674
|
return df_out
|
|
@@ -599,9 +702,11 @@ class KnowledgeBaseTable:
|
|
|
599
702
|
# Extract embedding model args from knowledge base table
|
|
600
703
|
embedding_args = self._kb.embedding_model.learn_args.get('using', {})
|
|
601
704
|
# Construct the embedding model directly
|
|
602
|
-
from mindsdb.integrations.handlers.langchain_embedding_handler.langchain_embedding_handler import construct_model_from_args
|
|
603
705
|
embeddings_model = construct_model_from_args(embedding_args)
|
|
604
706
|
logger.debug(f"Using knowledge base embedding model with args: {embedding_args}")
|
|
707
|
+
elif self._kb.params.get('embedding_model'):
|
|
708
|
+
embeddings_model = get_embedding_model_from_params(self._kb.params['embedding_model'])
|
|
709
|
+
logger.debug(f"Using knowledge base embedding model from params: {self._kb.params['embedding_model']}")
|
|
605
710
|
else:
|
|
606
711
|
embeddings_model = DEFAULT_EMBEDDINGS_MODEL_CLASS()
|
|
607
712
|
logger.debug("Using default embedding model as knowledge base has no embedding model")
|
|
@@ -747,26 +852,46 @@ class KnowledgeBaseController:
|
|
|
747
852
|
return kb
|
|
748
853
|
raise EntityExistsError("Knowledge base already exists", name)
|
|
749
854
|
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
else:
|
|
755
|
-
# get embedding model from input
|
|
855
|
+
embedding_model_params = params.get('embedding_model', None)
|
|
856
|
+
reranking_model_params = params.get('reranking_model', None)
|
|
857
|
+
|
|
858
|
+
if embedding_model:
|
|
756
859
|
model_name = embedding_model.parts[-1]
|
|
757
860
|
|
|
861
|
+
elif embedding_model_params:
|
|
862
|
+
# Get embedding model from params.
|
|
863
|
+
# This is called here to check validity of the parameters.
|
|
864
|
+
get_embedding_model_from_params(
|
|
865
|
+
embedding_model_params
|
|
866
|
+
)
|
|
867
|
+
|
|
868
|
+
else:
|
|
869
|
+
model_name = self._get_default_embedding_model(
|
|
870
|
+
project.name,
|
|
871
|
+
params=params
|
|
872
|
+
)
|
|
873
|
+
params['default_embedding_model'] = model_name
|
|
874
|
+
|
|
875
|
+
model_project = None
|
|
758
876
|
if embedding_model is not None and len(embedding_model.parts) > 1:
|
|
759
877
|
# model project is set
|
|
760
878
|
model_project = self.session.database_controller.get_project(embedding_model.parts[-2])
|
|
761
|
-
|
|
879
|
+
elif not embedding_model_params:
|
|
762
880
|
model_project = project
|
|
763
881
|
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
882
|
+
embedding_model_id = None
|
|
883
|
+
if model_project:
|
|
884
|
+
model = self.session.model_controller.get_model(
|
|
885
|
+
name=model_name,
|
|
886
|
+
project_name=model_project.name
|
|
887
|
+
)
|
|
888
|
+
model_record = db.Predictor.query.get(model['id'])
|
|
889
|
+
embedding_model_id = model_record.id
|
|
890
|
+
|
|
891
|
+
if reranking_model_params:
|
|
892
|
+
# Get reranking model from params.
|
|
893
|
+
# This is called here to check validity of the parameters.
|
|
894
|
+
get_reranking_model_from_params(reranking_model_params)
|
|
770
895
|
|
|
771
896
|
# search for the vector database table
|
|
772
897
|
if storage is None:
|
|
@@ -988,6 +1113,7 @@ class KnowledgeBaseController:
|
|
|
988
1113
|
'embedding_model': embedding_model.name if embedding_model is not None else None,
|
|
989
1114
|
'vector_database': None if vector_database is None else vector_database.name,
|
|
990
1115
|
'vector_database_table': record.vector_database_table,
|
|
1116
|
+
'query_id': record.query_id,
|
|
991
1117
|
'params': record.params
|
|
992
1118
|
})
|
|
993
1119
|
|
|
@@ -92,9 +92,7 @@ class DocumentPreprocessor:
|
|
|
92
92
|
|
|
93
93
|
def _generate_chunk_id(
|
|
94
94
|
self,
|
|
95
|
-
content: str,
|
|
96
95
|
chunk_index: Optional[int] = None,
|
|
97
|
-
content_column: str = None,
|
|
98
96
|
provided_id: str = None,
|
|
99
97
|
) -> str:
|
|
100
98
|
"""Generate deterministic ID for a chunk"""
|
|
@@ -262,15 +260,8 @@ Please give a short succinct context to situate this chunk within the overall do
|
|
|
262
260
|
if doc.metadata:
|
|
263
261
|
metadata.update(doc.metadata)
|
|
264
262
|
|
|
265
|
-
# Pass through doc.id and content_column
|
|
266
|
-
content_column = (
|
|
267
|
-
doc.metadata.get("content_column") if doc.metadata else None
|
|
268
|
-
)
|
|
269
263
|
chunk_id = self._generate_chunk_id(
|
|
270
|
-
|
|
271
|
-
chunk_index,
|
|
272
|
-
content_column=content_column,
|
|
273
|
-
provided_id=doc.id,
|
|
264
|
+
chunk_index=chunk_index, provided_id=doc.id
|
|
274
265
|
)
|
|
275
266
|
processed_chunks.append(
|
|
276
267
|
ProcessedChunk(
|
|
@@ -335,7 +326,7 @@ class TextChunkingPreprocessor(DocumentPreprocessor):
|
|
|
335
326
|
|
|
336
327
|
# Pass through doc.id and content_column
|
|
337
328
|
id = self._generate_chunk_id(
|
|
338
|
-
|
|
329
|
+
chunk_index=0, provided_id=doc.id
|
|
339
330
|
)
|
|
340
331
|
processed_chunks.append(
|
|
341
332
|
ProcessedChunk(
|
|
@@ -358,9 +349,7 @@ class TextChunkingPreprocessor(DocumentPreprocessor):
|
|
|
358
349
|
|
|
359
350
|
# Pass through doc.id and content_column
|
|
360
351
|
chunk_id = self._generate_chunk_id(
|
|
361
|
-
|
|
362
|
-
i,
|
|
363
|
-
content_column=content_column,
|
|
352
|
+
chunk_index=i,
|
|
364
353
|
provided_id=doc.id,
|
|
365
354
|
)
|
|
366
355
|
processed_chunks.append(
|