MindsDB 25.5.4.2__py3-none-any.whl → 25.6.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of MindsDB might be problematic. Click here for more details.

Files changed (76)
  1. mindsdb/__about__.py +1 -1
  2. mindsdb/api/a2a/agent.py +50 -26
  3. mindsdb/api/a2a/common/server/server.py +32 -26
  4. mindsdb/api/a2a/task_manager.py +68 -6
  5. mindsdb/api/executor/command_executor.py +69 -14
  6. mindsdb/api/executor/datahub/datanodes/integration_datanode.py +49 -65
  7. mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +91 -84
  8. mindsdb/api/executor/datahub/datanodes/project_datanode.py +29 -48
  9. mindsdb/api/executor/datahub/datanodes/system_tables.py +35 -61
  10. mindsdb/api/executor/planner/plan_join.py +67 -77
  11. mindsdb/api/executor/planner/query_planner.py +176 -155
  12. mindsdb/api/executor/planner/steps.py +37 -12
  13. mindsdb/api/executor/sql_query/result_set.py +45 -64
  14. mindsdb/api/executor/sql_query/steps/fetch_dataframe.py +14 -18
  15. mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +17 -18
  16. mindsdb/api/executor/sql_query/steps/insert_step.py +13 -33
  17. mindsdb/api/executor/sql_query/steps/subselect_step.py +43 -35
  18. mindsdb/api/executor/utilities/sql.py +42 -48
  19. mindsdb/api/http/namespaces/config.py +1 -1
  20. mindsdb/api/http/namespaces/file.py +14 -23
  21. mindsdb/api/http/namespaces/knowledge_bases.py +132 -154
  22. mindsdb/api/mysql/mysql_proxy/data_types/mysql_datum.py +12 -28
  23. mindsdb/api/mysql/mysql_proxy/data_types/mysql_packets/binary_resultset_row_package.py +59 -50
  24. mindsdb/api/mysql/mysql_proxy/data_types/mysql_packets/resultset_row_package.py +9 -8
  25. mindsdb/api/mysql/mysql_proxy/libs/constants/mysql.py +449 -461
  26. mindsdb/api/mysql/mysql_proxy/utilities/dump.py +87 -36
  27. mindsdb/integrations/handlers/bigquery_handler/bigquery_handler.py +219 -28
  28. mindsdb/integrations/handlers/file_handler/file_handler.py +15 -9
  29. mindsdb/integrations/handlers/file_handler/tests/test_file_handler.py +43 -24
  30. mindsdb/integrations/handlers/litellm_handler/litellm_handler.py +10 -3
  31. mindsdb/integrations/handlers/llama_index_handler/requirements.txt +1 -1
  32. mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +29 -33
  33. mindsdb/integrations/handlers/openai_handler/openai_handler.py +277 -356
  34. mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +74 -51
  35. mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +305 -98
  36. mindsdb/integrations/handlers/salesforce_handler/salesforce_handler.py +145 -40
  37. mindsdb/integrations/handlers/salesforce_handler/salesforce_tables.py +136 -6
  38. mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +352 -83
  39. mindsdb/integrations/libs/api_handler.py +279 -57
  40. mindsdb/integrations/libs/base.py +185 -30
  41. mindsdb/integrations/utilities/files/file_reader.py +99 -73
  42. mindsdb/integrations/utilities/handler_utils.py +23 -8
  43. mindsdb/integrations/utilities/sql_utils.py +35 -40
  44. mindsdb/interfaces/agents/agents_controller.py +226 -196
  45. mindsdb/interfaces/agents/constants.py +8 -1
  46. mindsdb/interfaces/agents/langchain_agent.py +42 -11
  47. mindsdb/interfaces/agents/mcp_client_agent.py +29 -21
  48. mindsdb/interfaces/agents/mindsdb_database_agent.py +23 -18
  49. mindsdb/interfaces/data_catalog/__init__.py +0 -0
  50. mindsdb/interfaces/data_catalog/base_data_catalog.py +54 -0
  51. mindsdb/interfaces/data_catalog/data_catalog_loader.py +375 -0
  52. mindsdb/interfaces/data_catalog/data_catalog_reader.py +38 -0
  53. mindsdb/interfaces/database/database.py +81 -57
  54. mindsdb/interfaces/database/integrations.py +222 -234
  55. mindsdb/interfaces/database/log.py +72 -104
  56. mindsdb/interfaces/database/projects.py +156 -193
  57. mindsdb/interfaces/file/file_controller.py +21 -65
  58. mindsdb/interfaces/knowledge_base/controller.py +66 -25
  59. mindsdb/interfaces/knowledge_base/evaluate.py +516 -0
  60. mindsdb/interfaces/knowledge_base/llm_client.py +75 -0
  61. mindsdb/interfaces/skills/custom/text2sql/mindsdb_kb_tools.py +83 -43
  62. mindsdb/interfaces/skills/skills_controller.py +31 -36
  63. mindsdb/interfaces/skills/sql_agent.py +113 -86
  64. mindsdb/interfaces/storage/db.py +242 -82
  65. mindsdb/migrations/versions/2025-05-28_a44643042fe8_added_data_catalog_tables.py +118 -0
  66. mindsdb/migrations/versions/2025-06-09_608e376c19a7_updated_data_catalog_data_types.py +58 -0
  67. mindsdb/utilities/config.py +13 -2
  68. mindsdb/utilities/log.py +35 -26
  69. mindsdb/utilities/ml_task_queue/task.py +19 -22
  70. mindsdb/utilities/render/sqlalchemy_render.py +129 -181
  71. mindsdb/utilities/starters.py +40 -0
  72. {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.3.0.dist-info}/METADATA +257 -257
  73. {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.3.0.dist-info}/RECORD +76 -68
  74. {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.3.0.dist-info}/WHEEL +0 -0
  75. {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.3.0.dist-info}/licenses/LICENSE +0 -0
  76. {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.3.0.dist-info}/top_level.txt +0 -0
@@ -14,6 +14,8 @@ from mindsdb.utilities.context import context as ctx
14
14
  from mindsdb.integrations.utilities.query_traversal import query_traversal
15
15
  from mindsdb.integrations.libs.response import INF_SCHEMA_COLUMNS_NAMES
16
16
  from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE
17
+ from mindsdb.utilities.config import config
18
+ from mindsdb.interfaces.data_catalog.data_catalog_reader import DataCatalogReader
17
19
 
18
20
  logger = log.getLogger(__name__)
19
21
 
@@ -28,7 +30,7 @@ def list_to_csv_str(array: List[List[Any]]) -> str:
28
30
  str: The array formatted as a CSV string using Excel dialect
29
31
  """
30
32
  output = StringIO()
31
- writer = csv.writer(output, dialect='excel')
33
+ writer = csv.writer(output, dialect="excel")
32
34
  str_array = [[str(item) for item in row] for row in array]
33
35
  writer.writerows(str_array)
34
36
  return output.getvalue()
@@ -55,23 +57,23 @@ def split_table_name(table_name: str) -> List[str]:
55
57
  'input': '`aaa`.`bbb.ccc`', 'output': ['aaa', 'bbb.ccc']
56
58
  """
57
59
  result = []
58
- current = ''
60
+ current = ""
59
61
  in_backticks = False
60
62
 
61
63
  i = 0
62
64
  while i < len(table_name):
63
- if table_name[i] == '`':
65
+ if table_name[i] == "`":
64
66
  in_backticks = not in_backticks
65
- elif table_name[i] == '.' and not in_backticks:
67
+ elif table_name[i] == "." and not in_backticks:
66
68
  if current:
67
- result.append(current.strip('`'))
68
- current = ''
69
+ result.append(current.strip("`"))
70
+ current = ""
69
71
  else:
70
72
  current += table_name[i]
71
73
  i += 1
72
74
 
73
75
  if current:
74
- result.append(current.strip('`'))
76
+ result.append(current.strip("`"))
75
77
 
76
78
  # ensure we split the table name
77
79
  result = [r.split(".") for r in result][0]
@@ -89,13 +91,13 @@ class SQLAgent:
89
91
  command_executor,
90
92
  databases: List[str],
91
93
  databases_struct: dict,
92
- knowledge_base_database: str = 'mindsdb',
94
+ knowledge_base_database: str = "mindsdb",
93
95
  include_tables: Optional[List[str]] = None,
94
96
  ignore_tables: Optional[List[str]] = None,
95
97
  include_knowledge_bases: Optional[List[str]] = None,
96
98
  ignore_knowledge_bases: Optional[List[str]] = None,
97
99
  sample_rows_in_table_info: int = 3,
98
- cache: Optional[dict] = None
100
+ cache: Optional[dict] = None,
99
101
  ):
100
102
  """
101
103
  Initialize SQLAgent.
@@ -133,12 +135,13 @@ class SQLAgent:
133
135
  self._cache = cache
134
136
 
135
137
  from mindsdb.interfaces.skills.skill_tool import SkillToolController
138
+
136
139
  # Initialize the skill tool controller from MindsDB
137
140
  self.skill_tool = SkillToolController()
138
141
 
139
142
  def _call_engine(self, query: str, database=None):
140
143
  # switch database
141
- ast_query = parse_sql(query.strip('`'))
144
+ ast_query = parse_sql(query.strip("`"))
142
145
  self._check_permissions(ast_query)
143
146
 
144
147
  if database is None:
@@ -148,10 +151,7 @@ class SQLAgent:
148
151
  # for now, we will just use the first one
149
152
  database = self._databases[0] if self._databases else "mindsdb"
150
153
 
151
- ret = self._command_executor.execute_command(
152
- ast_query,
153
- database_name=database
154
- )
154
+ ret = self._command_executor.execute_command(ast_query, database_name=database)
155
155
  return ret
156
156
 
157
157
  def _check_permissions(self, ast_query):
@@ -170,7 +170,7 @@ class SQLAgent:
170
170
 
171
171
  def _check_f(node, is_table=None, **kwargs):
172
172
  if is_table and isinstance(node, Identifier):
173
- table_name = '.'.join(node.parts)
173
+ table_name = ".".join(node.parts)
174
174
 
175
175
  # Get the list of available knowledge bases
176
176
  kb_names = self.get_usable_knowledge_base_names()
@@ -182,16 +182,21 @@ class SQLAgent:
182
182
  if is_kb and self._knowledge_bases_to_include:
183
183
  kb_parts = [split_table_name(x) for x in self._knowledge_bases_to_include]
184
184
  if node.parts not in kb_parts:
185
- raise ValueError(f"Knowledge base {table_name} not found. Available knowledge bases: {', '.join(self._knowledge_bases_to_include)}")
185
+ raise ValueError(
186
+ f"Knowledge base {table_name} not found. Available knowledge bases: {', '.join(self._knowledge_bases_to_include)}"
187
+ )
186
188
  # Regular table check
187
189
  elif not is_kb and self._tables_to_include and node.parts not in tables_parts:
188
- raise ValueError(f"Table {table_name} not found. Available tables: {', '.join(self._tables_to_include)}")
190
+ raise ValueError(
191
+ f"Table {table_name} not found. Available tables: {', '.join(self._tables_to_include)}"
192
+ )
189
193
  # Check if it's a restricted knowledge base
190
194
  elif is_kb and table_name in self._knowledge_bases_to_ignore:
191
195
  raise ValueError(f"Knowledge base {table_name} is not allowed.")
192
196
  # Check if it's a restricted table
193
197
  elif not is_kb and table_name in self._tables_to_ignore:
194
198
  raise ValueError(f"Table {table_name} is not allowed.")
199
+
195
200
  query_traversal(ast_query, _check_f)
196
201
 
197
202
  def get_usable_table_names(self) -> Iterable[str]:
@@ -200,7 +205,7 @@ class SQLAgent:
200
205
  Returns:
201
206
  Iterable[str]: list with table names
202
207
  """
203
- cache_key = f'{ctx.company_id}_{",".join(self._databases)}_tables'
208
+ cache_key = f"{ctx.company_id}_{','.join(self._databases)}_tables"
204
209
 
205
210
  # first check cache and return if found
206
211
  if self._cache:
@@ -218,13 +223,13 @@ class SQLAgent:
218
223
 
219
224
  schemas_names = list(self._mindsdb_db_struct[db_name].keys())
220
225
  if len(schemas_names) > 1 and None in schemas_names:
221
- raise Exception('default schema and named schemas can not be used in same filter')
226
+ raise Exception("default schema and named schemas can not be used in same filter")
222
227
 
223
228
  if None in schemas_names:
224
229
  # get tables only from default schema
225
230
  response = handler.get_tables()
226
231
  tables_in_default_schema = list(response.data_frame.table_name)
227
- schema_tables_restrictions = self._mindsdb_db_struct[db_name][None] # None - is default schema
232
+ schema_tables_restrictions = self._mindsdb_db_struct[db_name][None] # None - is default schema
228
233
  if schema_tables_restrictions is None:
229
234
  for table_name in tables_in_default_schema:
230
235
  result_tables.append([db_name, table_name])
@@ -233,27 +238,27 @@ class SQLAgent:
233
238
  if table_name in tables_in_default_schema:
234
239
  result_tables.append([db_name, table_name])
235
240
  else:
236
- if 'all' in inspect.signature(handler.get_tables).parameters:
241
+ if "all" in inspect.signature(handler.get_tables).parameters:
237
242
  response = handler.get_tables(all=True)
238
243
  else:
239
244
  response = handler.get_tables()
240
245
  response_schema_names = list(response.data_frame.table_schema.unique())
241
246
  schemas_intersection = set(schemas_names) & set(response_schema_names)
242
247
  if len(schemas_intersection) == 0:
243
- raise Exception('There are no allowed schemas in ds')
248
+ raise Exception("There are no allowed schemas in ds")
244
249
 
245
250
  for schema_name in schemas_intersection:
246
- schema_sub_df = response.data_frame[response.data_frame['table_schema'] == schema_name]
251
+ schema_sub_df = response.data_frame[response.data_frame["table_schema"] == schema_name]
247
252
  if self._mindsdb_db_struct[db_name][schema_name] is None:
248
253
  # all tables from schema allowed
249
254
  for row in schema_sub_df:
250
- result_tables.append([db_name, schema_name, row['table_name']])
255
+ result_tables.append([db_name, schema_name, row["table_name"]])
251
256
  else:
252
257
  for table_name in self._mindsdb_db_struct[db_name][schema_name]:
253
- if table_name in schema_sub_df['table_name'].values:
258
+ if table_name in schema_sub_df["table_name"].values:
254
259
  result_tables.append([db_name, schema_name, table_name])
255
260
 
256
- result_tables = ['.'.join(x) for x in result_tables]
261
+ result_tables = [".".join(x) for x in result_tables]
257
262
  if self._cache:
258
263
  self._cache.set(cache_key, set(result_tables))
259
264
  return result_tables
@@ -264,14 +269,14 @@ class SQLAgent:
264
269
  Returns:
265
270
  Iterable[str]: list with knowledge base names
266
271
  """
267
- cache_key = f'{ctx.company_id}_{self.knowledge_base_database}_knowledge_bases'
272
+ cache_key = f"{ctx.company_id}_{self.knowledge_base_database}_knowledge_bases"
268
273
 
269
274
  # todo we need to fix the cache, file cache can potentially store out of data information
270
275
  # # first check cache and return if found
271
276
  # if self._cache:
272
- # cached_kbs = self._cache.get(cache_key)
277
+ # cached_kbs = self._cache.get(cache_key)
273
278
  # if cached_kbs:
274
- # return cached_kbs
279
+ # return cached_kbs
275
280
 
276
281
  if self._knowledge_bases_to_include:
277
282
  return self._knowledge_bases_to_include
@@ -289,13 +294,15 @@ class SQLAgent:
289
294
  try:
290
295
  # Get knowledge bases from the project database
291
296
  kb_controller = self._command_executor.session.kb_controller
292
- kb_names = [kb['name'] for kb in kb_controller.list()]
297
+ kb_names = [kb["name"] for kb in kb_controller.list()]
293
298
 
294
299
  # Filter knowledge bases based on include list
295
300
  if self._knowledge_bases_to_include:
296
301
  kb_names = [kb_name for kb_name in kb_names if kb_name in self._knowledge_bases_to_include]
297
302
  if not kb_names:
298
- logger.warning(f"No knowledge bases found in the include list: {self._knowledge_bases_to_include}")
303
+ logger.warning(
304
+ f"No knowledge bases found in the include list: {self._knowledge_bases_to_include}"
305
+ )
299
306
  return []
300
307
 
301
308
  return kb_names
@@ -317,7 +324,7 @@ class SQLAgent:
317
324
  # Filter knowledge bases based on ignore list
318
325
  kb_names = []
319
326
  for row in result:
320
- kb_name = row['name']
327
+ kb_name = row["name"]
321
328
  if kb_name not in self._knowledge_bases_to_ignore:
322
329
  kb_names.append(kb_name)
323
330
 
@@ -368,7 +375,7 @@ class SQLAgent:
368
375
  return tables
369
376
 
370
377
  def get_knowledge_base_info(self, kb_names: Optional[List[str]] = None) -> str:
371
- """ Get information about specified knowledge bases.
378
+ """Get information about specified knowledge bases.
372
379
  Follows best practices as specified in: Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498)
373
380
  If `sample_rows_in_table_info`, the specified number of sample rows will be
374
381
  appended to each table description. This can increase performance as demonstrated in the paper.
@@ -388,38 +395,60 @@ class SQLAgent:
388
395
  return "\n\n".join(kbs_info)
389
396
 
390
397
  def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
391
- """ Get information about specified tables.
398
+ """Get information about specified tables.
392
399
  Follows best practices as specified in: Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498)
393
400
  If `sample_rows_in_table_info`, the specified number of sample rows will be
394
401
  appended to each table description. This can increase performance as demonstrated in the paper.
395
402
  """
403
+ if config.get("data_catalog", {}).get("enabled", False):
404
+ database_table_map = {}
405
+ for name in table_names or self.get_usable_table_names():
406
+ name = name.replace("`", "")
396
407
 
397
- all_tables = []
398
- for name in self.get_usable_table_names():
399
- # remove backticks
400
- name = name.replace("`", "")
408
+ parts = name.split(".", 1)
409
+ # TODO: Will there be situations where parts has more than 2 elements? Like a schema?
410
+ # This is unlikely given that we default to a single schema per database.
411
+ if len(parts) == 1:
412
+ raise ValueError(f"Invalid table name: {name}. Expected format is 'database.table'.")
401
413
 
402
- split = name.split(".")
403
- if len(split) > 1:
404
- all_tables.append(Identifier(parts=[split[0], split[1]]))
405
- else:
406
- all_tables.append(Identifier(name))
414
+ database_table_map[parts[0]] = database_table_map.get(parts[0], []) + [parts[1]]
407
415
 
408
- # if table_names is not None:
409
- # all_tables = self._resolve_table_names(table_names, all_tables)
416
+ data_catalog_str = ""
417
+ for database_name, table_names in database_table_map.items():
418
+ data_catalog_reader = DataCatalogReader(database_name=database_name, table_names=table_names)
410
419
 
411
- tables_info = []
412
- for table in all_tables:
413
- key = f"{ctx.company_id}_{table}_info"
414
- table_info = self._cache.get(key) if self._cache else None
415
- if True or table_info is None:
416
- table_info = self._get_single_table_info(table)
417
- if self._cache:
418
- self._cache.set(key, table_info)
420
+ data_catalog_str += data_catalog_reader.read_metadata_as_string()
421
+
422
+ return data_catalog_str
419
423
 
420
- tables_info.append(table_info)
424
+ else:
425
+ # TODO: Improve old logic without data catalog
426
+ all_tables = []
427
+ for name in self.get_usable_table_names():
428
+ # remove backticks
429
+ name = name.replace("`", "")
430
+
431
+ split = name.split(".")
432
+ if len(split) > 1:
433
+ all_tables.append(Identifier(parts=[split[0], split[1]]))
434
+ else:
435
+ all_tables.append(Identifier(name))
421
436
 
422
- return "\n\n".join(tables_info)
437
+ if table_names is not None:
438
+ all_tables = self._resolve_table_names(table_names, all_tables)
439
+
440
+ tables_info = []
441
+ for table in all_tables:
442
+ key = f"{ctx.company_id}_{table}_info"
443
+ table_info = self._cache.get(key) if self._cache else None
444
+ if True or table_info is None:
445
+ table_info = self._get_single_table_info(table)
446
+ if self._cache:
447
+ self._cache.set(key, table_info)
448
+
449
+ tables_info.append(table_info)
450
+
451
+ return "\n\n".join(tables_info)
423
452
 
424
453
  def get_kb_sample_rows(self, kb_name: str) -> str:
425
454
  """Get sample rows from a knowledge base.
@@ -430,7 +459,7 @@ class SQLAgent:
430
459
  Returns:
431
460
  str: A string containing the sample rows from the knowledge base.
432
461
  """
433
- logger.info(f'_get_sample_rows: knowledge base={kb_name}')
462
+ logger.info(f"_get_sample_rows: knowledge base={kb_name}")
434
463
  command = f"select * from {kb_name} limit 10;"
435
464
  try:
436
465
  ret = self._call_engine(command)
@@ -438,13 +467,12 @@ class SQLAgent:
438
467
 
439
468
  def truncate_value(val):
440
469
  str_val = str(val)
441
- return str_val if len(str_val) < 100 else (str_val[:100] + '...')
470
+ return str_val if len(str_val) < 100 else (str_val[:100] + "...")
442
471
 
443
- sample_rows = list(
444
- map(lambda row: [truncate_value(value) for value in row], sample_rows))
472
+ sample_rows = list(map(lambda row: [truncate_value(value) for value in row], sample_rows))
445
473
  sample_rows_str = "\n" + f"{kb_name}:" + list_to_csv_str(sample_rows)
446
474
  except Exception as e:
447
- logger.info(f'_get_sample_rows error: {e}')
475
+ logger.info(f"_get_sample_rows error: {e}")
448
476
  sample_rows_str = "\n" + "\t [error] Couldn't retrieve sample rows!"
449
477
 
450
478
  return sample_rows_str
@@ -471,11 +499,9 @@ class SQLAgent:
471
499
 
472
500
  fields = df[INF_SCHEMA_COLUMNS_NAMES.COLUMN_NAME].to_list()
473
501
  dtypes = [
474
- mysql_data_type.value if isinstance(mysql_data_type, MYSQL_DATA_TYPE) else (data_type or 'UNKNOWN')
475
- for mysql_data_type, data_type
476
- in zip(
477
- df[INF_SCHEMA_COLUMNS_NAMES.MYSQL_DATA_TYPE],
478
- df[INF_SCHEMA_COLUMNS_NAMES.DATA_TYPE]
502
+ mysql_data_type.value if isinstance(mysql_data_type, MYSQL_DATA_TYPE) else (data_type or "UNKNOWN")
503
+ for mysql_data_type, data_type in zip(
504
+ df[INF_SCHEMA_COLUMNS_NAMES.MYSQL_DATA_TYPE], df[INF_SCHEMA_COLUMNS_NAMES.DATA_TYPE]
479
505
  )
480
506
  ]
481
507
  except Exception as e:
@@ -492,16 +518,18 @@ class SQLAgent:
492
518
  logger.warning(f"Could not get sample rows for {table_str}: {e}")
493
519
  sample_rows_info = "\n\t [error] Couldn't retrieve sample rows!"
494
520
 
495
- info = f'Table named `{table_str}`:\n'
521
+ info = f"Table named `{table_str}`:\n"
496
522
  info += f"\nSample with first {self._sample_rows_in_table_info} rows from table {table_str} in CSV format (dialect is 'excel'):\n"
497
523
  info += sample_rows_info + "\n"
498
- info += '\nColumn data types: ' + ",\t".join(
499
- [f'\n`{field}` : `{dtype}`' for field, dtype in zip(fields, dtypes)]
500
- ) + '\n'
524
+ info += (
525
+ "\nColumn data types: "
526
+ + ",\t".join([f"\n`{field}` : `{dtype}`" for field, dtype in zip(fields, dtypes)])
527
+ + "\n"
528
+ )
501
529
  return info
502
530
 
503
531
  def _get_sample_rows(self, table: str, fields: List[str]) -> str:
504
- logger.info(f'_get_sample_rows: table={table} fields={fields}')
532
+ logger.info(f"_get_sample_rows: table={table} fields={fields}")
505
533
  command = f"select {', '.join(fields)} from {table} limit {self._sample_rows_in_table_info};"
506
534
  try:
507
535
  ret = self._call_engine(command)
@@ -509,20 +537,19 @@ class SQLAgent:
509
537
 
510
538
  def truncate_value(val):
511
539
  str_val = str(val)
512
- return str_val if len(str_val) < 100 else (str_val[:100] + '...')
540
+ return str_val if len(str_val) < 100 else (str_val[:100] + "...")
513
541
 
514
- sample_rows = list(
515
- map(lambda row: [truncate_value(value) for value in row], sample_rows))
542
+ sample_rows = list(map(lambda row: [truncate_value(value) for value in row], sample_rows))
516
543
  sample_rows_str = "\n" + list_to_csv_str([fields] + sample_rows)
517
544
  except Exception as e:
518
- logger.info(f'_get_sample_rows error: {e}')
545
+ logger.info(f"_get_sample_rows error: {e}")
519
546
  sample_rows_str = "\n" + "\t [error] Couldn't retrieve sample rows!"
520
547
 
521
548
  return sample_rows_str
522
549
 
523
550
  def _clean_query(self, query: str) -> str:
524
551
  # Sometimes LLM can input markdown into query tools.
525
- cmd = re.sub(r'```(sql)?', '', query)
552
+ cmd = re.sub(r"```(sql)?", "", query)
526
553
  return cmd
527
554
 
528
555
  def query(self, command: str, fetch: str = "all") -> str:
@@ -534,16 +561,16 @@ class SQLAgent:
534
561
  def _repr_result(ret):
535
562
  limit_rows = 30
536
563
 
537
- columns_str = ', '.join([repr(col.name) for col in ret.columns])
538
- res = f'Output columns: {columns_str}\n'
564
+ columns_str = ", ".join([repr(col.name) for col in ret.columns])
565
+ res = f"Output columns: {columns_str}\n"
539
566
 
540
567
  data = ret.to_lists()
541
568
  if len(data) > limit_rows:
542
569
  df = pd.DataFrame(data, columns=[col.name for col in ret.columns])
543
570
 
544
- res += f'Result has {len(data)} rows. Description of data:\n'
545
- res += str(df.describe(include='all')) + '\n\n'
546
- res += f'First {limit_rows} rows:\n'
571
+ res += f"Result has {len(data)} rows. Description of data:\n"
572
+ res += str(df.describe(include="all")) + "\n\n"
573
+ res += f"First {limit_rows} rows:\n"
547
574
 
548
575
  else:
549
576
  res += "Result in CSV format (dialect is 'excel'):\n"
@@ -562,20 +589,20 @@ class SQLAgent:
562
589
 
563
590
  def get_table_info_safe(self, table_names: Optional[List[str]] = None) -> str:
564
591
  try:
565
- logger.info(f'get_table_info_safe: {table_names}')
592
+ logger.info(f"get_table_info_safe: {table_names}")
566
593
  return self.get_table_info(table_names)
567
594
  except Exception as e:
568
- logger.info(f'get_table_info_safe error: {e}')
595
+ logger.info(f"get_table_info_safe error: {e}")
569
596
  return f"Error: {e}"
570
597
 
571
598
  def query_safe(self, command: str, fetch: str = "all") -> str:
572
599
  try:
573
- logger.info(f'query_safe (fetch={fetch}): {command}')
600
+ logger.info(f"query_safe (fetch={fetch}): {command}")
574
601
  return self.query(command, fetch)
575
602
  except Exception as e:
576
603
  logger.error(f"Error in query_safe: {str(e)}\n{traceback.format_exc()}")
577
- logger.info(f'query_safe error: {e}')
604
+ logger.info(f"query_safe error: {e}")
578
605
  msg = f"Error: {e}"
579
- if 'does not exist' in msg and ' relation ' in msg:
580
- msg += '\nAvailable tables: ' + ', '.join(self.get_usable_table_names())
606
+ if "does not exist" in msg and " relation " in msg:
607
+ msg += "\nAvailable tables: " + ", ".join(self.get_usable_table_names())
581
608
  return msg