vectordb-bench 0.0.1__1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +30 -0
- vectordb_bench/__main__.py +39 -0
- vectordb_bench/backend/__init__.py +0 -0
- vectordb_bench/backend/assembler.py +57 -0
- vectordb_bench/backend/cases.py +124 -0
- vectordb_bench/backend/clients/__init__.py +57 -0
- vectordb_bench/backend/clients/api.py +179 -0
- vectordb_bench/backend/clients/elastic_cloud/config.py +56 -0
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +152 -0
- vectordb_bench/backend/clients/milvus/config.py +123 -0
- vectordb_bench/backend/clients/milvus/milvus.py +182 -0
- vectordb_bench/backend/clients/pinecone/config.py +15 -0
- vectordb_bench/backend/clients/pinecone/pinecone.py +113 -0
- vectordb_bench/backend/clients/qdrant_cloud/config.py +16 -0
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +169 -0
- vectordb_bench/backend/clients/weaviate_cloud/config.py +45 -0
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +151 -0
- vectordb_bench/backend/clients/zilliz_cloud/config.py +34 -0
- vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +35 -0
- vectordb_bench/backend/dataset.py +393 -0
- vectordb_bench/backend/result_collector.py +15 -0
- vectordb_bench/backend/runner/__init__.py +12 -0
- vectordb_bench/backend/runner/mp_runner.py +124 -0
- vectordb_bench/backend/runner/serial_runner.py +164 -0
- vectordb_bench/backend/task_runner.py +290 -0
- vectordb_bench/backend/utils.py +85 -0
- vectordb_bench/base.py +6 -0
- vectordb_bench/frontend/components/check_results/charts.py +175 -0
- vectordb_bench/frontend/components/check_results/data.py +86 -0
- vectordb_bench/frontend/components/check_results/filters.py +97 -0
- vectordb_bench/frontend/components/check_results/headerIcon.py +18 -0
- vectordb_bench/frontend/components/check_results/nav.py +21 -0
- vectordb_bench/frontend/components/check_results/priceTable.py +48 -0
- vectordb_bench/frontend/components/run_test/autoRefresh.py +10 -0
- vectordb_bench/frontend/components/run_test/caseSelector.py +87 -0
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +47 -0
- vectordb_bench/frontend/components/run_test/dbSelector.py +36 -0
- vectordb_bench/frontend/components/run_test/generateTasks.py +21 -0
- vectordb_bench/frontend/components/run_test/hideSidebar.py +10 -0
- vectordb_bench/frontend/components/run_test/submitTask.py +69 -0
- vectordb_bench/frontend/const.py +391 -0
- vectordb_bench/frontend/pages/qps_with_price.py +60 -0
- vectordb_bench/frontend/pages/run_test.py +59 -0
- vectordb_bench/frontend/utils.py +6 -0
- vectordb_bench/frontend/vdb_benchmark.py +42 -0
- vectordb_bench/interface.py +239 -0
- vectordb_bench/log_util.py +103 -0
- vectordb_bench/metric.py +53 -0
- vectordb_bench/models.py +234 -0
- vectordb_bench/results/result_20230609_standard.json +3228 -0
- vectordb_bench-0.0.1.dist-info/LICENSE +21 -0
- vectordb_bench-0.0.1.dist-info/METADATA +226 -0
- vectordb_bench-0.0.1.dist-info/RECORD +56 -0
- vectordb_bench-0.0.1.dist-info/WHEEL +5 -0
- vectordb_bench-0.0.1.dist-info/entry_points.txt +2 -0
- vectordb_bench-0.0.1.dist-info/top_level.txt +1 -0
vectordb_bench/backend/clients/milvus/config.py
@@ -0,0 +1,123 @@
from pydantic import BaseModel, SecretStr
from ..api import DBConfig, DBCaseConfig, MetricType, IndexType


class MilvusConfig(DBConfig, BaseModel):
    uri: SecretStr | None = "http://localhost:19530"

    def to_dict(self) -> dict:
        return {"uri": self.uri.get_secret_value()}


class MilvusIndexConfig(BaseModel):
    """Base config for milvus"""

    index: IndexType
    metric_type: MetricType | None = None

    def parse_metric(self) -> str:
        if not self.metric_type:
            return ""

        if self.metric_type == MetricType.COSINE:
            return MetricType.L2.value
        return self.metric_type.value


class AutoIndexConfig(MilvusIndexConfig, DBCaseConfig):
    index: IndexType = IndexType.AUTOINDEX

    def index_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "index_type": self.index.value,
            "params": {},
        }

    def search_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
        }


class HNSWConfig(MilvusIndexConfig, DBCaseConfig):
    M: int
    efConstruction: int
    ef: int | None = None
    index: IndexType = IndexType.HNSW

    def index_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "index_type": self.index.value,
            "params": {"M": self.M, "efConstruction": self.efConstruction},
        }

    def search_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "params": {"ef": self.ef},
        }


class DISKANNConfig(MilvusIndexConfig, DBCaseConfig):
    search_list: int | None = None
    index: IndexType = IndexType.DISKANN

    def index_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "index_type": self.index.value,
            "params": {},
        }

    def search_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "params": {"search_list": self.search_list},
        }


class IVFFlatConfig(MilvusIndexConfig, DBCaseConfig):
    nlist: int
    nprobe: int | None = None
    index: IndexType = IndexType.IVFFlat

    def index_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "index_type": self.index.value,
            "params": {"nlist": self.nlist},
        }

    def search_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "params": {"nprobe": self.nprobe},
        }


class FLATConfig(MilvusIndexConfig, DBCaseConfig):
    index: IndexType = IndexType.Flat

    def index_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "index_type": self.index.value,
            "params": {},
        }

    def search_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "params": {},
        }


_milvus_case_config = {
    IndexType.AUTOINDEX: AutoIndexConfig,
    IndexType.HNSW: HNSWConfig,
    IndexType.DISKANN: DISKANNConfig,
    IndexType.IVFFlat: IVFFlatConfig,
    IndexType.Flat: FLATConfig,
}
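For readers skimming the diff, a minimal sketch of how these case configs are meant to be consumed: a config object is built once, and its `index_param()` / `search_param()` dicts are handed to pymilvus. The numeric values below are illustrative, not package defaults, and the sketch assumes `MetricType`/`IndexType` behave as ordinary enums and that `DBCaseConfig` adds no required fields of its own.

```python
# Illustrative values only; not defaults shipped with the package.
from vectordb_bench.backend.clients.api import IndexType, MetricType
from vectordb_bench.backend.clients.milvus.config import HNSWConfig, _milvus_case_config

case_config = HNSWConfig(M=16, efConstruction=200, ef=64, metric_type=MetricType.L2)
print(case_config.index_param())   # {"metric_type": ..., "index_type": ..., "params": {"M": 16, "efConstruction": 200}}
print(case_config.search_param())  # {"metric_type": ..., "params": {"ef": 64}}

# The same class can also be resolved from the index type via the lookup table:
assert _milvus_case_config[IndexType.HNSW] is HNSWConfig
```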
vectordb_bench/backend/clients/milvus/milvus.py
@@ -0,0 +1,182 @@
"""Wrapper around the Milvus vector database over VectorDB"""

import logging
from contextlib import contextmanager
from typing import Any, Iterable, Type

from pymilvus import Collection, utility
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusException

from ..api import VectorDB, DBCaseConfig, DBConfig, IndexType
from .config import MilvusConfig, _milvus_case_config


log = logging.getLogger(__name__)


class Milvus(VectorDB):
    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: DBCaseConfig,
        collection_name: str = "VectorDBBenchCollection",
        drop_old: bool = False,
        name: str = "Milvus",
    ):
        """Initialize wrapper around the milvus vector database."""
        self.name = name
        self.db_config = db_config
        self.case_config = db_case_config
        self.collection_name = collection_name

        self._primary_field = "pk"
        self._scalar_field = "id"
        self._vector_field = "vector"
        self._index_name = "vector_idx"

        from pymilvus import connections
        connections.connect(**self.db_config, timeout=30)
        if drop_old and utility.has_collection(self.collection_name):
            log.info(f"{self.name} client drop_old collection: {self.collection_name}")
            utility.drop_collection(self.collection_name)

        if not utility.has_collection(self.collection_name):
            fields = [
                FieldSchema(self._primary_field, DataType.INT64, is_primary=True),
                FieldSchema(self._scalar_field, DataType.INT64),
                FieldSchema(self._vector_field, DataType.FLOAT_VECTOR, dim=dim)
            ]

            log.info(f"{self.name} create collection: {self.collection_name}")

            # Create the collection
            coll = Collection(
                name=self.collection_name,
                schema=CollectionSchema(fields),
                consistency_level="Session",
            )

            # self._pre_load(coll)

        connections.disconnect("default")

    @classmethod
    def config_cls(cls) -> Type[DBConfig]:
        return MilvusConfig

    @classmethod
    def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
        return _milvus_case_config.get(index_type)


    @contextmanager
    def init(self) -> None:
        """
        Examples:
            >>> with self.init():
            >>>     self.insert_embeddings()
            >>>     self.search_embedding()
        """
        from pymilvus import connections
        self.col: Collection | None = None

        connections.connect(**self.db_config, timeout=60)
        # Grab the existing collection with connections
        self.col = Collection(self.collection_name)

        yield
        connections.disconnect("default")

    def _pre_load(self, coll: Collection):
        if not coll.has_index(index_name=self._index_name):
            log.info(f"{self.name} create index and load")
            try:
                coll.create_index(
                    self._vector_field,
                    self.case_config.index_param(),
                    index_name=self._index_name,
                )

                coll.load()
            except Exception as e:
                log.warning(f"{self.name} pre load error: {e}")
                raise e from None

    def _optimize(self):
        log.info(f"{self.name} optimizing before search")
        try:
            self.col.flush()
            self.col.compact()
            self.col.wait_for_compaction_completed()

            # wait for index done and load refresh
            self.col.create_index(
                self._vector_field,
                self.case_config.index_param(),
                index_name=self._index_name,
            )
            utility.wait_for_index_building_complete(self.collection_name)
            self.col.load()
            # self.col.load(_refresh=True)
            # utility.wait_for_loading_complete(self.collection_name)
            # import time; time.sleep(10)
        except Exception as e:
            log.warning(f"{self.name} optimize error: {e}")
            raise e from None

    def ready_to_load(self):
        assert self.col, "Please call self.init() before"
        self._pre_load(self.col)
        pass

    def ready_to_search(self):
        assert self.col, "Please call self.init() before"
        self._optimize()

    def insert_embeddings(
        self,
        embeddings: Iterable[list[float]],
        metadata: list[int],
        **kwargs: Any,
    ) -> int:
        """Insert embeddings into Milvus. should call self.init() first"""
        # use the first insert_embeddings to init collection
        assert self.col is not None
        insert_data = [
            metadata,
            metadata,
            embeddings,
        ]

        try:
            res = self.col.insert(insert_data, **kwargs)
            return len(res.primary_keys)
        except MilvusException as e:
            log.warning("Failed to insert data")
            raise e from None

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
    ) -> list[int]:
        """Perform a search on a query embedding and return results."""
        assert self.col is not None

        expr = f"{self._scalar_field} {filters.get('metadata')}" if filters else ""

        # Perform the search.
        res = self.col.search(
            data=[query],
            anns_field=self._vector_field,
            param=self.case_config.search_param(),
            limit=k,
            expr=expr,
        )

        # Organize results.
        ret = [result.id for result in res[0]]
        return ret
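A rough end-to-end sketch of the call order the benchmark follows with this wrapper: construct, enter `init()`, prepare, insert, optimize, then search. The endpoint, dimension, index settings, and filter string are placeholders, not values shipped with the package, and a running Milvus instance is assumed.

```python
# Hypothetical driver; endpoint and parameters are placeholders.
from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.backend.clients.milvus.config import HNSWConfig
from vectordb_bench.backend.clients.milvus.milvus import Milvus

db = Milvus(
    dim=128,
    db_config={"uri": "http://localhost:19530"},
    db_case_config=HNSWConfig(M=16, efConstruction=200, ef=64, metric_type=MetricType.L2),
    drop_old=True,
)

vectors = [[0.0] * 128 for _ in range(1_000)]
ids = list(range(1_000))

with db.init():                    # connect and bind self.col
    db.ready_to_load()             # create the index and load before ingestion
    db.insert_embeddings(vectors, ids)
    db.ready_to_search()           # flush, compact, re-index, load
    # filters["metadata"] is spliced into a Milvus expr, here "id >= 500"
    hits = db.search_embedding(vectors[0], k=10, filters={"metadata": ">= 500"})
```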
vectordb_bench/backend/clients/pinecone/config.py
@@ -0,0 +1,15 @@
from pydantic import BaseModel, SecretStr
from ..api import DBConfig


class PineconeConfig(DBConfig, BaseModel):
    api_key: SecretStr | None = None
    environment: SecretStr | None = None
    index_name: str

    def to_dict(self) -> dict:
        return {
            "api_key": self.api_key.get_secret_value(),
            "environment": self.environment.get_secret_value(),
            "index_name": self.index_name,
        }
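The config classes in this package share one pattern worth calling out: credentials are typed as pydantic `SecretStr`, so they stay masked in reprs and logs, and only `to_dict()` unwraps them for the client library. A small sketch with placeholder values, assuming `DBConfig` itself adds no required fields:

```python
from vectordb_bench.backend.clients.pinecone.config import PineconeConfig

cfg = PineconeConfig(api_key="pc-xxxx", environment="us-west1-gcp", index_name="bench-index")
print(cfg)            # api_key and environment render masked, e.g. SecretStr('**********')
print(cfg.to_dict())  # plaintext values, ready to hand to the Pinecone client
```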
vectordb_bench/backend/clients/pinecone/pinecone.py
@@ -0,0 +1,113 @@
"""Wrapper around the Pinecone vector database over VectorDB"""

import logging
from contextlib import contextmanager
from typing import Any, Type

from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
from .config import PineconeConfig


log = logging.getLogger(__name__)

PINECONE_MAX_NUM_PER_BATCH = 1000
PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024  # 2MB

class Pinecone(VectorDB):
    def __init__(
        self,
        dim,
        db_config: dict,
        db_case_config: DBCaseConfig,
        drop_old: bool = False,
    ):
        """Initialize wrapper around the Pinecone vector database."""
        self.index_name = db_config["index_name"]
        self.api_key = db_config["api_key"]
        self.environment = db_config["environment"]
        self.batch_size = int(min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH))
        # Pinecone connects to the server on import,
        # so place the import here.
        import pinecone
        pinecone.init(
            api_key=self.api_key, environment=self.environment)
        if drop_old:
            list_indexes = pinecone.list_indexes()
            if self.index_name in list_indexes:
                index = pinecone.Index(self.index_name)
                index_dim = index.describe_index_stats()["dimension"]
                if (index_dim != dim):
                    raise ValueError(
                        f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}")
                log.info(
                    f"Pinecone client delete old index: {self.index_name}")
                index.delete(delete_all=True)
                index.close()
            else:
                raise ValueError(
                    f"Pinecone index {self.index_name} does not exist")

        self._metadata_key = "meta"

    @classmethod
    def config_cls(cls) -> Type[DBConfig]:
        return PineconeConfig

    @classmethod
    def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
        return EmptyDBCaseConfig

    @contextmanager
    def init(self) -> None:
        import pinecone
        pinecone.init(
            api_key=self.api_key, environment=self.environment)
        self.index = pinecone.Index(self.index_name)
        yield
        self.index.close()

    def ready_to_load(self):
        pass

    def ready_to_search(self):
        pass

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
    ) -> list[str]:
        assert len(embeddings) == len(metadata)
        for batch_start_offset in range(0, len(embeddings), self.batch_size):
            batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
            insert_datas = []
            for i in range(batch_start_offset, batch_end_offset):
                insert_data = (str(metadata[i]), embeddings[i], {
                    self._metadata_key: metadata[i]})
                insert_datas.append(insert_data)
            self.index.upsert(insert_datas)
        return len(embeddings)

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
        **kwargs: Any,
    ) -> list[tuple[int, float]]:
        if filters is None:
            pinecone_filters = {}
        else:
            pinecone_filters = {self._metadata_key: {"$gte": filters["id"]}}
        try:
            res = self.index.query(
                top_k=k,
                vector=query,
                filter=pinecone_filters,
            )['matches']
        except Exception as e:
            print(f"Error querying index: {e}")
            raise e
        id_res = [int(one_res['id']) for one_res in res]
        return id_res
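A rough usage sketch for the wrapper above. The index name, API key, and environment are placeholders, the `EmptyDBCaseConfig` instantiation assumes that class takes no required arguments, and the target Pinecone index must already exist with a matching dimension, since the wrapper deletes and reuses indexes but never creates them.

```python
from vectordb_bench.backend.clients.api import EmptyDBCaseConfig
from vectordb_bench.backend.clients.pinecone.pinecone import Pinecone

db = Pinecone(
    dim=128,
    db_config={
        "index_name": "bench-index",      # hypothetical, must already exist
        "api_key": "YOUR_API_KEY",
        "environment": "us-west1-gcp",
    },
    db_case_config=EmptyDBCaseConfig(),
    drop_old=True,                        # wipes the existing index contents
)

with db.init():
    db.insert_embeddings([[0.0] * 128] * 10, list(range(10)))
    ids = db.search_embedding([0.0] * 128, k=5, filters={"id": 3})  # meta >= 3
```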
vectordb_bench/backend/clients/qdrant_cloud/config.py
@@ -0,0 +1,16 @@
from pydantic import BaseModel, SecretStr

from ..api import DBConfig


class QdrantConfig(DBConfig, BaseModel):
    url: SecretStr | None = None
    api_key: SecretStr | None = None
    prefer_grpc: bool = True

    def to_dict(self) -> dict:
        return {
            "url": self.url.get_secret_value(),
            "api_key": self.api_key.get_secret_value(),
            "prefer_grpc": self.prefer_grpc,
        }
vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py
@@ -0,0 +1,169 @@
"""Wrapper around the QdrantCloud vector database over VectorDB"""

import logging
import time
from contextlib import contextmanager
from typing import Any, Type

from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
from .config import QdrantConfig
from qdrant_client.http.models import (
    CollectionStatus,
    Distance,
    VectorParams,
    PayloadSchemaType,
    Batch,
    Filter,
    FieldCondition,
    Range,
)

from qdrant_client import QdrantClient


log = logging.getLogger(__name__)


class QdrantCloud(VectorDB):
    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: DBCaseConfig,
        collection_name: str = "QdrantCloudCollection",
        drop_old: bool = False,
    ):
        """Initialize wrapper around the QdrantCloud vector database."""
        self.db_config = db_config
        self.case_config = db_case_config
        self.collection_name = collection_name

        self._primary_field = "pk"
        self._vector_field = "vector"

        tmp_client = QdrantClient(**self.db_config)
        if drop_old:
            log.info(f"QdrantCloud client drop_old collection: {self.collection_name}")
            tmp_client.delete_collection(self.collection_name)

        self._create_collection(dim, tmp_client)
        tmp_client = None

    @classmethod
    def config_cls(cls) -> Type[DBConfig]:
        return QdrantConfig

    @classmethod
    def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
        return EmptyDBCaseConfig

    @contextmanager
    def init(self) -> None:
        """
        Examples:
            >>> with self.init():
            >>>     self.insert_embeddings()
            >>>     self.search_embedding()
        """
        self.qdrant_client = QdrantClient(**self.db_config)
        yield
        self.qdrant_client = None
        del(self.qdrant_client)

    def ready_to_load(self):
        pass


    def ready_to_search(self):
        assert self.qdrant_client, "Please call self.init() before"
        # wait for vectors to be fully indexed
        SECONDS_WAITING_FOR_INDEXING_API_CALL = 5
        try:
            while True:
                info = self.qdrant_client.get_collection(self.collection_name)
                time.sleep(SECONDS_WAITING_FOR_INDEXING_API_CALL)
                if info.status != CollectionStatus.GREEN:
                    continue
                if info.status == CollectionStatus.GREEN:
                    log.info(f"Stored vectors: {info.vectors_count}, Indexed vectors: {info.indexed_vectors_count}, Collection status: {info.indexed_vectors_count}")
                    return
        except Exception as e:
            log.warning(f"QdrantCloud ready to search error: {e}")
            raise e from None

    def _create_collection(self, dim, qdrant_client: QdrantClient):
        log.info(f"Create collection: {self.collection_name}")

        try:
            qdrant_client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=dim, distance=Distance.EUCLID)
            )

            qdrant_client.create_payload_index(
                collection_name=self.collection_name,
                field_name=self._primary_field,
                field_schema=PayloadSchemaType.INTEGER,
            )

        except Exception as e:
            if "already exists!" in str(e):
                return
            log.warning(f"Failed to create collection: {self.collection_name} error: {e}")
            raise e from None

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs: Any,
    ) -> list[str]:
        """Insert embeddings into QdrantCloud. Should call self.init() first."""
        assert self.qdrant_client is not None
        try:
            # TODO: counts
            _ = self.qdrant_client.upsert(
                collection_name=self.collection_name,
                wait=True,
                points=Batch(ids=metadata, payloads=[{self._primary_field: v} for v in metadata], vectors=embeddings)
            )

            return len(metadata)
        except Exception as e:
            log.info(f"Failed to insert data, {e}")
            raise e from None

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
        **kwargs: Any,
    ) -> list[int]:
        """Perform a search on a query embedding and return results with score.
        Should call self.init() first.
        """
        assert self.qdrant_client is not None

        f = None
        if filters:
            f = Filter(
                must=[FieldCondition(
                    key=self._primary_field,
                    range=Range(
                        gt=filters.get('id'),
                    ),
                )]
            )

        res = self.qdrant_client.search(
            collection_name=self.collection_name,
            query_vector=query,
            limit=k,
            query_filter=f,
            # with_payload=True,
        ),

        ret = [result.id for result in res[0]]
        return ret
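A comparable sketch for the QdrantCloud wrapper; the URL and API key are placeholders, and `EmptyDBCaseConfig` is assumed to take no required arguments. Note that `ready_to_search()` polls `get_collection()` every few seconds until the collection reports GREEN, so it can block for a while after a large load.

```python
from vectordb_bench.backend.clients.api import EmptyDBCaseConfig
from vectordb_bench.backend.clients.qdrant_cloud.qdrant_cloud import QdrantCloud

db = QdrantCloud(
    dim=128,
    db_config={
        "url": "https://example.cloud.qdrant.io:6333",   # placeholder endpoint
        "api_key": "YOUR_API_KEY",
        "prefer_grpc": True,
    },
    db_case_config=EmptyDBCaseConfig(),
    drop_old=True,
)

with db.init():
    db.insert_embeddings([[0.0] * 128] * 100, list(range(100)))
    db.ready_to_search()                                   # wait until indexing is GREEN
    ids = db.search_embedding([0.0] * 128, k=10, filters={"id": 50})  # pk > 50
```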
vectordb_bench/backend/clients/weaviate_cloud/config.py
@@ -0,0 +1,45 @@
from pydantic import BaseModel, SecretStr
import weaviate

from ..api import DBConfig, DBCaseConfig, MetricType


class WeaviateConfig(DBConfig, BaseModel):
    url: SecretStr | None = None
    api_key: SecretStr | None = None

    def to_dict(self) -> dict:
        return {
            "url": self.url.get_secret_value(),
            "auth_client_secret": weaviate.AuthApiKey(api_key=self.api_key.get_secret_value()),
        }


class WeaviateIndexConfig(BaseModel, DBCaseConfig):
    metric_type: MetricType | None = None
    ef: int | None = -1
    efConstruction: int | None = None
    maxConnections: int | None = None

    def parse_metric(self) -> str:
        if self.metric_type == MetricType.L2:
            return "l2-squared"
        elif self.metric_type == MetricType.IP:
            return "dot"
        return "cosine"

    def index_param(self) -> dict:
        if self.maxConnections is not None and self.efConstruction is not None:
            params = {
                "distance": self.parse_metric(),
                "maxConnections": self.maxConnections,
                "efConstruction": self.efConstruction,
            }
        else:
            params = {"distance": self.parse_metric()}
        return params

    def search_param(self) -> dict:
        return {
            "ef": self.ef,
        }
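Finally, a short sketch of how this Weaviate case config maps benchmark settings onto Weaviate's HNSW vector-index fields. The values are illustrative; when `maxConnections` or `efConstruction` is omitted, `index_param()` falls back to a distance-only dict, and it assumes `DBCaseConfig` adds no required fields.

```python
from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.backend.clients.weaviate_cloud.config import WeaviateIndexConfig

cfg = WeaviateIndexConfig(metric_type=MetricType.COSINE, ef=256, efConstruction=128, maxConnections=16)
print(cfg.index_param())   # {"distance": "cosine", "maxConnections": 16, "efConstruction": 128}
print(cfg.search_param())  # {"ef": 256}
```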