vectordb-bench 1.0.4__py3-none-any.whl → 1.0.7__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- vectordb_bench/__init__.py +1 -0
- vectordb_bench/backend/cases.py +45 -1
- vectordb_bench/backend/clients/__init__.py +47 -0
- vectordb_bench/backend/clients/api.py +2 -0
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +104 -40
- vectordb_bench/backend/clients/aws_opensearch/cli.py +52 -15
- vectordb_bench/backend/clients/aws_opensearch/config.py +27 -7
- vectordb_bench/backend/clients/hologres/cli.py +50 -0
- vectordb_bench/backend/clients/hologres/config.py +121 -0
- vectordb_bench/backend/clients/hologres/hologres.py +365 -0
- vectordb_bench/backend/clients/lancedb/lancedb.py +1 -0
- vectordb_bench/backend/clients/milvus/cli.py +29 -9
- vectordb_bench/backend/clients/milvus/config.py +2 -0
- vectordb_bench/backend/clients/milvus/milvus.py +1 -1
- vectordb_bench/backend/clients/oceanbase/cli.py +1 -0
- vectordb_bench/backend/clients/oceanbase/config.py +3 -1
- vectordb_bench/backend/clients/oceanbase/oceanbase.py +20 -4
- vectordb_bench/backend/clients/oss_opensearch/cli.py +155 -0
- vectordb_bench/backend/clients/oss_opensearch/config.py +157 -0
- vectordb_bench/backend/clients/oss_opensearch/oss_opensearch.py +582 -0
- vectordb_bench/backend/clients/oss_opensearch/run.py +166 -0
- vectordb_bench/backend/clients/pgdiskann/cli.py +45 -0
- vectordb_bench/backend/clients/pgdiskann/config.py +16 -0
- vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +94 -26
- vectordb_bench/backend/clients/s3_vectors/config.py +41 -0
- vectordb_bench/backend/clients/s3_vectors/s3_vectors.py +171 -0
- vectordb_bench/backend/clients/tidb/cli.py +0 -4
- vectordb_bench/backend/clients/tidb/config.py +22 -2
- vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -1
- vectordb_bench/backend/clients/zilliz_cloud/config.py +4 -1
- vectordb_bench/backend/dataset.py +70 -0
- vectordb_bench/backend/filter.py +17 -0
- vectordb_bench/backend/runner/mp_runner.py +4 -0
- vectordb_bench/backend/runner/rate_runner.py +23 -11
- vectordb_bench/backend/runner/read_write_runner.py +10 -9
- vectordb_bench/backend/runner/serial_runner.py +23 -7
- vectordb_bench/backend/task_runner.py +5 -4
- vectordb_bench/cli/cli.py +36 -0
- vectordb_bench/cli/vectordbbench.py +4 -0
- vectordb_bench/fig/custom_case_run_test.png +0 -0
- vectordb_bench/fig/custom_dataset.png +0 -0
- vectordb_bench/fig/homepage/bar-chart.png +0 -0
- vectordb_bench/fig/homepage/concurrent.png +0 -0
- vectordb_bench/fig/homepage/custom.png +0 -0
- vectordb_bench/fig/homepage/label_filter.png +0 -0
- vectordb_bench/fig/homepage/qp$.png +0 -0
- vectordb_bench/fig/homepage/run_test.png +0 -0
- vectordb_bench/fig/homepage/streaming.png +0 -0
- vectordb_bench/fig/homepage/table.png +0 -0
- vectordb_bench/fig/run_test_select_case.png +0 -0
- vectordb_bench/fig/run_test_select_db.png +0 -0
- vectordb_bench/fig/run_test_submit.png +0 -0
- vectordb_bench/frontend/components/check_results/filters.py +1 -4
- vectordb_bench/frontend/components/check_results/nav.py +2 -1
- vectordb_bench/frontend/components/concurrent/charts.py +5 -0
- vectordb_bench/frontend/components/int_filter/charts.py +60 -0
- vectordb_bench/frontend/components/streaming/data.py +7 -0
- vectordb_bench/frontend/components/welcome/welcomePrams.py +42 -4
- vectordb_bench/frontend/config/dbCaseConfigs.py +142 -16
- vectordb_bench/frontend/config/styles.py +4 -0
- vectordb_bench/frontend/pages/concurrent.py +1 -1
- vectordb_bench/frontend/pages/custom.py +1 -1
- vectordb_bench/frontend/pages/int_filter.py +56 -0
- vectordb_bench/frontend/pages/streaming.py +16 -3
- vectordb_bench/interface.py +5 -1
- vectordb_bench/metric.py +7 -0
- vectordb_bench/models.py +39 -4
- vectordb_bench/results/S3Vectors/result_20250722_standard_s3vectors.json +2509 -0
- vectordb_bench/results/getLeaderboardDataV2.py +23 -2
- vectordb_bench/results/leaderboard_v2.json +200 -0
- vectordb_bench/results/leaderboard_v2_streaming.json +128 -0
- {vectordb_bench-1.0.4.dist-info → vectordb_bench-1.0.7.dist-info}/METADATA +40 -8
- {vectordb_bench-1.0.4.dist-info → vectordb_bench-1.0.7.dist-info}/RECORD +77 -51
- {vectordb_bench-1.0.4.dist-info → vectordb_bench-1.0.7.dist-info}/WHEEL +0 -0
- {vectordb_bench-1.0.4.dist-info → vectordb_bench-1.0.7.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-1.0.4.dist-info → vectordb_bench-1.0.7.dist-info}/licenses/LICENSE +0 -0
- {vectordb_bench-1.0.4.dist-info → vectordb_bench-1.0.7.dist-info}/top_level.txt +0 -0

vectordb_bench/backend/clients/oss_opensearch/run.py (new file)
@@ -0,0 +1,166 @@
+import logging
+import random
+import time
+
+from opensearchpy import OpenSearch
+
+log = logging.getLogger(__name__)
+
+_HOST = "xxxxxx.us-west-2.es.amazonaws.com"
+_PORT = 443
+_AUTH = ("admin", "xxxxxx")  # For testing only. Don't store credentials in code.
+
+_INDEX_NAME = "my-dsl-index"
+_BATCH = 100
+_ROWS = 100
+_DIM = 128
+_TOPK = 10
+
+
+def create_client():
+    return OpenSearch(
+        hosts=[{"host": _HOST, "port": _PORT}],
+        http_compress=True,  # enables gzip compression for request bodies
+        http_auth=_AUTH,
+        use_ssl=True,
+        verify_certs=True,
+        ssl_assert_hostname=False,
+        ssl_show_warn=False,
+    )
+
+
+def create_index(client: OpenSearch, index_name: str):
+    settings = {
+        "index": {
+            "knn": True,
+            "number_of_shards": 1,
+            "refresh_interval": "5s",
+        },
+    }
+    mappings = {
+        "properties": {
+            "embedding": {
+                "type": "knn_vector",
+                "dimension": _DIM,
+                "method": {
+                    "engine": "faiss",
+                    "name": "hnsw",
+                    "space_type": "l2",
+                    "parameters": {
+                        "ef_construction": 256,
+                        "m": 16,
+                    },
+                },
+            },
+        },
+    }
+
+    response = client.indices.create(
+        index=index_name,
+        body={"settings": settings, "mappings": mappings},
+    )
+    log.info("\nCreating index:")
+    log.info(response)
+
+
+def delete_index(client: OpenSearch, index_name: str):
+    response = client.indices.delete(index=index_name)
+    log.info("\nDeleting index:")
+    log.info(response)
+
+
+def bulk_insert(client: OpenSearch, index_name: str):
+    # Perform bulk operations
+    ids = list(range(_ROWS))
+    vec = [[random.random() for _ in range(_DIM)] for _ in range(_ROWS)]
+
+    docs = []
+    for i in range(0, _ROWS, _BATCH):
+        docs.clear()
+        for j in range(_BATCH):
+            docs.append({"index": {"_index": index_name, "_id": ids[i + j]}})
+            docs.append({"embedding": vec[i + j]})
+        response = client.bulk(docs)
+        log.info(f"Adding documents: {len(response['items'])}, {response['errors']}")
+        response = client.indices.stats(index_name)
+        log.info(
+            f'Total document count in index: { response["_all"]["primaries"]["indexing"]["index_total"] }',
+        )
+
+
+def search(client: OpenSearch, index_name: str):
+    # Search for the document.
+    search_body = {
+        "size": _TOPK,
+        "query": {
+            "knn": {
+                "embedding": {
+                    "vector": [random.random() for _ in range(_DIM)],
+                    "k": _TOPK,
+                },
+            },
+        },
+    }
+    while True:
+        response = client.search(index=index_name, body=search_body)
+        log.info(f'\nSearch took: {response["took"]}')
+        log.info(f'\nSearch shards: {response["_shards"]}')
+        log.info(f'\nSearch hits total: {response["hits"]["total"]}')
+        result = response["hits"]["hits"]
+        if len(result) != 0:
+            log.info("\nSearch results:")
+            for hit in response["hits"]["hits"]:
+                log.info(hit["_id"], hit["_score"])
+            break
+        log.info("\nSearch not ready, sleep 1s")
+        time.sleep(1)
+
+
+SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
+WAITINT_FOR_REFRESH_SEC = 30
+
+
+def optimize_index(client: OpenSearch, index_name: str):
+    log.info(f"Starting force merge for index {index_name}")
+    force_merge_endpoint = f"/{index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false"
+    force_merge_task_id = client.transport.perform_request("POST", force_merge_endpoint)["task"]
+    while True:
+        time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
+        task_status = client.tasks.get(task_id=force_merge_task_id)
+        if task_status["completed"]:
+            break
+    log.info(f"Completed force merge for index {index_name}")
+
+
+def refresh_index(client: OpenSearch, index_name: str):
+    log.info(f"Starting refresh for index {index_name}")
+    while True:
+        try:
+            log.info("Starting the Refresh Index..")
+            client.indices.refresh(index=index_name)
+            break
+        except Exception as e:
+            log.info(
+                f"Refresh errored out. Sleeping for {WAITINT_FOR_REFRESH_SEC} sec and then Retrying : {e}",
+            )
+            time.sleep(WAITINT_FOR_REFRESH_SEC)
+            continue
+    log.info(f"Completed refresh for index {index_name}")
+
+
+def main():
+    client = create_client()
+    try:
+        create_index(client, _INDEX_NAME)
+        bulk_insert(client, _INDEX_NAME)
+        optimize_index(client, _INDEX_NAME)
+        refresh_index(client, _INDEX_NAME)
+        search(client, _INDEX_NAME)
+        delete_index(client, _INDEX_NAME)
+    except Exception as e:
+        log.info(e)
+        delete_index(client, _INDEX_NAME)
+
+
+if __name__ == "__main__":
+    main()
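
The script above drives a full index lifecycle: create, bulk insert, force merge, refresh, search, delete. Its bulk_insert() relies on OpenSearch's pairwise bulk format, where every document contributes an action entry followed by a source entry. A minimal sketch of that payload shape (index name and dimension are illustrative, not taken from the diff):

    import random

    # Each document becomes two entries: an action line, then a source line.
    actions = []
    for doc_id in range(2):
        vec = [random.random() for _ in range(128)]
        actions.append({"index": {"_index": "my-dsl-index", "_id": doc_id}})  # action
        actions.append({"embedding": vec})  # source
    # client.bulk(actions) ships both documents in one request;
    # response["errors"] is False when every item succeeded.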

vectordb_bench/backend/clients/pgdiskann/cli.py
@@ -5,6 +5,7 @@ import click
 from pydantic import SecretStr
 
 from vectordb_bench.backend.clients import DB
+from vectordb_bench.backend.clients.api import MetricType
 
 from ....cli.cli import (
     CommonTypedDict,
@@ -48,6 +49,15 @@ class PgDiskAnnTypedDict(CommonTypedDict):
             help="PgDiskAnn l_value_ib",
         ),
     ]
+    pq_param_num_chunks: Annotated[
+        int,
+        click.option(
+            "--pq-param-num-chunks",
+            type=int,
+            help="PgDiskAnn pq_param_num_chunks",
+            required=False,
+        ),
+    ]
     l_value_is: Annotated[
         float,
         click.option(
@@ -56,6 +66,37 @@ class PgDiskAnnTypedDict(CommonTypedDict):
             help="PgDiskAnn l_value_is",
         ),
     ]
+    reranking: Annotated[
+        bool | None,
+        click.option(
+            "--reranking/--skip-reranking",
+            type=bool,
+            help="Enable reranking for PQ search",
+            default=False,
+        ),
+    ]
+    reranking_metric: Annotated[
+        str | None,
+        click.option(
+            "--reranking-metric",
+            type=click.Choice(
+                [metric.value for metric in MetricType if metric.value not in ["HAMMING", "JACCARD", "DP"]],
+            ),
+            help="Distance metric for reranking",
+            default="COSINE",
+            show_default=True,
+            required=False,
+        ),
+    ]
+    quantized_fetch_limit: Annotated[
+        int | None,
+        click.option(
+            "--quantized-fetch-limit",
+            type=int,
+            help="Limit of inner query in case of reranking",
+            required=False,
+        ),
+    ]
     maintenance_work_mem: Annotated[
         str | None,
         click.option(
@@ -98,7 +139,11 @@ def PgDiskAnn(
         db_case_config=PgDiskANNImplConfig(
            max_neighbors=parameters["max_neighbors"],
            l_value_ib=parameters["l_value_ib"],
+           pq_param_num_chunks=parameters["pq_param_num_chunks"],
            l_value_is=parameters["l_value_is"],
+           reranking=parameters["reranking"],
+           reranking_metric=parameters["reranking_metric"],
+           quantized_fetch_limit=parameters["quantized_fetch_limit"],
            max_parallel_workers=parameters["max_parallel_workers"],
            maintenance_work_mem=parameters["maintenance_work_mem"],
         ),

vectordb_bench/backend/clients/pgdiskann/config.py
@@ -60,6 +60,13 @@ class PgDiskANNIndexConfig(BaseModel, DBCaseConfig):
             return "<#>"
         return "<=>"
 
+    def parse_reranking_metric_fun_op(self) -> LiteralString:
+        if self.reranking_metric == MetricType.L2:
+            return "<->"
+        if self.reranking_metric == MetricType.IP:
+            return "<#>"
+        return "<=>"
+
     def parse_metric_fun_str(self) -> str:
         if self.metric_type == MetricType.L2:
             return "l2_distance"
@@ -115,7 +122,11 @@ class PgDiskANNImplConfig(PgDiskANNIndexConfig):
     index: IndexType = IndexType.DISKANN
     max_neighbors: int | None
     l_value_ib: int | None
+    pq_param_num_chunks: int | None
     l_value_is: float | None
+    reranking: bool | None = None
+    reranking_metric: str | None = None
+    quantized_fetch_limit: int | None = None
     maintenance_work_mem: str | None = None
     max_parallel_workers: int | None = None
 
@@ -126,6 +137,8 @@ class PgDiskANNImplConfig(PgDiskANNIndexConfig):
             "options": {
                 "max_neighbors": self.max_neighbors,
                 "l_value_ib": self.l_value_ib,
+                "pq_param_num_chunks": self.pq_param_num_chunks,
+                "product_quantized": str(self.reranking),
             },
             "maintenance_work_mem": self.maintenance_work_mem,
             "max_parallel_workers": self.max_parallel_workers,
@@ -135,6 +148,9 @@ class PgDiskANNImplConfig(PgDiskANNIndexConfig):
         return {
             "metric": self.parse_metric(),
             "metric_fun_op": self.parse_metric_fun_op(),
+            "reranking": self.reranking,
+            "reranking_metric_fun_op": self.parse_reranking_metric_fun_op(),
+            "quantized_fetch_limit": self.quantized_fetch_limit,
         }
 
     def session_param(self) -> dict:
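
Taken together, the new pq_param_num_chunks, reranking, reranking_metric, and quantized_fetch_limit fields let a run enable product-quantization reranking from config alone; on the CLI side, the --reranking/--skip-reranking pair is an on/off switch defaulting to off. A minimal sketch of how the knobs surface through search_param(); the field names come from the hunks above, while metric_type is assumed to be inherited from PgDiskANNIndexConfig (it is referenced by parse_metric_fun_op):

    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.pgdiskann.config import PgDiskANNImplConfig

    cfg = PgDiskANNImplConfig(
        metric_type=MetricType.L2,  # base ANN ordering, maps to "<->"
        max_neighbors=32,
        l_value_ib=100,
        pq_param_num_chunks=64,     # new: number of PQ chunks
        l_value_is=100.0,
        reranking=True,             # new: re-rank quantized candidates
        reranking_metric="COSINE",  # new: falls through to "<=>" above
        quantized_fetch_limit=200,  # new: candidate pool for the inner query
    )
    print(cfg.search_param())
    # expected keys: metric, metric_fun_op, reranking,
    # reranking_metric_fun_op, quantized_fetch_limit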

vectordb_bench/backend/clients/pgdiskann/pgdiskann.py
@@ -90,38 +90,83 @@ class PgDiskANN(VectorDB):
     def init(self) -> Generator[None, None, None]:
         self.conn, self.cursor = self._create_connection(**self.db_config)
 
-        # index configuration may have commands defined that we should set during each client session
         session_options: dict[str, Any] = self.case_config.session_param()
 
         if len(session_options) > 0:
             for setting_name, setting_val in session_options.items():
-                command = sql.SQL("SET {setting_name}
-                    setting_name=sql.Identifier(setting_name),
-                    setting_val=sql.Identifier(str(setting_val)),
+                command = sql.SQL("SET {setting_name} = {setting_val};").format(
+                    setting_name=sql.Identifier(setting_name), setting_val=sql.Literal(setting_val)
                 )
                 log.debug(command.as_string(self.cursor))
                 self.cursor.execute(command)
             self.conn.commit()
 
-
-
-
-
-
-
-
-
-
+        search_params = self.case_config.search_param()
+
+        if search_params.get("reranking"):
+            # Reranking-enabled queries
+            self._filtered_search = sql.SQL(
+                """
+                SELECT i.id
+                FROM (
+                    SELECT id, embedding
+                    FROM public.{table_name}
+                    WHERE id >= %s
+                    ORDER BY embedding {metric_fun_op} %s::vector
+                    LIMIT {quantized_fetch_limit}::int
+                ) i
+                ORDER BY i.embedding {reranking_metric_fun_op} %s::vector
+                LIMIT %s::int
+                """
+            ).format(
+                table_name=sql.Identifier(self.table_name),
+                metric_fun_op=sql.SQL(search_params["metric_fun_op"]),
+                reranking_metric_fun_op=sql.SQL(search_params["reranking_metric_fun_op"]),
+                quantized_fetch_limit=sql.Literal(search_params["quantized_fetch_limit"]),
+            )
 
-
-
-
-
-
-
-
-
-
+            self._unfiltered_search = sql.SQL(
+                """
+                SELECT i.id
+                FROM (
+                    SELECT id, embedding
+                    FROM public.{table_name}
+                    ORDER BY embedding {metric_fun_op} %s::vector
+                    LIMIT {quantized_fetch_limit}::int
+                ) i
+                ORDER BY i.embedding {reranking_metric_fun_op} %s::vector
+                LIMIT %s::int
+                """
+            ).format(
+                table_name=sql.Identifier(self.table_name),
+                metric_fun_op=sql.SQL(search_params["metric_fun_op"]),
+                reranking_metric_fun_op=sql.SQL(search_params["reranking_metric_fun_op"]),
+                quantized_fetch_limit=sql.Literal(search_params["quantized_fetch_limit"]),
+            )
+
+        else:
+            self._filtered_search = sql.Composed(
+                [
+                    sql.SQL(
+                        "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding ",
+                    ).format(table_name=sql.Identifier(self.table_name)),
+                    sql.SQL(search_params["metric_fun_op"]),
+                    sql.SQL(" %s::vector LIMIT %s::int"),
+                ]
+            )
+
+            self._unfiltered_search = sql.Composed(
+                [
+                    sql.SQL("SELECT id FROM public.{table_name} ORDER BY embedding ").format(
+                        table_name=sql.Identifier(self.table_name)
+                    ),
+                    sql.SQL(search_params["metric_fun_op"]),
+                    sql.SQL(" %s::vector LIMIT %s::int"),
+                ]
+            )
+
+        log.debug(f"Unfiltered search query={self._unfiltered_search.as_string(self.conn)}")
+        log.debug(f"Filtered search query={self._filtered_search.as_string(self.conn)}")
 
         try:
             yield

@@ -234,7 +279,7 @@ class PgDiskANN(VectorDB):
             options.append(
                 sql.SQL("{option_name} = {val}").format(
                     option_name=sql.Identifier(option_name),
-                    val=sql.
+                    val=sql.Literal(option_val),
                 ),
             )
 

@@ -314,16 +359,39 @@
         assert self.conn is not None, "Connection is not initialized"
         assert self.cursor is not None, "Cursor is not initialized"
 
+        search_params = self.case_config.search_param()
+        is_reranking = search_params.get("reranking", False)
+
         q = np.asarray(query)
         if filters:
             gt = filters.get("id")
+            if is_reranking:
+                result = self.cursor.execute(
+                    self._filtered_search,
+                    (gt, q, q, k),
+                    prepare=True,
+                    binary=True,
+                )
+            else:
+                result = self.cursor.execute(
+                    self._filtered_search,
+                    (gt, q, k),
+                    prepare=True,
+                    binary=True,
+                )
+        elif is_reranking:
             result = self.cursor.execute(
-                self.
-                (
+                self._unfiltered_search,
+                (q, q, k),
                 prepare=True,
                 binary=True,
             )
         else:
-            result = self.cursor.execute(
+            result = self.cursor.execute(
+                self._unfiltered_search,
+                (q, k),
+                prepare=True,
+                binary=True,
+            )
 
         return [int(i[0]) for i in result.fetchall()]
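
Note the tuple arity in the branches above: with reranking on, the query vector is bound twice, once for the quantized inner ORDER BY and once for the exact outer re-rank. A placeholder sketch of the four parameter shapes:

    gt, q, k = 0, [0.0] * 128, 10  # placeholder filter id, query vector, top-k

    filtered_rerank = (gt, q, q, k)    # id bound + inner vector + rerank vector + limit
    filtered_plain = (gt, q, k)
    unfiltered_rerank = (q, q, k)
    unfiltered_plain = (q, k)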

vectordb_bench/backend/clients/s3_vectors/config.py (new file)
@@ -0,0 +1,41 @@
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, MetricType
+
+
+class S3VectorsConfig(DBConfig):
+    region_name: str = "us-west-2"
+    access_key_id: SecretStr
+    secret_access_key: SecretStr
+    bucket_name: str
+    index_name: str = "vdbbench-index"
+
+    def to_dict(self) -> dict:
+        return {
+            "region_name": self.region_name,
+            "access_key_id": self.access_key_id.get_secret_value() if self.access_key_id else "",
+            "secret_access_key": self.secret_access_key.get_secret_value() if self.secret_access_key else "",
+            "bucket_name": self.bucket_name,
+            "index_name": self.index_name,
+        }
+
+
+class S3VectorsIndexConfig(DBCaseConfig, BaseModel):
+    """Base config for s3-vectors"""
+
+    metric_type: MetricType | None = None
+    data_type: str = "float32"
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.COSINE:
+            return "cosine"
+        if self.metric_type == MetricType.L2:
+            return "euclidean"
+        msg = f"Unsupported metric type: {self.metric_type}"
+        raise ValueError(msg)
+
+    def index_param(self) -> dict:
+        return {}
+
+    def search_param(self) -> dict:
+        return {}
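
A minimal sketch of filling in the new config, assuming DBConfig adds no further required fields here; the credentials are placeholders:

    from pydantic import SecretStr

    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.s3_vectors.config import (
        S3VectorsConfig,
        S3VectorsIndexConfig,
    )

    db_config = S3VectorsConfig(
        access_key_id=SecretStr("EXAMPLE-KEY-ID"),
        secret_access_key=SecretStr("EXAMPLE-SECRET"),
        bucket_name="my-vector-bucket",
    ).to_dict()  # secrets are unwrapped only at this boundary

    case_config = S3VectorsIndexConfig(metric_type=MetricType.COSINE)
    assert case_config.parse_metric() == "cosine"  # L2 maps to "euclidean"; anything else raises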

vectordb_bench/backend/clients/s3_vectors/s3_vectors.py (new file)
@@ -0,0 +1,171 @@
+"""Wrapper around the Milvus vector database over VectorDB"""
+
+import logging
+from collections.abc import Iterable
+from contextlib import contextmanager
+
+import boto3
+
+from vectordb_bench.backend.filter import Filter, FilterOp
+
+from ..api import VectorDB
+from .config import S3VectorsIndexConfig
+
+log = logging.getLogger(__name__)
+
+
+class S3Vectors(VectorDB):
+    supported_filter_types: list[FilterOp] = [
+        FilterOp.NonFilter,
+        FilterOp.NumGE,
+        FilterOp.StrEqual,
+    ]
+
+    def __init__(
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: S3VectorsIndexConfig,
+        drop_old: bool = False,
+        with_scalar_labels: bool = False,
+        **kwargs,
+    ):
+        """Initialize wrapper around the s3-vectors client."""
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.with_scalar_labels = with_scalar_labels
+
+        self.batch_size = 500
+
+        self._scalar_id_field = "id"
+        self._scalar_label_field = "label"
+        self._vector_field = "vector"
+
+        self.region_name = self.db_config.get("region_name")
+        self.access_key_id = self.db_config.get("access_key_id")
+        self.secret_access_key = self.db_config.get("secret_access_key")
+        self.bucket_name = self.db_config.get("bucket_name")
+        self.index_name = self.db_config.get("index_name")
+
+        client = boto3.client(
+            service_name="s3vectors",
+            region_name=self.region_name,
+            aws_access_key_id=self.access_key_id,
+            aws_secret_access_key=self.secret_access_key,
+        )
+
+        if drop_old:
+            # delete old index if exists
+            response = client.list_indexes(vectorBucketName=self.bucket_name)
+            index_names = [index["indexName"] for index in response["indexes"]]
+            if self.index_name in index_names:
+                log.info(f"drop old index: {self.index_name}")
+                client.delete_index(vectorBucketName=self.bucket_name, indexName=self.index_name)
+
+            # create the index
+            client.create_index(
+                vectorBucketName=self.bucket_name,
+                indexName=self.index_name,
+                dataType=self.case_config.data_type,
+                dimension=dim,
+                distanceMetric=self.case_config.parse_metric(),
+            )
+
+        client.close()
+
+    @contextmanager
+    def init(self):
+        """
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+            >>>     self.search_embedding()
+        """
+        self.client = boto3.client(
+            service_name="s3vectors",
+            region_name=self.region_name,
+            aws_access_key_id=self.access_key_id,
+            aws_secret_access_key=self.secret_access_key,
+        )
+
+        yield
+        self.client.close()
+
+    def optimize(self, **kwargs):
+        return
+
+    def need_normalize_cosine(self) -> bool:
+        """Wheather this database need to normalize dataset to support COSINE"""
+        return False
+
+    def insert_embeddings(
+        self,
+        embeddings: Iterable[list[float]],
+        metadata: list[int],
+        labels_data: list[str] | None = None,
+        **kwargs,
+    ) -> tuple[int, Exception]:
+        """Insert embeddings into s3-vectors. should call self.init() first"""
+        # use the first insert_embeddings to init collection
+        assert self.client is not None
+        assert len(embeddings) == len(metadata)
+        insert_count = 0
+        try:
+            for batch_start_offset in range(0, len(embeddings), self.batch_size):
+                batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
+                insert_data = [
+                    {
+                        "key": str(metadata[i]),
+                        "data": {self.case_config.data_type: embeddings[i]},
+                        "metadata": (
+                            {self._scalar_label_field: labels_data[i], self._scalar_id_field: metadata[i]}
+                            if self.with_scalar_labels
+                            else {self._scalar_id_field: metadata[i]}
+                        ),
+                    }
+                    for i in range(batch_start_offset, batch_end_offset)
+                ]
+                self.client.put_vectors(
+                    vectorBucketName=self.bucket_name,
+                    indexName=self.index_name,
+                    vectors=insert_data,
+                )
+                insert_count += len(insert_data)
+        except Exception as e:
+            log.info(f"Failed to insert data: {e}")
+            return insert_count, e
+        return insert_count, None
+
+    def prepare_filter(self, filters: Filter):
+        if filters.type == FilterOp.NonFilter:
+            self.filter = None
+        elif filters.type == FilterOp.NumGE:
+            self.filter = {self._scalar_id_field: {"$gte": filters.int_value}}
+        elif filters.type == FilterOp.StrEqual:
+            self.filter = {self._scalar_label_field: filters.label_value}
+        else:
+            msg = f"Not support Filter for S3Vectors - {filters}"
+            raise ValueError(msg)
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        timeout: int | None = None,
+    ) -> list[int]:
+        """Perform a search on a query embedding and return results."""
+        assert self.client is not None
+
+        # Perform the search.
+        res = self.client.query_vectors(
+            vectorBucketName=self.bucket_name,
+            indexName=self.index_name,
+            queryVector={"float32": query},
+            topK=k,
+            filter=self.filter,
+            returnDistance=False,
+            returnMetadata=False,
+        )
+
+        # Organize results.
+        return [int(result["key"]) for result in res["vectors"]]
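
End to end, the new wrapper can be driven directly. Below is a sketch assuming valid AWS credentials and an existing S3 vector bucket (all names are placeholders). Note that search_embedding() reads self.filter unconditionally, so prepare_filter() must run first; this sketch assigns the attribute directly with the value prepare_filter() would set for FilterOp.NonFilter:

    import random

    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.s3_vectors.config import S3VectorsIndexConfig
    from vectordb_bench.backend.clients.s3_vectors.s3_vectors import S3Vectors

    db = S3Vectors(
        dim=128,
        db_config={
            "region_name": "us-west-2",
            "access_key_id": "EXAMPLE-KEY-ID",
            "secret_access_key": "EXAMPLE-SECRET",
            "bucket_name": "my-vector-bucket",
            "index_name": "vdbbench-index",
        },
        db_case_config=S3VectorsIndexConfig(metric_type=MetricType.COSINE),
        drop_old=True,
    )

    with db.init():
        vectors = [[random.random() for _ in range(128)] for _ in range(10)]
        count, err = db.insert_embeddings(vectors, metadata=list(range(10)))
        db.filter = None  # what prepare_filter() sets for FilterOp.NonFilter
        ids = db.search_embedding([random.random() for _ in range(128)], k=5)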

vectordb_bench/backend/clients/tidb/cli.py
@@ -17,7 +17,6 @@ class TiDBTypedDict(CommonTypedDict):
             help="Username",
             default="root",
             show_default=True,
-            required=True,
         ),
     ]
     password: Annotated[
@@ -37,7 +36,6 @@ class TiDBTypedDict(CommonTypedDict):
             type=str,
             default="127.0.0.1",
             show_default=True,
-            required=True,
             help="Db host",
         ),
     ]
@@ -48,7 +46,6 @@ class TiDBTypedDict(CommonTypedDict):
             type=int,
             default=4000,
             show_default=True,
-            required=True,
             help="Db Port",
         ),
     ]
@@ -59,7 +56,6 @@ class TiDBTypedDict(CommonTypedDict):
             type=str,
             default="test",
             show_default=True,
-            required=True,
            help="Db name",
         ),
     ]