unstructured-ingest 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (30)
  1. test/integration/connectors/utils/validation/equality.py +2 -1
  2. test/unit/v2/connectors/sql/test_sql.py +4 -2
  3. unstructured_ingest/__version__.py +1 -1
  4. unstructured_ingest/utils/data_prep.py +11 -3
  5. unstructured_ingest/utils/html.py +109 -0
  6. unstructured_ingest/utils/ndjson.py +52 -0
  7. unstructured_ingest/v2/interfaces/upload_stager.py +3 -13
  8. unstructured_ingest/v2/pipeline/steps/chunk.py +3 -4
  9. unstructured_ingest/v2/pipeline/steps/embed.py +3 -4
  10. unstructured_ingest/v2/pipeline/steps/partition.py +3 -4
  11. unstructured_ingest/v2/processes/connectors/confluence.py +95 -25
  12. unstructured_ingest/v2/processes/connectors/duckdb/base.py +2 -2
  13. unstructured_ingest/v2/processes/connectors/fsspec/azure.py +8 -8
  14. unstructured_ingest/v2/processes/connectors/fsspec/box.py +7 -7
  15. unstructured_ingest/v2/processes/connectors/fsspec/dropbox.py +9 -9
  16. unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +41 -9
  17. unstructured_ingest/v2/processes/connectors/fsspec/gcs.py +7 -7
  18. unstructured_ingest/v2/processes/connectors/fsspec/s3.py +8 -8
  19. unstructured_ingest/v2/processes/connectors/fsspec/sftp.py +5 -5
  20. unstructured_ingest/v2/processes/connectors/sql/__init__.py +4 -0
  21. unstructured_ingest/v2/processes/connectors/sql/singlestore.py +2 -1
  22. unstructured_ingest/v2/processes/connectors/sql/sql.py +12 -8
  23. unstructured_ingest/v2/processes/connectors/sql/sqlite.py +2 -1
  24. unstructured_ingest/v2/processes/connectors/sql/vastdb.py +270 -0
  25. {unstructured_ingest-0.4.0.dist-info → unstructured_ingest-0.4.1.dist-info}/METADATA +25 -22
  26. {unstructured_ingest-0.4.0.dist-info → unstructured_ingest-0.4.1.dist-info}/RECORD +30 -27
  27. {unstructured_ingest-0.4.0.dist-info → unstructured_ingest-0.4.1.dist-info}/LICENSE.md +0 -0
  28. {unstructured_ingest-0.4.0.dist-info → unstructured_ingest-0.4.1.dist-info}/WHEEL +0 -0
  29. {unstructured_ingest-0.4.0.dist-info → unstructured_ingest-0.4.1.dist-info}/entry_points.txt +0 -0
  30. {unstructured_ingest-0.4.0.dist-info → unstructured_ingest-0.4.1.dist-info}/top_level.txt +0 -0

unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py
@@ -119,7 +119,7 @@ class FsspecIndexer(Indexer):
             logger.error(f"failed to validate connection: {e}", exc_info=True)
             raise self.wrap_error(e=e)
 
-    def get_file_data(self) -> list[dict[str, Any]]:
+    def get_file_info(self) -> list[dict[str, Any]]:
         if not self.index_config.recursive:
             # fs.ls does not walk directories
             # directories that are listed in cloud storage can cause problems
@@ -156,24 +156,56 @@ class FsspecIndexer(Indexer):
 
         return random.sample(files, n)
 
-    def get_metadata(self, file_data: dict) -> FileDataSourceMetadata:
+    def get_metadata(self, file_info: dict) -> FileDataSourceMetadata:
         raise NotImplementedError()
 
-    def get_path(self, file_data: dict) -> str:
-        return file_data["name"]
+    def get_path(self, file_info: dict) -> str:
+        return file_info["name"]
 
     def sterilize_info(self, file_data: dict) -> dict:
         return sterilize_dict(data=file_data)
 
+    def create_init_file_data(self, remote_filepath: Optional[str] = None) -> FileData:
+        # Create initial file data that requires no network calls and is constructed purely
+        # with information that exists in the config
+        remote_filepath = remote_filepath or self.index_config.remote_url
+        path_without_protocol = remote_filepath.split("://")[1]
+        rel_path = remote_filepath.replace(path_without_protocol, "").lstrip("/")
+        return FileData(
+            identifier=str(uuid5(NAMESPACE_DNS, remote_filepath)),
+            connector_type=self.connector_type,
+            display_name=remote_filepath,
+            source_identifiers=SourceIdentifiers(
+                filename=Path(remote_filepath).name,
+                rel_path=rel_path or None,
+                fullpath=remote_filepath,
+            ),
+            metadata=FileDataSourceMetadata(url=remote_filepath),
+        )
+
+    def hydrate_file_data(self, init_file_data: FileData):
+        # Get file info
+        with self.connection_config.get_client(protocol=self.index_config.protocol) as client:
+            files = client.ls(self.index_config.path_without_protocol, detail=True)
+        filtered_files = [
+            file for file in files if file.get("size") > 0 and file.get("type") == "file"
+        ]
+        if not filtered_files:
+            raise ValueError(f"{init_file_data} did not reference any valid file")
+        if len(filtered_files) > 1:
+            raise ValueError(f"{init_file_data} referenced more than one file")
+        file_info = filtered_files[0]
+        init_file_data.additional_metadata = self.get_metadata(file_info=file_info)
+
     def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
-        files = self.get_file_data()
-        for file_data in files:
-            file_path = self.get_path(file_data=file_data)
+        files = self.get_file_info()
+        for file_info in files:
+            file_path = self.get_path(file_info=file_info)
             # Note: we remove any remaining leading slashes (Box introduces these)
             # to get a valid relative path
             rel_path = file_path.replace(self.index_config.path_without_protocol, "").lstrip("/")
 
-            additional_metadata = self.sterilize_info(file_data=file_data)
+            additional_metadata = self.sterilize_info(file_data=file_info)
             additional_metadata["original_file_path"] = file_path
             yield FileData(
                 identifier=str(uuid5(NAMESPACE_DNS, file_path)),
@@ -183,7 +215,7 @@ class FsspecIndexer(Indexer):
                     rel_path=rel_path or None,
                     fullpath=file_path,
                 ),
-                metadata=self.get_metadata(file_data=file_data),
+                metadata=self.get_metadata(file_info=file_info),
                 additional_metadata=additional_metadata,
                 display_name=file_path,
             )
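
The new create_init_file_data / hydrate_file_data pair splits single-file indexing into a config-only step (no network calls) and a hydration step that makes one ls() call against the store. A minimal sketch of how a caller might combine them, assuming an already configured FsspecIndexer subclass (constructor arguments are not part of this diff):

from unstructured_ingest.v2.interfaces import FileData
from unstructured_ingest.v2.processes.connectors.fsspec.fsspec import FsspecIndexer

def index_single_file(indexer: FsspecIndexer) -> FileData:
    # Step 1: build FileData purely from the configured remote_url; no network I/O.
    file_data = indexer.create_init_file_data()
    # Step 2: a single client.ls() against the store attaches real file metadata.
    indexer.hydrate_file_data(init_file_data=file_data)
    return file_data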

unstructured_ingest/v2/processes/connectors/fsspec/gcs.py
@@ -131,22 +131,22 @@ class GcsIndexer(FsspecIndexer):
     index_config: GcsIndexerConfig
     connector_type: str = CONNECTOR_TYPE
 
-    def get_metadata(self, file_data: dict) -> FileDataSourceMetadata:
-        path = file_data["name"]
+    def get_metadata(self, file_info: dict) -> FileDataSourceMetadata:
+        path = file_info["name"]
         date_created = None
         date_modified = None
-        if modified_at_str := file_data.get("updated"):
+        if modified_at_str := file_info.get("updated"):
             date_modified = str(parser.parse(modified_at_str).timestamp())
-        if created_at_str := file_data.get("timeCreated"):
+        if created_at_str := file_info.get("timeCreated"):
             date_created = str(parser.parse(created_at_str).timestamp())
 
-        file_size = file_data.get("size") if "size" in file_data else None
+        file_size = file_info.get("size") if "size" in file_info else None
 
-        version = file_data.get("etag")
+        version = file_info.get("etag")
         record_locator = {
             "protocol": self.index_config.protocol,
             "remote_file_path": self.index_config.remote_url,
-            "file_id": file_data.get("id"),
+            "file_id": file_info.get("id"),
         }
         return FileDataSourceMetadata(
             date_created=date_created,

unstructured_ingest/v2/processes/connectors/fsspec/s3.py
@@ -110,22 +110,22 @@ class S3Indexer(FsspecIndexer):
     def wrap_error(self, e: Exception) -> Exception:
         return self.connection_config.wrap_error(e=e)
 
-    def get_path(self, file_data: dict) -> str:
-        return file_data["Key"]
+    def get_path(self, file_info: dict) -> str:
+        return file_info["Key"]
 
-    def get_metadata(self, file_data: dict) -> FileDataSourceMetadata:
-        path = file_data["Key"]
+    def get_metadata(self, file_info: dict) -> FileDataSourceMetadata:
+        path = file_info["Key"]
         date_created = None
         date_modified = None
-        modified = file_data.get("LastModified")
+        modified = file_info.get("LastModified")
         if modified:
             date_created = str(modified.timestamp())
             date_modified = str(modified.timestamp())
 
-        file_size = file_data.get("size") if "size" in file_data else None
-        file_size = file_size or file_data.get("Size")
+        file_size = file_info.get("size") if "size" in file_info else None
+        file_size = file_size or file_info.get("Size")
 
-        version = file_data.get("ETag").rstrip('"').lstrip('"') if "ETag" in file_data else None
+        version = file_info.get("ETag").rstrip('"').lstrip('"') if "ETag" in file_info else None
         metadata: dict[str, str] = {}
         with contextlib.suppress(AttributeError):
             with self.connection_config.get_client(protocol=self.index_config.protocol) as client:

unstructured_ingest/v2/processes/connectors/fsspec/sftp.py
@@ -107,12 +107,12 @@ class SftpIndexer(FsspecIndexer):
             file.identifier = new_identifier
             yield file
 
-    def get_metadata(self, file_data: dict) -> FileDataSourceMetadata:
-        path = file_data["name"]
-        date_created = str(file_data.get("time").timestamp()) if "time" in file_data else None
-        date_modified = str(file_data.get("mtime").timestamp()) if "mtime" in file_data else None
+    def get_metadata(self, file_info: dict) -> FileDataSourceMetadata:
+        path = file_info["name"]
+        date_created = str(file_info.get("time").timestamp()) if "time" in file_info else None
+        date_modified = str(file_info.get("mtime").timestamp()) if "mtime" in file_info else None
 
-        file_size = file_data.get("size") if "size" in file_data else None
+        file_size = file_info.get("size") if "size" in file_info else None
 
         record_locator = {
             "protocol": self.index_config.protocol,

unstructured_ingest/v2/processes/connectors/sql/__init__.py
@@ -15,11 +15,14 @@ from .snowflake import CONNECTOR_TYPE as SNOWFLAKE_CONNECTOR_TYPE
 from .snowflake import snowflake_destination_entry, snowflake_source_entry
 from .sqlite import CONNECTOR_TYPE as SQLITE_CONNECTOR_TYPE
 from .sqlite import sqlite_destination_entry, sqlite_source_entry
+from .vastdb import CONNECTOR_TYPE as VASTDB_CONNECTOR_TYPE
+from .vastdb import vastdb_destination_entry, vastdb_source_entry
 
 add_source_entry(source_type=SQLITE_CONNECTOR_TYPE, entry=sqlite_source_entry)
 add_source_entry(source_type=POSTGRES_CONNECTOR_TYPE, entry=postgres_source_entry)
 add_source_entry(source_type=SNOWFLAKE_CONNECTOR_TYPE, entry=snowflake_source_entry)
 add_source_entry(source_type=SINGLESTORE_CONNECTOR_TYPE, entry=singlestore_source_entry)
+add_source_entry(source_type=VASTDB_CONNECTOR_TYPE, entry=vastdb_source_entry)
 
 add_destination_entry(destination_type=SQLITE_CONNECTOR_TYPE, entry=sqlite_destination_entry)
 add_destination_entry(destination_type=POSTGRES_CONNECTOR_TYPE, entry=postgres_destination_entry)
@@ -31,3 +34,4 @@ add_destination_entry(
     destination_type=DATABRICKS_DELTA_TABLES_CONNECTOR_TYPE,
     entry=databricks_delta_tables_destination_entry,
 )
+add_destination_entry(destination_type=VASTDB_CONNECTOR_TYPE, entry=vastdb_destination_entry)

unstructured_ingest/v2/processes/connectors/sql/singlestore.py
@@ -3,6 +3,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Generator, Optional
 
+import pandas as pd
 from pydantic import Field, Secret
 
 from unstructured_ingest.v2.logger import logger
@@ -139,7 +140,7 @@ class SingleStoreUploader(SQLUploader):
                 if isinstance(value, (list, dict)):
                     value = json.dumps(value)
                 if column_name in _DATE_COLUMNS:
-                    if value is None:
+                    if value is None or pd.isna(value):
                         parsed.append(None)
                     else:
                         parsed.append(parse_date_string(value))
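
The pd.isna guard added here (and mirrored below in sql.py and sqlite.py) exists because values that round-trip through a pandas DataFrame come back as NaT or NaN rather than None, so the old `is None` check let missing dates reach parse_date_string. A standalone illustration:

import pandas as pd

value = pd.NaT  # how a missing date surfaces after staging through a DataFrame
print(value is None)   # False: the old check misses pandas missing values
print(pd.isna(value))  # True: the new guard catches NaT (and float NaN) as well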

unstructured_ingest/v2/processes/connectors/sql/sql.py
@@ -14,7 +14,7 @@ from dateutil import parser
 from pydantic import BaseModel, Field, Secret
 
 from unstructured_ingest.error import DestinationConnectionError, SourceConnectionError
-from unstructured_ingest.utils.data_prep import get_data, get_data_df, split_dataframe
+from unstructured_ingest.utils.data_prep import get_data, get_data_df, split_dataframe, write_data
 from unstructured_ingest.v2.constants import RECORD_ID_LABEL
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
@@ -314,7 +314,7 @@ class SQLUploadStager(UploadStager):
         output_filename = f"{Path(output_filename).stem}{output_filename_suffix}"
         output_path = self.get_output_path(output_filename=output_filename, output_dir=output_dir)
 
-        self.write_output(output_path=output_path, data=df.to_dict(orient="records"))
+        write_data(path=output_path, data=df.to_dict(orient="records"))
         return output_path
 
 
@@ -332,6 +332,7 @@ class SQLUploader(Uploader):
     upload_config: SQLUploaderConfig
     connection_config: SQLConnectionConfig
     values_delimiter: str = "?"
+    _columns: list[str] = field(init=False, default=None)
 
     def precheck(self) -> None:
         try:
@@ -354,7 +355,7 @@ class SQLUploader(Uploader):
             parsed = []
             for column_name, value in zip(columns, row):
                 if column_name in _DATE_COLUMNS:
-                    if value is None:
+                    if value is None or pd.isna(value):  # pandas is nan
                         parsed.append(None)
                     else:
                         parsed.append(parse_date_string(value))
@@ -364,8 +365,9 @@ class SQLUploader(Uploader):
         return output
 
     def _fit_to_schema(self, df: pd.DataFrame) -> pd.DataFrame:
+        table_columns = self.get_table_columns()
         columns = set(df.columns)
-        schema_fields = set(columns)
+        schema_fields = set(table_columns)
         columns_to_drop = columns - schema_fields
         missing_columns = schema_fields - columns
 
@@ -395,8 +397,8 @@ class SQLUploader(Uploader):
                 f"record id column "
                 f"{self.upload_config.record_id_key}, skipping delete"
             )
+        df = self._fit_to_schema(df=df)
         df.replace({np.nan: None}, inplace=True)
-        self._fit_to_schema(df=df)
 
         columns = list(df.columns)
         stmt = "INSERT INTO {table_name} ({columns}) VALUES({values})".format(
@@ -424,9 +426,11 @@ class SQLUploader(Uploader):
                 cursor.executemany(stmt, values)
 
     def get_table_columns(self) -> list[str]:
-        with self.get_cursor() as cursor:
-            cursor.execute(f"SELECT * from {self.upload_config.table_name}")
-            return [desc[0] for desc in cursor.description]
+        if self._columns is None:
+            with self.get_cursor() as cursor:
+                cursor.execute(f"SELECT * from {self.upload_config.table_name} LIMIT 1")
+                self._columns = [desc[0] for desc in cursor.description]
+        return self._columns
 
     def can_delete(self) -> bool:
         return self.upload_config.record_id_key in self.get_table_columns()
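
get_table_columns() is now memoized on the uploader and probes the table with LIMIT 1, so can_delete() and the corrected _fit_to_schema() reuse a single lookup instead of issuing a full SELECT * on every call. A standalone sqlite3 sketch of the same pattern (an illustration, not the connector's own API):

import sqlite3

class ColumnCache:
    # Illustration of the memoization added to SQLUploader.get_table_columns():
    # fetch the column names once via a LIMIT 1 probe, then serve them from cache.
    def __init__(self, conn: sqlite3.Connection, table_name: str):
        self.conn = conn
        self.table_name = table_name
        self._columns: list[str] | None = None

    def get_table_columns(self) -> list[str]:
        if self._columns is None:
            cursor = self.conn.execute(f"SELECT * FROM {self.table_name} LIMIT 1")
            self._columns = [desc[0] for desc in cursor.description]
        return self._columns

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE elements (id TEXT, text TEXT, record_id TEXT)")
cache = ColumnCache(conn, "elements")
print(cache.get_table_columns())  # one query: ['id', 'text', 'record_id']
print(cache.get_table_columns())  # served from the cached list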

unstructured_ingest/v2/processes/connectors/sql/sqlite.py
@@ -4,6 +4,7 @@ from dataclasses import dataclass, field
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Generator
 
+import pandas as pd
 from pydantic import Field, Secret, model_validator
 
 from unstructured_ingest.v2.logger import logger
@@ -141,7 +142,7 @@ class SQLiteUploader(SQLUploader):
                 if isinstance(value, (list, dict)):
                     value = json.dumps(value)
                 if column_name in _DATE_COLUMNS:
-                    if value is None:
+                    if value is None or pd.isna(value):
                         parsed.append(None)
                     else:
                         parsed.append(parse_date_string(value))

unstructured_ingest/v2/processes/connectors/sql/vastdb.py (new file)
@@ -0,0 +1,270 @@
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, Optional
+
+import numpy as np
+import pandas as pd
+from pydantic import Field, Secret
+
+from unstructured_ingest.error import DestinationConnectionError
+from unstructured_ingest.utils.data_prep import split_dataframe
+from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.constants import RECORD_ID_LABEL
+from unstructured_ingest.v2.interfaces import (
+    FileData,
+)
+from unstructured_ingest.v2.logger import logger
+from unstructured_ingest.v2.processes.connector_registry import (
+    DestinationRegistryEntry,
+    SourceRegistryEntry,
+)
+from unstructured_ingest.v2.processes.connectors.sql.sql import (
+    _COLUMNS,
+    SQLAccessConfig,
+    SqlBatchFileData,
+    SQLConnectionConfig,
+    SQLDownloader,
+    SQLDownloaderConfig,
+    SQLIndexer,
+    SQLIndexerConfig,
+    SQLUploader,
+    SQLUploaderConfig,
+    SQLUploadStager,
+    SQLUploadStagerConfig,
+)
+from unstructured_ingest.v2.utils import get_enhanced_element_id
+
+if TYPE_CHECKING:
+    from vastdb import connect as VastdbConnect
+    from vastdb import transaction as VastdbTransaction
+    from vastdb.table import Table as VastdbTable
+
+CONNECTOR_TYPE = "vastdb"
+
+
+class VastdbAccessConfig(SQLAccessConfig):
+    endpoint: Optional[str] = Field(default=None, description="DB endpoint")
+    access_key_id: Optional[str] = Field(default=None, description="access key id")
+    access_key_secret: Optional[str] = Field(default=None, description="access key secret")
+
+
+class VastdbConnectionConfig(SQLConnectionConfig):
+    access_config: Secret[VastdbAccessConfig] = Field(
+        default=VastdbAccessConfig(), validate_default=True
+    )
+    vastdb_bucket: str
+    vastdb_schema: str
+    connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
+
+    @requires_dependencies(["vastdb"], extras="vastdb")
+    @contextmanager
+    def get_connection(self) -> "VastdbConnect":
+        from vastdb import connect
+
+        access_config = self.access_config.get_secret_value()
+        connection = connect(
+            endpoint=access_config.endpoint,
+            access=access_config.access_key_id,
+            secret=access_config.access_key_secret,
+        )
+        yield connection
+
+    @contextmanager
+    def get_cursor(self) -> "VastdbTransaction":
+        with self.get_connection() as connection:
+            with connection.transaction() as transaction:
+                yield transaction
+
+    @contextmanager
+    def get_table(self, table_name: str) -> "VastdbTable":
+        with self.get_cursor() as cursor:
+            bucket = cursor.bucket(self.vastdb_bucket)
+            schema = bucket.schema(self.vastdb_schema)
+            table = schema.table(table_name)
+            yield table
+
+
+class VastdbIndexerConfig(SQLIndexerConfig):
+    pass
+
+
+@dataclass
+class VastdbIndexer(SQLIndexer):
+    connection_config: VastdbConnectionConfig
+    index_config: VastdbIndexerConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    def _get_doc_ids(self) -> list[str]:
+        with self.connection_config.get_table(self.index_config.table_name) as table:
+            reader = table.select(columns=[self.index_config.id_column])
+            results = reader.read_all()  # Build a PyArrow Table from the RecordBatchReader
+            ids = sorted([result[self.index_config.id_column] for result in results.to_pylist()])
+        return ids
+
+    def precheck(self) -> None:
+        try:
+            with self.connection_config.get_table(self.index_config.table_name) as table:
+                table.select()
+        except Exception as e:
+            logger.error(f"failed to validate connection: {e}", exc_info=True)
+            raise DestinationConnectionError(f"failed to validate connection: {e}")
+
+
+class VastdbDownloaderConfig(SQLDownloaderConfig):
+    pass
+
+
+@dataclass
+class VastdbDownloader(SQLDownloader):
+    connection_config: VastdbConnectionConfig
+    download_config: VastdbDownloaderConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    @requires_dependencies(["ibis"], extras="vastdb")
+    def query_db(self, file_data: SqlBatchFileData) -> tuple[list[tuple], list[str]]:
+        from ibis import _  # imports the Ibis deferred expression
+
+        table_name = file_data.additional_metadata.table_name
+        id_column = file_data.additional_metadata.id_column
+        ids = tuple([item.identifier for item in file_data.batch_items])
+
+        with self.connection_config.get_table(table_name) as table:
+
+            predicate = _[id_column].isin(ids)
+
+            if self.download_config.fields:
+                # Vastdb requires the id column to be included in the fields
+                fields = self.download_config.fields + [id_column]
+                # dict.fromkeys to remove duplicates and keep order
+                reader = table.select(columns=list(dict.fromkeys(fields)), predicate=predicate)
+            else:
+                reader = table.select(predicate=predicate)
+            results = reader.read_all()
+            df = results.to_pandas()
+            return [tuple(r) for r in df.to_numpy()], results.column_names
+
+
+class VastdbUploadStagerConfig(SQLUploadStagerConfig):
+    rename_columns_map: Optional[dict] = Field(
+        default=None,
+        description="Map of column names to rename, ex: {'old_name': 'new_name'}",
+    )
+    additional_columns: Optional[list[str]] = Field(
+        default_factory=list, description="Additional columns to include in the upload"
+    )
+
+
+class VastdbUploadStager(SQLUploadStager):
+    upload_stager_config: VastdbUploadStagerConfig
+
+    def conform_dict(self, element_dict: dict, file_data: FileData) -> dict:
+        data = element_dict.copy()
+        metadata: dict[str, Any] = data.pop("metadata", {})
+        data_source = metadata.pop("data_source", {})
+        coordinates = metadata.pop("coordinates", {})
+
+        data.update(metadata)
+        data.update(data_source)
+        data.update(coordinates)
+
+        data["id"] = get_enhanced_element_id(element_dict=data, file_data=file_data)
+
+        # remove extraneous, not supported columns
+        # but also allow for additional columns
+        approved_columns = set(_COLUMNS).union(self.upload_stager_config.additional_columns)
+        element = {k: v for k, v in data.items() if k in approved_columns}
+        element[RECORD_ID_LABEL] = file_data.identifier
+        return element
+
+    def conform_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
+        df = super().conform_dataframe(df=df)
+        if self.upload_stager_config.rename_columns_map:
+            df.rename(columns=self.upload_stager_config.rename_columns_map, inplace=True)
+        return df
+
+
+class VastdbUploaderConfig(SQLUploaderConfig):
+    pass
+
+
+@dataclass
+class VastdbUploader(SQLUploader):
+    upload_config: VastdbUploaderConfig = field(default_factory=VastdbUploaderConfig)
+    connection_config: VastdbConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    def precheck(self) -> None:
+        try:
+            with self.connection_config.get_table(self.upload_config.table_name) as table:
+                table.select()
+        except Exception as e:
+            logger.error(f"failed to validate connection: {e}", exc_info=True)
+            raise DestinationConnectionError(f"failed to validate connection: {e}")
+
+    @requires_dependencies(["pyarrow"], extras="vastdb")
+    def upload_dataframe(self, df: pd.DataFrame, file_data: FileData) -> None:
+        import pyarrow as pa
+
+        if self.can_delete():
+            self.delete_by_record_id(file_data=file_data)
+        else:
+            logger.warning(
+                f"table doesn't contain expected "
+                f"record id column "
+                f"{self.upload_config.record_id_key}, skipping delete"
+            )
+        df.replace({np.nan: None}, inplace=True)
+        df = self._fit_to_schema(df=df)
+
+        logger.info(
+            f"writing a total of {len(df)} elements via"
+            f" document batches to destination"
+            f" table named {self.upload_config.table_name}"
+            f" with batch size {self.upload_config.batch_size}"
+        )
+
+        for rows in split_dataframe(df=df, chunk_size=self.upload_config.batch_size):
+
+            with self.connection_config.get_table(self.upload_config.table_name) as table:
+                pa_table = pa.Table.from_pandas(rows)
+                table.insert(pa_table)
+
+    def get_table_columns(self) -> list[str]:
+        if self._columns is None:
+            with self.connection_config.get_table(self.upload_config.table_name) as table:
+                self._columns = table.columns().names
+        return self._columns
+
+    @requires_dependencies(["ibis"], extras="vastdb")
+    def delete_by_record_id(self, file_data: FileData) -> None:
+        from ibis import _  # imports the Ibis deferred expression
+
+        logger.debug(
+            f"deleting any content with data "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
+            f"from table {self.upload_config.table_name}"
+        )
+        predicate = _[self.upload_config.record_id_key].isin([file_data.identifier])
+        with self.connection_config.get_table(self.upload_config.table_name) as table:
+            # Get the internal row id
+            rows_to_delete = table.select(
+                columns=[], predicate=predicate, internal_row_id=True
+            ).read_all()
+            table.delete(rows_to_delete)
+
+
+vastdb_source_entry = SourceRegistryEntry(
+    connection_config=VastdbConnectionConfig,
+    indexer_config=VastdbIndexerConfig,
+    indexer=VastdbIndexer,
+    downloader_config=VastdbDownloaderConfig,
+    downloader=VastdbDownloader,
+)
+
+vastdb_destination_entry = DestinationRegistryEntry(
+    connection_config=VastdbConnectionConfig,
+    uploader=VastdbUploader,
+    uploader_config=VastdbUploaderConfig,
+    upload_stager=VastdbUploadStager,
+    upload_stager_config=VastdbUploadStagerConfig,
+)