langchain-postgres 0.0.13__py3-none-any.whl → 0.0.14rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain_postgres/__init__.py +6 -0
- langchain_postgres/chat_message_histories.py +7 -1
- langchain_postgres/utils/pgvector_migrator.py +321 -0
- langchain_postgres/v2/__init__.py +0 -0
- langchain_postgres/v2/async_vectorstore.py +1268 -0
- langchain_postgres/v2/engine.py +351 -0
- langchain_postgres/v2/indexes.py +155 -0
- langchain_postgres/v2/vectorstores.py +842 -0
- langchain_postgres/vectorstores.py +11 -4
- langchain_postgres-0.0.14rc1.dist-info/METADATA +170 -0
- langchain_postgres-0.0.14rc1.dist-info/RECORD +16 -0
- langchain_postgres-0.0.13.dist-info/METADATA +0 -109
- langchain_postgres-0.0.13.dist-info/RECORD +0 -10
- {langchain_postgres-0.0.13.dist-info → langchain_postgres-0.0.14rc1.dist-info}/LICENSE +0 -0
- {langchain_postgres-0.0.13.dist-info → langchain_postgres-0.0.14rc1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,1268 @@
|
|
1
|
+
# TODO: Remove below import when minimum supported Python version is 3.10
|
2
|
+
from __future__ import annotations
|
3
|
+
|
4
|
+
import copy
|
5
|
+
import json
|
6
|
+
import uuid
|
7
|
+
from typing import Any, Callable, Iterable, Optional, Sequence
|
8
|
+
|
9
|
+
import numpy as np
|
10
|
+
from langchain_core.documents import Document
|
11
|
+
from langchain_core.embeddings import Embeddings
|
12
|
+
from langchain_core.vectorstores import VectorStore, utils
|
13
|
+
from sqlalchemy import RowMapping, text
|
14
|
+
from sqlalchemy.ext.asyncio import AsyncEngine
|
15
|
+
|
16
|
+
from .engine import PGEngine
|
17
|
+
from .indexes import (
|
18
|
+
DEFAULT_DISTANCE_STRATEGY,
|
19
|
+
DEFAULT_INDEX_NAME_SUFFIX,
|
20
|
+
BaseIndex,
|
21
|
+
DistanceStrategy,
|
22
|
+
ExactNearestNeighbor,
|
23
|
+
QueryOptions,
|
24
|
+
)
|
25
|
+
|
26
|
+
# Filter comparison operators and the native operator string each one maps to.
COMPARISONS_TO_NATIVE = {
    "$eq": "=",
    "$ne": "!=",
    "$lt": "<",
    "$lte": "<=",
    "$gt": ">",
    "$gte": ">=",
}

# Operators with no single native operator equivalent in the table above.
# NOTE(review): presumably handled by dedicated branches of the filter builder,
# which is outside the visible part of this file — confirm.
SPECIAL_CASED_OPERATORS = {"$in", "$nin", "$between", "$exists"}

# Pattern-matching operators for text values.
TEXT_OPERATORS = {"$like", "$ilike"}

# Operators that combine or negate other filter expressions.
LOGICAL_OPERATORS = {"$and", "$or", "$not"}

# Every operator name accepted in a filter: the union of the groups above.
SUPPORTED_OPERATORS = (
    set(COMPARISONS_TO_NATIVE)
    | TEXT_OPERATORS
    | LOGICAL_OPERATORS
    | SPECIAL_CASED_OPERATORS
)
|
55
|
+
|
56
|
+
|
57
|
+
class AsyncPGVectorStore(VectorStore):
|
58
|
+
"""Postgres Vector Store class"""
|
59
|
+
|
60
|
+
__create_key = object()
|
61
|
+
|
62
|
+
def __init__(
    self,
    key: object,
    engine: AsyncEngine,
    embedding_service: Embeddings,
    table_name: str,
    *,
    schema_name: str = "public",
    content_column: str = "content",
    embedding_column: str = "embedding",
    metadata_columns: Optional[list[str]] = None,
    id_column: str = "langchain_id",
    metadata_json_column: Optional[str] = "langchain_metadata",
    distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    index_query_options: Optional[QueryOptions] = None,
):
    """Store the configuration of a vector store; not meant to be called directly.

    Instances are expected to be built through the ``create`` classmethod,
    which passes the private creation key.

    Args:
        key (object): Private creation key; any other value is rejected.
        engine (AsyncEngine): Async SQLAlchemy engine used to open connections.
        embedding_service (Embeddings): Text embedding model to use.
        table_name (str): Name of the table backing the store.
        schema_name (str, optional): Name of the database schema. Defaults to "public".
        content_column (str): Column holding a Document's page_content. Defaults to "content".
        embedding_column (str): Column holding the embedding vectors. Defaults to "embedding".
        metadata_columns (list[str]): Column(s) holding individual metadata fields.
            ``None`` is stored as an empty list.
        id_column (str): Column holding the Document's id. Defaults to "langchain_id".
        metadata_json_column (str): Column storing remaining metadata as JSON.
            Defaults to "langchain_metadata".
        distance_strategy (DistanceStrategy): Distance strategy for similarity search.
        k (int): Number of Documents to return from search. Defaults to 4.
        fetch_k (int): Number of Documents to fetch to pass to the MMR algorithm.
        lambda_mult (float): Number between 0 and 1 controlling result diversity,
            0 being maximum diversity and 1 minimum diversity. Defaults to 0.5.
        index_query_options (QueryOptions): Index query option.

    Raises:
        Exception: If ``key`` is not the class's private creation key.
    """
    if key != AsyncPGVectorStore.__create_key:
        raise Exception("Only create class through 'create' or 'create_sync' methods!")

    # Connection and embedding backends.
    self.engine = engine
    self.embedding_service = embedding_service

    # Table layout.
    self.table_name = table_name
    self.schema_name = schema_name
    self.id_column = id_column
    self.content_column = content_column
    self.embedding_column = embedding_column
    self.metadata_columns = [] if metadata_columns is None else metadata_columns
    self.metadata_json_column = metadata_json_column

    # Search defaults.
    self.distance_strategy = distance_strategy
    self.k = k
    self.fetch_k = fetch_k
    self.lambda_mult = lambda_mult
    self.index_query_options = index_query_options
|
122
|
+
|
123
|
+
@classmethod
async def create(
    cls: type[AsyncPGVectorStore],
    engine: PGEngine,
    embedding_service: Embeddings,
    table_name: str,
    *,
    schema_name: str = "public",
    content_column: str = "content",
    embedding_column: str = "embedding",
    metadata_columns: Optional[list[str]] = None,
    ignore_metadata_columns: Optional[list[str]] = None,
    id_column: str = "langchain_id",
    metadata_json_column: Optional[str] = "langchain_metadata",
    distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    index_query_options: Optional[QueryOptions] = None,
) -> AsyncPGVectorStore:
    """Create an AsyncPGVectorStore instance after validating the table layout.

    The table's columns are read from ``information_schema.columns`` and
    checked against the requested column names before the instance is built.

    Args:
        engine (PGEngine): Connection pool engine for managing connections to postgres database.
        embedding_service (Embeddings): Text embedding model to use.
        table_name (str): Name of an existing table.
        schema_name (str, optional): Name of the database schema. Defaults to "public".
        content_column (str): Column that represent a Document's page_content. Defaults to "content".
        embedding_column (str): Column for embedding vectors. Defaults to "embedding".
        metadata_columns (list[str]): Column(s) that represent a document's metadata.
        ignore_metadata_columns (list[str]): Column(s) to ignore in pre-existing tables for a
            document's metadata. Can not be used with metadata_columns. When given, every
            remaining column other than the id, content, embedding and JSON metadata
            columns becomes a metadata column. Defaults to None.
        id_column (str): Column that represents the Document's id. Defaults to "langchain_id".
        metadata_json_column (str): Column to store metadata as JSON. Defaults to
            "langchain_metadata". Treated as absent if the table has no such column.
        distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search.
        k (int): Number of Documents to return from search. Defaults to 4.
        fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult (float): Number between 0 and 1 that determines the degree of diversity
            among the results with 0 corresponding to maximum diversity and 1 to minimum
            diversity. Defaults to 0.5.
        index_query_options (QueryOptions): Index query option.

    Raises:
        ValueError: If both metadata_columns and ignore_metadata_columns are given, if the
            id, content, embedding or a listed metadata column does not exist, if the
            content column is not a character type, or if the embedding column's
            data type is not reported as "USER-DEFINED".

    Returns:
        AsyncPGVectorStore
    """
    if metadata_columns is None:
        metadata_columns = []

    if metadata_columns and ignore_metadata_columns:
        raise ValueError(
            "Can not use both metadata_columns and ignore_metadata_columns."
        )

    # Get field type information
    stmt = "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = :table_name AND table_schema = :schema_name"
    async with engine._pool.connect() as conn:
        result = await conn.execute(
            text(stmt),
            {"table_name": table_name, "schema_name": schema_name},
        )
        results = result.mappings().fetchall()
    columns = {field["column_name"]: field["data_type"] for field in results}

    # Check columns
    if id_column not in columns:
        raise ValueError(f"Id column, {id_column}, does not exist.")
    if content_column not in columns:
        raise ValueError(f"Content column, {content_column}, does not exist.")
    content_type = columns[content_column]
    if content_type != "text" and "char" not in content_type:
        raise ValueError(
            f"Content column, {content_column}, is type, {content_type}. It must be a type of character string."
        )
    if embedding_column not in columns:
        raise ValueError(f"Embedding column, {embedding_column}, does not exist.")
    if columns[embedding_column] != "USER-DEFINED":
        raise ValueError(
            f"Embedding column, {embedding_column}, is not type Vector."
        )

    # A JSON metadata column that the table does not have is silently disabled.
    if metadata_json_column not in columns:
        metadata_json_column = None

    # If using metadata_columns check to make sure column exists
    for column in metadata_columns:
        if column not in columns:
            raise ValueError(f"Metadata column, {column}, does not exist.")

    # If using ignore_metadata_columns, every column that is neither ignored nor
    # one of the structural columns becomes a metadata column. A membership test
    # avoids mutating `columns` and cannot raise KeyError when an ignored name is
    # unknown or repeats a structural column.
    if ignore_metadata_columns:
        excluded = {
            *ignore_metadata_columns,
            id_column,
            content_column,
            embedding_column,
        }
        # The JSON column is written and read separately; listing it as a plain
        # metadata column too would name it twice in INSERT statements.
        if metadata_json_column:
            excluded.add(metadata_json_column)
        metadata_columns = [name for name in columns if name not in excluded]

    return cls(
        cls.__create_key,
        engine._pool,
        embedding_service,
        table_name,
        schema_name=schema_name,
        content_column=content_column,
        embedding_column=embedding_column,
        metadata_columns=metadata_columns,
        id_column=id_column,
        metadata_json_column=metadata_json_column,
        distance_strategy=distance_strategy,
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
        index_query_options=index_query_options,
    )
|
240
|
+
|
241
|
+
@property
def embeddings(self) -> Embeddings:
    """Return the embedding service this store was configured with."""
    service = self.embedding_service
    return service
|
244
|
+
|
245
|
+
async def aadd_embeddings(
    self,
    texts: Iterable[str],
    embeddings: list[list[float]],
    metadatas: Optional[list[dict]] = None,
    ids: Optional[list] = None,
    **kwargs: Any,
) -> list[str]:
    """Add data along with embeddings to the table.

    Each row is upserted: on an id conflict the content, embedding and
    metadata columns are overwritten with the new values. Rows are paired
    with ``zip``, so processing stops at the shortest of ``ids``, ``texts``,
    ``embeddings`` and ``metadatas``.

    Args:
        texts (Iterable[str]): Contents to store; any iterable, including one-shot iterators.
        embeddings (list[list[float]]): One embedding per text. An empty embedding, combined
            with an embedding service exposing ``embed_query_inline``, makes the database
            compute the embedding from an inline SQL expression instead.
        metadatas (Optional[list[dict]]): One metadata dict per text. Keys matching a
            metadata column go to that column; the rest go to the JSON metadata column
            when one is configured.
        ids (Optional[list]): Row ids. Missing (``None``) entries are replaced by random UUIDs.

    Returns:
        list[str]: The ids used, including any generated ones.

    Raises:
        :class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
    """
    # `texts` is iterated several times below; materialise it once so a
    # generator is not exhausted before the insert loop runs.
    texts = list(texts)

    if not ids:
        ids = [str(uuid.uuid4()) for _ in texts]
    else:
        # This is done to fill in any missing ids
        ids = [id if id is not None else str(uuid.uuid4()) for id in ids]
    if not metadatas:
        metadatas = [{} for _ in texts]

    # Check for inline embedding capability
    inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None)
    can_inline_embed = callable(inline_embed_func)

    # The column list and the upsert clause are the same for every row,
    # so build them once outside the loop.
    metadata_col_names = "".join(f', "{col}"' for col in self.metadata_columns)
    insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{metadata_col_names}'
    insert_stmt += (
        f""", "{self.metadata_json_column}")"""
        if self.metadata_json_column
        else ")"
    )

    upsert_stmt = f' ON CONFLICT ("{self.id_column}") DO UPDATE SET "{self.content_column}" = EXCLUDED."{self.content_column}", "{self.embedding_column}" = EXCLUDED."{self.embedding_column}"'
    if self.metadata_json_column:
        upsert_stmt += f', "{self.metadata_json_column}" = EXCLUDED."{self.metadata_json_column}"'
    for column in self.metadata_columns:
        upsert_stmt += f', "{column}" = EXCLUDED."{column}"'
    upsert_stmt += ";"

    # Insert embeddings
    for id, content, embedding, metadata in zip(ids, texts, embeddings, metadatas):
        values = {
            "id": id,
            "content": content,
            "embedding": str([float(dimension) for dimension in embedding]),
        }
        values_stmt = "VALUES (:id, :content, :embedding"

        if not embedding and can_inline_embed:
            # The inline expression is SQL text produced by the embedding
            # service and is spliced into the statement, not bound.
            values_stmt = f"VALUES (:id, :content, {self.embedding_service.embed_query_inline(content)}"  # type: ignore

        # Add metadata: known columns get their own bind parameter, anything
        # left over stays in `extra` for the JSON column.
        extra = copy.deepcopy(metadata)
        for metadata_column in self.metadata_columns:
            if metadata_column in metadata:
                values_stmt += f", :{metadata_column}"
                values[metadata_column] = metadata[metadata_column]
                del extra[metadata_column]
            else:
                values_stmt += ",null"

        # Add JSON column and/or close statement
        if self.metadata_json_column:
            values_stmt += ", :extra)"
            values["extra"] = json.dumps(extra)
        else:
            values_stmt += ")"

        query = insert_stmt + values_stmt + upsert_stmt
        # One connection and one commit per row, as before.
        async with self.engine.connect() as conn:
            await conn.execute(text(query), values)
            await conn.commit()

    return ids
|
325
|
+
|
326
|
+
async def aadd_texts(
    self,
    texts: Iterable[str],
    metadatas: Optional[list[dict]] = None,
    ids: Optional[list] = None,
    **kwargs: Any,
) -> list[str]:
    """Embed texts and add to the table.

    Args:
        texts (Iterable[str]): Texts to embed and store; any iterable, including one-shot iterators.
        metadatas (Optional[list[dict]]): One metadata dict per text.
        ids (Optional[list]): Row ids, forwarded to ``aadd_embeddings``.

    Returns:
        list[str]: The ids returned by ``aadd_embeddings``.

    Raises:
        :class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
    """
    # Materialise once: the texts are needed both for embedding and for the
    # insert, and a one-shot iterator would be empty the second time.
    texts = list(texts)

    # Check for inline embedding query
    inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None)
    if callable(inline_embed_func):
        # Empty placeholders tell aadd_embeddings to use the inline expression.
        embeddings: list[list[float]] = [[] for _ in texts]
    else:
        embeddings = await self.embedding_service.aembed_documents(texts)

    ids = await self.aadd_embeddings(
        texts, embeddings, metadatas=metadatas, ids=ids, **kwargs
    )
    return ids
|
349
|
+
|
350
|
+
async def aadd_documents(
    self,
    documents: list[Document],
    ids: Optional[list] = None,
    **kwargs: Any,
) -> list[str]:
    """Embed documents and add to the table.

    Page content and metadata are taken from each document. When ``ids`` is
    empty or omitted, each document's own ``id`` attribute is used instead.

    Returns:
        list[str]: The ids returned by ``aadd_texts``.

    Raises:
        :class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
    """
    contents = []
    document_metadatas = []
    for document in documents:
        contents.append(document.page_content)
        document_metadatas.append(document.metadata)

    row_ids = ids if ids else [document.id for document in documents]

    return await self.aadd_texts(
        contents, metadatas=document_metadatas, ids=row_ids, **kwargs
    )
|
367
|
+
|
368
|
+
async def adelete(
    self,
    ids: Optional[list] = None,
    **kwargs: Any,
) -> Optional[bool]:
    """Delete records from the table.

    Args:
        ids (Optional[list]): Ids of the rows to delete.

    Returns:
        Optional[bool]: ``False`` when no ids are given (nothing is executed),
        ``True`` once the DELETE has been executed and committed.

    Raises:
        :class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
    """
    if not ids:
        return False

    # One bind parameter per id keeps the values out of the SQL text.
    param_dict = {f"id_{i}": id for i, id in enumerate(ids)}
    placeholders = ", ".join(f":{name}" for name in param_dict)
    # The id column is double-quoted like every other identifier in this class,
    # so mixed-case or reserved-word column names resolve correctly.
    query = f'DELETE FROM "{self.schema_name}"."{self.table_name}" WHERE "{self.id_column}" in ({placeholders})'
    async with self.engine.connect() as conn:
        await conn.execute(text(query), param_dict)
        await conn.commit()
    return True
|
388
|
+
|
389
|
+
@classmethod
async def afrom_texts(  # type: ignore[override]
    cls: type[AsyncPGVectorStore],
    texts: list[str],
    embedding: Embeddings,
    engine: PGEngine,
    table_name: str,
    *,
    schema_name: str = "public",
    metadatas: Optional[list[dict]] = None,
    ids: Optional[list] = None,
    content_column: str = "content",
    embedding_column: str = "embedding",
    metadata_columns: Optional[list[str]] = None,
    ignore_metadata_columns: Optional[list[str]] = None,
    id_column: str = "langchain_id",
    metadata_json_column: str = "langchain_metadata",
    distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    index_query_options: Optional[QueryOptions] = None,
    **kwargs: Any,
) -> AsyncPGVectorStore:
    """Create an AsyncPGVectorStore instance from texts.

    Builds the store with ``create`` and then adds the texts to it.

    Args:
        texts (list[str]): Texts to add to the vector store.
        embedding (Embeddings): Text embedding model to use.
        engine (PGEngine): Connection pool engine for managing connections to postgres database.
        table_name (str): Name of an existing table.
        schema_name (str, optional): Name of the database schema. Defaults to "public".
        metadatas (Optional[list[dict]]): List of metadatas to add to table records.
        ids: (Optional[list[str]]): List of IDs to add to table records.
        content_column (str): Column that represent a Document's page_content. Defaults to "content".
        embedding_column (str): Column for embedding vectors. Defaults to "embedding".
        metadata_columns (list[str]): Column(s) that represent a document's metadata.
        ignore_metadata_columns (list[str]): Column(s) to ignore in pre-existing tables for a
            document's metadata. Can not be used with metadata_columns. Defaults to None.
        id_column (str): Column that represents the Document's id. Defaults to "langchain_id".
        metadata_json_column (str): Column to store metadata as JSON. Defaults to "langchain_metadata".
        distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search.
        k (int): Number of Documents to return from search. Defaults to 4.
        fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult (float): Number between 0 and 1 that determines the degree of diversity
            among the results with 0 corresponding to maximum diversity and 1 to minimum
            diversity. Defaults to 0.5.
        index_query_options (QueryOptions): Index query option.
        **kwargs: Forwarded to ``aadd_texts``.

    Raises:
        :class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.

    Returns:
        AsyncPGVectorStore
    """
    # Keyword options handed through to `create` unchanged.
    store_options = {
        "schema_name": schema_name,
        "content_column": content_column,
        "embedding_column": embedding_column,
        "metadata_columns": metadata_columns,
        "ignore_metadata_columns": ignore_metadata_columns,
        "id_column": id_column,
        "metadata_json_column": metadata_json_column,
        "distance_strategy": distance_strategy,
        "k": k,
        "fetch_k": fetch_k,
        "lambda_mult": lambda_mult,
        "index_query_options": index_query_options,
    }
    store = await cls.create(engine, embedding, table_name, **store_options)
    await store.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs)
    return store
|
459
|
+
|
460
|
+
@classmethod
async def afrom_documents(  # type: ignore[override]
    cls: type[AsyncPGVectorStore],
    documents: list[Document],
    embedding: Embeddings,
    engine: PGEngine,
    table_name: str,
    *,
    schema_name: str = "public",
    ids: Optional[list] = None,
    content_column: str = "content",
    embedding_column: str = "embedding",
    metadata_columns: Optional[list[str]] = None,
    ignore_metadata_columns: Optional[list[str]] = None,
    id_column: str = "langchain_id",
    metadata_json_column: str = "langchain_metadata",
    distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    index_query_options: Optional[QueryOptions] = None,
    **kwargs: Any,
) -> AsyncPGVectorStore:
    """Create an AsyncPGVectorStore instance from documents.

    Builds the store with ``create`` and then adds each document's page
    content and metadata through ``aadd_texts``. Row ids come only from
    ``ids``; the documents' own ``id`` attributes are not consulted here.

    Args:
        documents (list[Document]): Documents to add to the vector store.
        embedding (Embeddings): Text embedding model to use.
        engine (PGEngine): Connection pool engine for managing connections to postgres database.
        table_name (str): Name of an existing table.
        schema_name (str, optional): Name of the database schema. Defaults to "public".
        ids: (Optional[list[str]]): List of IDs to add to table records.
        content_column (str): Column that represent a Document's page_content. Defaults to "content".
        embedding_column (str): Column for embedding vectors. Defaults to "embedding".
        metadata_columns (list[str]): Column(s) that represent a document's metadata.
        ignore_metadata_columns (list[str]): Column(s) to ignore in pre-existing tables for a
            document's metadata. Can not be used with metadata_columns. Defaults to None.
        id_column (str): Column that represents the Document's id. Defaults to "langchain_id".
        metadata_json_column (str): Column to store metadata as JSON. Defaults to "langchain_metadata".
        distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search.
        k (int): Number of Documents to return from search. Defaults to 4.
        fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult (float): Number between 0 and 1 that determines the degree of diversity
            among the results with 0 corresponding to maximum diversity and 1 to minimum
            diversity. Defaults to 0.5.
        index_query_options (QueryOptions): Index query option.
        **kwargs: Forwarded to ``aadd_texts``.

    Raises:
        :class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.

    Returns:
        AsyncPGVectorStore
    """
    # Keyword options handed through to `create` unchanged.
    store_options = {
        "schema_name": schema_name,
        "content_column": content_column,
        "embedding_column": embedding_column,
        "metadata_columns": metadata_columns,
        "ignore_metadata_columns": ignore_metadata_columns,
        "id_column": id_column,
        "metadata_json_column": metadata_json_column,
        "distance_strategy": distance_strategy,
        "k": k,
        "fetch_k": fetch_k,
        "lambda_mult": lambda_mult,
        "index_query_options": index_query_options,
    }
    store = await cls.create(engine, embedding, table_name, **store_options)

    contents = [document.page_content for document in documents]
    document_metadatas = [document.metadata for document in documents]
    await store.aadd_texts(contents, metadatas=document_metadatas, ids=ids, **kwargs)
    return store
|
532
|
+
|
533
|
+
async def __query_collection(
    self,
    embedding: list[float],
    *,
    k: Optional[int] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> Sequence[RowMapping]:
    """Perform similarity search query on database.

    Args:
        embedding (list[float]): Query vector. When empty, and the embedding service
            exposes ``embed_query_inline`` and ``kwargs`` carries a ``"query"`` string,
            the service's inline SQL expression is used in place of a bound vector.
        k (Optional[int]): Maximum number of rows; falls back to ``self.k`` when falsy.
        filter (Optional[dict]): Metadata filter, turned into a WHERE clause by
            ``_create_filter_clause``. Non-dict or empty values are ignored.
        **kwargs: Only ``"query"`` is read here.

    Returns:
        Sequence[RowMapping]: Rows holding the selected columns plus a ``distance`` column,
        ordered by the distance strategy's operator.
    """
    k = k if k else self.k
    operator = self.distance_strategy.operator
    search_function = self.distance_strategy.search_function

    columns = self.metadata_columns + [
        self.id_column,
        self.content_column,
        self.embedding_column,
    ]
    if self.metadata_json_column:
        columns.append(self.metadata_json_column)

    column_names = ", ".join(f'"{col}"' for col in columns)

    safe_filter = None
    filter_dict = None
    if filter and isinstance(filter, dict):
        safe_filter, filter_dict = self._create_filter_clause(filter)
    param_filter = f"WHERE {safe_filter}" if safe_filter else ""

    param_dict: dict[str, Any] = {"k": k}
    inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None)
    if not embedding and callable(inline_embed_func) and "query" in kwargs:
        # The inline embedding is a SQL expression, so it has to be part of the
        # statement text (as aadd_embeddings does); binding it as a parameter
        # would send it to the database as a plain string value.
        embedding_expr = f"{inline_embed_func(kwargs['query'])}"
    else:
        embedding_expr = ":query_embedding"
        param_dict["query_embedding"] = (
            f"{[float(dimension) for dimension in embedding]}"
        )

    stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_expr}) as distance
    FROM "{self.schema_name}"."{self.table_name}" {param_filter} ORDER BY "{self.embedding_column}" {operator} {embedding_expr} LIMIT :k;
    """
    if filter_dict:
        param_dict.update(filter_dict)

    async with self.engine.connect() as conn:
        if self.index_query_options:
            # Set each query option individually; SET LOCAL is issued on the
            # same connection that runs the search.
            for query_option in self.index_query_options.to_parameter():
                query_options_stmt = f"SET LOCAL {query_option};"
                await conn.execute(text(query_options_stmt))
        result = await conn.execute(text(stmt), param_dict)
        results = result.mappings().fetchall()
    return results
|
587
|
+
|
588
|
+
async def asimilarity_search(
    self,
    query: str,
    k: Optional[int] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[Document]:
    """Return docs selected by similarity search on query.

    When the embedding service exposes a callable ``embed_query_inline``, no
    embedding is computed here; an empty vector is passed on and the raw
    query travels along in ``kwargs["query"]``.
    """
    supports_inline = callable(
        getattr(self.embedding_service, "embed_query_inline", None)
    )
    if supports_inline:
        embedding: list[float] = []
    else:
        embedding = await self.embedding_service.aembed_query(text=query)

    kwargs["query"] = query
    return await self.asimilarity_search_by_vector(
        embedding=embedding, k=k, filter=filter, **kwargs
    )
|
607
|
+
|
608
|
+
def _select_relevance_score_fn(self) -> Callable[[float], float]:
    """Select a relevance function based on distance strategy.

    The strategy is the one given to the vector store constructor. If it
    matches none of the three strategies checked below, ``None`` is returned.
    """
    strategy = self.distance_strategy
    if strategy == DistanceStrategy.COSINE_DISTANCE:
        return self._cosine_relevance_score_fn
    if strategy == DistanceStrategy.INNER_PRODUCT:
        return self._max_inner_product_relevance_score_fn
    if strategy == DistanceStrategy.EUCLIDEAN:
        return self._euclidean_relevance_score_fn
    return None  # type: ignore[return-value]
|
618
|
+
|
619
|
+
async def asimilarity_search_with_score(
    self,
    query: str,
    k: Optional[int] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[tuple[Document, float]]:
    """Return docs and distance scores selected by similarity search on query.

    When the embedding service exposes a callable ``embed_query_inline``, no
    embedding is computed here; an empty vector is passed on and the raw
    query travels along in ``kwargs["query"]``.
    """
    supports_inline = callable(
        getattr(self.embedding_service, "embed_query_inline", None)
    )
    if supports_inline:
        embedding: list[float] = []
    else:
        embedding = await self.embedding_service.aembed_query(text=query)

    kwargs["query"] = query
    return await self.asimilarity_search_with_score_by_vector(
        embedding=embedding, k=k, filter=filter, **kwargs
    )
|
639
|
+
|
640
|
+
async def asimilarity_search_by_vector(
    self,
    embedding: list[float],
    k: Optional[int] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[Document]:
    """Return docs selected by vector similarity search.

    Delegates to ``asimilarity_search_with_score_by_vector`` and drops the scores.
    """
    scored_documents = await self.asimilarity_search_with_score_by_vector(
        embedding=embedding, k=k, filter=filter, **kwargs
    )

    documents = []
    for document, _score in scored_documents:
        documents.append(document)
    return documents
|
653
|
+
|
654
|
+
async def asimilarity_search_with_score_by_vector(
    self,
    embedding: list[float],
    k: Optional[int] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[tuple[Document, float]]:
    """Return docs and distance scores selected by vector similarity search.

    Each result row becomes a ``Document`` whose metadata starts from the
    JSON metadata column (when configured and non-empty) and is then
    overlaid with the values of the individual metadata columns.
    """
    rows = await self.__query_collection(
        embedding=embedding, k=k, filter=filter, **kwargs
    )

    json_column = self.metadata_json_column
    scored_documents = []
    for row in rows:
        if json_column and row[json_column]:
            metadata = row[json_column]
        else:
            metadata = {}
        # Dedicated metadata columns take precedence over same-named JSON keys.
        for column in self.metadata_columns:
            metadata[column] = row[column]

        document = Document(
            page_content=row[self.content_column],
            metadata=metadata,
            id=str(row[self.id_column]),
        )
        scored_documents.append((document, row["distance"]))

    return scored_documents
|
687
|
+
|
688
|
+
async def amax_marginal_relevance_search(
    self,
    query: str,
    k: Optional[int] = None,
    fetch_k: Optional[int] = None,
    lambda_mult: Optional[float] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[Document]:
    """Return docs selected using the maximal marginal relevance.

    Embeds ``query`` with the store's embedding service and hands the
    resulting vector to ``amax_marginal_relevance_search_by_vector``.

    Args:
        query: Text to embed.
        k: Forwarded unchanged to the by-vector search.
        fetch_k: Forwarded unchanged to the by-vector search.
        lambda_mult: Forwarded unchanged to the by-vector search.
        filter: Forwarded unchanged to the by-vector search.
        **kwargs: Extra keyword arguments forwarded to the by-vector search.

    Returns:
        Whatever the by-vector search returns.
    """
    query_vector = await self.embedding_service.aembed_query(text=query)

    selected = await self.amax_marginal_relevance_search_by_vector(
        embedding=query_vector,
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
        filter=filter,
        **kwargs,
    )
    return selected
|
708
|
+
|
709
|
+
async def amax_marginal_relevance_search_by_vector(
    self,
    embedding: list[float],
    k: Optional[int] = None,
    fetch_k: Optional[int] = None,
    lambda_mult: Optional[float] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[Document]:
    """Return docs selected using the maximal marginal relevance.

    Delegates to ``amax_marginal_relevance_search_with_score_by_vector`` and
    keeps only the first element of every result it returns.

    Args:
        embedding: Query vector, passed positionally to the scored search.
        k: Forwarded unchanged to the scored search.
        fetch_k: Forwarded unchanged to the scored search.
        lambda_mult: Forwarded unchanged to the scored search.
        filter: Forwarded unchanged to the scored search.
        **kwargs: Extra keyword arguments forwarded to the scored search.

    Returns:
        The first element of each scored result, in the same order.
    """
    scored_results = await self.amax_marginal_relevance_search_with_score_by_vector(
        embedding,
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
        filter=filter,
        **kwargs,
    )

    documents = []
    for scored in scored_results:
        documents.append(scored[0])
    return documents
|
731
|
+
|
732
|
+
async def amax_marginal_relevance_search_with_score_by_vector(
    self,
    embedding: list[float],
    k: Optional[int] = None,
    fetch_k: Optional[int] = None,
    lambda_mult: Optional[float] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[tuple[Document, float]]:
    """Return docs and distance scores selected using the maximal marginal relevance.

    Queries a candidate pool of rows, runs maximal marginal relevance over
    the candidates' stored embeddings and returns the selected rows as
    ``(Document, distance)`` pairs.

    Args:
        embedding: Query vector.
        k: Number of results to select; falls back to ``self.k`` when falsy.
        fetch_k: Size of the candidate pool requested from the collection
            query; falls back to ``self.fetch_k`` when falsy.
        lambda_mult: Weighting passed to the MMR routine; falls back to
            ``self.lambda_mult`` only when ``None``.
        filter: Forwarded unchanged to the collection query.
        **kwargs: Extra keyword arguments forwarded to the collection query.

    Returns:
        The selected ``(Document, distance)`` pairs, kept in the order the
        rows came back from the collection query.
    """
    # Resolve the defaults BEFORE querying. The candidate query used to be
    # issued with the raw ``fetch_k`` (possibly None) and the resolved value
    # was never used afterwards, so ``self.fetch_k`` had no effect here.
    k = k if k else self.k
    fetch_k = fetch_k if fetch_k else self.fetch_k
    # 0.0 is a legitimate weighting, so only a missing value gets the default.
    lambda_mult = self.lambda_mult if lambda_mult is None else lambda_mult

    results = await self.__query_collection(
        embedding=embedding, k=fetch_k, filter=filter, **kwargs
    )

    # The embedding column is decoded with json.loads, i.e. each row is
    # expected to carry its vector as JSON text.
    embedding_list = [json.loads(row[self.embedding_column]) for row in results]
    # A set gives constant-time membership tests while filtering rows below.
    mmr_selected = set(
        utils.maximal_marginal_relevance(
            np.array(embedding, dtype=np.float32),
            embedding_list,
            k=k,
            lambda_mult=lambda_mult,
        )
    )

    documents_with_scores = []
    for position, row in enumerate(results):
        if position not in mmr_selected:
            continue

        metadata = (
            row[self.metadata_json_column]
            if self.metadata_json_column and row[self.metadata_json_column]
            else {}
        )
        # Dedicated metadata columns override same-named JSON metadata keys.
        for col in self.metadata_columns:
            metadata[col] = row[col]

        documents_with_scores.append(
            (
                Document(
                    page_content=row[self.content_column],
                    metadata=metadata,
                    id=str(row[self.id_column]),
                ),
                row["distance"],
            )
        )

    return documents_with_scores
|
778
|
+
|
779
|
+
async def aapply_vector_index(
    self,
    index: BaseIndex,
    name: Optional[str] = None,
    *,
    concurrently: bool = False,
) -> None:
    """Create index in the vector store table.

    Passing an ``ExactNearestNeighbor`` index drops the default-named vector
    index (via ``adrop_vector_index()``) instead of creating one.

    Args:
        index: Index definition; supplies the extension name, index type,
            index function, index options and optional partial-index
            predicate.
        name: Name for the new index. When omitted, ``index.name`` is used;
            if that is unset too, it is first assigned the table name plus
            ``DEFAULT_INDEX_NAME_SUFFIX``.
        concurrently: Emit ``CREATE INDEX CONCURRENTLY`` and run it on a
            connection switched to AUTOCOMMIT.
    """
    if isinstance(index, ExactNearestNeighbor):
        await self.adrop_vector_index()
        return

    # if extension name is mentioned, create the extension
    if index.extension_name:
        async with self.engine.connect() as conn:
            await conn.execute(
                text(f"CREATE EXTENSION IF NOT EXISTS {index.extension_name}")
            )
            await conn.commit()
    function = index.get_index_function()

    where_clause = f"WHERE ({index.partial_indexes})" if index.partial_indexes else ""
    params = "WITH " + index.index_options()
    if name is None:
        if index.name is None:
            index.name = self.table_name + DEFAULT_INDEX_NAME_SUFFIX
        name = index.name
    concurrently_kw = "CONCURRENTLY" if concurrently else ""
    stmt = f'CREATE INDEX {concurrently_kw} "{name}" ON "{self.schema_name}"."{self.table_name}" USING {index.index_type} ({self.embedding_column} {function}) {params} {where_clause};'
    if concurrently:
        # PostgreSQL refuses CREATE INDEX CONCURRENTLY inside a transaction
        # block, hence the AUTOCOMMIT isolation level and no explicit commit.
        async with self.engine.connect() as conn:
            autocommit_conn = await conn.execution_options(
                isolation_level="AUTOCOMMIT"
            )
            await autocommit_conn.execute(text(stmt))
    else:
        async with self.engine.connect() as conn:
            await conn.execute(text(stmt))
            await conn.commit()
|
817
|
+
|
818
|
+
async def areindex(self, index_name: Optional[str] = None) -> None:
    """Re-index the vector store table.

    Args:
        index_name: Index to rebuild. A falsy value selects the table name
            plus ``DEFAULT_INDEX_NAME_SUFFIX``.
    """
    if not index_name:
        index_name = self.table_name + DEFAULT_INDEX_NAME_SUFFIX
    statement = f'REINDEX INDEX "{index_name}";'

    async with self.engine.connect() as connection:
        await connection.execute(text(statement))
        await connection.commit()
|
825
|
+
|
826
|
+
async def adrop_vector_index(
    self,
    index_name: Optional[str] = None,
) -> None:
    """Drop the vector index.

    Uses ``DROP INDEX IF EXISTS``, so a missing index is not an error.

    Args:
        index_name: Index to drop. A falsy value selects the table name plus
            ``DEFAULT_INDEX_NAME_SUFFIX``.
    """
    if not index_name:
        index_name = self.table_name + DEFAULT_INDEX_NAME_SUFFIX
    statement = f'DROP INDEX IF EXISTS "{index_name}";'

    async with self.engine.connect() as connection:
        await connection.execute(text(statement))
        await connection.commit()
|
836
|
+
|
837
|
+
async def is_valid_index(
    self,
    index_name: Optional[str] = None,
) -> bool:
    """Check if index exists in the table.

    Looks the index up in ``pg_indexes`` by table name, schema name and
    index name.

    Args:
        index_name: Index to look for. A falsy value selects the table name
            plus ``DEFAULT_INDEX_NAME_SUFFIX``.

    Returns:
        ``True`` when exactly one matching row is found, ``False`` otherwise.
    """
    if not index_name:
        index_name = self.table_name + DEFAULT_INDEX_NAME_SUFFIX

    query = """
    SELECT tablename, indexname
    FROM pg_indexes
    WHERE tablename = :table_name AND schemaname = :schema_name AND indexname = :index_name;
    """
    bind_values = {
        "table_name": self.table_name,
        "schema_name": self.schema_name,
        "index_name": index_name,
    }

    async with self.engine.connect() as connection:
        outcome = await connection.execute(text(query), bind_values)
        matches = outcome.mappings().fetchall()

    return len(matches) == 1
|
858
|
+
|
859
|
+
async def aget_by_ids(self, ids: Sequence[str]) -> list[Document]:
    """Get documents by ids.

    Args:
        ids: Values to match against the id column.

    Returns:
        One ``Document`` per row whose id column matches one of ``ids``, in
        the order the database returns them. An empty ``ids`` yields an
        empty list without querying the database.
    """
    # An empty id list would render ``IN ()``, which PostgreSQL rejects as a
    # syntax error, so answer it directly.
    if len(ids) == 0:
        return []

    columns = self.metadata_columns + [
        self.id_column,
        self.content_column,
    ]
    if self.metadata_json_column:
        columns.append(self.metadata_json_column)

    column_names = ", ".join(f'"{col}"' for col in columns)

    # One named bind parameter per id: :id_0, :id_1, ...
    placeholders = ", ".join(f":id_{i}" for i in range(len(ids)))
    param_dict = {f"id_{i}": doc_id for i, doc_id in enumerate(ids)}

    query = f'SELECT {column_names} FROM "{self.schema_name}"."{self.table_name}" WHERE "{self.id_column}" IN ({placeholders});'

    async with self.engine.connect() as conn:
        result = await conn.execute(text(query), param_dict)
        result_map = result.mappings()
        results = result_map.fetchall()

    documents = []
    for row in results:
        metadata = (
            row[self.metadata_json_column]
            if self.metadata_json_column and row[self.metadata_json_column]
            else {}
        )
        # Dedicated metadata columns override same-named JSON metadata keys.
        for col in self.metadata_columns:
            metadata[col] = row[col]
        documents.append(
            Document(
                page_content=row[self.content_column],
                metadata=metadata,
                id=str(row[self.id_column]),
            )
        )

    return documents
|
901
|
+
|
902
|
+
def _handle_field_filter(
    self,
    *,
    field: str,
    value: Any,
) -> tuple[str, dict]:
    """Create a filter for a specific field.

    Args:
        field: name of field
        value: value to filter
            If provided as is then this will be an equality filter
            If provided as a dictionary then this will be a filter, the key
            will be the operator and the value will be the value to filter by

    Returns:
        A tuple of the SQL condition, written with named bind-parameter
        placeholders, and the dict mapping those parameter names to values.

    Raises:
        ValueError: If ``field`` is not a string, starts with ``$``, is not a
            valid identifier, or if ``value`` is a malformed operator
            specification.
        NotImplementedError: If an ``$in``/``$nin`` element has an
            unsupported type, or the operator has no SQL translation here.
    """
    if not isinstance(field, str):
        raise ValueError(
            f"field should be a string but got: {type(field)} with value: {field}"
        )

    if field.startswith("$"):
        raise ValueError(
            f"Invalid filter condition. Expected a field but got an operator: "
            f"{field}"
        )

    # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters
    if not field.isidentifier():
        raise ValueError(f"Invalid field name: {field}. Expected a valid identifier.")

    if isinstance(value, dict):
        # This is a filter specification
        if len(value) != 1:
            raise ValueError(
                "Invalid filter condition. Expected a value which "
                "is a dictionary with a single key that corresponds to an operator "
                f"but got a dictionary with {len(value)} keys. The first few "
                f"keys are: {list(value.keys())[:3]}"
            )
        operator, filter_value = next(iter(value.items()))
        # Verify that the operator is a supported one
        if operator not in SUPPORTED_OPERATORS:
            raise ValueError(
                f"Invalid operator: {operator}. "
                f"Expected one of {SUPPORTED_OPERATORS}"
            )
    else:  # Then we assume an equality operator
        operator = "$eq"
        filter_value = value

    # Every bind parameter gets a per-call random suffix. Callers merge the
    # parameter dicts of several field filters into one dict, so without the
    # suffix two conditions on the same field with the same operator would
    # overwrite each other's values.
    suffix = str(uuid.uuid4()).split("-")[0]

    if operator in COMPARISONS_TO_NATIVE:
        # native is trusted input
        native = COMPARISONS_TO_NATIVE[operator]
        param = f"{field}_{suffix}"
        return f"{field} {native} :{param}", {param: filter_value}

    if operator == "$between":
        low, high = filter_value
        low_param = f"{field}_low_{suffix}"
        high_param = f"{field}_high_{suffix}"
        return f"({field} BETWEEN :{low_param} AND :{high_param})", {
            low_param: low,
            high_param: high,
        }

    if operator in {"$in", "$nin"}:
        for val in filter_value:
            if not isinstance(val, (str, int, float)):
                raise NotImplementedError(
                    f"Unsupported type: {type(val)} for value: {val}"
                )

            if isinstance(val, bool):  # b/c bool is an instance of int
                raise NotImplementedError(
                    f"Unsupported type: {type(val)} for value: {val}"
                )

        if operator == "$in":
            param = f"{field}_in_{suffix}"
            return f"{field} = ANY(:{param})", {param: filter_value}
        param = f"{field}_nin_{suffix}"
        return f"{field} <> ALL (:{param})", {param: filter_value}

    if operator == "$like":
        param = f"{field}_like_{suffix}"
        return f"({field} LIKE :{param})", {param: filter_value}

    if operator == "$ilike":
        param = f"{field}_ilike_{suffix}"
        return f"({field} ILIKE :{param})", {param: filter_value}

    if operator == "$exists":
        if not isinstance(filter_value, bool):
            raise ValueError(
                "Expected a boolean value for $exists "
                f"operator, but got: {filter_value}"
            )
        if filter_value:
            return f"({field} IS NOT NULL)", {}
        return f"({field} IS NULL)", {}

    # Supported operator without a translation in this method.
    raise NotImplementedError()
|
1010
|
+
|
1011
|
+
def _create_filter_clause(self, filters: Any) -> tuple[str, dict]:
    """Translate a LangChain filter dictionary into a SQL WHERE fragment.

    Args:
        filters: Dictionary of filters to apply to the query. A single key
            may be a field name or one of ``$and`` / ``$or`` / ``$not``
            (matched case-insensitively); several keys must all be field
            names and are combined with ``AND``.

    Returns:
        A tuple of the SQL condition string and the dict of its bind
        parameters. An empty dictionary yields ``("", {})``.

    Raises:
        ValueError: If ``filters`` is not a dictionary, an unknown operator
            is used, an operator value has the wrong type, ``$and``/``$or``
            receive an empty list, or an operator appears next to other keys.
    """
    if not isinstance(filters, dict):
        raise ValueError(
            f"Invalid type: Expected a dictionary but got type: {type(filters)}"
        )

    if not filters:
        return "", {}

    if len(filters) > 1:
        # Then all keys have to be fields (they cannot be operators)
        for name in filters:
            if name.startswith("$"):
                raise ValueError(
                    f"Invalid filter condition. Expected a field but got: {name}"
                )
        # These are all fields and get combined using AND
        field_clauses = [
            self._handle_field_filter(field=name, value=item)
            for name, item in filters.items()
        ]
        merged = {}
        for clause in field_clauses:
            merged.update(clause[1])
        conditions = [clause[0] for clause in field_clauses]
        return f"({' AND '.join(conditions)})", merged

    # Exactly one key: either a field or a top-level operator.
    key, value = next(iter(filters.items()))
    if not key.startswith("$"):
        return self._handle_field_filter(field=key, value=value)

    operator_key = key.lower()
    # The only operators allowed at the top level are $AND, $OR, and $NOT
    if operator_key not in ("$and", "$or", "$not"):
        raise ValueError(
            f"Invalid filter condition. Expected $and, $or or $not "
            f"but got: {key}"
        )

    if operator_key == "$not":
        if isinstance(value, list):
            negated = [self._create_filter_clause(item) for item in value]
            merged = {}
            for clause in negated:
                merged.update(clause[1])
            statements = [f"NOT {clause[0]}" for clause in negated]
            return f"({' AND '.join(statements)})", merged
        if isinstance(value, dict):
            inner, inner_params = self._create_filter_clause(value)
            return f"(NOT {inner})", inner_params
        raise ValueError(
            f"Invalid filter condition. Expected a dictionary "
            f"or a list but got: {type(value)}"
        )

    # $and / $or
    if not isinstance(value, list):
        raise ValueError(f"Expected a list, but got {type(value)} for value: {value}")
    joiner = f" {key[1:].upper()} "
    sub_clauses = [self._create_filter_clause(item) for item in value]
    if not sub_clauses:
        raise ValueError(
            "Invalid filter condition. Expected a dictionary "
            "but got an empty dictionary"
        )
    if len(sub_clauses) == 1:
        # A single operand needs no joining or extra parentheses.
        return sub_clauses[0]
    merged = {}
    for clause in sub_clauses:
        merged.update(clause[1])
    conditions = [clause[0] for clause in sub_clauses]
    return f"({joiner.join(conditions)})", merged
|
1110
|
+
|
1111
|
+
def get_by_ids(self, ids: Sequence[str]) -> list[Document]:
    """Synchronous lookup by id; not available on the async store.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "Sync methods are not implemented for AsyncPGVectorStore. "
        "Use PGVectorStore interface instead."
    )
    raise NotImplementedError(message)
|
1115
|
+
|
1116
|
+
def add_texts(
    self,
    texts: Iterable[str],
    metadatas: Optional[list[dict]] = None,
    ids: Optional[list] = None,
    **kwargs: Any,
) -> list[str]:
    """Synchronous text insertion; not available on the async store.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "Sync methods are not implemented for AsyncPGVectorStore. "
        "Use PGVectorStore interface instead."
    )
    raise NotImplementedError(message)
|
1126
|
+
|
1127
|
+
def add_documents(
    self,
    documents: list[Document],
    ids: Optional[list] = None,
    **kwargs: Any,
) -> list[str]:
    """Synchronous document insertion; not available on the async store.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "Sync methods are not implemented for AsyncPGVectorStore. "
        "Use PGVectorStore interface instead."
    )
    raise NotImplementedError(message)
|
1136
|
+
|
1137
|
+
def delete(
    self,
    ids: Optional[list] = None,
    **kwargs: Any,
) -> Optional[bool]:
    """Synchronous deletion; not available on the async store.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "Sync methods are not implemented for AsyncPGVectorStore. "
        "Use PGVectorStore interface instead."
    )
    raise NotImplementedError(message)
|
1145
|
+
|
1146
|
+
@classmethod
def from_texts(  # type: ignore[override]
    cls: type[AsyncPGVectorStore],
    texts: list[str],
    embedding: Embeddings,
    engine: PGEngine,
    table_name: str,
    metadatas: Optional[list[dict]] = None,
    ids: Optional[list] = None,
    content_column: str = "content",
    embedding_column: str = "embedding",
    metadata_columns: Optional[list[str]] = None,
    ignore_metadata_columns: Optional[list[str]] = None,
    id_column: str = "langchain_id",
    metadata_json_column: str = "langchain_metadata",
    **kwargs: Any,
) -> AsyncPGVectorStore:
    """Synchronous construction from texts; not available on the async store.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "Sync methods are not implemented for AsyncPGVectorStore. "
        "Use PGVectorStore interface instead."
    )
    raise NotImplementedError(message)
|
1166
|
+
|
1167
|
+
@classmethod
def from_documents(  # type: ignore[override]
    cls: type[AsyncPGVectorStore],
    documents: list[Document],
    embedding: Embeddings,
    engine: PGEngine,
    table_name: str,
    ids: Optional[list] = None,
    content_column: str = "content",
    embedding_column: str = "embedding",
    metadata_columns: Optional[list[str]] = None,
    ignore_metadata_columns: Optional[list[str]] = None,
    id_column: str = "langchain_id",
    metadata_json_column: str = "langchain_metadata",
    **kwargs: Any,
) -> AsyncPGVectorStore:
    """Synchronous construction from documents; not available on the async store.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "Sync methods are not implemented for AsyncPGVectorStore. "
        "Use PGVectorStore interface instead."
    )
    raise NotImplementedError(message)
|
1186
|
+
|
1187
|
+
def similarity_search(
    self,
    query: str,
    k: Optional[int] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[Document]:
    """Synchronous similarity search; not available on the async store.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "Sync methods are not implemented for AsyncPGVectorStore. "
        "Use PGVectorStore interface instead."
    )
    raise NotImplementedError(message)
|
1197
|
+
|
1198
|
+
def similarity_search_with_score(
    self,
    query: str,
    k: Optional[int] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[tuple[Document, float]]:
    """Synchronous scored similarity search; not available on the async store.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "Sync methods are not implemented for AsyncPGVectorStore. "
        "Use PGVectorStore interface instead."
    )
    raise NotImplementedError(message)
|
1208
|
+
|
1209
|
+
def similarity_search_by_vector(
    self,
    embedding: list[float],
    k: Optional[int] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[Document]:
    """Synchronous search by vector; not available on the async store.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "Sync methods are not implemented for AsyncPGVectorStore. "
        "Use PGVectorStore interface instead."
    )
    raise NotImplementedError(message)
|
1219
|
+
|
1220
|
+
def similarity_search_with_score_by_vector(
    self,
    embedding: list[float],
    k: Optional[int] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[tuple[Document, float]]:
    """Synchronous scored search by vector; not available on the async store.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "Sync methods are not implemented for AsyncPGVectorStore. "
        "Use PGVectorStore interface instead."
    )
    raise NotImplementedError(message)
|
1230
|
+
|
1231
|
+
def max_marginal_relevance_search(
    self,
    query: str,
    k: Optional[int] = None,
    fetch_k: Optional[int] = None,
    lambda_mult: Optional[float] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[Document]:
    """Synchronous MMR search; not available on the async store.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "Sync methods are not implemented for AsyncPGVectorStore. "
        "Use PGVectorStore interface instead."
    )
    raise NotImplementedError(message)
|
1243
|
+
|
1244
|
+
def max_marginal_relevance_search_by_vector(
    self,
    embedding: list[float],
    k: Optional[int] = None,
    fetch_k: Optional[int] = None,
    lambda_mult: Optional[float] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[Document]:
    """Synchronous MMR search by vector; not available on the async store.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "Sync methods are not implemented for AsyncPGVectorStore. "
        "Use PGVectorStore interface instead."
    )
    raise NotImplementedError(message)
|
1256
|
+
|
1257
|
+
def max_marginal_relevance_search_with_score_by_vector(
    self,
    embedding: list[float],
    k: Optional[int] = None,
    fetch_k: Optional[int] = None,
    lambda_mult: Optional[float] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[tuple[Document, float]]:
    """Synchronous scored MMR search by vector; not available on the async store.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "Sync methods are not implemented for AsyncPGVectorStore. "
        "Use PGVectorStore interface instead."
    )
    raise NotImplementedError(message)
|