MindsDB 25.4.1.0__py3-none-any.whl → 25.4.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of MindsDB might be problematic.

Files changed (63)
  1. mindsdb/__about__.py +1 -1
  2. mindsdb/api/executor/command_executor.py +91 -61
  3. mindsdb/api/executor/data_types/answer.py +9 -12
  4. mindsdb/api/executor/datahub/classes/response.py +11 -0
  5. mindsdb/api/executor/datahub/datanodes/datanode.py +4 -4
  6. mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +10 -11
  7. mindsdb/api/executor/datahub/datanodes/integration_datanode.py +22 -16
  8. mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +43 -1
  9. mindsdb/api/executor/datahub/datanodes/project_datanode.py +20 -20
  10. mindsdb/api/executor/planner/plan_join.py +2 -2
  11. mindsdb/api/executor/planner/query_plan.py +1 -0
  12. mindsdb/api/executor/planner/query_planner.py +86 -14
  13. mindsdb/api/executor/planner/steps.py +11 -2
  14. mindsdb/api/executor/sql_query/result_set.py +10 -7
  15. mindsdb/api/executor/sql_query/sql_query.py +69 -84
  16. mindsdb/api/executor/sql_query/steps/__init__.py +1 -0
  17. mindsdb/api/executor/sql_query/steps/delete_step.py +2 -3
  18. mindsdb/api/executor/sql_query/steps/fetch_dataframe.py +5 -3
  19. mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +288 -0
  20. mindsdb/api/executor/sql_query/steps/insert_step.py +2 -2
  21. mindsdb/api/executor/sql_query/steps/prepare_steps.py +2 -2
  22. mindsdb/api/executor/sql_query/steps/subselect_step.py +20 -8
  23. mindsdb/api/executor/sql_query/steps/update_step.py +4 -6
  24. mindsdb/api/http/namespaces/sql.py +4 -1
  25. mindsdb/api/mysql/mysql_proxy/data_types/mysql_packets/ok_packet.py +1 -1
  26. mindsdb/api/mysql/mysql_proxy/executor/mysql_executor.py +4 -27
  27. mindsdb/api/mysql/mysql_proxy/libs/constants/mysql.py +1 -0
  28. mindsdb/api/mysql/mysql_proxy/mysql_proxy.py +38 -37
  29. mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +23 -13
  30. mindsdb/integrations/handlers/langchain_embedding_handler/langchain_embedding_handler.py +17 -16
  31. mindsdb/integrations/handlers/langchain_handler/langchain_handler.py +1 -0
  32. mindsdb/integrations/handlers/mssql_handler/mssql_handler.py +1 -1
  33. mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +3 -2
  34. mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +4 -4
  35. mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +26 -16
  36. mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +36 -7
  37. mindsdb/integrations/handlers/redshift_handler/redshift_handler.py +1 -1
  38. mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +18 -11
  39. mindsdb/integrations/libs/llm/config.py +11 -1
  40. mindsdb/integrations/libs/llm/utils.py +12 -0
  41. mindsdb/integrations/libs/ml_handler_process/learn_process.py +1 -2
  42. mindsdb/integrations/libs/response.py +9 -4
  43. mindsdb/integrations/libs/vectordatabase_handler.py +17 -5
  44. mindsdb/integrations/utilities/rag/rerankers/reranker_compressor.py +8 -98
  45. mindsdb/interfaces/agents/constants.py +12 -1
  46. mindsdb/interfaces/agents/langchain_agent.py +6 -0
  47. mindsdb/interfaces/database/log.py +8 -9
  48. mindsdb/interfaces/database/projects.py +1 -5
  49. mindsdb/interfaces/functions/controller.py +59 -17
  50. mindsdb/interfaces/functions/to_markdown.py +194 -0
  51. mindsdb/interfaces/jobs/jobs_controller.py +3 -3
  52. mindsdb/interfaces/knowledge_base/controller.py +223 -97
  53. mindsdb/interfaces/knowledge_base/preprocessing/document_preprocessor.py +3 -14
  54. mindsdb/interfaces/query_context/context_controller.py +224 -1
  55. mindsdb/interfaces/storage/db.py +23 -0
  56. mindsdb/migrations/versions/2025-03-21_fda503400e43_queries.py +45 -0
  57. mindsdb/utilities/context_executor.py +1 -1
  58. mindsdb/utilities/partitioning.py +35 -20
  59. {mindsdb-25.4.1.0.dist-info → mindsdb-25.4.2.1.dist-info}/METADATA +227 -224
  60. {mindsdb-25.4.1.0.dist-info → mindsdb-25.4.2.1.dist-info}/RECORD +63 -59
  61. {mindsdb-25.4.1.0.dist-info → mindsdb-25.4.2.1.dist-info}/WHEEL +0 -0
  62. {mindsdb-25.4.1.0.dist-info → mindsdb-25.4.2.1.dist-info}/licenses/LICENSE +0 -0
  63. {mindsdb-25.4.1.0.dist-info → mindsdb-25.4.2.1.dist-info}/top_level.txt +0 -0
mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py (new file)
@@ -0,0 +1,288 @@
+ import pandas as pd
+ import threading
+ import queue
+ from typing import List
+
+ from mindsdb_sql_parser import ASTNode
+ from mindsdb.api.executor.planner.steps import FetchDataframeStepPartition
+ from mindsdb.integrations.utilities.query_traversal import query_traversal
+
+ from mindsdb.interfaces.query_context.context_controller import RunningQuery
+ from mindsdb.api.executor.sql_query.result_set import ResultSet
+ from mindsdb.utilities import log
+ from mindsdb.utilities.config import Config
+ from mindsdb.utilities.context import Context, context as ctx
+ from mindsdb.utilities.partitioning import get_max_thread_count, split_data_frame
+ from mindsdb.api.executor.sql_query.steps.fetch_dataframe import get_table_alias, get_fill_param_fnc
+
+ from .base import BaseStepCall
+
+
+ logger = log.getLogger(__name__)
+
+
+ class FetchDataframePartitionCall(BaseStepCall):
+     """
+     Alternative to FetchDataframeCall that fetches data in batches by wrapping the user's query as:
+
+         select * from ({user query})
+         where {track_column} > {previous value}
+         order by {track_column}
+         limit {batch_size}
+     """
+
+     bind = FetchDataframeStepPartition
+
+     def call(self, step: FetchDataframeStepPartition) -> ResultSet:
+         """
+         Parameters:
+         - batch_size - number of rows to fetch from the database per iteration; optional, default 1000
+         - threads - run partitioning in threads; bool or int, optional; if set:
+           - int value: use this as the thread count
+           - true: enable threads, autodetect thread count
+           - false: disable threads even if the ml task queue is enabled
+         - track_column - column used for creating partitions;
+           the query is sorted by this column and each select is limited by batch_size
+         - error - behavior on a partition error, default 'raise';
+           when error='skip', errors in a partition are skipped and execution continues
+         """
+
+         self.dn = self.session.datahub.get(step.integration)
+         query = step.query
+
+         # fill params
+         fill_params = get_fill_param_fnc(self.steps_data)
+         query_traversal(query, fill_params)
+
+         # get query record
+         run_query = self.sql_query.run_query
+         if run_query is None:
+             raise RuntimeError('Error with partitioning of the query')
+         run_query.set_params(step.params)
+
+         self.table_alias = get_table_alias(step.query.from_table, self.context.get('database'))
+         self.current_step_num = step.step_num
+         self.substeps = step.steps
+
+         config = Config()
+
+         # is the ml task queue enabled?
+         use_threads, thread_count = False, None
+         if config['ml_task_queue']['type'] == 'redis':
+             use_threads = True
+
+         # use threads?
+         if 'threads' in step.params:
+             threads = step.params['threads']
+             # check bool before int: bool is a subclass of int in Python
+             if isinstance(threads, bool):
+                 # True: enable threads; False: disable even with ml task queue
+                 use_threads = threads
+             elif isinstance(threads, int):
+                 thread_count = threads
+                 use_threads = True
+
+         on_error = step.params.get('error', 'raise')
+         if use_threads:
+             return self.fetch_threads(run_query, query, thread_count=thread_count, on_error=on_error)
+         else:
+             return self.fetch_iterate(run_query, query, on_error=on_error)
+
+     def fetch_iterate(self, run_query: RunningQuery, query: ASTNode, on_error: str = None) -> ResultSet:
+         """
+         Process batches one by one in a loop
+         """
+
+         results = []
+         while True:
+
+             # fetch batch
+             query2 = run_query.get_partition_query(self.current_step_num, query)
+             response = self.dn.query(
+                 query=query2,
+                 session=self.session
+             )
+             df = response.data_frame
+
+             if df is None or len(df) == 0:
+                 break
+
+             # executing substeps can modify dataframe columns, so memorize the max tracking value first
+             max_track_value = run_query.get_max_track_value(df)
+             try:
+                 sub_data = self.exec_sub_steps(df)
+                 results.append(sub_data)
+             except Exception as e:
+                 if on_error == 'skip':
+                     logger.error(e)
+                 else:
+                     raise e
+
+             run_query.set_progress(df, max_track_value)
+
+         return self.concat_results(results)
+
+     def concat_results(self, results: List[ResultSet]) -> ResultSet:
+         """
+         Concatenate a list of result sets into a single result set
+         """
+         df_list = []
+         col_names = None
+         for res in results:
+             df, col_names = res.to_df_cols()
+             if len(df) > 0:
+                 df_list.append(df)
+
+         data = ResultSet()
+         if len(df_list) > 0:
+             data.from_df_cols(pd.concat(df_list), col_names)
+
+         return data
+
+     def exec_sub_steps(self, df: pd.DataFrame) -> ResultSet:
+         """
+         FetchDataframeStepPartition has substeps defined.
+         Every batch of data has to be used to execute these substeps:
+         - the batch of data is put in place of the result of FetchDataframeStepPartition
+         - substeps are executed using the result of the previous step (as if all fetched data were available)
+         - the final result is returned and used outside to concatenate with the results of the other batches
+         """
+
+         input_data = ResultSet()
+
+         input_data.from_df(
+             df,
+             table_name=self.table_alias[1],
+             table_alias=self.table_alias[2],
+             database=self.table_alias[0]
+         )
+
+         # execute with modified previous results
+         steps_data2 = self.steps_data.copy()
+         steps_data2[self.current_step_num] = input_data
+
+         sub_data = None
+         for substep in self.substeps:
+             sub_data = self.sql_query.execute_step(substep, steps_data=steps_data2)
+             steps_data2[substep.step_num] = sub_data
+         return sub_data
+
+     def fetch_threads(self, run_query: RunningQuery, query: ASTNode,
+                       thread_count: int = None, on_error: str = None) -> ResultSet:
+         """
+         Process batches in threads:
+         - spawn the required number of threads
+         - create in/out queues to communicate with the threads
+         - send tasks to the threads and receive results
+         """
+
+         # create communication queues
+         queue_in = queue.Queue()
+         queue_out = queue.Queue()
+         self.stop_event = threading.Event()
+
+         if thread_count is None:
+             thread_count = get_max_thread_count()
+
+         # 3 tasks per worker during 1 batch
+         partition_size = int(run_query.batch_size / thread_count / 3)
+         # min partition size
+         if partition_size < 10:
+             partition_size = 10
+
+         # create a pool of N workers
+         workers = []
+         results = []
+
+         try:
+             for i in range(thread_count):
+                 worker = threading.Thread(target=self._worker, daemon=True,
+                                           args=(ctx.dump(), queue_in, queue_out, self.stop_event))
+                 worker.start()
+                 workers.append(worker)
+
+             while True:
+                 # fetch batch
+                 query2 = run_query.get_partition_query(self.current_step_num, query)
+                 response = self.dn.query(
+                     query=query2,
+                     session=self.session
+                 )
+                 df = response.data_frame
+
+                 if df is None or len(df) == 0:
+                     # TODO detect cycles: the data handler ignores the condition and the output is repeated
+
+                     # exit & stop workers
+                     break
+
+                 max_track_value = run_query.get_max_track_value(df)
+
+                 # split into chunks and send to workers
+                 sent_chunks = 0
+                 for df2 in split_data_frame(df, partition_size):
+                     queue_in.put([sent_chunks, df2])
+                     sent_chunks += 1
+
+                 batch_results = []
+                 for i in range(sent_chunks):
+                     res = queue_out.get()
+                     if 'error' in res:
+                         if on_error == 'skip':
+                             logger.error(res['error'])
+                             continue
+                         else:
+                             raise RuntimeError(res['error'])
+
+                     if res['data']:
+                         batch_results.append(res)
+
+                 # sort results by chunk number to keep the original row order
+                 batch_results.sort(key=lambda x: x['num'])
+
+                 results.append(self.concat_results(
+                     [item['data'] for item in batch_results]
+                 ))
+
+                 # TODO
+                 # 1. get the next batch without updating track_value:
+                 #    it would allow keeping queue_in filled with data between fetching batches
+                 run_query.set_progress(df, max_track_value)
+         finally:
+             self.close_workers(workers)
+
+         return self.concat_results(results)
+
+     def close_workers(self, workers: List[threading.Thread]):
+         """
+         Send a stop signal to the workers
+         """
+
+         self.stop_event.set()
+         for worker in workers:
+             if worker.is_alive():
+                 worker.join()
+
+     def _worker(self, context: Context, queue_in: queue.Queue, queue_out: queue.Queue, stop_event: threading.Event):
+         """
+         Worker function. Executes incoming tasks until stop_event is set
+         """
+         ctx.load(context)
+         while True:
+             if stop_event.is_set():
+                 break
+
+             try:
+                 chunk_num, df = queue_in.get(timeout=1)
+                 if df is None:
+                     continue
+
+                 sub_data = self.exec_sub_steps(df)
+
+                 queue_out.put({'data': sub_data, 'num': chunk_num})
+             except queue.Empty:
+                 continue
+
+             except Exception as e:
+                 queue_out.put({'error': str(e)})
+                 stop_event.set()
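The class docstring above describes the batching contract. As a plain-Python illustration of that wrapping (the helper name wrap_partition_query, the table demo.sales, and the column id are hypothetical; in the release the equivalent query is built as an AST by RunningQuery.get_partition_query):

    # Hypothetical sketch of the per-batch query wrapping described in the docstring.
    def wrap_partition_query(user_query: str, track_column: str, batch_size: int, last_value=None) -> str:
        # the first batch has no lower bound; later batches resume after the last seen value
        where = f"where {track_column} > {last_value} " if last_value is not None else ""
        return (
            f"select * from ({user_query}) "
            f"{where}"
            f"order by {track_column} "
            f"limit {batch_size}"
        )

    # first batch, then the batch starting after the largest id seen so far
    print(wrap_partition_query("select * from demo.sales", "id", 1000))
    print(wrap_partition_query("select * from demo.sales", "id", 1000, last_value=1000))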
mindsdb/api/executor/sql_query/steps/insert_step.py
@@ -91,13 +91,13 @@ class InsertToTableCall(BaseStepCall):
              else:
                  col_names.add(col.alias)
 
-         dn.create_table(
+         response = dn.create_table(
              table_name=table_name,
              result_set=data,
              is_replace=is_replace,
              is_create=is_create
          )
-         return ResultSet()
+         return ResultSet(affected_rows=response.affected_rows)
 
 
  class SaveToTableCall(InsertToTableCall):
mindsdb/api/executor/sql_query/steps/prepare_steps.py
@@ -47,10 +47,10 @@ class GetTableColumnsCall(BaseStepCall):
          dn = self.session.datahub.get(step.namespace)
          ds_query = Select(from_table=Identifier(table), targets=[Star()], limit=Constant(0))
 
-         data, columns_info = dn.query(ds_query, session=self.session)
+         response = dn.query(ds_query, session=self.session)
 
          data = ResultSet()
-         for column in columns_info:
+         for column in response.columns:
              data.add_column(Column(
                  name=column['name'],
                  type=column.get('type'),
mindsdb/api/executor/sql_query/steps/subselect_step.py
@@ -3,13 +3,7 @@ from collections import defaultdict
  import pandas as pd
 
  from mindsdb_sql_parser.ast import (
-     Identifier,
-     Select,
-     Star,
-     Constant,
-     Parameter,
-     Function,
-     Variable
+     Identifier, Select, Star, Constant, Parameter, Function, Variable, BinaryOperation
  )
 
  from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import SERVER_VARIABLES
@@ -87,7 +81,7 @@ class QueryStepCall(BaseStepCall):
      bind = QueryStep
 
-     def call(self, step):
+     def call(self, step: QueryStep):
          query = step.query
 
          if step.from_table is not None:
@@ -190,6 +184,24 @@
          fill_params = get_fill_param_fnc(self.steps_data)
          query_traversal(query, fill_params)
 
+         if not step.strict_where:
+             # remove conditions that reference non-existent columns.
+             # such conditions may already have been used as input to a model or knowledge base
+             # but can be absent from their output
+
+             def remove_not_used_conditions(node, **kwargs):
+                 # neutralize binary conditions that reference unknown columns
+                 if isinstance(node, BinaryOperation):
+                     for arg in node.args:
+                         if isinstance(arg, Identifier) and len(arg.parts) > 1:
+                             key = tuple(arg.parts[-2:])
+                             if key not in col_idx:
+                                 # exclude: rewrite the condition to the always-true `0 = 0`
+                                 node.args = [Constant(0), Constant(0)]
+                                 node.op = '='
+
+             query_traversal(query.where, remove_not_used_conditions)
+
          query_traversal(query, check_fields)
          query.where = query_context_controller.remove_lasts(query.where)
 
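The rewrite keeps the WHERE tree structurally intact by turning each dangling condition into an always-true comparison rather than pruning AST nodes. A small sketch of the effect, using the same mindsdb_sql_parser.ast classes the hunk imports (toy column and value, not from the release):

    from mindsdb_sql_parser.ast import BinaryOperation, Identifier, Constant

    # a condition on a column that is absent from the step output, e.g. `kb1.relevance > 0.5`
    cond = BinaryOperation(op='>', args=[Identifier(parts=['kb1', 'relevance']), Constant(0.5)])

    # the same in-place neutralization the step applies:
    cond.args = [Constant(0), Constant(0)]
    cond.op = '='
    # the clause now reads `0 = 0`, which is always true, so this filter is effectively removed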
mindsdb/api/executor/sql_query/steps/update_step.py
@@ -18,8 +18,6 @@ class UpdateToTableCall(BaseStepCall):
      bind = UpdateToTable
 
      def call(self, step):
-         data = ResultSet()
-
          if len(step.table.parts) > 1:
              integration_name = step.table.parts[0]
              table_name_parts = step.table.parts[1:]
@@ -85,8 +83,8 @@
          if result_step is None:
              # run as is
-             dn.query(query=update_query, session=self.session)
-             return data
+             response = dn.query(query=update_query, session=self.session)
+             return ResultSet(affected_rows=response.affected_rows)
          result_data = self.steps_data[result_step.result.step_num]
 
          # link nodes with parameters for fast replacing with values
@@ -125,5 +123,5 @@
          for param_name, param in params_map_index:
              param.value = row[param_name]
 
-         dn.query(query=update_query, session=self.session)
-         return data
+         response = dn.query(query=update_query, session=self.session)
+         return ResultSet(affected_rows=response.affected_rows)
mindsdb/api/http/namespaces/sql.py
@@ -59,7 +59,10 @@ class Query(Resource):
          result = mysql_proxy.process_query(query)
 
          if result.type == SQL_RESPONSE_TYPE.OK:
-             query_response = {"type": SQL_RESPONSE_TYPE.OK}
+             query_response = {
+                 "type": SQL_RESPONSE_TYPE.OK,
+                 "affected_rows": result.affected_rows
+             }
          elif result.type == SQL_RESPONSE_TYPE.TABLE:
              data = result.data.to_lists(json_types=True)
              query_response = {
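With this change, an OK answer from the HTTP SQL endpoint (/api/sql/query) also reports how many rows a statement modified. For example, the response body built for an INSERT that added three rows would look like (illustrative row count):

    query_response = {
        "type": SQL_RESPONSE_TYPE.OK,
        "affected_rows": 3   # e.g. an INSERT that added three rows
    }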
mindsdb/api/mysql/mysql_proxy/data_types/mysql_packets/ok_packet.py
@@ -40,7 +40,7 @@ class OkPacket(Packet):
      def setup(self):
          eof = self._kwargs.get('eof', False)
          self.ok_header = Datum('int<1>', 0xFE if eof is True else 0)
-         self.affected_rows = Datum('int<lenenc>', self._kwargs.get('affected_rows', 0))
+         self.affected_rows = Datum('int<lenenc>', self._kwargs.get('affected_rows') or 0)
          self.last_insert_id = Datum('int<lenenc>', 0)
          status = self._kwargs.get('status', 0x0002)
          self.server_status = Datum('int<2>', status)
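The switch from a `.get()` default to `or 0` matters because the proxy now always passes affected_rows, which may be None: when the key exists, the `.get()` default is never applied. A two-line illustration of the Python semantics:

    kwargs = {'affected_rows': None}
    kwargs.get('affected_rows', 0)    # -> None: key is present, so the default is ignored
    kwargs.get('affected_rows') or 0  # -> 0: None is falsy, so the fallback applies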
mindsdb/api/mysql/mysql_proxy/executor/mysql_executor.py
@@ -4,6 +4,7 @@ from mindsdb.api.executor.planner import utils as planner_utils
  import mindsdb.utilities.profiler as profiler
  from mindsdb.api.executor.sql_query.result_set import Column
  from mindsdb.api.executor.sql_query import SQLQuery
+ from mindsdb.api.executor.data_types.answer import ExecuteAnswer
  from mindsdb.api.executor.command_executor import ExecuteCommands
  from mindsdb.api.mysql.mysql_proxy.utilities import ErSqlSyntaxError
  from mindsdb.utilities import log
@@ -12,37 +13,20 @@ logger = log.getLogger(__name__)
 
 
  class Executor:
-     """This class stores initial and intermediate params
-     between different steps of query execution. And it is also
-     creates a separate instance of ExecuteCommands to execute the current
-     query step.
-
-     IMPORTANT: A public API of this class is a contract.
-     And there are at least 2 classes strongly depend on it:
-         ExecuctorClient
-         ExecutorService.
-     These classes do the same work as Executor when
-     MindsDB works in 'modularity' mode.
-     Thus please make sure that IF you change the API,
-     you must update the API of these two classes as well!"""
-
      def __init__(self, session, sqlserver):
          self.session = session
          self.sqlserver = sqlserver
 
          self.query = None
 
-         # returned values
-         # all this attributes needs to be added in
-         # self.json() method
          self.columns = []
          self.params = []
          self.data = None
-         self.state_track = None
          self.server_status = None
          self.is_executed = False
          self.error_message = None
          self.error_code = None
+         self.executor_answer: ExecuteAnswer = None
 
          self.sql = ""
          self.sql_lower = ""
@@ -126,14 +110,7 @@
          if self.is_executed:
              return
 
-         ret = self.command_executor.execute_command(self.query)
-         self.error_code = ret.error_code
-         self.error_message = ret.error_message
+         executor_answer: ExecuteAnswer = self.command_executor.execute_command(self.query)
+         self.executor_answer = executor_answer
 
          self.is_executed = True
-
-         if ret.data is not None:
-             self.data = ret.data
-             self.columns = ret.data.columns
-
-         self.state_track = ret.state_track
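Callers now read results through the ExecuteAnswer object instead of the flattened data/columns/state_track attributes. A minimal sketch of the new access pattern (the session and server objects and the query string are placeholders):

    executor = Executor(session=session, sqlserver=server)  # placeholders
    executor.query_execute('select 1')       # hypothetical query
    answer = executor.executor_answer        # ExecuteAnswer filled by execute_command()
    if answer.data is not None:              # ResultSet, or None for OK-only answers
        rows = answer.data.to_lists()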
mindsdb/api/mysql/mysql_proxy/libs/constants/mysql.py
@@ -94,6 +94,7 @@ class COMMANDS(object):
      COM_STMT_PREPARE = int('0x16', 0)
      COM_STMT_EXECUTE = int('0x17', 0)
      COM_STMT_FETCH = int('0x1c', 0)
+     COM_STMT_RESET = int('0x1a', 0)
      COM_STMT_CLOSE = int('0x19', 0)
      COM_FIELD_LIST = int('0x04', 0)  # deprecated
mindsdb/api/mysql/mysql_proxy/mysql_proxy.py
@@ -21,7 +21,8 @@ import sys
  import tempfile
  import traceback
  from functools import partial
- from typing import Dict, List
+ from typing import Dict, List, Optional
+ from dataclasses import dataclass
 
  from numpy import dtype as np_dtype
  from pandas.api import types as pd_types
@@ -71,6 +72,7 @@ from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import (
      TYPES,
      getConstName,
  )
+ from mindsdb.api.executor.data_types.answer import ExecuteAnswer
  from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE
  from mindsdb.api.mysql.mysql_proxy.utilities import (
      ErWrongCharset,
@@ -93,24 +95,16 @@
  def empty_fn():
      pass
 
 
+ @dataclass
  class SQLAnswer:
-     def __init__(
-         self,
-         resp_type: RESPONSE_TYPE,
-         columns: List[Dict] = None,
-         data: List[Dict] = None,
-         status: int = None,
-         state_track: List[List] = None,
-         error_code: int = None,
-         error_message: str = None,
-     ):
-         self.resp_type = resp_type
-         self.columns = columns
-         self.data = data
-         self.status = status
-         self.state_track = state_track
-         self.error_code = error_code
-         self.error_message = error_message
+     resp_type: RESPONSE_TYPE = RESPONSE_TYPE.OK
+     columns: Optional[List[Dict]] = None
+     data: Optional[List[Dict]] = None  # resultSet ?
+     status: Optional[int] = None
+     state_track: Optional[List[List]] = None
+     error_code: Optional[int] = None
+     error_message: Optional[str] = None
+     affected_rows: Optional[int] = None
 
      @property
      def type(self):
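As a dataclass with defaults, SQLAnswer can now be built field-by-field with keyword arguments, e.g. (illustrative values):

    resp = SQLAnswer(affected_rows=3)        # resp_type defaults to RESPONSE_TYPE.OK
    err = SQLAnswer(
        resp_type=RESPONSE_TYPE.ERROR,
        error_code=1064,                     # illustrative MySQL error code
        error_message='syntax error',
    )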
@@ -333,7 +327,7 @@ class MysqlProxy(SocketServer.BaseRequestHandler):
              packages.append(self.last_packet())
              self.send_package_group(packages)
          elif answer.type == RESPONSE_TYPE.OK:
-             self.packet(OkPacket, state_track=answer.state_track).send()
+             self.packet(OkPacket, state_track=answer.state_track, affected_rows=answer.affected_rows).send()
          elif answer.type == RESPONSE_TYPE.ERROR:
              self.packet(
                  ErrPacket, err_code=answer.error_code, msg=answer.error_message
@@ -546,21 +540,23 @@
      @profiler.profile()
      def process_query(self, sql):
          executor = Executor(session=self.session, sqlserver=self)
-
          executor.query_execute(sql)
+         executor_answer = executor.executor_answer
 
-         if executor.data is None:
+         if executor_answer.data is None:
              resp = SQLAnswer(
                  resp_type=RESPONSE_TYPE.OK,
-                 state_track=executor.state_track,
+                 state_track=executor_answer.state_track,
+                 affected_rows=executor_answer.affected_rows
              )
          else:
              resp = SQLAnswer(
                  resp_type=RESPONSE_TYPE.TABLE,
-                 state_track=executor.state_track,
-                 columns=self.to_mysql_columns(executor.columns),
-                 data=executor.data,
+                 state_track=executor_answer.state_track,
+                 columns=self.to_mysql_columns(executor_answer.data.columns),
+                 data=executor_answer.data,
                  status=executor.server_status,
+                 affected_rows=executor_answer.affected_rows
              )
 
          # Increment the counter and include metadata in attributes
@@ -604,18 +600,20 @@
      def answer_stmt_execute(self, stmt_id, parameters):
          prepared_stmt = self.session.prepared_stmts[stmt_id]
-         executor = prepared_stmt["statement"]
+         executor: Executor = prepared_stmt["statement"]
 
          executor.stmt_execute(parameters)
 
-         if executor.data is None:
+         executor_answer: ExecuteAnswer = executor.executor_answer
+
+         if executor_answer.data is None:
              resp = SQLAnswer(
-                 resp_type=RESPONSE_TYPE.OK, state_track=executor.state_track
+                 resp_type=RESPONSE_TYPE.OK, state_track=executor_answer.state_track
              )
              return self.send_query_answer(resp)
 
          # TODO prepared_stmt['type'] == 'lock' is not used but it works
-         columns_def = self.to_mysql_columns(executor.columns)
+         columns_def = self.to_mysql_columns(executor_answer.data.columns)
          packages = [self.packet(ColumnCountPacket, count=len(columns_def))]
 
          packages.extend(self._get_column_defenition_packets(columns_def))
@@ -624,14 +622,14 @@
          packages.append(self.packet(EofPacket, status=0x0062))
 
          # send all
-         for row in executor.data.to_lists():
+         for row in executor_answer.data.to_lists():
              packages.append(
                  self.packet(BinaryResultsetRowPacket, data=row, columns=columns_def)
              )
 
          server_status = executor.server_status or 0x0002
          packages.append(self.last_packet(status=server_status))
-         prepared_stmt["fetched"] += len(executor.data)
+         prepared_stmt["fetched"] += len(executor_answer.data)
 
          return self.send_package_group(packages)
 
@@ -639,23 +637,24 @@
          prepared_stmt = self.session.prepared_stmts[stmt_id]
          executor = prepared_stmt["statement"]
          fetched = prepared_stmt["fetched"]
+         executor_answer: ExecuteAnswer = executor.executor_answer
 
-         if executor.data is None:
+         if executor_answer.data is None:
              resp = SQLAnswer(
-                 resp_type=RESPONSE_TYPE.OK, state_track=executor.state_track
+                 resp_type=RESPONSE_TYPE.OK, state_track=executor_answer.state_track
              )
              return self.send_query_answer(resp)
 
          packages = []
-         columns = self.to_mysql_columns(executor.columns)
-         for row in executor.data[fetched:limit].to_lists():
+         columns = self.to_mysql_columns(executor_answer.data.columns)
+         for row in executor_answer.data[fetched:limit].to_lists():
              packages.append(
                  self.packet(BinaryResultsetRowPacket, data=row, columns=columns)
              )
 
-         prepared_stmt["fetched"] += len(executor.data[fetched:limit])
+         prepared_stmt["fetched"] += len(executor_answer.data[fetched:limit])
 
-         if len(executor.data) <= limit + fetched:
+         if len(executor_answer.data) <= limit + fetched:
              status = sum(
                  [
                      SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT,
@@ -772,6 +771,8 @@
          elif p.type.value == COMMANDS.COM_FIELD_LIST:
              # this command is deprecated, but console client still use it.
              response = SQLAnswer(RESPONSE_TYPE.OK)
+         elif p.type.value == COMMANDS.COM_STMT_RESET:
+             response = SQLAnswer(RESPONSE_TYPE.OK)
          else:
              logger.warning("Command has no specific handler, return OK msg")
              logger.debug(str(p))