vectordb-bench 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/backend/clients/__init__.py +33 -1
- vectordb_bench/backend/clients/api.py +1 -1
- vectordb_bench/backend/clients/chroma/chroma.py +2 -2
- vectordb_bench/backend/clients/clickhouse/cli.py +66 -0
- vectordb_bench/backend/clients/clickhouse/clickhouse.py +156 -0
- vectordb_bench/backend/clients/clickhouse/config.py +60 -0
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +1 -1
- vectordb_bench/backend/clients/mariadb/cli.py +60 -45
- vectordb_bench/backend/clients/mariadb/config.py +11 -9
- vectordb_bench/backend/clients/mariadb/mariadb.py +52 -58
- vectordb_bench/backend/clients/milvus/cli.py +1 -19
- vectordb_bench/backend/clients/milvus/config.py +0 -1
- vectordb_bench/backend/clients/milvus/milvus.py +1 -1
- vectordb_bench/backend/clients/pgvector/cli.py +1 -2
- vectordb_bench/backend/clients/pinecone/pinecone.py +1 -1
- vectordb_bench/backend/clients/qdrant_cloud/config.py +1 -9
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +1 -1
- vectordb_bench/backend/clients/tidb/config.py +6 -9
- vectordb_bench/backend/clients/tidb/tidb.py +17 -18
- vectordb_bench/backend/clients/vespa/cli.py +47 -0
- vectordb_bench/backend/clients/vespa/config.py +51 -0
- vectordb_bench/backend/clients/vespa/util.py +15 -0
- vectordb_bench/backend/clients/vespa/vespa.py +249 -0
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +1 -1
- vectordb_bench/cli/cli.py +21 -17
- vectordb_bench/cli/vectordbbench.py +5 -1
- vectordb_bench/frontend/config/dbCaseConfigs.py +58 -7
- vectordb_bench/frontend/config/styles.py +2 -0
- vectordb_bench/models.py +5 -6
- {vectordb_bench-0.0.23.dist-info → vectordb_bench-0.0.25.dist-info}/METADATA +11 -3
- {vectordb_bench-0.0.23.dist-info → vectordb_bench-0.0.25.dist-info}/RECORD +35 -28
- {vectordb_bench-0.0.23.dist-info → vectordb_bench-0.0.25.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.23.dist-info → vectordb_bench-0.0.25.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.23.dist-info → vectordb_bench-0.0.25.dist-info/licenses}/LICENSE +0 -0
- {vectordb_bench-0.0.23.dist-info → vectordb_bench-0.0.25.dist-info}/top_level.txt +0 -0
@@ -1,27 +1,25 @@
|
|
1
|
-
from ..api import VectorDB
|
2
|
-
|
3
1
|
import logging
|
4
2
|
from contextlib import contextmanager
|
5
|
-
from typing import Any, Optional, Tuple
|
6
|
-
from ..api import VectorDB
|
7
|
-
from .config import MariaDBConfigDict, MariaDBIndexConfig
|
8
|
-
import numpy as np
|
9
3
|
|
10
4
|
import mariadb
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from ..api import VectorDB
|
8
|
+
from .config import MariaDBConfigDict, MariaDBIndexConfig
|
11
9
|
|
12
10
|
log = logging.getLogger(__name__)
|
13
11
|
|
12
|
+
|
14
13
|
class MariaDB(VectorDB):
|
15
14
|
def __init__(
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
15
|
+
self,
|
16
|
+
dim: int,
|
17
|
+
db_config: MariaDBConfigDict,
|
18
|
+
db_case_config: MariaDBIndexConfig,
|
19
|
+
collection_name: str = "vec_collection",
|
20
|
+
drop_old: bool = False,
|
21
|
+
**kwargs,
|
22
|
+
):
|
25
23
|
self.name = "MariaDB"
|
26
24
|
self.db_config = db_config
|
27
25
|
self.case_config = db_case_config
|
@@ -31,7 +29,7 @@ class MariaDB(VectorDB):
|
|
31
29
|
|
32
30
|
# construct basic units
|
33
31
|
self.conn, self.cursor = self._create_connection(**self.db_config)
|
34
|
-
|
32
|
+
|
35
33
|
if drop_old:
|
36
34
|
self._drop_db()
|
37
35
|
self._create_db_table(dim)
|
@@ -41,9 +39,8 @@ class MariaDB(VectorDB):
|
|
41
39
|
self.cursor = None
|
42
40
|
self.conn = None
|
43
41
|
|
44
|
-
|
45
42
|
@staticmethod
|
46
|
-
def _create_connection(**kwargs) ->
|
43
|
+
def _create_connection(**kwargs) -> tuple[mariadb.Connection, mariadb.Cursor]:
|
47
44
|
conn = mariadb.connect(**kwargs)
|
48
45
|
cursor = conn.cursor()
|
49
46
|
|
@@ -52,7 +49,6 @@ class MariaDB(VectorDB):
|
|
52
49
|
|
53
50
|
return conn, cursor
|
54
51
|
|
55
|
-
|
56
52
|
def _drop_db(self):
|
57
53
|
assert self.conn is not None, "Connection is not initialized"
|
58
54
|
assert self.cursor is not None, "Cursor is not initialized"
|
@@ -77,24 +73,23 @@ class MariaDB(VectorDB):
|
|
77
73
|
log.info(f"{self.name} client create table : {self.table_name}")
|
78
74
|
self.cursor.execute(f"USE {self.db_name}")
|
79
75
|
|
80
|
-
self.cursor.execute(
|
76
|
+
self.cursor.execute(
|
77
|
+
f"""
|
81
78
|
CREATE TABLE {self.table_name} (
|
82
79
|
id INT PRIMARY KEY,
|
83
80
|
v VECTOR({self.dim}) NOT NULL
|
84
81
|
) ENGINE={index_param["storage_engine"]}
|
85
|
-
"""
|
82
|
+
"""
|
83
|
+
)
|
86
84
|
self.cursor.execute("COMMIT")
|
87
85
|
|
88
86
|
except Exception as e:
|
89
|
-
log.warning(
|
90
|
-
f"Failed to create table: {self.table_name} error: {e}"
|
91
|
-
)
|
87
|
+
log.warning(f"Failed to create table: {self.table_name} error: {e}")
|
92
88
|
raise e from None
|
93
89
|
|
94
|
-
|
95
90
|
@contextmanager
|
96
|
-
def init(self)
|
97
|
-
"""
|
91
|
+
def init(self):
|
92
|
+
"""create and destory connections to database.
|
98
93
|
|
99
94
|
Examples:
|
100
95
|
>>> with self.init():
|
@@ -109,15 +104,21 @@ class MariaDB(VectorDB):
|
|
109
104
|
self.cursor.execute("SET GLOBAL max_allowed_packet = 1073741824")
|
110
105
|
|
111
106
|
if index_param["index_type"] == "HNSW":
|
112
|
-
if index_param["max_cache_size"]
|
113
|
-
self.cursor.execute(f"SET GLOBAL mhnsw_max_cache_size = {index_param[
|
114
|
-
if search_param["ef_search"]
|
115
|
-
self.cursor.execute(f"SET mhnsw_ef_search = {search_param[
|
107
|
+
if index_param["max_cache_size"] is not None:
|
108
|
+
self.cursor.execute(f"SET GLOBAL mhnsw_max_cache_size = {index_param['max_cache_size']}")
|
109
|
+
if search_param["ef_search"] is not None:
|
110
|
+
self.cursor.execute(f"SET mhnsw_ef_search = {search_param['ef_search']}")
|
116
111
|
self.cursor.execute("COMMIT")
|
117
112
|
|
118
|
-
self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)"
|
119
|
-
self.select_sql =
|
120
|
-
|
113
|
+
self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)" # noqa: S608
|
114
|
+
self.select_sql = (
|
115
|
+
f"SELECT id FROM {self.db_name}.{self.table_name}" # noqa: S608
|
116
|
+
f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d"
|
117
|
+
)
|
118
|
+
self.select_sql_with_filter = (
|
119
|
+
f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %d " # noqa: S608
|
120
|
+
f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d"
|
121
|
+
)
|
121
122
|
|
122
123
|
try:
|
123
124
|
yield
|
@@ -126,7 +127,6 @@ class MariaDB(VectorDB):
|
|
126
127
|
self.conn.close()
|
127
128
|
self.cursor = None
|
128
129
|
self.conn = None
|
129
|
-
|
130
130
|
|
131
131
|
def ready_to_load(self) -> bool:
|
132
132
|
pass
|
@@ -139,33 +139,31 @@ class MariaDB(VectorDB):
|
|
139
139
|
|
140
140
|
try:
|
141
141
|
index_options = f"DISTANCE={index_param['metric_type']}"
|
142
|
-
if index_param["index_type"] == "HNSW" and index_param["M"]
|
142
|
+
if index_param["index_type"] == "HNSW" and index_param["M"] is not None:
|
143
143
|
index_options += f" M={index_param['M']}"
|
144
144
|
|
145
|
-
self.cursor.execute(
|
145
|
+
self.cursor.execute(
|
146
|
+
f"""
|
146
147
|
ALTER TABLE {self.db_name}.{self.table_name}
|
147
148
|
ADD VECTOR KEY v(v) {index_options}
|
148
|
-
"""
|
149
|
+
"""
|
150
|
+
)
|
149
151
|
self.cursor.execute("COMMIT")
|
150
152
|
|
151
153
|
except Exception as e:
|
152
|
-
log.warning(
|
153
|
-
f"Failed to create index: {self.table_name} error: {e}"
|
154
|
-
)
|
154
|
+
log.warning(f"Failed to create index: {self.table_name} error: {e}")
|
155
155
|
raise e from None
|
156
156
|
|
157
|
-
pass
|
158
|
-
|
159
157
|
@staticmethod
|
160
|
-
def vector_to_hex(v):
|
161
|
-
return np.array(v,
|
158
|
+
def vector_to_hex(v): # noqa: ANN001
|
159
|
+
return np.array(v, "float32").tobytes()
|
162
160
|
|
163
161
|
def insert_embeddings(
|
164
162
|
self,
|
165
163
|
embeddings: list[list[float]],
|
166
164
|
metadata: list[int],
|
167
|
-
**kwargs
|
168
|
-
) ->
|
165
|
+
**kwargs,
|
166
|
+
) -> tuple[int, Exception]:
|
169
167
|
"""Insert embeddings into the database.
|
170
168
|
Should call self.init() first.
|
171
169
|
"""
|
@@ -178,7 +176,7 @@ class MariaDB(VectorDB):
|
|
178
176
|
|
179
177
|
batch_data = []
|
180
178
|
for i, row in enumerate(metadata_arr):
|
181
|
-
batch_data.append((int(row), self.vector_to_hex(embeddings_arr[i])))
|
179
|
+
batch_data.append((int(row), self.vector_to_hex(embeddings_arr[i])))
|
182
180
|
|
183
181
|
self.cursor.executemany(self.insert_sql, batch_data)
|
184
182
|
self.cursor.execute("COMMIT")
|
@@ -186,11 +184,8 @@ class MariaDB(VectorDB):
|
|
186
184
|
|
187
185
|
return len(metadata), None
|
188
186
|
except Exception as e:
|
189
|
-
log.warning(
|
190
|
-
f"Failed to insert data into Vector table ({self.table_name}), error: {e}"
|
191
|
-
)
|
187
|
+
log.warning(f"Failed to insert data into Vector table ({self.table_name}), error: {e}")
|
192
188
|
return 0, e
|
193
|
-
|
194
189
|
|
195
190
|
def search_embedding(
|
196
191
|
self,
|
@@ -198,17 +193,16 @@ class MariaDB(VectorDB):
|
|
198
193
|
k: int = 100,
|
199
194
|
filters: dict | None = None,
|
200
195
|
timeout: int | None = None,
|
201
|
-
**kwargs
|
202
|
-
) ->
|
196
|
+
**kwargs,
|
197
|
+
) -> list[int]:
|
203
198
|
assert self.conn is not None, "Connection is not initialized"
|
204
199
|
assert self.cursor is not None, "Cursor is not initialized"
|
205
200
|
|
206
|
-
search_param = self.case_config.search_param()
|
201
|
+
search_param = self.case_config.search_param() # noqa: F841
|
207
202
|
|
208
203
|
if filters:
|
209
|
-
self.cursor.execute(self.select_sql_with_filter, (filters.get(
|
204
|
+
self.cursor.execute(self.select_sql_with_filter, (filters.get("id"), self.vector_to_hex(query), k))
|
210
205
|
else:
|
211
206
|
self.cursor.execute(self.select_sql, (self.vector_to_hex(query), k))
|
212
207
|
|
213
|
-
return [id for id, in self.cursor.fetchall()]
|
214
|
-
|
208
|
+
return [id for (id,) in self.cursor.fetchall()] # noqa: A001
|
@@ -194,25 +194,6 @@ def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]):
|
|
194
194
|
**parameters,
|
195
195
|
)
|
196
196
|
|
197
|
-
@cli.command()
|
198
|
-
@click_parameter_decorators_from_typed_dict(MilvusGPUBruteForceTypedDict)
|
199
|
-
def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]):
|
200
|
-
from .config import GPUBruteForceConfig, MilvusConfig
|
201
|
-
|
202
|
-
run(
|
203
|
-
db=DBTYPE,
|
204
|
-
db_config=MilvusConfig(
|
205
|
-
db_label=parameters["db_label"],
|
206
|
-
uri=SecretStr(parameters["uri"]),
|
207
|
-
user=parameters["user_name"],
|
208
|
-
password=SecretStr(parameters["password"]),
|
209
|
-
),
|
210
|
-
db_case_config=GPUBruteForceConfig(
|
211
|
-
metric_type=parameters["metric_type"],
|
212
|
-
limit=parameters["limit"], # top-k for search
|
213
|
-
),
|
214
|
-
**parameters,
|
215
|
-
)
|
216
197
|
|
217
198
|
class MilvusGPUBruteForceTypedDict(CommonTypedDict, MilvusTypedDict):
|
218
199
|
metric_type: Annotated[
|
@@ -224,6 +205,7 @@ class MilvusGPUBruteForceTypedDict(CommonTypedDict, MilvusTypedDict):
|
|
224
205
|
click.option("--limit", type=int, required=True, help="Top-k limit for search"),
|
225
206
|
]
|
226
207
|
|
208
|
+
|
227
209
|
@cli.command()
|
228
210
|
@click_parameter_decorators_from_typed_dict(MilvusGPUBruteForceTypedDict)
|
229
211
|
def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]):
|
@@ -155,7 +155,7 @@ class Milvus(VectorDB):
|
|
155
155
|
embeddings: Iterable[list[float]],
|
156
156
|
metadata: list[int],
|
157
157
|
**kwargs,
|
158
|
-
) ->
|
158
|
+
) -> tuple[int, Exception]:
|
159
159
|
"""Insert embeddings into Milvus. should call self.init() first"""
|
160
160
|
# use the first insert_embeddings to init collection
|
161
161
|
assert self.col is not None
|
@@ -18,8 +18,7 @@ from ....cli.cli import (
|
|
18
18
|
)
|
19
19
|
|
20
20
|
|
21
|
-
#
|
22
|
-
def set_default_quantized_fetch_limit(ctx: any, param: any, value: any):
|
21
|
+
def set_default_quantized_fetch_limit(ctx: any, param: any, value: any): # noqa: ARG001
|
23
22
|
if ctx.params.get("reranking") and value is None:
|
24
23
|
# ef_search is the default value for quantized_fetch_limit as it's bound by ef_search.
|
25
24
|
# 100 is default value for quantized_fetch_limit for IVFFlat.
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from pydantic import BaseModel, SecretStr
|
1
|
+
from pydantic import BaseModel, SecretStr
|
2
2
|
|
3
3
|
from ..api import DBCaseConfig, DBConfig, MetricType
|
4
4
|
|
@@ -20,14 +20,6 @@ class QdrantConfig(DBConfig):
|
|
20
20
|
"url": self.url.get_secret_value(),
|
21
21
|
}
|
22
22
|
|
23
|
-
@validator("*")
|
24
|
-
def not_empty_field(cls, v: any, field: any):
|
25
|
-
if field.name in ["api_key", "db_label"]:
|
26
|
-
return v
|
27
|
-
if isinstance(v, str | SecretStr) and len(v) == 0:
|
28
|
-
raise ValueError("Empty string!")
|
29
|
-
return v
|
30
|
-
|
31
23
|
|
32
24
|
class QdrantIndexConfig(BaseModel, DBCaseConfig):
|
33
25
|
metric_type: MetricType | None = None
|
@@ -111,7 +111,7 @@ class QdrantCloud(VectorDB):
|
|
111
111
|
embeddings: list[list[float]],
|
112
112
|
metadata: list[int],
|
113
113
|
**kwargs,
|
114
|
-
) ->
|
114
|
+
) -> tuple[int, Exception]:
|
115
115
|
"""Insert embeddings into Milvus. should call self.init() first"""
|
116
116
|
assert self.qdrant_client is not None
|
117
117
|
try:
|
@@ -1,5 +1,6 @@
|
|
1
|
-
from pydantic import
|
2
|
-
|
1
|
+
from pydantic import BaseModel, SecretStr
|
2
|
+
|
3
|
+
from ..api import DBCaseConfig, DBConfig, MetricType
|
3
4
|
|
4
5
|
|
5
6
|
class TiDBConfig(DBConfig):
|
@@ -10,10 +11,6 @@ class TiDBConfig(DBConfig):
|
|
10
11
|
db_name: str = "test"
|
11
12
|
ssl: bool = False
|
12
13
|
|
13
|
-
@validator("*")
|
14
|
-
def not_empty_field(cls, v: any, field: any):
|
15
|
-
return v
|
16
|
-
|
17
14
|
def to_dict(self) -> dict:
|
18
15
|
pwd_str = self.password.get_secret_value()
|
19
16
|
return {
|
@@ -33,10 +30,10 @@ class TiDBIndexConfig(BaseModel, DBCaseConfig):
|
|
33
30
|
def get_metric_fn(self) -> str:
|
34
31
|
if self.metric_type == MetricType.L2:
|
35
32
|
return "vec_l2_distance"
|
36
|
-
|
33
|
+
if self.metric_type == MetricType.COSINE:
|
37
34
|
return "vec_cosine_distance"
|
38
|
-
|
39
|
-
|
35
|
+
msg = f"Unsupported metric type: {self.metric_type}"
|
36
|
+
raise ValueError(msg)
|
40
37
|
|
41
38
|
def index_param(self) -> dict:
|
42
39
|
return {
|
@@ -3,7 +3,7 @@ import io
|
|
3
3
|
import logging
|
4
4
|
import time
|
5
5
|
from contextlib import contextmanager
|
6
|
-
from typing import Any
|
6
|
+
from typing import Any
|
7
7
|
|
8
8
|
import pymysql
|
9
9
|
|
@@ -62,7 +62,7 @@ class TiDB(VectorDB):
|
|
62
62
|
conn.commit()
|
63
63
|
except Exception as e:
|
64
64
|
log.warning("Failed to drop table: %s error: %s", self.table_name, e)
|
65
|
-
raise
|
65
|
+
raise
|
66
66
|
|
67
67
|
def _create_table(self):
|
68
68
|
try:
|
@@ -80,7 +80,7 @@ class TiDB(VectorDB):
|
|
80
80
|
conn.commit()
|
81
81
|
except Exception as e:
|
82
82
|
log.warning("Failed to create table: %s error: %s", self.table_name, e)
|
83
|
-
raise
|
83
|
+
raise
|
84
84
|
|
85
85
|
def ready_to_load(self) -> bool:
|
86
86
|
pass
|
@@ -122,25 +122,25 @@ class TiDB(VectorDB):
|
|
122
122
|
f"""
|
123
123
|
SELECT PROGRESS FROM information_schema.tiflash_replica
|
124
124
|
WHERE TABLE_SCHEMA = "{database}" AND TABLE_NAME = "{self.table_name}"
|
125
|
-
"""
|
125
|
+
""" # noqa: S608
|
126
126
|
)
|
127
127
|
result = cursor.fetchone()
|
128
128
|
return result[0]
|
129
129
|
except Exception as e:
|
130
130
|
log.warning("Failed to check TiFlash replica progress: %s", e)
|
131
|
-
raise
|
131
|
+
raise
|
132
132
|
|
133
133
|
def _optimize_wait_tiflash_catch_up(self):
|
134
134
|
try:
|
135
135
|
with self._get_connection() as (conn, cursor):
|
136
136
|
cursor.execute('SET @@TIDB_ISOLATION_READ_ENGINES="tidb,tiflash"')
|
137
137
|
conn.commit()
|
138
|
-
cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")
|
138
|
+
cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}") # noqa: S608
|
139
139
|
result = cursor.fetchone()
|
140
140
|
return result[0]
|
141
141
|
except Exception as e:
|
142
142
|
log.warning("Failed to wait TiFlash to catch up: %s", e)
|
143
|
-
raise
|
143
|
+
raise
|
144
144
|
|
145
145
|
def _optimize_compact_tiflash(self):
|
146
146
|
try:
|
@@ -149,7 +149,7 @@ class TiDB(VectorDB):
|
|
149
149
|
conn.commit()
|
150
150
|
except Exception as e:
|
151
151
|
log.warning("Failed to compact table: %s", e)
|
152
|
-
raise
|
152
|
+
raise
|
153
153
|
|
154
154
|
def _optimize_get_tiflash_index_pending_rows(self):
|
155
155
|
try:
|
@@ -160,13 +160,13 @@ class TiDB(VectorDB):
|
|
160
160
|
SELECT SUM(ROWS_STABLE_NOT_INDEXED)
|
161
161
|
FROM information_schema.tiflash_indexes
|
162
162
|
WHERE TIDB_DATABASE = "{database}" AND TIDB_TABLE = "{self.table_name}"
|
163
|
-
"""
|
163
|
+
""" # noqa: S608
|
164
164
|
)
|
165
165
|
result = cursor.fetchone()
|
166
166
|
return result[0]
|
167
167
|
except Exception as e:
|
168
168
|
log.warning("Failed to read TiFlash index pending rows: %s", e)
|
169
|
-
raise
|
169
|
+
raise
|
170
170
|
|
171
171
|
def _insert_embeddings_serial(
|
172
172
|
self,
|
@@ -178,29 +178,28 @@ class TiDB(VectorDB):
|
|
178
178
|
try:
|
179
179
|
with self._get_connection() as (conn, cursor):
|
180
180
|
buf = io.StringIO()
|
181
|
-
buf.write(f"INSERT INTO {self.table_name} (id, embedding) VALUES ")
|
181
|
+
buf.write(f"INSERT INTO {self.table_name} (id, embedding) VALUES ") # noqa: S608
|
182
182
|
for i in range(offset, offset + size):
|
183
183
|
if i > offset:
|
184
184
|
buf.write(",")
|
185
|
-
buf.write(f'({metadata[i]}, "{
|
185
|
+
buf.write(f'({metadata[i]}, "{embeddings[i]!s}")')
|
186
186
|
cursor.execute(buf.getvalue())
|
187
187
|
conn.commit()
|
188
188
|
except Exception as e:
|
189
189
|
log.warning("Failed to insert data into table: %s", e)
|
190
|
-
raise
|
190
|
+
raise
|
191
191
|
|
192
192
|
def insert_embeddings(
|
193
193
|
self,
|
194
194
|
embeddings: list[list[float]],
|
195
195
|
metadata: list[int],
|
196
196
|
**kwargs: Any,
|
197
|
-
) ->
|
197
|
+
) -> tuple[int, Exception]:
|
198
198
|
workers = 10
|
199
199
|
# Avoid exceeding MAX_ALLOWED_PACKET (default=64MB)
|
200
200
|
max_batch_size = 64 * 1024 * 1024 // 24 // self.dim
|
201
201
|
batch_size = len(embeddings) // workers
|
202
|
-
|
203
|
-
batch_size = max_batch_size
|
202
|
+
batch_size = min(batch_size, max_batch_size)
|
204
203
|
with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
|
205
204
|
futures = []
|
206
205
|
for i in range(0, len(embeddings), batch_size):
|
@@ -227,8 +226,8 @@ class TiDB(VectorDB):
|
|
227
226
|
self.cursor.execute(
|
228
227
|
f"""
|
229
228
|
SELECT id FROM {self.table_name}
|
230
|
-
ORDER BY {self.search_fn}(embedding, "{
|
231
|
-
"""
|
229
|
+
ORDER BY {self.search_fn}(embedding, "{query!s}") LIMIT {k};
|
230
|
+
""" # noqa: S608
|
232
231
|
)
|
233
232
|
result = self.cursor.fetchall()
|
234
233
|
return [int(i[0]) for i in result]
|
@@ -0,0 +1,47 @@
|
|
1
|
+
from typing import Annotated, Unpack
|
2
|
+
|
3
|
+
import click
|
4
|
+
from pydantic import SecretStr
|
5
|
+
|
6
|
+
from vectordb_bench.backend.clients import DB
|
7
|
+
from vectordb_bench.cli.cli import (
|
8
|
+
CommonTypedDict,
|
9
|
+
HNSWFlavor1,
|
10
|
+
cli,
|
11
|
+
click_parameter_decorators_from_typed_dict,
|
12
|
+
run,
|
13
|
+
)
|
14
|
+
|
15
|
+
|
16
|
+
class VespaTypedDict(CommonTypedDict, HNSWFlavor1):
    """CLI options for the ``Vespa`` command.

    Adds the Vespa connection settings and the quantization mode to the
    options inherited from ``CommonTypedDict`` and ``HNSWFlavor1``. The
    ``click.option`` metadata attached to each field is consumed by
    ``click_parameter_decorators_from_typed_dict`` when the command is
    declared.
    """

    # Base URL of the Vespa endpoint (``--uri`` / ``-u``).
    uri: Annotated[
        str,
        click.option("--uri", "-u", type=str, help="uri connection string", default="http://127.0.0.1"),
    ]
    # Port of the Vespa endpoint (``--port`` / ``-p``).
    port: Annotated[
        int,
        click.option("--port", "-p", type=int, help="connection port", default=8080),
    ]
    # Quantization mode; matched case-insensitively against the two choices.
    quantization: Annotated[
        str,
        click.option(
            "--quantization",
            type=click.Choice(["none", "binary"], case_sensitive=False),
            help="vector quantization mode",
            default="none",
        ),
    ]
|
28
|
+
|
29
|
+
|
30
|
+
# CLI entry point that benchmarks Vespa with an HNSW index.
# NOTE: documented with comments rather than a docstring, because click would
# surface a docstring as the command's --help text.
@cli.command()
@click_parameter_decorators_from_typed_dict(VespaTypedDict)
def Vespa(**params: Unpack[VespaTypedDict]):
    from .config import VespaConfig, VespaHNSWConfig

    # (VespaHNSWConfig field name, CLI parameter name) pairs.
    field_to_option = (
        ("quantization_type", "quantization"),
        ("M", "m"),
        ("efConstruction", "ef_construction"),
        ("ef", "ef_search"),
    )

    # Forward only the options that carry a truthy value; anything else is
    # left out so that VespaHNSWConfig falls back to its own default.
    overrides = {}
    for field_name, option_name in field_to_option:
        option_value = params[option_name]
        if option_value:
            overrides[field_name] = option_value

    connection = VespaConfig(url=SecretStr(params["uri"]), port=params["port"])
    run(
        db=DB.Vespa,
        db_config=connection,
        db_case_config=VespaHNSWConfig(**overrides),
        **params,
    )
|
@@ -0,0 +1,51 @@
|
|
1
|
+
from typing import Literal, TypeAlias
|
2
|
+
|
3
|
+
from pydantic import BaseModel, SecretStr
|
4
|
+
|
5
|
+
from ..api import DBCaseConfig, DBConfig, MetricType
|
6
|
+
|
7
|
+
# Allowed distance-metric names; this is the return type of
# VespaHNSWConfig.parse_metric, which currently produces a subset of them.
VespaMetric: TypeAlias = Literal["euclidean", "angular", "dotproduct", "prenormalized-angular", "hamming", "geodegrees"]
|
8
|
+
|
9
|
+
# Allowed values for VespaHNSWConfig.quantization_type.
VespaQuantizationType: TypeAlias = Literal["none", "binary"]
|
10
|
+
|
11
|
+
|
12
|
+
class VespaConfig(DBConfig):
    """Connection settings for a Vespa endpoint."""

    # Base URL of the endpoint. The default has to be a SecretStr instance:
    # pydantic does not validate field defaults unless explicitly configured
    # to, so a plain ``str`` default would reach ``to_dict`` unconverted and
    # fail on ``get_secret_value()``.
    url: SecretStr = SecretStr("http://127.0.0.1")
    # Port of the endpoint.
    port: int = 8080

    def to_dict(self) -> dict:
        """Return the settings as a plain dict with the URL secret unwrapped.

        Returns:
            A dict with the keys ``"url"`` (``str``) and ``"port"`` (``int``).
        """
        return {
            "url": self.url.get_secret_value(),
            "port": self.port,
        }
|
21
|
+
|
22
|
+
|
23
|
+
class VespaHNSWConfig(BaseModel, DBCaseConfig):
|
24
|
+
metric_type: MetricType = MetricType.COSINE
|
25
|
+
quantization_type: VespaQuantizationType = "none"
|
26
|
+
M: int = 16
|
27
|
+
efConstruction: int = 200
|
28
|
+
ef: int = 100
|
29
|
+
|
30
|
+
def index_param(self) -> dict:
|
31
|
+
return {
|
32
|
+
"distance_metric": self.parse_metric(self.metric_type),
|
33
|
+
"max_links_per_node": self.M,
|
34
|
+
"neighbors_to_explore_at_insert": self.efConstruction,
|
35
|
+
}
|
36
|
+
|
37
|
+
def search_param(self) -> dict:
|
38
|
+
return {}
|
39
|
+
|
40
|
+
def parse_metric(self, metric_type: MetricType) -> VespaMetric:
|
41
|
+
match metric_type:
|
42
|
+
case MetricType.COSINE:
|
43
|
+
return "angular"
|
44
|
+
case MetricType.L2:
|
45
|
+
return "euclidean"
|
46
|
+
case MetricType.DP | MetricType.IP:
|
47
|
+
return "dotproduct"
|
48
|
+
case MetricType.HAMMING:
|
49
|
+
return "hamming"
|
50
|
+
case _:
|
51
|
+
raise NotImplementedError
|
@@ -0,0 +1,15 @@
|
|
1
|
+
"""Utility functions for supporting binary quantization
|
2
|
+
|
3
|
+
From https://docs.vespa.ai/en/binarizing-vectors.html#appendix-conversion-to-int8
|
4
|
+
"""
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
|
8
|
+
|
9
|
+
def binarize_tensor(tensor: list[float]) -> list[int]:
    """Pack the signs of a float vector into signed 8-bit integers.

    Every element greater than zero becomes a 1 bit and every other element
    a 0 bit. The bits are packed eight at a time, first element in the most
    significant position and the final byte zero-padded, then each byte is
    reinterpreted as an ``int8``.

    Args:
        tensor: The floating-point values to binarize.

    Returns:
        One integer in ``[-128, 127]`` per group of eight input values.
    """
    values = np.array(tensor)
    sign_bits = np.where(values > 0, 1, 0)
    packed_bytes = np.packbits(sign_bits, axis=0)
    # Bytes above 127 wrap around to negative numbers in the int8 view.
    return packed_bytes.astype(np.int8).tolist()
|