vectordb-bench 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff shows the contents of publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- vectordb_bench/__init__.py +1 -0
- vectordb_bench/backend/assembler.py +1 -1
- vectordb_bench/backend/cases.py +64 -18
- vectordb_bench/backend/clients/__init__.py +13 -0
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +159 -0
- vectordb_bench/backend/clients/aws_opensearch/cli.py +44 -0
- vectordb_bench/backend/clients/aws_opensearch/config.py +58 -0
- vectordb_bench/backend/clients/aws_opensearch/run.py +125 -0
- vectordb_bench/backend/dataset.py +27 -5
- vectordb_bench/cli/vectordbbench.py +2 -0
- vectordb_bench/custom/custom_case.json +18 -0
- vectordb_bench/frontend/components/check_results/charts.py +6 -6
- vectordb_bench/frontend/components/check_results/data.py +12 -12
- vectordb_bench/frontend/components/check_results/expanderStyle.py +1 -1
- vectordb_bench/frontend/components/check_results/filters.py +20 -13
- vectordb_bench/frontend/components/check_results/headerIcon.py +1 -1
- vectordb_bench/frontend/components/check_results/priceTable.py +1 -1
- vectordb_bench/frontend/components/check_results/stPageConfig.py +1 -1
- vectordb_bench/frontend/components/concurrent/charts.py +26 -29
- vectordb_bench/frontend/components/custom/displayCustomCase.py +31 -0
- vectordb_bench/frontend/components/custom/displaypPrams.py +11 -0
- vectordb_bench/frontend/components/custom/getCustomConfig.py +40 -0
- vectordb_bench/frontend/components/custom/initStyle.py +15 -0
- vectordb_bench/frontend/components/run_test/autoRefresh.py +1 -1
- vectordb_bench/frontend/components/run_test/caseSelector.py +40 -28
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -5
- vectordb_bench/frontend/components/run_test/dbSelector.py +2 -14
- vectordb_bench/frontend/components/run_test/generateTasks.py +3 -5
- vectordb_bench/frontend/components/run_test/initStyle.py +14 -0
- vectordb_bench/frontend/components/run_test/submitTask.py +1 -1
- vectordb_bench/frontend/{const → config}/dbCaseConfigs.py +138 -31
- vectordb_bench/frontend/{const → config}/styles.py +2 -0
- vectordb_bench/frontend/pages/concurrent.py +11 -18
- vectordb_bench/frontend/pages/custom.py +64 -0
- vectordb_bench/frontend/pages/quries_per_dollar.py +5 -5
- vectordb_bench/frontend/pages/run_test.py +4 -0
- vectordb_bench/frontend/pages/tables.py +2 -2
- vectordb_bench/frontend/utils.py +17 -1
- vectordb_bench/frontend/vdb_benchmark.py +3 -3
- vectordb_bench/models.py +8 -4
- vectordb_bench/results/getLeaderboardData.py +1 -1
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.12.dist-info}/METADATA +36 -13
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.12.dist-info}/RECORD +48 -37
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.12.dist-info}/WHEEL +1 -1
- /vectordb_bench/frontend/{const → config}/dbPrices.py +0 -0
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.12.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.12.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.12.dist-info}/top_level.txt +0 -0
vectordb_bench/__init__.py
CHANGED
```diff
@@ -35,6 +35,7 @@ class config:
 
 
     K_DEFAULT = 100  # default return top k nearest neighbors during search
+    CUSTOM_CONFIG_DIR = pathlib.Path(__file__).parent.joinpath("custom/custom_case.json")
 
     CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600  # 24h
     LOAD_TIMEOUT_DEFAULT = 2.5 * 3600  # 2.5h
```
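`CUSTOM_CONFIG_DIR` points at the bundled `custom/custom_case.json` (shown in full further down). For orientation, a small sketch, not from the package, of reading that file:

```python
import json

from vectordb_bench import config

# CUSTOM_CONFIG_DIR is already a pathlib.Path, so it can be read directly.
custom_cases = json.loads(config.CUSTOM_CONFIG_DIR.read_text())
for case in custom_cases:
    print(case["name"], case["dataset_config"]["dir"])
```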
vectordb_bench/backend/assembler.py
CHANGED

```diff
@@ -14,7 +14,7 @@ class Assembler:
     def assemble(cls, run_id , task: TaskConfig, source: DatasetSource) -> CaseRunner:
         c_cls = task.case_config.case_id.case_cls
 
-        c = c_cls()
+        c = c_cls(task.case_config.custom_case)
         if type(task.db_case_config) != EmptyDBCaseConfig:
             task.db_case_config.metric_type = c.dataset.data.metric_type
 
```
vectordb_bench/backend/cases.py
CHANGED
```diff
@@ -4,9 +4,13 @@ from enum import Enum, auto
 from typing import Type
 
 from vectordb_bench import config
+from vectordb_bench.backend.clients.api import MetricType
 from vectordb_bench.base import BaseModel
+from vectordb_bench.frontend.components.custom.getCustomConfig import (
+    CustomDatasetConfig,
+)
 
-from .dataset import Dataset, DatasetManager
+from .dataset import CustomDataset, Dataset, DatasetManager
 
 
 log = logging.getLogger(__name__)
```
```diff
@@ -44,25 +48,24 @@ class CaseType(Enum):
     Performance1536D50K = 50
 
     Custom = 100
+    PerformanceCustomDataset = 101
 
-    @property
     def case_cls(self, custom_configs: dict | None = None) -> Type["Case"]:
-        if
-
-
+        if custom_configs is None:
+            return type2case.get(self)()
+        else:
+            return type2case.get(self)(**custom_configs)
 
-    @property
-    def case_name(self) -> str:
-        c = self.case_cls
+    def case_name(self, custom_configs: dict | None = None) -> str:
+        c = self.case_cls(custom_configs)
         if c is not None:
-            return c().name
+            return c.name
         raise ValueError("Case unsupported")
 
-    @property
-    def case_description(self) -> str:
-        c = self.case_cls
+    def case_description(self, custom_configs: dict | None = None) -> str:
+        c = self.case_cls(custom_configs)
         if c is not None:
-            return c().description
+            return c.description
         raise ValueError("Case unsupported")
 
 
```
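The accessors change from properties that returned a case class to methods that instantiate the case, threading an optional `custom_configs` dict through to the constructor. A hedged sketch of the new call pattern, reusing the config shape from `custom_case.json` below:

```python
from vectordb_bench.backend.cases import CaseType

# Built-in cases need no config: the method instantiates the case class
# and reads `name` off the instance.
print(CaseType.Performance768D1M.case_name())

# A custom case forwards its dict into PerformanceCustomDataset.__init__.
custom = {
    "name": "My Dataset (Performace Case)",
    "description": "this is a customized dataset.",
    "load_timeout": 36000,
    "optimize_timeout": 36000,
    "dataset_config": {
        "name": "My Dataset",
        "dir": "/my_dataset_path",
        "size": 1000000,
        "dim": 1024,
        "metric_type": "L2",
        "file_count": 1,
        "use_shuffled": False,
        "with_gt": True,
    },
}
print(CaseType.PerformanceCustomDataset.case_name(custom))
```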
```diff
@@ -289,26 +292,69 @@ Results will show index building time, recall, and maximum QPS."""
     optimize_timeout: float | int | None = 15 * 60
 
 
+def metric_type_map(s: str) -> MetricType:
+    if s.lower() == "cosine":
+        return MetricType.COSINE
+    if s.lower() == "l2" or s.lower() == "euclidean":
+        return MetricType.L2
+    if s.lower() == "ip":
+        return MetricType.IP
+    err_msg = f"Not support metric_type: {s}"
+    log.error(err_msg)
+    raise RuntimeError(err_msg)
+
+
+class PerformanceCustomDataset(PerformanceCase):
+    case_id: CaseType = CaseType.PerformanceCustomDataset
+    name: str = "Performance With Custom Dataset"
+    description: str = ""
+    dataset: DatasetManager
+
+    def __init__(
+        self,
+        name,
+        description,
+        load_timeout,
+        optimize_timeout,
+        dataset_config,
+        **kwargs,
+    ):
+        dataset_config = CustomDatasetConfig(**dataset_config)
+        dataset = CustomDataset(
+            name=dataset_config.name,
+            size=dataset_config.size,
+            dim=dataset_config.dim,
+            metric_type=metric_type_map(dataset_config.metric_type),
+            use_shuffled=dataset_config.use_shuffled,
+            with_gt=dataset_config.with_gt,
+            dir=dataset_config.dir,
+            file_num=dataset_config.file_count,
+        )
+        super().__init__(
+            name=name,
+            description=description,
+            load_timeout=load_timeout,
+            optimize_timeout=optimize_timeout,
+            dataset=DatasetManager(data=dataset),
+        )
+
+
 type2case = {
     CaseType.CapacityDim960: CapacityDim960,
     CaseType.CapacityDim128: CapacityDim128,
-
     CaseType.Performance768D100M: Performance768D100M,
     CaseType.Performance768D10M: Performance768D10M,
     CaseType.Performance768D1M: Performance768D1M,
-
     CaseType.Performance768D10M1P: Performance768D10M1P,
     CaseType.Performance768D1M1P: Performance768D1M1P,
     CaseType.Performance768D10M99P: Performance768D10M99P,
     CaseType.Performance768D1M99P: Performance768D1M99P,
-
     CaseType.Performance1536D500K: Performance1536D500K,
     CaseType.Performance1536D5M: Performance1536D5M,
-
     CaseType.Performance1536D500K1P: Performance1536D500K1P,
     CaseType.Performance1536D5M1P: Performance1536D5M1P,
-
     CaseType.Performance1536D500K99P: Performance1536D500K99P,
     CaseType.Performance1536D5M99P: Performance1536D5M99P,
     CaseType.Performance1536D50K: Performance1536D50K,
+    CaseType.PerformanceCustomDataset: PerformanceCustomDataset,
 }
```
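`metric_type_map` is case-insensitive and treats `euclidean` as L2; any other string raises. A minimal check, derivable directly from the hunk above:

```python
from vectordb_bench.backend.cases import metric_type_map
from vectordb_bench.backend.clients.api import MetricType

assert metric_type_map("Cosine") is MetricType.COSINE
assert metric_type_map("euclidean") is MetricType.L2
assert metric_type_map("IP") is MetricType.IP
# metric_type_map("hamming") raises RuntimeError("Not support metric_type: hamming")
```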
vectordb_bench/backend/clients/__init__.py
CHANGED

```diff
@@ -32,6 +32,7 @@ class DB(Enum):
     PgVectoRS = "PgVectoRS"
     Redis = "Redis"
     Chroma = "Chroma"
+    AWSOpenSearch = "OpenSearch"
     Test = "test"
 
 
@@ -78,6 +79,10 @@ class DB(Enum):
             from .chroma.chroma import ChromaClient
             return ChromaClient
 
+        if self == DB.AWSOpenSearch:
+            from .aws_opensearch.aws_opensearch import AWSOpenSearch
+            return AWSOpenSearch
+
     @property
     def config_cls(self) -> Type[DBConfig]:
         """Import while in use"""
@@ -121,6 +126,10 @@ class DB(Enum):
             from .chroma.config import ChromaConfig
             return ChromaConfig
 
+        if self == DB.AWSOpenSearch:
+            from .aws_opensearch.config import AWSOpenSearchConfig
+            return AWSOpenSearchConfig
+
     def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
         if self == DB.Milvus:
             from .milvus.config import _milvus_case_config
@@ -150,6 +159,10 @@ class DB(Enum):
             from .pgvecto_rs.config import _pgvecto_rs_case_config
             return _pgvecto_rs_case_config.get(index_type)
 
+        if self == DB.AWSOpenSearch:
+            from .aws_opensearch.config import AWSOpenSearchIndexConfig
+            return AWSOpenSearchIndexConfig
+
         # DB.Pinecone, DB.Chroma, DB.Redis
         return EmptyDBCaseConfig
```
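The three new branches follow the enum's existing lazy-import pattern (the first hunk extends the property that returns the client class; its name sits outside the hunk). A hedged sketch of how the registry is consumed:

```python
from vectordb_bench.backend.clients import DB

db = DB.AWSOpenSearch
print(db.value)              # "OpenSearch"
print(db.config_cls)         # AWSOpenSearchConfig, imported on first access
print(db.case_config_cls())  # AWSOpenSearchIndexConfig; index_type is ignored here
```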
vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py
ADDED

```python
import logging
from contextlib import contextmanager
import time
from typing import Iterable, Type
from ..api import VectorDB, DBCaseConfig, DBConfig, IndexType
from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig
from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk

log = logging.getLogger(__name__)


class AWSOpenSearch(VectorDB):
    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: AWSOpenSearchIndexConfig,
        index_name: str = "vdb_bench_index",  # must be lowercase
        id_col_name: str = "id",
        vector_col_name: str = "embedding",
        drop_old: bool = False,
        **kwargs,
    ):
        self.dim = dim
        self.db_config = db_config
        self.case_config = db_case_config
        self.index_name = index_name
        self.id_col_name = id_col_name
        self.category_col_names = [
            f"scalar-{categoryCount}" for categoryCount in [2, 5, 10, 100, 1000]
        ]
        self.vector_col_name = vector_col_name

        log.info(f"AWS_OpenSearch client config: {self.db_config}")
        client = OpenSearch(**self.db_config)
        if drop_old:
            log.info(f"AWS_OpenSearch client drop old index: {self.index_name}")
            is_existed = client.indices.exists(index=self.index_name)
            if is_existed:
                client.indices.delete(index=self.index_name)
            self._create_index(client)

    @classmethod
    def config_cls(cls) -> AWSOpenSearchConfig:
        return AWSOpenSearchConfig

    @classmethod
    def case_config_cls(
        cls, index_type: IndexType | None = None
    ) -> AWSOpenSearchIndexConfig:
        return AWSOpenSearchIndexConfig

    def _create_index(self, client: OpenSearch):
        settings = {
            "index": {
                "knn": True,
                # "number_of_shards": 5,
                # "refresh_interval": "600s",
            }
        }
        mappings = {
            "properties": {
                self.id_col_name: {"type": "integer"},
                **{
                    categoryCol: {"type": "keyword"}
                    for categoryCol in self.category_col_names
                },
                self.vector_col_name: {
                    "type": "knn_vector",
                    "dimension": self.dim,
                    "method": self.case_config.index_param(),
                },
            }
        }
        try:
            client.indices.create(
                index=self.index_name, body=dict(settings=settings, mappings=mappings)
            )
        except Exception as e:
            log.warning(f"Failed to create index: {self.index_name} error: {str(e)}")
            raise e from None

    @contextmanager
    def init(self) -> None:
        """connect to elasticsearch"""
        self.client = OpenSearch(**self.db_config)

        yield
        # self.client.transport.close()
        self.client = None
        del self.client

    def insert_embeddings(
        self,
        embeddings: Iterable[list[float]],
        metadata: list[int],
        **kwargs,
    ) -> tuple[int, Exception]:
        """Insert the embeddings to the elasticsearch."""
        assert self.client is not None, "should self.init() first"

        insert_data = []
        for i in range(len(embeddings)):
            insert_data.append({"index": {"_index": self.index_name, "_id": metadata[i]}})
            insert_data.append({self.vector_col_name: embeddings[i]})
        try:
            resp = self.client.bulk(insert_data)
            log.info(f"AWS_OpenSearch adding documents: {len(resp['items'])}")
            resp = self.client.indices.stats(self.index_name)
            log.info(f"Total document count in index: {resp['_all']['primaries']['indexing']['index_total']}")
            return (len(embeddings), None)
        except Exception as e:
            log.warning(f"Failed to insert data: {self.index_name} error: {str(e)}")
            time.sleep(10)
            return self.insert_embeddings(embeddings, metadata)

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
    ) -> list[int]:
        """Get k most similar embeddings to query vector.

        Args:
            query(list[float]): query embedding to look up documents similar to.
            k(int): Number of most similar embeddings to return. Defaults to 100.
            filters(dict, optional): filtering expression to filter the data while searching.

        Returns:
            list[tuple[int, float]]: list of k most similar embeddings in (id, score) tuple to the query embedding.
        """
        assert self.client is not None, "should self.init() first"

        body = {
            "size": k,
            "query": {"knn": {self.vector_col_name: {"vector": query, "k": k}}},
        }
        try:
            resp = self.client.search(index=self.index_name, body=body)
            log.info(f'Search took: {resp["took"]}')
            log.info(f'Search shards: {resp["_shards"]}')
            log.info(f'Search hits total: {resp["hits"]["total"]}')
            result = [int(d["_id"]) for d in resp["hits"]["hits"]]
            # log.info(f'success! length={len(res)}')

            return result
        except Exception as e:
            log.warning(f"Failed to search: {self.index_name} error: {str(e)}")
            raise e from None

    def optimize(self):
        """optimize will be called between insertion and search in performance cases."""
        pass

    def ready_to_load(self):
        """ready_to_load will be called before load in load cases."""
        pass
```
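A hedged usage sketch of the lifecycle the benchmark runner drives (connect via the `init()` context manager, then insert and search); the endpoint and credentials are placeholders, and a reachable OpenSearch domain is assumed:

```python
from vectordb_bench.backend.clients.aws_opensearch.aws_opensearch import AWSOpenSearch
from vectordb_bench.backend.clients.aws_opensearch.config import (
    AWSOpenSearchConfig,
    AWSOpenSearchIndexConfig,
)

# Placeholder endpoint/credentials; to_dict() expands into OpenSearch(**kwargs).
db_config = AWSOpenSearchConfig(
    host="search-example.us-west-2.es.amazonaws.com",
    user="admin",
    password="<secret>",
).to_dict()

client = AWSOpenSearch(
    dim=128,
    db_config=db_config,
    db_case_config=AWSOpenSearchIndexConfig(),
    drop_old=True,  # recreates the index with the knn_vector mapping above
)

with client.init():
    count, err = client.insert_embeddings([[0.1] * 128, [0.2] * 128], metadata=[0, 1])
    ids = client.search_embedding([0.1] * 128, k=10)
```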
vectordb_bench/backend/clients/aws_opensearch/cli.py
ADDED

```python
from typing import Annotated, TypedDict, Unpack

import click
from pydantic import SecretStr

from ....cli.cli import (
    CommonTypedDict,
    HNSWFlavor2,
    cli,
    click_parameter_decorators_from_typed_dict,
    run,
)
from .. import DB


class AWSOpenSearchTypedDict(TypedDict):
    host: Annotated[
        str, click.option("--host", type=str, help="Db host", required=True)
    ]
    port: Annotated[int, click.option("--port", type=int, default=443, help="Db Port")]
    user: Annotated[str, click.option("--user", type=str, default="admin", help="Db User")]
    password: Annotated[str, click.option("--password", type=str, help="Db password")]


class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2):
    ...


@cli.command()
@click_parameter_decorators_from_typed_dict(AWSOpenSearchHNSWTypedDict)
def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]):
    from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig
    run(
        db=DB.AWSOpenSearch,
        db_config=AWSOpenSearchConfig(
            host=parameters["host"],
            port=parameters["port"],
            user=parameters["user"],
            password=SecretStr(parameters["password"]),
        ),
        db_case_config=AWSOpenSearchIndexConfig(
        ),
        **parameters,
    )
```
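Once registered (see the vectordbbench.py hunk further down), click exposes the command under the lowercased function name, so the subcommand should be `awsopensearch`. A hedged sketch via click's test runner; host and password are placeholders, and the options contributed by `CommonTypedDict`/`HNSWFlavor2` are omitted:

```python
from click.testing import CliRunner

from vectordb_bench.cli.vectordbbench import cli

runner = CliRunner()
result = runner.invoke(cli, [
    "awsopensearch",
    "--host", "search-example.us-west-2.es.amazonaws.com",
    "--user", "admin",
    "--password", "<secret>",
])
print(result.output)  # a usage error here would list the remaining required options
```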
vectordb_bench/backend/clients/aws_opensearch/config.py
ADDED

```python
from enum import Enum
from pydantic import SecretStr, BaseModel

from ..api import DBConfig, DBCaseConfig, MetricType, IndexType


class AWSOpenSearchConfig(DBConfig, BaseModel):
    host: str = ""
    port: int = 443
    user: str = ""
    password: SecretStr = ""

    def to_dict(self) -> dict:
        return {
            "hosts": [{'host': self.host, 'port': self.port}],
            "http_auth": (self.user, self.password.get_secret_value()),
            "use_ssl": True,
            "http_compress": True,
            "verify_certs": True,
            "ssl_assert_hostname": False,
            "ssl_show_warn": False,
            "timeout": 600,
        }


class AWSOS_Engine(Enum):
    nmslib = "nmslib"
    faiss = "faiss"
    lucene = "Lucene"


class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
    metric_type: MetricType = MetricType.L2
    engine: AWSOS_Engine = AWSOS_Engine.nmslib
    efConstruction: int = 360
    M: int = 30

    def parse_metric(self) -> str:
        if self.metric_type == MetricType.IP:
            return "innerproduct"  # only support faiss / nmslib, not for Lucene.
        elif self.metric_type == MetricType.COSINE:
            return "cosinesimil"
        return "l2"

    def index_param(self) -> dict:
        params = {
            "name": "hnsw",
            "space_type": self.parse_metric(),
            "engine": self.engine.value,
            "parameters": {
                "ef_construction": self.efConstruction,
                "m": self.M
            }
        }
        return params

    def search_param(self) -> dict:
        return {}
```
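`index_param()` produces the OpenSearch k-NN `method` object that the client embeds into the index mapping above. For example:

```python
from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.backend.clients.aws_opensearch.config import (
    AWSOS_Engine,
    AWSOpenSearchIndexConfig,
)

cfg = AWSOpenSearchIndexConfig(
    metric_type=MetricType.COSINE,
    engine=AWSOS_Engine.faiss,
    efConstruction=256,
    M=16,
)
print(cfg.index_param())
# {'name': 'hnsw', 'space_type': 'cosinesimil', 'engine': 'faiss',
#  'parameters': {'ef_construction': 256, 'm': 16}}
```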
vectordb_bench/backend/clients/aws_opensearch/run.py
ADDED

```python
import time, random
from opensearchpy import OpenSearch
from opensearch_dsl import Search, Document, Text, Keyword

_HOST = 'xxxxxx.us-west-2.es.amazonaws.com'
_PORT = 443
_AUTH = ('admin', 'xxxxxx')  # For testing only. Don't store credentials in code.

_INDEX_NAME = 'my-dsl-index'
_BATCH = 100
_ROWS = 100
_DIM = 128
_TOPK = 10


def create_client():
    client = OpenSearch(
        hosts=[{'host': _HOST, 'port': _PORT}],
        http_compress=True,  # enables gzip compression for request bodies
        http_auth=_AUTH,
        use_ssl=True,
        verify_certs=True,
        ssl_assert_hostname=False,
        ssl_show_warn=False,
    )
    return client


def create_index(client, index_name):
    settings = {
        "index": {
            "knn": True,
            "number_of_shards": 1,
            "refresh_interval": "5s",
        }
    }
    mappings = {
        "properties": {
            "embedding": {
                "type": "knn_vector",
                "dimension": _DIM,
                "method": {
                    "engine": "nmslib",
                    "name": "hnsw",
                    "space_type": "l2",
                    "parameters": {
                        "ef_construction": 128,
                        "m": 24,
                    }
                }
            }
        }
    }

    response = client.indices.create(index=index_name, body=dict(settings=settings, mappings=mappings))
    print('\nCreating index:')
    print(response)


def delete_index(client, index_name):
    response = client.indices.delete(index=index_name)
    print('\nDeleting index:')
    print(response)


def bulk_insert(client, index_name):
    # Perform bulk operations
    ids = [i for i in range(_ROWS)]
    vec = [[random.random() for _ in range(_DIM)] for _ in range(_ROWS)]

    docs = []
    for i in range(0, _ROWS, _BATCH):
        docs.clear()
        for j in range(0, _BATCH):
            docs.append({"index": {"_index": index_name, "_id": ids[i+j]}})
            docs.append({"embedding": vec[i+j]})
        response = client.bulk(docs)
        print('\nAdding documents:', len(response['items']), response['errors'])
    response = client.indices.stats(index_name)
    print('\nTotal document count in index:', response['_all']['primaries']['indexing']['index_total'])


def search(client, index_name):
    # Search for the document.
    search_body = {
        "size": _TOPK,
        "query": {
            "knn": {
                "embedding": {
                    "vector": [random.random() for _ in range(_DIM)],
                    "k": _TOPK,
                }
            }
        }
    }
    while True:
        response = client.search(index=index_name, body=search_body)
        print(f'\nSearch took: {response["took"]}')
        print(f'\nSearch shards: {response["_shards"]}')
        print(f'\nSearch hits total: {response["hits"]["total"]}')
        result = response["hits"]["hits"]
        if len(result) != 0:
            print('\nSearch results:')
            for hit in response["hits"]["hits"]:
                print(hit["_id"], hit["_score"])
            break
        else:
            print('\nSearch not ready, sleep 1s')
            time.sleep(1)


def main():
    client = create_client()
    try:
        create_index(client, _INDEX_NAME)
        bulk_insert(client, _INDEX_NAME)
        search(client, _INDEX_NAME)
        delete_index(client, _INDEX_NAME)
    except Exception as e:
        print(e)
        delete_index(client, _INDEX_NAME)


if __name__ == '__main__':
    main()
```
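run.py is a standalone smoke test rather than part of the benchmark path: after editing `_HOST` and `_AUTH` to point at a reachable domain, running the module creates the k-NN index, bulk-inserts 100 random 128-d vectors, polls search until hits appear (the index refreshes every 5s), and finally drops the index. The `opensearch_dsl` imports are never used in the body.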
vectordb_bench/backend/dataset.py
CHANGED

```diff
@@ -33,6 +33,7 @@ class BaseDataset(BaseModel):
     use_shuffled: bool
     with_gt: bool = False
     _size_label: dict[int, SizeLabel] = PrivateAttr()
+    isCustom: bool = False
 
     @validator("size")
     def verify_size(cls, v):
@@ -52,7 +53,27 @@ class BaseDataset(BaseModel):
     def file_count(self) -> int:
         return self._size_label.get(self.size).file_count
 
+class CustomDataset(BaseDataset):
+    dir: str
+    file_num: int
+    isCustom: bool = True
+
+    @validator("size")
+    def verify_size(cls, v):
+        return v
+
+    @property
+    def label(self) -> str:
+        return "Custom"
 
+    @property
+    def dir_name(self) -> str:
+        return self.dir
+
+    @property
+    def file_count(self) -> int:
+        return self.file_num
+
 class LAION(BaseDataset):
     name: str = "LAION"
     dim: int = 768
@@ -186,11 +207,12 @@ class DatasetManager(BaseModel):
         gt_file, test_file = utils.compose_gt_file(filters), "test.parquet"
         all_files.extend([gt_file, test_file])
 
-        source.reader().read(
-            dataset=self.data.dir_name.lower(),
-            files=all_files,
-            local_ds_root=self.data_dir,
-        )
+        if not self.data.isCustom:
+            source.reader().read(
+                dataset=self.data.dir_name.lower(),
+                files=all_files,
+                local_ds_root=self.data_dir,
+            )
 
         if gt_file is not None and test_file is not None:
             self.test_data = self._read_file(test_file)
```
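A `CustomDataset` bypasses the remote dataset download entirely (`isCustom` short-circuits the reader), so the parquet files must already exist under `dir`. A quick sketch of the properties in play, built from the fields shown in the hunks above:

```python
from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.backend.dataset import CustomDataset, DatasetManager

data = CustomDataset(
    name="My Dataset",
    size=1_000_000,
    dim=1024,
    metric_type=MetricType.L2,
    use_shuffled=False,
    with_gt=True,
    dir="/my_dataset_path",
    file_num=1,
)
assert data.label == "Custom" and data.file_count == 1
manager = DatasetManager(data=data)  # prepare() will skip source.reader().read(...)
```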
vectordb_bench/cli/vectordbbench.py
CHANGED

```diff
@@ -4,6 +4,7 @@ from ..backend.clients.test.cli import Test
 from ..backend.clients.weaviate_cloud.cli import Weaviate
 from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex
 from ..backend.clients.milvus.cli import MilvusAutoIndex
+from ..backend.clients.aws_opensearch.cli import AWSOpenSearch
 
 
 from .cli import cli
@@ -14,6 +15,7 @@ cli.add_command(Weaviate)
 cli.add_command(Test)
 cli.add_command(ZillizAutoIndex)
 cli.add_command(MilvusAutoIndex)
+cli.add_command(AWSOpenSearch)
 
 
 if __name__ == "__main__":
```
vectordb_bench/custom/custom_case.json
ADDED

```json
[
  {
    "name": "My Dataset (Performace Case)",
    "description": "this is a customized dataset.",
    "load_timeout": 36000,
    "optimize_timeout": 36000,
    "dataset_config": {
      "name": "My Dataset",
      "dir": "/my_dataset_path",
      "size": 1000000,
      "dim": 1024,
      "metric_type": "L2",
      "file_count": 1,
      "use_shuffled": false,
      "with_gt": true
    }
  }
]
```
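Each entry mirrors the constructor of `PerformanceCustomDataset`: the top-level keys feed the case itself, while `dataset_config` is validated through `CustomDatasetConfig` (note that `file_count` here becomes `file_num` on `CustomDataset`). A hedged sketch of the round trip:

```python
import json

from vectordb_bench import config
from vectordb_bench.backend.cases import PerformanceCustomDataset

entries = json.loads(config.CUSTOM_CONFIG_DIR.read_text())
case = PerformanceCustomDataset(**entries[0])
print(case.name, case.dataset.data.dir)
```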
vectordb_bench/frontend/components/check_results/charts.py
CHANGED

```diff
@@ -1,19 +1,19 @@
 from vectordb_bench.backend.cases import Case
 from vectordb_bench.frontend.components.check_results.expanderStyle import initMainExpanderStyle
 from vectordb_bench.metric import metricOrder, isLowerIsBetterMetric, metricUnitMap
-from vectordb_bench.frontend.
+from vectordb_bench.frontend.config.styles import *
 from vectordb_bench.models import ResultLabel
 import plotly.express as px
 
 
-def drawCharts(st, allData, failedTasks,
+def drawCharts(st, allData, failedTasks, caseNames: list[str]):
     initMainExpanderStyle(st)
-    for
-    chartContainer = st.expander(
-    data = [data for data in allData if data["case_name"] ==
+    for caseName in caseNames:
+        chartContainer = st.expander(caseName, True)
+        data = [data for data in allData if data["case_name"] == caseName]
         drawChart(data, chartContainer)
 
-    errorDBs = failedTasks[
+        errorDBs = failedTasks[caseName]
         showFailedDBs(chartContainer, errorDBs)
 
```