vectordb-bench 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +49 -24
- vectordb_bench/__main__.py +4 -3
- vectordb_bench/backend/assembler.py +12 -13
- vectordb_bench/backend/cases.py +56 -46
- vectordb_bench/backend/clients/__init__.py +101 -14
- vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +26 -0
- vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +18 -0
- vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +345 -0
- vectordb_bench/backend/clients/aliyun_opensearch/config.py +47 -0
- vectordb_bench/backend/clients/alloydb/alloydb.py +58 -80
- vectordb_bench/backend/clients/alloydb/cli.py +52 -35
- vectordb_bench/backend/clients/alloydb/config.py +30 -30
- vectordb_bench/backend/clients/api.py +8 -9
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +46 -47
- vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
- vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
- vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
- vectordb_bench/backend/clients/chroma/chroma.py +38 -36
- vectordb_bench/backend/clients/chroma/config.py +4 -2
- vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +23 -22
- vectordb_bench/backend/clients/memorydb/cli.py +8 -8
- vectordb_bench/backend/clients/memorydb/config.py +2 -2
- vectordb_bench/backend/clients/memorydb/memorydb.py +65 -53
- vectordb_bench/backend/clients/milvus/cli.py +62 -80
- vectordb_bench/backend/clients/milvus/config.py +31 -7
- vectordb_bench/backend/clients/milvus/milvus.py +23 -26
- vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
- vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
- vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +55 -73
- vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
- vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
- vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +33 -34
- vectordb_bench/backend/clients/pgvector/cli.py +40 -31
- vectordb_bench/backend/clients/pgvector/config.py +63 -73
- vectordb_bench/backend/clients/pgvector/pgvector.py +97 -98
- vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
- vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
- vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +38 -43
- vectordb_bench/backend/clients/pinecone/config.py +1 -0
- vectordb_bench/backend/clients/pinecone/pinecone.py +14 -21
- vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +40 -31
- vectordb_bench/backend/clients/redis/cli.py +6 -12
- vectordb_bench/backend/clients/redis/config.py +7 -5
- vectordb_bench/backend/clients/redis/redis.py +94 -58
- vectordb_bench/backend/clients/test/cli.py +1 -2
- vectordb_bench/backend/clients/test/config.py +2 -2
- vectordb_bench/backend/clients/test/test.py +4 -5
- vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
- vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +36 -22
- vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
- vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
- vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
- vectordb_bench/backend/data_source.py +30 -18
- vectordb_bench/backend/dataset.py +47 -27
- vectordb_bench/backend/result_collector.py +2 -3
- vectordb_bench/backend/runner/__init__.py +4 -6
- vectordb_bench/backend/runner/mp_runner.py +85 -34
- vectordb_bench/backend/runner/rate_runner.py +51 -23
- vectordb_bench/backend/runner/read_write_runner.py +140 -46
- vectordb_bench/backend/runner/serial_runner.py +99 -50
- vectordb_bench/backend/runner/util.py +4 -19
- vectordb_bench/backend/task_runner.py +95 -74
- vectordb_bench/backend/utils.py +17 -9
- vectordb_bench/base.py +0 -1
- vectordb_bench/cli/cli.py +65 -60
- vectordb_bench/cli/vectordbbench.py +6 -7
- vectordb_bench/frontend/components/check_results/charts.py +8 -19
- vectordb_bench/frontend/components/check_results/data.py +4 -16
- vectordb_bench/frontend/components/check_results/filters.py +8 -16
- vectordb_bench/frontend/components/check_results/nav.py +4 -4
- vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
- vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
- vectordb_bench/frontend/components/concurrent/charts.py +12 -12
- vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
- vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
- vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
- vectordb_bench/frontend/components/custom/initStyle.py +1 -1
- vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
- vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
- vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
- vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
- vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
- vectordb_bench/frontend/components/tables/data.py +3 -6
- vectordb_bench/frontend/config/dbCaseConfigs.py +108 -83
- vectordb_bench/frontend/pages/concurrent.py +3 -5
- vectordb_bench/frontend/pages/custom.py +30 -9
- vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
- vectordb_bench/frontend/pages/run_test.py +3 -7
- vectordb_bench/frontend/utils.py +1 -1
- vectordb_bench/frontend/vdb_benchmark.py +4 -6
- vectordb_bench/interface.py +56 -26
- vectordb_bench/log_util.py +59 -64
- vectordb_bench/metric.py +10 -11
- vectordb_bench/models.py +26 -43
- {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/METADATA +34 -42
- vectordb_bench-0.0.20.dist-info/RECORD +135 -0
- {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/WHEEL +1 -1
- vectordb_bench-0.0.18.dist-info/RECORD +0 -131
- {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/top_level.txt +0 -0
Selected file diffs:

vectordb_bench/backend/clients/chroma/config.py:

```diff
@@ -1,14 +1,16 @@
 from pydantic import SecretStr
+
 from ..api import DBConfig
 
+
 class ChromaConfig(DBConfig):
     password: SecretStr
     host: SecretStr
-    port: int
+    port: int
 
     def to_dict(self) -> dict:
         return {
             "host": self.host.get_secret_value(),
             "port": self.port,
             "password": self.password.get_secret_value(),
-        }
+        }
```
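The ChromaConfig hunk is formatting-only, but it shows the SecretStr pattern every DBConfig subclass here relies on: secret fields stay masked in reprs and logs, and are unwrapped only inside `to_dict()`. A minimal standalone sketch of the same pattern (plain pydantic; `DemoConfig` is a hypothetical stand-in for the real base class):

```python
from pydantic import BaseModel, SecretStr


class DemoConfig(BaseModel):
    password: SecretStr
    host: SecretStr
    port: int

    def to_dict(self) -> dict:
        return {
            "host": self.host.get_secret_value(),
            "port": self.port,
            "password": self.password.get_secret_value(),
        }


cfg = DemoConfig(password="hunter2", host="localhost", port=8000)
print(cfg)            # SecretStr fields print as '**********'
print(cfg.to_dict())  # unwrapped values, ready for the client constructor
```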
vectordb_bench/backend/clients/elastic_cloud/config.py:

```diff
@@ -1,7 +1,8 @@
 from enum import Enum
-from pydantic import SecretStr, BaseModel
 
-from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
 
 
 class ElasticCloudConfig(DBConfig, BaseModel):
@@ -32,12 +33,12 @@ class ElasticCloudIndexConfig(BaseModel, DBCaseConfig):
     def parse_metric(self) -> str:
         if self.metric_type == MetricType.L2:
             return "l2_norm"
-        elif self.metric_type == MetricType.IP:
+        if self.metric_type == MetricType.IP:
             return "dot_product"
         return "cosine"
 
     def index_param(self) -> dict:
-        params = {
+        return {
             "type": "dense_vector",
             "index": True,
             "element_type": self.element_type.value,
@@ -48,7 +49,6 @@ class ElasticCloudIndexConfig(BaseModel, DBCaseConfig):
             "ef_construction": self.efConstruction,
             },
         }
-        return params
 
     def search_param(self) -> dict:
         return {
```
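For context, the dict built by the new `index_param()` body is spliced into the `dense_vector` field mapping by the companion client (see `_create_indice` below). A hedged sketch of the full mapping shape, with illustrative values for everything the hunk elides (`similarity` from `parse_metric()`, the `index_options` nesting from the surrounding class):

```python
# Shape of the mapping produced for the vector field; values are illustrative.
vector_field_mapping = {
    "type": "dense_vector",
    "index": True,
    "element_type": "float",     # self.element_type.value
    "dims": 768,                 # added by the caller, not by index_param()
    "similarity": "cosine",      # parse_metric() result (assumed key)
    "index_options": {           # assumed nesting; the hunk shows only its tail
        "type": "hnsw",
        "m": 16,
        "ef_construction": 360,  # self.efConstruction
    },
}
```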
vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py:

```diff
@@ -1,17 +1,22 @@
 import logging
 import time
+from collections.abc import Iterable
 from contextlib import contextmanager
-
-from ..api import VectorDB
-from .config import ElasticCloudIndexConfig
+
 from elasticsearch.helpers import bulk
 
+from ..api import VectorDB
+from .config import ElasticCloudIndexConfig
 
 for logger in ("elasticsearch", "elastic_transport"):
     logging.getLogger(logger).setLevel(logging.WARNING)
 
 log = logging.getLogger(__name__)
 
+
+SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
+
+
 class ElasticCloud(VectorDB):
     def __init__(
         self,
@@ -46,14 +51,14 @@ class ElasticCloud(VectorDB):
     def init(self) -> None:
         """connect to elasticsearch"""
         from elasticsearch import Elasticsearch
+
         self.client = Elasticsearch(**self.db_config, request_timeout=180)
 
         yield
-        # self.client.transport.close()
         self.client = None
-        del(self.client)
+        del self.client
 
-    def _create_indice(self, client) -> None:
+    def _create_indice(self, client: any) -> None:
         mappings = {
             "_source": {"excludes": [self.vector_col_name]},
             "properties": {
@@ -62,13 +67,13 @@ class ElasticCloud(VectorDB):
                 "dims": self.dim,
                 **self.case_config.index_param(),
             },
-        }
+            },
         }
 
         try:
             client.indices.create(index=self.indice, mappings=mappings)
         except Exception as e:
-            log.warning(f"Failed to create indice: {self.indice} error: {str(e)}")
+            log.warning(f"Failed to create indice: {self.indice} error: {e!s}")
             raise e from None
 
     def insert_embeddings(
@@ -94,7 +99,7 @@ class ElasticCloud(VectorDB):
             bulk_insert_res = bulk(self.client, insert_data)
             return (bulk_insert_res[0], None)
         except Exception as e:
-            log.warning(f"Failed to insert data: {self.indice} error: {str(e)}")
+            log.warning(f"Failed to insert data: {self.indice} error: {e!s}")
             return (0, e)
 
     def search_embedding(
@@ -114,16 +119,12 @@ class ElasticCloud(VectorDB):
             list[tuple[int, float]]: list of k most similar embeddings in (id, score) tuple to the query embedding.
         """
         assert self.client is not None, "should self.init() first"
-        # is_existed_res = self.client.indices.exists(index=self.indice)
-        # assert is_existed_res.raw == True, "should self.init() first"
 
         knn = {
             "field": self.vector_col_name,
             "k": k,
             "num_candidates": self.case_config.num_candidates,
-            "filter": [{"range": {self.id_col_name: {"gt": filters["id"]}}}]
-            if filters
-            else [],
+            "filter": [{"range": {self.id_col_name: {"gt": filters["id"]}}}] if filters else [],
             "query_vector": query,
         }
         size = k
@@ -137,26 +138,26 @@ class ElasticCloud(VectorDB):
                 stored_fields="_none_",
                 filter_path=[f"hits.hits.fields.{self.id_col_name}"],
             )
-
-
-            return res
+            return [h["fields"][self.id_col_name][0] for h in res["hits"]["hits"]]
         except Exception as e:
-            log.warning(f"Failed to search: {self.indice} error: {str(e)}")
+            log.warning(f"Failed to search: {self.indice} error: {e!s}")
             raise e from None
 
     def optimize(self):
         """optimize will be called between insertion and search in performance cases."""
         assert self.client is not None, "should self.init() first"
         self.client.indices.refresh(index=self.indice)
-        force_merge_task_id = self.client.indices.forcemerge(index=self.indice, max_num_segments=1, wait_for_completion=False)["task"]
+        force_merge_task_id = self.client.indices.forcemerge(
+            index=self.indice,
+            max_num_segments=1,
+            wait_for_completion=False,
+        )["task"]
         log.info(f"Elasticsearch force merge task id: {force_merge_task_id}")
-        SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
         while True:
             time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
             task_status = self.client.tasks.get(task_id=force_merge_task_id)
-            if task_status['completed']:
+            if task_status["completed"]:
                 return
 
     def ready_to_load(self):
         """ready_to_load will be called before load in load cases."""
-        pass
```
vectordb_bench/backend/clients/memorydb/cli.py:

```diff
@@ -14,9 +14,7 @@ from .. import DB
 
 
 class MemoryDBTypedDict(TypedDict):
-    host: Annotated[
-        str, click.option("--host", type=str, help="Db host", required=True)
-    ]
+    host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
     password: Annotated[str, click.option("--password", type=str, help="Db password")]
     port: Annotated[int, click.option("--port", type=int, default=6379, help="Db Port")]
     ssl: Annotated[
@@ -44,7 +42,10 @@ class MemoryDBTypedDict(TypedDict):
             is_flag=True,
             show_default=True,
             default=False,
-            help="Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance. In production, MemoryDB only supports cluster mode (CME)",
+            help=(
+                "Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance.",
+                " In production, MemoryDB only supports cluster mode (CME)",
+            ),
         ),
     ]
     insert_batch_size: Annotated[
@@ -58,8 +59,7 @@ class MemoryDBTypedDict(TypedDict):
     ]
 
 
-class MemoryDBHNSWTypedDict(CommonTypedDict, MemoryDBTypedDict, HNSWFlavor2):
-    ...
+class MemoryDBHNSWTypedDict(CommonTypedDict, MemoryDBTypedDict, HNSWFlavor2): ...
 
 
 @cli.command()
@@ -82,7 +82,7 @@ def MemoryDB(**parameters: Unpack[MemoryDBHNSWTypedDict]):
             M=parameters["m"],
             ef_construction=parameters["ef_construction"],
             ef_runtime=parameters["ef_runtime"],
-            insert_batch_size=parameters["insert_batch_size"]
+            insert_batch_size=parameters["insert_batch_size"],
         ),
         **parameters,
-    )
+    )
```
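The cli.py hunks are mostly reflowing, but the file leans on one pattern worth spelling out: each TypedDict key stores its `click.option` in `Annotated` metadata, and a decorator later hangs those options on the command. A minimal sketch of that pattern (the collector below is a simplified stand-in for the package's own helper):

```python
from typing import Annotated, TypedDict, get_type_hints

import click


class DemoTypedDict(TypedDict):
    host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
    port: Annotated[int, click.option("--port", type=int, default=6379, help="Db Port")]


def click_options_from_typed_dict(typed_dict):
    """Collect the click.option decorators stored in Annotated metadata."""
    options = []
    for hint in get_type_hints(typed_dict, include_extras=True).values():
        options.extend(m for m in getattr(hint, "__metadata__", ()) if callable(m))

    def decorate(f):
        for opt in reversed(options):
            f = opt(f)
        return f

    return decorate


@click.command()
@click_options_from_typed_dict(DemoTypedDict)
def demo(**parameters):
    click.echo(parameters)


if __name__ == "__main__":
    demo()
```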
vectordb_bench/backend/clients/memorydb/config.py:

```diff
@@ -29,7 +29,7 @@ class MemoryDBIndexConfig(BaseModel, DBCaseConfig):
     def parse_metric(self) -> str:
         if self.metric_type == MetricType.L2:
             return "l2"
-        elif self.metric_type == MetricType.IP:
+        if self.metric_type == MetricType.IP:
             return "ip"
         return "cosine"
 
@@ -51,4 +51,4 @@ class MemoryDBHNSWConfig(MemoryDBIndexConfig):
     def search_param(self) -> dict:
         return {
             "ef_runtime": self.ef_runtime,
-        }
+        }
```
vectordb_bench/backend/clients/memorydb/memorydb.py:

```diff
@@ -1,30 +1,33 @@
-import logging, time
+import logging
+import time
+from collections.abc import Generator
 from contextlib import contextmanager
-from typing import Any, Generator
-from ..api import VectorDB, IndexType
-from .config import MemoryDBIndexConfig
+from typing import Any
+
+import numpy as np
 import redis
 from redis import Redis
 from redis.cluster import RedisCluster
-from redis.commands.search.field import TagField, VectorField
-from redis.commands.search.indexDefinition import IndexDefinition
+from redis.commands.search.field import NumericField, TagField, VectorField
+from redis.commands.search.indexDefinition import IndexDefinition
 from redis.commands.search.query import Query
-import numpy as np
 
+from ..api import IndexType, VectorDB
+from .config import MemoryDBIndexConfig
 
 log = logging.getLogger(__name__)
-INDEX_NAME = "index"
+INDEX_NAME = "index"  # Vector Index Name
+
 
 class MemoryDB(VectorDB):
     def __init__(
-            self,
-            dim: int,
-            db_config: dict,
-            db_case_config: MemoryDBIndexConfig,
-            drop_old: bool = False,
-            **kwargs
-        ):
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: MemoryDBIndexConfig,
+        drop_old: bool = False,
+        **kwargs,
+    ):
         self.db_config = db_config
         self.case_config = db_case_config
         self.collection_name = INDEX_NAME
@@ -44,10 +47,10 @@ class MemoryDB(VectorDB):
                 info = conn.ft(INDEX_NAME).info()
                 log.info(f"Index info: {info}")
             except redis.exceptions.ResponseError as e:
-                log.info(e)
+                log.warning(e)
                 drop_old = False
                 log.info(f"MemoryDB client drop_old collection: {self.collection_name}")
-
+
             log.info("Executing FLUSHALL")
             conn.flushall()
 
@@ -59,7 +62,7 @@ class MemoryDB(VectorDB):
                 self.wait_until(self.wait_for_empty_db, 3, "", rc)
                 log.debug(f"Flushall done in the host: {host}")
                 rc.close()
-
+
         self.make_index(dim, conn)
         conn.close()
         conn = None
@@ -69,7 +72,7 @@ class MemoryDB(VectorDB):
             # check to see if index exists
             conn.ft(INDEX_NAME).info()
         except Exception as e:
-            log.info(f"Error getting info for index '{INDEX_NAME}': {e}")
+            log.warning(f"Error getting info for index '{INDEX_NAME}': {e}")
             index_param = self.case_config.index_param()
             search_param = self.case_config.search_param()
             vector_parameters = {  # Vector Index Type: FLAT or HNSW
@@ -85,17 +88,19 @@ class MemoryDB(VectorDB):
                 vector_parameters["EF_RUNTIME"] = search_param["ef_runtime"]
 
             schema = (
-                TagField("id"),
-                NumericField("metadata"),
-                VectorField(
-                    "vector", "HNSW", vector_parameters  # Vector Field Name
+                TagField("id"),
+                NumericField("metadata"),
+                VectorField(
+                    "vector",  # Vector Field Name
+                    "HNSW",
+                    vector_parameters,
                 ),
             )
 
             definition = IndexDefinition(index_type=IndexType.HASH)
             rs = conn.ft(INDEX_NAME)
             rs.create_index(schema, definition=definition)
-
+
     def get_client(self, **kwargs):
         """
         Gets either cluster connection or normal connection based on `cmd` flag.
@@ -143,7 +148,7 @@ class MemoryDB(VectorDB):
 
     @contextmanager
     def init(self) -> Generator[None, None, None]:
-        """ create and destory connections to database.
+        """create and destory connections to database.
 
         Examples:
             >>> with self.init():
@@ -170,7 +175,7 @@ class MemoryDB(VectorDB):
         embeddings: list[list[float]],
         metadata: list[int],
         **kwargs: Any,
-    ) -> (int, Exception):
+    ) -> tuple[int, Exception | None]:
         """Insert embeddings into the database.
         Should call self.init() first.
         """
@@ -178,12 +183,15 @@ class MemoryDB(VectorDB):
         try:
             with self.conn.pipeline(transaction=False) as pipe:
                 for i, embedding in enumerate(embeddings):
-                    embedding = np.array(embedding).astype(np.float32)
-                    pipe.hset(metadata[i], mapping={
-                        "id": str(metadata[i]),
-                        "metadata": metadata[i],
-                        "vector": embedding.tobytes(),
-                    })
+                    ndarr_emb = np.array(embedding).astype(np.float32)
+                    pipe.hset(
+                        metadata[i],
+                        mapping={
+                            "id": str(metadata[i]),
+                            "metadata": metadata[i],
+                            "vector": ndarr_emb.tobytes(),
+                        },
+                    )
                     # Execute the pipe so we don't keep too much in memory at once
                     if (i + 1) % self.insert_batch_size == 0:
                         pipe.execute()
```
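The rewritten insert loop batches writes through a non-transactional pipeline and stores each vector as raw float32 bytes in a hash field. A standalone sketch of the same pattern with redis-py (host, dimensions, and batch size are illustrative):

```python
import numpy as np
from redis import Redis

conn = Redis(host="localhost", port=6379)
insert_batch_size = 100
embeddings = [[0.1] * 768 for _ in range(1000)]
metadata = list(range(1000))

with conn.pipeline(transaction=False) as pipe:
    for i, embedding in enumerate(embeddings):
        ndarr_emb = np.array(embedding).astype(np.float32)
        pipe.hset(
            metadata[i],
            mapping={
                "id": str(metadata[i]),
                "metadata": metadata[i],
                "vector": ndarr_emb.tobytes(),  # raw float32 bytes, as the index expects
            },
        )
        # Flush periodically so the pipeline doesn't buffer everything in memory.
        if (i + 1) % insert_batch_size == 0:
            pipe.execute()
    pipe.execute()  # flush the remainder
```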
vectordb_bench/backend/clients/memorydb/memorydb.py (continued):

```diff
@@ -192,9 +200,9 @@ class MemoryDB(VectorDB):
             result_len = i + 1
         except Exception as e:
             return 0, e
-
+
         return result_len, None
-
+
     def _post_insert(self):
         """Wait for indexing to finish"""
         client = self.get_client(primary=True)
@@ -208,21 +216,17 @@ class MemoryDB(VectorDB):
             self.wait_until(*args)
             log.debug(f"Background indexing completed in the host: {host_name}")
             rc.close()
-
-    def wait_until(
-        self, condition, interval=5, message="Operation took too long", *args
-    ):
+
+    def wait_until(self, condition: any, interval: int = 5, message: str = "Operation took too long", *args):
         while not condition(*args):
             time.sleep(interval)
-
+
     def wait_for_no_activity(self, client: redis.RedisCluster | redis.Redis):
-        return (
-            client.info("search")["search_background_indexing_status"] == "NO_ACTIVITY"
-        )
-
+        return client.info("search")["search_background_indexing_status"] == "NO_ACTIVITY"
+
     def wait_for_empty_db(self, client: redis.RedisCluster | redis.Redis):
         return client.execute_command("DBSIZE") == 0
-
+
     def search_embedding(
         self,
         query: list[float],
@@ -230,13 +234,13 @@ class MemoryDB(VectorDB):
         filters: dict | None = None,
         timeout: int | None = None,
         **kwargs: Any,
-    ) -> (list[int]):
+    ) -> list[int]:
         assert self.conn is not None
-
+
         query_vector = np.array(query).astype(np.float32).tobytes()
         query_obj = Query(f"*=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
         query_params = {"vec": query_vector}
-
+
         if filters:
             # benchmark test filters of format: {'metadata': '>=10000', 'id': 10000}
             # gets exact match for id, and range for metadata if they exist in filters
@@ -244,11 +248,19 @@ class MemoryDB(VectorDB):
             # Removing '>=' from the id_value: '>=10000'
             metadata_value = filters.get("metadata")[2:]
             if id_value and metadata_value:
-                query_obj = Query(f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
+                query_obj = (
+                    Query(
+                        f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} @vector $vec]",
+                    )
+                    .return_fields("id")
+                    .paging(0, k)
+                )
             elif id_value:
-                #gets exact match for id
+                # gets exact match for id
                 query_obj = Query(f"@id:{ {id_value} }=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
-            else:
-                query_obj = Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
+            else:  # metadata only case, greater than or equal to metadata value
+                query_obj = (
+                    Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
+                )
         res = self.conn.ft(INDEX_NAME).search(query_obj, query_params)
-        return [int(doc["id"]) for doc in res.docs]
+        return [int(doc["id"]) for doc in res.docs]
```
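The filtered-search hunk only reflows the query strings; the RediSearch syntax is unchanged: a numeric range over `@metadata` and an exact tag match on `@id`, prepended to the KNN clause. A standalone sketch of the hybrid form (index name and filter values are illustrative; assumes a RediSearch-compatible server such as MemoryDB):

```python
import numpy as np
from redis import Redis
from redis.commands.search.query import Query

conn = Redis(host="localhost", port=6379)
k = 10
query_params = {"vec": np.array([0.1] * 768, dtype=np.float32).tobytes()}

# Range filter (@metadata >= 10000) AND exact tag match (@id == 10000),
# then KNN over the binary "vector" field. The doubled braces in this
# literal render as {10000}, the tag-match syntax RediSearch expects.
query_obj = (
    Query(f"(@metadata:[10000 +inf] @id:{{10000}})=>[KNN {k} @vector $vec]")
    .return_fields("id")
    .paging(0, k)
)
res = conn.ft("index").search(query_obj, query_params)
ids = [int(doc["id"]) for doc in res.docs]
```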