MindsDB 25.5.4.2__py3-none-any.whl → 25.6.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of MindsDB might be problematic.
- mindsdb/__about__.py +1 -1
- mindsdb/api/a2a/agent.py +50 -26
- mindsdb/api/a2a/common/server/server.py +32 -26
- mindsdb/api/a2a/task_manager.py +68 -6
- mindsdb/api/executor/command_executor.py +69 -14
- mindsdb/api/executor/datahub/datanodes/integration_datanode.py +49 -65
- mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +91 -84
- mindsdb/api/executor/datahub/datanodes/project_datanode.py +29 -48
- mindsdb/api/executor/datahub/datanodes/system_tables.py +35 -61
- mindsdb/api/executor/planner/plan_join.py +67 -77
- mindsdb/api/executor/planner/query_planner.py +176 -155
- mindsdb/api/executor/planner/steps.py +37 -12
- mindsdb/api/executor/sql_query/result_set.py +45 -64
- mindsdb/api/executor/sql_query/steps/fetch_dataframe.py +14 -18
- mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +17 -18
- mindsdb/api/executor/sql_query/steps/insert_step.py +13 -33
- mindsdb/api/executor/sql_query/steps/subselect_step.py +43 -35
- mindsdb/api/executor/utilities/sql.py +42 -48
- mindsdb/api/http/namespaces/config.py +1 -1
- mindsdb/api/http/namespaces/file.py +14 -23
- mindsdb/api/http/namespaces/knowledge_bases.py +132 -154
- mindsdb/api/mysql/mysql_proxy/data_types/mysql_datum.py +12 -28
- mindsdb/api/mysql/mysql_proxy/data_types/mysql_packets/binary_resultset_row_package.py +59 -50
- mindsdb/api/mysql/mysql_proxy/data_types/mysql_packets/resultset_row_package.py +9 -8
- mindsdb/api/mysql/mysql_proxy/libs/constants/mysql.py +449 -461
- mindsdb/api/mysql/mysql_proxy/utilities/dump.py +87 -36
- mindsdb/integrations/handlers/bigquery_handler/bigquery_handler.py +219 -28
- mindsdb/integrations/handlers/file_handler/file_handler.py +15 -9
- mindsdb/integrations/handlers/file_handler/tests/test_file_handler.py +43 -24
- mindsdb/integrations/handlers/litellm_handler/litellm_handler.py +10 -3
- mindsdb/integrations/handlers/llama_index_handler/requirements.txt +1 -1
- mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +29 -33
- mindsdb/integrations/handlers/openai_handler/openai_handler.py +277 -356
- mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +74 -51
- mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +305 -98
- mindsdb/integrations/handlers/salesforce_handler/salesforce_handler.py +145 -40
- mindsdb/integrations/handlers/salesforce_handler/salesforce_tables.py +136 -6
- mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +352 -83
- mindsdb/integrations/libs/api_handler.py +279 -57
- mindsdb/integrations/libs/base.py +185 -30
- mindsdb/integrations/utilities/files/file_reader.py +99 -73
- mindsdb/integrations/utilities/handler_utils.py +23 -8
- mindsdb/integrations/utilities/sql_utils.py +35 -40
- mindsdb/interfaces/agents/agents_controller.py +226 -196
- mindsdb/interfaces/agents/constants.py +8 -1
- mindsdb/interfaces/agents/langchain_agent.py +42 -11
- mindsdb/interfaces/agents/mcp_client_agent.py +29 -21
- mindsdb/interfaces/agents/mindsdb_database_agent.py +23 -18
- mindsdb/interfaces/data_catalog/__init__.py +0 -0
- mindsdb/interfaces/data_catalog/base_data_catalog.py +54 -0
- mindsdb/interfaces/data_catalog/data_catalog_loader.py +375 -0
- mindsdb/interfaces/data_catalog/data_catalog_reader.py +38 -0
- mindsdb/interfaces/database/database.py +81 -57
- mindsdb/interfaces/database/integrations.py +222 -234
- mindsdb/interfaces/database/log.py +72 -104
- mindsdb/interfaces/database/projects.py +156 -193
- mindsdb/interfaces/file/file_controller.py +21 -65
- mindsdb/interfaces/knowledge_base/controller.py +66 -25
- mindsdb/interfaces/knowledge_base/evaluate.py +516 -0
- mindsdb/interfaces/knowledge_base/llm_client.py +75 -0
- mindsdb/interfaces/skills/custom/text2sql/mindsdb_kb_tools.py +83 -43
- mindsdb/interfaces/skills/skills_controller.py +31 -36
- mindsdb/interfaces/skills/sql_agent.py +113 -86
- mindsdb/interfaces/storage/db.py +242 -82
- mindsdb/migrations/versions/2025-05-28_a44643042fe8_added_data_catalog_tables.py +118 -0
- mindsdb/migrations/versions/2025-06-09_608e376c19a7_updated_data_catalog_data_types.py +58 -0
- mindsdb/utilities/config.py +13 -2
- mindsdb/utilities/log.py +35 -26
- mindsdb/utilities/ml_task_queue/task.py +19 -22
- mindsdb/utilities/render/sqlalchemy_render.py +129 -181
- mindsdb/utilities/starters.py +40 -0
- {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.3.0.dist-info}/METADATA +257 -257
- {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.3.0.dist-info}/RECORD +76 -68
- {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.3.0.dist-info}/WHEEL +0 -0
- {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.3.0.dist-info}/licenses/LICENSE +0 -0
- {mindsdb-25.5.4.2.dist-info → mindsdb-25.6.3.0.dist-info}/top_level.txt +0 -0
mindsdb/integrations/libs/base.py

@@ -1,4 +1,5 @@
 import ast
+import concurrent.futures
 import inspect
 import textwrap
 from _ast import AnnAssign, AugAssign
@@ -8,20 +9,20 @@ import pandas as pd
 from mindsdb_sql_parser.ast.base import ASTNode
 from mindsdb.utilities import log
 
-from mindsdb.integrations.libs.response import HandlerResponse, HandlerStatusResponse
+from mindsdb.integrations.libs.response import HandlerResponse, HandlerStatusResponse, RESPONSE_TYPE
 
 logger = log.getLogger(__name__)
 
 
 class BaseHandler:
-    """
+    """Base class for database handlers
 
     Base class for handlers that associate a source of information with the
     broader MindsDB ecosystem via SQL commands.
     """
 
     def __init__(self, name: str):
-        """
+        """constructor
         Args:
             name (str): the handler name
         """
@@ -29,7 +30,7 @@ class BaseHandler:
         self.name = name
 
     def connect(self):
-        """
+        """Set up any connections required by the handler
 
         Should return connection
 
@@ -37,7 +38,7 @@
         raise NotImplementedError()
 
     def disconnect(self):
-        """
+        """Close any existing connections
 
         Should switch self.is_connected.
         """
@@ -45,7 +46,7 @@
         return
 
     def check_connection(self) -> HandlerStatusResponse:
-        """
+        """Check connection to the handler
 
         Returns:
             HandlerStatusResponse
@@ -77,7 +78,7 @@
         raise NotImplementedError()
 
     def get_tables(self) -> HandlerResponse:
-        """
+        """Return list of entities
 
         Return list of entities that will be accesible as tables.
 
@@ -89,7 +90,7 @@
         raise NotImplementedError()
 
     def get_columns(self, table_name: str) -> HandlerResponse:
-        """
+        """Returns a list of entity columns
 
         Args:
             table_name (str): name of one of tables returned by self.get_tables()
@@ -113,6 +114,174 @@ class DatabaseHandler(BaseHandler):
         super().__init__(name)
 
 
+class MetaDatabaseHandler(DatabaseHandler):
+    """
+    Base class for handlers associated to data storage systems (e.g. databases, data warehouses, streaming services, etc.)
+
+    This class is used when the handler is also needed to store information in the data catalog.
+    This information is typically avaiable in the information schema or system tables of the database.
+    """
+
+    def __init__(self, name: str):
+        super().__init__(name)
+
+    def meta_get_tables(self, table_names: Optional[List[str]]) -> HandlerResponse:
+        """
+        Returns metadata information about the tables to be stored in the data catalog.
+
+        Returns:
+            HandlerResponse: The response should consist of the following columns:
+                - TABLE_NAME (str): Name of the table.
+                - TABLE_TYPE (str): Type of the table, e.g. 'BASE TABLE', 'VIEW', etc. (optional).
+                - TABLE_SCHEMA (str): Schema of the table (optional).
+                - TABLE_DESCRIPTION (str): Description of the table (optional).
+                - ROW_COUNT (int): Estimated number of rows in the table (optional).
+        """
+        raise NotImplementedError()
+
+    def meta_get_columns(self, table_names: Optional[List[str]]) -> HandlerResponse:
+        """
+        Returns metadata information about the columns in the tables to be stored in the data catalog.
+
+        Returns:
+            HandlerResponse: The response should consist of the following columns:
+                - TABLE_NAME (str): Name of the table.
+                - COLUMN_NAME (str): Name of the column.
+                - DATA_TYPE (str): Data type of the column, e.g. 'VARCHAR', 'INT', etc.
+                - COLUMN_DESCRIPTION (str): Description of the column (optional).
+                - IS_NULLABLE (bool): Whether the column can contain NULL values (optional).
+                - COLUMN_DEFAULT (str): Default value of the column (optional).
+        """
+        raise NotImplementedError()
+
+    def meta_get_column_statistics(self, table_names: Optional[List[str]]) -> HandlerResponse:
+        """
+        Returns metadata statisical information about the columns in the tables to be stored in the data catalog.
+        Either this method should be overridden in the handler or `meta_get_column_statistics_for_table` should be implemented.
+
+        Returns:
+            HandlerResponse: The response should consist of the following columns:
+                - TABLE_NAME (str): Name of the table.
+                - COLUMN_NAME (str): Name of the column.
+                - MOST_COMMON_VALUES (List[str]): Most common values in the column (optional).
+                - MOST_COMMON_FREQUENCIES (List[str]): Frequencies of the most common values in the column (optional).
+                - NULL_PERCENTAGE: Percentage of NULL values in the column (optional).
+                - MINIMUM_VALUE (str): Minimum value in the column (optional).
+                - MAXIMUM_VALUE (str): Maximum value in the column (optional).
+                - DISTINCT_VALUES_COUNT (int): Count of distinct values in the column (optional).
+        """
+        method = getattr(self, "meta_get_column_statistics_for_table")
+        if method.__func__ is not MetaDatabaseHandler.meta_get_column_statistics_for_table:
+            meta_columns = self.meta_get_columns(table_names)
+            grouped_columns = (
+                meta_columns.data_frame.groupby("table_name")
+                .agg(
+                    {
+                        "column_name": list,
+                    }
+                )
+                .reset_index()
+            )
+
+            executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
+            futures = []
+
+            results = []
+            with executor:
+                for _, row in grouped_columns.iterrows():
+                    table_name = row["table_name"]
+                    columns = row["column_name"]
+                    futures.append(executor.submit(self.meta_get_column_statistics_for_table, table_name, columns))
+
+                for future in concurrent.futures.as_completed(futures):
+                    try:
+                        result = future.result(timeout=120)
+                        if result.resp_type == RESPONSE_TYPE.TABLE:
+                            results.append(result.data_frame)
+                        else:
+                            logger.error(
+                                f"Error retrieving column statistics for table {table_name}: {result.error_message}"
+                            )
+                    except Exception as e:
+                        logger.error(f"Exception occurred while retrieving column statistics for table {table_name}: {e}")
+
+            if not results:
+                logger.warning("No column statistics could be retrieved for the specified tables.")
+                return HandlerResponse(RESPONSE_TYPE.ERROR, error_message="No column statistics could be retrieved.")
+            return HandlerResponse(
+                RESPONSE_TYPE.TABLE, pd.concat(results, ignore_index=True) if results else pd.DataFrame()
+            )
+
+        else:
+            raise NotImplementedError()
+
+    def meta_get_column_statistics_for_table(
+        self, table_name: str, column_names: Optional[List[str]] = None
+    ) -> HandlerResponse:
+        """
+        Returns metadata statistical information about the columns in a specific table to be stored in the data catalog.
+        Either this method should be implemented in the handler or `meta_get_column_statistics` should be overridden.
+
+        Args:
+            table_name (str): Name of the table.
+            column_names (Optional[List[str]]): List of column names to retrieve statistics for. If None, statistics for all columns will be returned.
+
+        Returns:
+            HandlerResponse: The response should consist of the following columns:
+                - TABLE_NAME (str): Name of the table.
+                - COLUMN_NAME (str): Name of the column.
+                - MOST_COMMON_VALUES (List[str]): Most common values in the column (optional).
+                - MOST_COMMON_FREQUENCIES (List[str]): Frequencies of the most common values in the column (optional).
+                - NULL_PERCENTAGE: Percentage of NULL values in the column (optional).
+                - MINIMUM_VALUE (str): Minimum value in the column (optional).
+                - MAXIMUM_VALUE (str): Maximum value in the column (optional).
+                - DISTINCT_VALUES_COUNT (int): Count of distinct values in the column (optional).
+        """
+        pass
+
+    def meta_get_primary_keys(self, table_names: Optional[List[str]]) -> HandlerResponse:
+        """
+        Returns metadata information about the primary keys in the tables to be stored in the data catalog.
+
+        Returns:
+            HandlerResponse: The response should consist of the following columns:
+                - TABLE_NAME (str): Name of the table.
+                - COLUMN_NAME (str): Name of the column that is part of the primary key.
+                - ORDINAL_POSITION (int): Position of the column in the primary key (optional).
+                - CONSTRAINT_NAME (str): Name of the primary key constraint (optional).
+        """
+        raise NotImplementedError()
+
+    def meta_get_foreign_keys(self, table_names: Optional[List[str]]) -> HandlerResponse:
+        """
+        Returns metadata information about the foreign keys in the tables to be stored in the data catalog.
+
+        Returns:
+            HandlerResponse: The response should consist of the following columns:
+                - PARENT_TABLE_NAME (str): Name of the parent table.
+                - PARENT_COLUMN_NAME (str): Name of the parent column that is part of the foreign key.
+                - CHILD_TABLE_NAME (str): Name of the child table.
+                - CHILD_COLUMN_NAME (str): Name of the child column that is part of the foreign key.
+                - CONSTRAINT_NAME (str): Name of the foreign key constraint (optional).
+        """
+        raise NotImplementedError()
+
+    def meta_get_handler_info(self, **kwargs) -> str:
+        """
+        Retrieves information about the design and implementation of the database handler.
+        This should include, but not be limited to, the following:
+            - The type of SQL queries and operations that the handler supports.
+            - etc.
+
+        Args:
+            kwargs: Additional keyword arguments that may be used in generating the handler information.
+
+        Returns:
+            str: A string containing information about the database handler's design and implementation.
+        """
+        pass
+
+
 class ArgProbeMixin:
     """
     A mixin class that provides probing of arguments that
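The new MetaDatabaseHandler contract is easiest to see from a concrete subclass. Below is a minimal sketch, assuming only the column contract documented in the docstrings above; the handler name and its data are hypothetical, not part of this release. Note that the inherited meta_get_column_statistics() groups the frame returned by meta_get_columns() on lowercase table_name/column_name labels before fanning out to meta_get_column_statistics_for_table() on a five-worker thread pool.

    import pandas as pd

    from mindsdb.integrations.libs.base import MetaDatabaseHandler
    from mindsdb.integrations.libs.response import HandlerResponse, RESPONSE_TYPE


    class ExampleMetaHandler(MetaDatabaseHandler):
        """Hypothetical handler publishing one static table to the data catalog."""

        def meta_get_tables(self, table_names=None) -> HandlerResponse:
            df = pd.DataFrame([{"TABLE_NAME": "orders", "TABLE_TYPE": "BASE TABLE", "ROW_COUNT": 1000}])
            return HandlerResponse(RESPONSE_TYPE.TABLE, df)

        def meta_get_columns(self, table_names=None) -> HandlerResponse:
            # Lowercase table_name/column_name here: the inherited
            # meta_get_column_statistics() groups on these exact labels.
            df = pd.DataFrame([
                {"table_name": "orders", "column_name": "id", "data_type": "INT"},
                {"table_name": "orders", "column_name": "status", "data_type": "VARCHAR"},
            ])
            return HandlerResponse(RESPONSE_TYPE.TABLE, df)

        def meta_get_column_statistics_for_table(self, table_name, column_names=None) -> HandlerResponse:
            # One call per table, submitted to the thread pool by the base class.
            rows = [{"TABLE_NAME": table_name, "COLUMN_NAME": col, "NULL_PERCENTAGE": 0.0}
                    for col in (column_names or [])]
            return HandlerResponse(RESPONSE_TYPE.TABLE, pd.DataFrame(rows))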
@@ -154,26 +323,16 @@ class ArgProbeMixin:
         self.visit(node.value)
 
     def visit_Subscript(self, node):
-        if (
-            isinstance(node.value, ast.Name)
-            and node.value.id in self.var_names_to_track
-        ):
-            if isinstance(node.slice, ast.Index) and isinstance(
-                node.slice.value, ast.Str
-            ):
+        if isinstance(node.value, ast.Name) and node.value.id in self.var_names_to_track:
+            if isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Str):
                 self.arg_keys.append({"name": node.slice.value.s, "required": True})
         self.generic_visit(node)
 
     def visit_Call(self, node):
         if isinstance(node.func, ast.Attribute) and node.func.attr == "get":
-            if (
-                isinstance(node.func.value, ast.Name)
-                and node.func.value.id in self.var_names_to_track
-            ):
+            if isinstance(node.func.value, ast.Name) and node.func.value.id in self.var_names_to_track:
                 if isinstance(node.args[0], ast.Str):
-                    self.arg_keys.append(
-                        {"name": node.args[0].s, "required": False}
-                    )
+                    self.arg_keys.append({"name": node.args[0].s, "required": False})
         self.generic_visit(node)
 
     @classmethod
@@ -197,9 +356,7 @@ class ArgProbeMixin:
         try:
             source_code = self.get_source_code(method_name)
         except Exception as e:
-            logger.error(
-                f"Failed to get source code of method {method_name} in {self.__class__.__name__}. Reason: {e}"
-            )
+            logger.error(f"Failed to get source code of method {method_name} in {self.__class__.__name__}. Reason: {e}")
             return []
 
         # parse the source code
@@ -238,9 +395,7 @@ class ArgProbeMixin:
         """
         method = getattr(self, method_name)
         if method is None:
-            raise Exception(
-                f"Method {method_name} does not exist in {self.__class__.__name__}"
-            )
+            raise Exception(f"Method {method_name} does not exist in {self.__class__.__name__}")
         source_code = inspect.getsource(method)
         return source_code
 
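The ArgProbeMixin hunks above are formatting-only, but the probing they touch is worth one illustration: the mixin walks a method's AST and records args["key"] subscripts as required parameters and args.get("key") calls as optional ones, for variables named in var_names_to_track. A hedged sketch of what would be recorded for a hypothetical engine method, assuming args is among the tracked names:

    from mindsdb.integrations.libs.base import BaseMLEngine


    class MyEngine(BaseMLEngine):
        def create(self, target, args=None, **kwargs):
            model = args["model_name"]      # probed as {"name": "model_name", "required": True}
            timeout = args.get("timeout")   # probed as {"name": "timeout", "required": False}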
@@ -288,8 +443,8 @@ class BaseMLEngine(ArgProbeMixin):
         self.engine_storage = engine_storage
         self.generative = False  # if True, the target column name does not have to be specified at creation time
 
-        if kwargs.get('base_model_storage'):
-            self.base_model_storage = kwargs['base_model_storage']  # available when updating a model
+        if kwargs.get("base_model_storage"):
+            self.base_model_storage = kwargs["base_model_storage"]  # available when updating a model
         else:
             self.base_model_storage = None
 
mindsdb/integrations/utilities/files/file_reader.py

@@ -1,10 +1,11 @@
+from dataclasses import dataclass, astuple
 import traceback
 import json
 import csv
 from io import BytesIO, StringIO, IOBase
 from pathlib import Path
 import codecs
-from typing import List
+from typing import List, Generator
 
 import filetype
 import pandas as pd
@@ -18,8 +19,27 @@ DEFAULT_CHUNK_SIZE = 500
 DEFAULT_CHUNK_OVERLAP = 250
 
 
-class FileDetectError(Exception):
-    ...
+class FileProcessingError(Exception): ...
+
+
+@dataclass(frozen=True, slots=True)
+class _SINGLE_PAGE_FORMAT:
+    CSV: str = "csv"
+    JSON: str = "json"
+    TXT: str = "txt"
+    PDF: str = "pdf"
+    PARQUET: str = "parquet"
+
+
+SINGLE_PAGE_FORMAT = _SINGLE_PAGE_FORMAT()
+
+
+@dataclass(frozen=True, slots=True)
+class _MULTI_PAGE_FORMAT:
+    XLSX: str = "xlsx"
+
+
+MULTI_PAGE_FORMAT = _MULTI_PAGE_FORMAT()
 
 
 def decode(file_obj: IOBase) -> StringIO:
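SINGLE_PAGE_FORMAT and MULTI_PAGE_FORMAT above are frozen, slotted dataclass instances acting as typo-safe enums; astuple() flattens them into plain tuples of format strings, which is how FormatDetector assembles supported_formats in the next hunk. A stdlib-only illustration with an abbreviated field set (the _FORMATS name is made up for the sketch):

    from dataclasses import astuple, dataclass


    @dataclass(frozen=True, slots=True)  # slots=True requires Python 3.10+
    class _FORMATS:
        CSV: str = "csv"
        PDF: str = "pdf"


    FORMATS = _FORMATS()
    assert astuple(FORMATS) == ("csv", "pdf")  # tuple in field-declaration order
    assert FORMATS.PDF in astuple(FORMATS)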
@@ -56,21 +76,20 @@ def decode(file_obj: IOBase) -> StringIO:
         data_str = StringIO(byte_str.decode(encoding, errors))
     except Exception as e:
         logger.error(traceback.format_exc())
-        raise
+        raise FileProcessingError("Could not load into string") from e
 
     return data_str
 
 
 class FormatDetector:
-
-
-    multipage_formats = ['xlsx']
+    supported_formats = astuple(SINGLE_PAGE_FORMAT) + astuple(MULTI_PAGE_FORMAT)
+    multipage_formats = astuple(MULTI_PAGE_FORMAT)
 
     def __init__(
         self,
-        path: str = None,
-        name: str = None,
-        file: IOBase = None
+        path: str | None = None,
+        name: str | None = None,
+        file: IOBase | None = None,
     ):
         """
         File format detector
@@ -81,16 +100,16 @@ class FormatDetector:
         :param file: file descriptor (via open(...), of BytesIO(...))
         """
         if path is not None:
-            file = open(path, 'rb')
+            file = open(path, "rb")
 
         elif file is not None:
             if name is None:
-                if hasattr(file, 'name'):
+                if hasattr(file, "name"):
                     path = file.name
                 else:
-                    path = 'file'
+                    path = "file"
         else:
-            raise
+            raise FileProcessingError("Wrong arguments: path or file is required")
 
         if name is None:
             name = Path(path).name
@@ -108,14 +127,14 @@ class FormatDetector:
         format = self.get_format_by_name()
         if format is not None:
             if format not in self.supported_formats:
-                raise
+                raise FileProcessingError(f"Not supported format: {format}")
 
         if format is None and self.file_obj is not None:
             format = self.get_format_by_content()
             self.file_obj.seek(0)
 
         if format is None:
-            raise
+            raise FileProcessingError(f"Unable to detect format: {self.name}")
 
         self.format = format
         return format
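With these hunks, detection fails loudly with FileProcessingError instead of bare raise statements: the file extension is consulted first, content sniffing is the fallback. A usage sketch, assuming the module path shown in the file list above:

    from io import BytesIO

    from mindsdb.integrations.utilities.files.file_reader import FileProcessingError, FormatDetector

    # Extension is checked first ...
    detector = FormatDetector(file=BytesIO(b"a,b\n1,2\n"), name="data.csv")
    assert detector.get_format() == "csv"

    # ... and an unsupported extension now raises instead of failing silently.
    try:
        FormatDetector(file=BytesIO(b"\x00\x01"), name="blob.bin").get_format()
    except FileProcessingError as exc:
        print(exc)  # Not supported format: bin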
@@ -124,33 +143,32 @@ class FormatDetector:
         extension = Path(self.name).suffix.strip(".").lower()
         if extension == "tsv":
             extension = "csv"
-            self.parameters['delimiter'] = '\t'
+            self.parameters["delimiter"] = "\t"
 
         return extension or None
 
     def get_format_by_content(self):
         if self.is_parquet(self.file_obj):
-            return 'parquet'
+            return SINGLE_PAGE_FORMAT.PARQUET
 
         file_type = filetype.guess(self.file_obj)
         if file_type is not None:
-
             if file_type.mime in {
                 "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
                 "application/vnd.ms-excel",
             }:
-                return 'xlsx'
+                return MULTI_PAGE_FORMAT.XLSX
 
-            if file_type.mime == 'application/pdf':
-                return 'pdf'
+            if file_type.mime == "application/pdf":
+                return SINGLE_PAGE_FORMAT.PDF
 
         file_obj = decode(self.file_obj)
 
         if self.is_json(file_obj):
-            return 'json'
+            return SINGLE_PAGE_FORMAT.JSON
 
         if self.is_csv(file_obj):
-            return 'csv'
+            return SINGLE_PAGE_FORMAT.CSV
 
     @staticmethod
     def is_json(data_obj: StringIO) -> bool:
@@ -198,35 +216,53 @@ class FormatDetector:
         return False
 
 
-class FileReader(FormatDetector):
+def format_column_names(df: pd.DataFrame):
+    df.columns = [column.strip(" \t") for column in df.columns]
+    if len(df.columns) != len(set(df.columns)) or any(len(column_name) == 0 for column_name in df.columns):
+        raise FileProcessingError("Each column should have a unique and non-empty name.")
+
 
+class FileReader(FormatDetector):
     def _get_fnc(self):
         format = self.get_format()
-        func = getattr(self, f'read_{format}', None)
+        func = getattr(self, f"read_{format}", None)
         if func is None:
-            raise
-
+            raise FileProcessingError(f"Unsupported format: {format}")
+
+        if format in astuple(MULTI_PAGE_FORMAT):
+
+            def format_multipage(*args, **kwargs):
+                for page_number, df in func(*args, **kwargs):
+                    format_column_names(df)
+                    yield page_number, df
+
+            return format_multipage
+
+        def format_singlepage(*args, **kwargs) -> pd.DataFrame:
+            """Check that the columns have unique not-empty names"""
+            df = func(*args, **kwargs)
+            format_column_names(df)
+            return df
+
+        return format_singlepage
 
     def get_pages(self, **kwargs) -> List[str]:
         """
-
+        Get list of tables in file
         """
         format = self.get_format()
         if format not in self.multipage_formats:
             # only one table
-            return ['main']
+            return ["main"]
 
         func = self._get_fnc()
         self.file_obj.seek(0)
 
-        return [
-            name for name, _ in
-            func(self.file_obj, only_names=True, **kwargs)
-        ]
+        return [name for name, _ in func(self.file_obj, only_names=True, **kwargs)]
 
-    def get_contents(self, **kwargs):
+    def get_contents(self, **kwargs) -> dict[str, pd.DataFrame]:
         """
-
+        Get all info(pages with content) from file as dict: {tablename, content}
         """
         func = self._get_fnc()
         self.file_obj.seek(0)
@@ -234,17 +270,13 @@ class FileReader(FormatDetector):
         format = self.get_format()
         if format not in self.multipage_formats:
             # only one table
-            return {'main': func(self.file_obj, name=self.name, **kwargs)}
+            return {"main": func(self.file_obj, name=self.name, **kwargs)}
 
-        return {
-            name: df
-            for name, df in
-            func(self.file_obj, **kwargs)
-        }
+        return {name: df for name, df in func(self.file_obj, **kwargs)}
 
-    def get_page_content(self, page_name: str = None, **kwargs) -> pd.DataFrame:
+    def get_page_content(self, page_name: str | None = None, **kwargs) -> pd.DataFrame:
         """
-
+        Get content of a single table
         """
         func = self._get_fnc()
         self.file_obj.seek(0)
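FileReader's page API follows directly from the single-page/multi-page split above: single-page formats expose one pseudo-table named "main", while an xlsx file exposes one page per sheet. A usage sketch under the same module-path assumption:

    from io import BytesIO

    from mindsdb.integrations.utilities.files.file_reader import FileReader

    reader = FileReader(file=BytesIO(b"a,b\n1,2\n"), name="data.csv")
    assert reader.get_pages() == ["main"]
    tables = reader.get_contents()      # {"main": <DataFrame with columns a, b>}
    df = reader.get_page_content()      # the same single table as a DataFrame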
@@ -258,7 +290,7 @@ class FileReader(FormatDetector):
         return df
 
     @staticmethod
-    def _get_csv_dialect(buffer, delimiter=None) -> csv.Dialect:
+    def _get_csv_dialect(buffer, delimiter: str | None = None) -> csv.Dialect | None:
         sample = buffer.readline()  # trying to get dialect from header
         buffer.seek(0)
         try:
@@ -270,42 +302,35 @@ class FileReader(FormatDetector):
             else:
                 accepted_csv_delimiters = [",", "\t", ";"]
                 try:
-                    dialect = csv.Sniffer().sniff(
-                        sample, delimiters=accepted_csv_delimiters
-                    )
-                    dialect.doublequote = (
-                        True  # assume that all csvs have " as string escape
-                    )
+                    dialect = csv.Sniffer().sniff(sample, delimiters=accepted_csv_delimiters)
+                    dialect.doublequote = True  # assume that all csvs have " as string escape
                 except Exception:
                     dialect = csv.reader(sample).dialect
                     if dialect.delimiter not in accepted_csv_delimiters:
-                        raise FileDetectError(
-                            f"CSV delimeter '{dialect.delimiter}' is not supported"
-                        )
+                        raise FileProcessingError(f"CSV delimeter '{dialect.delimiter}' is not supported")
 
         except csv.Error:
             dialect = None
         return dialect
 
     @classmethod
-    def read_csv(cls, file_obj: BytesIO, delimiter=None, **kwargs):
+    def read_csv(cls, file_obj: BytesIO, delimiter: str | None = None, **kwargs) -> pd.DataFrame:
         file_obj = decode(file_obj)
         dialect = cls._get_csv_dialect(file_obj, delimiter=delimiter)
-
         return pd.read_csv(file_obj, sep=dialect.delimiter, index_col=False)
 
     @staticmethod
-    def read_txt(file_obj: BytesIO, name=None, **kwargs):
+    def read_txt(file_obj: BytesIO, name: str | None = None, **kwargs) -> pd.DataFrame:
         # the lib is heavy, so import it only when needed
         from langchain_text_splitters import RecursiveCharacterTextSplitter
+
         file_obj = decode(file_obj)
 
         try:
             from langchain_core.documents import Document
         except ImportError:
-            raise ImportError(
-                "To import TXT document please install 'langchain-community':\n"
-                " pip install langchain-community"
+            raise FileProcessingError(
+                "To import TXT document please install 'langchain-community':\n pip install langchain-community"
             )
         text = file_obj.read()
 
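The dialect logic above boils down to: let csv.Sniffer choose among the accepted delimiters, force doublequote handling, and fall back to csv.reader's inferred dialect on failure. A stdlib-only illustration of the happy path:

    import csv

    sample = "id;name;amount\n1;widget;9.99\n"
    dialect = csv.Sniffer().sniff(sample, delimiters=[",", "\t", ";"])
    dialect.doublequote = True  # treat "" as an escaped quote, as read_csv expects
    assert dialect.delimiter == ";"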
@@ -317,15 +342,10 @@ class FileReader(FormatDetector):
         )
 
         docs = text_splitter.split_documents(documents)
-        return pd.DataFrame(
-            [
-                {"content": doc.page_content, "metadata": doc.metadata}
-                for doc in docs
-            ]
-        )
+        return pd.DataFrame([{"content": doc.page_content, "metadata": doc.metadata} for doc in docs])
 
     @staticmethod
-    def read_pdf(file_obj: BytesIO, name=None, **kwargs):
+    def read_pdf(file_obj: BytesIO, name: str | None = None, **kwargs) -> pd.DataFrame:
         # the libs are heavy, so import it only when needed
         import fitz  # pymupdf
         from langchain_text_splitters import RecursiveCharacterTextSplitter
@@ -340,30 +360,36 @@ class FileReader(FormatDetector):
         split_text = text_splitter.split_text(text)
 
         return pd.DataFrame(
-            {"content": split_text, "metadata": [{"file_format": "pdf", "source_file": name}] * len(split_text)}
+            {
+                "content": split_text,
+                "metadata": [{"file_format": "pdf", "source_file": name}] * len(split_text),
+            }
         )
 
     @staticmethod
-    def read_json(file_obj: BytesIO, **kwargs):
+    def read_json(file_obj: BytesIO, **kwargs) -> pd.DataFrame:
         file_obj = decode(file_obj)
         file_obj.seek(0)
         json_doc = json.loads(file_obj.read())
         return pd.json_normalize(json_doc, max_level=0)
 
     @staticmethod
-    def read_parquet(file_obj: BytesIO, **kwargs):
+    def read_parquet(file_obj: BytesIO, **kwargs) -> pd.DataFrame:
         return pd.read_parquet(file_obj)
 
     @staticmethod
-    def read_xlsx(file_obj: BytesIO, page_name=None, only_names=False, **kwargs):
+    def read_xlsx(
+        file_obj: BytesIO,
+        page_name: str | None = None,
+        only_names: bool = False,
+        **kwargs,
+    ) -> Generator[tuple[str, pd.DataFrame | None], None, None]:
         with pd.ExcelFile(file_obj) as xls:
-
             if page_name is not None:
                 # return specific page
                 yield page_name, pd.read_excel(xls, sheet_name=page_name)
 
             for page_name in xls.sheet_names:
-
                 if only_names:
                     # extract only pages names
                     df = None