vectordb-bench 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +49 -24
- vectordb_bench/__main__.py +4 -3
- vectordb_bench/backend/assembler.py +12 -13
- vectordb_bench/backend/cases.py +55 -45
- vectordb_bench/backend/clients/__init__.py +85 -14
- vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +1 -2
- vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +3 -4
- vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +112 -77
- vectordb_bench/backend/clients/aliyun_opensearch/config.py +6 -7
- vectordb_bench/backend/clients/alloydb/alloydb.py +59 -84
- vectordb_bench/backend/clients/alloydb/cli.py +51 -34
- vectordb_bench/backend/clients/alloydb/config.py +30 -30
- vectordb_bench/backend/clients/api.py +13 -24
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +50 -54
- vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
- vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
- vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
- vectordb_bench/backend/clients/chroma/chroma.py +39 -40
- vectordb_bench/backend/clients/chroma/config.py +4 -2
- vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +24 -26
- vectordb_bench/backend/clients/memorydb/cli.py +8 -8
- vectordb_bench/backend/clients/memorydb/config.py +2 -2
- vectordb_bench/backend/clients/memorydb/memorydb.py +67 -58
- vectordb_bench/backend/clients/milvus/cli.py +41 -83
- vectordb_bench/backend/clients/milvus/config.py +18 -8
- vectordb_bench/backend/clients/milvus/milvus.py +19 -39
- vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
- vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
- vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +56 -77
- vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
- vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
- vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +34 -43
- vectordb_bench/backend/clients/pgvector/cli.py +40 -31
- vectordb_bench/backend/clients/pgvector/config.py +63 -73
- vectordb_bench/backend/clients/pgvector/pgvector.py +98 -104
- vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
- vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
- vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +39 -49
- vectordb_bench/backend/clients/pinecone/config.py +1 -0
- vectordb_bench/backend/clients/pinecone/pinecone.py +15 -25
- vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +41 -35
- vectordb_bench/backend/clients/redis/cli.py +6 -12
- vectordb_bench/backend/clients/redis/config.py +7 -5
- vectordb_bench/backend/clients/redis/redis.py +95 -62
- vectordb_bench/backend/clients/test/cli.py +2 -3
- vectordb_bench/backend/clients/test/config.py +2 -2
- vectordb_bench/backend/clients/test/test.py +5 -9
- vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
- vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +37 -26
- vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
- vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
- vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
- vectordb_bench/backend/data_source.py +18 -14
- vectordb_bench/backend/dataset.py +47 -27
- vectordb_bench/backend/result_collector.py +2 -3
- vectordb_bench/backend/runner/__init__.py +4 -6
- vectordb_bench/backend/runner/mp_runner.py +56 -23
- vectordb_bench/backend/runner/rate_runner.py +30 -19
- vectordb_bench/backend/runner/read_write_runner.py +46 -22
- vectordb_bench/backend/runner/serial_runner.py +81 -46
- vectordb_bench/backend/runner/util.py +4 -3
- vectordb_bench/backend/task_runner.py +92 -92
- vectordb_bench/backend/utils.py +17 -10
- vectordb_bench/base.py +0 -1
- vectordb_bench/cli/cli.py +65 -60
- vectordb_bench/cli/vectordbbench.py +6 -7
- vectordb_bench/frontend/components/check_results/charts.py +8 -19
- vectordb_bench/frontend/components/check_results/data.py +4 -16
- vectordb_bench/frontend/components/check_results/filters.py +8 -16
- vectordb_bench/frontend/components/check_results/nav.py +4 -4
- vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
- vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
- vectordb_bench/frontend/components/concurrent/charts.py +12 -12
- vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
- vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
- vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
- vectordb_bench/frontend/components/custom/initStyle.py +1 -1
- vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
- vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
- vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
- vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
- vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
- vectordb_bench/frontend/components/tables/data.py +3 -6
- vectordb_bench/frontend/config/dbCaseConfigs.py +51 -84
- vectordb_bench/frontend/pages/concurrent.py +3 -5
- vectordb_bench/frontend/pages/custom.py +30 -9
- vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
- vectordb_bench/frontend/pages/run_test.py +3 -7
- vectordb_bench/frontend/utils.py +1 -1
- vectordb_bench/frontend/vdb_benchmark.py +4 -6
- vectordb_bench/interface.py +45 -24
- vectordb_bench/log_util.py +59 -64
- vectordb_bench/metric.py +10 -11
- vectordb_bench/models.py +26 -43
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/METADATA +22 -15
- vectordb_bench-0.0.21.dist-info/RECORD +135 -0
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/WHEEL +1 -1
- vectordb_bench-0.0.19.dist-info/RECORD +0 -135
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/api.py

@@ -1,9 +1,8 @@
 from abc import ABC, abstractmethod
-from enum import Enum
-from typing import Any, Type
 from contextlib import contextmanager
+from enum import Enum
 
-from pydantic import BaseModel,
+from pydantic import BaseModel, SecretStr, validator
 
 
 class MetricType(str, Enum):
@@ -65,13 +64,10 @@ class DBConfig(ABC, BaseModel):
         raise NotImplementedError
 
     @validator("*")
-    def not_empty_field(cls, v, field):
-        if (
-            field.name in cls.common_short_configs()
-            or field.name in cls.common_long_configs()
-        ):
+    def not_empty_field(cls, v: any, field: any):
+        if field.name in cls.common_short_configs() or field.name in cls.common_long_configs():
             return v
-        if not v and isinstance(v,
+        if not v and isinstance(v, str | SecretStr):
             raise ValueError("Empty string!")
         return v
 
@@ -141,6 +137,13 @@ class VectorDB(ABC):
     @contextmanager
     def init(self) -> None:
         """create and destory connections to database.
+        Why contextmanager:
+
+            In multiprocessing search tasks, vectordbbench might init
+            totally hundreds of thousands of connections with DB server.
+
+            Too many connections may drain local FDs or server connection resources.
+            If the DB client doesn't have `close()` method, just set the object to None.
 
         Examples:
             >>> with self.init():
@@ -191,9 +194,8 @@ class VectorDB(ABC):
         """
         raise NotImplementedError
 
-    # TODO: remove
     @abstractmethod
-    def optimize(self):
+    def optimize(self, data_size: int | None = None):
         """optimize will be called between insertion and search in performance cases.
 
         Should be blocked until the vectorDB is ready to be tested on
@@ -203,16 +205,3 @@ class VectorDB(ABC):
         Optimize's execution time is limited, the limited time is based on cases.
         """
         raise NotImplementedError
-
-    def optimize_with_size(self, data_size: int):
-        self.optimize()
-
-    # TODO: remove
-    @abstractmethod
-    def ready_to_load(self):
-        """ready_to_load will be called before load in load cases.
-
-        Should be blocked until the vectorDB is ready to be tested on
-        heavy load cases.
-        """
-        raise NotImplementedError
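
The net effect of the api.py hunks above: `optimize_with_size` and `ready_to_load` are removed, and clients implement a single abstract `optimize(data_size)` plus a contextmanager `init()` that opens one connection per process and drops it on exit, for the reasons the new docstring gives. A minimal sketch of a client against this 0.0.21 interface — `FakeDriver` and `MyClient` are hypothetical stand-ins, not code from the package:

    from contextlib import contextmanager


    class FakeDriver:
        """Placeholder for a real database driver."""

        def close(self) -> None: ...


    class MyClient:  # would subclass vectordb_bench.backend.clients.api.VectorDB
        def __init__(self, db_config: dict, **kwargs):
            self.db_config = db_config
            self.client = None

        @contextmanager
        def init(self):
            # One short-lived connection per process: multiprocessing search
            # runners may otherwise exhaust local FDs or server connection slots.
            self.client = FakeDriver()
            try:
                yield
            finally:
                self.client.close()  # or just set to None if the driver has no close()
                self.client = None

        def optimize(self, data_size: int | None = None):
            # 0.0.21 passes the dataset size here; 0.0.19's separate
            # optimize_with_size(data_size) wrapper is gone.
            ...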
vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py

@@ -1,14 +1,18 @@
 import logging
-from contextlib import contextmanager
 import time
-from
-from
-
+from collections.abc import Iterable
+from contextlib import contextmanager
+
 from opensearchpy import OpenSearch
-
+
+from ..api import IndexType, VectorDB
+from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig, AWSOS_Engine
 
 log = logging.getLogger(__name__)
 
+WAITING_FOR_REFRESH_SEC = 30
+WAITING_FOR_FORCE_MERGE_SEC = 30
+
 
 class AWSOpenSearch(VectorDB):
     def __init__(
@@ -17,7 +21,7 @@ class AWSOpenSearch(VectorDB):
         db_config: dict,
         db_case_config: AWSOpenSearchIndexConfig,
         index_name: str = "vdb_bench_index",  # must be lowercase
-        id_col_name: str = "
+        id_col_name: str = "_id",
         vector_col_name: str = "embedding",
         drop_old: bool = False,
         **kwargs,
@@ -27,9 +31,7 @@ class AWSOpenSearch(VectorDB):
         self.case_config = db_case_config
         self.index_name = index_name
         self.id_col_name = id_col_name
-        self.category_col_names = [
-            f"scalar-{categoryCount}" for categoryCount in [2, 5, 10, 100, 1000]
-        ]
+        self.category_col_names = [f"scalar-{categoryCount}" for categoryCount in [2, 5, 10, 100, 1000]]
         self.vector_col_name = vector_col_name
 
         log.info(f"AWS_OpenSearch client config: {self.db_config}")
@@ -46,39 +48,32 @@ class AWSOpenSearch(VectorDB):
         return AWSOpenSearchConfig
 
     @classmethod
-    def case_config_cls(
-        cls, index_type: IndexType | None = None
-    ) -> AWSOpenSearchIndexConfig:
+    def case_config_cls(cls, index_type: IndexType | None = None) -> AWSOpenSearchIndexConfig:
         return AWSOpenSearchIndexConfig
 
     def _create_index(self, client: OpenSearch):
         settings = {
             "index": {
                 "knn": True,
-
-                # "refresh_interval": "600s",
-            }
+            },
         }
         mappings = {
             "properties": {
-
-                **{
-                    categoryCol: {"type": "keyword"}
-                    for categoryCol in self.category_col_names
-                },
+                **{categoryCol: {"type": "keyword"} for categoryCol in self.category_col_names},
                 self.vector_col_name: {
                     "type": "knn_vector",
                     "dimension": self.dim,
                     "method": self.case_config.index_param(),
                 },
-            }
+            },
         }
         try:
             client.indices.create(
-                index=self.index_name,
+                index=self.index_name,
+                body={"settings": settings, "mappings": mappings},
             )
         except Exception as e:
-            log.warning(f"Failed to create index: {self.index_name} error: {
+            log.warning(f"Failed to create index: {self.index_name} error: {e!s}")
             raise e from None
 
     @contextmanager
@@ -87,7 +82,6 @@ class AWSOpenSearch(VectorDB):
         self.client = OpenSearch(**self.db_config)
 
         yield
-        # self.client.transport.close()
         self.client = None
         del self.client
 
@@ -102,16 +96,20 @@ class AWSOpenSearch(VectorDB):
 
         insert_data = []
         for i in range(len(embeddings)):
-            insert_data.append(
+            insert_data.append(
+                {"index": {"_index": self.index_name, self.id_col_name: metadata[i]}},
+            )
             insert_data.append({self.vector_col_name: embeddings[i]})
         try:
             resp = self.client.bulk(insert_data)
             log.info(f"AWS_OpenSearch adding documents: {len(resp['items'])}")
             resp = self.client.indices.stats(self.index_name)
-            log.info(
+            log.info(
+                f"Total document count in index: {resp['_all']['primaries']['indexing']['index_total']}",
+            )
             return (len(embeddings), None)
         except Exception as e:
-            log.warning(f"Failed to insert data: {self.index_name} error: {
+            log.warning(f"Failed to insert data: {self.index_name} error: {e!s}")
             time.sleep(10)
             return self.insert_embeddings(embeddings, metadata)
 
@@ -136,23 +134,26 @@ class AWSOpenSearch(VectorDB):
         body = {
             "size": k,
             "query": {"knn": {self.vector_col_name: {"vector": query, "k": k}}},
-            **({"filter": {"range": {self.id_col_name: {"gt": filters["id"]}}}} if filters else {})
+            **({"filter": {"range": {self.id_col_name: {"gt": filters["id"]}}}} if filters else {}),
         }
         try:
-            resp = self.client.search(
-
-
-
-
-
-
-
+            resp = self.client.search(
+                index=self.index_name,
+                body=body,
+                size=k,
+                _source=False,
+                docvalue_fields=[self.id_col_name],
+                stored_fields="_none_",
+            )
+            log.info(f"Search took: {resp['took']}")
+            log.info(f"Search shards: {resp['_shards']}")
+            log.info(f"Search hits total: {resp['hits']['total']}")
+            return [int(h["fields"][self.id_col_name][0]) for h in resp["hits"]["hits"]]
         except Exception as e:
-            log.warning(f"Failed to search: {self.index_name} error: {
+            log.warning(f"Failed to search: {self.index_name} error: {e!s}")
             raise e from None
 
-    def optimize(self):
+    def optimize(self, data_size: int | None = None):
         """optimize will be called between insertion and search in performance cases."""
         # Call refresh first to ensure that all segments are created
         self._refresh_index()
@@ -164,37 +165,32 @@ class AWSOpenSearch(VectorDB):
 
     def _refresh_index(self):
         log.debug(f"Starting refresh for index {self.index_name}")
-        SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC = 30
         while True:
             try:
-                log.info(
+                log.info("Starting the Refresh Index..")
                 self.client.indices.refresh(index=self.index_name)
                 break
             except Exception as e:
                 log.info(
-                    f"Refresh errored out. Sleeping for {
-
+                    f"Refresh errored out. Sleeping for {WAITING_FOR_REFRESH_SEC} sec and then Retrying : {e}",
+                )
+                time.sleep(WAITING_FOR_REFRESH_SEC)
                 continue
         log.debug(f"Completed refresh for index {self.index_name}")
 
     def _do_force_merge(self):
         log.debug(f"Starting force merge for index {self.index_name}")
-        force_merge_endpoint = f
-        force_merge_task_id = self.client.transport.perform_request(
-        SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
+        force_merge_endpoint = f"/{self.index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false"
+        force_merge_task_id = self.client.transport.perform_request("POST", force_merge_endpoint)["task"]
         while True:
-            time.sleep(
+            time.sleep(WAITING_FOR_FORCE_MERGE_SEC)
             task_status = self.client.tasks.get(task_id=force_merge_task_id)
-            if task_status[
+            if task_status["completed"]:
                 break
         log.debug(f"Completed force merge for index {self.index_name}")
 
     def _load_graphs_to_memory(self):
         if self.case_config.engine != AWSOS_Engine.lucene:
             log.info("Calling warmup API to load graphs into memory")
-            warmup_endpoint = f
-            self.client.transport.perform_request(
-
-    def ready_to_load(self):
-        """ready_to_load will be called before load in load cases."""
-        pass
+            warmup_endpoint = f"/_plugins/_knn/warmup/{self.index_name}"
+            self.client.transport.perform_request("GET", warmup_endpoint)
vectordb_bench/backend/clients/aws_opensearch/cli.py

@@ -14,22 +14,20 @@ from .. import DB
 
 
 class AWSOpenSearchTypedDict(TypedDict):
-    host: Annotated[
-        str, click.option("--host", type=str, help="Db host", required=True)
-    ]
+    host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
     port: Annotated[int, click.option("--port", type=int, default=443, help="Db Port")]
     user: Annotated[str, click.option("--user", type=str, default="admin", help="Db User")]
     password: Annotated[str, click.option("--password", type=str, help="Db password")]
 
 
-class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2):
-    ...
+class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2): ...
 
 
 @cli.command()
 @click_parameter_decorators_from_typed_dict(AWSOpenSearchHNSWTypedDict)
 def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]):
     from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig
+
     run(
         db=DB.AWSOpenSearch,
         db_config=AWSOpenSearchConfig(
@@ -38,7 +36,6 @@ def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]):
             user=parameters["user"],
             password=SecretStr(parameters["password"]),
         ),
-        db_case_config=AWSOpenSearchIndexConfig(
-        ),
+        db_case_config=AWSOpenSearchIndexConfig(),
         **parameters,
     )
vectordb_bench/backend/clients/aws_opensearch/config.py

@@ -1,10 +1,13 @@
 import logging
 from enum import Enum
-from pydantic import SecretStr, BaseModel
 
-from
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, MetricType
 
 log = logging.getLogger(__name__)
+
+
 class AWSOpenSearchConfig(DBConfig, BaseModel):
     host: str = ""
     port: int = 443
@@ -13,7 +16,7 @@ class AWSOpenSearchConfig(DBConfig, BaseModel):
 
     def to_dict(self) -> dict:
         return {
-            "hosts": [{
+            "hosts": [{"host": self.host, "port": self.port}],
             "http_auth": (self.user, self.password.get_secret_value()),
             "use_ssl": True,
             "http_compress": True,
@@ -40,25 +43,26 @@ class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
     def parse_metric(self) -> str:
         if self.metric_type == MetricType.IP:
             return "innerproduct"
-
+        if self.metric_type == MetricType.COSINE:
             if self.engine == AWSOS_Engine.faiss:
-                log.info(
+                log.info(
+                    "Using innerproduct because faiss doesn't support cosine as metric type for Opensearch",
+                )
                 return "innerproduct"
             return "cosinesimil"
         return "l2"
 
     def index_param(self) -> dict:
-
+        return {
             "name": "hnsw",
             "space_type": self.parse_metric(),
             "engine": self.engine.value,
             "parameters": {
                 "ef_construction": self.efConstruction,
                 "m": self.M,
-                "ef_search": self.efSearch
-            }
+                "ef_search": self.efSearch,
+            },
         }
-        return params
 
     def search_param(self) -> dict:
         return {}
vectordb_bench/backend/clients/aws_opensearch/run.py

@@ -1,12 +1,16 @@
-import
+import logging
+import random
+import time
+
 from opensearchpy import OpenSearch
-from opensearch_dsl import Search, Document, Text, Keyword
 
-
+log = logging.getLogger(__name__)
+
+_HOST = "xxxxxx.us-west-2.es.amazonaws.com"
 _PORT = 443
-_AUTH = (
+_AUTH = ("admin", "xxxxxx")  # For testing only. Don't store credentials in code.
 
-_INDEX_NAME =
+_INDEX_NAME = "my-dsl-index"
 _BATCH = 100
 _ROWS = 100
 _DIM = 128
@@ -14,25 +18,24 @@ _TOPK = 10
 
 
 def create_client():
-
-        hosts=[{
-        http_compress=True,
+    return OpenSearch(
+        hosts=[{"host": _HOST, "port": _PORT}],
+        http_compress=True,  # enables gzip compression for request bodies
         http_auth=_AUTH,
         use_ssl=True,
         verify_certs=True,
         ssl_assert_hostname=False,
         ssl_show_warn=False,
     )
-    return client
 
 
-def create_index(client, index_name):
+def create_index(client: OpenSearch, index_name: str):
     settings = {
         "index": {
             "knn": True,
             "number_of_shards": 1,
             "refresh_interval": "5s",
-        }
+        },
     }
     mappings = {
         "properties": {
@@ -46,41 +49,46 @@ def create_index(client, index_name):
             "parameters": {
                 "ef_construction": 256,
                 "m": 16,
-            }
-        }
-    }
-}
+                },
+            },
+        },
+    },
     }
 
-    response = client.indices.create(
-
-
+    response = client.indices.create(
+        index=index_name,
+        body={"settings": settings, "mappings": mappings},
+    )
+    log.info("\nCreating index:")
+    log.info(response)
 
 
-def delete_index(client, index_name):
+def delete_index(client: OpenSearch, index_name: str):
     response = client.indices.delete(index=index_name)
-
-
+    log.info("\nDeleting index:")
+    log.info(response)
 
 
-def bulk_insert(client, index_name):
+def bulk_insert(client: OpenSearch, index_name: str):
     # Perform bulk operations
-    ids =
+    ids = list(range(_ROWS))
     vec = [[random.random() for _ in range(_DIM)] for _ in range(_ROWS)]
 
     docs = []
     for i in range(0, _ROWS, _BATCH):
         docs.clear()
-        for j in range(
-            docs.append({"index": {"_index": index_name, "_id": ids[i+j]}})
-            docs.append({"embedding": vec[i+j]})
+        for j in range(_BATCH):
+            docs.append({"index": {"_index": index_name, "_id": ids[i + j]}})
+            docs.append({"embedding": vec[i + j]})
         response = client.bulk(docs)
-
+        log.info(f"Adding documents: {len(response['items'])}, {response['errors']}")
        response = client.indices.stats(index_name)
-
+        log.info(
+            f'Total document count in index: { response["_all"]["primaries"]["indexing"]["index_total"] }',
+        )
 
 
-def search(client, index_name):
+def search(client: OpenSearch, index_name: str):
     # Search for the document.
     search_body = {
         "size": _TOPK,
@@ -89,53 +97,55 @@ def search(client, index_name):
             "embedding": {
                 "vector": [random.random() for _ in range(_DIM)],
                 "k": _TOPK,
-            }
-        }
-    }
+                },
+            },
+        },
     }
     while True:
         response = client.search(index=index_name, body=search_body)
-
-
-
+        log.info(f'\nSearch took: {response["took"]}')
+        log.info(f'\nSearch shards: {response["_shards"]}')
+        log.info(f'\nSearch hits total: {response["hits"]["total"]}')
         result = response["hits"]["hits"]
         if len(result) != 0:
-
+            log.info("\nSearch results:")
             for hit in response["hits"]["hits"]:
-
+                log.info(hit["_id"], hit["_score"])
             break
-
-
-
-
-
-
-
-
-
+        log.info("\nSearch not ready, sleep 1s")
+        time.sleep(1)
+
+
+SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
+WAITINT_FOR_REFRESH_SEC = 30
+
+
+def optimize_index(client: OpenSearch, index_name: str):
+    log.info(f"Starting force merge for index {index_name}")
+    force_merge_endpoint = f"/{index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false"
+    force_merge_task_id = client.transport.perform_request("POST", force_merge_endpoint)["task"]
     while True:
         time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
         task_status = client.tasks.get(task_id=force_merge_task_id)
-        if task_status[
+        if task_status["completed"]:
             break
-
+    log.info(f"Completed force merge for index {index_name}")
 
 
-def refresh_index(client, index_name):
-
-    SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC = 30
+def refresh_index(client: OpenSearch, index_name: str):
+    log.info(f"Starting refresh for index {index_name}")
     while True:
         try:
-
+            log.info("Starting the Refresh Index..")
             client.indices.refresh(index=index_name)
             break
         except Exception as e:
-
-            f"Refresh errored out. Sleeping for {
-
+            log.info(
+                f"Refresh errored out. Sleeping for {WAITINT_FOR_REFRESH_SEC} sec and then Retrying : {e}",
+            )
+            time.sleep(WAITINT_FOR_REFRESH_SEC)
             continue
-
-
+    log.info(f"Completed refresh for index {index_name}")
 
 
 def main():
@@ -148,9 +158,9 @@ def main():
         search(client, _INDEX_NAME)
         delete_index(client, _INDEX_NAME)
     except Exception as e:
-
+        log.info(e)
         delete_index(client, _INDEX_NAME)
 
 
-if __name__ ==
+if __name__ == "__main__":
     main()