langroid 0.39.5__py3-none-any.whl → 0.41.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,427 @@
+ import json
+ import logging
+ import os
+ import re
+ from dataclasses import dataclass
+ from typing import (
+     TYPE_CHECKING,
+     Any,
+     Dict,
+     List,
+     Literal,
+     Optional,
+     Sequence,
+     Tuple,
+     Union,
+ )
+
+ from dotenv import load_dotenv
+
+ from langroid import LangroidImportError
+ from langroid.mytypes import Document
+
+ # import dataclass
+ from langroid.pydantic_v1 import BaseModel
+ from langroid.utils.configuration import settings
+ from langroid.vector_store.base import VectorStore, VectorStoreConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ has_pinecone: bool = True
+ try:
+     from pinecone import Pinecone, PineconeApiException, ServerlessSpec
+ except ImportError:
+
+     if not TYPE_CHECKING:
+
+         class ServerlessSpec(BaseModel):
+             """
+             Fallback Serverless specification configuration to avoid import errors.
+             """
+
+             cloud: str
+             region: str
+
+     PineconeApiException = Any  # type: ignore
+     Pinecone = Any  # type: ignore
+     has_pinecone = False
+
+
+ @dataclass(frozen=True)
+ class IndexMeta:
+     name: str
+     total_vector_count: int
+
+
+ class PineconeDBConfig(VectorStoreConfig):
+     cloud: bool = True
+     collection_name: str | None = "temp"
+     spec: ServerlessSpec = ServerlessSpec(cloud="aws", region="us-east-1")
+     deletion_protection: Literal["enabled", "disabled"] | None = None
+     metric: str = "cosine"
+     pagination_size: int = 100
+
+
+ class PineconeDB(VectorStore):
+     def __init__(self, config: PineconeDBConfig = PineconeDBConfig()):
+         super().__init__(config)
+         if not has_pinecone:
+             raise LangroidImportError("pinecone", "pinecone")
+         self.config: PineconeDBConfig = config
+         load_dotenv()
+         key = os.getenv("PINECONE_API_KEY")
+
+         if not key:
+             raise ValueError("PINECONE_API_KEY not set, could not instantiate client")
+         self.client = Pinecone(api_key=key)
+
+         if config.collection_name:
+             self.create_collection(
+                 collection_name=config.collection_name,
+                 replace=config.replace_collection,
+             )
+
+     def clear_empty_collections(self) -> int:
+         indexes = self._list_index_metas(empty=True)
+         n_deletes = 0
+         for index in indexes:
+             if index.total_vector_count == -1:
+                 logger.warning(
+                     f"Error fetching details for {index.name} when scanning indexes"
+                 )
+             n_deletes += 1
+             self.delete_collection(collection_name=index.name)
+         return n_deletes
+
+     def clear_all_collections(self, really: bool = False, prefix: str = "") -> int:
+         """
+         Returns:
+             Number of Pinecone indexes that were deleted
+
+         Args:
+             really: Optional[bool] - whether to really delete all Pinecone collections
+             prefix: Optional[str] - string to match potential Pinecone
+                 indexes for deletion
+         """
+         if not really:
+             logger.warning("Not deleting all collections, set really=True to confirm")
+             return 0
+         indexes = [
+             c for c in self._list_index_metas(empty=True) if c.name.startswith(prefix)
+         ]
+         if len(indexes) == 0:
+             logger.warning(f"No collections found with prefix {prefix}")
+             return 0
+         n_empty_deletes, n_non_empty_deletes = 0, 0
+         for index_desc in indexes:
+             self.delete_collection(collection_name=index_desc.name)
+             n_empty_deletes += index_desc.total_vector_count == 0
+             n_non_empty_deletes += index_desc.total_vector_count > 0
+         logger.warning(
+             f"""
+             Deleted {n_empty_deletes} empty indexes and
+             {n_non_empty_deletes} non-empty indexes
+             """
+         )
+         return n_empty_deletes + n_non_empty_deletes
+
+     def list_collections(self, empty: bool = False) -> List[str]:
+         """
+         Returns:
+             Names of indexes with at least one vector (all indexes if `empty` is True)
+
+         Args:
+             empty: Optional[bool] - whether to include empty collections
+         """
+         indexes = self.client.list_indexes()
+         res: List[str] = []
+         if empty:
+             res.extend(indexes.names())
+             return res
+
+         for index in indexes.names():
+             index_meta = self.client.Index(name=index)
+             if index_meta.describe_index_stats().get("total_vector_count", 0) > 0:
+                 res.append(index)
+         return res
+
+     def _list_index_metas(self, empty: bool = False) -> List[IndexMeta]:
+         """
+         Returns:
+             List of objects describing Pinecone indices
+
+         Args:
+             empty: Optional[bool] - whether to include empty collections
+         """
+         indexes = self.client.list_indexes()
+         res = []
+         for index in indexes.names():
+             index_meta = self._fetch_index_meta(index)
+             if empty:
+                 res.append(index_meta)
+             elif index_meta.total_vector_count > 0:
+                 res.append(index_meta)
+         return res
+
+     def _fetch_index_meta(self, index_name: str) -> IndexMeta:
+         """
+         Returns:
+             An IndexMeta dataclass describing the given index by name and
+                 vector count, avoiding repeated index-description calls
+
+         Args:
+             index_name: str - Name of the index in Pinecone
+         """
+         try:
+             index = self.client.Index(name=index_name)
+             stats = index.describe_index_stats()
+             return IndexMeta(
+                 name=index_name, total_vector_count=stats.get("total_vector_count", 0)
+             )
+         except PineconeApiException as e:
+             logger.warning(f"Error fetching details for index {index_name}")
+             logger.warning(e)
+             return IndexMeta(name=index_name, total_vector_count=-1)
+
+     def create_collection(self, collection_name: str, replace: bool = False) -> None:
+         """
+         Create a collection with the given name, optionally replacing an existing
+         collection if `replace` is True.
+
+         Args:
+             collection_name: str - Name of the collection to create.
+             replace: Optional[bool] - Whether to replace an existing collection
+                 with the same name. Defaults to False.
+         """
+         pattern = re.compile(r"^[a-z0-9-]+$")
+         if not pattern.match(collection_name):
+             raise ValueError(
+                 "Pinecone index names must be lowercase alphanumeric characters or '-'"
+             )
+         self.config.collection_name = collection_name
+         if collection_name in self.list_collections(empty=True):
+             index = self.client.Index(name=collection_name)
+             stats = index.describe_index_stats()
+             status = self.client.describe_index(name=collection_name)
+             if status["status"]["ready"] and stats["total_vector_count"] > 0:
+                 logger.warning(f"Non-empty collection {collection_name} already exists")
+                 if not replace:
+                     logger.warning("Not replacing collection")
+                     return
+                 else:
+                     logger.warning("Recreating fresh collection")
+             self.delete_collection(collection_name=collection_name)
+
+         payload = {
+             "name": collection_name,
+             "dimension": self.embedding_dim,
+             "spec": self.config.spec,
+             "metric": self.config.metric,
+             "timeout": self.config.timeout,
+         }
+
+         if self.config.deletion_protection:
+             payload["deletion_protection"] = self.config.deletion_protection
+
+         try:
+             self.client.create_index(**payload)
+         except PineconeApiException as e:
+             logger.error(e)
+
+     def delete_collection(self, collection_name: str) -> None:
+         logger.info(f"Attempting to delete {collection_name}")
+         try:
+             self.client.delete_index(name=collection_name)
+         except PineconeApiException as e:
+             logger.error(f"Failed to delete {collection_name}")
+             logger.error(e)
+
+     def add_documents(self, documents: Sequence[Document], namespace: str = "") -> None:
+         if self.config.collection_name is None:
+             raise ValueError("No collection name set, cannot ingest docs")
+
+         if len(documents) == 0:
+             logger.warning("Empty list of documents passed into add_documents")
+             return
+
+         super().maybe_add_ids(documents)
+         document_dicts = [doc.dict() for doc in documents]
+         document_ids = [doc.id() for doc in documents]
+         embedding_vectors = self.embedding_fn([doc.content for doc in documents])
+         vectors = [
+             {
+                 "id": document_id,
+                 "values": embedding_vector,
+                 "metadata": {
+                     **document_dict["metadata"],
+                     **{
+                         key: value
+                         for key, value in document_dict.items()
+                         if key != "metadata"
+                     },
+                 },
+             }
+             for document_dict, document_id, embedding_vector in zip(
+                 document_dicts, document_ids, embedding_vectors
+             )
+         ]
+
+         if self.config.collection_name not in self.list_collections(empty=True):
+             self.create_collection(
+                 collection_name=self.config.collection_name, replace=True
+             )
+
+         index = self.client.Index(name=self.config.collection_name)
+         batch_size = self.config.batch_size
+
+         for i in range(0, len(documents), batch_size):
+             try:
+                 if namespace:
+                     index.upsert(
+                         vectors=vectors[i : i + batch_size], namespace=namespace
+                     )
+                 else:
+                     index.upsert(vectors=vectors[i : i + batch_size])
+             except PineconeApiException as e:
+                 logger.error(
+                     f"Unable to add batch of docs between indices {i} and {i + batch_size}"
+                 )
+                 logger.error(e)
+
+     def get_all_documents(
+         self, prefix: str = "", namespace: str = ""
+     ) -> List[Document]:
+         """
+         Returns:
+             All documents for the collection currently defined in
+             the configuration object
+
+         Args:
+             prefix: str - document id prefix to search for
+             namespace: str - partition of vectors to search within the index
+         """
+         if self.config.collection_name is None:
+             raise ValueError("No collection name set, cannot retrieve docs")
+         docs = []
+
+         request_filters: Dict[str, Union[str, int]] = {
+             "limit": self.config.pagination_size
+         }
+         if prefix:
+             request_filters["prefix"] = prefix
+         if namespace:
+             request_filters["namespace"] = namespace
+
+         index = self.client.Index(name=self.config.collection_name)
+
+         while True:
+             response = index.list_paginated(**request_filters)
+             vectors = response.get("vectors", [])
+
+             if not vectors:
+                 logger.warning("Received empty list while requesting vector ids")
+                 logger.warning("Halting fetch requests")
+                 if settings.debug:
+                     logger.debug(f"Request for failed fetch was: {request_filters}")
+                 break
+
+             docs.extend(
+                 self.get_documents_by_ids(
+                     ids=[vector.get("id") for vector in vectors],
+                     namespace=namespace if namespace else "",
+                 )
+             )
+
+             pagination_token = response.get("pagination", {}).get("next", None)
+
+             if not pagination_token:
+                 break
+
+             request_filters["pagination_token"] = pagination_token
+
+         return docs
+
+     def get_documents_by_ids(
+         self, ids: List[str], namespace: str = ""
+     ) -> List[Document]:
+         """
+         Returns:
+             Documents built from the metadata of the fetched vectors, ordered by ids
+
+         Args:
+             ids: List[str] - vector data object ids to retrieve
+             namespace: str - partition of vectors to search within the index
+         """
+         if self.config.collection_name is None:
+             raise ValueError("No collection name set, cannot retrieve docs")
+         index = self.client.Index(name=self.config.collection_name)
+
+         if namespace:
+             records = index.fetch(ids=ids, namespace=namespace)
+         else:
+             records = index.fetch(ids=ids)
+
+         id_mapping = {key: value for key, value in records["vectors"].items()}
+         ordered_payloads = [id_mapping[_id] for _id in ids if _id in id_mapping]
+         return [
+             self.transform_pinecone_vector(payload.get("metadata", {}))
+             for payload in ordered_payloads
+         ]
+
+     def similar_texts_with_scores(
+         self,
+         text: str,
+         k: int = 1,
+         where: Optional[str] = None,
+         namespace: Optional[str] = None,
+     ) -> List[Tuple[Document, float]]:
+         if self.config.collection_name is None:
+             raise ValueError("No collection name set, cannot search")
+
+         if k < 1 or k > 9999:
+             raise ValueError(
+                 f"TopK for Pinecone vector search must be between 1 and 9999, k was {k}"
+             )
+
+         vector_search_request = {
+             "top_k": k,
+             "include_metadata": True,
+             "vector": self.embedding_fn([text])[0],
+         }
+         if where:
+             vector_search_request["filter"] = json.loads(where) if where else None
+         if namespace:
+             vector_search_request["namespace"] = namespace
+
+         index = self.client.Index(name=self.config.collection_name)
+         response = index.query(**vector_search_request)
+         doc_score_pairs = [
+             (
+                 self.transform_pinecone_vector(match.get("metadata", {})),
+                 match.get("score", 0),
+             )
+             for match in response.get("matches", [])
+         ]
+         if settings.debug:
+             max_score = max((pair[1] for pair in doc_score_pairs), default=0)
+             logger.info(f"Found {len(doc_score_pairs)} matches, max score: {max_score}")
+         self.show_if_debug(doc_score_pairs)
+         return doc_score_pairs
+
+     def transform_pinecone_vector(self, metadata_dict: Dict[str, Any]) -> Document:
+         """
+         Parses the metadata from a Pinecone vector query match and uses it to
+         construct an instance of the Document class associated with the
+         PineconeDBConfig class
+
+         Returns:
+             A Document built from the metadata fields of the query match
+
+         Args:
+             metadata_dict: Dict - the metadata dictionary from the Pinecone
+                 vector query match
+         """
+         return self.config.document_class(
+             **{**metadata_dict, "metadata": {**metadata_dict}}
+         )
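
For orientation, below is a minimal usage sketch of the PineconeDB vector store added in this release. It is not part of the diff above; it assumes langroid >= 0.41.0 installed with the pinecone extra, a PINECONE_API_KEY in the environment, a working embedding backend (the default embedding config typically needs an OpenAI API key), and that the new module is importable as langroid.vector_store.pineconedb. The index name "demo-docs", the sample documents, and the metadata filter in the comment are illustrative.

# Minimal usage sketch -- assumptions as noted above, not taken from the diff itself.
from langroid.mytypes import DocMetaData, Document
from langroid.vector_store.pineconedb import PineconeDB, PineconeDBConfig

# Index names must be lowercase alphanumerics or '-' (see create_collection above).
config = PineconeDBConfig(collection_name="demo-docs")
vecdb = PineconeDB(config)  # reads PINECONE_API_KEY, creates or reuses the index

docs = [
    Document(
        content="Pinecone is a managed vector database.",
        metadata=DocMetaData(),
    ),
    Document(
        content="Langroid agents can retrieve context from vector stores.",
        metadata=DocMetaData(),
    ),
]
vecdb.add_documents(docs)  # embeds contents and upserts in batches of config.batch_size

# similar_texts_with_scores returns (Document, score) pairs; `where` optionally takes
# a JSON-encoded Pinecone metadata filter, e.g. '{"source": {"$eq": "wikipedia"}}'.
for doc, score in vecdb.similar_texts_with_scores("What is Pinecone?", k=2):
    print(round(score, 3), doc.content)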