vectordb-bench 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. vectordb_bench/backend/clients/__init__.py +65 -1
  2. vectordb_bench/backend/clients/api.py +2 -1
  3. vectordb_bench/backend/clients/chroma/chroma.py +2 -2
  4. vectordb_bench/backend/clients/clickhouse/cli.py +66 -0
  5. vectordb_bench/backend/clients/clickhouse/clickhouse.py +156 -0
  6. vectordb_bench/backend/clients/clickhouse/config.py +60 -0
  7. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +1 -1
  8. vectordb_bench/backend/clients/mariadb/cli.py +122 -0
  9. vectordb_bench/backend/clients/mariadb/config.py +73 -0
  10. vectordb_bench/backend/clients/mariadb/mariadb.py +208 -0
  11. vectordb_bench/backend/clients/milvus/cli.py +32 -0
  12. vectordb_bench/backend/clients/milvus/config.py +32 -0
  13. vectordb_bench/backend/clients/milvus/milvus.py +1 -1
  14. vectordb_bench/backend/clients/pgvector/cli.py +14 -3
  15. vectordb_bench/backend/clients/pgvector/config.py +22 -5
  16. vectordb_bench/backend/clients/pgvector/pgvector.py +62 -19
  17. vectordb_bench/backend/clients/pinecone/pinecone.py +1 -1
  18. vectordb_bench/backend/clients/qdrant_cloud/config.py +1 -9
  19. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +1 -1
  20. vectordb_bench/backend/clients/tidb/cli.py +98 -0
  21. vectordb_bench/backend/clients/tidb/config.py +46 -0
  22. vectordb_bench/backend/clients/tidb/tidb.py +233 -0
  23. vectordb_bench/backend/clients/vespa/cli.py +47 -0
  24. vectordb_bench/backend/clients/vespa/config.py +51 -0
  25. vectordb_bench/backend/clients/vespa/util.py +15 -0
  26. vectordb_bench/backend/clients/vespa/vespa.py +249 -0
  27. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +1 -1
  28. vectordb_bench/cli/cli.py +20 -17
  29. vectordb_bench/cli/vectordbbench.py +8 -0
  30. vectordb_bench/frontend/config/dbCaseConfigs.py +147 -0
  31. vectordb_bench/frontend/config/styles.py +4 -0
  32. vectordb_bench/models.py +8 -6
  33. {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/METADATA +22 -3
  34. {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/RECORD +38 -25
  35. {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/WHEEL +1 -1
  36. {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/entry_points.txt +0 -0
  37. {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info/licenses}/LICENSE +0 -0
  38. {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,98 @@
1
+ from typing import Annotated, Unpack
2
+
3
+ import click
4
+ from pydantic import SecretStr
5
+
6
+ from vectordb_bench.backend.clients import DB
7
+
8
+ from ....cli.cli import CommonTypedDict, cli, click_parameter_decorators_from_typed_dict, run
9
+
10
+
11
class TiDBTypedDict(CommonTypedDict):
    """Command-line options accepted by the TiDB benchmark command.

    Note: the ``user_name`` field is declared through the ``--username`` option;
    the command reads the value back under the key ``"username"``.
    """

    # Account used to connect.
    user_name: Annotated[
        str,
        click.option("--username", default="root", type=str, required=True, show_default=True, help="Username"),
    ]
    # Plain-text password; wrapped in SecretStr by the command.
    password: Annotated[
        str,
        click.option("--password", default="", type=str, show_default=True, help="Password"),
    ]
    # Server address.
    host: Annotated[
        str,
        click.option("--host", default="127.0.0.1", type=str, required=True, show_default=True, help="Db host"),
    ]
    # Server port.
    port: Annotated[
        int,
        click.option("--port", default=4000, type=int, required=True, show_default=True, help="Db Port"),
    ]
    # Database (schema) that holds the benchmark table.
    db_name: Annotated[
        str,
        click.option("--db-name", default="test", type=str, required=True, show_default=True, help="Db name"),
    ]
    # Boolean on/off switch pair for SSL.
    ssl: Annotated[
        bool,
        click.option(
            "--ssl/--no-ssl",
            is_flag=True,
            default=False,
            show_default=True,
            help="Enable or disable SSL, for TiDB Serverless SSL must be enabled",
        ),
    ]
76
+
77
+
78
@cli.command()
@click_parameter_decorators_from_typed_dict(TiDBTypedDict)
def TiDB(
    **parameters: Unpack[TiDBTypedDict],
):
    """Run the benchmark against a TiDB server built from the CLI options."""
    from .config import TiDBConfig, TiDBIndexConfig

    # The user name is read under "username" (the --username option name),
    # not under the TypedDict field name ``user_name``.
    connection_config = TiDBConfig(
        db_label=parameters["db_label"],
        user_name=parameters["username"],
        password=SecretStr(parameters["password"]),
        host=parameters["host"],
        port=parameters["port"],
        db_name=parameters["db_name"],
        ssl=parameters["ssl"],
    )
    run(db=DB.TiDB, db_config=connection_config, db_case_config=TiDBIndexConfig(), **parameters)
@@ -0,0 +1,46 @@
1
+ from pydantic import BaseModel, SecretStr
2
+
3
+ from ..api import DBCaseConfig, DBConfig, MetricType
4
+
5
+
6
class TiDBConfig(DBConfig):
    """Connection settings for a TiDB server.

    ``to_dict()`` renders them as client connection keyword arguments
    (presumably fed to ``pymysql.connect`` by the TiDB client — verify at call site).
    """

    user_name: str = "root"
    password: SecretStr
    host: str = "127.0.0.1"
    port: int = 4000
    db_name: str = "test"
    ssl: bool = False

    def to_dict(self) -> dict:
        """Return connection kwargs with the password revealed in plain text."""
        connect_kwargs = {
            "host": self.host,
            "port": self.port,
            "user": self.user_name,
            "password": self.password.get_secret_value(),
            "database": self.db_name,
        }
        # A single switch drives both certificate and hostname verification.
        connect_kwargs["ssl_verify_cert"] = self.ssl
        connect_kwargs["ssl_verify_identity"] = self.ssl
        return connect_kwargs
25
+
26
+
27
+ class TiDBIndexConfig(BaseModel, DBCaseConfig):
28
+ metric_type: MetricType | None = None
29
+
30
+ def get_metric_fn(self) -> str:
31
+ if self.metric_type == MetricType.L2:
32
+ return "vec_l2_distance"
33
+ if self.metric_type == MetricType.COSINE:
34
+ return "vec_cosine_distance"
35
+ msg = f"Unsupported metric type: {self.metric_type}"
36
+ raise ValueError(msg)
37
+
38
+ def index_param(self) -> dict:
39
+ return {
40
+ "metric_fn": self.get_metric_fn(),
41
+ }
42
+
43
+ def search_param(self) -> dict:
44
+ return {
45
+ "metric_fn": self.get_metric_fn(),
46
+ }
@@ -0,0 +1,233 @@
1
+ import concurrent.futures
2
+ import io
3
+ import logging
4
+ import time
5
+ from contextlib import contextmanager
6
+ from typing import Any
7
+
8
+ import pymysql
9
+
10
+ from ..api import VectorDB
11
+ from .config import TiDBIndexConfig
12
+
13
log = logging.getLogger(__name__)


class TiDB(VectorDB):
    """VectorDB client that stores vectors in a TiDB table with a vector index.

    The table has an ``id BIGINT PRIMARY KEY`` column and an ``embedding VECTOR(dim)``
    column indexed with the configured distance function. Every statement runs on a
    ``pymysql`` connection opened from ``db_config``.
    """

    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: TiDBIndexConfig,
        collection_name: str = "vector_bench_test",
        drop_old: bool = False,
        **kwargs,
    ):
        """Store settings and, if requested, recreate the benchmark table.

        Args:
            dim: Vector dimensionality used for the ``VECTOR(dim)`` column.
            db_config: Keyword arguments for ``pymysql.connect``; must contain
                ``"database"`` (read by the optimize() helpers).
            db_case_config: Provides the distance function name via ``search_param()``
                and ``index_param()``.
            collection_name: Name of the table to use.
            drop_old: When true, drop any existing table and create a fresh one.
        """
        self.name = "TiDB"
        self.db_config = db_config
        self.case_config = db_case_config
        self.table_name = collection_name
        self.dim = dim
        self.conn = None  # Set while inside init()
        self.cursor = None  # Set while inside init()

        self.search_fn = db_case_config.search_param()["metric_fn"]

        if drop_old:
            self._drop_table()
            # CREATE TABLE has no IF NOT EXISTS, so it only runs after a drop.
            self._create_table()

    @contextmanager
    def init(self):
        """Hold one connection/cursor open (used by search_embedding) for the with-block."""
        with self._get_connection() as (conn, cursor):
            self.conn = conn
            self.cursor = cursor
            try:
                yield
            finally:
                self.conn = None
                self.cursor = None

    @contextmanager
    def _get_connection(self):
        """Yield ``(conn, cursor)`` for a fresh connection; both are closed on exit."""
        with pymysql.connect(**self.db_config) as conn:
            # ``autocommit`` is a method on pymysql connections; assigning to it would
            # shadow the method with a bool and break later internal calls (e.g. reconnect).
            conn.autocommit(False)
            with conn.cursor() as cursor:
                yield conn, cursor

    def _drop_table(self):
        """Drop the benchmark table if it exists; log and re-raise on failure."""
        try:
            with self._get_connection() as (conn, cursor):
                cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
                conn.commit()
        except Exception as e:
            log.warning("Failed to drop table: %s error: %s", self.table_name, e)
            raise

    def _create_table(self):
        """Create the table with a vector index on the configured distance function."""
        try:
            index_param = self.case_config.index_param()
            with self._get_connection() as (conn, cursor):
                cursor.execute(
                    f"""
                    CREATE TABLE {self.table_name} (
                        id BIGINT PRIMARY KEY,
                        embedding VECTOR({self.dim}) NOT NULL,
                        VECTOR INDEX (({index_param["metric_fn"]}(embedding)))
                    );
                    """
                )
                conn.commit()
        except Exception as e:
            log.warning("Failed to create table: %s error: %s", self.table_name, e)
            raise

    def ready_to_load(self) -> bool:
        # No preparation is performed; this returns None despite the annotation.
        pass

    def optimize(self, data_size: int | None = None) -> None:
        """Block until TiFlash has replicated, compacted and fully indexed the table.

        Polls every 2 seconds with no timeout.
        """
        while True:
            progress = self._optimize_check_tiflash_replica_progress()
            if progress == 1:
                break
            # PROGRESS is presumably a fraction (it is compared against 1); %s avoids
            # the %d truncation that would always log 0 until completion.
            log.info("Data replication not ready, progress: %s", progress)
            time.sleep(2)

        log.info("Waiting TiFlash to catch up...")
        self._optimize_wait_tiflash_catch_up()

        log.info("Start compacting TiFlash replica...")
        self._optimize_compact_tiflash()

        log.info("Waiting index build to finish...")
        log_reduce_seq = 0
        while True:
            pending_rows = self._optimize_get_tiflash_index_pending_rows()
            if pending_rows > 0:
                # Log only every 15th poll to keep output readable.
                if log_reduce_seq % 15 == 0:
                    log.info("Index not fully built, pending rows: %d", pending_rows)
                log_reduce_seq += 1
                time.sleep(2)
            else:
                break

        log.info("Index build finished successfully.")

    def _optimize_check_tiflash_replica_progress(self):
        """Return the PROGRESS column of this table's row in ``tiflash_replica``."""
        try:
            database = self.db_config["database"]
            with self._get_connection() as (_, cursor):
                # Bound parameters: the names come from configuration, and double-quoted
                # literals would be read as identifiers under ANSI_QUOTES.
                cursor.execute(
                    """
                    SELECT PROGRESS FROM information_schema.tiflash_replica
                    WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
                    """,
                    (database, self.table_name),
                )
                result = cursor.fetchone()
                return result[0]
        except Exception as e:
            log.warning("Failed to check TiFlash replica progress: %s", e)
            raise

    def _optimize_wait_tiflash_catch_up(self):
        """Run a COUNT(*) allowing the TiFlash engine so reads wait for it; return the count."""
        try:
            with self._get_connection() as (conn, cursor):
                cursor.execute('SET @@TIDB_ISOLATION_READ_ENGINES="tidb,tiflash"')
                conn.commit()
                cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")  # noqa: S608
                result = cursor.fetchone()
                return result[0]
        except Exception as e:
            log.warning("Failed to wait TiFlash to catch up: %s", e)
            raise

    def _optimize_compact_tiflash(self):
        """Issue ``ALTER TABLE ... COMPACT`` on the benchmark table."""
        try:
            with self._get_connection() as (conn, cursor):
                cursor.execute(f"ALTER TABLE {self.table_name} COMPACT")
                conn.commit()
        except Exception as e:
            log.warning("Failed to compact table: %s", e)
            raise

    def _optimize_get_tiflash_index_pending_rows(self):
        """Return the summed ROWS_STABLE_NOT_INDEXED for this table (0 if no rows match)."""
        try:
            database = self.db_config["database"]
            with self._get_connection() as (_, cursor):
                cursor.execute(
                    """
                    SELECT SUM(ROWS_STABLE_NOT_INDEXED)
                    FROM information_schema.tiflash_indexes
                    WHERE TIDB_DATABASE = %s AND TIDB_TABLE = %s
                    """,
                    (database, self.table_name),
                )
                result = cursor.fetchone()
                # SUM() over zero matching rows yields NULL; without this, optimize()
                # would crash on ``None > 0``. NOTE(review): assumes "no index rows"
                # means nothing is pending.
                return result[0] or 0
        except Exception as e:
            log.warning("Failed to read TiFlash index pending rows: %s", e)
            raise

    def _insert_embeddings_serial(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        offset: int,
        size: int,
    ) -> None:
        """Insert rows ``[offset, offset + size)`` with one multi-row INSERT and commit.

        Each vector is sent as its ``str()`` form (e.g. ``"[0.1, 0.2]"``).
        """
        try:
            with self._get_connection() as (conn, cursor):
                buf = io.StringIO()
                buf.write(f"INSERT INTO {self.table_name} (id, embedding) VALUES ")  # noqa: S608
                for i in range(offset, offset + size):
                    if i > offset:
                        buf.write(",")
                    buf.write(f'({metadata[i]}, "{embeddings[i]!s}")')
                cursor.execute(buf.getvalue())
                conn.commit()
        except Exception as e:
            log.warning("Failed to insert data into table: %s", e)
            raise

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs: Any,
    ) -> tuple[int, Exception | None]:
        """Insert all rows using up to 10 parallel batch inserts.

        Returns:
            ``(len(metadata), None)`` on success.

        Raises:
            The first exception raised by any batch; batches that have not started
            yet are cancelled first.
        """
        total = len(embeddings)
        if total == 0:
            return len(metadata), None

        workers = 10
        # Avoid exceeding MAX_ALLOWED_PACKET (default=64MB); 24 is presumably an upper
        # bound on the textual size of one float in the generated SQL.
        max_batch_size = 64 * 1024 * 1024 // 24 // self.dim
        # Clamp to >= 1: fewer rows than workers (or a huge dim) would otherwise yield 0,
        # and range(..., 0) raises ValueError.
        batch_size = max(1, min(total // workers, max_batch_size))

        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
            futures = [
                executor.submit(
                    self._insert_embeddings_serial,
                    embeddings,
                    metadata,
                    offset,
                    min(batch_size, total - offset),
                )
                for offset in range(0, total, batch_size)
            ]
            done, _ = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_EXCEPTION)
            # Cancel queued batches BEFORE re-raising; otherwise the exception skips the
            # cancellation and leaving the with-block waits for every remaining batch.
            executor.shutdown(wait=False, cancel_futures=True)
            for future in done:
                future.result()
        return len(metadata), None

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
        **kwargs: Any,
    ) -> list[int]:
        """Return ids of the ``k`` rows closest to ``query`` by ``self.search_fn``.

        Must be called inside ``init()``, which provides ``self.cursor``.
        ``filters`` and ``timeout`` are accepted but not used.
        """
        self.cursor.execute(
            f"""
            SELECT id FROM {self.table_name}
            ORDER BY {self.search_fn}(embedding, "{query!s}") LIMIT {k};
            """  # noqa: S608
        )
        result = self.cursor.fetchall()
        return [int(i[0]) for i in result]
@@ -0,0 +1,47 @@
1
+ from typing import Annotated, Unpack
2
+
3
+ import click
4
+ from pydantic import SecretStr
5
+
6
+ from vectordb_bench.backend.clients import DB
7
+ from vectordb_bench.cli.cli import (
8
+ CommonTypedDict,
9
+ HNSWFlavor1,
10
+ cli,
11
+ click_parameter_decorators_from_typed_dict,
12
+ run,
13
+ )
14
+
15
+
16
class VespaTypedDict(CommonTypedDict, HNSWFlavor1):
    """Command-line options specific to the Vespa benchmark command.

    The HNSW tuning keys read by the command (m / ef_construction / ef_search) are
    presumably supplied by ``HNSWFlavor1``.
    """

    # Base URL of the Vespa endpoint.
    uri: Annotated[
        str,
        click.option(
            "--uri",
            "-u",
            default="http://127.0.0.1",
            type=str,
            help="uri connection string",
        ),
    ]
    # Port of the Vespa endpoint.
    port: Annotated[
        int,
        click.option(
            "--port",
            "-p",
            default=8080,
            type=int,
            help="connection port",
        ),
    ]
    # Vector quantization mode.
    quantization: Annotated[
        str,
        click.option(
            "--quantization",
            default="none",
            type=click.Choice(["none", "binary"], case_sensitive=False),
        ),
    ]
28
+
29
+
30
@cli.command()
@click_parameter_decorators_from_typed_dict(VespaTypedDict)
def Vespa(**params: Unpack[VespaTypedDict]):
    """Run the benchmark against a Vespa endpoint using an HNSW index."""
    from .config import VespaConfig, VespaHNSWConfig

    # Forward only options with a truthy value; the others keep VespaHNSWConfig defaults.
    hnsw_overrides = {}
    for field_name, value in (
        ("quantization_type", params["quantization"]),
        ("M", params["m"]),
        ("efConstruction", params["ef_construction"]),
        ("ef", params["ef_search"]),
    ):
        if value:
            hnsw_overrides[field_name] = value

    connection = VespaConfig(url=SecretStr(params["uri"]), port=params["port"])
    run(
        db=DB.Vespa,
        db_config=connection,
        db_case_config=VespaHNSWConfig(**hnsw_overrides),
        **params,
    )
@@ -0,0 +1,51 @@
1
+ from typing import Literal, TypeAlias
2
+
3
+ from pydantic import BaseModel, SecretStr
4
+
5
+ from ..api import DBCaseConfig, DBConfig, MetricType
6
+
7
+ VespaMetric: TypeAlias = Literal["euclidean", "angular", "dotproduct", "prenormalized-angular", "hamming", "geodegrees"]
8
+
9
+ VespaQuantizationType: TypeAlias = Literal["none", "binary"]
10
+
11
+
12
class VespaConfig(DBConfig):
    """Connection settings for a Vespa endpoint."""

    # Default wrapped in SecretStr explicitly: pydantic does not validate/coerce field
    # defaults unless default validation is enabled, so a bare str default would stay
    # a str and to_dict() would fail on ``.get_secret_value()`` when url is omitted.
    url: SecretStr = SecretStr("http://127.0.0.1")
    port: int = 8080

    def to_dict(self):
        """Return ``{"url": <plain-text url>, "port": <port>}``."""
        return {
            "url": self.url.get_secret_value(),
            "port": self.port,
        }
21
+
22
+
23
class VespaHNSWConfig(BaseModel, DBCaseConfig):
    """HNSW index settings for Vespa.

    index_param() emits ``M`` as ``max_links_per_node`` and ``efConstruction`` as
    ``neighbors_to_explore_at_insert``. ``ef`` and ``quantization_type`` are stored
    on the model but not emitted by index_param()/search_param().
    """

    metric_type: MetricType = MetricType.COSINE
    quantization_type: VespaQuantizationType = "none"
    M: int = 16
    efConstruction: int = 200
    ef: int = 100

    def index_param(self) -> dict:
        """Index-time HNSW parameters, keyed by Vespa's setting names."""
        params = {}
        params["distance_metric"] = self.parse_metric(self.metric_type)
        params["max_links_per_node"] = self.M
        params["neighbors_to_explore_at_insert"] = self.efConstruction
        return params

    def search_param(self) -> dict:
        """No search-time parameters are emitted."""
        return dict()

    def parse_metric(self, metric_type: MetricType) -> VespaMetric:
        """Translate a benchmark MetricType to a Vespa distance-metric name.

        Raises:
            NotImplementedError: for metrics without a mapping.
        """
        if metric_type == MetricType.COSINE:
            return "angular"
        if metric_type == MetricType.L2:
            return "euclidean"
        if metric_type == MetricType.DP or metric_type == MetricType.IP:
            return "dotproduct"
        if metric_type == MetricType.HAMMING:
            return "hamming"
        raise NotImplementedError
@@ -0,0 +1,15 @@
1
+ """Utility functions for supporting binary quantization
2
+
3
+ From https://docs.vespa.ai/en/binarizing-vectors.html#appendix-conversion-to-int8
4
+ """
5
+
6
+ import numpy as np
7
+
8
+
9
def binarize_tensor(tensor: list[float]) -> list[int]:
    """Threshold values at zero (> 0 -> 1) and pack the bits, MSB first, into bytes.

    Each packed byte is returned as a signed int8 value (e.g. 0b11111111 -> -1);
    a trailing partial byte is zero-padded.
    """
    positive = np.asarray(tensor) > 0
    packed = np.packbits(positive, axis=0)
    # Reinterpret the uint8 bytes as int8 (same bit pattern).
    return packed.view(np.int8).tolist()