langchain-postgres 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1348 @@
1
+ from __future__ import annotations
2
+
3
+ import enum
4
+ import logging
5
+ import uuid
6
+ from typing import (
7
+ Any,
8
+ Callable,
9
+ Dict,
10
+ Iterable,
11
+ List,
12
+ Optional,
13
+ Tuple,
14
+ Type,
15
+ Union,
16
+ )
17
+
18
+ import numpy as np
19
+ import sqlalchemy
20
+ from sqlalchemy import SQLColumnExpression, cast, delete, func
21
+ from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert
22
+ from sqlalchemy.orm import Session, relationship, sessionmaker
23
+
24
+ try:
25
+ from sqlalchemy.orm import declarative_base
26
+ except ImportError:
27
+ from sqlalchemy.ext.declarative import declarative_base
28
+
29
+ from langchain_core.documents import Document
30
+ from langchain_core.embeddings import Embeddings
31
+ from langchain_core.runnables.config import run_in_executor
32
+ from langchain_core.utils import get_from_dict_or_env
33
+ from langchain_core.vectorstores import VectorStore
34
+
35
+ from langchain_postgres._utils import maximal_marginal_relevance
36
+
37
+
38
+ class DistanceStrategy(str, enum.Enum):
39
+ """Enumerator of the Distance strategies."""
40
+
41
+ EUCLIDEAN = "l2"
42
+ COSINE = "cosine"
43
+ MAX_INNER_PRODUCT = "inner"
44
+
45
+
46
+ DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
47
+
48
+ Base = declarative_base() # type: Any
49
+
50
+
51
+ _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
52
+
53
+
54
+ _classes: Any = None
55
+
56
+ COMPARISONS_TO_NATIVE = {
57
+ "$eq": "==",
58
+ "$ne": "!=",
59
+ "$lt": "<",
60
+ "$lte": "<=",
61
+ "$gt": ">",
62
+ "$gte": ">=",
63
+ }
64
+
65
+ SPECIAL_CASED_OPERATORS = {
66
+ "$in",
67
+ "$nin",
68
+ "$between",
69
+ }
70
+
71
+ TEXT_OPERATORS = {
72
+ "$like",
73
+ "$ilike",
74
+ }
75
+
76
+ LOGICAL_OPERATORS = {"$and", "$or"}
77
+
78
+ SUPPORTED_OPERATORS = (
79
+ set(COMPARISONS_TO_NATIVE)
80
+ .union(TEXT_OPERATORS)
81
+ .union(LOGICAL_OPERATORS)
82
+ .union(SPECIAL_CASED_OPERATORS)
83
+ )
84
+
85
+
86
+ def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> Any:
87
+ global _classes
88
+ if _classes is not None:
89
+ return _classes
90
+
91
+ from pgvector.sqlalchemy import Vector # type: ignore
92
+
93
+ class CollectionStore(Base):
94
+ """Collection store."""
95
+
96
+ __tablename__ = "langchain_pg_collection"
97
+
98
+ uuid = sqlalchemy.Column(
99
+ UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
100
+ )
101
+ name = sqlalchemy.Column(sqlalchemy.String, nullable=False, unique=True)
102
+ cmetadata = sqlalchemy.Column(JSON)
103
+
104
+ embeddings = relationship(
105
+ "EmbeddingStore",
106
+ back_populates="collection",
107
+ passive_deletes=True,
108
+ )
109
+
110
+ @classmethod
111
+ def get_by_name(
112
+ cls, session: Session, name: str
113
+ ) -> Optional["CollectionStore"]:
114
+ return session.query(cls).filter(cls.name == name).first() # type: ignore
115
+
116
+ @classmethod
117
+ def get_or_create(
118
+ cls,
119
+ session: Session,
120
+ name: str,
121
+ cmetadata: Optional[dict] = None,
122
+ ) -> Tuple["CollectionStore", bool]:
123
+ """Get or create a collection.
124
+ Returns:
125
+ Where the bool is True if the collection was created.
126
+ """ # noqa: E501
127
+ created = False
128
+ collection = cls.get_by_name(session, name)
129
+ if collection:
130
+ return collection, created
131
+
132
+ collection = cls(name=name, cmetadata=cmetadata)
133
+ session.add(collection)
134
+ session.commit()
135
+ created = True
136
+ return collection, created
137
+
138
+ class EmbeddingStore(Base):
139
+ """Embedding store."""
140
+
141
+ __tablename__ = "langchain_pg_embedding"
142
+
143
+ id = sqlalchemy.Column(
144
+ sqlalchemy.String, nullable=True, primary_key=True, index=True, unique=True
145
+ )
146
+
147
+ collection_id = sqlalchemy.Column(
148
+ UUID(as_uuid=True),
149
+ sqlalchemy.ForeignKey(
150
+ f"{CollectionStore.__tablename__}.uuid",
151
+ ondelete="CASCADE",
152
+ ),
153
+ )
154
+ collection = relationship(CollectionStore, back_populates="embeddings")
155
+
156
+ embedding: Vector = sqlalchemy.Column(Vector(vector_dimension))
157
+ document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
158
+ cmetadata = sqlalchemy.Column(JSONB, nullable=True)
159
+
160
+ __table_args__ = (
161
+ sqlalchemy.Index(
162
+ "ix_cmetadata_gin",
163
+ "cmetadata",
164
+ postgresql_using="gin",
165
+ postgresql_ops={"cmetadata": "jsonb_path_ops"},
166
+ ),
167
+ )
168
+
169
+ _classes = (EmbeddingStore, CollectionStore)
170
+
171
+ return _classes
172
+
173
+
174
+ def _results_to_docs(docs_and_scores: Any) -> List[Document]:
175
+ """Return docs from docs and scores."""
176
+ return [doc for doc, _ in docs_and_scores]
177
+
178
+
179
+ Connection = Union[sqlalchemy.engine.Engine, str]
180
+
181
+
182
+ class PGVector(VectorStore):
183
+ """Vectorstore implementation using Postgres as the backend.
184
+
185
+ Currently, there is no mechanism for supporting data migration.
186
+
187
+ So breaking changes in the vectorstore schema will require the user to recreate
188
+ the tables and re-add the documents.
189
+
190
+ If this is a concern, please use a different vectorstore. If
191
+ not, this implementation should be fine for your use case.
192
+
193
+ To use this vectorstore you need to have the `vector` extension installed.
194
+ The `vector` extension is a Postgres extension that provides vector
195
+ similarity search capabilities.
196
+
197
+ ```sh
198
+ docker run --name pgvector-container -e POSTGRES_PASSWORD=...
199
+ -d pgvector/pgvector:pg16
200
+ ```
201
+
202
+ Example:
203
+ .. code-block:: python
204
+
205
+ from langchain_postgres.vectorstores import PGVector
206
+ from langchain_openai.embeddings import OpenAIEmbeddings
207
+
208
+ connection_string = "postgresql+psycopg://..."
209
+ collection_name = "state_of_the_union_test"
210
+ embeddings = OpenAIEmbeddings()
211
+ vectorstore = PGVector.from_documents(
212
+ embedding=embeddings,
213
+ documents=docs,
214
+ connection=connection_string,
215
+ collection_name=collection_name,
216
+ use_jsonb=True,
217
+ )
218
+
219
+
220
+ This code has been ported over from langchain_community with minimal changes
221
+ to allow users to easily transition from langchain_community to langchain_postgres.
222
+
223
+ Some changes had to be made to address issues with the community implementation:
224
+ * langchain_postgres now works with psycopg3. Please update your
225
+ connection strings from `postgresql+psycopg2://...` to
226
+ `postgresql+psycopg://langchain:langchain@...`
227
+ (yes, the driver name is `psycopg` not `psycopg3`)
228
+ * The schema of the embedding store and collection have been changed to make
229
+ add_documents work correctly with user specified ids, specifically
230
+ when overwriting existing documents.
231
+ You will need to recreate the tables if you are using an existing database.
232
+ * A Connection object has to be provided explicitly. Connections will not be
233
+ picked up automatically based on env variables.
234
+ """
235
+
236
+ def __init__(
237
+ self,
238
+ embeddings: Embeddings,
239
+ *,
240
+ connection: Optional[Connection] = None,
241
+ embedding_length: Optional[int] = None,
242
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
243
+ collection_metadata: Optional[dict] = None,
244
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
245
+ pre_delete_collection: bool = False,
246
+ logger: Optional[logging.Logger] = None,
247
+ relevance_score_fn: Optional[Callable[[float], float]] = None,
248
+ engine_args: Optional[dict[str, Any]] = None,
249
+ use_jsonb: bool = True,
250
+ create_extension: bool = True,
251
+ ) -> None:
252
+ """Initialize the PGVector store.
253
+
254
+ Args:
255
+ connection: Postgres connection string.
256
+ embeddings: Any embedding function implementing
257
+ `langchain.embeddings.base.Embeddings` interface.
258
+ embedding_length: The length of the embedding vector. (default: None)
259
+ NOTE: This is not mandatory. Defining it will prevent vectors of
260
+ any other size to be added to the embeddings table but, without it,
261
+ the embeddings can't be indexed.
262
+ collection_name: The name of the collection to use. (default: langchain)
263
+ NOTE: This is not the name of the table, but the name of the collection.
264
+ The tables will be created when initializing the store (if not exists)
265
+ So, make sure the user has the right permissions to create tables.
266
+ distance_strategy: The distance strategy to use. (default: COSINE)
267
+ pre_delete_collection: If True, will delete the collection if it exists.
268
+ (default: False). Useful for testing.
269
+ engine_args: SQLAlchemy's create engine arguments.
270
+ use_jsonb: Use JSONB instead of JSON for metadata. (default: True)
271
+ Strongly discouraged from using JSON as it's not as efficient
272
+ for querying.
273
+ It's provided here for backwards compatibility with older versions,
274
+ and will be removed in the future.
275
+ create_extension: If True, will create the vector extension if it
276
+ doesn't exist. disabling creation is useful when using ReadOnly
277
+ Databases.
278
+ """
279
+ self.embedding_function = embeddings
280
+ self._embedding_length = embedding_length
281
+ self.collection_name = collection_name
282
+ self.collection_metadata = collection_metadata
283
+ self._distance_strategy = distance_strategy
284
+ self.pre_delete_collection = pre_delete_collection
285
+ self.logger = logger or logging.getLogger(__name__)
286
+ self.override_relevance_score_fn = relevance_score_fn
287
+
288
+ if isinstance(connection, str):
289
+ self._engine = sqlalchemy.create_engine(
290
+ url=connection, **(engine_args or {})
291
+ )
292
+ elif isinstance(connection, sqlalchemy.engine.Engine):
293
+ self._engine = connection
294
+ else:
295
+ raise ValueError(
296
+ "connection should be a connection string or an instance of "
297
+ "sqlalchemy.engine.Engine"
298
+ )
299
+
300
+ self._session_maker = sessionmaker(bind=self._engine)
301
+
302
+ self.use_jsonb = use_jsonb
303
+ self.create_extension = create_extension
304
+
305
+ if not use_jsonb:
306
+ # Replace with a deprecation warning.
307
+ raise NotImplementedError("use_jsonb=False is no longer supported.")
308
+ self.__post_init__()
309
+
310
+ def __post_init__(
311
+ self,
312
+ ) -> None:
313
+ """Initialize the store."""
314
+ if self.create_extension:
315
+ self.create_vector_extension()
316
+
317
+ EmbeddingStore, CollectionStore = _get_embedding_collection_store(
318
+ self._embedding_length
319
+ )
320
+ self.CollectionStore = CollectionStore
321
+ self.EmbeddingStore = EmbeddingStore
322
+ self.create_tables_if_not_exists()
323
+ self.create_collection()
324
+
325
+ def __del__(self) -> None:
326
+ if isinstance(self._engine, sqlalchemy.engine.Connection):
327
+ self._engine.close()
328
+
329
+ @property
330
+ def embeddings(self) -> Embeddings:
331
+ return self.embedding_function
332
+
333
+ def create_vector_extension(self) -> None:
334
+ try:
335
+ with self._session_maker() as session: # type: ignore[arg-type]
336
+ # The advisor lock fixes issue arising from concurrent
337
+ # creation of the vector extension.
338
+ # https://github.com/langchain-ai/langchain/issues/12933
339
+ # For more information see:
340
+ # https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS
341
+ statement = sqlalchemy.text(
342
+ "BEGIN;"
343
+ "SELECT pg_advisory_xact_lock(1573678846307946496);"
344
+ "CREATE EXTENSION IF NOT EXISTS vector;"
345
+ "COMMIT;"
346
+ )
347
+ session.execute(statement)
348
+ session.commit()
349
+ except Exception as e:
350
+ raise Exception(f"Failed to create vector extension: {e}") from e
351
+
352
+ def create_tables_if_not_exists(self) -> None:
353
+ with self._session_maker() as session:
354
+ Base.metadata.create_all(session.get_bind())
355
+
356
+ def drop_tables(self) -> None:
357
+ with self._session_maker() as session:
358
+ Base.metadata.drop_all(session.get_bind())
359
+
360
+ def create_collection(self) -> None:
361
+ if self.pre_delete_collection:
362
+ self.delete_collection()
363
+ with self._session_maker() as session: # type: ignore[arg-type]
364
+ self.CollectionStore.get_or_create(
365
+ session, self.collection_name, cmetadata=self.collection_metadata
366
+ )
367
+
368
+ def delete_collection(self) -> None:
369
+ self.logger.debug("Trying to delete collection")
370
+ with self._session_maker() as session: # type: ignore[arg-type]
371
+ collection = self.get_collection(session)
372
+ if not collection:
373
+ self.logger.warning("Collection not found")
374
+ return
375
+ session.delete(collection)
376
+ session.commit()
377
+
378
+ def delete(
379
+ self,
380
+ ids: Optional[List[str]] = None,
381
+ collection_only: bool = False,
382
+ **kwargs: Any,
383
+ ) -> None:
384
+ """Delete vectors by ids or uuids.
385
+
386
+ Args:
387
+ ids: List of ids to delete.
388
+ collection_only: Only delete ids in the collection.
389
+ """
390
+ with self._session_maker() as session:
391
+ if ids is not None:
392
+ self.logger.debug(
393
+ "Trying to delete vectors by ids (represented by the model "
394
+ "using the custom ids field)"
395
+ )
396
+
397
+ stmt = delete(self.EmbeddingStore)
398
+
399
+ if collection_only:
400
+ collection = self.get_collection(session)
401
+ if not collection:
402
+ self.logger.warning("Collection not found")
403
+ return
404
+
405
+ stmt = stmt.where(
406
+ self.EmbeddingStore.collection_id == collection.uuid
407
+ )
408
+
409
+ stmt = stmt.where(self.EmbeddingStore.id.in_(ids))
410
+ session.execute(stmt)
411
+ session.commit()
412
+
413
+ def get_collection(self, session: Session) -> Any:
414
+ return self.CollectionStore.get_by_name(session, self.collection_name)
415
+
416
+ @classmethod
417
+ def __from(
418
+ cls,
419
+ texts: List[str],
420
+ embeddings: List[List[float]],
421
+ embedding: Embeddings,
422
+ metadatas: Optional[List[dict]] = None,
423
+ ids: Optional[List[str]] = None,
424
+ collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
425
+ distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
426
+ connection: Optional[str] = None,
427
+ pre_delete_collection: bool = False,
428
+ *,
429
+ use_jsonb: bool = True,
430
+ **kwargs: Any,
431
+ ) -> PGVector:
432
+ if ids is None:
433
+ ids = [str(uuid.uuid1()) for _ in texts]
434
+
435
+ if not metadatas:
436
+ metadatas = [{} for _ in texts]
437
+
438
+ store = cls(
439
+ connection=connection,
440
+ collection_name=collection_name,
441
+ embeddings=embedding,
442
+ distance_strategy=distance_strategy,
443
+ pre_delete_collection=pre_delete_collection,
444
+ use_jsonb=use_jsonb,
445
+ **kwargs,
446
+ )
447
+
448
+ store.add_embeddings(
449
+ texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
450
+ )
451
+
452
+ return store
453
+
454
+ def add_embeddings(
455
+ self,
456
+ texts: Iterable[str],
457
+ embeddings: List[List[float]],
458
+ metadatas: Optional[List[dict]] = None,
459
+ ids: Optional[List[str]] = None,
460
+ **kwargs: Any,
461
+ ) -> List[str]:
462
+ """Add embeddings to the vectorstore.
463
+
464
+ Args:
465
+ texts: Iterable of strings to add to the vectorstore.
466
+ embeddings: List of list of embedding vectors.
467
+ metadatas: List of metadatas associated with the texts.
468
+ kwargs: vectorstore specific parameters
469
+ """
470
+ if ids is None:
471
+ ids = [str(uuid.uuid1()) for _ in texts]
472
+
473
+ if not metadatas:
474
+ metadatas = [{} for _ in texts]
475
+
476
+ with self._session_maker() as session: # type: ignore[arg-type]
477
+ collection = self.get_collection(session)
478
+ if not collection:
479
+ raise ValueError("Collection not found")
480
+ data = [
481
+ {
482
+ "id": id,
483
+ "collection_id": collection.uuid,
484
+ "embedding": embedding,
485
+ "document": text,
486
+ "cmetadata": metadata or {},
487
+ }
488
+ for text, metadata, embedding, id in zip(
489
+ texts, metadatas, embeddings, ids
490
+ )
491
+ ]
492
+ stmt = insert(self.EmbeddingStore).values(data)
493
+ on_conflict_stmt = stmt.on_conflict_do_update(
494
+ index_elements=["id"],
495
+ # Conflict detection based on these columns
496
+ set_={
497
+ "embedding": stmt.excluded.embedding,
498
+ "document": stmt.excluded.document,
499
+ "cmetadata": stmt.excluded.cmetadata,
500
+ },
501
+ )
502
+ session.execute(on_conflict_stmt)
503
+ session.commit()
504
+
505
+ return ids
506
+
507
+ def add_texts(
508
+ self,
509
+ texts: Iterable[str],
510
+ metadatas: Optional[List[dict]] = None,
511
+ ids: Optional[List[str]] = None,
512
+ **kwargs: Any,
513
+ ) -> List[str]:
514
+ """Run more texts through the embeddings and add to the vectorstore.
515
+
516
+ Args:
517
+ texts: Iterable of strings to add to the vectorstore.
518
+ metadatas: Optional list of metadatas associated with the texts.
519
+ kwargs: vectorstore specific parameters
520
+
521
+ Returns:
522
+ List of ids from adding the texts into the vectorstore.
523
+ """
524
+ embeddings = self.embedding_function.embed_documents(list(texts))
525
+ return self.add_embeddings(
526
+ texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
527
+ )
528
+
529
+ def similarity_search(
530
+ self,
531
+ query: str,
532
+ k: int = 4,
533
+ filter: Optional[dict] = None,
534
+ **kwargs: Any,
535
+ ) -> List[Document]:
536
+ """Run similarity search with PGVector with distance.
537
+
538
+ Args:
539
+ query (str): Query text to search for.
540
+ k (int): Number of results to return. Defaults to 4.
541
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
542
+
543
+ Returns:
544
+ List of Documents most similar to the query.
545
+ """
546
+ embedding = self.embedding_function.embed_query(text=query)
547
+ return self.similarity_search_by_vector(
548
+ embedding=embedding,
549
+ k=k,
550
+ filter=filter,
551
+ )
552
+
553
+ def similarity_search_with_score(
554
+ self,
555
+ query: str,
556
+ k: int = 4,
557
+ filter: Optional[dict] = None,
558
+ ) -> List[Tuple[Document, float]]:
559
+ """Return docs most similar to query.
560
+
561
+ Args:
562
+ query: Text to look up documents similar to.
563
+ k: Number of Documents to return. Defaults to 4.
564
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
565
+
566
+ Returns:
567
+ List of Documents most similar to the query and score for each.
568
+ """
569
+ embedding = self.embedding_function.embed_query(query)
570
+ docs = self.similarity_search_with_score_by_vector(
571
+ embedding=embedding, k=k, filter=filter
572
+ )
573
+ return docs
574
+
575
+ @property
576
+ def distance_strategy(self) -> Any:
577
+ if self._distance_strategy == DistanceStrategy.EUCLIDEAN:
578
+ return self.EmbeddingStore.embedding.l2_distance
579
+ elif self._distance_strategy == DistanceStrategy.COSINE:
580
+ return self.EmbeddingStore.embedding.cosine_distance
581
+ elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
582
+ return self.EmbeddingStore.embedding.max_inner_product
583
+ else:
584
+ raise ValueError(
585
+ f"Got unexpected value for distance: {self._distance_strategy}. "
586
+ f"Should be one of {', '.join([ds.value for ds in DistanceStrategy])}."
587
+ )
588
+
589
+ def similarity_search_with_score_by_vector(
590
+ self,
591
+ embedding: List[float],
592
+ k: int = 4,
593
+ filter: Optional[dict] = None,
594
+ ) -> List[Tuple[Document, float]]:
595
+ results = self.__query_collection(embedding=embedding, k=k, filter=filter)
596
+
597
+ return self._results_to_docs_and_scores(results)
598
+
599
+ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]:
600
+ """Return docs and scores from results."""
601
+ docs = [
602
+ (
603
+ Document(
604
+ page_content=result.EmbeddingStore.document,
605
+ metadata=result.EmbeddingStore.cmetadata,
606
+ ),
607
+ result.distance if self.embedding_function is not None else None,
608
+ )
609
+ for result in results
610
+ ]
611
+ return docs
612
+
613
+ def _handle_field_filter(
614
+ self,
615
+ field: str,
616
+ value: Any,
617
+ ) -> SQLColumnExpression:
618
+ """Create a filter for a specific field.
619
+
620
+ Args:
621
+ field: name of field
622
+ value: value to filter
623
+ If provided as is then this will be an equality filter
624
+ If provided as a dictionary then this will be a filter, the key
625
+ will be the operator and the value will be the value to filter by
626
+
627
+ Returns:
628
+ sqlalchemy expression
629
+ """
630
+ if not isinstance(field, str):
631
+ raise ValueError(
632
+ f"field should be a string but got: {type(field)} with value: {field}"
633
+ )
634
+
635
+ if field.startswith("$"):
636
+ raise ValueError(
637
+ f"Invalid filter condition. Expected a field but got an operator: "
638
+ f"{field}"
639
+ )
640
+
641
+ # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters
642
+ if not field.isidentifier():
643
+ raise ValueError(
644
+ f"Invalid field name: {field}. Expected a valid identifier."
645
+ )
646
+
647
+ if isinstance(value, dict):
648
+ # This is a filter specification
649
+ if len(value) != 1:
650
+ raise ValueError(
651
+ "Invalid filter condition. Expected a value which "
652
+ "is a dictionary with a single key that corresponds to an operator "
653
+ f"but got a dictionary with {len(value)} keys. The first few "
654
+ f"keys are: {list(value.keys())[:3]}"
655
+ )
656
+ operator, filter_value = list(value.items())[0]
657
+ # Verify that that operator is an operator
658
+ if operator not in SUPPORTED_OPERATORS:
659
+ raise ValueError(
660
+ f"Invalid operator: {operator}. "
661
+ f"Expected one of {SUPPORTED_OPERATORS}"
662
+ )
663
+ else: # Then we assume an equality operator
664
+ operator = "$eq"
665
+ filter_value = value
666
+
667
+ if operator in COMPARISONS_TO_NATIVE:
668
+ # Then we implement an equality filter
669
+ # native is trusted input
670
+ native = COMPARISONS_TO_NATIVE[operator]
671
+ return func.jsonb_path_match(
672
+ self.EmbeddingStore.cmetadata,
673
+ cast(f"$.{field} {native} $value", JSONPATH),
674
+ cast({"value": filter_value}, JSONB),
675
+ )
676
+ elif operator == "$between":
677
+ # Use AND with two comparisons
678
+ low, high = filter_value
679
+
680
+ lower_bound = func.jsonb_path_match(
681
+ self.EmbeddingStore.cmetadata,
682
+ cast(f"$.{field} >= $value", JSONPATH),
683
+ cast({"value": low}, JSONB),
684
+ )
685
+ upper_bound = func.jsonb_path_match(
686
+ self.EmbeddingStore.cmetadata,
687
+ cast(f"$.{field} <= $value", JSONPATH),
688
+ cast({"value": high}, JSONB),
689
+ )
690
+ return sqlalchemy.and_(lower_bound, upper_bound)
691
+ elif operator in {"$in", "$nin", "$like", "$ilike"}:
692
+ # We'll do force coercion to text
693
+ if operator in {"$in", "$nin"}:
694
+ for val in filter_value:
695
+ if not isinstance(val, (str, int, float)):
696
+ raise NotImplementedError(
697
+ f"Unsupported type: {type(val)} for value: {val}"
698
+ )
699
+
700
+ queried_field = self.EmbeddingStore.cmetadata[field].astext
701
+
702
+ if operator in {"$in"}:
703
+ return queried_field.in_([str(val) for val in filter_value])
704
+ elif operator in {"$nin"}:
705
+ return queried_field.nin_([str(val) for val in filter_value])
706
+ elif operator in {"$like"}:
707
+ return queried_field.like(filter_value)
708
+ elif operator in {"$ilike"}:
709
+ return queried_field.ilike(filter_value)
710
+ else:
711
+ raise NotImplementedError()
712
+ else:
713
+ raise NotImplementedError()
714
+
715
+ def _create_filter_clause_deprecated(self, key, value): # type: ignore[no-untyped-def]
716
+ """Deprecated functionality.
717
+
718
+ This is for backwards compatibility with the JSON based schema for metadata.
719
+ It uses incorrect operator syntax (operators are not prefixed with $).
720
+
721
+ This implementation is not efficient, and has bugs associated with
722
+ the way that it handles numeric filter clauses.
723
+ """
724
+ IN, NIN, BETWEEN, GT, LT, NE = "in", "nin", "between", "gt", "lt", "ne"
725
+ EQ, LIKE, CONTAINS, OR, AND = "eq", "like", "contains", "or", "and"
726
+
727
+ value_case_insensitive = {k.lower(): v for k, v in value.items()}
728
+ if IN in map(str.lower, value):
729
+ filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.in_(
730
+ value_case_insensitive[IN]
731
+ )
732
+ elif NIN in map(str.lower, value):
733
+ filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.not_in(
734
+ value_case_insensitive[NIN]
735
+ )
736
+ elif BETWEEN in map(str.lower, value):
737
+ filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.between(
738
+ str(value_case_insensitive[BETWEEN][0]),
739
+ str(value_case_insensitive[BETWEEN][1]),
740
+ )
741
+ elif GT in map(str.lower, value):
742
+ filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext > str(
743
+ value_case_insensitive[GT]
744
+ )
745
+ elif LT in map(str.lower, value):
746
+ filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext < str(
747
+ value_case_insensitive[LT]
748
+ )
749
+ elif NE in map(str.lower, value):
750
+ filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext != str(
751
+ value_case_insensitive[NE]
752
+ )
753
+ elif EQ in map(str.lower, value):
754
+ filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext == str(
755
+ value_case_insensitive[EQ]
756
+ )
757
+ elif LIKE in map(str.lower, value):
758
+ filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.like(
759
+ value_case_insensitive[LIKE]
760
+ )
761
+ elif CONTAINS in map(str.lower, value):
762
+ filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.contains(
763
+ value_case_insensitive[CONTAINS]
764
+ )
765
+ elif OR in map(str.lower, value):
766
+ or_clauses = [
767
+ self._create_filter_clause(key, sub_value)
768
+ for sub_value in value_case_insensitive[OR]
769
+ ]
770
+ filter_by_metadata = sqlalchemy.or_(*or_clauses)
771
+ elif AND in map(str.lower, value):
772
+ and_clauses = [
773
+ self._create_filter_clause(key, sub_value)
774
+ for sub_value in value_case_insensitive[AND]
775
+ ]
776
+ filter_by_metadata = sqlalchemy.and_(*and_clauses)
777
+
778
+ else:
779
+ filter_by_metadata = None
780
+
781
+ return filter_by_metadata
782
+
783
+ def _create_filter_clause_json_deprecated(
784
+ self, filter: Any
785
+ ) -> List[SQLColumnExpression]:
786
+ """Convert filters from IR to SQL clauses.
787
+
788
+ **DEPRECATED** This functionality will be deprecated in the future.
789
+
790
+ It implements translation of filters for a schema that uses JSON
791
+ for metadata rather than the JSONB field which is more efficient
792
+ for querying.
793
+ """
794
+ filter_clauses = []
795
+ for key, value in filter.items():
796
+ if isinstance(value, dict):
797
+ filter_by_metadata = self._create_filter_clause_deprecated(key, value)
798
+
799
+ if filter_by_metadata is not None:
800
+ filter_clauses.append(filter_by_metadata)
801
+ else:
802
+ filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext == str(
803
+ value
804
+ )
805
+ filter_clauses.append(filter_by_metadata)
806
+ return filter_clauses
807
+
808
+ def _create_filter_clause(self, filters: Any) -> Any:
809
+ """Convert LangChain IR filter representation to matching SQLAlchemy clauses.
810
+
811
+ At the top level, we still don't know if we're working with a field
812
+ or an operator for the keys. After we've determined that we can
813
+ call the appropriate logic to handle filter creation.
814
+
815
+ Args:
816
+ filters: Dictionary of filters to apply to the query.
817
+
818
+ Returns:
819
+ SQLAlchemy clause to apply to the query.
820
+ """
821
+ if isinstance(filters, dict):
822
+ if len(filters) == 1:
823
+ # The only operators allowed at the top level are $AND and $OR
824
+ # First check if an operator or a field
825
+ key, value = list(filters.items())[0]
826
+ if key.startswith("$"):
827
+ # Then it's an operator
828
+ if key.lower() not in ["$and", "$or"]:
829
+ raise ValueError(
830
+ f"Invalid filter condition. Expected $and or $or "
831
+ f"but got: {key}"
832
+ )
833
+ else:
834
+ # Then it's a field
835
+ return self._handle_field_filter(key, filters[key])
836
+
837
+ # Here we handle the $and and $or operators
838
+ if not isinstance(value, list):
839
+ raise ValueError(
840
+ f"Expected a list, but got {type(value)} for value: {value}"
841
+ )
842
+ if key.lower() == "$and":
843
+ and_ = [self._create_filter_clause(el) for el in value]
844
+ if len(and_) > 1:
845
+ return sqlalchemy.and_(*and_)
846
+ elif len(and_) == 1:
847
+ return and_[0]
848
+ else:
849
+ raise ValueError(
850
+ "Invalid filter condition. Expected a dictionary "
851
+ "but got an empty dictionary"
852
+ )
853
+ elif key.lower() == "$or":
854
+ or_ = [self._create_filter_clause(el) for el in value]
855
+ if len(or_) > 1:
856
+ return sqlalchemy.or_(*or_)
857
+ elif len(or_) == 1:
858
+ return or_[0]
859
+ else:
860
+ raise ValueError(
861
+ "Invalid filter condition. Expected a dictionary "
862
+ "but got an empty dictionary"
863
+ )
864
+ else:
865
+ raise ValueError(
866
+ f"Invalid filter condition. Expected $and or $or "
867
+ f"but got: {key}"
868
+ )
869
+ elif len(filters) > 1:
870
+ # Then all keys have to be fields (they cannot be operators)
871
+ for key in filters.keys():
872
+ if key.startswith("$"):
873
+ raise ValueError(
874
+ f"Invalid filter condition. Expected a field but got: {key}"
875
+ )
876
+ # These should all be fields and combined using an $and operator
877
+ and_ = [self._handle_field_filter(k, v) for k, v in filters.items()]
878
+ if len(and_) > 1:
879
+ return sqlalchemy.and_(*and_)
880
+ elif len(and_) == 1:
881
+ return and_[0]
882
+ else:
883
+ raise ValueError(
884
+ "Invalid filter condition. Expected a dictionary "
885
+ "but got an empty dictionary"
886
+ )
887
+ else:
888
+ raise ValueError("Got an empty dictionary for filters.")
889
+ else:
890
+ raise ValueError(
891
+ f"Invalid type: Expected a dictionary but got type: {type(filters)}"
892
+ )
893
+
894
+ def __query_collection(
895
+ self,
896
+ embedding: List[float],
897
+ k: int = 4,
898
+ filter: Optional[Dict[str, str]] = None,
899
+ ) -> List[Any]:
900
+ """Query the collection."""
901
+ with self._session_maker() as session: # type: ignore[arg-type]
902
+ collection = self.get_collection(session)
903
+ if not collection:
904
+ raise ValueError("Collection not found")
905
+
906
+ filter_by = [self.EmbeddingStore.collection_id == collection.uuid]
907
+ if filter:
908
+ if self.use_jsonb:
909
+ filter_clauses = self._create_filter_clause(filter)
910
+ if filter_clauses is not None:
911
+ filter_by.append(filter_clauses)
912
+ else:
913
+ # Old way of doing things
914
+ filter_clauses = self._create_filter_clause_json_deprecated(filter)
915
+ filter_by.extend(filter_clauses)
916
+
917
+ _type = self.EmbeddingStore
918
+
919
+ results: List[Any] = (
920
+ session.query(
921
+ self.EmbeddingStore,
922
+ self.distance_strategy(embedding).label("distance"), # type: ignore
923
+ )
924
+ .filter(*filter_by)
925
+ .order_by(sqlalchemy.asc("distance"))
926
+ .join(
927
+ self.CollectionStore,
928
+ self.EmbeddingStore.collection_id == self.CollectionStore.uuid,
929
+ )
930
+ .limit(k)
931
+ .all()
932
+ )
933
+
934
+ return results
935
+
936
def similarity_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> List[Document]:
    """Find the documents closest to a query embedding.

    Args:
        embedding: Query vector to compare stored embeddings against.
        k: How many documents to return. Defaults to 4.
        filter: Optional metadata filter. Defaults to None.
        **kwargs: Accepted for interface compatibility; not forwarded.

    Returns:
        The matching documents, with their scores dropped.
    """
    scored_docs = self.similarity_search_with_score_by_vector(
        embedding=embedding,
        k=k,
        filter=filter,
    )
    # Strip the scores and keep only the Document objects.
    return _results_to_docs(scored_docs)
957
+
958
@classmethod
def from_texts(
    cls: Type[PGVector],
    texts: List[str],
    embedding: Embeddings,
    metadatas: Optional[List[dict]] = None,
    *,
    collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
    distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
    ids: Optional[List[str]] = None,
    pre_delete_collection: bool = False,
    use_jsonb: bool = True,
    **kwargs: Any,
) -> PGVector:
    """Embed ``texts`` and build a PGVector store containing them.

    Args:
        texts: Raw texts to embed and store.
        embedding: Embeddings model whose ``embed_documents`` vectorises ``texts``.
        metadatas: Optional per-text metadata.
        collection_name: Name of the collection to write into.
        distance_strategy: Distance strategy used by the store.
        ids: Optional ids for the stored documents.
        pre_delete_collection: Forwarded to the private ``__from`` constructor.
        use_jsonb: Forwarded to the private ``__from`` constructor.
        **kwargs: Extra options forwarded to ``__from`` (e.g. ``connection``).

    Returns:
        The populated PGVector instance.
    """
    vectors = embedding.embed_documents(list(texts))

    return cls.__from(
        texts,
        vectors,
        embedding,
        metadatas=metadatas,
        ids=ids,
        collection_name=collection_name,
        distance_strategy=distance_strategy,
        pre_delete_collection=pre_delete_collection,
        use_jsonb=use_jsonb,
        **kwargs,
    )
987
+
988
@classmethod
def from_embeddings(
    cls,
    text_embeddings: List[Tuple[str, List[float]]],
    embedding: Embeddings,
    *,
    metadatas: Optional[List[dict]] = None,
    collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
    distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
    ids: Optional[List[str]] = None,
    pre_delete_collection: bool = False,
    **kwargs: Any,
) -> PGVector:
    """Build a PGVector store from texts whose embeddings are already computed.

    Args:
        text_embeddings: Sequence of ``(text, vector)`` pairs.
        embedding: Embeddings object kept by the store for later queries.
        metadatas: Optional list of metadatas associated with the texts.
        collection_name: Name of the collection.
        distance_strategy: Distance strategy to use.
        ids: Optional list of ids for the documents.
        pre_delete_collection: If True, will delete the collection if it exists.
            **Attention**: This will delete all the documents in the existing
            collection.
        kwargs: Additional arguments forwarded to the private ``__from``.

    Returns:
        PGVector: PGVector instance.

    Example:
        .. code-block:: python

            from langchain_postgres.vectorstores import PGVector
            from langchain_openai.embeddings import OpenAIEmbeddings

            embeddings = OpenAIEmbeddings()
            text_embeddings = embeddings.embed_documents(texts)
            text_embedding_pairs = list(zip(texts, text_embeddings))
            vectorstore = PGVector.from_embeddings(text_embedding_pairs, embeddings)
    """
    # Split the (text, vector) pairs into two parallel lists.
    raw_texts = [pair[0] for pair in text_embeddings]
    vectors = [pair[1] for pair in text_embeddings]

    return cls.__from(
        raw_texts,
        vectors,
        embedding,
        metadatas=metadatas,
        ids=ids,
        collection_name=collection_name,
        distance_strategy=distance_strategy,
        pre_delete_collection=pre_delete_collection,
        **kwargs,
    )
1043
+
1044
@classmethod
def from_existing_index(
    cls: Type[PGVector],
    embedding: Embeddings,
    *,
    collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
    distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
    pre_delete_collection: bool = False,
    connection: Optional[Connection] = None,
    **kwargs: Any,
) -> PGVector:
    """Get an instance of an existing PGVector store.

    This method only constructs the store object; it does not insert any
    new embeddings.

    Args:
        embedding: Embeddings object passed to the constructor as ``embeddings``.
        collection_name: Name of the collection to attach to.
        distance_strategy: Distance strategy to use.
        pre_delete_collection: Passed through to the constructor.
        connection: Optional connection passed through to the constructor.
        **kwargs: Additional constructor arguments.

    Returns:
        The constructed PGVector instance.
    """
    return cls(
        connection=connection,
        collection_name=collection_name,
        embeddings=embedding,
        distance_strategy=distance_strategy,
        pre_delete_collection=pre_delete_collection,
        **kwargs,
    )
1070
+
1071
@classmethod
def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
    """Resolve the Postgres connection string.

    Looks up ``connection_string`` in ``kwargs`` and falls back to the
    ``PGVECTOR_CONNECTION_STRING`` environment variable, via
    ``get_from_dict_or_env``.

    Args:
        kwargs: Mapping that may contain a ``connection_string`` entry.

    Returns:
        The non-empty connection string.

    Raises:
        ValueError: If the resolved connection string is empty.
    """
    connection_string: str = get_from_dict_or_env(
        data=kwargs,
        key="connection_string",
        env_key="PGVECTOR_CONNECTION_STRING",
    )

    if not connection_string:
        # Adjacent literals are concatenated; keep explicit spaces between them.
        raise ValueError(
            "Postgres connection string is required. "
            "Either pass it as a parameter "
            "or set the PGVECTOR_CONNECTION_STRING environment variable."
        )

    return connection_string
1087
+
1088
@classmethod
def from_documents(
    cls: Type[PGVector],
    documents: List[Document],
    embedding: Embeddings,
    *,
    connection: Optional[Connection] = None,
    collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
    distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
    ids: Optional[List[str]] = None,
    pre_delete_collection: bool = False,
    use_jsonb: bool = True,
    **kwargs: Any,
) -> PGVector:
    """Build a PGVector store from ``Document`` objects.

    Splits each document into its ``page_content`` and ``metadata`` and
    delegates to ``from_texts``, which performs the embedding.

    Args:
        documents: Documents to embed and store.
        embedding: Embeddings model used to vectorise the page contents.
        connection: Optional connection, forwarded through ``from_texts``.
        collection_name: Name of the collection to write into.
        distance_strategy: Distance strategy used by the store.
        ids: Optional ids for the stored documents.
        pre_delete_collection: Forwarded to ``from_texts``.
        use_jsonb: Forwarded to ``from_texts``.
        **kwargs: Extra options forwarded to ``from_texts``.

    Returns:
        The populated PGVector instance.
    """
    page_texts = [doc.page_content for doc in documents]
    page_metadatas = [doc.metadata for doc in documents]

    return cls.from_texts(
        texts=page_texts,
        embedding=embedding,
        metadatas=page_metadatas,
        connection=connection,
        collection_name=collection_name,
        distance_strategy=distance_strategy,
        ids=ids,
        pre_delete_collection=pre_delete_collection,
        use_jsonb=use_jsonb,
        **kwargs,
    )
1119
+
1120
@classmethod
def connection_string_from_db_params(
    cls,
    driver: str,
    host: str,
    port: int,
    database: str,
    user: str,
    password: str,
) -> str:
    """Return a SQLAlchemy connection URL built from database parameters.

    Args:
        driver: Driver name; only ``"psycopg"`` (psycopg3) is accepted.
        host: Database host.
        port: Database port.
        database: Database name.
        user: Raw (unescaped) user name.
        password: Raw (unescaped) password.

    Returns:
        A ``postgresql+psycopg://`` URL with user and password percent-encoded.

    Raises:
        NotImplementedError: If ``driver`` is not ``"psycopg"``.
    """
    from urllib.parse import quote_plus

    if driver != "psycopg":
        raise NotImplementedError("Only psycopg3 driver is supported")
    # Credentials may contain URL-reserved characters such as '@', ':', '/'
    # or '%'. Percent-encode them so the URL parses back to the same values.
    return (
        f"postgresql+{driver}://{quote_plus(user)}:{quote_plus(password)}"
        f"@{host}:{port}/{database}"
    )
1134
+
1135
def _select_relevance_score_fn(self) -> Callable[[float], float]:
    """Pick the function that converts a raw distance into a relevance score.

    The 'correct' relevance function may differ depending on a few things,
    including:
    - the distance / similarity metric used by the VectorStore
    - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
    - embedding dimensionality
    - etc.

    Returns:
        ``self.override_relevance_score_fn`` when it is set; otherwise the
        scorer matching ``self._distance_strategy``.

    Raises:
        ValueError: If the distance strategy has no known scorer.
    """
    if self.override_relevance_score_fn is not None:
        return self.override_relevance_score_fn

    # Default strategy is to rely on the distance strategy provided in the
    # vectorstore constructor.
    if self._distance_strategy == DistanceStrategy.COSINE:
        return self._cosine_relevance_score_fn
    elif self._distance_strategy == DistanceStrategy.EUCLIDEAN:
        return self._euclidean_relevance_score_fn
    elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
        return self._max_inner_product_relevance_score_fn
    else:
        raise ValueError(
            "No supported normalization function"
            f" for distance_strategy of {self._distance_strategy}."
            " Consider providing relevance_score_fn to PGVector constructor."
        )
1161
+
1162
def max_marginal_relevance_search_with_score_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Return docs selected using the maximal marginal relevance with score
    to embedding vector.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        embedding: Embedding to look up documents similar to.
        k (int): Number of Documents to return. Defaults to 4.
        fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
            Defaults to 20.
        lambda_mult (float): Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        List[Tuple[Document, float]]: List of Documents selected by maximal marginal
            relevance to the query and score for each.
    """
    results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter)

    embedding_list = [result.EmbeddingStore.embedding for result in results]

    mmr_selected = maximal_marginal_relevance(
        np.array(embedding, dtype=np.float32),
        embedding_list,
        k=k,
        lambda_mult=lambda_mult,
    )
    # Use a set for O(1) membership tests. Iterating ``candidates`` (rather
    # than ``mmr_selected``) keeps the order in which the query returned them.
    selected = set(mmr_selected)

    candidates = self._results_to_docs_and_scores(results)

    return [r for i, r in enumerate(candidates) if i in selected]
1206
+
1207
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Embed ``query`` and run a maximal-marginal-relevance (MMR) search.

    MMR balances similarity to the query against diversity among the
    selected documents.

    Args:
        query (str): Text whose embedding drives the search.
        k (int): Number of Documents to return. Defaults to 4.
        fetch_k (int): Number of candidates fetched before MMR re-ranking.
            Defaults to 20.
        lambda_mult (float): Value between 0 and 1; 0 means maximum diversity
            and 1 means minimum diversity. Defaults to 0.5.
        filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
        **kwargs: Forwarded unchanged to
            ``max_marginal_relevance_search_by_vector``.

    Returns:
        List[Document]: The documents chosen by MMR.
    """
    query_vector = self.embedding_function.embed_query(query)
    search_options = dict(k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter)
    return self.max_marginal_relevance_search_by_vector(
        query_vector, **search_options, **kwargs
    )
1244
+
1245
def max_marginal_relevance_search_with_score(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Embed ``query`` and run an MMR search that also reports scores.

    Maximal marginal relevance trades similarity to the query against
    diversity among the selected documents.

    Args:
        query (str): Text whose embedding drives the search.
        k (int): Number of Documents to return. Defaults to 4.
        fetch_k (int): Number of candidates fetched before MMR re-ranking.
            Defaults to 20.
        lambda_mult (float): Value between 0 and 1; 0 means maximum diversity
            and 1 means minimum diversity. Defaults to 0.5.
        filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
        **kwargs: Forwarded unchanged to
            ``max_marginal_relevance_search_with_score_by_vector``.

    Returns:
        List[Tuple[Document, float]]: ``(document, score)`` pairs chosen by MMR.
    """
    search = self.max_marginal_relevance_search_with_score_by_vector
    return search(
        embedding=self.embedding_function.embed_query(query),
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
        filter=filter,
        **kwargs,
    )
1284
+
1285
def max_marginal_relevance_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance
    to embedding vector.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        embedding (List[float]): Query embedding vector to look up documents
            similar to.
        k (int): Number of Documents to return. Defaults to 4.
        fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
            Defaults to 20.
        lambda_mult (float): Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
        **kwargs: Forwarded to
            ``max_marginal_relevance_search_with_score_by_vector``.

    Returns:
        List[Document]: List of Documents selected by maximal marginal relevance.
    """
    docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
        embedding,
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
        filter=filter,
        **kwargs,
    )

    # Drop the scores; callers of this variant only want the documents.
    return _results_to_docs(docs_and_scores)
1324
+
1325
async def amax_marginal_relevance_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Async variant of ``max_marginal_relevance_search_by_vector``.

    Runs the synchronous search through ``run_in_executor`` with the default
    executor (``None``) and awaits its result.

    Args:
        embedding: Query embedding vector.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of candidates fetched before MMR re-ranking.
            Defaults to 20.
        lambda_mult: Diversity trade-off between 0 and 1. Defaults to 0.5.
        filter: Filter by metadata. Defaults to None.
        **kwargs: Forwarded to the synchronous search.

    Returns:
        The documents chosen by MMR.
    """
    # Temporary workaround: offload the blocking search so this coroutine can
    # be awaited. The proper fix is a natively asynchronous implementation.
    sync_search = self.max_marginal_relevance_search_by_vector
    return await run_in_executor(
        None,
        sync_search,
        embedding,
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
        filter=filter,
        **kwargs,
    )