vectordb-bench 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. vectordb_bench/backend/clients/__init__.py +22 -0
  2. vectordb_bench/backend/clients/api.py +21 -1
  3. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +47 -6
  4. vectordb_bench/backend/clients/aws_opensearch/config.py +12 -6
  5. vectordb_bench/backend/clients/aws_opensearch/run.py +34 -3
  6. vectordb_bench/backend/clients/memorydb/cli.py +88 -0
  7. vectordb_bench/backend/clients/memorydb/config.py +54 -0
  8. vectordb_bench/backend/clients/memorydb/memorydb.py +254 -0
  9. vectordb_bench/backend/clients/pgvecto_rs/cli.py +154 -0
  10. vectordb_bench/backend/clients/pgvecto_rs/config.py +108 -73
  11. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +159 -59
  12. vectordb_bench/backend/clients/pgvector/cli.py +17 -2
  13. vectordb_bench/backend/clients/pgvector/config.py +20 -5
  14. vectordb_bench/backend/clients/pgvector/pgvector.py +95 -25
  15. vectordb_bench/backend/clients/pgvectorscale/cli.py +108 -0
  16. vectordb_bench/backend/clients/pgvectorscale/config.py +111 -0
  17. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +290 -0
  18. vectordb_bench/backend/clients/pinecone/config.py +0 -2
  19. vectordb_bench/backend/clients/pinecone/pinecone.py +34 -36
  20. vectordb_bench/backend/clients/redis/cli.py +8 -0
  21. vectordb_bench/backend/clients/redis/config.py +37 -6
  22. vectordb_bench/backend/runner/mp_runner.py +2 -1
  23. vectordb_bench/cli/cli.py +137 -0
  24. vectordb_bench/cli/vectordbbench.py +7 -1
  25. vectordb_bench/frontend/components/check_results/charts.py +9 -6
  26. vectordb_bench/frontend/components/check_results/data.py +13 -6
  27. vectordb_bench/frontend/components/concurrent/charts.py +3 -6
  28. vectordb_bench/frontend/components/run_test/caseSelector.py +10 -0
  29. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +37 -15
  30. vectordb_bench/frontend/components/run_test/initStyle.py +3 -1
  31. vectordb_bench/frontend/config/dbCaseConfigs.py +230 -9
  32. vectordb_bench/frontend/pages/quries_per_dollar.py +13 -5
  33. vectordb_bench/frontend/vdb_benchmark.py +11 -3
  34. vectordb_bench/models.py +25 -9
  35. vectordb_bench/results/Milvus/result_20230727_standard_milvus.json +53 -1
  36. vectordb_bench/results/Milvus/result_20230808_standard_milvus.json +48 -0
  37. vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +29 -1
  38. vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +24 -0
  39. vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +98 -49
  40. vectordb_bench/results/getLeaderboardData.py +17 -7
  41. vectordb_bench/results/leaderboard.json +1 -1
  42. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/METADATA +64 -31
  43. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/RECORD +47 -40
  44. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/WHEEL +1 -1
  45. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/LICENSE +0 -0
  46. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/entry_points.txt +0 -0
  47. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/top_level.txt +0 -0
@@ -1,73 +1,138 @@
1
1
  """Wrapper around the Pgvecto.rs vector database over VectorDB"""
2
2
 
3
- import io
4
3
  import logging
4
+ import pprint
5
5
  from contextlib import contextmanager
6
- from typing import Any
7
- import pandas as pd
8
- import psycopg2
9
- import psycopg2.extras
6
+ from typing import Any, Generator, Optional, Tuple
10
7
 
11
- from ..api import VectorDB, DBCaseConfig
8
+ import numpy as np
9
+ import psycopg
10
+ from psycopg import Connection, Cursor, sql
11
+ from pgvecto_rs.psycopg import register_vector
12
+
13
+ from ..api import VectorDB
14
+ from .config import PgVectoRSConfig, PgVectoRSIndexConfig
12
15
 
13
16
  log = logging.getLogger(__name__)
14
17
 
18
+
15
19
  class PgVectoRS(VectorDB):
16
- """Use SQLAlchemy instructions"""
20
+ """Use psycopg instructions"""
21
+
22
+ conn: psycopg.Connection[Any] | None = None
23
+ cursor: psycopg.Cursor[Any] | None = None
24
+ _unfiltered_search: sql.Composed
25
+ _filtered_search: sql.Composed
17
26
 
18
27
  def __init__(
19
28
  self,
20
29
  dim: int,
21
- db_config: dict,
22
- db_case_config: DBCaseConfig,
23
- collection_name: str = "PgVectorCollection",
30
+ db_config: PgVectoRSConfig,
31
+ db_case_config: PgVectoRSIndexConfig,
32
+ collection_name: str = "PgVectoRSCollection",
24
33
  drop_old: bool = False,
25
34
  **kwargs,
26
35
  ):
36
+
37
+ self.name = "PgVectorRS"
27
38
  self.db_config = db_config
28
39
  self.case_config = db_case_config
29
40
  self.table_name = collection_name
30
41
  self.dim = dim
31
42
 
32
- self._index_name = "pqvector_index"
43
+ self._index_name = "pgvectors_index"
33
44
  self._primary_field = "id"
34
45
  self._vector_field = "embedding"
35
46
 
36
47
  # construct basic units
37
- self.conn = psycopg2.connect(**self.db_config)
38
- self.conn.autocommit = False
39
- self.cursor = self.conn.cursor()
48
+ self.conn, self.cursor = self._create_connection(**self.db_config)
40
49
 
41
- # create vector extension
42
- self.cursor.execute("CREATE EXTENSION IF NOT EXISTS vectors")
43
- self.conn.commit()
50
+ log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}")
51
+ if not any(
52
+ (
53
+ self.case_config.create_index_before_load,
54
+ self.case_config.create_index_after_load,
55
+ )
56
+ ):
57
+ err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
58
+ log.error(err)
59
+ raise RuntimeError(
60
+ f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
61
+ )
44
62
 
45
63
  if drop_old:
46
64
  log.info(f"Pgvecto.rs client drop table : {self.table_name}")
47
65
  self._drop_index()
48
66
  self._drop_table()
49
67
  self._create_table(dim)
50
- self._create_index()
68
+ if self.case_config.create_index_before_load:
69
+ self._create_index()
51
70
 
52
71
  self.cursor.close()
53
72
  self.conn.close()
54
73
  self.cursor = None
55
74
  self.conn = None
56
75
 
76
+ @staticmethod
77
+ def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
78
+ conn = psycopg.connect(**kwargs)
79
+
80
+ # create vector extension
81
+ conn.execute("CREATE EXTENSION IF NOT EXISTS vectors")
82
+ conn.commit()
83
+ register_vector(conn)
84
+
85
+ conn.autocommit = False
86
+ cursor = conn.cursor()
87
+
88
+ assert conn is not None, "Connection is not initialized"
89
+ assert cursor is not None, "Cursor is not initialized"
90
+
91
+ return conn, cursor
92
+
57
93
  @contextmanager
58
- def init(self) -> None:
94
+ def init(self) -> Generator[None, None, None]:
59
95
  """
60
96
  Examples:
61
97
  >>> with self.init():
62
98
  >>> self.insert_embeddings()
63
99
  >>> self.search_embedding()
64
100
  """
65
- self.conn = psycopg2.connect(**self.db_config)
66
- self.conn.autocommit = False
67
- self.cursor = self.conn.cursor()
68
- self.cursor.execute('SET search_path = "$user", public, vectors')
101
+
102
+ self.conn, self.cursor = self._create_connection(**self.db_config)
103
+
104
+ # index configuration may have commands defined that we should set during each client session
105
+ session_options = self.case_config.session_param()
106
+
107
+ for key, val in session_options.items():
108
+ command = sql.SQL("SET {setting_name} " + "= {val};").format(
109
+ setting_name=sql.Identifier(key),
110
+ val=val,
111
+ )
112
+ log.debug(command.as_string(self.cursor))
113
+ self.cursor.execute(command)
69
114
  self.conn.commit()
70
115
 
116
+ self._filtered_search = sql.Composed(
117
+ [
118
+ sql.SQL(
119
+ "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
120
+ ).format(table_name=sql.Identifier(self.table_name)),
121
+ sql.SQL(self.case_config.search_param()["metric_fun_op"]),
122
+ sql.SQL(" %s::vector LIMIT %s::int"),
123
+ ]
124
+ )
125
+
126
+ self._unfiltered_search = sql.Composed(
127
+ [
128
+ sql.SQL(
129
+ "SELECT id FROM public.{table_name} ORDER BY embedding "
130
+ ).format(table_name=sql.Identifier(self.table_name)),
131
+ sql.SQL(self.case_config.search_param()["metric_fun_op"]),
132
+ sql.SQL(" %s::vector LIMIT %s::int"),
133
+ ]
134
+ )
135
+
71
136
  try:
72
137
  yield
73
138
  finally:
@@ -79,42 +144,65 @@ class PgVectoRS(VectorDB):
79
144
  def _drop_table(self):
80
145
  assert self.conn is not None, "Connection is not initialized"
81
146
  assert self.cursor is not None, "Cursor is not initialized"
147
+ log.info(f"{self.name} client drop table : {self.table_name}")
82
148
 
83
- self.cursor.execute(f'DROP TABLE IF EXISTS public."{self.table_name}"')
149
+ self.cursor.execute(
150
+ sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
151
+ table_name=sql.Identifier(self.table_name)
152
+ )
153
+ )
84
154
  self.conn.commit()
85
155
 
86
156
  def ready_to_load(self):
87
157
  pass
88
158
 
89
159
  def optimize(self):
90
- pass
160
+ self._post_insert()
91
161
 
92
- def ready_to_search(self):
93
- pass
162
+ def _post_insert(self):
163
+ log.info(f"{self.name} post insert before optimize")
164
+ if self.case_config.create_index_after_load:
165
+ self._drop_index()
166
+ self._create_index()
94
167
 
95
168
  def _drop_index(self):
96
169
  assert self.conn is not None, "Connection is not initialized"
97
170
  assert self.cursor is not None, "Cursor is not initialized"
171
+ log.info(f"{self.name} client drop index : {self._index_name}")
98
172
 
99
- self.cursor.execute(f'DROP INDEX IF EXISTS "{self._index_name}"')
173
+ drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
174
+ index_name=sql.Identifier(self._index_name)
175
+ )
176
+ log.debug(drop_index_sql.as_string(self.cursor))
177
+ self.cursor.execute(drop_index_sql)
100
178
  self.conn.commit()
101
179
 
102
180
  def _create_index(self):
103
181
  assert self.conn is not None, "Connection is not initialized"
104
182
  assert self.cursor is not None, "Cursor is not initialized"
183
+ log.info(f"{self.name} client create index : {self._index_name}")
105
184
 
106
185
  index_param = self.case_config.index_param()
107
186
 
187
+ index_create_sql = sql.SQL(
188
+ """
189
+ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
190
+ USING vectors (embedding {embedding_metric}) WITH (options = {index_options})
191
+ """
192
+ ).format(
193
+ index_name=sql.Identifier(self._index_name),
194
+ table_name=sql.Identifier(self.table_name),
195
+ embedding_metric=sql.Identifier(index_param["metric"]),
196
+ index_options=index_param["options"],
197
+ )
108
198
  try:
109
- # create table
110
- self.cursor.execute(
111
- f'CREATE INDEX IF NOT EXISTS {self._index_name} ON public."{self.table_name}" \
112
- USING vectors (embedding {index_param["metric"]}) WITH (options = $${index_param["options"]}$$);'
113
- )
199
+ log.debug(index_create_sql.as_string(self.cursor))
200
+ self.cursor.execute(index_create_sql)
114
201
  self.conn.commit()
115
202
  except Exception as e:
116
203
  log.warning(
117
- f"Failed to create pgvecto.rs table: {self.table_name} error: {e}"
204
+ f"Failed to create pgvecto.rs index {self._index_name} \
205
+ at table {self.table_name} error: {e}"
118
206
  )
119
207
  raise e from None
120
208
 
@@ -122,12 +210,18 @@ class PgVectoRS(VectorDB):
122
210
  assert self.conn is not None, "Connection is not initialized"
123
211
  assert self.cursor is not None, "Cursor is not initialized"
124
212
 
213
+ table_create_sql = sql.SQL(
214
+ """
215
+ CREATE TABLE IF NOT EXISTS public.{table_name}
216
+ (id BIGINT PRIMARY KEY, embedding vector({dim}))
217
+ """
218
+ ).format(
219
+ table_name=sql.Identifier(self.table_name),
220
+ dim=dim,
221
+ )
125
222
  try:
126
223
  # create table
127
- self.cursor.execute(
128
- f'CREATE TABLE IF NOT EXISTS public."{self.table_name}" \
129
- (id Integer PRIMARY KEY, embedding vector({dim}));'
130
- )
224
+ self.cursor.execute(table_create_sql)
131
225
  self.conn.commit()
132
226
  except Exception as e:
133
227
  log.warning(
@@ -140,7 +234,7 @@ class PgVectoRS(VectorDB):
140
234
  embeddings: list[list[float]],
141
235
  metadata: list[int],
142
236
  **kwargs: Any,
143
- ) -> (int, Exception):
237
+ ) -> Tuple[int, Optional[Exception]]:
144
238
  assert self.conn is not None, "Connection is not initialized"
145
239
  assert self.cursor is not None, "Cursor is not initialized"
146
240
 
@@ -148,19 +242,27 @@ class PgVectoRS(VectorDB):
148
242
  assert self.cursor is not None, "Cursor is not initialized"
149
243
 
150
244
  try:
151
- items = {
152
- "id": metadata,
153
- "embedding": embeddings
154
- }
155
- df = pd.DataFrame(items)
156
- csv_buffer = io.StringIO()
157
- df.to_csv(csv_buffer, index=False, header=False)
158
- csv_buffer.seek(0)
159
- self.cursor.copy_expert(f"COPY public.\"{self.table_name}\" FROM STDIN WITH (FORMAT CSV)", csv_buffer)
245
+ metadata_arr = np.array(metadata)
246
+ embeddings_arr = np.array(embeddings)
247
+
248
+ with self.cursor.copy(
249
+ sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
250
+ table_name=sql.Identifier(self.table_name)
251
+ )
252
+ ) as copy:
253
+ copy.set_types(["bigint", "vector"])
254
+ for i, row in enumerate(metadata_arr):
255
+ copy.write_row((row, embeddings_arr[i]))
160
256
  self.conn.commit()
257
+
258
+ if kwargs.get("last_batch"):
259
+ self._post_insert()
260
+
161
261
  return len(metadata), None
162
262
  except Exception as e:
163
- log.warning(f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}")
263
+ log.warning(
264
+ f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}"
265
+ )
164
266
  return 0, e
165
267
 
166
268
  def search_embedding(
@@ -173,20 +275,18 @@ class PgVectoRS(VectorDB):
173
275
  assert self.conn is not None, "Connection is not initialized"
174
276
  assert self.cursor is not None, "Cursor is not initialized"
175
277
 
176
- search_param = self.case_config.search_param()
278
+ q = np.asarray(query)
177
279
 
178
280
  if filters:
281
+ log.debug(self._filtered_search.as_string(self.cursor))
179
282
  gt = filters.get("id")
180
- self.cursor.execute(
181
- f"SELECT id FROM (SELECT * FROM public.\"{self.table_name}\" ORDER BY embedding \
182
- {search_param['metrics_op']} '{query}' LIMIT {k}) AS X WHERE id > {gt} ;"
283
+ result = self.cursor.execute(
284
+ self._filtered_search, (gt, q, k), prepare=True, binary=True
183
285
  )
184
286
  else:
185
- self.cursor.execute(
186
- f"SELECT id FROM public.\"{self.table_name}\" ORDER BY embedding \
187
- {search_param['metrics_op']} '{query}' LIMIT {k};"
287
+ log.debug(self._unfiltered_search.as_string(self.cursor))
288
+ result = self.cursor.execute(
289
+ self._unfiltered_search, (q, k), prepare=True, binary=True
188
290
  )
189
- self.conn.commit()
190
- result = self.cursor.fetchall()
191
291
 
192
- return [int(i[0]) for i in result]
292
+ return [int(i[0]) for i in result.fetchall()]
@@ -10,6 +10,7 @@ from ....cli.cli import (
10
10
  IVFFlatTypedDict,
11
11
  cli,
12
12
  click_parameter_decorators_from_typed_dict,
13
+ get_custom_case_config,
13
14
  run,
14
15
  )
15
16
  from vectordb_bench.backend.clients import DB
@@ -56,7 +57,15 @@ class PgVectorTypedDict(CommonTypedDict):
56
57
  required=False,
57
58
  ),
58
59
  ]
59
-
60
+ quantization_type: Annotated[
61
+ Optional[str],
62
+ click.option(
63
+ "--quantization-type",
64
+ type=click.Choice(["none", "halfvec"]),
65
+ help="quantization type for vectors",
66
+ required=False,
67
+ ),
68
+ ]
60
69
 
61
70
  class PgVectorIVFFlatTypedDict(PgVectorTypedDict, IVFFlatTypedDict):
62
71
  ...
@@ -69,6 +78,7 @@ def PgVectorIVFFlat(
69
78
  ):
70
79
  from .config import PgVectorConfig, PgVectorIVFFlatConfig
71
80
 
81
+ parameters["custom_case"] = get_custom_case_config(parameters)
72
82
  run(
73
83
  db=DB.PgVector,
74
84
  db_config=PgVectorConfig(
@@ -79,7 +89,10 @@ def PgVectorIVFFlat(
79
89
  db_name=parameters["db_name"],
80
90
  ),
81
91
  db_case_config=PgVectorIVFFlatConfig(
82
- metric_type=None, lists=parameters["lists"], probes=parameters["probes"]
92
+ metric_type=None,
93
+ lists=parameters["lists"],
94
+ probes=parameters["probes"],
95
+ quantization_type=parameters["quantization_type"],
83
96
  ),
84
97
  **parameters,
85
98
  )
@@ -96,6 +109,7 @@ def PgVectorHNSW(
96
109
  ):
97
110
  from .config import PgVectorConfig, PgVectorHNSWConfig
98
111
 
112
+ parameters["custom_case"] = get_custom_case_config(parameters)
99
113
  run(
100
114
  db=DB.PgVector,
101
115
  db_config=PgVectorConfig(
@@ -111,6 +125,7 @@ def PgVectorHNSW(
111
125
  ef_search=parameters["ef_search"],
112
126
  maintenance_work_mem=parameters["maintenance_work_mem"],
113
127
  max_parallel_workers=parameters["max_parallel_workers"],
128
+ quantization_type=parameters["quantization_type"],
114
129
  ),
115
130
  **parameters,
116
131
  )
@@ -59,11 +59,18 @@ class PgVectorIndexConfig(BaseModel, DBCaseConfig):
59
59
  create_index_after_load: bool = True
60
60
 
61
61
  def parse_metric(self) -> str:
62
- if self.metric_type == MetricType.L2:
63
- return "vector_l2_ops"
64
- elif self.metric_type == MetricType.IP:
65
- return "vector_ip_ops"
66
- return "vector_cosine_ops"
62
+ if self.quantization_type == "halfvec":
63
+ if self.metric_type == MetricType.L2:
64
+ return "halfvec_l2_ops"
65
+ elif self.metric_type == MetricType.IP:
66
+ return "halfvec_ip_ops"
67
+ return "halfvec_cosine_ops"
68
+ else:
69
+ if self.metric_type == MetricType.L2:
70
+ return "vector_l2_ops"
71
+ elif self.metric_type == MetricType.IP:
72
+ return "vector_ip_ops"
73
+ return "vector_cosine_ops"
67
74
 
68
75
  def parse_metric_fun_op(self) -> LiteralString:
69
76
  if self.metric_type == MetricType.L2:
@@ -143,9 +150,12 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
143
150
  index: IndexType = IndexType.ES_IVFFlat
144
151
  maintenance_work_mem: Optional[str] = None
145
152
  max_parallel_workers: Optional[int] = None
153
+ quantization_type: Optional[str] = None
146
154
 
147
155
  def index_param(self) -> PgVectorIndexParam:
148
156
  index_parameters = {"lists": self.lists}
157
+ if self.quantization_type == "none":
158
+ self.quantization_type = None
149
159
  return {
150
160
  "metric": self.parse_metric(),
151
161
  "index_type": self.index.value,
@@ -154,6 +164,7 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
154
164
  ),
155
165
  "maintenance_work_mem": self.maintenance_work_mem,
156
166
  "max_parallel_workers": self.max_parallel_workers,
167
+ "quantization_type": self.quantization_type,
157
168
  }
158
169
 
159
170
  def search_param(self) -> PgVectorSearchParam:
@@ -183,9 +194,12 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
183
194
  index: IndexType = IndexType.ES_HNSW
184
195
  maintenance_work_mem: Optional[str] = None
185
196
  max_parallel_workers: Optional[int] = None
197
+ quantization_type: Optional[str] = None
186
198
 
187
199
  def index_param(self) -> PgVectorIndexParam:
188
200
  index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
201
+ if self.quantization_type == "none":
202
+ self.quantization_type = None
189
203
  return {
190
204
  "metric": self.parse_metric(),
191
205
  "index_type": self.index.value,
@@ -194,6 +208,7 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
194
208
  ),
195
209
  "maintenance_work_mem": self.maintenance_work_mem,
196
210
  "max_parallel_workers": self.max_parallel_workers,
211
+ "quantization_type": self.quantization_type,
197
212
  }
198
213
 
199
214
  def search_param(self) -> PgVectorSearchParam:
@@ -22,7 +22,7 @@ class PgVector(VectorDB):
22
22
  conn: psycopg.Connection[Any] | None = None
23
23
  cursor: psycopg.Cursor[Any] | None = None
24
24
 
25
- # TODO add filters support
25
+ _filtered_search: sql.Composed
26
26
  _unfiltered_search: sql.Composed
27
27
 
28
28
  def __init__(
@@ -112,15 +112,63 @@ class PgVector(VectorDB):
112
112
  self.cursor.execute(command)
113
113
  self.conn.commit()
114
114
 
115
- self._unfiltered_search = sql.Composed(
116
- [
117
- sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
118
- sql.Identifier(self.table_name)
119
- ),
120
- sql.SQL(self.case_config.search_param()["metric_fun_op"]),
121
- sql.SQL(" %s::vector LIMIT %s::int"),
122
- ]
123
- )
115
+ index_param = self.case_config.index_param()
116
+ # The following sections assume that the quantization_type value matches the quantization function name
117
+ if index_param["quantization_type"] != None:
118
+ self._filtered_search = sql.Composed(
119
+ [
120
+ sql.SQL(
121
+ "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding::{quantization_type}({dim}) "
122
+ ).format(
123
+ table_name=sql.Identifier(self.table_name),
124
+ quantization_type=sql.SQL(index_param["quantization_type"]),
125
+ dim=sql.Literal(self.dim),
126
+ ),
127
+ sql.SQL(self.case_config.search_param()["metric_fun_op"]),
128
+ sql.SQL(" %s::{quantization_type}({dim}) LIMIT %s::int").format(
129
+ quantization_type=sql.SQL(index_param["quantization_type"]),
130
+ dim=sql.Literal(self.dim),
131
+ ),
132
+ ]
133
+ )
134
+ else:
135
+ self._filtered_search = sql.Composed(
136
+ [
137
+ sql.SQL(
138
+ "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
139
+ ).format(table_name=sql.Identifier(self.table_name)),
140
+ sql.SQL(self.case_config.search_param()["metric_fun_op"]),
141
+ sql.SQL(" %s::vector LIMIT %s::int"),
142
+ ]
143
+ )
144
+
145
+ if index_param["quantization_type"] != None:
146
+ self._unfiltered_search = sql.Composed(
147
+ [
148
+ sql.SQL(
149
+ "SELECT id FROM public.{table_name} ORDER BY embedding::{quantization_type}({dim}) "
150
+ ).format(
151
+ table_name=sql.Identifier(self.table_name),
152
+ quantization_type=sql.SQL(index_param["quantization_type"]),
153
+ dim=sql.Literal(self.dim),
154
+ ),
155
+ sql.SQL(self.case_config.search_param()["metric_fun_op"]),
156
+ sql.SQL(" %s::{quantization_type}({dim}) LIMIT %s::int").format(
157
+ quantization_type=sql.SQL(index_param["quantization_type"]),
158
+ dim=sql.Literal(self.dim),
159
+ ),
160
+ ]
161
+ )
162
+ else:
163
+ self._unfiltered_search = sql.Composed(
164
+ [
165
+ sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
166
+ sql.Identifier(self.table_name)
167
+ ),
168
+ sql.SQL(self.case_config.search_param()["metric_fun_op"]),
169
+ sql.SQL(" %s::vector LIMIT %s::int"),
170
+ ]
171
+ )
124
172
 
125
173
  try:
126
174
  yield
@@ -255,17 +303,34 @@ class PgVector(VectorDB):
255
303
  else:
256
304
  with_clause = sql.Composed(())
257
305
 
258
- index_create_sql = sql.SQL(
259
- """
260
- CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
261
- USING {index_type} (embedding {embedding_metric})
262
- """
263
- ).format(
264
- index_name=sql.Identifier(self._index_name),
265
- table_name=sql.Identifier(self.table_name),
266
- index_type=sql.Identifier(index_param["index_type"]),
267
- embedding_metric=sql.Identifier(index_param["metric"]),
268
- )
306
+ if index_param["quantization_type"] != None:
307
+ index_create_sql = sql.SQL(
308
+ """
309
+ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
310
+ USING {index_type} ((embedding::{quantization_type}({dim})) {embedding_metric})
311
+ """
312
+ ).format(
313
+ index_name=sql.Identifier(self._index_name),
314
+ table_name=sql.Identifier(self.table_name),
315
+ index_type=sql.Identifier(index_param["index_type"]),
316
+ # This assumes that the quantization_type value matches the quantization function name
317
+ quantization_type=sql.SQL(index_param["quantization_type"]),
318
+ dim=self.dim,
319
+ embedding_metric=sql.Identifier(index_param["metric"]),
320
+ )
321
+ else:
322
+ index_create_sql = sql.SQL(
323
+ """
324
+ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
325
+ USING {index_type} (embedding {embedding_metric})
326
+ """
327
+ ).format(
328
+ index_name=sql.Identifier(self._index_name),
329
+ table_name=sql.Identifier(self.table_name),
330
+ index_type=sql.Identifier(index_param["index_type"]),
331
+ embedding_metric=sql.Identifier(index_param["metric"]),
332
+ )
333
+
269
334
  index_create_sql_with_with_clause = (
270
335
  index_create_sql + with_clause
271
336
  ).join(" ")
@@ -342,9 +407,14 @@ class PgVector(VectorDB):
342
407
  assert self.cursor is not None, "Cursor is not initialized"
343
408
 
344
409
  q = np.asarray(query)
345
- # TODO add filters support
346
- result = self.cursor.execute(
347
- self._unfiltered_search, (q, k), prepare=True, binary=True
348
- )
410
+ if filters:
411
+ gt = filters.get("id")
412
+ result = self.cursor.execute(
413
+ self._filtered_search, (gt, q, k), prepare=True, binary=True
414
+ )
415
+ else:
416
+ result = self.cursor.execute(
417
+ self._unfiltered_search, (q, k), prepare=True, binary=True
418
+ )
349
419
 
350
420
  return [int(i[0]) for i in result.fetchall()]