vectordb-bench 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. The information is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (105)
  1. vectordb_bench/__init__.py +49 -24
  2. vectordb_bench/__main__.py +4 -3
  3. vectordb_bench/backend/assembler.py +12 -13
  4. vectordb_bench/backend/cases.py +56 -46
  5. vectordb_bench/backend/clients/__init__.py +101 -14
  6. vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +26 -0
  7. vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +18 -0
  8. vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +345 -0
  9. vectordb_bench/backend/clients/aliyun_opensearch/config.py +47 -0
  10. vectordb_bench/backend/clients/alloydb/alloydb.py +58 -80
  11. vectordb_bench/backend/clients/alloydb/cli.py +52 -35
  12. vectordb_bench/backend/clients/alloydb/config.py +30 -30
  13. vectordb_bench/backend/clients/api.py +8 -9
  14. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +46 -47
  15. vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
  16. vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
  17. vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
  18. vectordb_bench/backend/clients/chroma/chroma.py +38 -36
  19. vectordb_bench/backend/clients/chroma/config.py +4 -2
  20. vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
  21. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +23 -22
  22. vectordb_bench/backend/clients/memorydb/cli.py +8 -8
  23. vectordb_bench/backend/clients/memorydb/config.py +2 -2
  24. vectordb_bench/backend/clients/memorydb/memorydb.py +65 -53
  25. vectordb_bench/backend/clients/milvus/cli.py +62 -80
  26. vectordb_bench/backend/clients/milvus/config.py +31 -7
  27. vectordb_bench/backend/clients/milvus/milvus.py +23 -26
  28. vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
  29. vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
  30. vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +55 -73
  31. vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
  32. vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
  33. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +33 -34
  34. vectordb_bench/backend/clients/pgvector/cli.py +40 -31
  35. vectordb_bench/backend/clients/pgvector/config.py +63 -73
  36. vectordb_bench/backend/clients/pgvector/pgvector.py +97 -98
  37. vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
  38. vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
  39. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +38 -43
  40. vectordb_bench/backend/clients/pinecone/config.py +1 -0
  41. vectordb_bench/backend/clients/pinecone/pinecone.py +14 -21
  42. vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
  43. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +40 -31
  44. vectordb_bench/backend/clients/redis/cli.py +6 -12
  45. vectordb_bench/backend/clients/redis/config.py +7 -5
  46. vectordb_bench/backend/clients/redis/redis.py +94 -58
  47. vectordb_bench/backend/clients/test/cli.py +1 -2
  48. vectordb_bench/backend/clients/test/config.py +2 -2
  49. vectordb_bench/backend/clients/test/test.py +4 -5
  50. vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
  51. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
  52. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +36 -22
  53. vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
  54. vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
  55. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
  56. vectordb_bench/backend/data_source.py +30 -18
  57. vectordb_bench/backend/dataset.py +47 -27
  58. vectordb_bench/backend/result_collector.py +2 -3
  59. vectordb_bench/backend/runner/__init__.py +4 -6
  60. vectordb_bench/backend/runner/mp_runner.py +85 -34
  61. vectordb_bench/backend/runner/rate_runner.py +51 -23
  62. vectordb_bench/backend/runner/read_write_runner.py +140 -46
  63. vectordb_bench/backend/runner/serial_runner.py +99 -50
  64. vectordb_bench/backend/runner/util.py +4 -19
  65. vectordb_bench/backend/task_runner.py +95 -74
  66. vectordb_bench/backend/utils.py +17 -9
  67. vectordb_bench/base.py +0 -1
  68. vectordb_bench/cli/cli.py +65 -60
  69. vectordb_bench/cli/vectordbbench.py +6 -7
  70. vectordb_bench/frontend/components/check_results/charts.py +8 -19
  71. vectordb_bench/frontend/components/check_results/data.py +4 -16
  72. vectordb_bench/frontend/components/check_results/filters.py +8 -16
  73. vectordb_bench/frontend/components/check_results/nav.py +4 -4
  74. vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
  75. vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
  76. vectordb_bench/frontend/components/concurrent/charts.py +12 -12
  77. vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
  78. vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
  79. vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
  80. vectordb_bench/frontend/components/custom/initStyle.py +1 -1
  81. vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
  82. vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
  83. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
  84. vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
  85. vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
  86. vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
  87. vectordb_bench/frontend/components/tables/data.py +3 -6
  88. vectordb_bench/frontend/config/dbCaseConfigs.py +108 -83
  89. vectordb_bench/frontend/pages/concurrent.py +3 -5
  90. vectordb_bench/frontend/pages/custom.py +30 -9
  91. vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
  92. vectordb_bench/frontend/pages/run_test.py +3 -7
  93. vectordb_bench/frontend/utils.py +1 -1
  94. vectordb_bench/frontend/vdb_benchmark.py +4 -6
  95. vectordb_bench/interface.py +56 -26
  96. vectordb_bench/log_util.py +59 -64
  97. vectordb_bench/metric.py +10 -11
  98. vectordb_bench/models.py +26 -43
  99. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/METADATA +34 -42
  100. vectordb_bench-0.0.20.dist-info/RECORD +135 -0
  101. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/WHEEL +1 -1
  102. vectordb_bench-0.0.18.dist-info/RECORD +0 -131
  103. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/LICENSE +0 -0
  104. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/entry_points.txt +0 -0
  105. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/runner/rate_runner.py

@@ -1,35 +1,36 @@
+import concurrent
 import logging
+import multiprocessing as mp
 import time
 from concurrent.futures import ThreadPoolExecutor
-import multiprocessing as mp
-
 
+from vectordb_bench import config
 from vectordb_bench.backend.clients import api
 from vectordb_bench.backend.dataset import DataSetIterator
 from vectordb_bench.backend.utils import time_it
-from vectordb_bench import config
 
-from .util import get_data, is_futures_completed, get_future_exceptions
+from .util import get_data
+
 log = logging.getLogger(__name__)
 
 
 class RatedMultiThreadingInsertRunner:
     def __init__(
         self,
-        rate: int, # numRows per second
+        rate: int,  # numRows per second
         db: api.VectorDB,
         dataset_iter: DataSetIterator,
         normalize: bool = False,
         timeout: float | None = None,
     ):
-        self.timeout = timeout if isinstance(timeout, (int, float)) else None
+        self.timeout = timeout if isinstance(timeout, int | float) else None
         self.dataset = dataset_iter
         self.db = db
         self.normalize = normalize
         self.insert_rate = rate
         self.batch_rate = rate // config.NUM_PER_BATCH
 
-    def send_insert_task(self, db, emb: list[list[float]], metadata: list[str]):
+    def send_insert_task(self, db: api.VectorDB, emb: list[list[float]], metadata: list[str]):
         db.insert_embeddings(emb, metadata)
 
     @time_it
@@ -42,7 +43,9 @@ class RatedMultiThreadingInsertRunner:
                 rate = self.batch_rate
                 for data in self.dataset:
                     emb, metadata = get_data(data, self.normalize)
-                    executing_futures.append(executor.submit(self.send_insert_task, self.db, emb, metadata))
+                    executing_futures.append(
+                        executor.submit(self.send_insert_task, self.db, emb, metadata),
+                    )
                     rate -= 1
 
                     if rate == 0:
@@ -54,26 +57,51 @@
                     start_time = time.perf_counter()
                     finished, elapsed_time = submit_by_rate()
                     if finished is True:
-                        q.put(None, block=True)
+                        q.put(True, block=True)
                         log.info(f"End of dataset, left unfinished={len(executing_futures)}")
-                        return
+                        break
 
-                    q.put(True, block=False)
+                    q.put(False, block=False)
                     wait_interval = 1 - elapsed_time if elapsed_time < 1 else 0.001
 
-                    e, completed = is_futures_completed(executing_futures, wait_interval)
-                    if completed is True:
-                        ex = get_future_exceptions(executing_futures)
-                        if ex is not None:
-                            log.warn(f"task error, terminating, err={ex}")
-                            q.put(None)
-                            executor.shutdown(wait=True, cancel_futures=True)
-                            raise ex
+                    try:
+                        done, not_done = concurrent.futures.wait(
+                            executing_futures,
+                            timeout=wait_interval,
+                            return_when=concurrent.futures.FIRST_EXCEPTION,
+                        )
+
+                        if len(not_done) > 0:
+                            log.warning(
+                                f"Failed to finish all tasks in 1s, [{len(not_done)}/{len(executing_futures)}] ",
+                                f"tasks are not done, waited={wait_interval:.2f}, trying to wait in the next round",
+                            )
+                            executing_futures = list(not_done)
                         else:
-                            log.debug(f"Finished {len(executing_futures)} insert-{config.NUM_PER_BATCH} task in 1s, wait_interval={wait_interval:.2f}")
-                            executing_futures = []
-                    else:
-                        log.warning(f"Failed to finish tasks in 1s, {e}, waited={wait_interval:.2f}, try to check the next round")
+                            log.debug(
+                                f"Finished {len(executing_futures)} insert-{config.NUM_PER_BATCH} ",
+                                f"task in 1s, wait_interval={wait_interval:.2f}",
+                            )
+                            executing_futures = []
+                    except Exception as e:
+                        log.warning(f"task error, terminating, err={e}")
+                        q.put(None, block=True)
+                        executor.shutdown(wait=True, cancel_futures=True)
+                        raise e from e
+
                     dur = time.perf_counter() - start_time
                     if dur < 1:
                         time.sleep(1 - dur)
+
+                # wait for all tasks in executing_futures to complete
+                if len(executing_futures) > 0:
+                    try:
+                        done, _ = concurrent.futures.wait(
+                            executing_futures,
+                            return_when=concurrent.futures.FIRST_EXCEPTION,
+                        )
+                    except Exception as e:
+                        log.warning(f"task error, terminating, err={e}")
+                        q.put(None, block=True)
+                        executor.shutdown(wait=True, cancel_futures=True)
+                        raise e from e
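The rewritten loop in run_with_rate drops the is_futures_completed/get_future_exceptions helpers in favor of concurrent.futures.wait with return_when=FIRST_EXCEPTION: each second it submits one batch-rate worth of insert tasks, waits out the rest of the second, and rolls any unfinished futures into the next round. A self-contained sketch of that submit-then-wait pattern, assuming a hypothetical insert_batch workload rather than the vectordb-bench client API:

```python
import concurrent.futures
import time
from concurrent.futures import ThreadPoolExecutor


def insert_batch(batch: list[int]) -> None:
    """Hypothetical stand-in for db.insert_embeddings(emb, metadata)."""
    time.sleep(0.05)


def run_with_rate(batches: list[list[int]], batches_per_second: int) -> None:
    pending: list[concurrent.futures.Future] = []
    with ThreadPoolExecutor(max_workers=4) as executor:
        for i in range(0, len(batches), batches_per_second):
            start = time.perf_counter()
            # submit this second's quota of work
            pending += [executor.submit(insert_batch, b) for b in batches[i : i + batches_per_second]]

            # wait for the remainder of the second; stop waiting early on the first failure
            done, not_done = concurrent.futures.wait(
                pending,
                timeout=max(1 - (time.perf_counter() - start), 0.001),
                return_when=concurrent.futures.FIRST_EXCEPTION,
            )
            for f in done:
                if f.exception() is not None:
                    executor.shutdown(wait=True, cancel_futures=True)
                    raise f.exception()
            pending = list(not_done)  # unfinished work rolls over into the next round

            leftover = 1 - (time.perf_counter() - start)
            if leftover > 0:
                time.sleep(leftover)  # hold the pace near batches_per_second

        # drain whatever is still outstanding once the dataset is exhausted
        concurrent.futures.wait(pending, return_when=concurrent.futures.FIRST_EXCEPTION)


if __name__ == "__main__":
    run_with_rate([[i] for i in range(20)], batches_per_second=5)
```

Carrying not_done forward instead of blocking on every future is what keeps the submission cadence close to the target rate even when the database falls behind.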
vectordb_bench/backend/runner/read_write_runner.py

@@ -1,16 +1,18 @@
+import concurrent
 import logging
-from typing import Iterable
+import math
 import multiprocessing as mp
-import concurrent
+from collections.abc import Iterable
+
 import numpy as np
-import math
 
-from .mp_runner import MultiProcessingSearchRunner
-from .serial_runner import SerialSearchRunner
-from .rate_runner import RatedMultiThreadingInsertRunner
 from vectordb_bench.backend.clients import api
 from vectordb_bench.backend.dataset import DatasetManager
 
+from .mp_runner import MultiProcessingSearchRunner
+from .rate_runner import RatedMultiThreadingInsertRunner
+from .serial_runner import SerialSearchRunner
+
 log = logging.getLogger(__name__)
 
 
@@ -24,19 +26,28 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunn
         k: int = 100,
         filters: dict | None = None,
         concurrencies: Iterable[int] = (1, 15, 50),
-        search_stage: Iterable[float] = (0.5, 0.6, 0.7, 0.8, 0.9, 1.0),  # search in any insert portion, 0.0 means search from the start
-        read_dur_after_write: int = 300,  # seconds, search duration when insertion is done
+        search_stage: Iterable[float] = (
+            0.5,
+            0.6,
+            0.7,
+            0.8,
+            0.9,
+        ),  # search from insert portion, 0.0 means search from the start
+        read_dur_after_write: int = 300,  # seconds, search duration when insertion is done
         timeout: float | None = None,
     ):
         self.insert_rate = insert_rate
         self.data_volume = dataset.data.size
 
         for stage in search_stage:
-            assert 0.0 <= stage <= 1.0, "each search stage should be in [0.0, 1.0]"
+            assert 0.0 <= stage < 1.0, "each search stage should be in [0.0, 1.0)"
         self.search_stage = sorted(search_stage)
         self.read_dur_after_write = read_dur_after_write
 
-        log.info(f"Init runner, concurencys={concurrencies}, search_stage={search_stage}, stage_search_dur={read_dur_after_write}")
+        log.info(
+            f"Init runner, concurencys={concurrencies}, search_stage={search_stage}, ",
+            f"stage_search_dur={read_dur_after_write}",
+        )
 
         test_emb = np.stack(dataset.test_data["emb"])
         if normalize:
@@ -65,48 +76,131 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunn
             k=k,
         )
 
+    def run_optimize(self):
+        """Optimize needs to run in differenct process for pymilvus schema recursion problem"""
+        with self.db.init():
+            log.info("Search after write - Optimize start")
+            self.db.optimize()
+            log.info("Search after write - Optimize finished")
+
+    def run_search(self):
+        log.info("Search after write - Serial search start")
+        res, ssearch_dur = self.serial_search_runner.run()
+        recall, ndcg, p99_latency = res
+        log.info(
+            f"Search after write - Serial search - recall={recall}, ndcg={ndcg}, p99={p99_latency}, ",
+            f"dur={ssearch_dur:.4f}",
+        )
+        log.info(
+            f"Search after wirte - Conc search start, dur for each conc={self.read_dur_after_write}",
+        )
+        max_qps = self.run_by_dur(self.read_dur_after_write)
+        log.info(f"Search after wirte - Conc search finished, max_qps={max_qps}")
+
+        return (max_qps, recall, ndcg, p99_latency)
+
     def run_read_write(self):
-        futures = []
         with mp.Manager() as m:
             q = m.Queue()
-            with concurrent.futures.ProcessPoolExecutor(mp_context=mp.get_context("spawn"), max_workers=2) as executor:
-                futures.append(executor.submit(self.run_with_rate, q))
-                futures.append(executor.submit(self.run_search_by_sig, q))
-
-                for future in concurrent.futures.as_completed(futures):
-                    res = future.result()
-                    log.info(f"Result = {res}")
-
+            with concurrent.futures.ProcessPoolExecutor(
+                mp_context=mp.get_context("spawn"),
+                max_workers=2,
+            ) as executor:
+                read_write_futures = []
+                read_write_futures.append(executor.submit(self.run_with_rate, q))
+                read_write_futures.append(executor.submit(self.run_search_by_sig, q))
+
+                try:
+                    for f in concurrent.futures.as_completed(read_write_futures):
+                        res = f.result()
+                        log.info(f"Result = {res}")
+
+                    # Wait for read_write_futures finishing and do optimize and search
+                    op_future = executor.submit(self.run_optimize)
+                    op_future.result()
+
+                    search_future = executor.submit(self.run_search)
+                    last_res = search_future.result()
+
+                    log.info(f"Max QPS after optimze and search: {last_res}")
+                except Exception as e:
+                    log.warning(f"Read and write error: {e}")
+                    executor.shutdown(wait=True, cancel_futures=True)
+                    raise e from e
         log.info("Concurrent read write all done")
 
-
-    def run_search_by_sig(self, q):
-        res = []
+    def run_search_by_sig(self, q: mp.Queue):
+        """
+        Args:
+            q: multiprocessing queue
+                (None) means abnormal exit
+                (False) means updating progress
+                (True) means normal exit
+        """
+        result, start_batch = [], 0
         total_batch = math.ceil(self.data_volume / self.insert_rate)
-        batch = 0
-        recall = 'x'
+        recall, ndcg, p99_latency = None, None, None
+
+        def wait_next_target(start: int, target_batch: int) -> bool:
+            """Return False when receive True or None"""
+            while start < target_batch:
+                sig = q.get(block=True)
+
+                if sig is None or sig is True:
+                    return False
+                start += 1
+            return True
 
         for idx, stage in enumerate(self.search_stage):
             target_batch = int(total_batch * stage)
-            while q.get(block=True):
-                batch += 1
-                if batch >= target_batch:
-                    perc = int(stage * 100)
-                    log.info(f"Insert {perc}% done, total batch={total_batch}")
-                    log.info(f"[{batch}/{total_batch}] Serial search - {perc}% start")
-                    recall, ndcg, p99 =self.serial_search_runner.run()
-
-                    if idx < len(self.search_stage) - 1:
-                        stage_search_dur = (self.data_volume * (self.search_stage[idx + 1] - stage) // self.insert_rate) // len(self.concurrencies)
-                        if stage_search_dur < 30:
-                            log.warning(f"Search duration too short, please reduce concurrency count or insert rate, or increase dataset volume: dur={stage_search_dur}, concurrencies={len(self.concurrencies)}, insert_rate={self.insert_rate}")
-                        log.info(f"[{batch}/{total_batch}] Conc search - {perc}% start, dur for each conc={stage_search_dur}s")
-                    else:
-                        last_search_dur = self.data_volume * (1.0 - stage) // self.insert_rate
-                        stage_search_dur = last_search_dur + self.read_dur_after_write
-                        log.info(f"[{batch}/{total_batch}] Last conc search - {perc}% start, [read_until_write|read_after_write|total] =[{last_search_dur}s|{self.read_dur_after_write}s|{stage_search_dur}s]")
-
-                    max_qps = self.run_by_dur(stage_search_dur)
-                    res.append((perc, max_qps, recall))
-                    break
-        return res
+            perc = int(stage * 100)
+
+            got = wait_next_target(start_batch, target_batch)
+            if got is False:
+                log.warning(
+                    f"Abnormal exit, target_batch={target_batch}, start_batch={start_batch}",
+                )
+                return None
+
+            log.info(f"Insert {perc}% done, total batch={total_batch}")
+            log.info(f"[{target_batch}/{total_batch}] Serial search - {perc}% start")
+            res, ssearch_dur = self.serial_search_runner.run()
+            recall, ndcg, p99_latency = res
+            log.info(
+                f"[{target_batch}/{total_batch}] Serial search - {perc}% done, recall={recall}, ",
+                f"ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur:.4f}",
+            )
+
+            # Search duration for non-last search stage is carefully calculated.
+            # If duration for each concurrency is less than 30s, runner will raise error.
+            if idx < len(self.search_stage) - 1:
+                total_dur_between_stages = self.data_volume * (self.search_stage[idx + 1] - stage) // self.insert_rate
+                csearch_dur = total_dur_between_stages - ssearch_dur
+
+                # Try to leave room for init process executors
+                csearch_dur = csearch_dur - 30 if csearch_dur > 60 else csearch_dur
+
+                each_conc_search_dur = csearch_dur / len(self.concurrencies)
+                if each_conc_search_dur < 30:
+                    warning_msg = (
+                        f"Results might be inaccurate, duration[{csearch_dur:.4f}] left for conc-search is too short, ",
+                        f"total available dur={total_dur_between_stages}, serial_search_cost={ssearch_dur}.",
+                    )
+                    log.warning(warning_msg)
+
+            # The last stage
+            else:
+                each_conc_search_dur = 60
+
+            log.info(
+                f"[{target_batch}/{total_batch}] Concurrent search - {perc}% start, dur={each_conc_search_dur:.4f}",
+            )
+            max_qps = self.run_by_dur(each_conc_search_dur)
+            result.append((perc, max_qps, recall, ndcg, p99_latency))
+
+            start_batch = target_batch
+
+        # Drain the queue
+        while q.empty() is False:
+            q.get(block=True)
        return result
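The new docstring on run_search_by_sig spells out the queue protocol the two processes use: False is a per-batch progress tick from the writer, True signals a normal end of the dataset, and None signals an abnormal exit. A stripped-down sketch of that three-valued protocol between two processes (the producer/consumer functions and batch counts here are illustrative, not the benchmark classes):

```python
import multiprocessing as mp


def producer(q, total_batches: int) -> None:
    """Insert side: signal progress after every batch, then signal completion."""
    for _ in range(total_batches):
        # ... insert one batch here ...
        q.put(False)  # progress tick
    q.put(True)       # normal exit; q.put(None) would mean abnormal exit


def consumer(q, target_batches: list[int]) -> None:
    """Search side: block until each insert milestone is reached, then search."""
    seen = 0
    for target in target_batches:
        while seen < target:
            sig = q.get(block=True)
            if sig is True or sig is None:
                return  # producer finished or died before this milestone
            seen += 1
        print(f"reached {target} batches, run a search stage here")


if __name__ == "__main__":
    with mp.Manager() as m:
        q = m.Queue()
        total = 100
        p = mp.get_context("spawn").Process(target=producer, args=(q, total))
        p.start()
        consumer(q, [50, 80, 100])
        p.join()
```

Blocking q.get calls on the reader side are what turn insert progress into the search_stage milestones without any polling.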
vectordb_bench/backend/runner/serial_runner.py

@@ -1,20 +1,21 @@
-import time
-import logging
-import traceback
 import concurrent
-import multiprocessing as mp
+import logging
 import math
-import psutil
+import multiprocessing as mp
+import time
+import traceback
 
 import numpy as np
 import pandas as pd
+import psutil
 
-from ..clients import api
+from vectordb_bench.backend.dataset import DatasetManager
+
+from ... import config
 from ...metric import calc_ndcg, calc_recall, get_ideal_dcg
 from ...models import LoadTimeoutError, PerformanceTimeoutError
 from .. import utils
-from ... import config
-from vectordb_bench.backend.dataset import DatasetManager
+from ..clients import api
 
 NUM_PER_BATCH = config.NUM_PER_BATCH
 LOAD_MAX_TRY_COUNT = 10
@@ -22,9 +23,16 @@ WAITTING_TIME = 60
 
 log = logging.getLogger(__name__)
 
+
 class SerialInsertRunner:
-    def __init__(self, db: api.VectorDB, dataset: DatasetManager, normalize: bool, timeout: float | None = None):
-        self.timeout = timeout if isinstance(timeout, (int, float)) else None
+    def __init__(
+        self,
+        db: api.VectorDB,
+        dataset: DatasetManager,
+        normalize: bool,
+        timeout: float | None = None,
+    ):
+        self.timeout = timeout if isinstance(timeout, int | float) else None
         self.dataset = dataset
         self.db = db
         self.normalize = normalize
@@ -32,18 +40,20 @@ class SerialInsertRunner:
     def task(self) -> int:
         count = 0
         with self.db.init():
-            log.info(f"({mp.current_process().name:16}) Start inserting embeddings in batch {config.NUM_PER_BATCH}")
+            log.info(
+                f"({mp.current_process().name:16}) Start inserting embeddings in batch {config.NUM_PER_BATCH}",
+            )
             start = time.perf_counter()
             for data_df in self.dataset:
-                all_metadata = data_df['id'].tolist()
+                all_metadata = data_df["id"].tolist()
 
-                emb_np = np.stack(data_df['emb'])
+                emb_np = np.stack(data_df["emb"])
                 if self.normalize:
                     log.debug("normalize the 100k train data")
                     all_embeddings = (emb_np / np.linalg.norm(emb_np, axis=1)[:, np.newaxis]).tolist()
                 else:
                     all_embeddings = emb_np.tolist()
-                del(emb_np)
+                del emb_np
                 log.debug(f"batch dataset size: {len(all_embeddings)}, {len(all_metadata)}")
 
                 insert_count, error = self.db.insert_embeddings(
@@ -56,30 +66,41 @@ class SerialInsertRunner:
                 assert insert_count == len(all_metadata)
                 count += insert_count
                 if count % 100_000 == 0:
-                    log.info(f"({mp.current_process().name:16}) Loaded {count} embeddings into VectorDB")
+                    log.info(
+                        f"({mp.current_process().name:16}) Loaded {count} embeddings into VectorDB",
+                    )
 
-            log.info(f"({mp.current_process().name:16}) Finish loading all dataset into VectorDB, dur={time.perf_counter()-start}")
+            log.info(
+                f"({mp.current_process().name:16}) Finish loading all dataset into VectorDB, ",
+                f"dur={time.perf_counter()-start}",
+            )
         return count
 
-    def endless_insert_data(self, all_embeddings, all_metadata, left_id: int = 0) -> int:
+    def endless_insert_data(self, all_embeddings: list, all_metadata: list, left_id: int = 0) -> int:
         with self.db.init():
             # unique id for endlessness insertion
-            all_metadata = [i+left_id for i in all_metadata]
+            all_metadata = [i + left_id for i in all_metadata]
 
-            NUM_BATCHES = math.ceil(len(all_embeddings)/NUM_PER_BATCH)
-            log.info(f"({mp.current_process().name:16}) Start inserting {len(all_embeddings)} embeddings in batch {NUM_PER_BATCH}")
+            num_batches = math.ceil(len(all_embeddings) / NUM_PER_BATCH)
+            log.info(
+                f"({mp.current_process().name:16}) Start inserting {len(all_embeddings)} ",
+                f"embeddings in batch {NUM_PER_BATCH}",
+            )
             count = 0
-            for batch_id in range(NUM_BATCHES):
+            for batch_id in range(num_batches):
                 retry_count = 0
                 already_insert_count = 0
-                metadata = all_metadata[batch_id*NUM_PER_BATCH : (batch_id+1)*NUM_PER_BATCH]
-                embeddings = all_embeddings[batch_id*NUM_PER_BATCH : (batch_id+1)*NUM_PER_BATCH]
+                metadata = all_metadata[batch_id * NUM_PER_BATCH : (batch_id + 1) * NUM_PER_BATCH]
+                embeddings = all_embeddings[batch_id * NUM_PER_BATCH : (batch_id + 1) * NUM_PER_BATCH]
 
-                log.debug(f"({mp.current_process().name:16}) batch [{batch_id:3}/{NUM_BATCHES}], Start inserting {len(metadata)} embeddings")
+                log.debug(
+                    f"({mp.current_process().name:16}) batch [{batch_id:3}/{num_batches}], ",
+                    f"Start inserting {len(metadata)} embeddings",
+                )
                 while retry_count < LOAD_MAX_TRY_COUNT:
                     insert_count, error = self.db.insert_embeddings(
-                        embeddings=embeddings[already_insert_count :],
-                        metadata=metadata[already_insert_count :],
+                        embeddings=embeddings[already_insert_count:],
+                        metadata=metadata[already_insert_count:],
                     )
                     already_insert_count += insert_count
                     if error is not None:
@@ -91,17 +112,26 @@ class SerialInsertRunner:
                             raise error
                     else:
                         break
-                log.debug(f"({mp.current_process().name:16}) batch [{batch_id:3}/{NUM_BATCHES}], Finish inserting {len(metadata)} embeddings")
+                log.debug(
+                    f"({mp.current_process().name:16}) batch [{batch_id:3}/{num_batches}], ",
+                    f"Finish inserting {len(metadata)} embeddings",
+                )
 
                 assert already_insert_count == len(metadata)
                 count += already_insert_count
-            log.info(f"({mp.current_process().name:16}) Finish inserting {len(all_embeddings)} embeddings in batch {NUM_PER_BATCH}")
+            log.info(
+                f"({mp.current_process().name:16}) Finish inserting {len(all_embeddings)} embeddings in ",
+                f"batch {NUM_PER_BATCH}",
+            )
         return count
 
     @utils.time_it
     def _insert_all_batches(self) -> int:
         """Performance case only"""
-        with concurrent.futures.ProcessPoolExecutor(mp_context=mp.get_context('spawn'), max_workers=1) as executor:
+        with concurrent.futures.ProcessPoolExecutor(
+            mp_context=mp.get_context("spawn"),
+            max_workers=1,
+        ) as executor:
            future = executor.submit(self.task)
            try:
                count = future.result(timeout=self.timeout)
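_insert_all_batches keeps the earlier structure: the whole insert runs in a single spawned worker process, and the parent bounds it with future.result(timeout=self.timeout). The same pattern in isolation looks roughly like this (slow_task is a placeholder workload, and the timeout handling is a sketch rather than the runner's exact error path):

```python
import concurrent.futures
import multiprocessing as mp
import time


def slow_task(seconds: float) -> int:
    """Placeholder for the real insert workload; must be a top-level (picklable) function."""
    time.sleep(seconds)
    return 42


def run_with_timeout(seconds: float, timeout: float | None) -> int:
    # spawn context avoids inheriting state from the parent process (same choice as the diff)
    with concurrent.futures.ProcessPoolExecutor(
        mp_context=mp.get_context("spawn"),
        max_workers=1,
    ) as executor:
        future = executor.submit(slow_task, seconds)
        try:
            return future.result(timeout=timeout)
        except concurrent.futures.TimeoutError:
            # cancel what we can and surface the timeout to the caller
            executor.shutdown(wait=True, cancel_futures=True)
            raise


if __name__ == "__main__":
    print(run_with_timeout(0.1, timeout=5))
```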
vectordb_bench/backend/runner/serial_runner.py (continued)

@@ -121,8 +151,11 @@ class SerialInsertRunner:
         """run forever util DB raises exception or crash"""
         # datasets for load tests are quite small, can fit into memory
         # only 1 file
-        data_df = [data_df for data_df in self.dataset][0]
-        all_embeddings, all_metadata = np.stack(data_df["emb"]).tolist(), data_df['id'].tolist()
+        data_df = next(iter(self.dataset))
+        all_embeddings, all_metadata = (
+            np.stack(data_df["emb"]).tolist(),
+            data_df["id"].tolist(),
+        )
 
         start_time = time.perf_counter()
         max_load_count, times = 0, 0
@@ -130,18 +163,26 @@ class SerialInsertRunner:
             with self.db.init():
                 self.db.ready_to_load()
                 while time.perf_counter() - start_time < self.timeout:
-                    count = self.endless_insert_data(all_embeddings, all_metadata, left_id=max_load_count)
+                    count = self.endless_insert_data(
+                        all_embeddings,
+                        all_metadata,
+                        left_id=max_load_count,
+                    )
                     max_load_count += count
                     times += 1
-                    log.info(f"Loaded {times} entire dataset, current max load counts={utils.numerize(max_load_count)}, {max_load_count}")
+                    log.info(
+                        f"Loaded {times} entire dataset, current max load counts={utils.numerize(max_load_count)}, ",
+                        f"{max_load_count}",
+                    )
         except Exception as e:
-            log.info(f"Capacity case load reach limit, insertion counts={utils.numerize(max_load_count)}, {max_load_count}, err={e}")
+            log.info(
+                f"Capacity case load reach limit, insertion counts={utils.numerize(max_load_count)}, ",
+                f"{max_load_count}, err={e}",
+            )
             traceback.print_exc()
             return max_load_count
         else:
-            msg = f"capacity case load timeout in {self.timeout}s"
-            log.info(msg)
-            raise LoadTimeoutError(msg)
+            raise LoadTimeoutError(self.timeout)
 
     def run(self) -> int:
         count, dur = self._insert_all_batches()
@@ -167,8 +208,10 @@ class SerialSearchRunner:
         self.test_data = test_data
         self.ground_truth = ground_truth
 
-    def search(self, args: tuple[list, pd.DataFrame]):
-        log.info(f"{mp.current_process().name:14} start search the entire test_data to get recall and latency")
+    def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]:
+        log.info(
+            f"{mp.current_process().name:14} start search the entire test_data to get recall and latency",
+        )
         with self.db.init():
             test_data, ground_truth = args
             ideal_dcg = get_ideal_dcg(self.k)
@@ -193,13 +236,15 @@ class SerialSearchRunner:
 
                 latencies.append(time.perf_counter() - s)
 
-                gt = ground_truth['neighbors_id'][idx]
-                recalls.append(calc_recall(self.k, gt[:self.k], results))
-                ndcgs.append(calc_ndcg(gt[:self.k], results, ideal_dcg))
-
+                gt = ground_truth["neighbors_id"][idx]
+                recalls.append(calc_recall(self.k, gt[: self.k], results))
+                ndcgs.append(calc_ndcg(gt[: self.k], results, ideal_dcg))
 
                 if len(latencies) % 100 == 0:
-                    log.debug(f"({mp.current_process().name:14}) search_count={len(latencies):3}, latest_latency={latencies[-1]}, latest recall={recalls[-1]}")
+                    log.debug(
+                        f"({mp.current_process().name:14}) search_count={len(latencies):3}, ",
+                        f"latest_latency={latencies[-1]}, latest recall={recalls[-1]}",
+                    )
 
             avg_latency = round(np.mean(latencies), 4)
             avg_recall = round(np.mean(recalls), 4)
@@ -213,16 +258,20 @@ class SerialSearchRunner:
                 f"avg_recall={avg_recall}, "
                 f"avg_ndcg={avg_ndcg},"
                 f"avg_latency={avg_latency}, "
-                f"p99={p99}"
-                )
+                f"p99={p99}",
+            )
         return (avg_recall, avg_ndcg, p99)
 
-
     def _run_in_subprocess(self) -> tuple[float, float]:
         with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
             future = executor.submit(self.search, (self.test_data, self.ground_truth))
-            result = future.result()
-            return result
+            return future.result()
+
+    @utils.time_it
+    def run(self) -> tuple[float, float, float]:
+        """
+        Returns:
+            tuple[tuple[float, float, float], float]: (avg_recall, avg_ndcg, p99_latency), cost
 
-    def run(self) -> tuple[float, float]:
+        """
         return self._run_in_subprocess()
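The docstring added to run() documents a ((avg_recall, avg_ndcg, p99_latency), cost) return shape, which is what lets callers elsewhere in the diff write `res, ssearch_dur = self.serial_search_runner.run()`. That extra cost element comes from the utils.time_it decorator; a minimal decorator with the same contract might look like the sketch below (an assumption about its behavior, not the project's exact implementation in vectordb_bench/backend/utils.py):

```python
import time
from collections.abc import Callable
from functools import wraps


def time_it(func: Callable):
    """Return (original_result, elapsed_seconds) instead of the bare result."""

    @wraps(func)
    def inner(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        return result, time.perf_counter() - start

    return inner


@time_it
def serial_search() -> tuple[float, float, float]:
    # stand-in for SerialSearchRunner.run(); returns (recall, ndcg, p99)
    return 0.98, 0.95, 0.012


if __name__ == "__main__":
    res, ssearch_dur = serial_search()
    recall, ndcg, p99 = res
    print(f"recall={recall}, ndcg={ndcg}, p99={p99}, dur={ssearch_dur:.4f}")
```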
vectordb_bench/backend/runner/util.py

@@ -1,32 +1,17 @@
 import logging
-import concurrent
-from typing import Iterable
 
-from pandas import DataFrame
 import numpy as np
+from pandas import DataFrame
 
 log = logging.getLogger(__name__)
 
+
 def get_data(data_df: DataFrame, normalize: bool) -> tuple[list[list[float]], list[str]]:
-    all_metadata = data_df['id'].tolist()
-    emb_np = np.stack(data_df['emb'])
+    all_metadata = data_df["id"].tolist()
+    emb_np = np.stack(data_df["emb"])
     if normalize:
         log.debug("normalize the 100k train data")
         all_embeddings = (emb_np / np.linalg.norm(emb_np, axis=1)[:, np.newaxis]).tolist()
     else:
         all_embeddings = emb_np.tolist()
     return all_embeddings, all_metadata
-
-def is_futures_completed(futures: Iterable[concurrent.futures.Future], interval) -> (Exception, bool):
-    try:
-        list(concurrent.futures.as_completed(futures, timeout=interval))
-    except TimeoutError as e:
-        return e, False
-    return None, True
-
-
-def get_future_exceptions(futures: Iterable[concurrent.futures.Future]) -> BaseException | None:
-    for f in futures:
-        if f.exception() is not None:
-            return f.exception()
-    return
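After the cleanup, util.py keeps only get_data, which turns a dataset batch into plain Python lists and optionally L2-normalizes the embeddings. A small self-contained sketch of the same normalization math on made-up data (it mirrors the helper rather than importing it):

```python
import numpy as np
from pandas import DataFrame

# made-up two-row batch in the shape the runners iterate over:
# an "id" column and an "emb" column holding one vector per row
data_df = DataFrame(
    {
        "id": [0, 1],
        "emb": [np.array([3.0, 4.0]), np.array([1.0, 0.0])],
    }
)

emb_np = np.stack(data_df["emb"])                       # shape (2, 2)
norms = np.linalg.norm(emb_np, axis=1)[:, np.newaxis]   # per-row L2 norm, shape (2, 1)
normalized = (emb_np / norms).tolist()

print(normalized[0])           # [0.6, 0.8] -- unit-length version of [3.0, 4.0]
print(data_df["id"].tolist())  # [0, 1]
```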