MindsDB 25.3.4.2__py3-none-any.whl → 25.4.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Potentially problematic release: this version of MindsDB might be problematic.
- mindsdb/__about__.py +1 -1
- mindsdb/__main__.py +21 -4
- mindsdb/api/executor/command_executor.py +62 -61
- mindsdb/api/executor/data_types/answer.py +9 -12
- mindsdb/api/executor/datahub/classes/response.py +11 -0
- mindsdb/api/executor/datahub/datanodes/datanode.py +4 -4
- mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +7 -9
- mindsdb/api/executor/datahub/datanodes/integration_datanode.py +22 -16
- mindsdb/api/executor/datahub/datanodes/project_datanode.py +20 -20
- mindsdb/api/executor/planner/plan_join.py +1 -1
- mindsdb/api/executor/planner/steps.py +2 -1
- mindsdb/api/executor/sql_query/result_set.py +10 -7
- mindsdb/api/executor/sql_query/sql_query.py +36 -82
- mindsdb/api/executor/sql_query/steps/delete_step.py +2 -3
- mindsdb/api/executor/sql_query/steps/fetch_dataframe.py +5 -3
- mindsdb/api/executor/sql_query/steps/insert_step.py +2 -2
- mindsdb/api/executor/sql_query/steps/prepare_steps.py +2 -2
- mindsdb/api/executor/sql_query/steps/subselect_step.py +20 -8
- mindsdb/api/executor/sql_query/steps/update_step.py +4 -6
- mindsdb/api/http/namespaces/sql.py +4 -1
- mindsdb/api/mcp/__init__.py +0 -0
- mindsdb/api/mcp/start.py +152 -0
- mindsdb/api/mysql/mysql_proxy/data_types/mysql_packets/ok_packet.py +1 -1
- mindsdb/api/mysql/mysql_proxy/executor/mysql_executor.py +4 -27
- mindsdb/api/mysql/mysql_proxy/libs/constants/mysql.py +1 -0
- mindsdb/api/mysql/mysql_proxy/mysql_proxy.py +38 -37
- mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +23 -13
- mindsdb/integrations/handlers/mssql_handler/mssql_handler.py +1 -1
- mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +3 -2
- mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +4 -4
- mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +19 -5
- mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +9 -4
- mindsdb/integrations/handlers/redshift_handler/redshift_handler.py +1 -1
- mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +18 -11
- mindsdb/integrations/libs/ml_handler_process/learn_process.py +1 -2
- mindsdb/integrations/libs/response.py +9 -4
- mindsdb/integrations/libs/vectordatabase_handler.py +37 -25
- mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +35 -15
- mindsdb/interfaces/database/log.py +8 -9
- mindsdb/interfaces/database/projects.py +16 -5
- mindsdb/interfaces/functions/controller.py +59 -17
- mindsdb/interfaces/functions/to_markdown.py +194 -0
- mindsdb/interfaces/jobs/jobs_controller.py +3 -3
- mindsdb/interfaces/knowledge_base/controller.py +143 -26
- mindsdb/interfaces/knowledge_base/preprocessing/document_preprocessor.py +3 -14
- mindsdb/interfaces/query_context/context_controller.py +3 -1
- mindsdb/utilities/config.py +8 -0
- mindsdb/utilities/starters.py +7 -0
- {mindsdb-25.3.4.2.dist-info → mindsdb-25.4.2.0.dist-info}/METADATA +233 -231
- {mindsdb-25.3.4.2.dist-info → mindsdb-25.4.2.0.dist-info}/RECORD +53 -49
- {mindsdb-25.3.4.2.dist-info → mindsdb-25.4.2.0.dist-info}/WHEEL +0 -0
- {mindsdb-25.3.4.2.dist-info → mindsdb-25.4.2.0.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.3.4.2.dist-info → mindsdb-25.4.2.0.dist-info}/top_level.txt +0 -0

--- a/mindsdb/api/mysql/mysql_proxy/mysql_proxy.py
+++ b/mindsdb/api/mysql/mysql_proxy/mysql_proxy.py
@@ -21,7 +21,8 @@ import sys
 import tempfile
 import traceback
 from functools import partial
-from typing import Dict, List
+from typing import Dict, List, Optional
+from dataclasses import dataclass

 from numpy import dtype as np_dtype
 from pandas.api import types as pd_types
@@ -71,6 +72,7 @@ from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import (
     TYPES,
     getConstName,
 )
+from mindsdb.api.executor.data_types.answer import ExecuteAnswer
 from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE
 from mindsdb.api.mysql.mysql_proxy.utilities import (
     ErWrongCharset,
@@ -93,24 +95,16 @@ def empty_fn():
     pass


+@dataclass
 class SQLAnswer:
-
-
-
-
-
-
-
-
-        error_message: str = None,
-    ):
-        self.resp_type = resp_type
-        self.columns = columns
-        self.data = data
-        self.status = status
-        self.state_track = state_track
-        self.error_code = error_code
-        self.error_message = error_message
+    resp_type: RESPONSE_TYPE = RESPONSE_TYPE.OK
+    columns: Optional[List[Dict]] = None
+    data: Optional[List[Dict]] = None  # resultSet ?
+    status: Optional[int] = None
+    state_track: Optional[List[List]] = None
+    error_code: Optional[int] = None
+    error_message: Optional[str] = None
+    affected_rows: Optional[int] = None

     @property
     def type(self):
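
The hunk above swaps a hand-written __init__ for @dataclass field declarations, which generate an equivalent keyword-argument constructor. A minimal sketch of the same pattern, using a stand-in Answer class with a plain string in place of the real RESPONSE_TYPE enum:

    from dataclasses import dataclass
    from typing import Dict, List, Optional

    @dataclass
    class Answer:  # stand-in for SQLAnswer
        resp_type: str = "ok"  # the real class defaults to RESPONSE_TYPE.OK
        columns: Optional[List[Dict]] = None
        error_message: Optional[str] = None
        affected_rows: Optional[int] = None

    a = Answer(affected_rows=3)          # generated __init__ accepts keywords
    print(a.resp_type, a.affected_rows)  # -> ok 3

Because every field has a default, existing call sites keep working while the new affected_rows field stays optional.
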
@@ -333,7 +327,7 @@ class MysqlProxy(SocketServer.BaseRequestHandler):
             packages.append(self.last_packet())
             self.send_package_group(packages)
         elif answer.type == RESPONSE_TYPE.OK:
-            self.packet(OkPacket, state_track=answer.state_track).send()
+            self.packet(OkPacket, state_track=answer.state_track, affected_rows=answer.affected_rows).send()
         elif answer.type == RESPONSE_TYPE.ERROR:
             self.packet(
                 ErrPacket, err_code=answer.error_code, msg=answer.error_message
@@ -546,21 +540,23 @@ class MysqlProxy(SocketServer.BaseRequestHandler):
     @profiler.profile()
     def process_query(self, sql):
         executor = Executor(session=self.session, sqlserver=self)
-
         executor.query_execute(sql)
+        executor_answer = executor.executor_answer

-        if
+        if executor_answer.data is None:
             resp = SQLAnswer(
                 resp_type=RESPONSE_TYPE.OK,
-                state_track=
+                state_track=executor_answer.state_track,
+                affected_rows=executor_answer.affected_rows
             )
         else:
             resp = SQLAnswer(
                 resp_type=RESPONSE_TYPE.TABLE,
-                state_track=
-                columns=self.to_mysql_columns(
-                data=
+                state_track=executor_answer.state_track,
+                columns=self.to_mysql_columns(executor_answer.data.columns),
+                data=executor_answer.data,
                 status=executor.server_status,
+                affected_rows=executor_answer.affected_rows
             )

         # Increment the counter and include metadata in attributes
@@ -604,18 +600,20 @@ class MysqlProxy(SocketServer.BaseRequestHandler):

     def answer_stmt_execute(self, stmt_id, parameters):
         prepared_stmt = self.session.prepared_stmts[stmt_id]
-        executor = prepared_stmt["statement"]
+        executor: Executor = prepared_stmt["statement"]

         executor.stmt_execute(parameters)

-
+        executor_answer: ExecuteAnswer = executor.executor_answer
+
+        if executor_answer.data is None:
             resp = SQLAnswer(
-                resp_type=RESPONSE_TYPE.OK, state_track=
+                resp_type=RESPONSE_TYPE.OK, state_track=executor_answer.state_track
             )
             return self.send_query_answer(resp)

         # TODO prepared_stmt['type'] == 'lock' is not used but it works
-        columns_def = self.to_mysql_columns(
+        columns_def = self.to_mysql_columns(executor_answer.data.columns)
         packages = [self.packet(ColumnCountPacket, count=len(columns_def))]

         packages.extend(self._get_column_defenition_packets(columns_def))
@@ -624,14 +622,14 @@ class MysqlProxy(SocketServer.BaseRequestHandler):
         packages.append(self.packet(EofPacket, status=0x0062))

         # send all
-        for row in
+        for row in executor_answer.data.to_lists():
             packages.append(
                 self.packet(BinaryResultsetRowPacket, data=row, columns=columns_def)
             )

         server_status = executor.server_status or 0x0002
         packages.append(self.last_packet(status=server_status))
-        prepared_stmt["fetched"] += len(
+        prepared_stmt["fetched"] += len(executor_answer.data)

         return self.send_package_group(packages)
@@ -639,23 +637,24 @@
         prepared_stmt = self.session.prepared_stmts[stmt_id]
         executor = prepared_stmt["statement"]
         fetched = prepared_stmt["fetched"]
+        executor_answer: ExecuteAnswer = executor.executor_answer

-        if
+        if executor_answer.data is None:
             resp = SQLAnswer(
-                resp_type=RESPONSE_TYPE.OK, state_track=
+                resp_type=RESPONSE_TYPE.OK, state_track=executor_answer.state_track
             )
             return self.send_query_answer(resp)

         packages = []
-        columns = self.to_mysql_columns(
-        for row in
+        columns = self.to_mysql_columns(executor_answer.data.columns)
+        for row in executor_answer.data[fetched:limit].to_lists():
             packages.append(
                 self.packet(BinaryResultsetRowPacket, data=row, columns=columns)
             )

-        prepared_stmt["fetched"] += len(
+        prepared_stmt["fetched"] += len(executor_answer.data[fetched:limit])

-        if len(
+        if len(executor_answer.data) <= limit + fetched:
             status = sum(
                 [
                     SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT,
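
The fetch handler above pages through executor_answer.data with a slice and a running "fetched" counter (the proxy itself slices data[fetched:limit], treating limit as an end index). A rough sketch of the bookkeeping with a plain list instead of the real ResultSet, using a per-batch size for readability:

    rows = list(range(10))  # stand-in for executor_answer.data
    fetched, batch_size = 0, 4

    batch = rows[fetched:fetched + batch_size]   # rows for one COM_STMT_FETCH
    fetched += len(batch)
    is_last = len(rows) <= batch_size + fetched  # mirrors the len(...) <= limit + fetched test
    print(batch, fetched, is_last)               # -> [0, 1, 2, 3] 4 False
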
@@ -772,6 +771,8 @@
         elif p.type.value == COMMANDS.COM_FIELD_LIST:
             # this command is deprecated, but console client still use it.
             response = SQLAnswer(RESPONSE_TYPE.OK)
+        elif p.type.value == COMMANDS.COM_STMT_RESET:
+            response = SQLAnswer(RESPONSE_TYPE.OK)
         else:
             logger.warning("Command has no specific handler, return OK msg")
             logger.debug(str(p))

--- a/mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py
+++ b/mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py
@@ -1,5 +1,6 @@
 import ast
 import sys
+import os
 from typing import Dict, List, Optional, Union
 import hashlib

@@ -67,6 +68,8 @@ class ChromaDBHandler(VectorStoreHandler):
             "persist_directory": self.persist_directory,
         }

+        self._use_handler_storage = False
+
         self.connect()

     def validate_connection_parameters(self, name, **kwargs):
@@ -79,11 +82,15 @@ class ChromaDBHandler(VectorStoreHandler):

         config = ChromaHandlerConfig(**_config)

-        if config.persist_directory
-
-
-
-
+        if config.persist_directory:
+            if os.path.isabs(config.persist_directory):
+                self.persist_directory = config.persist_directory
+            elif not self.handler_storage.is_temporal:
+                # get full persistence directory from handler storage
+                self.persist_directory = self.handler_storage.folder_get(
+                    config.persist_directory
+                )
+                self._use_handler_storage = True

         return config
@@ -105,7 +112,7 @@ class ChromaDBHandler(VectorStoreHandler):

     def _sync(self):
         """Sync the database to disk if using persistent storage"""
-        if self.persist_directory:
+        if self.persist_directory and self._use_handler_storage:
             self.handler_storage.folder_sync(self.persist_directory)

     def __del__(self):
@@ -162,6 +169,8 @@ class ChromaDBHandler(VectorStoreHandler):
             FilterOperator.LESS_THAN_OR_EQUAL: "$lte",
             FilterOperator.GREATER_THAN: "$gt",
             FilterOperator.GREATER_THAN_OR_EQUAL: "$gte",
+            FilterOperator.IN: "$in",
+            FilterOperator.NOT_IN: "$nin",
         }

         if operator not in mapping:
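
With IN and NOT IN added to the operator map, SQL predicates can now be translated into ChromaDB's where-filter syntax. An illustrative translation with hypothetical field names and values:

    # WHERE category IN ('news', 'blog')
    where_in = {"category": {"$in": ["news", "blog"]}}
    # WHERE category NOT IN ('spam')
    where_nin = {"category": {"$nin": ["spam"]}}
    # either dict can be passed as the `where` argument of a collection query:
    # collection.query(query_texts=["..."], n_results=5, where=where_in)
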
@@ -308,7 +317,7 @@ class ChromaDBHandler(VectorStoreHandler):
         }

         if columns is not None:
-            payload = {column: payload[column] for column in columns}
+            payload = {column: payload[column] for column in columns if column != TableField.DISTANCE.value}

         # always include distance
         distance_filter = None
@@ -316,10 +325,11 @@ class ChromaDBHandler(VectorStoreHandler):
         if distances is not None:
             payload[distance_col] = distances

-
-
-
-
+        if conditions is not None:
+            for cond in conditions:
+                if cond.column == distance_col:
+                    distance_filter = cond
+                    break

         df = pd.DataFrame(payload)
         if distance_filter is not None:
@@ -413,8 +423,8 @@ class ChromaDBHandler(VectorStoreHandler):
             collection.upsert(
                 ids=data_dict[TableField.ID.value],
                 documents=data_dict[TableField.CONTENT.value],
-                embeddings=data_dict.get(TableField.EMBEDDINGS.value),
-                metadatas=data_dict.get(TableField.METADATA.value)
+                embeddings=data_dict.get(TableField.EMBEDDINGS.value, None),
+                metadatas=data_dict.get(TableField.METADATA.value, None)
             )
             self._sync()
         except Exception as e:

--- a/mindsdb/integrations/handlers/mssql_handler/mssql_handler.py
+++ b/mindsdb/integrations/handlers/mssql_handler/mssql_handler.py
@@ -177,7 +177,7 @@ class SqlServerHandler(DatabaseHandler):
                     )
                 )
             else:
-                response = Response(RESPONSE_TYPE.OK)
+                response = Response(RESPONSE_TYPE.OK, affected_rows=cur.rowcount)
             connection.commit()
         except Exception as e:
             logger.error(f'Error running query: {query} on {self.database}, {e}!')

--- a/mindsdb/integrations/handlers/mysql_handler/mysql_handler.py
+++ b/mindsdb/integrations/handlers/mysql_handler/mysql_handler.py
@@ -178,10 +178,11 @@ class MySQLHandler(DatabaseHandler):
                     pd.DataFrame(
                         result,
                         columns=[x[0] for x in cur.description]
-                    )
+                    ),
+                    affected_rows=cur.rowcount
                 )
             else:
-                response = Response(RESPONSE_TYPE.OK)
+                response = Response(RESPONSE_TYPE.OK, affected_rows=cur.rowcount)
         except mysql.connector.Error as e:
             logger.error(f'Error running query: {query} on {self.connection_data["database"]}!')
             response = Response(

--- a/mindsdb/integrations/handlers/oracle_handler/oracle_handler.py
+++ b/mindsdb/integrations/handlers/oracle_handler/oracle_handler.py
@@ -205,8 +205,10 @@ class OracleHandler(DatabaseHandler):
         with connection.cursor() as cur:
             try:
                 cur.execute(query)
-
-
+                if cur.description is None:
+                    response = Response(RESPONSE_TYPE.OK, affected_rows=cur.rowcount)
+                else:
+                    result = cur.fetchall()
                     response = Response(
                         RESPONSE_TYPE.TABLE,
                         data_frame=pd.DataFrame(
@@ -214,8 +216,6 @@ class OracleHandler(DatabaseHandler):
                             columns=[row[0] for row in cur.description],
                         ),
                     )
-                else:
-                    response = Response(RESPONSE_TYPE.OK)

                 connection.commit()
             except DatabaseError as database_error:

--- a/mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py
+++ b/mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py
@@ -149,7 +149,7 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
         for key, value in filter_conditions.items():
             if key == "embeddings":
                 continue
-            if value['op'].lower()
+            if value['op'].lower() in ('in', 'not in'):
                 values = list(repr(i) for i in value['value'])
                 value['value'] = '({})'.format(', '.join(values))
             else:
@@ -165,9 +165,9 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):

     @staticmethod
     def _construct_full_after_from_clause(
+        where_clause: str,
         offset_clause: str,
         limit_clause: str,
-        where_clause: str,
     ) -> str:

         return f"{where_clause} {offset_clause} {limit_clause}"
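
The reorder puts where_clause first so the positional arguments line up with the order the fragments appear in the emitted SQL; the method itself is just an f-string join. A quick check of the contract:

    def construct_after_from(where_clause: str, offset_clause: str, limit_clause: str) -> str:
        # mirrors _construct_full_after_from_clause after the parameter reorder
        return f"{where_clause} {offset_clause} {limit_clause}"

    print(construct_after_from("WHERE id = 1", "OFFSET 10", "LIMIT 5"))
    # -> WHERE id = 1 OFFSET 10 LIMIT 5
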
@@ -200,10 +200,20 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
             where_clause, offset_clause, limit_clause
         )

-
-
+        # Handle distance column specially since it's calculated, not stored
+        modified_columns = []
+        has_distance = False
+        if columns is not None:
+            for col in columns:
+                if col == TableField.DISTANCE.value:
+                    has_distance = True
+                else:
+                    modified_columns.append(col)
         else:
-
+            modified_columns = ['id', 'content', 'embeddings', 'metadata']
+            has_distance = True
+
+        targets = ', '.join(modified_columns)


         if filter_conditions:
@@ -227,6 +237,10 @@ class PgVectorHandler(PostgresHandler, VectorStoreHandler):
             # Use cosine similarity for dense vectors
             distance_op = "<=>"

+            # Calculate distance as part of the query if needed
+            if has_distance:
+                targets = f"{targets}, (embeddings {distance_op} '{search_vector}') as distance"
+
             return f"SELECT {targets} FROM {table_name} ORDER BY embeddings {distance_op} '{search_vector}' ASC {after_from_clause}"

         else:
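
With has_distance set, the computed column rides along in the target list instead of being read from storage. A rough sketch of the SELECT this branch builds, with a hypothetical table name and query vector (<=> is pgvector's cosine-distance operator):

    targets = "id, content, metadata"
    distance_op = "<=>"
    search_vector = "[0.1, 0.2, 0.3]"
    targets = f"{targets}, (embeddings {distance_op} '{search_vector}') as distance"
    print(f"SELECT {targets} FROM items "
          f"ORDER BY embeddings {distance_op} '{search_vector}' ASC LIMIT 5")
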

--- a/mindsdb/integrations/handlers/postgres_handler/postgres_handler.py
+++ b/mindsdb/integrations/handlers/postgres_handler/postgres_handler.py
@@ -228,7 +228,7 @@ class PostgresHandler(DatabaseHandler):
             else:
                 cur.execute(query)
                 if cur.pgresult is None or ExecStatus(cur.pgresult.status) == ExecStatus.COMMAND_OK:
-                    response = Response(RESPONSE_TYPE.OK)
+                    response = Response(RESPONSE_TYPE.OK, affected_rows=cur.rowcount)
                 else:
                     result = cur.fetchall()
                     df = DataFrame(
@@ -238,7 +238,8 @@ class PostgresHandler(DatabaseHandler):
                     self._cast_dtypes(df, cur.description)
                     response = Response(
                         RESPONSE_TYPE.TABLE,
-                        df
+                        data_frame=df,
+                        affected_rows=cur.rowcount
                     )
             connection.commit()
         except Exception as e:
@@ -255,15 +256,16 @@ class PostgresHandler(DatabaseHandler):

         return response

-    def insert(self, table_name: str, df: pd.DataFrame):
+    def insert(self, table_name: str, df: pd.DataFrame) -> Response:
         need_to_close = not self.is_connected

         connection = self.connect()

         columns = [f'"{c}"' for c in df.columns]
+        rowcount = None
         with connection.cursor() as cur:
             try:
-                with cur.copy(f'copy "{table_name}" ({",".join(columns)}) from STDIN
+                with cur.copy(f'copy "{table_name}" ({",".join(columns)}) from STDIN WITH CSV') as copy:
                     df.to_csv(copy, index=False, header=False)

                 connection.commit()
@@ -271,10 +273,13 @@ class PostgresHandler(DatabaseHandler):
                 logger.error(f'Error running insert to {table_name} on {self.database}, {e}!')
                 connection.rollback()
                 raise e
+            rowcount = cur.rowcount

         if need_to_close:
             self.disconnect()

+        return Response(RESPONSE_TYPE.OK, affected_rows=rowcount)
+
     @profiler.profile()
     def query(self, query: ASTNode) -> Response:
         """
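
The cur.copy(...) context manager is psycopg 3's COPY API; the Copy object it yields is a writable target, which is why df.to_csv(copy, ...) can stream rows straight into COPY ... FROM STDIN WITH CSV. A self-contained sketch under that assumption, with a hypothetical DSN and table:

    import pandas as pd
    import psycopg  # psycopg 3

    df = pd.DataFrame({"a": [1, 2], "b": ["x", "y"]})
    with psycopg.connect("dbname=test") as conn:  # hypothetical DSN
        with conn.cursor() as cur:
            # Copy exposes write(), so to_csv streams CSV rows into the pipe
            with cur.copy('COPY "t" ("a","b") FROM STDIN WITH CSV') as copy:
                df.to_csv(copy, index=False, header=False)
        # the psycopg 3 connection context manager commits on clean exit
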

--- a/mindsdb/integrations/handlers/redshift_handler/redshift_handler.py
+++ b/mindsdb/integrations/handlers/redshift_handler/redshift_handler.py
@@ -52,7 +52,7 @@ class RedshiftHandler(PostgresHandler):
         with connection.cursor() as cur:
             try:
                 cur.executemany(query, df.values.tolist())
-                response = Response(RESPONSE_TYPE.OK)
+                response = Response(RESPONSE_TYPE.OK, affected_rows=cur.rowcount)

                 connection.commit()
             except Exception as e:

--- a/mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py
+++ b/mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py
@@ -230,18 +230,25 @@ class SnowflakeHandler(DatabaseHandler):
             # Fallback for CREATE/DELETE/UPDATE. These commands returns table with single column,
             # but it cannot be retrieved as pandas DataFrame.
             result = cur.fetchall()
-
-
-
-
-
-
+            match result:
+                case (
+                    [{'number of rows inserted': affected_rows}]
+                    | [{'number of rows deleted': affected_rows}]
+                    | [{'number of rows updated': affected_rows, 'number of multi-joined rows updated': _}]
+                ):
+                    response = Response(RESPONSE_TYPE.OK, affected_rows=affected_rows)
+                case list():
+                    response = Response(
+                        RESPONSE_TYPE.TABLE,
+                        DataFrame(
+                            result,
+                            columns=[x[0] for x in cur.description]
+                        )
                     )
-
-
-
-
-                response = Response(RESPONSE_TYPE.OK)
+                case _:
+                    # Looks like SnowFlake always returns something in response, so this is suspicious
+                    logger.warning('Snowflake did not return any data in response.')
+                    response = Response(RESPONSE_TYPE.OK)
         except Exception as e:
             logger.error(f"Error running query: {query} on {self.connection_data.get('database')}, {e}!")
             response = Response(
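
The fallback above leans on structural pattern matching: a one-element list whose dict carries one of Snowflake's row-count keys binds affected_rows, while any other list is treated as real table data. A standalone sketch of the same dispatch:

    def classify(result):
        match result:
            case ([{'number of rows inserted': affected_rows}]
                  | [{'number of rows deleted': affected_rows}]):
                return 'ok', affected_rows   # DML status row
            case list():
                return 'table', len(result)  # genuine result set
            case _:
                return 'ok', None            # no data at all

    print(classify([{'number of rows inserted': 7}]))  # -> ('ok', 7)
    print(classify([{'ID': 1}, {'ID': 2}]))            # -> ('table', 2)

Mapping patterns match any dict that contains at least the listed keys, so extra columns in a status row would not break the match.
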

--- a/mindsdb/integrations/libs/ml_handler_process/learn_process.py
+++ b/mindsdb/integrations/libs/ml_handler_process/learn_process.py
@@ -78,8 +78,7 @@ def learn_process(data_integration_ref: dict, problem_definition: dict, fetch_da
     query_ast = parse_sql(fetch_data_query)
     sqlquery = SQLQuery(query_ast, session=sql_session)

-
-    training_data_df = result['result']
+    training_data_df = sqlquery.fetched_data.to_df()

     training_data_columns_count, training_data_rows_count = 0, 0
     if training_data_df is not None:

--- a/mindsdb/integrations/libs/response.py
+++ b/mindsdb/integrations/libs/response.py
@@ -1,3 +1,4 @@
+from typing import Optional
 from pandas import DataFrame

 from mindsdb.utilities import log
@@ -8,13 +9,16 @@ from mindsdb_sql_parser.ast import ASTNode
 logger = log.getLogger(__name__)

 class HandlerResponse:
-    def __init__(self, resp_type: RESPONSE_TYPE, data_frame: DataFrame = None,
-
+    def __init__(self, resp_type: RESPONSE_TYPE, data_frame: DataFrame = None, query: ASTNode = 0, error_code: int = 0,
+                 error_message: Optional[str] = None, affected_rows: Optional[int] = None) -> None:
         self.resp_type = resp_type
         self.query = query
         self.data_frame = data_frame
         self.error_code = error_code
         self.error_message = error_message
+        self.affected_rows = affected_rows
+        if isinstance(self.affected_rows, int) is False or self.affected_rows < 0:
+            self.affected_rows = 0

     @property
     def type(self):
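
The constructor now coerces anything that is not a non-negative integer to 0, so consumers can rely on affected_rows always being an int. The guard in isolation:

    def normalize_affected_rows(value):
        # same check as HandlerResponse.__init__: non-int or negative becomes 0
        if isinstance(value, int) is False or value < 0:
            return 0
        return value

    print(normalize_affected_rows(None))  # -> 0 (short-circuits before value < 0)
    print(normalize_affected_rows(-1))    # -> 0
    print(normalize_affected_rows(12))    # -> 12
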
@@ -35,13 +39,14 @@ class HandlerResponse:
                 "error": self.error_message}

     def __repr__(self):
-        return "%s: resp_type=%s, query=%s, data_frame=%s, err_code=%s, error=%s" % (
+        return "%s: resp_type=%s, query=%s, data_frame=%s, err_code=%s, error=%s, affected_rows=%s" % (
             self.__class__.__name__,
             self.resp_type,
             self.query,
             self.data_frame,
             self.error_code,
-            self.error_message
+            self.error_message,
+            self.affected_rows
         )

 class HandlerStatusResponse:

--- a/mindsdb/integrations/libs/vectordatabase_handler.py
+++ b/mindsdb/integrations/libs/vectordatabase_handler.py
@@ -20,7 +20,7 @@ from mindsdb_sql_parser.ast.base import ASTNode

 from mindsdb.integrations.libs.response import RESPONSE_TYPE, HandlerResponse
 from mindsdb.utilities import log
-from mindsdb.integrations.utilities.sql_utils import
+from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator

 from mindsdb.integrations.utilities.query_traversal import query_traversal
 from .base import BaseHandler
@@ -39,6 +39,7 @@ class TableField(Enum):
     METADATA = "metadata"
     SEARCH_VECTOR = "search_vector"
     DISTANCE = "distance"
+    RELEVANCE = "relevance"


 class DistanceFunction(Enum):
@@ -69,6 +70,10 @@ class VectorStoreHandler(BaseHandler):
             "name": TableField.METADATA.value,
             "data_type": "json",
         },
+        {
+            "name": TableField.DISTANCE.value,
+            "data_type": "float",
+        },
     ]

     def validate_connection_parameters(self, name, **kwargs):
@@ -89,7 +94,7 @@ class VectorStoreHandler(BaseHandler):
         else:
             return value

-    def
+    def extract_conditions(self, where_statement) -> Optional[List[FilterCondition]]:
         conditions = []
         # parse conditions
         if where_statement is not None:
@@ -110,13 +115,7 @@ class VectorStoreHandler(BaseHandler):
                 right_hand = node.args[1].value
             elif isinstance(node.args[1], Tuple):
                 # Constant could be actually a list i.e. [1.2, 3.2]
-                right_hand = [
-                    ast.literal_eval(item.value)
-                    if isinstance(item, Constant)
-                    and not isinstance(item.value, list)
-                    else item.value
-                    for item in node.args[1].items
-                ]
+                right_hand = [item.value for item in node.args[1].items]
             else:
                 raise Exception(f"Unsupported right hand side: {node.args[1]}")
             conditions.append(
@@ -125,18 +124,21 @@ class VectorStoreHandler(BaseHandler):

         query_traversal(where_statement, _extract_comparison_conditions)

-        # try to treat conditions that are not in TableField as metadata conditions
-        for condition in conditions:
-            if not self._is_condition_allowed(condition):
-                condition.column = (
-                    TableField.METADATA.value + "." + condition.column
-                )
-
         else:
             conditions = None

         return conditions

+    def _convert_metadata_filters(self, conditions):
+        if conditions is None:
+            return
+        # try to treat conditions that are not in TableField as metadata conditions
+        for condition in conditions:
+            if not self._is_condition_allowed(condition):
+                condition.column = (
+                    TableField.METADATA.value + "." + condition.column
+                )
+
     def _is_columns_allowed(self, columns: List[str]) -> bool:
         """
         Check if columns are allowed.
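
Extracting the metadata rewrite into _convert_metadata_filters lets callers pass precomputed conditions (as the dispatch_update, dispatch_delete and dispatch_select hunks below do) while keeping the old behaviour: any column not in the vector-store schema is treated as a key inside the JSON metadata column. A sketch with a stand-in condition type (hypothetical; the real FilterCondition comes from mindsdb.integrations.utilities.sql_utils):

    from dataclasses import dataclass

    @dataclass
    class Cond:  # stand-in for FilterCondition
        column: str
        op: str
        value: object

    SCHEMA_COLUMNS = {"id", "content", "embeddings", "metadata", "distance"}

    def convert_metadata_filters(conditions):
        if conditions is None:
            return
        # columns outside the schema become metadata.<column> paths
        for cond in conditions:
            if cond.column.split(".")[0] not in SCHEMA_COLUMNS:
                cond.column = "metadata." + cond.column

    conds = [Cond("author", "=", "ada"), Cond("id", "=", 1)]
    convert_metadata_filters(conds)
    print([c.column for c in conds])  # -> ['metadata.author', 'id']
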
@@ -234,7 +236,7 @@ class VectorStoreHandler(BaseHandler):

         return self.do_upsert(table_name, pd.DataFrame(data))

-    def
+    def dispatch_update(self, query: Update, conditions: List[FilterCondition] = None):
         """
         Dispatch update query to the appropriate method.
         """
@@ -253,8 +255,15 @@ class VectorStoreHandler(BaseHandler):
                 pass
             row[k] = v

-
-
+        if conditions is None:
+            where_statement = query.where
+            conditions = self.extract_conditions(where_statement)
+
+        for condition in conditions:
+            if condition.op != FilterOperator.EQUAL:
+                raise NotImplementedError
+
+            row[condition.column] = condition.value

         # checks
         if TableField.EMBEDDINGS.value not in row:
@@ -325,14 +334,16 @@ class VectorStoreHandler(BaseHandler):
         if not df_insert.empty:
             self.insert(table_name, df_insert)

-    def dispatch_delete(self, query: Delete):
+    def dispatch_delete(self, query: Delete, conditions: List[FilterCondition] = None):
         """
         Dispatch delete query to the appropriate method.
         """
         # parse key arguments
         table_name = query.table.parts[-1]
-
-
+        if conditions is None:
+            where_statement = query.where
+            conditions = self.extract_conditions(where_statement)
+        self._convert_metadata_filters(conditions)

         # dispatch delete
         return self.delete(table_name, conditions=conditions)
@@ -356,9 +367,10 @@ class VectorStoreHandler(BaseHandler):
         )

         # check if columns are allowed
-        where_statement = query.where
         if conditions is None:
-
+            where_statement = query.where
+            conditions = self.extract_conditions(where_statement)
+        self._convert_metadata_filters(conditions)

         # get offset and limit
         offset = query.offset.value if query.offset is not None else None
@@ -381,7 +393,7 @@ class VectorStoreHandler(BaseHandler):
             CreateTable: self._dispatch_create_table,
             DropTables: self._dispatch_drop_table,
             Insert: self._dispatch_insert,
-            Update: self.
+            Update: self.dispatch_update,
             Delete: self.dispatch_delete,
             Select: self.dispatch_select,
         }