vectordb-bench 0.0.29__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. vectordb_bench/__init__.py +14 -27
  2. vectordb_bench/backend/assembler.py +19 -6
  3. vectordb_bench/backend/cases.py +186 -23
  4. vectordb_bench/backend/clients/__init__.py +32 -0
  5. vectordb_bench/backend/clients/api.py +22 -1
  6. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +249 -43
  7. vectordb_bench/backend/clients/aws_opensearch/cli.py +51 -21
  8. vectordb_bench/backend/clients/aws_opensearch/config.py +58 -16
  9. vectordb_bench/backend/clients/chroma/chroma.py +6 -2
  10. vectordb_bench/backend/clients/elastic_cloud/config.py +19 -1
  11. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +133 -45
  12. vectordb_bench/backend/clients/lancedb/cli.py +62 -8
  13. vectordb_bench/backend/clients/lancedb/config.py +14 -1
  14. vectordb_bench/backend/clients/lancedb/lancedb.py +21 -9
  15. vectordb_bench/backend/clients/memorydb/memorydb.py +2 -2
  16. vectordb_bench/backend/clients/milvus/cli.py +30 -9
  17. vectordb_bench/backend/clients/milvus/config.py +3 -0
  18. vectordb_bench/backend/clients/milvus/milvus.py +81 -23
  19. vectordb_bench/backend/clients/oceanbase/cli.py +100 -0
  20. vectordb_bench/backend/clients/oceanbase/config.py +125 -0
  21. vectordb_bench/backend/clients/oceanbase/oceanbase.py +215 -0
  22. vectordb_bench/backend/clients/pinecone/pinecone.py +39 -25
  23. vectordb_bench/backend/clients/qdrant_cloud/config.py +59 -3
  24. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +100 -33
  25. vectordb_bench/backend/clients/qdrant_local/cli.py +60 -0
  26. vectordb_bench/backend/clients/qdrant_local/config.py +47 -0
  27. vectordb_bench/backend/clients/qdrant_local/qdrant_local.py +232 -0
  28. vectordb_bench/backend/clients/weaviate_cloud/cli.py +29 -3
  29. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -0
  30. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +5 -0
  31. vectordb_bench/backend/dataset.py +143 -27
  32. vectordb_bench/backend/filter.py +76 -0
  33. vectordb_bench/backend/runner/__init__.py +3 -3
  34. vectordb_bench/backend/runner/mp_runner.py +52 -39
  35. vectordb_bench/backend/runner/rate_runner.py +68 -52
  36. vectordb_bench/backend/runner/read_write_runner.py +125 -68
  37. vectordb_bench/backend/runner/serial_runner.py +56 -23
  38. vectordb_bench/backend/task_runner.py +48 -20
  39. vectordb_bench/cli/batch_cli.py +121 -0
  40. vectordb_bench/cli/cli.py +59 -1
  41. vectordb_bench/cli/vectordbbench.py +7 -0
  42. vectordb_bench/config-files/batch_sample_config.yml +17 -0
  43. vectordb_bench/frontend/components/check_results/data.py +16 -11
  44. vectordb_bench/frontend/components/check_results/filters.py +53 -25
  45. vectordb_bench/frontend/components/check_results/headerIcon.py +16 -13
  46. vectordb_bench/frontend/components/check_results/nav.py +20 -0
  47. vectordb_bench/frontend/components/custom/displayCustomCase.py +43 -8
  48. vectordb_bench/frontend/components/custom/displaypPrams.py +10 -5
  49. vectordb_bench/frontend/components/custom/getCustomConfig.py +10 -0
  50. vectordb_bench/frontend/components/label_filter/charts.py +60 -0
  51. vectordb_bench/frontend/components/run_test/caseSelector.py +48 -52
  52. vectordb_bench/frontend/components/run_test/dbSelector.py +9 -5
  53. vectordb_bench/frontend/components/run_test/inputWidget.py +48 -0
  54. vectordb_bench/frontend/components/run_test/submitTask.py +3 -1
  55. vectordb_bench/frontend/components/streaming/charts.py +253 -0
  56. vectordb_bench/frontend/components/streaming/data.py +62 -0
  57. vectordb_bench/frontend/components/tables/data.py +1 -1
  58. vectordb_bench/frontend/components/welcome/explainPrams.py +66 -0
  59. vectordb_bench/frontend/components/welcome/pagestyle.py +106 -0
  60. vectordb_bench/frontend/components/welcome/welcomePrams.py +147 -0
  61. vectordb_bench/frontend/config/dbCaseConfigs.py +420 -41
  62. vectordb_bench/frontend/config/styles.py +32 -2
  63. vectordb_bench/frontend/pages/concurrent.py +5 -1
  64. vectordb_bench/frontend/pages/custom.py +4 -0
  65. vectordb_bench/frontend/pages/label_filter.py +56 -0
  66. vectordb_bench/frontend/pages/quries_per_dollar.py +5 -1
  67. vectordb_bench/frontend/pages/results.py +60 -0
  68. vectordb_bench/frontend/pages/run_test.py +3 -3
  69. vectordb_bench/frontend/pages/streaming.py +135 -0
  70. vectordb_bench/frontend/pages/tables.py +4 -0
  71. vectordb_bench/frontend/vdb_benchmark.py +16 -41
  72. vectordb_bench/interface.py +6 -2
  73. vectordb_bench/metric.py +15 -1
  74. vectordb_bench/models.py +38 -11
  75. vectordb_bench/results/ElasticCloud/result_20250318_standard_elasticcloud.json +5890 -0
  76. vectordb_bench/results/Milvus/result_20250509_standard_milvus.json +6138 -0
  77. vectordb_bench/results/OpenSearch/result_20250224_standard_opensearch.json +7319 -0
  78. vectordb_bench/results/Pinecone/result_20250124_standard_pinecone.json +2365 -0
  79. vectordb_bench/results/QdrantCloud/result_20250602_standard_qdrantcloud.json +3556 -0
  80. vectordb_bench/results/ZillizCloud/result_20250613_standard_zillizcloud.json +6290 -0
  81. vectordb_bench/results/dbPrices.json +12 -4
  82. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/METADATA +131 -32
  83. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/RECORD +87 -65
  84. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/WHEEL +1 -1
  85. vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +0 -791
  86. vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +0 -679
  87. vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +0 -1352
  88. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/entry_points.txt +0 -0
  89. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/licenses/LICENSE +0 -0
  90. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,215 @@
1
+ import logging
2
+ import struct
3
+ import time
4
+ from collections.abc import Generator
5
+ from contextlib import contextmanager
6
+ from typing import Any
7
+
8
+ import mysql.connector as mysql
9
+
10
+ from ..api import IndexType, VectorDB
11
+ from .config import OceanBaseConfigDict, OceanBaseHNSWConfig
12
+
13
+ log = logging.getLogger(__name__)
14
+
15
+ OCEANBASE_DEFAULT_LOAD_BATCH_SIZE = 256
16
+
17
+
18
+ class OceanBase(VectorDB):
19
def __init__(
    self,
    dim: int,
    db_config: OceanBaseConfigDict,
    db_case_config: OceanBaseHNSWConfig,
    collection_name: str = "items",
    drop_old: bool = False,
    **kwargs,
):
    """Store the benchmark settings and optionally recreate the target table.

    Args:
        dim: Dimension used for the ``VECTOR`` column of the table.
        db_config: Connection settings; ``host``, ``user``, ``port``,
            ``password`` and ``database`` are read when connecting.
        db_case_config: Index and search configuration for this case.
        collection_name: Name of the table that stores the embeddings.
        drop_old: When true, drop the existing table and create a new one.
        **kwargs: Accepted and ignored.
    """
    self.name = "OceanBase"
    self.dim = dim
    self.db_config = db_config
    self.db_case_config = db_case_config
    self.table_name = collection_name
    self.load_batch_size = OCEANBASE_DEFAULT_LOAD_BATCH_SIZE
    self._index_name = "vidx"
    self._primary_field = "id"
    self._vector_field = "embedding"

    # The connection settings carry the plaintext password; never log it.
    loggable_config = {
        key: ("******" if key == "password" else value) for key, value in dict(self.db_config).items()
    }
    log.info(
        f"{self.name} initialized with config:\nDatabase: {loggable_config}\nCase Config: {self.db_case_config}"
    )

    self._conn = None
    self._cursor = None

    try:
        self._connect()
        if drop_old:
            self._drop_table()
            # NOTE(review): the create step is assumed to be nested under
            # ``drop_old``; a plain CREATE TABLE would fail on an existing table.
            self._create_table()
    finally:
        # The constructor only prepares the schema; workers reconnect via init().
        self._disconnect()
52
+
53
def _connect(self):
    """Open a connection and a cursor from the values in ``self.db_config``.

    Raises:
        mysql.Error: Re-raised after being logged when connecting fails.
    """
    settings = self.db_config
    try:
        connection = mysql.connect(
            **{key: settings[key] for key in ("host", "user", "port", "password", "database")}
        )
        self._conn = connection
        self._cursor = connection.cursor()
    except mysql.Error:
        log.exception("Failed to connect to the database")
        raise
66
+
67
def _disconnect(self):
    """Close the cursor and the connection, if present, and clear both attributes.

    The connection is still closed and both attributes are still reset when
    closing the cursor raises; that exception propagates to the caller.
    """
    cursor, conn = self._cursor, self._conn
    # Reset first so the object never keeps references to half-closed handles.
    self._cursor = None
    self._conn = None
    try:
        if cursor:
            cursor.close()
    finally:
        # Runs even if cursor.close() failed, so the connection is not leaked.
        if conn:
            conn.close()
74
+
75
@contextmanager
def init(self) -> Generator[None, None, None]:
    """Context manager that connects, applies session settings, and disconnects on exit.

    Autocommit is switched on, then the search parameter matching the
    configured index family is set for the session before yielding.
    """
    try:
        self._connect()
        cursor = self._cursor
        cursor.execute("SET autocommit=1")

        hnsw_family = {IndexType.HNSW, IndexType.HNSW_SQ, IndexType.HNSW_BQ}
        if self.db_case_config.index in hnsw_family:
            ef_search = self.db_case_config.search_param()["params"]["ef_search"]
            session_setting = f"SET ob_hnsw_ef_search={ef_search}"
        else:
            nprobes = self.db_case_config.search_param()["params"]["ivf_nprobes"]
            session_setting = f"SET ob_ivf_nprobes={nprobes}"
        cursor.execute(session_setting)
        yield
    finally:
        self._disconnect()
92
+
93
def _drop_table(self):
    """Drop ``self.table_name`` if it exists.

    Raises:
        ValueError: If no cursor is open.
    """
    cursor = self._cursor
    if not cursor:
        raise ValueError("Cursor is not initialized")

    table = self.table_name
    log.info(f"Dropping table {table}")
    cursor.execute(f"DROP TABLE IF EXISTS {table}")
99
+
100
def _create_table(self):
    """Create the table with an integer primary key and a ``VECTOR(dim)`` column.

    Raises:
        ValueError: If no cursor is open.
    """
    cursor = self._cursor
    if not cursor:
        raise ValueError("Cursor is not initialized")

    log.info(f"Creating table {self.table_name}")
    ddl = f"""
            CREATE TABLE {self.table_name} (
                id INT PRIMARY KEY,
                embedding VECTOR({self.dim})
            );
        """
    cursor.execute(ddl)
112
+
113
def optimize(self, data_size: int):
    """Build the vector index, run a major freeze, and gather schema statistics.

    Args:
        data_size: Not used by this implementation.

    Raises:
        ValueError: If no cursor is open.
        mysql.Error: Re-raised after being logged when a statement fails.
    """
    # Same guard as the other cursor-using methods, instead of an AttributeError.
    if not self._cursor:
        raise ValueError("Cursor is not initialized")

    index_params = self.db_case_config.index_param()
    index_args = ", ".join(f"{k}={v}" for k, v in index_params["params"].items())
    index_query = (
        f"CREATE /*+ PARALLEL(18) */ VECTOR INDEX idx1 "
        f"ON {self.table_name}(embedding) "
        f"WITH (distance={self.db_case_config.parse_metric()}, "
        f"type={index_params['index_type']}, lib={index_params['lib']}, {index_args}"
    )

    if self.db_case_config.index in {IndexType.HNSW, IndexType.HNSW_SQ, IndexType.HNSW_BQ}:
        index_query += ", extra_info_max_size=32"

    index_query += ")"

    log.info("Create index query: %s", index_query)

    try:
        log.info("Creating index...")
        start_time = time.time()
        self._cursor.execute(index_query)
        log.info(f"Index created in {time.time() - start_time:.2f} seconds")

        log.info("Performing major freeze...")
        self._cursor.execute("ALTER SYSTEM MAJOR FREEZE;")
        # Pause before the first status poll, as the original code does.
        time.sleep(10)
        self._wait_for_major_compaction()

        log.info("Gathering schema statistics...")
        # Gather statistics for the database this client connected to rather
        # than a hard-coded schema name.
        database = self.db_config["database"]
        self._cursor.execute(f"CALL dbms_stats.gather_schema_stats('{database}', degree => 96);")
    except mysql.Error:
        log.exception("Failed to optimize index")
        raise
146
+
147
def need_normalize_cosine(self) -> bool:
    """Return True only when the configured index type is ``HNSW_BQ``."""
    if self.db_case_config.index != IndexType.HNSW_BQ:
        return False

    log.info("current HNSW_BQ only supports L2, cosine dataset need normalize.")
    return True
153
+
154
def _wait_for_major_compaction(self):
    """Poll the zone compaction view every 10 seconds until every row is IDLE."""
    status_query = (
        "SELECT IF(COUNT(*) = COUNT(STATUS = 'IDLE' OR NULL), 'TRUE', 'FALSE') "
        "AS all_status_idle FROM oceanbase.DBA_OB_ZONE_MAJOR_COMPACTION;"
    )
    while True:
        self._cursor.execute(status_query)
        row = self._cursor.fetchone()
        if row[0] == "TRUE":
            return
        time.sleep(10)
164
+
165
+ def insert_embeddings(
166
+ self,
167
+ embeddings: list[list[float]],
168
+ metadata: list[int],
169
+ **kwargs: Any,
170
+ ) -> tuple[int, Exception | None]:
171
+ if not self._cursor:
172
+ raise ValueError("Cursor is not initialized")
173
+
174
+ insert_count = 0
175
+ try:
176
+ for batch_start in range(0, len(embeddings), self.load_batch_size):
177
+ batch_end = min(batch_start + self.load_batch_size, len(embeddings))
178
+ batch = [(metadata[i], embeddings[i]) for i in range(batch_start, batch_end)]
179
+ values = ", ".join(f"({item_id}, '[{','.join(map(str, embedding))}]')" for item_id, embedding in batch)
180
+ self._cursor.execute(
181
+ f"INSERT /*+ ENABLE_PARALLEL_DML PARALLEL(32) */ INTO {self.table_name} VALUES {values}" # noqa: S608
182
+ )
183
+ insert_count += len(batch)
184
+ except mysql.Error:
185
+ log.exception("Failed to insert embeddings")
186
+ raise
187
+
188
+ return insert_count, None
189
+
190
+ def search_embedding(
191
+ self,
192
+ query: list[float],
193
+ k: int = 100,
194
+ filters: dict[str, Any] | None = None,
195
+ timeout: int | None = None,
196
+ ) -> list[int]:
197
+ if not self._cursor:
198
+ raise ValueError("Cursor is not initialized")
199
+
200
+ packed = struct.pack(f"<{len(query)}f", *query)
201
+ hex_vec = packed.hex()
202
+ filter_clause = f"WHERE id >= {filters['id']}" if filters else ""
203
+ query_str = (
204
+ f"SELECT id FROM {self.table_name} " # noqa: S608
205
+ f"{filter_clause} ORDER BY "
206
+ f"{self.db_case_config.parse_metric_func_str()}(embedding, X'{hex_vec}') "
207
+ f"APPROXIMATE LIMIT {k}"
208
+ )
209
+
210
+ try:
211
+ self._cursor.execute(query_str)
212
+ return [row[0] for row in self._cursor.fetchall()]
213
+ except mysql.Error:
214
+ log.exception("Failed to execute search query")
215
+ raise
@@ -5,8 +5,9 @@ from contextlib import contextmanager
5
5
 
6
6
  import pinecone
7
7
 
8
- from ..api import DBCaseConfig, DBConfig, EmptyDBCaseConfig, IndexType, VectorDB
9
- from .config import PineconeConfig
8
+ from vectordb_bench.backend.filter import Filter, FilterOp
9
+
10
+ from ..api import DBCaseConfig, VectorDB
10
11
 
11
12
  log = logging.getLogger(__name__)
12
13
 
@@ -15,12 +16,19 @@ PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024 # 2MB
15
16
 
16
17
 
17
18
  class Pinecone(VectorDB):
19
+ supported_filter_types: list[FilterOp] = [
20
+ FilterOp.NonFilter,
21
+ FilterOp.NumGE,
22
+ FilterOp.StrEqual,
23
+ ]
24
+
18
25
  def __init__(
19
26
  self,
20
27
  dim: int,
21
28
  db_config: dict,
22
29
  db_case_config: DBCaseConfig,
23
30
  drop_old: bool = False,
31
+ with_scalar_labels: bool = False,
24
32
  **kwargs,
25
33
  ):
26
34
  """Initialize wrapper around the milvus vector database."""
@@ -33,6 +41,7 @@ class Pinecone(VectorDB):
33
41
  pc = pinecone.Pinecone(api_key=self.api_key)
34
42
  index = pc.Index(self.index_name)
35
43
 
44
+ self.with_scalar_labels = with_scalar_labels
36
45
  if drop_old:
37
46
  index_stats = index.describe_index_stats()
38
47
  index_dim = index_stats["dimension"]
@@ -43,15 +52,8 @@ class Pinecone(VectorDB):
43
52
  log.info(f"Pinecone index delete namespace: {namespace}")
44
53
  index.delete(delete_all=True, namespace=namespace)
45
54
 
46
- self._metadata_key = "meta"
47
-
48
- @classmethod
49
- def config_cls(cls) -> type[DBConfig]:
50
- return PineconeConfig
51
-
52
- @classmethod
53
- def case_config_cls(cls, index_type: IndexType | None = None) -> type[DBCaseConfig]:
54
- return EmptyDBCaseConfig
55
+ self._scalar_id_field = "meta"
56
+ self._scalar_label_field = "label"
55
57
 
56
58
  @contextmanager
57
59
  def init(self):
@@ -66,6 +68,7 @@ class Pinecone(VectorDB):
66
68
  self,
67
69
  embeddings: list[list[float]],
68
70
  metadata: list[int],
71
+ labels_data: list[str] | None = None,
69
72
  **kwargs,
70
73
  ) -> tuple[int, Exception]:
71
74
  assert len(embeddings) == len(metadata)
@@ -75,33 +78,44 @@ class Pinecone(VectorDB):
75
78
  batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
76
79
  insert_datas = []
77
80
  for i in range(batch_start_offset, batch_end_offset):
81
+ metadata_dict = {self._scalar_id_field: metadata[i]}
82
+ if self.with_scalar_labels:
83
+ metadata_dict[self._scalar_label_field] = labels_data[i]
78
84
  insert_data = (
79
85
  str(metadata[i]),
80
86
  embeddings[i],
81
- {self._metadata_key: metadata[i]},
87
+ metadata_dict,
82
88
  )
83
89
  insert_datas.append(insert_data)
84
90
  self.index.upsert(insert_datas)
85
91
  insert_count += batch_end_offset - batch_start_offset
86
92
  except Exception as e:
87
- return (insert_count, e)
88
- return (len(embeddings), None)
93
+ return insert_count, e
94
+ return len(embeddings), None
89
95
 
90
96
  def search_embedding(
91
97
  self,
92
98
  query: list[float],
93
99
  k: int = 100,
94
- filters: dict | None = None,
95
100
  timeout: int | None = None,
96
101
  ) -> list[int]:
97
- pinecone_filters = {} if filters is None else {self._metadata_key: {"$gte": filters["id"]}}
98
- try:
99
- res = self.index.query(
100
- top_k=k,
101
- vector=query,
102
- filter=pinecone_filters,
103
- )["matches"]
104
- except Exception as e:
105
- log.warning(f"Error querying index: {e}")
106
- raise e from e
102
+ pinecone_filters = self.expr
103
+ res = self.index.query(
104
+ top_k=k,
105
+ vector=query,
106
+ filter=pinecone_filters,
107
+ )["matches"]
107
108
  return [int(one_res["id"]) for one_res in res]
109
+
110
def prepare_filter(self, filters: Filter):
    """Translate ``filters`` into a metadata filter and store it on ``self.expr``.

    Raises:
        ValueError: If the filter type is not one of the handled kinds.
    """
    kind = filters.type
    if kind == FilterOp.NonFilter:
        self.expr = None
        return
    if kind == FilterOp.NumGE:
        self.expr = {self._scalar_id_field: {"$gte": filters.int_value}}
        return
    if kind == FilterOp.StrEqual:
        # "$in" with a one-element list would work as well as "$eq" here,
        # e.g. {self._scalar_label_field: {"$in": [filters.label_value]}}
        self.expr = {self._scalar_label_field: {"$eq": filters.label_value}}
        return

    msg = f"Not support Filter for Pinecone - {filters}"
    raise ValueError(msg)
@@ -1,7 +1,12 @@
1
- from pydantic import BaseModel, SecretStr
1
+ from typing import TypeVar
2
+
3
+ from pydantic import BaseModel, SecretStr, validator
2
4
 
3
5
  from ..api import DBCaseConfig, DBConfig, MetricType
4
6
 
7
+ # define type "SearchParams"
8
+ SearchParams = TypeVar("SearchParams")
9
+
5
10
 
6
11
  # Allowing `api_key` to be left empty, to ensure compatibility with the open-source Qdrant.
7
12
  class QdrantConfig(DBConfig):
@@ -20,9 +25,43 @@ class QdrantConfig(DBConfig):
20
25
  "url": self.url.get_secret_value(),
21
26
  }
22
27
 
28
@validator("*")
def not_empty_field(cls, v: any, field: any):
    """Skip the inherited non-empty check for ``api_key``; delegate for other fields."""
    optional_fields = ["api_key"]
    if field.name not in optional_fields:
        return super().not_empty_field(v, field)
    # api_key may be left empty.
    return v
33
+
23
34
 
24
35
  class QdrantIndexConfig(BaseModel, DBCaseConfig):
25
36
  metric_type: MetricType | None = None
37
+ m: int = 16
38
+ payload_m: int = 16 # only for label_filter cases
39
+ create_payload_int_index: bool = False
40
+ create_payload_keyword_index: bool = False
41
+ is_tenant: bool = False
42
+ use_scalar_quant: bool = False
43
+ sq_quantile: float = 0.99
44
+ default_segment_number: int = 0
45
+
46
+ use_rescore: bool = False
47
+ oversampling: float = 1.0
48
+ indexed_only: bool = False
49
+ hnsw_ef: int | None = 100
50
+ exact: bool = False
51
+
52
+ with_payload: bool = False
53
+
54
def __eq__(self, obj: object):
    """Compare two configs on their collection-building fields only.

    Only ``m``, ``payload_m``, the payload-index flags, ``is_tenant`` and the
    quantization/segment settings take part in the comparison.

    Returns:
        True or False when ``obj`` exposes the compared fields, otherwise
        ``NotImplemented`` so that comparing with an unrelated object
        (for example ``None``) does not raise ``AttributeError``.
    """
    compared_fields = (
        "m",
        "payload_m",
        "create_payload_int_index",
        "create_payload_keyword_index",
        "is_tenant",
        "use_scalar_quant",
        "sq_quantile",
        "default_segment_number",
    )
    missing = object()
    for name in compared_fields:
        other_value = getattr(obj, name, missing)
        if other_value is missing:
            return NotImplemented
        if getattr(self, name) != other_value:
            return False
    return True
26
65
 
27
66
  def parse_metric(self) -> str:
28
67
  if self.metric_type == MetricType.L2:
@@ -36,5 +75,22 @@ class QdrantIndexConfig(BaseModel, DBCaseConfig):
36
75
  def index_param(self) -> dict:
37
76
  return {"distance": self.parse_metric()}
38
77
 
39
- def search_param(self) -> dict:
40
- return {}
78
def search_param(self) -> SearchParams:
    """Build the qdrant ``SearchParams`` object from this config's search fields."""
    # Imported at call time rather than at module import.
    from qdrant_client.http.models import QuantizationSearchParams, SearchParams

    quantization = None
    if self.use_rescore:
        quantization = QuantizationSearchParams(
            ignore=False,
            rescore=True,
            oversampling=self.oversampling,
        )

    return SearchParams(
        hnsw_ef=self.hnsw_ef,
        exact=self.exact,
        indexed_only=self.indexed_only,
        quantization=quantization,
    )
@@ -9,13 +9,24 @@ from qdrant_client.http.models import (
9
9
  Batch,
10
10
  CollectionStatus,
11
11
  FieldCondition,
12
- Filter,
12
+ HnswConfigDiff,
13
+ KeywordIndexParams,
14
+ OptimizersConfigDiff,
13
15
  PayloadSchemaType,
14
16
  Range,
17
+ ScalarQuantization,
18
+ ScalarQuantizationConfig,
19
+ ScalarType,
15
20
  VectorParams,
16
21
  )
22
+ from qdrant_client.http.models import (
23
+ Filter as QdrantFilter,
24
+ )
17
25
 
18
- from ..api import DBCaseConfig, VectorDB
26
+ from vectordb_bench.backend.clients.qdrant_cloud.config import QdrantIndexConfig
27
+ from vectordb_bench.backend.filter import Filter, FilterOp
28
+
29
+ from ..api import VectorDB
19
30
 
20
31
  log = logging.getLogger(__name__)
21
32
 
@@ -25,24 +36,33 @@ QDRANT_BATCH_SIZE = 500
25
36
 
26
37
 
27
38
  class QdrantCloud(VectorDB):
39
+ supported_filter_types: list[FilterOp] = [
40
+ FilterOp.NonFilter,
41
+ FilterOp.NumGE,
42
+ FilterOp.StrEqual,
43
+ ]
44
+
28
45
  def __init__(
29
46
  self,
30
47
  dim: int,
31
48
  db_config: dict,
32
- db_case_config: DBCaseConfig,
49
+ db_case_config: QdrantIndexConfig,
33
50
  collection_name: str = "QdrantCloudCollection",
34
51
  drop_old: bool = False,
52
+ with_scalar_labels: bool = False,
35
53
  **kwargs,
36
54
  ):
37
55
  """Initialize wrapper around the QdrantCloud vector database."""
38
56
  self.db_config = db_config
39
- self.case_config = db_case_config
57
+ self.db_case_config = db_case_config
40
58
  self.collection_name = collection_name
41
59
 
42
60
  self._primary_field = "pk"
61
+ self._scalar_label_field = "label"
43
62
  self._vector_field = "vector"
44
63
 
45
64
  tmp_client = QdrantClient(**self.db_config)
65
+ self.with_scalar_labels = with_scalar_labels
46
66
  if drop_old:
47
67
  log.info(f"QdrantCloud client drop_old collection: {self.collection_name}")
48
68
  tmp_client.delete_collection(self.collection_name)
@@ -50,7 +70,7 @@ class QdrantCloud(VectorDB):
50
70
  tmp_client = None
51
71
 
52
72
  @contextmanager
53
- def init(self) -> None:
73
+ def init(self):
54
74
  """
55
75
  Examples:
56
76
  >>> with self.init():
@@ -74,7 +94,7 @@ class QdrantCloud(VectorDB):
74
94
  if info.status == CollectionStatus.GREEN:
75
95
  msg = (
76
96
  f"Stored vectors: {info.vectors_count}, Indexed vectors: {info.indexed_vectors_count}, "
77
- f"Collection status: {info.indexed_vectors_count}"
97
+ f"Collection status: {info.status}, Segment counts: {info.segments_count}"
78
98
  )
79
99
  log.info(msg)
80
100
  return
@@ -86,19 +106,48 @@ class QdrantCloud(VectorDB):
86
106
  log.info(f"Create collection: {self.collection_name}")
87
107
 
88
108
  try:
109
+ # whether to use quant (SQ8)
110
+ quantization_config = None
111
+ if self.db_case_config.use_scalar_quant:
112
+ quantization_config = ScalarQuantization(
113
+ scalar=ScalarQuantizationConfig(
114
+ type=ScalarType.INT8,
115
+ quantile=self.db_case_config.sq_quantile,
116
+ always_ram=True,
117
+ )
118
+ )
119
+
120
+ # create collection
89
121
  qdrant_client.create_collection(
90
122
  collection_name=self.collection_name,
91
123
  vectors_config=VectorParams(
92
124
  size=dim,
93
- distance=self.case_config.index_param()["distance"],
125
+ distance=self.db_case_config.parse_metric(),
126
+ ),
127
+ hnsw_config=HnswConfigDiff(m=self.db_case_config.m, payload_m=self.db_case_config.payload_m),
128
+ optimizers_config=OptimizersConfigDiff(
129
+ default_segment_number=self.db_case_config.default_segment_number
94
130
  ),
131
+ quantization_config=quantization_config,
95
132
  )
96
133
 
97
- qdrant_client.create_payload_index(
98
- collection_name=self.collection_name,
99
- field_name=self._primary_field,
100
- field_schema=PayloadSchemaType.INTEGER,
101
- )
134
+ # create payload_index for int-field
135
+ if self.db_case_config.create_payload_int_index:
136
+ qdrant_client.create_payload_index(
137
+ collection_name=self.collection_name,
138
+ field_name=self._primary_field,
139
+ field_schema=PayloadSchemaType.INTEGER,
140
+ )
141
+
142
+ # create payload_index for str-field
143
+ if self.with_scalar_labels and self.db_case_config.create_payload_keyword_index:
144
+ qdrant_client.create_payload_index(
145
+ collection_name=self.collection_name,
146
+ field_name=self._scalar_label_field,
147
+ field_schema=KeywordIndexParams(
148
+ type=PayloadSchemaType.KEYWORD, is_tenant=self.db_case_config.is_tenant
149
+ ),
150
+ )
102
151
 
103
152
  except Exception as e:
104
153
  if "already exists!" in str(e):
@@ -110,16 +159,22 @@ class QdrantCloud(VectorDB):
110
159
  self,
111
160
  embeddings: list[list[float]],
112
161
  metadata: list[int],
162
+ labels_data: list[str] | None = None,
113
163
  **kwargs,
114
164
  ) -> tuple[int, Exception]:
115
165
  """Insert embeddings into Milvus. should call self.init() first"""
116
166
  assert self.qdrant_client is not None
117
167
  try:
118
- # TODO: counts
119
168
  for offset in range(0, len(embeddings), QDRANT_BATCH_SIZE):
120
169
  vectors = embeddings[offset : offset + QDRANT_BATCH_SIZE]
121
170
  ids = metadata[offset : offset + QDRANT_BATCH_SIZE]
122
- payloads = [{self._primary_field: v} for v in ids]
171
+ if self.with_scalar_labels:
172
+ labels = labels_data[offset : offset + QDRANT_BATCH_SIZE]
173
+ payloads = [
174
+ {self._primary_field: pk, self._scalar_label_field: labels[i]} for i, pk in enumerate(ids)
175
+ ]
176
+ else:
177
+ payloads = [{self._primary_field: pk} for i, pk in enumerate(ids)]
123
178
  _ = self.qdrant_client.upsert(
124
179
  collection_name=self.collection_name,
125
180
  wait=True,
@@ -135,34 +190,46 @@ class QdrantCloud(VectorDB):
135
190
  self,
136
191
  query: list[float],
137
192
  k: int = 100,
138
- filters: dict | None = None,
139
193
  timeout: int | None = None,
194
+ **kwargs,
140
195
  ) -> list[int]:
141
196
  """Perform a search on a query embedding and return results with score.
142
197
  Should call self.init() first.
143
198
  """
144
199
  assert self.qdrant_client is not None
145
200
 
146
- f = None
147
- if filters:
148
- f = Filter(
201
+ res = self.qdrant_client.search(
202
+ collection_name=self.collection_name,
203
+ query_vector=query,
204
+ limit=k,
205
+ query_filter=self.query_filter,
206
+ search_params=self.db_case_config.search_param(),
207
+ with_payload=self.db_case_config.with_payload,
208
+ )
209
+
210
+ return [r.id for r in res]
211
+
212
+ def prepare_filter(self, filters: Filter):
213
+ if filters.type == FilterOp.NonFilter:
214
+ self.query_filter = None
215
+ elif filters.type == FilterOp.NumGE:
216
+ self.query_filter = QdrantFilter(
149
217
  must=[
150
218
  FieldCondition(
151
219
  key=self._primary_field,
152
- range=Range(
153
- gt=filters.get("id"),
154
- ),
220
+ range=Range(gte=filters.int_value),
155
221
  ),
156
- ],
222
+ ]
157
223
  )
158
-
159
- res = (
160
- self.qdrant_client.search(
161
- collection_name=self.collection_name,
162
- query_vector=query,
163
- limit=k,
164
- query_filter=f,
165
- ),
166
- )
167
-
168
- return [result.id for result in res[0]]
224
+ elif filters.type == FilterOp.StrEqual:
225
+ self.query_filter = QdrantFilter(
226
+ must=[
227
+ FieldCondition(
228
+ key=self._scalar_label_field,
229
+ match={"value": filters.label_value},
230
+ ),
231
+ ]
232
+ )
233
+ else:
234
+ msg = f"Not support Filter for Qdrant - {filters}"
235
+ raise ValueError(msg)