langchain-postgres 0.0.12__py3-none-any.whl → 0.0.14rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,351 @@
+ from __future__ import annotations
+
+ import asyncio
+ from dataclasses import dataclass
+ from threading import Thread
+ from typing import TYPE_CHECKING, Any, Awaitable, Optional, TypeVar, TypedDict, Union
+
+ from sqlalchemy import text
+ from sqlalchemy.engine import URL
+ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
+
+ if TYPE_CHECKING:
+     import asyncpg  # type: ignore
+
+ T = TypeVar("T")
+
+
+ class ColumnDict(TypedDict):
+     name: str
+     data_type: str
+     nullable: bool
+
+
+ @dataclass
+ class Column:
+     name: str
+     data_type: str
+     nullable: bool = True
+
+     def __post_init__(self) -> None:
+         """Check if initialization parameters are valid.
+
+         Raises:
+             ValueError: If name is not a string.
+             ValueError: If data_type is not a string.
+         """
+
+         if not isinstance(self.name, str):
+             raise ValueError("Column name must be type string")
+         if not isinstance(self.data_type, str):
+             raise ValueError("Column data_type must be type string")
+
+
+ class PGEngine:
+     """A class for managing connections to a Postgres database."""
+
+     _default_loop: Optional[asyncio.AbstractEventLoop] = None
+     _default_thread: Optional[Thread] = None
+     __create_key = object()
+
+     def __init__(
+         self,
+         key: object,
+         pool: AsyncEngine,
+         loop: Optional[asyncio.AbstractEventLoop],
+         thread: Optional[Thread],
+     ) -> None:
+         """PGEngine constructor.
+
+         Args:
+             key (object): Prevent direct constructor usage.
+             pool (AsyncEngine): Async engine connection pool.
+             loop (Optional[asyncio.AbstractEventLoop]): Async event loop used to create the engine.
+             thread (Optional[Thread]): Thread used to create the engine async.
+
+         Raises:
+             Exception: If the constructor is called directly by the user.
+         """
+
+         if key != PGEngine.__create_key:
+             raise Exception(
+                 "Only create class through 'from_connection_string' or 'from_engine' methods!"
+             )
+         self._pool = pool
+         self._loop = loop
+         self._thread = thread
+
+     @classmethod
+     def from_engine(
+         cls: type[PGEngine],
+         engine: AsyncEngine,
+         loop: Optional[asyncio.AbstractEventLoop] = None,
+     ) -> PGEngine:
+         """Create a PGEngine instance from an AsyncEngine."""
+         return cls(cls.__create_key, engine, loop, None)
+
+     @classmethod
+     def from_connection_string(
+         cls,
+         url: str | URL,
+         **kwargs: Any,
+     ) -> PGEngine:
+         """Create a PGEngine instance from a connection string or URL.
+
+         Args:
+             url (str | URL): The URL used to connect to the database.
+             **kwargs (Any): Additional arguments passed to create_async_engine.
+
+         Raises:
+             ValueError: If not all database URL arguments are specified.
+
+         Returns:
+             PGEngine
+         """
+         # Running a loop in a background thread allows us to support
+         # async methods from non-async environments
+         if cls._default_loop is None:
+             cls._default_loop = asyncio.new_event_loop()
+             cls._default_thread = Thread(
+                 target=cls._default_loop.run_forever, daemon=True
+             )
+             cls._default_thread.start()
+
+         engine = create_async_engine(url, **kwargs)
+         return cls(cls.__create_key, engine, cls._default_loop, cls._default_thread)
+
+     async def _run_as_async(self, coro: Awaitable[T]) -> T:
+         """Run a coroutine asynchronously, using the background event loop if one was provided."""
+         # If a loop has not been provided, attempt to run in current thread
+         if not self._loop:
+             return await coro
+         # Otherwise, run in the background thread
+         return await asyncio.wrap_future(
+             asyncio.run_coroutine_threadsafe(coro, self._loop)
+         )
+
+     def _run_as_sync(self, coro: Awaitable[T]) -> T:
+         """Run a coroutine synchronously on the background event loop."""
+         if not self._loop:
+             raise Exception(
+                 "Engine was initialized without a background loop and cannot call sync methods."
+             )
+         return asyncio.run_coroutine_threadsafe(coro, self._loop).result()
+
+     async def close(self) -> None:
+         """Dispose of the connection pool."""
+         await self._run_as_async(self._pool.dispose())
+
+     def _escape_postgres_identifier(self, name: str) -> str:
+         return name.replace('"', '""')
+
+     def _validate_column_dict(self, col: ColumnDict) -> None:
+         if not isinstance(col.get("name"), str):
+             raise TypeError("The 'name' field must be a string.")
+         if not isinstance(col.get("data_type"), str):
+             raise TypeError("The 'data_type' field must be a string.")
+         if not isinstance(col.get("nullable"), bool):
+             raise TypeError("The 'nullable' field must be a boolean.")
+
+     async def _ainit_vectorstore_table(
+         self,
+         table_name: str,
+         vector_size: int,
+         *,
+         schema_name: str = "public",
+         content_column: str = "content",
+         embedding_column: str = "embedding",
+         metadata_columns: Optional[list[Union[Column, ColumnDict]]] = None,
+         metadata_json_column: str = "langchain_metadata",
+         id_column: Union[str, Column, ColumnDict] = "langchain_id",
+         overwrite_existing: bool = False,
+         store_metadata: bool = True,
+     ) -> None:
+         """
+         Create a table for saving vectors to be used with PGVectorStore.
+
+         Args:
+             table_name (str): The database table name.
+             vector_size (int): Vector size for the embedding model to be used.
+             schema_name (str): The schema name.
+                 Default: "public".
+             content_column (str): Name of the column to store document content.
+                 Default: "content".
+             embedding_column (str): Name of the column to store vector embeddings.
+                 Default: "embedding".
+             metadata_columns (Optional[list[Union[Column, ColumnDict]]]): A list of Columns to create for custom
+                 metadata. Default: None. Optional.
+             metadata_json_column (str): The column to store extra metadata in JSON format.
+                 Default: "langchain_metadata". Optional.
+             id_column (Union[str, Column, ColumnDict]): Column to store ids.
+                 Default: "langchain_id" column name with data type UUID. Optional.
+             overwrite_existing (bool): Whether to drop the existing table. Default: False.
+             store_metadata (bool): Whether to store metadata in the table.
+                 Default: True.
+
+         Raises:
+             :class:`DuplicateTableError <asyncpg.exceptions.DuplicateTableError>`: if the table already exists.
+             :class:`UndefinedObjectError <asyncpg.exceptions.UndefinedObjectError>`: if the data type of the id column is not a PostgreSQL data type.
+         """
+
+         schema_name = self._escape_postgres_identifier(schema_name)
+         table_name = self._escape_postgres_identifier(table_name)
+         content_column = self._escape_postgres_identifier(content_column)
+         embedding_column = self._escape_postgres_identifier(embedding_column)
+         if metadata_columns is None:
+             metadata_columns = []
+         else:
+             for col in metadata_columns:
+                 if isinstance(col, Column):
+                     col.name = self._escape_postgres_identifier(col.name)
+                 elif isinstance(col, dict):
+                     self._validate_column_dict(col)
+                     col["name"] = self._escape_postgres_identifier(col["name"])
+         if isinstance(id_column, str):
+             id_column = self._escape_postgres_identifier(id_column)
+         elif isinstance(id_column, Column):
+             id_column.name = self._escape_postgres_identifier(id_column.name)
+         else:
+             self._validate_column_dict(id_column)
+             id_column["name"] = self._escape_postgres_identifier(id_column["name"])
+
+         async with self._pool.connect() as conn:
+             await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
+             await conn.commit()
+
+         if overwrite_existing:
+             async with self._pool.connect() as conn:
+                 await conn.execute(
+                     text(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"')
+                 )
+                 await conn.commit()
+
+         if isinstance(id_column, str):
+             id_data_type = "UUID"
+             id_column_name = id_column
+         elif isinstance(id_column, Column):
+             id_data_type = id_column.data_type
+             id_column_name = id_column.name
+         else:
+             id_data_type = id_column["data_type"]
+             id_column_name = id_column["name"]
+
+         query = f"""CREATE TABLE "{schema_name}"."{table_name}"(
+             "{id_column_name}" {id_data_type} PRIMARY KEY,
+             "{content_column}" TEXT NOT NULL,
+             "{embedding_column}" vector({vector_size}) NOT NULL"""
+         for column in metadata_columns:
+             if isinstance(column, Column):
+                 nullable = "NOT NULL" if not column.nullable else ""
+                 query += f',\n"{column.name}" {column.data_type} {nullable}'
+             elif isinstance(column, dict):
+                 nullable = "NOT NULL" if not column["nullable"] else ""
+                 query += f',\n"{column["name"]}" {column["data_type"]} {nullable}'
+         if store_metadata:
+             query += f""",\n"{metadata_json_column}" JSON"""
+         query += "\n);"
+
+         async with self._pool.connect() as conn:
+             await conn.execute(text(query))
+             await conn.commit()
+
+     async def ainit_vectorstore_table(
+         self,
+         table_name: str,
+         vector_size: int,
+         *,
+         schema_name: str = "public",
+         content_column: str = "content",
+         embedding_column: str = "embedding",
+         metadata_columns: Optional[list[Union[Column, ColumnDict]]] = None,
+         metadata_json_column: str = "langchain_metadata",
+         id_column: Union[str, Column, ColumnDict] = "langchain_id",
+         overwrite_existing: bool = False,
+         store_metadata: bool = True,
+     ) -> None:
+         """
+         Create a table for saving vectors to be used with PGVectorStore.
+
+         Args:
+             table_name (str): The database table name.
+             vector_size (int): Vector size for the embedding model to be used.
+             schema_name (str): The schema name.
+                 Default: "public".
+             content_column (str): Name of the column to store document content.
+                 Default: "content".
+             embedding_column (str): Name of the column to store vector embeddings.
+                 Default: "embedding".
+             metadata_columns (Optional[list[Union[Column, ColumnDict]]]): A list of Columns to create for custom
+                 metadata. Default: None. Optional.
+             metadata_json_column (str): The column to store extra metadata in JSON format.
+                 Default: "langchain_metadata". Optional.
+             id_column (Union[str, Column, ColumnDict]): Column to store ids.
+                 Default: "langchain_id" column name with data type UUID. Optional.
+             overwrite_existing (bool): Whether to drop the existing table. Default: False.
+             store_metadata (bool): Whether to store metadata in the table.
+                 Default: True.
+         """
+         await self._run_as_async(
+             self._ainit_vectorstore_table(
+                 table_name,
+                 vector_size,
+                 schema_name=schema_name,
+                 content_column=content_column,
+                 embedding_column=embedding_column,
+                 metadata_columns=metadata_columns,
+                 metadata_json_column=metadata_json_column,
+                 id_column=id_column,
+                 overwrite_existing=overwrite_existing,
+                 store_metadata=store_metadata,
+             )
+         )
+
+     def init_vectorstore_table(
+         self,
+         table_name: str,
+         vector_size: int,
+         *,
+         schema_name: str = "public",
+         content_column: str = "content",
+         embedding_column: str = "embedding",
+         metadata_columns: Optional[list[Union[Column, ColumnDict]]] = None,
+         metadata_json_column: str = "langchain_metadata",
+         id_column: Union[str, Column, ColumnDict] = "langchain_id",
+         overwrite_existing: bool = False,
+         store_metadata: bool = True,
+     ) -> None:
+         """
+         Create a table for saving vectors to be used with PGVectorStore.
+
+         Args:
+             table_name (str): The database table name.
+             vector_size (int): Vector size for the embedding model to be used.
+             schema_name (str): The schema name.
+                 Default: "public".
+             content_column (str): Name of the column to store document content.
+                 Default: "content".
+             embedding_column (str): Name of the column to store vector embeddings.
+                 Default: "embedding".
+             metadata_columns (Optional[list[Union[Column, ColumnDict]]]): A list of Columns to create for custom
+                 metadata. Default: None. Optional.
+             metadata_json_column (str): The column to store extra metadata in JSON format.
+                 Default: "langchain_metadata". Optional.
+             id_column (Union[str, Column, ColumnDict]): Column to store ids.
+                 Default: "langchain_id" column name with data type UUID. Optional.
+             overwrite_existing (bool): Whether to drop the existing table. Default: False.
+             store_metadata (bool): Whether to store metadata in the table.
+                 Default: True.
+         """
+         self._run_as_sync(
+             self._ainit_vectorstore_table(
+                 table_name,
+                 vector_size,
+                 schema_name=schema_name,
+                 content_column=content_column,
+                 embedding_column=embedding_column,
+                 metadata_columns=metadata_columns,
+                 metadata_json_column=metadata_json_column,
+                 id_column=id_column,
+                 overwrite_existing=overwrite_existing,
+                 store_metadata=store_metadata,
+             )
+         )
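
The engine module above exposes both async (ainit_vectorstore_table) and sync (init_vectorstore_table) entry points for provisioning a pgvector-backed table, with the sync path dispatched to a background event loop. Below is a minimal usage sketch; the import path, connection URL, table name, and column definitions are illustrative assumptions rather than part of this diff, and the postgresql+asyncpg URL assumes the asyncpg driver is installed.

    from langchain_postgres import Column, PGEngine  # top-level exports assumed

    # Placeholder SQLAlchemy async URL for a local database.
    engine = PGEngine.from_connection_string(
        url="postgresql+asyncpg://user:password@localhost:5432/vectordb",
    )

    # Synchronous call; PGEngine runs it on its background event loop.
    engine.init_vectorstore_table(
        table_name="my_documents",
        vector_size=768,  # must match the embedding model's output dimension
        metadata_columns=[Column(name="source", data_type="TEXT")],
        overwrite_existing=True,
    )
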
@@ -0,0 +1,155 @@
+ """Index classes to add vector indexes on the PGVectorStore.
+
+ Learn more about vector indexes at https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
+ """
+
+ import enum
+ import re
+ import warnings
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+
+ @dataclass
+ class StrategyMixin:
+     operator: str
+     search_function: str
+     index_function: str
+
+
+ class DistanceStrategy(StrategyMixin, enum.Enum):
+     """Enumerator of the distance strategies."""
+
+     EUCLIDEAN = "<->", "l2_distance", "vector_l2_ops"
+     COSINE_DISTANCE = "<=>", "cosine_distance", "vector_cosine_ops"
+     INNER_PRODUCT = "<#>", "inner_product", "vector_ip_ops"
+
+
+ DEFAULT_DISTANCE_STRATEGY: DistanceStrategy = DistanceStrategy.COSINE_DISTANCE
+ DEFAULT_INDEX_NAME_SUFFIX: str = "langchainvectorindex"
+
+
+ def validate_identifier(identifier: str) -> None:
+     if re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", identifier) is None:
+         raise ValueError(
+             f"Invalid identifier: {identifier}. Identifiers must start with a letter or underscore, and subsequent characters can be letters, digits, or underscores."
+         )
+
+
+ @dataclass
+ class BaseIndex(ABC):
+     """
+     Abstract base class for defining vector indexes.
+
+     Attributes:
+         name (Optional[str]): A human-readable name for the index. Defaults to None.
+         index_type (str): A string identifying the type of index. Defaults to "base".
+         distance_strategy (DistanceStrategy): The strategy used to calculate distances
+             between vectors in the index. Defaults to DistanceStrategy.COSINE_DISTANCE.
+         partial_indexes (Optional[list[str]]): A list of names of partial indexes. Defaults to None.
+         extension_name (Optional[str]): The name of the extension to be created for the index, if any. Defaults to None.
+     """
+
+     name: Optional[str] = None
+     index_type: str = "base"
+     distance_strategy: DistanceStrategy = field(
+         default_factory=lambda: DistanceStrategy.COSINE_DISTANCE
+     )
+     partial_indexes: Optional[list[str]] = None
+     extension_name: Optional[str] = None
+
+     @abstractmethod
+     def index_options(self) -> str:
+         """Set index query options for vector store initialization."""
+         raise NotImplementedError(
+             "index_options method must be implemented by subclass"
+         )
+
+     def get_index_function(self) -> str:
+         return self.distance_strategy.index_function
+
+     def __post_init__(self) -> None:
+         """Check if initialization parameters are valid.
+
+         Raises:
+             ValueError: If extension_name or index_type is not a valid PostgreSQL identifier.
+         """
+
+         if self.extension_name:
+             validate_identifier(self.extension_name)
+         if self.index_type:
+             validate_identifier(self.index_type)
+
+
+ @dataclass
+ class ExactNearestNeighbor(BaseIndex):
+     index_type: str = "exactnearestneighbor"
+
+
+ @dataclass
+ class QueryOptions(ABC):
+     @abstractmethod
+     def to_parameter(self) -> list[str]:
+         """Convert index attributes to list of configurations."""
+         raise NotImplementedError("to_parameter method must be implemented by subclass")
+
+     @abstractmethod
+     def to_string(self) -> str:
+         """Convert index attributes to string."""
+         raise NotImplementedError("to_string method must be implemented by subclass")
+
+
+ @dataclass
+ class HNSWIndex(BaseIndex):
+     index_type: str = "hnsw"
+     m: int = 16
+     ef_construction: int = 64
+
+     def index_options(self) -> str:
+         """Set index query options for vector store initialization."""
+         return f"(m = {self.m}, ef_construction = {self.ef_construction})"
+
+
+ @dataclass
+ class HNSWQueryOptions(QueryOptions):
+     ef_search: int = 40
+
+     def to_parameter(self) -> list[str]:
+         """Convert index attributes to list of configurations."""
+         return [f"hnsw.ef_search = {self.ef_search}"]
+
+     def to_string(self) -> str:
+         """Convert index attributes to string."""
+         warnings.warn(
+             "to_string is deprecated, use to_parameter instead.",
+             DeprecationWarning,
+         )
+         return f"hnsw.ef_search = {self.ef_search}"
+
+
+ @dataclass
+ class IVFFlatIndex(BaseIndex):
+     index_type: str = "ivfflat"
+     lists: int = 100
+
+     def index_options(self) -> str:
+         """Set index query options for vector store initialization."""
+         return f"(lists = {self.lists})"
+
+
+ @dataclass
+ class IVFFlatQueryOptions(QueryOptions):
+     probes: int = 1
+
+     def to_parameter(self) -> list[str]:
+         """Convert index attributes to list of configurations."""
+         return [f"ivfflat.probes = {self.probes}"]
+
+     def to_string(self) -> str:
+         """Convert index attributes to string."""
+         warnings.warn(
+             "to_string is deprecated, use to_parameter instead.",
+             DeprecationWarning,
+         )
+         return f"ivfflat.probes = {self.probes}"