vectordb-bench 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (34)
  1. vectordb_bench/__init__.py +14 -3
  2. vectordb_bench/backend/cases.py +34 -13
  3. vectordb_bench/backend/clients/__init__.py +6 -1
  4. vectordb_bench/backend/clients/api.py +12 -8
  5. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +4 -2
  6. vectordb_bench/backend/clients/milvus/milvus.py +17 -10
  7. vectordb_bench/backend/clients/pgvector/config.py +49 -0
  8. vectordb_bench/backend/clients/pgvector/pgvector.py +171 -0
  9. vectordb_bench/backend/clients/pinecone/pinecone.py +4 -3
  10. vectordb_bench/backend/clients/qdrant_cloud/config.py +20 -2
  11. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +11 -11
  12. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +5 -5
  13. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +3 -1
  14. vectordb_bench/backend/dataset.py +99 -149
  15. vectordb_bench/backend/result_collector.py +2 -2
  16. vectordb_bench/backend/runner/mp_runner.py +29 -13
  17. vectordb_bench/backend/runner/serial_runner.py +69 -51
  18. vectordb_bench/backend/task_runner.py +43 -48
  19. vectordb_bench/frontend/components/get_results/saveAsImage.py +4 -2
  20. vectordb_bench/frontend/const/dbCaseConfigs.py +35 -4
  21. vectordb_bench/frontend/const/dbPrices.py +5 -33
  22. vectordb_bench/frontend/const/styles.py +9 -3
  23. vectordb_bench/metric.py +0 -1
  24. vectordb_bench/models.py +12 -8
  25. vectordb_bench/results/dbPrices.json +32 -0
  26. vectordb_bench/results/getLeaderboardData.py +52 -0
  27. vectordb_bench/results/leaderboard.json +1 -0
  28. vectordb_bench/results/{result_20230609_standard.json → result_20230705_standard.json} +670 -214
  29. {vectordb_bench-0.0.2.dist-info → vectordb_bench-0.0.3.dist-info}/METADATA +98 -13
  30. {vectordb_bench-0.0.2.dist-info → vectordb_bench-0.0.3.dist-info}/RECORD +34 -29
  31. {vectordb_bench-0.0.2.dist-info → vectordb_bench-0.0.3.dist-info}/LICENSE +0 -0
  32. {vectordb_bench-0.0.2.dist-info → vectordb_bench-0.0.3.dist-info}/WHEEL +0 -0
  33. {vectordb_bench-0.0.2.dist-info → vectordb_bench-0.0.3.dist-info}/entry_points.txt +0 -0
  34. {vectordb_bench-0.0.2.dist-info → vectordb_bench-0.0.3.dist-info}/top_level.txt +0 -0
vectordb_bench/__init__.py
@@ -18,12 +18,23 @@ class config:
     USE_SHUFFLED_DATA = env.bool("USE_SHUFFLED_DATA", True)
 
     RESULTS_LOCAL_DIR = pathlib.Path(__file__).parent.joinpath("results")
-    CASE_TIMEOUT_IN_SECOND = 24 * 60 * 60
+
+    CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600 # 24h
+    LOAD_TIMEOUT_1M = 2.5 * 3600 # 2.5h
+    LOAD_TIMEOUT_10M = 25 * 3600 # 25h
+    LOAD_TIMEOUT_100M = 250 * 3600 # 10.41d
+
+    OPTIMIZE_TIMEOUT_1M = 15 * 60 # 15min
+    OPTIMIZE_TIMEOUT_10M = 2.5 * 3600 # 2.5h
+    OPTIMIZE_TIMEOUT_100M = 25 * 3600 # 1.04d
 
 
     def display(self) -> str:
-        tmp = [i for i in inspect.getmembers(self)
-               if not inspect.ismethod(i[1]) and not i[0].startswith('_') \
+        tmp = [
+            i for i in inspect.getmembers(self)
+            if not inspect.ismethod(i[1])
+            and not i[0].startswith('_')
+            and "TIMEOUT" not in i[0]
         ]
         return tmp
 
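The release replaces the single CASE_TIMEOUT_IN_SECOND with per-stage, per-scale budgets. A minimal sketch of how a caller might select one of them, assuming the 1M/10M/100M tiers above (the helper pick_load_timeout is hypothetical, not part of the package):

    from vectordb_bench import config

    def pick_load_timeout(dataset_size: int) -> float:
        # Hypothetical helper: map a dataset size to the matching load budget.
        if dataset_size <= 1_000_000:
            return config.LOAD_TIMEOUT_1M    # 2.5h
        if dataset_size <= 10_000_000:
            return config.LOAD_TIMEOUT_10M   # 25h
        return config.LOAD_TIMEOUT_100M      # 250h

Note that display() now filters out any member whose name contains "TIMEOUT", so these budgets stay out of the printed config summary.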
vectordb_bench/backend/cases.py
@@ -2,8 +2,10 @@ import typing
 import logging
 from enum import Enum, auto
 
-from . import dataset as ds
-from ..base import BaseModel
+from vectordb_bench import config
+from vectordb_bench.base import BaseModel
+
+from .dataset import Dataset, DatasetManager
 
 
 log = logging.getLogger(__name__)
@@ -44,7 +46,7 @@ class CaseType(Enum):
         if c is not None:
             return c().name
         raise ValueError("Case unsupported")
-
+
     @property
     def case_description(self) -> str:
         c = self.case_cls
@@ -73,7 +75,10 @@ class Case(BaseModel):
     label: CaseLabel
     name: str
     description: str
-    dataset: ds.DataSet
+    dataset: DatasetManager
+
+    load_timeout: float | int
+    optimize_timeout: float | int | None
 
     filter_rate: float | None
 
@@ -92,6 +97,8 @@ class Case(BaseModel):
 class CapacityCase(Case, BaseModel):
     label: CaseLabel = CaseLabel.Load
     filter_rate: float | None = None
+    load_timeout: float | int = config.CAPACITY_TIMEOUT_IN_SECONDS
+    optimize_timeout: float | int | None = None
 
 
 class PerformanceCase(Case, BaseModel):
@@ -101,7 +108,7 @@ class PerformanceCase(Case, BaseModel):
 
 class CapacityDim960(CapacityCase):
     case_id: CaseType = CaseType.CapacityDim960
-    dataset: ds.DataSet = ds.get(ds.Name.GIST, ds.Label.SMALL)
+    dataset: DatasetManager = Dataset.GIST.manager(100_000)
     name: str = "Capacity Test (960 Dim Repeated)"
     description: str = """This case tests the vector database's loading capacity by repeatedly inserting large-dimension vectors (GIST 100K vectors, <b>960 dimensions</b>) until it is fully loaded.
 Number of inserted vectors will be reported."""
@@ -109,7 +116,7 @@ Number of inserted vectors will be reported."""
 
 class CapacityDim128(CapacityCase):
     case_id: CaseType = CaseType.CapacityDim128
-    dataset: ds.DataSet = ds.get(ds.Name.SIFT, ds.Label.SMALL)
+    dataset: DatasetManager = Dataset.SIFT.manager(500_000)
     name: str = "Capacity Test (128 Dim Repeated)"
     description: str = """This case tests the vector database's loading capacity by repeatedly inserting small-dimension vectors (SIFT 100K vectors, <b>128 dimensions</b>) until it is fully loaded.
 Number of inserted vectors will be reported."""
@@ -117,64 +124,78 @@ Number of inserted vectors will be reported."""
 
 class Performance10M(PerformanceCase):
     case_id: CaseType = CaseType.Performance10M
-    dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.LARGE)
+    dataset: DatasetManager = Dataset.COHERE.manager(10_000_000)
     name: str = "Search Performance Test (10M Dataset, 768 Dim)"
     description: str = """This case tests the search performance of a vector database with a large dataset (<b>Cohere 10M vectors</b>, 768 dimensions) at varying parallel levels.
 Results will show index building time, recall, and maximum QPS."""
+    load_timeout: float | int = config.LOAD_TIMEOUT_10M
+    optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_10M
 
 
 class Performance1M(PerformanceCase):
     case_id: CaseType = CaseType.Performance1M
-    dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.MEDIUM)
+    dataset: DatasetManager = Dataset.COHERE.manager(1_000_000)
     name: str = "Search Performance Test (1M Dataset, 768 Dim)"
     description: str = """This case tests the search performance of a vector database with a medium dataset (<b>Cohere 1M vectors</b>, 768 dimensions) at varying parallel levels.
 Results will show index building time, recall, and maximum QPS."""
+    load_timeout: float | int = config.LOAD_TIMEOUT_1M
+    optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1M
 
 
 class Performance10M1P(PerformanceCase):
     case_id: CaseType = CaseType.Performance10M1P
     filter_rate: float | int | None = 0.01
-    dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.LARGE)
+    dataset: DatasetManager = Dataset.COHERE.manager(10_000_000)
     name: str = "Filtering Search Performance Test (10M Dataset, 768 Dim, Filter 1%)"
     description: str = """This case tests the search performance of a vector database with a large dataset (<b>Cohere 10M vectors</b>, 768 dimensions) under a low filtering rate (<b>1% vectors</b>), at varying parallel levels.
 Results will show index building time, recall, and maximum QPS."""
+    load_timeout: float | int = config.LOAD_TIMEOUT_10M
+    optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_10M
 
 
 class Performance1M1P(PerformanceCase):
     case_id: CaseType = CaseType.Performance1M1P
     filter_rate: float | int | None = 0.01
-    dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.MEDIUM)
+    dataset: DatasetManager = Dataset.COHERE.manager(1_000_000)
     name: str = "Filtering Search Performance Test (1M Dataset, 768 Dim, Filter 1%)"
     description: str = """This case tests the search performance of a vector database with a medium dataset (<b>Cohere 1M vectors</b>, 768 dimensions) under a low filtering rate (<b>1% vectors</b>), at varying parallel levels.
 Results will show index building time, recall, and maximum QPS."""
+    load_timeout: float | int = config.LOAD_TIMEOUT_1M
+    optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1M
 
 
 class Performance10M99P(PerformanceCase):
     case_id: CaseType = CaseType.Performance10M99P
     filter_rate: float | int | None = 0.99
-    dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.LARGE)
+    dataset: DatasetManager = Dataset.COHERE.manager(10_000_000)
     name: str = "Filtering Search Performance Test (10M Dataset, 768 Dim, Filter 99%)"
     description: str = """This case tests the search performance of a vector database with a large dataset (<b>Cohere 10M vectors</b>, 768 dimensions) under a high filtering rate (<b>99% vectors</b>), at varying parallel levels.
 Results will show index building time, recall, and maximum QPS."""
+    load_timeout: float | int = config.LOAD_TIMEOUT_10M
+    optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_10M
 
 
 class Performance1M99P(PerformanceCase):
     case_id: CaseType = CaseType.Performance1M99P
     filter_rate: float | int | None = 0.99
-    dataset: ds.DataSet = ds.get(ds.Name.Cohere, ds.Label.MEDIUM)
+    dataset: DatasetManager = Dataset.COHERE.manager(1_000_000)
     name: str = "Filtering Search Performance Test (1M Dataset, 768 Dim, Filter 99%)"
     description: str = """This case tests the search performance of a vector database with a medium dataset (<b>Cohere 1M vectors</b>, 768 dimensions) under a high filtering rate (<b>99% vectors</b>), at varying parallel levels.
 Results will show index building time, recall, and maximum QPS."""
+    load_timeout: float | int = config.LOAD_TIMEOUT_1M
+    optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1M
 
 
 
 class Performance100M(PerformanceCase):
     case_id: CaseType = CaseType.Performance100M
     filter_rate: float | int | None = None
-    dataset: ds.DataSet = ds.get(ds.Name.LAION, ds.Label.LARGE)
+    dataset: DatasetManager = Dataset.LAION.manager(100_000_000)
     name: str = "Search Performance Test (100M Dataset, 768 Dim)"
     description: str = """This case tests the search performance of a vector database with a large 100M dataset (<b>LAION 100M vectors</b>, 768 dimensions), at varying parallel levels.
 Results will show index building time, recall, and maximum QPS."""
+    load_timeout: float | int = config.LOAD_TIMEOUT_100M
+    optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_100M
 
 
 type2case = {
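Every case now declares its own load and optimize budgets alongside its dataset manager, so the task runner can enforce per-case deadlines. A small sketch reading them off a built-in case (printed values follow from the config constants above; assumes the case's remaining fields all carry defaults, as the diff suggests):

    from vectordb_bench.backend.cases import Performance1M

    case = Performance1M()
    print(case.load_timeout)      # 9000.0 (LOAD_TIMEOUT_1M, 2.5h)
    print(case.optimize_timeout)  # 900 (OPTIMIZE_TIMEOUT_1M, 15min)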
vectordb_bench/backend/clients/__init__.py
@@ -15,7 +15,7 @@ from .pinecone.pinecone import Pinecone
 from .weaviate_cloud.weaviate_cloud import WeaviateCloud
 from .qdrant_cloud.qdrant_cloud import QdrantCloud
 from .zilliz_cloud.zilliz_cloud import ZillizCloud
-
+from .pgvector.pgvector import PgVector
 
 class DB(Enum):
     """Database types
@@ -35,6 +35,7 @@ class DB(Enum):
     ElasticCloud = "ElasticCloud"
     QdrantCloud = "QdrantCloud"
     WeaviateCloud = "WeaviateCloud"
+    PgVector = "PgVector"
 
 
     @property
@@ -49,8 +50,12 @@ db2client = {
     DB.ElasticCloud: ElasticCloud,
     DB.QdrantCloud: QdrantCloud,
     DB.Pinecone: Pinecone,
+    DB.PgVector: PgVector
 }
 
+for db in DB:
+    assert issubclass(db.init_cls, VectorDB)
+
 
 __all__ = [
     "DB", "VectorDB", "DBConfig", "DBCaseConfig", "IndexType", "MetricType", "EmptyDBCaseConfig",
vectordb_bench/backend/clients/api.py
@@ -73,7 +73,7 @@ class VectorDB(ABC):
 
     In each process, the benchmark cases ensure VectorDB.init() calls before any other methods operations
 
-    insert_embeddings, search_embedding, and, ready_to_search will be timed for each call.
+    insert_embeddings, search_embedding, and, optimize will be timed for each call.
 
     Examples:
         >>> milvus = Milvus()
@@ -90,9 +90,12 @@ class VectorDB(ABC):
         db_case_config: DBCaseConfig | None,
         collection_name: str,
         drop_old: bool = False,
-        **kwargs
+        **kwargs,
     ) -> None:
-        """Initialize wrapper around the vector database client
+        """Initialize wrapper around the vector database client.
+
+        Please drop the existing collection if drop_old is True. And create collection
+        if collection not in the Vector Database
 
         Args:
             dim(int): the dimension of the dataset
@@ -130,7 +133,7 @@ class VectorDB(ABC):
         self,
         embeddings: list[list[float]],
         metadata: list[int],
-        kwargs: Any,
+        **kwargs,
     ) -> (int, Exception):
         """Insert the embeddings to the vector database. The default number of embeddings for
         each insert_embeddings is 5000.
@@ -138,7 +141,7 @@ class VectorDB(ABC):
         Args:
             embeddings(list[list[float]]): list of embedding to add to the vector database.
             metadatas(list[int]): metadata associated with the embeddings, for filtering.
-            kwargs(Any): vector database specific parameters.
+            **kwargs(Any): vector database specific parameters.
 
         Returns:
             int: inserted data count
@@ -166,13 +169,14 @@ class VectorDB(ABC):
 
     # TODO: remove
     @abstractmethod
-    def ready_to_search(self):
-        """ready_to_search will be called between insertion and search in performance cases.
+    def optimize(self):
+        """optimize will be called between insertion and search in performance cases.
 
         Should be blocked until the vectorDB is ready to be tested on
         heavy performance cases.
 
-        Time(insert the dataset) + Time(ready_to_search) will be recorded as "load_duration" metric
+        Time(insert the dataset) + Time(optimize) will be recorded as "load_duration" metric
+        Optimize's execution time is limited, the limited time is based on cases.
        """
        raise NotImplementedError
 
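The rename from ready_to_search to optimize keeps the timing contract: insert time plus optimize time is reported as load_duration, and optimize is now bounded by the per-case optimize_timeout. A hedged sketch of that composition (measure_load is illustrative, not the package's actual runner):

    import time

    def measure_load(db, embeddings, metadata) -> float:
        # Illustrative only: load_duration = Time(insert) + Time(optimize).
        start = time.perf_counter()
        with db.init():
            db.insert_embeddings(embeddings, metadata)
            db.optimize()
        return time.perf_counter() - start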
vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py
@@ -21,6 +21,7 @@ class ElasticCloud(VectorDB):
         id_col_name: str = "id",
         vector_col_name: str = "vector",
         drop_old: bool = False,
+        **kwargs,
     ):
         self.dim = dim
         self.db_config = db_config
@@ -83,6 +84,7 @@ class ElasticCloud(VectorDB):
         self,
         embeddings: Iterable[list[float]],
         metadata: list[int],
+        **kwargs,
     ) -> (int, Exception):
         """Insert the embeddings to the elasticsearch."""
         assert self.client is not None, "should self.init() first"
@@ -143,8 +145,8 @@ class ElasticCloud(VectorDB):
             log.warning(f"Failed to search: {self.indice} error: {str(e)}")
             raise e from None
 
-    def ready_to_search(self):
-        """ready_to_search will be called between insertion and search in performance cases."""
+    def optimize(self):
+        """optimize will be called between insertion and search in performance cases."""
         pass
 
     def ready_to_load(self):
vectordb_bench/backend/clients/milvus/milvus.py
@@ -2,7 +2,7 @@
 
 import logging
 from contextlib import contextmanager
-from typing import Any, Iterable, Type
+from typing import Iterable, Type
 
 from pymilvus import Collection, utility
 from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusException
@@ -24,6 +24,7 @@ class Milvus(VectorDB):
         collection_name: str = "VectorDBBenchCollection",
         drop_old: bool = False,
         name: str = "Milvus",
+        **kwargs,
     ):
         """Initialize wrapper around the milvus vector database."""
         self.name = name
@@ -53,7 +54,7 @@ class Milvus(VectorDB):
             log.info(f"{self.name} create collection: {self.collection_name}")
 
             # Create the collection
-            coll = Collection(
+            Collection(
                 name=self.collection_name,
                 schema=CollectionSchema(fields),
                 consistency_level="Session",
@@ -107,6 +108,14 @@ class Milvus(VectorDB):
 
     def _optimize(self):
         log.info(f"{self.name} optimizing before search")
+        try:
+            self.col.load()
+        except Exception as e:
+            log.warning(f"{self.name} optimize error: {e}")
+            raise e from None
+
+    def _post_insert(self):
+        log.info(f"{self.name} post insert before optimize")
         try:
             self.col.flush()
             self.col.compact()
@@ -119,10 +128,6 @@ class Milvus(VectorDB):
                 index_name=self._index_name,
             )
             utility.wait_for_index_building_complete(self.collection_name)
-            self.col.load()
-            # self.col.load(_refresh=True)
-            # utility.wait_for_loading_complete(self.collection_name)
-            # import time; time.sleep(10)
         except Exception as e:
             log.warning(f"{self.name} optimize error: {e}")
             raise e from None
@@ -132,7 +137,7 @@ class Milvus(VectorDB):
         self._pre_load(self.col)
         pass
 
-    def ready_to_search(self):
+    def optimize(self):
         assert self.col, "Please call self.init() before"
         self._optimize()
 
@@ -140,7 +145,7 @@ class Milvus(VectorDB):
         self,
         embeddings: Iterable[list[float]],
         metadata: list[int],
-        **kwargs: Any,
+        **kwargs,
    ) -> (int, Exception):
         """Insert embeddings into Milvus. should call self.init() first"""
         # use the first insert_embeddings to init collection
@@ -155,10 +160,12 @@ class Milvus(VectorDB):
                 metadata[batch_start_offset : batch_end_offset],
                 embeddings[batch_start_offset : batch_end_offset],
             ]
-            res = self.col.insert(insert_data, **kwargs)
+            res = self.col.insert(insert_data)
             insert_count += len(res.primary_keys)
+            if kwargs.get("last_batch"):
+                self._post_insert()
         except MilvusException as e:
-            log.warning("Failed to insert data")
+            log.info(f"Failed to insert data: {e}")
             return (insert_count, e)
         return (insert_count, None)
 
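Milvus now defers the flush/compact/index-wait sequence to _post_insert(), which runs only when the caller flags the final batch. A hedged sketch of the expected calling convention (the batches list is illustrative; the runner's real loop may differ):

    batches = [(embeddings_chunk, metadata_chunk)]  # illustrative data
    for i, (emb, meta) in enumerate(batches):
        count, err = milvus.insert_embeddings(
            emb,
            meta,
            last_batch=(i == len(batches) - 1),  # triggers _post_insert() exactly once
        )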
vectordb_bench/backend/clients/pgvector/config.py (new file)
@@ -0,0 +1,49 @@
+from pydantic import BaseModel, SecretStr
+from ..api import DBConfig, DBCaseConfig, MetricType
+
+POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
+
+class PgVectorConfig(DBConfig):
+    user_name: SecretStr = "postgres"
+    password: SecretStr
+    url: SecretStr
+    db_name: str
+
+    def to_dict(self) -> dict:
+        user_str = self.user_name.get_secret_value()
+        pwd_str = self.password.get_secret_value()
+        url_str = self.url.get_secret_value()
+        return {
+            "url" : POSTGRE_URL_PLACEHOLDER%(user_str, pwd_str, url_str, self.db_name)
+        }
+
+class PgVectorIndexConfig(BaseModel, DBCaseConfig):
+    metric_type: MetricType | None = None
+    lists: int | None = 1000
+    probes: int | None = 10
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "vector_l2_ops"
+        elif self.metric_type == MetricType.IP:
+            return "vector_ip_ops"
+        return "vector_cosine_ops"
+
+    def parse_metric_fun_str(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "l2_distance"
+        elif self.metric_type == MetricType.IP:
+            return "max_inner_product"
+        return "cosine_distance"
+
+    def index_param(self) -> dict:
+        return {
+            "lists" : self.lists,
+            "metric" : self.parse_metric()
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "probes" : self.probes,
+            "metric_fun" : self.parse_metric_fun_str()
+        }
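These two methods feed pgvector's ivfflat index creation and its query-time settings. For example, an L2 case yields (values follow directly from the defaults above):

    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.pgvector.config import PgVectorIndexConfig

    cfg = PgVectorIndexConfig(metric_type=MetricType.L2)
    print(cfg.index_param())   # {'lists': 1000, 'metric': 'vector_l2_ops'}
    print(cfg.search_param())  # {'probes': 10, 'metric_fun': 'l2_distance'}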
vectordb_bench/backend/clients/pgvector/pgvector.py (new file)
@@ -0,0 +1,171 @@
+"""Wrapper around the Pgvector vector database over VectorDB"""
+
+import logging
+import time
+from contextlib import contextmanager
+from typing import Any, Type
+from functools import wraps
+
+from ..api import VectorDB, DBConfig, DBCaseConfig, IndexType
+from pgvector.sqlalchemy import Vector
+from .config import PgVectorConfig, PgVectorIndexConfig
+from sqlalchemy import (
+    MetaData,
+    create_engine,
+    insert,
+    select,
+    Index,
+    Table,
+    text,
+    Column,
+    Float,
+    Integer
+)
+from sqlalchemy.orm import (
+    declarative_base,
+    mapped_column,
+    Session
+)
+
+log = logging.getLogger(__name__)
+
+class PgVector(VectorDB):
+    """ Use SQLAlchemy instructions"""
+    def __init__(
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: DBCaseConfig,
+        collection_name: str = "PgVectorCollection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.table_name = collection_name
+        self.dim = dim
+
+        self._index_name = "pqvector_index"
+        self._primary_field = "id"
+        self._vector_field = "embedding"
+
+        # construct basic units
+        pg_engine = create_engine(**self.db_config)
+        Base = declarative_base()
+        pq_metadata = Base.metadata
+        pq_metadata.reflect(pg_engine)
+
+        # create vector extension
+        with pg_engine.connect() as conn:
+            conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
+            conn.commit()
+
+        self.pg_table = self._get_table_schema(pq_metadata)
+        if drop_old and self.table_name in pq_metadata.tables:
+            log.info(f"Pgvector client drop table : {self.table_name}")
+            # self.pg_table.drop(pg_engine, checkfirst=True)
+            pq_metadata.drop_all(pg_engine)
+        self._create_table(dim, pg_engine)
+
+
+    @classmethod
+    def config_cls(cls) -> Type[DBConfig]:
+        return PgVectorConfig
+
+    @classmethod
+    def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
+        return PgVectorIndexConfig
+
+    @contextmanager
+    def init(self) -> None:
+        """
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+            >>>     self.search_embedding()
+        """
+        self.pg_engine = create_engine(**self.db_config)
+
+        Base = declarative_base()
+        pq_metadata = Base.metadata
+        pq_metadata.reflect(self.pg_engine)
+        self.pg_session = Session(self.pg_engine)
+        self.pg_table = self._get_table_schema(pq_metadata)
+        yield
+        self.pg_session = None
+        self.pg_engine = None
+        del (self.pg_session)
+        del (self.pg_engine)
+
+    def ready_to_load(self):
+        pass
+
+    def optimize(self):
+        pass
+
+    def ready_to_search(self):
+        pass
+
+    def _get_table_schema(self, pq_metadata):
+        return Table(
+            self.table_name,
+            pq_metadata,
+            Column(self._primary_field, Integer, primary_key=True),
+            Column(self._vector_field, Vector(self.dim)),
+            extend_existing=True
+        )
+
+    def _create_index(self, pg_engine):
+        index_param = self.case_config.index_param()
+        index = Index(self._index_name, self.pg_table.c.embedding,
+            postgresql_using='ivfflat',
+            postgresql_with={'lists': index_param["lists"]},
+            postgresql_ops={'embedding': index_param["metric"]}
+        )
+        index.drop(pg_engine, checkfirst = True)
+        index.create(pg_engine)
+
+    def _create_table(self, dim, pg_engine : int):
+        try:
+            # create table
+            self.pg_table.create(bind = pg_engine, checkfirst = True)
+            # create vec index
+            self._create_index(pg_engine)
+        except Exception as e:
+            log.warning(f"Failed to create pgvector table: {self.table_name} error: {e}")
+            raise e from None
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> (int, Exception):
+        try:
+            items = [dict(id = metadata[i], embedding=embeddings[i]) for i in range(len(metadata))]
+            self.pg_session.execute(insert(self.pg_table), items)
+            self.pg_session.commit()
+            return len(metadata), None
+        except Exception as e:
+            log.warning(f"Failed to insert data into pgvector table ({self.table_name}), error: {e}")
+            return 0, e
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+    ) -> list[int]:
+        assert self.pg_table is not None
+        search_param =self.case_config.search_param()
+        with self.pg_engine.connect() as conn:
+            conn.execute(text(f'SET ivfflat.probes = {search_param["probes"]}'))
+            conn.commit()
+        op_fun = getattr(self.pg_table.c.embedding, search_param["metric_fun"])
+        if filters:
+            res = self.pg_session.scalars(select(self.pg_table).order_by(op_fun(query)).filter(self.pg_table.c.id > filters.get('id')).limit(k))
+        else:
+            res = self.pg_session.scalars(select(self.pg_table).order_by(op_fun(query)).limit(k))
+        return list(res)
+
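Tying the pieces together, a caller drives the client through init() exactly as its docstring example shows. A hedged end-to-end sketch (connection values are placeholders; pydantic coerces the plain strings into SecretStr):

    from vectordb_bench.backend.clients.pgvector.config import PgVectorConfig, PgVectorIndexConfig
    from vectordb_bench.backend.clients.pgvector.pgvector import PgVector

    db_config = PgVectorConfig(
        user_name="postgres",
        password="example-password",   # placeholder
        url="localhost:5432",          # placeholder host:port
        db_name="vectordb_bench",
    ).to_dict()
    client = PgVector(
        dim=768,
        db_config=db_config,
        db_case_config=PgVectorIndexConfig(),
        drop_old=True,
    )
    with client.init():
        client.insert_embeddings([[0.1] * 768], [1])
        ids = client.search_embedding([0.1] * 768, k=10)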
vectordb_bench/backend/clients/pinecone/pinecone.py
@@ -2,7 +2,7 @@
 
 import logging
 from contextlib import contextmanager
-from typing import Any, Type
+from typing import Type
 
 from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
 from .config import PineconeConfig
@@ -20,6 +20,7 @@ class Pinecone(VectorDB):
         db_config: dict,
         db_case_config: DBCaseConfig,
         drop_old: bool = False,
+        **kwargs,
     ):
         """Initialize wrapper around the milvus vector database."""
         self.index_name = db_config["index_name"]
@@ -69,13 +70,14 @@ class Pinecone(VectorDB):
     def ready_to_load(self):
         pass
 
-    def ready_to_search(self):
+    def optimize(self):
         pass
 
     def insert_embeddings(
         self,
         embeddings: list[list[float]],
         metadata: list[int],
+        **kwargs,
     ) -> (int, Exception):
         assert len(embeddings) == len(metadata)
         insert_count = 0
@@ -99,7 +101,6 @@ class Pinecone(VectorDB):
         k: int = 100,
         filters: dict | None = None,
         timeout: int | None = None,
-        **kwargs: Any,
     ) -> list[tuple[int, float]]:
         if filters is None:
             pinecone_filters = {}
vectordb_bench/backend/clients/qdrant_cloud/config.py
@@ -1,6 +1,7 @@
-from pydantic import SecretStr
+from pydantic import BaseModel, SecretStr
 
-from ..api import DBConfig
+from ..api import DBConfig, DBCaseConfig, MetricType
+from qdrant_client.models import Distance
 
 
 class QdrantConfig(DBConfig):
@@ -13,3 +14,20 @@ class QdrantConfig(DBConfig):
             "api_key": self.api_key.get_secret_value(),
             "prefer_grpc": True,
         }
+
+class QdrantIndexConfig(BaseModel, DBCaseConfig):
+    metric_type: MetricType | None = None
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return Distance.EUCLID
+        elif self.metric_type == MetricType.IP:
+            return Distance.DOT
+        return Distance.COSINE
+
+    def index_param(self) -> dict:
+        params = {"distance": self.parse_metric()}
+        return params
+
+    def search_param(self) -> dict:
+        return {}
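As with the pgvector case config, QdrantIndexConfig folds the metric choice into the parameters used at collection-creation time, while search needs no extra parameters. For instance (assumes qdrant-client is installed):

    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.qdrant_cloud.config import QdrantIndexConfig

    cfg = QdrantIndexConfig(metric_type=MetricType.IP)
    print(cfg.index_param())   # {'distance': <Distance.DOT: 'Dot'>}
    print(cfg.search_param())  # {}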