vectordb-bench 0.0.29__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. vectordb_bench/__init__.py +14 -27
  2. vectordb_bench/backend/assembler.py +19 -6
  3. vectordb_bench/backend/cases.py +186 -23
  4. vectordb_bench/backend/clients/__init__.py +32 -0
  5. vectordb_bench/backend/clients/api.py +22 -1
  6. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +249 -43
  7. vectordb_bench/backend/clients/aws_opensearch/cli.py +51 -21
  8. vectordb_bench/backend/clients/aws_opensearch/config.py +58 -16
  9. vectordb_bench/backend/clients/chroma/chroma.py +6 -2
  10. vectordb_bench/backend/clients/elastic_cloud/config.py +19 -1
  11. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +133 -45
  12. vectordb_bench/backend/clients/lancedb/cli.py +62 -8
  13. vectordb_bench/backend/clients/lancedb/config.py +14 -1
  14. vectordb_bench/backend/clients/lancedb/lancedb.py +21 -9
  15. vectordb_bench/backend/clients/memorydb/memorydb.py +2 -2
  16. vectordb_bench/backend/clients/milvus/cli.py +30 -9
  17. vectordb_bench/backend/clients/milvus/config.py +3 -0
  18. vectordb_bench/backend/clients/milvus/milvus.py +81 -23
  19. vectordb_bench/backend/clients/oceanbase/cli.py +100 -0
  20. vectordb_bench/backend/clients/oceanbase/config.py +125 -0
  21. vectordb_bench/backend/clients/oceanbase/oceanbase.py +215 -0
  22. vectordb_bench/backend/clients/pinecone/pinecone.py +39 -25
  23. vectordb_bench/backend/clients/qdrant_cloud/config.py +59 -3
  24. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +100 -33
  25. vectordb_bench/backend/clients/qdrant_local/cli.py +60 -0
  26. vectordb_bench/backend/clients/qdrant_local/config.py +47 -0
  27. vectordb_bench/backend/clients/qdrant_local/qdrant_local.py +232 -0
  28. vectordb_bench/backend/clients/weaviate_cloud/cli.py +29 -3
  29. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -0
  30. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +5 -0
  31. vectordb_bench/backend/dataset.py +143 -27
  32. vectordb_bench/backend/filter.py +76 -0
  33. vectordb_bench/backend/runner/__init__.py +3 -3
  34. vectordb_bench/backend/runner/mp_runner.py +52 -39
  35. vectordb_bench/backend/runner/rate_runner.py +68 -52
  36. vectordb_bench/backend/runner/read_write_runner.py +125 -68
  37. vectordb_bench/backend/runner/serial_runner.py +56 -23
  38. vectordb_bench/backend/task_runner.py +48 -20
  39. vectordb_bench/cli/batch_cli.py +121 -0
  40. vectordb_bench/cli/cli.py +59 -1
  41. vectordb_bench/cli/vectordbbench.py +7 -0
  42. vectordb_bench/config-files/batch_sample_config.yml +17 -0
  43. vectordb_bench/frontend/components/check_results/data.py +16 -11
  44. vectordb_bench/frontend/components/check_results/filters.py +53 -25
  45. vectordb_bench/frontend/components/check_results/headerIcon.py +16 -13
  46. vectordb_bench/frontend/components/check_results/nav.py +20 -0
  47. vectordb_bench/frontend/components/custom/displayCustomCase.py +43 -8
  48. vectordb_bench/frontend/components/custom/displaypPrams.py +10 -5
  49. vectordb_bench/frontend/components/custom/getCustomConfig.py +10 -0
  50. vectordb_bench/frontend/components/label_filter/charts.py +60 -0
  51. vectordb_bench/frontend/components/run_test/caseSelector.py +48 -52
  52. vectordb_bench/frontend/components/run_test/dbSelector.py +9 -5
  53. vectordb_bench/frontend/components/run_test/inputWidget.py +48 -0
  54. vectordb_bench/frontend/components/run_test/submitTask.py +3 -1
  55. vectordb_bench/frontend/components/streaming/charts.py +253 -0
  56. vectordb_bench/frontend/components/streaming/data.py +62 -0
  57. vectordb_bench/frontend/components/tables/data.py +1 -1
  58. vectordb_bench/frontend/components/welcome/explainPrams.py +66 -0
  59. vectordb_bench/frontend/components/welcome/pagestyle.py +106 -0
  60. vectordb_bench/frontend/components/welcome/welcomePrams.py +147 -0
  61. vectordb_bench/frontend/config/dbCaseConfigs.py +420 -41
  62. vectordb_bench/frontend/config/styles.py +32 -2
  63. vectordb_bench/frontend/pages/concurrent.py +5 -1
  64. vectordb_bench/frontend/pages/custom.py +4 -0
  65. vectordb_bench/frontend/pages/label_filter.py +56 -0
  66. vectordb_bench/frontend/pages/quries_per_dollar.py +5 -1
  67. vectordb_bench/frontend/pages/results.py +60 -0
  68. vectordb_bench/frontend/pages/run_test.py +3 -3
  69. vectordb_bench/frontend/pages/streaming.py +135 -0
  70. vectordb_bench/frontend/pages/tables.py +4 -0
  71. vectordb_bench/frontend/vdb_benchmark.py +16 -41
  72. vectordb_bench/interface.py +6 -2
  73. vectordb_bench/metric.py +15 -1
  74. vectordb_bench/models.py +38 -11
  75. vectordb_bench/results/ElasticCloud/result_20250318_standard_elasticcloud.json +5890 -0
  76. vectordb_bench/results/Milvus/result_20250509_standard_milvus.json +6138 -0
  77. vectordb_bench/results/OpenSearch/result_20250224_standard_opensearch.json +7319 -0
  78. vectordb_bench/results/Pinecone/result_20250124_standard_pinecone.json +2365 -0
  79. vectordb_bench/results/QdrantCloud/result_20250602_standard_qdrantcloud.json +3556 -0
  80. vectordb_bench/results/ZillizCloud/result_20250613_standard_zillizcloud.json +6290 -0
  81. vectordb_bench/results/dbPrices.json +12 -4
  82. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/METADATA +131 -32
  83. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/RECORD +87 -65
  84. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/WHEEL +1 -1
  85. vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +0 -791
  86. vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +0 -679
  87. vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +0 -1352
  88. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/entry_points.txt +0 -0
  89. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/licenses/LICENSE +0 -0
  90. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/runner/rate_runner.py
@@ -30,78 +30,94 @@ class RatedMultiThreadingInsertRunner:
  self.insert_rate = rate
  self.batch_rate = rate // config.NUM_PER_BATCH

- def send_insert_task(self, db: api.VectorDB, emb: list[list[float]], metadata: list[str]):
- db.insert_embeddings(emb, metadata)
+ self.executing_futures = []
+ self.sig_idx = 0
+
+ def send_insert_task(self, db: api.VectorDB, emb: list[list[float]], metadata: list[str], retry_idx: int = 0):
+ _, error = db.insert_embeddings(emb, metadata)
+ if error is not None:
+ log.warning(f"Insert Failed, try_idx={retry_idx}, Exception: {error}")
+ retry_idx += 1
+ if retry_idx <= config.MAX_INSERT_RETRY:
+ time.sleep(retry_idx)
+ self.send_insert_task(db, emb=emb, metadata=metadata, retry_idx=retry_idx)
+ else:
+ msg = f"Insert failed and retried more than {config.MAX_INSERT_RETRY} times"
+ raise RuntimeError(msg) from None

  @time_it
  def run_with_rate(self, q: mp.Queue):
  with ThreadPoolExecutor(max_workers=mp.cpu_count()) as executor:
- executing_futures = []

  @time_it
  def submit_by_rate() -> bool:
  rate = self.batch_rate
  for data in self.dataset:
  emb, metadata = get_data(data, self.normalize)
- executing_futures.append(
- executor.submit(self.send_insert_task, self.db, emb, metadata),
- )
+ self.executing_futures.append(executor.submit(self.send_insert_task, self.db, emb, metadata))
  rate -= 1

  if rate == 0:
  return False
  return rate == self.batch_rate

+ def check_and_send_signal(wait_interval: float, finished: bool = False):
+ try:
+ done, not_done = concurrent.futures.wait(
+ self.executing_futures,
+ timeout=wait_interval,
+ return_when=concurrent.futures.FIRST_EXCEPTION,
+ )
+ _ = [fut.result() for fut in done]
+ if len(not_done) > 0:
+ self.executing_futures = list(not_done)
+ else:
+ self.executing_futures = []
+
+ self.sig_idx += len(done)
+ while self.sig_idx >= self.batch_rate:
+ self.sig_idx -= self.batch_rate
+ if self.sig_idx < self.batch_rate and len(not_done) == 0 and finished:
+ q.put(True, block=True)
+ else:
+ q.put(False, block=False)
+
+ except Exception as e:
+ log.warning(f"task error, terminating, err={e}")
+ q.put(None, block=True)
+ executor.shutdown(wait=True, cancel_futures=True)
+ raise e from None
+
+ time_per_batch = config.TIME_PER_BATCH
  with self.db.init():
+ start_time = time.perf_counter()
+ round_idx = 0
+
  while True:
- start_time = time.perf_counter()
- finished, elapsed_time = submit_by_rate()
- if finished is True:
- q.put(True, block=True)
- log.info(f"End of dataset, left unfinished={len(executing_futures)}")
- break
-
- q.put(False, block=False)
- wait_interval = 1 - elapsed_time if elapsed_time < 1 else 0.001
-
- try:
- done, not_done = concurrent.futures.wait(
- executing_futures,
- timeout=wait_interval,
- return_when=concurrent.futures.FIRST_EXCEPTION,
- )
-
- if len(not_done) > 0:
- log.warning(
- f"Failed to finish all tasks in 1s, [{len(not_done)}/{len(executing_futures)}] "
- f"tasks are not done, waited={wait_interval:.2f}, trying to wait in the next round"
+ if len(self.executing_futures) > 200:
+ log.warning("Skip data insertion this round. There are 200+ unfinished insertion tasks.")
+ else:
+ finished, elapsed_time = submit_by_rate()
+ if finished is True:
+ log.info(
+ f"End of dataset, left unfinished={len(self.executing_futures)}, num_round={round_idx}"
  )
- executing_futures = list(not_done)
- else:
- log.debug(
- f"Finished {len(executing_futures)} insert-{config.NUM_PER_BATCH} "
- f"task in 1s, wait_interval={wait_interval:.2f}"
+ break
+ if elapsed_time >= 1.5:
+ log.warning(
+ f"Submit insert tasks took {elapsed_time}s, expected 1s, "
+ f"indicating potential resource limitations on the client machine.",
  )
- executing_futures = []
- except Exception as e:
- log.warning(f"task error, terminating, err={e}")
- q.put(None, block=True)
- executor.shutdown(wait=True, cancel_futures=True)
- raise e from e

- dur = time.perf_counter() - start_time
- if dur < 1:
- time.sleep(1 - dur)
+ check_and_send_signal(wait_interval=0.001, finished=False)
+ dur = time.perf_counter() - start_time - round_idx * time_per_batch
+ if dur < time_per_batch:
+ time.sleep(time_per_batch - dur)
+ round_idx += 1

  # wait for all tasks in executing_futures to complete
- if len(executing_futures) > 0:
- try:
- done, _ = concurrent.futures.wait(
- executing_futures,
- return_when=concurrent.futures.FIRST_EXCEPTION,
- )
- except Exception as e:
- log.warning(f"task error, terminating, err={e}")
- q.put(None, block=True)
- executor.shutdown(wait=True, cancel_futures=True)
- raise e from e
+ while len(self.executing_futures) > 0:
+ check_and_send_signal(wait_interval=1, finished=True)
+ round_idx += 1
+
+ log.info(f"Finish all streaming insertion, num_round={round_idx}")
vectordb_bench/backend/runner/read_write_runner.py
@@ -1,13 +1,18 @@
  import concurrent
+ import concurrent.futures
  import logging
  import math
  import multiprocessing as mp
+ import time
  from collections.abc import Iterable

  import numpy as np

  from vectordb_bench.backend.clients import api
  from vectordb_bench.backend.dataset import DatasetManager
+ from vectordb_bench.backend.filter import Filter, non_filter
+ from vectordb_bench.backend.utils import time_it
+ from vectordb_bench.metric import Metric

  from .mp_runner import MultiProcessingSearchRunner
  from .rate_runner import RatedMultiThreadingInsertRunner
@@ -24,35 +29,39 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunner):
  insert_rate: int = 1000,
  normalize: bool = False,
  k: int = 100,
- filters: dict | None = None,
+ filters: Filter = non_filter,
  concurrencies: Iterable[int] = (1, 15, 50),
- search_stage: Iterable[float] = (
+ search_stages: Iterable[float] = (
  0.5,
  0.6,
  0.7,
  0.8,
  0.9,
  ), # search from insert portion, 0.0 means search from the start
+ optimize_after_write: bool = True,
  read_dur_after_write: int = 300, # seconds, search duration when insertion is done
  timeout: float | None = None,
  ):
  self.insert_rate = insert_rate
  self.data_volume = dataset.data.size

- for stage in search_stage:
+ for stage in search_stages:
  assert 0.0 <= stage < 1.0, "each search stage should be in [0.0, 1.0)"
- self.search_stage = sorted(search_stage)
+ self.search_stages = sorted(search_stages)
+ self.optimize_after_write = optimize_after_write
  self.read_dur_after_write = read_dur_after_write

  log.info(
- f"Init runner, concurencys={concurrencies}, search_stage={search_stage}, "
- f"stage_search_dur={read_dur_after_write}"
+ f"Init runner, concurencys={concurrencies}, search_stages={self.search_stages}, "
+ f"stage_search_dur={read_dur_after_write}",
  )

- test_emb = np.stack(dataset.test_data["emb"])
  if normalize:
+ test_emb = np.array(dataset.test_data)
  test_emb = test_emb / np.linalg.norm(test_emb, axis=1)[:, np.newaxis]
- test_emb = test_emb.tolist()
+ test_emb = test_emb.tolist()
+ else:
+ test_emb = dataset.test_data

  MultiProcessingSearchRunner.__init__(
  self,
@@ -74,8 +83,10 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunner):
  test_data=test_emb,
  ground_truth=dataset.gt_data,
  k=k,
+ filters=filters,
  )

+ @time_it
  def run_optimize(self):
  """Optimize needs to run in differenct process for pymilvus schema recursion problem"""
  with self.db.init():
@@ -83,49 +94,102 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunner):
  self.db.optimize(data_size=self.data_volume)
  log.info("Search after write - Optimize finished")

- def run_search(self):
+ def run_search(self, perc: int):
  log.info("Search after write - Serial search start")
+ test_time = round(time.perf_counter(), 4)
  res, ssearch_dur = self.serial_search_runner.run()
  recall, ndcg, p99_latency = res
  log.info(
- f"Search after write - Serial search - recall={recall}, ndcg={ndcg}, p99={p99_latency}, "
- f"dur={ssearch_dur:.4f}",
+ f"Search after write - Serial search - recall={recall}, ndcg={ndcg}, "
+ f"p99={p99_latency}, dur={ssearch_dur:.4f}",
+ )
+ log.info(
+ f"Search after wirte - Conc search start, dur for each conc={self.read_dur_after_write}",
  )
- log.info(f"Search after wirte - Conc search start, dur for each conc={self.read_dur_after_write}")
- max_qps = self.run_by_dur(self.read_dur_after_write)
+ max_qps, conc_failed_rate = self.run_by_dur(self.read_dur_after_write)
  log.info(f"Search after wirte - Conc search finished, max_qps={max_qps}")

- return (max_qps, recall, ndcg, p99_latency)
+ return [(perc, test_time, max_qps, recall, ndcg, p99_latency, conc_failed_rate)]

- def run_read_write(self):
- with mp.Manager() as m:
- q = m.Queue()
- with concurrent.futures.ProcessPoolExecutor(
- mp_context=mp.get_context("spawn"),
- max_workers=2,
- ) as executor:
- read_write_futures = []
- read_write_futures.append(executor.submit(self.run_with_rate, q))
- read_write_futures.append(executor.submit(self.run_search_by_sig, q))
+ def run_read_write(self) -> Metric:
+ """
+ Test search performance with a fixed insert rate.
+ - Insert requests are sent to VectorDB at a fixed rate within a dedicated insert process pool.
+ - if the database cannot promptly process these requests, the process pool will accumulate insert tasks.
+ - Search Tests are categorized into three types:
+ - streaming_search: Initiates a new search test upon receiving a signal that the inserted data has
+ reached the search_stage.
+ - streaming_end_search: initiates a new search test after all data has been inserted.
+ - optimized_search (optional): After the streaming_end_search, optimizes and initiates a search test.
+ """
+ m = Metric()
+ with mp.Manager() as mp_manager:
+ q = mp_manager.Queue()
+ with concurrent.futures.ProcessPoolExecutor(mp_context=mp.get_context("spawn"), max_workers=2) as executor:
+ insert_future = executor.submit(self.run_with_rate, q)
+ streaming_search_future = executor.submit(self.run_search_by_sig, q)

  try:
- for f in concurrent.futures.as_completed(read_write_futures):
- res = f.result()
- log.info(f"Result = {res}")
+ start_time = time.perf_counter()
+ _, m.insert_duration = insert_future.result()
+ streaming_search_res = streaming_search_future.result()
+ if streaming_search_res is None:
+ streaming_search_res = []
+
+ streaming_end_search_future = executor.submit(self.run_search, 100)
+ streaming_end_search_res = streaming_end_search_future.result()

  # Wait for read_write_futures finishing and do optimize and search
- op_future = executor.submit(self.run_optimize)
- op_future.result()
+ if self.optimize_after_write:
+ op_future = executor.submit(self.run_optimize)
+ _, m.optimize_duration = op_future.result()
+ log.info(f"Optimize cost {m.optimize_duration}s")
+ optimized_search_future = executor.submit(self.run_search, 110)
+ optimized_search_res = optimized_search_future.result()
+ else:
+ log.info("Skip optimization and search")
+ optimized_search_res = []

- search_future = executor.submit(self.run_search)
- last_res = search_future.result()
+ r = [*streaming_search_res, *streaming_end_search_res, *optimized_search_res]
+ m.st_search_stage_list = [d[0] for d in r]
+ m.st_search_time_list = [round(d[1] - start_time, 4) for d in r]
+ m.st_max_qps_list_list = [d[2] for d in r]
+ m.st_recall_list = [d[3] for d in r]
+ m.st_ndcg_list = [d[4] for d in r]
+ m.st_serial_latency_p99_list = [d[5] for d in r]
+ m.st_conc_failed_rate_list = [d[6] for d in r]

- log.info(f"Max QPS after optimze and search: {last_res}")
  except Exception as e:
  log.warning(f"Read and write error: {e}")
  executor.shutdown(wait=True, cancel_futures=True)
- raise e from e
- log.info("Concurrent read write all done")
+ # raise e
+ m.st_ideal_insert_duration = math.ceil(self.data_volume / self.insert_rate)
+ log.info(f"Concurrent read write all done, results: {m}")
+ return m
+
+ def get_each_conc_search_dur(self, ssearch_dur: float, cur_stage: float, next_stage: float) -> float:
+ # Search duration for non-last search stage is carefully calculated.
+ # If duration for each concurrency is less than 30s, runner will raise error.
+ total_dur_between_stages = self.data_volume * (next_stage - cur_stage) // self.insert_rate
+ csearch_dur = total_dur_between_stages - ssearch_dur
+
+ # Try to leave room for init process executors
+ if csearch_dur > 60:
+ csearch_dur -= 30
+ elif csearch_dur > 30:
+ csearch_dur -= 15
+ else:
+ csearch_dur /= 2
+
+ each_conc_search_dur = round(csearch_dur / len(self.concurrencies), 4)
+ if each_conc_search_dur < 30:
+ warning_msg = (
+ f"Results might be inaccurate, duration[{csearch_dur:.4f}] left for conc-search is too short, "
+ f"total available dur={total_dur_between_stages}, serial_search_cost={ssearch_dur}, "
+ f"each_conc_search_dur={each_conc_search_dur}."
+ )
+ log.warning(warning_msg)
+ return each_conc_search_dur

  def run_search_by_sig(self, q: mp.Queue):
  """
vectordb_bench/backend/runner/read_write_runner.py (continued)
@@ -149,7 +213,7 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunner):
  start += 1
  return True

- for idx, stage in enumerate(self.search_stage):
+ for idx, stage in enumerate(self.search_stages):
  target_batch = int(total_batch * stage)
  perc = int(stage * 100)

@@ -159,41 +223,34 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunner):
  return None

  log.info(f"Insert {perc}% done, total batch={total_batch}")
- log.info(f"[{target_batch}/{total_batch}] Serial search - {perc}% start")
- res, ssearch_dur = self.serial_search_runner.run()
- recall, ndcg, p99_latency = res
- log.info(
- f"[{target_batch}/{total_batch}] Serial search - {perc}% done, recall={recall}, "
- f"ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur:.4f}"
- )
-
- # Search duration for non-last search stage is carefully calculated.
- # If duration for each concurrency is less than 30s, runner will raise error.
- if idx < len(self.search_stage) - 1:
- total_dur_between_stages = self.data_volume * (self.search_stage[idx + 1] - stage) // self.insert_rate
- csearch_dur = total_dur_between_stages - ssearch_dur
-
- # Try to leave room for init process executors
- csearch_dur = csearch_dur - 30 if csearch_dur > 60 else csearch_dur
+ test_time = round(time.perf_counter(), 4)
+ max_qps, recall, ndcg, p99_latency, conc_failed_rate = 0, 0, 0, 0, 0
+ try:
+ log.info(f"[{target_batch}/{total_batch}] Serial search - {perc}% start")
+ res, ssearch_dur = self.serial_search_runner.run()
+ ssearch_dur = round(ssearch_dur, 4)
+ recall, ndcg, p99_latency = res
+ log.info(
+ f"[{target_batch}/{total_batch}] Serial search - {perc}% done, "
+ f"recall={recall}, ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur}"
+ )

- each_conc_search_dur = csearch_dur / len(self.concurrencies)
- if each_conc_search_dur < 30:
- warning_msg = (
- f"Results might be inaccurate, duration[{csearch_dur:.4f}] left for conc-search is too short, "
- f"total available dur={total_dur_between_stages}, serial_search_cost={ssearch_dur}."
+ each_conc_search_dur = self.get_each_conc_search_dur(
+ ssearch_dur,
+ cur_stage=stage,
+ next_stage=self.search_stages[idx + 1] if idx < len(self.search_stages) - 1 else 1.0,
+ )
+ if each_conc_search_dur > 10:
+ log.info(
+ f"[{target_batch}/{total_batch}] Concurrent search - {perc}% start, "
+ f"dur={each_conc_search_dur:.4f}"
  )
- log.warning(warning_msg)
-
- # The last stage
- else:
- each_conc_search_dur = 60
-
- log.info(
- f"[{target_batch}/{total_batch}] Concurrent search - {perc}% start, dur={each_conc_search_dur:.4f}"
- )
- max_qps = self.run_by_dur(each_conc_search_dur)
- result.append((perc, max_qps, recall, ndcg, p99_latency))
-
+ max_qps, conc_failed_rate = self.run_by_dur(each_conc_search_dur)
+ else:
+ log.warning(f"Skip concurrent tests, each_conc_search_dur={each_conc_search_dur} less than 10s.")
+ except Exception as e:
+ log.warning(f"Streaming Search Failed at stage={stage}. Exception: {e}")
+ result.append((perc, test_time, max_qps, recall, ndcg, p99_latency, conc_failed_rate))
  start_batch = target_batch

  # Drain the queue
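The streaming search above is driven by signals on the multiprocessing queue shared with the insert process: False marks another completed batch group, True marks end of dataset, and None marks a fatal error. A toy sketch of that producer/consumer protocol, with simplified signal granularity and made-up batch counts (not the runner's actual helpers):

# Toy sketch of the insert/search queue protocol (illustrative only).
import multiprocessing as mp


def inserter(q, total_batches: int) -> None:
    for _ in range(total_batches):
        # ... insert one batch of embeddings here ...
        q.put(False, block=False)      # progress signal
    q.put(True, block=True)            # end-of-dataset signal


def searcher(q, total_batches: int, stages: list[float]) -> None:
    done = 0
    for stage in stages:
        target = int(total_batches * stage)
        while done < target:
            sig = q.get(block=True)
            if sig is None:            # inserter hit a fatal error
                return
            if sig is True:            # dataset finished early
                done = total_batches
                break
            done += 1
        print(f"insert {int(stage * 100)}% done -> run serial + concurrent search")


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    with mp.Manager() as manager:
        q = manager.Queue()
        p_insert = ctx.Process(target=inserter, args=(q, 100))
        p_search = ctx.Process(target=searcher, args=(q, 100, [0.5, 0.9]))
        p_insert.start()
        p_search.start()
        p_insert.join()
        p_search.join()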
vectordb_bench/backend/runner/serial_runner.py
@@ -6,10 +6,10 @@ import time
  import traceback

  import numpy as np
- import pandas as pd
  import psutil

  from vectordb_bench.backend.dataset import DatasetManager
+ from vectordb_bench.backend.filter import Filter, FilterOp, non_filter

  from ... import config
  from ...metric import calc_ndcg, calc_recall, get_ideal_dcg
@@ -18,8 +18,7 @@ from .. import utils
  from ..clients import api

  NUM_PER_BATCH = config.NUM_PER_BATCH
- LOAD_MAX_TRY_COUNT = 10
- WAITTING_TIME = 60
+ LOAD_MAX_TRY_COUNT = config.LOAD_MAX_TRY_COUNT

  log = logging.getLogger(__name__)

@@ -30,12 +29,26 @@ class SerialInsertRunner:
  db: api.VectorDB,
  dataset: DatasetManager,
  normalize: bool,
+ filters: Filter = non_filter,
  timeout: float | None = None,
  ):
  self.timeout = timeout if isinstance(timeout, int | float) else None
  self.dataset = dataset
  self.db = db
  self.normalize = normalize
+ self.filters = filters
+
+ def retry_insert(self, db: api.VectorDB, retry_idx: int = 0, **kwargs):
+ _, error = db.insert_embeddings(**kwargs)
+ if error is not None:
+ log.warning(f"Insert Failed, try_idx={retry_idx}, Exception: {error}")
+ retry_idx += 1
+ if retry_idx <= config.MAX_INSERT_RETRY:
+ time.sleep(retry_idx)
+ self.retry_insert(db, retry_idx=retry_idx, **kwargs)
+ else:
+ msg = f"Insert failed and retried more than {config.MAX_INSERT_RETRY} times"
+ raise RuntimeError(msg) from None

  def task(self) -> int:
  count = 0
@@ -43,9 +56,9 @@ class SerialInsertRunner:
  log.info(f"({mp.current_process().name:16}) Start inserting embeddings in batch {config.NUM_PER_BATCH}")
  start = time.perf_counter()
  for data_df in self.dataset:
- all_metadata = data_df["id"].tolist()
+ all_metadata = data_df[self.dataset.data.train_id_field].tolist()

- emb_np = np.stack(data_df["emb"])
+ emb_np = np.stack(data_df[self.dataset.data.train_vector_field])
  if self.normalize:
  log.debug("normalize the 100k train data")
  all_embeddings = (emb_np / np.linalg.norm(emb_np, axis=1)[:, np.newaxis]).tolist()
@@ -54,12 +67,25 @@ class SerialInsertRunner:
  del emb_np
  log.debug(f"batch dataset size: {len(all_embeddings)}, {len(all_metadata)}")

+ labels_data = None
+ if self.filters.type == FilterOp.StrEqual:
+ if self.dataset.data.scalar_labels_file_separated:
+ labels_data = self.dataset.scalar_labels[self.filters.label_field][all_metadata].to_list()
+ else:
+ labels_data = data_df[self.filters.label_field].tolist()
+
  insert_count, error = self.db.insert_embeddings(
  embeddings=all_embeddings,
  metadata=all_metadata,
+ labels_data=labels_data,
  )
  if error is not None:
- raise error
+ self.retry_insert(
+ self.db,
+ embeddings=all_embeddings,
+ metadata=all_metadata,
+ labels_data=labels_data,
+ )

  assert insert_count == len(all_metadata)
  count += insert_count
@@ -101,7 +127,7 @@ class SerialInsertRunner:
  already_insert_count += insert_count
  if error is not None:
  retry_count += 1
- time.sleep(WAITTING_TIME)
+ time.sleep(10)

  log.info(f"Failed to insert data, try {retry_count} time")
  if retry_count >= LOAD_MAX_TRY_COUNT:
@@ -149,8 +175,8 @@ class SerialInsertRunner:
  # only 1 file
  data_df = next(iter(self.dataset))
  all_embeddings, all_metadata = (
- np.stack(data_df["emb"]).tolist(),
- data_df["id"].tolist(),
+ np.stack(data_df[self.dataset.data.train_vector_field]).tolist(),
+ data_df[self.dataset.data.train_id_field].tolist(),
  )

  start_time = time.perf_counter()
@@ -188,9 +214,9 @@ class SerialSearchRunner:
  self,
  db: api.VectorDB,
  test_data: list[list[float]],
- ground_truth: pd.DataFrame,
+ ground_truth: list[list[int]],
  k: int = 100,
- filters: dict | None = None,
+ filters: Filter = non_filter,
  ):
  self.db = db
  self.k = k
@@ -202,35 +228,42 @@ class SerialSearchRunner:
  self.test_data = test_data
  self.ground_truth = ground_truth

- def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]:
+ def _get_db_search_res(self, emb: list[float], retry_idx: int = 0) -> list[int]:
+ try:
+ results = self.db.search_embedding(emb, self.k)
+ except Exception as e:
+ log.warning(f"Serial search failed, retry_idx={retry_idx}, Exception: {e}")
+ if retry_idx < config.MAX_SEARCH_RETRY:
+ return self._get_db_search_res(emb=emb, retry_idx=retry_idx + 1)
+
+ msg = f"Serial search failed and retried more than {config.MAX_SEARCH_RETRY} times"
+ raise RuntimeError(msg) from e
+
+ return results
+
+ def search(self, args: tuple[list, list[list[int]]]) -> tuple[float, float, float]:
  log.info(f"{mp.current_process().name:14} start search the entire test_data to get recall and latency")
  with self.db.init():
+ self.db.prepare_filter(self.filters)
  test_data, ground_truth = args
  ideal_dcg = get_ideal_dcg(self.k)

  log.debug(f"test dataset size: {len(test_data)}")
- if ground_truth is not None:
- log.debug(f"ground truth size: {ground_truth.columns}, shape: {ground_truth.shape}")
+ log.debug(f"ground truth size: {len(ground_truth)}")

  latencies, recalls, ndcgs = [], [], []
  for idx, emb in enumerate(test_data):
  s = time.perf_counter()
  try:
- results = self.db.search_embedding(
- emb,
- self.k,
- self.filters,
- )
-
+ results = self._get_db_search_res(emb)
  except Exception as e:
  log.warning(f"VectorDB search_embedding error: {e}")
- traceback.print_exc(chain=True)
  raise e from None

  latencies.append(time.perf_counter() - s)

  if ground_truth is not None:
- gt = ground_truth["neighbors_id"][idx]
+ gt = ground_truth[idx]
  recalls.append(calc_recall(self.k, gt[: self.k], results))
  ndcgs.append(calc_ndcg(gt[: self.k], results, ideal_dcg))
  else:
@@ -253,7 +286,7 @@ class SerialSearchRunner:
  f"cost={cost}s, "
  f"queries={len(latencies)}, "
  f"avg_recall={avg_recall}, "
- f"avg_ndcg={avg_ndcg},"
+ f"avg_ndcg={avg_ndcg}, "
  f"avg_latency={avg_latency}, "
  f"p99={p99}"
  )
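SerialSearchRunner scores each query against the ground-truth neighbor ids using recall and NDCG (via calc_recall, calc_ndcg, and get_ideal_dcg from vectordb_bench.metric). The stand-ins below illustrate that math for a single query; they are written for this note and are not guaranteed to match the library's exact definitions:

# Illustrative stand-ins for per-query recall / NDCG; not vectordb_bench.metric.
import math


def recall_at_k(k: int, ground_truth: list[int], results: list[int]) -> float:
    # Fraction of the top-k ground-truth ids that the search returned.
    gt = set(ground_truth[:k])
    return len(gt & set(results[:k])) / len(gt)


def ideal_dcg(k: int) -> float:
    # DCG of a perfect ranking where every position holds a relevant item.
    return sum(1.0 / math.log2(i + 2) for i in range(k))


def ndcg_at_k(ground_truth: list[int], results: list[int], ideal: float) -> float:
    gt = set(ground_truth)
    dcg = sum(1.0 / math.log2(i + 2) for i, r in enumerate(results) if r in gt)
    return dcg / ideal


if __name__ == "__main__":
    gt, res, k = [1, 2, 3, 4, 5], [1, 9, 3, 8, 5], 5
    print(recall_at_k(k, gt, res))                            # 3 of 5 hits -> 0.6
    print(round(ndcg_at_k(gt[:k], res[:k], ideal_dcg(k)), 4))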