MindsDB 25.5.4.2__py3-none-any.whl → 25.6.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of MindsDB might be problematic; see the package registry listing for more details.

Files changed (69)
  1. mindsdb/__about__.py +1 -1
  2. mindsdb/api/a2a/agent.py +28 -25
  3. mindsdb/api/a2a/common/server/server.py +32 -26
  4. mindsdb/api/executor/command_executor.py +69 -14
  5. mindsdb/api/executor/datahub/datanodes/integration_datanode.py +49 -65
  6. mindsdb/api/executor/datahub/datanodes/project_datanode.py +29 -48
  7. mindsdb/api/executor/datahub/datanodes/system_tables.py +35 -61
  8. mindsdb/api/executor/planner/plan_join.py +67 -77
  9. mindsdb/api/executor/planner/query_planner.py +176 -155
  10. mindsdb/api/executor/planner/steps.py +37 -12
  11. mindsdb/api/executor/sql_query/result_set.py +45 -64
  12. mindsdb/api/executor/sql_query/steps/fetch_dataframe.py +14 -18
  13. mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +17 -18
  14. mindsdb/api/executor/sql_query/steps/insert_step.py +13 -33
  15. mindsdb/api/executor/sql_query/steps/subselect_step.py +43 -35
  16. mindsdb/api/executor/utilities/sql.py +42 -48
  17. mindsdb/api/http/namespaces/config.py +1 -1
  18. mindsdb/api/http/namespaces/file.py +14 -23
  19. mindsdb/api/mysql/mysql_proxy/data_types/mysql_datum.py +12 -28
  20. mindsdb/api/mysql/mysql_proxy/data_types/mysql_packets/binary_resultset_row_package.py +59 -50
  21. mindsdb/api/mysql/mysql_proxy/data_types/mysql_packets/resultset_row_package.py +9 -8
  22. mindsdb/api/mysql/mysql_proxy/libs/constants/mysql.py +449 -461
  23. mindsdb/api/mysql/mysql_proxy/utilities/dump.py +87 -36
  24. mindsdb/integrations/handlers/file_handler/file_handler.py +15 -9
  25. mindsdb/integrations/handlers/file_handler/tests/test_file_handler.py +43 -24
  26. mindsdb/integrations/handlers/litellm_handler/litellm_handler.py +10 -3
  27. mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +26 -33
  28. mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +74 -51
  29. mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +305 -98
  30. mindsdb/integrations/handlers/salesforce_handler/salesforce_handler.py +53 -34
  31. mindsdb/integrations/handlers/salesforce_handler/salesforce_tables.py +136 -6
  32. mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +334 -83
  33. mindsdb/integrations/libs/api_handler.py +261 -57
  34. mindsdb/integrations/libs/base.py +100 -29
  35. mindsdb/integrations/utilities/files/file_reader.py +99 -73
  36. mindsdb/integrations/utilities/handler_utils.py +23 -8
  37. mindsdb/integrations/utilities/sql_utils.py +35 -40
  38. mindsdb/interfaces/agents/agents_controller.py +196 -192
  39. mindsdb/interfaces/agents/constants.py +7 -1
  40. mindsdb/interfaces/agents/langchain_agent.py +42 -11
  41. mindsdb/interfaces/agents/mcp_client_agent.py +29 -21
  42. mindsdb/interfaces/data_catalog/__init__.py +0 -0
  43. mindsdb/interfaces/data_catalog/base_data_catalog.py +54 -0
  44. mindsdb/interfaces/data_catalog/data_catalog_loader.py +359 -0
  45. mindsdb/interfaces/data_catalog/data_catalog_reader.py +34 -0
  46. mindsdb/interfaces/database/database.py +81 -57
  47. mindsdb/interfaces/database/integrations.py +220 -234
  48. mindsdb/interfaces/database/log.py +72 -104
  49. mindsdb/interfaces/database/projects.py +156 -193
  50. mindsdb/interfaces/file/file_controller.py +21 -65
  51. mindsdb/interfaces/knowledge_base/controller.py +63 -10
  52. mindsdb/interfaces/knowledge_base/evaluate.py +519 -0
  53. mindsdb/interfaces/knowledge_base/llm_client.py +75 -0
  54. mindsdb/interfaces/skills/custom/text2sql/mindsdb_kb_tools.py +83 -43
  55. mindsdb/interfaces/skills/skills_controller.py +54 -36
  56. mindsdb/interfaces/skills/sql_agent.py +109 -86
  57. mindsdb/interfaces/storage/db.py +223 -79
  58. mindsdb/migrations/versions/2025-05-28_a44643042fe8_added_data_catalog_tables.py +118 -0
  59. mindsdb/migrations/versions/2025-06-09_608e376c19a7_updated_data_catalog_data_types.py +58 -0
  60. mindsdb/utilities/config.py +9 -2
  61. mindsdb/utilities/log.py +35 -26
  62. mindsdb/utilities/ml_task_queue/task.py +19 -22
  63. mindsdb/utilities/render/sqlalchemy_render.py +129 -181
  64. mindsdb/utilities/starters.py +40 -0
  65. {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.2.0.dist-info}/METADATA +253 -253
  66. {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.2.0.dist-info}/RECORD +69 -61
  67. {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.2.0.dist-info}/WHEEL +0 -0
  68. {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.2.0.dist-info}/licenses/LICENSE +0 -0
  69. {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.2.0.dist-info}/top_level.txt +0 -0
@@ -14,6 +14,8 @@ from mindsdb.utilities.context import context as ctx
14
14
  from mindsdb.integrations.utilities.query_traversal import query_traversal
15
15
  from mindsdb.integrations.libs.response import INF_SCHEMA_COLUMNS_NAMES
16
16
  from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE
17
+ from mindsdb.utilities.config import config
18
+ from mindsdb.interfaces.data_catalog.data_catalog_reader import DataCatalogReader
17
19
 
18
20
  logger = log.getLogger(__name__)
19
21
 
@@ -28,7 +30,7 @@ def list_to_csv_str(array: List[List[Any]]) -> str:
28
30
  str: The array formatted as a CSV string using Excel dialect
29
31
  """
30
32
  output = StringIO()
31
- writer = csv.writer(output, dialect='excel')
33
+ writer = csv.writer(output, dialect="excel")
32
34
  str_array = [[str(item) for item in row] for row in array]
33
35
  writer.writerows(str_array)
34
36
  return output.getvalue()
@@ -55,23 +57,23 @@ def split_table_name(table_name: str) -> List[str]:
55
57
  'input': '`aaa`.`bbb.ccc`', 'output': ['aaa', 'bbb.ccc']
56
58
  """
57
59
  result = []
58
- current = ''
60
+ current = ""
59
61
  in_backticks = False
60
62
 
61
63
  i = 0
62
64
  while i < len(table_name):
63
- if table_name[i] == '`':
65
+ if table_name[i] == "`":
64
66
  in_backticks = not in_backticks
65
- elif table_name[i] == '.' and not in_backticks:
67
+ elif table_name[i] == "." and not in_backticks:
66
68
  if current:
67
- result.append(current.strip('`'))
68
- current = ''
69
+ result.append(current.strip("`"))
70
+ current = ""
69
71
  else:
70
72
  current += table_name[i]
71
73
  i += 1
72
74
 
73
75
  if current:
74
- result.append(current.strip('`'))
76
+ result.append(current.strip("`"))
75
77
 
76
78
  # ensure we split the table name
77
79
  result = [r.split(".") for r in result][0]
@@ -89,13 +91,13 @@ class SQLAgent:
89
91
  command_executor,
90
92
  databases: List[str],
91
93
  databases_struct: dict,
92
- knowledge_base_database: str = 'mindsdb',
94
+ knowledge_base_database: str = "mindsdb",
93
95
  include_tables: Optional[List[str]] = None,
94
96
  ignore_tables: Optional[List[str]] = None,
95
97
  include_knowledge_bases: Optional[List[str]] = None,
96
98
  ignore_knowledge_bases: Optional[List[str]] = None,
97
99
  sample_rows_in_table_info: int = 3,
98
- cache: Optional[dict] = None
100
+ cache: Optional[dict] = None,
99
101
  ):
100
102
  """
101
103
  Initialize SQLAgent.
@@ -133,12 +135,13 @@ class SQLAgent:
133
135
  self._cache = cache
134
136
 
135
137
  from mindsdb.interfaces.skills.skill_tool import SkillToolController
138
+
136
139
  # Initialize the skill tool controller from MindsDB
137
140
  self.skill_tool = SkillToolController()
138
141
 
139
142
  def _call_engine(self, query: str, database=None):
140
143
  # switch database
141
- ast_query = parse_sql(query.strip('`'))
144
+ ast_query = parse_sql(query.strip("`"))
142
145
  self._check_permissions(ast_query)
143
146
 
144
147
  if database is None:
@@ -148,10 +151,7 @@ class SQLAgent:
148
151
  # for now, we will just use the first one
149
152
  database = self._databases[0] if self._databases else "mindsdb"
150
153
 
151
- ret = self._command_executor.execute_command(
152
- ast_query,
153
- database_name=database
154
- )
154
+ ret = self._command_executor.execute_command(ast_query, database_name=database)
155
155
  return ret
156
156
 
157
157
  def _check_permissions(self, ast_query):
@@ -170,7 +170,7 @@ class SQLAgent:
170
170
 
171
171
  def _check_f(node, is_table=None, **kwargs):
172
172
  if is_table and isinstance(node, Identifier):
173
- table_name = '.'.join(node.parts)
173
+ table_name = ".".join(node.parts)
174
174
 
175
175
  # Get the list of available knowledge bases
176
176
  kb_names = self.get_usable_knowledge_base_names()
@@ -182,16 +182,21 @@ class SQLAgent:
182
182
  if is_kb and self._knowledge_bases_to_include:
183
183
  kb_parts = [split_table_name(x) for x in self._knowledge_bases_to_include]
184
184
  if node.parts not in kb_parts:
185
- raise ValueError(f"Knowledge base {table_name} not found. Available knowledge bases: {', '.join(self._knowledge_bases_to_include)}")
185
+ raise ValueError(
186
+ f"Knowledge base {table_name} not found. Available knowledge bases: {', '.join(self._knowledge_bases_to_include)}"
187
+ )
186
188
  # Regular table check
187
189
  elif not is_kb and self._tables_to_include and node.parts not in tables_parts:
188
- raise ValueError(f"Table {table_name} not found. Available tables: {', '.join(self._tables_to_include)}")
190
+ raise ValueError(
191
+ f"Table {table_name} not found. Available tables: {', '.join(self._tables_to_include)}"
192
+ )
189
193
  # Check if it's a restricted knowledge base
190
194
  elif is_kb and table_name in self._knowledge_bases_to_ignore:
191
195
  raise ValueError(f"Knowledge base {table_name} is not allowed.")
192
196
  # Check if it's a restricted table
193
197
  elif not is_kb and table_name in self._tables_to_ignore:
194
198
  raise ValueError(f"Table {table_name} is not allowed.")
199
+
195
200
  query_traversal(ast_query, _check_f)
196
201
 
197
202
  def get_usable_table_names(self) -> Iterable[str]:
@@ -200,7 +205,7 @@ class SQLAgent:
200
205
  Returns:
201
206
  Iterable[str]: list with table names
202
207
  """
203
- cache_key = f'{ctx.company_id}_{",".join(self._databases)}_tables'
208
+ cache_key = f"{ctx.company_id}_{','.join(self._databases)}_tables"
204
209
 
205
210
  # first check cache and return if found
206
211
  if self._cache:
@@ -218,13 +223,13 @@ class SQLAgent:
218
223
 
219
224
  schemas_names = list(self._mindsdb_db_struct[db_name].keys())
220
225
  if len(schemas_names) > 1 and None in schemas_names:
221
- raise Exception('default schema and named schemas can not be used in same filter')
226
+ raise Exception("default schema and named schemas can not be used in same filter")
222
227
 
223
228
  if None in schemas_names:
224
229
  # get tables only from default schema
225
230
  response = handler.get_tables()
226
231
  tables_in_default_schema = list(response.data_frame.table_name)
227
- schema_tables_restrictions = self._mindsdb_db_struct[db_name][None] # None - is default schema
232
+ schema_tables_restrictions = self._mindsdb_db_struct[db_name][None] # None - is default schema
228
233
  if schema_tables_restrictions is None:
229
234
  for table_name in tables_in_default_schema:
230
235
  result_tables.append([db_name, table_name])
@@ -233,27 +238,27 @@ class SQLAgent:
233
238
  if table_name in tables_in_default_schema:
234
239
  result_tables.append([db_name, table_name])
235
240
  else:
236
- if 'all' in inspect.signature(handler.get_tables).parameters:
241
+ if "all" in inspect.signature(handler.get_tables).parameters:
237
242
  response = handler.get_tables(all=True)
238
243
  else:
239
244
  response = handler.get_tables()
240
245
  response_schema_names = list(response.data_frame.table_schema.unique())
241
246
  schemas_intersection = set(schemas_names) & set(response_schema_names)
242
247
  if len(schemas_intersection) == 0:
243
- raise Exception('There are no allowed schemas in ds')
248
+ raise Exception("There are no allowed schemas in ds")
244
249
 
245
250
  for schema_name in schemas_intersection:
246
- schema_sub_df = response.data_frame[response.data_frame['table_schema'] == schema_name]
251
+ schema_sub_df = response.data_frame[response.data_frame["table_schema"] == schema_name]
247
252
  if self._mindsdb_db_struct[db_name][schema_name] is None:
248
253
  # all tables from schema allowed
249
254
  for row in schema_sub_df:
250
- result_tables.append([db_name, schema_name, row['table_name']])
255
+ result_tables.append([db_name, schema_name, row["table_name"]])
251
256
  else:
252
257
  for table_name in self._mindsdb_db_struct[db_name][schema_name]:
253
- if table_name in schema_sub_df['table_name'].values:
258
+ if table_name in schema_sub_df["table_name"].values:
254
259
  result_tables.append([db_name, schema_name, table_name])
255
260
 
256
- result_tables = ['.'.join(x) for x in result_tables]
261
+ result_tables = [".".join(x) for x in result_tables]
257
262
  if self._cache:
258
263
  self._cache.set(cache_key, set(result_tables))
259
264
  return result_tables
@@ -264,14 +269,14 @@ class SQLAgent:
264
269
  Returns:
265
270
  Iterable[str]: list with knowledge base names
266
271
  """
267
- cache_key = f'{ctx.company_id}_{self.knowledge_base_database}_knowledge_bases'
272
+ cache_key = f"{ctx.company_id}_{self.knowledge_base_database}_knowledge_bases"
268
273
 
269
274
  # todo we need to fix the cache, file cache can potentially store out of data information
270
275
  # # first check cache and return if found
271
276
  # if self._cache:
272
- # cached_kbs = self._cache.get(cache_key)
277
+ # cached_kbs = self._cache.get(cache_key)
273
278
  # if cached_kbs:
274
- # return cached_kbs
279
+ # return cached_kbs
275
280
 
276
281
  if self._knowledge_bases_to_include:
277
282
  return self._knowledge_bases_to_include
@@ -289,13 +294,15 @@ class SQLAgent:
289
294
  try:
290
295
  # Get knowledge bases from the project database
291
296
  kb_controller = self._command_executor.session.kb_controller
292
- kb_names = [kb['name'] for kb in kb_controller.list()]
297
+ kb_names = [kb["name"] for kb in kb_controller.list()]
293
298
 
294
299
  # Filter knowledge bases based on include list
295
300
  if self._knowledge_bases_to_include:
296
301
  kb_names = [kb_name for kb_name in kb_names if kb_name in self._knowledge_bases_to_include]
297
302
  if not kb_names:
298
- logger.warning(f"No knowledge bases found in the include list: {self._knowledge_bases_to_include}")
303
+ logger.warning(
304
+ f"No knowledge bases found in the include list: {self._knowledge_bases_to_include}"
305
+ )
299
306
  return []
300
307
 
301
308
  return kb_names
@@ -317,7 +324,7 @@ class SQLAgent:
317
324
  # Filter knowledge bases based on ignore list
318
325
  kb_names = []
319
326
  for row in result:
320
- kb_name = row['name']
327
+ kb_name = row["name"]
321
328
  if kb_name not in self._knowledge_bases_to_ignore:
322
329
  kb_names.append(kb_name)
323
330
 
@@ -368,7 +375,7 @@ class SQLAgent:
368
375
  return tables
369
376
 
370
377
  def get_knowledge_base_info(self, kb_names: Optional[List[str]] = None) -> str:
371
- """ Get information about specified knowledge bases.
378
+ """Get information about specified knowledge bases.
372
379
  Follows best practices as specified in: Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498)
373
380
  If `sample_rows_in_table_info`, the specified number of sample rows will be
374
381
  appended to each table description. This can increase performance as demonstrated in the paper.
@@ -388,38 +395,56 @@ class SQLAgent:
388
395
  return "\n\n".join(kbs_info)
389
396
 
390
397
  def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
391
- """ Get information about specified tables.
398
+ """Get information about specified tables.
392
399
  Follows best practices as specified in: Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498)
393
400
  If `sample_rows_in_table_info`, the specified number of sample rows will be
394
401
  appended to each table description. This can increase performance as demonstrated in the paper.
395
402
  """
403
+ if config.get("data_catalog", {}).get("enabled", False):
404
+ database_table_map = {}
405
+ for name in self.get_usable_table_names():
406
+ name = name.replace("`", "")
396
407
 
397
- all_tables = []
398
- for name in self.get_usable_table_names():
399
- # remove backticks
400
- name = name.replace("`", "")
408
+ # TODO: Can there be situations where the database name is returned from the above method?
409
+ parts = name.split(".", 1)
410
+ database_table_map[parts[0]] = database_table_map.get(parts[0], []) + [parts[1]]
401
411
 
402
- split = name.split(".")
403
- if len(split) > 1:
404
- all_tables.append(Identifier(parts=[split[0], split[1]]))
405
- else:
406
- all_tables.append(Identifier(name))
412
+ data_catalog_str = ""
413
+ for database_name, table_names in database_table_map.items():
414
+ data_catalog_reader = DataCatalogReader(database_name=database_name, table_names=table_names)
407
415
 
408
- # if table_names is not None:
409
- # all_tables = self._resolve_table_names(table_names, all_tables)
416
+ data_catalog_str += data_catalog_reader.read_metadata_as_string()
410
417
 
411
- tables_info = []
412
- for table in all_tables:
413
- key = f"{ctx.company_id}_{table}_info"
414
- table_info = self._cache.get(key) if self._cache else None
415
- if True or table_info is None:
416
- table_info = self._get_single_table_info(table)
417
- if self._cache:
418
- self._cache.set(key, table_info)
418
+ return data_catalog_str
419
419
 
420
- tables_info.append(table_info)
420
+ else:
421
+ # TODO: Improve old logic without data catalog
422
+ all_tables = []
423
+ for name in self.get_usable_table_names():
424
+ # remove backticks
425
+ name = name.replace("`", "")
426
+
427
+ split = name.split(".")
428
+ if len(split) > 1:
429
+ all_tables.append(Identifier(parts=[split[0], split[1]]))
430
+ else:
431
+ all_tables.append(Identifier(name))
432
+
433
+ # if table_names is not None:
434
+ # all_tables = self._resolve_table_names(table_names, all_tables)
435
+
436
+ tables_info = []
437
+ for table in all_tables:
438
+ key = f"{ctx.company_id}_{table}_info"
439
+ table_info = self._cache.get(key) if self._cache else None
440
+ if True or table_info is None:
441
+ table_info = self._get_single_table_info(table)
442
+ if self._cache:
443
+ self._cache.set(key, table_info)
444
+
445
+ tables_info.append(table_info)
421
446
 
422
- return "\n\n".join(tables_info)
447
+ return "\n\n".join(tables_info)
423
448
 
424
449
  def get_kb_sample_rows(self, kb_name: str) -> str:
425
450
  """Get sample rows from a knowledge base.
@@ -430,7 +455,7 @@ class SQLAgent:
430
455
  Returns:
431
456
  str: A string containing the sample rows from the knowledge base.
432
457
  """
433
- logger.info(f'_get_sample_rows: knowledge base={kb_name}')
458
+ logger.info(f"_get_sample_rows: knowledge base={kb_name}")
434
459
  command = f"select * from {kb_name} limit 10;"
435
460
  try:
436
461
  ret = self._call_engine(command)
@@ -438,13 +463,12 @@ class SQLAgent:
438
463
 
439
464
  def truncate_value(val):
440
465
  str_val = str(val)
441
- return str_val if len(str_val) < 100 else (str_val[:100] + '...')
466
+ return str_val if len(str_val) < 100 else (str_val[:100] + "...")
442
467
 
443
- sample_rows = list(
444
- map(lambda row: [truncate_value(value) for value in row], sample_rows))
468
+ sample_rows = list(map(lambda row: [truncate_value(value) for value in row], sample_rows))
445
469
  sample_rows_str = "\n" + f"{kb_name}:" + list_to_csv_str(sample_rows)
446
470
  except Exception as e:
447
- logger.info(f'_get_sample_rows error: {e}')
471
+ logger.info(f"_get_sample_rows error: {e}")
448
472
  sample_rows_str = "\n" + "\t [error] Couldn't retrieve sample rows!"
449
473
 
450
474
  return sample_rows_str
@@ -471,11 +495,9 @@ class SQLAgent:
471
495
 
472
496
  fields = df[INF_SCHEMA_COLUMNS_NAMES.COLUMN_NAME].to_list()
473
497
  dtypes = [
474
- mysql_data_type.value if isinstance(mysql_data_type, MYSQL_DATA_TYPE) else (data_type or 'UNKNOWN')
475
- for mysql_data_type, data_type
476
- in zip(
477
- df[INF_SCHEMA_COLUMNS_NAMES.MYSQL_DATA_TYPE],
478
- df[INF_SCHEMA_COLUMNS_NAMES.DATA_TYPE]
498
+ mysql_data_type.value if isinstance(mysql_data_type, MYSQL_DATA_TYPE) else (data_type or "UNKNOWN")
499
+ for mysql_data_type, data_type in zip(
500
+ df[INF_SCHEMA_COLUMNS_NAMES.MYSQL_DATA_TYPE], df[INF_SCHEMA_COLUMNS_NAMES.DATA_TYPE]
479
501
  )
480
502
  ]
481
503
  except Exception as e:
@@ -492,16 +514,18 @@ class SQLAgent:
492
514
  logger.warning(f"Could not get sample rows for {table_str}: {e}")
493
515
  sample_rows_info = "\n\t [error] Couldn't retrieve sample rows!"
494
516
 
495
- info = f'Table named `{table_str}`:\n'
517
+ info = f"Table named `{table_str}`:\n"
496
518
  info += f"\nSample with first {self._sample_rows_in_table_info} rows from table {table_str} in CSV format (dialect is 'excel'):\n"
497
519
  info += sample_rows_info + "\n"
498
- info += '\nColumn data types: ' + ",\t".join(
499
- [f'\n`{field}` : `{dtype}`' for field, dtype in zip(fields, dtypes)]
500
- ) + '\n'
520
+ info += (
521
+ "\nColumn data types: "
522
+ + ",\t".join([f"\n`{field}` : `{dtype}`" for field, dtype in zip(fields, dtypes)])
523
+ + "\n"
524
+ )
501
525
  return info
502
526
 
503
527
  def _get_sample_rows(self, table: str, fields: List[str]) -> str:
504
- logger.info(f'_get_sample_rows: table={table} fields={fields}')
528
+ logger.info(f"_get_sample_rows: table={table} fields={fields}")
505
529
  command = f"select {', '.join(fields)} from {table} limit {self._sample_rows_in_table_info};"
506
530
  try:
507
531
  ret = self._call_engine(command)
@@ -509,20 +533,19 @@ class SQLAgent:
509
533
 
510
534
  def truncate_value(val):
511
535
  str_val = str(val)
512
- return str_val if len(str_val) < 100 else (str_val[:100] + '...')
536
+ return str_val if len(str_val) < 100 else (str_val[:100] + "...")
513
537
 
514
- sample_rows = list(
515
- map(lambda row: [truncate_value(value) for value in row], sample_rows))
538
+ sample_rows = list(map(lambda row: [truncate_value(value) for value in row], sample_rows))
516
539
  sample_rows_str = "\n" + list_to_csv_str([fields] + sample_rows)
517
540
  except Exception as e:
518
- logger.info(f'_get_sample_rows error: {e}')
541
+ logger.info(f"_get_sample_rows error: {e}")
519
542
  sample_rows_str = "\n" + "\t [error] Couldn't retrieve sample rows!"
520
543
 
521
544
  return sample_rows_str
522
545
 
523
546
  def _clean_query(self, query: str) -> str:
524
547
  # Sometimes LLM can input markdown into query tools.
525
- cmd = re.sub(r'```(sql)?', '', query)
548
+ cmd = re.sub(r"```(sql)?", "", query)
526
549
  return cmd
527
550
 
528
551
  def query(self, command: str, fetch: str = "all") -> str:
@@ -534,16 +557,16 @@ class SQLAgent:
534
557
  def _repr_result(ret):
535
558
  limit_rows = 30
536
559
 
537
- columns_str = ', '.join([repr(col.name) for col in ret.columns])
538
- res = f'Output columns: {columns_str}\n'
560
+ columns_str = ", ".join([repr(col.name) for col in ret.columns])
561
+ res = f"Output columns: {columns_str}\n"
539
562
 
540
563
  data = ret.to_lists()
541
564
  if len(data) > limit_rows:
542
565
  df = pd.DataFrame(data, columns=[col.name for col in ret.columns])
543
566
 
544
- res += f'Result has {len(data)} rows. Description of data:\n'
545
- res += str(df.describe(include='all')) + '\n\n'
546
- res += f'First {limit_rows} rows:\n'
567
+ res += f"Result has {len(data)} rows. Description of data:\n"
568
+ res += str(df.describe(include="all")) + "\n\n"
569
+ res += f"First {limit_rows} rows:\n"
547
570
 
548
571
  else:
549
572
  res += "Result in CSV format (dialect is 'excel'):\n"
@@ -562,20 +585,20 @@ class SQLAgent:
562
585
 
563
586
  def get_table_info_safe(self, table_names: Optional[List[str]] = None) -> str:
564
587
  try:
565
- logger.info(f'get_table_info_safe: {table_names}')
588
+ logger.info(f"get_table_info_safe: {table_names}")
566
589
  return self.get_table_info(table_names)
567
590
  except Exception as e:
568
- logger.info(f'get_table_info_safe error: {e}')
591
+ logger.info(f"get_table_info_safe error: {e}")
569
592
  return f"Error: {e}"
570
593
 
571
594
  def query_safe(self, command: str, fetch: str = "all") -> str:
572
595
  try:
573
- logger.info(f'query_safe (fetch={fetch}): {command}')
596
+ logger.info(f"query_safe (fetch={fetch}): {command}")
574
597
  return self.query(command, fetch)
575
598
  except Exception as e:
576
599
  logger.error(f"Error in query_safe: {str(e)}\n{traceback.format_exc()}")
577
- logger.info(f'query_safe error: {e}')
600
+ logger.info(f"query_safe error: {e}")
578
601
  msg = f"Error: {e}"
579
- if 'does not exist' in msg and ' relation ' in msg:
580
- msg += '\nAvailable tables: ' + ', '.join(self.get_usable_table_names())
602
+ if "does not exist" in msg and " relation " in msg:
603
+ msg += "\nAvailable tables: " + ", ".join(self.get_usable_table_names())
581
604
  return msg