vectordb-bench 0.0.14__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -37,23 +37,24 @@ class config:
     K_DEFAULT = 100 # default return top k nearest neighbors during search
     CUSTOM_CONFIG_DIR = pathlib.Path(__file__).parent.joinpath("custom/custom_case.json")
 
-    CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600 # 24h
-    LOAD_TIMEOUT_DEFAULT = 2.5 * 3600 # 2.5h
-    LOAD_TIMEOUT_768D_1M = 2.5 * 3600 # 2.5h
-    LOAD_TIMEOUT_768D_10M = 25 * 3600 # 25h
-    LOAD_TIMEOUT_768D_100M = 250 * 3600 # 10.41d
+    CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600 # 24h
+    LOAD_TIMEOUT_DEFAULT = 24 * 3600 # 24h
+    LOAD_TIMEOUT_768D_1M = 24 * 3600 # 24h
+    LOAD_TIMEOUT_768D_10M = 240 * 3600 # 10d
+    LOAD_TIMEOUT_768D_100M = 2400 * 3600 # 100d
 
-    LOAD_TIMEOUT_1536D_500K = 2.5 * 3600 # 2.5h
-    LOAD_TIMEOUT_1536D_5M = 25 * 3600 # 25h
+    LOAD_TIMEOUT_1536D_500K = 24 * 3600 # 24h
+    LOAD_TIMEOUT_1536D_5M = 240 * 3600 # 10d
 
-    OPTIMIZE_TIMEOUT_DEFAULT = 30 * 60 # 30min
-    OPTIMIZE_TIMEOUT_768D_1M = 30 * 60 # 30min
-    OPTIMIZE_TIMEOUT_768D_10M = 5 * 3600 # 5h
-    OPTIMIZE_TIMEOUT_768D_100M = 50 * 3600 # 50h
+    OPTIMIZE_TIMEOUT_DEFAULT = 24 * 3600 # 24h
+    OPTIMIZE_TIMEOUT_768D_1M = 24 * 3600 # 24h
+    OPTIMIZE_TIMEOUT_768D_10M = 240 * 3600 # 10d
+    OPTIMIZE_TIMEOUT_768D_100M = 2400 * 3600 # 100d
 
 
-    OPTIMIZE_TIMEOUT_1536D_500K = 15 * 60 # 15min
-    OPTIMIZE_TIMEOUT_1536D_5M = 2.5 * 3600 # 2.5h
+    OPTIMIZE_TIMEOUT_1536D_500K = 24 * 3600 # 24h
+    OPTIMIZE_TIMEOUT_1536D_5M = 240 * 3600 # 10d
+
     def display(self) -> str:
         tmp = [
             i for i in inspect.getmembers(self)
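
The load and optimize timeouts are raised to at least 24 hours, and up to 100 days for the 100M-vector cases. A quick sanity check of the new constants (a throwaway snippet, not part of the package) converts the second-based values to the durations named in the comments:

    # illustrative only: convert the new timeout constants to hours and days
    for label, seconds in {
        "LOAD_TIMEOUT_DEFAULT": 24 * 3600,
        "LOAD_TIMEOUT_768D_10M": 240 * 3600,
        "LOAD_TIMEOUT_768D_100M": 2400 * 3600,
    }.items():
        print(f"{label}: {seconds / 3600:.0f} h = {seconds / 86400:.0f} d")
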
@@ -31,6 +31,7 @@ class DB(Enum):
     PgVector = "PgVector"
     PgVectoRS = "PgVectoRS"
     PgVectorScale = "PgVectorScale"
+    PgDiskANN = "PgDiskANN"
     Redis = "Redis"
     MemoryDB = "MemoryDB"
     Chroma = "Chroma"
@@ -77,6 +78,10 @@ class DB(Enum):
             from .pgvectorscale.pgvectorscale import PgVectorScale
             return PgVectorScale
 
+        if self == DB.PgDiskANN:
+            from .pgdiskann.pgdiskann import PgDiskANN
+            return PgDiskANN
+
         if self == DB.Redis:
             from .redis.redis import Redis
             return Redis
@@ -132,6 +137,10 @@ class DB(Enum):
             from .pgvectorscale.config import PgVectorScaleConfig
             return PgVectorScaleConfig
 
+        if self == DB.PgDiskANN:
+            from .pgdiskann.config import PgDiskANNConfig
+            return PgDiskANNConfig
+
         if self == DB.Redis:
             from .redis.config import RedisConfig
             return RedisConfig
@@ -185,6 +194,10 @@ class DB(Enum):
             from .pgvectorscale.config import _pgvectorscale_case_config
             return _pgvectorscale_case_config.get(index_type)
 
+        if self == DB.PgDiskANN:
+            from .pgdiskann.config import _pgdiskann_case_config
+            return _pgdiskann_case_config.get(index_type)
+
         # DB.Pinecone, DB.Chroma, DB.Redis
         return EmptyDBCaseConfig
 
@@ -10,6 +10,8 @@ class MetricType(str, Enum):
     L2 = "L2"
     COSINE = "COSINE"
     IP = "IP"
+    HAMMING = "HAMMING"
+    JACCARD = "JACCARD"
 
 
 class IndexType(str, Enum):
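
Taken together, the hunks above register the new client: DB.PgDiskANN resolves to its client, connection-config, and case-config classes through the same deferred imports used by the other Postgres-based clients, and MetricType gains HAMMING and JACCARD members. A minimal resolution sketch (illustrative only; it assumes init_cls and config_cls remain properties and case_config_cls(index_type) remains a method, as for the existing clients):

    from vectordb_bench.backend.clients import DB
    from vectordb_bench.backend.clients.api import IndexType

    db = DB.PgDiskANN
    client_cls = db.init_cls                                   # pgdiskann.pgdiskann.PgDiskANN
    config_cls = db.config_cls                                 # pgdiskann.config.PgDiskANNConfig
    case_config_cls = db.case_config_cls(IndexType.DISKANN)    # PgDiskANNImplConfig
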
@@ -0,0 +1,99 @@
+import click
+import os
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from typing import Annotated, Optional, Unpack
+from vectordb_bench.backend.clients import DB
+
+
+class PgDiskAnnTypedDict(CommonTypedDict):
+    user_name: Annotated[
+        str, click.option("--user-name", type=str, help="Db username", required=True)
+    ]
+    password: Annotated[
+        str,
+        click.option("--password",
+                     type=str,
+                     help="Postgres database password",
+                     default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
+                     show_default="$POSTGRES_PASSWORD",
+                     ),
+    ]
+
+    host: Annotated[
+        str, click.option("--host", type=str, help="Db host", required=True)
+    ]
+    db_name: Annotated[
+        str, click.option("--db-name", type=str, help="Db name", required=True)
+    ]
+    max_neighbors: Annotated[
+        int,
+        click.option(
+            "--max-neighbors", type=int, help="PgDiskAnn max neighbors",
+        ),
+    ]
+    l_value_ib: Annotated[
+        int,
+        click.option(
+            "--l-value-ib", type=int, help="PgDiskAnn l_value_ib",
+        ),
+    ]
+    l_value_is: Annotated[
+        float,
+        click.option(
+            "--l-value-is", type=float, help="PgDiskAnn l_value_is",
+        ),
+    ]
+    maintenance_work_mem: Annotated[
+        Optional[str],
+        click.option(
+            "--maintenance-work-mem",
+            type=str,
+            help="Sets the maximum memory to be used for maintenance operations (index creation). "
+            "Can be entered as string with unit like '64GB' or as an integer number of KB."
+            "This will set the parameters: max_parallel_maintenance_workers,"
+            " max_parallel_workers & table(parallel_workers)",
+            required=False,
+        ),
+    ]
+    max_parallel_workers: Annotated[
+        Optional[int],
+        click.option(
+            "--max-parallel-workers",
+            type=int,
+            help="Sets the maximum number of parallel processes per maintenance operation (index creation)",
+            required=False,
+        ),
+    ]
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(PgDiskAnnTypedDict)
+def PgDiskAnn(
+    **parameters: Unpack[PgDiskAnnTypedDict],
+):
+    from .config import PgDiskANNConfig, PgDiskANNImplConfig
+
+    run(
+        db=DB.PgDiskANN,
+        db_config=PgDiskANNConfig(
+            db_label=parameters["db_label"],
+            user_name=SecretStr(parameters["user_name"]),
+            password=SecretStr(parameters["password"]),
+            host=parameters["host"],
+            db_name=parameters["db_name"],
+        ),
+        db_case_config=PgDiskANNImplConfig(
+            max_neighbors=parameters["max_neighbors"],
+            l_value_ib=parameters["l_value_ib"],
+            l_value_is=parameters["l_value_is"],
+            max_parallel_workers=parameters["max_parallel_workers"],
+            maintenance_work_mem=parameters["maintenance_work_mem"],
+        ),
+        **parameters,
+    )
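
The new module above wires a pgdiskann command into the package's shared click CLI (its relative import resolves to vectordb_bench.cli.cli). Below is a hedged sketch of inspecting the command with click's test runner; it assumes the module is imported by the CLI package so that @cli.command() has registered it, and that click derives the command name "pgdiskann" from the PgDiskAnn function name:

    from click.testing import CliRunner
    from vectordb_bench.cli.cli import cli  # the click group the new module imports

    runner = CliRunner()
    result = runner.invoke(cli, ["pgdiskann", "--help"])
    print(result.output)  # lists --user-name, --host, --db-name, --max-neighbors, ...

    # A real benchmark run would combine these flags with the shared options
    # inherited from CommonTypedDict, e.g. (placeholder values):
    #   pgdiskann --user-name postgres --host localhost --db-name test \
    #       --max-neighbors 32 --l-value-ib 100 --l-value-is 100
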
@@ -0,0 +1,145 @@
+from abc import abstractmethod
+from typing import Any, Mapping, Optional, Sequence, TypedDict
+from pydantic import BaseModel, SecretStr
+from typing_extensions import LiteralString
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
+
+POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
+
+
+class PgDiskANNConfigDict(TypedDict):
+    """These keys will be directly used as kwargs in psycopg connection string,
+    so the names must match exactly psycopg API"""
+
+    user: str
+    password: str
+    host: str
+    port: int
+    dbname: str
+
+
+class PgDiskANNConfig(DBConfig):
+    user_name: SecretStr = SecretStr("postgres")
+    password: SecretStr
+    host: str = "localhost"
+    port: int = 5432
+    db_name: str
+
+    def to_dict(self) -> PgDiskANNConfigDict:
+        user_str = self.user_name.get_secret_value()
+        pwd_str = self.password.get_secret_value()
+        return {
+            "host": self.host,
+            "port": self.port,
+            "dbname": self.db_name,
+            "user": user_str,
+            "password": pwd_str,
+        }
+
+
+class PgDiskANNIndexConfig(BaseModel, DBCaseConfig):
+    metric_type: MetricType | None = None
+    create_index_before_load: bool = False
+    create_index_after_load: bool = True
+    maintenance_work_mem: Optional[str]
+    max_parallel_workers: Optional[int]
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "vector_l2_ops"
+        elif self.metric_type == MetricType.IP:
+            return "vector_ip_ops"
+        return "vector_cosine_ops"
+
+    def parse_metric_fun_op(self) -> LiteralString:
+        if self.metric_type == MetricType.L2:
+            return "<->"
+        elif self.metric_type == MetricType.IP:
+            return "<#>"
+        return "<=>"
+
+    def parse_metric_fun_str(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "l2_distance"
+        elif self.metric_type == MetricType.IP:
+            return "max_inner_product"
+        return "cosine_distance"
+
+    @abstractmethod
+    def index_param(self) -> dict:
+        ...
+
+    @abstractmethod
+    def search_param(self) -> dict:
+        ...
+
+    @abstractmethod
+    def session_param(self) -> dict:
+        ...
+
+    @staticmethod
+    def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[dict[str, Any]]:
+        """Walk through mappings, creating a List of {key1 = value} pairs. That will be used to build a where clause"""
+        options = []
+        for option_name, value in with_options.items():
+            if value is not None:
+                options.append(
+                    {
+                        "option_name": option_name,
+                        "val": str(value),
+                    }
+                )
+        return options
+
+    @staticmethod
+    def _optionally_build_set_options(
+        set_mapping: Mapping[str, Any]
+    ) -> Sequence[dict[str, Any]]:
+        """Walk through options, creating 'SET 'key1 = "value1";' list"""
+        session_options = []
+        for setting_name, value in set_mapping.items():
+            if value:
+                session_options.append(
+                    {"parameter": {
+                        "setting_name": setting_name,
+                        "val": str(value),
+                    },
+                    }
+                )
+        return session_options
+
+
+class PgDiskANNImplConfig(PgDiskANNIndexConfig):
+    index: IndexType = IndexType.DISKANN
+    max_neighbors: int | None
+    l_value_ib: int | None
+    l_value_is: float | None
+    maintenance_work_mem: Optional[str] = None
+    max_parallel_workers: Optional[int] = None
+
+    def index_param(self) -> dict:
+        return {
+            "metric": self.parse_metric(),
+            "index_type": self.index.value,
+            "options": {
+                "max_neighbors": self.max_neighbors,
+                "l_value_ib": self.l_value_ib,
+            },
+            "maintenance_work_mem": self.maintenance_work_mem,
+            "max_parallel_workers": self.max_parallel_workers,
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "metric": self.parse_metric(),
+            "metric_fun_op": self.parse_metric_fun_op(),
+        }
+
+    def session_param(self) -> dict:
+        return {
+            "diskann.l_value_is": self.l_value_is,
+        }
+
+_pgdiskann_case_config = {
+    IndexType.DISKANN: PgDiskANNImplConfig,
+}
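
For reference, a small sketch (not part of the diff) of the three parameter groups the client reads from this case config; the module path follows the deferred import in the DB enum, and the values are placeholders:

    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.pgdiskann.config import PgDiskANNImplConfig

    cfg = PgDiskANNImplConfig(
        metric_type=MetricType.COSINE,
        max_neighbors=32,
        l_value_ib=100,
        l_value_is=100.0,
    )
    print(cfg.index_param())    # metric 'vector_cosine_ops', index_type IndexType.DISKANN.value,
                                # options {'max_neighbors': 32, 'l_value_ib': 100}
    print(cfg.search_param())   # {'metric': 'vector_cosine_ops', 'metric_fun_op': '<=>'}
    print(cfg.session_param())  # {'diskann.l_value_is': 100.0}
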
@@ -0,0 +1,350 @@
+"""Wrapper around the pg_diskann vector database over VectorDB"""
+
+import logging
+import pprint
+from contextlib import contextmanager
+from typing import Any, Generator, Optional, Tuple
+
+import numpy as np
+import psycopg
+from pgvector.psycopg import register_vector
+from psycopg import Connection, Cursor, sql
+
+from ..api import VectorDB
+from .config import PgDiskANNConfigDict, PgDiskANNIndexConfig
+
+log = logging.getLogger(__name__)
+
+
+class PgDiskANN(VectorDB):
+    """Use psycopg instructions"""
+
+    conn: psycopg.Connection[Any] | None = None
+    coursor: psycopg.Cursor[Any] | None = None
+
+    _filtered_search: sql.Composed
+    _unfiltered_search: sql.Composed
+
+    def __init__(
+        self,
+        dim: int,
+        db_config: PgDiskANNConfigDict,
+        db_case_config: PgDiskANNIndexConfig,
+        collection_name: str = "pg_diskann_collection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.name = "PgDiskANN"
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.table_name = collection_name
+        self.dim = dim
+
+        self._index_name = "pgdiskann_index"
+        self._primary_field = "id"
+        self._vector_field = "embedding"
+
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}")
+        if not any(
+            (
+                self.case_config.create_index_before_load,
+                self.case_config.create_index_after_load,
+            )
+        ):
+            err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
+            log.error(err)
+            raise RuntimeError(
+                f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
+            )
+
+        if drop_old:
+            self._drop_index()
+            self._drop_table()
+            self._create_table(dim)
+            if self.case_config.create_index_before_load:
+                self._create_index()
+
+        self.cursor.close()
+        self.conn.close()
+        self.cursor = None
+        self.conn = None
+
+    @staticmethod
+    def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
+        conn = psycopg.connect(**kwargs)
+        conn.cursor().execute("CREATE EXTENSION IF NOT EXISTS pg_diskann CASCADE")
+        conn.commit()
+        register_vector(conn)
+        conn.autocommit = False
+        cursor = conn.cursor()
+
+        assert conn is not None, "Connection is not initialized"
+        assert cursor is not None, "Cursor is not initialized"
+
+        return conn, cursor
+
+    @contextmanager
+    def init(self) -> Generator[None, None, None]:
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        # index configuration may have commands defined that we should set during each client session
+        session_options: dict[str, Any] = self.case_config.session_param()
+
+        if len(session_options) > 0:
+            for setting_name, setting_val in session_options.items():
+                command = sql.SQL("SET {setting_name} " + "= {setting_val};").format(
+                    setting_name=sql.Identifier(setting_name),
+                    setting_val=sql.Identifier(str(setting_val)),
+                )
+                log.debug(command.as_string(self.cursor))
+                self.cursor.execute(command)
+            self.conn.commit()
+
+        self._filtered_search = sql.Composed(
+            [
+                sql.SQL(
+                    "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
+                ).format(table_name=sql.Identifier(self.table_name)),
+                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                sql.SQL(" %s::vector LIMIT %s::int"),
+            ]
+        )
+
+        self._unfiltered_search = sql.Composed(
+            [
+                sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
+                    sql.Identifier(self.table_name)
+                ),
+                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                sql.SQL(" %s::vector LIMIT %s::int"),
+            ]
+        )
+
+        try:
+            yield
+        finally:
+            self.cursor.close()
+            self.conn.close()
+            self.cursor = None
+            self.conn = None
+
+    def _drop_table(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop table : {self.table_name}")
+
+        self.cursor.execute(
+            sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
+                table_name=sql.Identifier(self.table_name)
+            )
+        )
+        self.conn.commit()
+
+    def ready_to_load(self):
+        pass
+
+    def optimize(self):
+        self._post_insert()
+
+    def _post_insert(self):
+        log.info(f"{self.name} post insert before optimize")
+        if self.case_config.create_index_after_load:
+            self._drop_index()
+            self._create_index()
+
+    def _drop_index(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop index : {self._index_name}")
+
+        drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
+            index_name=sql.Identifier(self._index_name)
+        )
+        log.debug(drop_index_sql.as_string(self.cursor))
+        self.cursor.execute(drop_index_sql)
+        self.conn.commit()
+
+    def _set_parallel_index_build_param(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        index_param = self.case_config.index_param()
+
+        if index_param["maintenance_work_mem"] is not None:
+            self.cursor.execute(
+                sql.SQL("SET maintenance_work_mem TO {};").format(
+                    index_param["maintenance_work_mem"]
+                )
+            )
+            self.cursor.execute(
+                sql.SQL("ALTER USER {} SET maintenance_work_mem TO {};").format(
+                    sql.Identifier(self.db_config["user"]),
+                    index_param["maintenance_work_mem"],
+                )
+            )
+            self.conn.commit()
+
+        if index_param["max_parallel_workers"] is not None:
+            self.cursor.execute(
+                sql.SQL("SET max_parallel_maintenance_workers TO '{}';").format(
+                    index_param["max_parallel_workers"]
+                )
+            )
+            self.cursor.execute(
+                sql.SQL(
+                    "ALTER USER {} SET max_parallel_maintenance_workers TO '{}';"
+                ).format(
+                    sql.Identifier(self.db_config["user"]),
+                    index_param["max_parallel_workers"],
+                )
+            )
+            self.cursor.execute(
+                sql.SQL("SET max_parallel_workers TO '{}';").format(
+                    index_param["max_parallel_workers"]
+                )
+            )
+            self.cursor.execute(
+                sql.SQL(
+                    "ALTER USER {} SET max_parallel_workers TO '{}';"
+                ).format(
+                    sql.Identifier(self.db_config["user"]),
+                    index_param["max_parallel_workers"],
+                )
+            )
+            self.cursor.execute(
+                sql.SQL(
+                    "ALTER TABLE {} SET (parallel_workers = {});"
+                ).format(
+                    sql.Identifier(self.table_name),
+                    index_param["max_parallel_workers"],
+                )
+            )
+            self.conn.commit()
+
+        results = self.cursor.execute(
+            sql.SQL("SHOW max_parallel_maintenance_workers;")
+        ).fetchall()
+        results.extend(
+            self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall()
+        )
+        results.extend(
+            self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall()
+        )
+        log.info(f"{self.name} parallel index creation parameters: {results}")
+    def _create_index(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client create index : {self._index_name}")
+
+        index_param: dict[str, Any] = self.case_config.index_param()
+        self._set_parallel_index_build_param()
+
+        options = []
+        for option_name, option_val in index_param["options"].items():
+            if option_val is not None:
+                options.append(
+                    sql.SQL("{option_name} = {val}").format(
+                        option_name=sql.Identifier(option_name),
+                        val=sql.Identifier(str(option_val)),
+                    )
+                )
+
+        if any(options):
+            with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
+        else:
+            with_clause = sql.Composed(())
+
+        index_create_sql = sql.SQL(
+            """
+            CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
+            USING {index_type} (embedding {embedding_metric})
+            """
+        ).format(
+            index_name=sql.Identifier(self._index_name),
+            table_name=sql.Identifier(self.table_name),
+            index_type=sql.Identifier(index_param["index_type"].lower()),
+            embedding_metric=sql.Identifier(index_param["metric"]),
+        )
+        index_create_sql_with_with_clause = (
+            index_create_sql + with_clause
+        ).join(" ")
+        log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
+        self.cursor.execute(index_create_sql_with_with_clause)
+        self.conn.commit()
+
+    def _create_table(self, dim: int):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            log.info(f"{self.name} client create table : {self.table_name}")
+
+            self.cursor.execute(
+                sql.SQL(
+                    "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
+                ).format(table_name=sql.Identifier(self.table_name), dim=dim)
+            )
+            self.conn.commit()
+        except Exception as e:
+            log.warning(
+                f"Failed to create pgdiskann table: {self.table_name} error: {e}"
+            )
+            raise e from None
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> Tuple[int, Optional[Exception]]:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            metadata_arr = np.array(metadata)
+            embeddings_arr = np.array(embeddings)
+
+            with self.cursor.copy(
+                sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
+                    table_name=sql.Identifier(self.table_name)
+                )
+            ) as copy:
+                copy.set_types(["bigint", "vector"])
+                for i, row in enumerate(metadata_arr):
+                    copy.write_row((row, embeddings_arr[i]))
+            self.conn.commit()
+
+            if kwargs.get("last_batch"):
+                self._post_insert()
+
+            return len(metadata), None
+        except Exception as e:
+            log.warning(
+                f"Failed to insert data into table ({self.table_name}), error: {e}"
+            )
+            return 0, e
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+    ) -> list[int]:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        q = np.asarray(query)
+        if filters:
+            gt = filters.get("id")
+            result = self.cursor.execute(
+                self._filtered_search, (gt, q, k), prepare=True, binary=True
+            )
+        else:
+            result = self.cursor.execute(
+                self._unfiltered_search, (q, k), prepare=True, binary=True
+            )
+
+        return [int(i[0]) for i in result.fetchall()]
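
Finally, a hedged end-to-end sketch of driving the new client directly, outside the benchmark runner, using only the methods defined above. It assumes a reachable Postgres instance with the pg_diskann and pgvector extensions installed; dimensions, credentials, and parameter values are placeholders:

    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.pgdiskann.config import PgDiskANNImplConfig
    from vectordb_bench.backend.clients.pgdiskann.pgdiskann import PgDiskANN

    db_config = {
        "host": "localhost",
        "port": 5432,
        "dbname": "vectordb_bench_test",
        "user": "postgres",
        "password": "example",
    }
    case_config = PgDiskANNImplConfig(
        metric_type=MetricType.COSINE,
        max_neighbors=32,
        l_value_ib=100,
        l_value_is=100.0,
    )

    client = PgDiskANN(dim=4, db_config=db_config, db_case_config=case_config, drop_old=True)
    with client.init():
        client.insert_embeddings([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]], metadata=[1, 2])
        client.optimize()  # create_index_after_load defaults to True, so this builds the diskann index
        print(client.search_embedding([0.1, 0.2, 0.3, 0.4], k=1))
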