unstructured-ingest 0.3.11__py3-none-any.whl → 0.3.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of unstructured-ingest might be problematic. Click here for more details.
- test/integration/connectors/test_milvus.py +13 -0
- test/integration/connectors/test_onedrive.py +6 -0
- test/integration/connectors/test_redis.py +119 -0
- test/integration/connectors/test_vectara.py +270 -0
- test/integration/embedders/test_bedrock.py +28 -0
- test/integration/embedders/test_octoai.py +14 -0
- test/integration/embedders/test_openai.py +13 -0
- test/integration/embedders/test_togetherai.py +10 -0
- test/integration/partitioners/test_partitioner.py +2 -2
- test/unit/embed/test_octoai.py +8 -1
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/embed/bedrock.py +39 -11
- unstructured_ingest/embed/interfaces.py +5 -0
- unstructured_ingest/embed/octoai.py +44 -3
- unstructured_ingest/embed/openai.py +37 -1
- unstructured_ingest/embed/togetherai.py +28 -1
- unstructured_ingest/embed/voyageai.py +33 -1
- unstructured_ingest/v2/errors.py +18 -0
- unstructured_ingest/v2/processes/connectors/__init__.py +7 -0
- unstructured_ingest/v2/processes/connectors/chroma.py +0 -1
- unstructured_ingest/v2/processes/connectors/kafka/cloud.py +5 -2
- unstructured_ingest/v2/processes/connectors/kafka/kafka.py +14 -3
- unstructured_ingest/v2/processes/connectors/milvus.py +15 -6
- unstructured_ingest/v2/processes/connectors/neo4j.py +2 -0
- unstructured_ingest/v2/processes/connectors/onedrive.py +79 -25
- unstructured_ingest/v2/processes/connectors/qdrant/qdrant.py +0 -1
- unstructured_ingest/v2/processes/connectors/redisdb.py +182 -0
- unstructured_ingest/v2/processes/connectors/vectara.py +350 -0
- unstructured_ingest/v2/unstructured_api.py +25 -2
- {unstructured_ingest-0.3.11.dist-info → unstructured_ingest-0.3.12.dist-info}/METADATA +23 -19
- {unstructured_ingest-0.3.11.dist-info → unstructured_ingest-0.3.12.dist-info}/RECORD +35 -31
- test/integration/connectors/test_kafka.py +0 -304
- {unstructured_ingest-0.3.11.dist-info → unstructured_ingest-0.3.12.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.3.11.dist-info → unstructured_ingest-0.3.12.dist-info}/WHEEL +0 -0
- {unstructured_ingest-0.3.11.dist-info → unstructured_ingest-0.3.12.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.3.11.dist-info → unstructured_ingest-0.3.12.dist-info}/top_level.txt +0 -0
|
@@ -4,7 +4,15 @@ from typing import TYPE_CHECKING
|
|
|
4
4
|
from pydantic import Field, SecretStr
|
|
5
5
|
|
|
6
6
|
from unstructured_ingest.embed.interfaces import BaseEmbeddingEncoder, EmbeddingConfig
|
|
7
|
+
from unstructured_ingest.logger import logger
|
|
7
8
|
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
9
|
+
from unstructured_ingest.v2.errors import (
|
|
10
|
+
ProviderError,
|
|
11
|
+
QuotaError,
|
|
12
|
+
RateLimitError,
|
|
13
|
+
UserAuthError,
|
|
14
|
+
UserError,
|
|
15
|
+
)
|
|
8
16
|
|
|
9
17
|
if TYPE_CHECKING:
|
|
10
18
|
from openai import OpenAI
|
|
@@ -30,12 +38,45 @@ class OctoAiEmbeddingConfig(EmbeddingConfig):
|
|
|
30
38
|
class OctoAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
31
39
|
config: OctoAiEmbeddingConfig
|
|
32
40
|
|
|
41
|
+
def wrap_error(self, e: Exception) -> Exception:
    """Map an exception from the OpenAI-compatible client onto an ingest error.

    Exceptions that are not ``APIStatusError`` are logged and re-raised as-is.
    For status errors the HTTP status code selects the result:

    * 401 -> ``UserAuthError``
    * 429 -> ``QuotaError`` when the error code is ``"insufficient_quota"``,
      otherwise ``RateLimitError``
    * any other 4xx -> ``UserError``
    * 5xx -> ``ProviderError``

    Any other status is logged and the original exception is returned.
    """
    # https://platform.openai.com/docs/guides/error-codes/api-errors
    from openai import APIStatusError

    if not isinstance(e, APIStatusError):
        logger.error(f"unhandled exception from openai: {e}", exc_info=True)
        raise e

    status = e.status_code
    if status >= 500:
        return ProviderError(e.message)
    if status < 400:
        logger.error(f"unhandled exception from openai: {e}", exc_info=True)
        return e

    # Everything below is a 4xx, i.e. a user-side error.
    if status == 401:
        return UserAuthError(e.message)
    if status != 429:
        return UserError(e.message)
    # 429 covers both "rate limit exceeded" and "quota exceeded".
    if e.code == "insufficient_quota":
        return QuotaError(e.message)
    return RateLimitError(e.message)
|
|
64
|
+
|
|
33
65
|
def embed_query(self, query: str):
    """Embed one query string.

    Builds a client from the encoder config, requests embeddings for ``query``
    with the configured model name, and returns the embedding of the first
    entry in the response. Any exception raised while creating the client or
    calling the API is passed through ``wrap_error`` and the result is raised.
    """
    try:
        api = self.config.get_client()
        result = api.embeddings.create(input=query, model=self.config.embedder_model_name)
    except Exception as err:
        raise self.wrap_error(e=err)
    first = result.data[0]
    return first.embedding
|
|
37
72
|
|
|
38
73
|
def embed_documents(self, elements: list[dict]) -> list[dict]:
    """Embed the ``"text"`` of every element and attach the vectors.

    Elements without a ``"text"`` key contribute an empty string. All texts
    are sent in a single embeddings request; exceptions from client creation
    or the request are passed through ``wrap_error`` and the result is raised.
    The vectors are handed, together with the original elements, to
    ``_add_embeddings_to_elements`` and its return value is returned.
    """
    payload = [element.get("text", "") for element in elements]
    try:
        api = self.config.get_client()
        result = api.embeddings.create(input=payload, model=self.config.embedder_model_name)
    except Exception as err:
        raise self.wrap_error(e=err)
    vectors = [entry.embedding for entry in result.data]
    return self._add_embeddings_to_elements(elements, vectors)
|
|
@@ -4,7 +4,15 @@ from typing import TYPE_CHECKING
|
|
|
4
4
|
from pydantic import Field, SecretStr
|
|
5
5
|
|
|
6
6
|
from unstructured_ingest.embed.interfaces import BaseEmbeddingEncoder, EmbeddingConfig
|
|
7
|
+
from unstructured_ingest.logger import logger
|
|
7
8
|
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
9
|
+
from unstructured_ingest.v2.errors import (
|
|
10
|
+
ProviderError,
|
|
11
|
+
QuotaError,
|
|
12
|
+
RateLimitError,
|
|
13
|
+
UserAuthError,
|
|
14
|
+
UserError,
|
|
15
|
+
)
|
|
8
16
|
|
|
9
17
|
if TYPE_CHECKING:
|
|
10
18
|
from openai import OpenAI
|
|
@@ -25,9 +33,37 @@ class OpenAIEmbeddingConfig(EmbeddingConfig):
|
|
|
25
33
|
class OpenAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
26
34
|
config: OpenAIEmbeddingConfig
|
|
27
35
|
|
|
36
|
+
def wrap_error(self, e: Exception) -> Exception:
    """Map an exception from the OpenAI client onto an ingest error.

    Exceptions that are not ``APIStatusError`` are logged and re-raised as-is.
    For status errors the HTTP status code selects the result:

    * 401 -> ``UserAuthError``
    * 429 -> ``QuotaError`` when the error code is ``"insufficient_quota"``,
      otherwise ``RateLimitError``
    * any other 4xx -> ``UserError``
    * 5xx -> ``ProviderError``

    Any other status is logged and the original exception is returned.
    """
    # https://platform.openai.com/docs/guides/error-codes/api-errors
    from openai import APIStatusError

    if not isinstance(e, APIStatusError):
        logger.error(f"unhandled exception from openai: {e}", exc_info=True)
        raise e

    status = e.status_code
    if status >= 500:
        return ProviderError(e.message)
    if status < 400:
        logger.error(f"unhandled exception from openai: {e}", exc_info=True)
        return e

    # Everything below is a 4xx, i.e. a user-side error.
    if status == 401:
        return UserAuthError(e.message)
    if status != 429:
        return UserError(e.message)
    # 429 covers both "rate limit exceeded" and "quota exceeded".
    if e.code == "insufficient_quota":
        return QuotaError(e.message)
    return RateLimitError(e.message)
|
|
59
|
+
|
|
28
60
|
def embed_query(self, query: str) -> list[float]:
    """Embed one query string and return the first embedding in the response.

    The client is created outside the guarded section, so only exceptions
    from the embeddings request itself are passed through ``wrap_error``.
    """
    api = self.config.get_client()
    try:
        result = api.embeddings.create(input=query, model=self.config.embedder_model_name)
    except Exception as err:
        raise self.wrap_error(e=err)
    first = result.data[0]
    return first.embedding
|
|
32
68
|
|
|
33
69
|
def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
@@ -4,7 +4,15 @@ from typing import TYPE_CHECKING
|
|
|
4
4
|
from pydantic import Field, SecretStr
|
|
5
5
|
|
|
6
6
|
from unstructured_ingest.embed.interfaces import BaseEmbeddingEncoder, EmbeddingConfig
|
|
7
|
+
from unstructured_ingest.logger import logger
|
|
7
8
|
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
9
|
+
from unstructured_ingest.v2.errors import (
|
|
10
|
+
RateLimitError as CustomRateLimitError,
|
|
11
|
+
)
|
|
12
|
+
from unstructured_ingest.v2.errors import (
|
|
13
|
+
UserAuthError,
|
|
14
|
+
UserError,
|
|
15
|
+
)
|
|
8
16
|
|
|
9
17
|
if TYPE_CHECKING:
|
|
10
18
|
from together import Together
|
|
@@ -27,6 +35,20 @@ class TogetherAIEmbeddingConfig(EmbeddingConfig):
|
|
|
27
35
|
class TogetherAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
28
36
|
config: TogetherAIEmbeddingConfig
|
|
29
37
|
|
|
38
|
+
def wrap_error(self, e: Exception) -> Exception:
    """Map an exception from the Together client onto an ingest error.

    Exceptions that are not ``TogetherException`` are logged and returned
    unchanged. Otherwise the first exception argument is used as the message:
    ``AuthenticationError`` -> ``UserAuthError``, ``RateLimitError`` ->
    ``CustomRateLimitError`` (the ingest rate-limit error, aliased on import
    to avoid clashing with Together's class), anything else -> ``UserError``.
    """
    # https://docs.together.ai/docs/error-codes
    from together.error import AuthenticationError, RateLimitError, TogetherException

    if not isinstance(e, TogetherException):
        # The message previously said "openai" (copy-paste from the OpenAI encoder).
        logger.error(f"unhandled exception from togetherai: {e}", exc_info=True)
        return e
    message = e.args[0]
    if isinstance(e, AuthenticationError):
        return UserAuthError(message)
    if isinstance(e, RateLimitError):
        return CustomRateLimitError(message)
    return UserError(message)
|
|
51
|
+
|
|
30
52
|
def embed_query(self, query: str) -> list[float]:
    """Embed a single query by sending it as a one-element batch."""
    vectors = self._embed_documents(elements=[query])
    return vectors[0]
|
|
32
54
|
|
|
@@ -36,5 +58,10 @@ class TogetherAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
36
58
|
|
|
37
59
|
def _embed_documents(self, elements: list[str]) -> list[list[float]]:
|
|
38
60
|
client = self.config.get_client()
|
|
39
|
-
|
|
61
|
+
try:
|
|
62
|
+
outputs = client.embeddings.create(
|
|
63
|
+
model=self.config.embedder_model_name, input=elements
|
|
64
|
+
)
|
|
65
|
+
except Exception as e:
|
|
66
|
+
raise self.wrap_error(e=e)
|
|
40
67
|
return [outputs.data[i].embedding for i in range(len(elements))]
|
|
@@ -4,7 +4,16 @@ from typing import TYPE_CHECKING, Optional
|
|
|
4
4
|
from pydantic import Field, SecretStr
|
|
5
5
|
|
|
6
6
|
from unstructured_ingest.embed.interfaces import BaseEmbeddingEncoder, EmbeddingConfig
|
|
7
|
+
from unstructured_ingest.logger import logger
|
|
7
8
|
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
9
|
+
from unstructured_ingest.v2.errors import (
|
|
10
|
+
ProviderError,
|
|
11
|
+
UserAuthError,
|
|
12
|
+
UserError,
|
|
13
|
+
)
|
|
14
|
+
from unstructured_ingest.v2.errors import (
|
|
15
|
+
RateLimitError as CustomRateLimitError,
|
|
16
|
+
)
|
|
8
17
|
|
|
9
18
|
if TYPE_CHECKING:
|
|
10
19
|
from voyageai import Client as VoyageAIClient
|
|
@@ -38,9 +47,32 @@ class VoyageAIEmbeddingConfig(EmbeddingConfig):
|
|
|
38
47
|
class VoyageAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
39
48
|
config: VoyageAIEmbeddingConfig
|
|
40
49
|
|
|
50
|
+
def wrap_error(self, e: Exception) -> Exception:
    """Map an exception from the VoyageAI client onto an ingest error.

    Exceptions that are not ``VoyageError`` are logged and re-raised as-is.
    Otherwise: ``AuthenticationError`` -> ``UserAuthError``;
    ``RateLimitError`` -> ``CustomRateLimitError`` (the ingest rate-limit
    error, aliased on import to avoid clashing with VoyageAI's class); any
    other error with a 4xx ``http_status`` -> ``UserError``; 5xx ->
    ``ProviderError``. Errors with no usable status are logged and returned
    unchanged.
    """
    # https://docs.voyageai.com/docs/error-codes
    from voyageai.error import AuthenticationError, RateLimitError, VoyageError

    if not isinstance(e, VoyageError):
        # The message previously said "openai" (copy-paste from the OpenAI encoder).
        logger.error(f"unhandled exception from voyageai: {e}", exc_info=True)
        raise e
    http_code = e.http_status
    message = e.user_message
    if isinstance(e, AuthenticationError):
        return UserAuthError(message)
    if isinstance(e, RateLimitError):
        return CustomRateLimitError(message)
    # NOTE(review): http_status appears to be optional on VoyageError (e.g. for
    # connection/timeout errors) - confirm. Comparing None against ints would
    # raise TypeError and hide the real failure, so guard before the checks.
    if http_code is not None:
        if 400 <= http_code < 500:
            return UserError(message)
        if http_code >= 500:
            return ProviderError(message)
    logger.error(f"unhandled exception from voyageai: {e}", exc_info=True)
    return e
|
|
69
|
+
|
|
41
70
|
def _embed_documents(self, elements: list[str]) -> list[list[float]]:
|
|
42
71
|
client: VoyageAIClient = self.config.get_client()
|
|
43
|
-
|
|
72
|
+
try:
|
|
73
|
+
response = client.embed(texts=elements, model=self.config.embedder_model_name)
|
|
74
|
+
except Exception as e:
|
|
75
|
+
self.wrap_error(e=e)
|
|
44
76
|
return response.embeddings
|
|
45
77
|
|
|
46
78
|
def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
@@ -48,12 +48,16 @@ from .outlook import CONNECTOR_TYPE as OUTLOOK_CONNECTOR_TYPE
|
|
|
48
48
|
from .outlook import outlook_source_entry
|
|
49
49
|
from .pinecone import CONNECTOR_TYPE as PINECONE_CONNECTOR_TYPE
|
|
50
50
|
from .pinecone import pinecone_destination_entry
|
|
51
|
+
from .redisdb import CONNECTOR_TYPE as REDIS_CONNECTOR_TYPE
|
|
52
|
+
from .redisdb import redis_destination_entry
|
|
51
53
|
from .salesforce import CONNECTOR_TYPE as SALESFORCE_CONNECTOR_TYPE
|
|
52
54
|
from .salesforce import salesforce_source_entry
|
|
53
55
|
from .sharepoint import CONNECTOR_TYPE as SHAREPOINT_CONNECTOR_TYPE
|
|
54
56
|
from .sharepoint import sharepoint_source_entry
|
|
55
57
|
from .slack import CONNECTOR_TYPE as SLACK_CONNECTOR_TYPE
|
|
56
58
|
from .slack import slack_source_entry
|
|
59
|
+
from .vectara import CONNECTOR_TYPE as VECTARA_CONNECTOR_TYPE
|
|
60
|
+
from .vectara import vectara_destination_entry
|
|
57
61
|
|
|
58
62
|
add_source_entry(source_type=ASTRA_DB_CONNECTOR_TYPE, entry=astra_db_source_entry)
|
|
59
63
|
add_destination_entry(destination_type=ASTRA_DB_CONNECTOR_TYPE, entry=astra_db_destination_entry)
|
|
@@ -101,4 +105,7 @@ add_source_entry(source_type=GITLAB_CONNECTOR_TYPE, entry=gitlab_source_entry)
|
|
|
101
105
|
|
|
102
106
|
add_source_entry(source_type=SLACK_CONNECTOR_TYPE, entry=slack_source_entry)
|
|
103
107
|
|
|
108
|
+
add_destination_entry(destination_type=VECTARA_CONNECTOR_TYPE, entry=vectara_destination_entry)
|
|
104
109
|
add_source_entry(source_type=CONFLUENCE_CONNECTOR_TYPE, entry=confluence_source_entry)
|
|
110
|
+
|
|
111
|
+
add_destination_entry(destination_type=REDIS_CONNECTOR_TYPE, entry=redis_destination_entry)
|
|
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
|
|
|
4
4
|
|
|
5
5
|
from pydantic import Field, Secret, SecretStr
|
|
6
6
|
|
|
7
|
+
from unstructured_ingest.v2.logger import logger
|
|
7
8
|
from unstructured_ingest.v2.processes.connector_registry import (
|
|
8
9
|
DestinationRegistryEntry,
|
|
9
10
|
SourceRegistryEntry,
|
|
@@ -50,6 +51,7 @@ class CloudKafkaConnectionConfig(KafkaConnectionConfig):
|
|
|
50
51
|
"sasl.password": access_config.secret.get_secret_value(),
|
|
51
52
|
"sasl.mechanism": "PLAIN",
|
|
52
53
|
"security.protocol": "SASL_SSL",
|
|
54
|
+
"logger": logger,
|
|
53
55
|
}
|
|
54
56
|
|
|
55
57
|
return conf
|
|
@@ -61,10 +63,11 @@ class CloudKafkaConnectionConfig(KafkaConnectionConfig):
|
|
|
61
63
|
|
|
62
64
|
conf = {
|
|
63
65
|
"bootstrap.servers": f"{bootstrap}:{port}",
|
|
64
|
-
"sasl.username": access_config.kafka_api_key,
|
|
65
|
-
"sasl.password": access_config.secret,
|
|
66
|
+
"sasl.username": access_config.kafka_api_key.get_secret_value(),
|
|
67
|
+
"sasl.password": access_config.secret.get_secret_value(),
|
|
66
68
|
"sasl.mechanism": "PLAIN",
|
|
67
69
|
"security.protocol": "SASL_SSL",
|
|
70
|
+
"logger": logger,
|
|
68
71
|
}
|
|
69
72
|
|
|
70
73
|
return conf
|
|
@@ -170,7 +170,7 @@ class KafkaIndexer(Indexer, ABC):
|
|
|
170
170
|
]
|
|
171
171
|
if self.index_config.topic not in current_topics:
|
|
172
172
|
raise SourceConnectionError(
|
|
173
|
-
"expected topic {} not detected in cluster: {}".format(
|
|
173
|
+
"expected topic '{}' not detected in cluster: '{}'".format(
|
|
174
174
|
self.index_config.topic, ", ".join(current_topics)
|
|
175
175
|
)
|
|
176
176
|
)
|
|
@@ -232,6 +232,13 @@ class KafkaUploader(Uploader, ABC):
|
|
|
232
232
|
topic for topic in cluster_meta.topics if topic != "__consumer_offsets"
|
|
233
233
|
]
|
|
234
234
|
logger.info(f"successfully checked available topics: {current_topics}")
|
|
235
|
+
if self.upload_config.topic not in current_topics:
|
|
236
|
+
raise DestinationConnectionError(
|
|
237
|
+
"expected topic '{}' not detected in cluster: '{}'".format(
|
|
238
|
+
self.upload_config.topic, ", ".join(current_topics)
|
|
239
|
+
)
|
|
240
|
+
)
|
|
241
|
+
|
|
235
242
|
except Exception as e:
|
|
236
243
|
logger.error(f"failed to validate connection: {e}", exc_info=True)
|
|
237
244
|
raise DestinationConnectionError(f"failed to validate connection: {e}")
|
|
@@ -243,8 +250,10 @@ class KafkaUploader(Uploader, ABC):
|
|
|
243
250
|
failed_producer = False
|
|
244
251
|
|
|
245
252
|
def acked(err, msg):
|
|
253
|
+
nonlocal failed_producer
|
|
246
254
|
if err is not None:
|
|
247
|
-
|
|
255
|
+
failed_producer = True
|
|
256
|
+
logger.error("Failed to deliver kafka message: %s: %s" % (str(msg), str(err)))
|
|
248
257
|
|
|
249
258
|
for element in elements:
|
|
250
259
|
producer.produce(
|
|
@@ -253,7 +262,9 @@ class KafkaUploader(Uploader, ABC):
|
|
|
253
262
|
callback=acked,
|
|
254
263
|
)
|
|
255
264
|
|
|
256
|
-
producer
|
|
265
|
+
while producer_len := len(producer):
|
|
266
|
+
logger.debug(f"another iteration of kafka producer flush. Queue length: {producer_len}")
|
|
267
|
+
producer.flush(timeout=self.upload_config.timeout)
|
|
257
268
|
if failed_producer:
|
|
258
269
|
raise KafkaException("failed to produce all messages in batch")
|
|
259
270
|
|
|
@@ -156,11 +156,18 @@ class MilvusUploader(Uploader):
|
|
|
156
156
|
|
|
157
157
|
@DestinationConnectionError.wrap
def precheck(self):
    """Verify that the configured Milvus collection exists.

    Opens a client via ``get_client`` and asks whether the collection named
    in the upload config exists.

    Raises:
        DestinationConnectionError: if the collection does not exist, or if
            the Milvus client raises ``MilvusException`` (chained as cause).
    """
    from pymilvus import MilvusException

    try:
        with self.get_client() as milvus_client:
            collection_exists = milvus_client.has_collection(self.upload_config.collection_name)
            if collection_exists:
                return
            raise DestinationConnectionError(
                f"Collection '{self.upload_config.collection_name}' does not exist"
            )
    except MilvusException as milvus_exception:
        raise DestinationConnectionError(
            f"failed to precheck Milvus: {str(milvus_exception.message)}"
        ) from milvus_exception
|
|
164
171
|
|
|
165
172
|
@contextmanager
|
|
166
173
|
def get_client(self) -> Generator["MilvusClient", None, None]:
|
|
@@ -197,7 +204,9 @@ class MilvusUploader(Uploader):
|
|
|
197
204
|
try:
|
|
198
205
|
res = client.insert(collection_name=self.upload_config.collection_name, data=data)
|
|
199
206
|
except MilvusException as milvus_exception:
|
|
200
|
-
raise WriteError(
|
|
207
|
+
raise WriteError(
|
|
208
|
+
f"failed to upload records to Milvus: {str(milvus_exception.message)}"
|
|
209
|
+
) from milvus_exception
|
|
201
210
|
if "err_count" in res and isinstance(res["err_count"], int) and res["err_count"] > 0:
|
|
202
211
|
err_count = res["err_count"]
|
|
203
212
|
raise WriteError(f"failed to upload {err_count} docs")
|
|
@@ -378,6 +378,8 @@ class Neo4jUploader(Uploader):
|
|
|
378
378
|
|
|
379
379
|
neo4j_destination_entry = DestinationRegistryEntry(
|
|
380
380
|
connection_config=Neo4jConnectionConfig,
|
|
381
|
+
upload_stager=Neo4jUploadStager,
|
|
382
|
+
upload_stager_config=Neo4jUploadStagerConfig,
|
|
381
383
|
uploader=Neo4jUploader,
|
|
382
384
|
uploader_config=Neo4jUploaderConfig,
|
|
383
385
|
)
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import json
|
|
4
5
|
from dataclasses import dataclass
|
|
5
6
|
from pathlib import Path
|
|
6
7
|
from time import time
|
|
7
|
-
from typing import TYPE_CHECKING, Any, Generator, Optional
|
|
8
|
+
from typing import TYPE_CHECKING, Any, AsyncIterator, Generator, Iterator, Optional, TypeVar
|
|
8
9
|
|
|
9
10
|
from dateutil import parser
|
|
10
11
|
from pydantic import Field, Secret
|
|
@@ -100,6 +101,27 @@ class OnedriveIndexerConfig(IndexerConfig):
|
|
|
100
101
|
recursive: bool = False
|
|
101
102
|
|
|
102
103
|
|
|
104
|
+
T = TypeVar("T")


def async_iterable_to_sync_iterable(iterator: AsyncIterator[T]) -> Iterator[T]:
    """Drive an async iterator from synchronous code, yielding items lazily.

    A new event loop is created and installed with ``asyncio.set_event_loop``;
    each item is produced by running ``iterator.__anext__()`` to completion on
    that loop. When iteration ends - by exhaustion, by an exception, or
    because the consumer stopped early and closed this generator - the async
    iterator is closed (if it exposes ``aclose``), remaining async generators
    are shut down, and the loop is closed.

    Args:
        iterator: the async iterator to consume.

    Yields:
        The items of ``iterator`` in order.
    """
    # This version works on Python 3.9 by manually handling the async iteration.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        while True:
            try:
                # Instead of anext(iterator), we directly call __anext__().
                # __anext__ returns a coroutine that we must run until complete.
                result = loop.run_until_complete(iterator.__anext__())
            except StopAsyncIteration:
                break
            yield result
    finally:
        try:
            # Closing the loop without closing the async generator first means
            # its pending ``finally`` blocks / context managers never run when
            # the consumer stops early. aclose() is a no-op once exhausted.
            aclose = getattr(iterator, "aclose", None)
            if aclose is not None:
                loop.run_until_complete(aclose())
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            loop.close()
|
|
123
|
+
|
|
124
|
+
|
|
103
125
|
@dataclass
|
|
104
126
|
class OnedriveIndexer(Indexer):
|
|
105
127
|
connection_config: OnedriveConnectionConfig
|
|
@@ -116,17 +138,21 @@ class OnedriveIndexer(Indexer):
|
|
|
116
138
|
logger.error(f"failed to validate connection: {e}", exc_info=True)
|
|
117
139
|
raise SourceConnectionError(f"failed to validate connection: {e}")
|
|
118
140
|
|
|
119
|
-
def
|
|
141
|
+
def list_objects_sync(self, folder: DriveItem, recursive: bool) -> list["DriveItem"]:
    """List the files under ``folder``, optionally descending into subfolders.

    Fetches the folder's children with one query. Items flagged ``is_file``
    are collected; when ``recursive`` is true, each item flagged ``is_folder``
    is listed the same way and its files are appended after the direct ones.
    """
    children = folder.children.get().execute_query()
    collected = [child for child in children if child.is_file]
    if recursive:
        subfolders = [child for child in children if child.is_folder]
        for subfolder in subfolders:
            collected += self.list_objects_sync(subfolder, recursive)
    return collected
|
|
128
151
|
|
|
129
|
-
def
|
|
152
|
+
async def list_objects(self, folder: "DriveItem", recursive: bool) -> list["DriveItem"]:
    """Async wrapper: run ``list_objects_sync`` in a worker thread and await it."""
    listing = await asyncio.to_thread(self.list_objects_sync, folder, recursive)
    return listing
|
|
154
|
+
|
|
155
|
+
def get_root_sync(self, client: "GraphClient") -> "DriveItem":
|
|
130
156
|
root = client.users[self.connection_config.user_pname].drive.get().execute_query().root
|
|
131
157
|
if fpath := self.index_config.path:
|
|
132
158
|
root = root.get_by_path(fpath).get().execute_query()
|
|
@@ -134,7 +160,10 @@ class OnedriveIndexer(Indexer):
|
|
|
134
160
|
raise ValueError(f"Unable to find directory, given: {fpath}")
|
|
135
161
|
return root
|
|
136
162
|
|
|
137
|
-
def
|
|
163
|
+
async def get_root(self, client: "GraphClient") -> "DriveItem":
    """Async wrapper: run ``get_root_sync`` in a worker thread and await it."""
    root_item = await asyncio.to_thread(self.get_root_sync, client)
    return root_item
|
|
165
|
+
|
|
166
|
+
def get_properties_sync(self, drive_item: "DriveItem") -> dict:
|
|
138
167
|
properties = drive_item.properties
|
|
139
168
|
filtered_properties = {}
|
|
140
169
|
for k, v in properties.items():
|
|
@@ -145,7 +174,10 @@ class OnedriveIndexer(Indexer):
|
|
|
145
174
|
pass
|
|
146
175
|
return filtered_properties
|
|
147
176
|
|
|
148
|
-
def
|
|
177
|
+
async def get_properties(self, drive_item: "DriveItem") -> dict:
    """Async wrapper: run ``get_properties_sync`` in a worker thread and await it."""
    props = await asyncio.to_thread(self.get_properties_sync, drive_item)
    return props
|
|
179
|
+
|
|
180
|
+
def drive_item_to_file_data_sync(self, drive_item: "DriveItem") -> FileData:
|
|
149
181
|
file_path = drive_item.parent_reference.path.split(":")[-1]
|
|
150
182
|
file_path = file_path[1:] if file_path and file_path[0] == "/" else file_path
|
|
151
183
|
filename = drive_item.name
|
|
@@ -176,17 +208,34 @@ class OnedriveIndexer(Indexer):
|
|
|
176
208
|
"server_relative_path": server_path,
|
|
177
209
|
},
|
|
178
210
|
),
|
|
179
|
-
additional_metadata=self.
|
|
211
|
+
additional_metadata=self.get_properties_sync(drive_item=drive_item),
|
|
180
212
|
)
|
|
181
213
|
|
|
182
|
-
def
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
214
|
+
async def drive_item_to_file_data(self, drive_item: "DriveItem") -> FileData:
    """Async wrapper: run ``drive_item_to_file_data_sync`` in a worker thread."""
    # Offload the synchronous conversion so it does not run on the event loop.
    converted = await asyncio.to_thread(self.drive_item_to_file_data_sync, drive_item)
    return converted
|
|
217
|
+
|
|
218
|
+
async def _run_async(self, **kwargs: Any) -> AsyncIterator[FileData]:
|
|
219
|
+
token_resp = await asyncio.to_thread(self.connection_config.get_token)
|
|
220
|
+
if "error" in token_resp:
|
|
221
|
+
raise SourceConnectionError(
|
|
222
|
+
f"[{CONNECTOR_TYPE}]: {token_resp['error']} ({token_resp.get('error_description')})"
|
|
223
|
+
)
|
|
224
|
+
|
|
225
|
+
client = await asyncio.to_thread(self.connection_config.get_client)
|
|
226
|
+
root = await self.get_root(client=client)
|
|
227
|
+
drive_items = await self.list_objects(folder=root, recursive=self.index_config.recursive)
|
|
228
|
+
|
|
186
229
|
for drive_item in drive_items:
|
|
187
|
-
file_data = self.drive_item_to_file_data(drive_item=drive_item)
|
|
230
|
+
file_data = await self.drive_item_to_file_data(drive_item=drive_item)
|
|
188
231
|
yield file_data
|
|
189
232
|
|
|
233
|
+
def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
    """Synchronously yield the items produced by ``_run_async``.

    The async generator is bridged through
    ``async_iterable_to_sync_iterable``, so items are produced one at a time
    rather than being collected into a list first.
    """
    yield from async_iterable_to_sync_iterable(self._run_async(**kwargs))
|
|
238
|
+
|
|
190
239
|
|
|
191
240
|
class OnedriveDownloaderConfig(DownloaderConfig):
|
|
192
241
|
pass
|
|
@@ -220,19 +269,24 @@ class OnedriveDownloader(Downloader):
|
|
|
220
269
|
|
|
221
270
|
@SourceConnectionError.wrap
def run(self, file_data: FileData, **kwargs: Any) -> DownloadResponse:
    """Download the file described by ``file_data`` and return the response.

    Fetches the remote file, reads its ``size`` property (0 if absent),
    creates the parent directory of the download path, and writes the content
    there. Files whose size exceeds ``MAX_MB_SIZE`` are fetched through a
    chunked download session; others with a single download call. Any
    exception is logged with the connector type and re-raised.
    """
    try:
        remote_file = self._fetch_file(file_data=file_data)
        fsize = remote_file.get_property("size", 0)
        download_path = self.get_download_path(file_data=file_data)
        download_path.parent.mkdir(parents=True, exist_ok=True)
        logger.info(f"downloading {file_data.source_identifiers.fullpath} to {download_path}")
        if fsize <= MAX_MB_SIZE:
            with download_path.open(mode="wb") as sink:
                remote_file.download(sink).execute_query()
        else:
            logger.info(f"downloading file with size: {fsize} bytes in chunks")
            with download_path.open(mode="wb") as sink:
                remote_file.download_session(sink, chunk_size=1024 * 1024 * 100).execute_query()
        return self.generate_download_response(file_data=file_data, download_path=download_path)
    except Exception as e:
        logger.error(f"[{CONNECTOR_TYPE}] Exception during downloading: {e}", exc_info=True)
        # Re-raise to see full stack trace locally
        raise
|
|
236
290
|
|
|
237
291
|
|
|
238
292
|
class OnedriveUploaderConfig(UploaderConfig):
|