unstructured-ingest 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- test/integration/connectors/utils/validation/equality.py +2 -1
- test/unit/v2/connectors/sql/test_sql.py +4 -2
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/utils/data_prep.py +11 -3
- unstructured_ingest/utils/html.py +109 -0
- unstructured_ingest/utils/ndjson.py +52 -0
- unstructured_ingest/v2/interfaces/upload_stager.py +3 -13
- unstructured_ingest/v2/pipeline/steps/chunk.py +3 -4
- unstructured_ingest/v2/pipeline/steps/embed.py +3 -4
- unstructured_ingest/v2/pipeline/steps/partition.py +3 -4
- unstructured_ingest/v2/processes/connectors/confluence.py +95 -25
- unstructured_ingest/v2/processes/connectors/duckdb/base.py +2 -2
- unstructured_ingest/v2/processes/connectors/fsspec/azure.py +8 -8
- unstructured_ingest/v2/processes/connectors/fsspec/box.py +7 -7
- unstructured_ingest/v2/processes/connectors/fsspec/dropbox.py +9 -9
- unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +41 -9
- unstructured_ingest/v2/processes/connectors/fsspec/gcs.py +7 -7
- unstructured_ingest/v2/processes/connectors/fsspec/s3.py +8 -8
- unstructured_ingest/v2/processes/connectors/fsspec/sftp.py +5 -5
- unstructured_ingest/v2/processes/connectors/sql/__init__.py +4 -0
- unstructured_ingest/v2/processes/connectors/sql/singlestore.py +2 -1
- unstructured_ingest/v2/processes/connectors/sql/sql.py +12 -8
- unstructured_ingest/v2/processes/connectors/sql/sqlite.py +2 -1
- unstructured_ingest/v2/processes/connectors/sql/vastdb.py +270 -0
- {unstructured_ingest-0.4.0.dist-info → unstructured_ingest-0.4.1.dist-info}/METADATA +25 -22
- {unstructured_ingest-0.4.0.dist-info → unstructured_ingest-0.4.1.dist-info}/RECORD +30 -27
- {unstructured_ingest-0.4.0.dist-info → unstructured_ingest-0.4.1.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.4.0.dist-info → unstructured_ingest-0.4.1.dist-info}/WHEEL +0 -0
- {unstructured_ingest-0.4.0.dist-info → unstructured_ingest-0.4.1.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.4.0.dist-info → unstructured_ingest-0.4.1.dist-info}/top_level.txt +0 -0
unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py

@@ -119,7 +119,7 @@ class FsspecIndexer(Indexer):
             logger.error(f"failed to validate connection: {e}", exc_info=True)
             raise self.wrap_error(e=e)
 
-    def
+    def get_file_info(self) -> list[dict[str, Any]]:
         if not self.index_config.recursive:
             # fs.ls does not walk directories
             # directories that are listed in cloud storage can cause problems
@@ -156,24 +156,56 @@ class FsspecIndexer(Indexer):
 
         return random.sample(files, n)
 
-    def get_metadata(self,
+    def get_metadata(self, file_info: dict) -> FileDataSourceMetadata:
         raise NotImplementedError()
 
-    def get_path(self,
-        return
+    def get_path(self, file_info: dict) -> str:
+        return file_info["name"]
 
     def sterilize_info(self, file_data: dict) -> dict:
         return sterilize_dict(data=file_data)
 
+    def create_init_file_data(self, remote_filepath: Optional[str] = None) -> FileData:
+        # Create initial file data that requires no network calls and is constructed purely
+        # with information that exists in the config
+        remote_filepath = remote_filepath or self.index_config.remote_url
+        path_without_protocol = remote_filepath.split("://")[1]
+        rel_path = remote_filepath.replace(path_without_protocol, "").lstrip("/")
+        return FileData(
+            identifier=str(uuid5(NAMESPACE_DNS, remote_filepath)),
+            connector_type=self.connector_type,
+            display_name=remote_filepath,
+            source_identifiers=SourceIdentifiers(
+                filename=Path(remote_filepath).name,
+                rel_path=rel_path or None,
+                fullpath=remote_filepath,
+            ),
+            metadata=FileDataSourceMetadata(url=remote_filepath),
+        )
+
+    def hydrate_file_data(self, init_file_data: FileData):
+        # Get file info
+        with self.connection_config.get_client(protocol=self.index_config.protocol) as client:
+            files = client.ls(self.index_config.path_without_protocol, detail=True)
+        filtered_files = [
+            file for file in files if file.get("size") > 0 and file.get("type") == "file"
+        ]
+        if not filtered_files:
+            raise ValueError(f"{init_file_data} did not reference any valid file")
+        if len(filtered_files) > 1:
+            raise ValueError(f"{init_file_data} referenced more than one file")
+        file_info = filtered_files[0]
+        init_file_data.additional_metadata = self.get_metadata(file_info=file_info)
+
     def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
-        files = self.
-        for
-            file_path = self.get_path(
+        files = self.get_file_info()
+        for file_info in files:
+            file_path = self.get_path(file_info=file_info)
             # Note: we remove any remaining leading slashes (Box introduces these)
             # to get a valid relative path
             rel_path = file_path.replace(self.index_config.path_without_protocol, "").lstrip("/")
 
-            additional_metadata = self.sterilize_info(file_data=
+            additional_metadata = self.sterilize_info(file_data=file_info)
             additional_metadata["original_file_path"] = file_path
             yield FileData(
                 identifier=str(uuid5(NAMESPACE_DNS, file_path)),
@@ -183,7 +215,7 @@ class FsspecIndexer(Indexer):
                     rel_path=rel_path or None,
                     fullpath=file_path,
                 ),
-                metadata=self.get_metadata(
+                metadata=self.get_metadata(file_info=file_info),
                 additional_metadata=additional_metadata,
                 display_name=file_path,
             )
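
Both the new create_init_file_data and the existing run loop derive a deterministic identifier by hashing the remote path with uuid5, so the same file always maps to the same FileData identifier across runs. A minimal standalone illustration of that derivation (the bucket and key below are made up for the example):

from pathlib import Path
from uuid import NAMESPACE_DNS, uuid5

remote_filepath = "s3://my-bucket/docs/report.pdf"  # hypothetical remote URL

# Deterministic: the same URL always produces the same UUID.
identifier = str(uuid5(NAMESPACE_DNS, remote_filepath))
path_without_protocol = remote_filepath.split("://")[1]  # "my-bucket/docs/report.pdf"
filename = Path(remote_filepath).name                    # "report.pdf"

print(identifier, path_without_protocol, filename)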
unstructured_ingest/v2/processes/connectors/fsspec/gcs.py

@@ -131,22 +131,22 @@ class GcsIndexer(FsspecIndexer):
     index_config: GcsIndexerConfig
     connector_type: str = CONNECTOR_TYPE
 
-    def get_metadata(self,
-        path =
+    def get_metadata(self, file_info: dict) -> FileDataSourceMetadata:
+        path = file_info["name"]
         date_created = None
         date_modified = None
-        if modified_at_str :=
+        if modified_at_str := file_info.get("updated"):
            date_modified = str(parser.parse(modified_at_str).timestamp())
-        if created_at_str :=
+        if created_at_str := file_info.get("timeCreated"):
            date_created = str(parser.parse(created_at_str).timestamp())
 
-        file_size =
+        file_size = file_info.get("size") if "size" in file_info else None
 
-        version =
+        version = file_info.get("etag")
         record_locator = {
             "protocol": self.index_config.protocol,
             "remote_file_path": self.index_config.remote_url,
-            "file_id":
+            "file_id": file_info.get("id"),
         }
         return FileDataSourceMetadata(
             date_created=date_created,
unstructured_ingest/v2/processes/connectors/fsspec/s3.py

@@ -110,22 +110,22 @@ class S3Indexer(FsspecIndexer):
     def wrap_error(self, e: Exception) -> Exception:
         return self.connection_config.wrap_error(e=e)
 
-    def get_path(self,
-        return
+    def get_path(self, file_info: dict) -> str:
+        return file_info["Key"]
 
-    def get_metadata(self,
-        path =
+    def get_metadata(self, file_info: dict) -> FileDataSourceMetadata:
+        path = file_info["Key"]
         date_created = None
         date_modified = None
-        modified =
+        modified = file_info.get("LastModified")
         if modified:
             date_created = str(modified.timestamp())
             date_modified = str(modified.timestamp())
 
-        file_size =
-        file_size = file_size or
+        file_size = file_info.get("size") if "size" in file_info else None
+        file_size = file_size or file_info.get("Size")
 
-        version =
+        version = file_info.get("ETag").rstrip('"').lstrip('"') if "ETag" in file_info else None
         metadata: dict[str, str] = {}
         with contextlib.suppress(AttributeError):
             with self.connection_config.get_client(protocol=self.index_config.protocol) as client:
unstructured_ingest/v2/processes/connectors/fsspec/sftp.py

@@ -107,12 +107,12 @@ class SftpIndexer(FsspecIndexer):
             file.identifier = new_identifier
             yield file
 
-    def get_metadata(self,
-        path =
-        date_created = str(
-        date_modified = str(
+    def get_metadata(self, file_info: dict) -> FileDataSourceMetadata:
+        path = file_info["name"]
+        date_created = str(file_info.get("time").timestamp()) if "time" in file_info else None
+        date_modified = str(file_info.get("mtime").timestamp()) if "mtime" in file_info else None
 
-        file_size =
+        file_size = file_info.get("size") if "size" in file_info else None
 
         record_locator = {
             "protocol": self.index_config.protocol,
unstructured_ingest/v2/processes/connectors/sql/__init__.py

@@ -15,11 +15,14 @@ from .snowflake import CONNECTOR_TYPE as SNOWFLAKE_CONNECTOR_TYPE
 from .snowflake import snowflake_destination_entry, snowflake_source_entry
 from .sqlite import CONNECTOR_TYPE as SQLITE_CONNECTOR_TYPE
 from .sqlite import sqlite_destination_entry, sqlite_source_entry
+from .vastdb import CONNECTOR_TYPE as VASTDB_CONNECTOR_TYPE
+from .vastdb import vastdb_destination_entry, vastdb_source_entry
 
 add_source_entry(source_type=SQLITE_CONNECTOR_TYPE, entry=sqlite_source_entry)
 add_source_entry(source_type=POSTGRES_CONNECTOR_TYPE, entry=postgres_source_entry)
 add_source_entry(source_type=SNOWFLAKE_CONNECTOR_TYPE, entry=snowflake_source_entry)
 add_source_entry(source_type=SINGLESTORE_CONNECTOR_TYPE, entry=singlestore_source_entry)
+add_source_entry(source_type=VASTDB_CONNECTOR_TYPE, entry=vastdb_source_entry)
 
 add_destination_entry(destination_type=SQLITE_CONNECTOR_TYPE, entry=sqlite_destination_entry)
 add_destination_entry(destination_type=POSTGRES_CONNECTOR_TYPE, entry=postgres_destination_entry)
@@ -31,3 +34,4 @@ add_destination_entry(
     destination_type=DATABRICKS_DELTA_TABLES_CONNECTOR_TYPE,
     entry=databricks_delta_tables_destination_entry,
 )
+add_destination_entry(destination_type=VASTDB_CONNECTOR_TYPE, entry=vastdb_destination_entry)
unstructured_ingest/v2/processes/connectors/sql/singlestore.py

@@ -3,6 +3,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Generator, Optional
 
+import pandas as pd
 from pydantic import Field, Secret
 
 from unstructured_ingest.v2.logger import logger
@@ -139,7 +140,7 @@ class SingleStoreUploader(SQLUploader):
             if isinstance(value, (list, dict)):
                 value = json.dumps(value)
             if column_name in _DATE_COLUMNS:
-                if value is None:
+                if value is None or pd.isna(value):
                     parsed.append(None)
                 else:
                     parsed.append(parse_date_string(value))
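
The switch from `value is None` to `value is None or pd.isna(value)` matters because missing date values coming out of a pandas DataFrame arrive as NaN/NaT rather than None. A minimal standalone illustration (the sample values are made up):

import pandas as pd

# None, float NaN and NaT all mean "missing" once data has passed through pandas,
# but only the first is caught by an `is None` check.
for value in [None, float("nan"), pd.NaT, "2024-01-01T00:00:00"]:
    print(repr(value), "is None:", value is None, "pd.isna:", bool(pd.isna(value)))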
unstructured_ingest/v2/processes/connectors/sql/sql.py

@@ -14,7 +14,7 @@ from dateutil import parser
 from pydantic import BaseModel, Field, Secret
 
 from unstructured_ingest.error import DestinationConnectionError, SourceConnectionError
-from unstructured_ingest.utils.data_prep import get_data, get_data_df, split_dataframe
+from unstructured_ingest.utils.data_prep import get_data, get_data_df, split_dataframe, write_data
 from unstructured_ingest.v2.constants import RECORD_ID_LABEL
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
@@ -314,7 +314,7 @@ class SQLUploadStager(UploadStager):
         output_filename = f"{Path(output_filename).stem}{output_filename_suffix}"
         output_path = self.get_output_path(output_filename=output_filename, output_dir=output_dir)
 
-
+        write_data(path=output_path, data=df.to_dict(orient="records"))
         return output_path
 
 
@@ -332,6 +332,7 @@ class SQLUploader(Uploader):
     upload_config: SQLUploaderConfig
     connection_config: SQLConnectionConfig
     values_delimiter: str = "?"
+    _columns: list[str] = field(init=False, default=None)
 
     def precheck(self) -> None:
         try:
@@ -354,7 +355,7 @@ class SQLUploader(Uploader):
         parsed = []
         for column_name, value in zip(columns, row):
             if column_name in _DATE_COLUMNS:
-                if value is None:
+                if value is None or pd.isna(value):  # pandas is nan
                     parsed.append(None)
                 else:
                     parsed.append(parse_date_string(value))
@@ -364,8 +365,9 @@ class SQLUploader(Uploader):
         return output
 
     def _fit_to_schema(self, df: pd.DataFrame) -> pd.DataFrame:
+        table_columns = self.get_table_columns()
         columns = set(df.columns)
-        schema_fields = set(
+        schema_fields = set(table_columns)
         columns_to_drop = columns - schema_fields
         missing_columns = schema_fields - columns
 
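
Only the start of _fit_to_schema appears in this hunk; the remainder is unchanged and not shown. As a rough sketch of the general pattern, assuming the extra DataFrame columns are dropped and columns present only in the table schema are added as empty values (the helper and sample data below are hypothetical, not the library's code):

import pandas as pd

def fit_to_schema(df: pd.DataFrame, table_columns: list[str]) -> pd.DataFrame:
    # Mirrors the set arithmetic in the hunk above.
    columns = set(df.columns)
    schema_fields = set(table_columns)
    columns_to_drop = columns - schema_fields
    missing_columns = schema_fields - columns

    df = df.drop(columns=list(columns_to_drop))
    for column in missing_columns:
        df[column] = pd.Series(dtype=object)  # filled with NaN placeholders
    return df

df = pd.DataFrame({"text": ["hello"], "unexpected": [1]})
print(fit_to_schema(df, table_columns=["text", "record_id"]).columns.tolist())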
@@ -395,8 +397,8 @@ class SQLUploader(Uploader):
                 f"record id column "
                 f"{self.upload_config.record_id_key}, skipping delete"
             )
+        df = self._fit_to_schema(df=df)
         df.replace({np.nan: None}, inplace=True)
-        self._fit_to_schema(df=df)
 
         columns = list(df.columns)
         stmt = "INSERT INTO {table_name} ({columns}) VALUES({values})".format(
@@ -424,9 +426,11 @@ class SQLUploader(Uploader):
             cursor.executemany(stmt, values)
 
     def get_table_columns(self) -> list[str]:
-
-
-
+        if self._columns is None:
+            with self.get_cursor() as cursor:
+                cursor.execute(f"SELECT * from {self.upload_config.table_name} LIMIT 1")
+                self._columns = [desc[0] for desc in cursor.description]
+        return self._columns
 
     def can_delete(self) -> bool:
         return self.upload_config.record_id_key in self.get_table_columns()
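
get_table_columns now reads the column names from the DB-API cursor's description attribute and caches them in _columns, so repeated callers (for example can_delete and _fit_to_schema) only hit the database once. A minimal standalone illustration of the cursor.description technique using the stdlib sqlite3 module (the table name and columns are invented for the example):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE elements (id TEXT, text TEXT, record_id TEXT)")

# cursor.description is populated after a SELECT even when no rows come back;
# the first item of each 7-tuple is the column name.
cursor = conn.execute("SELECT * FROM elements LIMIT 1")
columns = [desc[0] for desc in cursor.description]
print(columns)  # ['id', 'text', 'record_id']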
unstructured_ingest/v2/processes/connectors/sql/sqlite.py

@@ -4,6 +4,7 @@ from dataclasses import dataclass, field
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Generator
 
+import pandas as pd
 from pydantic import Field, Secret, model_validator
 
 from unstructured_ingest.v2.logger import logger
@@ -141,7 +142,7 @@ class SQLiteUploader(SQLUploader):
             if isinstance(value, (list, dict)):
                 value = json.dumps(value)
             if column_name in _DATE_COLUMNS:
-                if value is None:
+                if value is None or pd.isna(value):
                     parsed.append(None)
                 else:
                     parsed.append(parse_date_string(value))
unstructured_ingest/v2/processes/connectors/sql/vastdb.py

@@ -0,0 +1,270 @@
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, Optional
+
+import numpy as np
+import pandas as pd
+from pydantic import Field, Secret
+
+from unstructured_ingest.error import DestinationConnectionError
+from unstructured_ingest.utils.data_prep import split_dataframe
+from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.constants import RECORD_ID_LABEL
+from unstructured_ingest.v2.interfaces import (
+    FileData,
+)
+from unstructured_ingest.v2.logger import logger
+from unstructured_ingest.v2.processes.connector_registry import (
+    DestinationRegistryEntry,
+    SourceRegistryEntry,
+)
+from unstructured_ingest.v2.processes.connectors.sql.sql import (
+    _COLUMNS,
+    SQLAccessConfig,
+    SqlBatchFileData,
+    SQLConnectionConfig,
+    SQLDownloader,
+    SQLDownloaderConfig,
+    SQLIndexer,
+    SQLIndexerConfig,
+    SQLUploader,
+    SQLUploaderConfig,
+    SQLUploadStager,
+    SQLUploadStagerConfig,
+)
+from unstructured_ingest.v2.utils import get_enhanced_element_id
+
+if TYPE_CHECKING:
+    from vastdb import connect as VastdbConnect
+    from vastdb import transaction as VastdbTransaction
+    from vastdb.table import Table as VastdbTable
+
+CONNECTOR_TYPE = "vastdb"
+
+
+class VastdbAccessConfig(SQLAccessConfig):
+    endpoint: Optional[str] = Field(default=None, description="DB endpoint")
+    access_key_id: Optional[str] = Field(default=None, description="access key id")
+    access_key_secret: Optional[str] = Field(default=None, description="access key secret")
+
+
+class VastdbConnectionConfig(SQLConnectionConfig):
+    access_config: Secret[VastdbAccessConfig] = Field(
+        default=VastdbAccessConfig(), validate_default=True
+    )
+    vastdb_bucket: str
+    vastdb_schema: str
+    connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
+
+    @requires_dependencies(["vastdb"], extras="vastdb")
+    @contextmanager
+    def get_connection(self) -> "VastdbConnect":
+        from vastdb import connect
+
+        access_config = self.access_config.get_secret_value()
+        connection = connect(
+            endpoint=access_config.endpoint,
+            access=access_config.access_key_id,
+            secret=access_config.access_key_secret,
+        )
+        yield connection
+
+    @contextmanager
+    def get_cursor(self) -> "VastdbTransaction":
+        with self.get_connection() as connection:
+            with connection.transaction() as transaction:
+                yield transaction
+
+    @contextmanager
+    def get_table(self, table_name: str) -> "VastdbTable":
+        with self.get_cursor() as cursor:
+            bucket = cursor.bucket(self.vastdb_bucket)
+            schema = bucket.schema(self.vastdb_schema)
+            table = schema.table(table_name)
+            yield table
+
+
+class VastdbIndexerConfig(SQLIndexerConfig):
+    pass
+
+
+@dataclass
+class VastdbIndexer(SQLIndexer):
+    connection_config: VastdbConnectionConfig
+    index_config: VastdbIndexerConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    def _get_doc_ids(self) -> list[str]:
+        with self.connection_config.get_table(self.index_config.table_name) as table:
+            reader = table.select(columns=[self.index_config.id_column])
+            results = reader.read_all()  # Build a PyArrow Table from the RecordBatchReader
+            ids = sorted([result[self.index_config.id_column] for result in results.to_pylist()])
+            return ids
+
+    def precheck(self) -> None:
+        try:
+            with self.connection_config.get_table(self.index_config.table_name) as table:
+                table.select()
+        except Exception as e:
+            logger.error(f"failed to validate connection: {e}", exc_info=True)
+            raise DestinationConnectionError(f"failed to validate connection: {e}")
+
+
+class VastdbDownloaderConfig(SQLDownloaderConfig):
+    pass
+
+
+@dataclass
+class VastdbDownloader(SQLDownloader):
+    connection_config: VastdbConnectionConfig
+    download_config: VastdbDownloaderConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    @requires_dependencies(["ibis"], extras="vastdb")
+    def query_db(self, file_data: SqlBatchFileData) -> tuple[list[tuple], list[str]]:
+        from ibis import _  # imports the Ibis deferred expression
+
+        table_name = file_data.additional_metadata.table_name
+        id_column = file_data.additional_metadata.id_column
+        ids = tuple([item.identifier for item in file_data.batch_items])
+
+        with self.connection_config.get_table(table_name) as table:
+
+            predicate = _[id_column].isin(ids)
+
+            if self.download_config.fields:
+                # Vastdb requires the id column to be included in the fields
+                fields = self.download_config.fields + [id_column]
+                # dict.fromkeys to remove duplicates and keep order
+                reader = table.select(columns=list(dict.fromkeys(fields)), predicate=predicate)
+            else:
+                reader = table.select(predicate=predicate)
+            results = reader.read_all()
+            df = results.to_pandas()
+            return [tuple(r) for r in df.to_numpy()], results.column_names
+
+
+class VastdbUploadStagerConfig(SQLUploadStagerConfig):
+    rename_columns_map: Optional[dict] = Field(
+        default=None,
+        description="Map of column names to rename, ex: {'old_name': 'new_name'}",
+    )
+    additional_columns: Optional[list[str]] = Field(
+        default_factory=list, description="Additional columns to include in the upload"
+    )
+
+
+class VastdbUploadStager(SQLUploadStager):
+    upload_stager_config: VastdbUploadStagerConfig
+
+    def conform_dict(self, element_dict: dict, file_data: FileData) -> dict:
+        data = element_dict.copy()
+        metadata: dict[str, Any] = data.pop("metadata", {})
+        data_source = metadata.pop("data_source", {})
+        coordinates = metadata.pop("coordinates", {})
+
+        data.update(metadata)
+        data.update(data_source)
+        data.update(coordinates)
+
+        data["id"] = get_enhanced_element_id(element_dict=data, file_data=file_data)
+
+        # remove extraneous, not supported columns
+        # but also allow for additional columns
+        approved_columns = set(_COLUMNS).union(self.upload_stager_config.additional_columns)
+        element = {k: v for k, v in data.items() if k in approved_columns}
+        element[RECORD_ID_LABEL] = file_data.identifier
+        return element
+
+    def conform_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
+        df = super().conform_dataframe(df=df)
+        if self.upload_stager_config.rename_columns_map:
+            df.rename(columns=self.upload_stager_config.rename_columns_map, inplace=True)
+        return df
+
+
+class VastdbUploaderConfig(SQLUploaderConfig):
+    pass
+
+
+@dataclass
+class VastdbUploader(SQLUploader):
+    upload_config: VastdbUploaderConfig = field(default_factory=VastdbUploaderConfig)
+    connection_config: VastdbConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    def precheck(self) -> None:
+        try:
+            with self.connection_config.get_table(self.upload_config.table_name) as table:
+                table.select()
+        except Exception as e:
+            logger.error(f"failed to validate connection: {e}", exc_info=True)
+            raise DestinationConnectionError(f"failed to validate connection: {e}")
+
+    @requires_dependencies(["pyarrow"], extras="vastdb")
+    def upload_dataframe(self, df: pd.DataFrame, file_data: FileData) -> None:
+        import pyarrow as pa
+
+        if self.can_delete():
+            self.delete_by_record_id(file_data=file_data)
+        else:
+            logger.warning(
+                f"table doesn't contain expected "
+                f"record id column "
+                f"{self.upload_config.record_id_key}, skipping delete"
+            )
+        df.replace({np.nan: None}, inplace=True)
+        df = self._fit_to_schema(df=df)
+
+        logger.info(
+            f"writing a total of {len(df)} elements via"
+            f" document batches to destination"
+            f" table named {self.upload_config.table_name}"
+            f" with batch size {self.upload_config.batch_size}"
+        )
+
+        for rows in split_dataframe(df=df, chunk_size=self.upload_config.batch_size):
+
+            with self.connection_config.get_table(self.upload_config.table_name) as table:
+                pa_table = pa.Table.from_pandas(rows)
+                table.insert(pa_table)
+
+    def get_table_columns(self) -> list[str]:
+        if self._columns is None:
+            with self.connection_config.get_table(self.upload_config.table_name) as table:
+                self._columns = table.columns().names
+        return self._columns
+
+    @requires_dependencies(["ibis"], extras="vastdb")
+    def delete_by_record_id(self, file_data: FileData) -> None:
+        from ibis import _  # imports the Ibis deferred expression
+
+        logger.debug(
+            f"deleting any content with data "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
+            f"from table {self.upload_config.table_name}"
+        )
+        predicate = _[self.upload_config.record_id_key].isin([file_data.identifier])
+        with self.connection_config.get_table(self.upload_config.table_name) as table:
+            # Get the internal row id
+            rows_to_delete = table.select(
+                columns=[], predicate=predicate, internal_row_id=True
+            ).read_all()
+            table.delete(rows_to_delete)
+
+
+vastdb_source_entry = SourceRegistryEntry(
+    connection_config=VastdbConnectionConfig,
+    indexer_config=VastdbIndexerConfig,
+    indexer=VastdbIndexer,
+    downloader_config=VastdbDownloaderConfig,
+    downloader=VastdbDownloader,
+)
+
+vastdb_destination_entry = DestinationRegistryEntry(
+    connection_config=VastdbConnectionConfig,
+    uploader=VastdbUploader,
+    uploader_config=VastdbUploaderConfig,
+    upload_stager=VastdbUploadStager,
+    upload_stager_config=VastdbUploadStagerConfig,
+)