vectordb-bench 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +14 -3
- vectordb_bench/backend/cases.py +34 -13
- vectordb_bench/backend/clients/__init__.py +6 -1
- vectordb_bench/backend/clients/api.py +12 -8
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +4 -2
- vectordb_bench/backend/clients/milvus/milvus.py +17 -10
- vectordb_bench/backend/clients/pgvector/config.py +49 -0
- vectordb_bench/backend/clients/pgvector/pgvector.py +171 -0
- vectordb_bench/backend/clients/pinecone/pinecone.py +4 -3
- vectordb_bench/backend/clients/qdrant_cloud/config.py +20 -2
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +11 -11
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +5 -5
- vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +3 -1
- vectordb_bench/backend/dataset.py +99 -149
- vectordb_bench/backend/result_collector.py +2 -2
- vectordb_bench/backend/runner/mp_runner.py +29 -13
- vectordb_bench/backend/runner/serial_runner.py +69 -51
- vectordb_bench/backend/task_runner.py +43 -48
- vectordb_bench/frontend/components/get_results/saveAsImage.py +4 -2
- vectordb_bench/frontend/const/dbCaseConfigs.py +35 -4
- vectordb_bench/frontend/const/dbPrices.py +5 -33
- vectordb_bench/frontend/const/styles.py +9 -3
- vectordb_bench/metric.py +0 -1
- vectordb_bench/models.py +12 -8
- vectordb_bench/results/dbPrices.json +32 -0
- vectordb_bench/results/getLeaderboardData.py +52 -0
- vectordb_bench/results/leaderboard.json +1 -0
- vectordb_bench/results/{result_20230609_standard.json → result_20230705_standard.json} +670 -214
- {vectordb_bench-0.0.2.dist-info → vectordb_bench-0.0.3.dist-info}/METADATA +98 -13
- {vectordb_bench-0.0.2.dist-info → vectordb_bench-0.0.3.dist-info}/RECORD +34 -29
- {vectordb_bench-0.0.2.dist-info → vectordb_bench-0.0.3.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.2.dist-info → vectordb_bench-0.0.3.dist-info}/WHEEL +0 -0
- {vectordb_bench-0.0.2.dist-info → vectordb_bench-0.0.3.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.2.dist-info → vectordb_bench-0.0.3.dist-info}/top_level.txt +0 -0
vectordb_bench/__init__.py
CHANGED
@@ -18,12 +18,23 @@ class config:
     USE_SHUFFLED_DATA = env.bool("USE_SHUFFLED_DATA", True)

     RESULTS_LOCAL_DIR = pathlib.Path(__file__).parent.joinpath("results")
-
+
+    CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600 # 24h
+    LOAD_TIMEOUT_1M = 2.5 * 3600 # 2.5h
+    LOAD_TIMEOUT_10M = 25 * 3600 # 25h
+    LOAD_TIMEOUT_100M = 250 * 3600 # 10.41d
+
+    OPTIMIZE_TIMEOUT_1M = 15 * 60 # 15min
+    OPTIMIZE_TIMEOUT_10M = 2.5 * 3600 # 2.5h
+    OPTIMIZE_TIMEOUT_100M = 25 * 3600 # 1.04d


     def display(self) -> str:
-        tmp = [
-
+        tmp = [
+            i for i in inspect.getmembers(self)
+            if not inspect.ismethod(i[1])
+            and not i[0].startswith('_')
+            and "TIMEOUT" not in i[0]
         ]
         return tmp

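Note: the new *_TIMEOUT_* constants are deliberately excluded from display(). A minimal standalone sketch of that filter, runnable on its own (NUM_PER_BATCH is a made-up non-timeout setting used only for illustration):

    import inspect

    class config:
        LOAD_TIMEOUT_1M = 2.5 * 3600   # filtered out: name contains "TIMEOUT"
        NUM_PER_BATCH = 5000           # hypothetical setting; survives the filter

        def display(self) -> list:
            return [
                i for i in inspect.getmembers(self)
                if not inspect.ismethod(i[1])     # drop bound methods
                and not i[0].startswith('_')      # drop private/dunder names
                and "TIMEOUT" not in i[0]         # drop the new timeout knobs
            ]

    print(config().display())  # [('NUM_PER_BATCH', 5000)]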
vectordb_bench/backend/cases.py
CHANGED
@@ -2,8 +2,10 @@ import typing
 import logging
 from enum import Enum, auto

-from
-from
+from vectordb_bench import config
+from vectordb_bench.base import BaseModel
+
+from .dataset import Dataset, DatasetManager


 log = logging.getLogger(__name__)
@@ -44,7 +46,7 @@ class CaseType(Enum):
         if c is not None:
             return c().name
         raise ValueError("Case unsupported")
-
+
     @property
     def case_description(self) -> str:
         c = self.case_cls
@@ -73,7 +75,10 @@ class Case(BaseModel):
     label: CaseLabel
     name: str
     description: str
-    dataset:
+    dataset: DatasetManager
+
+    load_timeout: float | int
+    optimize_timeout: float | int | None

     filter_rate: float | None

@@ -92,6 +97,8 @@ class Case(BaseModel):
 class CapacityCase(Case, BaseModel):
     label: CaseLabel = CaseLabel.Load
     filter_rate: float | None = None
+    load_timeout: float | int = config.CAPACITY_TIMEOUT_IN_SECONDS
+    optimize_timeout: float | int | None = None


 class PerformanceCase(Case, BaseModel):
@@ -101,7 +108,7 @@ class PerformanceCase(Case, BaseModel):

 class CapacityDim960(CapacityCase):
     case_id: CaseType = CaseType.CapacityDim960
-    dataset:
+    dataset: DatasetManager = Dataset.GIST.manager(100_000)
     name: str = "Capacity Test (960 Dim Repeated)"
     description: str = """This case tests the vector database's loading capacity by repeatedly inserting large-dimension vectors (GIST 100K vectors, <b>960 dimensions</b>) until it is fully loaded.
 Number of inserted vectors will be reported."""
@@ -109,7 +116,7 @@ Number of inserted vectors will be reported."""

 class CapacityDim128(CapacityCase):
     case_id: CaseType = CaseType.CapacityDim128
-    dataset:
+    dataset: DatasetManager = Dataset.SIFT.manager(500_000)
     name: str = "Capacity Test (128 Dim Repeated)"
     description: str = """This case tests the vector database's loading capacity by repeatedly inserting small-dimension vectors (SIFT 100K vectors, <b>128 dimensions</b>) until it is fully loaded.
 Number of inserted vectors will be reported."""
@@ -117,64 +124,78 @@ Number of inserted vectors will be reported."""

 class Performance10M(PerformanceCase):
     case_id: CaseType = CaseType.Performance10M
-    dataset:
+    dataset: DatasetManager = Dataset.COHERE.manager(10_000_000)
     name: str = "Search Performance Test (10M Dataset, 768 Dim)"
     description: str = """This case tests the search performance of a vector database with a large dataset (<b>Cohere 10M vectors</b>, 768 dimensions) at varying parallel levels.
 Results will show index building time, recall, and maximum QPS."""
+    load_timeout: float | int = config.LOAD_TIMEOUT_10M
+    optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_10M


 class Performance1M(PerformanceCase):
     case_id: CaseType = CaseType.Performance1M
-    dataset:
+    dataset: DatasetManager = Dataset.COHERE.manager(1_000_000)
     name: str = "Search Performance Test (1M Dataset, 768 Dim)"
     description: str = """This case tests the search performance of a vector database with a medium dataset (<b>Cohere 1M vectors</b>, 768 dimensions) at varying parallel levels.
 Results will show index building time, recall, and maximum QPS."""
+    load_timeout: float | int = config.LOAD_TIMEOUT_1M
+    optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1M


 class Performance10M1P(PerformanceCase):
     case_id: CaseType = CaseType.Performance10M1P
     filter_rate: float | int | None = 0.01
-    dataset:
+    dataset: DatasetManager = Dataset.COHERE.manager(10_000_000)
     name: str = "Filtering Search Performance Test (10M Dataset, 768 Dim, Filter 1%)"
     description: str = """This case tests the search performance of a vector database with a large dataset (<b>Cohere 10M vectors</b>, 768 dimensions) under a low filtering rate (<b>1% vectors</b>), at varying parallel levels.
 Results will show index building time, recall, and maximum QPS."""
+    load_timeout: float | int = config.LOAD_TIMEOUT_10M
+    optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_10M


 class Performance1M1P(PerformanceCase):
     case_id: CaseType = CaseType.Performance1M1P
     filter_rate: float | int | None = 0.01
-    dataset:
+    dataset: DatasetManager = Dataset.COHERE.manager(1_000_000)
     name: str = "Filtering Search Performance Test (1M Dataset, 768 Dim, Filter 1%)"
     description: str = """This case tests the search performance of a vector database with a medium dataset (<b>Cohere 1M vectors</b>, 768 dimensions) under a low filtering rate (<b>1% vectors</b>), at varying parallel levels.
 Results will show index building time, recall, and maximum QPS."""
+    load_timeout: float | int = config.LOAD_TIMEOUT_1M
+    optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1M


 class Performance10M99P(PerformanceCase):
     case_id: CaseType = CaseType.Performance10M99P
     filter_rate: float | int | None = 0.99
-    dataset:
+    dataset: DatasetManager = Dataset.COHERE.manager(10_000_000)
     name: str = "Filtering Search Performance Test (10M Dataset, 768 Dim, Filter 99%)"
     description: str = """This case tests the search performance of a vector database with a large dataset (<b>Cohere 10M vectors</b>, 768 dimensions) under a high filtering rate (<b>99% vectors</b>), at varying parallel levels.
 Results will show index building time, recall, and maximum QPS."""
+    load_timeout: float | int = config.LOAD_TIMEOUT_10M
+    optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_10M


 class Performance1M99P(PerformanceCase):
     case_id: CaseType = CaseType.Performance1M99P
     filter_rate: float | int | None = 0.99
-    dataset:
+    dataset: DatasetManager = Dataset.COHERE.manager(1_000_000)
     name: str = "Filtering Search Performance Test (1M Dataset, 768 Dim, Filter 99%)"
     description: str = """This case tests the search performance of a vector database with a medium dataset (<b>Cohere 1M vectors</b>, 768 dimensions) under a high filtering rate (<b>99% vectors</b>), at varying parallel levels.
 Results will show index building time, recall, and maximum QPS."""
+    load_timeout: float | int = config.LOAD_TIMEOUT_1M
+    optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1M



 class Performance100M(PerformanceCase):
     case_id: CaseType = CaseType.Performance100M
     filter_rate: float | int | None = None
-    dataset:
+    dataset: DatasetManager = Dataset.LAION.manager(100_000_000)
     name: str = "Search Performance Test (100M Dataset, 768 Dim)"
     description: str = """This case tests the search performance of a vector database with a large 100M dataset (<b>LAION 100M vectors</b>, 768 dimensions), at varying parallel levels.
 Results will show index building time, recall, and maximum QPS."""
+    load_timeout: float | int = config.LOAD_TIMEOUT_100M
+    optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_100M


 type2case = {
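Note: every case now carries its own load_timeout/optimize_timeout budget, taken from the new config constants above. The diff itself does not show enforcement; a hedged stdlib-only sketch of how a runner could apply the budget (load_all is a hypothetical stand-in for the real insert loop):

    from concurrent.futures import ThreadPoolExecutor
    from concurrent.futures import TimeoutError as FuturesTimeout

    def run_with_budget(fn, budget_s: float):
        # Run fn in a worker thread and stop waiting once the budget is spent.
        with ThreadPoolExecutor(max_workers=1) as pool:
            future = pool.submit(fn)
            try:
                return future.result(timeout=budget_s)
            except FuturesTimeout:
                raise TimeoutError(f"exceeded {budget_s}s budget") from None

    # e.g. run_with_budget(load_all, case.load_timeout); for a Performance1M
    # case that allows config.LOAD_TIMEOUT_1M == 2.5 * 3600 seconds.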
vectordb_bench/backend/clients/__init__.py
CHANGED
@@ -15,7 +15,7 @@ from .pinecone.pinecone import Pinecone
 from .weaviate_cloud.weaviate_cloud import WeaviateCloud
 from .qdrant_cloud.qdrant_cloud import QdrantCloud
 from .zilliz_cloud.zilliz_cloud import ZillizCloud
-
+from .pgvector.pgvector import PgVector

 class DB(Enum):
     """Database types
@@ -35,6 +35,7 @@ class DB(Enum):
     ElasticCloud = "ElasticCloud"
     QdrantCloud = "QdrantCloud"
     WeaviateCloud = "WeaviateCloud"
+    PgVector = "PgVector"


     @property
@@ -49,8 +50,12 @@ db2client = {
     DB.ElasticCloud: ElasticCloud,
     DB.QdrantCloud: QdrantCloud,
     DB.Pinecone: Pinecone,
+    DB.PgVector: PgVector
 }

+for db in DB:
+    assert issubclass(db.init_cls, VectorDB)
+

 __all__ = [
     "DB", "VectorDB", "DBConfig", "DBCaseConfig", "IndexType", "MetricType", "EmptyDBCaseConfig",
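Note: the new import-time loop makes a miswired registry fail fast. Resolving a client class then looks like this (a sketch; init_cls is assumed to be the enum property that consults db2client):

    db = DB.PgVector
    client_cls = db.init_cls                  # -> PgVector
    assert issubclass(client_cls, VectorDB)   # the same check the loop performs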
vectordb_bench/backend/clients/api.py
CHANGED
@@ -73,7 +73,7 @@ class VectorDB(ABC):

     In each process, the benchmark cases ensure VectorDB.init() calls before any other methods operations

-    insert_embeddings, search_embedding, and,
+    insert_embeddings, search_embedding, and, optimize will be timed for each call.

     Examples:
         >>> milvus = Milvus()
@@ -90,9 +90,12 @@ class VectorDB(ABC):
         db_case_config: DBCaseConfig | None,
         collection_name: str,
         drop_old: bool = False,
-        **kwargs
+        **kwargs,
     ) -> None:
-        """Initialize wrapper around the vector database client
+        """Initialize wrapper around the vector database client.
+
+        Please drop the existing collection if drop_old is True. And create collection
+        if collection not in the Vector Database

         Args:
             dim(int): the dimension of the dataset
@@ -130,7 +133,7 @@ class VectorDB(ABC):
         self,
         embeddings: list[list[float]],
         metadata: list[int],
-        kwargs
+        **kwargs,
     ) -> (int, Exception):
         """Insert the embeddings to the vector database. The default number of embeddings for
         each insert_embeddings is 5000.
@@ -138,7 +141,7 @@ class VectorDB(ABC):
         Args:
             embeddings(list[list[float]]): list of embedding to add to the vector database.
             metadatas(list[int]): metadata associated with the embeddings, for filtering.
-            kwargs(Any): vector database specific parameters.
+            **kwargs(Any): vector database specific parameters.

         Returns:
             int: inserted data count
@@ -166,13 +169,14 @@ class VectorDB(ABC):

     # TODO: remove
     @abstractmethod
-    def
-    """
+    def optimize(self):
+        """optimize will be called between insertion and search in performance cases.

         Should be blocked until the vectorDB is ready to be tested on
         heavy performance cases.

-        Time(insert the dataset) + Time(
+        Time(insert the dataset) + Time(optimize) will be recorded as "load_duration" metric
+        Optimize's execution time is limited, the limited time is based on cases.
         """
         raise NotImplementedError

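Note: optimize() now has a documented contract: block until the database can serve heavy search traffic, with the call's duration counted toward load_duration and capped by the case's optimize_timeout. A purely illustrative conforming shape (FakeClient and _index_ready are not part of the package):

    import time

    class FakeClient:
        def optimize(self):
            # Block until ready; managed services (Pinecone, Elastic Cloud
            # below) legitimately make this a no-op instead.
            while not self._index_ready():
                time.sleep(5)

        def _index_ready(self) -> bool:
            return True   # hypothetical readiness probe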
vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py
CHANGED
@@ -21,6 +21,7 @@ class ElasticCloud(VectorDB):
         id_col_name: str = "id",
         vector_col_name: str = "vector",
         drop_old: bool = False,
+        **kwargs,
     ):
         self.dim = dim
         self.db_config = db_config
@@ -83,6 +84,7 @@ class ElasticCloud(VectorDB):
         self,
         embeddings: Iterable[list[float]],
         metadata: list[int],
+        **kwargs,
     ) -> (int, Exception):
         """Insert the embeddings to the elasticsearch."""
         assert self.client is not None, "should self.init() first"
@@ -143,8 +145,8 @@ class ElasticCloud(VectorDB):
             log.warning(f"Failed to search: {self.indice} error: {str(e)}")
             raise e from None

-    def
-    """
+    def optimize(self):
+        """optimize will be called between insertion and search in performance cases."""
         pass

     def ready_to_load(self):
vectordb_bench/backend/clients/milvus/milvus.py
CHANGED
@@ -2,7 +2,7 @@

 import logging
 from contextlib import contextmanager
-from typing import
+from typing import Iterable, Type

 from pymilvus import Collection, utility
 from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusException
@@ -24,6 +24,7 @@ class Milvus(VectorDB):
         collection_name: str = "VectorDBBenchCollection",
         drop_old: bool = False,
         name: str = "Milvus",
+        **kwargs,
     ):
         """Initialize wrapper around the milvus vector database."""
         self.name = name
@@ -53,7 +54,7 @@ class Milvus(VectorDB):
         log.info(f"{self.name} create collection: {self.collection_name}")

         # Create the collection
-
+        Collection(
             name=self.collection_name,
             schema=CollectionSchema(fields),
             consistency_level="Session",
@@ -107,6 +108,14 @@ class Milvus(VectorDB):

     def _optimize(self):
         log.info(f"{self.name} optimizing before search")
+        try:
+            self.col.load()
+        except Exception as e:
+            log.warning(f"{self.name} optimize error: {e}")
+            raise e from None
+
+    def _post_insert(self):
+        log.info(f"{self.name} post insert before optimize")
         try:
             self.col.flush()
             self.col.compact()
@@ -119,10 +128,6 @@ class Milvus(VectorDB):
                 index_name=self._index_name,
             )
             utility.wait_for_index_building_complete(self.collection_name)
-            self.col.load()
-            # self.col.load(_refresh=True)
-            # utility.wait_for_loading_complete(self.collection_name)
-            # import time; time.sleep(10)
         except Exception as e:
             log.warning(f"{self.name} optimize error: {e}")
             raise e from None
@@ -132,7 +137,7 @@ class Milvus(VectorDB):
         self._pre_load(self.col)
         pass

-    def
+    def optimize(self):
         assert self.col, "Please call self.init() before"
         self._optimize()

@@ -140,7 +145,7 @@ class Milvus(VectorDB):
         self,
         embeddings: Iterable[list[float]],
         metadata: list[int],
-        **kwargs
+        **kwargs,
     ) -> (int, Exception):
         """Insert embeddings into Milvus. should call self.init() first"""
         # use the first insert_embeddings to init collection
@@ -155,10 +160,12 @@ class Milvus(VectorDB):
                 metadata[batch_start_offset : batch_end_offset],
                 embeddings[batch_start_offset : batch_end_offset],
             ]
-            res = self.col.insert(insert_data
+            res = self.col.insert(insert_data)
             insert_count += len(res.primary_keys)
+            if kwargs.get("last_batch"):
+                self._post_insert()
         except MilvusException as e:
-            log.
+            log.info(f"Failed to insert data: {e}")
             return (insert_count, e)
         return (insert_count, None)

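Note: insert_embeddings() now reacts to a last_batch kwarg so the flush/compact/index-wait sequence runs exactly once, at the end of loading. A sketch of the caller side (chunks is illustrative, not from the diff):

    for i, (vectors, ids) in enumerate(chunks):
        count, err = milvus.insert_embeddings(
            embeddings=vectors,
            metadata=ids,
            last_batch=(i == len(chunks) - 1),  # triggers _post_insert() once
        )
        if err is not None:
            break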
vectordb_bench/backend/clients/pgvector/config.py
ADDED
@@ -0,0 +1,49 @@
+from pydantic import BaseModel, SecretStr
+from ..api import DBConfig, DBCaseConfig, MetricType
+
+POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
+
+class PgVectorConfig(DBConfig):
+    user_name: SecretStr = "postgres"
+    password: SecretStr
+    url: SecretStr
+    db_name: str
+
+    def to_dict(self) -> dict:
+        user_str = self.user_name.get_secret_value()
+        pwd_str = self.password.get_secret_value()
+        url_str = self.url.get_secret_value()
+        return {
+            "url" : POSTGRE_URL_PLACEHOLDER%(user_str, pwd_str, url_str, self.db_name)
+        }
+
+class PgVectorIndexConfig(BaseModel, DBCaseConfig):
+    metric_type: MetricType | None = None
+    lists: int | None = 1000
+    probes: int | None = 10
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "vector_l2_ops"
+        elif self.metric_type == MetricType.IP:
+            return "vector_ip_ops"
+        return "vector_cosine_ops"
+
+    def parse_metric_fun_str(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "l2_distance"
+        elif self.metric_type == MetricType.IP:
+            return "max_inner_product"
+        return "cosine_distance"
+
+    def index_param(self) -> dict:
+        return {
+            "lists" : self.lists,
+            "metric" : self.parse_metric()
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "probes" : self.probes,
+            "metric_fun" : self.parse_metric_fun_str()
+        }
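Note: to_dict() assembles a libpq-style connection URL from the secret fields. A quick usage sketch (all values are placeholders):

    cfg = PgVectorConfig(
        user_name="postgres",
        password="example-password",
        url="localhost:5432",
        db_name="vectordb",
    )
    print(cfg.to_dict())
    # {'url': 'postgresql://postgres:example-password@localhost:5432/vectordb'}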
vectordb_bench/backend/clients/pgvector/pgvector.py
ADDED
@@ -0,0 +1,171 @@
+"""Wrapper around the Pgvector vector database over VectorDB"""
+
+import logging
+import time
+from contextlib import contextmanager
+from typing import Any, Type
+from functools import wraps
+
+from ..api import VectorDB, DBConfig, DBCaseConfig, IndexType
+from pgvector.sqlalchemy import Vector
+from .config import PgVectorConfig, PgVectorIndexConfig
+from sqlalchemy import (
+    MetaData,
+    create_engine,
+    insert,
+    select,
+    Index,
+    Table,
+    text,
+    Column,
+    Float,
+    Integer
+)
+from sqlalchemy.orm import (
+    declarative_base,
+    mapped_column,
+    Session
+)
+
+log = logging.getLogger(__name__)
+
+class PgVector(VectorDB):
+    """ Use SQLAlchemy instructions"""
+    def __init__(
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: DBCaseConfig,
+        collection_name: str = "PgVectorCollection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.table_name = collection_name
+        self.dim = dim
+
+        self._index_name = "pqvector_index"
+        self._primary_field = "id"
+        self._vector_field = "embedding"
+
+        # construct basic units
+        pg_engine = create_engine(**self.db_config)
+        Base = declarative_base()
+        pq_metadata = Base.metadata
+        pq_metadata.reflect(pg_engine)
+
+        # create vector extension
+        with pg_engine.connect() as conn:
+            conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
+            conn.commit()
+
+        self.pg_table = self._get_table_schema(pq_metadata)
+        if drop_old and self.table_name in pq_metadata.tables:
+            log.info(f"Pgvector client drop table : {self.table_name}")
+            # self.pg_table.drop(pg_engine, checkfirst=True)
+            pq_metadata.drop_all(pg_engine)
+        self._create_table(dim, pg_engine)
+
+
+    @classmethod
+    def config_cls(cls) -> Type[DBConfig]:
+        return PgVectorConfig
+
+    @classmethod
+    def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
+        return PgVectorIndexConfig
+
+    @contextmanager
+    def init(self) -> None:
+        """
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+            >>>     self.search_embedding()
+        """
+        self.pg_engine = create_engine(**self.db_config)
+
+        Base = declarative_base()
+        pq_metadata = Base.metadata
+        pq_metadata.reflect(self.pg_engine)
+        self.pg_session = Session(self.pg_engine)
+        self.pg_table = self._get_table_schema(pq_metadata)
+        yield
+        self.pg_session = None
+        self.pg_engine = None
+        del (self.pg_session)
+        del (self.pg_engine)
+
+    def ready_to_load(self):
+        pass
+
+    def optimize(self):
+        pass
+
+    def ready_to_search(self):
+        pass
+
+    def _get_table_schema(self, pq_metadata):
+        return Table(
+            self.table_name,
+            pq_metadata,
+            Column(self._primary_field, Integer, primary_key=True),
+            Column(self._vector_field, Vector(self.dim)),
+            extend_existing=True
+        )
+
+    def _create_index(self, pg_engine):
+        index_param = self.case_config.index_param()
+        index = Index(self._index_name, self.pg_table.c.embedding,
+            postgresql_using='ivfflat',
+            postgresql_with={'lists': index_param["lists"]},
+            postgresql_ops={'embedding': index_param["metric"]}
+        )
+        index.drop(pg_engine, checkfirst = True)
+        index.create(pg_engine)
+
+    def _create_table(self, dim, pg_engine : int):
+        try:
+            # create table
+            self.pg_table.create(bind = pg_engine, checkfirst = True)
+            # create vec index
+            self._create_index(pg_engine)
+        except Exception as e:
+            log.warning(f"Failed to create pgvector table: {self.table_name} error: {e}")
+            raise e from None
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> (int, Exception):
+        try:
+            items = [dict(id = metadata[i], embedding=embeddings[i]) for i in range(len(metadata))]
+            self.pg_session.execute(insert(self.pg_table), items)
+            self.pg_session.commit()
+            return len(metadata), None
+        except Exception as e:
+            log.warning(f"Failed to insert data into pgvector table ({self.table_name}), error: {e}")
+            return 0, e
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+    ) -> list[int]:
+        assert self.pg_table is not None
+        search_param = self.case_config.search_param()
+        with self.pg_engine.connect() as conn:
+            conn.execute(text(f'SET ivfflat.probes = {search_param["probes"]}'))
+            conn.commit()
+        op_fun = getattr(self.pg_table.c.embedding, search_param["metric_fun"])
+        if filters:
+            res = self.pg_session.scalars(select(self.pg_table).order_by(op_fun(query)).filter(self.pg_table.c.id > filters.get('id')).limit(k))
+        else:
+            res = self.pg_session.scalars(select(self.pg_table).order_by(op_fun(query)).limit(k))
+        return list(res)
+
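Note: tying the new client together, a hedged end-to-end sketch following the init() docstring above (connection values are placeholders; MetricType.COSINE is assumed to exist alongside L2 and IP):

    db = PgVector(
        dim=4,
        db_config={"url": "postgresql://postgres:example@localhost:5432/vectordb"},
        db_case_config=PgVectorIndexConfig(metric_type=MetricType.COSINE),
        drop_old=True,
    )
    with db.init():
        count, err = db.insert_embeddings(embeddings=[[0.1, 0.2, 0.3, 0.4]], metadata=[1])
        ids = db.search_embedding(query=[0.1, 0.2, 0.3, 0.4], k=10)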
vectordb_bench/backend/clients/pinecone/pinecone.py
CHANGED
@@ -2,7 +2,7 @@

 import logging
 from contextlib import contextmanager
-from typing import
+from typing import Type

 from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
 from .config import PineconeConfig
@@ -20,6 +20,7 @@ class Pinecone(VectorDB):
         db_config: dict,
         db_case_config: DBCaseConfig,
         drop_old: bool = False,
+        **kwargs,
     ):
         """Initialize wrapper around the milvus vector database."""
         self.index_name = db_config["index_name"]
@@ -69,13 +70,14 @@ class Pinecone(VectorDB):
     def ready_to_load(self):
         pass

-    def
+    def optimize(self):
         pass

     def insert_embeddings(
         self,
         embeddings: list[list[float]],
         metadata: list[int],
+        **kwargs,
     ) -> (int, Exception):
         assert len(embeddings) == len(metadata)
         insert_count = 0
@@ -99,7 +101,6 @@ class Pinecone(VectorDB):
         k: int = 100,
         filters: dict | None = None,
         timeout: int | None = None,
-        **kwargs: Any,
     ) -> list[tuple[int, float]]:
         if filters is None:
             pinecone_filters = {}
vectordb_bench/backend/clients/qdrant_cloud/config.py
CHANGED
@@ -1,6 +1,7 @@
-from pydantic import SecretStr
+from pydantic import BaseModel, SecretStr

-from ..api import DBConfig
+from ..api import DBConfig, DBCaseConfig, MetricType
+from qdrant_client.models import Distance


 class QdrantConfig(DBConfig):
@@ -13,3 +14,20 @@ class QdrantConfig(DBConfig):
             "api_key": self.api_key.get_secret_value(),
             "prefer_grpc": True,
         }
+
+class QdrantIndexConfig(BaseModel, DBCaseConfig):
+    metric_type: MetricType | None = None
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return Distance.EUCLID
+        elif self.metric_type == MetricType.IP:
+            return Distance.DOT
+        return Distance.COSINE
+
+    def index_param(self) -> dict:
+        params = {"distance": self.parse_metric()}
+        return params
+
+    def search_param(self) -> dict:
+        return {}