vectordb-bench 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +49 -24
- vectordb_bench/__main__.py +4 -3
- vectordb_bench/backend/assembler.py +12 -13
- vectordb_bench/backend/cases.py +55 -45
- vectordb_bench/backend/clients/__init__.py +85 -14
- vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +1 -2
- vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +3 -4
- vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +112 -77
- vectordb_bench/backend/clients/aliyun_opensearch/config.py +6 -7
- vectordb_bench/backend/clients/alloydb/alloydb.py +59 -84
- vectordb_bench/backend/clients/alloydb/cli.py +51 -34
- vectordb_bench/backend/clients/alloydb/config.py +30 -30
- vectordb_bench/backend/clients/api.py +13 -24
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +50 -54
- vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
- vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
- vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
- vectordb_bench/backend/clients/chroma/chroma.py +39 -40
- vectordb_bench/backend/clients/chroma/config.py +4 -2
- vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +24 -26
- vectordb_bench/backend/clients/memorydb/cli.py +8 -8
- vectordb_bench/backend/clients/memorydb/config.py +2 -2
- vectordb_bench/backend/clients/memorydb/memorydb.py +67 -58
- vectordb_bench/backend/clients/milvus/cli.py +41 -83
- vectordb_bench/backend/clients/milvus/config.py +18 -8
- vectordb_bench/backend/clients/milvus/milvus.py +19 -39
- vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
- vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
- vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +56 -77
- vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
- vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
- vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +34 -43
- vectordb_bench/backend/clients/pgvector/cli.py +40 -31
- vectordb_bench/backend/clients/pgvector/config.py +63 -73
- vectordb_bench/backend/clients/pgvector/pgvector.py +98 -104
- vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
- vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
- vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +39 -49
- vectordb_bench/backend/clients/pinecone/config.py +1 -0
- vectordb_bench/backend/clients/pinecone/pinecone.py +15 -25
- vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +41 -35
- vectordb_bench/backend/clients/redis/cli.py +6 -12
- vectordb_bench/backend/clients/redis/config.py +7 -5
- vectordb_bench/backend/clients/redis/redis.py +95 -62
- vectordb_bench/backend/clients/test/cli.py +2 -3
- vectordb_bench/backend/clients/test/config.py +2 -2
- vectordb_bench/backend/clients/test/test.py +5 -9
- vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
- vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +37 -26
- vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
- vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
- vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
- vectordb_bench/backend/data_source.py +18 -14
- vectordb_bench/backend/dataset.py +47 -27
- vectordb_bench/backend/result_collector.py +2 -3
- vectordb_bench/backend/runner/__init__.py +4 -6
- vectordb_bench/backend/runner/mp_runner.py +56 -23
- vectordb_bench/backend/runner/rate_runner.py +30 -19
- vectordb_bench/backend/runner/read_write_runner.py +46 -22
- vectordb_bench/backend/runner/serial_runner.py +81 -46
- vectordb_bench/backend/runner/util.py +4 -3
- vectordb_bench/backend/task_runner.py +92 -92
- vectordb_bench/backend/utils.py +17 -10
- vectordb_bench/base.py +0 -1
- vectordb_bench/cli/cli.py +65 -60
- vectordb_bench/cli/vectordbbench.py +6 -7
- vectordb_bench/frontend/components/check_results/charts.py +8 -19
- vectordb_bench/frontend/components/check_results/data.py +4 -16
- vectordb_bench/frontend/components/check_results/filters.py +8 -16
- vectordb_bench/frontend/components/check_results/nav.py +4 -4
- vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
- vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
- vectordb_bench/frontend/components/concurrent/charts.py +12 -12
- vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
- vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
- vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
- vectordb_bench/frontend/components/custom/initStyle.py +1 -1
- vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
- vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
- vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
- vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
- vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
- vectordb_bench/frontend/components/tables/data.py +3 -6
- vectordb_bench/frontend/config/dbCaseConfigs.py +51 -84
- vectordb_bench/frontend/pages/concurrent.py +3 -5
- vectordb_bench/frontend/pages/custom.py +30 -9
- vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
- vectordb_bench/frontend/pages/run_test.py +3 -7
- vectordb_bench/frontend/utils.py +1 -1
- vectordb_bench/frontend/vdb_benchmark.py +4 -6
- vectordb_bench/interface.py +45 -24
- vectordb_bench/log_util.py +59 -64
- vectordb_bench/metric.py +10 -11
- vectordb_bench/models.py +26 -43
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/METADATA +22 -15
- vectordb_bench-0.0.21.dist-info/RECORD +135 -0
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/WHEEL +1 -1
- vectordb_bench-0.0.19.dist-info/RECORD +0 -135
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/top_level.txt +0 -0
**vectordb_bench/backend/clients/chroma/chroma.py**

```diff
@@ -1,55 +1,55 @@
-import chromadb
-import logging
+import logging
 from contextlib import contextmanager
 from typing import Any
-from ..api import DBCaseConfig, VectorDB
+
+import chromadb
+
+from ..api import DBCaseConfig, VectorDB
 
 log = logging.getLogger(__name__)
+
+
 class ChromaClient(VectorDB):
-    """Chroma client for VectorDB.
+    """Chroma client for VectorDB.
     To set up Chroma in docker, see https://docs.trychroma.com/usage-guide
     or the instructions in tests/test_chroma.py
 
     To change to running in process, modify the HttpClient() in __init__() and init().
-    """
+    """
 
     def __init__(
-        self,
-        dim: int,
-        db_config: dict,
-        db_case_config: DBCaseConfig,
-        drop_old: bool = False,
-        **kwargs
-
-    ):
-
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: DBCaseConfig,
+        drop_old: bool = False,
+        **kwargs,
+    ):
         self.db_config = db_config
         self.case_config = db_case_config
-        self.collection_name = 'example2'
+        self.collection_name = "example2"
 
-        client = chromadb.HttpClient(host=self.db_config["host"],
-                                     port=self.db_config["port"])
+        client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"])
         assert client.heartbeat() is not None
         if drop_old:
             try:
-                client.reset()
-            except:
+                client.reset()  # Reset the database
+            except Exception:
                 drop_old = False
             log.info(f"Chroma client drop_old collection: {self.collection_name}")
 
     @contextmanager
     def init(self) -> None:
-        """
+        """create and destory connections to database.
 
         Examples:
             >>> with self.init():
             >>>     self.insert_embeddings()
         """
-        #create connection
-        self.client = chromadb.HttpClient(host=self.db_config["host"],
-                                          port=self.db_config["port"])
-
-        self.collection = self.client.get_or_create_collection('example2')
+        # create connection
+        self.client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"])
+
+        self.collection = self.client.get_or_create_collection("example2")
         yield
         self.client = None
         self.collection = None
@@ -57,10 +57,7 @@ class ChromaClient(VectorDB):
     def ready_to_search(self) -> bool:
         pass
 
-    def ready_to_load(self) -> bool:
-        pass
-
-    def optimize(self) -> None:
+    def optimize(self, data_size: int | None = None):
         pass
 
     def insert_embeddings(
@@ -79,12 +76,12 @@ class ChromaClient(VectorDB):
         Returns:
             (int, Exception): number of embeddings inserted and exception if any
         """
-        ids=[str(i) for i in metadata]
-        metadata = [{"id": int(i)} for i in metadata]
+        ids = [str(i) for i in metadata]
+        metadata = [{"id": int(i)} for i in metadata]
         if len(embeddings) > 0:
             self.collection.add(embeddings=embeddings, ids=ids, metadatas=metadata)
         return len(embeddings), None
-
+
     def search_embedding(
         self,
         query: list[float],
@@ -100,17 +97,19 @@ class ChromaClient(VectorDB):
             kwargs: other arguments
 
         Returns:
-            Dict {ids: list[list[int]],
-                  embedding: list[list[float]]
+            Dict {ids: list[list[int]],
+                embedding: list[list[float]]
                 distance: list[list[float]]}
         """
         if filters:
             # assumes benchmark test filters of format: {'metadata': '>=10000', 'id': 10000}
             id_value = filters.get("id")
-            results = self.collection.query(
-                query_embeddings=query,
-                n_results=k,
-                where={"id": {"$gt": id_value}})
+            results = self.collection.query(
+                query_embeddings=query,
+                n_results=k,
+                where={"id": {"$gt": id_value}},
+            )
+            # return list of id's in results
+            return [int(i) for i in results.get("ids")[0]]
         results = self.collection.query(query_embeddings=query, n_results=k)
-        return [int(i) for i in results.get(
-            "ids")[0]]
+        return [int(i) for i in results.get("ids")[0]]
```
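The behavioral change here is that the filtered path now returns early instead of falling through to a second, unfiltered query. A minimal standalone sketch of the same chromadb calls, assuming a local Chroma server (host, port, and the vectors below are placeholder values, not taken from the diff):

```python
import chromadb

# Placeholder endpoint; mirrors the wiring in ChromaClient above.
client = chromadb.HttpClient(host="localhost", port=8000)
collection = client.get_or_create_collection("example2")

collection.add(
    embeddings=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
    ids=["1", "2"],
    metadatas=[{"id": 1}, {"id": 2}],
)

# Filtered KNN, matching the early-return branch of search_embedding()
results = collection.query(
    query_embeddings=[[0.1, 0.2, 0.3]],
    n_results=2,
    where={"id": {"$gt": 1}},
)
print([int(i) for i in results["ids"][0]])
```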
**vectordb_bench/backend/clients/chroma/config.py**

```diff
@@ -1,14 +1,16 @@
 from pydantic import SecretStr
+
 from ..api import DBConfig
 
+
 class ChromaConfig(DBConfig):
     password: SecretStr
     host: SecretStr
-    port: int
+    port: int
 
     def to_dict(self) -> dict:
         return {
             "host": self.host.get_secret_value(),
             "port": self.port,
             "password": self.password.get_secret_value(),
-        }
+        }
```
**vectordb_bench/backend/clients/elastic_cloud/config.py**

```diff
@@ -1,7 +1,8 @@
 from enum import Enum
-from pydantic import SecretStr, BaseModel
 
-from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
 
 
 class ElasticCloudConfig(DBConfig, BaseModel):
@@ -32,12 +33,12 @@ class ElasticCloudIndexConfig(BaseModel, DBCaseConfig):
     def parse_metric(self) -> str:
         if self.metric_type == MetricType.L2:
             return "l2_norm"
-        elif self.metric_type == MetricType.IP:
+        if self.metric_type == MetricType.IP:
             return "dot_product"
         return "cosine"
 
     def index_param(self) -> dict:
-        params = {
+        return {
             "type": "dense_vector",
             "index": True,
             "element_type": self.element_type.value,
@@ -48,7 +49,6 @@ class ElasticCloudIndexConfig(BaseModel, DBCaseConfig):
                 "ef_construction": self.efConstruction,
             },
         }
-        return params
 
     def search_param(self) -> dict:
         return {
```
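The dict that `index_param()` now returns directly is an Elasticsearch `dense_vector` mapping fragment. A standalone sketch of the same shape, with illustrative `m`/`ef_construction` defaults that are not taken from this diff:

```python
# Simplified stand-in for ElasticCloudIndexConfig.index_param()
def index_param(element_type: str = "float", m: int = 16, ef_construction: int = 360) -> dict:
    # Matches the Elasticsearch dense_vector mapping schema
    return {
        "type": "dense_vector",
        "index": True,
        "element_type": element_type,
        "index_options": {
            "type": "hnsw",
            "m": m,
            "ef_construction": ef_construction,
        },
    }
```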
**vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py**

```diff
@@ -1,17 +1,22 @@
 import logging
 import time
+from collections.abc import Iterable
 from contextlib import contextmanager
-
-from ..api import VectorDB
-from .config import ElasticCloudIndexConfig
+
 from elasticsearch.helpers import bulk
 
+from ..api import VectorDB
+from .config import ElasticCloudIndexConfig
 
 for logger in ("elasticsearch", "elastic_transport"):
     logging.getLogger(logger).setLevel(logging.WARNING)
 
 log = logging.getLogger(__name__)
 
+
+SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
+
+
 class ElasticCloud(VectorDB):
     def __init__(
         self,
@@ -46,14 +51,14 @@ class ElasticCloud(VectorDB):
     def init(self) -> None:
         """connect to elasticsearch"""
         from elasticsearch import Elasticsearch
+
         self.client = Elasticsearch(**self.db_config, request_timeout=180)
 
         yield
-        # self.client.transport.close()
         self.client = None
-        del self.client
+        del self.client
 
-    def _create_indice(self, client) -> None:
+    def _create_indice(self, client: any) -> None:
         mappings = {
             "_source": {"excludes": [self.vector_col_name]},
             "properties": {
@@ -62,13 +67,13 @@ class ElasticCloud(VectorDB):
                     "dims": self.dim,
                     **self.case_config.index_param(),
                 },
-            }
+            },
         }
 
         try:
             client.indices.create(index=self.indice, mappings=mappings)
         except Exception as e:
-            log.warning(f"Failed to create indice: {self.indice} error: {str(e)}")
+            log.warning(f"Failed to create indice: {self.indice} error: {e!s}")
             raise e from None
 
     def insert_embeddings(
@@ -94,7 +99,7 @@ class ElasticCloud(VectorDB):
             bulk_insert_res = bulk(self.client, insert_data)
             return (bulk_insert_res[0], None)
         except Exception as e:
-            log.warning(f"Failed to insert data: {self.indice} error: {str(e)}")
+            log.warning(f"Failed to insert data: {self.indice} error: {e!s}")
             return (0, e)
 
     def search_embedding(
@@ -114,16 +119,12 @@ class ElasticCloud(VectorDB):
             list[tuple[int, float]]: list of k most similar embeddings in (id, score) tuple to the query embedding.
         """
         assert self.client is not None, "should self.init() first"
-        # is_existed_res = self.client.indices.exists(index=self.indice)
-        # assert is_existed_res.raw == True, "should self.init() first"
 
         knn = {
             "field": self.vector_col_name,
             "k": k,
             "num_candidates": self.case_config.num_candidates,
-            "filter": [{"range": {self.id_col_name: {"gt": filters["id"]}}}]
-            if filters
-            else [],
+            "filter": [{"range": {self.id_col_name: {"gt": filters["id"]}}}] if filters else [],
             "query_vector": query,
         }
         size = k
@@ -137,26 +138,23 @@ class ElasticCloud(VectorDB):
                 stored_fields="_none_",
                 filter_path=[f"hits.hits.fields.{self.id_col_name}"],
             )
-            res = [h["fields"][self.id_col_name][0] for h in res["hits"]["hits"]]
-
-            return res
+            return [h["fields"][self.id_col_name][0] for h in res["hits"]["hits"]]
         except Exception as e:
-            log.warning(f"Failed to search: {self.indice} error: {str(e)}")
+            log.warning(f"Failed to search: {self.indice} error: {e!s}")
             raise e from None
 
-    def optimize(self):
+    def optimize(self, data_size: int | None = None):
         """optimize will be called between insertion and search in performance cases."""
         assert self.client is not None, "should self.init() first"
         self.client.indices.refresh(index=self.indice)
-        force_merge_task_id = self.client.indices.forcemerge(index=self.indice, max_num_segments=1, wait_for_completion=False)["task"]
+        force_merge_task_id = self.client.indices.forcemerge(
+            index=self.indice,
+            max_num_segments=1,
+            wait_for_completion=False,
+        )["task"]
         log.info(f"Elasticsearch force merge task id: {force_merge_task_id}")
-        SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
         while True:
             time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
             task_status = self.client.tasks.get(task_id=force_merge_task_id)
-            if task_status["completed"]:
+            if task_status["completed"]:
                 return
-
-    def ready_to_load(self):
-        """ready_to_load will be called before load in load cases."""
-        pass
```
**vectordb_bench/backend/clients/memorydb/cli.py**

```diff
@@ -14,9 +14,7 @@ from .. import DB
 
 
 class MemoryDBTypedDict(TypedDict):
-    host: Annotated[
-        str, click.option("--host", type=str, help="Db host", required=True)
-    ]
+    host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
     password: Annotated[str, click.option("--password", type=str, help="Db password")]
     port: Annotated[int, click.option("--port", type=int, default=6379, help="Db Port")]
     ssl: Annotated[
@@ -44,7 +42,10 @@ class MemoryDBTypedDict(TypedDict):
             is_flag=True,
             show_default=True,
             default=False,
-            help="Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance. In production, MemoryDB only supports cluster mode (CME)",
+            help=(
+                "Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance."
+                " In production, MemoryDB only supports cluster mode (CME)"
+            ),
         ),
     ]
     insert_batch_size: Annotated[
@@ -58,8 +59,7 @@ class MemoryDBTypedDict(TypedDict):
     ]
 
 
-class MemoryDBHNSWTypedDict(CommonTypedDict, MemoryDBTypedDict, HNSWFlavor2):
-    ...
+class MemoryDBHNSWTypedDict(CommonTypedDict, MemoryDBTypedDict, HNSWFlavor2): ...
 
 
 @cli.command()
@@ -82,7 +82,7 @@ def MemoryDB(**parameters: Unpack[MemoryDBHNSWTypedDict]):
             M=parameters["m"],
             ef_construction=parameters["ef_construction"],
             ef_runtime=parameters["ef_runtime"],
-            insert_batch_size=parameters["insert_batch_size"]
+            insert_batch_size=parameters["insert_batch_size"],
         ),
         **parameters,
-    )
+    )
```
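The CLI declares each option once as `Annotated[...]` metadata on a `TypedDict` and applies it to the command later. A minimal sketch of that pattern; the `options_from` helper here is hypothetical, not the project's actual decorator, and `MemoryDBOptions` is a two-field stand-in for `MemoryDBTypedDict`:

```python
import click
from typing import Annotated, TypedDict, get_type_hints


class MemoryDBOptions(TypedDict):
    # Hypothetical subset of MemoryDBTypedDict
    host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
    port: Annotated[int, click.option("--port", type=int, default=6379, help="Db Port")]


def options_from(typed_dict: type):
    """Apply every click.option stashed in the TypedDict's Annotated metadata."""
    def wrap(func):
        for hint in get_type_hints(typed_dict, include_extras=True).values():
            for meta in getattr(hint, "__metadata__", ()):
                func = meta(func)  # click.option(...) returns a decorator
        return func
    return wrap


@click.command()
@options_from(MemoryDBOptions)
def memorydb(**params):
    click.echo(f"connecting to {params['host']}:{params['port']}")


if __name__ == "__main__":
    memorydb()
```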
**vectordb_bench/backend/clients/memorydb/config.py**

```diff
@@ -29,7 +29,7 @@ class MemoryDBIndexConfig(BaseModel, DBCaseConfig):
     def parse_metric(self) -> str:
         if self.metric_type == MetricType.L2:
             return "l2"
-        elif self.metric_type == MetricType.IP:
+        if self.metric_type == MetricType.IP:
             return "ip"
         return "cosine"
 
@@ -51,4 +51,4 @@ class MemoryDBHNSWConfig(MemoryDBIndexConfig):
     def search_param(self) -> dict:
         return {
             "ef_runtime": self.ef_runtime,
-        }
+        }
```
**vectordb_bench/backend/clients/memorydb/memorydb.py**

```diff
@@ -1,30 +1,33 @@
-import logging
+import logging
+import time
+from collections.abc import Generator
 from contextlib import contextmanager
-from typing import Any
-
-
+from typing import Any
+
+import numpy as np
 import redis
 from redis import Redis
 from redis.cluster import RedisCluster
-from redis.commands.search.field import TagField, VectorField
-from redis.commands.search.indexDefinition import IndexDefinition
+from redis.commands.search.field import NumericField, TagField, VectorField
+from redis.commands.search.indexDefinition import IndexDefinition
 from redis.commands.search.query import Query
-import numpy as np
 
+from ..api import IndexType, VectorDB
+from .config import MemoryDBIndexConfig
 
 log = logging.getLogger(__name__)
-INDEX_NAME = "index"
+INDEX_NAME = "index"  # Vector Index Name
+
 
 class MemoryDB(VectorDB):
     def __init__(
-        self,
-        dim: int,
-        db_config: dict,
-        db_case_config: MemoryDBIndexConfig,
-        drop_old: bool = False,
-        **kwargs
-
-    ):
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: MemoryDBIndexConfig,
+        drop_old: bool = False,
+        **kwargs,
+    ):
         self.db_config = db_config
         self.case_config = db_case_config
         self.collection_name = INDEX_NAME
@@ -44,10 +47,10 @@ class MemoryDB(VectorDB):
                 info = conn.ft(INDEX_NAME).info()
                 log.info(f"Index info: {info}")
             except redis.exceptions.ResponseError as e:
-                log.warning(e)
+                log.warning(e)
                 drop_old = False
                 log.info(f"MemoryDB client drop_old collection: {self.collection_name}")
-
+
             log.info("Executing FLUSHALL")
             conn.flushall()
 
@@ -59,7 +62,7 @@ class MemoryDB(VectorDB):
                     self.wait_until(self.wait_for_empty_db, 3, "", rc)
                     log.debug(f"Flushall done in the host: {host}")
                     rc.close()
-
+
             self.make_index(dim, conn)
         conn.close()
         conn = None
@@ -69,7 +72,7 @@ class MemoryDB(VectorDB):
             # check to see if index exists
             conn.ft(INDEX_NAME).info()
         except Exception as e:
-            log.warning(f"Error getting info for index '{INDEX_NAME}': {e}")
+            log.warning(f"Error getting info for index '{INDEX_NAME}': {e}")
         index_param = self.case_config.index_param()
         search_param = self.case_config.search_param()
         vector_parameters = {  # Vector Index Type: FLAT or HNSW
@@ -85,17 +88,19 @@ class MemoryDB(VectorDB):
             vector_parameters["EF_RUNTIME"] = search_param["ef_runtime"]
 
         schema = (
-            TagField("id"),
-            NumericField("metadata"),
-            VectorField(
-                "vector", "HNSW", vector_parameters
+            TagField("id"),
+            NumericField("metadata"),
+            VectorField(
+                "vector",  # Vector Field Name
+                "HNSW",
+                vector_parameters,
             ),
         )
 
         definition = IndexDefinition(index_type=IndexType.HASH)
         rs = conn.ft(INDEX_NAME)
         rs.create_index(schema, definition=definition)
-
+
     def get_client(self, **kwargs):
         """
         Gets either cluster connection or normal connection based on `cmd` flag.
@@ -143,7 +148,7 @@ class MemoryDB(VectorDB):
 
     @contextmanager
     def init(self) -> Generator[None, None, None]:
-        """
+        """create and destory connections to database.
 
         Examples:
             >>> with self.init():
@@ -152,17 +157,14 @@ class MemoryDB(VectorDB):
         self.conn = self.get_client()
         search_param = self.case_config.search_param()
         if search_param["ef_runtime"]:
-            self.ef_runtime_str = f'EF_RUNTIME {search_param["ef_runtime"]}'
+            self.ef_runtime_str = f"EF_RUNTIME {search_param['ef_runtime']}"
         else:
             self.ef_runtime_str = ""
         yield
         self.conn.close()
         self.conn = None
 
-    def ready_to_load(self):
-        pass
-
-    def optimize(self) -> None:
+    def optimize(self, data_size: int | None = None):
         self._post_insert()
 
     def insert_embeddings(
@@ -170,7 +172,7 @@ class MemoryDB(VectorDB):
         embeddings: list[list[float]],
         metadata: list[int],
         **kwargs: Any,
-    ) -> (int, Exception):
+    ) -> tuple[int, Exception | None]:
         """Insert embeddings into the database.
         Should call self.init() first.
         """
@@ -178,12 +180,15 @@ class MemoryDB(VectorDB):
         try:
             with self.conn.pipeline(transaction=False) as pipe:
                 for i, embedding in enumerate(embeddings):
-                    embedding = np.array(embedding).astype(np.float32)
-                    pipe.hset(metadata[i], mapping={
-                        "id": str(metadata[i]),
-                        "metadata": metadata[i],
-                        "vector": embedding.tobytes(),
-                    })
+                    ndarr_emb = np.array(embedding).astype(np.float32)
+                    pipe.hset(
+                        metadata[i],
+                        mapping={
+                            "id": str(metadata[i]),
+                            "metadata": metadata[i],
+                            "vector": ndarr_emb.tobytes(),
+                        },
+                    )
                     # Execute the pipe so we don't keep too much in memory at once
                     if (i + 1) % self.insert_batch_size == 0:
                         pipe.execute()
@@ -192,9 +197,9 @@ class MemoryDB(VectorDB):
             result_len = i + 1
         except Exception as e:
             return 0, e
-
+
         return result_len, None
-
+
     def _post_insert(self):
         """Wait for indexing to finish"""
         client = self.get_client(primary=True)
@@ -208,21 +213,17 @@ class MemoryDB(VectorDB):
             self.wait_until(*args)
             log.debug(f"Background indexing completed in the host: {host_name}")
             rc.close()
-
-    def wait_until(
-        self, condition, interval=5, message="Operation took too long", *args
-    ):
+
+    def wait_until(self, condition: any, interval: int = 5, message: str = "Operation took too long", *args):
         while not condition(*args):
             time.sleep(interval)
-
+
     def wait_for_no_activity(self, client: redis.RedisCluster | redis.Redis):
-        return (
-            client.info("search")["search_background_indexing_status"] == "NO_ACTIVITY"
-        )
-
+        return client.info("search")["search_background_indexing_status"] == "NO_ACTIVITY"
+
     def wait_for_empty_db(self, client: redis.RedisCluster | redis.Redis):
         return client.execute_command("DBSIZE") == 0
-
+
     def search_embedding(
         self,
         query: list[float],
@@ -230,13 +231,13 @@ class MemoryDB(VectorDB):
         filters: dict | None = None,
         timeout: int | None = None,
         **kwargs: Any,
-    ) -> list[int]:
+    ) -> list[int]:
         assert self.conn is not None
-
+
         query_vector = np.array(query).astype(np.float32).tobytes()
         query_obj = Query(f"*=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
         query_params = {"vec": query_vector}
-
+
         if filters:
             # benchmark test filters of format: {'metadata': '>=10000', 'id': 10000}
             # gets exact match for id, and range for metadata if they exist in filters
@@ -244,11 +245,19 @@ class MemoryDB(VectorDB):
             # Removing '>=' from the id_value: '>=10000'
             metadata_value = filters.get("metadata")[2:]
             if id_value and metadata_value:
-                query_obj = Query(f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
+                query_obj = (
+                    Query(
+                        f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} @vector $vec]",
+                    )
+                    .return_fields("id")
+                    .paging(0, k)
+                )
             elif id_value:
-                #gets exact match for id
+                # gets exact match for id
                 query_obj = Query(f"@id:{ {id_value} }=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
-            else:
-                query_obj = Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
+            else:  # metadata only case, greater than or equal to metadata value
+                query_obj = (
+                    Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
+                )
         res = self.conn.ft(INDEX_NAME).search(query_obj, query_params)
-        return [int(doc["id"]) for doc in res.docs]
+        return [int(doc["id"]) for doc in res.docs]
```