vectordb-bench 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +49 -24
- vectordb_bench/__main__.py +4 -3
- vectordb_bench/backend/assembler.py +12 -13
- vectordb_bench/backend/cases.py +56 -46
- vectordb_bench/backend/clients/__init__.py +101 -14
- vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +26 -0
- vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +18 -0
- vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +345 -0
- vectordb_bench/backend/clients/aliyun_opensearch/config.py +47 -0
- vectordb_bench/backend/clients/alloydb/alloydb.py +58 -80
- vectordb_bench/backend/clients/alloydb/cli.py +52 -35
- vectordb_bench/backend/clients/alloydb/config.py +30 -30
- vectordb_bench/backend/clients/api.py +8 -9
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +46 -47
- vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
- vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
- vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
- vectordb_bench/backend/clients/chroma/chroma.py +38 -36
- vectordb_bench/backend/clients/chroma/config.py +4 -2
- vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +23 -22
- vectordb_bench/backend/clients/memorydb/cli.py +8 -8
- vectordb_bench/backend/clients/memorydb/config.py +2 -2
- vectordb_bench/backend/clients/memorydb/memorydb.py +65 -53
- vectordb_bench/backend/clients/milvus/cli.py +62 -80
- vectordb_bench/backend/clients/milvus/config.py +31 -7
- vectordb_bench/backend/clients/milvus/milvus.py +23 -26
- vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
- vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
- vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +55 -73
- vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
- vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
- vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +33 -34
- vectordb_bench/backend/clients/pgvector/cli.py +40 -31
- vectordb_bench/backend/clients/pgvector/config.py +63 -73
- vectordb_bench/backend/clients/pgvector/pgvector.py +97 -98
- vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
- vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
- vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +38 -43
- vectordb_bench/backend/clients/pinecone/config.py +1 -0
- vectordb_bench/backend/clients/pinecone/pinecone.py +14 -21
- vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +40 -31
- vectordb_bench/backend/clients/redis/cli.py +6 -12
- vectordb_bench/backend/clients/redis/config.py +7 -5
- vectordb_bench/backend/clients/redis/redis.py +94 -58
- vectordb_bench/backend/clients/test/cli.py +1 -2
- vectordb_bench/backend/clients/test/config.py +2 -2
- vectordb_bench/backend/clients/test/test.py +4 -5
- vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
- vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +36 -22
- vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
- vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
- vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
- vectordb_bench/backend/data_source.py +30 -18
- vectordb_bench/backend/dataset.py +47 -27
- vectordb_bench/backend/result_collector.py +2 -3
- vectordb_bench/backend/runner/__init__.py +4 -6
- vectordb_bench/backend/runner/mp_runner.py +85 -34
- vectordb_bench/backend/runner/rate_runner.py +51 -23
- vectordb_bench/backend/runner/read_write_runner.py +140 -46
- vectordb_bench/backend/runner/serial_runner.py +99 -50
- vectordb_bench/backend/runner/util.py +4 -19
- vectordb_bench/backend/task_runner.py +95 -74
- vectordb_bench/backend/utils.py +17 -9
- vectordb_bench/base.py +0 -1
- vectordb_bench/cli/cli.py +65 -60
- vectordb_bench/cli/vectordbbench.py +6 -7
- vectordb_bench/frontend/components/check_results/charts.py +8 -19
- vectordb_bench/frontend/components/check_results/data.py +4 -16
- vectordb_bench/frontend/components/check_results/filters.py +8 -16
- vectordb_bench/frontend/components/check_results/nav.py +4 -4
- vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
- vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
- vectordb_bench/frontend/components/concurrent/charts.py +12 -12
- vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
- vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
- vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
- vectordb_bench/frontend/components/custom/initStyle.py +1 -1
- vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
- vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
- vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
- vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
- vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
- vectordb_bench/frontend/components/tables/data.py +3 -6
- vectordb_bench/frontend/config/dbCaseConfigs.py +108 -83
- vectordb_bench/frontend/pages/concurrent.py +3 -5
- vectordb_bench/frontend/pages/custom.py +30 -9
- vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
- vectordb_bench/frontend/pages/run_test.py +3 -7
- vectordb_bench/frontend/utils.py +1 -1
- vectordb_bench/frontend/vdb_benchmark.py +4 -6
- vectordb_bench/interface.py +56 -26
- vectordb_bench/log_util.py +59 -64
- vectordb_bench/metric.py +10 -11
- vectordb_bench/models.py +26 -43
- {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/METADATA +34 -42
- vectordb_bench-0.0.20.dist-info/RECORD +135 -0
- {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/WHEEL +1 -1
- vectordb_bench-0.0.18.dist-info/RECORD +0 -131
- {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/top_level.txt +0 -0
```diff
--- a/vectordb_bench/backend/clients/api.py
+++ b/vectordb_bench/backend/clients/api.py
@@ -1,9 +1,8 @@
 from abc import ABC, abstractmethod
-from enum import Enum
-from typing import Any, Type
 from contextlib import contextmanager
+from enum import Enum
 
-from pydantic import BaseModel,
+from pydantic import BaseModel, SecretStr, validator
 
 
 class MetricType(str, Enum):
```
```diff
@@ -65,13 +64,10 @@ class DBConfig(ABC, BaseModel):
         raise NotImplementedError
 
     @validator("*")
-    def not_empty_field(cls, v, field):
-        if (
-            field.name in cls.common_short_configs()
-            or field.name in cls.common_long_configs()
-        ):
+    def not_empty_field(cls, v: any, field: any):
+        if field.name in cls.common_short_configs() or field.name in cls.common_long_configs():
             return v
-        if not v and isinstance(v,
+        if not v and isinstance(v, str | SecretStr):
             raise ValueError("Empty string!")
         return v
 
```
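As an aside, the catch-all validator above follows pydantic v1's `@validator` API. A minimal self-contained sketch of the pattern — the `ExampleConfig` class and its fields are hypothetical, and the `str | SecretStr` union needs Python 3.10+:

```python
# Hypothetical config class mirroring the validator pattern above (pydantic v1 API).
from pydantic import BaseModel, SecretStr, validator


class ExampleConfig(BaseModel):
    host: str = ""
    password: SecretStr = SecretStr("dummy")

    @validator("*")
    def not_empty_field(cls, v, field):
        # Reject empty strings/secrets so a bad CLI or env value fails at config time.
        if not v and isinstance(v, str | SecretStr):
            raise ValueError("Empty string!")
        return v


ExampleConfig(host="localhost")   # passes
# ExampleConfig(host="")          # raises ValueError: Empty string!
```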
```diff
@@ -204,6 +200,9 @@ class VectorDB(ABC):
         """
         raise NotImplementedError
 
+    def optimize_with_size(self, data_size: int):
+        self.optimize()
+
     # TODO: remove
     @abstractmethod
     def ready_to_load(self):
```
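The new `optimize_with_size()` hook defaults to delegating to `optimize()`, so existing clients keep working while size-aware backends can override it. A standalone sketch of that contract — both classes here are hypothetical stand-ins, not the real `VectorDB` base:

```python
# Minimal sketch of the hook's contract; classes are illustrative only.
class BaseSketch:
    def optimize(self):
        print("generic post-insert optimization")

    def optimize_with_size(self, data_size: int):
        # Default: ignore the size hint and keep the old behavior.
        self.optimize()


class SizeAwareClient(BaseSketch):
    def optimize_with_size(self, data_size: int):
        # A backend can pick cheaper or heavier index builds from the row count.
        if data_size > 10_000_000:
            print("large dataset: tune build settings first")
        self.optimize()


SizeAwareClient().optimize_with_size(50_000_000)
```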
```diff
--- a/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py
+++ b/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py
@@ -1,14 +1,18 @@
 import logging
-from contextlib import contextmanager
 import time
-from
-from
-
+from collections.abc import Iterable
+from contextlib import contextmanager
+
 from opensearchpy import OpenSearch
-
+
+from ..api import IndexType, VectorDB
+from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig
 
 log = logging.getLogger(__name__)
 
+WAITING_FOR_REFRESH_SEC = 30
+WAITING_FOR_FORCE_MERGE_SEC = 30
+
 
 class AWSOpenSearch(VectorDB):
     def __init__(
```
```diff
@@ -17,7 +21,7 @@ class AWSOpenSearch(VectorDB):
         db_config: dict,
         db_case_config: AWSOpenSearchIndexConfig,
         index_name: str = "vdb_bench_index",  # must be lowercase
-        id_col_name: str = "
+        id_col_name: str = "_id",
         vector_col_name: str = "embedding",
         drop_old: bool = False,
         **kwargs,
```
```diff
@@ -27,9 +31,7 @@
         self.case_config = db_case_config
         self.index_name = index_name
         self.id_col_name = id_col_name
-        self.category_col_names = [
-            f"scalar-{categoryCount}" for categoryCount in [2, 5, 10, 100, 1000]
-        ]
+        self.category_col_names = [f"scalar-{categoryCount}" for categoryCount in [2, 5, 10, 100, 1000]]
         self.vector_col_name = vector_col_name
 
         log.info(f"AWS_OpenSearch client config: {self.db_config}")
```
```diff
@@ -46,39 +48,32 @@ class AWSOpenSearch(VectorDB):
         return AWSOpenSearchConfig
 
     @classmethod
-    def case_config_cls(
-        cls, index_type: IndexType | None = None
-    ) -> AWSOpenSearchIndexConfig:
+    def case_config_cls(cls, index_type: IndexType | None = None) -> AWSOpenSearchIndexConfig:
         return AWSOpenSearchIndexConfig
 
     def _create_index(self, client: OpenSearch):
         settings = {
             "index": {
                 "knn": True,
-
-                # "refresh_interval": "600s",
-            }
+            },
         }
         mappings = {
             "properties": {
-
-                **{
-                    categoryCol: {"type": "keyword"}
-                    for categoryCol in self.category_col_names
-                },
+                **{categoryCol: {"type": "keyword"} for categoryCol in self.category_col_names},
                 self.vector_col_name: {
                     "type": "knn_vector",
                     "dimension": self.dim,
                     "method": self.case_config.index_param(),
                 },
-            }
+            },
         }
         try:
             client.indices.create(
-                index=self.index_name,
+                index=self.index_name,
+                body={"settings": settings, "mappings": mappings},
             )
         except Exception as e:
-            log.warning(f"Failed to create index: {self.index_name} error: {
+            log.warning(f"Failed to create index: {self.index_name} error: {e!s}")
             raise e from None
 
     @contextmanager
```
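For readers unfamiliar with the OpenSearch k-NN mapping, the body assembled by `_create_index()` looks roughly like this — dimension, engine, and parameter values are illustrative placeholders, not values from the diff:

```python
# Illustrative request body for OpenSearch k-NN index creation.
body = {
    "settings": {"index": {"knn": True}},
    "mappings": {
        "properties": {
            "scalar-2": {"type": "keyword"},  # one keyword field per category column
            "embedding": {
                "type": "knn_vector",
                "dimension": 768,  # placeholder; the client uses self.dim
                "method": {  # produced by case_config.index_param()
                    "name": "hnsw",
                    "space_type": "l2",
                    "engine": "faiss",
                    "parameters": {"ef_construction": 256, "m": 16, "ef_search": 256},
                },
            },
        },
    },
}
# client.indices.create(index="vdb_bench_index", body=body)  # needs a live cluster
```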
```diff
@@ -87,7 +82,6 @@ class AWSOpenSearch(VectorDB):
         self.client = OpenSearch(**self.db_config)
 
         yield
-        # self.client.transport.close()
         self.client = None
         del self.client
 
```
```diff
@@ -102,16 +96,20 @@
 
         insert_data = []
         for i in range(len(embeddings)):
-            insert_data.append(
+            insert_data.append(
+                {"index": {"_index": self.index_name, self.id_col_name: metadata[i]}},
+            )
             insert_data.append({self.vector_col_name: embeddings[i]})
         try:
             resp = self.client.bulk(insert_data)
             log.info(f"AWS_OpenSearch adding documents: {len(resp['items'])}")
             resp = self.client.indices.stats(self.index_name)
-            log.info(
+            log.info(
+                f"Total document count in index: {resp['_all']['primaries']['indexing']['index_total']}",
+            )
             return (len(embeddings), None)
         except Exception as e:
-            log.warning(f"Failed to insert data: {self.index_name} error: {
+            log.warning(f"Failed to insert data: {self.index_name} error: {e!s}")
             time.sleep(10)
             return self.insert_embeddings(embeddings, metadata)
 
```
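The bulk payload assembled above alternates action and source lines, and with `id_col_name="_id"` the document id rides in the action metadata. A toy payload in that shape:

```python
# Toy bulk payload in the shape insert_embeddings() builds.
insert_data = [
    {"index": {"_index": "vdb_bench_index", "_id": 0}},
    {"embedding": [0.12, 0.34, 0.56]},
    {"index": {"_index": "vdb_bench_index", "_id": 1}},
    {"embedding": [0.78, 0.90, 0.11]},
]
# resp = client.bulk(insert_data)  # requires a live OpenSearch client
```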
```diff
@@ -136,20 +134,23 @@ class AWSOpenSearch(VectorDB):
         body = {
             "size": k,
             "query": {"knn": {self.vector_col_name: {"vector": query, "k": k}}},
-            **({"filter": {"range": {self.id_col_name: {"gt": filters["id"]}}}} if filters else {})
+            **({"filter": {"range": {self.id_col_name: {"gt": filters["id"]}}}} if filters else {}),
         }
         try:
-            resp = self.client.search(
+            resp = self.client.search(
+                index=self.index_name,
+                body=body,
+                size=k,
+                _source=False,
+                docvalue_fields=[self.id_col_name],
+                stored_fields="_none_",
+            )
             log.info(f'Search took: {resp["took"]}')
             log.info(f'Search shards: {resp["_shards"]}')
             log.info(f'Search hits total: {resp["hits"]["total"]}')
-
-            #result = [int(d["_id"]) for d in resp["hits"]["hits"]]
-            # log.info(f'success! length={len(res)}')
-
-            return result
+            return [int(h["fields"][self.id_col_name][0]) for h in resp["hits"]["hits"]]
         except Exception as e:
-            log.warning(f"Failed to search: {self.index_name} error: {
+            log.warning(f"Failed to search: {self.index_name} error: {e!s}")
             raise e from None
 
     def optimize(self):
```
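Note the changed result handling: with `_source=False` and `docvalue_fields`, ids come back under each hit's `fields` key instead of `_source`. A sketch of the query shape and id extraction — dimension and names are placeholders, and `client` is assumed to be a live `OpenSearch` handle:

```python
# Sketch of the k-NN query and docvalue-based id extraction.
k = 10
query_vector = [0.1] * 768  # placeholder dimension
body = {
    "size": k,
    "query": {"knn": {"embedding": {"vector": query_vector, "k": k}}},
}
# resp = client.search(index="vdb_bench_index", body=body, size=k,
#                      _source=False, docvalue_fields=["_id"], stored_fields="_none_")
# ids = [int(h["fields"]["_id"][0]) for h in resp["hits"]["hits"]]
```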
```diff
@@ -164,37 +165,35 @@ class AWSOpenSearch(VectorDB):
 
     def _refresh_index(self):
         log.debug(f"Starting refresh for index {self.index_name}")
-        SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC = 30
         while True:
             try:
-                log.info(
+                log.info("Starting the Refresh Index..")
                 self.client.indices.refresh(index=self.index_name)
                 break
             except Exception as e:
                 log.info(
-                    f"Refresh errored out. Sleeping for {
-
+                    f"Refresh errored out. Sleeping for {WAITING_FOR_REFRESH_SEC} sec and then Retrying : {e}",
+                )
+                time.sleep(WAITING_FOR_REFRESH_SEC)
                 continue
         log.debug(f"Completed refresh for index {self.index_name}")
 
     def _do_force_merge(self):
         log.debug(f"Starting force merge for index {self.index_name}")
-        force_merge_endpoint = f
-        force_merge_task_id = self.client.transport.perform_request(
-        SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
+        force_merge_endpoint = f"/{self.index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false"
+        force_merge_task_id = self.client.transport.perform_request("POST", force_merge_endpoint)["task"]
         while True:
-            time.sleep(
+            time.sleep(WAITING_FOR_FORCE_MERGE_SEC)
             task_status = self.client.tasks.get(task_id=force_merge_task_id)
-            if task_status[
+            if task_status["completed"]:
                 break
         log.debug(f"Completed force merge for index {self.index_name}")
 
     def _load_graphs_to_memory(self):
         if self.case_config.engine != AWSOS_Engine.lucene:
             log.info("Calling warmup API to load graphs into memory")
-            warmup_endpoint = f
-            self.client.transport.perform_request(
+            warmup_endpoint = f"/_plugins/_knn/warmup/{self.index_name}"
+            self.client.transport.perform_request("GET", warmup_endpoint)
 
     def ready_to_load(self):
         """ready_to_load will be called before load in load cases."""
-        pass
```
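For reference, the raw REST round-trips behind `_do_force_merge()` and `_load_graphs_to_memory()`, sketched with a placeholder index name (the calls themselves are commented out since they need a live cluster):

```python
# Sketch of the REST calls, assuming `client` is an opensearchpy.OpenSearch.
index = "vdb_bench_index"  # placeholder

# 1. Async force-merge down to one segment; returns a task id to poll.
endpoint = f"/{index}/_forcemerge?max_num_segments=1&wait_for_completion=false"
# task_id = client.transport.perform_request("POST", endpoint)["task"]
# client.tasks.get(task_id=task_id)["completed"]  # poll until True

# 2. k-NN plugin warmup, loading graphs into memory (skipped for the Lucene engine).
# client.transport.perform_request("GET", f"/_plugins/_knn/warmup/{index}")
```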
```diff
--- a/vectordb_bench/backend/clients/aws_opensearch/cli.py
+++ b/vectordb_bench/backend/clients/aws_opensearch/cli.py
@@ -14,22 +14,20 @@ from .. import DB
 
 
 class AWSOpenSearchTypedDict(TypedDict):
-    host: Annotated[
-        str, click.option("--host", type=str, help="Db host", required=True)
-    ]
+    host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
     port: Annotated[int, click.option("--port", type=int, default=443, help="Db Port")]
     user: Annotated[str, click.option("--user", type=str, default="admin", help="Db User")]
     password: Annotated[str, click.option("--password", type=str, help="Db password")]
 
 
-class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2):
-    ...
+class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2): ...
 
 
 @cli.command()
 @click_parameter_decorators_from_typed_dict(AWSOpenSearchHNSWTypedDict)
 def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]):
     from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig
+
     run(
         db=DB.AWSOpenSearch,
         db_config=AWSOpenSearchConfig(
```
```diff
@@ -38,7 +36,6 @@ def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]):
             user=parameters["user"],
             password=SecretStr(parameters["password"]),
         ),
-        db_case_config=AWSOpenSearchIndexConfig(
-        ),
+        db_case_config=AWSOpenSearchIndexConfig(),
         **parameters,
     )
```
```diff
--- a/vectordb_bench/backend/clients/aws_opensearch/config.py
+++ b/vectordb_bench/backend/clients/aws_opensearch/config.py
@@ -1,10 +1,13 @@
 import logging
 from enum import Enum
-from pydantic import SecretStr, BaseModel
 
-from
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, MetricType
 
 log = logging.getLogger(__name__)
+
+
 class AWSOpenSearchConfig(DBConfig, BaseModel):
     host: str = ""
     port: int = 443
```
```diff
@@ -13,7 +16,7 @@ class AWSOpenSearchConfig(DBConfig, BaseModel):
 
     def to_dict(self) -> dict:
         return {
-            "hosts": [{
+            "hosts": [{"host": self.host, "port": self.port}],
             "http_auth": (self.user, self.password.get_secret_value()),
             "use_ssl": True,
             "http_compress": True,
```
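The dict returned by `to_dict()` feeds straight into the `opensearchpy` constructor; a sketch with placeholder host and credentials:

```python
from opensearchpy import OpenSearch

# Placeholder values; the real ones come from AWSOpenSearchConfig.to_dict().
conf = {
    "hosts": [{"host": "example.us-west-2.es.amazonaws.com", "port": 443}],
    "http_auth": ("admin", "secret"),
    "use_ssl": True,
    "http_compress": True,
}
# client = OpenSearch(**conf)
```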
```diff
@@ -40,25 +43,26 @@ class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
     def parse_metric(self) -> str:
         if self.metric_type == MetricType.IP:
             return "innerproduct"
-
+        if self.metric_type == MetricType.COSINE:
             if self.engine == AWSOS_Engine.faiss:
-                log.info(
+                log.info(
+                    "Using innerproduct because faiss doesn't support cosine as metric type for Opensearch",
+                )
                 return "innerproduct"
             return "cosinesimil"
         return "l2"
 
     def index_param(self) -> dict:
-
+        return {
             "name": "hnsw",
             "space_type": self.parse_metric(),
             "engine": self.engine.value,
             "parameters": {
                 "ef_construction": self.efConstruction,
                 "m": self.M,
-                "ef_search": self.efSearch
-            }
+                "ef_search": self.efSearch,
+            },
         }
-        return params
 
     def search_param(self) -> dict:
         return {}
```
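Concretely, `index_param()` now returns the HNSW method block that `_create_index()` embeds into the mapping. For the faiss engine with a COSINE metric (mapped to `innerproduct` above) it would look like this, with illustrative parameter values:

```python
# Example return value of index_param(); parameter values are illustrative.
index_method = {
    "name": "hnsw",
    "space_type": "innerproduct",
    "engine": "faiss",
    "parameters": {"ef_construction": 256, "m": 16, "ef_search": 256},
}
```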
```diff
--- a/vectordb_bench/backend/clients/aws_opensearch/run.py
+++ b/vectordb_bench/backend/clients/aws_opensearch/run.py
@@ -1,12 +1,16 @@
-import
+import logging
+import random
+import time
+
 from opensearchpy import OpenSearch
-from opensearch_dsl import Search, Document, Text, Keyword
 
-
+log = logging.getLogger(__name__)
+
+_HOST = "xxxxxx.us-west-2.es.amazonaws.com"
 _PORT = 443
-_AUTH = (
+_AUTH = ("admin", "xxxxxx")  # For testing only. Don't store credentials in code.
 
-_INDEX_NAME =
+_INDEX_NAME = "my-dsl-index"
 _BATCH = 100
 _ROWS = 100
 _DIM = 128
```
```diff
@@ -14,25 +18,24 @@ _TOPK = 10
 
 
 def create_client():
-
-    hosts=[{
-    http_compress=True,
+    return OpenSearch(
+        hosts=[{"host": _HOST, "port": _PORT}],
+        http_compress=True,  # enables gzip compression for request bodies
         http_auth=_AUTH,
         use_ssl=True,
         verify_certs=True,
         ssl_assert_hostname=False,
         ssl_show_warn=False,
     )
-    return client
 
 
-def create_index(client, index_name):
+def create_index(client: OpenSearch, index_name: str):
     settings = {
         "index": {
             "knn": True,
             "number_of_shards": 1,
             "refresh_interval": "5s",
-        }
+        },
     }
     mappings = {
         "properties": {
```
```diff
@@ -46,41 +49,46 @@ def create_index(client, index_name):
                     "parameters": {
                         "ef_construction": 256,
                         "m": 16,
-                    }
-                }
-            }
-        }
+                    },
+                },
+            },
+        },
     }
 
-    response = client.indices.create(
-
-
+    response = client.indices.create(
+        index=index_name,
+        body={"settings": settings, "mappings": mappings},
+    )
+    log.info("\nCreating index:")
+    log.info(response)
 
 
-def delete_index(client, index_name):
+def delete_index(client: OpenSearch, index_name: str):
     response = client.indices.delete(index=index_name)
-
-
+    log.info("\nDeleting index:")
+    log.info(response)
 
 
-def bulk_insert(client, index_name):
+def bulk_insert(client: OpenSearch, index_name: str):
     # Perform bulk operations
-    ids =
+    ids = list(range(_ROWS))
     vec = [[random.random() for _ in range(_DIM)] for _ in range(_ROWS)]
 
     docs = []
     for i in range(0, _ROWS, _BATCH):
         docs.clear()
-        for j in range(
-            docs.append({"index": {"_index": index_name, "_id": ids[i+j]}})
-            docs.append({"embedding": vec[i+j]})
+        for j in range(_BATCH):
+            docs.append({"index": {"_index": index_name, "_id": ids[i + j]}})
+            docs.append({"embedding": vec[i + j]})
         response = client.bulk(docs)
-
+        log.info(f"Adding documents: {len(response['items'])}, {response['errors']}")
         response = client.indices.stats(index_name)
-
+        log.info(
+            f'Total document count in index: { response["_all"]["primaries"]["indexing"]["index_total"] }',
+        )
 
 
-def search(client, index_name):
+def search(client: OpenSearch, index_name: str):
     # Search for the document.
     search_body = {
         "size": _TOPK,
```
```diff
@@ -89,53 +97,55 @@ def search(client, index_name):
                 "embedding": {
                     "vector": [random.random() for _ in range(_DIM)],
                     "k": _TOPK,
-                }
-            }
-        }
+                },
+            },
+        },
     }
     while True:
         response = client.search(index=index_name, body=search_body)
-
-
-
+        log.info(f'\nSearch took: {response["took"]}')
+        log.info(f'\nSearch shards: {response["_shards"]}')
+        log.info(f'\nSearch hits total: {response["hits"]["total"]}')
         result = response["hits"]["hits"]
         if len(result) != 0:
-
+            log.info("\nSearch results:")
             for hit in response["hits"]["hits"]:
-
+                log.info(hit["_id"], hit["_score"])
             break
-
-
-
-
-
-
-
-
-
+        log.info("\nSearch not ready, sleep 1s")
+        time.sleep(1)
+
+
+SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
+WAITINT_FOR_REFRESH_SEC = 30
+
+
+def optimize_index(client: OpenSearch, index_name: str):
+    log.info(f"Starting force merge for index {index_name}")
+    force_merge_endpoint = f"/{index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false"
+    force_merge_task_id = client.transport.perform_request("POST", force_merge_endpoint)["task"]
     while True:
         time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
         task_status = client.tasks.get(task_id=force_merge_task_id)
-        if task_status[
+        if task_status["completed"]:
             break
-
+    log.info(f"Completed force merge for index {index_name}")
 
 
-def refresh_index(client, index_name):
-
-    SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC = 30
+def refresh_index(client: OpenSearch, index_name: str):
+    log.info(f"Starting refresh for index {index_name}")
     while True:
         try:
-
+            log.info("Starting the Refresh Index..")
             client.indices.refresh(index=index_name)
             break
         except Exception as e:
-
-            f"Refresh errored out. Sleeping for {
-
+            log.info(
+                f"Refresh errored out. Sleeping for {WAITINT_FOR_REFRESH_SEC} sec and then Retrying : {e}",
+            )
+            time.sleep(WAITINT_FOR_REFRESH_SEC)
             continue
-
-
+    log.info(f"Completed refresh for index {index_name}")
 
 
 def main():
```
```diff
@@ -148,9 +158,9 @@ def main():
         search(client, _INDEX_NAME)
         delete_index(client, _INDEX_NAME)
     except Exception as e:
-
+        log.info(e)
         delete_index(client, _INDEX_NAME)
 
 
-if __name__ ==
+if __name__ == "__main__":
     main()
```
```diff
--- a/vectordb_bench/backend/clients/chroma/chroma.py
+++ b/vectordb_bench/backend/clients/chroma/chroma.py
@@ -1,55 +1,55 @@
-import
-import logging
+import logging
 from contextlib import contextmanager
 from typing import Any
-
+
+import chromadb
+
+from ..api import DBCaseConfig, VectorDB
 
 log = logging.getLogger(__name__)
+
+
 class ChromaClient(VectorDB):
-    """Chroma client for VectorDB.
+    """Chroma client for VectorDB.
     To set up Chroma in docker, see https://docs.trychroma.com/usage-guide
     or the instructions in tests/test_chroma.py
 
     To change to running in process, modify the HttpClient() in __init__() and init().
-    """
+    """
 
     def __init__(
-
-
-
-
-
-
-
-        ):
-
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: DBCaseConfig,
+        drop_old: bool = False,
+        **kwargs,
+    ):
         self.db_config = db_config
         self.case_config = db_case_config
-        self.collection_name =
+        self.collection_name = "example2"
 
-        client = chromadb.HttpClient(host=self.db_config["host"],
-                                     port=self.db_config["port"])
+        client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"])
         assert client.heartbeat() is not None
         if drop_old:
             try:
-                client.reset()
-            except:
+                client.reset()  # Reset the database
+            except Exception:
                 drop_old = False
         log.info(f"Chroma client drop_old collection: {self.collection_name}")
 
     @contextmanager
     def init(self) -> None:
-        """
+        """create and destory connections to database.
 
         Examples:
             >>> with self.init():
             >>> self.insert_embeddings()
         """
-        #create connection
-        self.client = chromadb.HttpClient(host=self.db_config["host"],
-
-
-        self.collection = self.client.get_or_create_collection('example2')
+        # create connection
+        self.client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"])
+
+        self.collection = self.client.get_or_create_collection("example2")
         yield
         self.client = None
         self.collection = None
```
```diff
@@ -79,12 +79,12 @@ class ChromaClient(VectorDB):
         Returns:
             (int, Exception): number of embeddings inserted and exception if any
         """
-        ids=[str(i) for i in metadata]
-        metadata = [{"id": int(i)} for i in metadata]
+        ids = [str(i) for i in metadata]
+        metadata = [{"id": int(i)} for i in metadata]
         if len(embeddings) > 0:
             self.collection.add(embeddings=embeddings, ids=ids, metadatas=metadata)
         return len(embeddings), None
-
+
     def search_embedding(
         self,
         query: list[float],
```
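Chroma expects string ids plus one metadata dict per document, which is why the insert path stringifies the integer ids. Toy values showing the shapes involved:

```python
# Toy illustration of the id/metadata shapes passed to collection.add().
metadata_ids = [0, 1]
ids = [str(i) for i in metadata_ids]                 # ["0", "1"]
metadatas = [{"id": int(i)} for i in metadata_ids]   # [{"id": 0}, {"id": 1}]
embeddings = [[0.1, 0.2], [0.3, 0.4]]
# self.collection.add(embeddings=embeddings, ids=ids, metadatas=metadatas)
```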
```diff
@@ -100,17 +100,19 @@ class ChromaClient(VectorDB):
             kwargs: other arguments
 
         Returns:
-            Dict {ids: list[list[int]],
-                  embedding: list[list[float]]
+            Dict {ids: list[list[int]],
+                  embedding: list[list[float]]
                   distance: list[list[float]]}
         """
         if filters:
             # assumes benchmark test filters of format: {'metadata': '>=10000', 'id': 10000}
             id_value = filters.get("id")
-            results = self.collection.query(
-
-
-
+            results = self.collection.query(
+                query_embeddings=query,
+                n_results=k,
+                where={"id": {"$gt": id_value}},
+            )
+            # return list of id's in results
+            return [int(i) for i in results.get("ids")[0]]
         results = self.collection.query(query_embeddings=query, n_results=k)
-        return [int(i) for i in results.get(
-
+        return [int(i) for i in results.get("ids")[0]]
```