vectordb-bench 0.0.22__py3-none-any.whl → 0.0.23__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
@@ -94,7 +94,7 @@ class PgVector(VectorDB):
         reranking = self.case_config.search_param()["reranking"]
         column_name = (
             sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding"))
-            if index_param["quantization_type"] == "bit"
+            if index_param["quantization_type"] == "bit" and index_param["table_quantization_type"] != "bit"
             else sql.SQL("embedding")
         )
         search_vector = (
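The condition above is the core of this hunk: the new table_quantization_type option means a column that is already stored as bit needs no on-the-fly quantization, so binary_quantize() is applied only when the index is bit-quantized but the table is not. A minimal restatement in plain Python (the helper name is ours, for illustration only):

    def pick_column(quantization_type: str, table_quantization_type: str) -> str:
        # Quantize on the fly only when the index is "bit" and the stored
        # column is not already "bit"; otherwise use the column directly.
        if quantization_type == "bit" and table_quantization_type != "bit":
            return "binary_quantize(embedding)"
        return "embedding"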
@@ -104,7 +104,8 @@ class PgVector(VectorDB):
         )

         # The following sections assume that the quantization_type value matches the quantization function name
-        if index_param["quantization_type"] is not None:
+        if index_param["quantization_type"] != index_param["table_quantization_type"]:
+            # Reranking makes sense only if table quantization is not "bit"
             if index_param["quantization_type"] == "bit" and reranking:
                 # Embeddings needs to be passed to binary_quantize function if quantization_type is bit
                 search_query = sql.Composed(
@@ -113,7 +114,7 @@ class PgVector(VectorDB):
                             """
                             SELECT i.id
                             FROM (
-                                SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance
+                                SELECT id, embedding {reranking_metric_fun_op} %s::{table_quantization_type} AS distance
                                 FROM public.{table_name} {where_clause}
                                 ORDER BY {column_name}::{quantization_type}({dim})
                             """,
@@ -123,6 +124,8 @@ class PgVector(VectorDB):
                             reranking_metric_fun_op=sql.SQL(
                                 self.case_config.search_param()["reranking_metric_fun_op"],
                             ),
+                            search_vector=search_vector,
+                            table_quantization_type=sql.SQL(index_param["table_quantization_type"]),
                             quantization_type=sql.SQL(index_param["quantization_type"]),
                             dim=sql.Literal(self.dim),
                             where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
@@ -130,7 +133,7 @@ class PgVector(VectorDB):
                         sql.SQL(self.case_config.search_param()["metric_fun_op"]),
                         sql.SQL(
                             """
-                            {search_vector}
+                            {search_vector}::{quantization_type}({dim})
                             LIMIT {quantized_fetch_limit}
                             ) i
                             ORDER BY i.distance
@@ -138,6 +141,8 @@ class PgVector(VectorDB):
                             """,
                         ).format(
                             search_vector=search_vector,
+                            quantization_type=sql.SQL(index_param["quantization_type"]),
+                            dim=sql.Literal(self.dim),
                             quantized_fetch_limit=sql.Literal(
                                 self.case_config.search_param()["quantized_fetch_limit"],
                             ),
@@ -160,10 +165,12 @@ class PgVector(VectorDB):
                             where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
                         ),
                         sql.SQL(self.case_config.search_param()["metric_fun_op"]),
-                        sql.SQL(" {search_vector} LIMIT %s::int").format(
+                        sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format(
                             search_vector=search_vector,
+                            quantization_type=sql.SQL(index_param["quantization_type"]),
+                            dim=sql.Literal(self.dim),
                         ),
-                    ],
+                    ]
                 )
         else:
             search_query = sql.Composed(
@@ -175,8 +182,12 @@ class PgVector(VectorDB):
                     where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
                     ),
                     sql.SQL(self.case_config.search_param()["metric_fun_op"]),
-                    sql.SQL(" %s::vector LIMIT %s::int"),
-                ],
+                    sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format(
+                        search_vector=search_vector,
+                        quantization_type=sql.SQL(index_param["quantization_type"]),
+                        dim=sql.Literal(self.dim),
+                    ),
+                ]
             )

         return search_query
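For orientation, a hand-rendered sketch (not captured output) of what the plain branch above composes once formatted, assuming quantization_type and table_quantization_type are both "halfvec", dim = 768, a filtered run, and a table named items; the distance operator and identifiers are illustrative assumptions:

    rendered = (
        'SELECT id FROM public."items" WHERE id >= %s '
        "ORDER BY embedding <=> %s::halfvec(768) LIMIT %s::int"
    )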
@@ -323,7 +334,7 @@ class PgVector(VectorDB):
         )
         with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())

-        if index_param["quantization_type"] is not None:
+        if index_param["quantization_type"] != index_param["table_quantization_type"]:
             index_create_sql = sql.SQL(
                 """
                 CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
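For a sense of what this branch emits, a hand-rendered guess at the DDL for an HNSW index with quantization_type = "bit" over a vector-typed table; the index and table names, dim = 768, the operator class, and the WITH options are illustrative assumptions, not captured output:

    ddl = (
        'CREATE INDEX IF NOT EXISTS "items_embedding_idx" ON public."items" '
        "USING hnsw ((binary_quantize(embedding)::bit(768)) bit_hamming_ops) "
        "WITH (m = 16, ef_construction = 64);"
    )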
@@ -365,14 +376,23 @@ class PgVector(VectorDB):
         assert self.conn is not None, "Connection is not initialized"
         assert self.cursor is not None, "Cursor is not initialized"

+        index_param = self.case_config.index_param()
+
         try:
             log.info(f"{self.name} client create table : {self.table_name}")

             # create table
             self.cursor.execute(
                 sql.SQL(
-                    "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
-                ).format(table_name=sql.Identifier(self.table_name), dim=dim),
+                    """
+                    CREATE TABLE IF NOT EXISTS public.{table_name}
+                    (id BIGINT PRIMARY KEY, embedding {table_quantization_type}({dim}));
+                    """
+                ).format(
+                    table_name=sql.Identifier(self.table_name),
+                    table_quantization_type=sql.SQL(index_param["table_quantization_type"]),
+                    dim=dim,
+                )
             )
             self.cursor.execute(
                 sql.SQL(
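Hand-rendered examples of the DDL above for the three table_quantization_type values that appear in this diff; dim = 768 and the table name are assumptions:

    for t in ("vector", "halfvec", "bit"):
        print(
            f'CREATE TABLE IF NOT EXISTS public."items" '
            f"(id BIGINT PRIMARY KEY, embedding {t}(768));"
        )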
@@ -393,18 +413,41 @@ class PgVector(VectorDB):
         assert self.conn is not None, "Connection is not initialized"
         assert self.cursor is not None, "Cursor is not initialized"

+        index_param = self.case_config.index_param()
+
         try:
             metadata_arr = np.array(metadata)
             embeddings_arr = np.array(embeddings)

-            with self.cursor.copy(
-                sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
-                    table_name=sql.Identifier(self.table_name),
-                ),
-            ) as copy:
-                copy.set_types(["bigint", "vector"])
-                for i, row in enumerate(metadata_arr):
-                    copy.write_row((row, embeddings_arr[i]))
+            if index_param["table_quantization_type"] == "bit":
+                with self.cursor.copy(
+                    sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT TEXT)").format(
+                        table_name=sql.Identifier(self.table_name)
+                    )
+                ) as copy:
+                    # Same logic as pgvector binary_quantize
+                    for i, row in enumerate(metadata_arr):
+                        embeddings_bit = ""
+                        for embedding in embeddings_arr[i]:
+                            if embedding > 0:
+                                embeddings_bit += "1"
+                            else:
+                                embeddings_bit += "0"
+                        copy.write_row((str(row), embeddings_bit))
+            else:
+                with self.cursor.copy(
+                    sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
+                        table_name=sql.Identifier(self.table_name)
+                    )
+                ) as copy:
+                    if index_param["table_quantization_type"] == "halfvec":
+                        copy.set_types(["bigint", "halfvec"])
+                        for i, row in enumerate(metadata_arr):
+                            copy.write_row((row, np.float16(embeddings_arr[i])))
+                    else:
+                        copy.set_types(["bigint", "vector"])
+                        for i, row in enumerate(metadata_arr):
+                            copy.write_row((row, embeddings_arr[i]))
             self.conn.commit()

             if kwargs.get("last_batch"):
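The bit branch above reimplements pgvector's binary_quantize rule (a positive component maps to 1, anything else to 0) with a per-component Python loop so the bit string can be streamed through COPY ... FORMAT TEXT. A vectorized sketch of the same transform, using the NumPy the module already imports:

    import numpy as np

    def to_bit_string(vec: np.ndarray) -> str:
        # "1" where the component is positive, else "0" - the same rule
        # as the inner loop in the diff above
        return "".join(np.where(vec > 0, "1", "0"))

    assert to_bit_string(np.array([0.3, -1.2, 0.0])) == "100"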
@@ -0,0 +1,98 @@
+from typing import Annotated, Unpack
+
+import click
+from pydantic import SecretStr
+
+from vectordb_bench.backend.clients import DB
+
+from ....cli.cli import CommonTypedDict, cli, click_parameter_decorators_from_typed_dict, run
+
+
+class TiDBTypedDict(CommonTypedDict):
+    user_name: Annotated[
+        str,
+        click.option(
+            "--username",
+            type=str,
+            help="Username",
+            default="root",
+            show_default=True,
+            required=True,
+        ),
+    ]
+    password: Annotated[
+        str,
+        click.option(
+            "--password",
+            type=str,
+            default="",
+            show_default=True,
+            help="Password",
+        ),
+    ]
+    host: Annotated[
+        str,
+        click.option(
+            "--host",
+            type=str,
+            default="127.0.0.1",
+            show_default=True,
+            required=True,
+            help="Db host",
+        ),
+    ]
+    port: Annotated[
+        int,
+        click.option(
+            "--port",
+            type=int,
+            default=4000,
+            show_default=True,
+            required=True,
+            help="Db Port",
+        ),
+    ]
+    db_name: Annotated[
+        str,
+        click.option(
+            "--db-name",
+            type=str,
+            default="test",
+            show_default=True,
+            required=True,
+            help="Db name",
+        ),
+    ]
+    ssl: Annotated[
+        bool,
+        click.option(
+            "--ssl/--no-ssl",
+            default=False,
+            show_default=True,
+            is_flag=True,
+            help="Enable or disable SSL, for TiDB Serverless SSL must be enabled",
+        ),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(TiDBTypedDict)
+def TiDB(
+    **parameters: Unpack[TiDBTypedDict],
+):
+    from .config import TiDBConfig, TiDBIndexConfig
+
+    run(
+        db=DB.TiDB,
+        db_config=TiDBConfig(
+            db_label=parameters["db_label"],
+            user_name=parameters["username"],
+            password=SecretStr(parameters["password"]),
+            host=parameters["host"],
+            port=parameters["port"],
+            db_name=parameters["db_name"],
+            ssl=parameters["ssl"],
+        ),
+        db_case_config=TiDBIndexConfig(),
+        **parameters,
+    )
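Once registered with the CLI aggregator (see the final hunks of this diff), the benchmark should be invocable as something like vectordbbench tidb --username root --host 127.0.0.1 --port 4000 --db-name test, plus the common benchmark options contributed by CommonTypedDict; the lowercase command name is an assumption based on click's default handling of the TiDB function name.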
@@ -0,0 +1,49 @@
+from pydantic import SecretStr, BaseModel, validator
+from ..api import DBConfig, DBCaseConfig, MetricType
+
+
+class TiDBConfig(DBConfig):
+    user_name: str = "root"
+    password: SecretStr
+    host: str = "127.0.0.1"
+    port: int = 4000
+    db_name: str = "test"
+    ssl: bool = False
+
+    @validator("*")
+    def not_empty_field(cls, v: any, field: any):
+        return v
+
+    def to_dict(self) -> dict:
+        pwd_str = self.password.get_secret_value()
+        return {
+            "host": self.host,
+            "port": self.port,
+            "user": self.user_name,
+            "password": pwd_str,
+            "database": self.db_name,
+            "ssl_verify_cert": self.ssl,
+            "ssl_verify_identity": self.ssl,
+        }
+
+
+class TiDBIndexConfig(BaseModel, DBCaseConfig):
+    metric_type: MetricType | None = None
+
+    def get_metric_fn(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "vec_l2_distance"
+        elif self.metric_type == MetricType.COSINE:
+            return "vec_cosine_distance"
+        else:
+            raise ValueError(f"Unsupported metric type: {self.metric_type}")
+
+    def index_param(self) -> dict:
+        return {
+            "metric_fn": self.get_metric_fn(),
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "metric_fn": self.get_metric_fn(),
+        }
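A minimal sketch of how this config becomes pymysql.connect() keyword arguments (the client below calls pymysql.connect(**self.db_config)); the SecretStr value and the assumption that the remaining DBConfig fields accept an empty db_label are ours:

    from pydantic import SecretStr

    cfg = TiDBConfig(db_label="", password=SecretStr("secret"))
    print(cfg.to_dict())
    # {'host': '127.0.0.1', 'port': 4000, 'user': 'root', 'password': 'secret',
    #  'database': 'test', 'ssl_verify_cert': False, 'ssl_verify_identity': False}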
@@ -0,0 +1,234 @@
+import concurrent.futures
+import io
+import logging
+import time
+from contextlib import contextmanager
+from typing import Any, Optional, Tuple
+
+import pymysql
+
+from ..api import VectorDB
+from .config import TiDBIndexConfig
+
+log = logging.getLogger(__name__)
+
+
+class TiDB(VectorDB):
+    def __init__(
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: TiDBIndexConfig,
+        collection_name: str = "vector_bench_test",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.name = "TiDB"
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.table_name = collection_name
+        self.dim = dim
+        self.conn = None  # To be inited by init()
+        self.cursor = None  # To be inited by init()
+
+        self.search_fn = db_case_config.search_param()["metric_fn"]
+
+        if drop_old:
+            self._drop_table()
+            self._create_table()
+
+    @contextmanager
+    def init(self):
+        with self._get_connection() as (conn, cursor):
+            self.conn = conn
+            self.cursor = cursor
+            try:
+                yield
+            finally:
+                self.conn = None
+                self.cursor = None
+
+    @contextmanager
+    def _get_connection(self):
+        with pymysql.connect(**self.db_config) as conn:
+            conn.autocommit = False
+            with conn.cursor() as cursor:
+                yield conn, cursor
+
+    def _drop_table(self):
+        try:
+            with self._get_connection() as (conn, cursor):
+                cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
+                conn.commit()
+        except Exception as e:
+            log.warning("Failed to drop table: %s error: %s", self.table_name, e)
+            raise e
+
+    def _create_table(self):
+        try:
+            index_param = self.case_config.index_param()
+            with self._get_connection() as (conn, cursor):
+                cursor.execute(
+                    f"""
+                    CREATE TABLE {self.table_name} (
+                        id BIGINT PRIMARY KEY,
+                        embedding VECTOR({self.dim}) NOT NULL,
+                        VECTOR INDEX (({index_param["metric_fn"]}(embedding)))
+                    );
+                    """
+                )
+                conn.commit()
+        except Exception as e:
+            log.warning("Failed to create table: %s error: %s", self.table_name, e)
+            raise e
+
+    def ready_to_load(self) -> bool:
+        pass
+
+    def optimize(self, data_size: int | None = None) -> None:
+        while True:
+            progress = self._optimize_check_tiflash_replica_progress()
+            if progress != 1:
+                log.info("Data replication not ready, progress: %d", progress)
+                time.sleep(2)
+            else:
+                break
+
+        log.info("Waiting TiFlash to catch up...")
+        self._optimize_wait_tiflash_catch_up()
+
+        log.info("Start compacting TiFlash replica...")
+        self._optimize_compact_tiflash()
+
+        log.info("Waiting index build to finish...")
+        log_reduce_seq = 0
+        while True:
+            pending_rows = self._optimize_get_tiflash_index_pending_rows()
+            if pending_rows > 0:
+                if log_reduce_seq % 15 == 0:
+                    log.info("Index not fully built, pending rows: %d", pending_rows)
+                log_reduce_seq += 1
+                time.sleep(2)
+            else:
+                break
+
+        log.info("Index build finished successfully.")
+
+    def _optimize_check_tiflash_replica_progress(self):
+        try:
+            database = self.db_config["database"]
+            with self._get_connection() as (_, cursor):
+                cursor.execute(
+                    f"""
+                    SELECT PROGRESS FROM information_schema.tiflash_replica
+                    WHERE TABLE_SCHEMA = "{database}" AND TABLE_NAME = "{self.table_name}"
+                    """
+                )
+                result = cursor.fetchone()
+                return result[0]
+        except Exception as e:
+            log.warning("Failed to check TiFlash replica progress: %s", e)
+            raise e
+
+    def _optimize_wait_tiflash_catch_up(self):
+        try:
+            with self._get_connection() as (conn, cursor):
+                cursor.execute('SET @@TIDB_ISOLATION_READ_ENGINES="tidb,tiflash"')
+                conn.commit()
+                cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")
+                result = cursor.fetchone()
+                return result[0]
+        except Exception as e:
+            log.warning("Failed to wait TiFlash to catch up: %s", e)
+            raise e
+
+    def _optimize_compact_tiflash(self):
+        try:
+            with self._get_connection() as (conn, cursor):
+                cursor.execute(f"ALTER TABLE {self.table_name} COMPACT")
+                conn.commit()
+        except Exception as e:
+            log.warning("Failed to compact table: %s", e)
+            raise e
+
+    def _optimize_get_tiflash_index_pending_rows(self):
+        try:
+            database = self.db_config["database"]
+            with self._get_connection() as (_, cursor):
+                cursor.execute(
+                    f"""
+                    SELECT SUM(ROWS_STABLE_NOT_INDEXED)
+                    FROM information_schema.tiflash_indexes
+                    WHERE TIDB_DATABASE = "{database}" AND TIDB_TABLE = "{self.table_name}"
+                    """
+                )
+                result = cursor.fetchone()
+                return result[0]
+        except Exception as e:
+            log.warning("Failed to read TiFlash index pending rows: %s", e)
+            raise e
+
+    def _insert_embeddings_serial(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        offset: int,
+        size: int,
+    ) -> Exception:
+        try:
+            with self._get_connection() as (conn, cursor):
+                buf = io.StringIO()
+                buf.write(f"INSERT INTO {self.table_name} (id, embedding) VALUES ")
+                for i in range(offset, offset + size):
+                    if i > offset:
+                        buf.write(",")
+                    buf.write(f'({metadata[i]}, "{str(embeddings[i])}")')
+                cursor.execute(buf.getvalue())
+                conn.commit()
+        except Exception as e:
+            log.warning("Failed to insert data into table: %s", e)
+            raise e
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> Tuple[int, Optional[Exception]]:
+        workers = 10
+        # Avoid exceeding MAX_ALLOWED_PACKET (default=64MB)
+        max_batch_size = 64 * 1024 * 1024 // 24 // self.dim
+        batch_size = len(embeddings) // workers
+        if batch_size > max_batch_size:
+            batch_size = max_batch_size
+        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
+            futures = []
+            for i in range(0, len(embeddings), batch_size):
+                offset = i
+                size = min(batch_size, len(embeddings) - i)
+                future = executor.submit(self._insert_embeddings_serial, embeddings, metadata, offset, size)
+                futures.append(future)
+            done, pending = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_EXCEPTION)
+            executor.shutdown(wait=False)
+            for future in done:
+                future.result()
+            for future in pending:
+                future.cancel()
+        return len(metadata), None
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+        **kwargs: Any,
+    ) -> list[int]:
+        self.cursor.execute(
+            f"""
+            SELECT id FROM {self.table_name}
+            ORDER BY {self.search_fn}(embedding, "{str(query)}") LIMIT {k};
+            """
+        )
+        result = self.cursor.fetchall()
+        return [int(i[0]) for i in result]
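One detail worth noting in insert_embeddings above: the batch cap budgets roughly 24 bytes of textual INSERT payload per vector component so a single statement stays under MySQL's default 64 MB MAX_ALLOWED_PACKET. A worked example of that arithmetic (dim = 768 is an assumption):

    dim = 768
    max_batch_size = 64 * 1024 * 1024 // 24 // dim
    print(max_batch_size)  # 3640 -> at most ~3640 rows per INSERT statement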
@@ -1,5 +1,6 @@
 from ..backend.clients.alloydb.cli import AlloyDBScaNN
 from ..backend.clients.aws_opensearch.cli import AWSOpenSearch
+from ..backend.clients.mariadb.cli import MariaDBHNSW
 from ..backend.clients.memorydb.cli import MemoryDB
 from ..backend.clients.milvus.cli import MilvusAutoIndex
 from ..backend.clients.pgdiskann.cli import PgDiskAnn
@@ -10,6 +11,7 @@ from ..backend.clients.redis.cli import Redis
 from ..backend.clients.test.cli import Test
 from ..backend.clients.weaviate_cloud.cli import Weaviate
 from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex
+from ..backend.clients.tidb.cli import TiDB
 from .cli import cli

 cli.add_command(PgVectorHNSW)
@@ -25,6 +27,8 @@ cli.add_command(AWSOpenSearch)
 cli.add_command(PgVectorScaleDiskAnn)
 cli.add_command(PgDiskAnn)
 cli.add_command(AlloyDBScaNN)
+cli.add_command(MariaDBHNSW)
+cli.add_command(TiDB)


 if __name__ == "__main__":