vectordb-bench 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +14 -3
- vectordb_bench/backend/assembler.py +2 -2
- vectordb_bench/backend/cases.py +146 -57
- vectordb_bench/backend/clients/__init__.py +6 -1
- vectordb_bench/backend/clients/api.py +23 -11
- vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +11 -9
- vectordb_bench/backend/clients/milvus/config.py +2 -3
- vectordb_bench/backend/clients/milvus/milvus.py +32 -19
- vectordb_bench/backend/clients/pgvector/config.py +49 -0
- vectordb_bench/backend/clients/pgvector/pgvector.py +171 -0
- vectordb_bench/backend/clients/pinecone/config.py +3 -3
- vectordb_bench/backend/clients/pinecone/pinecone.py +19 -13
- vectordb_bench/backend/clients/qdrant_cloud/config.py +23 -6
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +12 -13
- vectordb_bench/backend/clients/weaviate_cloud/config.py +3 -3
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +9 -8
- vectordb_bench/backend/clients/zilliz_cloud/config.py +5 -4
- vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +3 -1
- vectordb_bench/backend/dataset.py +100 -162
- vectordb_bench/backend/result_collector.py +2 -2
- vectordb_bench/backend/runner/mp_runner.py +29 -13
- vectordb_bench/backend/runner/serial_runner.py +98 -36
- vectordb_bench/backend/task_runner.py +43 -48
- vectordb_bench/frontend/components/check_results/charts.py +10 -21
- vectordb_bench/frontend/components/check_results/data.py +31 -15
- vectordb_bench/frontend/components/check_results/expanderStyle.py +37 -0
- vectordb_bench/frontend/components/check_results/filters.py +61 -33
- vectordb_bench/frontend/components/check_results/footer.py +8 -0
- vectordb_bench/frontend/components/check_results/headerIcon.py +8 -4
- vectordb_bench/frontend/components/check_results/nav.py +7 -6
- vectordb_bench/frontend/components/check_results/priceTable.py +3 -2
- vectordb_bench/frontend/components/check_results/stPageConfig.py +18 -0
- vectordb_bench/frontend/components/get_results/saveAsImage.py +50 -0
- vectordb_bench/frontend/components/run_test/autoRefresh.py +1 -1
- vectordb_bench/frontend/components/run_test/caseSelector.py +19 -16
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +20 -7
- vectordb_bench/frontend/components/run_test/dbSelector.py +5 -5
- vectordb_bench/frontend/components/run_test/hideSidebar.py +4 -6
- vectordb_bench/frontend/components/run_test/submitTask.py +16 -10
- vectordb_bench/frontend/const/dbCaseConfigs.py +291 -0
- vectordb_bench/frontend/const/dbPrices.py +6 -0
- vectordb_bench/frontend/const/styles.py +58 -0
- vectordb_bench/frontend/pages/{qps_with_price.py → quries_per_dollar.py} +24 -17
- vectordb_bench/frontend/pages/run_test.py +17 -11
- vectordb_bench/frontend/vdb_benchmark.py +19 -12
- vectordb_bench/metric.py +19 -10
- vectordb_bench/models.py +14 -40
- vectordb_bench/results/dbPrices.json +32 -0
- vectordb_bench/results/getLeaderboardData.py +52 -0
- vectordb_bench/results/leaderboard.json +1 -0
- vectordb_bench/results/{result_20230609_standard.json → result_20230705_standard.json} +1910 -897
- {vectordb_bench-0.0.1.dist-info → vectordb_bench-0.0.3.dist-info}/METADATA +107 -27
- vectordb_bench-0.0.3.dist-info/RECORD +67 -0
- vectordb_bench/frontend/const.py +0 -391
- vectordb_bench-0.0.1.dist-info/RECORD +0 -56
- {vectordb_bench-0.0.1.dist-info → vectordb_bench-0.0.3.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.1.dist-info → vectordb_bench-0.0.3.dist-info}/WHEEL +0 -0
- {vectordb_bench-0.0.1.dist-info → vectordb_bench-0.0.3.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.1.dist-info → vectordb_bench-0.0.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,49 @@
|
|
1
|
+
from pydantic import BaseModel, SecretStr
|
2
|
+
from ..api import DBConfig, DBCaseConfig, MetricType
|
3
|
+
|
4
|
+
# URL template filled as (user, password, url, db_name) by PgVectorConfig.to_dict.
POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"


class PgVectorConfig(DBConfig):
    """Connection settings for a PostgreSQL server with the pgvector extension."""

    # Default must itself be a SecretStr: pydantic does not validate defaults,
    # so a plain "postgres" string would lack get_secret_value().
    user_name: SecretStr = SecretStr("postgres")
    password: SecretStr
    url: SecretStr
    db_name: str

    def to_dict(self) -> dict:
        """Build the keyword arguments for ``sqlalchemy.create_engine``.

        Returns:
            ``{"url": "postgresql://<user>:<password>@<url>/<db_name>"}``. The
            user name and password are percent-encoded so characters such as
            ``@``, ``:`` or ``/`` in credentials cannot corrupt the URL.
        """
        from urllib.parse import quote

        # safe="" also encodes "/", and spaces become %20 (decoded correctly by
        # both unquote and unquote_plus).
        user_str = quote(self.user_name.get_secret_value(), safe="")
        pwd_str = quote(self.password.get_secret_value(), safe="")
        url_str = self.url.get_secret_value()
        return {
            "url": POSTGRE_URL_PLACEHOLDER % (user_str, pwd_str, url_str, self.db_name)
        }
|
19
|
+
|
20
|
+
class PgVectorIndexConfig(BaseModel, DBCaseConfig):
|
21
|
+
metric_type: MetricType | None = None
|
22
|
+
lists: int | None = 1000
|
23
|
+
probes: int | None = 10
|
24
|
+
|
25
|
+
def parse_metric(self) -> str:
|
26
|
+
if self.metric_type == MetricType.L2:
|
27
|
+
return "vector_l2_ops"
|
28
|
+
elif self.metric_type == MetricType.IP:
|
29
|
+
return "vector_ip_ops"
|
30
|
+
return "vector_cosine_ops"
|
31
|
+
|
32
|
+
def parse_metric_fun_str(self) -> str:
|
33
|
+
if self.metric_type == MetricType.L2:
|
34
|
+
return "l2_distance"
|
35
|
+
elif self.metric_type == MetricType.IP:
|
36
|
+
return "max_inner_product"
|
37
|
+
return "cosine_distance"
|
38
|
+
|
39
|
+
def index_param(self) -> dict:
|
40
|
+
return {
|
41
|
+
"lists" : self.lists,
|
42
|
+
"metric" : self.parse_metric()
|
43
|
+
}
|
44
|
+
|
45
|
+
def search_param(self) -> dict:
|
46
|
+
return {
|
47
|
+
"probes" : self.probes,
|
48
|
+
"metric_fun" : self.parse_metric_fun_str()
|
49
|
+
}
|
@@ -0,0 +1,171 @@
|
|
1
|
+
"""Wrapper around the Pgvector vector database over VectorDB"""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
import time
|
5
|
+
from contextlib import contextmanager
|
6
|
+
from typing import Any, Type
|
7
|
+
from functools import wraps
|
8
|
+
|
9
|
+
from ..api import VectorDB, DBConfig, DBCaseConfig, IndexType
|
10
|
+
from pgvector.sqlalchemy import Vector
|
11
|
+
from .config import PgVectorConfig, PgVectorIndexConfig
|
12
|
+
from sqlalchemy import (
|
13
|
+
MetaData,
|
14
|
+
create_engine,
|
15
|
+
insert,
|
16
|
+
select,
|
17
|
+
Index,
|
18
|
+
Table,
|
19
|
+
text,
|
20
|
+
Column,
|
21
|
+
Float,
|
22
|
+
Integer
|
23
|
+
)
|
24
|
+
from sqlalchemy.orm import (
|
25
|
+
declarative_base,
|
26
|
+
mapped_column,
|
27
|
+
Session
|
28
|
+
)
|
29
|
+
|
30
|
+
log = logging.getLogger(__name__)


class PgVector(VectorDB):
    """VectorDB client for PostgreSQL with the pgvector extension, via SQLAlchemy.

    The benchmark table has an integer primary-key column (``id``) and a
    ``Vector(dim)`` column (``embedding``); an ivfflat index is built on the
    vector column using the case config's ``index_param()``.
    """

    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: DBCaseConfig,
        collection_name: str = "PgVectorCollection",
        drop_old: bool = False,
        **kwargs,
    ):
        """Prepare the database: vector extension, optional drop, table and index.

        Args:
            dim: dimension of the ``Vector`` column.
            db_config: keyword arguments passed to ``sqlalchemy.create_engine``.
            db_case_config: case config providing ``index_param()`` and
                ``search_param()``.
            collection_name: name of the benchmark table.
            drop_old: drop the benchmark table (and with it its index) before
                recreating it.
            **kwargs: accepted for interface compatibility; unused.
        """
        self.db_config = db_config
        self.case_config = db_case_config
        self.table_name = collection_name
        self.dim = dim

        self._index_name = "pqvector_index"
        self._primary_field = "id"
        self._vector_field = "embedding"

        # Setup-only engine: init() creates its own, so dispose of this one
        # (and its connection pool) once setup is finished.
        pg_engine = create_engine(**self.db_config)
        try:
            pq_metadata = MetaData()
            pq_metadata.reflect(pg_engine)

            # create vector extension
            with pg_engine.connect() as conn:
                conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
                conn.commit()

            self.pg_table = self._get_table_schema(pq_metadata)
            if drop_old:
                log.info(f"Pgvector client drop table : {self.table_name}")
                # Drop only the benchmark table. MetaData.drop_all() would drop
                # every table reflected from the database, not just this one.
                self.pg_table.drop(pg_engine, checkfirst=True)
            self._create_table(dim, pg_engine)
        finally:
            pg_engine.dispose()

    @classmethod
    def config_cls(cls) -> Type[DBConfig]:
        return PgVectorConfig

    @classmethod
    def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
        return PgVectorIndexConfig

    @contextmanager
    def init(self) -> None:
        """Open an engine and ORM session for the duration of a ``with`` block.

        Examples:
            >>> with self.init():
            >>>     self.insert_embeddings()
            >>>     self.search_embedding()

        ``self.pg_engine`` and ``self.pg_session`` exist only inside the block.
        On exit -- including when the block raises -- the session is closed,
        the engine's pool is disposed, and both attributes are deleted.
        """
        self.pg_engine = create_engine(**self.db_config)
        self.pg_session = Session(self.pg_engine)
        try:
            pq_metadata = MetaData()
            pq_metadata.reflect(self.pg_engine)
            self.pg_table = self._get_table_schema(pq_metadata)
            yield
        finally:
            self.pg_session.close()
            self.pg_engine.dispose()
            del self.pg_session
            del self.pg_engine

    def ready_to_load(self):
        """No preparation is needed before loading."""
        pass

    def optimize(self):
        """No post-load optimization step; the index is built at table creation."""
        pass

    def ready_to_search(self):
        """No preparation is needed before searching."""
        pass

    def _get_table_schema(self, pq_metadata):
        """Declare (or redefine via ``extend_existing``) the benchmark table on ``pq_metadata``."""
        return Table(
            self.table_name,
            pq_metadata,
            Column(self._primary_field, Integer, primary_key=True),
            Column(self._vector_field, Vector(self.dim)),
            extend_existing=True,
        )

    def _create_index(self, pg_engine):
        """Drop (if present) and rebuild the ivfflat index on the vector column."""
        index_param = self.case_config.index_param()
        index = Index(
            self._index_name,
            self.pg_table.c[self._vector_field],
            postgresql_using="ivfflat",
            postgresql_with={"lists": index_param["lists"]},
            postgresql_ops={self._vector_field: index_param["metric"]},
        )
        index.drop(pg_engine, checkfirst=True)
        index.create(pg_engine)

    def _create_table(self, dim, pg_engine):
        """Create the benchmark table if missing, then rebuild its vector index.

        ``dim`` is unused; the column dimension comes from ``self.dim``.

        Raises:
            Any error from table/index creation, after logging it.
        """
        try:
            self.pg_table.create(bind=pg_engine, checkfirst=True)
            self._create_index(pg_engine)
        except Exception as e:
            log.warning(f"Failed to create pgvector table: {self.table_name} error: {e}")
            # Bare raise keeps the chained driver error; "raise e from None"
            # would reset e.__cause__ and hide it.
            raise

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs: Any,
    ) -> (int, Exception):
        """Insert ``(id, embedding)`` rows in a single ``insert()`` execution and commit.

        Must be called inside ``self.init()``.

        Returns:
            ``(len(metadata), None)`` on success, ``(0, error)`` on failure.
        """
        try:
            items = [
                {self._primary_field: pk, self._vector_field: emb}
                for pk, emb in zip(metadata, embeddings, strict=True)
            ]
            self.pg_session.execute(insert(self.pg_table), items)
            self.pg_session.commit()
            return len(metadata), None
        except Exception as e:
            # A failed statement leaves the transaction aborted; roll back so
            # the session stays usable for later inserts and searches.
            self.pg_session.rollback()
            log.warning(f"Failed to insert data into pgvector table ({self.table_name}), error: {e}")
            return 0, e

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
    ) -> list[int]:
        """Return the ids of the ``k`` rows nearest to ``query``.

        Must be called inside ``self.init()``. Ordering uses the comparator
        named by ``search_param()["metric_fun"]``; when ``filters`` is given,
        only rows with ``id > filters["id"]`` are considered. ``timeout`` is
        accepted for interface compatibility and unused.
        """
        assert self.pg_table is not None
        search_param = self.case_config.search_param()
        # SET is connection-scoped: run it through the session so it applies
        # to the same connection/transaction that executes the query below.
        self.pg_session.execute(
            text(f"SET ivfflat.probes = {int(search_param['probes'])}")
        )
        id_col = self.pg_table.c[self._primary_field]
        distance = getattr(self.pg_table.c[self._vector_field], search_param["metric_fun"])
        # Select only the id column: the vectors themselves are not returned.
        stmt = select(id_col).order_by(distance(query))
        if filters:
            # NOTE(review): strict ">" on filters["id"]; confirm this matches
            # the filter semantics used for ground truth in filtered cases.
            stmt = stmt.where(id_col > filters.get("id"))
        return list(self.pg_session.scalars(stmt.limit(k)))
|
171
|
+
|
@@ -2,9 +2,9 @@ from pydantic import BaseModel, SecretStr
|
|
2
2
|
from ..api import DBConfig
|
3
3
|
|
4
4
|
|
5
|
-
class PineconeConfig(DBConfig
|
6
|
-
api_key: SecretStr
|
7
|
-
environment: SecretStr
|
5
|
+
class PineconeConfig(DBConfig):
|
6
|
+
api_key: SecretStr
|
7
|
+
environment: SecretStr
|
8
8
|
index_name: str
|
9
9
|
|
10
10
|
def to_dict(self) -> dict:
|
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
import logging
|
4
4
|
from contextlib import contextmanager
|
5
|
-
from typing import
|
5
|
+
from typing import Type
|
6
6
|
|
7
7
|
from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
|
8
8
|
from .config import PineconeConfig
|
@@ -20,6 +20,7 @@ class Pinecone(VectorDB):
|
|
20
20
|
db_config: dict,
|
21
21
|
db_case_config: DBCaseConfig,
|
22
22
|
drop_old: bool = False,
|
23
|
+
**kwargs,
|
23
24
|
):
|
24
25
|
"""Initialize wrapper around the milvus vector database."""
|
25
26
|
self.index_name = db_config["index_name"]
|
@@ -69,24 +70,30 @@ class Pinecone(VectorDB):
|
|
69
70
|
def ready_to_load(self):
|
70
71
|
pass
|
71
72
|
|
72
|
-
def
|
73
|
+
def optimize(self):
|
73
74
|
pass
|
74
75
|
|
75
76
|
def insert_embeddings(
|
76
77
|
self,
|
77
78
|
embeddings: list[list[float]],
|
78
79
|
metadata: list[int],
|
79
|
-
|
80
|
+
**kwargs,
|
81
|
+
) -> (int, Exception):
|
80
82
|
assert len(embeddings) == len(metadata)
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
83
|
+
insert_count = 0
|
84
|
+
try:
|
85
|
+
for batch_start_offset in range(0, len(embeddings), self.batch_size):
|
86
|
+
batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
|
87
|
+
insert_datas = []
|
88
|
+
for i in range(batch_start_offset, batch_end_offset):
|
89
|
+
insert_data = (str(metadata[i]), embeddings[i], {
|
90
|
+
self._metadata_key: metadata[i]})
|
91
|
+
insert_datas.append(insert_data)
|
92
|
+
self.index.upsert(insert_datas)
|
93
|
+
insert_count += batch_end_offset - batch_start_offset
|
94
|
+
except Exception as e:
|
95
|
+
return (insert_count, e)
|
96
|
+
return (len(embeddings), None)
|
90
97
|
|
91
98
|
def search_embedding(
|
92
99
|
self,
|
@@ -94,7 +101,6 @@ class Pinecone(VectorDB):
|
|
94
101
|
k: int = 100,
|
95
102
|
filters: dict | None = None,
|
96
103
|
timeout: int | None = None,
|
97
|
-
**kwargs: Any,
|
98
104
|
) -> list[tuple[int, float]]:
|
99
105
|
if filters is None:
|
100
106
|
pinecone_filters = {}
|
@@ -1,16 +1,33 @@
|
|
1
1
|
from pydantic import BaseModel, SecretStr
|
2
2
|
|
3
|
-
from ..api import DBConfig
|
3
|
+
from ..api import DBConfig, DBCaseConfig, MetricType
|
4
|
+
from qdrant_client.models import Distance
|
4
5
|
|
5
6
|
|
6
|
-
class QdrantConfig(DBConfig
|
7
|
-
url: SecretStr
|
8
|
-
api_key: SecretStr
|
9
|
-
prefer_grpc: bool = True
|
7
|
+
class QdrantConfig(DBConfig):
|
8
|
+
url: SecretStr
|
9
|
+
api_key: SecretStr
|
10
10
|
|
11
11
|
def to_dict(self) -> dict:
|
12
12
|
return {
|
13
13
|
"url": self.url.get_secret_value(),
|
14
14
|
"api_key": self.api_key.get_secret_value(),
|
15
|
-
"prefer_grpc":
|
15
|
+
"prefer_grpc": True,
|
16
16
|
}
|
17
|
+
|
18
|
+
class QdrantIndexConfig(BaseModel, DBCaseConfig):
|
19
|
+
metric_type: MetricType | None = None
|
20
|
+
|
21
|
+
def parse_metric(self) -> str:
|
22
|
+
if self.metric_type == MetricType.L2:
|
23
|
+
return Distance.EUCLID
|
24
|
+
elif self.metric_type == MetricType.IP:
|
25
|
+
return Distance.DOT
|
26
|
+
return Distance.COSINE
|
27
|
+
|
28
|
+
def index_param(self) -> dict:
|
29
|
+
params = {"distance": self.parse_metric()}
|
30
|
+
return params
|
31
|
+
|
32
|
+
def search_param(self) -> dict:
|
33
|
+
return {}
|
@@ -3,13 +3,12 @@
|
|
3
3
|
import logging
|
4
4
|
import time
|
5
5
|
from contextlib import contextmanager
|
6
|
-
from typing import
|
6
|
+
from typing import Type
|
7
7
|
|
8
|
-
from ..api import VectorDB, DBConfig, DBCaseConfig,
|
9
|
-
from .config import QdrantConfig
|
8
|
+
from ..api import VectorDB, DBConfig, DBCaseConfig, IndexType
|
9
|
+
from .config import QdrantConfig, QdrantIndexConfig
|
10
10
|
from qdrant_client.http.models import (
|
11
11
|
CollectionStatus,
|
12
|
-
Distance,
|
13
12
|
VectorParams,
|
14
13
|
PayloadSchemaType,
|
15
14
|
Batch,
|
@@ -32,6 +31,7 @@ class QdrantCloud(VectorDB):
|
|
32
31
|
db_case_config: DBCaseConfig,
|
33
32
|
collection_name: str = "QdrantCloudCollection",
|
34
33
|
drop_old: bool = False,
|
34
|
+
**kwargs,
|
35
35
|
):
|
36
36
|
"""Initialize wrapper around the QdrantCloud vector database."""
|
37
37
|
self.db_config = db_config
|
@@ -55,7 +55,7 @@ class QdrantCloud(VectorDB):
|
|
55
55
|
|
56
56
|
@classmethod
|
57
57
|
def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
|
58
|
-
return
|
58
|
+
return QdrantIndexConfig
|
59
59
|
|
60
60
|
@contextmanager
|
61
61
|
def init(self) -> None:
|
@@ -74,7 +74,7 @@ class QdrantCloud(VectorDB):
|
|
74
74
|
pass
|
75
75
|
|
76
76
|
|
77
|
-
def
|
77
|
+
def optimize(self):
|
78
78
|
assert self.qdrant_client, "Please call self.init() before"
|
79
79
|
# wait for vectors to be fully indexed
|
80
80
|
SECONDS_WAITING_FOR_INDEXING_API_CALL = 5
|
@@ -97,7 +97,7 @@ class QdrantCloud(VectorDB):
|
|
97
97
|
try:
|
98
98
|
qdrant_client.create_collection(
|
99
99
|
collection_name=self.collection_name,
|
100
|
-
vectors_config=VectorParams(size=dim, distance=
|
100
|
+
vectors_config=VectorParams(size=dim, distance=self.case_config.index_param()["distance"])
|
101
101
|
)
|
102
102
|
|
103
103
|
qdrant_client.create_payload_index(
|
@@ -116,8 +116,8 @@ class QdrantCloud(VectorDB):
|
|
116
116
|
self,
|
117
117
|
embeddings: list[list[float]],
|
118
118
|
metadata: list[int],
|
119
|
-
**kwargs
|
120
|
-
) ->
|
119
|
+
**kwargs,
|
120
|
+
) -> (int, Exception):
|
121
121
|
"""Insert embeddings into Milvus. should call self.init() first"""
|
122
122
|
assert self.qdrant_client is not None
|
123
123
|
try:
|
@@ -127,11 +127,11 @@ class QdrantCloud(VectorDB):
|
|
127
127
|
wait=True,
|
128
128
|
points=Batch(ids=metadata, payloads=[{self._primary_field: v} for v in metadata], vectors=embeddings)
|
129
129
|
)
|
130
|
-
|
131
|
-
return len(metadata)
|
132
130
|
except Exception as e:
|
133
131
|
log.info(f"Failed to insert data, {e}")
|
134
|
-
|
132
|
+
return 0, e
|
133
|
+
else:
|
134
|
+
return len(metadata), None
|
135
135
|
|
136
136
|
def search_embedding(
|
137
137
|
self,
|
@@ -139,7 +139,6 @@ class QdrantCloud(VectorDB):
|
|
139
139
|
k: int = 100,
|
140
140
|
filters: dict | None = None,
|
141
141
|
timeout: int | None = None,
|
142
|
-
**kwargs: Any,
|
143
142
|
) -> list[int]:
|
144
143
|
"""Perform a search on a query embedding and return results with score.
|
145
144
|
Should call self.init() first.
|
@@ -4,9 +4,9 @@ import weaviate
|
|
4
4
|
from ..api import DBConfig, DBCaseConfig, MetricType
|
5
5
|
|
6
6
|
|
7
|
-
class WeaviateConfig(DBConfig
|
8
|
-
url: SecretStr
|
9
|
-
api_key: SecretStr
|
7
|
+
class WeaviateConfig(DBConfig):
|
8
|
+
url: SecretStr
|
9
|
+
api_key: SecretStr
|
10
10
|
|
11
11
|
def to_dict(self) -> dict:
|
12
12
|
return {
|
@@ -1,7 +1,7 @@
|
|
1
1
|
"""Wrapper around the Weaviate vector database over VectorDB"""
|
2
2
|
|
3
3
|
import logging
|
4
|
-
from typing import
|
4
|
+
from typing import Iterable, Type
|
5
5
|
from contextlib import contextmanager
|
6
6
|
|
7
7
|
from weaviate.exceptions import WeaviateBaseError
|
@@ -21,6 +21,7 @@ class WeaviateCloud(VectorDB):
|
|
21
21
|
db_case_config: DBCaseConfig,
|
22
22
|
collection_name: str = "VectorDBBenchCollection",
|
23
23
|
drop_old: bool = False,
|
24
|
+
**kwargs,
|
24
25
|
):
|
25
26
|
"""Initialize wrapper around the weaviate vector database."""
|
26
27
|
self.db_config = db_config
|
@@ -70,7 +71,7 @@ class WeaviateCloud(VectorDB):
|
|
70
71
|
"""Should call insert first, do nothing"""
|
71
72
|
pass
|
72
73
|
|
73
|
-
def
|
74
|
+
def optimize(self):
|
74
75
|
assert self.client.schema.exists(self.collection_name)
|
75
76
|
self.client.schema.update_config(self.collection_name, {"vectorIndexConfig": self.case_config.search_param() } )
|
76
77
|
|
@@ -98,11 +99,11 @@ class WeaviateCloud(VectorDB):
|
|
98
99
|
self,
|
99
100
|
embeddings: Iterable[list[float]],
|
100
101
|
metadata: list[int],
|
101
|
-
**kwargs
|
102
|
-
) -> int:
|
102
|
+
**kwargs,
|
103
|
+
) -> (int, Exception):
|
103
104
|
"""Insert embeddings into Weaviate"""
|
104
105
|
assert self.client.schema.exists(self.collection_name)
|
105
|
-
|
106
|
+
insert_count = 0
|
106
107
|
try:
|
107
108
|
with self.client.batch as batch:
|
108
109
|
batch.batch_size = len(metadata)
|
@@ -114,10 +115,11 @@ class WeaviateCloud(VectorDB):
|
|
114
115
|
class_name=self.collection_name,
|
115
116
|
vector=embeddings[i]
|
116
117
|
))
|
117
|
-
|
118
|
+
insert_count += 1
|
119
|
+
return (len(res), None)
|
118
120
|
except WeaviateBaseError as e:
|
119
121
|
log.warning(f"Failed to insert data, error: {str(e)}")
|
120
|
-
|
122
|
+
return (insert_count, e)
|
121
123
|
|
122
124
|
def search_embedding(
|
123
125
|
self,
|
@@ -125,7 +127,6 @@ class WeaviateCloud(VectorDB):
|
|
125
127
|
k: int = 100,
|
126
128
|
filters: dict | None = None,
|
127
129
|
timeout: int | None = None,
|
128
|
-
**kwargs: Any,
|
129
130
|
) -> list[int]:
|
130
131
|
"""Perform a search on a query embedding and return results with distance.
|
131
132
|
Should call self.init() first.
|
@@ -1,12 +1,13 @@
|
|
1
|
-
from pydantic import
|
1
|
+
from pydantic import SecretStr
|
2
|
+
|
2
3
|
from ..api import DBCaseConfig, DBConfig
|
3
4
|
from ..milvus.config import MilvusIndexConfig, IndexType
|
4
5
|
|
5
6
|
|
6
|
-
class ZillizCloudConfig(DBConfig
|
7
|
-
uri: SecretStr
|
7
|
+
class ZillizCloudConfig(DBConfig):
|
8
|
+
uri: SecretStr
|
8
9
|
user: str
|
9
|
-
password: SecretStr
|
10
|
+
password: SecretStr
|
10
11
|
|
11
12
|
def to_dict(self) -> dict:
|
12
13
|
return {
|
@@ -14,7 +14,8 @@ class ZillizCloud(Milvus):
|
|
14
14
|
db_case_config: DBCaseConfig,
|
15
15
|
collection_name: str = "ZillizCloudVectorDBBench",
|
16
16
|
drop_old: bool = False,
|
17
|
-
name: str = "ZillizCloud"
|
17
|
+
name: str = "ZillizCloud",
|
18
|
+
**kwargs,
|
18
19
|
):
|
19
20
|
super().__init__(
|
20
21
|
dim=dim,
|
@@ -23,6 +24,7 @@ class ZillizCloud(Milvus):
|
|
23
24
|
collection_name=collection_name,
|
24
25
|
drop_old=drop_old,
|
25
26
|
name=name,
|
27
|
+
**kwargs,
|
26
28
|
)
|
27
29
|
|
28
30
|
@classmethod
|