vectordb-bench 0.0.22__py3-none-any.whl → 0.0.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/backend/clients/__init__.py +32 -0
- vectordb_bench/backend/clients/api.py +1 -0
- vectordb_bench/backend/clients/mariadb/cli.py +107 -0
- vectordb_bench/backend/clients/mariadb/config.py +71 -0
- vectordb_bench/backend/clients/mariadb/mariadb.py +214 -0
- vectordb_bench/backend/clients/milvus/cli.py +50 -0
- vectordb_bench/backend/clients/milvus/config.py +33 -0
- vectordb_bench/backend/clients/pgvector/cli.py +13 -1
- vectordb_bench/backend/clients/pgvector/config.py +22 -5
- vectordb_bench/backend/clients/pgvector/pgvector.py +62 -19
- vectordb_bench/backend/clients/tidb/cli.py +98 -0
- vectordb_bench/backend/clients/tidb/config.py +49 -0
- vectordb_bench/backend/clients/tidb/tidb.py +234 -0
- vectordb_bench/cli/vectordbbench.py +4 -0
- vectordb_bench/frontend/config/dbCaseConfigs.py +96 -0
- vectordb_bench/frontend/config/styles.py +2 -0
- vectordb_bench/models.py +3 -0
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.23.dist-info}/METADATA +13 -2
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.23.dist-info}/RECORD +23 -17
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.23.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.23.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.23.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.23.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/pgvector/pgvector.py
@@ -94,7 +94,7 @@ class PgVector(VectorDB):
         reranking = self.case_config.search_param()["reranking"]
         column_name = (
             sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding"))
-            if index_param["quantization_type"] == "bit"
+            if index_param["quantization_type"] == "bit" and index_param["table_quantization_type"] != "bit"
             else sql.SQL("embedding")
         )
         search_vector = (
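For orientation, a minimal sketch (assuming psycopg's `sql` module; the column name mirrors the hunk above) of the two column expressions this change now selects between:

```python
from psycopg import sql

# Hypothetical illustration of the branch above. With a bit-quantized index
# over a non-bit table, the stored vector is binary-quantized at query time:
on_the_fly = sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding"))
# ...which renders as: binary_quantize("embedding")

# When the table itself is already bit-typed (the new second condition),
# re-quantizing is unnecessary, so the plain column is used:
plain = sql.SQL("embedding")
```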
@@ -104,7 +104,8 @@ class PgVector(VectorDB):
         )
 
         # The following sections assume that the quantization_type value matches the quantization function name
-        if index_param["quantization_type"]
+        if index_param["quantization_type"] != index_param["table_quantization_type"]:
+            # Reranking makes sense only if table quantization is not "bit"
             if index_param["quantization_type"] == "bit" and reranking:
                 # Embeddings needs to be passed to binary_quantize function if quantization_type is bit
                 search_query = sql.Composed(
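In other words, query-time casting (and any rerank step) now applies only when the index and table quantization types differ; a small sketch of the new control flow, with illustrative values:

```python
# Hypothetical values for illustration:
index_param = {"quantization_type": "bit", "table_quantization_type": "halfvec"}
reranking = True

if index_param["quantization_type"] != index_param["table_quantization_type"]:
    # Rerank only makes sense when the table itself is not bit-quantized,
    # since the rerank pass re-scores candidates against the table's own type.
    use_rerank = index_param["quantization_type"] == "bit" and reranking
```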
@@ -113,7 +114,7 @@ class PgVector(VectorDB):
                     """
                         SELECT i.id
                         FROM (
-                            SELECT id, embedding {reranking_metric_fun_op} %s::
+                            SELECT id, embedding {reranking_metric_fun_op} %s::{table_quantization_type} AS distance
                             FROM public.{table_name} {where_clause}
                             ORDER BY {column_name}::{quantization_type}({dim})
                     """,
@@ -123,6 +124,8 @@ class PgVector(VectorDB):
                     reranking_metric_fun_op=sql.SQL(
                         self.case_config.search_param()["reranking_metric_fun_op"],
                     ),
+                    search_vector=search_vector,
+                    table_quantization_type=sql.SQL(index_param["table_quantization_type"]),
                     quantization_type=sql.SQL(index_param["quantization_type"]),
                     dim=sql.Literal(self.dim),
                     where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
@@ -130,7 +133,7 @@ class PgVector(VectorDB):
                 sql.SQL(self.case_config.search_param()["metric_fun_op"]),
                 sql.SQL(
                     """
-                        {search_vector}
+                        {search_vector}::{quantization_type}({dim})
                         LIMIT {quantized_fetch_limit}
                     ) i
                     ORDER BY i.distance
@@ -138,6 +141,8 @@ class PgVector(VectorDB):
                     """,
                 ).format(
                     search_vector=search_vector,
+                    quantization_type=sql.SQL(index_param["quantization_type"]),
+                    dim=sql.Literal(self.dim),
                     quantized_fetch_limit=sql.Literal(
                         self.case_config.search_param()["quantized_fetch_limit"],
                     ),
@@ -160,10 +165,12 @@ class PgVector(VectorDB):
                         where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
                     ),
                     sql.SQL(self.case_config.search_param()["metric_fun_op"]),
-                    sql.SQL(" {search_vector} LIMIT %s::int").format(
+                    sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format(
                         search_vector=search_vector,
+                        quantization_type=sql.SQL(index_param["quantization_type"]),
+                        dim=sql.Literal(self.dim),
                     ),
-                ]
+                ]
             )
         else:
             search_query = sql.Composed(
|
@@ -175,8 +182,12 @@ class PgVector(VectorDB):
|
|
175
182
|
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
|
176
183
|
),
|
177
184
|
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
|
178
|
-
sql.SQL("
|
179
|
-
|
185
|
+
sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format(
|
186
|
+
search_vector=search_vector,
|
187
|
+
quantization_type=sql.SQL(index_param["quantization_type"]),
|
188
|
+
dim=sql.Literal(self.dim),
|
189
|
+
),
|
190
|
+
]
|
180
191
|
)
|
181
192
|
|
182
193
|
return search_query
|
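Concretely, the fragment added in both branches changes shape roughly as follows (a runnable sketch with hypothetical values):

```python
# Sketch of the rendered fragment; quantization_type and dim are illustrative.
quantization_type = "bit"
dim = 960
before = " {search_vector} LIMIT %s::int"
after = f" {{search_vector}}::{quantization_type}({dim}) LIMIT %s::int"
assert after == " {search_vector}::bit(960) LIMIT %s::int"
```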
@@ -323,7 +334,7 @@ class PgVector(VectorDB):
         )
         with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
 
-        if index_param["quantization_type"]
+        if index_param["quantization_type"] != index_param["table_quantization_type"]:
             index_create_sql = sql.SQL(
                 """
                 CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
@@ -365,14 +376,23 @@ class PgVector(VectorDB):
         assert self.conn is not None, "Connection is not initialized"
         assert self.cursor is not None, "Cursor is not initialized"
 
+        index_param = self.case_config.index_param()
+
         try:
             log.info(f"{self.name} client create table : {self.table_name}")
 
             # create table
             self.cursor.execute(
                 sql.SQL(
-                    "
-
+                    """
+                    CREATE TABLE IF NOT EXISTS public.{table_name}
+                    (id BIGINT PRIMARY KEY, embedding {table_quantization_type}({dim}));
+                    """
+                ).format(
+                    table_name=sql.Identifier(self.table_name),
+                    table_quantization_type=sql.SQL(index_param["table_quantization_type"]),
+                    dim=dim,
+                )
             )
             self.cursor.execute(
                 sql.SQL(
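With a hypothetical table name and dimension, the new CREATE TABLE above renders to roughly the following (illustrative only):

```python
# Illustrative rendering; the table name and dim are hypothetical.
#
#   CREATE TABLE IF NOT EXISTS public."pgvector_collection"
#   (id BIGINT PRIMARY KEY, embedding halfvec(960));
#
# Note that table_quantization_type is spliced in as raw SQL (a type name),
# while table_name is safely quoted as an identifier.
```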
@@ -393,18 +413,41 @@ class PgVector(VectorDB):
         assert self.conn is not None, "Connection is not initialized"
         assert self.cursor is not None, "Cursor is not initialized"
 
+        index_param = self.case_config.index_param()
+
         try:
             metadata_arr = np.array(metadata)
             embeddings_arr = np.array(embeddings)
 
-
-
-
-
-            copy
-
-
+            if index_param["table_quantization_type"] == "bit":
+                with self.cursor.copy(
+                    sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT TEXT)").format(
+                        table_name=sql.Identifier(self.table_name)
+                    )
+                ) as copy:
+                    # Same logic as pgvector binary_quantize
+                    for i, row in enumerate(metadata_arr):
+                        embeddings_bit = ""
+                        for embedding in embeddings_arr[i]:
+                            if embedding > 0:
+                                embeddings_bit += "1"
+                            else:
+                                embeddings_bit += "0"
+                        copy.write_row((str(row), embeddings_bit))
+            else:
+                with self.cursor.copy(
+                    sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
+                        table_name=sql.Identifier(self.table_name)
+                    )
+                ) as copy:
+                    if index_param["table_quantization_type"] == "halfvec":
+                        copy.set_types(["bigint", "halfvec"])
+                        for i, row in enumerate(metadata_arr):
+                            copy.write_row((row, np.float16(embeddings_arr[i])))
+                    else:
+                        copy.set_types(["bigint", "vector"])
+                        for i, row in enumerate(metadata_arr):
+                            copy.write_row((row, embeddings_arr[i]))
             self.conn.commit()
 
             if kwargs.get("last_batch"):
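The per-element loop above mirrors pgvector's binary_quantize rule (1 for strictly positive components, else 0). A roughly equivalent vectorized sketch, not the project's code:

```python
import numpy as np

def to_bit_string(vec: np.ndarray) -> str:
    # 1 where the component is > 0, else 0 -- same rule as the loop above
    return "".join(np.where(vec > 0, "1", "0"))

# e.g. to_bit_string(np.array([0.3, -1.2, 0.0])) == "100"
```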
vectordb_bench/backend/clients/tidb/cli.py (new file)
@@ -0,0 +1,98 @@
+from typing import Annotated, Unpack
+
+import click
+from pydantic import SecretStr
+
+from vectordb_bench.backend.clients import DB
+
+from ....cli.cli import CommonTypedDict, cli, click_parameter_decorators_from_typed_dict, run
+
+
+class TiDBTypedDict(CommonTypedDict):
+    user_name: Annotated[
+        str,
+        click.option(
+            "--username",
+            type=str,
+            help="Username",
+            default="root",
+            show_default=True,
+            required=True,
+        ),
+    ]
+    password: Annotated[
+        str,
+        click.option(
+            "--password",
+            type=str,
+            default="",
+            show_default=True,
+            help="Password",
+        ),
+    ]
+    host: Annotated[
+        str,
+        click.option(
+            "--host",
+            type=str,
+            default="127.0.0.1",
+            show_default=True,
+            required=True,
+            help="Db host",
+        ),
+    ]
+    port: Annotated[
+        int,
+        click.option(
+            "--port",
+            type=int,
+            default=4000,
+            show_default=True,
+            required=True,
+            help="Db Port",
+        ),
+    ]
+    db_name: Annotated[
+        str,
+        click.option(
+            "--db-name",
+            type=str,
+            default="test",
+            show_default=True,
+            required=True,
+            help="Db name",
+        ),
+    ]
+    ssl: Annotated[
+        bool,
+        click.option(
+            "--ssl/--no-ssl",
+            default=False,
+            show_default=True,
+            is_flag=True,
+            help="Enable or disable SSL, for TiDB Serverless SSL must be enabled",
+        ),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(TiDBTypedDict)
+def TiDB(
+    **parameters: Unpack[TiDBTypedDict],
+):
+    from .config import TiDBConfig, TiDBIndexConfig
+
+    run(
+        db=DB.TiDB,
+        db_config=TiDBConfig(
+            db_label=parameters["db_label"],
+            user_name=parameters["username"],
+            password=SecretStr(parameters["password"]),
+            host=parameters["host"],
+            port=parameters["port"],
+            db_name=parameters["db_name"],
+            ssl=parameters["ssl"],
+        ),
+        db_case_config=TiDBIndexConfig(),
+        **parameters,
+    )
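The command wires in a parameterless `TiDBIndexConfig()`: the TiDB client exposes no index tunables here, only the distance function derived from the metric type. A minimal sketch (import paths follow the file list above; the metric value is hypothetical):

```python
from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.backend.clients.tidb.config import TiDBIndexConfig

cfg = TiDBIndexConfig(metric_type=MetricType.COSINE)
print(cfg.index_param())   # {"metric_fn": "vec_cosine_distance"}
print(cfg.search_param())  # the same mapping is reused at search time
```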
vectordb_bench/backend/clients/tidb/config.py (new file)
@@ -0,0 +1,49 @@
+from pydantic import SecretStr, BaseModel, validator
+from ..api import DBConfig, DBCaseConfig, MetricType
+
+
+class TiDBConfig(DBConfig):
+    user_name: str = "root"
+    password: SecretStr
+    host: str = "127.0.0.1"
+    port: int = 4000
+    db_name: str = "test"
+    ssl: bool = False
+
+    @validator("*")
+    def not_empty_field(cls, v: any, field: any):
+        return v
+
+    def to_dict(self) -> dict:
+        pwd_str = self.password.get_secret_value()
+        return {
+            "host": self.host,
+            "port": self.port,
+            "user": self.user_name,
+            "password": pwd_str,
+            "database": self.db_name,
+            "ssl_verify_cert": self.ssl,
+            "ssl_verify_identity": self.ssl,
+        }
+
+
+class TiDBIndexConfig(BaseModel, DBCaseConfig):
+    metric_type: MetricType | None = None
+
+    def get_metric_fn(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "vec_l2_distance"
+        elif self.metric_type == MetricType.COSINE:
+            return "vec_cosine_distance"
+        else:
+            raise ValueError(f"Unsupported metric type: {self.metric_type}")
+
+    def index_param(self) -> dict:
+        return {
+            "metric_fn": self.get_metric_fn(),
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "metric_fn": self.get_metric_fn(),
+        }
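For reference, a sketch (hypothetical values, and assuming the `DBConfig` base accepts a `db_label`) of how `to_dict()` produces the keyword arguments later splatted into `pymysql.connect(**self.db_config)` in the client below:

```python
from pydantic import SecretStr
from vectordb_bench.backend.clients.tidb.config import TiDBConfig

cfg = TiDBConfig(db_label="local", password=SecretStr(""))
print(cfg.to_dict())
# {'host': '127.0.0.1', 'port': 4000, 'user': 'root', 'password': '',
#  'database': 'test', 'ssl_verify_cert': False, 'ssl_verify_identity': False}
```

Note that the catch-all `@validator("*")` above is a pass-through; it performs no validation as written.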
vectordb_bench/backend/clients/tidb/tidb.py (new file)
@@ -0,0 +1,234 @@
+import concurrent.futures
+import io
+import logging
+import time
+from contextlib import contextmanager
+from typing import Any, Optional, Tuple
+
+import pymysql
+
+from ..api import VectorDB
+from .config import TiDBIndexConfig
+
+log = logging.getLogger(__name__)
+
+
+class TiDB(VectorDB):
+    def __init__(
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: TiDBIndexConfig,
+        collection_name: str = "vector_bench_test",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.name = "TiDB"
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.table_name = collection_name
+        self.dim = dim
+        self.conn = None  # To be inited by init()
+        self.cursor = None  # To be inited by init()
+
+        self.search_fn = db_case_config.search_param()["metric_fn"]
+
+        if drop_old:
+            self._drop_table()
+            self._create_table()
+
+    @contextmanager
+    def init(self):
+        with self._get_connection() as (conn, cursor):
+            self.conn = conn
+            self.cursor = cursor
+            try:
+                yield
+            finally:
+                self.conn = None
+                self.cursor = None
+
+    @contextmanager
+    def _get_connection(self):
+        with pymysql.connect(**self.db_config) as conn:
+            conn.autocommit = False
+            with conn.cursor() as cursor:
+                yield conn, cursor
+
+    def _drop_table(self):
+        try:
+            with self._get_connection() as (conn, cursor):
+                cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
+                conn.commit()
+        except Exception as e:
+            log.warning("Failed to drop table: %s error: %s", self.table_name, e)
+            raise e
+
+    def _create_table(self):
+        try:
+            index_param = self.case_config.index_param()
+            with self._get_connection() as (conn, cursor):
+                cursor.execute(
+                    f"""
+                    CREATE TABLE {self.table_name} (
+                        id BIGINT PRIMARY KEY,
+                        embedding VECTOR({self.dim}) NOT NULL,
+                        VECTOR INDEX (({index_param["metric_fn"]}(embedding)))
+                    );
+                    """
+                )
+                conn.commit()
+        except Exception as e:
+            log.warning("Failed to create table: %s error: %s", self.table_name, e)
+            raise e
+
+    def ready_to_load(self) -> bool:
+        pass
+
+    def optimize(self, data_size: int | None = None) -> None:
+        while True:
+            progress = self._optimize_check_tiflash_replica_progress()
+            if progress != 1:
+                log.info("Data replication not ready, progress: %d", progress)
+                time.sleep(2)
+            else:
+                break
+
+        log.info("Waiting TiFlash to catch up...")
+        self._optimize_wait_tiflash_catch_up()
+
+        log.info("Start compacting TiFlash replica...")
+        self._optimize_compact_tiflash()
+
+        log.info("Waiting index build to finish...")
+        log_reduce_seq = 0
+        while True:
+            pending_rows = self._optimize_get_tiflash_index_pending_rows()
+            if pending_rows > 0:
+                if log_reduce_seq % 15 == 0:
+                    log.info("Index not fully built, pending rows: %d", pending_rows)
+                log_reduce_seq += 1
+                time.sleep(2)
+            else:
+                break
+
+        log.info("Index build finished successfully.")
+
+    def _optimize_check_tiflash_replica_progress(self):
+        try:
+            database = self.db_config["database"]
+            with self._get_connection() as (_, cursor):
+                cursor.execute(
+                    f"""
+                    SELECT PROGRESS FROM information_schema.tiflash_replica
+                    WHERE TABLE_SCHEMA = "{database}" AND TABLE_NAME = "{self.table_name}"
+                    """
+                )
+                result = cursor.fetchone()
+                return result[0]
+        except Exception as e:
+            log.warning("Failed to check TiFlash replica progress: %s", e)
+            raise e
+
+    def _optimize_wait_tiflash_catch_up(self):
+        try:
+            with self._get_connection() as (conn, cursor):
+                cursor.execute('SET @@TIDB_ISOLATION_READ_ENGINES="tidb,tiflash"')
+                conn.commit()
+                cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")
+                result = cursor.fetchone()
+                return result[0]
+        except Exception as e:
+            log.warning("Failed to wait TiFlash to catch up: %s", e)
+            raise e
+
+    def _optimize_compact_tiflash(self):
+        try:
+            with self._get_connection() as (conn, cursor):
+                cursor.execute(f"ALTER TABLE {self.table_name} COMPACT")
+                conn.commit()
+        except Exception as e:
+            log.warning("Failed to compact table: %s", e)
+            raise e
+
+    def _optimize_get_tiflash_index_pending_rows(self):
+        try:
+            database = self.db_config["database"]
+            with self._get_connection() as (_, cursor):
+                cursor.execute(
+                    f"""
+                    SELECT SUM(ROWS_STABLE_NOT_INDEXED)
+                    FROM information_schema.tiflash_indexes
+                    WHERE TIDB_DATABASE = "{database}" AND TIDB_TABLE = "{self.table_name}"
+                    """
+                )
+                result = cursor.fetchone()
+                return result[0]
+        except Exception as e:
+            log.warning("Failed to read TiFlash index pending rows: %s", e)
+            raise e
+
+    def _insert_embeddings_serial(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        offset: int,
+        size: int,
+    ) -> Exception:
+        try:
+            with self._get_connection() as (conn, cursor):
+                buf = io.StringIO()
+                buf.write(f"INSERT INTO {self.table_name} (id, embedding) VALUES ")
+                for i in range(offset, offset + size):
+                    if i > offset:
+                        buf.write(",")
+                    buf.write(f'({metadata[i]}, "{str(embeddings[i])}")')
+                cursor.execute(buf.getvalue())
+                conn.commit()
+        except Exception as e:
+            log.warning("Failed to insert data into table: %s", e)
+            raise e
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> Tuple[int, Optional[Exception]]:
+        workers = 10
+        # Avoid exceeding MAX_ALLOWED_PACKET (default=64MB)
+        max_batch_size = 64 * 1024 * 1024 // 24 // self.dim
+        batch_size = len(embeddings) // workers
+        if batch_size > max_batch_size:
+            batch_size = max_batch_size
+        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
+            futures = []
+            for i in range(0, len(embeddings), batch_size):
+                offset = i
+                size = min(batch_size, len(embeddings) - i)
+                future = executor.submit(self._insert_embeddings_serial, embeddings, metadata, offset, size)
+                futures.append(future)
+            done, pending = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_EXCEPTION)
+            executor.shutdown(wait=False)
+            for future in done:
+                future.result()
+            for future in pending:
+                future.cancel()
+            return len(metadata), None
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+        **kwargs: Any,
+    ) -> list[int]:
+        self.cursor.execute(
+            f"""
+            SELECT id FROM {self.table_name}
+            ORDER BY {self.search_fn}(embedding, "{str(query)}") LIMIT {k};
+            """
+        )
+        result = self.cursor.fetchall()
+        return [int(i[0]) for i in result]
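A quick worked example of the `max_batch_size` bound in `insert_embeddings` above (the 24 is a rough bytes-per-component estimate for the text-encoded vector literal):

```python
# Sketch: packet budget per INSERT statement, assuming a hypothetical dim of 768.
dim = 768
max_batch_size = 64 * 1024 * 1024 // 24 // dim
print(max_batch_size)  # 3640 rows per statement at dim=768
```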
vectordb_bench/cli/vectordbbench.py
@@ -1,5 +1,6 @@
 from ..backend.clients.alloydb.cli import AlloyDBScaNN
 from ..backend.clients.aws_opensearch.cli import AWSOpenSearch
+from ..backend.clients.mariadb.cli import MariaDBHNSW
 from ..backend.clients.memorydb.cli import MemoryDB
 from ..backend.clients.milvus.cli import MilvusAutoIndex
 from ..backend.clients.pgdiskann.cli import PgDiskAnn
@@ -10,6 +11,7 @@ from ..backend.clients.redis.cli import Redis
 from ..backend.clients.test.cli import Test
 from ..backend.clients.weaviate_cloud.cli import Weaviate
 from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex
+from ..backend.clients.tidb.cli import TiDB
 from .cli import cli
 
 cli.add_command(PgVectorHNSW)
@@ -25,6 +27,8 @@ cli.add_command(AWSOpenSearch)
 cli.add_command(PgVectorScaleDiskAnn)
 cli.add_command(PgDiskAnn)
 cli.add_command(AlloyDBScaNN)
+cli.add_command(MariaDBHNSW)
+cli.add_command(TiDB)
 
 
 if __name__ == "__main__":