unstructured-ingest 0.0.19__py3-none-any.whl → 0.0.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of unstructured-ingest might be problematic. Click here for more details.
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/cli/cmds/astradb.py +2 -2
- unstructured_ingest/connector/astradb.py +54 -24
- unstructured_ingest/embed/bedrock.py +56 -19
- unstructured_ingest/embed/huggingface.py +22 -22
- unstructured_ingest/embed/interfaces.py +11 -4
- unstructured_ingest/embed/mixedbreadai.py +17 -17
- unstructured_ingest/embed/octoai.py +7 -7
- unstructured_ingest/embed/openai.py +15 -20
- unstructured_ingest/embed/vertexai.py +25 -17
- unstructured_ingest/embed/voyageai.py +22 -17
- unstructured_ingest/v2/cli/base/cmd.py +1 -1
- unstructured_ingest/v2/interfaces/connector.py +1 -1
- unstructured_ingest/v2/pipeline/pipeline.py +3 -1
- unstructured_ingest/v2/pipeline/steps/chunk.py +1 -1
- unstructured_ingest/v2/pipeline/steps/download.py +6 -2
- unstructured_ingest/v2/pipeline/steps/embed.py +1 -1
- unstructured_ingest/v2/pipeline/steps/filter.py +1 -1
- unstructured_ingest/v2/pipeline/steps/index.py +4 -2
- unstructured_ingest/v2/pipeline/steps/partition.py +1 -1
- unstructured_ingest/v2/pipeline/steps/stage.py +3 -1
- unstructured_ingest/v2/pipeline/steps/uncompress.py +1 -1
- unstructured_ingest/v2/pipeline/steps/upload.py +6 -2
- unstructured_ingest/v2/processes/chunker.py +8 -29
- unstructured_ingest/v2/processes/connectors/airtable.py +1 -1
- unstructured_ingest/v2/processes/connectors/astradb.py +26 -19
- unstructured_ingest/v2/processes/connectors/databricks_volumes.py +11 -8
- unstructured_ingest/v2/processes/connectors/elasticsearch.py +2 -2
- unstructured_ingest/v2/processes/connectors/fsspec/azure.py +31 -5
- unstructured_ingest/v2/processes/connectors/fsspec/box.py +31 -2
- unstructured_ingest/v2/processes/connectors/fsspec/dropbox.py +36 -8
- unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +25 -77
- unstructured_ingest/v2/processes/connectors/fsspec/gcs.py +30 -1
- unstructured_ingest/v2/processes/connectors/fsspec/s3.py +15 -18
- unstructured_ingest/v2/processes/connectors/fsspec/sftp.py +22 -1
- unstructured_ingest/v2/processes/connectors/milvus.py +2 -2
- unstructured_ingest/v2/processes/connectors/opensearch.py +2 -2
- unstructured_ingest/v2/processes/partitioner.py +9 -55
- unstructured_ingest/v2/unstructured_api.py +87 -0
- unstructured_ingest/v2/utils.py +1 -1
- unstructured_ingest-0.0.22.dist-info/METADATA +186 -0
- {unstructured_ingest-0.0.19.dist-info → unstructured_ingest-0.0.22.dist-info}/RECORD +46 -45
- {unstructured_ingest-0.0.19.dist-info → unstructured_ingest-0.0.22.dist-info}/WHEEL +1 -1
- unstructured_ingest-0.0.19.dist-info/METADATA +0 -639
- {unstructured_ingest-0.0.19.dist-info → unstructured_ingest-0.0.22.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.0.19.dist-info → unstructured_ingest-0.0.22.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.0.19.dist-info → unstructured_ingest-0.0.22.dist-info}/top_level.txt +0 -0
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "0.0.
|
|
1
|
+
__version__ = "0.0.22" # pragma: no cover
|
|
@@ -37,11 +37,11 @@ class AstraDBCliConfig(SimpleAstraDBConfig, CliConfig):
|
|
|
37
37
|
"numbers, and underscores.",
|
|
38
38
|
),
|
|
39
39
|
click.Option(
|
|
40
|
-
["--
|
|
40
|
+
["--keyspace"],
|
|
41
41
|
required=False,
|
|
42
42
|
default=None,
|
|
43
43
|
type=str,
|
|
44
|
-
help="The Astra DB connection
|
|
44
|
+
help="The Astra DB connection keyspace.",
|
|
45
45
|
),
|
|
46
46
|
]
|
|
47
47
|
return options
|
|
@@ -24,7 +24,8 @@ from unstructured_ingest.utils.data_prep import batch_generator, flatten_dict
|
|
|
24
24
|
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
25
25
|
|
|
26
26
|
if t.TYPE_CHECKING:
|
|
27
|
-
from astrapy
|
|
27
|
+
from astrapy import Collection as AstraDBCollection
|
|
28
|
+
from astrapy import Database as AstraDB
|
|
28
29
|
|
|
29
30
|
NON_INDEXED_FIELDS = ["metadata._node_content", "content"]
|
|
30
31
|
|
|
@@ -39,6 +40,7 @@ class AstraDBAccessConfig(AccessConfig):
|
|
|
39
40
|
class SimpleAstraDBConfig(BaseConnectorConfig):
|
|
40
41
|
access_config: AstraDBAccessConfig
|
|
41
42
|
collection_name: str
|
|
43
|
+
keyspace: t.Optional[str] = None
|
|
42
44
|
namespace: t.Optional[str] = None
|
|
43
45
|
|
|
44
46
|
|
|
@@ -98,22 +100,30 @@ class AstraDBSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
|
|
|
98
100
|
@requires_dependencies(["astrapy"], extras="astradb")
|
|
99
101
|
def astra_db_collection(self) -> "AstraDBCollection":
|
|
100
102
|
if self._astra_db_collection is None:
|
|
101
|
-
from astrapy
|
|
103
|
+
from astrapy import DataAPIClient as AstraDBClient
|
|
102
104
|
|
|
103
|
-
#
|
|
105
|
+
# Choose keyspace or deprecated namespace
|
|
106
|
+
keyspace_param = self.connector_config.keyspace or self.connector_config.namespace
|
|
107
|
+
|
|
108
|
+
# Create a client object to interact with the Astra DB
|
|
104
109
|
# caller_name/version for Astra DB tracking
|
|
105
|
-
|
|
106
|
-
api_endpoint=self.connector_config.access_config.api_endpoint,
|
|
107
|
-
token=self.connector_config.access_config.token,
|
|
108
|
-
namespace=self.connector_config.namespace,
|
|
110
|
+
my_client = AstraDBClient(
|
|
109
111
|
caller_name=integration_name,
|
|
110
112
|
caller_version=integration_version,
|
|
111
113
|
)
|
|
112
114
|
|
|
113
|
-
#
|
|
114
|
-
self.
|
|
115
|
-
|
|
115
|
+
# Get the database object
|
|
116
|
+
self._astra_db = my_client.get_database(
|
|
117
|
+
api_endpoint=self.connector_config.access_config.api_endpoint,
|
|
118
|
+
token=self.connector_config.access_config.token,
|
|
119
|
+
keyspace=keyspace_param,
|
|
116
120
|
)
|
|
121
|
+
|
|
122
|
+
# Create and connect to the newly created collection
|
|
123
|
+
self._astra_db_collection = self._astra_db.get_collection(
|
|
124
|
+
name=self.connector_config.collection_name,
|
|
125
|
+
)
|
|
126
|
+
|
|
117
127
|
return self._astra_db_collection # type: ignore
|
|
118
128
|
|
|
119
129
|
@requires_dependencies(["astrapy"], extras="astradb")
|
|
@@ -132,8 +142,14 @@ class AstraDBSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
|
|
|
132
142
|
@requires_dependencies(["astrapy"], extras="astradb")
|
|
133
143
|
def get_ingest_docs(self): # type: ignore
|
|
134
144
|
# Perform the find operation
|
|
135
|
-
|
|
145
|
+
astra_db_docs_cursor = self.astra_db_collection.find({})
|
|
136
146
|
|
|
147
|
+
# Iterate over the cursor
|
|
148
|
+
astra_db_docs = []
|
|
149
|
+
for result in astra_db_docs_cursor:
|
|
150
|
+
astra_db_docs.append(result)
|
|
151
|
+
|
|
152
|
+
# Create a list of AstraDBIngestDoc objects
|
|
137
153
|
doc_list = []
|
|
138
154
|
for record in astra_db_docs:
|
|
139
155
|
doc = AstraDBIngestDoc(
|
|
@@ -182,30 +198,41 @@ class AstraDBDestinationConnector(BaseDestinationConnector):
|
|
|
182
198
|
@requires_dependencies(["astrapy"], extras="astradb")
|
|
183
199
|
def astra_db_collection(self) -> "AstraDBCollection":
|
|
184
200
|
if self._astra_db_collection is None:
|
|
185
|
-
from astrapy
|
|
201
|
+
from astrapy import DataAPIClient as AstraDBClient
|
|
202
|
+
from astrapy.exceptions import CollectionAlreadyExistsException
|
|
203
|
+
|
|
204
|
+
# Choose keyspace or deprecated namespace
|
|
205
|
+
keyspace_param = self.connector_config.keyspace or self.connector_config.namespace
|
|
186
206
|
|
|
187
207
|
collection_name = self.connector_config.collection_name
|
|
188
208
|
embedding_dimension = self.write_config.embedding_dimension
|
|
189
|
-
|
|
190
|
-
# If the user has requested an indexing policy, pass it to the Astra DB
|
|
191
209
|
requested_indexing_policy = self.write_config.requested_indexing_policy
|
|
192
|
-
options = {"indexing": requested_indexing_policy} if requested_indexing_policy else None
|
|
193
210
|
|
|
211
|
+
# Create a client object to interact with the Astra DB
|
|
194
212
|
# caller_name/version for Astra DB tracking
|
|
195
|
-
|
|
196
|
-
api_endpoint=self.connector_config.access_config.api_endpoint,
|
|
197
|
-
token=self.connector_config.access_config.token,
|
|
198
|
-
namespace=self.connector_config.namespace,
|
|
213
|
+
my_client = AstraDBClient(
|
|
199
214
|
caller_name=integration_name,
|
|
200
215
|
caller_version=integration_version,
|
|
201
216
|
)
|
|
202
217
|
|
|
203
|
-
#
|
|
204
|
-
self.
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
218
|
+
# Get the database object
|
|
219
|
+
self._astra_db = my_client.get_database(
|
|
220
|
+
api_endpoint=self.connector_config.access_config.api_endpoint,
|
|
221
|
+
token=self.connector_config.access_config.token,
|
|
222
|
+
keyspace=keyspace_param,
|
|
208
223
|
)
|
|
224
|
+
|
|
225
|
+
# Create and connect to the newly created collection
|
|
226
|
+
try:
|
|
227
|
+
self._astra_db_collection = self._astra_db.create_collection(
|
|
228
|
+
name=collection_name,
|
|
229
|
+
dimension=embedding_dimension,
|
|
230
|
+
indexing=requested_indexing_policy,
|
|
231
|
+
)
|
|
232
|
+
except CollectionAlreadyExistsException as e:
|
|
233
|
+
logger.info(f"{e}", exc_info=True)
|
|
234
|
+
self._astra_db_collection = self._astra_db.get_collection(name=collection_name)
|
|
235
|
+
|
|
209
236
|
return self._astra_db_collection
|
|
210
237
|
|
|
211
238
|
@requires_dependencies(["astrapy"], extras="astradb")
|
|
@@ -224,6 +251,9 @@ class AstraDBDestinationConnector(BaseDestinationConnector):
|
|
|
224
251
|
def write_dict(self, *args, elements_dict: t.List[t.Dict[str, t.Any]], **kwargs) -> None:
|
|
225
252
|
logger.info(f"inserting / updating {len(elements_dict)} documents to Astra DB.")
|
|
226
253
|
|
|
254
|
+
if self._astra_db_collection is None:
|
|
255
|
+
raise DestinationConnectionError("Astra DB collection not available for insertion.")
|
|
256
|
+
|
|
227
257
|
astra_db_batch_size = self.write_config.batch_size
|
|
228
258
|
|
|
229
259
|
for batch in batch_generator(elements_dict, astra_db_batch_size):
|
|
@@ -1,38 +1,43 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
1
3
|
from dataclasses import dataclass
|
|
2
|
-
from typing import TYPE_CHECKING
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
3
5
|
|
|
4
6
|
import numpy as np
|
|
5
|
-
from pydantic import SecretStr
|
|
7
|
+
from pydantic import Field, SecretStr
|
|
6
8
|
|
|
7
9
|
from unstructured_ingest.embed.interfaces import BaseEmbeddingEncoder, EmbeddingConfig
|
|
8
10
|
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
9
11
|
|
|
10
12
|
if TYPE_CHECKING:
|
|
11
|
-
from
|
|
13
|
+
from botocore.client import BaseClient
|
|
14
|
+
|
|
15
|
+
class BedrockClient(BaseClient):
|
|
16
|
+
def invoke_model(self, body: str, modelId: str, trace: str) -> dict:
|
|
17
|
+
pass
|
|
12
18
|
|
|
13
19
|
|
|
14
20
|
class BedrockEmbeddingConfig(EmbeddingConfig):
|
|
15
21
|
aws_access_key_id: SecretStr
|
|
16
22
|
aws_secret_access_key: SecretStr
|
|
17
23
|
region_name: str = "us-west-2"
|
|
24
|
+
embed_model_name: str = Field(default="amazon.titan-embed-text-v1", alias="model_name")
|
|
18
25
|
|
|
19
26
|
@requires_dependencies(
|
|
20
|
-
["boto3", "numpy", "
|
|
27
|
+
["boto3", "numpy", "botocore"],
|
|
21
28
|
extras="bedrock",
|
|
22
29
|
)
|
|
23
|
-
def get_client(self) -> "
|
|
30
|
+
def get_client(self) -> "BedrockClient":
|
|
24
31
|
# delay import only when needed
|
|
25
32
|
import boto3
|
|
26
|
-
from langchain_community.embeddings import BedrockEmbeddings
|
|
27
33
|
|
|
28
|
-
|
|
34
|
+
bedrock_client = boto3.client(
|
|
29
35
|
service_name="bedrock-runtime",
|
|
30
36
|
aws_access_key_id=self.aws_access_key_id.get_secret_value(),
|
|
31
37
|
aws_secret_access_key=self.aws_secret_access_key.get_secret_value(),
|
|
32
38
|
region_name=self.region_name,
|
|
33
39
|
)
|
|
34
40
|
|
|
35
|
-
bedrock_client = BedrockEmbeddings(client=bedrock_runtime)
|
|
36
41
|
return bedrock_client
|
|
37
42
|
|
|
38
43
|
|
|
@@ -40,28 +45,60 @@ class BedrockEmbeddingConfig(EmbeddingConfig):
|
|
|
40
45
|
class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
41
46
|
config: BedrockEmbeddingConfig
|
|
42
47
|
|
|
43
|
-
def get_exemplary_embedding(self) ->
|
|
48
|
+
def get_exemplary_embedding(self) -> list[float]:
|
|
44
49
|
return self.embed_query(query="Q")
|
|
45
50
|
|
|
46
|
-
def num_of_dimensions(self):
|
|
51
|
+
def num_of_dimensions(self) -> tuple[int, ...]:
|
|
47
52
|
exemplary_embedding = self.get_exemplary_embedding()
|
|
48
53
|
return np.shape(exemplary_embedding)
|
|
49
54
|
|
|
50
|
-
def is_unit_vector(self):
|
|
55
|
+
def is_unit_vector(self) -> bool:
|
|
51
56
|
exemplary_embedding = self.get_exemplary_embedding()
|
|
52
57
|
return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
|
|
53
58
|
|
|
54
|
-
def embed_query(self, query):
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
59
|
+
def embed_query(self, query: str) -> list[float]:
|
|
60
|
+
"""Call out to Bedrock embedding endpoint."""
|
|
61
|
+
# replace newlines, which can negatively affect performance.
|
|
62
|
+
text = query.replace(os.linesep, " ")
|
|
63
|
+
|
|
64
|
+
# format input body for provider
|
|
65
|
+
provider = self.config.embed_model_name.split(".")[0]
|
|
66
|
+
input_body = {}
|
|
67
|
+
if provider == "cohere":
|
|
68
|
+
if "input_type" not in input_body:
|
|
69
|
+
input_body["input_type"] = "search_document"
|
|
70
|
+
input_body["texts"] = [text]
|
|
71
|
+
else:
|
|
72
|
+
# includes common provider == "amazon"
|
|
73
|
+
input_body["inputText"] = text
|
|
74
|
+
body = json.dumps(input_body)
|
|
75
|
+
|
|
76
|
+
try:
|
|
77
|
+
bedrock_client = self.config.get_client()
|
|
78
|
+
# invoke bedrock API
|
|
79
|
+
response = bedrock_client.invoke_model(
|
|
80
|
+
body=body,
|
|
81
|
+
modelId=self.config.embed_model_name,
|
|
82
|
+
accept="application/json",
|
|
83
|
+
contentType="application/json",
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
# format output based on provider
|
|
87
|
+
response_body = json.loads(response.get("body").read())
|
|
88
|
+
if provider == "cohere":
|
|
89
|
+
return response_body.get("embeddings")[0]
|
|
90
|
+
else:
|
|
91
|
+
# includes common provider == "amazon"
|
|
92
|
+
return response_body.get("embedding")
|
|
93
|
+
except Exception as e:
|
|
94
|
+
raise ValueError(f"Error raised by inference endpoint: {e}")
|
|
95
|
+
|
|
96
|
+
def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
97
|
+
embeddings = [self.embed_query(query=e.get("text", "")) for e in elements]
|
|
61
98
|
elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
|
|
62
99
|
return elements_with_embeddings
|
|
63
100
|
|
|
64
|
-
def _add_embeddings_to_elements(self, elements, embeddings) ->
|
|
101
|
+
def _add_embeddings_to_elements(self, elements, embeddings) -> list[dict]:
|
|
65
102
|
assert len(elements) == len(embeddings)
|
|
66
103
|
elements_w_embedding = []
|
|
67
104
|
for i, element in enumerate(elements):
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
|
-
from typing import TYPE_CHECKING,
|
|
2
|
+
from typing import TYPE_CHECKING, Optional
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
from pydantic import Field
|
|
@@ -8,7 +8,7 @@ from unstructured_ingest.embed.interfaces import BaseEmbeddingEncoder, Embedding
|
|
|
8
8
|
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
9
9
|
|
|
10
10
|
if TYPE_CHECKING:
|
|
11
|
-
from
|
|
11
|
+
from sentence_transformers import SentenceTransformer
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class HuggingFaceEmbeddingConfig(EmbeddingConfig):
|
|
@@ -19,51 +19,51 @@ class HuggingFaceEmbeddingConfig(EmbeddingConfig):
|
|
|
19
19
|
default_factory=lambda: {"device": "cpu"}, alias="model_kwargs"
|
|
20
20
|
)
|
|
21
21
|
encode_kwargs: Optional[dict] = Field(default_factory=lambda: {"normalize_embeddings": False})
|
|
22
|
-
cache_folder: Optional[
|
|
22
|
+
cache_folder: Optional[str] = Field(default=None)
|
|
23
23
|
|
|
24
24
|
@requires_dependencies(
|
|
25
|
-
["
|
|
25
|
+
["sentence_transformers"],
|
|
26
26
|
extras="embed-huggingface",
|
|
27
27
|
)
|
|
28
|
-
def get_client(self) -> "
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
model_name=self.embedder_model_name,
|
|
34
|
-
model_kwargs=self.embedder_model_kwargs,
|
|
35
|
-
encode_kwargs=self.encode_kwargs,
|
|
28
|
+
def get_client(self) -> "SentenceTransformer":
|
|
29
|
+
from sentence_transformers import SentenceTransformer
|
|
30
|
+
|
|
31
|
+
return SentenceTransformer(
|
|
32
|
+
model_name_or_path=self.embedder_model_name,
|
|
36
33
|
cache_folder=self.cache_folder,
|
|
34
|
+
**self.embedder_model_kwargs,
|
|
37
35
|
)
|
|
38
|
-
return client
|
|
39
36
|
|
|
40
37
|
|
|
41
38
|
@dataclass
|
|
42
39
|
class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
43
40
|
config: HuggingFaceEmbeddingConfig
|
|
44
41
|
|
|
45
|
-
def get_exemplary_embedding(self) ->
|
|
42
|
+
def get_exemplary_embedding(self) -> list[float]:
|
|
46
43
|
return self.embed_query(query="Q")
|
|
47
44
|
|
|
48
|
-
def num_of_dimensions(self):
|
|
45
|
+
def num_of_dimensions(self) -> tuple[int, ...]:
|
|
49
46
|
exemplary_embedding = self.get_exemplary_embedding()
|
|
50
47
|
return np.shape(exemplary_embedding)
|
|
51
48
|
|
|
52
|
-
def is_unit_vector(self):
|
|
49
|
+
def is_unit_vector(self) -> bool:
|
|
53
50
|
exemplary_embedding = self.get_exemplary_embedding()
|
|
54
51
|
return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
|
|
55
52
|
|
|
56
|
-
def embed_query(self, query):
|
|
57
|
-
|
|
58
|
-
return client.embed_query(str(query))
|
|
53
|
+
def embed_query(self, query: str) -> list[float]:
|
|
54
|
+
return self._embed_documents(texts=[query])[0]
|
|
59
55
|
|
|
60
|
-
def
|
|
56
|
+
def _embed_documents(self, texts: list[str]) -> list[list[float]]:
|
|
61
57
|
client = self.config.get_client()
|
|
62
|
-
embeddings = client.
|
|
58
|
+
embeddings = client.encode(texts, **self.config.encode_kwargs)
|
|
59
|
+
return embeddings.tolist()
|
|
60
|
+
|
|
61
|
+
def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
62
|
+
embeddings = self._embed_documents([e.get("text", "") for e in elements])
|
|
63
63
|
elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
|
|
64
64
|
return elements_with_embeddings
|
|
65
65
|
|
|
66
|
-
def _add_embeddings_to_elements(self, elements: list[dict], embeddings: list) ->
|
|
66
|
+
def _add_embeddings_to_elements(self, elements: list[dict], embeddings: list) -> list[dict]:
|
|
67
67
|
assert len(elements) == len(embeddings)
|
|
68
68
|
elements_w_embedding = []
|
|
69
69
|
|
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
2
|
from dataclasses import dataclass
|
|
3
|
-
from typing import List, Tuple
|
|
4
3
|
|
|
5
4
|
from pydantic import BaseModel
|
|
6
5
|
|
|
@@ -19,7 +18,7 @@ class BaseEmbeddingEncoder(ABC):
|
|
|
19
18
|
|
|
20
19
|
@property
|
|
21
20
|
@abstractmethod
|
|
22
|
-
def num_of_dimensions(self) ->
|
|
21
|
+
def num_of_dimensions(self) -> tuple[int, ...]:
|
|
23
22
|
"""Number of dimensions for the embedding vector."""
|
|
24
23
|
|
|
25
24
|
@property
|
|
@@ -28,9 +27,17 @@ class BaseEmbeddingEncoder(ABC):
|
|
|
28
27
|
"""Denotes if the embedding vector is a unit vector."""
|
|
29
28
|
|
|
30
29
|
@abstractmethod
|
|
31
|
-
def embed_documents(self, elements:
|
|
30
|
+
def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
32
31
|
pass
|
|
33
32
|
|
|
34
33
|
@abstractmethod
|
|
35
|
-
def embed_query(self, query: str) ->
|
|
34
|
+
def embed_query(self, query: str) -> list[float]:
|
|
36
35
|
pass
|
|
36
|
+
|
|
37
|
+
def _embed_documents(self, elements: list[str]) -> list[list[float]]:
|
|
38
|
+
results = []
|
|
39
|
+
for text in elements:
|
|
40
|
+
response = self.embed_query(query=text)
|
|
41
|
+
results.append(response)
|
|
42
|
+
|
|
43
|
+
return results
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import os
|
|
2
2
|
from dataclasses import dataclass, field
|
|
3
|
-
from typing import TYPE_CHECKING,
|
|
3
|
+
from typing import TYPE_CHECKING, Optional
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
from pydantic import Field, SecretStr
|
|
@@ -67,10 +67,10 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
67
67
|
|
|
68
68
|
config: MixedbreadAIEmbeddingConfig
|
|
69
69
|
|
|
70
|
-
_exemplary_embedding: Optional[
|
|
70
|
+
_exemplary_embedding: Optional[list[float]] = field(init=False, default=None)
|
|
71
71
|
_request_options: Optional["RequestOptions"] = field(init=False, default=None)
|
|
72
72
|
|
|
73
|
-
def get_exemplary_embedding(self) ->
|
|
73
|
+
def get_exemplary_embedding(self) -> list[float]:
|
|
74
74
|
"""Get an exemplary embedding to determine dimensions and unit vector status."""
|
|
75
75
|
return self._embed(["Q"])[0]
|
|
76
76
|
|
|
@@ -91,7 +91,7 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
91
91
|
)
|
|
92
92
|
|
|
93
93
|
@property
|
|
94
|
-
def num_of_dimensions(self):
|
|
94
|
+
def num_of_dimensions(self) -> tuple[int, ...]:
|
|
95
95
|
"""Get the number of dimensions for the embeddings."""
|
|
96
96
|
exemplary_embedding = self.get_exemplary_embedding()
|
|
97
97
|
return np.shape(exemplary_embedding)
|
|
@@ -102,15 +102,15 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
102
102
|
exemplary_embedding = self.get_exemplary_embedding()
|
|
103
103
|
return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
|
|
104
104
|
|
|
105
|
-
def _embed(self, texts:
|
|
105
|
+
def _embed(self, texts: list[str]) -> list[list[float]]:
|
|
106
106
|
"""
|
|
107
107
|
Embed a list of texts using the Mixedbread AI API.
|
|
108
108
|
|
|
109
109
|
Args:
|
|
110
|
-
texts (
|
|
110
|
+
texts (list[str]): List of texts to embed.
|
|
111
111
|
|
|
112
112
|
Returns:
|
|
113
|
-
|
|
113
|
+
list[list[float]]: List of embeddings.
|
|
114
114
|
"""
|
|
115
115
|
batch_size = BATCH_SIZE
|
|
116
116
|
batch_itr = range(0, len(texts), batch_size)
|
|
@@ -132,17 +132,17 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
132
132
|
|
|
133
133
|
@staticmethod
|
|
134
134
|
def _add_embeddings_to_elements(
|
|
135
|
-
elements:
|
|
136
|
-
) ->
|
|
135
|
+
elements: list[dict], embeddings: list[list[float]]
|
|
136
|
+
) -> list[dict]:
|
|
137
137
|
"""
|
|
138
138
|
Add embeddings to elements.
|
|
139
139
|
|
|
140
140
|
Args:
|
|
141
|
-
elements (
|
|
142
|
-
embeddings (
|
|
141
|
+
elements (list[Element]): List of elements.
|
|
142
|
+
embeddings (list[list[float]]): List of embeddings.
|
|
143
143
|
|
|
144
144
|
Returns:
|
|
145
|
-
|
|
145
|
+
list[Element]: Elements with embeddings added.
|
|
146
146
|
"""
|
|
147
147
|
assert len(elements) == len(embeddings)
|
|
148
148
|
elements_w_embedding = []
|
|
@@ -151,20 +151,20 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
151
151
|
elements_w_embedding.append(element)
|
|
152
152
|
return elements
|
|
153
153
|
|
|
154
|
-
def embed_documents(self, elements:
|
|
154
|
+
def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
155
155
|
"""
|
|
156
156
|
Embed a list of document elements.
|
|
157
157
|
|
|
158
158
|
Args:
|
|
159
|
-
elements (
|
|
159
|
+
elements (list[Element]): List of document elements.
|
|
160
160
|
|
|
161
161
|
Returns:
|
|
162
|
-
|
|
162
|
+
list[Element]: Elements with embeddings.
|
|
163
163
|
"""
|
|
164
164
|
embeddings = self._embed([e.get("text", "") for e in elements])
|
|
165
165
|
return self._add_embeddings_to_elements(elements, embeddings)
|
|
166
166
|
|
|
167
|
-
def embed_query(self, query: str) ->
|
|
167
|
+
def embed_query(self, query: str) -> list[float]:
|
|
168
168
|
"""
|
|
169
169
|
Embed a query string.
|
|
170
170
|
|
|
@@ -172,6 +172,6 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
172
172
|
query (str): Query string to embed.
|
|
173
173
|
|
|
174
174
|
Returns:
|
|
175
|
-
|
|
175
|
+
list[float]: Embedding of the query.
|
|
176
176
|
"""
|
|
177
177
|
return self._embed([query])[0]
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from dataclasses import dataclass, field
|
|
2
|
-
from typing import TYPE_CHECKING,
|
|
2
|
+
from typing import TYPE_CHECKING, Optional
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
from pydantic import Field, SecretStr
|
|
@@ -31,16 +31,16 @@ class OctoAiEmbeddingConfig(EmbeddingConfig):
|
|
|
31
31
|
class OctoAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
32
32
|
config: OctoAiEmbeddingConfig
|
|
33
33
|
# Uses the OpenAI SDK
|
|
34
|
-
_exemplary_embedding: Optional[
|
|
34
|
+
_exemplary_embedding: Optional[list[float]] = field(init=False, default=None)
|
|
35
35
|
|
|
36
|
-
def get_exemplary_embedding(self) ->
|
|
36
|
+
def get_exemplary_embedding(self) -> list[float]:
|
|
37
37
|
return self.embed_query("Q")
|
|
38
38
|
|
|
39
|
-
def num_of_dimensions(self):
|
|
39
|
+
def num_of_dimensions(self) -> tuple[int, ...]:
|
|
40
40
|
exemplary_embedding = self.get_exemplary_embedding()
|
|
41
41
|
return np.shape(exemplary_embedding)
|
|
42
42
|
|
|
43
|
-
def is_unit_vector(self):
|
|
43
|
+
def is_unit_vector(self) -> bool:
|
|
44
44
|
exemplary_embedding = self.get_exemplary_embedding()
|
|
45
45
|
return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
|
|
46
46
|
|
|
@@ -49,12 +49,12 @@ class OctoAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
49
49
|
response = client.embeddings.create(input=query, model=self.config.embedder_model_name)
|
|
50
50
|
return response.data[0].embedding
|
|
51
51
|
|
|
52
|
-
def embed_documents(self, elements:
|
|
52
|
+
def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
53
53
|
embeddings = [self.embed_query(e.get("text", "")) for e in elements]
|
|
54
54
|
elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
|
|
55
55
|
return elements_with_embeddings
|
|
56
56
|
|
|
57
|
-
def _add_embeddings_to_elements(self, elements, embeddings) ->
|
|
57
|
+
def _add_embeddings_to_elements(self, elements, embeddings) -> list[dict]:
|
|
58
58
|
assert len(elements) == len(embeddings)
|
|
59
59
|
elements_w_embedding = []
|
|
60
60
|
for i, element in enumerate(elements):
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
|
-
from typing import TYPE_CHECKING
|
|
2
|
+
from typing import TYPE_CHECKING
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
from pydantic import Field, SecretStr
|
|
@@ -8,51 +8,46 @@ from unstructured_ingest.embed.interfaces import BaseEmbeddingEncoder, Embedding
|
|
|
8
8
|
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
9
9
|
|
|
10
10
|
if TYPE_CHECKING:
|
|
11
|
-
from
|
|
11
|
+
from openai import OpenAI
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class OpenAIEmbeddingConfig(EmbeddingConfig):
|
|
15
15
|
api_key: SecretStr
|
|
16
16
|
embedder_model_name: str = Field(default="text-embedding-ada-002", alias="model_name")
|
|
17
17
|
|
|
18
|
-
@requires_dependencies(["
|
|
19
|
-
def get_client(self) -> "
|
|
20
|
-
|
|
21
|
-
from langchain_openai import OpenAIEmbeddings
|
|
18
|
+
@requires_dependencies(["openai"], extras="openai")
|
|
19
|
+
def get_client(self) -> "OpenAI":
|
|
20
|
+
from openai import OpenAI
|
|
22
21
|
|
|
23
|
-
|
|
24
|
-
openai_api_key=self.api_key.get_secret_value(),
|
|
25
|
-
model=self.embedder_model_name, # type:ignore
|
|
26
|
-
)
|
|
27
|
-
return openai_client
|
|
22
|
+
return OpenAI(api_key=self.api_key.get_secret_value())
|
|
28
23
|
|
|
29
24
|
|
|
30
25
|
@dataclass
|
|
31
26
|
class OpenAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
32
27
|
config: OpenAIEmbeddingConfig
|
|
33
28
|
|
|
34
|
-
def get_exemplary_embedding(self) ->
|
|
29
|
+
def get_exemplary_embedding(self) -> list[float]:
|
|
35
30
|
return self.embed_query(query="Q")
|
|
36
31
|
|
|
37
|
-
def num_of_dimensions(self):
|
|
32
|
+
def num_of_dimensions(self) -> tuple[int, ...]:
|
|
38
33
|
exemplary_embedding = self.get_exemplary_embedding()
|
|
39
34
|
return np.shape(exemplary_embedding)
|
|
40
35
|
|
|
41
|
-
def is_unit_vector(self):
|
|
36
|
+
def is_unit_vector(self) -> bool:
|
|
42
37
|
exemplary_embedding = self.get_exemplary_embedding()
|
|
43
38
|
return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
|
|
44
39
|
|
|
45
|
-
def embed_query(self, query):
|
|
40
|
+
def embed_query(self, query: str) -> list[float]:
|
|
46
41
|
client = self.config.get_client()
|
|
47
|
-
|
|
42
|
+
response = client.embeddings.create(input=query, model=self.config.embedder_model_name)
|
|
43
|
+
return response.data[0].embedding
|
|
48
44
|
|
|
49
|
-
def embed_documents(self, elements:
|
|
50
|
-
|
|
51
|
-
embeddings = client.embed_documents([e.get("text", "") for e in elements])
|
|
45
|
+
def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
46
|
+
embeddings = self._embed_documents([e.get("text", "") for e in elements])
|
|
52
47
|
elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
|
|
53
48
|
return elements_with_embeddings
|
|
54
49
|
|
|
55
|
-
def _add_embeddings_to_elements(self, elements, embeddings) ->
|
|
50
|
+
def _add_embeddings_to_elements(self, elements, embeddings) -> list[dict]:
|
|
56
51
|
assert len(elements) == len(embeddings)
|
|
57
52
|
elements_w_embedding = []
|
|
58
53
|
for i, element in enumerate(elements):
|