vectordb-bench 0.0.13__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (34)
  1. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +47 -6
  2. vectordb_bench/backend/clients/aws_opensearch/config.py +12 -6
  3. vectordb_bench/backend/clients/aws_opensearch/run.py +34 -3
  4. vectordb_bench/backend/clients/pgvector/cli.py +17 -2
  5. vectordb_bench/backend/clients/pgvector/config.py +20 -5
  6. vectordb_bench/backend/clients/pgvector/pgvector.py +95 -25
  7. vectordb_bench/backend/clients/pgvectorscale/cli.py +108 -0
  8. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +22 -4
  9. vectordb_bench/backend/clients/pinecone/config.py +0 -2
  10. vectordb_bench/backend/clients/pinecone/pinecone.py +34 -36
  11. vectordb_bench/backend/clients/redis/cli.py +8 -0
  12. vectordb_bench/backend/clients/redis/config.py +37 -6
  13. vectordb_bench/backend/runner/mp_runner.py +2 -1
  14. vectordb_bench/cli/cli.py +137 -0
  15. vectordb_bench/cli/vectordbbench.py +2 -1
  16. vectordb_bench/frontend/components/check_results/charts.py +9 -6
  17. vectordb_bench/frontend/components/concurrent/charts.py +3 -6
  18. vectordb_bench/frontend/config/dbCaseConfigs.py +57 -0
  19. vectordb_bench/frontend/pages/quries_per_dollar.py +13 -5
  20. vectordb_bench/frontend/vdb_benchmark.py +11 -3
  21. vectordb_bench/models.py +7 -3
  22. vectordb_bench/results/Milvus/result_20230727_standard_milvus.json +53 -1
  23. vectordb_bench/results/Milvus/result_20230808_standard_milvus.json +48 -0
  24. vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +29 -1
  25. vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +24 -0
  26. vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +98 -49
  27. vectordb_bench/results/getLeaderboardData.py +17 -7
  28. vectordb_bench/results/leaderboard.json +1 -1
  29. {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.14.dist-info}/METADATA +60 -35
  30. {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.14.dist-info}/RECORD +34 -33
  31. {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.14.dist-info}/WHEEL +1 -1
  32. {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.14.dist-info}/LICENSE +0 -0
  33. {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.14.dist-info}/entry_points.txt +0 -0
  34. {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.14.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py

@@ -3,7 +3,7 @@ from contextlib import contextmanager
 import time
 from typing import Iterable, Type
 from ..api import VectorDB, DBCaseConfig, DBConfig, IndexType
-from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig
+from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig, AWSOS_Engine
 from opensearchpy import OpenSearch
 from opensearchpy.helpers import bulk

@@ -83,7 +83,7 @@ class AWSOpenSearch(VectorDB):

     @contextmanager
     def init(self) -> None:
-        """connect to elasticsearch"""
+        """connect to opensearch"""
         self.client = OpenSearch(**self.db_config)

         yield
@@ -97,7 +97,7 @@ class AWSOpenSearch(VectorDB):
         metadata: list[int],
         **kwargs,
     ) -> tuple[int, Exception]:
-        """Insert the embeddings to the elasticsearch."""
+        """Insert the embeddings to the opensearch."""
         assert self.client is not None, "should self.init() first"

         insert_data = []
@@ -136,13 +136,15 @@ class AWSOpenSearch(VectorDB):
         body = {
             "size": k,
             "query": {"knn": {self.vector_col_name: {"vector": query, "k": k}}},
+            **({"filter": {"range": {self.id_col_name: {"gt": filters["id"]}}}} if filters else {})
         }
         try:
-            resp = self.client.search(index=self.index_name, body=body)
+            resp = self.client.search(index=self.index_name, body=body, size=k, _source=False, docvalue_fields=[self.id_col_name], stored_fields="_none_", filter_path=[f"hits.hits.fields.{self.id_col_name}"],)
             log.info(f'Search took: {resp["took"]}')
             log.info(f'Search shards: {resp["_shards"]}')
             log.info(f'Search hits total: {resp["hits"]["total"]}')
-            result = [int(d["_id"]) for d in resp["hits"]["hits"]]
+            result = [h["fields"][self.id_col_name][0] for h in resp["hits"]["hits"]]
+            # result = [int(d["_id"]) for d in resp["hits"]["hits"]]
             # log.info(f'success! length={len(res)}')

             return result
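The reworked search call returns only id docvalues instead of full _source documents, which keeps responses small under concurrent query load. A minimal standalone sketch of the same request shape, assuming opensearch-py, a local cluster, and illustrative index/field names and dimension:

    from opensearchpy import OpenSearch

    client = OpenSearch(hosts=[{"host": "localhost", "port": 9200}])  # illustrative connection

    k = 10
    body = {
        "size": k,
        "query": {"knn": {"embedding": {"vector": [0.1] * 768, "k": k}}},
    }
    resp = client.search(
        index="test_index",
        body=body,
        size=k,
        _source=False,                         # skip stored documents
        docvalue_fields=["id"],                # return only the id column values
        stored_fields="_none_",
        filter_path=["hits.hits.fields.id"],   # prune everything else from the response
    )
    ids = [h["fields"]["id"][0] for h in resp["hits"]["hits"]]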
@@ -152,7 +154,46 @@ class AWSOpenSearch(VectorDB):

     def optimize(self):
         """optimize will be called between insertion and search in performance cases."""
-        pass
+        # Call refresh first to ensure that all segments are created
+        self._refresh_index()
+        self._do_force_merge()
+        # Call refresh again to ensure that the index is ready after force merge.
+        self._refresh_index()
+        # ensure that all graphs are loaded in memory and ready for search
+        self._load_graphs_to_memory()
+
+    def _refresh_index(self):
+        log.debug(f"Starting refresh for index {self.index_name}")
+        SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC = 30
+        while True:
+            try:
+                log.info(f"Starting the Refresh Index..")
+                self.client.indices.refresh(index=self.index_name)
+                break
+            except Exception as e:
+                log.info(
+                    f"Refresh errored out. Sleeping for {SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC} sec and then Retrying : {e}")
+                time.sleep(SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC)
+                continue
+        log.debug(f"Completed refresh for index {self.index_name}")
+
+    def _do_force_merge(self):
+        log.debug(f"Starting force merge for index {self.index_name}")
+        force_merge_endpoint = f'/{self.index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false'
+        force_merge_task_id = self.client.transport.perform_request('POST', force_merge_endpoint)['task']
+        SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
+        while True:
+            time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
+            task_status = self.client.tasks.get(task_id=force_merge_task_id)
+            if task_status['completed']:
+                break
+        log.debug(f"Completed force merge for index {self.index_name}")
+
+    def _load_graphs_to_memory(self):
+        if self.case_config.engine != AWSOS_Engine.lucene:
+            log.info("Calling warmup API to load graphs into memory")
+            warmup_endpoint = f'/_plugins/_knn/warmup/{self.index_name}'
+            self.client.transport.perform_request('GET', warmup_endpoint)

     def ready_to_load(self):
         """ready_to_load will be called before load in load cases."""
vectordb_bench/backend/clients/aws_opensearch/config.py

@@ -1,9 +1,10 @@
+import logging
 from enum import Enum
 from pydantic import SecretStr, BaseModel

 from ..api import DBConfig, DBCaseConfig, MetricType, IndexType

-
+log = logging.getLogger(__name__)
 class AWSOpenSearchConfig(DBConfig, BaseModel):
     host: str = ""
     port: int = 443
@@ -31,14 +32,18 @@ class AWSOS_Engine(Enum):

 class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
     metric_type: MetricType = MetricType.L2
-    engine: AWSOS_Engine = AWSOS_Engine.nmslib
-    efConstruction: int = 360
-    M: int = 30
+    engine: AWSOS_Engine = AWSOS_Engine.faiss
+    efConstruction: int = 256
+    efSearch: int = 256
+    M: int = 16

     def parse_metric(self) -> str:
         if self.metric_type == MetricType.IP:
-            return "innerproduct"  # only support faiss / nmslib, not for Lucene.
+            return "innerproduct"
         elif self.metric_type == MetricType.COSINE:
+            if self.engine == AWSOS_Engine.faiss:
+                log.info(f"Using metric type as innerproduct because faiss doesn't support cosine as metric type for Opensearch")
+                return "innerproduct"
             return "cosinesimil"
         return "l2"

@@ -49,7 +54,8 @@ class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
             "engine": self.engine.value,
             "parameters": {
                 "ef_construction": self.efConstruction,
-                "m": self.M
+                "m": self.M,
+                "ef_search": self.efSearch
             }
         }
         return params
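With the new defaults, index_param() emits faiss HNSW parameters including ef_search. A sketch of what the method block should now contain (the surrounding name/space_type keys sit outside this hunk and are assumed unchanged):

    from vectordb_bench.backend.clients.aws_opensearch.config import AWSOpenSearchIndexConfig

    cfg = AWSOpenSearchIndexConfig()  # defaults: faiss, efConstruction=256, efSearch=256, M=16
    print(cfg.index_param())
    # Expected to include, roughly:
    # {..., "engine": "faiss",
    #  "parameters": {"ef_construction": 256, "m": 16, "ef_search": 256}}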
vectordb_bench/backend/clients/aws_opensearch/run.py

@@ -40,12 +40,12 @@ def create_index(client, index_name):
             "type": "knn_vector",
             "dimension": _DIM,
             "method": {
-                "engine": "nmslib",
+                "engine": "faiss",
                 "name": "hnsw",
                 "space_type": "l2",
                 "parameters": {
-                    "ef_construction": 128,
-                    "m": 24,
+                    "ef_construction": 256,
+                    "m": 16,
                 }
             }
         }
@@ -108,12 +108,43 @@ def search(client, index_name):
         print('\nSearch not ready, sleep 1s')
         time.sleep(1)

+def optimize_index(client, index_name):
+    print(f"Starting force merge for index {index_name}")
+    force_merge_endpoint = f'/{index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false'
+    force_merge_task_id = client.transport.perform_request('POST', force_merge_endpoint)['task']
+    SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
+    while True:
+        time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
+        task_status = client.tasks.get(task_id=force_merge_task_id)
+        if task_status['completed']:
+            break
+    print(f"Completed force merge for index {index_name}")
+
+
+def refresh_index(client, index_name):
+    print(f"Starting refresh for index {index_name}")
+    SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC = 30
+    while True:
+        try:
+            print(f"Starting the Refresh Index..")
+            client.indices.refresh(index=index_name)
+            break
+        except Exception as e:
+            print(
+                f"Refresh errored out. Sleeping for {SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC} sec and then Retrying : {e}")
+            time.sleep(SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC)
+            continue
+    print(f"Completed refresh for index {index_name}")
+
+

 def main():
     client = create_client()
     try:
         create_index(client, _INDEX_NAME)
         bulk_insert(client, _INDEX_NAME)
+        optimize_index(client, _INDEX_NAME)
+        refresh_index(client, _INDEX_NAME)
         search(client, _INDEX_NAME)
         delete_index(client, _INDEX_NAME)
     except Exception as e:
vectordb_bench/backend/clients/pgvector/cli.py

@@ -10,6 +10,7 @@ from ....cli.cli import (
     IVFFlatTypedDict,
     cli,
     click_parameter_decorators_from_typed_dict,
+    get_custom_case_config,
     run,
 )
 from vectordb_bench.backend.clients import DB
@@ -56,7 +57,15 @@ class PgVectorTypedDict(CommonTypedDict):
             required=False,
         ),
     ]
-
+    quantization_type: Annotated[
+        Optional[str],
+        click.option(
+            "--quantization-type",
+            type=click.Choice(["none", "halfvec"]),
+            help="quantization type for vectors",
+            required=False,
+        ),
+    ]

 class PgVectorIVFFlatTypedDict(PgVectorTypedDict, IVFFlatTypedDict):
     ...
@@ -69,6 +78,7 @@ def PgVectorIVFFlat(
 ):
     from .config import PgVectorConfig, PgVectorIVFFlatConfig

+    parameters["custom_case"] = get_custom_case_config(parameters)
     run(
         db=DB.PgVector,
         db_config=PgVectorConfig(
@@ -79,7 +89,10 @@ def PgVectorIVFFlat(
             db_name=parameters["db_name"],
         ),
         db_case_config=PgVectorIVFFlatConfig(
-            metric_type=None, lists=parameters["lists"], probes=parameters["probes"]
+            metric_type=None,
+            lists=parameters["lists"],
+            probes=parameters["probes"],
+            quantization_type=parameters["quantization_type"],
         ),
         **parameters,
     )
@@ -96,6 +109,7 @@ def PgVectorHNSW(
 ):
     from .config import PgVectorConfig, PgVectorHNSWConfig

+    parameters["custom_case"] = get_custom_case_config(parameters)
     run(
         db=DB.PgVector,
         db_config=PgVectorConfig(
@@ -111,6 +125,7 @@ def PgVectorHNSW(
             ef_search=parameters["ef_search"],
             maintenance_work_mem=parameters["maintenance_work_mem"],
             max_parallel_workers=parameters["max_parallel_workers"],
+            quantization_type=parameters["quantization_type"],
         ),
         **parameters,
     )
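With --quantization-type wired through, a pgvector HNSW run can opt into halfvec from the command line. A hypothetical invocation (click derives the subcommand name from the function name; host, credentials, and parameter values are placeholders, not recommendations):

    vectordbbench pgvectorhnsw --user-name postgres --password secret \
        --host localhost --db-name vectordb \
        --m 16 --ef-construction 128 --ef-search 100 \
        --quantization-type halfvec --case-type Performance768D1M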
vectordb_bench/backend/clients/pgvector/config.py

@@ -59,11 +59,18 @@ class PgVectorIndexConfig(BaseModel, DBCaseConfig):
     create_index_after_load: bool = True

     def parse_metric(self) -> str:
-        if self.metric_type == MetricType.L2:
-            return "vector_l2_ops"
-        elif self.metric_type == MetricType.IP:
-            return "vector_ip_ops"
-        return "vector_cosine_ops"
+        if self.quantization_type == "halfvec":
+            if self.metric_type == MetricType.L2:
+                return "halfvec_l2_ops"
+            elif self.metric_type == MetricType.IP:
+                return "halfvec_ip_ops"
+            return "halfvec_cosine_ops"
+        else:
+            if self.metric_type == MetricType.L2:
+                return "vector_l2_ops"
+            elif self.metric_type == MetricType.IP:
+                return "vector_ip_ops"
+            return "vector_cosine_ops"

     def parse_metric_fun_op(self) -> LiteralString:
         if self.metric_type == MetricType.L2:
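The operator class is now keyed on both the quantization type and the metric. Summarized as a sketch of the mapping the branches above implement:

    # (quantization_type, metric) -> pgvector operator class
    OPCLASS = {
        ("halfvec", "L2"):     "halfvec_l2_ops",
        ("halfvec", "IP"):     "halfvec_ip_ops",
        ("halfvec", "COSINE"): "halfvec_cosine_ops",
        (None, "L2"):          "vector_l2_ops",
        (None, "IP"):          "vector_ip_ops",
        (None, "COSINE"):      "vector_cosine_ops",
    }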
@@ -143,9 +150,12 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
     index: IndexType = IndexType.ES_IVFFlat
     maintenance_work_mem: Optional[str] = None
     max_parallel_workers: Optional[int] = None
+    quantization_type: Optional[str] = None

     def index_param(self) -> PgVectorIndexParam:
         index_parameters = {"lists": self.lists}
+        if self.quantization_type == "none":
+            self.quantization_type = None
         return {
             "metric": self.parse_metric(),
             "index_type": self.index.value,
@@ -154,6 +164,7 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
             ),
             "maintenance_work_mem": self.maintenance_work_mem,
             "max_parallel_workers": self.max_parallel_workers,
+            "quantization_type": self.quantization_type,
         }

     def search_param(self) -> PgVectorSearchParam:
@@ -183,9 +194,12 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
     index: IndexType = IndexType.ES_HNSW
     maintenance_work_mem: Optional[str] = None
     max_parallel_workers: Optional[int] = None
+    quantization_type: Optional[str] = None

     def index_param(self) -> PgVectorIndexParam:
         index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
+        if self.quantization_type == "none":
+            self.quantization_type = None
         return {
             "metric": self.parse_metric(),
             "index_type": self.index.value,
@@ -194,6 +208,7 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
             ),
             "maintenance_work_mem": self.maintenance_work_mem,
             "max_parallel_workers": self.max_parallel_workers,
+            "quantization_type": self.quantization_type,
         }

     def search_param(self) -> PgVectorSearchParam:
vectordb_bench/backend/clients/pgvector/pgvector.py

@@ -22,7 +22,7 @@ class PgVector(VectorDB):
     conn: psycopg.Connection[Any] | None = None
     cursor: psycopg.Cursor[Any] | None = None

-    # TODO add filters support
+    _filtered_search: sql.Composed
     _unfiltered_search: sql.Composed

     def __init__(
@@ -112,15 +112,63 @@ class PgVector(VectorDB):
             self.cursor.execute(command)
         self.conn.commit()

-        self._unfiltered_search = sql.Composed(
-            [
-                sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
-                    sql.Identifier(self.table_name)
-                ),
-                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
-                sql.SQL(" %s::vector LIMIT %s::int"),
-            ]
-        )
+        index_param = self.case_config.index_param()
+        # The following sections assume that the quantization_type value matches the quantization function name
+        if index_param["quantization_type"] != None:
+            self._filtered_search = sql.Composed(
+                [
+                    sql.SQL(
+                        "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding::{quantization_type}({dim}) "
+                    ).format(
+                        table_name=sql.Identifier(self.table_name),
+                        quantization_type=sql.SQL(index_param["quantization_type"]),
+                        dim=sql.Literal(self.dim),
+                    ),
+                    sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                    sql.SQL(" %s::{quantization_type}({dim}) LIMIT %s::int").format(
+                        quantization_type=sql.SQL(index_param["quantization_type"]),
+                        dim=sql.Literal(self.dim),
+                    ),
+                ]
+            )
+        else:
+            self._filtered_search = sql.Composed(
+                [
+                    sql.SQL(
+                        "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
+                    ).format(table_name=sql.Identifier(self.table_name)),
+                    sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                    sql.SQL(" %s::vector LIMIT %s::int"),
+                ]
+            )
+
+        if index_param["quantization_type"] != None:
+            self._unfiltered_search = sql.Composed(
+                [
+                    sql.SQL(
+                        "SELECT id FROM public.{table_name} ORDER BY embedding::{quantization_type}({dim}) "
+                    ).format(
+                        table_name=sql.Identifier(self.table_name),
+                        quantization_type=sql.SQL(index_param["quantization_type"]),
+                        dim=sql.Literal(self.dim),
+                    ),
+                    sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                    sql.SQL(" %s::{quantization_type}({dim}) LIMIT %s::int").format(
+                        quantization_type=sql.SQL(index_param["quantization_type"]),
+                        dim=sql.Literal(self.dim),
+                    ),
+                ]
+            )
+        else:
+            self._unfiltered_search = sql.Composed(
+                [
+                    sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
+                        sql.Identifier(self.table_name)
+                    ),
+                    sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                    sql.SQL(" %s::vector LIMIT %s::int"),
+                ]
+            )

         try:
             yield
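With quantization enabled, both the stored column and the query parameter are cast in SQL. Illustrative renderings of the composed statements, assuming quantization_type="halfvec", dim=768, a cosine operator, and a placeholder table name:

    # Rendered unfiltered statement (identifiers quoted by sql.Identifier):
    unfiltered = (
        'SELECT id FROM public."pg_table" ORDER BY embedding::halfvec(768) '
        "<=> %s::halfvec(768) LIMIT %s::int"
    )
    # The filtered variant adds one leading bind parameter for the id predicate:
    filtered = (
        'SELECT id FROM public."pg_table" WHERE id >= %s ORDER BY '
        "embedding::halfvec(768) <=> %s::halfvec(768) LIMIT %s::int"
    )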
@@ -255,17 +303,34 @@ class PgVector(VectorDB):
         else:
             with_clause = sql.Composed(())

-        index_create_sql = sql.SQL(
-            """
-            CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
-            USING {index_type} (embedding {embedding_metric})
-            """
-        ).format(
-            index_name=sql.Identifier(self._index_name),
-            table_name=sql.Identifier(self.table_name),
-            index_type=sql.Identifier(index_param["index_type"]),
-            embedding_metric=sql.Identifier(index_param["metric"]),
-        )
+        if index_param["quantization_type"] != None:
+            index_create_sql = sql.SQL(
+                """
+                CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
+                USING {index_type} ((embedding::{quantization_type}({dim})) {embedding_metric})
+                """
+            ).format(
+                index_name=sql.Identifier(self._index_name),
+                table_name=sql.Identifier(self.table_name),
+                index_type=sql.Identifier(index_param["index_type"]),
+                # This assumes that the quantization_type value matches the quantization function name
+                quantization_type=sql.SQL(index_param["quantization_type"]),
+                dim=self.dim,
+                embedding_metric=sql.Identifier(index_param["metric"]),
+            )
+        else:
+            index_create_sql = sql.SQL(
+                """
+                CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
+                USING {index_type} (embedding {embedding_metric})
+                """
+            ).format(
+                index_name=sql.Identifier(self._index_name),
+                table_name=sql.Identifier(self.table_name),
+                index_type=sql.Identifier(index_param["index_type"]),
+                embedding_metric=sql.Identifier(index_param["metric"]),
+            )
+
         index_create_sql_with_with_clause = (
             index_create_sql + with_clause
         ).join(" ")
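The quantized branch creates a functional index over the cast expression, so the index expression and the query-time casts above must agree exactly. An illustrative rendering, assuming HNSW, cosine, dim=768, and placeholder identifiers:

    create_index_ddl = (
        'CREATE INDEX IF NOT EXISTS "idx_embedding" ON public."pg_table" '
        'USING "hnsw" ((embedding::halfvec(768)) "halfvec_cosine_ops")'
    )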
@@ -342,9 +407,14 @@ class PgVector(VectorDB):
         assert self.cursor is not None, "Cursor is not initialized"

         q = np.asarray(query)
-        # TODO add filters support
-        result = self.cursor.execute(
-            self._unfiltered_search, (q, k), prepare=True, binary=True
-        )
+        if filters:
+            gt = filters.get("id")
+            result = self.cursor.execute(
+                self._filtered_search, (gt, q, k), prepare=True, binary=True
+            )
+        else:
+            result = self.cursor.execute(
+                self._unfiltered_search, (q, k), prepare=True, binary=True
+            )

         return [int(i[0]) for i in result.fetchall()]
vectordb_bench/backend/clients/pgvectorscale/cli.py (new file)

@@ -0,0 +1,108 @@
+import click
+import os
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from typing import Annotated, Unpack
+from vectordb_bench.backend.clients import DB
+
+
+class PgVectorScaleTypedDict(CommonTypedDict):
+    user_name: Annotated[
+        str, click.option("--user-name", type=str, help="Db username", required=True)
+    ]
+    password: Annotated[
+        str,
+        click.option("--password",
+                     type=str,
+                     help="Postgres database password",
+                     default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
+                     show_default="$POSTGRES_PASSWORD",
+                     ),
+    ]
+
+    host: Annotated[
+        str, click.option("--host", type=str, help="Db host", required=True)
+    ]
+    db_name: Annotated[
+        str, click.option("--db-name", type=str, help="Db name", required=True)
+    ]
+
+
+class PgVectorScaleDiskAnnTypedDict(PgVectorScaleTypedDict):
+    storage_layout: Annotated[
+        str,
+        click.option(
+            "--storage-layout", type=str, help="Streaming DiskANN storage layout",
+        ),
+    ]
+    num_neighbors: Annotated[
+        int,
+        click.option(
+            "--num-neighbors", type=int, help="Streaming DiskANN num neighbors",
+        ),
+    ]
+    search_list_size: Annotated[
+        int,
+        click.option(
+            "--search-list-size", type=int, help="Streaming DiskANN search list size",
+        ),
+    ]
+    max_alpha: Annotated[
+        float,
+        click.option(
+            "--max-alpha", type=float, help="Streaming DiskANN max alpha",
+        ),
+    ]
+    num_dimensions: Annotated[
+        int,
+        click.option(
+            "--num-dimensions", type=int, help="Streaming DiskANN num dimensions",
+        ),
+    ]
+    query_search_list_size: Annotated[
+        int,
+        click.option(
+            "--query-search-list-size", type=int, help="Streaming DiskANN query search list size",
+        ),
+    ]
+    query_rescore: Annotated[
+        int,
+        click.option(
+            "--query-rescore", type=int, help="Streaming DiskANN query rescore",
+        ),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(PgVectorScaleDiskAnnTypedDict)
+def PgVectorScaleDiskAnn(
+    **parameters: Unpack[PgVectorScaleDiskAnnTypedDict],
+):
+    from .config import PgVectorScaleConfig, PgVectorScaleStreamingDiskANNConfig
+
+    run(
+        db=DB.PgVectorScale,
+        db_config=PgVectorScaleConfig(
+            db_label=parameters["db_label"],
+            user_name=SecretStr(parameters["user_name"]),
+            password=SecretStr(parameters["password"]),
+            host=parameters["host"],
+            db_name=parameters["db_name"],
+        ),
+        db_case_config=PgVectorScaleStreamingDiskANNConfig(
+            storage_layout=parameters["storage_layout"],
+            num_neighbors=parameters["num_neighbors"],
+            search_list_size=parameters["search_list_size"],
+            max_alpha=parameters["max_alpha"],
+            num_dimensions=parameters["num_dimensions"],
+            query_search_list_size=parameters["query_search_list_size"],
+            query_rescore=parameters["query_rescore"],
+        ),
+        **parameters,
+    )
vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py

@@ -22,6 +22,9 @@ class PgVectorScale(VectorDB):
     conn: psycopg.Connection[Any] | None = None
     coursor: psycopg.Cursor[Any] | None = None

+    _unfiltered_search: sql.Composed
+    _filtered_search: sql.Composed
+
     def __init__(
         self,
         dim: int,
@@ -99,6 +102,16 @@ class PgVectorScale(VectorDB):
             self.cursor.execute(command)
         self.conn.commit()

+        self._filtered_search = sql.Composed(
+            [
+                sql.SQL("SELECT id FROM public.{} WHERE id >= %s ORDER BY embedding ").format(
+                    sql.Identifier(self.table_name),
+                ),
+                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                sql.SQL(" %s::vector LIMIT %s::int")
+            ]
+        )
+
         self._unfiltered_search = sql.Composed(
             [
                 sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
@@ -264,9 +277,14 @@ class PgVectorScale(VectorDB):
         assert self.cursor is not None, "Cursor is not initialized"

         q = np.asarray(query)
-        # TODO add filters support
-        result = self.cursor.execute(
-            self._unfiltered_search, (q, k), prepare=True, binary=True
-        )
+        if filters:
+            gt = filters.get("id")
+            result = self.cursor.execute(
+                self._filtered_search, (gt, q, k), prepare=True, binary=True
+            )
+        else:
+            result = self.cursor.execute(
+                self._unfiltered_search, (q, k), prepare=True, binary=True
+            )

         return [int(i[0]) for i in result.fetchall()]
vectordb_bench/backend/clients/pinecone/config.py

@@ -4,12 +4,10 @@ from ..api import DBConfig

 class PineconeConfig(DBConfig):
     api_key: SecretStr
-    environment: SecretStr
     index_name: str

     def to_dict(self) -> dict:
         return {
             "api_key": self.api_key.get_secret_value(),
-            "environment": self.environment.get_secret_value(),
             "index_name": self.index_name,
         }
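Dropping environment lines up with the newer Pinecone client, which authenticates with an API key alone and resolves indexes by name. A minimal sketch, assuming the pinecone-client 3.x API and placeholder values:

    from pinecone import Pinecone

    pc = Pinecone(api_key="YOUR_API_KEY")  # environment/region no longer supplied here
    index = pc.Index("vdb-bench-index")    # index resolved by name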