vectordb-bench 1.0.4__py3-none-any.whl → 1.0.5__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
- vectordb_bench/backend/cases.py +45 -1
- vectordb_bench/backend/clients/__init__.py +32 -0
- vectordb_bench/backend/clients/milvus/cli.py +4 -9
- vectordb_bench/backend/clients/oss_opensearch/cli.py +155 -0
- vectordb_bench/backend/clients/oss_opensearch/config.py +157 -0
- vectordb_bench/backend/clients/oss_opensearch/oss_opensearch.py +582 -0
- vectordb_bench/backend/clients/oss_opensearch/run.py +166 -0
- vectordb_bench/backend/clients/s3_vectors/config.py +41 -0
- vectordb_bench/backend/clients/s3_vectors/s3_vectors.py +171 -0
- vectordb_bench/backend/clients/tidb/cli.py +0 -4
- vectordb_bench/backend/clients/tidb/config.py +22 -2
- vectordb_bench/backend/dataset.py +70 -0
- vectordb_bench/backend/filter.py +17 -0
- vectordb_bench/backend/runner/mp_runner.py +4 -0
- vectordb_bench/backend/runner/read_write_runner.py +10 -9
- vectordb_bench/backend/runner/serial_runner.py +23 -7
- vectordb_bench/backend/task_runner.py +5 -4
- vectordb_bench/cli/vectordbbench.py +2 -0
- vectordb_bench/fig/custom_case_run_test.png +0 -0
- vectordb_bench/fig/custom_dataset.png +0 -0
- vectordb_bench/fig/homepage/bar-chart.png +0 -0
- vectordb_bench/fig/homepage/concurrent.png +0 -0
- vectordb_bench/fig/homepage/custom.png +0 -0
- vectordb_bench/fig/homepage/label_filter.png +0 -0
- vectordb_bench/fig/homepage/qp$.png +0 -0
- vectordb_bench/fig/homepage/run_test.png +0 -0
- vectordb_bench/fig/homepage/streaming.png +0 -0
- vectordb_bench/fig/homepage/table.png +0 -0
- vectordb_bench/fig/run_test_select_case.png +0 -0
- vectordb_bench/fig/run_test_select_db.png +0 -0
- vectordb_bench/fig/run_test_submit.png +0 -0
- vectordb_bench/frontend/components/check_results/filters.py +1 -4
- vectordb_bench/frontend/components/check_results/nav.py +2 -1
- vectordb_bench/frontend/components/concurrent/charts.py +5 -0
- vectordb_bench/frontend/components/int_filter/charts.py +60 -0
- vectordb_bench/frontend/components/streaming/data.py +7 -0
- vectordb_bench/frontend/components/welcome/welcomePrams.py +42 -4
- vectordb_bench/frontend/config/dbCaseConfigs.py +60 -13
- vectordb_bench/frontend/config/styles.py +3 -0
- vectordb_bench/frontend/pages/concurrent.py +1 -1
- vectordb_bench/frontend/pages/custom.py +1 -1
- vectordb_bench/frontend/pages/int_filter.py +56 -0
- vectordb_bench/frontend/pages/streaming.py +16 -3
- vectordb_bench/metric.py +7 -0
- vectordb_bench/models.py +36 -4
- vectordb_bench/results/S3Vectors/result_20250722_standard_s3vectors.json +2509 -0
- {vectordb_bench-1.0.4.dist-info → vectordb_bench-1.0.5.dist-info}/METADATA +1 -1
- {vectordb_bench-1.0.4.dist-info → vectordb_bench-1.0.5.dist-info}/RECORD +52 -30
- {vectordb_bench-1.0.4.dist-info → vectordb_bench-1.0.5.dist-info}/WHEEL +0 -0
- {vectordb_bench-1.0.4.dist-info → vectordb_bench-1.0.5.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-1.0.4.dist-info → vectordb_bench-1.0.5.dist-info}/licenses/LICENSE +0 -0
- {vectordb_bench-1.0.4.dist-info → vectordb_bench-1.0.5.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/oss_opensearch/run.py
ADDED
@@ -0,0 +1,166 @@
+import logging
+import random
+import time
+
+from opensearchpy import OpenSearch
+
+log = logging.getLogger(__name__)
+
+_HOST = "xxxxxx.us-west-2.es.amazonaws.com"
+_PORT = 443
+_AUTH = ("admin", "xxxxxx")  # For testing only. Don't store credentials in code.
+
+_INDEX_NAME = "my-dsl-index"
+_BATCH = 100
+_ROWS = 100
+_DIM = 128
+_TOPK = 10
+
+
+def create_client():
+    return OpenSearch(
+        hosts=[{"host": _HOST, "port": _PORT}],
+        http_compress=True,  # enables gzip compression for request bodies
+        http_auth=_AUTH,
+        use_ssl=True,
+        verify_certs=True,
+        ssl_assert_hostname=False,
+        ssl_show_warn=False,
+    )
+
+
+def create_index(client: OpenSearch, index_name: str):
+    settings = {
+        "index": {
+            "knn": True,
+            "number_of_shards": 1,
+            "refresh_interval": "5s",
+        },
+    }
+    mappings = {
+        "properties": {
+            "embedding": {
+                "type": "knn_vector",
+                "dimension": _DIM,
+                "method": {
+                    "engine": "faiss",
+                    "name": "hnsw",
+                    "space_type": "l2",
+                    "parameters": {
+                        "ef_construction": 256,
+                        "m": 16,
+                    },
+                },
+            },
+        },
+    }
+
+    response = client.indices.create(
+        index=index_name,
+        body={"settings": settings, "mappings": mappings},
+    )
+    log.info("\nCreating index:")
+    log.info(response)
+
+
+def delete_index(client: OpenSearch, index_name: str):
+    response = client.indices.delete(index=index_name)
+    log.info("\nDeleting index:")
+    log.info(response)
+
+
+def bulk_insert(client: OpenSearch, index_name: str):
+    # Perform bulk operations
+    ids = list(range(_ROWS))
+    vec = [[random.random() for _ in range(_DIM)] for _ in range(_ROWS)]
+
+    docs = []
+    for i in range(0, _ROWS, _BATCH):
+        docs.clear()
+        for j in range(_BATCH):
+            docs.append({"index": {"_index": index_name, "_id": ids[i + j]}})
+            docs.append({"embedding": vec[i + j]})
+        response = client.bulk(docs)
+        log.info(f"Adding documents: {len(response['items'])}, {response['errors']}")
+        response = client.indices.stats(index_name)
+        log.info(
+            f'Total document count in index: { response["_all"]["primaries"]["indexing"]["index_total"] }',
+        )
+
+
+def search(client: OpenSearch, index_name: str):
+    # Search for the document.
+    search_body = {
+        "size": _TOPK,
+        "query": {
+            "knn": {
+                "embedding": {
+                    "vector": [random.random() for _ in range(_DIM)],
+                    "k": _TOPK,
+                },
+            },
+        },
+    }
+    while True:
+        response = client.search(index=index_name, body=search_body)
+        log.info(f'\nSearch took: {response["took"]}')
+        log.info(f'\nSearch shards: {response["_shards"]}')
+        log.info(f'\nSearch hits total: {response["hits"]["total"]}')
+        result = response["hits"]["hits"]
+        if len(result) != 0:
+            log.info("\nSearch results:")
+            for hit in response["hits"]["hits"]:
+                log.info(hit["_id"], hit["_score"])
+            break
+        log.info("\nSearch not ready, sleep 1s")
+        time.sleep(1)
+
+
+SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
+WAITINT_FOR_REFRESH_SEC = 30
+
+
+def optimize_index(client: OpenSearch, index_name: str):
+    log.info(f"Starting force merge for index {index_name}")
+    force_merge_endpoint = f"/{index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false"
+    force_merge_task_id = client.transport.perform_request("POST", force_merge_endpoint)["task"]
+    while True:
+        time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
+        task_status = client.tasks.get(task_id=force_merge_task_id)
+        if task_status["completed"]:
+            break
+    log.info(f"Completed force merge for index {index_name}")
+
+
+def refresh_index(client: OpenSearch, index_name: str):
+    log.info(f"Starting refresh for index {index_name}")
+    while True:
+        try:
+            log.info("Starting the Refresh Index..")
+            client.indices.refresh(index=index_name)
+            break
+        except Exception as e:
+            log.info(
+                f"Refresh errored out. Sleeping for {WAITINT_FOR_REFRESH_SEC} sec and then Retrying : {e}",
+            )
+            time.sleep(WAITINT_FOR_REFRESH_SEC)
+            continue
+    log.info(f"Completed refresh for index {index_name}")
+
+
+def main():
+    client = create_client()
+    try:
+        create_index(client, _INDEX_NAME)
+        bulk_insert(client, _INDEX_NAME)
+        optimize_index(client, _INDEX_NAME)
+        refresh_index(client, _INDEX_NAME)
+        search(client, _INDEX_NAME)
+        delete_index(client, _INDEX_NAME)
+    except Exception as e:
+        log.info(e)
+        delete_index(client, _INDEX_NAME)
+
+
+if __name__ == "__main__":
+    main()
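The file above is a standalone OpenSearch smoke test guarded by the "__main__" check. A rough way to exercise it, assuming the opensearch-py dependency is installed and the placeholder _HOST/_AUTH values at the top of the file have been replaced with real ones, is:

import logging

from vectordb_bench.backend.clients.oss_opensearch import run

logging.basicConfig(level=logging.INFO)  # the script logs everything at INFO level
run.main()  # create index -> bulk insert -> force merge -> refresh -> search -> delete index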
vectordb_bench/backend/clients/s3_vectors/config.py
ADDED
@@ -0,0 +1,41 @@
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, MetricType
+
+
+class S3VectorsConfig(DBConfig):
+    region_name: str = "us-west-2"
+    access_key_id: SecretStr
+    secret_access_key: SecretStr
+    bucket_name: str
+    index_name: str = "vdbbench-index"
+
+    def to_dict(self) -> dict:
+        return {
+            "region_name": self.region_name,
+            "access_key_id": self.access_key_id.get_secret_value() if self.access_key_id else "",
+            "secret_access_key": self.secret_access_key.get_secret_value() if self.secret_access_key else "",
+            "bucket_name": self.bucket_name,
+            "index_name": self.index_name,
+        }
+
+
+class S3VectorsIndexConfig(DBCaseConfig, BaseModel):
+    """Base config for s3-vectors"""
+
+    metric_type: MetricType | None = None
+    data_type: str = "float32"
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.COSINE:
+            return "cosine"
+        if self.metric_type == MetricType.L2:
+            return "euclidean"
+        msg = f"Unsupported metric type: {self.metric_type}"
+        raise ValueError(msg)
+
+    def index_param(self) -> dict:
+        return {}
+
+    def search_param(self) -> dict:
+        return {}
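For orientation, a minimal sketch of how these two config classes fit together (the credential and bucket values are placeholders, and it assumes the DBConfig base class needs no additional required fields):

from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.backend.clients.s3_vectors.config import S3VectorsConfig, S3VectorsIndexConfig

config = S3VectorsConfig(
    access_key_id="AKIA...",           # placeholder
    secret_access_key="...",           # placeholder
    bucket_name="my-vector-bucket",    # placeholder
)
print(config.to_dict()["region_name"])  # "us-west-2" by default

index_config = S3VectorsIndexConfig(metric_type=MetricType.COSINE)
print(index_config.parse_metric())  # "cosine"; MetricType.L2 maps to "euclidean"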
vectordb_bench/backend/clients/s3_vectors/s3_vectors.py
ADDED
@@ -0,0 +1,171 @@
+"""Wrapper around the Milvus vector database over VectorDB"""
+
+import logging
+from collections.abc import Iterable
+from contextlib import contextmanager
+
+import boto3
+
+from vectordb_bench.backend.filter import Filter, FilterOp
+
+from ..api import VectorDB
+from .config import S3VectorsIndexConfig
+
+log = logging.getLogger(__name__)
+
+
+class S3Vectors(VectorDB):
+    supported_filter_types: list[FilterOp] = [
+        FilterOp.NonFilter,
+        FilterOp.NumGE,
+        FilterOp.StrEqual,
+    ]
+
+    def __init__(
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: S3VectorsIndexConfig,
+        drop_old: bool = False,
+        with_scalar_labels: bool = False,
+        **kwargs,
+    ):
+        """Initialize wrapper around the s3-vectors client."""
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.with_scalar_labels = with_scalar_labels
+
+        self.batch_size = 500
+
+        self._scalar_id_field = "id"
+        self._scalar_label_field = "label"
+        self._vector_field = "vector"
+
+        self.region_name = self.db_config.get("region_name")
+        self.access_key_id = self.db_config.get("access_key_id")
+        self.secret_access_key = self.db_config.get("secret_access_key")
+        self.bucket_name = self.db_config.get("bucket_name")
+        self.index_name = self.db_config.get("index_name")
+
+        client = boto3.client(
+            service_name="s3vectors",
+            region_name=self.region_name,
+            aws_access_key_id=self.access_key_id,
+            aws_secret_access_key=self.secret_access_key,
+        )
+
+        if drop_old:
+            # delete old index if exists
+            response = client.list_indexes(vectorBucketName=self.bucket_name)
+            index_names = [index["indexName"] for index in response["indexes"]]
+            if self.index_name in index_names:
+                log.info(f"drop old index: {self.index_name}")
+                client.delete_index(vectorBucketName=self.bucket_name, indexName=self.index_name)
+
+            # create the index
+            client.create_index(
+                vectorBucketName=self.bucket_name,
+                indexName=self.index_name,
+                dataType=self.case_config.data_type,
+                dimension=dim,
+                distanceMetric=self.case_config.parse_metric(),
+            )
+
+        client.close()
+
+    @contextmanager
+    def init(self):
+        """
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+            >>>     self.search_embedding()
+        """
+        self.client = boto3.client(
+            service_name="s3vectors",
+            region_name=self.region_name,
+            aws_access_key_id=self.access_key_id,
+            aws_secret_access_key=self.secret_access_key,
+        )
+
+        yield
+        self.client.close()
+
+    def optimize(self, **kwargs):
+        return
+
+    def need_normalize_cosine(self) -> bool:
+        """Wheather this database need to normalize dataset to support COSINE"""
+        return False
+
+    def insert_embeddings(
+        self,
+        embeddings: Iterable[list[float]],
+        metadata: list[int],
+        labels_data: list[str] | None = None,
+        **kwargs,
+    ) -> tuple[int, Exception]:
+        """Insert embeddings into s3-vectors. should call self.init() first"""
+        # use the first insert_embeddings to init collection
+        assert self.client is not None
+        assert len(embeddings) == len(metadata)
+        insert_count = 0
+        try:
+            for batch_start_offset in range(0, len(embeddings), self.batch_size):
+                batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
+                insert_data = [
+                    {
+                        "key": str(metadata[i]),
+                        "data": {self.case_config.data_type: embeddings[i]},
+                        "metadata": (
+                            {self._scalar_label_field: labels_data[i], self._scalar_id_field: metadata[i]}
+                            if self.with_scalar_labels
+                            else {self._scalar_id_field: metadata[i]}
+                        ),
+                    }
+                    for i in range(batch_start_offset, batch_end_offset)
+                ]
+                self.client.put_vectors(
+                    vectorBucketName=self.bucket_name,
+                    indexName=self.index_name,
+                    vectors=insert_data,
+                )
+                insert_count += len(insert_data)
+        except Exception as e:
+            log.info(f"Failed to insert data: {e}")
+            return insert_count, e
+        return insert_count, None
+
+    def prepare_filter(self, filters: Filter):
+        if filters.type == FilterOp.NonFilter:
+            self.filter = None
+        elif filters.type == FilterOp.NumGE:
+            self.filter = {self._scalar_id_field: {"$gte": filters.int_value}}
+        elif filters.type == FilterOp.StrEqual:
+            self.filter = {self._scalar_label_field: filters.label_value}
+        else:
+            msg = f"Not support Filter for S3Vectors - {filters}"
+            raise ValueError(msg)
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        timeout: int | None = None,
+    ) -> list[int]:
+        """Perform a search on a query embedding and return results."""
+        assert self.client is not None
+
+        # Perform the search.
+        res = self.client.query_vectors(
+            vectorBucketName=self.bucket_name,
+            indexName=self.index_name,
+            queryVector={"float32": query},
+            topK=k,
+            filter=self.filter,
+            returnDistance=False,
+            returnMetadata=False,
+        )
+
+        # Organize results.
+        return [int(result["key"]) for result in res["vectors"]]
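Driven outside the benchmark harness, the intended call sequence follows the Examples block in S3Vectors.init(). A sketch, reusing the config objects from the previous sketch; the dimension and vectors are placeholders, and in a real run the task runner supplies them and calls prepare_filter() before searching:

from vectordb_bench.backend.clients.s3_vectors.s3_vectors import S3Vectors

db = S3Vectors(dim=4, db_config=config.to_dict(), db_case_config=index_config, drop_old=True)
with db.init():
    count, err = db.insert_embeddings(embeddings=[[0.1, 0.2, 0.3, 0.4]], metadata=[0])
    # search_embedding() reads self.filter, so prepare_filter() must run first,
    # e.g. with a FilterOp.NonFilter filter when no filtering is wanted.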
vectordb_bench/backend/clients/tidb/cli.py
CHANGED
@@ -17,7 +17,6 @@ class TiDBTypedDict(CommonTypedDict):
             help="Username",
             default="root",
             show_default=True,
-            required=True,
         ),
     ]
     password: Annotated[
@@ -37,7 +36,6 @@ class TiDBTypedDict(CommonTypedDict):
             type=str,
             default="127.0.0.1",
             show_default=True,
-            required=True,
             help="Db host",
         ),
     ]
@@ -48,7 +46,6 @@ class TiDBTypedDict(CommonTypedDict):
             type=int,
             default=4000,
             show_default=True,
-            required=True,
             help="Db Port",
         ),
     ]
@@ -59,7 +56,6 @@ class TiDBTypedDict(CommonTypedDict):
             type=str,
             default="test",
            show_default=True,
-            required=True,
             help="Db name",
         ),
     ]
vectordb_bench/backend/clients/tidb/config.py
CHANGED
@@ -1,8 +1,20 @@
-from
+from typing import TypedDict
+
+from pydantic import BaseModel, SecretStr, validator
 
 from ..api import DBCaseConfig, DBConfig, MetricType
 
 
+class TiDBConfigDict(TypedDict):
+    host: str
+    port: int
+    user: str
+    password: str
+    database: str
+    ssl_verify_cert: bool
+    ssl_verify_identity: bool
+
+
 class TiDBConfig(DBConfig):
     user_name: str = "root"
     password: SecretStr
@@ -11,7 +23,7 @@ class TiDBConfig(DBConfig):
     db_name: str = "test"
     ssl: bool = False
 
-    def to_dict(self) ->
+    def to_dict(self) -> TiDBConfigDict:
         pwd_str = self.password.get_secret_value()
         return {
             "host": self.host,
@@ -23,6 +35,14 @@ class TiDBConfig(DBConfig):
             "ssl_verify_identity": self.ssl,
         }
 
+    @validator("*")
+    def not_empty_field(cls, v: any, field: any):
+        if field.name in ["password", "db_label"]:
+            return v
+        if isinstance(v, str | SecretStr) and len(v) == 0:
+            raise ValueError("Empty string!")
+        return v
+
 
 class TiDBIndexConfig(BaseModel, DBCaseConfig):
     metric_type: MetricType | None = None
vectordb_bench/backend/dataset.py
CHANGED
@@ -48,6 +48,7 @@ class BaseDataset(BaseModel):
     scalar_labels_file_separated: bool = True
     scalar_labels_file: str = "scalar_labels.parquet"
     scalar_label_percentages: list[float] = []
+    scalar_int_rates: list[float] = []
     train_id_field: str = "id"
     train_vector_field: str = "emb"
     test_file: str = "test.parquet"
@@ -164,6 +165,29 @@ class Cohere(BaseDataset):
     }
     with_scalar_labels: bool = True
     scalar_label_percentages: list[float] = [0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5]
+    scalar_int_rates: list[float] = [
+        0.001,
+        0.002,
+        0.005,
+        0.01,
+        0.02,
+        0.05,
+        0.1,
+        0.2,
+        0.3,
+        0.4,
+        0.5,
+        0.6,
+        0.7,
+        0.8,
+        0.9,
+        0.95,
+        0.98,
+        0.99,
+        0.995,
+        0.998,
+        0.999,
+    ]
 
 
 class Bioasq(BaseDataset):
@@ -178,6 +202,29 @@ class Bioasq(BaseDataset):
     }
     with_scalar_labels: bool = True
     scalar_label_percentages: list[float] = [0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5]
+    scalar_int_rates: list[float] = [
+        0.001,
+        0.002,
+        0.005,
+        0.01,
+        0.02,
+        0.05,
+        0.1,
+        0.2,
+        0.3,
+        0.4,
+        0.5,
+        0.6,
+        0.7,
+        0.8,
+        0.9,
+        0.95,
+        0.98,
+        0.99,
+        0.995,
+        0.998,
+        0.999,
+    ]
 
 
 class Glove(BaseDataset):
@@ -217,6 +264,29 @@ class OpenAI(BaseDataset):
     }
     with_scalar_labels: bool = True
     scalar_label_percentages: list[float] = [0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5]
+    scalar_int_rates: list[float] = [
+        0.001,
+        0.002,
+        0.005,
+        0.01,
+        0.02,
+        0.05,
+        0.1,
+        0.2,
+        0.3,
+        0.4,
+        0.5,
+        0.6,
+        0.7,
+        0.8,
+        0.9,
+        0.95,
+        0.98,
+        0.99,
+        0.995,
+        0.998,
+        0.999,
+    ]
 
 
 class DatasetManager(BaseModel):
vectordb_bench/backend/filter.py
CHANGED
@@ -51,6 +51,23 @@ class IntFilter(Filter):
        raise RuntimeError(msg)
 
 
+class NewIntFilter(Filter):
+    type: FilterOp = FilterOp.NumGE
+    int_field: str = "id"
+    int_value: int
+
+    @property
+    def int_rate(self) -> str:
+        r = self.filter_rate * 100
+        if 1 <= r <= 99:
+            return f"int_{int(r)}p"
+        return f"int_{r:.1f}p"
+
+    @property
+    def groundtruth_file(self) -> str:
+        return f"neighbors_{self.int_rate}.parquet"
+
+
 class LabelFilter(Filter):
     """
     filter expr: label_field == label_value, like `color == "red"`
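For orientation, the file-naming scheme produced by NewIntFilter, assuming filter_rate is a field inherited from the Filter base class (the values below are illustrative only):

f = NewIntFilter(filter_rate=0.01, int_value=990_000)   # hypothetical values
assert f.int_rate == "int_1p"                           # 0.01 * 100 = 1 -> "int_1p"
assert f.groundtruth_file == "neighbors_int_1p.parquet"
# rates outside 1%..99% keep one decimal place, e.g. filter_rate=0.001 -> "int_0.1p"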
vectordb_bench/backend/runner/mp_runner.py
CHANGED
@@ -103,6 +103,7 @@ class MultiProcessingSearchRunner:
         conc_num_list = []
         conc_qps_list = []
         conc_latency_p99_list = []
+        conc_latency_p95_list = []
         conc_latency_avg_list = []
         try:
             for conc in self.concurrencies:
@@ -125,6 +126,7 @@ class MultiProcessingSearchRunner:
                 all_count = sum([r.result()[0] for r in future_iter])
                 latencies = sum([r.result()[2] for r in future_iter], start=[])
                 latency_p99 = np.percentile(latencies, 99)
+                latency_p95 = np.percentile(latencies, 95)
                 latency_avg = np.mean(latencies)
                 cost = time.perf_counter() - start
 
@@ -132,6 +134,7 @@ class MultiProcessingSearchRunner:
                 conc_num_list.append(conc)
                 conc_qps_list.append(qps)
                 conc_latency_p99_list.append(latency_p99)
+                conc_latency_p95_list.append(latency_p95)
                 conc_latency_avg_list.append(latency_avg)
                 log.info(f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}")
 
@@ -156,6 +159,7 @@ class MultiProcessingSearchRunner:
             conc_num_list,
             conc_qps_list,
             conc_latency_p99_list,
+            conc_latency_p95_list,
             conc_latency_avg_list,
         )
 
vectordb_bench/backend/runner/read_write_runner.py
CHANGED
@@ -98,10 +98,10 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunner):
         log.info("Search after write - Serial search start")
         test_time = round(time.perf_counter(), 4)
         res, ssearch_dur = self.serial_search_runner.run()
-        recall, ndcg, p99_latency = res
+        recall, ndcg, p99_latency, p95_latency = res
         log.info(
             f"Search after write - Serial search - recall={recall}, ndcg={ndcg}, "
-            f"p99={p99_latency}, dur={ssearch_dur:.4f}",
+            f"p99={p99_latency}, p95={p95_latency}, dur={ssearch_dur:.4f}",
         )
         log.info(
             f"Search after wirte - Conc search start, dur for each conc={self.read_dur_after_write}",
@@ -109,7 +109,7 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunner):
         max_qps, conc_failed_rate = self.run_by_dur(self.read_dur_after_write)
         log.info(f"Search after wirte - Conc search finished, max_qps={max_qps}")
 
-        return [(perc, test_time, max_qps, recall, ndcg, p99_latency, conc_failed_rate)]
+        return [(perc, test_time, max_qps, recall, ndcg, p99_latency, p95_latency, conc_failed_rate)]
 
     def run_read_write(self) -> Metric:
         """
@@ -157,7 +157,8 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunner):
                 m.st_recall_list = [d[3] for d in r]
                 m.st_ndcg_list = [d[4] for d in r]
                 m.st_serial_latency_p99_list = [d[5] for d in r]
-                m.
+                m.st_serial_latency_p95_list = [d[6] for d in r]
+                m.st_conc_failed_rate_list = [d[7] for d in r]
 
         except Exception as e:
             log.warning(f"Read and write error: {e}")
@@ -201,7 +202,7 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunner):
         """
         result, start_batch = [], 0
         total_batch = math.ceil(self.data_volume / self.insert_rate)
-        recall, ndcg, p99_latency = None, None, None
+        recall, ndcg, p99_latency, p95_latency = None, None, None, None
 
         def wait_next_target(start: int, target_batch: int) -> bool:
             """Return False when receive True or None"""
@@ -224,15 +225,15 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunner):
 
             log.info(f"Insert {perc}% done, total batch={total_batch}")
             test_time = round(time.perf_counter(), 4)
-            max_qps, recall, ndcg, p99_latency, conc_failed_rate = 0, 0, 0, 0, 0
+            max_qps, recall, ndcg, p99_latency, p95_latency, conc_failed_rate = 0, 0, 0, 0, 0, 0
             try:
                 log.info(f"[{target_batch}/{total_batch}] Serial search - {perc}% start")
                 res, ssearch_dur = self.serial_search_runner.run()
                 ssearch_dur = round(ssearch_dur, 4)
-                recall, ndcg, p99_latency = res
+                recall, ndcg, p99_latency, p95_latency = res
                 log.info(
                     f"[{target_batch}/{total_batch}] Serial search - {perc}% done, "
-                    f"recall={recall}, ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur}"
+                    f"recall={recall}, ndcg={ndcg}, p99={p99_latency}, p95={p95_latency}, dur={ssearch_dur}"
                 )
 
                 each_conc_search_dur = self.get_each_conc_search_dur(
@@ -250,7 +251,7 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunner):
                 log.warning(f"Skip concurrent tests, each_conc_search_dur={each_conc_search_dur} less than 10s.")
             except Exception as e:
                 log.warning(f"Streaming Search Failed at stage={stage}. Exception: {e}")
-            result.append((perc, test_time, max_qps, recall, ndcg, p99_latency, conc_failed_rate))
+            result.append((perc, test_time, max_qps, recall, ndcg, p99_latency, p95_latency, conc_failed_rate))
             start_batch = target_batch
 
             # Drain the queue