unstructured-ingest 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of unstructured-ingest might be problematic. Click here for more details.

@@ -19,24 +19,31 @@ from unstructured_ingest.v2.processes.connectors.sharepoint import (
19
19
  )
20
20
 
21
21
 
22
+ def sharepoint_config():
23
+ class SharepointTestConfig:
24
+ def __init__(self):
25
+ self.client_id = os.environ["SHAREPOINT_CLIENT_ID"]
26
+ self.client_cred = os.environ["SHAREPOINT_CRED"]
27
+ self.user_pname = os.environ["MS_USER_PNAME"]
28
+ self.tenant = os.environ["MS_TENANT_ID"]
29
+
30
+ return SharepointTestConfig()
31
+
32
+
22
33
  @pytest.mark.asyncio
23
34
  @pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG, BLOB_STORAGE_TAG)
24
35
  @requires_env("SHAREPOINT_CLIENT_ID", "SHAREPOINT_CRED", "MS_TENANT_ID", "MS_USER_PNAME")
25
36
  async def test_sharepoint_source(temp_dir):
26
- # Retrieve environment variables
27
37
  site = "https://unstructuredio.sharepoint.com/sites/utic-platform-test-source"
28
- client_id = os.environ["SHAREPOINT_CLIENT_ID"]
29
- client_cred = os.environ["SHAREPOINT_CRED"]
30
- user_pname = os.environ["MS_USER_PNAME"]
31
- tenant = os.environ["MS_TENANT_ID"]
38
+ config = sharepoint_config()
32
39
 
33
40
  # Create connection and indexer configurations
34
- access_config = SharepointAccessConfig(client_cred=client_cred)
41
+ access_config = SharepointAccessConfig(client_cred=config.client_cred)
35
42
  connection_config = SharepointConnectionConfig(
36
- client_id=client_id,
43
+ client_id=config.client_id,
37
44
  site=site,
38
- tenant=tenant,
39
- user_pname=user_pname,
45
+ tenant=config.tenant,
46
+ user_pname=config.user_pname,
40
47
  access_config=access_config,
41
48
  )
42
49
  index_config = SharepointIndexerConfig(recursive=True)
@@ -58,7 +65,151 @@ async def test_sharepoint_source(temp_dir):
58
65
  indexer=indexer,
59
66
  downloader=downloader,
60
67
  configs=SourceValidationConfigs(
61
- test_id="sharepoint",
68
+ test_id="sharepoint1",
69
+ expected_num_files=4,
70
+ validate_downloaded_files=True,
71
+ exclude_fields_extend=[
72
+ "metadata.date_created",
73
+ "metadata.date_modified",
74
+ "additional_metadata.LastModified",
75
+ "additional_metadata.@microsoft.graph.downloadUrl",
76
+ ],
77
+ ),
78
+ )
79
+
80
+
81
+ @pytest.mark.asyncio
82
+ @pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG, BLOB_STORAGE_TAG)
83
+ @requires_env("SHAREPOINT_CLIENT_ID", "SHAREPOINT_CRED", "MS_TENANT_ID", "MS_USER_PNAME")
84
+ async def test_sharepoint_source_with_path(temp_dir):
85
+ site = "https://unstructuredio.sharepoint.com/sites/utic-platform-test-source"
86
+ config = sharepoint_config()
87
+
88
+ # Create connection and indexer configurations
89
+ access_config = SharepointAccessConfig(client_cred=config.client_cred)
90
+ connection_config = SharepointConnectionConfig(
91
+ client_id=config.client_id,
92
+ site=site,
93
+ tenant=config.tenant,
94
+ user_pname=config.user_pname,
95
+ access_config=access_config,
96
+ )
97
+ index_config = SharepointIndexerConfig(recursive=True, path="Folder1")
98
+
99
+ download_config = SharepointDownloaderConfig(download_dir=temp_dir)
100
+
101
+ # Instantiate indexer and downloader
102
+ indexer = SharepointIndexer(
103
+ connection_config=connection_config,
104
+ index_config=index_config,
105
+ )
106
+ downloader = SharepointDownloader(
107
+ connection_config=connection_config,
108
+ download_config=download_config,
109
+ )
110
+
111
+ # Run the source connector validation
112
+ await source_connector_validation(
113
+ indexer=indexer,
114
+ downloader=downloader,
115
+ configs=SourceValidationConfigs(
116
+ test_id="sharepoint2",
117
+ expected_num_files=2,
118
+ validate_downloaded_files=True,
119
+ exclude_fields_extend=[
120
+ "metadata.date_created",
121
+ "metadata.date_modified",
122
+ "additional_metadata.LastModified",
123
+ "additional_metadata.@microsoft.graph.downloadUrl",
124
+ ],
125
+ ),
126
+ )
127
+
128
+
129
+ @pytest.mark.asyncio
130
+ @pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG, BLOB_STORAGE_TAG)
131
+ @requires_env("SHAREPOINT_CLIENT_ID", "SHAREPOINT_CRED", "MS_TENANT_ID", "MS_USER_PNAME")
132
+ async def test_sharepoint_root_with_path(temp_dir):
133
+ site = "https://unstructuredio.sharepoint.com/"
134
+ config = sharepoint_config()
135
+
136
+ # Create connection and indexer configurations
137
+ access_config = SharepointAccessConfig(client_cred=config.client_cred)
138
+ connection_config = SharepointConnectionConfig(
139
+ client_id=config.client_id,
140
+ site=site,
141
+ tenant=config.tenant,
142
+ user_pname=config.user_pname,
143
+ access_config=access_config,
144
+ )
145
+ index_config = SharepointIndexerConfig(recursive=True, path="e2e-test-folder")
146
+
147
+ download_config = SharepointDownloaderConfig(download_dir=temp_dir)
148
+
149
+ # Instantiate indexer and downloader
150
+ indexer = SharepointIndexer(
151
+ connection_config=connection_config,
152
+ index_config=index_config,
153
+ )
154
+ downloader = SharepointDownloader(
155
+ connection_config=connection_config,
156
+ download_config=download_config,
157
+ )
158
+
159
+ # Run the source connector validation
160
+ await source_connector_validation(
161
+ indexer=indexer,
162
+ downloader=downloader,
163
+ configs=SourceValidationConfigs(
164
+ test_id="sharepoint3",
165
+ expected_num_files=1,
166
+ validate_downloaded_files=True,
167
+ exclude_fields_extend=[
168
+ "metadata.date_created",
169
+ "metadata.date_modified",
170
+ "additional_metadata.LastModified",
171
+ "additional_metadata.@microsoft.graph.downloadUrl",
172
+ ],
173
+ ),
174
+ )
175
+
176
+
177
+ @pytest.mark.asyncio
178
+ @pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG, BLOB_STORAGE_TAG)
179
+ @requires_env("SHAREPOINT_CLIENT_ID", "SHAREPOINT_CRED", "MS_TENANT_ID", "MS_USER_PNAME")
180
+ async def test_sharepoint_shared_documents(temp_dir):
181
+ site = "https://unstructuredio.sharepoint.com/sites/utic-platform-test-source"
182
+ config = sharepoint_config()
183
+
184
+ # Create connection and indexer configurations
185
+ access_config = SharepointAccessConfig(client_cred=config.client_cred)
186
+ connection_config = SharepointConnectionConfig(
187
+ client_id=config.client_id,
188
+ site=site,
189
+ tenant=config.tenant,
190
+ user_pname=config.user_pname,
191
+ access_config=access_config,
192
+ )
193
+ index_config = SharepointIndexerConfig(recursive=True, path="Shared Documents")
194
+
195
+ download_config = SharepointDownloaderConfig(download_dir=temp_dir)
196
+
197
+ # Instantiate indexer and downloader
198
+ indexer = SharepointIndexer(
199
+ connection_config=connection_config,
200
+ index_config=index_config,
201
+ )
202
+ downloader = SharepointDownloader(
203
+ connection_config=connection_config,
204
+ download_config=download_config,
205
+ )
206
+
207
+ # Run the source connector validation
208
+ await source_connector_validation(
209
+ indexer=indexer,
210
+ downloader=downloader,
211
+ configs=SourceValidationConfigs(
212
+ test_id="sharepoint4",
62
213
  expected_num_files=4,
63
214
  validate_downloaded_files=True,
64
215
  exclude_fields_extend=[
@@ -15,7 +15,7 @@ def generate_embedder_config_params() -> dict:
15
15
  "region_name": fake.city(),
16
16
  }
17
17
  if random.random() < 0.5:
18
- params["embed_model_name"] = fake.word()
18
+ params["embedder_model_name"] = fake.word()
19
19
  return params
20
20
 
21
21
 
@@ -16,7 +16,7 @@ fake = faker.Faker()
16
16
  def generate_embedder_config_params() -> dict:
17
17
  params = {}
18
18
  if random.random() < 0.5:
19
- params["embed_model_name"] = fake.word() if random.random() < 0.5 else None
19
+ params["embedder_model_name"] = fake.word() if random.random() < 0.5 else None
20
20
  params["embedder_model_kwargs"] = (
21
21
  generate_random_dictionary(key_type=str, value_type=Any)
22
22
  if random.random() < 0.5
@@ -1 +1 @@
1
- __version__ = "0.5.2" # pragma: no cover
1
+ __version__ = "0.5.4" # pragma: no cover
@@ -44,7 +44,13 @@ class AzureOpenAIEmbeddingConfig(OpenAIEmbeddingConfig):
44
44
  class AzureOpenAIEmbeddingEncoder(OpenAIEmbeddingEncoder):
45
45
  config: AzureOpenAIEmbeddingConfig
46
46
 
47
+ def get_client(self) -> "AzureOpenAI":
48
+ return self.config.get_client()
49
+
47
50
 
48
51
  @dataclass
49
52
  class AsyncAzureOpenAIEmbeddingEncoder(AsyncOpenAIEmbeddingEncoder):
50
53
  config: AzureOpenAIEmbeddingConfig
54
+
55
+ def get_client(self) -> "AsyncAzureOpenAI":
56
+ return self.config.get_async_client()
@@ -15,7 +15,13 @@ from unstructured_ingest.embed.interfaces import (
15
15
  )
16
16
  from unstructured_ingest.logger import logger
17
17
  from unstructured_ingest.utils.dep_check import requires_dependencies
18
- from unstructured_ingest.v2.errors import ProviderError, RateLimitError, UserAuthError, UserError
18
+ from unstructured_ingest.v2.errors import (
19
+ ProviderError,
20
+ RateLimitError,
21
+ UserAuthError,
22
+ UserError,
23
+ is_internal_error,
24
+ )
19
25
 
20
26
  if TYPE_CHECKING:
21
27
  from botocore.client import BaseClient
@@ -51,9 +57,11 @@ class BedrockEmbeddingConfig(EmbeddingConfig):
51
57
  aws_access_key_id: SecretStr
52
58
  aws_secret_access_key: SecretStr
53
59
  region_name: str = "us-west-2"
54
- embed_model_name: str = Field(default="amazon.titan-embed-text-v1", alias="model_name")
60
+ embedder_model_name: str = Field(default="amazon.titan-embed-text-v1", alias="model_name")
55
61
 
56
62
  def wrap_error(self, e: Exception) -> Exception:
63
+ if is_internal_error(e=e):
64
+ return e
57
65
  from botocore.exceptions import ClientError
58
66
 
59
67
  if isinstance(e, ClientError):
@@ -122,7 +130,7 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
122
130
 
123
131
  def embed_query(self, query: str) -> list[float]:
124
132
  """Call out to Bedrock embedding endpoint."""
125
- provider = self.config.embed_model_name.split(".")[0]
133
+ provider = self.config.embedder_model_name.split(".")[0]
126
134
  body = conform_query(query=query, provider=provider)
127
135
 
128
136
  bedrock_client = self.config.get_client()
@@ -130,7 +138,7 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
130
138
  try:
131
139
  response = bedrock_client.invoke_model(
132
140
  body=json.dumps(body),
133
- modelId=self.config.embed_model_name,
141
+ modelId=self.config.embedder_model_name,
134
142
  accept="application/json",
135
143
  contentType="application/json",
136
144
  )
@@ -148,6 +156,8 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
148
156
  def embed_documents(self, elements: list[dict]) -> list[dict]:
149
157
  elements = elements.copy()
150
158
  elements_with_text = [e for e in elements if e.get("text")]
159
+ if not elements_with_text:
160
+ return elements
151
161
  embeddings = [self.embed_query(query=e["text"]) for e in elements_with_text]
152
162
  for element, embedding in zip(elements_with_text, embeddings):
153
163
  element[EMBEDDINGS_KEY] = embedding
@@ -163,7 +173,7 @@ class AsyncBedrockEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
163
173
 
164
174
  async def embed_query(self, query: str) -> list[float]:
165
175
  """Call out to Bedrock embedding endpoint."""
166
- provider = self.config.embed_model_name.split(".")[0]
176
+ provider = self.config.embedder_model_name.split(".")[0]
167
177
  body = conform_query(query=query, provider=provider)
168
178
  try:
169
179
  async with self.config.get_async_client() as bedrock_client:
@@ -171,7 +181,7 @@ class AsyncBedrockEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
171
181
  try:
172
182
  response = await bedrock_client.invoke_model(
173
183
  body=json.dumps(body),
174
- modelId=self.config.embed_model_name,
184
+ modelId=self.config.embedder_model_name,
175
185
  accept="application/json",
176
186
  contentType="application/json",
177
187
  )
@@ -47,7 +47,7 @@ class HuggingFaceEmbeddingConfig(EmbeddingConfig):
47
47
  class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
48
48
  config: HuggingFaceEmbeddingConfig
49
49
 
50
- def embed_query(self, query: str) -> list[float]:
50
+ def _embed_query(self, query: str) -> list[float]:
51
51
  return self._embed_documents(texts=[query])[0]
52
52
 
53
53
  def _embed_documents(self, texts: list[str]) -> list[list[float]]:
@@ -58,6 +58,8 @@ class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
58
58
  def embed_documents(self, elements: list[dict]) -> list[dict]:
59
59
  elements = elements.copy()
60
60
  elements_with_text = [e for e in elements if e.get("text")]
61
+ if not elements_with_text:
62
+ return elements
61
63
  embeddings = self._embed_documents([e["text"] for e in elements_with_text])
62
64
  for element, embedding in zip(elements_with_text, embeddings):
63
65
  element[EMBEDDINGS_KEY] = embedding
@@ -1,11 +1,12 @@
1
- import asyncio
2
- from abc import ABC, abstractmethod
1
+ from abc import ABC
3
2
  from dataclasses import dataclass
4
- from typing import Optional
3
+ from typing import Any, Optional
5
4
 
6
5
  import numpy as np
7
6
  from pydantic import BaseModel, Field
8
7
 
8
+ from unstructured_ingest.utils.data_prep import batch_generator
9
+
9
10
  EMBEDDINGS_KEY = "embeddings"
10
11
 
11
12
 
@@ -50,21 +51,37 @@ class BaseEmbeddingEncoder(BaseEncoder, ABC):
50
51
  exemplary_embedding = self.get_exemplary_embedding()
51
52
  return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
52
53
 
53
- @abstractmethod
54
- def embed_documents(self, elements: list[dict]) -> list[dict]:
55
- pass
54
+ def get_client(self):
55
+ raise NotImplementedError
56
56
 
57
- @abstractmethod
58
- def embed_query(self, query: str) -> list[float]:
59
- pass
57
+ def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
58
+ raise NotImplementedError
60
59
 
61
- def _embed_documents(self, elements: list[str]) -> list[list[float]]:
62
- results = []
63
- for text in elements:
64
- response = self.embed_query(query=text)
65
- results.append(response)
60
def embed_documents(self, elements: list[dict]) -> list[dict]:
    """Attach an embedding to every element that has non-empty text.

    The texts are passed through ``batch_generator`` with
    ``batch_size=self.config.batch_size`` (falling back to ``len(texts)``
    when the configured value is falsy). Each batch is handed to
    ``self.embed_batch`` and the returned vectors are stored on the
    matching element under ``EMBEDDINGS_KEY``.

    Args:
        elements: Element dictionaries. Only entries with a truthy
            ``"text"`` value are embedded; the rest are left untouched.

    Returns:
        A shallow copy of ``elements``. The dictionaries are shared with
        the input list, so embeddings are written onto the caller's dicts.

    Raises:
        Exception: Whatever ``self.wrap_error`` returns for an exception
            raised while embedding a batch.
    """
    elements = elements.copy()
    elements_with_text = [e for e in elements if e.get("text")]
    if not elements_with_text:
        # Nothing to embed: skip client creation and the provider call.
        return elements
    client = self.get_client()
    texts = [e["text"] for e in elements_with_text]
    embeddings: list[list[float]] = []
    try:
        for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
            # Keep the per-batch result under its own name: rebinding
            # ``embeddings`` here would throw away every earlier batch and
            # misalign vectors with their elements.
            batch_embeddings = self.embed_batch(client=client, batch=batch)
            embeddings.extend(batch_embeddings)
    except Exception as e:
        raise self.wrap_error(e=e)
    for element, embedding in zip(elements_with_text, embeddings):
        element[EMBEDDINGS_KEY] = embedding
    return elements
75
+
76
+ def _embed_query(self, query: str) -> list[float]:
77
+ client = self.get_client()
78
+ return self.embed_batch(client=client, batch=[query])[0]
66
79
 
67
- return results
80
+ def embed_query(self, query: str) -> list[float]:
81
+ try:
82
+ return self._embed_query(query=query)
83
+ except Exception as e:
84
+ raise self.wrap_error(e=e)
68
85
 
69
86
 
70
87
  @dataclass
@@ -88,14 +105,35 @@ class AsyncBaseEmbeddingEncoder(BaseEncoder, ABC):
88
105
  exemplary_embedding = await self.get_exemplary_embedding()
89
106
  return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
90
107
 
91
- @abstractmethod
108
+ def get_client(self):
109
+ raise NotImplementedError
110
+
111
+ async def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
112
+ raise NotImplementedError
113
+
92
114
  async def embed_documents(self, elements: list[dict]) -> list[dict]:
93
- pass
115
+ client = self.get_client()
116
+ elements = elements.copy()
117
+ elements_with_text = [e for e in elements if e.get("text")]
118
+ texts = [e["text"] for e in elements_with_text]
119
+ embeddings = []
120
+ try:
121
+ for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
122
+ embeddings = await self.embed_batch(client=client, batch=batch)
123
+ embeddings.extend(embeddings)
124
+ except Exception as e:
125
+ raise self.wrap_error(e=e)
126
+ for element, embedding in zip(elements_with_text, embeddings):
127
+ element[EMBEDDINGS_KEY] = embedding
128
+ return elements
129
+
130
+ async def _embed_query(self, query: str) -> list[float]:
131
+ client = self.get_client()
132
+ embeddings = await self.embed_batch(client=client, batch=[query])
133
+ return embeddings[0]
94
134
 
95
- @abstractmethod
96
135
  async def embed_query(self, query: str) -> list[float]:
97
- pass
98
-
99
- async def _embed_documents(self, elements: list[str]) -> list[list[float]]:
100
- results = await asyncio.gather(*[self.embed_query(query=text) for text in elements])
101
- return results
136
+ try:
137
+ return await self._embed_query(query=query)
138
+ except Exception as e:
139
+ raise self.wrap_error(e=e)
@@ -1,4 +1,3 @@
1
- import asyncio
2
1
  import os
3
2
  from dataclasses import dataclass
4
3
  from typing import TYPE_CHECKING
@@ -6,12 +5,10 @@ from typing import TYPE_CHECKING
6
5
  from pydantic import Field, SecretStr
7
6
 
8
7
  from unstructured_ingest.embed.interfaces import (
9
- EMBEDDINGS_KEY,
10
8
  AsyncBaseEmbeddingEncoder,
11
9
  BaseEmbeddingEncoder,
12
10
  EmbeddingConfig,
13
11
  )
14
- from unstructured_ingest.utils.data_prep import batch_generator
15
12
  from unstructured_ingest.utils.dep_check import requires_dependencies
16
13
 
17
14
  USER_AGENT = "@mixedbread-ai/unstructured"
@@ -85,7 +82,7 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
85
82
 
86
83
  def get_exemplary_embedding(self) -> list[float]:
87
84
  """Get an exemplary embedding to determine dimensions and unit vector status."""
88
- return self._embed(["Q"])[0]
85
+ return self.embed_query(query="Q")
89
86
 
90
87
  @requires_dependencies(
91
88
  ["mixedbread_ai"],
@@ -100,59 +97,19 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
100
97
  additional_headers={"User-Agent": USER_AGENT},
101
98
  )
102
99
 
103
- def _embed(self, texts: list[str]) -> list[list[float]]:
104
- """
105
- Embed a list of texts using the Mixedbread AI API.
106
-
107
- Args:
108
- texts (list[str]): List of texts to embed.
109
-
110
- Returns:
111
- list[list[float]]: List of embeddings.
112
- """
113
-
114
- responses = []
115
- client = self.config.get_client()
116
- for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
117
- response = client.embeddings(
118
- model=self.config.embedder_model_name,
119
- normalized=True,
120
- encoding_format=ENCODING_FORMAT,
121
- truncation_strategy=TRUNCATION_STRATEGY,
122
- request_options=self.get_request_options(),
123
- input=batch,
124
- )
125
- responses.append(response)
126
- return [item.embedding for response in responses for item in response.data]
127
-
128
- def embed_documents(self, elements: list[dict]) -> list[dict]:
129
- """
130
- Embed a list of document elements.
131
-
132
- Args:
133
- elements (list[Element]): List of document elements.
134
-
135
- Returns:
136
- list[Element]: Elements with embeddings.
137
- """
138
- elements = elements.copy()
139
- elements_with_text = [e for e in elements if e.get("text")]
140
- embeddings = self._embed([e["text"] for e in elements_with_text])
141
- for element, embedding in zip(elements_with_text, embeddings):
142
- element[EMBEDDINGS_KEY] = embedding
143
- return elements
144
-
145
- def embed_query(self, query: str) -> list[float]:
146
- """
147
- Embed a query string.
148
-
149
- Args:
150
- query (str): Query string to embed.
151
-
152
- Returns:
153
- list[float]: Embedding of the query.
154
- """
155
- return self._embed([query])[0]
100
+ def get_client(self) -> "MixedbreadAI":
101
+ return self.config.get_client()
102
+
103
+ def embed_batch(self, client: "MixedbreadAI", batch: list[str]) -> list[list[float]]:
104
+ response = client.embeddings(
105
+ model=self.config.embedder_model_name,
106
+ normalized=True,
107
+ encoding_format=ENCODING_FORMAT,
108
+ truncation_strategy=TRUNCATION_STRATEGY,
109
+ request_options=self.get_request_options(),
110
+ input=batch,
111
+ )
112
+ return [datum.embedding for datum in response.data]
156
113
 
157
114
 
158
115
  @dataclass
@@ -162,8 +119,7 @@ class AsyncMixedbreadAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
162
119
 
163
120
  async def get_exemplary_embedding(self) -> list[float]:
164
121
  """Get an exemplary embedding to determine dimensions and unit vector status."""
165
- embedding = await self._embed(["Q"])
166
- return embedding[0]
122
+ return await self.embed_query(query="Q")
167
123
 
168
124
  @requires_dependencies(
169
125
  ["mixedbread_ai"],
@@ -178,58 +134,16 @@ class AsyncMixedbreadAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
178
134
  additional_headers={"User-Agent": USER_AGENT},
179
135
  )
180
136
 
181
- async def _embed(self, texts: list[str]) -> list[list[float]]:
182
- """
183
- Embed a list of texts using the Mixedbread AI API.
184
-
185
- Args:
186
- texts (list[str]): List of texts to embed.
187
-
188
- Returns:
189
- list[list[float]]: List of embeddings.
190
- """
191
- client = self.config.get_async_client()
192
- tasks = []
193
- for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
194
- tasks.append(
195
- client.embeddings(
196
- model=self.config.embedder_model_name,
197
- normalized=True,
198
- encoding_format=ENCODING_FORMAT,
199
- truncation_strategy=TRUNCATION_STRATEGY,
200
- request_options=self.get_request_options(),
201
- input=batch,
202
- )
203
- )
204
- responses = await asyncio.gather(*tasks)
205
- return [item.embedding for response in responses for item in response.data]
206
-
207
- async def embed_documents(self, elements: list[dict]) -> list[dict]:
208
- """
209
- Embed a list of document elements.
210
-
211
- Args:
212
- elements (list[Element]): List of document elements.
213
-
214
- Returns:
215
- list[Element]: Elements with embeddings.
216
- """
217
- elements = elements.copy()
218
- elements_with_text = [e for e in elements if e.get("text")]
219
- embeddings = await self._embed([e["text"] for e in elements_with_text])
220
- for element, embedding in zip(elements_with_text, embeddings):
221
- element[EMBEDDINGS_KEY] = embedding
222
- return elements
223
-
224
- async def embed_query(self, query: str) -> list[float]:
225
- """
226
- Embed a query string.
227
-
228
- Args:
229
- query (str): Query string to embed.
230
-
231
- Returns:
232
- list[float]: Embedding of the query.
233
- """
234
- embedding = await self._embed([query])
235
- return embedding[0]
137
+ def get_client(self) -> "AsyncMixedbreadAI":
138
+ return self.config.get_async_client()
139
+
140
+ async def embed_batch(self, client: "AsyncMixedbreadAI", batch: list[str]) -> list[list[float]]:
141
+ response = await client.embeddings(
142
+ model=self.config.embedder_model_name,
143
+ normalized=True,
144
+ encoding_format=ENCODING_FORMAT,
145
+ truncation_strategy=TRUNCATION_STRATEGY,
146
+ request_options=self.get_request_options(),
147
+ input=batch,
148
+ )
149
+ return [datum.embedding for datum in response.data]