MindsDB 25.5.4.1-py3-none-any.whl → 25.6.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of MindsDB has been flagged as possibly problematic.
Files changed (70)
  1. mindsdb/__about__.py +1 -1
  2. mindsdb/api/a2a/agent.py +28 -25
  3. mindsdb/api/a2a/common/server/server.py +32 -26
  4. mindsdb/api/a2a/run_a2a.py +1 -1
  5. mindsdb/api/executor/command_executor.py +69 -14
  6. mindsdb/api/executor/datahub/datanodes/integration_datanode.py +49 -65
  7. mindsdb/api/executor/datahub/datanodes/project_datanode.py +29 -48
  8. mindsdb/api/executor/datahub/datanodes/system_tables.py +35 -61
  9. mindsdb/api/executor/planner/plan_join.py +67 -77
  10. mindsdb/api/executor/planner/query_planner.py +176 -155
  11. mindsdb/api/executor/planner/steps.py +37 -12
  12. mindsdb/api/executor/sql_query/result_set.py +45 -64
  13. mindsdb/api/executor/sql_query/steps/fetch_dataframe.py +14 -18
  14. mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +17 -18
  15. mindsdb/api/executor/sql_query/steps/insert_step.py +13 -33
  16. mindsdb/api/executor/sql_query/steps/subselect_step.py +43 -35
  17. mindsdb/api/executor/utilities/sql.py +42 -48
  18. mindsdb/api/http/namespaces/config.py +1 -1
  19. mindsdb/api/http/namespaces/file.py +14 -23
  20. mindsdb/api/mysql/mysql_proxy/data_types/mysql_datum.py +12 -28
  21. mindsdb/api/mysql/mysql_proxy/data_types/mysql_packets/binary_resultset_row_package.py +59 -50
  22. mindsdb/api/mysql/mysql_proxy/data_types/mysql_packets/resultset_row_package.py +9 -8
  23. mindsdb/api/mysql/mysql_proxy/libs/constants/mysql.py +449 -461
  24. mindsdb/api/mysql/mysql_proxy/utilities/dump.py +87 -36
  25. mindsdb/integrations/handlers/file_handler/file_handler.py +15 -9
  26. mindsdb/integrations/handlers/file_handler/tests/test_file_handler.py +43 -24
  27. mindsdb/integrations/handlers/litellm_handler/litellm_handler.py +10 -3
  28. mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +26 -33
  29. mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +74 -51
  30. mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +305 -98
  31. mindsdb/integrations/handlers/salesforce_handler/salesforce_handler.py +53 -34
  32. mindsdb/integrations/handlers/salesforce_handler/salesforce_tables.py +136 -6
  33. mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +334 -83
  34. mindsdb/integrations/libs/api_handler.py +261 -57
  35. mindsdb/integrations/libs/base.py +100 -29
  36. mindsdb/integrations/utilities/files/file_reader.py +99 -73
  37. mindsdb/integrations/utilities/handler_utils.py +23 -8
  38. mindsdb/integrations/utilities/sql_utils.py +35 -40
  39. mindsdb/interfaces/agents/agents_controller.py +196 -192
  40. mindsdb/interfaces/agents/constants.py +7 -1
  41. mindsdb/interfaces/agents/langchain_agent.py +42 -11
  42. mindsdb/interfaces/agents/mcp_client_agent.py +29 -21
  43. mindsdb/interfaces/data_catalog/__init__.py +0 -0
  44. mindsdb/interfaces/data_catalog/base_data_catalog.py +54 -0
  45. mindsdb/interfaces/data_catalog/data_catalog_loader.py +359 -0
  46. mindsdb/interfaces/data_catalog/data_catalog_reader.py +34 -0
  47. mindsdb/interfaces/database/database.py +81 -57
  48. mindsdb/interfaces/database/integrations.py +220 -234
  49. mindsdb/interfaces/database/log.py +72 -104
  50. mindsdb/interfaces/database/projects.py +156 -193
  51. mindsdb/interfaces/file/file_controller.py +21 -65
  52. mindsdb/interfaces/knowledge_base/controller.py +63 -10
  53. mindsdb/interfaces/knowledge_base/evaluate.py +519 -0
  54. mindsdb/interfaces/knowledge_base/llm_client.py +75 -0
  55. mindsdb/interfaces/skills/custom/text2sql/mindsdb_kb_tools.py +83 -43
  56. mindsdb/interfaces/skills/skills_controller.py +54 -36
  57. mindsdb/interfaces/skills/sql_agent.py +109 -86
  58. mindsdb/interfaces/storage/db.py +223 -79
  59. mindsdb/migrations/versions/2025-05-28_a44643042fe8_added_data_catalog_tables.py +118 -0
  60. mindsdb/migrations/versions/2025-06-09_608e376c19a7_updated_data_catalog_data_types.py +58 -0
  61. mindsdb/utilities/config.py +9 -2
  62. mindsdb/utilities/log.py +35 -26
  63. mindsdb/utilities/ml_task_queue/task.py +19 -22
  64. mindsdb/utilities/render/sqlalchemy_render.py +129 -181
  65. mindsdb/utilities/starters.py +49 -1
  66. {mindsdb-25.5.4.1.dist-info → mindsdb-25.6.2.0.dist-info}/METADATA +268 -268
  67. {mindsdb-25.5.4.1.dist-info → mindsdb-25.6.2.0.dist-info}/RECORD +70 -62
  68. {mindsdb-25.5.4.1.dist-info → mindsdb-25.6.2.0.dist-info}/WHEEL +0 -0
  69. {mindsdb-25.5.4.1.dist-info → mindsdb-25.6.2.0.dist-info}/licenses/LICENSE +0 -0
  70. {mindsdb-25.5.4.1.dist-info → mindsdb-25.6.2.0.dist-info}/top_level.txt +0 -0
mindsdb/integrations/utilities/files/file_reader.py

@@ -1,10 +1,11 @@
+ from dataclasses import dataclass, astuple
  import traceback
  import json
  import csv
  from io import BytesIO, StringIO, IOBase
  from pathlib import Path
  import codecs
- from typing import List
+ from typing import List, Generator
  
  import filetype
  import pandas as pd
@@ -18,8 +19,27 @@ DEFAULT_CHUNK_SIZE = 500
  DEFAULT_CHUNK_OVERLAP = 250
  
  
- class FileDetectError(Exception):
-     ...
+ class FileProcessingError(Exception): ...
+ 
+ 
+ @dataclass(frozen=True, slots=True)
+ class _SINGLE_PAGE_FORMAT:
+     CSV: str = "csv"
+     JSON: str = "json"
+     TXT: str = "txt"
+     PDF: str = "pdf"
+     PARQUET: str = "parquet"
+ 
+ 
+ SINGLE_PAGE_FORMAT = _SINGLE_PAGE_FORMAT()
+ 
+ 
+ @dataclass(frozen=True, slots=True)
+ class _MULTI_PAGE_FORMAT:
+     XLSX: str = "xlsx"
+ 
+ 
+ MULTI_PAGE_FORMAT = _MULTI_PAGE_FORMAT()
  
  
  def decode(file_obj: IOBase) -> StringIO:
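
Note: the new format registry replaces ad-hoc string lists with frozen, slotted dataclasses, so attribute access and membership checks share a single definition. A minimal standalone sketch of the pattern (hypothetical _FORMATS name, not MindsDB code):

from dataclasses import dataclass, astuple

@dataclass(frozen=True, slots=True)
class _FORMATS:
    CSV: str = "csv"
    XLSX: str = "xlsx"

FORMATS = _FORMATS()

# astuple() flattens the instance into its field values, so the
# per-format constants and the "supported" tuple cannot drift apart.
assert astuple(FORMATS) == ("csv", "xlsx")
assert FORMATS.CSV in astuple(FORMATS)

frozen=True makes the instances immutable; slots=True (Python 3.10+) keeps them lightweight.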
@@ -56,21 +76,20 @@ def decode(file_obj: IOBase) -> StringIO:
          data_str = StringIO(byte_str.decode(encoding, errors))
      except Exception as e:
          logger.error(traceback.format_exc())
-         raise FileDetectError("Could not load into string") from e
+         raise FileProcessingError("Could not load into string") from e
  
      return data_str
  
  
  class FormatDetector:
- 
-     supported_formats = ['parquet', 'csv', 'xlsx', 'pdf', 'json', 'txt']
-     multipage_formats = ['xlsx']
+     supported_formats = astuple(SINGLE_PAGE_FORMAT) + astuple(MULTI_PAGE_FORMAT)
+     multipage_formats = astuple(MULTI_PAGE_FORMAT)
  
      def __init__(
          self,
-         path: str = None,
-         name: str = None,
-         file: IOBase = None
+         path: str | None = None,
+         name: str | None = None,
+         file: IOBase | None = None,
      ):
          """
          File format detector
@@ -81,16 +100,16 @@ class FormatDetector:
          :param file: file descriptor (via open(...), of BytesIO(...))
          """
          if path is not None:
-             file = open(path, 'rb')
+             file = open(path, "rb")
  
          elif file is not None:
              if name is None:
-                 if hasattr(file, 'name'):
+                 if hasattr(file, "name"):
                      path = file.name
                  else:
-                     path = 'file'
+                     path = "file"
          else:
-             raise FileDetectError('Wrong arguments: path or file is required')
+             raise FileProcessingError("Wrong arguments: path or file is required")
  
          if name is None:
              name = Path(path).name
@@ -108,14 +127,14 @@ class FormatDetector:
          format = self.get_format_by_name()
          if format is not None:
              if format not in self.supported_formats:
-                 raise FileDetectError(f'Not supported format: {format}')
+                 raise FileProcessingError(f"Not supported format: {format}")
  
          if format is None and self.file_obj is not None:
              format = self.get_format_by_content()
              self.file_obj.seek(0)
  
          if format is None:
-             raise FileDetectError(f'Unable to detect format: {self.name}')
+             raise FileProcessingError(f"Unable to detect format: {self.name}")
  
          self.format = format
          return format
@@ -124,33 +143,32 @@ class FormatDetector:
          extension = Path(self.name).suffix.strip(".").lower()
          if extension == "tsv":
              extension = "csv"
-             self.parameters['delimiter'] = '\t'
+             self.parameters["delimiter"] = "\t"
  
          return extension or None
  
      def get_format_by_content(self):
          if self.is_parquet(self.file_obj):
-             return "parquet"
+             return SINGLE_PAGE_FORMAT.PARQUET
  
          file_type = filetype.guess(self.file_obj)
          if file_type is not None:
- 
              if file_type.mime in {
                  "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
                  "application/vnd.ms-excel",
              }:
-                 return 'xlsx'
+                 return MULTI_PAGE_FORMAT.XLSX
  
-             if file_type.mime == 'application/pdf':
-                 return "pdf"
+             if file_type.mime == "application/pdf":
+                 return SINGLE_PAGE_FORMAT.PDF
  
          file_obj = decode(self.file_obj)
  
          if self.is_json(file_obj):
-             return "json"
+             return SINGLE_PAGE_FORMAT.JSON
  
          if self.is_csv(file_obj):
-             return "csv"
+             return SINGLE_PAGE_FORMAT.CSV
  
      @staticmethod
      def is_json(data_obj: StringIO) -> bool:
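
Note: get_format_by_content now returns registry constants instead of string literals, but the detection order is unchanged: parquet magic bytes first, then MIME sniffing via filetype, then text probes for JSON and CSV. A rough standalone illustration of the MIME branch (guess_by_mime is a hypothetical helper; the MIME set mirrors the diff):

import filetype

def guess_by_mime(file_obj) -> str | None:
    # filetype.guess inspects the file's leading magic bytes
    kind = filetype.guess(file_obj)
    if kind is None:
        return None
    if kind.mime in {
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        "application/vnd.ms-excel",
    }:
        return "xlsx"
    if kind.mime == "application/pdf":
        return "pdf"
    return None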
@@ -198,35 +216,53 @@ class FormatDetector:
          return False
  
  
- class FileReader(FormatDetector):
+ def format_column_names(df: pd.DataFrame):
+     df.columns = [column.strip(" \t") for column in df.columns]
+     if len(df.columns) != len(set(df.columns)) or any(len(column_name) == 0 for column_name in df.columns):
+         raise FileProcessingError("Each column should have a unique and non-empty name.")
+ 
  
+ class FileReader(FormatDetector):
      def _get_fnc(self):
          format = self.get_format()
-         func = getattr(self, f'read_{format}', None)
+         func = getattr(self, f"read_{format}", None)
          if func is None:
-             raise FileDetectError(f'Unsupported format: {format}')
-         return func
+             raise FileProcessingError(f"Unsupported format: {format}")
+ 
+         if format in astuple(MULTI_PAGE_FORMAT):
+ 
+             def format_multipage(*args, **kwargs):
+                 for page_number, df in func(*args, **kwargs):
+                     format_column_names(df)
+                     yield page_number, df
+ 
+             return format_multipage
+ 
+         def format_singlepage(*args, **kwargs) -> pd.DataFrame:
+             """Check that the columns have unique not-empty names"""
+             df = func(*args, **kwargs)
+             format_column_names(df)
+             return df
+ 
+         return format_singlepage
  
      def get_pages(self, **kwargs) -> List[str]:
          """
-             Get list of tables in file
+         Get list of tables in file
          """
          format = self.get_format()
          if format not in self.multipage_formats:
              # only one table
-             return ['main']
+             return ["main"]
  
          func = self._get_fnc()
          self.file_obj.seek(0)
  
-         return [
-             name for name, _ in
-             func(self.file_obj, only_names=True, **kwargs)
-         ]
+         return [name for name, _ in func(self.file_obj, only_names=True, **kwargs)]
  
-     def get_contents(self, **kwargs):
+     def get_contents(self, **kwargs) -> dict[str, pd.DataFrame]:
          """
-             Get all info(pages with content) from file as dict: {tablename, content}
+         Get all info(pages with content) from file as dict: {tablename, content}
          """
          func = self._get_fnc()
          self.file_obj.seek(0)
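
Note: the reworked _get_fnc wraps every reader so column names are validated exactly once, for both single-page and multi-page formats. A quick sketch of what the new format_column_names check accepts and rejects (hypothetical frames; the function body is copied from the diff):

import pandas as pd

class FileProcessingError(Exception): ...

def format_column_names(df: pd.DataFrame):
    # strip surrounding spaces/tabs, then require unique, non-empty names
    df.columns = [column.strip(" \t") for column in df.columns]
    if len(df.columns) != len(set(df.columns)) or any(len(c) == 0 for c in df.columns):
        raise FileProcessingError("Each column should have a unique and non-empty name.")

ok = pd.DataFrame({"a ": [1], "b": [2]})
format_column_names(ok)  # columns become ["a", "b"]

dup = pd.DataFrame([[1, 2]], columns=["a", "a "])
# format_column_names(dup) would raise: "a " collides with "a" after stripping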
@@ -234,17 +270,13 @@ class FileReader(FormatDetector):
          format = self.get_format()
          if format not in self.multipage_formats:
              # only one table
-             return {'main': func(self.file_obj, name=self.name, **kwargs)}
+             return {"main": func(self.file_obj, name=self.name, **kwargs)}
  
-         return {
-             name: df
-             for name, df in
-             func(self.file_obj, **kwargs)
-         }
+         return {name: df for name, df in func(self.file_obj, **kwargs)}
  
-     def get_page_content(self, page_name: str = None, **kwargs) -> pd.DataFrame:
+     def get_page_content(self, page_name: str | None = None, **kwargs) -> pd.DataFrame:
          """
-             Get content of a single table
+         Get content of a single table
          """
          func = self._get_fnc()
          self.file_obj.seek(0)
@@ -258,7 +290,7 @@ class FileReader(FormatDetector):
          return df
  
      @staticmethod
-     def _get_csv_dialect(buffer, delimiter=None) -> csv.Dialect:
+     def _get_csv_dialect(buffer, delimiter: str | None = None) -> csv.Dialect | None:
          sample = buffer.readline()  # trying to get dialect from header
          buffer.seek(0)
          try:
@@ -270,42 +302,35 @@ class FileReader(FormatDetector):
          else:
              accepted_csv_delimiters = [",", "\t", ";"]
              try:
-                 dialect = csv.Sniffer().sniff(
-                     sample, delimiters=accepted_csv_delimiters
-                 )
-                 dialect.doublequote = (
-                     True  # assume that all csvs have " as string escape
-                 )
+                 dialect = csv.Sniffer().sniff(sample, delimiters=accepted_csv_delimiters)
+                 dialect.doublequote = True  # assume that all csvs have " as string escape
              except Exception:
                  dialect = csv.reader(sample).dialect
                  if dialect.delimiter not in accepted_csv_delimiters:
-                     raise Exception(
-                         f"CSV delimeter '{dialect.delimiter}' is not supported"
-                     )
+                     raise FileProcessingError(f"CSV delimeter '{dialect.delimiter}' is not supported")
  
          except csv.Error:
              dialect = None
          return dialect
  
      @classmethod
-     def read_csv(cls, file_obj: BytesIO, delimiter=None, **kwargs):
+     def read_csv(cls, file_obj: BytesIO, delimiter: str | None = None, **kwargs) -> pd.DataFrame:
          file_obj = decode(file_obj)
          dialect = cls._get_csv_dialect(file_obj, delimiter=delimiter)
- 
          return pd.read_csv(file_obj, sep=dialect.delimiter, index_col=False)
  
      @staticmethod
-     def read_txt(file_obj: BytesIO, name=None, **kwargs):
+     def read_txt(file_obj: BytesIO, name: str | None = None, **kwargs) -> pd.DataFrame:
          # the lib is heavy, so import it only when needed
          from langchain_text_splitters import RecursiveCharacterTextSplitter
+ 
          file_obj = decode(file_obj)
  
          try:
              from langchain_core.documents import Document
          except ImportError:
-             raise ImportError(
-                 "To import TXT document please install 'langchain-community':\n"
-                 " pip install langchain-community"
+             raise FileProcessingError(
+                 "To import TXT document please install 'langchain-community':\n pip install langchain-community"
              )
          text = file_obj.read()
  
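
Note: the consolidated _get_csv_dialect keeps the existing strategy: try csv.Sniffer against a whitelist of delimiters, fall back to csv.reader's default dialect, and reject anything outside the whitelist. A standalone sketch of the sniffing step (sniff_delimiter is a hypothetical helper):

import csv

def sniff_delimiter(sample: str, accepted: str = ",;\t") -> str | None:
    # returns the guessed delimiter, or None if the sample is undecidable
    try:
        dialect = csv.Sniffer().sniff(sample, delimiters=accepted)
    except csv.Error:
        return None
    return dialect.delimiter

print(sniff_delimiter("a;b;c\n1;2;3"))      # ;
print(sniff_delimiter("col1\tcol2\n1\t2"))  # tab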
@@ -317,15 +342,10 @@ class FileReader(FormatDetector):
          )
  
          docs = text_splitter.split_documents(documents)
-         return pd.DataFrame(
-             [
-                 {"content": doc.page_content, "metadata": doc.metadata}
-                 for doc in docs
-             ]
-         )
+         return pd.DataFrame([{"content": doc.page_content, "metadata": doc.metadata} for doc in docs])
  
      @staticmethod
-     def read_pdf(file_obj: BytesIO, name=None, **kwargs):
+     def read_pdf(file_obj: BytesIO, name: str | None = None, **kwargs) -> pd.DataFrame:
          # the libs are heavy, so import it only when needed
          import fitz  # pymupdf
          from langchain_text_splitters import RecursiveCharacterTextSplitter
@@ -340,30 +360,36 @@ class FileReader(FormatDetector):
          split_text = text_splitter.split_text(text)
  
          return pd.DataFrame(
-             {"content": split_text, "metadata": [{"file_format": "pdf", "source_file": name}] * len(split_text)}
+             {
+                 "content": split_text,
+                 "metadata": [{"file_format": "pdf", "source_file": name}] * len(split_text),
+             }
          )
  
      @staticmethod
-     def read_json(file_obj: BytesIO, **kwargs):
+     def read_json(file_obj: BytesIO, **kwargs) -> pd.DataFrame:
          file_obj = decode(file_obj)
          file_obj.seek(0)
          json_doc = json.loads(file_obj.read())
          return pd.json_normalize(json_doc, max_level=0)
  
      @staticmethod
-     def read_parquet(file_obj: BytesIO, **kwargs):
+     def read_parquet(file_obj: BytesIO, **kwargs) -> pd.DataFrame:
          return pd.read_parquet(file_obj)
  
      @staticmethod
-     def read_xlsx(file_obj: BytesIO, page_name=None, only_names=False, **kwargs):
+     def read_xlsx(
+         file_obj: BytesIO,
+         page_name: str | None = None,
+         only_names: bool = False,
+         **kwargs,
+     ) -> Generator[tuple[str, pd.DataFrame | None], None, None]:
          with pd.ExcelFile(file_obj) as xls:
- 
              if page_name is not None:
                  # return specific page
                  yield page_name, pd.read_excel(xls, sheet_name=page_name)
  
              for page_name in xls.sheet_names:
- 
                  if only_names:
                      # extract only pages names
                      df = None
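
Note: read_xlsx is now annotated as a generator that yields (sheet_name, DataFrame) pairs, with None in place of a frame when only_names=True. A minimal consumption sketch (simplified reader; "report.xlsx" is a hypothetical file):

from io import BytesIO
import pandas as pd

def read_xlsx(file_obj, only_names=False):
    # simplified version of the reader in the diff
    with pd.ExcelFile(file_obj) as xls:
        for sheet in xls.sheet_names:
            df = None if only_names else pd.read_excel(xls, sheet_name=sheet)
            yield sheet, df

with open("report.xlsx", "rb") as f:
    names = [name for name, _ in read_xlsx(BytesIO(f.read()), only_names=True)]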
mindsdb/integrations/utilities/handler_utils.py

@@ -39,22 +39,36 @@ def get_api_key(
      if "using" in create_args and f"{api_name.lower()}_api_key" in create_args["using"]:
          return create_args["using"][f"{api_name.lower()}_api_key"]
  
+     # 1.5 - Check for generic api_key in using
+     if "using" in create_args and "api_key" in create_args["using"]:
+         return create_args["using"]["api_key"]
+ 
      # 2
      if f"{api_name.lower()}_api_key" in create_args:
          return create_args[f"{api_name.lower()}_api_key"]
  
-     # 2.5 - Check in params dictionary if it exists (for agents)
+     # 2.5 - Check for generic api_key
+     if "api_key" in create_args:
+         return create_args["api_key"]
+ 
+     # 3 - Check in params dictionary if it exists (for agents)
      if "params" in create_args and create_args["params"] is not None:
          if f"{api_name.lower()}_api_key" in create_args["params"]:
              return create_args["params"][f"{api_name.lower()}_api_key"]
+         # 3.5 - Check for generic api_key in params
+         if "api_key" in create_args["params"]:
+             return create_args["params"]["api_key"]
  
-     # 3
+     # 4
      if engine_storage is not None:
          connection_args = engine_storage.get_connection_args()
          if f"{api_name.lower()}_api_key" in connection_args:
              return connection_args[f"{api_name.lower()}_api_key"]
+         # 4.5 - Check for generic api_key in connection_args
+         if "api_key" in connection_args:
+             return connection_args["api_key"]
  
-     # 4
+     # 5
      api_key = os.getenv(f"{api_name.lower()}_api_key")
      if api_key is not None:
          return api_key
@@ -62,15 +76,15 @@ def get_api_key(
      if api_key is not None:
          return api_key
  
-     # 5
+     # 6
      config = Config()
      api_cfg = config.get(api_name, {})
      if f"{api_name.lower()}_api_key" in api_cfg:
          return api_cfg[f"{api_name.lower()}_api_key"]
  
-     # 6
-     if 'api_keys' in create_args and api_name in create_args['api_keys']:
-         return create_args['api_keys'][api_name]
+     # 7
+     if "api_keys" in create_args and api_name in create_args["api_keys"]:
+         return create_args["api_keys"][api_name]
  
      if strict:
          provider_upper = api_name.upper()
@@ -79,8 +93,9 @@
          error_message = (
              f"API key for {api_name} not found. Please provide it using one of the following methods:\n"
              f"1. Set the {api_key_env_var} environment variable\n"
-             f"2. Provide it as '{api_key_arg}' parameter when creating an agent using the CREATE AGENT syntax\n"
+             f"2. Provide it as '{api_key_arg}' parameter or 'api_key' parameter when creating an agent using the CREATE AGENT syntax\n"
              f" Example: CREATE AGENT my_agent USING model='gpt-4', provider='{api_name}', {api_key_arg}='your-api-key';\n"
+             f" Or: CREATE AGENT my_agent USING model='gpt-4', provider='{api_name}', api_key='your-api-key';\n"
          )
          raise Exception(error_message)
      return None
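
Note: the net effect of the handler_utils.py changes is that a generic api_key is now accepted at every lookup site, checked immediately after the provider-specific {provider}_api_key and before falling through to connection args, the environment, and config. A condensed sketch of the per-scope ordering (resolve_api_key is a hypothetical stand-in; create_args is a plain dict):

import os

def resolve_api_key(api_name: str, create_args: dict) -> str | None:
    # the provider-specific key wins over the generic "api_key" at each level
    specific = f"{api_name.lower()}_api_key"
    scopes = (create_args.get("using", {}), create_args, create_args.get("params") or {})
    for scope in scopes:
        if specific in scope:
            return scope[specific]
        if "api_key" in scope:
            return scope["api_key"]
    # engine storage, config and "api_keys" lookups omitted for brevity
    return os.getenv(specific)

print(resolve_api_key("openai", {"using": {"api_key": "sk-..."}}))  # sk-...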
mindsdb/integrations/utilities/sql_utils.py

@@ -46,11 +46,7 @@ class FilterCondition:
  
      def __eq__(self, __value: object) -> bool:
          if isinstance(__value, FilterCondition):
-             return (
-                 self.column == __value.column
-                 and self.op == __value.op
-                 and self.value == __value.value
-             )
+             return self.column == __value.column and self.op == __value.op and self.value == __value.value
          else:
              return False
  
@@ -75,7 +71,7 @@ def make_sql_session():
      from mindsdb.api.executor.controllers.session_controller import SessionController
  
      sql_session = SessionController()
-     sql_session.database = config.get('default_project')
+     sql_session.database = config.get("default_project")
      return sql_session
  
  
@@ -84,44 +80,50 @@ def conditions_to_filter(binary_op: ASTNode):
  
      filters = {}
      for op, arg1, arg2 in conditions:
-         if op != '=':
+         if op != "=":
              raise NotImplementedError
          filters[arg1] = arg2
      return filters
  
  
- def extract_comparison_conditions(binary_op: ASTNode):
-     '''Extracts all simple comparison conditions that must be true from an AST node.
+ def extract_comparison_conditions(binary_op: ASTNode, ignore_functions=False):
+     """Extracts all simple comparison conditions that must be true from an AST node.
      Does NOT support 'or' conditions.
-     '''
+     """
      conditions = []
  
      def _extract_comparison_conditions(node: ASTNode, **kwargs):
          if isinstance(node, ast.BinaryOperation):
              op = node.op.lower()
-             if op == 'and':
+             if op == "and":
                  # Want to separate individual conditions, not include 'and' as its own condition.
                  return
-             elif not isinstance(node.args[0], ast.Identifier):
+ 
+             arg1, arg2 = node.args
+             if ignore_functions and isinstance(arg1, ast.Function):
+                 # handle lower/upper
+                 if arg1.op.lower() in ("lower", "upper"):
+                     if isinstance(arg1.args[0], ast.Identifier):
+                         arg1 = arg1.args[0]
+ 
+             if not isinstance(arg1, ast.Identifier):
                  # Only support [identifier] =/</>/>=/<=/etc [constant] comparisons.
-                 raise NotImplementedError(f'Not implemented arg1: {node.args[0]}')
+                 raise NotImplementedError(f"Not implemented arg1: {arg1}")
  
-             if isinstance(node.args[1], ast.Constant):
-                 value = node.args[1].value
-             elif isinstance(node.args[1], ast.Tuple):
-                 value = [i.value for i in node.args[1].items]
+             if isinstance(arg2, ast.Constant):
+                 value = arg2.value
+             elif isinstance(arg2, ast.Tuple):
+                 value = [i.value for i in arg2.items]
              else:
-                 raise NotImplementedError(f'Not implemented arg2: {node.args[1]}')
+                 raise NotImplementedError(f"Not implemented arg2: {arg2}")
  
-             conditions.append([op, node.args[0].parts[-1], value])
+             conditions.append([op, arg1.parts[-1], value])
          if isinstance(node, ast.BetweenOperation):
              var, up, down = node.args
              if not (
-                 isinstance(var, ast.Identifier)
-                 and isinstance(up, ast.Constant)
-                 and isinstance(down, ast.Constant)
+                 isinstance(var, ast.Identifier) and isinstance(up, ast.Constant) and isinstance(down, ast.Constant)
              ):
-                 raise NotImplementedError(f'Not implemented: {node}')
+                 raise NotImplementedError(f"Not implemented: {node}")
  
              op = node.op.lower()
              conditions.append([op, var.parts[-1], (up.value, down.value)])
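
Note: the new ignore_functions flag lets predicates such as lower(col) = 'x' be treated as plain column comparisons by unwrapping the function call and keeping the inner identifier. Illustrated with minimal stand-in node classes (not the mindsdb_sql AST):

class Identifier:
    def __init__(self, *parts):
        self.parts = list(parts)

class Function:
    def __init__(self, op, arg):
        self.op, self.args = op, [arg]

def unwrap(arg, ignore_functions=True):
    # mirrors the diff: only lower/upper over a bare identifier is unwrapped
    if ignore_functions and isinstance(arg, Function) and arg.op.lower() in ("lower", "upper"):
        if isinstance(arg.args[0], Identifier):
            return arg.args[0]
    return arg

node = Function("LOWER", Identifier("t", "city"))
print(unwrap(node).parts[-1])  # city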
@@ -131,16 +133,13 @@ def extract_comparison_conditions(binary_op: ASTNode):
  
  
  def project_dataframe(df, targets, table_columns):
-     '''
-     case-insensitive projection
-     'select A' and 'select a' return different column case but with the same content
-     '''
+     """
+     case-insensitive projection
+     'select A' and 'select a' return different column case but with the same content
+     """
  
      columns = []
-     df_cols_idx = {
-         col.lower(): col
-         for col in df.columns
-     }
+     df_cols_idx = {col.lower(): col for col in df.columns}
      df_col_rename = {}
  
      for target in targets:
@@ -156,10 +155,7 @@ def project_dataframe(df, targets, table_columns):
              col = target.parts[-1]
              col_df = df_cols_idx.get(col.lower())
              if col_df is not None:
-                 if (
-                     hasattr(target, 'alias')
-                     and isinstance(target.alias, ast.Identifier)
-                 ):
+                 if hasattr(target, "alias") and isinstance(target.alias, ast.Identifier):
                      df_col_rename[col_df] = target.alias.parts[0]
                  else:
                      df_col_rename[col_df] = col
@@ -184,14 +180,13 @@
  
  
  def filter_dataframe(df: pd.DataFrame, conditions: list):
- 
      # convert list of conditions to ast.
      # assumes that list was got from extract_comparison_conditions
      where_query = None
      for op, arg1, arg2 in conditions:
          op = op.lower()
  
-         if op == 'between':
+         if op == "between":
              item = ast.BetweenOperation(args=[ast.Identifier(arg1), ast.Constant(arg2[0]), ast.Constant(arg2[1])])
          else:
              if isinstance(arg2, (tuple, list)):
@@ -201,9 +196,9 @@ def filter_dataframe(df: pd.DataFrame, conditions: list):
          if where_query is None:
              where_query = item
          else:
-             where_query = ast.BinaryOperation(op='and', args=[where_query, item])
+             where_query = ast.BinaryOperation(op="and", args=[where_query, item])
  
-     query = ast.Select(targets=[ast.Star()], from_table=ast.Identifier('df'), where=where_query)
+     query = ast.Select(targets=[ast.Star()], from_table=ast.Identifier("df"), where=where_query)
  
      return query_df(df, query)
  
@@ -220,7 +215,7 @@ def sort_dataframe(df, order_by: list):
              continue
  
          cols.append(col)
-         ascending.append(False if order.direction.lower() == 'desc' else True)
+         ascending.append(False if order.direction.lower() == "desc" else True)
      if len(cols) > 0:
          df = df.sort_values(by=cols, ascending=ascending)
      return df
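
Note: sort_dataframe maps each ORDER BY direction onto pandas' per-column ascending list; only the quoting changed in this hunk. For reference, the equivalent pandas call (hypothetical frame):

import pandas as pd

df = pd.DataFrame({"a": [2, 1, 1], "b": [3, 2, 1]})
# ORDER BY a ASC, b DESC -> one ascending flag per sort column
print(df.sort_values(by=["a", "b"], ascending=[True, False]))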