vectordb-bench 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +1 -0
- vectordb_bench/backend/assembler.py +1 -1
- vectordb_bench/backend/cases.py +64 -18
- vectordb_bench/backend/clients/__init__.py +35 -0
- vectordb_bench/backend/clients/api.py +21 -1
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +159 -0
- vectordb_bench/backend/clients/aws_opensearch/cli.py +44 -0
- vectordb_bench/backend/clients/aws_opensearch/config.py +58 -0
- vectordb_bench/backend/clients/aws_opensearch/run.py +125 -0
- vectordb_bench/backend/clients/memorydb/cli.py +88 -0
- vectordb_bench/backend/clients/memorydb/config.py +54 -0
- vectordb_bench/backend/clients/memorydb/memorydb.py +254 -0
- vectordb_bench/backend/clients/pgvecto_rs/cli.py +154 -0
- vectordb_bench/backend/clients/pgvecto_rs/config.py +108 -73
- vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +159 -59
- vectordb_bench/backend/clients/pgvectorscale/config.py +111 -0
- vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +272 -0
- vectordb_bench/backend/dataset.py +27 -5
- vectordb_bench/cli/vectordbbench.py +7 -0
- vectordb_bench/custom/custom_case.json +18 -0
- vectordb_bench/frontend/components/check_results/charts.py +6 -6
- vectordb_bench/frontend/components/check_results/data.py +18 -11
- vectordb_bench/frontend/components/check_results/expanderStyle.py +1 -1
- vectordb_bench/frontend/components/check_results/filters.py +20 -13
- vectordb_bench/frontend/components/check_results/headerIcon.py +1 -1
- vectordb_bench/frontend/components/check_results/priceTable.py +1 -1
- vectordb_bench/frontend/components/check_results/stPageConfig.py +1 -1
- vectordb_bench/frontend/components/concurrent/charts.py +26 -29
- vectordb_bench/frontend/components/custom/displayCustomCase.py +31 -0
- vectordb_bench/frontend/components/custom/displaypPrams.py +11 -0
- vectordb_bench/frontend/components/custom/getCustomConfig.py +40 -0
- vectordb_bench/frontend/components/custom/initStyle.py +15 -0
- vectordb_bench/frontend/components/run_test/autoRefresh.py +1 -1
- vectordb_bench/frontend/components/run_test/caseSelector.py +50 -28
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +37 -19
- vectordb_bench/frontend/components/run_test/dbSelector.py +2 -14
- vectordb_bench/frontend/components/run_test/generateTasks.py +3 -5
- vectordb_bench/frontend/components/run_test/initStyle.py +16 -0
- vectordb_bench/frontend/components/run_test/submitTask.py +1 -1
- vectordb_bench/frontend/{const → config}/dbCaseConfigs.py +311 -40
- vectordb_bench/frontend/{const → config}/styles.py +2 -0
- vectordb_bench/frontend/pages/concurrent.py +11 -18
- vectordb_bench/frontend/pages/custom.py +64 -0
- vectordb_bench/frontend/pages/quries_per_dollar.py +5 -5
- vectordb_bench/frontend/pages/run_test.py +4 -0
- vectordb_bench/frontend/pages/tables.py +2 -2
- vectordb_bench/frontend/utils.py +17 -1
- vectordb_bench/frontend/vdb_benchmark.py +3 -3
- vectordb_bench/models.py +26 -10
- vectordb_bench/results/getLeaderboardData.py +1 -1
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/METADATA +46 -15
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/RECORD +57 -40
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/WHEEL +1 -1
- /vectordb_bench/frontend/{const → config}/dbPrices.py +0 -0
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/top_level.txt +0 -0
vectordb_bench/__init__.py
CHANGED
@@ -35,6 +35,7 @@ class config:
|
|
35
35
|
|
36
36
|
|
37
37
|
K_DEFAULT = 100 # default return top k nearest neighbors during search
|
38
|
+
CUSTOM_CONFIG_DIR = pathlib.Path(__file__).parent.joinpath("custom/custom_case.json")
|
38
39
|
|
39
40
|
CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600 # 24h
|
40
41
|
LOAD_TIMEOUT_DEFAULT = 2.5 * 3600 # 2.5h
|
@@ -14,7 +14,7 @@ class Assembler:
|
|
14
14
|
def assemble(cls, run_id , task: TaskConfig, source: DatasetSource) -> CaseRunner:
|
15
15
|
c_cls = task.case_config.case_id.case_cls
|
16
16
|
|
17
|
-
c = c_cls()
|
17
|
+
c = c_cls(task.case_config.custom_case)
|
18
18
|
if type(task.db_case_config) != EmptyDBCaseConfig:
|
19
19
|
task.db_case_config.metric_type = c.dataset.data.metric_type
|
20
20
|
|
vectordb_bench/backend/cases.py
CHANGED
@@ -4,9 +4,13 @@ from enum import Enum, auto
|
|
4
4
|
from typing import Type
|
5
5
|
|
6
6
|
from vectordb_bench import config
|
7
|
+
from vectordb_bench.backend.clients.api import MetricType
|
7
8
|
from vectordb_bench.base import BaseModel
|
9
|
+
from vectordb_bench.frontend.components.custom.getCustomConfig import (
|
10
|
+
CustomDatasetConfig,
|
11
|
+
)
|
8
12
|
|
9
|
-
from .dataset import Dataset, DatasetManager
|
13
|
+
from .dataset import CustomDataset, Dataset, DatasetManager
|
10
14
|
|
11
15
|
|
12
16
|
log = logging.getLogger(__name__)
|
@@ -44,25 +48,24 @@ class CaseType(Enum):
|
|
44
48
|
Performance1536D50K = 50
|
45
49
|
|
46
50
|
Custom = 100
|
51
|
+
PerformanceCustomDataset = 101
|
47
52
|
|
48
|
-
@property
|
49
53
|
def case_cls(self, custom_configs: dict | None = None) -> Type["Case"]:
|
50
|
-
if
|
51
|
-
|
52
|
-
|
54
|
+
if custom_configs is None:
|
55
|
+
return type2case.get(self)()
|
56
|
+
else:
|
57
|
+
return type2case.get(self)(**custom_configs)
|
53
58
|
|
54
|
-
|
55
|
-
|
56
|
-
c = self.case_cls
|
59
|
+
def case_name(self, custom_configs: dict | None = None) -> str:
|
60
|
+
c = self.case_cls(custom_configs)
|
57
61
|
if c is not None:
|
58
|
-
return c
|
62
|
+
return c.name
|
59
63
|
raise ValueError("Case unsupported")
|
60
64
|
|
61
|
-
|
62
|
-
|
63
|
-
c = self.case_cls
|
65
|
+
def case_description(self, custom_configs: dict | None = None) -> str:
|
66
|
+
c = self.case_cls(custom_configs)
|
64
67
|
if c is not None:
|
65
|
-
return c
|
68
|
+
return c.description
|
66
69
|
raise ValueError("Case unsupported")
|
67
70
|
|
68
71
|
|
@@ -289,26 +292,69 @@ Results will show index building time, recall, and maximum QPS."""
|
|
289
292
|
optimize_timeout: float | int | None = 15 * 60
|
290
293
|
|
291
294
|
|
295
|
+
def metric_type_map(s: str) -> MetricType:
    """Translate a metric name from a custom dataset config into a MetricType.

    Matching is case-insensitive: "cosine" gives COSINE, "l2" and
    "euclidean" both give L2, and "ip" gives IP.

    Raises:
        RuntimeError: if the name is none of the above. The message is
            logged at error level before the exception is raised.
    """
    metric_name = s.lower()
    if metric_name == "cosine":
        return MetricType.COSINE
    if metric_name in ("l2", "euclidean"):
        return MetricType.L2
    if metric_name == "ip":
        return MetricType.IP

    err_msg = f"Not support metric_type: {s}"
    log.error(err_msg)
    raise RuntimeError(err_msg)
|
305
|
+
|
306
|
+
|
307
|
+
class PerformanceCustomDataset(PerformanceCase):
    # Performance case whose dataset is described by a user-supplied
    # configuration instead of one of the built-in datasets.
    case_id: CaseType = CaseType.PerformanceCustomDataset
    name: str = "Performance With Custom Dataset"
    description: str = ""
    dataset: DatasetManager

    def __init__(
        self,
        name,
        description,
        load_timeout,
        optimize_timeout,
        dataset_config,
        **kwargs,
    ):
        """Build the case from a custom-case description.

        Args:
            name: display name of the case.
            description: free-text description of the case.
            load_timeout: forwarded unchanged to the parent case.
            optimize_timeout: forwarded unchanged to the parent case.
            dataset_config: mapping of keyword arguments for
                ``CustomDatasetConfig``.
            **kwargs: accepted and ignored.
        """
        parsed_config = CustomDatasetConfig(**dataset_config)
        custom_dataset = CustomDataset(
            name=parsed_config.name,
            size=parsed_config.size,
            dim=parsed_config.dim,
            # The config stores the metric as text; convert it to the enum.
            metric_type=metric_type_map(parsed_config.metric_type),
            use_shuffled=parsed_config.use_shuffled,
            with_gt=parsed_config.with_gt,
            dir=parsed_config.dir,
            # The config calls this `file_count`, the dataset `file_num`.
            file_num=parsed_config.file_count,
        )
        manager = DatasetManager(data=custom_dataset)
        super().__init__(
            name=name,
            description=description,
            load_timeout=load_timeout,
            optimize_timeout=optimize_timeout,
            dataset=manager,
        )
|
340
|
+
|
341
|
+
|
292
342
|
type2case = {
|
293
343
|
CaseType.CapacityDim960: CapacityDim960,
|
294
344
|
CaseType.CapacityDim128: CapacityDim128,
|
295
|
-
|
296
345
|
CaseType.Performance768D100M: Performance768D100M,
|
297
346
|
CaseType.Performance768D10M: Performance768D10M,
|
298
347
|
CaseType.Performance768D1M: Performance768D1M,
|
299
|
-
|
300
348
|
CaseType.Performance768D10M1P: Performance768D10M1P,
|
301
349
|
CaseType.Performance768D1M1P: Performance768D1M1P,
|
302
350
|
CaseType.Performance768D10M99P: Performance768D10M99P,
|
303
351
|
CaseType.Performance768D1M99P: Performance768D1M99P,
|
304
|
-
|
305
352
|
CaseType.Performance1536D500K: Performance1536D500K,
|
306
353
|
CaseType.Performance1536D5M: Performance1536D5M,
|
307
|
-
|
308
354
|
CaseType.Performance1536D500K1P: Performance1536D500K1P,
|
309
355
|
CaseType.Performance1536D5M1P: Performance1536D5M1P,
|
310
|
-
|
311
356
|
CaseType.Performance1536D500K99P: Performance1536D500K99P,
|
312
357
|
CaseType.Performance1536D5M99P: Performance1536D5M99P,
|
313
358
|
CaseType.Performance1536D50K: Performance1536D50K,
|
359
|
+
CaseType.PerformanceCustomDataset: PerformanceCustomDataset,
|
314
360
|
}
|
@@ -30,8 +30,11 @@ class DB(Enum):
|
|
30
30
|
WeaviateCloud = "WeaviateCloud"
|
31
31
|
PgVector = "PgVector"
|
32
32
|
PgVectoRS = "PgVectoRS"
|
33
|
+
PgVectorScale = "PgVectorScale"
|
33
34
|
Redis = "Redis"
|
35
|
+
MemoryDB = "MemoryDB"
|
34
36
|
Chroma = "Chroma"
|
37
|
+
AWSOpenSearch = "OpenSearch"
|
35
38
|
Test = "test"
|
36
39
|
|
37
40
|
|
@@ -69,15 +72,27 @@ class DB(Enum):
|
|
69
72
|
if self == DB.PgVectoRS:
|
70
73
|
from .pgvecto_rs.pgvecto_rs import PgVectoRS
|
71
74
|
return PgVectoRS
|
75
|
+
|
76
|
+
if self == DB.PgVectorScale:
|
77
|
+
from .pgvectorscale.pgvectorscale import PgVectorScale
|
78
|
+
return PgVectorScale
|
72
79
|
|
73
80
|
if self == DB.Redis:
|
74
81
|
from .redis.redis import Redis
|
75
82
|
return Redis
|
83
|
+
|
84
|
+
if self == DB.MemoryDB:
|
85
|
+
from .memorydb.memorydb import MemoryDB
|
86
|
+
return MemoryDB
|
76
87
|
|
77
88
|
if self == DB.Chroma:
|
78
89
|
from .chroma.chroma import ChromaClient
|
79
90
|
return ChromaClient
|
80
91
|
|
92
|
+
if self == DB.AWSOpenSearch:
|
93
|
+
from .aws_opensearch.aws_opensearch import AWSOpenSearch
|
94
|
+
return AWSOpenSearch
|
95
|
+
|
81
96
|
@property
|
82
97
|
def config_cls(self) -> Type[DBConfig]:
|
83
98
|
"""Import while in use"""
|
@@ -113,14 +128,26 @@ class DB(Enum):
|
|
113
128
|
from .pgvecto_rs.config import PgVectoRSConfig
|
114
129
|
return PgVectoRSConfig
|
115
130
|
|
131
|
+
if self == DB.PgVectorScale:
|
132
|
+
from .pgvectorscale.config import PgVectorScaleConfig
|
133
|
+
return PgVectorScaleConfig
|
134
|
+
|
116
135
|
if self == DB.Redis:
|
117
136
|
from .redis.config import RedisConfig
|
118
137
|
return RedisConfig
|
138
|
+
|
139
|
+
if self == DB.MemoryDB:
|
140
|
+
from .memorydb.config import MemoryDBConfig
|
141
|
+
return MemoryDBConfig
|
119
142
|
|
120
143
|
if self == DB.Chroma:
|
121
144
|
from .chroma.config import ChromaConfig
|
122
145
|
return ChromaConfig
|
123
146
|
|
147
|
+
if self == DB.AWSOpenSearch:
|
148
|
+
from .aws_opensearch.config import AWSOpenSearchConfig
|
149
|
+
return AWSOpenSearchConfig
|
150
|
+
|
124
151
|
def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
|
125
152
|
if self == DB.Milvus:
|
126
153
|
from .milvus.config import _milvus_case_config
|
@@ -150,6 +177,14 @@ class DB(Enum):
|
|
150
177
|
from .pgvecto_rs.config import _pgvecto_rs_case_config
|
151
178
|
return _pgvecto_rs_case_config.get(index_type)
|
152
179
|
|
180
|
+
if self == DB.AWSOpenSearch:
|
181
|
+
from .aws_opensearch.config import AWSOpenSearchIndexConfig
|
182
|
+
return AWSOpenSearchIndexConfig
|
183
|
+
|
184
|
+
if self == DB.PgVectorScale:
|
185
|
+
from .pgvectorscale.config import _pgvectorscale_case_config
|
186
|
+
return _pgvectorscale_case_config.get(index_type)
|
187
|
+
|
153
188
|
# DB.Pinecone, DB.Chroma, DB.Redis
|
154
189
|
return EmptyDBCaseConfig
|
155
190
|
|
@@ -15,6 +15,7 @@ class MetricType(str, Enum):
|
|
15
15
|
class IndexType(str, Enum):
|
16
16
|
HNSW = "HNSW"
|
17
17
|
DISKANN = "DISKANN"
|
18
|
+
STREAMING_DISKANN = "DISKANN"
|
18
19
|
IVFFlat = "IVF_FLAT"
|
19
20
|
IVFSQ8 = "IVF_SQ8"
|
20
21
|
Flat = "FLAT"
|
@@ -38,6 +39,22 @@ class DBConfig(ABC, BaseModel):
|
|
38
39
|
"""
|
39
40
|
|
40
41
|
db_label: str = ""
|
42
|
+
version: str = ""
|
43
|
+
note: str = ""
|
44
|
+
|
45
|
+
@staticmethod
|
46
|
+
def common_short_configs() -> list[str]:
|
47
|
+
"""
|
48
|
+
short input, such as `db_label`, `version`
|
49
|
+
"""
|
50
|
+
return ["version", "db_label"]
|
51
|
+
|
52
|
+
@staticmethod
|
53
|
+
def common_long_configs() -> list[str]:
|
54
|
+
"""
|
55
|
+
long input, such as `note`
|
56
|
+
"""
|
57
|
+
return ["note"]
|
41
58
|
|
42
59
|
@abstractmethod
|
43
60
|
def to_dict(self) -> dict:
|
@@ -45,7 +62,10 @@ class DBConfig(ABC, BaseModel):
|
|
45
62
|
|
46
63
|
@validator("*")
|
47
64
|
def not_empty_field(cls, v, field):
|
48
|
-
if
|
65
|
+
if (
|
66
|
+
field.name in cls.common_short_configs()
|
67
|
+
or field.name in cls.common_long_configs()
|
68
|
+
):
|
49
69
|
return v
|
50
70
|
if not v and isinstance(v, (str, SecretStr)):
|
51
71
|
raise ValueError("Empty string!")
|
@@ -0,0 +1,159 @@
|
|
1
|
+
import logging
|
2
|
+
from contextlib import contextmanager
|
3
|
+
import time
|
4
|
+
from typing import Iterable, Type
|
5
|
+
from ..api import VectorDB, DBCaseConfig, DBConfig, IndexType
|
6
|
+
from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig
|
7
|
+
from opensearchpy import OpenSearch
|
8
|
+
from opensearchpy.helpers import bulk
|
9
|
+
|
10
|
+
log = logging.getLogger(__name__)
|
11
|
+
|
12
|
+
|
13
|
+
class AWSOpenSearch(VectorDB):
    """VectorDB client backed by an OpenSearch k-NN index.

    The constructor only prepares the index. The connection used for
    inserts and searches is opened by the ``init()`` context manager and
    stored on ``self.client`` for the duration of the ``with`` block.
    """

    # Bounded retry policy for bulk inserts: a failing bulk request is
    # retried a fixed number of times instead of recursing without limit.
    _INSERT_MAX_ATTEMPTS = 10
    _INSERT_RETRY_WAIT_SECONDS = 10

    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: AWSOpenSearchIndexConfig,
        index_name: str = "vdb_bench_index",  # must be lowercase
        id_col_name: str = "id",
        vector_col_name: str = "embedding",
        drop_old: bool = False,
        **kwargs,
    ):
        """Store the settings and, when ``drop_old`` is set, rebuild the index.

        Args:
            dim: dimension of the vectors stored in the index.
            db_config: keyword arguments passed straight to ``OpenSearch(...)``.
            db_case_config: index configuration; its ``index_param()`` is used
                as the k-NN ``method`` of the vector field.
            index_name: name of the index (must be lowercase).
            id_col_name: name of the integer id field.
            vector_col_name: name of the ``knn_vector`` field.
            drop_old: delete an existing index of the same name and create a
                fresh one.
            **kwargs: accepted and ignored.
        """
        self.dim = dim
        self.db_config = db_config
        self.case_config = db_case_config
        self.index_name = index_name
        self.id_col_name = id_col_name
        self.category_col_names = [
            f"scalar-{categoryCount}" for categoryCount in [2, 5, 10, 100, 1000]
        ]
        self.vector_col_name = vector_col_name

        # `http_auth` carries the plaintext password; mask it before logging.
        loggable_config = {
            key: ("***" if key == "http_auth" else value)
            for key, value in self.db_config.items()
        }
        log.info(f"AWS_OpenSearch client config: {loggable_config}")
        client = OpenSearch(**self.db_config)
        # NOTE(review): the extracted source lost its indentation; the index is
        # assumed to be (re)created only when drop_old is set - confirm.
        if drop_old:
            log.info(f"AWS_OpenSearch client drop old index: {self.index_name}")
            is_existed = client.indices.exists(index=self.index_name)
            if is_existed:
                client.indices.delete(index=self.index_name)
            self._create_index(client)

    @classmethod
    def config_cls(cls) -> Type[AWSOpenSearchConfig]:
        """Return the connection-config class for this client."""
        return AWSOpenSearchConfig

    @classmethod
    def case_config_cls(
        cls, index_type: IndexType | None = None
    ) -> Type[AWSOpenSearchIndexConfig]:
        """Return the index-config class; ``index_type`` is not consulted."""
        return AWSOpenSearchIndexConfig

    def _create_index(self, client: OpenSearch):
        """Create the k-NN index with id, category and vector fields.

        Raises:
            Exception: whatever ``client.indices.create`` raises, after a
                warning has been logged.
        """
        settings = {
            "index": {
                "knn": True,
            }
        }
        mappings = {
            "properties": {
                self.id_col_name: {"type": "integer"},
                **{
                    categoryCol: {"type": "keyword"}
                    for categoryCol in self.category_col_names
                },
                self.vector_col_name: {
                    "type": "knn_vector",
                    "dimension": self.dim,
                    "method": self.case_config.index_param(),
                },
            }
        }
        try:
            client.indices.create(
                index=self.index_name, body=dict(settings=settings, mappings=mappings)
            )
        except Exception as e:
            log.warning(f"Failed to create index: {self.index_name} error: {str(e)}")
            raise

    @contextmanager
    def init(self) -> None:
        """Open an OpenSearch client for the duration of the ``with`` block.

        The client is stored on ``self.client`` and the attribute is removed
        again on exit, including when the managed body raises.
        """
        self.client = OpenSearch(**self.db_config)
        try:
            yield
        finally:
            self.client = None
            del self.client

    def insert_embeddings(
        self,
        embeddings: Iterable[list[float]],
        metadata: list[int],
        **kwargs,
    ) -> tuple[int, Exception]:
        """Bulk-insert the embeddings into the OpenSearch index.

        ``metadata[i]`` is used as the document ``_id`` of ``embeddings[i]``.
        A failing bulk request is logged and retried after a pause, up to
        ``_INSERT_MAX_ATTEMPTS`` attempts in total.

        Returns:
            ``(number_of_embeddings, None)`` on success, or ``(0, error)``
            with the last exception once every attempt has failed.
        """
        assert self.client is not None, "should self.init() first"

        # Build the request once, outside the retry loop, so a malformed
        # input fails immediately instead of being retried.
        insert_data = []
        for i, embedding in enumerate(embeddings):
            insert_data.append({"index": {"_index": self.index_name, "_id": metadata[i]}})
            insert_data.append({self.vector_col_name: embedding})
        row_count = len(insert_data) // 2  # two bulk lines per document

        last_error = None
        for attempt in range(1, self._INSERT_MAX_ATTEMPTS + 1):
            try:
                resp = self.client.bulk(insert_data)
                log.info(f"AWS_OpenSearch adding documents: {len(resp['items'])}")
                resp = self.client.indices.stats(self.index_name)
                log.info(f"Total document count in index: {resp['_all']['primaries']['indexing']['index_total']}")
                return (row_count, None)
            except Exception as e:
                last_error = e
                log.warning(f"Failed to insert data: {self.index_name} error: {str(e)}")
                if attempt < self._INSERT_MAX_ATTEMPTS:
                    time.sleep(self._INSERT_RETRY_WAIT_SECONDS)
        return (0, last_error)

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
    ) -> list[int]:
        """Get the ids of the k most similar embeddings to the query vector.

        Args:
            query(list[float]): query embedding to look up documents similar to.
            k(int): Number of most similar embeddings to return. Defaults to 100.
            filters(dict, optional): accepted for interface compatibility; it
                is not applied to the query.

        Returns:
            list[int]: document ids of the hits, in the order returned by
                OpenSearch.
        """
        assert self.client is not None, "should self.init() first"

        body = {
            "size": k,
            "query": {"knn": {self.vector_col_name: {"vector": query, "k": k}}},
        }
        try:
            resp = self.client.search(index=self.index_name, body=body)
            # Debug level with lazy arguments: this runs once per query in the
            # measured search path.
            log.debug("Search took: %s", resp["took"])
            log.debug("Search shards: %s", resp["_shards"])
            log.debug("Search hits total: %s", resp["hits"]["total"])
            return [int(d["_id"]) for d in resp["hits"]["hits"]]
        except Exception as e:
            log.warning(f"Failed to search: {self.index_name} error: {str(e)}")
            raise

    def optimize(self):
        """optimize will be called between insertion and search in performance cases."""
        pass

    def ready_to_load(self):
        """ready_to_load will be called before load in load cases."""
        pass
|
@@ -0,0 +1,44 @@
|
|
1
|
+
from typing import Annotated, TypedDict, Unpack
|
2
|
+
|
3
|
+
import click
|
4
|
+
from pydantic import SecretStr
|
5
|
+
|
6
|
+
from ....cli.cli import (
|
7
|
+
CommonTypedDict,
|
8
|
+
HNSWFlavor2,
|
9
|
+
cli,
|
10
|
+
click_parameter_decorators_from_typed_dict,
|
11
|
+
run,
|
12
|
+
)
|
13
|
+
from .. import DB
|
14
|
+
|
15
|
+
|
16
|
+
class AWSOpenSearchTypedDict(TypedDict):
    # Connection options of the AWSOpenSearch CLI command. Each field pairs
    # its value type with the click option that supplies it.
    host: Annotated[
        str,
        click.option("--host", type=str, help="Db host", required=True),
    ]
    port: Annotated[
        int,
        click.option("--port", type=int, default=443, help="Db Port"),
    ]
    user: Annotated[
        str,
        click.option("--user", type=str, default="admin", help="Db User"),
    ]
    # NOTE(review): --password has neither a default nor required=True.
    password: Annotated[
        str,
        click.option("--password", type=str, help="Db password"),
    ]
|
23
|
+
|
24
|
+
|
25
|
+
# Full option set of the AWSOpenSearch command: the common benchmark options,
# the connection options and the HNSW options, combined by inheritance.
class AWSOpenSearchHNSWTypedDict(
    CommonTypedDict,
    AWSOpenSearchTypedDict,
    HNSWFlavor2,
):
    ...
|
27
|
+
|
28
|
+
|
29
|
+
# CLI entry point for benchmarking AWS OpenSearch. Documented with comments
# rather than a docstring so the command's help output is unchanged.
@cli.command()
@click_parameter_decorators_from_typed_dict(AWSOpenSearchHNSWTypedDict)
def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]):
    # Imported here so the config module is only loaded when the command runs.
    from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig

    connection = AWSOpenSearchConfig(
        host=parameters["host"],
        port=parameters["port"],
        user=parameters["user"],
        password=SecretStr(parameters["password"]),
    )
    # NOTE(review): the index config is built with its defaults; the HNSW
    # options collected from the command line are not passed to it - confirm
    # whether that is intended.
    index_config = AWSOpenSearchIndexConfig()

    run(
        db=DB.AWSOpenSearch,
        db_config=connection,
        db_case_config=index_config,
        **parameters,
    )
|
@@ -0,0 +1,58 @@
|
|
1
|
+
from enum import Enum
|
2
|
+
from pydantic import SecretStr, BaseModel
|
3
|
+
|
4
|
+
from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
|
5
|
+
|
6
|
+
|
7
|
+
class AWSOpenSearchConfig(DBConfig, BaseModel):
    # Connection settings for an AWS OpenSearch domain.
    host: str = ""
    port: int = 443
    user: str = ""
    password: SecretStr = ""

    def to_dict(self) -> dict:
        """Return the settings as a dict of OpenSearch client options.

        The password is unwrapped into the ``http_auth`` tuple, so the
        returned dict contains the secret in plain text.
        """
        endpoint = {'host': self.host, 'port': self.port}
        credentials = (self.user, self.password.get_secret_value())
        # NOTE(review): certificates are verified but hostname assertion is
        # switched off - confirm this is intended.
        return dict(
            hosts=[endpoint],
            http_auth=credentials,
            use_ssl=True,
            http_compress=True,
            verify_certs=True,
            ssl_assert_hostname=False,
            ssl_show_warn=False,
            timeout=600,
        )
|
24
|
+
|
25
|
+
|
26
|
+
class AWSOS_Engine(Enum):
    """Engine names selectable for the index; the value is what gets sent."""

    nmslib = "nmslib"
    faiss = "faiss"
    # Capitalised value, unlike the other two members.
    lucene = "Lucene"
|
30
|
+
|
31
|
+
|
32
|
+
class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
    # Index settings used to build the k-NN `method` of the vector field.
    metric_type: MetricType = MetricType.L2
    engine: AWSOS_Engine = AWSOS_Engine.nmslib
    efConstruction: int = 360
    M: int = 30

    def parse_metric(self) -> str:
        """Return the space type matching ``metric_type``.

        IP gives "innerproduct", COSINE gives "cosinesimil", and anything
        else gives "l2".
        """
        if self.metric_type == MetricType.IP:
            # only support faiss / nmslib, not for Lucene.
            return "innerproduct"
        if self.metric_type == MetricType.COSINE:
            return "cosinesimil"
        return "l2"

    def index_param(self) -> dict:
        """Return the HNSW method definition for index creation."""
        hnsw_parameters = {
            "ef_construction": self.efConstruction,
            "m": self.M,
        }
        return {
            "name": "hnsw",
            "space_type": self.parse_metric(),
            "engine": self.engine.value,
            "parameters": hnsw_parameters,
        }

    def search_param(self) -> dict:
        """Return the search-time parameters; there are none."""
        return {}
|
@@ -0,0 +1,125 @@
|
|
1
|
+
import time, random
|
2
|
+
from opensearchpy import OpenSearch
|
3
|
+
from opensearch_dsl import Search, Document, Text, Keyword
|
4
|
+
|
5
|
+
# Connection placeholders for this stand-alone smoke-test script.
_HOST = 'xxxxxx.us-west-2.es.amazonaws.com'
_PORT = 443
_AUTH = ('admin', 'xxxxxx')  # For testing only. Don't store credentials in code.

# Index name and workload size used by the functions below.
_INDEX_NAME = 'my-dsl-index'
_BATCH = 100  # documents per bulk request
_ROWS = 100  # total documents inserted
_DIM = 128  # vector dimension
_TOPK = 10  # hits requested per search
|
14
|
+
|
15
|
+
|
16
|
+
def create_client():
    """Return an OpenSearch client for the module-level host and credentials."""
    endpoint = {'host': _HOST, 'port': _PORT}
    return OpenSearch(
        hosts=[endpoint],
        http_compress=True,  # enables gzip compression for request bodies
        http_auth=_AUTH,
        use_ssl=True,
        verify_certs=True,
        ssl_assert_hostname=False,
        ssl_show_warn=False,
    )
|
27
|
+
|
28
|
+
|
29
|
+
def create_index(client, index_name):
    """Create a single-shard k-NN index with one `embedding` vector field.

    The creation response is printed.
    """
    hnsw_method = {
        "engine": "nmslib",
        "name": "hnsw",
        "space_type": "l2",
        "parameters": {
            "ef_construction": 128,
            "m": 24,
        },
    }
    request_body = {
        "settings": {
            "index": {
                "knn": True,
                "number_of_shards": 1,
                "refresh_interval": "5s",
            }
        },
        "mappings": {
            "properties": {
                "embedding": {
                    "type": "knn_vector",
                    "dimension": _DIM,
                    "method": hnsw_method,
                }
            }
        },
    }

    response = client.indices.create(index=index_name, body=request_body)
    print('\nCreating index:')
    print(response)
|
58
|
+
|
59
|
+
|
60
|
+
def delete_index(client, index_name):
    """Delete the index and print the server's response."""
    outcome = client.indices.delete(index=index_name)
    print('\nDeleting index:')
    print(outcome)
|
64
|
+
|
65
|
+
|
66
|
+
def bulk_insert(client, index_name):
    """Insert `_ROWS` random vectors into the index in bulk requests.

    Documents are sent in batches of at most `_BATCH`. After each request
    the number of items and the `errors` flag of the response are printed.
    The index's total document count is printed at the end.
    """
    ids = list(range(_ROWS))
    vec = [[random.random() for _ in range(_DIM)] for _ in range(_ROWS)]

    for start in range(0, _ROWS, _BATCH):
        stop = start + _BATCH
        docs = []
        # Slicing clamps the final batch to the rows that exist, so a row
        # count that is not a multiple of the batch size no longer raises
        # IndexError.
        for doc_id, embedding in zip(ids[start:stop], vec[start:stop]):
            docs.append({"index": {"_index": index_name, "_id": doc_id}})
            docs.append({"embedding": embedding})
        response = client.bulk(docs)
        print('\nAdding documents:', len(response['items']), response['errors'])

    response = client.indices.stats(index_name)
    print('\nTotal document count in index:', response['_all']['primaries']['indexing']['index_total'])
|
81
|
+
|
82
|
+
|
83
|
+
def search(client, index_name):
    """Run one random k-NN query, polling until it returns at least one hit.

    The same query body is re-sent once per second while the hit list is
    empty. There is no upper bound on the number of attempts.
    """
    query_vector = [random.random() for _ in range(_DIM)]
    search_body = {
        "size": _TOPK,
        "query": {
            "knn": {
                "embedding": {
                    "vector": query_vector,
                    "k": _TOPK,
                }
            }
        },
    }

    while True:
        response = client.search(index=index_name, body=search_body)
        print(f'\nSearch took: {response["took"]}')
        print(f'\nSearch shards: {response["_shards"]}')
        print(f'\nSearch hits total: {response["hits"]["total"]}')

        hits = response["hits"]["hits"]
        if not hits:
            # Nothing searchable yet; wait and ask again.
            print('\nSearch not ready, sleep 1s')
            time.sleep(1)
            continue

        print('\nSearch results:')
        for hit in hits:
            print(hit["_id"], hit["_score"])
        break
|
110
|
+
|
111
|
+
|
112
|
+
def main():
    """Create the index, insert, search, and always delete the index afterwards.

    Errors from the create/insert/search steps are printed rather than
    propagated. The delete runs exactly once, in ``finally``; an error
    raised by the delete itself propagates.
    """
    client = create_client()
    try:
        create_index(client, _INDEX_NAME)
        bulk_insert(client, _INDEX_NAME)
        search(client, _INDEX_NAME)
    except Exception as e:
        print(e)
    finally:
        # Single cleanup point: a failing delete is not attempted twice.
        delete_index(client, _INDEX_NAME)


if __name__ == '__main__':
    main()
|