langchain-postgres 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,36 +1,50 @@
+# pylint: disable=too-many-lines
 from __future__ import annotations
 
+import contextlib
 import enum
 import logging
 import uuid
 from typing import (
     Any,
+    AsyncGenerator,
     Callable,
     Dict,
+    Generator,
     Iterable,
     List,
     Optional,
+    Sequence,
     Tuple,
     Type,
     Union,
 )
+from typing import (
+    cast as typing_cast,
+)
 
 import numpy as np
 import sqlalchemy
-from sqlalchemy import SQLColumnExpression, cast, delete, func
-from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert
-from sqlalchemy.orm import Session, relationship, sessionmaker
-
-try:
-    from sqlalchemy.orm import declarative_base
-except ImportError:
-    from sqlalchemy.ext.declarative import declarative_base
-
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
-from langchain_core.runnables.config import run_in_executor
 from langchain_core.utils import get_from_dict_or_env
 from langchain_core.vectorstores import VectorStore
+from sqlalchemy import SQLColumnExpression, cast, create_engine, delete, func, select
+from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert
+from sqlalchemy.engine import Connection, Engine
+from sqlalchemy.ext.asyncio import (
+    AsyncEngine,
+    AsyncSession,
+    async_sessionmaker,
+    create_async_engine,
+)
+from sqlalchemy.orm import (
+    Session,
+    declarative_base,
+    relationship,
+    scoped_session,
+    sessionmaker,
+)
 
 from langchain_postgres._utils import maximal_marginal_relevance
 
@@ -74,7 +88,7 @@ TEXT_OPERATORS = {
     "$ilike",
 }
 
-LOGICAL_OPERATORS = {"$and", "$or"}
+LOGICAL_OPERATORS = {"$and", "$or", "$not"}
 
 SUPPORTED_OPERATORS = (
     set(COMPARISONS_TO_NATIVE)
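
The `$not` entry added to LOGICAL_OPERATORS enables negated filters (handled in the `_create_filter_clause` hunk further down). A minimal usage sketch, assuming an already-initialized `vectorstore` and an illustrative `location` metadata field:

    # Hypothetical store and field names; returns documents whose
    # 'location' metadata is NOT 'pond'.
    docs = vectorstore.similarity_search(
        "ducks",
        k=10,
        filter={"$not": {"location": {"$eq": "pond"}}},
    )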
@@ -112,7 +126,27 @@ def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> A
         def get_by_name(
             cls, session: Session, name: str
         ) -> Optional["CollectionStore"]:
-            return session.query(cls).filter(cls.name == name).first()  # type: ignore
+            return (
+                session.query(cls)
+                .filter(typing_cast(sqlalchemy.Column, cls.name) == name)
+                .first()
+            )
+
+        @classmethod
+        async def aget_by_name(
+            cls, session: AsyncSession, name: str
+        ) -> Optional["CollectionStore"]:
+            return (
+                (
+                    await session.execute(
+                        select(CollectionStore).where(
+                            typing_cast(sqlalchemy.Column, cls.name) == name
+                        )
+                    )
+                )
+                .scalars()
+                .first()
+            )
 
         @classmethod
         def get_or_create(
@@ -136,6 +170,28 @@ def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> A
                 created = True
             return collection, created
 
+        @classmethod
+        async def aget_or_create(
+            cls,
+            session: AsyncSession,
+            name: str,
+            cmetadata: Optional[dict] = None,
+        ) -> Tuple["CollectionStore", bool]:
+            """
+            Get or create a collection.
+            Returns [Collection, bool] where the bool is True if the collection was created.
+            """  # noqa: E501
+            created = False
+            collection = await cls.aget_by_name(session, name)
+            if collection:
+                return collection, created
+
+            collection = cls(name=name, cmetadata=cmetadata)
+            session.add(collection)
+            await session.commit()
+            created = True
+            return collection, created
+
     class EmbeddingStore(Base):
         """Embedding store."""
 
@@ -177,7 +233,16 @@ def _results_to_docs(docs_and_scores: Any) -> List[Document]:
     return [doc for doc, _ in docs_and_scores]
 
 
-Connection = Union[sqlalchemy.engine.Engine, str]
+def _create_vector_extension(conn: Connection) -> None:
+    statement = sqlalchemy.text(
+        "SELECT pg_advisory_xact_lock(1573678846307946496);"
+        "CREATE EXTENSION IF NOT EXISTS vector;"
+    )
+    conn.execute(statement)
+    conn.commit()
+
+
+DBConnection = Union[sqlalchemy.engine.Engine, str]
 
 
 class PGVector(VectorStore):
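
The advisory lock in `_create_vector_extension` serializes concurrent `CREATE EXTENSION` calls from multiple workers (this fixes langchain issue #12933, cited in the removed comment below); a transaction-scoped advisory lock is released automatically at commit. A standalone sketch of the same pattern, with a placeholder connection URL:

    # Two processes racing on CREATE EXTENSION would otherwise error;
    # pg_advisory_xact_lock makes the second caller wait its turn.
    import sqlalchemy

    engine = sqlalchemy.create_engine("postgresql+psycopg://user:pass@localhost/db")
    with engine.connect() as conn:
        conn.execute(sqlalchemy.text(
            "SELECT pg_advisory_xact_lock(1573678846307946496);"
            "CREATE EXTENSION IF NOT EXISTS vector;"
        ))
        conn.commit()  # releases the transaction-scoped lock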
@@ -215,6 +280,7 @@ class PGVector(VectorStore):
             connection=connection_string,
             collection_name=collection_name,
             use_jsonb=True,
+            async_mode=False,
         )
 
 
@@ -232,13 +298,52 @@ class PGVector(VectorStore):
       You will need to recreate the tables if you are using an existing database.
     * A Connection object has to be provided explicitly. Connections will not be
       picked up automatically based on env variables.
+    * langchain_postgres now accepts async connections. If you want to use the async
+      version, you need to set `async_mode=True` when initializing the store or
+      use an async engine.
+
+    Supported filter operators:
+
+    * $eq: Equality operator
+    * $ne: Not equal operator
+    * $lt: Less than operator
+    * $lte: Less than or equal operator
+    * $gt: Greater than operator
+    * $gte: Greater than or equal operator
+    * $in: In operator
+    * $nin: Not in operator
+    * $between: Between operator
+    * $exists: Exists operator
+    * $like: Like operator
+    * $ilike: Case insensitive like operator
+    * $and: Logical AND operator
+    * $or: Logical OR operator
+    * $not: Logical NOT operator
+
+    Example:
+
+    .. code-block:: python
+
+        vectorstore.similarity_search('kitty', k=10, filter={
+            'id': {'$in': [1, 5, 2, 9]}
+        })
+
+    If you provide a dict with multiple fields but no operators, the top level
+    will be interpreted as a logical **AND** filter:
+
+    .. code-block:: python
+
+        vectorstore.similarity_search('ducks', k=10, filter={
+            'id': {'$in': [1, 5, 2, 9]},
+            'location': {'$in': ["pond", "market"]}
+        })
+
     """
 
     def __init__(
         self,
         embeddings: Embeddings,
         *,
-        connection: Optional[Connection] = None,
+        connection: Union[None, DBConnection, Engine, AsyncEngine, str] = None,
         embedding_length: Optional[int] = None,
         collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
         collection_metadata: Optional[dict] = None,
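
A minimal end-to-end sketch of the new async mode. Everything here is illustrative: the connection URL, the collection name, and the `ToyEmbeddings` stub (any real `Embeddings` implementation works; the base class supplies default async wrappers):

    import asyncio
    from typing import List

    from langchain_core.embeddings import Embeddings
    from langchain_postgres.vectorstores import PGVector
    from sqlalchemy.ext.asyncio import create_async_engine

    class ToyEmbeddings(Embeddings):
        """Deterministic stand-in embeddings, for demonstration only."""

        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            # Map each text to a tiny fixed-size vector.
            return [[float(len(t) % 7), 1.0, 0.0] for t in texts]

        def embed_query(self, text: str) -> List[float]:
            return self.embed_documents([text])[0]

    async def main() -> None:
        engine = create_async_engine("postgresql+psycopg://user:pass@localhost/db")
        store = PGVector(
            embeddings=ToyEmbeddings(),
            connection=engine,  # passing an AsyncEngine switches on async_mode
            collection_name="demo",
        )
        await store.aadd_texts(["hello pgvector"])
        print(await store.asimilarity_search("hello", k=1))

    asyncio.run(main())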
@@ -249,11 +354,13 @@ class PGVector(VectorStore):
         engine_args: Optional[dict[str, Any]] = None,
         use_jsonb: bool = True,
         create_extension: bool = True,
+        async_mode: bool = False,
     ) -> None:
         """Initialize the PGVector store.
+        For an async version, use `PGVector.acreate()` instead.
 
         Args:
-            connection: Postgres connection string.
+            connection: Postgres connection string or (async)engine.
            embeddings: Any embedding function implementing
                `langchain.embeddings.base.Embeddings` interface.
            embedding_length: The length of the embedding vector. (default: None)
@@ -277,6 +384,7 @@ class PGVector(VectorStore):
                doesn't exist. Disabling creation is useful when using ReadOnly
                Databases.
            """
+        self.async_mode = async_mode
         self.embedding_function = embeddings
         self._embedding_length = embedding_length
         self.collection_name = collection_name
@@ -285,20 +393,33 @@ class PGVector(VectorStore):
         self.pre_delete_collection = pre_delete_collection
         self.logger = logger or logging.getLogger(__name__)
         self.override_relevance_score_fn = relevance_score_fn
+        self._engine: Optional[Engine] = None
+        self._async_engine: Optional[AsyncEngine] = None
+        self._async_init = False
 
         if isinstance(connection, str):
-            self._engine = sqlalchemy.create_engine(
-                url=connection, **(engine_args or {})
-            )
-        elif isinstance(connection, sqlalchemy.engine.Engine):
+            if async_mode:
+                self._async_engine = create_async_engine(
+                    connection, **(engine_args or {})
+                )
+            else:
+                self._engine = create_engine(url=connection, **(engine_args or {}))
+        elif isinstance(connection, Engine):
+            self.async_mode = False
             self._engine = connection
+        elif isinstance(connection, AsyncEngine):
+            self.async_mode = True
+            self._async_engine = connection
         else:
             raise ValueError(
                 "connection should be a connection string or an instance of "
-                "sqlalchemy.engine.Engine"
+                "sqlalchemy.engine.Engine or sqlalchemy.ext.asyncio.engine.AsyncEngine"
             )
-
-        self._session_maker = sessionmaker(bind=self._engine)
+        self.session_maker: Union[scoped_session, async_sessionmaker]
+        if self.async_mode:
+            self.session_maker = async_sessionmaker(bind=self._async_engine)
+        else:
+            self.session_maker = scoped_session(sessionmaker(bind=self._engine))
 
         self.use_jsonb = use_jsonb
         self.create_extension = create_extension
@@ -306,7 +427,8 @@ class PGVector(VectorStore):
         if not use_jsonb:
             # Replace with a deprecation warning.
             raise NotImplementedError("use_jsonb=False is no longer supported.")
-        self.__post_init__()
+        if not self.async_mode:
+            self.__post_init__()
 
     def __post_init__(
         self,
@@ -323,52 +445,99 @@ class PGVector(VectorStore):
         self.create_tables_if_not_exists()
         self.create_collection()
 
-    def __del__(self) -> None:
-        if isinstance(self._engine, sqlalchemy.engine.Connection):
-            self._engine.close()
+    async def __apost_init__(
+        self,
+    ) -> None:
+        """Async initialize the store (use lazy approach)."""
+        if self._async_init:  # Warning: possible race condition
+            return
+        self._async_init = True
+
+        EmbeddingStore, CollectionStore = _get_embedding_collection_store(
+            self._embedding_length
+        )
+        self.CollectionStore = CollectionStore
+        self.EmbeddingStore = EmbeddingStore
+        if self.create_extension:
+            await self.acreate_vector_extension()
+
+        await self.acreate_tables_if_not_exists()
+        await self.acreate_collection()
 
     @property
     def embeddings(self) -> Embeddings:
         return self.embedding_function
 
     def create_vector_extension(self) -> None:
+        assert self._engine, "engine not found"
         try:
-            with self._session_maker() as session:  # type: ignore[arg-type]
-                # The advisor lock fixes issue arising from concurrent
-                # creation of the vector extension.
-                # https://github.com/langchain-ai/langchain/issues/12933
-                # For more information see:
-                # https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS
-                statement = sqlalchemy.text(
-                    "BEGIN;"
-                    "SELECT pg_advisory_xact_lock(1573678846307946496);"
-                    "CREATE EXTENSION IF NOT EXISTS vector;"
-                    "COMMIT;"
-                )
-                session.execute(statement)
-                session.commit()
+            with self._engine.connect() as conn:
+                _create_vector_extension(conn)
         except Exception as e:
             raise Exception(f"Failed to create vector extension: {e}") from e
 
+    async def acreate_vector_extension(self) -> None:
+        assert self._async_engine, "_async_engine not found"
+
+        async with self._async_engine.begin() as conn:
+            await conn.run_sync(_create_vector_extension)
+
     def create_tables_if_not_exists(self) -> None:
-        with self._session_maker() as session:
+        with self._make_sync_session() as session:
             Base.metadata.create_all(session.get_bind())
+            session.commit()
+
+    async def acreate_tables_if_not_exists(self) -> None:
+        assert self._async_engine, "This method must be called with async_mode"
+        async with self._async_engine.begin() as conn:
+            await conn.run_sync(Base.metadata.create_all)
 
     def drop_tables(self) -> None:
-        with self._session_maker() as session:
+        with self._make_sync_session() as session:
             Base.metadata.drop_all(session.get_bind())
+            session.commit()
+
+    async def adrop_tables(self) -> None:
+        assert self._async_engine, "This method must be called with async_mode"
+        await self.__apost_init__()  # Lazy async init
+        async with self._async_engine.begin() as conn:
+            await conn.run_sync(Base.metadata.drop_all)
 
     def create_collection(self) -> None:
         if self.pre_delete_collection:
             self.delete_collection()
-        with self._session_maker() as session:  # type: ignore[arg-type]
+        with self._make_sync_session() as session:
             self.CollectionStore.get_or_create(
                 session, self.collection_name, cmetadata=self.collection_metadata
             )
+            session.commit()
+
+    async def acreate_collection(self) -> None:
+        await self.__apost_init__()  # Lazy async init
+        async with self._make_async_session() as session:
+            if self.pre_delete_collection:
+                await self._adelete_collection(session)
+            await self.CollectionStore.aget_or_create(
+                session, self.collection_name, cmetadata=self.collection_metadata
+            )
+            await session.commit()
+
+    def _delete_collection(self, session: Session) -> None:
+        collection = self.get_collection(session)
+        if not collection:
+            self.logger.warning("Collection not found")
+            return
+        session.delete(collection)
+
+    async def _adelete_collection(self, session: AsyncSession) -> None:
+        collection = await self.aget_collection(session)
+        if not collection:
+            self.logger.warning("Collection not found")
+            return
+        await session.delete(collection)
 
     def delete_collection(self) -> None:
-        self.logger.debug("Trying to delete collection")
-        with self._session_maker() as session:  # type: ignore[arg-type]
+        with self._make_sync_session() as session:
             collection = self.get_collection(session)
             if not collection:
                 self.logger.warning("Collection not found")
@@ -376,6 +545,16 @@ class PGVector(VectorStore):
                 session.delete(collection)
             session.commit()
 
+    async def adelete_collection(self) -> None:
+        await self.__apost_init__()  # Lazy async init
+        async with self._make_async_session() as session:
+            collection = await self.aget_collection(session)
+            if not collection:
+                self.logger.warning("Collection not found")
+                return
+            await session.delete(collection)
+            await session.commit()
+
     def delete(
         self,
         ids: Optional[List[str]] = None,
@@ -388,7 +567,7 @@ class PGVector(VectorStore):
             ids: List of ids to delete.
             collection_only: Only delete ids in the collection.
         """
-        with self._session_maker() as session:
+        with self._make_sync_session() as session:
             if ids is not None:
                 self.logger.debug(
                     "Trying to delete vectors by ids (represented by the model "
@@ -411,9 +590,51 @@ class PGVector(VectorStore):
                 session.execute(stmt)
             session.commit()
 
+    async def adelete(
+        self,
+        ids: Optional[List[str]] = None,
+        collection_only: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        """Async delete vectors by ids or uuids.
+
+        Args:
+            ids: List of ids to delete.
+            collection_only: Only delete ids in the collection.
+        """
+        await self.__apost_init__()  # Lazy async init
+        async with self._make_async_session() as session:
+            if ids is not None:
+                self.logger.debug(
+                    "Trying to delete vectors by ids (represented by the model "
+                    "using the custom ids field)"
+                )
+
+                stmt = delete(self.EmbeddingStore)
+
+                if collection_only:
+                    collection = await self.aget_collection(session)
+                    if not collection:
+                        self.logger.warning("Collection not found")
+                        return
+
+                    stmt = stmt.where(
+                        self.EmbeddingStore.collection_id == collection.uuid
+                    )
+
+                stmt = stmt.where(self.EmbeddingStore.id.in_(ids))
+                await session.execute(stmt)
+            await session.commit()
+
     def get_collection(self, session: Session) -> Any:
+        assert not self._async_engine, "This method must be called without async_mode"
         return self.CollectionStore.get_by_name(session, self.collection_name)
 
+    async def aget_collection(self, session: AsyncSession) -> Any:
+        assert self._async_engine, "This method must be called with async_mode"
+        await self.__apost_init__()  # Lazy async init
+        return await self.CollectionStore.aget_by_name(session, self.collection_name)
+
     @classmethod
     def __from(
         cls,
452
673
 
453
674
  return store
454
675
 
676
+ @classmethod
677
+ async def __afrom(
678
+ cls,
679
+ texts: List[str],
680
+ embeddings: List[List[float]],
681
+ embedding: Embeddings,
682
+ metadatas: Optional[List[dict]] = None,
683
+ ids: Optional[List[str]] = None,
684
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
685
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
686
+ connection: Optional[str] = None,
687
+ pre_delete_collection: bool = False,
688
+ *,
689
+ use_jsonb: bool = True,
690
+ **kwargs: Any,
691
+ ) -> PGVector:
692
+ if ids is None:
693
+ ids = [str(uuid.uuid1()) for _ in texts]
694
+
695
+ if not metadatas:
696
+ metadatas = [{} for _ in texts]
697
+
698
+ store = cls(
699
+ connection=connection,
700
+ collection_name=collection_name,
701
+ embeddings=embedding,
702
+ distance_strategy=distance_strategy,
703
+ pre_delete_collection=pre_delete_collection,
704
+ use_jsonb=use_jsonb,
705
+ async_mode=True,
706
+ **kwargs,
707
+ )
708
+
709
+ await store.aadd_embeddings(
710
+ texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
711
+ )
712
+
713
+ return store
714
+
455
715
  def add_embeddings(
456
716
  self,
457
717
  texts: Iterable[str],
@@ -466,15 +726,18 @@ class PGVector(VectorStore):
             texts: Iterable of strings to add to the vectorstore.
             embeddings: List of list of embedding vectors.
             metadatas: List of metadatas associated with the texts.
+            ids: Optional list of ids for the documents.
+                If not provided, will generate a new id for each document.
             kwargs: vectorstore specific parameters
         """
+        assert not self._async_engine, "This method must be called with sync_mode"
         if ids is None:
             ids = [str(uuid.uuid4()) for _ in texts]
 
         if not metadatas:
             metadatas = [{} for _ in texts]
 
-        with self._session_maker() as session:  # type: ignore[arg-type]
+        with self._make_sync_session() as session:  # type: ignore[arg-type]
             collection = self.get_collection(session)
             if not collection:
                 raise ValueError("Collection not found")
@@ -505,6 +768,62 @@ class PGVector(VectorStore):
 
         return ids
 
+    async def aadd_embeddings(
+        self,
+        texts: Iterable[str],
+        embeddings: List[List[float]],
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> List[str]:
+        """Async add embeddings to the vectorstore.
+
+        Args:
+            texts: Iterable of strings to add to the vectorstore.
+            embeddings: List of list of embedding vectors.
+            metadatas: List of metadatas associated with the texts.
+            ids: Optional list of ids for the texts.
+                If not provided, will generate a new id for each text.
+            kwargs: vectorstore specific parameters
+        """
+        await self.__apost_init__()  # Lazy async init
+        if ids is None:
+            ids = [str(uuid.uuid1()) for _ in texts]
+
+        if not metadatas:
+            metadatas = [{} for _ in texts]
+
+        async with self._make_async_session() as session:  # type: ignore[arg-type]
+            collection = await self.aget_collection(session)
+            if not collection:
+                raise ValueError("Collection not found")
+            data = [
+                {
+                    "id": id,
+                    "collection_id": collection.uuid,
+                    "embedding": embedding,
+                    "document": text,
+                    "cmetadata": metadata or {},
+                }
+                for text, metadata, embedding, id in zip(
+                    texts, metadatas, embeddings, ids
+                )
+            ]
+            stmt = insert(self.EmbeddingStore).values(data)
+            on_conflict_stmt = stmt.on_conflict_do_update(
+                index_elements=["id"],
+                # Conflict detection based on these columns
+                set_={
+                    "embedding": stmt.excluded.embedding,
+                    "document": stmt.excluded.document,
+                    "cmetadata": stmt.excluded.cmetadata,
+                },
+            )
+            await session.execute(on_conflict_stmt)
+            await session.commit()
+
+        return ids
+
     def add_texts(
         self,
         texts: Iterable[str],
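
Note the `on_conflict_do_update` statement above: re-inserting an existing id updates the stored embedding, document, and metadata in place rather than raising. A sketch of the resulting upsert behavior (id and texts illustrative, inside a coroutine):

    await store.aadd_texts(["first draft"], ids=["doc-1"])
    # Same id again: the row is overwritten, not duplicated.
    await store.aadd_texts(["second draft"], ids=["doc-1"])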
@@ -517,16 +836,44 @@ class PGVector(VectorStore):
         Args:
             texts: Iterable of strings to add to the vectorstore.
             metadatas: Optional list of metadatas associated with the texts.
+            ids: Optional list of ids for the texts.
+                If not provided, will generate a new id for each text.
             kwargs: vectorstore specific parameters
 
         Returns:
             List of ids from adding the texts into the vectorstore.
         """
+        assert not self._async_engine, "This method must be called without async_mode"
         embeddings = self.embedding_function.embed_documents(list(texts))
         return self.add_embeddings(
             texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
         )
 
+    async def aadd_texts(
+        self,
+        texts: Iterable[str],
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> List[str]:
+        """Run more texts through the embeddings and add to the vectorstore.
+
+        Args:
+            texts: Iterable of strings to add to the vectorstore.
+            metadatas: Optional list of metadatas associated with the texts.
+            ids: Optional list of ids for the texts.
+                If not provided, will generate a new id for each text.
+            kwargs: vectorstore specific parameters
+
+        Returns:
+            List of ids from adding the texts into the vectorstore.
+        """
+        await self.__apost_init__()  # Lazy async init
+        embeddings = await self.embedding_function.aembed_documents(list(texts))
+        return await self.aadd_embeddings(
+            texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
+        )
+
     def similarity_search(
         self,
         query: str,
@@ -544,6 +891,7 @@ class PGVector(VectorStore):
         Returns:
             List of Documents most similar to the query.
         """
+        assert not self._async_engine, "This method must be called without async_mode"
         embedding = self.embedding_function.embed_query(text=query)
         return self.similarity_search_by_vector(
             embedding=embedding,
@@ -551,6 +899,31 @@ class PGVector(VectorStore):
             filter=filter,
         )
 
+    async def asimilarity_search(
+        self,
+        query: str,
+        k: int = 4,
+        filter: Optional[dict] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Run similarity search with PGVector with distance.
+
+        Args:
+            query (str): Query text to search for.
+            k (int): Number of results to return. Defaults to 4.
+            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+        Returns:
+            List of Documents most similar to the query.
+        """
+        await self.__apost_init__()  # Lazy async init
+        embedding = self.embedding_function.embed_query(text=query)
+        return await self.asimilarity_search_by_vector(
+            embedding=embedding,
+            k=k,
+            filter=filter,
+        )
+
     def similarity_search_with_score(
         self,
         query: str,
@@ -567,12 +940,36 @@ class PGVector(VectorStore):
         Returns:
             List of Documents most similar to the query and score for each.
         """
+        assert not self._async_engine, "This method must be called without async_mode"
         embedding = self.embedding_function.embed_query(query)
         docs = self.similarity_search_with_score_by_vector(
             embedding=embedding, k=k, filter=filter
         )
         return docs
 
+    async def asimilarity_search_with_score(
+        self,
+        query: str,
+        k: int = 4,
+        filter: Optional[dict] = None,
+    ) -> List[Tuple[Document, float]]:
+        """Return docs most similar to query.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+        Returns:
+            List of Documents most similar to the query and score for each.
+        """
+        await self.__apost_init__()  # Lazy async init
+        embedding = self.embedding_function.embed_query(query)
+        docs = await self.asimilarity_search_with_score_by_vector(
+            embedding=embedding, k=k, filter=filter
+        )
+        return docs
+
     @property
     def distance_strategy(self) -> Any:
         if self._distance_strategy == DistanceStrategy.EUCLIDEAN:
@@ -593,10 +990,25 @@ class PGVector(VectorStore):
         k: int = 4,
         filter: Optional[dict] = None,
     ) -> List[Tuple[Document, float]]:
+        assert not self._async_engine, "This method must be called without async_mode"
         results = self.__query_collection(embedding=embedding, k=k, filter=filter)
 
         return self._results_to_docs_and_scores(results)
 
+    async def asimilarity_search_with_score_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        filter: Optional[dict] = None,
+    ) -> List[Tuple[Document, float]]:
+        await self.__apost_init__()  # Lazy async init
+        async with self._make_async_session() as session:  # type: ignore[arg-type]
+            results = await self.__aquery_collection(
+                session=session, embedding=embedding, k=k, filter=filter
+            )
+
+            return self._results_to_docs_and_scores(results)
+
     def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]:
         """Return docs and scores from results."""
         docs = [
@@ -698,6 +1110,11 @@ class PGVector(VectorStore):
                 f"Unsupported type: {type(val)} for value: {val}"
             )
 
+        if isinstance(val, bool):  # b/c bool is an instance of int
+            raise NotImplementedError(
+                f"Unsupported type: {type(val)} for value: {val}"
+            )
+
         queried_field = self.EmbeddingStore.cmetadata[field].astext
 
         if operator in {"$in"}:
@@ -832,26 +1249,25 @@ class PGVector(VectorStore):
         """
         if isinstance(filters, dict):
             if len(filters) == 1:
-                # The only operators allowed at the top level are $AND and $OR
+                # The only operators allowed at the top level are $AND, $OR, and $NOT
                 # First check if an operator or a field
                 key, value = list(filters.items())[0]
                 if key.startswith("$"):
                     # Then it's an operator
-                    if key.lower() not in ["$and", "$or"]:
+                    if key.lower() not in ["$and", "$or", "$not"]:
                         raise ValueError(
-                            f"Invalid filter condition. Expected $and or $or "
+                            f"Invalid filter condition. Expected $and, $or or $not "
                             f"but got: {key}"
                         )
                 else:
                     # Then it's a field
                     return self._handle_field_filter(key, filters[key])
 
-            # Here we handle the $and and $or operators
-            if not isinstance(value, list):
-                raise ValueError(
-                    f"Expected a list, but got {type(value)} for value: {value}"
-                )
             if key.lower() == "$and":
+                if not isinstance(value, list):
+                    raise ValueError(
+                        f"Expected a list, but got {type(value)} for value: {value}"
+                    )
                 and_ = [self._create_filter_clause(el) for el in value]
                 if len(and_) > 1:
                     return sqlalchemy.and_(*and_)
@@ -863,6 +1279,10 @@ class PGVector(VectorStore):
                     "but got an empty dictionary"
                 )
             elif key.lower() == "$or":
+                if not isinstance(value, list):
+                    raise ValueError(
+                        f"Expected a list, but got {type(value)} for value: {value}"
+                    )
                 or_ = [self._create_filter_clause(el) for el in value]
                 if len(or_) > 1:
                     return sqlalchemy.or_(*or_)
@@ -873,9 +1293,29 @@ class PGVector(VectorStore):
                     "Invalid filter condition. Expected a dictionary "
                     "but got an empty dictionary"
                 )
+            elif key.lower() == "$not":
+                if isinstance(value, list):
+                    not_conditions = [
+                        self._create_filter_clause(item) for item in value
+                    ]
+                    not_ = sqlalchemy.and_(
+                        *[
+                            sqlalchemy.not_(condition)
+                            for condition in not_conditions
+                        ]
+                    )
+                    return not_
+                elif isinstance(value, dict):
+                    not_ = self._create_filter_clause(value)
+                    return sqlalchemy.not_(not_)
+                else:
+                    raise ValueError(
+                        f"Invalid filter condition. Expected a dictionary "
+                        f"or a list but got: {type(value)}"
+                    )
             else:
                 raise ValueError(
-                    f"Invalid filter condition. Expected $and or $or "
+                    f"Invalid filter condition. Expected $and, $or or $not "
                    f"but got: {key}"
                )
        elif len(filters) > 1:
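
As the new branch shows, `$not` accepts either a dict (plain negation) or a list, where each element is negated and the negations are joined with AND. A sketch with illustrative field names:

    # Equivalent to NOT(location == 'pond') AND NOT(location == 'market').
    docs = vectorstore.similarity_search(
        "ducks",
        k=10,
        filter={"$not": [{"location": "pond"}, {"location": "market"}]},
    )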
@@ -908,9 +1348,9 @@ class PGVector(VectorStore):
         embedding: List[float],
         k: int = 4,
         filter: Optional[Dict[str, str]] = None,
-    ) -> List[Any]:
+    ) -> Sequence[Any]:
         """Query the collection."""
-        with self._session_maker() as session:  # type: ignore[arg-type]
+        with self._make_sync_session() as session:  # type: ignore[arg-type]
             collection = self.get_collection(session)
             if not collection:
                 raise ValueError("Collection not found")
@@ -931,7 +1371,7 @@ class PGVector(VectorStore):
             results: List[Any] = (
                 session.query(
                     self.EmbeddingStore,
-                    self.distance_strategy(embedding).label("distance"),  # type: ignore
+                    self.distance_strategy(embedding).label("distance"),
                 )
                 .filter(*filter_by)
                 .order_by(sqlalchemy.asc("distance"))
@@ -945,6 +1385,50 @@ class PGVector(VectorStore):
 
         return results
 
+    async def __aquery_collection(
+        self,
+        session: AsyncSession,
+        embedding: List[float],
+        k: int = 4,
+        filter: Optional[Dict[str, str]] = None,
+    ) -> Sequence[Any]:
+        """Query the collection."""
+        async with self._make_async_session() as session:  # type: ignore[arg-type]
+            collection = await self.aget_collection(session)
+            if not collection:
+                raise ValueError("Collection not found")
+
+            filter_by = [self.EmbeddingStore.collection_id == collection.uuid]
+            if filter:
+                if self.use_jsonb:
+                    filter_clauses = self._create_filter_clause(filter)
+                    if filter_clauses is not None:
+                        filter_by.append(filter_clauses)
+                else:
+                    # Old way of doing things
+                    filter_clauses = self._create_filter_clause_json_deprecated(filter)
+                    filter_by.extend(filter_clauses)
+
+            _type = self.EmbeddingStore
+
+            stmt = (
+                select(
+                    self.EmbeddingStore,
+                    self.distance_strategy(embedding).label("distance"),
+                )
+                .filter(*filter_by)
+                .order_by(sqlalchemy.asc("distance"))
+                .join(
+                    self.CollectionStore,
+                    self.EmbeddingStore.collection_id == self.CollectionStore.uuid,
+                )
+                .limit(k)
+            )
+
+            results: Sequence[Any] = (await session.execute(stmt)).all()
+
+            return results
+
     def similarity_search_by_vector(
         self,
         embedding: List[float],
@@ -962,11 +1446,36 @@ class PGVector(VectorStore):
         Returns:
             List of Documents most similar to the query vector.
         """
+        assert not self._async_engine, "This method must be called without async_mode"
         docs_and_scores = self.similarity_search_with_score_by_vector(
             embedding=embedding, k=k, filter=filter
         )
         return _results_to_docs(docs_and_scores)
 
+    async def asimilarity_search_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        filter: Optional[dict] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Return docs most similar to embedding vector.
+
+        Args:
+            embedding: Embedding to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+        Returns:
+            List of Documents most similar to the query vector.
+        """
+        assert self._async_engine, "This method must be called with async_mode"
+        await self.__apost_init__()  # Lazy async init
+        docs_and_scores = await self.asimilarity_search_with_score_by_vector(
+            embedding=embedding, k=k, filter=filter
+        )
+        return _results_to_docs(docs_and_scores)
+
     @classmethod
     def from_texts(
         cls: Type[PGVector],
@@ -997,6 +1506,35 @@ class PGVector(VectorStore):
             **kwargs,
         )
 
+    @classmethod
+    async def afrom_texts(
+        cls: Type[PGVector],
+        texts: List[str],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+        distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+        ids: Optional[List[str]] = None,
+        pre_delete_collection: bool = False,
+        *,
+        use_jsonb: bool = True,
+        **kwargs: Any,
+    ) -> PGVector:
+        """Return VectorStore initialized from documents and embeddings."""
+        embeddings = embedding.embed_documents(list(texts))
+        return await cls.__afrom(
+            texts,
+            embeddings,
+            embedding,
+            metadatas=metadatas,
+            ids=ids,
+            collection_name=collection_name,
+            distance_strategy=distance_strategy,
+            pre_delete_collection=pre_delete_collection,
+            use_jsonb=use_jsonb,
+            **kwargs,
+        )
+
     @classmethod
     def from_embeddings(
         cls,
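
Usage sketch for the new async constructor (connection URL and texts are placeholders; `ToyEmbeddings` is the stub from the earlier sketch; must run inside a coroutine):

    store = await PGVector.afrom_texts(
        texts=["a duck on a pond", "a cat at the market"],
        embedding=ToyEmbeddings(),  # placeholder Embeddings implementation
        collection_name="demo",
        connection="postgresql+psycopg://user:pass@localhost/db",
    )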
@@ -1019,6 +1557,7 @@ class PGVector(VectorStore):
             collection_name: Name of the collection.
             distance_strategy: Distance strategy to use.
             ids: Optional list of ids for the documents.
+                If not provided, will generate a new id for each document.
             pre_delete_collection: If True, will delete the collection if it exists.
                 **Attention**: This will delete all the documents in the existing
                 collection.
@@ -1053,6 +1592,51 @@ class PGVector(VectorStore):
             **kwargs,
         )
 
+    @classmethod
+    async def afrom_embeddings(
+        cls,
+        text_embeddings: List[Tuple[str, List[float]]],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+        distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+        ids: Optional[List[str]] = None,
+        pre_delete_collection: bool = False,
+        **kwargs: Any,
+    ) -> PGVector:
+        """Construct PGVector wrapper from raw documents and pre-generated
+        embeddings.
+
+        Return VectorStore initialized from documents and embeddings.
+        A Postgres connection string is required: either pass it as a parameter
+        or set the PGVECTOR_CONNECTION_STRING environment variable.
+
+        Example:
+            .. code-block:: python
+
+                from langchain_postgres.vectorstores import PGVector
+                from langchain_community.embeddings import OpenAIEmbeddings
+
+                embeddings = OpenAIEmbeddings()
+                text_embeddings = embeddings.embed_documents(texts)
+                text_embedding_pairs = list(zip(texts, text_embeddings))
+                store = await PGVector.afrom_embeddings(
+                    text_embedding_pairs, embeddings
+                )
+        """
+        texts = [t[0] for t in text_embeddings]
+        embeddings = [t[1] for t in text_embeddings]
+
+        return await cls.__afrom(
+            texts,
+            embeddings,
+            embedding,
+            metadatas=metadatas,
+            ids=ids,
+            collection_name=collection_name,
+            distance_strategy=distance_strategy,
+            pre_delete_collection=pre_delete_collection,
+            **kwargs,
+        )
+
     @classmethod
     def from_existing_index(
         cls: Type[PGVector],
@@ -1061,7 +1645,7 @@ class PGVector(VectorStore):
         collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
         distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
         pre_delete_collection: bool = False,
-        connection: Optional[Connection] = None,
+        connection: Optional[DBConnection] = None,
         **kwargs: Any,
     ) -> PGVector:
         """
@@ -1080,11 +1664,39 @@ class PGVector(VectorStore):
 
         return store
 
+    @classmethod
+    async def afrom_existing_index(
+        cls: Type[PGVector],
+        embedding: Embeddings,
+        *,
+        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+        distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+        pre_delete_collection: bool = False,
+        connection: Optional[DBConnection] = None,
+        **kwargs: Any,
+    ) -> PGVector:
+        """
+        Get an instance of an existing PGVector store. This method will
+        return the instance of the store without inserting any new
+        embeddings.
+        """
+        store = PGVector(
+            connection=connection,
+            collection_name=collection_name,
+            embeddings=embedding,
+            distance_strategy=distance_strategy,
+            pre_delete_collection=pre_delete_collection,
+            async_mode=True,
+            **kwargs,
+        )
+
+        return store
+
     @classmethod
     def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
         connection_string: str = get_from_dict_or_env(
             data=kwargs,
-            key="connection_string",
+            key="connection",
             env_key="PGVECTOR_CONNECTION_STRING",
         )
 
@@ -1103,7 +1715,7 @@ class PGVector(VectorStore):
         documents: List[Document],
         embedding: Embeddings,
         *,
-        connection: Optional[Connection] = None,
+        connection: Optional[DBConnection] = None,
         collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
         distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
         ids: Optional[List[str]] = None,
@@ -1129,6 +1741,44 @@ class PGVector(VectorStore):
             **kwargs,
         )
 
+    @classmethod
+    async def afrom_documents(
+        cls: Type[PGVector],
+        documents: List[Document],
+        embedding: Embeddings,
+        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+        distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+        ids: Optional[List[str]] = None,
+        pre_delete_collection: bool = False,
+        *,
+        use_jsonb: bool = True,
+        **kwargs: Any,
+    ) -> PGVector:
+        """
+        Return VectorStore initialized from documents and embeddings.
+        A Postgres connection string is required: either pass it as a parameter
+        or set the PGVECTOR_CONNECTION_STRING environment variable.
+        """
+
+        texts = [d.page_content for d in documents]
+        metadatas = [d.metadata for d in documents]
+        connection_string = cls.get_connection_string(kwargs)
+
+        kwargs["connection"] = connection_string
+
+        return await cls.afrom_texts(
+            texts=texts,
+            pre_delete_collection=pre_delete_collection,
+            embedding=embedding,
+            distance_strategy=distance_strategy,
+            metadatas=metadatas,
+            ids=ids,
+            collection_name=collection_name,
+            use_jsonb=use_jsonb,
+            **kwargs,
+        )
+
     @classmethod
     def connection_string_from_db_params(
         cls,
@@ -1201,6 +1851,7 @@ class PGVector(VectorStore):
             List[Tuple[Document, float]]: List of Documents selected by maximal marginal
                 relevance to the query and score for each.
         """
+        assert not self._async_engine, "This method must be called without async_mode"
         results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter)
 
         embedding_list = [result.EmbeddingStore.embedding for result in results]
@@ -1216,6 +1867,55 @@ class PGVector(VectorStore):
 
         return [r for i, r in enumerate(candidates) if i in mmr_selected]
 
+    async def amax_marginal_relevance_search_with_score_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        filter: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> List[Tuple[Document, float]]:
+        """Return docs selected using the maximal marginal relevance with score
+        to embedding vector.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            embedding: Embedding to look up documents similar to.
+            k (int): Number of Documents to return. Defaults to 4.
+            fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
+                Defaults to 20.
+            lambda_mult (float): Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+        Returns:
+            List[Tuple[Document, float]]: List of Documents selected by maximal marginal
+                relevance to the query and score for each.
+        """
+        await self.__apost_init__()  # Lazy async init
+        async with self._make_async_session() as session:
+            results = await self.__aquery_collection(
+                session=session, embedding=embedding, k=fetch_k, filter=filter
+            )
+
+            embedding_list = [result.EmbeddingStore.embedding for result in results]
+
+            mmr_selected = maximal_marginal_relevance(
+                np.array(embedding, dtype=np.float32),
+                embedding_list,
+                k=k,
+                lambda_mult=lambda_mult,
+            )
+
+            candidates = self._results_to_docs_and_scores(results)
+
+            return [r for i, r in enumerate(candidates) if i in mmr_selected]
+
     def max_marginal_relevance_search(
         self,
         query: str,
@@ -1254,6 +1954,45 @@ class PGVector(VectorStore):
             **kwargs,
         )
 
+    async def amax_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        filter: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            query (str): Text to look up documents similar to.
+            k (int): Number of Documents to return. Defaults to 4.
+            fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
+                Defaults to 20.
+            lambda_mult (float): Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+        Returns:
+            List[Document]: List of Documents selected by maximal marginal relevance.
+        """
+        await self.__apost_init__()  # Lazy async init
+        embedding = self.embedding_function.embed_query(query)
+        return await self.amax_marginal_relevance_search_by_vector(
+            embedding,
+            k=k,
+            fetch_k=fetch_k,
+            lambda_mult=lambda_mult,
+            filter=filter,
+            **kwargs,
+        )
+
     def max_marginal_relevance_search_with_score(
         self,
         query: str,
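
Usage sketch for the async MMR entry point, assuming the async `store` from the earlier example: fetch the 20 nearest candidates and return the 4 that best balance relevance and diversity:

    docs = await store.amax_marginal_relevance_search(
        "ducks",
        k=4,              # results returned
        fetch_k=20,       # candidates fetched before re-ranking
        lambda_mult=0.5,  # 0 = max diversity, 1 = pure similarity
    )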
@@ -1294,6 +2033,47 @@ class PGVector(VectorStore):
         )
         return docs
 
+    async def amax_marginal_relevance_search_with_score(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        filter: Optional[dict] = None,
+        **kwargs: Any,
+    ) -> List[Tuple[Document, float]]:
+        """Return docs selected using the maximal marginal relevance with score.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            query (str): Text to look up documents similar to.
+            k (int): Number of Documents to return. Defaults to 4.
+            fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
+                Defaults to 20.
+            lambda_mult (float): Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+        Returns:
+            List[Tuple[Document, float]]: List of Documents selected by maximal marginal
+                relevance to the query and score for each.
+        """
+        await self.__apost_init__()  # Lazy async init
+        embedding = self.embedding_function.embed_query(query)
+        docs = await self.amax_marginal_relevance_search_with_score_by_vector(
+            embedding=embedding,
+            k=k,
+            fetch_k=fetch_k,
+            lambda_mult=lambda_mult,
+            filter=filter,
+            **kwargs,
+        )
+        return docs
+
     def max_marginal_relevance_search_by_vector(
         self,
         embedding: List[float],
@@ -1343,18 +2123,58 @@ class PGVector(VectorStore):
         filter: Optional[Dict[str, str]] = None,
         **kwargs: Any,
     ) -> List[Document]:
-        """Return docs selected using the maximal marginal relevance."""
-
-        # This is a temporary workaround to make the similarity search
-        # asynchronous. The proper solution is to make the similarity search
-        # asynchronous in the vector store implementations.
-        return await run_in_executor(
-            None,
-            self.max_marginal_relevance_search_by_vector,
-            embedding,
-            k=k,
-            fetch_k=fetch_k,
-            lambda_mult=lambda_mult,
-            filter=filter,
-            **kwargs,
+        """Return docs selected using the maximal marginal relevance
+        to embedding vector.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            embedding (List[float]): Embedding to look up documents similar to.
+            k (int): Number of Documents to return. Defaults to 4.
+            fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
+                Defaults to 20.
+            lambda_mult (float): Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+        Returns:
+            List[Document]: List of Documents selected by maximal marginal relevance.
+        """
+        await self.__apost_init__()  # Lazy async init
+        docs_and_scores = (
+            await self.amax_marginal_relevance_search_with_score_by_vector(
+                embedding,
+                k=k,
+                fetch_k=fetch_k,
+                lambda_mult=lambda_mult,
+                filter=filter,
+                **kwargs,
+            )
         )
+
+        return _results_to_docs(docs_and_scores)
+
+    @contextlib.contextmanager
+    def _make_sync_session(self) -> Generator[Session, None, None]:
+        """Make a sync session."""
+        if self.async_mode:
+            raise ValueError(
+                "Attempting to use a sync method when async mode is turned on. "
+                "Please use the corresponding async method instead."
+            )
+        with self.session_maker() as session:
+            yield typing_cast(Session, session)
+
+    @contextlib.asynccontextmanager
+    async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
+        """Make an async session."""
+        if not self.async_mode:
+            raise ValueError(
+                "Attempting to use an async method when sync mode is turned on. "
+                "Please use the corresponding sync method instead."
+            )
+        async with self.session_maker() as session:
+            yield typing_cast(AsyncSession, session)
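
These two context managers are what enforce the sync/async split introduced in this release. A failure-mode sketch, reusing the placeholder engine and `ToyEmbeddings` stub from the earlier example:

    # A store built in async mode rejects sync entry points with a ValueError:
    store = PGVector(
        embeddings=ToyEmbeddings(),
        connection=create_async_engine("postgresql+psycopg://user:pass@localhost/db"),
        collection_name="demo",
    )
    try:
        store.create_tables_if_not_exists()  # routed through _make_sync_session
    except ValueError as err:
        print(err)  # "Attempting to use a sync method when async mode is turned on..."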