unstructured-ingest 0.1.1__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (39)
  1. test/integration/connectors/conftest.py +13 -0
  2. test/integration/connectors/databricks_tests/test_volumes_native.py +8 -4
  3. test/integration/connectors/sql/test_postgres.py +6 -10
  4. test/integration/connectors/sql/test_singlestore.py +156 -0
  5. test/integration/connectors/sql/test_snowflake.py +205 -0
  6. test/integration/connectors/sql/test_sqlite.py +6 -10
  7. test/integration/connectors/test_delta_table.py +138 -0
  8. test/integration/connectors/test_s3.py +1 -1
  9. test/integration/connectors/utils/docker.py +78 -0
  10. test/integration/connectors/utils/docker_compose.py +23 -8
  11. test/integration/connectors/utils/validation.py +93 -2
  12. unstructured_ingest/__version__.py +1 -1
  13. unstructured_ingest/v2/cli/utils/click.py +32 -1
  14. unstructured_ingest/v2/cli/utils/model_conversion.py +10 -3
  15. unstructured_ingest/v2/interfaces/file_data.py +1 -0
  16. unstructured_ingest/v2/interfaces/indexer.py +4 -1
  17. unstructured_ingest/v2/pipeline/pipeline.py +10 -2
  18. unstructured_ingest/v2/pipeline/steps/index.py +18 -1
  19. unstructured_ingest/v2/processes/connectors/__init__.py +13 -6
  20. unstructured_ingest/v2/processes/connectors/astradb.py +278 -55
  21. unstructured_ingest/v2/processes/connectors/databricks/volumes.py +3 -1
  22. unstructured_ingest/v2/processes/connectors/delta_table.py +185 -0
  23. unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +1 -0
  24. unstructured_ingest/v2/processes/connectors/slack.py +248 -0
  25. unstructured_ingest/v2/processes/connectors/sql/__init__.py +15 -2
  26. unstructured_ingest/v2/processes/connectors/sql/postgres.py +33 -56
  27. unstructured_ingest/v2/processes/connectors/sql/singlestore.py +168 -0
  28. unstructured_ingest/v2/processes/connectors/sql/snowflake.py +162 -0
  29. unstructured_ingest/v2/processes/connectors/sql/sql.py +51 -12
  30. unstructured_ingest/v2/processes/connectors/sql/sqlite.py +31 -32
  31. unstructured_ingest/v2/unstructured_api.py +1 -1
  32. {unstructured_ingest-0.1.1.dist-info → unstructured_ingest-0.2.1.dist-info}/METADATA +19 -17
  33. {unstructured_ingest-0.1.1.dist-info → unstructured_ingest-0.2.1.dist-info}/RECORD +37 -31
  34. unstructured_ingest/v2/processes/connectors/databricks_volumes.py +0 -250
  35. unstructured_ingest/v2/processes/connectors/singlestore.py +0 -156
  36. {unstructured_ingest-0.1.1.dist-info → unstructured_ingest-0.2.1.dist-info}/LICENSE.md +0 -0
  37. {unstructured_ingest-0.1.1.dist-info → unstructured_ingest-0.2.1.dist-info}/WHEEL +0 -0
  38. {unstructured_ingest-0.1.1.dist-info → unstructured_ingest-0.2.1.dist-info}/entry_points.txt +0 -0
  39. {unstructured_ingest-0.1.1.dist-info → unstructured_ingest-0.2.1.dist-info}/top_level.txt +0 -0
unstructured_ingest/v2/processes/connectors/slack.py (new file)
@@ -0,0 +1,248 @@
+ import hashlib
+ import time
+ import xml.etree.ElementTree as ET
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Generator, Optional
+
+ from pydantic import Field, Secret
+
+ from unstructured_ingest.error import SourceConnectionError
+ from unstructured_ingest.logger import logger
+ from unstructured_ingest.utils.dep_check import requires_dependencies
+ from unstructured_ingest.v2.interfaces import (
+     AccessConfig,
+     ConnectionConfig,
+     Downloader,
+     DownloaderConfig,
+     Indexer,
+     IndexerConfig,
+     download_responses,
+ )
+ from unstructured_ingest.v2.interfaces.file_data import (
+     FileData,
+     FileDataSourceMetadata,
+     SourceIdentifiers,
+ )
+ from unstructured_ingest.v2.processes.connector_registry import SourceRegistryEntry
+
+ if TYPE_CHECKING:
+     from slack_sdk import WebClient
+     from slack_sdk.web.async_client import AsyncWebClient
+
+ # NOTE: Pagination limit set to the upper end of the recommended range
+ # https://api.slack.com/apis/pagination#facts
+ PAGINATION_LIMIT = 200
+
+ CONNECTOR_TYPE = "slack"
+
+
+ class SlackAccessConfig(AccessConfig):
+     token: str = Field(
+         description="Bot token used to access Slack API, must have channels:history scope for the"
+         " bot user."
+     )
+
+
+ class SlackConnectionConfig(ConnectionConfig):
+     access_config: Secret[SlackAccessConfig]
+
+     @requires_dependencies(["slack_sdk"], extras="slack")
+     @SourceConnectionError.wrap
+     def get_client(self) -> "WebClient":
+         from slack_sdk import WebClient
+
+         return WebClient(token=self.access_config.get_secret_value().token)
+
+     @requires_dependencies(["slack_sdk"], extras="slack")
+     @SourceConnectionError.wrap
+     def get_async_client(self) -> "AsyncWebClient":
+         from slack_sdk.web.async_client import AsyncWebClient
+
+         return AsyncWebClient(token=self.access_config.get_secret_value().token)
+
+
+ class SlackIndexerConfig(IndexerConfig):
+     channels: list[str] = Field(
+         description="Comma-delimited list of Slack channel IDs to pull messages from, can be"
+         " both public or private channels."
+     )
+     start_date: Optional[datetime] = Field(
+         default=None,
+         description="Start date/time in formats YYYY-MM-DD[T]HH:MM[:SS[.ffffff]][Z or [±]HH[:]MM]"
+         " or YYYY-MM-DD",
+     )
+     end_date: Optional[datetime] = Field(
+         default=None,
+         description="End date/time in formats YYYY-MM-DD[T]HH:MM[:SS[.ffffff]][Z or [±]HH[:]MM]"
+         " or YYYY-MM-DD",
+     )
+
+
+ @dataclass
+ class SlackIndexer(Indexer):
+     index_config: SlackIndexerConfig
+     connection_config: SlackConnectionConfig
+     connector_type: str = CONNECTOR_TYPE
+
+     def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
+         client = self.connection_config.get_client()
+         for channel in self.index_config.channels:
+             messages = []
+             oldest = (
+                 str(self.index_config.start_date.timestamp())
+                 if self.index_config.start_date is not None
+                 else None
+             )
+             latest = (
+                 str(self.index_config.end_date.timestamp())
+                 if self.index_config.end_date is not None
+                 else None
+             )
+             for conversation_history in client.conversations_history(
+                 channel=channel,
+                 oldest=oldest,
+                 latest=latest,
+                 limit=PAGINATION_LIMIT,
+             ):
+                 messages = conversation_history.get("messages", [])
+                 if messages:
+                     yield self._messages_to_file_data(messages, channel)
+
+     def _messages_to_file_data(
+         self,
+         messages: list[dict],
+         channel: str,
+     ) -> FileData:
+         ts_oldest = min((message["ts"] for message in messages), key=lambda m: float(m))
+         ts_newest = max((message["ts"] for message in messages), key=lambda m: float(m))
+
+         identifier_base = f"{channel}-{ts_oldest}-{ts_newest}"
+         identifier = hashlib.sha256(identifier_base.encode("utf-8")).hexdigest()
+         filename = identifier[:16]
+
+         return FileData(
+             identifier=identifier,
+             connector_type=CONNECTOR_TYPE,
+             source_identifiers=SourceIdentifiers(
+                 filename=f"{filename}.xml", fullpath=f"{filename}.xml"
+             ),
+             metadata=FileDataSourceMetadata(
+                 date_created=ts_oldest,
+                 date_modified=ts_newest,
+                 date_processed=str(time.time()),
+                 record_locator={
+                     "channel": channel,
+                     "oldest": ts_oldest,
+                     "latest": ts_newest,
+                 },
+             ),
+         )
+
+     @SourceConnectionError.wrap
+     def precheck(self) -> None:
+         client = self.connection_config.get_client()
+         for channel in self.index_config.channels:
+             # NOTE: Querying conversations history guarantees that the bot is in the channel
+             client.conversations_history(channel=channel, limit=1)
+
+
+ class SlackDownloaderConfig(DownloaderConfig):
+     pass
+
+
+ @dataclass
+ class SlackDownloader(Downloader):
+     connector_type: str = CONNECTOR_TYPE
+     connection_config: SlackConnectionConfig
+     download_config: SlackDownloaderConfig = field(default_factory=SlackDownloaderConfig)
+
+     def run(self, file_data, **kwargs):
+         raise NotImplementedError
+
+     async def run_async(self, file_data: FileData, **kwargs) -> download_responses:
+         # NOTE: Indexer should provide source identifiers required to generate the download path
+         download_path = self.get_download_path(file_data)
+         if download_path is None:
+             logger.error(
+                 "Generated download path is None, source_identifiers might be missing"
+                 "from FileData."
+             )
+             raise ValueError("Generated invalid download path.")
+
+         await self._download_conversation(file_data, download_path)
+         return self.generate_download_response(file_data, download_path)
+
+     def is_async(self):
+         return True
+
+     async def _download_conversation(self, file_data: FileData, download_path: Path) -> None:
+         # NOTE: Indexer should supply the record locator in metadata
+         if (
+             file_data.metadata.record_locator is None
+             or "channel" not in file_data.metadata.record_locator
+             or "oldest" not in file_data.metadata.record_locator
+             or "latest" not in file_data.metadata.record_locator
+         ):
+             logger.error(
+                 f"Invalid record locator in metadata: {file_data.metadata.record_locator}."
+                 "Keys 'channel', 'oldest' and 'latest' must be present."
+             )
+             raise ValueError("Invalid record locator.")
+
+         client = self.connection_config.get_async_client()
+         messages = []
+         async for conversation_history in await client.conversations_history(
+             channel=file_data.metadata.record_locator["channel"],
+             oldest=file_data.metadata.record_locator["oldest"],
+             latest=file_data.metadata.record_locator["latest"],
+             limit=PAGINATION_LIMIT,
+             # NOTE: In order to get the exact same range of messages as indexer, it provides
+             # timestamps of oldest and newest messages, inclusive=True is necessary to include them
+             inclusive=True,
+         ):
+             messages += conversation_history.get("messages", [])
+
+         conversation = []
+         for message in messages:
+             thread_messages = []
+             async for conversations_replies in await client.conversations_replies(
+                 channel=file_data.metadata.record_locator["channel"],
+                 ts=message["ts"],
+                 limit=PAGINATION_LIMIT,
+             ):
+                 thread_messages += conversations_replies.get("messages", [])
+
+             # NOTE: Replies contains the whole thread, including the message references by the `ts`
+             # parameter even if it's the only message (there were no replies).
+             # Reference: https://api.slack.com/methods/conversations.replies#markdown
+             conversation.append(thread_messages)
+
+         conversation_xml = self._conversation_to_xml(conversation)
+         download_path.parent.mkdir(exist_ok=True, parents=True)
+         conversation_xml.write(download_path, encoding="utf-8", xml_declaration=True)
+
+     def _conversation_to_xml(self, conversation: list[list[dict]]) -> ET.ElementTree:
+         root = ET.Element("messages")
+
+         for thread in conversation:
+             message, *replies = thread
+             message_elem = ET.SubElement(root, "message")
+             text_elem = ET.SubElement(message_elem, "text")
+             text_elem.text = message.get("text")
+
+             for reply in replies:
+                 reply_msg = reply.get("text", "")
+                 text_elem.text = "".join([str(text_elem.text), " <reply> ", reply_msg])
+
+         return ET.ElementTree(root)
+
+
+ slack_source_entry = SourceRegistryEntry(
+     indexer=SlackIndexer,
+     indexer_config=SlackIndexerConfig,
+     downloader=SlackDownloader,
+     downloader_config=DownloaderConfig,
+     connection_config=SlackConnectionConfig,
+ )
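For orientation, below is a minimal sketch of driving the new Slack source connector directly rather than through the CLI. It only uses names defined in the hunk above; it assumes the slack extra (slack_sdk) is installed, issues real conversations.history calls, and the token and channel ID are placeholders.

from datetime import datetime

from unstructured_ingest.v2.processes.connectors.slack import (
    SlackAccessConfig,
    SlackConnectionConfig,
    SlackIndexer,
    SlackIndexerConfig,
)

indexer = SlackIndexer(
    connection_config=SlackConnectionConfig(
        access_config=SlackAccessConfig(token="xoxb-placeholder"),  # placeholder bot token
    ),
    index_config=SlackIndexerConfig(
        channels=["C0123456789"],  # placeholder channel ID
        start_date=datetime(2024, 1, 1),
    ),
)

# Each yielded FileData covers one page of channel history; its record_locator carries
# the channel plus oldest/latest timestamps that the downloader later re-queries.
for file_data in indexer.run():
    print(file_data.identifier, file_data.metadata.record_locator)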
unstructured_ingest/v2/processes/connectors/sql/__init__.py
@@ -2,12 +2,25 @@ from __future__ import annotations

  from unstructured_ingest.v2.processes.connector_registry import (
      add_destination_entry,
+     add_source_entry,
  )

  from .postgres import CONNECTOR_TYPE as POSTGRES_CONNECTOR_TYPE
- from .postgres import postgres_destination_entry
+ from .postgres import postgres_destination_entry, postgres_source_entry
+ from .singlestore import CONNECTOR_TYPE as SINGLESTORE_CONNECTOR_TYPE
+ from .singlestore import singlestore_destination_entry
+ from .snowflake import CONNECTOR_TYPE as SNOWFLAKE_CONNECTOR_TYPE
+ from .snowflake import snowflake_destination_entry, snowflake_source_entry
  from .sqlite import CONNECTOR_TYPE as SQLITE_CONNECTOR_TYPE
- from .sqlite import sqlite_destination_entry
+ from .sqlite import sqlite_destination_entry, sqlite_source_entry
+
+ add_source_entry(source_type=SQLITE_CONNECTOR_TYPE, entry=sqlite_source_entry)
+ add_source_entry(source_type=POSTGRES_CONNECTOR_TYPE, entry=postgres_source_entry)
+ add_source_entry(source_type=SNOWFLAKE_CONNECTOR_TYPE, entry=snowflake_source_entry)

  add_destination_entry(destination_type=SQLITE_CONNECTOR_TYPE, entry=sqlite_destination_entry)
  add_destination_entry(destination_type=POSTGRES_CONNECTOR_TYPE, entry=postgres_destination_entry)
+ add_destination_entry(destination_type=SNOWFLAKE_CONNECTOR_TYPE, entry=snowflake_destination_entry)
+ add_destination_entry(
+     destination_type=SINGLESTORE_CONNECTOR_TYPE, entry=singlestore_destination_entry
+ )
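Registration in this __init__ is what makes a connector discoverable through the v2 registry by its connector type string. As a hedged illustration of the same pattern, a hypothetical connector module would be wired in like this (the mydb module and its entries do not exist in the package; they only show the shape):

from unstructured_ingest.v2.processes.connector_registry import (
    add_destination_entry,
    add_source_entry,
)

# Hypothetical module and entries, shown only to illustrate the registration pattern.
from .mydb import CONNECTOR_TYPE as MYDB_CONNECTOR_TYPE
from .mydb import mydb_destination_entry, mydb_source_entry

add_source_entry(source_type=MYDB_CONNECTOR_TYPE, entry=mydb_source_entry)
add_destination_entry(destination_type=MYDB_CONNECTOR_TYPE, entry=mydb_destination_entry)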
unstructured_ingest/v2/processes/connectors/sql/postgres.py
@@ -1,17 +1,17 @@
+ from contextlib import contextmanager
  from dataclasses import dataclass, field
- from pathlib import Path
- from typing import TYPE_CHECKING, Any, Optional
+ from typing import TYPE_CHECKING, Generator, Optional

- import numpy as np
- import pandas as pd
  from pydantic import Field, Secret

  from unstructured_ingest.utils.dep_check import requires_dependencies
  from unstructured_ingest.v2.interfaces import FileData
  from unstructured_ingest.v2.logger import logger
- from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+ from unstructured_ingest.v2.processes.connector_registry import (
+     DestinationRegistryEntry,
+     SourceRegistryEntry,
+ )
  from unstructured_ingest.v2.processes.connectors.sql.sql import (
-     _DATE_COLUMNS,
      SQLAccessConfig,
      SQLConnectionConfig,
      SQLDownloader,
@@ -22,11 +22,11 @@ from unstructured_ingest.v2.processes.connectors.sql.sql import (
      SQLUploaderConfig,
      SQLUploadStager,
      SQLUploadStagerConfig,
-     parse_date_string,
  )

  if TYPE_CHECKING:
      from psycopg2.extensions import connection as PostgresConnection
+     from psycopg2.extensions import cursor as PostgresCursor

  CONNECTOR_TYPE = "postgres"

@@ -48,18 +48,33 @@ class PostgresConnectionConfig(SQLConnectionConfig):
      port: Optional[int] = Field(default=5432, description="DB host connection port")
      connector_type: str = Field(default=CONNECTOR_TYPE, init=False)

+     @contextmanager
      @requires_dependencies(["psycopg2"], extras="postgres")
-     def get_connection(self) -> "PostgresConnection":
+     def get_connection(self) -> Generator["PostgresConnection", None, None]:
          from psycopg2 import connect

          access_config = self.access_config.get_secret_value()
-         return connect(
+         connection = connect(
              user=self.username,
              password=access_config.password,
              dbname=self.database,
              host=self.host,
              port=self.port,
          )
+         try:
+             yield connection
+         finally:
+             connection.commit()
+             connection.close()
+
+     @contextmanager
+     def get_cursor(self) -> Generator["PostgresCursor", None, None]:
+         with self.get_connection() as connection:
+             cursor = connection.cursor()
+             try:
+                 yield cursor
+             finally:
+                 cursor.close()


  class PostgresIndexerConfig(SQLIndexerConfig):
@@ -72,16 +87,6 @@ class PostgresIndexer(SQLIndexer):
      index_config: PostgresIndexerConfig
      connector_type: str = CONNECTOR_TYPE

-     def _get_doc_ids(self) -> list[str]:
-         connection = self.connection_config.get_connection()
-         with connection.cursor() as cursor:
-             cursor.execute(
-                 f"SELECT {self.index_config.id_column} FROM {self.index_config.table_name}"
-             )
-             results = cursor.fetchall()
-             ids = [result[0] for result in results]
-         return ids
-

  class PostgresDownloaderConfig(SQLDownloaderConfig):
      pass
@@ -97,8 +102,7 @@ class PostgresDownloader(SQLDownloader):
          table_name = file_data.additional_metadata["table_name"]
          id_column = file_data.additional_metadata["id_column"]
          ids = file_data.additional_metadata["ids"]
-         connection = self.connection_config.get_connection()
-         with connection.cursor() as cursor:
+         with self.connection_config.get_cursor() as cursor:
              fields = ",".join(self.download_config.fields) if self.download_config.fields else "*"
              query = "SELECT {fields} FROM {table_name} WHERE {id_column} in ({ids})".format(
                  fields=fields,
@@ -130,43 +134,16 @@ class PostgresUploader(SQLUploader):
      upload_config: PostgresUploaderConfig = field(default_factory=PostgresUploaderConfig)
      connection_config: PostgresConnectionConfig
      connector_type: str = CONNECTOR_TYPE
+     values_delimiter: str = "%s"

-     def prepare_data(
-         self, columns: list[str], data: tuple[tuple[Any, ...], ...]
-     ) -> list[tuple[Any, ...]]:
-         output = []
-         for row in data:
-             parsed = []
-             for column_name, value in zip(columns, row):
-                 if column_name in _DATE_COLUMNS:
-                     if value is None:
-                         parsed.append(None)
-                     else:
-                         parsed.append(parse_date_string(value))
-                 else:
-                     parsed.append(value)
-             output.append(tuple(parsed))
-         return output
-
-     def upload_contents(self, path: Path) -> None:
-         df = pd.read_json(path, orient="records", lines=True)
-         logger.debug(f"uploading {len(df)} entries to {self.connection_config.database} ")
-         df.replace({np.nan: None}, inplace=True)
-
-         columns = tuple(df.columns)
-         stmt = f"INSERT INTO {self.upload_config.table_name} ({','.join(columns)}) \
-             VALUES({','.join(['%s' for x in columns])})"  # noqa E501
-
-         for rows in pd.read_json(
-             path, orient="records", lines=True, chunksize=self.upload_config.batch_size
-         ):
-             with self.connection_config.get_connection() as conn:
-                 values = self.prepare_data(columns, tuple(rows.itertuples(index=False, name=None)))
-                 with conn.cursor() as cur:
-                     cur.executemany(stmt, values)
-
-                 conn.commit()

+ postgres_source_entry = SourceRegistryEntry(
+     connection_config=PostgresConnectionConfig,
+     indexer_config=PostgresIndexerConfig,
+     indexer=PostgresIndexer,
+     downloader_config=PostgresDownloaderConfig,
+     downloader=PostgresDownloader,
+ )


  postgres_destination_entry = DestinationRegistryEntry(
unstructured_ingest/v2/processes/connectors/sql/singlestore.py (new file)
@@ -0,0 +1,168 @@
+ import json
+ from contextlib import contextmanager
+ from dataclasses import dataclass, field
+ from typing import TYPE_CHECKING, Any, Generator, Optional
+
+ from pydantic import Field, Secret
+
+ from unstructured_ingest.v2.interfaces import FileData
+ from unstructured_ingest.v2.logger import logger
+ from unstructured_ingest.v2.processes.connector_registry import (
+     DestinationRegistryEntry,
+     SourceRegistryEntry,
+ )
+ from unstructured_ingest.v2.processes.connectors.sql.sql import (
+     _DATE_COLUMNS,
+     SQLAccessConfig,
+     SQLConnectionConfig,
+     SQLDownloader,
+     SQLDownloaderConfig,
+     SQLIndexer,
+     SQLIndexerConfig,
+     SQLUploader,
+     SQLUploaderConfig,
+     SQLUploadStager,
+     SQLUploadStagerConfig,
+     parse_date_string,
+ )
+
+ if TYPE_CHECKING:
+     from singlestoredb.connection import Connection as SingleStoreConnection
+     from singlestoredb.connection import Cursor as SingleStoreCursor
+
+ CONNECTOR_TYPE = "singlestore"
+
+
+ class SingleStoreAccessConfig(SQLAccessConfig):
+     password: Optional[str] = Field(default=None, description="SingleStore password")
+
+
+ class SingleStoreConnectionConfig(SQLConnectionConfig):
+     access_config: Secret[SingleStoreAccessConfig]
+     host: Optional[str] = Field(default=None, description="SingleStore host")
+     port: Optional[int] = Field(default=None, description="SingleStore port")
+     user: Optional[str] = Field(default=None, description="SingleStore user")
+     database: Optional[str] = Field(default=None, description="SingleStore database")
+
+     @contextmanager
+     def get_connection(self) -> Generator["SingleStoreConnection", None, None]:
+         import singlestoredb as s2
+
+         connection = s2.connect(
+             host=self.host,
+             port=self.port,
+             database=self.database,
+             user=self.user,
+             password=self.access_config.get_secret_value().password,
+         )
+         try:
+             yield connection
+         finally:
+             connection.commit()
+             connection.close()
+
+     @contextmanager
+     def get_cursor(self) -> Generator["SingleStoreCursor", None, None]:
+         with self.get_connection() as connection:
+             with connection.cursor() as cursor:
+                 try:
+                     yield cursor
+                 finally:
+                     cursor.close()
+
+
+ class SingleStoreIndexerConfig(SQLIndexerConfig):
+     pass
+
+
+ @dataclass
+ class SingleStoreIndexer(SQLIndexer):
+     connection_config: SingleStoreConnectionConfig
+     index_config: SingleStoreIndexerConfig
+     connector_type: str = CONNECTOR_TYPE
+
+
+ class SingleStoreDownloaderConfig(SQLDownloaderConfig):
+     pass
+
+
+ @dataclass
+ class SingleStoreDownloader(SQLDownloader):
+     connection_config: SingleStoreConnectionConfig
+     download_config: SingleStoreDownloaderConfig
+     connector_type: str = CONNECTOR_TYPE
+
+     def query_db(self, file_data: FileData) -> tuple[list[tuple], list[str]]:
+         table_name = file_data.additional_metadata["table_name"]
+         id_column = file_data.additional_metadata["id_column"]
+         ids = file_data.additional_metadata["ids"]
+         with self.connection_config.get_connection() as sqlite_connection:
+             cursor = sqlite_connection.cursor()
+             fields = ",".join(self.download_config.fields) if self.download_config.fields else "*"
+             query = "SELECT {fields} FROM {table_name} WHERE {id_column} in ({ids})".format(
+                 fields=fields,
+                 table_name=table_name,
+                 id_column=id_column,
+                 ids=",".join([str(i) for i in ids]),
+             )
+             logger.debug(f"running query: {query}")
+             cursor.execute(query)
+             rows = cursor.fetchall()
+             columns = [col[0] for col in cursor.description]
+             return rows, columns
+
+
+ class SingleStoreUploadStagerConfig(SQLUploadStagerConfig):
+     pass
+
+
+ class SingleStoreUploadStager(SQLUploadStager):
+     upload_stager_config: SingleStoreUploadStagerConfig
+
+
+ class SingleStoreUploaderConfig(SQLUploaderConfig):
+     pass
+
+
+ @dataclass
+ class SingleStoreUploader(SQLUploader):
+     upload_config: SingleStoreUploaderConfig = field(default_factory=SingleStoreUploaderConfig)
+     connection_config: SingleStoreConnectionConfig
+     values_delimiter: str = "%s"
+     connector_type: str = CONNECTOR_TYPE
+
+     def prepare_data(
+         self, columns: list[str], data: tuple[tuple[Any, ...], ...]
+     ) -> list[tuple[Any, ...]]:
+         output = []
+         for row in data:
+             parsed = []
+             for column_name, value in zip(columns, row):
+                 if isinstance(value, (list, dict)):
+                     value = json.dumps(value)
+                 if column_name in _DATE_COLUMNS:
+                     if value is None:
+                         parsed.append(None)
+                     else:
+                         parsed.append(parse_date_string(value))
+                 else:
+                     parsed.append(value)
+             output.append(tuple(parsed))
+         return output
+
+
+ singlestore_source_entry = SourceRegistryEntry(
+     connection_config=SingleStoreConnectionConfig,
+     indexer_config=SingleStoreIndexerConfig,
+     indexer=SQLIndexer,
+     downloader_config=SingleStoreDownloaderConfig,
+     downloader=SingleStoreDownloader,
+ )
+
+ singlestore_destination_entry = DestinationRegistryEntry(
+     connection_config=SingleStoreConnectionConfig,
+     uploader=SingleStoreUploader,
+     uploader_config=SingleStoreUploaderConfig,
+     upload_stager=SingleStoreUploadStager,
+     upload_stager_config=SingleStoreUploadStagerConfig,
+ )
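Similarly, a minimal sketch of the SingleStore connection config in use, built only from names defined in the new module above; all connection values are placeholders and singlestoredb must be installed.

from unstructured_ingest.v2.processes.connectors.sql.singlestore import (
    SingleStoreAccessConfig,
    SingleStoreConnectionConfig,
)

connection_config = SingleStoreConnectionConfig(
    access_config=SingleStoreAccessConfig(password="placeholder"),
    host="localhost",
    port=3306,
    user="root",
    database="ingest_test",
)

# get_cursor() nests the connection and cursor context managers defined above, so the
# connection is committed and closed, and the cursor closed, once the block exits.
with connection_config.get_cursor() as cursor:
    cursor.execute("SHOW TABLES")
    print(cursor.fetchall())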