MindsDB 25.5.4.2__py3-none-any.whl → 25.6.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MindsDB has been flagged as potentially problematic; see the registry's advisory for this release for more details.
- mindsdb/__about__.py +1 -1
- mindsdb/api/a2a/agent.py +50 -26
- mindsdb/api/a2a/common/server/server.py +32 -26
- mindsdb/api/a2a/task_manager.py +68 -6
- mindsdb/api/executor/command_executor.py +69 -14
- mindsdb/api/executor/datahub/datanodes/integration_datanode.py +49 -65
- mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +91 -84
- mindsdb/api/executor/datahub/datanodes/project_datanode.py +29 -48
- mindsdb/api/executor/datahub/datanodes/system_tables.py +35 -61
- mindsdb/api/executor/planner/plan_join.py +67 -77
- mindsdb/api/executor/planner/query_planner.py +176 -155
- mindsdb/api/executor/planner/steps.py +37 -12
- mindsdb/api/executor/sql_query/result_set.py +45 -64
- mindsdb/api/executor/sql_query/steps/fetch_dataframe.py +14 -18
- mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +17 -18
- mindsdb/api/executor/sql_query/steps/insert_step.py +13 -33
- mindsdb/api/executor/sql_query/steps/subselect_step.py +43 -35
- mindsdb/api/executor/utilities/sql.py +42 -48
- mindsdb/api/http/namespaces/config.py +1 -1
- mindsdb/api/http/namespaces/file.py +14 -23
- mindsdb/api/http/namespaces/knowledge_bases.py +132 -154
- mindsdb/api/mysql/mysql_proxy/data_types/mysql_datum.py +12 -28
- mindsdb/api/mysql/mysql_proxy/data_types/mysql_packets/binary_resultset_row_package.py +59 -50
- mindsdb/api/mysql/mysql_proxy/data_types/mysql_packets/resultset_row_package.py +9 -8
- mindsdb/api/mysql/mysql_proxy/libs/constants/mysql.py +449 -461
- mindsdb/api/mysql/mysql_proxy/utilities/dump.py +87 -36
- mindsdb/integrations/handlers/bigquery_handler/bigquery_handler.py +219 -28
- mindsdb/integrations/handlers/file_handler/file_handler.py +15 -9
- mindsdb/integrations/handlers/file_handler/tests/test_file_handler.py +43 -24
- mindsdb/integrations/handlers/litellm_handler/litellm_handler.py +10 -3
- mindsdb/integrations/handlers/llama_index_handler/requirements.txt +1 -1
- mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +29 -33
- mindsdb/integrations/handlers/openai_handler/openai_handler.py +277 -356
- mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +74 -51
- mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +305 -98
- mindsdb/integrations/handlers/salesforce_handler/salesforce_handler.py +145 -40
- mindsdb/integrations/handlers/salesforce_handler/salesforce_tables.py +136 -6
- mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +352 -83
- mindsdb/integrations/libs/api_handler.py +279 -57
- mindsdb/integrations/libs/base.py +185 -30
- mindsdb/integrations/utilities/files/file_reader.py +99 -73
- mindsdb/integrations/utilities/handler_utils.py +23 -8
- mindsdb/integrations/utilities/sql_utils.py +35 -40
- mindsdb/interfaces/agents/agents_controller.py +226 -196
- mindsdb/interfaces/agents/constants.py +8 -1
- mindsdb/interfaces/agents/langchain_agent.py +42 -11
- mindsdb/interfaces/agents/mcp_client_agent.py +29 -21
- mindsdb/interfaces/agents/mindsdb_database_agent.py +23 -18
- mindsdb/interfaces/data_catalog/__init__.py +0 -0
- mindsdb/interfaces/data_catalog/base_data_catalog.py +54 -0
- mindsdb/interfaces/data_catalog/data_catalog_loader.py +375 -0
- mindsdb/interfaces/data_catalog/data_catalog_reader.py +38 -0
- mindsdb/interfaces/database/database.py +81 -57
- mindsdb/interfaces/database/integrations.py +222 -234
- mindsdb/interfaces/database/log.py +72 -104
- mindsdb/interfaces/database/projects.py +156 -193
- mindsdb/interfaces/file/file_controller.py +21 -65
- mindsdb/interfaces/knowledge_base/controller.py +66 -25
- mindsdb/interfaces/knowledge_base/evaluate.py +516 -0
- mindsdb/interfaces/knowledge_base/llm_client.py +75 -0
- mindsdb/interfaces/skills/custom/text2sql/mindsdb_kb_tools.py +83 -43
- mindsdb/interfaces/skills/skills_controller.py +31 -36
- mindsdb/interfaces/skills/sql_agent.py +113 -86
- mindsdb/interfaces/storage/db.py +242 -82
- mindsdb/migrations/versions/2025-05-28_a44643042fe8_added_data_catalog_tables.py +118 -0
- mindsdb/migrations/versions/2025-06-09_608e376c19a7_updated_data_catalog_data_types.py +58 -0
- mindsdb/utilities/config.py +13 -2
- mindsdb/utilities/log.py +35 -26
- mindsdb/utilities/ml_task_queue/task.py +19 -22
- mindsdb/utilities/render/sqlalchemy_render.py +129 -181
- mindsdb/utilities/starters.py +40 -0
- {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.3.0.dist-info}/METADATA +257 -257
- {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.3.0.dist-info}/RECORD +76 -68
- {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.3.0.dist-info}/WHEEL +0 -0
- {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.3.0.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.3.0.dist-info}/top_level.txt +0 -0
|
@@ -14,6 +14,8 @@ from mindsdb.utilities.context import context as ctx
|
|
|
14
14
|
from mindsdb.integrations.utilities.query_traversal import query_traversal
|
|
15
15
|
from mindsdb.integrations.libs.response import INF_SCHEMA_COLUMNS_NAMES
|
|
16
16
|
from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE
|
|
17
|
+
from mindsdb.utilities.config import config
|
|
18
|
+
from mindsdb.interfaces.data_catalog.data_catalog_reader import DataCatalogReader
|
|
17
19
|
|
|
18
20
|
logger = log.getLogger(__name__)
|
|
19
21
|
|
|
@@ -28,7 +30,7 @@ def list_to_csv_str(array: List[List[Any]]) -> str:
|
|
|
28
30
|
str: The array formatted as a CSV string using Excel dialect
|
|
29
31
|
"""
|
|
30
32
|
output = StringIO()
|
|
31
|
-
writer = csv.writer(output, dialect=
|
|
33
|
+
writer = csv.writer(output, dialect="excel")
|
|
32
34
|
str_array = [[str(item) for item in row] for row in array]
|
|
33
35
|
writer.writerows(str_array)
|
|
34
36
|
return output.getvalue()
|
|
@@ -55,23 +57,23 @@ def split_table_name(table_name: str) -> List[str]:
|
|
|
55
57
|
'input': '`aaa`.`bbb.ccc`', 'output': ['aaa', 'bbb.ccc']
|
|
56
58
|
"""
|
|
57
59
|
result = []
|
|
58
|
-
current =
|
|
60
|
+
current = ""
|
|
59
61
|
in_backticks = False
|
|
60
62
|
|
|
61
63
|
i = 0
|
|
62
64
|
while i < len(table_name):
|
|
63
|
-
if table_name[i] ==
|
|
65
|
+
if table_name[i] == "`":
|
|
64
66
|
in_backticks = not in_backticks
|
|
65
|
-
elif table_name[i] ==
|
|
67
|
+
elif table_name[i] == "." and not in_backticks:
|
|
66
68
|
if current:
|
|
67
|
-
result.append(current.strip(
|
|
68
|
-
current =
|
|
69
|
+
result.append(current.strip("`"))
|
|
70
|
+
current = ""
|
|
69
71
|
else:
|
|
70
72
|
current += table_name[i]
|
|
71
73
|
i += 1
|
|
72
74
|
|
|
73
75
|
if current:
|
|
74
|
-
result.append(current.strip(
|
|
76
|
+
result.append(current.strip("`"))
|
|
75
77
|
|
|
76
78
|
# ensure we split the table name
|
|
77
79
|
result = [r.split(".") for r in result][0]
|
|
@@ -89,13 +91,13 @@ class SQLAgent:
|
|
|
89
91
|
command_executor,
|
|
90
92
|
databases: List[str],
|
|
91
93
|
databases_struct: dict,
|
|
92
|
-
knowledge_base_database: str =
|
|
94
|
+
knowledge_base_database: str = "mindsdb",
|
|
93
95
|
include_tables: Optional[List[str]] = None,
|
|
94
96
|
ignore_tables: Optional[List[str]] = None,
|
|
95
97
|
include_knowledge_bases: Optional[List[str]] = None,
|
|
96
98
|
ignore_knowledge_bases: Optional[List[str]] = None,
|
|
97
99
|
sample_rows_in_table_info: int = 3,
|
|
98
|
-
cache: Optional[dict] = None
|
|
100
|
+
cache: Optional[dict] = None,
|
|
99
101
|
):
|
|
100
102
|
"""
|
|
101
103
|
Initialize SQLAgent.
|
|
@@ -133,12 +135,13 @@ class SQLAgent:
|
|
|
133
135
|
self._cache = cache
|
|
134
136
|
|
|
135
137
|
from mindsdb.interfaces.skills.skill_tool import SkillToolController
|
|
138
|
+
|
|
136
139
|
# Initialize the skill tool controller from MindsDB
|
|
137
140
|
self.skill_tool = SkillToolController()
|
|
138
141
|
|
|
139
142
|
def _call_engine(self, query: str, database=None):
|
|
140
143
|
# switch database
|
|
141
|
-
ast_query = parse_sql(query.strip(
|
|
144
|
+
ast_query = parse_sql(query.strip("`"))
|
|
142
145
|
self._check_permissions(ast_query)
|
|
143
146
|
|
|
144
147
|
if database is None:
|
|
@@ -148,10 +151,7 @@ class SQLAgent:
|
|
|
148
151
|
# for now, we will just use the first one
|
|
149
152
|
database = self._databases[0] if self._databases else "mindsdb"
|
|
150
153
|
|
|
151
|
-
ret = self._command_executor.execute_command(
|
|
152
|
-
ast_query,
|
|
153
|
-
database_name=database
|
|
154
|
-
)
|
|
154
|
+
ret = self._command_executor.execute_command(ast_query, database_name=database)
|
|
155
155
|
return ret
|
|
156
156
|
|
|
157
157
|
def _check_permissions(self, ast_query):
|
|
@@ -170,7 +170,7 @@ class SQLAgent:
|
|
|
170
170
|
|
|
171
171
|
def _check_f(node, is_table=None, **kwargs):
|
|
172
172
|
if is_table and isinstance(node, Identifier):
|
|
173
|
-
table_name =
|
|
173
|
+
table_name = ".".join(node.parts)
|
|
174
174
|
|
|
175
175
|
# Get the list of available knowledge bases
|
|
176
176
|
kb_names = self.get_usable_knowledge_base_names()
|
|
@@ -182,16 +182,21 @@ class SQLAgent:
|
|
|
182
182
|
if is_kb and self._knowledge_bases_to_include:
|
|
183
183
|
kb_parts = [split_table_name(x) for x in self._knowledge_bases_to_include]
|
|
184
184
|
if node.parts not in kb_parts:
|
|
185
|
-
raise ValueError(
|
|
185
|
+
raise ValueError(
|
|
186
|
+
f"Knowledge base {table_name} not found. Available knowledge bases: {', '.join(self._knowledge_bases_to_include)}"
|
|
187
|
+
)
|
|
186
188
|
# Regular table check
|
|
187
189
|
elif not is_kb and self._tables_to_include and node.parts not in tables_parts:
|
|
188
|
-
raise ValueError(
|
|
190
|
+
raise ValueError(
|
|
191
|
+
f"Table {table_name} not found. Available tables: {', '.join(self._tables_to_include)}"
|
|
192
|
+
)
|
|
189
193
|
# Check if it's a restricted knowledge base
|
|
190
194
|
elif is_kb and table_name in self._knowledge_bases_to_ignore:
|
|
191
195
|
raise ValueError(f"Knowledge base {table_name} is not allowed.")
|
|
192
196
|
# Check if it's a restricted table
|
|
193
197
|
elif not is_kb and table_name in self._tables_to_ignore:
|
|
194
198
|
raise ValueError(f"Table {table_name} is not allowed.")
|
|
199
|
+
|
|
195
200
|
query_traversal(ast_query, _check_f)
|
|
196
201
|
|
|
197
202
|
def get_usable_table_names(self) -> Iterable[str]:
|
|
@@ -200,7 +205,7 @@ class SQLAgent:
|
|
|
200
205
|
Returns:
|
|
201
206
|
Iterable[str]: list with table names
|
|
202
207
|
"""
|
|
203
|
-
cache_key = f
|
|
208
|
+
cache_key = f"{ctx.company_id}_{','.join(self._databases)}_tables"
|
|
204
209
|
|
|
205
210
|
# first check cache and return if found
|
|
206
211
|
if self._cache:
|
|
@@ -218,13 +223,13 @@ class SQLAgent:
|
|
|
218
223
|
|
|
219
224
|
schemas_names = list(self._mindsdb_db_struct[db_name].keys())
|
|
220
225
|
if len(schemas_names) > 1 and None in schemas_names:
|
|
221
|
-
raise Exception(
|
|
226
|
+
raise Exception("default schema and named schemas can not be used in same filter")
|
|
222
227
|
|
|
223
228
|
if None in schemas_names:
|
|
224
229
|
# get tables only from default schema
|
|
225
230
|
response = handler.get_tables()
|
|
226
231
|
tables_in_default_schema = list(response.data_frame.table_name)
|
|
227
|
-
schema_tables_restrictions = self._mindsdb_db_struct[db_name][None]
|
|
232
|
+
schema_tables_restrictions = self._mindsdb_db_struct[db_name][None] # None - is default schema
|
|
228
233
|
if schema_tables_restrictions is None:
|
|
229
234
|
for table_name in tables_in_default_schema:
|
|
230
235
|
result_tables.append([db_name, table_name])
|
|
@@ -233,27 +238,27 @@ class SQLAgent:
|
|
|
233
238
|
if table_name in tables_in_default_schema:
|
|
234
239
|
result_tables.append([db_name, table_name])
|
|
235
240
|
else:
|
|
236
|
-
if
|
|
241
|
+
if "all" in inspect.signature(handler.get_tables).parameters:
|
|
237
242
|
response = handler.get_tables(all=True)
|
|
238
243
|
else:
|
|
239
244
|
response = handler.get_tables()
|
|
240
245
|
response_schema_names = list(response.data_frame.table_schema.unique())
|
|
241
246
|
schemas_intersection = set(schemas_names) & set(response_schema_names)
|
|
242
247
|
if len(schemas_intersection) == 0:
|
|
243
|
-
raise Exception(
|
|
248
|
+
raise Exception("There are no allowed schemas in ds")
|
|
244
249
|
|
|
245
250
|
for schema_name in schemas_intersection:
|
|
246
|
-
schema_sub_df = response.data_frame[response.data_frame[
|
|
251
|
+
schema_sub_df = response.data_frame[response.data_frame["table_schema"] == schema_name]
|
|
247
252
|
if self._mindsdb_db_struct[db_name][schema_name] is None:
|
|
248
253
|
# all tables from schema allowed
|
|
249
254
|
for row in schema_sub_df:
|
|
250
|
-
result_tables.append([db_name, schema_name, row[
|
|
255
|
+
result_tables.append([db_name, schema_name, row["table_name"]])
|
|
251
256
|
else:
|
|
252
257
|
for table_name in self._mindsdb_db_struct[db_name][schema_name]:
|
|
253
|
-
if table_name in schema_sub_df[
|
|
258
|
+
if table_name in schema_sub_df["table_name"].values:
|
|
254
259
|
result_tables.append([db_name, schema_name, table_name])
|
|
255
260
|
|
|
256
|
-
result_tables = [
|
|
261
|
+
result_tables = [".".join(x) for x in result_tables]
|
|
257
262
|
if self._cache:
|
|
258
263
|
self._cache.set(cache_key, set(result_tables))
|
|
259
264
|
return result_tables
|
|
@@ -264,14 +269,14 @@ class SQLAgent:
|
|
|
264
269
|
Returns:
|
|
265
270
|
Iterable[str]: list with knowledge base names
|
|
266
271
|
"""
|
|
267
|
-
cache_key = f
|
|
272
|
+
cache_key = f"{ctx.company_id}_{self.knowledge_base_database}_knowledge_bases"
|
|
268
273
|
|
|
269
274
|
# todo we need to fix the cache, file cache can potentially store out of data information
|
|
270
275
|
# # first check cache and return if found
|
|
271
276
|
# if self._cache:
|
|
272
|
-
#
|
|
277
|
+
# cached_kbs = self._cache.get(cache_key)
|
|
273
278
|
# if cached_kbs:
|
|
274
|
-
#
|
|
279
|
+
# return cached_kbs
|
|
275
280
|
|
|
276
281
|
if self._knowledge_bases_to_include:
|
|
277
282
|
return self._knowledge_bases_to_include
|
|
@@ -289,13 +294,15 @@ class SQLAgent:
|
|
|
289
294
|
try:
|
|
290
295
|
# Get knowledge bases from the project database
|
|
291
296
|
kb_controller = self._command_executor.session.kb_controller
|
|
292
|
-
kb_names = [kb[
|
|
297
|
+
kb_names = [kb["name"] for kb in kb_controller.list()]
|
|
293
298
|
|
|
294
299
|
# Filter knowledge bases based on include list
|
|
295
300
|
if self._knowledge_bases_to_include:
|
|
296
301
|
kb_names = [kb_name for kb_name in kb_names if kb_name in self._knowledge_bases_to_include]
|
|
297
302
|
if not kb_names:
|
|
298
|
-
logger.warning(
|
|
303
|
+
logger.warning(
|
|
304
|
+
f"No knowledge bases found in the include list: {self._knowledge_bases_to_include}"
|
|
305
|
+
)
|
|
299
306
|
return []
|
|
300
307
|
|
|
301
308
|
return kb_names
|
|
@@ -317,7 +324,7 @@ class SQLAgent:
|
|
|
317
324
|
# Filter knowledge bases based on ignore list
|
|
318
325
|
kb_names = []
|
|
319
326
|
for row in result:
|
|
320
|
-
kb_name = row[
|
|
327
|
+
kb_name = row["name"]
|
|
321
328
|
if kb_name not in self._knowledge_bases_to_ignore:
|
|
322
329
|
kb_names.append(kb_name)
|
|
323
330
|
|
|
@@ -368,7 +375,7 @@ class SQLAgent:
|
|
|
368
375
|
return tables
|
|
369
376
|
|
|
370
377
|
def get_knowledge_base_info(self, kb_names: Optional[List[str]] = None) -> str:
|
|
371
|
-
"""
|
|
378
|
+
"""Get information about specified knowledge bases.
|
|
372
379
|
Follows best practices as specified in: Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498)
|
|
373
380
|
If `sample_rows_in_table_info`, the specified number of sample rows will be
|
|
374
381
|
appended to each table description. This can increase performance as demonstrated in the paper.
|
|
@@ -388,38 +395,60 @@ class SQLAgent:
|
|
|
388
395
|
return "\n\n".join(kbs_info)
|
|
389
396
|
|
|
390
397
|
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
|
|
391
|
-
"""
|
|
398
|
+
"""Get information about specified tables.
|
|
392
399
|
Follows best practices as specified in: Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498)
|
|
393
400
|
If `sample_rows_in_table_info`, the specified number of sample rows will be
|
|
394
401
|
appended to each table description. This can increase performance as demonstrated in the paper.
|
|
395
402
|
"""
|
|
403
|
+
if config.get("data_catalog", {}).get("enabled", False):
|
|
404
|
+
database_table_map = {}
|
|
405
|
+
for name in table_names or self.get_usable_table_names():
|
|
406
|
+
name = name.replace("`", "")
|
|
396
407
|
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
408
|
+
parts = name.split(".", 1)
|
|
409
|
+
# TODO: Will there be situations where parts has more than 2 elements? Like a schema?
|
|
410
|
+
# This is unlikely given that we default to a single schema per database.
|
|
411
|
+
if len(parts) == 1:
|
|
412
|
+
raise ValueError(f"Invalid table name: {name}. Expected format is 'database.table'.")
|
|
401
413
|
|
|
402
|
-
|
|
403
|
-
if len(split) > 1:
|
|
404
|
-
all_tables.append(Identifier(parts=[split[0], split[1]]))
|
|
405
|
-
else:
|
|
406
|
-
all_tables.append(Identifier(name))
|
|
414
|
+
database_table_map[parts[0]] = database_table_map.get(parts[0], []) + [parts[1]]
|
|
407
415
|
|
|
408
|
-
|
|
409
|
-
|
|
416
|
+
data_catalog_str = ""
|
|
417
|
+
for database_name, table_names in database_table_map.items():
|
|
418
|
+
data_catalog_reader = DataCatalogReader(database_name=database_name, table_names=table_names)
|
|
410
419
|
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
table_info = self._cache.get(key) if self._cache else None
|
|
415
|
-
if True or table_info is None:
|
|
416
|
-
table_info = self._get_single_table_info(table)
|
|
417
|
-
if self._cache:
|
|
418
|
-
self._cache.set(key, table_info)
|
|
420
|
+
data_catalog_str += data_catalog_reader.read_metadata_as_string()
|
|
421
|
+
|
|
422
|
+
return data_catalog_str
|
|
419
423
|
|
|
420
|
-
|
|
424
|
+
else:
|
|
425
|
+
# TODO: Improve old logic without data catalog
|
|
426
|
+
all_tables = []
|
|
427
|
+
for name in self.get_usable_table_names():
|
|
428
|
+
# remove backticks
|
|
429
|
+
name = name.replace("`", "")
|
|
430
|
+
|
|
431
|
+
split = name.split(".")
|
|
432
|
+
if len(split) > 1:
|
|
433
|
+
all_tables.append(Identifier(parts=[split[0], split[1]]))
|
|
434
|
+
else:
|
|
435
|
+
all_tables.append(Identifier(name))
|
|
421
436
|
|
|
422
|
-
|
|
437
|
+
if table_names is not None:
|
|
438
|
+
all_tables = self._resolve_table_names(table_names, all_tables)
|
|
439
|
+
|
|
440
|
+
tables_info = []
|
|
441
|
+
for table in all_tables:
|
|
442
|
+
key = f"{ctx.company_id}_{table}_info"
|
|
443
|
+
table_info = self._cache.get(key) if self._cache else None
|
|
444
|
+
if True or table_info is None:
|
|
445
|
+
table_info = self._get_single_table_info(table)
|
|
446
|
+
if self._cache:
|
|
447
|
+
self._cache.set(key, table_info)
|
|
448
|
+
|
|
449
|
+
tables_info.append(table_info)
|
|
450
|
+
|
|
451
|
+
return "\n\n".join(tables_info)
|
|
423
452
|
|
|
424
453
|
def get_kb_sample_rows(self, kb_name: str) -> str:
|
|
425
454
|
"""Get sample rows from a knowledge base.
|
|
@@ -430,7 +459,7 @@ class SQLAgent:
|
|
|
430
459
|
Returns:
|
|
431
460
|
str: A string containing the sample rows from the knowledge base.
|
|
432
461
|
"""
|
|
433
|
-
logger.info(f
|
|
462
|
+
logger.info(f"_get_sample_rows: knowledge base={kb_name}")
|
|
434
463
|
command = f"select * from {kb_name} limit 10;"
|
|
435
464
|
try:
|
|
436
465
|
ret = self._call_engine(command)
|
|
@@ -438,13 +467,12 @@ class SQLAgent:
|
|
|
438
467
|
|
|
439
468
|
def truncate_value(val):
|
|
440
469
|
str_val = str(val)
|
|
441
|
-
return str_val if len(str_val) < 100 else (str_val[:100] +
|
|
470
|
+
return str_val if len(str_val) < 100 else (str_val[:100] + "...")
|
|
442
471
|
|
|
443
|
-
sample_rows = list(
|
|
444
|
-
map(lambda row: [truncate_value(value) for value in row], sample_rows))
|
|
472
|
+
sample_rows = list(map(lambda row: [truncate_value(value) for value in row], sample_rows))
|
|
445
473
|
sample_rows_str = "\n" + f"{kb_name}:" + list_to_csv_str(sample_rows)
|
|
446
474
|
except Exception as e:
|
|
447
|
-
logger.info(f
|
|
475
|
+
logger.info(f"_get_sample_rows error: {e}")
|
|
448
476
|
sample_rows_str = "\n" + "\t [error] Couldn't retrieve sample rows!"
|
|
449
477
|
|
|
450
478
|
return sample_rows_str
|
|
@@ -471,11 +499,9 @@ class SQLAgent:
|
|
|
471
499
|
|
|
472
500
|
fields = df[INF_SCHEMA_COLUMNS_NAMES.COLUMN_NAME].to_list()
|
|
473
501
|
dtypes = [
|
|
474
|
-
mysql_data_type.value if isinstance(mysql_data_type, MYSQL_DATA_TYPE) else (data_type or
|
|
475
|
-
for mysql_data_type, data_type
|
|
476
|
-
|
|
477
|
-
df[INF_SCHEMA_COLUMNS_NAMES.MYSQL_DATA_TYPE],
|
|
478
|
-
df[INF_SCHEMA_COLUMNS_NAMES.DATA_TYPE]
|
|
502
|
+
mysql_data_type.value if isinstance(mysql_data_type, MYSQL_DATA_TYPE) else (data_type or "UNKNOWN")
|
|
503
|
+
for mysql_data_type, data_type in zip(
|
|
504
|
+
df[INF_SCHEMA_COLUMNS_NAMES.MYSQL_DATA_TYPE], df[INF_SCHEMA_COLUMNS_NAMES.DATA_TYPE]
|
|
479
505
|
)
|
|
480
506
|
]
|
|
481
507
|
except Exception as e:
|
|
@@ -492,16 +518,18 @@ class SQLAgent:
|
|
|
492
518
|
logger.warning(f"Could not get sample rows for {table_str}: {e}")
|
|
493
519
|
sample_rows_info = "\n\t [error] Couldn't retrieve sample rows!"
|
|
494
520
|
|
|
495
|
-
info = f
|
|
521
|
+
info = f"Table named `{table_str}`:\n"
|
|
496
522
|
info += f"\nSample with first {self._sample_rows_in_table_info} rows from table {table_str} in CSV format (dialect is 'excel'):\n"
|
|
497
523
|
info += sample_rows_info + "\n"
|
|
498
|
-
info +=
|
|
499
|
-
|
|
500
|
-
|
|
524
|
+
info += (
|
|
525
|
+
"\nColumn data types: "
|
|
526
|
+
+ ",\t".join([f"\n`{field}` : `{dtype}`" for field, dtype in zip(fields, dtypes)])
|
|
527
|
+
+ "\n"
|
|
528
|
+
)
|
|
501
529
|
return info
|
|
502
530
|
|
|
503
531
|
def _get_sample_rows(self, table: str, fields: List[str]) -> str:
|
|
504
|
-
logger.info(f
|
|
532
|
+
logger.info(f"_get_sample_rows: table={table} fields={fields}")
|
|
505
533
|
command = f"select {', '.join(fields)} from {table} limit {self._sample_rows_in_table_info};"
|
|
506
534
|
try:
|
|
507
535
|
ret = self._call_engine(command)
|
|
@@ -509,20 +537,19 @@ class SQLAgent:
|
|
|
509
537
|
|
|
510
538
|
def truncate_value(val):
|
|
511
539
|
str_val = str(val)
|
|
512
|
-
return str_val if len(str_val) < 100 else (str_val[:100] +
|
|
540
|
+
return str_val if len(str_val) < 100 else (str_val[:100] + "...")
|
|
513
541
|
|
|
514
|
-
sample_rows = list(
|
|
515
|
-
map(lambda row: [truncate_value(value) for value in row], sample_rows))
|
|
542
|
+
sample_rows = list(map(lambda row: [truncate_value(value) for value in row], sample_rows))
|
|
516
543
|
sample_rows_str = "\n" + list_to_csv_str([fields] + sample_rows)
|
|
517
544
|
except Exception as e:
|
|
518
|
-
logger.info(f
|
|
545
|
+
logger.info(f"_get_sample_rows error: {e}")
|
|
519
546
|
sample_rows_str = "\n" + "\t [error] Couldn't retrieve sample rows!"
|
|
520
547
|
|
|
521
548
|
return sample_rows_str
|
|
522
549
|
|
|
523
550
|
def _clean_query(self, query: str) -> str:
|
|
524
551
|
# Sometimes LLM can input markdown into query tools.
|
|
525
|
-
cmd = re.sub(r
|
|
552
|
+
cmd = re.sub(r"```(sql)?", "", query)
|
|
526
553
|
return cmd
|
|
527
554
|
|
|
528
555
|
def query(self, command: str, fetch: str = "all") -> str:
|
|
@@ -534,16 +561,16 @@ class SQLAgent:
|
|
|
534
561
|
def _repr_result(ret):
|
|
535
562
|
limit_rows = 30
|
|
536
563
|
|
|
537
|
-
columns_str =
|
|
538
|
-
res = f
|
|
564
|
+
columns_str = ", ".join([repr(col.name) for col in ret.columns])
|
|
565
|
+
res = f"Output columns: {columns_str}\n"
|
|
539
566
|
|
|
540
567
|
data = ret.to_lists()
|
|
541
568
|
if len(data) > limit_rows:
|
|
542
569
|
df = pd.DataFrame(data, columns=[col.name for col in ret.columns])
|
|
543
570
|
|
|
544
|
-
res += f
|
|
545
|
-
res += str(df.describe(include=
|
|
546
|
-
res += f
|
|
571
|
+
res += f"Result has {len(data)} rows. Description of data:\n"
|
|
572
|
+
res += str(df.describe(include="all")) + "\n\n"
|
|
573
|
+
res += f"First {limit_rows} rows:\n"
|
|
547
574
|
|
|
548
575
|
else:
|
|
549
576
|
res += "Result in CSV format (dialect is 'excel'):\n"
|
|
@@ -562,20 +589,20 @@ class SQLAgent:
|
|
|
562
589
|
|
|
563
590
|
def get_table_info_safe(self, table_names: Optional[List[str]] = None) -> str:
|
|
564
591
|
try:
|
|
565
|
-
logger.info(f
|
|
592
|
+
logger.info(f"get_table_info_safe: {table_names}")
|
|
566
593
|
return self.get_table_info(table_names)
|
|
567
594
|
except Exception as e:
|
|
568
|
-
logger.info(f
|
|
595
|
+
logger.info(f"get_table_info_safe error: {e}")
|
|
569
596
|
return f"Error: {e}"
|
|
570
597
|
|
|
571
598
|
def query_safe(self, command: str, fetch: str = "all") -> str:
|
|
572
599
|
try:
|
|
573
|
-
logger.info(f
|
|
600
|
+
logger.info(f"query_safe (fetch={fetch}): {command}")
|
|
574
601
|
return self.query(command, fetch)
|
|
575
602
|
except Exception as e:
|
|
576
603
|
logger.error(f"Error in query_safe: {str(e)}\n{traceback.format_exc()}")
|
|
577
|
-
logger.info(f
|
|
604
|
+
logger.info(f"query_safe error: {e}")
|
|
578
605
|
msg = f"Error: {e}"
|
|
579
|
-
if
|
|
580
|
-
msg +=
|
|
606
|
+
if "does not exist" in msg and " relation " in msg:
|
|
607
|
+
msg += "\nAvailable tables: " + ", ".join(self.get_usable_table_names())
|
|
581
608
|
return msg
|