unstructured-ingest 0.0.19__py3-none-any.whl → 0.0.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of unstructured-ingest might be problematic; see the registry's advisory for this release for more details.

Files changed (47)
  1. unstructured_ingest/__version__.py +1 -1
  2. unstructured_ingest/cli/cmds/astradb.py +2 -2
  3. unstructured_ingest/connector/astradb.py +54 -24
  4. unstructured_ingest/embed/bedrock.py +56 -19
  5. unstructured_ingest/embed/huggingface.py +22 -22
  6. unstructured_ingest/embed/interfaces.py +11 -4
  7. unstructured_ingest/embed/mixedbreadai.py +17 -17
  8. unstructured_ingest/embed/octoai.py +7 -7
  9. unstructured_ingest/embed/openai.py +15 -20
  10. unstructured_ingest/embed/vertexai.py +25 -17
  11. unstructured_ingest/embed/voyageai.py +22 -17
  12. unstructured_ingest/v2/cli/base/cmd.py +1 -1
  13. unstructured_ingest/v2/interfaces/connector.py +1 -1
  14. unstructured_ingest/v2/pipeline/pipeline.py +3 -1
  15. unstructured_ingest/v2/pipeline/steps/chunk.py +1 -1
  16. unstructured_ingest/v2/pipeline/steps/download.py +6 -2
  17. unstructured_ingest/v2/pipeline/steps/embed.py +1 -1
  18. unstructured_ingest/v2/pipeline/steps/filter.py +1 -1
  19. unstructured_ingest/v2/pipeline/steps/index.py +4 -2
  20. unstructured_ingest/v2/pipeline/steps/partition.py +1 -1
  21. unstructured_ingest/v2/pipeline/steps/stage.py +3 -1
  22. unstructured_ingest/v2/pipeline/steps/uncompress.py +1 -1
  23. unstructured_ingest/v2/pipeline/steps/upload.py +6 -2
  24. unstructured_ingest/v2/processes/chunker.py +8 -29
  25. unstructured_ingest/v2/processes/connectors/airtable.py +1 -1
  26. unstructured_ingest/v2/processes/connectors/astradb.py +26 -19
  27. unstructured_ingest/v2/processes/connectors/databricks_volumes.py +11 -8
  28. unstructured_ingest/v2/processes/connectors/elasticsearch.py +2 -2
  29. unstructured_ingest/v2/processes/connectors/fsspec/azure.py +31 -5
  30. unstructured_ingest/v2/processes/connectors/fsspec/box.py +31 -2
  31. unstructured_ingest/v2/processes/connectors/fsspec/dropbox.py +36 -8
  32. unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +25 -77
  33. unstructured_ingest/v2/processes/connectors/fsspec/gcs.py +30 -1
  34. unstructured_ingest/v2/processes/connectors/fsspec/s3.py +15 -18
  35. unstructured_ingest/v2/processes/connectors/fsspec/sftp.py +22 -1
  36. unstructured_ingest/v2/processes/connectors/milvus.py +2 -2
  37. unstructured_ingest/v2/processes/connectors/opensearch.py +2 -2
  38. unstructured_ingest/v2/processes/partitioner.py +9 -55
  39. unstructured_ingest/v2/unstructured_api.py +87 -0
  40. unstructured_ingest/v2/utils.py +1 -1
  41. unstructured_ingest-0.0.22.dist-info/METADATA +186 -0
  42. {unstructured_ingest-0.0.19.dist-info → unstructured_ingest-0.0.22.dist-info}/RECORD +46 -45
  43. {unstructured_ingest-0.0.19.dist-info → unstructured_ingest-0.0.22.dist-info}/WHEEL +1 -1
  44. unstructured_ingest-0.0.19.dist-info/METADATA +0 -639
  45. {unstructured_ingest-0.0.19.dist-info → unstructured_ingest-0.0.22.dist-info}/LICENSE.md +0 -0
  46. {unstructured_ingest-0.0.19.dist-info → unstructured_ingest-0.0.22.dist-info}/entry_points.txt +0 -0
  47. {unstructured_ingest-0.0.19.dist-info → unstructured_ingest-0.0.22.dist-info}/top_level.txt +0 -0
@@ -1 +1 @@
1
- __version__ = "0.0.19" # pragma: no cover
1
+ __version__ = "0.0.22" # pragma: no cover
@@ -37,11 +37,11 @@ class AstraDBCliConfig(SimpleAstraDBConfig, CliConfig):
37
37
  "numbers, and underscores.",
38
38
  ),
39
39
  click.Option(
40
- ["--namespace"],
40
+ ["--keyspace"],
41
41
  required=False,
42
42
  default=None,
43
43
  type=str,
44
- help="The Astra DB connection namespace.",
44
+ help="The Astra DB connection keyspace.",
45
45
  ),
46
46
  ]
47
47
  return options
@@ -24,7 +24,8 @@ from unstructured_ingest.utils.data_prep import batch_generator, flatten_dict
24
24
  from unstructured_ingest.utils.dep_check import requires_dependencies
25
25
 
26
26
  if t.TYPE_CHECKING:
27
- from astrapy.db import AstraDB, AstraDBCollection
27
+ from astrapy import Collection as AstraDBCollection
28
+ from astrapy import Database as AstraDB
28
29
 
29
30
  NON_INDEXED_FIELDS = ["metadata._node_content", "content"]
30
31
 
@@ -39,6 +40,7 @@ class AstraDBAccessConfig(AccessConfig):
39
40
  class SimpleAstraDBConfig(BaseConnectorConfig):
40
41
  access_config: AstraDBAccessConfig
41
42
  collection_name: str
43
+ keyspace: t.Optional[str] = None
42
44
  namespace: t.Optional[str] = None
43
45
 
44
46
 
@@ -98,22 +100,30 @@ class AstraDBSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
98
100
  @requires_dependencies(["astrapy"], extras="astradb")
99
101
  def astra_db_collection(self) -> "AstraDBCollection":
100
102
  if self._astra_db_collection is None:
101
- from astrapy.db import AstraDB
103
+ from astrapy import DataAPIClient as AstraDBClient
102
104
 
103
- # Build the Astra DB object.
105
+ # Choose keyspace or deprecated namespace
106
+ keyspace_param = self.connector_config.keyspace or self.connector_config.namespace
107
+
108
+ # Create a client object to interact with the Astra DB
104
109
  # caller_name/version for Astra DB tracking
105
- self._astra_db = AstraDB(
106
- api_endpoint=self.connector_config.access_config.api_endpoint,
107
- token=self.connector_config.access_config.token,
108
- namespace=self.connector_config.namespace,
110
+ my_client = AstraDBClient(
109
111
  caller_name=integration_name,
110
112
  caller_version=integration_version,
111
113
  )
112
114
 
113
- # Create and connect to the collection
114
- self._astra_db_collection = self._astra_db.collection(
115
- collection_name=self.connector_config.collection_name,
115
+ # Get the database object
116
+ self._astra_db = my_client.get_database(
117
+ api_endpoint=self.connector_config.access_config.api_endpoint,
118
+ token=self.connector_config.access_config.token,
119
+ keyspace=keyspace_param,
116
120
  )
121
+
122
+ # Create and connect to the newly created collection
123
+ self._astra_db_collection = self._astra_db.get_collection(
124
+ name=self.connector_config.collection_name,
125
+ )
126
+
117
127
  return self._astra_db_collection # type: ignore
118
128
 
119
129
  @requires_dependencies(["astrapy"], extras="astradb")
@@ -132,8 +142,14 @@ class AstraDBSourceConnector(SourceConnectorCleanupMixin, BaseSourceConnector):
132
142
  @requires_dependencies(["astrapy"], extras="astradb")
133
143
  def get_ingest_docs(self): # type: ignore
134
144
  # Perform the find operation
135
- astra_db_docs = list(self.astra_db_collection.paginated_find())
145
+ astra_db_docs_cursor = self.astra_db_collection.find({})
136
146
 
147
+ # Iterate over the cursor
148
+ astra_db_docs = []
149
+ for result in astra_db_docs_cursor:
150
+ astra_db_docs.append(result)
151
+
152
+ # Create a list of AstraDBIngestDoc objects
137
153
  doc_list = []
138
154
  for record in astra_db_docs:
139
155
  doc = AstraDBIngestDoc(
@@ -182,30 +198,41 @@ class AstraDBDestinationConnector(BaseDestinationConnector):
182
198
  @requires_dependencies(["astrapy"], extras="astradb")
183
199
  def astra_db_collection(self) -> "AstraDBCollection":
184
200
  if self._astra_db_collection is None:
185
- from astrapy.db import AstraDB
201
+ from astrapy import DataAPIClient as AstraDBClient
202
+ from astrapy.exceptions import CollectionAlreadyExistsException
203
+
204
+ # Choose keyspace or deprecated namespace
205
+ keyspace_param = self.connector_config.keyspace or self.connector_config.namespace
186
206
 
187
207
  collection_name = self.connector_config.collection_name
188
208
  embedding_dimension = self.write_config.embedding_dimension
189
-
190
- # If the user has requested an indexing policy, pass it to the Astra DB
191
209
  requested_indexing_policy = self.write_config.requested_indexing_policy
192
- options = {"indexing": requested_indexing_policy} if requested_indexing_policy else None
193
210
 
211
+ # Create a client object to interact with the Astra DB
194
212
  # caller_name/version for Astra DB tracking
195
- self._astra_db = AstraDB(
196
- api_endpoint=self.connector_config.access_config.api_endpoint,
197
- token=self.connector_config.access_config.token,
198
- namespace=self.connector_config.namespace,
213
+ my_client = AstraDBClient(
199
214
  caller_name=integration_name,
200
215
  caller_version=integration_version,
201
216
  )
202
217
 
203
- # Create and connect to the newly created collection
204
- self._astra_db_collection = self._astra_db.create_collection(
205
- collection_name=collection_name,
206
- dimension=embedding_dimension,
207
- options=options,
218
+ # Get the database object
219
+ self._astra_db = my_client.get_database(
220
+ api_endpoint=self.connector_config.access_config.api_endpoint,
221
+ token=self.connector_config.access_config.token,
222
+ keyspace=keyspace_param,
208
223
  )
224
+
225
+ # Create and connect to the newly created collection
226
+ try:
227
+ self._astra_db_collection = self._astra_db.create_collection(
228
+ name=collection_name,
229
+ dimension=embedding_dimension,
230
+ indexing=requested_indexing_policy,
231
+ )
232
+ except CollectionAlreadyExistsException as e:
233
+ logger.info(f"{e}", exc_info=True)
234
+ self._astra_db_collection = self._astra_db.get_collection(name=collection_name)
235
+
209
236
  return self._astra_db_collection
210
237
 
211
238
  @requires_dependencies(["astrapy"], extras="astradb")
@@ -224,6 +251,9 @@ class AstraDBDestinationConnector(BaseDestinationConnector):
224
251
  def write_dict(self, *args, elements_dict: t.List[t.Dict[str, t.Any]], **kwargs) -> None:
225
252
  logger.info(f"inserting / updating {len(elements_dict)} documents to Astra DB.")
226
253
 
254
+ if self._astra_db_collection is None:
255
+ raise DestinationConnectionError("Astra DB collection not available for insertion.")
256
+
227
257
  astra_db_batch_size = self.write_config.batch_size
228
258
 
229
259
  for batch in batch_generator(elements_dict, astra_db_batch_size):
@@ -1,38 +1,43 @@
1
+ import json
2
+ import os
1
3
  from dataclasses import dataclass
2
- from typing import TYPE_CHECKING, List
4
+ from typing import TYPE_CHECKING
3
5
 
4
6
  import numpy as np
5
- from pydantic import SecretStr
7
+ from pydantic import Field, SecretStr
6
8
 
7
9
  from unstructured_ingest.embed.interfaces import BaseEmbeddingEncoder, EmbeddingConfig
8
10
  from unstructured_ingest.utils.dep_check import requires_dependencies
9
11
 
10
12
  if TYPE_CHECKING:
11
- from langchain_community.embeddings import BedrockEmbeddings
13
+ from botocore.client import BaseClient
14
+
15
+ class BedrockClient(BaseClient):
16
+ def invoke_model(self, body: str, modelId: str, trace: str) -> dict:
17
+ pass
12
18
 
13
19
 
14
20
  class BedrockEmbeddingConfig(EmbeddingConfig):
15
21
  aws_access_key_id: SecretStr
16
22
  aws_secret_access_key: SecretStr
17
23
  region_name: str = "us-west-2"
24
+ embed_model_name: str = Field(default="amazon.titan-embed-text-v1", alias="model_name")
18
25
 
19
26
  @requires_dependencies(
20
- ["boto3", "numpy", "langchain_community"],
27
+ ["boto3", "numpy", "botocore"],
21
28
  extras="bedrock",
22
29
  )
23
- def get_client(self) -> "BedrockEmbeddings":
30
+ def get_client(self) -> "BedrockClient":
24
31
  # delay import only when needed
25
32
  import boto3
26
- from langchain_community.embeddings import BedrockEmbeddings
27
33
 
28
- bedrock_runtime = boto3.client(
34
+ bedrock_client = boto3.client(
29
35
  service_name="bedrock-runtime",
30
36
  aws_access_key_id=self.aws_access_key_id.get_secret_value(),
31
37
  aws_secret_access_key=self.aws_secret_access_key.get_secret_value(),
32
38
  region_name=self.region_name,
33
39
  )
34
40
 
35
- bedrock_client = BedrockEmbeddings(client=bedrock_runtime)
36
41
  return bedrock_client
37
42
 
38
43
 
@@ -40,28 +45,60 @@ class BedrockEmbeddingConfig(EmbeddingConfig):
40
45
  class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
41
46
  config: BedrockEmbeddingConfig
42
47
 
43
- def get_exemplary_embedding(self) -> List[float]:
48
+ def get_exemplary_embedding(self) -> list[float]:
44
49
  return self.embed_query(query="Q")
45
50
 
46
- def num_of_dimensions(self):
51
+ def num_of_dimensions(self) -> tuple[int, ...]:
47
52
  exemplary_embedding = self.get_exemplary_embedding()
48
53
  return np.shape(exemplary_embedding)
49
54
 
50
- def is_unit_vector(self):
55
+ def is_unit_vector(self) -> bool:
51
56
  exemplary_embedding = self.get_exemplary_embedding()
52
57
  return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
53
58
 
54
- def embed_query(self, query):
55
- bedrock_client = self.config.get_client()
56
- return np.array(bedrock_client.embed_query(query))
57
-
58
- def embed_documents(self, elements: List[dict]) -> List[dict]:
59
- bedrock_client = self.config.get_client()
60
- embeddings = bedrock_client.embed_documents([e.get("text", "") for e in elements])
59
+ def embed_query(self, query: str) -> list[float]:
60
+ """Call out to Bedrock embedding endpoint."""
61
+ # replace newlines, which can negatively affect performance.
62
+ text = query.replace(os.linesep, " ")
63
+
64
+ # format input body for provider
65
+ provider = self.config.embed_model_name.split(".")[0]
66
+ input_body = {}
67
+ if provider == "cohere":
68
+ if "input_type" not in input_body:
69
+ input_body["input_type"] = "search_document"
70
+ input_body["texts"] = [text]
71
+ else:
72
+ # includes common provider == "amazon"
73
+ input_body["inputText"] = text
74
+ body = json.dumps(input_body)
75
+
76
+ try:
77
+ bedrock_client = self.config.get_client()
78
+ # invoke bedrock API
79
+ response = bedrock_client.invoke_model(
80
+ body=body,
81
+ modelId=self.config.embed_model_name,
82
+ accept="application/json",
83
+ contentType="application/json",
84
+ )
85
+
86
+ # format output based on provider
87
+ response_body = json.loads(response.get("body").read())
88
+ if provider == "cohere":
89
+ return response_body.get("embeddings")[0]
90
+ else:
91
+ # includes common provider == "amazon"
92
+ return response_body.get("embedding")
93
+ except Exception as e:
94
+ raise ValueError(f"Error raised by inference endpoint: {e}")
95
+
96
+ def embed_documents(self, elements: list[dict]) -> list[dict]:
97
+ embeddings = [self.embed_query(query=e.get("text", "")) for e in elements]
61
98
  elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
62
99
  return elements_with_embeddings
63
100
 
64
- def _add_embeddings_to_elements(self, elements, embeddings) -> List[dict]:
101
+ def _add_embeddings_to_elements(self, elements, embeddings) -> list[dict]:
65
102
  assert len(elements) == len(embeddings)
66
103
  elements_w_embedding = []
67
104
  for i, element in enumerate(elements):
@@ -1,5 +1,5 @@
1
1
  from dataclasses import dataclass
2
- from typing import TYPE_CHECKING, List, Optional
2
+ from typing import TYPE_CHECKING, Optional
3
3
 
4
4
  import numpy as np
5
5
  from pydantic import Field
@@ -8,7 +8,7 @@ from unstructured_ingest.embed.interfaces import BaseEmbeddingEncoder, Embedding
8
8
  from unstructured_ingest.utils.dep_check import requires_dependencies
9
9
 
10
10
  if TYPE_CHECKING:
11
- from langchain_huggingface.embeddings import HuggingFaceEmbeddings
11
+ from sentence_transformers import SentenceTransformer
12
12
 
13
13
 
14
14
  class HuggingFaceEmbeddingConfig(EmbeddingConfig):
@@ -19,51 +19,51 @@ class HuggingFaceEmbeddingConfig(EmbeddingConfig):
19
19
  default_factory=lambda: {"device": "cpu"}, alias="model_kwargs"
20
20
  )
21
21
  encode_kwargs: Optional[dict] = Field(default_factory=lambda: {"normalize_embeddings": False})
22
- cache_folder: Optional[dict] = Field(default=None)
22
+ cache_folder: Optional[str] = Field(default=None)
23
23
 
24
24
  @requires_dependencies(
25
- ["langchain_huggingface"],
25
+ ["sentence_transformers"],
26
26
  extras="embed-huggingface",
27
27
  )
28
- def get_client(self) -> "HuggingFaceEmbeddings":
29
- """Creates a langchain Huggingface python client to embed elements."""
30
- from langchain_huggingface.embeddings import HuggingFaceEmbeddings
31
-
32
- client = HuggingFaceEmbeddings(
33
- model_name=self.embedder_model_name,
34
- model_kwargs=self.embedder_model_kwargs,
35
- encode_kwargs=self.encode_kwargs,
28
+ def get_client(self) -> "SentenceTransformer":
29
+ from sentence_transformers import SentenceTransformer
30
+
31
+ return SentenceTransformer(
32
+ model_name_or_path=self.embedder_model_name,
36
33
  cache_folder=self.cache_folder,
34
+ **self.embedder_model_kwargs,
37
35
  )
38
- return client
39
36
 
40
37
 
41
38
  @dataclass
42
39
  class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
43
40
  config: HuggingFaceEmbeddingConfig
44
41
 
45
- def get_exemplary_embedding(self) -> List[float]:
42
+ def get_exemplary_embedding(self) -> list[float]:
46
43
  return self.embed_query(query="Q")
47
44
 
48
- def num_of_dimensions(self):
45
+ def num_of_dimensions(self) -> tuple[int, ...]:
49
46
  exemplary_embedding = self.get_exemplary_embedding()
50
47
  return np.shape(exemplary_embedding)
51
48
 
52
- def is_unit_vector(self):
49
+ def is_unit_vector(self) -> bool:
53
50
  exemplary_embedding = self.get_exemplary_embedding()
54
51
  return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
55
52
 
56
- def embed_query(self, query):
57
- client = self.config.get_client()
58
- return client.embed_query(str(query))
53
+ def embed_query(self, query: str) -> list[float]:
54
+ return self._embed_documents(texts=[query])[0]
59
55
 
60
- def embed_documents(self, elements: List[dict]) -> List[dict]:
56
+ def _embed_documents(self, texts: list[str]) -> list[list[float]]:
61
57
  client = self.config.get_client()
62
- embeddings = client.embed_documents([e.get("text", "") for e in elements])
58
+ embeddings = client.encode(texts, **self.config.encode_kwargs)
59
+ return embeddings.tolist()
60
+
61
+ def embed_documents(self, elements: list[dict]) -> list[dict]:
62
+ embeddings = self._embed_documents([e.get("text", "") for e in elements])
63
63
  elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
64
64
  return elements_with_embeddings
65
65
 
66
- def _add_embeddings_to_elements(self, elements: list[dict], embeddings: list) -> List[dict]:
66
+ def _add_embeddings_to_elements(self, elements: list[dict], embeddings: list) -> list[dict]:
67
67
  assert len(elements) == len(embeddings)
68
68
  elements_w_embedding = []
69
69
 
@@ -1,6 +1,5 @@
1
1
  from abc import ABC, abstractmethod
2
2
  from dataclasses import dataclass
3
- from typing import List, Tuple
4
3
 
5
4
  from pydantic import BaseModel
6
5
 
@@ -19,7 +18,7 @@ class BaseEmbeddingEncoder(ABC):
19
18
 
20
19
  @property
21
20
  @abstractmethod
22
- def num_of_dimensions(self) -> Tuple[int]:
21
+ def num_of_dimensions(self) -> tuple[int, ...]:
23
22
  """Number of dimensions for the embedding vector."""
24
23
 
25
24
  @property
@@ -28,9 +27,17 @@ class BaseEmbeddingEncoder(ABC):
28
27
  """Denotes if the embedding vector is a unit vector."""
29
28
 
30
29
  @abstractmethod
31
- def embed_documents(self, elements: List[dict]) -> List[dict]:
30
+ def embed_documents(self, elements: list[dict]) -> list[dict]:
32
31
  pass
33
32
 
34
33
  @abstractmethod
35
- def embed_query(self, query: str) -> List[float]:
34
+ def embed_query(self, query: str) -> list[float]:
36
35
  pass
36
+
37
+ def _embed_documents(self, elements: list[str]) -> list[list[float]]:
38
+ results = []
39
+ for text in elements:
40
+ response = self.embed_query(query=text)
41
+ results.append(response)
42
+
43
+ return results
@@ -1,6 +1,6 @@
1
1
  import os
2
2
  from dataclasses import dataclass, field
3
- from typing import TYPE_CHECKING, List, Optional
3
+ from typing import TYPE_CHECKING, Optional
4
4
 
5
5
  import numpy as np
6
6
  from pydantic import Field, SecretStr
@@ -67,10 +67,10 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
67
67
 
68
68
  config: MixedbreadAIEmbeddingConfig
69
69
 
70
- _exemplary_embedding: Optional[List[float]] = field(init=False, default=None)
70
+ _exemplary_embedding: Optional[list[float]] = field(init=False, default=None)
71
71
  _request_options: Optional["RequestOptions"] = field(init=False, default=None)
72
72
 
73
- def get_exemplary_embedding(self) -> List[float]:
73
+ def get_exemplary_embedding(self) -> list[float]:
74
74
  """Get an exemplary embedding to determine dimensions and unit vector status."""
75
75
  return self._embed(["Q"])[0]
76
76
 
@@ -91,7 +91,7 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
91
91
  )
92
92
 
93
93
  @property
94
- def num_of_dimensions(self):
94
+ def num_of_dimensions(self) -> tuple[int, ...]:
95
95
  """Get the number of dimensions for the embeddings."""
96
96
  exemplary_embedding = self.get_exemplary_embedding()
97
97
  return np.shape(exemplary_embedding)
@@ -102,15 +102,15 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
102
102
  exemplary_embedding = self.get_exemplary_embedding()
103
103
  return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
104
104
 
105
- def _embed(self, texts: List[str]) -> List[List[float]]:
105
+ def _embed(self, texts: list[str]) -> list[list[float]]:
106
106
  """
107
107
  Embed a list of texts using the Mixedbread AI API.
108
108
 
109
109
  Args:
110
- texts (List[str]): List of texts to embed.
110
+ texts (list[str]): List of texts to embed.
111
111
 
112
112
  Returns:
113
- List[List[float]]: List of embeddings.
113
+ list[list[float]]: List of embeddings.
114
114
  """
115
115
  batch_size = BATCH_SIZE
116
116
  batch_itr = range(0, len(texts), batch_size)
@@ -132,17 +132,17 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
132
132
 
133
133
  @staticmethod
134
134
  def _add_embeddings_to_elements(
135
- elements: List[dict], embeddings: List[List[float]]
136
- ) -> List[dict]:
135
+ elements: list[dict], embeddings: list[list[float]]
136
+ ) -> list[dict]:
137
137
  """
138
138
  Add embeddings to elements.
139
139
 
140
140
  Args:
141
- elements (List[Element]): List of elements.
142
- embeddings (List[List[float]]): List of embeddings.
141
+ elements (list[Element]): List of elements.
142
+ embeddings (list[list[float]]): List of embeddings.
143
143
 
144
144
  Returns:
145
- List[Element]: Elements with embeddings added.
145
+ list[Element]: Elements with embeddings added.
146
146
  """
147
147
  assert len(elements) == len(embeddings)
148
148
  elements_w_embedding = []
@@ -151,20 +151,20 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
151
151
  elements_w_embedding.append(element)
152
152
  return elements
153
153
 
154
- def embed_documents(self, elements: List[dict]) -> List[dict]:
154
+ def embed_documents(self, elements: list[dict]) -> list[dict]:
155
155
  """
156
156
  Embed a list of document elements.
157
157
 
158
158
  Args:
159
- elements (List[Element]): List of document elements.
159
+ elements (list[Element]): List of document elements.
160
160
 
161
161
  Returns:
162
- List[Element]: Elements with embeddings.
162
+ list[Element]: Elements with embeddings.
163
163
  """
164
164
  embeddings = self._embed([e.get("text", "") for e in elements])
165
165
  return self._add_embeddings_to_elements(elements, embeddings)
166
166
 
167
- def embed_query(self, query: str) -> List[float]:
167
+ def embed_query(self, query: str) -> list[float]:
168
168
  """
169
169
  Embed a query string.
170
170
 
@@ -172,6 +172,6 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
172
172
  query (str): Query string to embed.
173
173
 
174
174
  Returns:
175
- List[float]: Embedding of the query.
175
+ list[float]: Embedding of the query.
176
176
  """
177
177
  return self._embed([query])[0]
@@ -1,5 +1,5 @@
1
1
  from dataclasses import dataclass, field
2
- from typing import TYPE_CHECKING, List, Optional
2
+ from typing import TYPE_CHECKING, Optional
3
3
 
4
4
  import numpy as np
5
5
  from pydantic import Field, SecretStr
@@ -31,16 +31,16 @@ class OctoAiEmbeddingConfig(EmbeddingConfig):
31
31
  class OctoAIEmbeddingEncoder(BaseEmbeddingEncoder):
32
32
  config: OctoAiEmbeddingConfig
33
33
  # Uses the OpenAI SDK
34
- _exemplary_embedding: Optional[List[float]] = field(init=False, default=None)
34
+ _exemplary_embedding: Optional[list[float]] = field(init=False, default=None)
35
35
 
36
- def get_exemplary_embedding(self) -> List[float]:
36
+ def get_exemplary_embedding(self) -> list[float]:
37
37
  return self.embed_query("Q")
38
38
 
39
- def num_of_dimensions(self):
39
+ def num_of_dimensions(self) -> tuple[int, ...]:
40
40
  exemplary_embedding = self.get_exemplary_embedding()
41
41
  return np.shape(exemplary_embedding)
42
42
 
43
- def is_unit_vector(self):
43
+ def is_unit_vector(self) -> bool:
44
44
  exemplary_embedding = self.get_exemplary_embedding()
45
45
  return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
46
46
 
@@ -49,12 +49,12 @@ class OctoAIEmbeddingEncoder(BaseEmbeddingEncoder):
49
49
  response = client.embeddings.create(input=query, model=self.config.embedder_model_name)
50
50
  return response.data[0].embedding
51
51
 
52
- def embed_documents(self, elements: List[dict]) -> List[dict]:
52
+ def embed_documents(self, elements: list[dict]) -> list[dict]:
53
53
  embeddings = [self.embed_query(e.get("text", "")) for e in elements]
54
54
  elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
55
55
  return elements_with_embeddings
56
56
 
57
- def _add_embeddings_to_elements(self, elements, embeddings) -> List[dict]:
57
+ def _add_embeddings_to_elements(self, elements, embeddings) -> list[dict]:
58
58
  assert len(elements) == len(embeddings)
59
59
  elements_w_embedding = []
60
60
  for i, element in enumerate(elements):
@@ -1,5 +1,5 @@
1
1
  from dataclasses import dataclass
2
- from typing import TYPE_CHECKING, List
2
+ from typing import TYPE_CHECKING
3
3
 
4
4
  import numpy as np
5
5
  from pydantic import Field, SecretStr
@@ -8,51 +8,46 @@ from unstructured_ingest.embed.interfaces import BaseEmbeddingEncoder, Embedding
8
8
  from unstructured_ingest.utils.dep_check import requires_dependencies
9
9
 
10
10
  if TYPE_CHECKING:
11
- from langchain_openai.embeddings import OpenAIEmbeddings
11
+ from openai import OpenAI
12
12
 
13
13
 
14
14
  class OpenAIEmbeddingConfig(EmbeddingConfig):
15
15
  api_key: SecretStr
16
16
  embedder_model_name: str = Field(default="text-embedding-ada-002", alias="model_name")
17
17
 
18
- @requires_dependencies(["langchain_openai"], extras="openai")
19
- def get_client(self) -> "OpenAIEmbeddings":
20
- """Creates a langchain OpenAI python client to embed elements."""
21
- from langchain_openai import OpenAIEmbeddings
18
+ @requires_dependencies(["openai"], extras="openai")
19
+ def get_client(self) -> "OpenAI":
20
+ from openai import OpenAI
22
21
 
23
- openai_client = OpenAIEmbeddings(
24
- openai_api_key=self.api_key.get_secret_value(),
25
- model=self.embedder_model_name, # type:ignore
26
- )
27
- return openai_client
22
+ return OpenAI(api_key=self.api_key.get_secret_value())
28
23
 
29
24
 
30
25
  @dataclass
31
26
  class OpenAIEmbeddingEncoder(BaseEmbeddingEncoder):
32
27
  config: OpenAIEmbeddingConfig
33
28
 
34
- def get_exemplary_embedding(self) -> List[float]:
29
+ def get_exemplary_embedding(self) -> list[float]:
35
30
  return self.embed_query(query="Q")
36
31
 
37
- def num_of_dimensions(self):
32
+ def num_of_dimensions(self) -> tuple[int, ...]:
38
33
  exemplary_embedding = self.get_exemplary_embedding()
39
34
  return np.shape(exemplary_embedding)
40
35
 
41
- def is_unit_vector(self):
36
+ def is_unit_vector(self) -> bool:
42
37
  exemplary_embedding = self.get_exemplary_embedding()
43
38
  return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
44
39
 
45
- def embed_query(self, query):
40
+ def embed_query(self, query: str) -> list[float]:
46
41
  client = self.config.get_client()
47
- return client.embed_query(str(query))
42
+ response = client.embeddings.create(input=query, model=self.config.embedder_model_name)
43
+ return response.data[0].embedding
48
44
 
49
- def embed_documents(self, elements: List[dict]) -> List[dict]:
50
- client = self.config.get_client()
51
- embeddings = client.embed_documents([e.get("text", "") for e in elements])
45
+ def embed_documents(self, elements: list[dict]) -> list[dict]:
46
+ embeddings = self._embed_documents([e.get("text", "") for e in elements])
52
47
  elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
53
48
  return elements_with_embeddings
54
49
 
55
- def _add_embeddings_to_elements(self, elements, embeddings) -> List[dict]:
50
+ def _add_embeddings_to_elements(self, elements, embeddings) -> list[dict]:
56
51
  assert len(elements) == len(embeddings)
57
52
  elements_w_embedding = []
58
53
  for i, element in enumerate(elements):