vectordb-bench 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23) hide show
  1. vectordb_bench/backend/clients/__init__.py +22 -0
  2. vectordb_bench/backend/clients/api.py +21 -1
  3. vectordb_bench/backend/clients/memorydb/cli.py +88 -0
  4. vectordb_bench/backend/clients/memorydb/config.py +54 -0
  5. vectordb_bench/backend/clients/memorydb/memorydb.py +254 -0
  6. vectordb_bench/backend/clients/pgvecto_rs/cli.py +154 -0
  7. vectordb_bench/backend/clients/pgvecto_rs/config.py +108 -73
  8. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +159 -59
  9. vectordb_bench/backend/clients/pgvectorscale/config.py +111 -0
  10. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +272 -0
  11. vectordb_bench/cli/vectordbbench.py +5 -0
  12. vectordb_bench/frontend/components/check_results/data.py +13 -6
  13. vectordb_bench/frontend/components/run_test/caseSelector.py +10 -0
  14. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +37 -15
  15. vectordb_bench/frontend/components/run_test/initStyle.py +3 -1
  16. vectordb_bench/frontend/config/dbCaseConfigs.py +173 -9
  17. vectordb_bench/models.py +18 -6
  18. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.13.dist-info}/METADATA +11 -3
  19. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.13.dist-info}/RECORD +23 -17
  20. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.13.dist-info}/WHEEL +1 -1
  21. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.13.dist-info}/LICENSE +0 -0
  22. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.13.dist-info}/entry_points.txt +0 -0
  23. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.13.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,272 @@
1
+ """Wrapper around the Pgvectorscale vector database over VectorDB"""
2
+
3
+ import logging
4
+ import pprint
5
+ from contextlib import contextmanager
6
+ from typing import Any, Generator, Optional, Tuple
7
+
8
+ import numpy as np
9
+ import psycopg
10
+ from pgvector.psycopg import register_vector
11
+ from psycopg import Connection, Cursor, sql
12
+
13
+ from ..api import VectorDB
14
+ from .config import PgVectorScaleConfigDict, PgVectorScaleIndexConfig
15
+
16
+ log = logging.getLogger(__name__)
17
+
18
+
19
class PgVectorScale(VectorDB):
    """VectorDB client for PostgreSQL with the vectorscale extension, via psycopg.

    ``__init__`` opens a short-lived connection to optionally drop/re-create
    the table and index, then closes it again. Loading and searching must
    happen inside the ``init()`` context manager, which opens the connection
    and cursor that ``insert_embeddings`` and ``search_embedding`` rely on.
    """

    conn: psycopg.Connection[Any] | None = None
    cursor: psycopg.Cursor[Any] | None = None

    def __init__(
        self,
        dim: int,
        db_config: PgVectorScaleConfigDict,
        db_case_config: PgVectorScaleIndexConfig,
        collection_name: str = "pg_vectorscale_collection",
        drop_old: bool = False,
        **kwargs,
    ):
        """Store the configuration and, if requested, reset the table and index.

        Args:
            dim: Dimension of the ``embedding`` vector column.
            db_config: Keyword arguments forwarded to ``psycopg.connect``.
            db_case_config: Index configuration; must enable at least one of
                ``create_index_before_load`` / ``create_index_after_load``.
            collection_name: Name of the table in the ``public`` schema.
            drop_old: When true, drop the existing index and table and
                re-create the table (plus the index if
                ``create_index_before_load`` is set).
            **kwargs: Accepted for interface compatibility; unused here.

        Raises:
            RuntimeError: If the case config enables neither index-creation
                option.
        """
        self.name = "PgVectorScale"
        self.db_config = db_config
        self.case_config = db_case_config
        self.table_name = collection_name
        self.dim = dim

        self._index_name = "pgvectorscale_index"
        self._primary_field = "id"
        self._vector_field = "embedding"

        log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}")
        # Validate before connecting so an invalid config never leaves an
        # open connection behind.
        if not (
            self.case_config.create_index_before_load
            or self.case_config.create_index_after_load
        ):
            err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
            log.error(err)
            raise RuntimeError(
                f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
            )

        self.conn, self.cursor = self._create_connection(**self.db_config)
        try:
            if drop_old:
                self._drop_index()
                self._drop_table()
                self._create_table(dim)
                if self.case_config.create_index_before_load:
                    self._create_index()
        finally:
            # Always release the setup connection, even if a DDL step failed.
            self._close_connection()

    def _close_connection(self) -> None:
        """Close the current cursor and connection, if any, and reset both to None."""
        cursor, conn = self.cursor, self.conn
        self.cursor = None
        self.conn = None
        try:
            if cursor is not None:
                cursor.close()
        finally:
            if conn is not None:
                conn.close()

    @staticmethod
    def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
        """Connect, ensure the vectorscale extension exists, and return (conn, cursor).

        ``kwargs`` are passed straight to ``psycopg.connect``. The connection
        is closed again if any setup step after connecting fails.
        """
        conn = psycopg.connect(**kwargs)
        try:
            # Use a scoped cursor for the one-off extension statement so it
            # is closed instead of lingering on the connection.
            with conn.cursor() as setup_cursor:
                setup_cursor.execute(
                    "CREATE EXTENSION IF NOT EXISTS vectorscale CASCADE"
                )
            conn.commit()
            register_vector(conn)
            conn.autocommit = False
            cursor = conn.cursor()
        except Exception:
            conn.close()
            raise

        return conn, cursor

    @contextmanager
    def init(self) -> Generator[None, None, None]:
        """Open a connection for the duration of the ``with`` block.

        Applies every setting returned by ``case_config.session_param()`` via
        ``SET``, prepares the unfiltered search statement, yields, and always
        closes the cursor and connection on exit.
        """
        self.conn, self.cursor = self._create_connection(**self.db_config)
        try:
            # index configuration may have commands defined that we should set during each client session
            session_options: dict[str, Any] = self.case_config.session_param()

            for setting_name, setting_val in session_options.items():
                command = sql.SQL("SET {setting_name} = {setting_val};").format(
                    setting_name=sql.Identifier(setting_name),
                    setting_val=sql.Identifier(str(setting_val)),
                )
                log.debug(command.as_string(self.cursor))
                self.cursor.execute(command)
            if session_options:
                self.conn.commit()

            self._unfiltered_search = sql.Composed(
                [
                    sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
                        sql.Identifier(self.table_name)
                    ),
                    sql.SQL(self.case_config.search_param()["metric_fun_op"]),
                    sql.SQL(" %s::vector LIMIT %s::int"),
                ]
            )

            yield
        finally:
            # Runs for setup failures as well as for errors raised by the
            # caller's block, so the connection is never leaked.
            self._close_connection()

    def _drop_table(self):
        """Drop the benchmark table if it exists and commit."""
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"
        log.info(f"{self.name} client drop table : {self.table_name}")

        self.cursor.execute(
            sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
                table_name=sql.Identifier(self.table_name)
            )
        )
        self.conn.commit()

    def ready_to_load(self):
        """No preparation is needed before loading."""
        pass

    def optimize(self):
        """Run the post-insert step (index rebuild when configured)."""
        self._post_insert()

    def _post_insert(self):
        """Drop and re-create the index when ``create_index_after_load`` is set."""
        log.info(f"{self.name} post insert before optimize")
        if self.case_config.create_index_after_load:
            self._drop_index()
            self._create_index()

    def _drop_index(self):
        """Drop the index if it exists and commit."""
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"
        log.info(f"{self.name} client drop index : {self._index_name}")

        drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
            index_name=sql.Identifier(self._index_name)
        )
        log.debug(drop_index_sql.as_string(self.cursor))
        self.cursor.execute(drop_index_sql)
        self.conn.commit()

    def _create_index(self):
        """Create the index described by ``case_config.index_param()`` and commit.

        Every non-None entry of ``index_param["options"]`` becomes a
        ``name = value`` item in the ``WITH (...)`` clause, followed by
        ``num_bits_per_dimension``.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"
        log.info(f"{self.name} client create index : {self._index_name}")

        index_param: dict[str, Any] = self.case_config.index_param()

        # NOTE(review): option values are quoted with sql.Identifier (double
        # quotes) rather than sql.Literal; kept as-is to preserve the emitted
        # SQL — confirm this is intentional.
        options = [
            sql.SQL("{option_name} = {val}").format(
                option_name=sql.Identifier(option_name),
                val=sql.Identifier(str(option_val)),
            )
            for option_name, option_val in index_param["options"].items()
            if option_val is not None
        ]

        # NOTE(review): the 900-dimension threshold is hard-coded here; its
        # rationale is not visible in this file.
        num_bits_per_dimension = "2" if self.dim < 900 else "1"
        options.append(
            sql.SQL("{option_name} = {val}").format(
                option_name=sql.Identifier("num_bits_per_dimension"),
                val=sql.Identifier(num_bits_per_dimension),
            )
        )

        # num_bits_per_dimension is always appended, so the WITH clause is
        # never empty.
        with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))

        index_create_sql = sql.SQL(
            """
            CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
            USING {index_type} (embedding {embedding_metric})
            """
        ).format(
            index_name=sql.Identifier(self._index_name),
            table_name=sql.Identifier(self.table_name),
            index_type=sql.Identifier(index_param["index_type"].lower()),
            embedding_metric=sql.Identifier(index_param["metric"]),
        )
        # Join only the two statements; joining the flattened fragments would
        # insert a separator between every piece (e.g. `public. "table"`).
        index_create_sql_with_with_clause = sql.SQL(" ").join(
            (index_create_sql, with_clause)
        )
        log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
        self.cursor.execute(index_create_sql_with_with_clause)
        self.conn.commit()

    def _create_table(self, dim: int):
        """Create the ``(id BIGINT PRIMARY KEY, embedding vector(dim))`` table.

        Logs a warning and re-raises if the statement fails.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        try:
            log.info(f"{self.name} client create table : {self.table_name}")

            self.cursor.execute(
                sql.SQL(
                    "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
                ).format(table_name=sql.Identifier(self.table_name), dim=dim)
            )
            self.conn.commit()
        except Exception as e:
            log.warning(
                f"Failed to create pgvectorscale table: {self.table_name} error: {e}"
            )
            raise e from None

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs: Any,
    ) -> Tuple[int, Optional[Exception]]:
        """Bulk-load ``(id, embedding)`` rows with a binary COPY.

        Args:
            embeddings: Vectors to insert; ``embeddings[i]`` is stored with
                ``metadata[i]`` as its id.
            metadata: Row ids.
            **kwargs: If ``last_batch`` is truthy, the post-insert step runs
                after the rows are committed.

        Returns:
            ``(len(metadata), None)`` on success, or ``(0, exception)`` if
            anything failed; the exception is returned, not raised.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        try:
            metadata_arr = np.array(metadata)
            embeddings_arr = np.array(embeddings)

            with self.cursor.copy(
                sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
                    table_name=sql.Identifier(self.table_name)
                )
            ) as copy:
                copy.set_types(["bigint", "vector"])
                for i, row_id in enumerate(metadata_arr):
                    copy.write_row((row_id, embeddings_arr[i]))
            self.conn.commit()

            if kwargs.get("last_batch"):
                self._post_insert()

            return len(metadata), None
        except Exception as e:
            log.warning(
                f"Failed to insert data into pgvectorscale table ({self.table_name}), error: {e}"
            )
            # A failed statement leaves the transaction aborted; roll back so
            # later batches on this connection are not rejected.
            try:
                self.conn.rollback()
            except Exception as rollback_err:
                log.warning(
                    f"Rollback after failed insert also failed: {rollback_err}"
                )
            return 0, e

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
    ) -> list[int]:
        """Return the ids of the ``k`` rows ordered first by the search statement.

        Must be called inside ``init()``, which prepares the statement.
        ``filters`` and ``timeout`` are currently ignored.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        query_vector = np.asarray(query)
        # TODO add filters support
        result = self.cursor.execute(
            self._unfiltered_search, (query_vector, k), prepare=True, binary=True
        )

        return [int(row[0]) for row in result.fetchall()]
@@ -1,5 +1,7 @@
1
1
  from ..backend.clients.pgvector.cli import PgVectorHNSW
2
+ from ..backend.clients.pgvecto_rs.cli import PgVectoRSHNSW, PgVectoRSIVFFlat
2
3
  from ..backend.clients.redis.cli import Redis
4
+ from ..backend.clients.memorydb.cli import MemoryDB
3
5
  from ..backend.clients.test.cli import Test
4
6
  from ..backend.clients.weaviate_cloud.cli import Weaviate
5
7
  from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex
@@ -10,7 +12,10 @@ from ..backend.clients.aws_opensearch.cli import AWSOpenSearch
10
12
  from .cli import cli
11
13
 
12
14
  cli.add_command(PgVectorHNSW)
15
+ cli.add_command(PgVectoRSHNSW)
16
+ cli.add_command(PgVectoRSIVFFlat)
13
17
  cli.add_command(Redis)
18
+ cli.add_command(MemoryDB)
14
19
  cli.add_command(Weaviate)
15
20
  cli.add_command(Test)
16
21
  cli.add_command(ZillizAutoIndex)
@@ -24,7 +24,10 @@ def getFilterTasks(
24
24
  task
25
25
  for task in tasks
26
26
  if task.task_config.db_name in dbNames
27
- and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames
27
+ and task.task_config.case_config.case_id.case_cls(
28
+ task.task_config.case_config.custom_case
29
+ ).name
30
+ in caseNames
28
31
  ]
29
32
  return filterTasks
30
33
 
@@ -35,17 +38,20 @@ def mergeTasks(tasks: list[CaseResult]):
35
38
  db_name = task.task_config.db_name
36
39
  db = task.task_config.db.value
37
40
  db_label = task.task_config.db_config.db_label or ""
38
- case = task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case)
41
+ version = task.task_config.db_config.version or ""
42
+ case = task.task_config.case_config.case_id.case_cls(
43
+ task.task_config.case_config.custom_case
44
+ )
39
45
  dbCaseMetricsMap[db_name][case.name] = {
40
46
  "db": db,
41
47
  "db_label": db_label,
48
+ "version": version,
42
49
  "metrics": mergeMetrics(
43
50
  dbCaseMetricsMap[db_name][case.name].get("metrics", {}),
44
51
  asdict(task.metrics),
45
52
  ),
46
53
  "label": getBetterLabel(
47
- dbCaseMetricsMap[db_name][case.name].get(
48
- "label", ResultLabel.FAILED),
54
+ dbCaseMetricsMap[db_name][case.name].get("label", ResultLabel.FAILED),
49
55
  task.label,
50
56
  ),
51
57
  }
@@ -57,6 +63,7 @@ def mergeTasks(tasks: list[CaseResult]):
57
63
  metrics = metricInfo["metrics"]
58
64
  db = metricInfo["db"]
59
65
  db_label = metricInfo["db_label"]
66
+ version = metricInfo["version"]
60
67
  label = metricInfo["label"]
61
68
  if label == ResultLabel.NORMAL:
62
69
  mergedTasks.append(
@@ -64,6 +71,7 @@ def mergeTasks(tasks: list[CaseResult]):
64
71
  "db_name": db_name,
65
72
  "db": db,
66
73
  "db_label": db_label,
74
+ "version": version,
67
75
  "case_name": case_name,
68
76
  "metricsSet": set(metrics.keys()),
69
77
  **metrics,
@@ -79,8 +87,7 @@ def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict:
79
87
  metrics = {**metrics_1}
80
88
  for key, value in metrics_2.items():
81
89
  metrics[key] = (
82
- getBetterMetric(
83
- key, value, metrics[key]) if key in metrics else value
90
+ getBetterMetric(key, value, metrics[key]) if key in metrics else value
84
91
  )
85
92
 
86
93
  return metrics
@@ -100,6 +100,16 @@ def caseConfigSetting(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, active
100
100
  value=config.inputConfig["value"],
101
101
  help=config.inputHelp,
102
102
  )
103
+ elif config.inputType == InputType.Float:
104
+ caseConfig[config.label] = column.number_input(
105
+ config.displayLabel if config.displayLabel else config.label.value,
106
+ step=config.inputConfig.get("step", 0.1),
107
+ min_value=config.inputConfig["min"],
108
+ max_value=config.inputConfig["max"],
109
+ key=key,
110
+ value=config.inputConfig["value"],
111
+ help=config.inputHelp,
112
+ )
103
113
  k += 1
104
114
  if k == 0:
105
115
  columns[1].write("Auto")
@@ -1,9 +1,10 @@
1
1
  from pydantic import ValidationError
2
- from vectordb_bench.frontend.config.styles import *
2
+ from vectordb_bench.backend.clients import DB
3
+ from vectordb_bench.frontend.config.styles import DB_CONFIG_SETTING_COLUMNS
3
4
  from vectordb_bench.frontend.utils import inputIsPassword
4
5
 
5
6
 
6
- def dbConfigSettings(st, activedDbList):
7
+ def dbConfigSettings(st, activedDbList: list[DB]):
7
8
  expander = st.expander("Configurations for the selected databases", True)
8
9
 
9
10
  dbConfigs = {}
@@ -27,7 +28,7 @@ def dbConfigSettings(st, activedDbList):
27
28
  return dbConfigs, isAllValid
28
29
 
29
30
 
30
- def dbConfigSettingItem(st, activeDb):
31
+ def dbConfigSettingItem(st, activeDb: DB):
31
32
  st.markdown(
32
33
  f"<div style='font-weight: 600; font-size: 20px; margin-top: 16px;'>{activeDb.value}</div>",
33
34
  unsafe_allow_html=True,
@@ -36,20 +37,41 @@ def dbConfigSettingItem(st, activeDb):
36
37
 
37
38
  dbConfigClass = activeDb.config_cls
38
39
  properties = dbConfigClass.schema().get("properties")
39
- propertiesItems = list(properties.items())
40
- moveDBLabelToLast(propertiesItems)
41
40
  dbConfig = {}
42
- for j, property in enumerate(propertiesItems):
43
- column = columns[j % DB_CONFIG_SETTING_COLUMNS]
44
- key, value = property
41
+ idx = 0
42
+
43
+ # db config (unique)
44
+ for key, property in properties.items():
45
+ if (
46
+ key not in dbConfigClass.common_short_configs()
47
+ and key not in dbConfigClass.common_long_configs()
48
+ ):
49
+ column = columns[idx % DB_CONFIG_SETTING_COLUMNS]
50
+ idx += 1
51
+ dbConfig[key] = column.text_input(
52
+ key,
53
+ key="%s-%s" % (activeDb.name, key),
54
+ value=property.get("default", ""),
55
+ type="password" if inputIsPassword(key) else "default",
56
+ )
57
+ # db config (common short labels)
58
+ for key in dbConfigClass.common_short_configs():
59
+ column = columns[idx % DB_CONFIG_SETTING_COLUMNS]
60
+ idx += 1
45
61
  dbConfig[key] = column.text_input(
46
62
  key,
47
- key="%s-%s" % (activeDb, key),
48
- value=value.get("default", ""),
49
- type="password" if inputIsPassword(key) else "default",
63
+ key="%s-%s" % (activeDb.name, key),
64
+ value="",
65
+ type="default",
66
+ placeholder="optional, for labeling results",
50
67
  )
51
- return dbConfig
52
-
53
68
 
54
- def moveDBLabelToLast(propertiesItems):
55
- propertiesItems.sort(key=lambda x: 1 if x[0] == "db_label" else 0)
69
+ # db config (common long text_input)
70
+ for key in dbConfigClass.common_long_configs():
71
+ dbConfig[key] = st.text_area(
72
+ key,
73
+ key="%s-%s" % (activeDb.name, key),
74
+ value="",
75
+ placeholder="optional",
76
+ )
77
+ return dbConfig
@@ -9,6 +9,8 @@ def initStyle(st):
9
9
  div[data-testid='stHorizontalBlock'] {gap: 8px;}
10
10
  /* check box */
11
11
  .stCheckbox p { color: #000; font-size: 18px; font-weight: 600; }
12
+ /* db selector - db_name should not wrap */
13
+ div[data-testid="stVerticalBlockBorderWrapper"] div[data-testid="stCheckbox"] div[data-testid="stWidgetLabel"] p { white-space: nowrap; }
12
14
  </style>""",
13
15
  unsafe_allow_html=True,
14
- )
16
+ )