vectordb-bench 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23):
  1. vectordb_bench/backend/clients/__init__.py +22 -0
  2. vectordb_bench/backend/clients/api.py +21 -1
  3. vectordb_bench/backend/clients/memorydb/cli.py +88 -0
  4. vectordb_bench/backend/clients/memorydb/config.py +54 -0
  5. vectordb_bench/backend/clients/memorydb/memorydb.py +254 -0
  6. vectordb_bench/backend/clients/pgvecto_rs/cli.py +154 -0
  7. vectordb_bench/backend/clients/pgvecto_rs/config.py +108 -73
  8. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +159 -59
  9. vectordb_bench/backend/clients/pgvectorscale/config.py +111 -0
  10. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +272 -0
  11. vectordb_bench/cli/vectordbbench.py +5 -0
  12. vectordb_bench/frontend/components/check_results/data.py +13 -6
  13. vectordb_bench/frontend/components/run_test/caseSelector.py +10 -0
  14. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +37 -15
  15. vectordb_bench/frontend/components/run_test/initStyle.py +3 -1
  16. vectordb_bench/frontend/config/dbCaseConfigs.py +173 -9
  17. vectordb_bench/models.py +18 -6
  18. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.13.dist-info}/METADATA +11 -3
  19. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.13.dist-info}/RECORD +23 -17
  20. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.13.dist-info}/WHEEL +1 -1
  21. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.13.dist-info}/LICENSE +0 -0
  22. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.13.dist-info}/entry_points.txt +0 -0
  23. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.13.dist-info}/top_level.txt +0 -0
@@ -1,30 +1,53 @@
1
- from typing import Literal
1
+ from abc import abstractmethod
2
+ from typing import TypedDict
3
+
2
4
  from pydantic import BaseModel, SecretStr
3
- from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
5
+ from pgvecto_rs.types import IndexOption, Ivf, Hnsw, Flat, Quantization
6
+ from pgvecto_rs.types.index import QuantizationType, QuantizationRatio
7
+
8
+ from ..api import DBConfig, DBCaseConfig, IndexType, MetricType
4
9
 
5
10
  POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
6
11
 
7
12
 
13
+ class PgVectorRSConfigDict(TypedDict):
14
+ """These keys will be directly used as kwargs in psycopg connection string,
15
+ so the names must match exactly psycopg API"""
16
+
17
+ user: str
18
+ password: str
19
+ host: str
20
+ port: int
21
+ dbname: str
22
+
23
+
8
24
  class PgVectoRSConfig(DBConfig):
9
- user_name: SecretStr = "postgres"
25
+ user_name: str = "postgres"
10
26
  password: SecretStr
11
27
  host: str = "localhost"
12
28
  port: int = 5432
13
29
  db_name: str
14
30
 
15
31
  def to_dict(self) -> dict:
16
- user_str = self.user_name.get_secret_value()
32
+ user_str = self.user_name
17
33
  pwd_str = self.password.get_secret_value()
18
34
  return {
19
35
  "host": self.host,
20
36
  "port": self.port,
21
37
  "dbname": self.db_name,
22
38
  "user": user_str,
23
- "password": pwd_str
39
+ "password": pwd_str,
24
40
  }
25
41
 
42
+
26
43
  class PgVectoRSIndexConfig(BaseModel, DBCaseConfig):
27
44
  metric_type: MetricType | None = None
45
+ create_index_before_load: bool = False
46
+ create_index_after_load: bool = True
47
+
48
+ max_parallel_workers: int | None = None
49
+ quantization_type: QuantizationType | None = None
50
+ quantization_ratio: QuantizationRatio | None = None
28
51
 
29
52
  def parse_metric(self) -> str:
30
53
  if self.metric_type == MetricType.L2:
@@ -40,88 +63,100 @@ class PgVectoRSIndexConfig(BaseModel, DBCaseConfig):
40
63
  return "<#>"
41
64
  return "<=>"
42
65
 
43
- class PgVectoRSQuantConfig(PgVectoRSIndexConfig):
44
- quantizationType: Literal["trivial", "scalar", "product"]
45
- quantizationRatio: None | Literal["x4", "x8", "x16", "x32", "x64"]
46
-
47
- def parse_quantization(self) -> str:
48
- if self.quantizationType == "trivial":
49
- return "quantization = { trivial = { } }"
50
- elif self.quantizationType == "scalar":
51
- return "quantization = { scalar = { } }"
52
- else:
53
- return f'quantization = {{ product = {{ ratio = "{self.quantizationRatio}" }} }}'
54
-
66
+ def search_param(self) -> dict:
67
+ return {
68
+ "metric_fun_op": self.parse_metric_fun_op(),
69
+ }
55
70
 
56
- class HNSWConfig(PgVectoRSQuantConfig):
57
- M: int
58
- efConstruction: int
59
- index: IndexType = IndexType.HNSW
71
+ @abstractmethod
72
+ def index_param(self) -> dict[str, str]: ...
60
73
 
61
- def index_param(self) -> dict:
62
- options = f"""
63
- [indexing.hnsw]
64
- m = {self.M}
65
- ef_construction = {self.efConstruction}
66
- {self.parse_quantization()}
67
- """
68
- return {"options": options, "metric": self.parse_metric()}
74
+ @abstractmethod
75
+ def session_param(self) -> dict[str, str | int]: ...
69
76
 
70
- def search_param(self) -> dict:
71
- return {"metrics_op": self.parse_metric_fun_op()}
72
77
 
78
+ class PgVectoRSHNSWConfig(PgVectoRSIndexConfig):
79
+ index: IndexType = IndexType.HNSW
80
+ m: int | None = None
81
+ ef_search: int | None
82
+ ef_construction: int | None = None
73
83
 
74
- class IVFFlatConfig(PgVectoRSQuantConfig):
75
- nlist: int
76
- nprobe: int | None = None
84
+ def index_param(self) -> dict[str, str]:
85
+ if self.quantization_type is None:
86
+ quantization = None
87
+ else:
88
+ quantization = Quantization(
89
+ typ=self.quantization_type, ratio=self.quantization_ratio
90
+ )
91
+
92
+ option = IndexOption(
93
+ index=Hnsw(
94
+ m=self.m,
95
+ ef_construction=self.ef_construction,
96
+ quantization=quantization,
97
+ ),
98
+ threads=self.max_parallel_workers,
99
+ )
100
+ return {"options": option.dumps(), "metric": self.parse_metric()}
101
+
102
+ def session_param(self) -> dict[str, str | int]:
103
+ session_parameters = {}
104
+ if self.ef_search is not None:
105
+ session_parameters["vectors.hnsw_ef_search"] = str(self.ef_search)
106
+ return session_parameters
107
+
108
+
109
+ class PgVectoRSIVFFlatConfig(PgVectoRSIndexConfig):
77
110
  index: IndexType = IndexType.IVFFlat
111
+ probes: int | None
112
+ lists: int | None
113
+
114
+ def index_param(self) -> dict[str, str]:
115
+ if self.quantization_type is None:
116
+ quantization = None
117
+ else:
118
+ quantization = Quantization(
119
+ typ=self.quantization_type, ratio=self.quantization_ratio
120
+ )
78
121
 
79
- def index_param(self) -> dict:
80
- options = f"""
81
- [indexing.ivf]
82
- nlist = {self.nlist}
83
- nsample = {self.nprobe if self.nprobe else 10}
84
- {self.parse_quantization()}
85
- """
86
- return {"options": options, "metric": self.parse_metric()}
122
+ option = IndexOption(
123
+ index=Ivf(nlist=self.lists, quantization=quantization),
124
+ threads=self.max_parallel_workers,
125
+ )
126
+ return {"options": option.dumps(), "metric": self.parse_metric()}
87
127
 
88
- def search_param(self) -> dict:
89
- return {"metrics_op": self.parse_metric_fun_op()}
90
-
91
- class IVFFlatSQ8Config(PgVectoRSIndexConfig):
92
- nlist: int
93
- nprobe: int | None = None
94
- index: IndexType = IndexType.IVFSQ8
95
-
96
- def index_param(self) -> dict:
97
- options = f"""
98
- [indexing.ivf]
99
- nlist = {self.nlist}
100
- nsample = {self.nprobe if self.nprobe else 10}
101
- quantization = {{ scalar = {{ }} }}
102
- """
103
- return {"options": options, "metric": self.parse_metric()}
128
+ def session_param(self) -> dict[str, str | int]:
129
+ session_parameters = {}
130
+ if self.probes is not None:
131
+ session_parameters["vectors.ivf_nprobe"] = str(self.probes)
132
+ return session_parameters
104
133
 
105
- def search_param(self) -> dict:
106
- return {"metrics_op": self.parse_metric_fun_op()}
107
134
 
108
- class FLATConfig(PgVectoRSQuantConfig):
135
+ class PgVectoRSFLATConfig(PgVectoRSIndexConfig):
109
136
  index: IndexType = IndexType.Flat
110
137
 
111
- def index_param(self) -> dict:
112
- options = f"""
113
- [indexing.flat]
114
- {self.parse_quantization()}
115
- """
116
- return {"options": options, "metric": self.parse_metric()}
138
+ def index_param(self) -> dict[str, str]:
139
+ if self.quantization_type is None:
140
+ quantization = None
141
+ else:
142
+ quantization = Quantization(
143
+ typ=self.quantization_type, ratio=self.quantization_ratio
144
+ )
117
145
 
118
- def search_param(self) -> dict:
119
- return {"metrics_op": self.parse_metric_fun_op()}
146
+ option = IndexOption(
147
+ index=Flat(
148
+ quantization=quantization,
149
+ ),
150
+ threads=self.max_parallel_workers,
151
+ )
152
+ return {"options": option.dumps(), "metric": self.parse_metric()}
153
+
154
+ def session_param(self) -> dict[str, str | int]:
155
+ return {}
120
156
 
121
157
 
122
158
  _pgvecto_rs_case_config = {
123
- IndexType.HNSW: HNSWConfig,
124
- IndexType.IVFFlat: IVFFlatConfig,
125
- IndexType.IVFSQ8: IVFFlatSQ8Config,
126
- IndexType.Flat: FLATConfig,
159
+ IndexType.HNSW: PgVectoRSHNSWConfig,
160
+ IndexType.IVFFlat: PgVectoRSIVFFlatConfig,
161
+ IndexType.Flat: PgVectoRSFLATConfig,
127
162
  }
@@ -1,73 +1,138 @@
1
1
  """Wrapper around the Pgvecto.rs vector database over VectorDB"""
2
2
 
3
- import io
4
3
  import logging
4
+ import pprint
5
5
  from contextlib import contextmanager
6
- from typing import Any
7
- import pandas as pd
8
- import psycopg2
9
- import psycopg2.extras
6
+ from typing import Any, Generator, Optional, Tuple
10
7
 
11
- from ..api import VectorDB, DBCaseConfig
8
+ import numpy as np
9
+ import psycopg
10
+ from psycopg import Connection, Cursor, sql
11
+ from pgvecto_rs.psycopg import register_vector
12
+
13
+ from ..api import VectorDB
14
+ from .config import PgVectoRSConfig, PgVectoRSIndexConfig
12
15
 
13
16
  log = logging.getLogger(__name__)
14
17
 
18
+
15
19
  class PgVectoRS(VectorDB):
16
- """Use SQLAlchemy instructions"""
20
+ """Use psycopg instructions"""
21
+
22
+ conn: psycopg.Connection[Any] | None = None
23
+ cursor: psycopg.Cursor[Any] | None = None
24
+ _unfiltered_search: sql.Composed
25
+ _filtered_search: sql.Composed
17
26
 
18
27
  def __init__(
19
28
  self,
20
29
  dim: int,
21
- db_config: dict,
22
- db_case_config: DBCaseConfig,
23
- collection_name: str = "PgVectorCollection",
30
+ db_config: PgVectoRSConfig,
31
+ db_case_config: PgVectoRSIndexConfig,
32
+ collection_name: str = "PgVectoRSCollection",
24
33
  drop_old: bool = False,
25
34
  **kwargs,
26
35
  ):
36
+
37
+ self.name = "PgVectorRS"
27
38
  self.db_config = db_config
28
39
  self.case_config = db_case_config
29
40
  self.table_name = collection_name
30
41
  self.dim = dim
31
42
 
32
- self._index_name = "pqvector_index"
43
+ self._index_name = "pgvectors_index"
33
44
  self._primary_field = "id"
34
45
  self._vector_field = "embedding"
35
46
 
36
47
  # construct basic units
37
- self.conn = psycopg2.connect(**self.db_config)
38
- self.conn.autocommit = False
39
- self.cursor = self.conn.cursor()
48
+ self.conn, self.cursor = self._create_connection(**self.db_config)
40
49
 
41
- # create vector extension
42
- self.cursor.execute("CREATE EXTENSION IF NOT EXISTS vectors")
43
- self.conn.commit()
50
+ log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}")
51
+ if not any(
52
+ (
53
+ self.case_config.create_index_before_load,
54
+ self.case_config.create_index_after_load,
55
+ )
56
+ ):
57
+ err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
58
+ log.error(err)
59
+ raise RuntimeError(
60
+ f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
61
+ )
44
62
 
45
63
  if drop_old:
46
64
  log.info(f"Pgvecto.rs client drop table : {self.table_name}")
47
65
  self._drop_index()
48
66
  self._drop_table()
49
67
  self._create_table(dim)
50
- self._create_index()
68
+ if self.case_config.create_index_before_load:
69
+ self._create_index()
51
70
 
52
71
  self.cursor.close()
53
72
  self.conn.close()
54
73
  self.cursor = None
55
74
  self.conn = None
56
75
 
76
+ @staticmethod
77
+ def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
78
+ conn = psycopg.connect(**kwargs)
79
+
80
+ # create vector extension
81
+ conn.execute("CREATE EXTENSION IF NOT EXISTS vectors")
82
+ conn.commit()
83
+ register_vector(conn)
84
+
85
+ conn.autocommit = False
86
+ cursor = conn.cursor()
87
+
88
+ assert conn is not None, "Connection is not initialized"
89
+ assert cursor is not None, "Cursor is not initialized"
90
+
91
+ return conn, cursor
92
+
57
93
  @contextmanager
58
- def init(self) -> None:
94
+ def init(self) -> Generator[None, None, None]:
59
95
  """
60
96
  Examples:
61
97
  >>> with self.init():
62
98
  >>> self.insert_embeddings()
63
99
  >>> self.search_embedding()
64
100
  """
65
- self.conn = psycopg2.connect(**self.db_config)
66
- self.conn.autocommit = False
67
- self.cursor = self.conn.cursor()
68
- self.cursor.execute('SET search_path = "$user", public, vectors')
101
+
102
+ self.conn, self.cursor = self._create_connection(**self.db_config)
103
+
104
+ # index configuration may have commands defined that we should set during each client session
105
+ session_options = self.case_config.session_param()
106
+
107
+ for key, val in session_options.items():
108
+ command = sql.SQL("SET {setting_name} " + "= {val};").format(
109
+ setting_name=sql.Identifier(key),
110
+ val=val,
111
+ )
112
+ log.debug(command.as_string(self.cursor))
113
+ self.cursor.execute(command)
69
114
  self.conn.commit()
70
115
 
116
+ self._filtered_search = sql.Composed(
117
+ [
118
+ sql.SQL(
119
+ "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
120
+ ).format(table_name=sql.Identifier(self.table_name)),
121
+ sql.SQL(self.case_config.search_param()["metric_fun_op"]),
122
+ sql.SQL(" %s::vector LIMIT %s::int"),
123
+ ]
124
+ )
125
+
126
+ self._unfiltered_search = sql.Composed(
127
+ [
128
+ sql.SQL(
129
+ "SELECT id FROM public.{table_name} ORDER BY embedding "
130
+ ).format(table_name=sql.Identifier(self.table_name)),
131
+ sql.SQL(self.case_config.search_param()["metric_fun_op"]),
132
+ sql.SQL(" %s::vector LIMIT %s::int"),
133
+ ]
134
+ )
135
+
71
136
  try:
72
137
  yield
73
138
  finally:
@@ -79,42 +144,65 @@ class PgVectoRS(VectorDB):
79
144
  def _drop_table(self):
80
145
  assert self.conn is not None, "Connection is not initialized"
81
146
  assert self.cursor is not None, "Cursor is not initialized"
147
+ log.info(f"{self.name} client drop table : {self.table_name}")
82
148
 
83
- self.cursor.execute(f'DROP TABLE IF EXISTS public."{self.table_name}"')
149
+ self.cursor.execute(
150
+ sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
151
+ table_name=sql.Identifier(self.table_name)
152
+ )
153
+ )
84
154
  self.conn.commit()
85
155
 
86
156
  def ready_to_load(self):
87
157
  pass
88
158
 
89
159
  def optimize(self):
90
- pass
160
+ self._post_insert()
91
161
 
92
- def ready_to_search(self):
93
- pass
162
+ def _post_insert(self):
163
+ log.info(f"{self.name} post insert before optimize")
164
+ if self.case_config.create_index_after_load:
165
+ self._drop_index()
166
+ self._create_index()
94
167
 
95
168
  def _drop_index(self):
96
169
  assert self.conn is not None, "Connection is not initialized"
97
170
  assert self.cursor is not None, "Cursor is not initialized"
171
+ log.info(f"{self.name} client drop index : {self._index_name}")
98
172
 
99
- self.cursor.execute(f'DROP INDEX IF EXISTS "{self._index_name}"')
173
+ drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
174
+ index_name=sql.Identifier(self._index_name)
175
+ )
176
+ log.debug(drop_index_sql.as_string(self.cursor))
177
+ self.cursor.execute(drop_index_sql)
100
178
  self.conn.commit()
101
179
 
102
180
  def _create_index(self):
103
181
  assert self.conn is not None, "Connection is not initialized"
104
182
  assert self.cursor is not None, "Cursor is not initialized"
183
+ log.info(f"{self.name} client create index : {self._index_name}")
105
184
 
106
185
  index_param = self.case_config.index_param()
107
186
 
187
+ index_create_sql = sql.SQL(
188
+ """
189
+ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
190
+ USING vectors (embedding {embedding_metric}) WITH (options = {index_options})
191
+ """
192
+ ).format(
193
+ index_name=sql.Identifier(self._index_name),
194
+ table_name=sql.Identifier(self.table_name),
195
+ embedding_metric=sql.Identifier(index_param["metric"]),
196
+ index_options=index_param["options"],
197
+ )
108
198
  try:
109
- # create table
110
- self.cursor.execute(
111
- f'CREATE INDEX IF NOT EXISTS {self._index_name} ON public."{self.table_name}" \
112
- USING vectors (embedding {index_param["metric"]}) WITH (options = $${index_param["options"]}$$);'
113
- )
199
+ log.debug(index_create_sql.as_string(self.cursor))
200
+ self.cursor.execute(index_create_sql)
114
201
  self.conn.commit()
115
202
  except Exception as e:
116
203
  log.warning(
117
- f"Failed to create pgvecto.rs table: {self.table_name} error: {e}"
204
+ f"Failed to create pgvecto.rs index {self._index_name} \
205
+ at table {self.table_name} error: {e}"
118
206
  )
119
207
  raise e from None
120
208
 
@@ -122,12 +210,18 @@ class PgVectoRS(VectorDB):
122
210
  assert self.conn is not None, "Connection is not initialized"
123
211
  assert self.cursor is not None, "Cursor is not initialized"
124
212
 
213
+ table_create_sql = sql.SQL(
214
+ """
215
+ CREATE TABLE IF NOT EXISTS public.{table_name}
216
+ (id BIGINT PRIMARY KEY, embedding vector({dim}))
217
+ """
218
+ ).format(
219
+ table_name=sql.Identifier(self.table_name),
220
+ dim=dim,
221
+ )
125
222
  try:
126
223
  # create table
127
- self.cursor.execute(
128
- f'CREATE TABLE IF NOT EXISTS public."{self.table_name}" \
129
- (id Integer PRIMARY KEY, embedding vector({dim}));'
130
- )
224
+ self.cursor.execute(table_create_sql)
131
225
  self.conn.commit()
132
226
  except Exception as e:
133
227
  log.warning(
@@ -140,7 +234,7 @@ class PgVectoRS(VectorDB):
140
234
  embeddings: list[list[float]],
141
235
  metadata: list[int],
142
236
  **kwargs: Any,
143
- ) -> (int, Exception):
237
+ ) -> Tuple[int, Optional[Exception]]:
144
238
  assert self.conn is not None, "Connection is not initialized"
145
239
  assert self.cursor is not None, "Cursor is not initialized"
146
240
 
@@ -148,19 +242,27 @@ class PgVectoRS(VectorDB):
148
242
  assert self.cursor is not None, "Cursor is not initialized"
149
243
 
150
244
  try:
151
- items = {
152
- "id": metadata,
153
- "embedding": embeddings
154
- }
155
- df = pd.DataFrame(items)
156
- csv_buffer = io.StringIO()
157
- df.to_csv(csv_buffer, index=False, header=False)
158
- csv_buffer.seek(0)
159
- self.cursor.copy_expert(f"COPY public.\"{self.table_name}\" FROM STDIN WITH (FORMAT CSV)", csv_buffer)
245
+ metadata_arr = np.array(metadata)
246
+ embeddings_arr = np.array(embeddings)
247
+
248
+ with self.cursor.copy(
249
+ sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
250
+ table_name=sql.Identifier(self.table_name)
251
+ )
252
+ ) as copy:
253
+ copy.set_types(["bigint", "vector"])
254
+ for i, row in enumerate(metadata_arr):
255
+ copy.write_row((row, embeddings_arr[i]))
160
256
  self.conn.commit()
257
+
258
+ if kwargs.get("last_batch"):
259
+ self._post_insert()
260
+
161
261
  return len(metadata), None
162
262
  except Exception as e:
163
- log.warning(f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}")
263
+ log.warning(
264
+ f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}"
265
+ )
164
266
  return 0, e
165
267
 
166
268
  def search_embedding(
@@ -173,20 +275,18 @@ class PgVectoRS(VectorDB):
173
275
  assert self.conn is not None, "Connection is not initialized"
174
276
  assert self.cursor is not None, "Cursor is not initialized"
175
277
 
176
- search_param = self.case_config.search_param()
278
+ q = np.asarray(query)
177
279
 
178
280
  if filters:
281
+ log.debug(self._filtered_search.as_string(self.cursor))
179
282
  gt = filters.get("id")
180
- self.cursor.execute(
181
- f"SELECT id FROM (SELECT * FROM public.\"{self.table_name}\" ORDER BY embedding \
182
- {search_param['metrics_op']} '{query}' LIMIT {k}) AS X WHERE id > {gt} ;"
283
+ result = self.cursor.execute(
284
+ self._filtered_search, (gt, q, k), prepare=True, binary=True
183
285
  )
184
286
  else:
185
- self.cursor.execute(
186
- f"SELECT id FROM public.\"{self.table_name}\" ORDER BY embedding \
187
- {search_param['metrics_op']} '{query}' LIMIT {k};"
287
+ log.debug(self._unfiltered_search.as_string(self.cursor))
288
+ result = self.cursor.execute(
289
+ self._unfiltered_search, (q, k), prepare=True, binary=True
188
290
  )
189
- self.conn.commit()
190
- result = self.cursor.fetchall()
191
291
 
192
- return [int(i[0]) for i in result]
292
+ return [int(i[0]) for i in result.fetchall()]
@@ -0,0 +1,111 @@
1
+ from abc import abstractmethod
2
+ from typing import TypedDict
3
+ from pydantic import BaseModel, SecretStr
4
+ from typing_extensions import LiteralString
5
+ from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
6
+
7
+ POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
8
+
9
+
10
+ class PgVectorScaleConfigDict(TypedDict):
11
+ """These keys will be directly used as kwargs in psycopg connection string,
12
+ so the names must match exactly psycopg API"""
13
+
14
+ user: str
15
+ password: str
16
+ host: str
17
+ port: int
18
+ dbname: str
19
+
20
+
21
+ class PgVectorScaleConfig(DBConfig):
22
+ user_name: SecretStr = SecretStr("postgres")
23
+ password: SecretStr
24
+ host: str = "localhost"
25
+ port: int = 5432
26
+ db_name: str
27
+
28
+ def to_dict(self) -> PgVectorScaleConfigDict:
29
+ user_str = self.user_name.get_secret_value()
30
+ pwd_str = self.password.get_secret_value()
31
+ return {
32
+ "host": self.host,
33
+ "port": self.port,
34
+ "dbname": self.db_name,
35
+ "user": user_str,
36
+ "password": pwd_str,
37
+ }
38
+
39
+
40
+ class PgVectorScaleIndexConfig(BaseModel, DBCaseConfig):
41
+ metric_type: MetricType | None = None
42
+ create_index_before_load: bool = False
43
+ create_index_after_load: bool = True
44
+
45
+ def parse_metric(self) -> str:
46
+ if self.metric_type == MetricType.COSINE:
47
+ return "vector_cosine_ops"
48
+ return ""
49
+
50
+ def parse_metric_fun_op(self) -> LiteralString:
51
+ if self.metric_type == MetricType.COSINE:
52
+ return "<=>"
53
+ return ""
54
+
55
+ def parse_metric_fun_str(self) -> str:
56
+ if self.metric_type == MetricType.COSINE:
57
+ return "cosine_distance"
58
+ return ""
59
+
60
+ @abstractmethod
61
+ def index_param(self) -> dict:
62
+ ...
63
+
64
+ @abstractmethod
65
+ def search_param(self) -> dict:
66
+ ...
67
+
68
+ @abstractmethod
69
+ def session_param(self) -> dict:
70
+ ...
71
+
72
+
73
+ class PgVectorScaleStreamingDiskANNConfig(PgVectorScaleIndexConfig):
74
+ index: IndexType = IndexType.STREAMING_DISKANN
75
+ storage_layout: str | None
76
+ num_neighbors: int | None
77
+ search_list_size: int | None
78
+ max_alpha: float | None
79
+ num_dimensions: int | None
80
+ num_bits_per_dimension: int | None
81
+ query_search_list_size: int | None
82
+ query_rescore: int | None
83
+
84
+ def index_param(self) -> dict:
85
+ return {
86
+ "metric": self.parse_metric(),
87
+ "index_type": self.index.value,
88
+ "options": {
89
+ "storage_layout": self.storage_layout,
90
+ "num_neighbors": self.num_neighbors,
91
+ "search_list_size": self.search_list_size,
92
+ "max_alpha": self.max_alpha,
93
+ "num_dimensions": self.num_dimensions,
94
+ },
95
+ }
96
+
97
+ def search_param(self) -> dict:
98
+ return {
99
+ "metric": self.parse_metric(),
100
+ "metric_fun_op": self.parse_metric_fun_op(),
101
+ }
102
+
103
+ def session_param(self) -> dict:
104
+ return {
105
+ "diskann.query_search_list_size": self.query_search_list_size,
106
+ "diskann.query_rescore": self.query_rescore,
107
+ }
108
+
109
+ _pgvectorscale_case_config = {
110
+ IndexType.STREAMING_DISKANN: PgVectorScaleStreamingDiskANNConfig,
111
+ }