MindsDB 25.5.4.1__py3-none-any.whl → 25.6.2.0__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of MindsDB might be problematic.
- mindsdb/__about__.py +1 -1
- mindsdb/api/a2a/agent.py +28 -25
- mindsdb/api/a2a/common/server/server.py +32 -26
- mindsdb/api/a2a/run_a2a.py +1 -1
- mindsdb/api/executor/command_executor.py +69 -14
- mindsdb/api/executor/datahub/datanodes/integration_datanode.py +49 -65
- mindsdb/api/executor/datahub/datanodes/project_datanode.py +29 -48
- mindsdb/api/executor/datahub/datanodes/system_tables.py +35 -61
- mindsdb/api/executor/planner/plan_join.py +67 -77
- mindsdb/api/executor/planner/query_planner.py +176 -155
- mindsdb/api/executor/planner/steps.py +37 -12
- mindsdb/api/executor/sql_query/result_set.py +45 -64
- mindsdb/api/executor/sql_query/steps/fetch_dataframe.py +14 -18
- mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +17 -18
- mindsdb/api/executor/sql_query/steps/insert_step.py +13 -33
- mindsdb/api/executor/sql_query/steps/subselect_step.py +43 -35
- mindsdb/api/executor/utilities/sql.py +42 -48
- mindsdb/api/http/namespaces/config.py +1 -1
- mindsdb/api/http/namespaces/file.py +14 -23
- mindsdb/api/mysql/mysql_proxy/data_types/mysql_datum.py +12 -28
- mindsdb/api/mysql/mysql_proxy/data_types/mysql_packets/binary_resultset_row_package.py +59 -50
- mindsdb/api/mysql/mysql_proxy/data_types/mysql_packets/resultset_row_package.py +9 -8
- mindsdb/api/mysql/mysql_proxy/libs/constants/mysql.py +449 -461
- mindsdb/api/mysql/mysql_proxy/utilities/dump.py +87 -36
- mindsdb/integrations/handlers/file_handler/file_handler.py +15 -9
- mindsdb/integrations/handlers/file_handler/tests/test_file_handler.py +43 -24
- mindsdb/integrations/handlers/litellm_handler/litellm_handler.py +10 -3
- mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +26 -33
- mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +74 -51
- mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +305 -98
- mindsdb/integrations/handlers/salesforce_handler/salesforce_handler.py +53 -34
- mindsdb/integrations/handlers/salesforce_handler/salesforce_tables.py +136 -6
- mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +334 -83
- mindsdb/integrations/libs/api_handler.py +261 -57
- mindsdb/integrations/libs/base.py +100 -29
- mindsdb/integrations/utilities/files/file_reader.py +99 -73
- mindsdb/integrations/utilities/handler_utils.py +23 -8
- mindsdb/integrations/utilities/sql_utils.py +35 -40
- mindsdb/interfaces/agents/agents_controller.py +196 -192
- mindsdb/interfaces/agents/constants.py +7 -1
- mindsdb/interfaces/agents/langchain_agent.py +42 -11
- mindsdb/interfaces/agents/mcp_client_agent.py +29 -21
- mindsdb/interfaces/data_catalog/__init__.py +0 -0
- mindsdb/interfaces/data_catalog/base_data_catalog.py +54 -0
- mindsdb/interfaces/data_catalog/data_catalog_loader.py +359 -0
- mindsdb/interfaces/data_catalog/data_catalog_reader.py +34 -0
- mindsdb/interfaces/database/database.py +81 -57
- mindsdb/interfaces/database/integrations.py +220 -234
- mindsdb/interfaces/database/log.py +72 -104
- mindsdb/interfaces/database/projects.py +156 -193
- mindsdb/interfaces/file/file_controller.py +21 -65
- mindsdb/interfaces/knowledge_base/controller.py +63 -10
- mindsdb/interfaces/knowledge_base/evaluate.py +519 -0
- mindsdb/interfaces/knowledge_base/llm_client.py +75 -0
- mindsdb/interfaces/skills/custom/text2sql/mindsdb_kb_tools.py +83 -43
- mindsdb/interfaces/skills/skills_controller.py +54 -36
- mindsdb/interfaces/skills/sql_agent.py +109 -86
- mindsdb/interfaces/storage/db.py +223 -79
- mindsdb/migrations/versions/2025-05-28_a44643042fe8_added_data_catalog_tables.py +118 -0
- mindsdb/migrations/versions/2025-06-09_608e376c19a7_updated_data_catalog_data_types.py +58 -0
- mindsdb/utilities/config.py +9 -2
- mindsdb/utilities/log.py +35 -26
- mindsdb/utilities/ml_task_queue/task.py +19 -22
- mindsdb/utilities/render/sqlalchemy_render.py +129 -181
- mindsdb/utilities/starters.py +49 -1
- {mindsdb-25.5.4.1.dist-info → mindsdb-25.6.2.0.dist-info}/METADATA +268 -268
- {mindsdb-25.5.4.1.dist-info → mindsdb-25.6.2.0.dist-info}/RECORD +70 -62
- {mindsdb-25.5.4.1.dist-info → mindsdb-25.6.2.0.dist-info}/WHEEL +0 -0
- {mindsdb-25.5.4.1.dist-info → mindsdb-25.6.2.0.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.5.4.1.dist-info → mindsdb-25.6.2.0.dist-info}/top_level.txt +0 -0
mindsdb/interfaces/skills/sql_agent.py +109 -86

@@ -14,6 +14,8 @@ from mindsdb.utilities.context import context as ctx
 from mindsdb.integrations.utilities.query_traversal import query_traversal
 from mindsdb.integrations.libs.response import INF_SCHEMA_COLUMNS_NAMES
 from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE
+from mindsdb.utilities.config import config
+from mindsdb.interfaces.data_catalog.data_catalog_reader import DataCatalogReader
 
 logger = log.getLogger(__name__)
 
@@ -28,7 +30,7 @@ def list_to_csv_str(array: List[List[Any]]) -> str:
         str: The array formatted as a CSV string using Excel dialect
     """
     output = StringIO()
-    writer = csv.writer(output, dialect='excel')
+    writer = csv.writer(output, dialect="excel")
     str_array = [[str(item) for item in row] for row in array]
     writer.writerows(str_array)
    return output.getvalue()
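The change above is quote style only; behavior is unchanged. For reference, the helper is self-contained and runnable as-is:

import csv
from io import StringIO
from typing import Any, List

def list_to_csv_str(array: List[List[Any]]) -> str:
    """Render a list of rows as a CSV string using the Excel dialect."""
    output = StringIO()
    writer = csv.writer(output, dialect="excel")
    # Every cell is stringified first, so non-string values (ints, None) are safe.
    str_array = [[str(item) for item in row] for row in array]
    writer.writerows(str_array)
    return output.getvalue()

print(list_to_csv_str([["id", "name"], [1, "alpha"], [2, "beta"]]))
# id,name
# 1,alpha
# 2,beta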
@@ -55,23 +57,23 @@ def split_table_name(table_name: str) -> List[str]:
     'input': '`aaa`.`bbb.ccc`', 'output': ['aaa', 'bbb.ccc']
     """
     result = []
-    current = ''
+    current = ""
     in_backticks = False
 
     i = 0
     while i < len(table_name):
-        if table_name[i] == '`':
+        if table_name[i] == "`":
             in_backticks = not in_backticks
-        elif table_name[i] == '.' and not in_backticks:
+        elif table_name[i] == "." and not in_backticks:
             if current:
-                result.append(current.strip('`'))
-                current = ''
+                result.append(current.strip("`"))
+                current = ""
         else:
             current += table_name[i]
         i += 1
 
     if current:
-        result.append(current.strip('`'))
+        result.append(current.strip("`"))
 
     # ensure we split the table name
     result = [r.split(".") for r in result][0]
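The splitter walks the name character by character, toggling `in_backticks` so dots inside backticks never act as separators. A standalone copy of the core loop with the docstring's example (the file's version additionally re-splits on unquoted dots at the end):

from typing import List

def split_table_name(table_name: str) -> List[str]:
    """Split a dotted identifier, treating backtick-quoted parts as atomic."""
    result = []
    current = ""
    in_backticks = False

    i = 0
    while i < len(table_name):
        if table_name[i] == "`":
            # Toggle quoting; dots inside backticks are literal, not separators.
            in_backticks = not in_backticks
        elif table_name[i] == "." and not in_backticks:
            if current:
                result.append(current.strip("`"))
                current = ""
        else:
            current += table_name[i]
        i += 1

    if current:
        result.append(current.strip("`"))

    return result

assert split_table_name("`aaa`.`bbb.ccc`") == ["aaa", "bbb.ccc"]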
@@ -89,13 +91,13 @@ class SQLAgent:
         command_executor,
         databases: List[str],
         databases_struct: dict,
-        knowledge_base_database: str = 'mindsdb',
+        knowledge_base_database: str = "mindsdb",
         include_tables: Optional[List[str]] = None,
         ignore_tables: Optional[List[str]] = None,
         include_knowledge_bases: Optional[List[str]] = None,
         ignore_knowledge_bases: Optional[List[str]] = None,
         sample_rows_in_table_info: int = 3,
-        cache: Optional[dict] = None
+        cache: Optional[dict] = None,
     ):
         """
         Initialize SQLAgent.
@@ -133,12 +135,13 @@ class SQLAgent:
         self._cache = cache
 
         from mindsdb.interfaces.skills.skill_tool import SkillToolController
+
         # Initialize the skill tool controller from MindsDB
         self.skill_tool = SkillToolController()
 
     def _call_engine(self, query: str, database=None):
         # switch database
-        ast_query = parse_sql(query.strip('`'))
+        ast_query = parse_sql(query.strip("`"))
         self._check_permissions(ast_query)
 
         if database is None:
@@ -148,10 +151,7 @@ class SQLAgent:
             # for now, we will just use the first one
             database = self._databases[0] if self._databases else "mindsdb"
 
-        ret = self._command_executor.execute_command(
-            ast_query,
-            database_name=database
-        )
+        ret = self._command_executor.execute_command(ast_query, database_name=database)
         return ret
 
     def _check_permissions(self, ast_query):
@@ -170,7 +170,7 @@ class SQLAgent:
 
         def _check_f(node, is_table=None, **kwargs):
             if is_table and isinstance(node, Identifier):
-                table_name = '.'.join(node.parts)
+                table_name = ".".join(node.parts)
 
                 # Get the list of available knowledge bases
                 kb_names = self.get_usable_knowledge_base_names()
@@ -182,16 +182,21 @@ class SQLAgent:
                 if is_kb and self._knowledge_bases_to_include:
                     kb_parts = [split_table_name(x) for x in self._knowledge_bases_to_include]
                     if node.parts not in kb_parts:
-                        raise ValueError(f"Knowledge base {table_name} not found. Available knowledge bases: {', '.join(self._knowledge_bases_to_include)}")
+                        raise ValueError(
+                            f"Knowledge base {table_name} not found. Available knowledge bases: {', '.join(self._knowledge_bases_to_include)}"
+                        )
                 # Regular table check
                 elif not is_kb and self._tables_to_include and node.parts not in tables_parts:
-                    raise ValueError(f"Table {table_name} not found. Available tables: {', '.join(self._tables_to_include)}")
+                    raise ValueError(
+                        f"Table {table_name} not found. Available tables: {', '.join(self._tables_to_include)}"
+                    )
                 # Check if it's a restricted knowledge base
                 elif is_kb and table_name in self._knowledge_bases_to_ignore:
                     raise ValueError(f"Knowledge base {table_name} is not allowed.")
                 # Check if it's a restricted table
                 elif not is_kb and table_name in self._tables_to_ignore:
                     raise ValueError(f"Table {table_name} is not allowed.")
+
         query_traversal(ast_query, _check_f)
 
     def get_usable_table_names(self) -> Iterable[str]:
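Inside `_check_permissions`, the traversal callback reduces each table reference to its split identifier parts and checks them against the include/ignore lists. A minimal sketch of that decision in isolation — `check_table_access` is a hypothetical name, not part of the diff:

from typing import List, Optional

def check_table_access(
    parts: List[str],
    include: Optional[List[List[str]]],  # allow-list of split names, or None for "no restriction"
    ignore: List[str],                   # deny-list of dotted names
) -> None:
    """Raise if a table reference is outside the allow-list or on the deny-list."""
    table_name = ".".join(parts)
    if include is not None and parts not in include:
        raise ValueError(f"Table {table_name} not found.")
    if table_name in ignore:
        raise ValueError(f"Table {table_name} is not allowed.")

check_table_access(["db", "sales"], include=[["db", "sales"]], ignore=[])  # passes
# check_table_access(["db", "hr"], include=[["db", "sales"]], ignore=[])   # raises ValueError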
@@ -200,7 +205,7 @@ class SQLAgent:
         Returns:
             Iterable[str]: list with table names
         """
-        cache_key = f'{ctx.company_id}_{",".join(self._databases)}_tables'
+        cache_key = f"{ctx.company_id}_{','.join(self._databases)}_tables"
 
         # first check cache and return if found
         if self._cache:
@@ -218,13 +223,13 @@ class SQLAgent:
 
             schemas_names = list(self._mindsdb_db_struct[db_name].keys())
             if len(schemas_names) > 1 and None in schemas_names:
-                raise Exception('default schema and named schemas can not be used in same filter')
+                raise Exception("default schema and named schemas can not be used in same filter")
 
             if None in schemas_names:
                 # get tables only from default schema
                 response = handler.get_tables()
                 tables_in_default_schema = list(response.data_frame.table_name)
-                schema_tables_restrictions = self._mindsdb_db_struct[db_name][None]
+                schema_tables_restrictions = self._mindsdb_db_struct[db_name][None]  # None - is default schema
                 if schema_tables_restrictions is None:
                     for table_name in tables_in_default_schema:
                         result_tables.append([db_name, table_name])
@@ -233,27 +238,27 @@ class SQLAgent:
                         if table_name in tables_in_default_schema:
                             result_tables.append([db_name, table_name])
             else:
-                if 'all' in inspect.signature(handler.get_tables).parameters:
+                if "all" in inspect.signature(handler.get_tables).parameters:
                     response = handler.get_tables(all=True)
                 else:
                     response = handler.get_tables()
                 response_schema_names = list(response.data_frame.table_schema.unique())
                 schemas_intersection = set(schemas_names) & set(response_schema_names)
                 if len(schemas_intersection) == 0:
-                    raise Exception('There are no allowed schemas in ds')
+                    raise Exception("There are no allowed schemas in ds")
 
                 for schema_name in schemas_intersection:
-                    schema_sub_df = response.data_frame[response.data_frame['table_schema'] == schema_name]
+                    schema_sub_df = response.data_frame[response.data_frame["table_schema"] == schema_name]
                     if self._mindsdb_db_struct[db_name][schema_name] is None:
                         # all tables from schema allowed
                         for row in schema_sub_df:
-                            result_tables.append([db_name, schema_name, row['table_name']])
+                            result_tables.append([db_name, schema_name, row["table_name"]])
                     else:
                         for table_name in self._mindsdb_db_struct[db_name][schema_name]:
-                            if table_name in schema_sub_df['table_name'].values:
+                            if table_name in schema_sub_df["table_name"].values:
                                 result_tables.append([db_name, schema_name, table_name])
 
-        result_tables = ['.'.join(x) for x in result_tables]
+        result_tables = [".".join(x) for x in result_tables]
         if self._cache:
             self._cache.set(cache_key, set(result_tables))
         return result_tables
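The nested `databases_struct` filter read here maps database to schema to tables, where `None` at the schema key means the default schema and `None` at the table level means no restriction. A toy resolution against a fake catalog (plain dicts standing in for the handler's `get_tables` response; all names hypothetical):

# database -> schema (None = default schema) -> allowed tables (None = all tables)
databases_struct = {
    "pg": {"public": ["orders", "users"], "audit": None},
}

# What the data source actually contains, schema -> tables.
catalog = {"public": ["orders", "users", "tmp"], "audit": ["log"]}

result_tables = []
for schema, allowed in databases_struct["pg"].items():
    present = catalog.get(schema, [])
    # None means every table in the schema is allowed.
    names = present if allowed is None else [t for t in allowed if t in present]
    result_tables += [f"pg.{schema}.{t}" for t in names]

print(result_tables)
# ['pg.public.orders', 'pg.public.users', 'pg.audit.log']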
@@ -264,14 +269,14 @@ class SQLAgent:
         Returns:
             Iterable[str]: list with knowledge base names
         """
-        cache_key = f'{ctx.company_id}_{self.knowledge_base_database}_knowledge_bases'
+        cache_key = f"{ctx.company_id}_{self.knowledge_base_database}_knowledge_bases"
 
         # todo we need to fix the cache, file cache can potentially store out of data information
         # # first check cache and return if found
         # if self._cache:
-        #    cached_kbs = self._cache.get(cache_key)
+        #     cached_kbs = self._cache.get(cache_key)
         # if cached_kbs:
-        #    return cached_kbs
+        #     return cached_kbs
 
         if self._knowledge_bases_to_include:
             return self._knowledge_bases_to_include
@@ -289,13 +294,15 @@ class SQLAgent:
             try:
                 # Get knowledge bases from the project database
                 kb_controller = self._command_executor.session.kb_controller
-                kb_names = [kb['name'] for kb in kb_controller.list()]
+                kb_names = [kb["name"] for kb in kb_controller.list()]
 
                 # Filter knowledge bases based on include list
                 if self._knowledge_bases_to_include:
                     kb_names = [kb_name for kb_name in kb_names if kb_name in self._knowledge_bases_to_include]
                     if not kb_names:
-                        logger.warning(f"No knowledge bases found in the include list: {self._knowledge_bases_to_include}")
+                        logger.warning(
+                            f"No knowledge bases found in the include list: {self._knowledge_bases_to_include}"
+                        )
                         return []
 
                 return kb_names
@@ -317,7 +324,7 @@ class SQLAgent:
         # Filter knowledge bases based on ignore list
         kb_names = []
         for row in result:
-            kb_name = row['name']
+            kb_name = row["name"]
             if kb_name not in self._knowledge_bases_to_ignore:
                 kb_names.append(kb_name)
 
@@ -368,7 +375,7 @@ class SQLAgent:
         return tables
 
     def get_knowledge_base_info(self, kb_names: Optional[List[str]] = None) -> str:
-        """
+        """Get information about specified knowledge bases.
         Follows best practices as specified in: Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498)
         If `sample_rows_in_table_info`, the specified number of sample rows will be
         appended to each table description. This can increase performance as demonstrated in the paper.
@@ -388,38 +395,56 @@ class SQLAgent:
         return "\n\n".join(kbs_info)
 
     def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
-        """
+        """Get information about specified tables.
         Follows best practices as specified in: Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498)
         If `sample_rows_in_table_info`, the specified number of sample rows will be
         appended to each table description. This can increase performance as demonstrated in the paper.
         """
+        if config.get("data_catalog", {}).get("enabled", False):
+            database_table_map = {}
+            for name in self.get_usable_table_names():
+                name = name.replace("`", "")
 
-        all_tables = []
-        for name in self.get_usable_table_names():
-            # remove backticks
-            name = name.replace("`", "")
+                # TODO: Can there be situations where the database name is returned from the above method?
+                parts = name.split(".", 1)
+                database_table_map[parts[0]] = database_table_map.get(parts[0], []) + [parts[1]]
 
-            split = name.split(".")
-            if len(split) > 1:
-                all_tables.append(Identifier(parts=[split[0], split[1]]))
-            else:
-                all_tables.append(Identifier(name))
+            data_catalog_str = ""
+            for database_name, table_names in database_table_map.items():
+                data_catalog_reader = DataCatalogReader(database_name=database_name, table_names=table_names)
 
-        # if table_names is not None:
-        #     all_tables = self._resolve_table_names(table_names, all_tables)
+                data_catalog_str += data_catalog_reader.read_metadata_as_string()
 
-        tables_info = []
-        for table in all_tables:
-            key = f"{ctx.company_id}_{table}_info"
-            table_info = self._cache.get(key) if self._cache else None
-            if True or table_info is None:
-                table_info = self._get_single_table_info(table)
-                if self._cache:
-                    self._cache.set(key, table_info)
+            return data_catalog_str
 
-            tables_info.append(table_info)
+        else:
+            # TODO: Improve old logic without data catalog
+            all_tables = []
+            for name in self.get_usable_table_names():
+                # remove backticks
+                name = name.replace("`", "")
+
+                split = name.split(".")
+                if len(split) > 1:
+                    all_tables.append(Identifier(parts=[split[0], split[1]]))
+                else:
+                    all_tables.append(Identifier(name))
+
+            # if table_names is not None:
+            #     all_tables = self._resolve_table_names(table_names, all_tables)
+
+            tables_info = []
+            for table in all_tables:
+                key = f"{ctx.company_id}_{table}_info"
+                table_info = self._cache.get(key) if self._cache else None
+                if True or table_info is None:
+                    table_info = self._get_single_table_info(table)
+                    if self._cache:
+                        self._cache.set(key, table_info)
+
+                tables_info.append(table_info)
 
-        return "\n\n".join(tables_info)
+            return "\n\n".join(tables_info)
 
     def get_kb_sample_rows(self, kb_name: str) -> str:
         """Get sample rows from a knowledge base.
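With the `data_catalog` flag enabled, table descriptions now come from `DataCatalogReader` rather than live sampling, so the method only has to group dotted table names by database before handing them over. That grouping step in isolation:

# Group dotted table names by database, splitting only on the first dot,
# exactly as get_table_info does before calling DataCatalogReader.
names = ["pg.public.orders", "pg.public.users", "mysql.sales"]

database_table_map = {}
for name in names:
    database, table = name.split(".", 1)
    database_table_map[database] = database_table_map.get(database, []) + [table]

print(database_table_map)
# {'pg': ['public.orders', 'public.users'], 'mysql': ['sales']}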
@@ -430,7 +455,7 @@ class SQLAgent:
         Returns:
             str: A string containing the sample rows from the knowledge base.
         """
-        logger.info(f'_get_sample_rows: knowledge base={kb_name}')
+        logger.info(f"_get_sample_rows: knowledge base={kb_name}")
         command = f"select * from {kb_name} limit 10;"
         try:
             ret = self._call_engine(command)
@@ -438,13 +463,12 @@ class SQLAgent:
 
             def truncate_value(val):
                 str_val = str(val)
-                return str_val if len(str_val) < 100 else (str_val[:100] + '...')
+                return str_val if len(str_val) < 100 else (str_val[:100] + "...")
 
-            sample_rows = list(
-                map(lambda row: [truncate_value(value) for value in row], sample_rows))
+            sample_rows = list(map(lambda row: [truncate_value(value) for value in row], sample_rows))
             sample_rows_str = "\n" + f"{kb_name}:" + list_to_csv_str(sample_rows)
         except Exception as e:
-            logger.info(f'_get_sample_rows error: {e}')
+            logger.info(f"_get_sample_rows error: {e}")
             sample_rows_str = "\n" + "\t [error] Couldn't retrieve sample rows!"
 
         return sample_rows_str
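Both sample-row helpers clip each cell to 100 characters before CSV rendering, which bounds how much a single wide value can inflate the agent prompt. The clipping step on its own:

def truncate_value(val, limit: int = 100) -> str:
    """Stringify a cell and clip it so one long value can't blow up the prompt."""
    str_val = str(val)
    return str_val if len(str_val) < limit else (str_val[:limit] + "...")

rows = [["id", "body"], [1, "x" * 500]]
clipped = [[truncate_value(value) for value in row] for row in rows]
print(clipped[1][1][-6:])  # 'xxx...' - 100 chars kept, plus an ellipsis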
@@ -471,11 +495,9 @@ class SQLAgent:
 
             fields = df[INF_SCHEMA_COLUMNS_NAMES.COLUMN_NAME].to_list()
             dtypes = [
-                mysql_data_type.value if isinstance(mysql_data_type, MYSQL_DATA_TYPE) else (data_type or 'UNKNOWN')
-                for mysql_data_type, data_type
-                in zip(
-                    df[INF_SCHEMA_COLUMNS_NAMES.MYSQL_DATA_TYPE],
-                    df[INF_SCHEMA_COLUMNS_NAMES.DATA_TYPE]
+                mysql_data_type.value if isinstance(mysql_data_type, MYSQL_DATA_TYPE) else (data_type or "UNKNOWN")
+                for mysql_data_type, data_type in zip(
+                    df[INF_SCHEMA_COLUMNS_NAMES.MYSQL_DATA_TYPE], df[INF_SCHEMA_COLUMNS_NAMES.DATA_TYPE]
                 )
             ]
         except Exception as e:
@@ -492,16 +514,18 @@ class SQLAgent:
             logger.warning(f"Could not get sample rows for {table_str}: {e}")
             sample_rows_info = "\n\t [error] Couldn't retrieve sample rows!"
 
-        info = f'Table named `{table_str}`:\n'
+        info = f"Table named `{table_str}`:\n"
         info += f"\nSample with first {self._sample_rows_in_table_info} rows from table {table_str} in CSV format (dialect is 'excel'):\n"
         info += sample_rows_info + "\n"
-        info += '\nColumn data types: ' + \
-            ',\t'.join([f'\n`{field}` : `{dtype}`' for field, dtype in zip(fields, dtypes)]) + \
-            '\n'
+        info += (
+            "\nColumn data types: "
+            + ",\t".join([f"\n`{field}` : `{dtype}`" for field, dtype in zip(fields, dtypes)])
+            + "\n"
+        )
         return info
 
     def _get_sample_rows(self, table: str, fields: List[str]) -> str:
-        logger.info(f'_get_sample_rows: table={table} fields={fields}')
+        logger.info(f"_get_sample_rows: table={table} fields={fields}")
         command = f"select {', '.join(fields)} from {table} limit {self._sample_rows_in_table_info};"
         try:
             ret = self._call_engine(command)
@@ -509,20 +533,19 @@ class SQLAgent:
 
             def truncate_value(val):
                 str_val = str(val)
-                return str_val if len(str_val) < 100 else (str_val[:100] + '...')
+                return str_val if len(str_val) < 100 else (str_val[:100] + "...")
 
-            sample_rows = list(
-                map(lambda row: [truncate_value(value) for value in row], sample_rows))
+            sample_rows = list(map(lambda row: [truncate_value(value) for value in row], sample_rows))
             sample_rows_str = "\n" + list_to_csv_str([fields] + sample_rows)
         except Exception as e:
-            logger.info(f'_get_sample_rows error: {e}')
+            logger.info(f"_get_sample_rows error: {e}")
             sample_rows_str = "\n" + "\t [error] Couldn't retrieve sample rows!"
 
         return sample_rows_str
 
     def _clean_query(self, query: str) -> str:
         # Sometimes LLM can input markdown into query tools.
-        cmd = re.sub(r'```(sql)?', '', query)
+        cmd = re.sub(r"```(sql)?", "", query)
         return cmd
@@ -534,16 +557,16 @@ class SQLAgent:
         def _repr_result(ret):
             limit_rows = 30
 
-            columns_str = ', '.join([repr(col.name) for col in ret.columns])
-            res = f'Output columns: {columns_str}\n'
+            columns_str = ", ".join([repr(col.name) for col in ret.columns])
+            res = f"Output columns: {columns_str}\n"
 
             data = ret.to_lists()
             if len(data) > limit_rows:
                 df = pd.DataFrame(data, columns=[col.name for col in ret.columns])
 
-                res += f'Result has {len(data)} rows. Description of data:\n'
-                res += str(df.describe(include='all')) + '\n\n'
-                res += f'First {limit_rows} rows:\n'
+                res += f"Result has {len(data)} rows. Description of data:\n"
+                res += str(df.describe(include="all")) + "\n\n"
+                res += f"First {limit_rows} rows:\n"
 
             else:
                 res += "Result in CSV format (dialect is 'excel'):\n"
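For results longer than 30 rows, `_repr_result` substitutes a pandas `describe()` summary plus the first rows instead of dumping everything, so the agent always sees a bounded view of a result set. A sketch of the same branching over a plain DataFrame (the real method consumes MindsDB result sets, not DataFrames):

import pandas as pd

LIMIT_ROWS = 30

def repr_result(df: pd.DataFrame) -> str:
    """Summarize large results instead of inlining every row."""
    res = f"Output columns: {', '.join(map(repr, df.columns))}\n"
    if len(df) > LIMIT_ROWS:
        res += f"Result has {len(df)} rows. Description of data:\n"
        res += str(df.describe(include="all")) + "\n\n"
        res += f"First {LIMIT_ROWS} rows:\n"
        res += df.head(LIMIT_ROWS).to_csv(index=False)
    else:
        res += "Result in CSV format (dialect is 'excel'):\n"
        res += df.to_csv(index=False)
    return res

print(repr_result(pd.DataFrame({"a": range(100)})))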
@@ -562,20 +585,20 @@ class SQLAgent:
 
     def get_table_info_safe(self, table_names: Optional[List[str]] = None) -> str:
         try:
-            logger.info(f'get_table_info_safe: {table_names}')
+            logger.info(f"get_table_info_safe: {table_names}")
             return self.get_table_info(table_names)
         except Exception as e:
-            logger.info(f'get_table_info_safe error: {e}')
+            logger.info(f"get_table_info_safe error: {e}")
             return f"Error: {e}"
 
     def query_safe(self, command: str, fetch: str = "all") -> str:
         try:
-            logger.info(f'query_safe (fetch={fetch}): {command}')
+            logger.info(f"query_safe (fetch={fetch}): {command}")
             return self.query(command, fetch)
         except Exception as e:
             logger.error(f"Error in query_safe: {str(e)}\n{traceback.format_exc()}")
-            logger.info(f'query_safe error: {e}')
+            logger.info(f"query_safe error: {e}")
             msg = f"Error: {e}"
-            if 'does not exist' in msg and ' relation ' in msg:
-                msg += '\nAvailable tables: ' + ', '.join(self.get_usable_table_names())
+            if "does not exist" in msg and " relation " in msg:
+                msg += "\nAvailable tables: " + ", ".join(self.get_usable_table_names())
             return msg