vectordb-bench 0.0.29__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +14 -27
- vectordb_bench/backend/assembler.py +19 -6
- vectordb_bench/backend/cases.py +186 -23
- vectordb_bench/backend/clients/__init__.py +32 -0
- vectordb_bench/backend/clients/api.py +22 -1
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +249 -43
- vectordb_bench/backend/clients/aws_opensearch/cli.py +51 -21
- vectordb_bench/backend/clients/aws_opensearch/config.py +58 -16
- vectordb_bench/backend/clients/chroma/chroma.py +6 -2
- vectordb_bench/backend/clients/elastic_cloud/config.py +19 -1
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +133 -45
- vectordb_bench/backend/clients/lancedb/cli.py +62 -8
- vectordb_bench/backend/clients/lancedb/config.py +14 -1
- vectordb_bench/backend/clients/lancedb/lancedb.py +21 -9
- vectordb_bench/backend/clients/memorydb/memorydb.py +2 -2
- vectordb_bench/backend/clients/milvus/cli.py +30 -9
- vectordb_bench/backend/clients/milvus/config.py +3 -0
- vectordb_bench/backend/clients/milvus/milvus.py +81 -23
- vectordb_bench/backend/clients/oceanbase/cli.py +100 -0
- vectordb_bench/backend/clients/oceanbase/config.py +125 -0
- vectordb_bench/backend/clients/oceanbase/oceanbase.py +215 -0
- vectordb_bench/backend/clients/pinecone/pinecone.py +39 -25
- vectordb_bench/backend/clients/qdrant_cloud/config.py +59 -3
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +100 -33
- vectordb_bench/backend/clients/qdrant_local/cli.py +60 -0
- vectordb_bench/backend/clients/qdrant_local/config.py +47 -0
- vectordb_bench/backend/clients/qdrant_local/qdrant_local.py +232 -0
- vectordb_bench/backend/clients/weaviate_cloud/cli.py +29 -3
- vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -0
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +5 -0
- vectordb_bench/backend/dataset.py +143 -27
- vectordb_bench/backend/filter.py +76 -0
- vectordb_bench/backend/runner/__init__.py +3 -3
- vectordb_bench/backend/runner/mp_runner.py +52 -39
- vectordb_bench/backend/runner/rate_runner.py +68 -52
- vectordb_bench/backend/runner/read_write_runner.py +125 -68
- vectordb_bench/backend/runner/serial_runner.py +56 -23
- vectordb_bench/backend/task_runner.py +48 -20
- vectordb_bench/cli/batch_cli.py +121 -0
- vectordb_bench/cli/cli.py +59 -1
- vectordb_bench/cli/vectordbbench.py +7 -0
- vectordb_bench/config-files/batch_sample_config.yml +17 -0
- vectordb_bench/frontend/components/check_results/data.py +16 -11
- vectordb_bench/frontend/components/check_results/filters.py +53 -25
- vectordb_bench/frontend/components/check_results/headerIcon.py +16 -13
- vectordb_bench/frontend/components/check_results/nav.py +20 -0
- vectordb_bench/frontend/components/custom/displayCustomCase.py +43 -8
- vectordb_bench/frontend/components/custom/displaypPrams.py +10 -5
- vectordb_bench/frontend/components/custom/getCustomConfig.py +10 -0
- vectordb_bench/frontend/components/label_filter/charts.py +60 -0
- vectordb_bench/frontend/components/run_test/caseSelector.py +48 -52
- vectordb_bench/frontend/components/run_test/dbSelector.py +9 -5
- vectordb_bench/frontend/components/run_test/inputWidget.py +48 -0
- vectordb_bench/frontend/components/run_test/submitTask.py +3 -1
- vectordb_bench/frontend/components/streaming/charts.py +253 -0
- vectordb_bench/frontend/components/streaming/data.py +62 -0
- vectordb_bench/frontend/components/tables/data.py +1 -1
- vectordb_bench/frontend/components/welcome/explainPrams.py +66 -0
- vectordb_bench/frontend/components/welcome/pagestyle.py +106 -0
- vectordb_bench/frontend/components/welcome/welcomePrams.py +147 -0
- vectordb_bench/frontend/config/dbCaseConfigs.py +420 -41
- vectordb_bench/frontend/config/styles.py +32 -2
- vectordb_bench/frontend/pages/concurrent.py +5 -1
- vectordb_bench/frontend/pages/custom.py +4 -0
- vectordb_bench/frontend/pages/label_filter.py +56 -0
- vectordb_bench/frontend/pages/quries_per_dollar.py +5 -1
- vectordb_bench/frontend/pages/results.py +60 -0
- vectordb_bench/frontend/pages/run_test.py +3 -3
- vectordb_bench/frontend/pages/streaming.py +135 -0
- vectordb_bench/frontend/pages/tables.py +4 -0
- vectordb_bench/frontend/vdb_benchmark.py +16 -41
- vectordb_bench/interface.py +6 -2
- vectordb_bench/metric.py +15 -1
- vectordb_bench/models.py +38 -11
- vectordb_bench/results/ElasticCloud/result_20250318_standard_elasticcloud.json +5890 -0
- vectordb_bench/results/Milvus/result_20250509_standard_milvus.json +6138 -0
- vectordb_bench/results/OpenSearch/result_20250224_standard_opensearch.json +7319 -0
- vectordb_bench/results/Pinecone/result_20250124_standard_pinecone.json +2365 -0
- vectordb_bench/results/QdrantCloud/result_20250602_standard_qdrantcloud.json +3556 -0
- vectordb_bench/results/ZillizCloud/result_20250613_standard_zillizcloud.json +6290 -0
- vectordb_bench/results/dbPrices.json +12 -4
- {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/METADATA +131 -32
- {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/RECORD +87 -65
- {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/WHEEL +1 -1
- vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +0 -791
- vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +0 -679
- vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +0 -1352
- {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/aws_opensearch/config.py

```diff
@@ -10,17 +10,21 @@ log = logging.getLogger(__name__)

 class AWSOpenSearchConfig(DBConfig, BaseModel):
     host: str = ""
-    port: int =
+    port: int = 80
     user: str = ""
     password: SecretStr = ""

     def to_dict(self) -> dict:
+        use_ssl = self.port == 443
+        http_auth = (
+            (self.user, self.password.get_secret_value()) if len(self.user) != 0 and len(self.password) != 0 else ()
+        )
         return {
             "hosts": [{"host": self.host, "port": self.port}],
-            "http_auth":
-            "use_ssl":
+            "http_auth": http_auth,
+            "use_ssl": use_ssl,
             "http_compress": True,
-            "verify_certs":
+            "verify_certs": use_ssl,
             "ssl_assert_hostname": False,
             "ssl_show_warn": False,
             "timeout": 600,
```
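The rewritten `to_dict()` derives TLS and auth settings from the configured port instead of hard-coding them. Below is a minimal standalone sketch of the same logic so it can be run in isolation; the `OpenSearchConnSketch` model name and the example host are illustrative, not part of the package.

```python
from pydantic import BaseModel, SecretStr


class OpenSearchConnSketch(BaseModel):
    """Illustrative stand-in for AWSOpenSearchConfig's connection fields."""

    host: str = ""
    port: int = 80
    user: str = ""
    password: SecretStr = SecretStr("")

    def to_dict(self) -> dict:
        # SSL is assumed only when the standard HTTPS port is configured.
        use_ssl = self.port == 443
        # Basic auth is attached only when both user and password are non-empty.
        http_auth = (
            (self.user, self.password.get_secret_value())
            if self.user and self.password.get_secret_value()
            else ()
        )
        return {
            "hosts": [{"host": self.host, "port": self.port}],
            "http_auth": http_auth,
            "use_ssl": use_ssl,
            "verify_certs": use_ssl,
        }


# port 443 -> SSL and cert verification on; any other port -> plain HTTP
print(OpenSearchConnSketch(host="opensearch.example.com", port=443,
                           user="admin", password=SecretStr("secret")).to_dict())
```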
```diff
@@ -28,16 +32,22 @@ class AWSOpenSearchConfig(DBConfig, BaseModel):


 class AWSOS_Engine(Enum):
-    nmslib = "nmslib"
     faiss = "faiss"
-    lucene = "
+    lucene = "lucene"
+
+
+class AWSOSQuantization(Enum):
+    fp32 = "fp32"
+    fp16 = "fp16"


 class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
     metric_type: MetricType = MetricType.L2
     engine: AWSOS_Engine = AWSOS_Engine.faiss
     efConstruction: int = 256
-    efSearch: int =
+    efSearch: int = 100
+    engine_name: str | None = None
+    metric_type_name: str | None = None
     M: int = 16
     index_thread_qty: int | None = 4
     number_of_shards: int | None = 1
@@ -46,33 +56,65 @@ class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
     refresh_interval: str | None = "60s"
     force_merge_enabled: bool | None = True
     flush_threshold_size: str | None = "5120mb"
-
-    index_thread_qty_during_force_merge: int
+    index_thread_qty_during_force_merge: int = 8
     cb_threshold: str | None = "50%"
+    number_of_indexing_clients: int | None = 1
+    use_routing: bool = False # for label-filter cases
+    oversample_factor: float = 1.0
+    quantization_type: AWSOSQuantization = AWSOSQuantization.fp32
+
+    def __eq__(self, obj: any):
+        return (
+            self.engine == obj.engine
+            and self.M == obj.M
+            and self.efConstruction == obj.efConstruction
+            and self.number_of_shards == obj.number_of_shards
+            and self.number_of_replicas == obj.number_of_replicas
+            and self.number_of_segments == obj.number_of_segments
+            and self.use_routing == obj.use_routing
+            and self.quantization_type == obj.quantization_type
+        )

     def parse_metric(self) -> str:
+        log.info(f"User specified metric_type: {self.metric_type_name}")
+        self.metric_type = MetricType[self.metric_type_name.upper()]
         if self.metric_type == MetricType.IP:
             return "innerproduct"
         if self.metric_type == MetricType.COSINE:
-            if self.engine == AWSOS_Engine.faiss:
-                log.info(
-                    "Using innerproduct because faiss doesn't support cosine as metric type for Opensearch",
-                )
-                return "innerproduct"
             return "cosinesimil"
+        if self.metric_type == MetricType.L2:
+            log.info("Using l2 as specified by user")
+            return "l2"
         return "l2"

+    @property
+    def use_quant(self) -> bool:
+        return self.quantization_type is not AWSOSQuantization.fp32
+
     def index_param(self) -> dict:
+        log.info(f"Using engine: {self.engine} for index creation")
+        log.info(f"Using metric_type: {self.metric_type_name} for index creation")
+        log.info(f"Resulting space_type: {self.parse_metric()} for index creation")
+
+        parameters = {"ef_construction": self.efConstruction, "m": self.M}
+
+        if self.engine == AWSOS_Engine.faiss and self.faiss_use_fp16:
+            parameters["encoder"] = {"name": "sq", "parameters": {"type": "fp16"}}
+
         return {
             "name": "hnsw",
-            "space_type": self.parse_metric(),
             "engine": self.engine.value,
             "parameters": {
                 "ef_construction": self.efConstruction,
                 "m": self.M,
                 "ef_search": self.efSearch,
+                **(
+                    {"encoder": {"name": "sq", "parameters": {"type": self.quantization_type.fp16.value}}}
+                    if self.use_quant
+                    else {}
+                ),
             },
         }

     def search_param(self) -> dict:
-        return {}
+        return {"ef_search": self.efSearch}
```
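With the new `quantization_type` field, `index_param()` injects a faiss scalar-quantizer (`sq`) encoder into the HNSW method only when fp16 is requested. A standalone sketch of the resulting method mapping; the `Quantization` enum and `hnsw_method` helper are illustrative names:

```python
from enum import Enum


class Quantization(Enum):
    fp32 = "fp32"
    fp16 = "fp16"


def hnsw_method(ef_construction: int, m: int, ef_search: int, quantization: Quantization) -> dict:
    """Build an OpenSearch k-NN HNSW method mapping shaped like the diffed index_param()."""
    parameters = {"ef_construction": ef_construction, "m": m, "ef_search": ef_search}
    if quantization is not Quantization.fp32:
        # scalar-quantizer encoder, added only when quantization is requested
        parameters["encoder"] = {"name": "sq", "parameters": {"type": quantization.value}}
    return {"name": "hnsw", "engine": "faiss", "parameters": parameters}


print(hnsw_method(256, 16, 100, Quantization.fp32))  # no encoder block
print(hnsw_method(256, 16, 100, Quantization.fp16))  # encoder: sq / fp16
```

The diff itself reads the encoder type as `self.quantization_type.fp16.value`, which always evaluates to `"fp16"`; since fp16 is the only non-fp32 option offered, that is equivalent to the member's own `.value` used in the sketch.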
vectordb_bench/backend/clients/chroma/chroma.py

```diff
@@ -78,8 +78,12 @@ class ChromaClient(VectorDB):
         """
         ids = [str(i) for i in metadata]
         metadata = [{"id": int(i)} for i in metadata]
-
-
+        try:
+            if len(embeddings) > 0:
+                self.collection.add(embeddings=embeddings, ids=ids, metadatas=metadata)
+        except Exception as e:
+            log.warning(f"Failed to insert data: error: {e!s}")
+            return 0, e
         return len(embeddings), None

     def search_embedding(
```
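The Chroma client now guards the batch insert and reports failures through the `(inserted_count, error)` pair used by the other clients. A minimal sketch of that pattern, with `collection` standing in for a chromadb collection handle:

```python
import logging

log = logging.getLogger(__name__)


def insert_batch(collection, embeddings: list[list[float]], metadata: list[int]) -> tuple[int, Exception | None]:
    """Insert one batch and return (inserted_count, error), as the diffed client does."""
    ids = [str(i) for i in metadata]
    metadatas = [{"id": int(i)} for i in metadata]
    try:
        if len(embeddings) > 0:
            collection.add(embeddings=embeddings, ids=ids, metadatas=metadatas)
    except Exception as e:
        log.warning(f"Failed to insert data: error: {e!s}")
        return 0, e
    return len(embeddings), None
```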
vectordb_bench/backend/clients/elastic_cloud/config.py

```diff
@@ -23,13 +23,31 @@ class ESElementType(str, Enum):

 class ElasticCloudIndexConfig(BaseModel, DBCaseConfig):
     element_type: ESElementType = ESElementType.float
-    index: IndexType = IndexType.ES_HNSW
+    index: IndexType = IndexType.ES_HNSW
+    number_of_shards: int = 1
+    number_of_replicas: int = 0
+    refresh_interval: str = "30s"
+    merge_max_thread_count: int = 8
+    use_rescore: bool = False
+    oversample_ratio: float = 2.0
+    use_routing: bool = False
+    use_force_merge: bool = True

     metric_type: MetricType | None = None
     efConstruction: int | None = None
     M: int | None = None
     num_candidates: int | None = None

+    def __eq__(self, obj: any):
+        return (
+            self.index == obj.index
+            and self.number_of_shards == obj.number_of_shards
+            and self.number_of_replicas == obj.number_of_replicas
+            and self.use_routing == obj.use_routing
+            and self.efConstruction == obj.efConstruction
+            and self.M == obj.M
+        )
+
     def parse_metric(self) -> str:
         if self.metric_type == MetricType.L2:
             return "l2_norm"
```
vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py

```diff
@@ -5,6 +5,8 @@ from contextlib import contextmanager

 from elasticsearch.helpers import bulk

+from vectordb_bench.backend.filter import Filter, FilterOp
+
 from ..api import VectorDB
 from .config import ElasticCloudIndexConfig

```
```diff
@@ -18,6 +20,12 @@ SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30


 class ElasticCloud(VectorDB):
+    supported_filter_types: list[FilterOp] = [
+        FilterOp.NonFilter,
+        FilterOp.NumGE,
+        FilterOp.StrEqual,
+    ]
+
     def __init__(
         self,
         dim: int,
```
```diff
@@ -25,8 +33,10 @@ class ElasticCloud(VectorDB):
         db_case_config: ElasticCloudIndexConfig,
         indice: str = "vdb_bench_indice", # must be lowercase
         id_col_name: str = "id",
+        label_col_name: str = "label",
         vector_col_name: str = "vector",
         drop_old: bool = False,
+        with_scalar_labels: bool = False,
         **kwargs,
     ):
         self.dim = dim
```
```diff
@@ -34,7 +44,9 @@ class ElasticCloud(VectorDB):
         self.case_config = db_case_config
         self.indice = indice
         self.id_col_name = id_col_name
+        self.label_col_name = label_col_name
         self.vector_col_name = vector_col_name
+        self.with_scalar_labels = with_scalar_labels

         from elasticsearch import Elasticsearch

```
```diff
@@ -69,9 +81,17 @@ class ElasticCloud(VectorDB):
                 },
             },
         }
+        settings = {
+            "index": {
+                "number_of_shards": self.case_config.number_of_shards,
+                "number_of_replicas": self.case_config.number_of_replicas,
+                "refresh_interval": self.case_config.refresh_interval,
+                "merge.scheduler.max_thread_count": self.case_config.merge_max_thread_count,
+            }
+        }

         try:
-            client.indices.create(index=self.indice, mappings=mappings)
+            client.indices.create(index=self.indice, mappings=mappings, settings=settings)
         except Exception as e:
             log.warning(f"Failed to create indice: {self.indice} error: {e!s}")
             raise e from None
```
```diff
@@ -80,21 +100,48 @@ class ElasticCloud(VectorDB):
         self,
         embeddings: Iterable[list[float]],
         metadata: list[int],
+        labels_data: list[str] | None = None,
         **kwargs,
     ) -> tuple[int, Exception]:
         """Insert the embeddings to the elasticsearch."""
         assert self.client is not None, "should self.init() first"

-        insert_data =
-
-
-
-
-
-
-
-
-
+        insert_data = (
+            [
+                (
+                    {
+                        "_index": self.indice,
+                        "_source": {
+                            self.id_col_name: metadata[i],
+                            self.label_col_name: labels_data[i],
+                            self.vector_col_name: embeddings[i],
+                        },
+                        "_routing": labels_data[i],
+                    }
+                    if self.case_config.use_routing
+                    else {
+                        "_index": self.indice,
+                        "_source": {
+                            self.id_col_name: metadata[i],
+                            self.label_col_name: labels_data[i],
+                            self.vector_col_name: embeddings[i],
+                        },
+                    }
+                )
+                for i in range(len(embeddings))
+            ]
+            if self.with_scalar_labels
+            else [
+                {
+                    "_index": self.indice,
+                    "_source": {
+                        self.id_col_name: metadata[i],
+                        self.vector_col_name: embeddings[i],
+                    },
+                }
+                for i in range(len(embeddings))
+            ]
+        )
         try:
             bulk_insert_res = bulk(self.client, insert_data)
             return (bulk_insert_res[0], None)
```
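The nested conditional above builds one bulk action per row and, when routing is enabled, adds a `_routing` key alongside `_source`. The same shape can be expressed per row with a small helper; field names mirror the diff and the helper itself is illustrative:

```python
def build_bulk_action(
    index: str,
    doc_id: int,
    vector: list[float],
    label: str | None,
    id_col: str = "id",
    label_col: str = "label",
    vector_col: str = "vector",
    use_routing: bool = False,
) -> dict:
    """One elasticsearch.helpers.bulk action, with an optional label column and routing key."""
    source = {id_col: doc_id, vector_col: vector}
    if label is not None:
        source[label_col] = label
    action = {"_index": index, "_source": source}
    if use_routing and label is not None:
        # routing by label keeps rows with the same label on the same shard,
        # so a label-filtered query can be routed to a single shard
        action["_routing"] = label
    return action


print(build_bulk_action("vdb_bench_indice", 1, [0.1, 0.2], "label_1", use_routing=True))
```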
```diff
@@ -102,59 +149,100 @@ class ElasticCloud(VectorDB):
             log.warning(f"Failed to insert data: {self.indice} error: {e!s}")
             return (0, e)

+    def prepare_filter(self, filters: Filter):
+        self.routing_key = None
+        if filters.type == FilterOp.NonFilter:
+            self.filter = []
+        elif filters.type == FilterOp.NumGE:
+            self.filter = {"range": {self.id_col_name: {"gt": filters.int_value}}}
+        elif filters.type == FilterOp.StrEqual:
+            self.filter = {"term": {self.label_col_name: filters.label_value}}
+            if self.case_config.use_routing:
+                self.routing_key = filters.label_value
+        else:
+            msg = f"Not support Filter for Milvus - {filters}"
+            raise ValueError(msg)
+
     def search_embedding(
         self,
         query: list[float],
         k: int = 100,
-
+        **kwargs,
     ) -> list[int]:
         """Get k most similar embeddings to query vector.

         Args:
             query(list[float]): query embedding to look up documents similar to.
             k(int): Number of most similar embeddings to return. Defaults to 100.
-            filters(dict, optional): filtering expression to filter the data while searching.

         Returns:
             list[tuple[int, float]]: list of k most similar embeddings in (id, score) tuple to the query embedding.
         """
         assert self.client is not None, "should self.init() first"

-
-
-
-
-
-
-
+        if self.case_config.use_rescore:
+            oversample_k = int(k * self.case_config.oversample_ratio)
+            oversample_num_candidates = int(self.case_config.num_candidates * self.case_config.oversample_ratio)
+            knn = {
+                "field": self.vector_col_name,
+                "k": oversample_k,
+                "num_candidates": oversample_num_candidates,
+                "filter": self.filter,
+                "query_vector": query,
+            }
+            rescore = {
+                "window_size": oversample_k,
+                "query": {
+                    "rescore_query": {
+                        "script_score": {
+                            "query": {"match_all": {}},
+                            "script": {
+                                "source": f"cosineSimilarity(params.queryVector, '{self.vector_col_name}')",
+                                "params": {"queryVector": query},
+                            },
+                        }
+                    },
+                    "query_weight": 0,
+                    "rescore_query_weight": 1,
+                },
+            }
+        else:
+            knn = {
+                "field": self.vector_col_name,
+                "k": k,
+                "num_candidates": self.case_config.num_candidates,
+                "filter": self.filter,
+                "query_vector": query,
+            }
+            rescore = None
         size = k
-
-
-
-
-
-
-
-
-
-
-
-
-
-            raise e from None
+
+        res = self.client.search(
+            index=self.indice,
+            knn=knn,
+            routing=self.routing_key,
+            rescore=rescore,
+            size=size,
+            _source=False,
+            docvalue_fields=[self.id_col_name],
+            stored_fields="_none_",
+            filter_path=[f"hits.hits.fields.{self.id_col_name}"],
+        )
+        return [h["fields"][self.id_col_name][0] for h in res["hits"]["hits"]]

     def optimize(self, data_size: int | None = None):
         """optimize will be called between insertion and search in performance cases."""
         assert self.client is not None, "should self.init() first"
         self.client.indices.refresh(index=self.indice)
-
-
-
-
-
-
-
-
-
-
-
+        if self.case_config.use_force_merge:
+            force_merge_task_id = self.client.indices.forcemerge(
+                index=self.indice,
+                max_num_segments=1,
+                wait_for_completion=False,
+            )["task"]
+            log.info(f"Elasticsearch force merge task id: {force_merge_task_id}")
+            while True:
+                time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
+                task_status = self.client.tasks.get(task_id=force_merge_task_id)
+                if task_status["completed"]:
+                    return
```
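When `use_rescore` is enabled, the new search path oversamples the approximate kNN stage and then re-scores the candidates with an exact `cosineSimilarity` script, keeping only the rescore score. A condensed sketch of the request bodies built above; the `build_knn_bodies` helper name is illustrative:

```python
def build_knn_bodies(query: list[float], k: int, num_candidates: int,
                     vector_col: str, es_filter, use_rescore: bool,
                     oversample_ratio: float = 2.0):
    """Return (knn, rescore) request bodies shaped like the diffed search_embedding()."""
    if not use_rescore:
        knn = {"field": vector_col, "k": k, "num_candidates": num_candidates,
               "filter": es_filter, "query_vector": query}
        return knn, None

    oversample_k = int(k * oversample_ratio)
    knn = {
        "field": vector_col,
        "k": oversample_k,  # fetch more approximate hits than requested
        "num_candidates": int(num_candidates * oversample_ratio),
        "filter": es_filter,
        "query_vector": query,
    }
    rescore = {
        "window_size": oversample_k,  # re-score every oversampled hit
        "query": {
            "rescore_query": {
                "script_score": {
                    "query": {"match_all": {}},
                    "script": {
                        "source": f"cosineSimilarity(params.queryVector, '{vector_col}')",
                        "params": {"queryVector": query},
                    },
                }
            },
            "query_weight": 0,          # drop the approximate score
            "rescore_query_weight": 1,  # keep only the exact score
        },
    }
    return knn, rescore


knn, rescore = build_knn_bodies([0.1, 0.2], k=10, num_candidates=100,
                                vector_col="vector", es_filter=[], use_rescore=True)
print(knn["k"], rescore["window_size"])  # both oversampled to 20
```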
vectordb_bench/backend/clients/lancedb/cli.py

```diff
@@ -58,10 +58,46 @@ def LanceDBAutoIndex(**parameters: Unpack[LanceDBTypedDict]):
     )


+class LanceDBIVFPQTypedDict(CommonTypedDict, LanceDBTypedDict):
+    num_partitions: Annotated[
+        int,
+        click.option(
+            "--num-partitions",
+            type=int,
+            default=0,
+            help="Number of partitions for IVFPQ index, unset = use LanceDB default",
+        ),
+    ]
+    num_sub_vectors: Annotated[
+        int,
+        click.option(
+            "--num-sub-vectors",
+            type=int,
+            default=0,
+            help="Number of sub-vectors for IVFPQ index, unset = use LanceDB default",
+        ),
+    ]
+    nbits: Annotated[
+        int,
+        click.option(
+            "--nbits",
+            type=int,
+            default=8,
+            help="Number of bits for IVFPQ index (must be 4 or 8), unset = use LanceDB default",
+        ),
+    ]
+    nprobes: Annotated[
+        int,
+        click.option(
+            "--nprobes", type=int, default=0, help="Number of probes for IVFPQ search, unset = use LanceDB default"
+        ),
+    ]
+
+
 @cli.command()
-@click_parameter_decorators_from_typed_dict(
-def LanceDBIVFPQ(**parameters: Unpack[
-    from .config import LanceDBConfig,
+@click_parameter_decorators_from_typed_dict(LanceDBIVFPQTypedDict)
+def LanceDBIVFPQ(**parameters: Unpack[LanceDBIVFPQTypedDict]):
+    from .config import LanceDBConfig, LanceDBIndexConfig

     run(
         db=DB.LanceDB,
```
```diff
@@ -70,15 +106,29 @@ def LanceDBIVFPQ(**parameters: Unpack[LanceDBTypedDict]):
             uri=parameters["uri"],
             token=SecretStr(parameters["token"]) if parameters.get("token") else None,
         ),
-        db_case_config=
+        db_case_config=LanceDBIndexConfig(
+            index=IndexType.IVFPQ,
+            num_partitions=parameters["num_partitions"],
+            num_sub_vectors=parameters["num_sub_vectors"],
+            nbits=parameters["nbits"],
+            nprobes=parameters["nprobes"],
+        ),
         **parameters,
     )


+class LanceDBHNSWTypedDict(CommonTypedDict, LanceDBTypedDict):
+    m: Annotated[int, click.option("--m", type=int, default=0, help="HNSW parameter m")]
+    ef_construction: Annotated[
+        int, click.option("--ef-construction", type=int, default=0, help="HNSW parameter ef_construction")
+    ]
+    ef: Annotated[int, click.option("--ef", type=int, default=0, help="HNSW search parameter ef")]
+
+
 @cli.command()
-@click_parameter_decorators_from_typed_dict(
-def LanceDBHNSW(**parameters: Unpack[
-    from .config import LanceDBConfig,
+@click_parameter_decorators_from_typed_dict(LanceDBHNSWTypedDict)
+def LanceDBHNSW(**parameters: Unpack[LanceDBHNSWTypedDict]):
+    from .config import LanceDBConfig, LanceDBHNSWIndexConfig

     run(
         db=DB.LanceDB,
```
```diff
@@ -87,6 +137,10 @@ def LanceDBHNSW(**parameters: Unpack[LanceDBTypedDict]):
             uri=parameters["uri"],
             token=SecretStr(parameters["token"]) if parameters.get("token") else None,
         ),
-        db_case_config=
+        db_case_config=LanceDBHNSWIndexConfig(
+            m=parameters["m"],
+            ef_construction=parameters["ef_construction"],
+            ef=parameters["ef"],
+        ),
         **parameters,
     )
```
vectordb_bench/backend/clients/lancedb/config.py

```diff
@@ -25,6 +25,7 @@ class LanceDBIndexConfig(BaseModel, DBCaseConfig):
     nbits: int = 8 # Must be 4 or 8
     sample_rate: int = 256
     max_iterations: int = 50
+    nprobes: int = 0

     def index_param(self) -> dict:
         if self.index not in [
```
```diff
@@ -52,7 +53,11 @@ class LanceDBIndexConfig(BaseModel, DBCaseConfig):
         return params

     def search_param(self) -> dict:
-
+        params = {}
+        if self.nprobes > 0:
+            params["nprobes"] = self.nprobes
+
+        return params

     def parse_metric(self) -> str:
         if self.metric_type in [MetricType.L2, MetricType.COSINE]:
```
```diff
@@ -81,6 +86,7 @@ class LanceDBHNSWIndexConfig(LanceDBIndexConfig):
     index: IndexType = IndexType.HNSW
     m: int = 0
     ef_construction: int = 0
+    ef: int = 0

     def index_param(self) -> dict:
         params = LanceDBIndexConfig.index_param(self)
```
```diff
@@ -94,6 +100,13 @@ class LanceDBHNSWIndexConfig(LanceDBIndexConfig):

         return params

+    def search_param(self) -> dict:
+        params = {}
+        if self.ef != 0:
+            params = {"ef": self.ef}
+
+        return params
+

 _lancedb_case_config = {
     IndexType.IVFPQ: LanceDBIndexConfig,
```
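In both LanceDB `search_param()` implementations a value of 0 means "leave the knob unset and fall back to LanceDB's default". A compact sketch of that convention (the function names here are illustrative):

```python
def ivfpq_search_param(nprobes: int = 0) -> dict:
    # 0 -> no override, let LanceDB pick its default
    return {"nprobes": nprobes} if nprobes > 0 else {}


def hnsw_search_param(ef: int = 0) -> dict:
    return {"ef": ef} if ef != 0 else {}


assert ivfpq_search_param() == {}
assert ivfpq_search_param(20) == {"nprobes": 20}
assert hnsw_search_param(64) == {"ef": 64}
```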
vectordb_bench/backend/clients/lancedb/lancedb.py

```diff
@@ -32,6 +32,10 @@ class LanceDB(VectorDB):
         self.table_name = collection_name
         self.dim = dim
         self.uri = db_config["uri"]
+        # avoid the search_param being called every time during the search process
+        self.search_config = db_case_config.search_param()
+
+        log.info(f"Search config: {self.search_config}")

         db = lancedb.connect(self.uri)

```
```diff
@@ -45,7 +49,7 @@ class LanceDB(VectorDB):
             db.open_table(self.table_name)
         except Exception:
             schema = pa.schema(
-                [pa.field("id", pa.int64()), pa.field("vector", pa.list_(pa.
+                [pa.field("id", pa.int64()), pa.field("vector", pa.list_(pa.float32(), list_size=self.dim))]
             )
             db.create_table(self.table_name, schema=schema, mode="overwrite")

```
```diff
@@ -77,20 +81,28 @@ class LanceDB(VectorDB):
         filters: dict | None = None,
     ) -> list[int]:
         if filters:
-            results = (
-
-                .
-
-                .
-
-
+            results = self.table.search(query).select(["id"]).where(f"id >= {filters['id']}", prefilter=True).limit(k)
+            if self.case_config.index == IndexType.IVFPQ and "nprobes" in self.search_config:
+                results = results.nprobes(self.search_config["nprobes"]).to_list()
+            elif self.case_config.index == IndexType.HNSW and "ef" in self.search_config:
+                results = results.ef(self.search_config["ef"]).to_list()
+            else:
+                results = results.to_list()
         else:
-            results = self.table.search(query).select(["id"]).limit(k)
+            results = self.table.search(query).select(["id"]).limit(k)
+            if self.case_config.index == IndexType.IVFPQ and "nprobes" in self.search_config:
+                results = results.nprobes(self.search_config["nprobes"]).to_list()
+            elif self.case_config.index == IndexType.HNSW and "ef" in self.search_config:
+                results = results.ef(self.search_config["ef"]).to_list()
+            else:
+                results = results.to_list()
+
         return [int(result["id"]) for result in results]

     def optimize(self, data_size: int | None = None):
         if self.table and hasattr(self, "case_config") and self.case_config.index != IndexType.NONE:
             log.info(f"Creating index for LanceDB table ({self.table_name})")
+            log.info(f"Index parameters: {self.case_config.index_param()}")
             self.table.create_index(**self.case_config.index_param())
             # Better recall with IVF_PQ (though still bad) but breaks HNSW: https://github.com/lancedb/lancedb/issues/2369
             if self.case_config.index in (IndexType.IVFPQ, IndexType.AUTOINDEX):
```
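The repeated if/elif chains in `search_embedding` apply the cached `search_config` to LanceDB's query builder before materializing results. A sketch of that dispatch as one helper; it assumes a query builder exposing `.nprobes()`, `.ef()`, and `.to_list()` exactly as the diff uses them, and takes the index type as a plain string for self-containment:

```python
def apply_search_config(query_builder, index_type: str, search_config: dict) -> list[dict]:
    """Apply per-index search knobs from search_config, then materialize the hits."""
    if index_type == "IVFPQ" and "nprobes" in search_config:
        query_builder = query_builder.nprobes(search_config["nprobes"])
    elif index_type == "HNSW" and "ef" in search_config:
        query_builder = query_builder.ef(search_config["ef"])
    return query_builder.to_list()


# Hypothetical call, given an open LanceDB table:
# hits = apply_search_config(table.search(query).select(["id"]).limit(k), "HNSW", {"ef": 64})
# ids = [int(h["id"]) for h in hits]
```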
vectordb_bench/backend/clients/memorydb/memorydb.py

```diff
@@ -9,10 +9,10 @@ import redis
 from redis import Redis
 from redis.cluster import RedisCluster
 from redis.commands.search.field import NumericField, TagField, VectorField
-from redis.commands.search.indexDefinition import IndexDefinition
+from redis.commands.search.indexDefinition import IndexDefinition, IndexType
 from redis.commands.search.query import Query

-from ..api import
+from ..api import VectorDB
 from .config import MemoryDBIndexConfig

 log = logging.getLogger(__name__)
```
|