vectordb-bench 0.0.30__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. vectordb_bench/__init__.py +14 -27
  2. vectordb_bench/backend/assembler.py +19 -6
  3. vectordb_bench/backend/cases.py +186 -23
  4. vectordb_bench/backend/clients/__init__.py +16 -0
  5. vectordb_bench/backend/clients/api.py +22 -1
  6. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +82 -41
  7. vectordb_bench/backend/clients/aws_opensearch/config.py +23 -4
  8. vectordb_bench/backend/clients/chroma/chroma.py +6 -2
  9. vectordb_bench/backend/clients/elastic_cloud/config.py +19 -1
  10. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +133 -45
  11. vectordb_bench/backend/clients/milvus/config.py +1 -0
  12. vectordb_bench/backend/clients/milvus/milvus.py +74 -22
  13. vectordb_bench/backend/clients/oceanbase/cli.py +100 -0
  14. vectordb_bench/backend/clients/oceanbase/config.py +125 -0
  15. vectordb_bench/backend/clients/oceanbase/oceanbase.py +215 -0
  16. vectordb_bench/backend/clients/pinecone/pinecone.py +39 -25
  17. vectordb_bench/backend/clients/qdrant_cloud/config.py +59 -3
  18. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +100 -33
  19. vectordb_bench/backend/dataset.py +143 -27
  20. vectordb_bench/backend/filter.py +76 -0
  21. vectordb_bench/backend/runner/__init__.py +3 -3
  22. vectordb_bench/backend/runner/mp_runner.py +52 -39
  23. vectordb_bench/backend/runner/rate_runner.py +68 -52
  24. vectordb_bench/backend/runner/read_write_runner.py +125 -68
  25. vectordb_bench/backend/runner/serial_runner.py +56 -23
  26. vectordb_bench/backend/task_runner.py +48 -20
  27. vectordb_bench/cli/cli.py +59 -1
  28. vectordb_bench/cli/vectordbbench.py +3 -0
  29. vectordb_bench/frontend/components/check_results/data.py +16 -11
  30. vectordb_bench/frontend/components/check_results/filters.py +53 -25
  31. vectordb_bench/frontend/components/check_results/headerIcon.py +16 -13
  32. vectordb_bench/frontend/components/check_results/nav.py +20 -0
  33. vectordb_bench/frontend/components/custom/displayCustomCase.py +43 -8
  34. vectordb_bench/frontend/components/custom/displaypPrams.py +10 -5
  35. vectordb_bench/frontend/components/custom/getCustomConfig.py +10 -0
  36. vectordb_bench/frontend/components/label_filter/charts.py +60 -0
  37. vectordb_bench/frontend/components/run_test/caseSelector.py +48 -52
  38. vectordb_bench/frontend/components/run_test/dbSelector.py +9 -5
  39. vectordb_bench/frontend/components/run_test/inputWidget.py +48 -0
  40. vectordb_bench/frontend/components/run_test/submitTask.py +3 -1
  41. vectordb_bench/frontend/components/streaming/charts.py +253 -0
  42. vectordb_bench/frontend/components/streaming/data.py +62 -0
  43. vectordb_bench/frontend/components/tables/data.py +1 -1
  44. vectordb_bench/frontend/components/welcome/explainPrams.py +66 -0
  45. vectordb_bench/frontend/components/welcome/pagestyle.py +106 -0
  46. vectordb_bench/frontend/components/welcome/welcomePrams.py +147 -0
  47. vectordb_bench/frontend/config/dbCaseConfigs.py +307 -40
  48. vectordb_bench/frontend/config/styles.py +32 -2
  49. vectordb_bench/frontend/pages/concurrent.py +5 -1
  50. vectordb_bench/frontend/pages/custom.py +4 -0
  51. vectordb_bench/frontend/pages/label_filter.py +56 -0
  52. vectordb_bench/frontend/pages/quries_per_dollar.py +5 -1
  53. vectordb_bench/frontend/pages/results.py +60 -0
  54. vectordb_bench/frontend/pages/run_test.py +3 -3
  55. vectordb_bench/frontend/pages/streaming.py +135 -0
  56. vectordb_bench/frontend/pages/tables.py +4 -0
  57. vectordb_bench/frontend/vdb_benchmark.py +16 -41
  58. vectordb_bench/interface.py +6 -2
  59. vectordb_bench/metric.py +15 -1
  60. vectordb_bench/models.py +31 -11
  61. vectordb_bench/results/ElasticCloud/result_20250318_standard_elasticcloud.json +5890 -0
  62. vectordb_bench/results/Milvus/result_20250509_standard_milvus.json +6138 -0
  63. vectordb_bench/results/OpenSearch/result_20250224_standard_opensearch.json +7319 -0
  64. vectordb_bench/results/Pinecone/result_20250124_standard_pinecone.json +2365 -0
  65. vectordb_bench/results/QdrantCloud/result_20250602_standard_qdrantcloud.json +3556 -0
  66. vectordb_bench/results/ZillizCloud/result_20250613_standard_zillizcloud.json +6290 -0
  67. vectordb_bench/results/dbPrices.json +12 -4
  68. {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/METADATA +85 -32
  69. {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/RECORD +73 -56
  70. vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +0 -791
  71. vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +0 -679
  72. vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +0 -1352
  73. {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/WHEEL +0 -0
  74. {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/entry_points.txt +0 -0
  75. {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/licenses/LICENSE +0 -0
  76. {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/top_level.txt +0 -0
@@ -5,8 +5,10 @@ from contextlib import contextmanager
5
5
 
6
6
  from opensearchpy import OpenSearch
7
7
 
8
- from ..api import IndexType, VectorDB
9
- from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig, AWSOS_Engine
8
+ from vectordb_bench.backend.filter import Filter, FilterOp
9
+
10
+ from ..api import VectorDB
11
+ from .config import AWSOpenSearchIndexConfig, AWSOS_Engine
10
12
 
11
13
  log = logging.getLogger(__name__)
12
14
 
@@ -16,6 +18,12 @@ SECONDS_WAITING_FOR_REPLICAS_TO_BE_ENABLED_SEC = 30
16
18
 
17
19
 
18
20
  class AWSOpenSearch(VectorDB):
21
+ supported_filter_types: list[FilterOp] = [
22
+ FilterOp.NonFilter,
23
+ FilterOp.NumGE,
24
+ FilterOp.StrEqual,
25
+ ]
26
+
19
27
  def __init__(
20
28
  self,
21
29
  dim: int,
@@ -23,8 +31,10 @@ class AWSOpenSearch(VectorDB):
23
31
  db_case_config: AWSOpenSearchIndexConfig,
24
32
  index_name: str = "vdb_bench_index", # must be lowercase
25
33
  id_col_name: str = "_id",
34
+ label_col_name: str = "label",
26
35
  vector_col_name: str = "embedding",
27
36
  drop_old: bool = False,
37
+ with_scalar_labels: bool = False,
28
38
  **kwargs,
29
39
  ):
30
40
  self.dim = dim
@@ -32,8 +42,9 @@ class AWSOpenSearch(VectorDB):
32
42
  self.case_config = db_case_config
33
43
  self.index_name = index_name
34
44
  self.id_col_name = id_col_name
35
- self.category_col_names = [f"scalar-{categoryCount}" for categoryCount in [2, 5, 10, 100, 1000]]
45
+ self.label_col_name = label_col_name
36
46
  self.vector_col_name = vector_col_name
47
+ self.with_scalar_labels = with_scalar_labels
37
48
 
38
49
  log.info(f"AWS_OpenSearch client config: {self.db_config}")
39
50
  log.info(f"AWS_OpenSearch db case config : {self.case_config}")
@@ -53,14 +64,6 @@ class AWSOpenSearch(VectorDB):
53
64
  self._update_ef_search_before_search(client)
54
65
  self._load_graphs_to_memory(client)
55
66
 
56
- @classmethod
57
- def config_cls(cls) -> AWSOpenSearchConfig:
58
- return AWSOpenSearchConfig
59
-
60
- @classmethod
61
- def case_config_cls(cls, index_type: IndexType | None = None) -> AWSOpenSearchIndexConfig:
62
- return AWSOpenSearchIndexConfig
63
-
64
67
  def _create_index(self, client: OpenSearch) -> None:
65
68
  ef_search_value = (
66
69
  self.case_config.ef_search if self.case_config.ef_search is not None else self.case_config.efSearch
@@ -93,7 +96,8 @@ class AWSOpenSearch(VectorDB):
93
96
  mappings = {
94
97
  "_source": {"excludes": [self.vector_col_name], "recovery_source_excludes": [self.vector_col_name]},
95
98
  "properties": {
96
- **{categoryCol: {"type": "keyword"} for categoryCol in self.category_col_names},
99
+ self.id_col_name: {"type": "integer", "store": True},
100
+ self.label_col_name: {"type": "keyword"},
97
101
  self.vector_col_name: {
98
102
  "type": "knn_vector",
99
103
  "dimension": self.dim,
@@ -125,6 +129,7 @@ class AWSOpenSearch(VectorDB):
125
129
  self,
126
130
  embeddings: Iterable[list[float]],
127
131
  metadata: list[int],
132
+ labels_data: list[str] | None = None,
128
133
  **kwargs,
129
134
  ) -> tuple[int, Exception]:
130
135
  """Insert the embeddings to the opensearch."""
@@ -135,34 +140,42 @@ class AWSOpenSearch(VectorDB):
135
140
 
136
141
  if num_clients <= 1:
137
142
  log.info("Using single client for data insertion")
138
- return self._insert_with_single_client(embeddings, metadata)
143
+ return self._insert_with_single_client(embeddings, metadata, labels_data)
139
144
  log.info(f"Using {num_clients} parallel clients for data insertion")
140
- return self._insert_with_multiple_clients(embeddings, metadata, num_clients)
145
+ return self._insert_with_multiple_clients(embeddings, metadata, num_clients, labels_data)
141
146
 
142
147
  def _insert_with_single_client(
143
- self, embeddings: Iterable[list[float]], metadata: list[int]
148
+ self,
149
+ embeddings: Iterable[list[float]],
150
+ metadata: list[int],
151
+ labels_data: list[str] | None = None,
144
152
  ) -> tuple[int, Exception]:
145
153
  insert_data = []
146
154
  for i in range(len(embeddings)):
147
- insert_data.append(
148
- {"index": {"_index": self.index_name, self.id_col_name: metadata[i]}},
149
- )
150
- insert_data.append({self.vector_col_name: embeddings[i]})
155
+ index_data = {"index": {"_index": self.index_name, self.id_col_name: metadata[i]}}
156
+ if self.with_scalar_labels and self.case_config.use_routing:
157
+ index_data["routing"] = labels_data[i]
158
+ insert_data.append(index_data)
159
+
160
+ other_data = {self.vector_col_name: embeddings[i]}
161
+ if self.with_scalar_labels:
162
+ other_data[self.label_col_name] = labels_data[i]
163
+ insert_data.append(other_data)
164
+
151
165
  try:
152
- resp = self.client.bulk(insert_data)
153
- log.info(f"AWS_OpenSearch adding documents: {len(resp['items'])}")
154
- resp = self.client.indices.stats(self.index_name)
155
- log.info(
156
- f"Total document count in index: {resp['_all']['primaries']['indexing']['index_total']}",
157
- )
158
- return (len(embeddings), None)
166
+ self.client.bulk(insert_data)
167
+ return len(embeddings), None
159
168
  except Exception as e:
160
169
  log.warning(f"Failed to insert data: {self.index_name} error: {e!s}")
161
170
  time.sleep(10)
162
171
  return self._insert_with_single_client(embeddings, metadata)
163
172
 
164
173
  def _insert_with_multiple_clients(
165
- self, embeddings: Iterable[list[float]], metadata: list[int], num_clients: int
174
+ self,
175
+ embeddings: Iterable[list[float]],
176
+ metadata: list[int],
177
+ num_clients: int,
178
+ labels_data: list[str] | None = None,
166
179
  ) -> tuple[int, Exception]:
167
180
  import concurrent.futures
168
181
  from concurrent.futures import ThreadPoolExecutor
@@ -173,7 +186,7 @@ class AWSOpenSearch(VectorDB):
173
186
 
174
187
  for i in range(0, len(embeddings_list), chunk_size):
175
188
  end = min(i + chunk_size, len(embeddings_list))
176
- chunks.append((embeddings_list[i:end], metadata[i:end]))
189
+ chunks.append((embeddings_list[i:end], metadata[i:end], labels_data[i:end]))
177
190
 
178
191
  clients = []
179
192
  for _ in range(min(num_clients, len(chunks))):
@@ -183,15 +196,20 @@ class AWSOpenSearch(VectorDB):
183
196
  log.info(f"AWS_OpenSearch using {len(clients)} parallel clients for data insertion")
184
197
 
185
198
  def insert_chunk(client_idx: int, chunk_idx: int):
186
- chunk_embeddings, chunk_metadata = chunks[chunk_idx]
199
+ chunk_embeddings, chunk_metadata, chunk_labels_data = chunks[chunk_idx]
187
200
  client = clients[client_idx]
188
201
 
189
202
  insert_data = []
190
203
  for i in range(len(chunk_embeddings)):
191
- insert_data.append(
192
- {"index": {"_index": self.index_name, self.id_col_name: chunk_metadata[i]}},
193
- )
194
- insert_data.append({self.vector_col_name: chunk_embeddings[i]})
204
+ index_data = {"index": {"_index": self.index_name, self.id_col_name: chunk_metadata[i]}}
205
+ if self.with_scalar_labels and self.case_config.use_routing:
206
+ index_data["routing"] = chunk_labels_data[i]
207
+ insert_data.append(index_data)
208
+
209
+ other_data = {self.vector_col_name: chunk_embeddings[i]}
210
+ if self.with_scalar_labels:
211
+ other_data[self.label_col_name] = chunk_labels_data[i]
212
+ insert_data.append(other_data)
195
213
 
196
214
  try:
197
215
  resp = client.bulk(insert_data)
@@ -266,17 +284,16 @@ class AWSOpenSearch(VectorDB):
266
284
  self,
267
285
  query: list[float],
268
286
  k: int = 100,
269
- filters: dict | None = None,
287
+ **kwargs,
270
288
  ) -> list[int]:
271
289
  """Get k most similar embeddings to query vector.
272
290
 
273
291
  Args:
274
292
  query(list[float]): query embedding to look up documents similar to.
275
293
  k(int): Number of most similar embeddings to return. Defaults to 100.
276
- filters(dict, optional): filtering expression to filter the data while searching.
277
294
 
278
295
  Returns:
279
- list[tuple[int, float]]: list of k most similar embeddings in (id, score) tuple to the query embedding.
296
+ list[int]: list of k most similar ids to the query embedding.
280
297
  """
281
298
  assert self.client is not None, "should self.init() first"
282
299
 
@@ -287,11 +304,16 @@ class AWSOpenSearch(VectorDB):
287
304
  self.vector_col_name: {
288
305
  "vector": query,
289
306
  "k": k,
290
- "method_parameters": {"ef_search": self.case_config.efSearch},
307
+ "method_parameters": self.case_config.search_param(),
308
+ **({"filter": self.filter} if self.filter else {}),
309
+ **(
310
+ {"rescore": {"oversample_factor": self.case_config.oversample_factor}}
311
+ if self.case_config.use_quant
312
+ else {}
313
+ ),
291
314
  }
292
315
  }
293
316
  },
294
- **({"filter": {"range": {self.id_col_name: {"gt": filters["id"]}}}} if filters else {}),
295
317
  }
296
318
 
297
319
  try:
@@ -303,15 +325,34 @@ class AWSOpenSearch(VectorDB):
303
325
  docvalue_fields=[self.id_col_name],
304
326
  stored_fields="_none_",
305
327
  preference="_only_local" if self.case_config.number_of_shards == 1 else None,
328
+ routing=self.routing_key,
306
329
  )
307
330
  log.debug(f"Search took: {resp['took']}")
308
331
  log.debug(f"Search shards: {resp['_shards']}")
309
332
  log.debug(f"Search hits total: {resp['hits']['total']}")
310
- return [int(h["fields"][self.id_col_name][0]) for h in resp["hits"]["hits"]]
333
+ try:
334
+ return [int(h["fields"][self.id_col_name][0]) for h in resp["hits"]["hits"]]
335
+ except Exception:
336
+ # empty results
337
+ return []
311
338
  except Exception as e:
312
339
  log.warning(f"Failed to search: {self.index_name} error: {e!s}")
313
340
  raise e from None
314
341
 
342
+ def prepare_filter(self, filters: Filter):
343
+ self.routing_key = None
344
+ if filters.type == FilterOp.NonFilter:
345
+ self.filter = None
346
+ elif filters.type == FilterOp.NumGE:
347
+ self.filter = {"range": {self.id_col_name: {"gt": filters.int_value}}}
348
+ elif filters.type == FilterOp.StrEqual:
349
+ self.filter = {"term": {self.label_col_name: filters.label_value}}
350
+ if self.case_config.use_routing:
351
+ self.routing_key = filters.label_value
352
+ else:
353
+ msg = f"Not support Filter for OpenSearch - {filters}"
354
+ raise ValueError(msg)
355
+
315
356
  def optimize(self, data_size: int | None = None):
316
357
  """optimize will be called between insertion and search in performance cases."""
317
358
  self._update_ef_search()
@@ -392,7 +433,7 @@ class AWSOpenSearch(VectorDB):
392
433
  )
393
434
  log.info(f"response of updating setting is: {output}")
394
435
 
395
- log.debug(f"Starting force merge for index {self.index_name}")
436
+ log.info(f"Starting force merge for index {self.index_name}")
396
437
  segments = self.case_config.number_of_segments
397
438
  force_merge_endpoint = f"/{self.index_name}/_forcemerge?max_num_segments={segments}&wait_for_completion=false"
398
439
  force_merge_task_id = self.client.transport.perform_request("POST", force_merge_endpoint)["task"]
@@ -401,7 +442,7 @@ class AWSOpenSearch(VectorDB):
401
442
  task_status = self.client.tasks.get(task_id=force_merge_task_id)
402
443
  if task_status["completed"]:
403
444
  break
404
- log.debug(f"Completed force merge for index {self.index_name}")
445
+ log.info(f"Completed force merge for index {self.index_name}")
405
446
 
406
447
  def _load_graphs_to_memory(self, client: OpenSearch):
407
448
  if self.case_config.engine != AWSOS_Engine.lucene:
@@ -45,7 +45,7 @@ class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
45
45
  metric_type: MetricType = MetricType.L2
46
46
  engine: AWSOS_Engine = AWSOS_Engine.faiss
47
47
  efConstruction: int = 256
48
- ef_search: int = 200
48
+ efSearch: int = 100
49
49
  engine_name: str | None = None
50
50
  metric_type_name: str | None = None
51
51
  M: int = 16
@@ -56,10 +56,25 @@ class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
56
56
  refresh_interval: str | None = "60s"
57
57
  force_merge_enabled: bool | None = True
58
58
  flush_threshold_size: str | None = "5120mb"
59
- index_thread_qty_during_force_merge: int
59
+ index_thread_qty_during_force_merge: int = 8
60
60
  cb_threshold: str | None = "50%"
61
+ number_of_indexing_clients: int | None = 1
62
+ use_routing: bool = False # for label-filter cases
63
+ oversample_factor: float = 1.0
61
64
  quantization_type: AWSOSQuantization = AWSOSQuantization.fp32
62
65
 
66
+ def __eq__(self, obj: any):
67
+ return (
68
+ self.engine == obj.engine
69
+ and self.M == obj.M
70
+ and self.efConstruction == obj.efConstruction
71
+ and self.number_of_shards == obj.number_of_shards
72
+ and self.number_of_replicas == obj.number_of_replicas
73
+ and self.number_of_segments == obj.number_of_segments
74
+ and self.use_routing == obj.use_routing
75
+ and self.quantization_type == obj.quantization_type
76
+ )
77
+
63
78
  def parse_metric(self) -> str:
64
79
  log.info(f"User specified metric_type: {self.metric_type_name}")
65
80
  self.metric_type = MetricType[self.metric_type_name.upper()]
@@ -72,6 +87,10 @@ class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
72
87
  return "l2"
73
88
  return "l2"
74
89
 
90
+ @property
91
+ def use_quant(self) -> bool:
92
+ return self.quantization_type is not AWSOSQuantization.fp32
93
+
75
94
  def index_param(self) -> dict:
76
95
  log.info(f"Using engine: {self.engine} for index creation")
77
96
  log.info(f"Using metric_type: {self.metric_type_name} for index creation")
@@ -91,11 +110,11 @@ class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
91
110
  "ef_search": self.efSearch,
92
111
  **(
93
112
  {"encoder": {"name": "sq", "parameters": {"type": self.quantization_type.fp16.value}}}
94
- if self.quantization_type is not AWSOSQuantization.fp32
113
+ if self.use_quant
95
114
  else {}
96
115
  ),
97
116
  },
98
117
  }
99
118
 
100
119
  def search_param(self) -> dict:
101
- return {}
120
+ return {"ef_search": self.efSearch}
@@ -78,8 +78,12 @@ class ChromaClient(VectorDB):
78
78
  """
79
79
  ids = [str(i) for i in metadata]
80
80
  metadata = [{"id": int(i)} for i in metadata]
81
- if len(embeddings) > 0:
82
- self.collection.add(embeddings=embeddings, ids=ids, metadatas=metadata)
81
+ try:
82
+ if len(embeddings) > 0:
83
+ self.collection.add(embeddings=embeddings, ids=ids, metadatas=metadata)
84
+ except Exception as e:
85
+ log.warning(f"Failed to insert data: error: {e!s}")
86
+ return 0, e
83
87
  return len(embeddings), None
84
88
 
85
89
  def search_embedding(
@@ -23,13 +23,31 @@ class ESElementType(str, Enum):
23
23
 
24
24
  class ElasticCloudIndexConfig(BaseModel, DBCaseConfig):
25
25
  element_type: ESElementType = ESElementType.float
26
- index: IndexType = IndexType.ES_HNSW # ES only support 'hnsw'
26
+ index: IndexType = IndexType.ES_HNSW
27
+ number_of_shards: int = 1
28
+ number_of_replicas: int = 0
29
+ refresh_interval: str = "30s"
30
+ merge_max_thread_count: int = 8
31
+ use_rescore: bool = False
32
+ oversample_ratio: float = 2.0
33
+ use_routing: bool = False
34
+ use_force_merge: bool = True
27
35
 
28
36
  metric_type: MetricType | None = None
29
37
  efConstruction: int | None = None
30
38
  M: int | None = None
31
39
  num_candidates: int | None = None
32
40
 
41
+ def __eq__(self, obj: any):
42
+ return (
43
+ self.index == obj.index
44
+ and self.number_of_shards == obj.number_of_shards
45
+ and self.number_of_replicas == obj.number_of_replicas
46
+ and self.use_routing == obj.use_routing
47
+ and self.efConstruction == obj.efConstruction
48
+ and self.M == obj.M
49
+ )
50
+
33
51
  def parse_metric(self) -> str:
34
52
  if self.metric_type == MetricType.L2:
35
53
  return "l2_norm"
@@ -5,6 +5,8 @@ from contextlib import contextmanager
5
5
 
6
6
  from elasticsearch.helpers import bulk
7
7
 
8
+ from vectordb_bench.backend.filter import Filter, FilterOp
9
+
8
10
  from ..api import VectorDB
9
11
  from .config import ElasticCloudIndexConfig
10
12
 
@@ -18,6 +20,12 @@ SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
18
20
 
19
21
 
20
22
  class ElasticCloud(VectorDB):
23
+ supported_filter_types: list[FilterOp] = [
24
+ FilterOp.NonFilter,
25
+ FilterOp.NumGE,
26
+ FilterOp.StrEqual,
27
+ ]
28
+
21
29
  def __init__(
22
30
  self,
23
31
  dim: int,
@@ -25,8 +33,10 @@ class ElasticCloud(VectorDB):
25
33
  db_case_config: ElasticCloudIndexConfig,
26
34
  indice: str = "vdb_bench_indice", # must be lowercase
27
35
  id_col_name: str = "id",
36
+ label_col_name: str = "label",
28
37
  vector_col_name: str = "vector",
29
38
  drop_old: bool = False,
39
+ with_scalar_labels: bool = False,
30
40
  **kwargs,
31
41
  ):
32
42
  self.dim = dim
@@ -34,7 +44,9 @@ class ElasticCloud(VectorDB):
34
44
  self.case_config = db_case_config
35
45
  self.indice = indice
36
46
  self.id_col_name = id_col_name
47
+ self.label_col_name = label_col_name
37
48
  self.vector_col_name = vector_col_name
49
+ self.with_scalar_labels = with_scalar_labels
38
50
 
39
51
  from elasticsearch import Elasticsearch
40
52
 
@@ -69,9 +81,17 @@ class ElasticCloud(VectorDB):
69
81
  },
70
82
  },
71
83
  }
84
+ settings = {
85
+ "index": {
86
+ "number_of_shards": self.case_config.number_of_shards,
87
+ "number_of_replicas": self.case_config.number_of_replicas,
88
+ "refresh_interval": self.case_config.refresh_interval,
89
+ "merge.scheduler.max_thread_count": self.case_config.merge_max_thread_count,
90
+ }
91
+ }
72
92
 
73
93
  try:
74
- client.indices.create(index=self.indice, mappings=mappings)
94
+ client.indices.create(index=self.indice, mappings=mappings, settings=settings)
75
95
  except Exception as e:
76
96
  log.warning(f"Failed to create indice: {self.indice} error: {e!s}")
77
97
  raise e from None
@@ -80,21 +100,48 @@ class ElasticCloud(VectorDB):
80
100
  self,
81
101
  embeddings: Iterable[list[float]],
82
102
  metadata: list[int],
103
+ labels_data: list[str] | None = None,
83
104
  **kwargs,
84
105
  ) -> tuple[int, Exception]:
85
106
  """Insert the embeddings to the elasticsearch."""
86
107
  assert self.client is not None, "should self.init() first"
87
108
 
88
- insert_data = [
89
- {
90
- "_index": self.indice,
91
- "_source": {
92
- self.id_col_name: metadata[i],
93
- self.vector_col_name: embeddings[i],
94
- },
95
- }
96
- for i in range(len(embeddings))
97
- ]
109
+ insert_data = (
110
+ [
111
+ (
112
+ {
113
+ "_index": self.indice,
114
+ "_source": {
115
+ self.id_col_name: metadata[i],
116
+ self.label_col_name: labels_data[i],
117
+ self.vector_col_name: embeddings[i],
118
+ },
119
+ "_routing": labels_data[i],
120
+ }
121
+ if self.case_config.use_routing
122
+ else {
123
+ "_index": self.indice,
124
+ "_source": {
125
+ self.id_col_name: metadata[i],
126
+ self.label_col_name: labels_data[i],
127
+ self.vector_col_name: embeddings[i],
128
+ },
129
+ }
130
+ )
131
+ for i in range(len(embeddings))
132
+ ]
133
+ if self.with_scalar_labels
134
+ else [
135
+ {
136
+ "_index": self.indice,
137
+ "_source": {
138
+ self.id_col_name: metadata[i],
139
+ self.vector_col_name: embeddings[i],
140
+ },
141
+ }
142
+ for i in range(len(embeddings))
143
+ ]
144
+ )
98
145
  try:
99
146
  bulk_insert_res = bulk(self.client, insert_data)
100
147
  return (bulk_insert_res[0], None)
@@ -102,59 +149,100 @@ class ElasticCloud(VectorDB):
102
149
  log.warning(f"Failed to insert data: {self.indice} error: {e!s}")
103
150
  return (0, e)
104
151
 
152
+ def prepare_filter(self, filters: Filter):
153
+ self.routing_key = None
154
+ if filters.type == FilterOp.NonFilter:
155
+ self.filter = []
156
+ elif filters.type == FilterOp.NumGE:
157
+ self.filter = {"range": {self.id_col_name: {"gt": filters.int_value}}}
158
+ elif filters.type == FilterOp.StrEqual:
159
+ self.filter = {"term": {self.label_col_name: filters.label_value}}
160
+ if self.case_config.use_routing:
161
+ self.routing_key = filters.label_value
162
+ else:
163
+ msg = f"Not support Filter for Milvus - {filters}"
164
+ raise ValueError(msg)
165
+
105
166
  def search_embedding(
106
167
  self,
107
168
  query: list[float],
108
169
  k: int = 100,
109
- filters: dict | None = None,
170
+ **kwargs,
110
171
  ) -> list[int]:
111
172
  """Get k most similar embeddings to query vector.
112
173
 
113
174
  Args:
114
175
  query(list[float]): query embedding to look up documents similar to.
115
176
  k(int): Number of most similar embeddings to return. Defaults to 100.
116
- filters(dict, optional): filtering expression to filter the data while searching.
117
177
 
118
178
  Returns:
119
179
  list[tuple[int, float]]: list of k most similar embeddings in (id, score) tuple to the query embedding.
120
180
  """
121
181
  assert self.client is not None, "should self.init() first"
122
182
 
123
- knn = {
124
- "field": self.vector_col_name,
125
- "k": k,
126
- "num_candidates": self.case_config.num_candidates,
127
- "filter": [{"range": {self.id_col_name: {"gt": filters["id"]}}}] if filters else [],
128
- "query_vector": query,
129
- }
183
+ if self.case_config.use_rescore:
184
+ oversample_k = int(k * self.case_config.oversample_ratio)
185
+ oversample_num_candidates = int(self.case_config.num_candidates * self.case_config.oversample_ratio)
186
+ knn = {
187
+ "field": self.vector_col_name,
188
+ "k": oversample_k,
189
+ "num_candidates": oversample_num_candidates,
190
+ "filter": self.filter,
191
+ "query_vector": query,
192
+ }
193
+ rescore = {
194
+ "window_size": oversample_k,
195
+ "query": {
196
+ "rescore_query": {
197
+ "script_score": {
198
+ "query": {"match_all": {}},
199
+ "script": {
200
+ "source": f"cosineSimilarity(params.queryVector, '{self.vector_col_name}')",
201
+ "params": {"queryVector": query},
202
+ },
203
+ }
204
+ },
205
+ "query_weight": 0,
206
+ "rescore_query_weight": 1,
207
+ },
208
+ }
209
+ else:
210
+ knn = {
211
+ "field": self.vector_col_name,
212
+ "k": k,
213
+ "num_candidates": self.case_config.num_candidates,
214
+ "filter": self.filter,
215
+ "query_vector": query,
216
+ }
217
+ rescore = None
130
218
  size = k
131
- try:
132
- res = self.client.search(
133
- index=self.indice,
134
- knn=knn,
135
- size=size,
136
- _source=False,
137
- docvalue_fields=[self.id_col_name],
138
- stored_fields="_none_",
139
- filter_path=[f"hits.hits.fields.{self.id_col_name}"],
140
- )
141
- return [h["fields"][self.id_col_name][0] for h in res["hits"]["hits"]]
142
- except Exception as e:
143
- log.warning(f"Failed to search: {self.indice} error: {e!s}")
144
- raise e from None
219
+
220
+ res = self.client.search(
221
+ index=self.indice,
222
+ knn=knn,
223
+ routing=self.routing_key,
224
+ rescore=rescore,
225
+ size=size,
226
+ _source=False,
227
+ docvalue_fields=[self.id_col_name],
228
+ stored_fields="_none_",
229
+ filter_path=[f"hits.hits.fields.{self.id_col_name}"],
230
+ )
231
+ return [h["fields"][self.id_col_name][0] for h in res["hits"]["hits"]]
145
232
 
146
233
  def optimize(self, data_size: int | None = None):
147
234
  """optimize will be called between insertion and search in performance cases."""
148
235
  assert self.client is not None, "should self.init() first"
149
236
  self.client.indices.refresh(index=self.indice)
150
- force_merge_task_id = self.client.indices.forcemerge(
151
- index=self.indice,
152
- max_num_segments=1,
153
- wait_for_completion=False,
154
- )["task"]
155
- log.info(f"Elasticsearch force merge task id: {force_merge_task_id}")
156
- while True:
157
- time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
158
- task_status = self.client.tasks.get(task_id=force_merge_task_id)
159
- if task_status["completed"]:
160
- return
237
+ if self.case_config.use_force_merge:
238
+ force_merge_task_id = self.client.indices.forcemerge(
239
+ index=self.indice,
240
+ max_num_segments=1,
241
+ wait_for_completion=False,
242
+ )["task"]
243
+ log.info(f"Elasticsearch force merge task id: {force_merge_task_id}")
244
+ while True:
245
+ time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
246
+ task_status = self.client.tasks.get(task_id=force_merge_task_id)
247
+ if task_status["completed"]:
248
+ return
@@ -35,6 +35,7 @@ class MilvusIndexConfig(BaseModel):
35
35
 
36
36
  index: IndexType
37
37
  metric_type: MetricType | None = None
38
+ use_partition_key: bool = True # for label-filter
38
39
 
39
40
  @property
40
41
  def is_gpu_index(self) -> bool: