unstructured-ingest 0.5.2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of unstructured-ingest might be problematic; see the registry's advisory page for more details.

@@ -15,7 +15,7 @@ def generate_embedder_config_params() -> dict:
15
15
  "region_name": fake.city(),
16
16
  }
17
17
  if random.random() < 0.5:
18
- params["embed_model_name"] = fake.word()
18
+ params["embedder_model_name"] = fake.word()
19
19
  return params
20
20
 
21
21
 
@@ -16,7 +16,7 @@ fake = faker.Faker()
16
16
  def generate_embedder_config_params() -> dict:
17
17
  params = {}
18
18
  if random.random() < 0.5:
19
- params["embed_model_name"] = fake.word() if random.random() < 0.5 else None
19
+ params["embedder_model_name"] = fake.word() if random.random() < 0.5 else None
20
20
  params["embedder_model_kwargs"] = (
21
21
  generate_random_dictionary(key_type=str, value_type=Any)
22
22
  if random.random() < 0.5
@@ -1 +1 @@
1
- __version__ = "0.5.2" # pragma: no cover
1
+ __version__ = "0.5.3" # pragma: no cover
@@ -44,7 +44,13 @@ class AzureOpenAIEmbeddingConfig(OpenAIEmbeddingConfig):
44
44
  class AzureOpenAIEmbeddingEncoder(OpenAIEmbeddingEncoder):
45
45
  config: AzureOpenAIEmbeddingConfig
46
46
 
47
+ def get_client(self) -> "AzureOpenAI":
48
+ return self.config.get_client()
49
+
47
50
 
48
51
  @dataclass
49
52
  class AsyncAzureOpenAIEmbeddingEncoder(AsyncOpenAIEmbeddingEncoder):
50
53
  config: AzureOpenAIEmbeddingConfig
54
+
55
+ def get_client(self) -> "AsyncAzureOpenAI":
56
+ return self.config.get_async_client()
@@ -15,7 +15,13 @@ from unstructured_ingest.embed.interfaces import (
15
15
  )
16
16
  from unstructured_ingest.logger import logger
17
17
  from unstructured_ingest.utils.dep_check import requires_dependencies
18
- from unstructured_ingest.v2.errors import ProviderError, RateLimitError, UserAuthError, UserError
18
+ from unstructured_ingest.v2.errors import (
19
+ ProviderError,
20
+ RateLimitError,
21
+ UserAuthError,
22
+ UserError,
23
+ is_internal_error,
24
+ )
19
25
 
20
26
  if TYPE_CHECKING:
21
27
  from botocore.client import BaseClient
@@ -51,9 +57,11 @@ class BedrockEmbeddingConfig(EmbeddingConfig):
51
57
  aws_access_key_id: SecretStr
52
58
  aws_secret_access_key: SecretStr
53
59
  region_name: str = "us-west-2"
54
- embed_model_name: str = Field(default="amazon.titan-embed-text-v1", alias="model_name")
60
+ embedder_model_name: str = Field(default="amazon.titan-embed-text-v1", alias="model_name")
55
61
 
56
62
  def wrap_error(self, e: Exception) -> Exception:
63
+ if is_internal_error(e=e):
64
+ return e
57
65
  from botocore.exceptions import ClientError
58
66
 
59
67
  if isinstance(e, ClientError):
@@ -122,7 +130,7 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
122
130
 
123
131
  def embed_query(self, query: str) -> list[float]:
124
132
  """Call out to Bedrock embedding endpoint."""
125
- provider = self.config.embed_model_name.split(".")[0]
133
+ provider = self.config.embedder_model_name.split(".")[0]
126
134
  body = conform_query(query=query, provider=provider)
127
135
 
128
136
  bedrock_client = self.config.get_client()
@@ -130,7 +138,7 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
130
138
  try:
131
139
  response = bedrock_client.invoke_model(
132
140
  body=json.dumps(body),
133
- modelId=self.config.embed_model_name,
141
+ modelId=self.config.embedder_model_name,
134
142
  accept="application/json",
135
143
  contentType="application/json",
136
144
  )
@@ -148,6 +156,8 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
148
156
  def embed_documents(self, elements: list[dict]) -> list[dict]:
149
157
  elements = elements.copy()
150
158
  elements_with_text = [e for e in elements if e.get("text")]
159
+ if not elements_with_text:
160
+ return elements
151
161
  embeddings = [self.embed_query(query=e["text"]) for e in elements_with_text]
152
162
  for element, embedding in zip(elements_with_text, embeddings):
153
163
  element[EMBEDDINGS_KEY] = embedding
@@ -163,7 +173,7 @@ class AsyncBedrockEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
163
173
 
164
174
  async def embed_query(self, query: str) -> list[float]:
165
175
  """Call out to Bedrock embedding endpoint."""
166
- provider = self.config.embed_model_name.split(".")[0]
176
+ provider = self.config.embedder_model_name.split(".")[0]
167
177
  body = conform_query(query=query, provider=provider)
168
178
  try:
169
179
  async with self.config.get_async_client() as bedrock_client:
@@ -171,7 +181,7 @@ class AsyncBedrockEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
171
181
  try:
172
182
  response = await bedrock_client.invoke_model(
173
183
  body=json.dumps(body),
174
- modelId=self.config.embed_model_name,
184
+ modelId=self.config.embedder_model_name,
175
185
  accept="application/json",
176
186
  contentType="application/json",
177
187
  )
@@ -47,7 +47,7 @@ class HuggingFaceEmbeddingConfig(EmbeddingConfig):
47
47
  class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
48
48
  config: HuggingFaceEmbeddingConfig
49
49
 
50
- def embed_query(self, query: str) -> list[float]:
50
+ def _embed_query(self, query: str) -> list[float]:
51
51
  return self._embed_documents(texts=[query])[0]
52
52
 
53
53
  def _embed_documents(self, texts: list[str]) -> list[list[float]]:
@@ -58,6 +58,8 @@ class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
58
58
  def embed_documents(self, elements: list[dict]) -> list[dict]:
59
59
  elements = elements.copy()
60
60
  elements_with_text = [e for e in elements if e.get("text")]
61
+ if not elements_with_text:
62
+ return elements
61
63
  embeddings = self._embed_documents([e["text"] for e in elements_with_text])
62
64
  for element, embedding in zip(elements_with_text, embeddings):
63
65
  element[EMBEDDINGS_KEY] = embedding
@@ -1,11 +1,12 @@
1
- import asyncio
2
- from abc import ABC, abstractmethod
1
+ from abc import ABC
3
2
  from dataclasses import dataclass
4
- from typing import Optional
3
+ from typing import Any, Optional
5
4
 
6
5
  import numpy as np
7
6
  from pydantic import BaseModel, Field
8
7
 
8
+ from unstructured_ingest.utils.data_prep import batch_generator
9
+
9
10
  EMBEDDINGS_KEY = "embeddings"
10
11
 
11
12
 
@@ -50,21 +51,37 @@ class BaseEmbeddingEncoder(BaseEncoder, ABC):
50
51
  exemplary_embedding = self.get_exemplary_embedding()
51
52
  return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
52
53
 
53
- @abstractmethod
54
- def embed_documents(self, elements: list[dict]) -> list[dict]:
55
- pass
54
+ def get_client(self):
55
+ raise NotImplementedError
56
56
 
57
- @abstractmethod
58
- def embed_query(self, query: str) -> list[float]:
59
- pass
57
+ def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
58
+ raise NotImplementedError
60
59
 
61
- def _embed_documents(self, elements: list[str]) -> list[list[float]]:
62
- results = []
63
- for text in elements:
64
- response = self.embed_query(query=text)
65
- results.append(response)
60
+ def embed_documents(self, elements: list[dict]) -> list[dict]:
61
+ client = self.get_client()
62
+ elements = elements.copy()
63
+ elements_with_text = [e for e in elements if e.get("text")]
64
+ texts = [e["text"] for e in elements_with_text]
65
+ embeddings = []
66
+ try:
67
+ for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
68
+ embeddings = self.embed_batch(client=client, batch=batch)
69
+ embeddings.extend(embeddings)
70
+ except Exception as e:
71
+ raise self.wrap_error(e=e)
72
+ for element, embedding in zip(elements_with_text, embeddings):
73
+ element[EMBEDDINGS_KEY] = embedding
74
+ return elements
75
+
76
+ def _embed_query(self, query: str) -> list[float]:
77
+ client = self.get_client()
78
+ return self.embed_batch(client=client, batch=[query])[0]
66
79
 
67
- return results
80
+ def embed_query(self, query: str) -> list[float]:
81
+ try:
82
+ return self._embed_query(query=query)
83
+ except Exception as e:
84
+ raise self.wrap_error(e=e)
68
85
 
69
86
 
70
87
  @dataclass
@@ -88,14 +105,35 @@ class AsyncBaseEmbeddingEncoder(BaseEncoder, ABC):
88
105
  exemplary_embedding = await self.get_exemplary_embedding()
89
106
  return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
90
107
 
91
- @abstractmethod
108
+ def get_client(self):
109
+ raise NotImplementedError
110
+
111
+ async def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
112
+ raise NotImplementedError
113
+
92
114
  async def embed_documents(self, elements: list[dict]) -> list[dict]:
93
- pass
115
+ client = self.get_client()
116
+ elements = elements.copy()
117
+ elements_with_text = [e for e in elements if e.get("text")]
118
+ texts = [e["text"] for e in elements_with_text]
119
+ embeddings = []
120
+ try:
121
+ for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
122
+ embeddings = await self.embed_batch(client=client, batch=batch)
123
+ embeddings.extend(embeddings)
124
+ except Exception as e:
125
+ raise self.wrap_error(e=e)
126
+ for element, embedding in zip(elements_with_text, embeddings):
127
+ element[EMBEDDINGS_KEY] = embedding
128
+ return elements
129
+
130
+ async def _embed_query(self, query: str) -> list[float]:
131
+ client = self.get_client()
132
+ embeddings = await self.embed_batch(client=client, batch=[query])
133
+ return embeddings[0]
94
134
 
95
- @abstractmethod
96
135
  async def embed_query(self, query: str) -> list[float]:
97
- pass
98
-
99
- async def _embed_documents(self, elements: list[str]) -> list[list[float]]:
100
- results = await asyncio.gather(*[self.embed_query(query=text) for text in elements])
101
- return results
136
+ try:
137
+ return await self._embed_query(query=query)
138
+ except Exception as e:
139
+ raise self.wrap_error(e=e)
@@ -1,4 +1,3 @@
1
- import asyncio
2
1
  import os
3
2
  from dataclasses import dataclass
4
3
  from typing import TYPE_CHECKING
@@ -6,12 +5,10 @@ from typing import TYPE_CHECKING
6
5
  from pydantic import Field, SecretStr
7
6
 
8
7
  from unstructured_ingest.embed.interfaces import (
9
- EMBEDDINGS_KEY,
10
8
  AsyncBaseEmbeddingEncoder,
11
9
  BaseEmbeddingEncoder,
12
10
  EmbeddingConfig,
13
11
  )
14
- from unstructured_ingest.utils.data_prep import batch_generator
15
12
  from unstructured_ingest.utils.dep_check import requires_dependencies
16
13
 
17
14
  USER_AGENT = "@mixedbread-ai/unstructured"
@@ -85,7 +82,7 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
85
82
 
86
83
  def get_exemplary_embedding(self) -> list[float]:
87
84
  """Get an exemplary embedding to determine dimensions and unit vector status."""
88
- return self._embed(["Q"])[0]
85
+ return self.embed_query(query="Q")
89
86
 
90
87
  @requires_dependencies(
91
88
  ["mixedbread_ai"],
@@ -100,59 +97,19 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
100
97
  additional_headers={"User-Agent": USER_AGENT},
101
98
  )
102
99
 
103
- def _embed(self, texts: list[str]) -> list[list[float]]:
104
- """
105
- Embed a list of texts using the Mixedbread AI API.
106
-
107
- Args:
108
- texts (list[str]): List of texts to embed.
109
-
110
- Returns:
111
- list[list[float]]: List of embeddings.
112
- """
113
-
114
- responses = []
115
- client = self.config.get_client()
116
- for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
117
- response = client.embeddings(
118
- model=self.config.embedder_model_name,
119
- normalized=True,
120
- encoding_format=ENCODING_FORMAT,
121
- truncation_strategy=TRUNCATION_STRATEGY,
122
- request_options=self.get_request_options(),
123
- input=batch,
124
- )
125
- responses.append(response)
126
- return [item.embedding for response in responses for item in response.data]
127
-
128
- def embed_documents(self, elements: list[dict]) -> list[dict]:
129
- """
130
- Embed a list of document elements.
131
-
132
- Args:
133
- elements (list[Element]): List of document elements.
134
-
135
- Returns:
136
- list[Element]: Elements with embeddings.
137
- """
138
- elements = elements.copy()
139
- elements_with_text = [e for e in elements if e.get("text")]
140
- embeddings = self._embed([e["text"] for e in elements_with_text])
141
- for element, embedding in zip(elements_with_text, embeddings):
142
- element[EMBEDDINGS_KEY] = embedding
143
- return elements
144
-
145
- def embed_query(self, query: str) -> list[float]:
146
- """
147
- Embed a query string.
148
-
149
- Args:
150
- query (str): Query string to embed.
151
-
152
- Returns:
153
- list[float]: Embedding of the query.
154
- """
155
- return self._embed([query])[0]
100
+ def get_client(self) -> "MixedbreadAI":
101
+ return self.config.get_client()
102
+
103
+ def embed_batch(self, client: "MixedbreadAI", batch: list[str]) -> list[list[float]]:
104
+ response = client.embeddings(
105
+ model=self.config.embedder_model_name,
106
+ normalized=True,
107
+ encoding_format=ENCODING_FORMAT,
108
+ truncation_strategy=TRUNCATION_STRATEGY,
109
+ request_options=self.get_request_options(),
110
+ input=batch,
111
+ )
112
+ return [datum.embedding for datum in response.data]
156
113
 
157
114
 
158
115
  @dataclass
@@ -162,8 +119,7 @@ class AsyncMixedbreadAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
162
119
 
163
120
  async def get_exemplary_embedding(self) -> list[float]:
164
121
  """Get an exemplary embedding to determine dimensions and unit vector status."""
165
- embedding = await self._embed(["Q"])
166
- return embedding[0]
122
+ return await self.embed_query(query="Q")
167
123
 
168
124
  @requires_dependencies(
169
125
  ["mixedbread_ai"],
@@ -178,58 +134,16 @@ class AsyncMixedbreadAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
178
134
  additional_headers={"User-Agent": USER_AGENT},
179
135
  )
180
136
 
181
- async def _embed(self, texts: list[str]) -> list[list[float]]:
182
- """
183
- Embed a list of texts using the Mixedbread AI API.
184
-
185
- Args:
186
- texts (list[str]): List of texts to embed.
187
-
188
- Returns:
189
- list[list[float]]: List of embeddings.
190
- """
191
- client = self.config.get_async_client()
192
- tasks = []
193
- for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
194
- tasks.append(
195
- client.embeddings(
196
- model=self.config.embedder_model_name,
197
- normalized=True,
198
- encoding_format=ENCODING_FORMAT,
199
- truncation_strategy=TRUNCATION_STRATEGY,
200
- request_options=self.get_request_options(),
201
- input=batch,
202
- )
203
- )
204
- responses = await asyncio.gather(*tasks)
205
- return [item.embedding for response in responses for item in response.data]
206
-
207
- async def embed_documents(self, elements: list[dict]) -> list[dict]:
208
- """
209
- Embed a list of document elements.
210
-
211
- Args:
212
- elements (list[Element]): List of document elements.
213
-
214
- Returns:
215
- list[Element]: Elements with embeddings.
216
- """
217
- elements = elements.copy()
218
- elements_with_text = [e for e in elements if e.get("text")]
219
- embeddings = await self._embed([e["text"] for e in elements_with_text])
220
- for element, embedding in zip(elements_with_text, embeddings):
221
- element[EMBEDDINGS_KEY] = embedding
222
- return elements
223
-
224
- async def embed_query(self, query: str) -> list[float]:
225
- """
226
- Embed a query string.
227
-
228
- Args:
229
- query (str): Query string to embed.
230
-
231
- Returns:
232
- list[float]: Embedding of the query.
233
- """
234
- embedding = await self._embed([query])
235
- return embedding[0]
137
+ def get_client(self) -> "AsyncMixedbreadAI":
138
+ return self.config.get_async_client()
139
+
140
+ async def embed_batch(self, client: "AsyncMixedbreadAI", batch: list[str]) -> list[list[float]]:
141
+ response = await client.embeddings(
142
+ model=self.config.embedder_model_name,
143
+ normalized=True,
144
+ encoding_format=ENCODING_FORMAT,
145
+ truncation_strategy=TRUNCATION_STRATEGY,
146
+ request_options=self.get_request_options(),
147
+ input=batch,
148
+ )
149
+ return [datum.embedding for datum in response.data]
@@ -4,13 +4,11 @@ from typing import TYPE_CHECKING
4
4
  from pydantic import Field, SecretStr
5
5
 
6
6
  from unstructured_ingest.embed.interfaces import (
7
- EMBEDDINGS_KEY,
8
7
  AsyncBaseEmbeddingEncoder,
9
8
  BaseEmbeddingEncoder,
10
9
  EmbeddingConfig,
11
10
  )
12
11
  from unstructured_ingest.logger import logger
13
- from unstructured_ingest.utils.data_prep import batch_generator
14
12
  from unstructured_ingest.utils.dep_check import requires_dependencies
15
13
  from unstructured_ingest.v2.errors import (
16
14
  ProviderError,
@@ -18,6 +16,7 @@ from unstructured_ingest.v2.errors import (
18
16
  RateLimitError,
19
17
  UserAuthError,
20
18
  UserError,
19
+ is_internal_error,
21
20
  )
22
21
 
23
22
  if TYPE_CHECKING:
@@ -30,6 +29,8 @@ class OctoAiEmbeddingConfig(EmbeddingConfig):
30
29
  base_url: str = Field(default="https://text.octoai.run/v1")
31
30
 
32
31
  def wrap_error(self, e: Exception) -> Exception:
32
+ if is_internal_error(e=e):
33
+ return e
33
34
  # https://platform.openai.com/docs/guides/error-codes/api-errors
34
35
  from openai import APIStatusError
35
36
 
@@ -81,31 +82,17 @@ class OctoAIEmbeddingEncoder(BaseEmbeddingEncoder):
81
82
  def wrap_error(self, e: Exception) -> Exception:
82
83
  return self.config.wrap_error(e=e)
83
84
 
84
- def embed_query(self, query: str):
85
- try:
86
- client = self.config.get_client()
87
- response = client.embeddings.create(input=query, model=self.config.embedder_model_name)
88
- except Exception as e:
89
- raise self.wrap_error(e=e)
85
+ def _embed_query(self, query: str):
86
+ client = self.get_client()
87
+ response = client.embeddings.create(input=query, model=self.config.embedder_model_name)
90
88
  return response.data[0].embedding
91
89
 
92
- def embed_documents(self, elements: list[dict]) -> list[dict]:
93
- elements = elements.copy()
94
- elements_with_text = [e for e in elements if e.get("text")]
95
- texts = [e["text"] for e in elements_with_text]
96
- embeddings = []
97
- client = self.config.get_client()
98
- try:
99
- for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
100
- response = client.embeddings.create(
101
- input=batch, model=self.config.embedder_model_name
102
- )
103
- embeddings.extend([data.embedding for data in response.data])
104
- except Exception as e:
105
- raise self.wrap_error(e=e)
106
- for element, embedding in zip(elements_with_text, embeddings):
107
- element[EMBEDDINGS_KEY] = embedding
108
- return elements
90
+ def get_client(self) -> "OpenAI":
91
+ return self.config.get_client()
92
+
93
+ def embed_batch(self, client: "OpenAI", batch: list[str]) -> list[list[float]]:
94
+ response = client.embeddings.create(input=batch, model=self.config.embedder_model_name)
95
+ return [data.embedding for data in response.data]
109
96
 
110
97
 
111
98
  @dataclass
@@ -115,30 +102,11 @@ class AsyncOctoAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
115
102
  def wrap_error(self, e: Exception) -> Exception:
116
103
  return self.config.wrap_error(e=e)
117
104
 
118
- async def embed_query(self, query: str):
119
- client = self.config.get_async_client()
120
- try:
121
- response = await client.embeddings.create(
122
- input=query, model=self.config.embedder_model_name
123
- )
124
- except Exception as e:
125
- raise self.wrap_error(e=e)
126
- return response.data[0].embedding
105
+ def get_client(self) -> "AsyncOpenAI":
106
+ return self.config.get_async_client()
127
107
 
128
- async def embed_documents(self, elements: list[dict]) -> list[dict]:
129
- elements = elements.copy()
130
- elements_with_text = [e for e in elements if e.get("text")]
131
- texts = [e["text"] for e in elements_with_text]
132
- client = self.config.get_async_client()
133
- embeddings = []
134
- try:
135
- for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
136
- response = await client.embeddings.create(
137
- input=batch, model=self.config.embedder_model_name
138
- )
139
- embeddings.extend([data.embedding for data in response.data])
140
- except Exception as e:
141
- raise self.wrap_error(e=e)
142
- for element, embedding in zip(elements_with_text, embeddings):
143
- element[EMBEDDINGS_KEY] = embedding
144
- return elements
108
+ async def embed_batch(self, client: "AsyncOpenAI", batch: list[str]) -> list[list[float]]:
109
+ response = await client.embeddings.create(
110
+ input=batch, model=self.config.embedder_model_name
111
+ )
112
+ return [data.embedding for data in response.data]
@@ -4,13 +4,11 @@ from typing import TYPE_CHECKING
4
4
  from pydantic import Field, SecretStr
5
5
 
6
6
  from unstructured_ingest.embed.interfaces import (
7
- EMBEDDINGS_KEY,
8
7
  AsyncBaseEmbeddingEncoder,
9
8
  BaseEmbeddingEncoder,
10
9
  EmbeddingConfig,
11
10
  )
12
11
  from unstructured_ingest.logger import logger
13
- from unstructured_ingest.utils.data_prep import batch_generator
14
12
  from unstructured_ingest.utils.dep_check import requires_dependencies
15
13
  from unstructured_ingest.v2.errors import (
16
14
  ProviderError,
@@ -18,6 +16,7 @@ from unstructured_ingest.v2.errors import (
18
16
  RateLimitError,
19
17
  UserAuthError,
20
18
  UserError,
19
+ is_internal_error,
21
20
  )
22
21
 
23
22
  if TYPE_CHECKING:
@@ -29,6 +28,8 @@ class OpenAIEmbeddingConfig(EmbeddingConfig):
29
28
  embedder_model_name: str = Field(default="text-embedding-ada-002", alias="model_name")
30
29
 
31
30
  def wrap_error(self, e: Exception) -> Exception:
31
+ if is_internal_error(e=e):
32
+ return e
32
33
  # https://platform.openai.com/docs/guides/error-codes/api-errors
33
34
  from openai import APIStatusError
34
35
 
@@ -72,32 +73,12 @@ class OpenAIEmbeddingEncoder(BaseEmbeddingEncoder):
72
73
  def wrap_error(self, e: Exception) -> Exception:
73
74
  return self.config.wrap_error(e=e)
74
75
 
75
- def embed_query(self, query: str) -> list[float]:
76
-
77
- client = self.config.get_client()
78
- try:
79
- response = client.embeddings.create(input=query, model=self.config.embedder_model_name)
80
- except Exception as e:
81
- raise self.wrap_error(e=e)
82
- return response.data[0].embedding
83
-
84
- def embed_documents(self, elements: list[dict]) -> list[dict]:
85
- client = self.config.get_client()
86
- elements = elements.copy()
87
- elements_with_text = [e for e in elements if e.get("text")]
88
- texts = [e["text"] for e in elements_with_text]
89
- embeddings = []
90
- try:
91
- for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
92
- response = client.embeddings.create(
93
- input=batch, model=self.config.embedder_model_name
94
- )
95
- embeddings.extend([data.embedding for data in response.data])
96
- except Exception as e:
97
- raise self.wrap_error(e=e)
98
- for element, embedding in zip(elements_with_text, embeddings):
99
- element[EMBEDDINGS_KEY] = embedding
100
- return elements
76
+ def get_client(self) -> "OpenAI":
77
+ return self.config.get_client()
78
+
79
+ def embed_batch(self, client: "OpenAI", batch: list[str]) -> list[list[float]]:
80
+ response = client.embeddings.create(input=batch, model=self.config.embedder_model_name)
81
+ return [data.embedding for data in response.data]
101
82
 
102
83
 
103
84
  @dataclass
@@ -107,30 +88,11 @@ class AsyncOpenAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
107
88
  def wrap_error(self, e: Exception) -> Exception:
108
89
  return self.config.wrap_error(e=e)
109
90
 
110
- async def embed_query(self, query: str) -> list[float]:
111
- client = self.config.get_async_client()
112
- try:
113
- response = await client.embeddings.create(
114
- input=query, model=self.config.embedder_model_name
115
- )
116
- except Exception as e:
117
- raise self.wrap_error(e=e)
118
- return response.data[0].embedding
119
-
120
- async def embed_documents(self, elements: list[dict]) -> list[dict]:
121
- client = self.config.get_async_client()
122
- elements = elements.copy()
123
- elements_with_text = [e for e in elements if e.get("text")]
124
- texts = [e["text"] for e in elements_with_text]
125
- embeddings = []
126
- try:
127
- for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
128
- response = await client.embeddings.create(
129
- input=batch, model=self.config.embedder_model_name
130
- )
131
- embeddings.extend([data.embedding for data in response.data])
132
- except Exception as e:
133
- raise self.wrap_error(e=e)
134
- for element, embedding in zip(elements_with_text, embeddings):
135
- element[EMBEDDINGS_KEY] = embedding
136
- return elements
91
+ def get_client(self) -> "AsyncOpenAI":
92
+ return self.config.get_async_client()
93
+
94
+ async def embed_batch(self, client: "AsyncOpenAI", batch: list[str]) -> list[list[float]]:
95
+ response = await client.embeddings.create(
96
+ input=batch, model=self.config.embedder_model_name
97
+ )
98
+ return [data.embedding for data in response.data]