langchain-postgres 0.0.13__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,348 @@
+ from __future__ import annotations
+
+ import asyncio
+ from dataclasses import dataclass
+ from threading import Thread
+ from typing import Any, Awaitable, Optional, TypedDict, TypeVar, Union
+
+ from sqlalchemy import text
+ from sqlalchemy.engine import URL
+ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
+
+ T = TypeVar("T")
+
+
+ class ColumnDict(TypedDict):
+     name: str
+     data_type: str
+     nullable: bool
+
+
+ @dataclass
+ class Column:
+     name: str
+     data_type: str
+     nullable: bool = True
+
+     def __post_init__(self) -> None:
+         """Check if initialization parameters are valid.
+
+         Raises:
+             ValueError: If name is not a string.
+             ValueError: If data_type is not a string.
+         """
+
+         if not isinstance(self.name, str):
+             raise ValueError("Column name must be type string")
+         if not isinstance(self.data_type, str):
+             raise ValueError("Column data_type must be type string")
+
+
+ class PGEngine:
+     """A class for managing connections to a Postgres database."""
+
+     _default_loop: Optional[asyncio.AbstractEventLoop] = None
+     _default_thread: Optional[Thread] = None
+     __create_key = object()
+
+     def __init__(
+         self,
+         key: object,
+         pool: AsyncEngine,
+         loop: Optional[asyncio.AbstractEventLoop],
+         thread: Optional[Thread],
+     ) -> None:
+         """PGEngine constructor.
+
+         Args:
+             key (object): Prevents direct constructor usage.
+             pool (AsyncEngine): Async engine connection pool.
+             loop (Optional[asyncio.AbstractEventLoop]): Async event loop used to create the engine.
+             thread (Optional[Thread]): Thread used to create the engine asynchronously.
+
+         Raises:
+             Exception: If the constructor is called directly by the user.
+         """
+
+         if key != PGEngine.__create_key:
+             raise Exception(
+                 "Only create class through 'from_connection_string' or 'from_engine' methods!"
+             )
+         self._pool = pool
+         self._loop = loop
+         self._thread = thread
+
+     @classmethod
+     def from_engine(
+         cls: type[PGEngine],
+         engine: AsyncEngine,
+         loop: Optional[asyncio.AbstractEventLoop] = None,
+     ) -> PGEngine:
+         """Create a PGEngine instance from an AsyncEngine."""
+         return cls(cls.__create_key, engine, loop, None)
+
+     @classmethod
+     def from_connection_string(
+         cls,
+         url: str | URL,
+         **kwargs: Any,
+     ) -> PGEngine:
+         """Create a PGEngine instance from a connection string or URL.
+
+         Args:
+             url (Union[str, URL]): The URL used to connect to the database.
+
+         Raises:
+             ValueError: If the database URL is missing required arguments.
+
+         Returns:
+             PGEngine
+         """
+         # Running a loop in a background thread allows us to support
+         # async methods from non-async environments.
+         if cls._default_loop is None:
+             cls._default_loop = asyncio.new_event_loop()
+             cls._default_thread = Thread(
+                 target=cls._default_loop.run_forever, daemon=True
+             )
+             cls._default_thread.start()
+
+         engine = create_async_engine(url, **kwargs)
+         return cls(cls.__create_key, engine, cls._default_loop, cls._default_thread)
+
+     async def _run_as_async(self, coro: Awaitable[T]) -> T:
+         """Await a coroutine, dispatching it to the background event loop if one exists."""
+         # If a loop has not been provided, attempt to run in the current thread
+         if not self._loop:
+             return await coro
+         # Otherwise, run in the background thread
+         return await asyncio.wrap_future(
+             asyncio.run_coroutine_threadsafe(coro, self._loop)
+         )
+
+     def _run_as_sync(self, coro: Awaitable[T]) -> T:
+         """Run an async coroutine synchronously on the background event loop."""
+         if not self._loop:
+             raise Exception(
+                 "Engine was initialized without a background loop and cannot call sync methods."
+             )
+         return asyncio.run_coroutine_threadsafe(coro, self._loop).result()
+
+     async def close(self) -> None:
+         """Dispose of the connection pool."""
+         await self._run_as_async(self._pool.dispose())
+
+     def _escape_postgres_identifier(self, name: str) -> str:
+         return name.replace('"', '""')
+
+     def _validate_column_dict(self, col: ColumnDict) -> None:
+         if not isinstance(col.get("name"), str):
+             raise TypeError("The 'name' field must be a string.")
+         if not isinstance(col.get("data_type"), str):
+             raise TypeError("The 'data_type' field must be a string.")
+         if not isinstance(col.get("nullable"), bool):
+             raise TypeError("The 'nullable' field must be a boolean.")
+
+     async def _ainit_vectorstore_table(
+         self,
+         table_name: str,
+         vector_size: int,
+         *,
+         schema_name: str = "public",
+         content_column: str = "content",
+         embedding_column: str = "embedding",
+         metadata_columns: Optional[list[Union[Column, ColumnDict]]] = None,
+         metadata_json_column: str = "langchain_metadata",
+         id_column: Union[str, Column, ColumnDict] = "langchain_id",
+         overwrite_existing: bool = False,
+         store_metadata: bool = True,
+     ) -> None:
+         """
+         Create a table for saving vectors to be used with PGVectorStore.
+
+         Args:
+             table_name (str): The database table name.
+             vector_size (int): Vector size for the embedding model to be used.
+             schema_name (str): The schema name.
+                 Default: "public".
+             content_column (str): Name of the column to store document content.
+                 Default: "content".
+             embedding_column (str): Name of the column to store vector embeddings.
+                 Default: "embedding".
+             metadata_columns (Optional[list[Union[Column, ColumnDict]]]): A list of Columns to create for custom
+                 metadata. Default: None. Optional.
+             metadata_json_column (str): The column to store extra metadata in JSON format.
+                 Default: "langchain_metadata". Optional.
+             id_column (Union[str, Column, ColumnDict]): Column to store ids.
+                 Default: "langchain_id" column name with data type UUID. Optional.
+             overwrite_existing (bool): Whether to drop the existing table. Default: False.
+             store_metadata (bool): Whether to store metadata in the table.
+                 Default: True.
+
+         Raises:
+             :class:`DuplicateTableError <asyncpg.exceptions.DuplicateTableError>`: if the table already exists.
+             :class:`UndefinedObjectError <asyncpg.exceptions.UndefinedObjectError>`: if the data type of the id column is not a PostgreSQL data type.
+         """
+
+         schema_name = self._escape_postgres_identifier(schema_name)
+         table_name = self._escape_postgres_identifier(table_name)
+         content_column = self._escape_postgres_identifier(content_column)
+         embedding_column = self._escape_postgres_identifier(embedding_column)
+         if metadata_columns is None:
+             metadata_columns = []
+         else:
+             for col in metadata_columns:
+                 if isinstance(col, Column):
+                     col.name = self._escape_postgres_identifier(col.name)
+                 elif isinstance(col, dict):
+                     self._validate_column_dict(col)
+                     col["name"] = self._escape_postgres_identifier(col["name"])
+         if isinstance(id_column, str):
+             id_column = self._escape_postgres_identifier(id_column)
+         elif isinstance(id_column, Column):
+             id_column.name = self._escape_postgres_identifier(id_column.name)
+         else:
+             self._validate_column_dict(id_column)
+             id_column["name"] = self._escape_postgres_identifier(id_column["name"])
+
+         async with self._pool.connect() as conn:
+             await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
+             await conn.commit()
+
+         if overwrite_existing:
+             async with self._pool.connect() as conn:
+                 await conn.execute(
+                     text(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"')
+                 )
+                 await conn.commit()
+
+         if isinstance(id_column, str):
+             id_data_type = "UUID"
+             id_column_name = id_column
+         elif isinstance(id_column, Column):
+             id_data_type = id_column.data_type
+             id_column_name = id_column.name
+         else:
+             id_data_type = id_column["data_type"]
+             id_column_name = id_column["name"]
+
+         query = f"""CREATE TABLE "{schema_name}"."{table_name}"(
+             "{id_column_name}" {id_data_type} PRIMARY KEY,
+             "{content_column}" TEXT NOT NULL,
+             "{embedding_column}" vector({vector_size}) NOT NULL"""
+         for column in metadata_columns:
+             if isinstance(column, Column):
+                 nullable = "NOT NULL" if not column.nullable else ""
+                 query += f',\n"{column.name}" {column.data_type} {nullable}'
+             elif isinstance(column, dict):
+                 nullable = "NOT NULL" if not column["nullable"] else ""
+                 query += f',\n"{column["name"]}" {column["data_type"]} {nullable}'
+         if store_metadata:
+             query += f""",\n"{metadata_json_column}" JSON"""
+         query += "\n);"
+
+         async with self._pool.connect() as conn:
+             await conn.execute(text(query))
+             await conn.commit()
+
+     async def ainit_vectorstore_table(
+         self,
+         table_name: str,
+         vector_size: int,
+         *,
+         schema_name: str = "public",
+         content_column: str = "content",
+         embedding_column: str = "embedding",
+         metadata_columns: Optional[list[Union[Column, ColumnDict]]] = None,
+         metadata_json_column: str = "langchain_metadata",
+         id_column: Union[str, Column, ColumnDict] = "langchain_id",
+         overwrite_existing: bool = False,
+         store_metadata: bool = True,
+     ) -> None:
+         """
+         Create a table for saving vectors to be used with PGVectorStore.
+
+         Args:
+             table_name (str): The database table name.
+             vector_size (int): Vector size for the embedding model to be used.
+             schema_name (str): The schema name.
+                 Default: "public".
+             content_column (str): Name of the column to store document content.
+                 Default: "content".
+             embedding_column (str): Name of the column to store vector embeddings.
+                 Default: "embedding".
+             metadata_columns (Optional[list[Union[Column, ColumnDict]]]): A list of Columns to create for custom
+                 metadata. Default: None. Optional.
+             metadata_json_column (str): The column to store extra metadata in JSON format.
+                 Default: "langchain_metadata". Optional.
+             id_column (Union[str, Column, ColumnDict]): Column to store ids.
+                 Default: "langchain_id" column name with data type UUID. Optional.
+             overwrite_existing (bool): Whether to drop the existing table. Default: False.
+             store_metadata (bool): Whether to store metadata in the table.
+                 Default: True.
+         """
+         await self._run_as_async(
+             self._ainit_vectorstore_table(
+                 table_name,
+                 vector_size,
+                 schema_name=schema_name,
+                 content_column=content_column,
+                 embedding_column=embedding_column,
+                 metadata_columns=metadata_columns,
+                 metadata_json_column=metadata_json_column,
+                 id_column=id_column,
+                 overwrite_existing=overwrite_existing,
+                 store_metadata=store_metadata,
+             )
+         )
+
+     def init_vectorstore_table(
+         self,
+         table_name: str,
+         vector_size: int,
+         *,
+         schema_name: str = "public",
+         content_column: str = "content",
+         embedding_column: str = "embedding",
+         metadata_columns: Optional[list[Union[Column, ColumnDict]]] = None,
+         metadata_json_column: str = "langchain_metadata",
+         id_column: Union[str, Column, ColumnDict] = "langchain_id",
+         overwrite_existing: bool = False,
+         store_metadata: bool = True,
+     ) -> None:
+         """
+         Create a table for saving vectors to be used with PGVectorStore.
+
+         Args:
+             table_name (str): The database table name.
+             vector_size (int): Vector size for the embedding model to be used.
+             schema_name (str): The schema name.
+                 Default: "public".
+             content_column (str): Name of the column to store document content.
+                 Default: "content".
+             embedding_column (str): Name of the column to store vector embeddings.
+                 Default: "embedding".
+             metadata_columns (Optional[list[Union[Column, ColumnDict]]]): A list of Columns to create for custom
+                 metadata. Default: None. Optional.
+             metadata_json_column (str): The column to store extra metadata in JSON format.
+                 Default: "langchain_metadata". Optional.
+             id_column (Union[str, Column, ColumnDict]): Column to store ids.
+                 Default: "langchain_id" column name with data type UUID. Optional.
+             overwrite_existing (bool): Whether to drop the existing table. Default: False.
+             store_metadata (bool): Whether to store metadata in the table.
+                 Default: True.
+         """
+         self._run_as_sync(
+             self._ainit_vectorstore_table(
+                 table_name,
+                 vector_size,
+                 schema_name=schema_name,
+                 content_column=content_column,
+                 embedding_column=embedding_column,
+                 metadata_columns=metadata_columns,
+                 metadata_json_column=metadata_json_column,
+                 id_column=id_column,
+                 overwrite_existing=overwrite_existing,
+                 store_metadata=store_metadata,
+             )
+         )
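
For context, a minimal usage sketch of the engine added above. It is illustrative only: the top-level import path, the asyncpg connection URL, the vector size, and the table/column names are assumptions, not something this diff confirms.

    # Sketch only: assumes PGEngine and Column are exported at the package root
    # and that an asyncpg-backed Postgres instance with the pgvector extension
    # is reachable at the placeholder URL below.
    from langchain_postgres import PGEngine, Column

    engine = PGEngine.from_connection_string(
        url="postgresql+asyncpg://user:password@localhost:5432/mydb",
    )

    # Sync path: work is dispatched to the background event loop that
    # from_connection_string() started.
    engine.init_vectorstore_table(
        table_name="documents",
        vector_size=768,
        metadata_columns=[Column("source", "TEXT")],
    )

    # Async path: the same operation awaited directly from async code.
    # await engine.ainit_vectorstore_table(table_name="documents", vector_size=768)
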
@@ -0,0 +1,155 @@
+ """Index class to add vector indexes on the PGVectorStore.
+
+ Learn more about vector indexes at https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
+ """
+
+ import enum
+ import re
+ import warnings
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+
+ @dataclass
+ class StrategyMixin:
+     operator: str
+     search_function: str
+     index_function: str
+
+
+ class DistanceStrategy(StrategyMixin, enum.Enum):
+     """Enumerator of the Distance strategies."""
+
+     EUCLIDEAN = "<->", "l2_distance", "vector_l2_ops"
+     COSINE_DISTANCE = "<=>", "cosine_distance", "vector_cosine_ops"
+     INNER_PRODUCT = "<#>", "inner_product", "vector_ip_ops"
+
+
+ DEFAULT_DISTANCE_STRATEGY: DistanceStrategy = DistanceStrategy.COSINE_DISTANCE
+ DEFAULT_INDEX_NAME_SUFFIX: str = "langchainvectorindex"
+
+
+ def validate_identifier(identifier: str) -> None:
+     if re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", identifier) is None:
+         raise ValueError(
+             f"Invalid identifier: {identifier}. Identifiers must start with a letter or underscore, and subsequent characters can be letters, digits, or underscores."
+         )
+
+
+ @dataclass
+ class BaseIndex(ABC):
+     """
+     Abstract base class for defining vector indexes.
+
+     Attributes:
+         name (Optional[str]): A human-readable name for the index. Defaults to None.
+         index_type (str): A string identifying the type of index. Defaults to "base".
+         distance_strategy (DistanceStrategy): The strategy used to calculate distances
+             between vectors in the index. Defaults to DistanceStrategy.COSINE_DISTANCE.
+         partial_indexes (Optional[list[str]]): A list of names of partial indexes. Defaults to None.
+         extension_name (Optional[str]): The name of the extension to be created for the index, if any. Defaults to None.
+     """
+
+     name: Optional[str] = None
+     index_type: str = "base"
+     distance_strategy: DistanceStrategy = field(
+         default_factory=lambda: DistanceStrategy.COSINE_DISTANCE
+     )
+     partial_indexes: Optional[list[str]] = None
+     extension_name: Optional[str] = None
+
+     @abstractmethod
+     def index_options(self) -> str:
+         """Set index query options for vector store initialization."""
+         raise NotImplementedError(
+             "index_options method must be implemented by subclass"
+         )
+
+     def get_index_function(self) -> str:
+         return self.distance_strategy.index_function
+
+     def __post_init__(self) -> None:
+         """Check if initialization parameters are valid.
+
+         Raises:
+             ValueError: If extension_name or index_type is not a valid PostgreSQL identifier.
+         """
+
+         if self.extension_name:
+             validate_identifier(self.extension_name)
+         if self.index_type:
+             validate_identifier(self.index_type)
+
+
+ @dataclass
+ class ExactNearestNeighbor(BaseIndex):
+     index_type: str = "exactnearestneighbor"
+
+
+ @dataclass
+ class QueryOptions(ABC):
+     @abstractmethod
+     def to_parameter(self) -> list[str]:
+         """Convert index attributes to list of configurations."""
+         raise NotImplementedError("to_parameter method must be implemented by subclass")
+
+     @abstractmethod
+     def to_string(self) -> str:
+         """Convert index attributes to string."""
+         raise NotImplementedError("to_string method must be implemented by subclass")
+
+
+ @dataclass
+ class HNSWIndex(BaseIndex):
+     index_type: str = "hnsw"
+     m: int = 16
+     ef_construction: int = 64
+
+     def index_options(self) -> str:
+         """Set index query options for vector store initialization."""
+         return f"(m = {self.m}, ef_construction = {self.ef_construction})"
+
+
+ @dataclass
+ class HNSWQueryOptions(QueryOptions):
+     ef_search: int = 40
+
+     def to_parameter(self) -> list[str]:
+         """Convert index attributes to list of configurations."""
+         return [f"hnsw.ef_search = {self.ef_search}"]
+
+     def to_string(self) -> str:
+         """Convert index attributes to string."""
+         warnings.warn(
+             "to_string is deprecated, use to_parameter instead.",
+             DeprecationWarning,
+         )
+         return f"hnsw.ef_search = {self.ef_search}"
+
+
+ @dataclass
+ class IVFFlatIndex(BaseIndex):
+     index_type: str = "ivfflat"
+     lists: int = 100
+
+     def index_options(self) -> str:
+         """Set index query options for vector store initialization."""
+         return f"(lists = {self.lists})"
+
+
+ @dataclass
+ class IVFFlatQueryOptions(QueryOptions):
+     probes: int = 1
+
+     def to_parameter(self) -> list[str]:
+         """Convert index attributes to list of configurations."""
+         return [f"ivfflat.probes = {self.probes}"]
+
+     def to_string(self) -> str:
+         """Convert index attributes to string."""
+         warnings.warn(
+             "to_string is deprecated, use to_parameter instead.",
+             DeprecationWarning,
+         )
+         return f"ivfflat.probes = {self.probes}"