unstructured-ingest 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of unstructured-ingest might be problematic; consult the package's registry page for more details.

Files changed (25):
  1. test/integration/connectors/test_google_drive.py +141 -0
  2. test/unit/v2/embedders/test_bedrock.py +1 -1
  3. test/unit/v2/embedders/test_huggingface.py +1 -1
  4. unstructured_ingest/__version__.py +1 -1
  5. unstructured_ingest/embed/azure_openai.py +6 -0
  6. unstructured_ingest/embed/bedrock.py +29 -12
  7. unstructured_ingest/embed/huggingface.py +14 -5
  8. unstructured_ingest/embed/interfaces.py +63 -44
  9. unstructured_ingest/embed/mixedbreadai.py +28 -105
  10. unstructured_ingest/embed/octoai.py +19 -44
  11. unstructured_ingest/embed/openai.py +17 -48
  12. unstructured_ingest/embed/togetherai.py +16 -49
  13. unstructured_ingest/embed/vertexai.py +15 -39
  14. unstructured_ingest/embed/voyageai.py +16 -42
  15. unstructured_ingest/v2/errors.py +7 -0
  16. unstructured_ingest/v2/processes/connectors/google_drive.py +132 -3
  17. unstructured_ingest/v2/processes/connectors/neo4j.py +129 -43
  18. unstructured_ingest/v2/processes/connectors/sql/snowflake.py +53 -3
  19. unstructured_ingest/v2/processes/embedder.py +9 -7
  20. {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/METADATA +99 -87
  21. {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/RECORD +25 -25
  22. {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/WHEEL +1 -1
  23. {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/LICENSE.md +0 -0
  24. {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/entry_points.txt +0 -0
  25. {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/top_level.txt +0 -0
@@ -9,7 +9,6 @@ from unstructured_ingest.embed.interfaces import (
9
9
  EmbeddingConfig,
10
10
  )
11
11
  from unstructured_ingest.logger import logger
12
- from unstructured_ingest.utils.data_prep import batch_generator
13
12
  from unstructured_ingest.utils.dep_check import requires_dependencies
14
13
  from unstructured_ingest.v2.errors import (
15
14
  ProviderError,
@@ -17,6 +16,7 @@ from unstructured_ingest.v2.errors import (
17
16
  RateLimitError,
18
17
  UserAuthError,
19
18
  UserError,
19
+ is_internal_error,
20
20
  )
21
21
 
22
22
  if TYPE_CHECKING:
@@ -29,6 +29,8 @@ class OctoAiEmbeddingConfig(EmbeddingConfig):
29
29
  base_url: str = Field(default="https://text.octoai.run/v1")
30
30
 
31
31
  def wrap_error(self, e: Exception) -> Exception:
32
+ if is_internal_error(e=e):
33
+ return e
32
34
  # https://platform.openai.com/docs/guides/error-codes/api-errors
33
35
  from openai import APIStatusError
34
36
 
@@ -80,28 +82,17 @@ class OctoAIEmbeddingEncoder(BaseEmbeddingEncoder):
80
82
  def wrap_error(self, e: Exception) -> Exception:
81
83
  return self.config.wrap_error(e=e)
82
84
 
83
- def embed_query(self, query: str):
84
- try:
85
- client = self.config.get_client()
86
- response = client.embeddings.create(input=query, model=self.config.embedder_model_name)
87
- except Exception as e:
88
- raise self.wrap_error(e=e)
85
+ def _embed_query(self, query: str):
86
+ client = self.get_client()
87
+ response = client.embeddings.create(input=query, model=self.config.embedder_model_name)
89
88
  return response.data[0].embedding
90
89
 
91
- def embed_documents(self, elements: list[dict]) -> list[dict]:
92
- texts = [e.get("text", "") for e in elements]
93
- embeddings = []
94
- client = self.config.get_client()
95
- try:
96
- for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
97
- response = client.embeddings.create(
98
- input=batch, model=self.config.embedder_model_name
99
- )
100
- embeddings.extend([data.embedding for data in response.data])
101
- except Exception as e:
102
- raise self.wrap_error(e=e)
103
- elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
104
- return elements_with_embeddings
90
+ def get_client(self) -> "OpenAI":
91
+ return self.config.get_client()
92
+
93
+ def embed_batch(self, client: "OpenAI", batch: list[str]) -> list[list[float]]:
94
+ response = client.embeddings.create(input=batch, model=self.config.embedder_model_name)
95
+ return [data.embedding for data in response.data]
105
96
 
106
97
 
107
98
  @dataclass
@@ -111,27 +102,11 @@ class AsyncOctoAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
111
102
  def wrap_error(self, e: Exception) -> Exception:
112
103
  return self.config.wrap_error(e=e)
113
104
 
114
- async def embed_query(self, query: str):
115
- client = self.config.get_async_client()
116
- try:
117
- response = await client.embeddings.create(
118
- input=query, model=self.config.embedder_model_name
119
- )
120
- except Exception as e:
121
- raise self.wrap_error(e=e)
122
- return response.data[0].embedding
105
+ def get_client(self) -> "AsyncOpenAI":
106
+ return self.config.get_async_client()
123
107
 
124
- async def embed_documents(self, elements: list[dict]) -> list[dict]:
125
- texts = [e.get("text", "") for e in elements]
126
- client = self.config.get_async_client()
127
- embeddings = []
128
- try:
129
- for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
130
- response = await client.embeddings.create(
131
- input=batch, model=self.config.embedder_model_name
132
- )
133
- embeddings.extend([data.embedding for data in response.data])
134
- except Exception as e:
135
- raise self.wrap_error(e=e)
136
- elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
137
- return elements_with_embeddings
108
+ async def embed_batch(self, client: "AsyncOpenAI", batch: list[str]) -> list[list[float]]:
109
+ response = await client.embeddings.create(
110
+ input=batch, model=self.config.embedder_model_name
111
+ )
112
+ return [data.embedding for data in response.data]
@@ -9,7 +9,6 @@ from unstructured_ingest.embed.interfaces import (
9
9
  EmbeddingConfig,
10
10
  )
11
11
  from unstructured_ingest.logger import logger
12
- from unstructured_ingest.utils.data_prep import batch_generator
13
12
  from unstructured_ingest.utils.dep_check import requires_dependencies
14
13
  from unstructured_ingest.v2.errors import (
15
14
  ProviderError,
@@ -17,6 +16,7 @@ from unstructured_ingest.v2.errors import (
17
16
  RateLimitError,
18
17
  UserAuthError,
19
18
  UserError,
19
+ is_internal_error,
20
20
  )
21
21
 
22
22
  if TYPE_CHECKING:
@@ -28,6 +28,8 @@ class OpenAIEmbeddingConfig(EmbeddingConfig):
28
28
  embedder_model_name: str = Field(default="text-embedding-ada-002", alias="model_name")
29
29
 
30
30
  def wrap_error(self, e: Exception) -> Exception:
31
+ if is_internal_error(e=e):
32
+ return e
31
33
  # https://platform.openai.com/docs/guides/error-codes/api-errors
32
34
  from openai import APIStatusError
33
35
 
@@ -71,29 +73,12 @@ class OpenAIEmbeddingEncoder(BaseEmbeddingEncoder):
71
73
  def wrap_error(self, e: Exception) -> Exception:
72
74
  return self.config.wrap_error(e=e)
73
75
 
74
- def embed_query(self, query: str) -> list[float]:
75
-
76
- client = self.config.get_client()
77
- try:
78
- response = client.embeddings.create(input=query, model=self.config.embedder_model_name)
79
- except Exception as e:
80
- raise self.wrap_error(e=e)
81
- return response.data[0].embedding
82
-
83
- def embed_documents(self, elements: list[dict]) -> list[dict]:
84
- client = self.config.get_client()
85
- texts = [e.get("text", "") for e in elements]
86
- embeddings = []
87
- try:
88
- for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
89
- response = client.embeddings.create(
90
- input=batch, model=self.config.embedder_model_name
91
- )
92
- embeddings.extend([data.embedding for data in response.data])
93
- except Exception as e:
94
- raise self.wrap_error(e=e)
95
- elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
96
- return elements_with_embeddings
76
+ def get_client(self) -> "OpenAI":
77
+ return self.config.get_client()
78
+
79
+ def embed_batch(self, client: "OpenAI", batch: list[str]) -> list[list[float]]:
80
+ response = client.embeddings.create(input=batch, model=self.config.embedder_model_name)
81
+ return [data.embedding for data in response.data]
97
82
 
98
83
 
99
84
  @dataclass
@@ -103,27 +88,11 @@ class AsyncOpenAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
103
88
  def wrap_error(self, e: Exception) -> Exception:
104
89
  return self.config.wrap_error(e=e)
105
90
 
106
- async def embed_query(self, query: str) -> list[float]:
107
- client = self.config.get_async_client()
108
- try:
109
- response = await client.embeddings.create(
110
- input=query, model=self.config.embedder_model_name
111
- )
112
- except Exception as e:
113
- raise self.wrap_error(e=e)
114
- return response.data[0].embedding
115
-
116
- async def embed_documents(self, elements: list[dict]) -> list[dict]:
117
- client = self.config.get_async_client()
118
- texts = [e.get("text", "") for e in elements]
119
- embeddings = []
120
- try:
121
- for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
122
- response = await client.embeddings.create(
123
- input=batch, model=self.config.embedder_model_name
124
- )
125
- embeddings.extend([data.embedding for data in response.data])
126
- except Exception as e:
127
- raise self.wrap_error(e=e)
128
- elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
129
- return elements_with_embeddings
91
+ def get_client(self) -> "AsyncOpenAI":
92
+ return self.config.get_async_client()
93
+
94
+ async def embed_batch(self, client: "AsyncOpenAI", batch: list[str]) -> list[list[float]]:
95
+ response = await client.embeddings.create(
96
+ input=batch, model=self.config.embedder_model_name
97
+ )
98
+ return [data.embedding for data in response.data]
@@ -1,5 +1,5 @@
1
1
  from dataclasses import dataclass
2
- from typing import TYPE_CHECKING
2
+ from typing import TYPE_CHECKING, Any
3
3
 
4
4
  from pydantic import Field, SecretStr
5
5
 
@@ -9,15 +9,11 @@ from unstructured_ingest.embed.interfaces import (
9
9
  EmbeddingConfig,
10
10
  )
11
11
  from unstructured_ingest.logger import logger
12
- from unstructured_ingest.utils.data_prep import batch_generator
13
12
  from unstructured_ingest.utils.dep_check import requires_dependencies
14
13
  from unstructured_ingest.v2.errors import (
15
14
  RateLimitError as CustomRateLimitError,
16
15
  )
17
- from unstructured_ingest.v2.errors import (
18
- UserAuthError,
19
- UserError,
20
- )
16
+ from unstructured_ingest.v2.errors import UserAuthError, UserError, is_internal_error
21
17
 
22
18
  if TYPE_CHECKING:
23
19
  from together import AsyncTogether, Together
@@ -30,6 +26,8 @@ class TogetherAIEmbeddingConfig(EmbeddingConfig):
30
26
  )
31
27
 
32
28
  def wrap_error(self, e: Exception) -> Exception:
29
+ if is_internal_error(e=e):
30
+ return e
33
31
  # https://docs.together.ai/docs/error-codes
34
32
  from together.error import AuthenticationError, RateLimitError, TogetherException
35
33
 
@@ -63,27 +61,12 @@ class TogetherAIEmbeddingEncoder(BaseEmbeddingEncoder):
63
61
  def wrap_error(self, e: Exception) -> Exception:
64
62
  return self.config.wrap_error(e=e)
65
63
 
66
- def embed_query(self, query: str) -> list[float]:
67
- return self._embed_documents(elements=[query])[0]
68
-
69
- def embed_documents(self, elements: list[dict]) -> list[dict]:
70
- embeddings = self._embed_documents([e.get("text", "") for e in elements])
71
- return self._add_embeddings_to_elements(elements, embeddings)
72
-
73
- def _embed_documents(self, elements: list[str]) -> list[list[float]]:
74
- client = self.config.get_client()
75
- embeddings = []
76
- try:
77
- for batch in batch_generator(
78
- elements, batch_size=self.config.batch_size or len(elements)
79
- ):
80
- outputs = client.embeddings.create(
81
- model=self.config.embedder_model_name, input=batch
82
- )
83
- embeddings.extend([outputs.data[i].embedding for i in range(len(batch))])
84
- except Exception as e:
85
- raise self.wrap_error(e=e)
86
- return embeddings
64
+ def get_client(self) -> "Together":
65
+ return self.config.get_client()
66
+
67
def embed_batch(self, client: "Together", batch: list[str]) -> list[list[float]]:
    """Embed one batch of texts with a single ``client.embeddings.create`` call.

    Args:
        client: A TogetherAI client, as returned by ``get_client``.
        batch: Texts to embed, in order.

    Returns:
        One embedding per entry of the response's ``data`` list, in response order.

    Raises:
        ValueError: If the response does not contain exactly one embedding per
            input text. Returning a mismatched list would misalign embeddings
            with their source elements downstream.
    """
    outputs = client.embeddings.create(model=self.config.embedder_model_name, input=batch)
    embeddings = [item.embedding for item in outputs.data]
    if len(embeddings) != len(batch):
        raise ValueError(
            f"Expected {len(batch)} embeddings from TogetherAI, got {len(embeddings)}"
        )
    return embeddings
87
70
 
88
71
 
89
72
  @dataclass
@@ -93,25 +76,9 @@ class AsyncTogetherAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
93
76
  def wrap_error(self, e: Exception) -> Exception:
94
77
  return self.config.wrap_error(e=e)
95
78
 
96
- async def embed_query(self, query: str) -> list[float]:
97
- embedding = await self._embed_documents(elements=[query])
98
- return embedding[0]
99
-
100
- async def embed_documents(self, elements: list[dict]) -> list[dict]:
101
- embeddings = await self._embed_documents([e.get("text", "") for e in elements])
102
- return self._add_embeddings_to_elements(elements, embeddings)
103
-
104
- async def _embed_documents(self, elements: list[str]) -> list[list[float]]:
105
- client = self.config.get_async_client()
106
- embeddings = []
107
- try:
108
- for batch in batch_generator(
109
- elements, batch_size=self.config.batch_size or len(elements)
110
- ):
111
- outputs = await client.embeddings.create(
112
- model=self.config.embedder_model_name, input=batch
113
- )
114
- embeddings.extend([outputs.data[i].embedding for i in range(len(batch))])
115
- except Exception as e:
116
- raise self.wrap_error(e=e)
117
- return embeddings
79
+ def get_client(self) -> "AsyncTogether":
80
+ return self.config.get_async_client()
81
+
82
async def embed_batch(self, client: "AsyncTogether", batch: list[str]) -> list[list[float]]:
    """Asynchronously embed one batch of texts with ``client.embeddings.create``.

    Args:
        client: An async TogetherAI client, as returned by ``get_client``.
        batch: Texts to embed, in order.

    Returns:
        One embedding per entry of the response's ``data`` list, in response order.

    Raises:
        ValueError: If the response does not contain exactly one embedding per
            input text, so that results cannot be silently misaligned.
    """
    outputs = await client.embeddings.create(
        model=self.config.embedder_model_name, input=batch
    )
    embeddings = [item.embedding for item in outputs.data]
    if len(embeddings) != len(batch):
        raise ValueError(
            f"Expected {len(batch)} embeddings from TogetherAI, got {len(embeddings)}"
        )
    return embeddings
@@ -13,9 +13,8 @@ from unstructured_ingest.embed.interfaces import (
13
13
  BaseEmbeddingEncoder,
14
14
  EmbeddingConfig,
15
15
  )
16
- from unstructured_ingest.utils.data_prep import batch_generator
17
16
  from unstructured_ingest.utils.dep_check import requires_dependencies
18
- from unstructured_ingest.v2.errors import UserAuthError
17
+ from unstructured_ingest.v2.errors import UserAuthError, is_internal_error
19
18
 
20
19
  if TYPE_CHECKING:
21
20
  from vertexai.language_models import TextEmbeddingModel
@@ -39,6 +38,8 @@ class VertexAIEmbeddingConfig(EmbeddingConfig):
39
38
  )
40
39
 
41
40
  def wrap_error(self, e: Exception) -> Exception:
41
+ if is_internal_error(e=e):
42
+ return e
42
43
  from google.auth.exceptions import GoogleAuthError
43
44
 
44
45
  if isinstance(e, GoogleAuthError):
@@ -71,31 +72,19 @@ class VertexAIEmbeddingEncoder(BaseEmbeddingEncoder):
71
72
  def wrap_error(self, e: Exception) -> Exception:
72
73
  return self.config.wrap_error(e=e)
73
74
 
74
- def embed_query(self, query):
75
- return self._embed_documents(elements=[query])[0]
76
-
77
- def embed_documents(self, elements: list[dict]) -> list[dict]:
78
- embeddings = self._embed_documents([e.get("text", "") for e in elements])
79
- elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
80
- return elements_with_embeddings
75
+ def get_client(self) -> "TextEmbeddingModel":
76
+ return self.config.get_client()
81
77
 
82
78
  @requires_dependencies(
83
79
  ["vertexai"],
84
80
  extras="embed-vertexai",
85
81
  )
86
- def _embed_documents(self, elements: list[str]) -> list[list[float]]:
82
+ def embed_batch(self, client: "TextEmbeddingModel", batch: list[str]) -> list[list[float]]:
87
83
  from vertexai.language_models import TextEmbeddingInput
88
84
 
89
- inputs = [TextEmbeddingInput(text=element) for element in elements]
90
- client = self.config.get_client()
91
- embeddings = []
92
- try:
93
- for batch in batch_generator(inputs, batch_size=self.config.batch_size or len(inputs)):
94
- response = client.get_embeddings(batch)
95
- embeddings.extend([e.values for e in response])
96
- except Exception as e:
97
- raise self.wrap_error(e=e)
98
- return embeddings
85
+ inputs = [TextEmbeddingInput(text=text) for text in batch]
86
+ response = client.get_embeddings(inputs)
87
+ return [e.values for e in response]
99
88
 
100
89
 
101
90
  @dataclass
@@ -105,29 +94,16 @@ class AsyncVertexAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
105
94
  def wrap_error(self, e: Exception) -> Exception:
106
95
  return self.config.wrap_error(e=e)
107
96
 
108
- async def embed_query(self, query):
109
- embedding = await self._embed_documents(elements=[query])
110
- return embedding[0]
111
-
112
- async def embed_documents(self, elements: list[dict]) -> list[dict]:
113
- embeddings = await self._embed_documents([e.get("text", "") for e in elements])
114
- elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
115
- return elements_with_embeddings
97
+ def get_client(self) -> "TextEmbeddingModel":
98
+ return self.config.get_client()
116
99
 
117
100
  @requires_dependencies(
118
101
  ["vertexai"],
119
102
  extras="embed-vertexai",
120
103
  )
121
- async def _embed_documents(self, elements: list[str]) -> list[list[float]]:
104
+ async def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
122
105
  from vertexai.language_models import TextEmbeddingInput
123
106
 
124
- inputs = [TextEmbeddingInput(text=element) for element in elements]
125
- client = self.config.get_client()
126
- embeddings = []
127
- try:
128
- for batch in batch_generator(inputs, batch_size=self.config.batch_size or len(inputs)):
129
- response = await client.get_embeddings_async(batch)
130
- embeddings.extend([e.values for e in response])
131
- except Exception as e:
132
- raise self.wrap_error(e=e)
133
- return embeddings
107
+ inputs = [TextEmbeddingInput(text=text) for text in batch]
108
+ response = await client.get_embeddings_async(inputs)
109
+ return [e.values for e in response]
@@ -9,13 +9,8 @@ from unstructured_ingest.embed.interfaces import (
9
9
  EmbeddingConfig,
10
10
  )
11
11
  from unstructured_ingest.logger import logger
12
- from unstructured_ingest.utils.data_prep import batch_generator
13
12
  from unstructured_ingest.utils.dep_check import requires_dependencies
14
- from unstructured_ingest.v2.errors import (
15
- ProviderError,
16
- UserAuthError,
17
- UserError,
18
- )
13
+ from unstructured_ingest.v2.errors import ProviderError, UserAuthError, UserError, is_internal_error
19
14
  from unstructured_ingest.v2.errors import (
20
15
  RateLimitError as CustomRateLimitError,
21
16
  )
@@ -38,6 +33,8 @@ class VoyageAIEmbeddingConfig(EmbeddingConfig):
38
33
  timeout_in_seconds: Optional[int] = None
39
34
 
40
35
  def wrap_error(self, e: Exception) -> Exception:
36
+ if is_internal_error(e=e):
37
+ return e
41
38
  # https://docs.voyageai.com/docs/error-codes
42
39
  from voyageai.error import AuthenticationError, RateLimitError, VoyageError
43
40
 
@@ -95,23 +92,12 @@ class VoyageAIEmbeddingEncoder(BaseEmbeddingEncoder):
95
92
  def wrap_error(self, e: Exception) -> Exception:
96
93
  return self.config.wrap_error(e=e)
97
94
 
98
- def _embed_documents(self, elements: list[str]) -> list[list[float]]:
99
- client = self.config.get_client()
100
- embeddings = []
101
- try:
102
- for batch in batch_generator(elements, batch_size=self.config.batch_size):
103
- response = client.embed(texts=batch, model=self.config.embedder_model_name)
104
- embeddings.extend(response.embeddings)
105
- except Exception as e:
106
- raise self.wrap_error(e=e)
107
- return embeddings
108
-
109
- def embed_documents(self, elements: list[dict]) -> list[dict]:
110
- embeddings = self._embed_documents([e.get("text", "") for e in elements])
111
- return self._add_embeddings_to_elements(elements, embeddings)
95
+ def get_client(self) -> "VoyageAIClient":
96
+ return self.config.get_client()
112
97
 
113
- def embed_query(self, query: str) -> list[float]:
114
- return self._embed_documents(elements=[query])[0]
98
+ def embed_batch(self, client: "VoyageAIClient", batch: list[str]) -> list[list[float]]:
99
+ response = client.embed(texts=batch, model=self.config.embedder_model_name)
100
+ return response.embeddings
115
101
 
116
102
 
117
103
  @dataclass
@@ -121,23 +107,11 @@ class AsyncVoyageAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
121
107
  def wrap_error(self, e: Exception) -> Exception:
122
108
  return self.config.wrap_error(e=e)
123
109
 
124
- async def _embed_documents(self, elements: list[str]) -> list[list[float]]:
125
- client = self.config.get_async_client()
126
- embeddings = []
127
- try:
128
- for batch in batch_generator(
129
- elements, batch_size=self.config.batch_size or len(elements)
130
- ):
131
- response = await client.embed(texts=batch, model=self.config.embedder_model_name)
132
- embeddings.extend(response.embeddings)
133
- except Exception as e:
134
- raise self.wrap_error(e=e)
135
- return embeddings
136
-
137
- async def embed_documents(self, elements: list[dict]) -> list[dict]:
138
- embeddings = await self._embed_documents([e.get("text", "") for e in elements])
139
- return self._add_embeddings_to_elements(elements, embeddings)
140
-
141
- async def embed_query(self, query: str) -> list[float]:
142
- embedding = await self._embed_documents(elements=[query])
143
- return embedding[0]
110
+ def get_client(self) -> "AsyncVoyageAIClient":
111
+ return self.config.get_async_client()
112
+
113
+ async def embed_batch(
114
+ self, client: "AsyncVoyageAIClient", batch: list[str]
115
+ ) -> list[list[float]]:
116
+ response = await client.embed(texts=batch, model=self.config.embedder_model_name)
117
+ return response.embeddings
@@ -16,3 +16,10 @@ class QuotaError(UserError):
16
16
 
17
17
  class ProviderError(Exception):
18
18
  pass
19
+
20
+
21
# Exception classes the ingest pipeline already knows how to classify. The
# embedder configs' ``wrap_error`` methods return instances of these unchanged
# rather than re-wrapping them.
recognized_errors = [UserError, UserAuthError, RateLimitError, QuotaError, ProviderError]


def is_internal_error(e: Exception) -> bool:
    """Report whether *e* is an instance of any class in ``recognized_errors``.

    The check uses ``isinstance``, so subclasses of the listed classes also match.
    """
    return isinstance(e, tuple(recognized_errors))
@@ -132,12 +132,141 @@ class GoogleDriveIndexer(Indexer):
132
132
  ]
133
133
  )
134
134
 
135
@staticmethod
def verify_drive_api_enabled(client) -> None:
    """Make a minimal Drive API call to confirm the API is enabled and reachable.

    The call lists at most one file from the ``drive`` space, requesting only its ``id``.

    Args:
        client: Drive files resource (the same object ``precheck`` passes to
            ``get_root_info`` as ``files_client``); must support ``list(...).execute()``.

    Raises:
        SourceConnectionError: If the call fails with an ``HttpError``. When the
            error body mentions the Drive API being disabled / never used, the
            message says so; otherwise it reports an unknown failure. The
            original ``HttpError`` is chained as the cause.
    """
    from googleapiclient.errors import HttpError

    try:
        client.list(spaces="drive", pageSize=1, fields="files(id)").execute()
    except HttpError as e:
        # ``content`` may be missing/None or (depending on transport) not bytes.
        raw_content = getattr(e, "content", None) or b""
        if isinstance(raw_content, bytes):
            error_content = raw_content.decode(errors="replace")
        else:
            error_content = str(raw_content)
        lower_error = error_content.lower()
        if "drive api" in lower_error and (
            "not enabled" in lower_error or "not been used" in lower_error
        ):
            # Implicit concatenation: a backslash continuation inside the literal
            # would embed the source indentation into the message.
            raise SourceConnectionError(
                "Google Drive API is not enabled for your project. "
                "Please enable it in the Google Cloud Console."
            ) from e
        raise SourceConnectionError(
            "Google drive API unreachable for an unknown reason!"
        ) from e
158
+
159
@staticmethod
def count_files_recursively(
    files_client, folder_id: str, extensions: "list[str] | None" = None
) -> int:
    """Count non-folder files under *folder_id*, descending into every sub-folder.

    Sub-folders are always traversed; the extension filter applies only to files.

    Args:
        files_client: Drive files resource supporting ``list(...).execute()``.
        folder_id: ID of the folder to start from.
        extensions: File extensions to count, compared case-insensitively against
            each item's ``fileExtension``. ``None`` or empty means count every file.

    Returns:
        Number of matching files found.
    """
    folder_mime_type = "application/vnd.google-apps.folder"
    # Normalize the filter once instead of once per listed item.
    valid_exts = {ext.lower() for ext in extensions} if extensions else None

    count = 0
    stack = [folder_id]
    # Never list the same folder twice (e.g. one reachable through more than one
    # parent), so the walk cannot double count or loop forever.
    visited = set()
    while stack:
        current_folder = stack.pop()
        if current_folder in visited:
            continue
        visited.add(current_folder)

        # List every item under the current folder, page by page.
        query = f"'{current_folder}' in parents"
        page_token = None
        while True:
            response = files_client.list(
                spaces="drive",
                q=query,
                fields="nextPageToken, files(id, mimeType, fileExtension)",
                pageToken=page_token,
                pageSize=1000,
            ).execute()
            for item in response.get("files", []):
                if item.get("mimeType") == folder_mime_type:
                    stack.append(item["id"])
                    continue
                file_ext = (item.get("fileExtension") or "").lower()
                if valid_exts is None or file_ext in valid_exts:
                    count += 1
            page_token = response.get("nextPageToken")
            if not page_token:
                break
    return count
198
+
135
199
  def precheck(self) -> None:
200
+ """
201
+ Enhanced precheck that verifies not only connectivity
202
+ but also that the provided drive_id is valid and accessible.
203
+ """
136
204
  try:
137
- self.connection_config.get_client()
205
+ with self.connection_config.get_client() as client:
206
+ # First, verify that the Drive API is enabled.
207
+ self.verify_drive_api_enabled(client)
208
+
209
+ # Try to retrieve metadata for the drive id.
210
+ # This will catch errors such as an invalid drive id or insufficient permissions.
211
+ root_info = self.get_root_info(
212
+ files_client=client, object_id=self.connection_config.drive_id
213
+ )
214
+ logger.info(
215
+ f"Successfully retrieved drive root info: "
216
+ f"{root_info.get('name', 'Unnamed')} (ID: {root_info.get('id')})"
217
+ )
218
+
219
+ # If the target is a folder, perform file count check.
220
+ if self.is_dir(root_info):
221
+ if self.index_config.recursive:
222
+ file_count = self.count_files_recursively(
223
+ client,
224
+ self.connection_config.drive_id,
225
+ extensions=self.index_config.extensions,
226
+ )
227
+ if file_count == 0:
228
+ logger.warning(
229
+ "Empty folder: no files found recursively in the folder. \
230
+ Please verify that the folder contains files and \
231
+ that the service account has proper permissions."
232
+ )
233
+ # raise SourceConnectionError(
234
+ # "Empty folder: no files found recursively in the folder. "
235
+ # "Please verify that the folder contains files and \
236
+ # that the service account has proper permissions."
237
+ # )
238
+ else:
239
+ logger.info(f"Found {file_count} files recursively in the folder.")
240
+ else:
241
+ # Non-recursive: check for at least one immediate non-folder child.
242
+ response = client.list(
243
+ spaces="drive",
244
+ fields="files(id)",
245
+ pageSize=1,
246
+ q=f"'{self.connection_config.drive_id}' in parents",
247
+ ).execute()
248
+ if not response.get("files"):
249
+ logger.warning(
250
+ "Empty folder: no files found at the folder's root level. "
251
+ "Please verify that the folder contains files and \
252
+ that the service account has proper permissions."
253
+ )
254
+ # raise SourceConnectionError(
255
+ # "Empty folder: no files found at the folder's root level. "
256
+ # "Please verify that the folder contains files and \
257
+ # that the service account has proper permissions."
258
+ # )
259
+ else:
260
+ logger.info("Found files at the folder's root level.")
261
+ else:
262
+ # If the target is a file, precheck passes.
263
+ logger.info("Drive ID corresponds to a file. Precheck passed.")
264
+
138
265
  except Exception as e:
139
- logger.error(f"failed to validate connection: {e}", exc_info=True)
140
- raise SourceConnectionError(f"failed to validate connection: {e}")
266
+ logger.error(
267
+ "Failed to validate Google Drive connection during precheck", exc_info=True
268
+ )
269
+ raise SourceConnectionError(f"Precheck failed: {e}")
141
270
 
142
271
  @staticmethod
143
272
  def is_dir(record: dict) -> bool: