vectordb-bench 0.0.19__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. vectordb_bench/__init__.py +49 -24
  2. vectordb_bench/__main__.py +4 -3
  3. vectordb_bench/backend/assembler.py +12 -13
  4. vectordb_bench/backend/cases.py +55 -45
  5. vectordb_bench/backend/clients/__init__.py +75 -14
  6. vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +1 -2
  7. vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +3 -4
  8. vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +111 -70
  9. vectordb_bench/backend/clients/aliyun_opensearch/config.py +6 -7
  10. vectordb_bench/backend/clients/alloydb/alloydb.py +58 -80
  11. vectordb_bench/backend/clients/alloydb/cli.py +51 -34
  12. vectordb_bench/backend/clients/alloydb/config.py +30 -30
  13. vectordb_bench/backend/clients/api.py +5 -9
  14. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +46 -47
  15. vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
  16. vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
  17. vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
  18. vectordb_bench/backend/clients/chroma/chroma.py +38 -36
  19. vectordb_bench/backend/clients/chroma/config.py +4 -2
  20. vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
  21. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +23 -22
  22. vectordb_bench/backend/clients/memorydb/cli.py +8 -8
  23. vectordb_bench/backend/clients/memorydb/config.py +2 -2
  24. vectordb_bench/backend/clients/memorydb/memorydb.py +65 -53
  25. vectordb_bench/backend/clients/milvus/cli.py +41 -83
  26. vectordb_bench/backend/clients/milvus/config.py +18 -8
  27. vectordb_bench/backend/clients/milvus/milvus.py +18 -19
  28. vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
  29. vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
  30. vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +55 -73
  31. vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
  32. vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
  33. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +33 -34
  34. vectordb_bench/backend/clients/pgvector/cli.py +40 -31
  35. vectordb_bench/backend/clients/pgvector/config.py +63 -73
  36. vectordb_bench/backend/clients/pgvector/pgvector.py +97 -98
  37. vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
  38. vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
  39. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +38 -43
  40. vectordb_bench/backend/clients/pinecone/config.py +1 -0
  41. vectordb_bench/backend/clients/pinecone/pinecone.py +14 -21
  42. vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
  43. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +40 -31
  44. vectordb_bench/backend/clients/redis/cli.py +6 -12
  45. vectordb_bench/backend/clients/redis/config.py +7 -5
  46. vectordb_bench/backend/clients/redis/redis.py +94 -58
  47. vectordb_bench/backend/clients/test/cli.py +1 -2
  48. vectordb_bench/backend/clients/test/config.py +2 -2
  49. vectordb_bench/backend/clients/test/test.py +4 -5
  50. vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
  51. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
  52. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +36 -22
  53. vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
  54. vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
  55. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
  56. vectordb_bench/backend/data_source.py +30 -18
  57. vectordb_bench/backend/dataset.py +47 -27
  58. vectordb_bench/backend/result_collector.py +2 -3
  59. vectordb_bench/backend/runner/__init__.py +4 -6
  60. vectordb_bench/backend/runner/mp_runner.py +85 -34
  61. vectordb_bench/backend/runner/rate_runner.py +30 -19
  62. vectordb_bench/backend/runner/read_write_runner.py +51 -23
  63. vectordb_bench/backend/runner/serial_runner.py +91 -48
  64. vectordb_bench/backend/runner/util.py +4 -3
  65. vectordb_bench/backend/task_runner.py +92 -72
  66. vectordb_bench/backend/utils.py +17 -10
  67. vectordb_bench/base.py +0 -1
  68. vectordb_bench/cli/cli.py +65 -60
  69. vectordb_bench/cli/vectordbbench.py +6 -7
  70. vectordb_bench/frontend/components/check_results/charts.py +8 -19
  71. vectordb_bench/frontend/components/check_results/data.py +4 -16
  72. vectordb_bench/frontend/components/check_results/filters.py +8 -16
  73. vectordb_bench/frontend/components/check_results/nav.py +4 -4
  74. vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
  75. vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
  76. vectordb_bench/frontend/components/concurrent/charts.py +12 -12
  77. vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
  78. vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
  79. vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
  80. vectordb_bench/frontend/components/custom/initStyle.py +1 -1
  81. vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
  82. vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
  83. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
  84. vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
  85. vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
  86. vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
  87. vectordb_bench/frontend/components/tables/data.py +3 -6
  88. vectordb_bench/frontend/config/dbCaseConfigs.py +51 -84
  89. vectordb_bench/frontend/pages/concurrent.py +3 -5
  90. vectordb_bench/frontend/pages/custom.py +30 -9
  91. vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
  92. vectordb_bench/frontend/pages/run_test.py +3 -7
  93. vectordb_bench/frontend/utils.py +1 -1
  94. vectordb_bench/frontend/vdb_benchmark.py +4 -6
  95. vectordb_bench/interface.py +56 -26
  96. vectordb_bench/log_util.py +59 -64
  97. vectordb_bench/metric.py +10 -11
  98. vectordb_bench/models.py +26 -43
  99. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/METADATA +22 -15
  100. vectordb_bench-0.0.20.dist-info/RECORD +135 -0
  101. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/WHEEL +1 -1
  102. vectordb_bench-0.0.19.dist-info/RECORD +0 -135
  103. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/LICENSE +0 -0
  104. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/entry_points.txt +0 -0
  105. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/runner/rate_runner.py

@@ -1,36 +1,36 @@
+ import concurrent
  import logging
+ import multiprocessing as mp
  import time
- import concurrent
  from concurrent.futures import ThreadPoolExecutor
- import multiprocessing as mp
-
 
+ from vectordb_bench import config
  from vectordb_bench.backend.clients import api
  from vectordb_bench.backend.dataset import DataSetIterator
  from vectordb_bench.backend.utils import time_it
- from vectordb_bench import config
 
  from .util import get_data
+
  log = logging.getLogger(__name__)
 
 
  class RatedMultiThreadingInsertRunner:
  def __init__(
  self,
- rate: int, # numRows per second
+ rate: int,  # numRows per second
  db: api.VectorDB,
  dataset_iter: DataSetIterator,
  normalize: bool = False,
  timeout: float | None = None,
  ):
- self.timeout = timeout if isinstance(timeout, (int, float)) else None
+ self.timeout = timeout if isinstance(timeout, int | float) else None
  self.dataset = dataset_iter
  self.db = db
  self.normalize = normalize
  self.insert_rate = rate
  self.batch_rate = rate // config.NUM_PER_BATCH
 
- def send_insert_task(self, db, emb: list[list[float]], metadata: list[str]):
+ def send_insert_task(self, db: api.VectorDB, emb: list[list[float]], metadata: list[str]):
  db.insert_embeddings(emb, metadata)
 
  @time_it
@@ -43,7 +43,9 @@ class RatedMultiThreadingInsertRunner:
  rate = self.batch_rate
  for data in self.dataset:
  emb, metadata = get_data(data, self.normalize)
- executing_futures.append(executor.submit(self.send_insert_task, self.db, emb, metadata))
+ executing_futures.append(
+ executor.submit(self.send_insert_task, self.db, emb, metadata),
+ )
  rate -= 1
 
  if rate == 0:
@@ -66,19 +68,26 @@ class RatedMultiThreadingInsertRunner:
  done, not_done = concurrent.futures.wait(
  executing_futures,
  timeout=wait_interval,
- return_when=concurrent.futures.FIRST_EXCEPTION)
+ return_when=concurrent.futures.FIRST_EXCEPTION,
+ )
 
  if len(not_done) > 0:
- log.warning(f"Failed to finish all tasks in 1s, [{len(not_done)}/{len(executing_futures)}] tasks are not done, waited={wait_interval:.2f}, trying to wait in the next round")
+ log.warning(
+ f"Failed to finish all tasks in 1s, [{len(not_done)}/{len(executing_futures)}] ",
+ f"tasks are not done, waited={wait_interval:.2f}, trying to wait in the next round",
+ )
  executing_futures = list(not_done)
  else:
- log.debug(f"Finished {len(executing_futures)} insert-{config.NUM_PER_BATCH} task in 1s, wait_interval={wait_interval:.2f}")
+ log.debug(
+ f"Finished {len(executing_futures)} insert-{config.NUM_PER_BATCH} ",
+ f"task in 1s, wait_interval={wait_interval:.2f}",
+ )
  executing_futures = []
  except Exception as e:
- log.warn(f"task error, terminating, err={e}")
- q.put(None, block=True)
- executor.shutdown(wait=True, cancel_futures=True)
- raise e
+ log.warning(f"task error, terminating, err={e}")
+ q.put(None, block=True)
+ executor.shutdown(wait=True, cancel_futures=True)
+ raise e from e
 
  dur = time.perf_counter() - start_time
  if dur < 1:
@@ -87,10 +96,12 @@ class RatedMultiThreadingInsertRunner:
  # wait for all tasks in executing_futures to complete
  if len(executing_futures) > 0:
  try:
- done, _ = concurrent.futures.wait(executing_futures,
- return_when=concurrent.futures.FIRST_EXCEPTION)
+ done, _ = concurrent.futures.wait(
+ executing_futures,
+ return_when=concurrent.futures.FIRST_EXCEPTION,
+ )
  except Exception as e:
- log.warn(f"task error, terminating, err={e}")
+ log.warning(f"task error, terminating, err={e}")
  q.put(None, block=True)
  executor.shutdown(wait=True, cancel_futures=True)
- raise e
+ raise e from e

vectordb_bench/backend/runner/read_write_runner.py

@@ -1,16 +1,18 @@
+ import concurrent
  import logging
- from typing import Iterable
+ import math
  import multiprocessing as mp
- import concurrent
+ from collections.abc import Iterable
+
  import numpy as np
- import math
 
- from .mp_runner import MultiProcessingSearchRunner
- from .serial_runner import SerialSearchRunner
- from .rate_runner import RatedMultiThreadingInsertRunner
  from vectordb_bench.backend.clients import api
  from vectordb_bench.backend.dataset import DatasetManager
 
+ from .mp_runner import MultiProcessingSearchRunner
+ from .rate_runner import RatedMultiThreadingInsertRunner
+ from .serial_runner import SerialSearchRunner
+
  log = logging.getLogger(__name__)
 
 
@@ -24,8 +26,14 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunn
  k: int = 100,
  filters: dict | None = None,
  concurrencies: Iterable[int] = (1, 15, 50),
- search_stage: Iterable[float] = (0.5, 0.6, 0.7, 0.8, 0.9), # search from insert portion, 0.0 means search from the start
- read_dur_after_write: int = 300, # seconds, search duration when insertion is done
+ search_stage: Iterable[float] = (
+ 0.5,
+ 0.6,
+ 0.7,
+ 0.8,
+ 0.9,
+ ), # search from insert portion, 0.0 means search from the start
+ read_dur_after_write: int = 300, # seconds, search duration when insertion is done
  timeout: float | None = None,
  ):
  self.insert_rate = insert_rate
@@ -36,7 +44,10 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunn
  self.search_stage = sorted(search_stage)
  self.read_dur_after_write = read_dur_after_write
 
- log.info(f"Init runner, concurencys={concurrencies}, search_stage={search_stage}, stage_search_dur={read_dur_after_write}")
+ log.info(
+ f"Init runner, concurencys={concurrencies}, search_stage={search_stage}, ",
+ f"stage_search_dur={read_dur_after_write}",
+ )
 
  test_emb = np.stack(dataset.test_data["emb"])
  if normalize:
@@ -76,8 +87,13 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunn
  log.info("Search after write - Serial search start")
  res, ssearch_dur = self.serial_search_runner.run()
  recall, ndcg, p99_latency = res
- log.info(f"Search after write - Serial search - recall={recall}, ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur:.4f}")
- log.info(f"Search after wirte - Conc search start, dur for each conc={self.read_dur_after_write}")
+ log.info(
+ f"Search after write - Serial search - recall={recall}, ndcg={ndcg}, p99={p99_latency}, ",
+ f"dur={ssearch_dur:.4f}",
+ )
+ log.info(
+ f"Search after wirte - Conc search start, dur for each conc={self.read_dur_after_write}",
+ )
  max_qps = self.run_by_dur(self.read_dur_after_write)
  log.info(f"Search after wirte - Conc search finished, max_qps={max_qps}")
 
@@ -86,7 +102,10 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunn
  def run_read_write(self):
  with mp.Manager() as m:
  q = m.Queue()
- with concurrent.futures.ProcessPoolExecutor(mp_context=mp.get_context("spawn"), max_workers=2) as executor:
+ with concurrent.futures.ProcessPoolExecutor(
+ mp_context=mp.get_context("spawn"),
+ max_workers=2,
+ ) as executor:
  read_write_futures = []
  read_write_futures.append(executor.submit(self.run_with_rate, q))
  read_write_futures.append(executor.submit(self.run_search_by_sig, q))
@@ -107,10 +126,10 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunn
  except Exception as e:
  log.warning(f"Read and write error: {e}")
  executor.shutdown(wait=True, cancel_futures=True)
- raise e
+ raise e from e
  log.info("Concurrent read write all done")
 
- def run_search_by_sig(self, q):
+ def run_search_by_sig(self, q: mp.Queue):
  """
  Args:
  q: multiprocessing queue
@@ -122,15 +141,14 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunn
  total_batch = math.ceil(self.data_volume / self.insert_rate)
  recall, ndcg, p99_latency = None, None, None
 
- def wait_next_target(start, target_batch) -> bool:
+ def wait_next_target(start: int, target_batch: int) -> bool:
  """Return False when receive True or None"""
  while start < target_batch:
  sig = q.get(block=True)
 
  if sig is None or sig is True:
  return False
- else:
- start += 1
+ start += 1
  return True
 
  for idx, stage in enumerate(self.search_stage):
@@ -139,19 +157,24 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunn
 
  got = wait_next_target(start_batch, target_batch)
  if got is False:
- log.warning(f"Abnormal exit, target_batch={target_batch}, start_batch={start_batch}")
- return
+ log.warning(
+ f"Abnormal exit, target_batch={target_batch}, start_batch={start_batch}",
+ )
+ return None
 
  log.info(f"Insert {perc}% done, total batch={total_batch}")
  log.info(f"[{target_batch}/{total_batch}] Serial search - {perc}% start")
  res, ssearch_dur = self.serial_search_runner.run()
  recall, ndcg, p99_latency = res
- log.info(f"[{target_batch}/{total_batch}] Serial search - {perc}% done, recall={recall}, ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur:.4f}")
+ log.info(
+ f"[{target_batch}/{total_batch}] Serial search - {perc}% done, recall={recall}, ",
+ f"ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur:.4f}",
+ )
 
  # Search duration for non-last search stage is carefully calculated.
  # If duration for each concurrency is less than 30s, runner will raise error.
  if idx < len(self.search_stage) - 1:
- total_dur_between_stages = self.data_volume * (self.search_stage[idx + 1] - stage) // self.insert_rate
+ total_dur_between_stages = self.data_volume * (self.search_stage[idx + 1] - stage) // self.insert_rate
  csearch_dur = total_dur_between_stages - ssearch_dur
 
  # Try to leave room for init process executors
@@ -159,14 +182,19 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunn
 
  each_conc_search_dur = csearch_dur / len(self.concurrencies)
  if each_conc_search_dur < 30:
- warning_msg = f"Results might be inaccurate, duration[{csearch_dur:.4f}] left for conc-search is too short, total available dur={total_dur_between_stages}, serial_search_cost={ssearch_dur}."
+ warning_msg = (
+ f"Results might be inaccurate, duration[{csearch_dur:.4f}] left for conc-search is too short, ",
+ f"total available dur={total_dur_between_stages}, serial_search_cost={ssearch_dur}.",
+ )
  log.warning(warning_msg)
 
  # The last stage
  else:
  each_conc_search_dur = 60
 
- log.info(f"[{target_batch}/{total_batch}] Concurrent search - {perc}% start, dur={each_conc_search_dur:.4f}")
+ log.info(
+ f"[{target_batch}/{total_batch}] Concurrent search - {perc}% start, dur={each_conc_search_dur:.4f}",
+ )
  max_qps = self.run_by_dur(each_conc_search_dur)
  result.append((perc, max_qps, recall, ndcg, p99_latency))
 

vectordb_bench/backend/runner/serial_runner.py

@@ -1,20 +1,21 @@
- import time
- import logging
- import traceback
  import concurrent
- import multiprocessing as mp
+ import logging
  import math
- import psutil
+ import multiprocessing as mp
+ import time
+ import traceback
 
  import numpy as np
  import pandas as pd
+ import psutil
 
- from ..clients import api
+ from vectordb_bench.backend.dataset import DatasetManager
+
+ from ... import config
  from ...metric import calc_ndcg, calc_recall, get_ideal_dcg
  from ...models import LoadTimeoutError, PerformanceTimeoutError
  from .. import utils
- from ... import config
- from vectordb_bench.backend.dataset import DatasetManager
+ from ..clients import api
  NUM_PER_BATCH = config.NUM_PER_BATCH
  LOAD_MAX_TRY_COUNT = 10
@@ -22,9 +23,16 @@ WAITTING_TIME = 60
 
  log = logging.getLogger(__name__)
 
+
  class SerialInsertRunner:
- def __init__(self, db: api.VectorDB, dataset: DatasetManager, normalize: bool, timeout: float | None = None):
- self.timeout = timeout if isinstance(timeout, (int, float)) else None
+ def __init__(
+ self,
+ db: api.VectorDB,
+ dataset: DatasetManager,
+ normalize: bool,
+ timeout: float | None = None,
+ ):
+ self.timeout = timeout if isinstance(timeout, int | float) else None
  self.dataset = dataset
  self.db = db
  self.normalize = normalize
@@ -32,18 +40,20 @@ class SerialInsertRunner:
  def task(self) -> int:
  count = 0
  with self.db.init():
- log.info(f"({mp.current_process().name:16}) Start inserting embeddings in batch {config.NUM_PER_BATCH}")
+ log.info(
+ f"({mp.current_process().name:16}) Start inserting embeddings in batch {config.NUM_PER_BATCH}",
+ )
  start = time.perf_counter()
  for data_df in self.dataset:
- all_metadata = data_df['id'].tolist()
+ all_metadata = data_df["id"].tolist()
 
- emb_np = np.stack(data_df['emb'])
+ emb_np = np.stack(data_df["emb"])
  if self.normalize:
  log.debug("normalize the 100k train data")
  all_embeddings = (emb_np / np.linalg.norm(emb_np, axis=1)[:, np.newaxis]).tolist()
  else:
  all_embeddings = emb_np.tolist()
- del(emb_np)
+ del emb_np
  log.debug(f"batch dataset size: {len(all_embeddings)}, {len(all_metadata)}")
 
  insert_count, error = self.db.insert_embeddings(
@@ -56,30 +66,41 @@ class SerialInsertRunner:
  assert insert_count == len(all_metadata)
  count += insert_count
  if count % 100_000 == 0:
- log.info(f"({mp.current_process().name:16}) Loaded {count} embeddings into VectorDB")
+ log.info(
+ f"({mp.current_process().name:16}) Loaded {count} embeddings into VectorDB",
+ )
 
- log.info(f"({mp.current_process().name:16}) Finish loading all dataset into VectorDB, dur={time.perf_counter()-start}")
+ log.info(
+ f"({mp.current_process().name:16}) Finish loading all dataset into VectorDB, ",
+ f"dur={time.perf_counter()-start}",
+ )
  return count
 
- def endless_insert_data(self, all_embeddings, all_metadata, left_id: int = 0) -> int:
+ def endless_insert_data(self, all_embeddings: list, all_metadata: list, left_id: int = 0) -> int:
  with self.db.init():
  # unique id for endlessness insertion
- all_metadata = [i+left_id for i in all_metadata]
+ all_metadata = [i + left_id for i in all_metadata]
 
- NUM_BATCHES = math.ceil(len(all_embeddings)/NUM_PER_BATCH)
- log.info(f"({mp.current_process().name:16}) Start inserting {len(all_embeddings)} embeddings in batch {NUM_PER_BATCH}")
+ num_batches = math.ceil(len(all_embeddings) / NUM_PER_BATCH)
+ log.info(
+ f"({mp.current_process().name:16}) Start inserting {len(all_embeddings)} ",
+ f"embeddings in batch {NUM_PER_BATCH}",
+ )
  count = 0
- for batch_id in range(NUM_BATCHES):
+ for batch_id in range(num_batches):
  retry_count = 0
  already_insert_count = 0
- metadata = all_metadata[batch_id*NUM_PER_BATCH : (batch_id+1)*NUM_PER_BATCH]
- embeddings = all_embeddings[batch_id*NUM_PER_BATCH : (batch_id+1)*NUM_PER_BATCH]
+ metadata = all_metadata[batch_id * NUM_PER_BATCH : (batch_id + 1) * NUM_PER_BATCH]
+ embeddings = all_embeddings[batch_id * NUM_PER_BATCH : (batch_id + 1) * NUM_PER_BATCH]
 
- log.debug(f"({mp.current_process().name:16}) batch [{batch_id:3}/{NUM_BATCHES}], Start inserting {len(metadata)} embeddings")
+ log.debug(
+ f"({mp.current_process().name:16}) batch [{batch_id:3}/{num_batches}], ",
+ f"Start inserting {len(metadata)} embeddings",
+ )
  while retry_count < LOAD_MAX_TRY_COUNT:
  insert_count, error = self.db.insert_embeddings(
- embeddings=embeddings[already_insert_count :],
- metadata=metadata[already_insert_count :],
+ embeddings=embeddings[already_insert_count:],
+ metadata=metadata[already_insert_count:],
  )
  already_insert_count += insert_count
  if error is not None:
@@ -91,17 +112,26 @@ class SerialInsertRunner:
  raise error
  else:
  break
- log.debug(f"({mp.current_process().name:16}) batch [{batch_id:3}/{NUM_BATCHES}], Finish inserting {len(metadata)} embeddings")
+ log.debug(
+ f"({mp.current_process().name:16}) batch [{batch_id:3}/{num_batches}], ",
+ f"Finish inserting {len(metadata)} embeddings",
+ )
 
  assert already_insert_count == len(metadata)
  count += already_insert_count
- log.info(f"({mp.current_process().name:16}) Finish inserting {len(all_embeddings)} embeddings in batch {NUM_PER_BATCH}")
+ log.info(
+ f"({mp.current_process().name:16}) Finish inserting {len(all_embeddings)} embeddings in ",
+ f"batch {NUM_PER_BATCH}",
+ )
  return count
 
  @utils.time_it
  def _insert_all_batches(self) -> int:
  """Performance case only"""
- with concurrent.futures.ProcessPoolExecutor(mp_context=mp.get_context('spawn'), max_workers=1) as executor:
+ with concurrent.futures.ProcessPoolExecutor(
+ mp_context=mp.get_context("spawn"),
+ max_workers=1,
+ ) as executor:
  future = executor.submit(self.task)
  try:
  count = future.result(timeout=self.timeout)
@@ -121,8 +151,11 @@ class SerialInsertRunner:
  """run forever util DB raises exception or crash"""
  # datasets for load tests are quite small, can fit into memory
  # only 1 file
- data_df = [data_df for data_df in self.dataset][0]
- all_embeddings, all_metadata = np.stack(data_df["emb"]).tolist(), data_df['id'].tolist()
+ data_df = next(iter(self.dataset))
+ all_embeddings, all_metadata = (
+ np.stack(data_df["emb"]).tolist(),
+ data_df["id"].tolist(),
+ )
 
  start_time = time.perf_counter()
  max_load_count, times = 0, 0
@@ -130,18 +163,26 @@ class SerialInsertRunner:
  with self.db.init():
  self.db.ready_to_load()
  while time.perf_counter() - start_time < self.timeout:
- count = self.endless_insert_data(all_embeddings, all_metadata, left_id=max_load_count)
+ count = self.endless_insert_data(
+ all_embeddings,
+ all_metadata,
+ left_id=max_load_count,
+ )
  max_load_count += count
  times += 1
- log.info(f"Loaded {times} entire dataset, current max load counts={utils.numerize(max_load_count)}, {max_load_count}")
+ log.info(
+ f"Loaded {times} entire dataset, current max load counts={utils.numerize(max_load_count)}, ",
+ f"{max_load_count}",
+ )
  except Exception as e:
- log.info(f"Capacity case load reach limit, insertion counts={utils.numerize(max_load_count)}, {max_load_count}, err={e}")
+ log.info(
+ f"Capacity case load reach limit, insertion counts={utils.numerize(max_load_count)}, ",
+ f"{max_load_count}, err={e}",
+ )
  traceback.print_exc()
  return max_load_count
  else:
- msg = f"capacity case load timeout in {self.timeout}s"
- log.info(msg)
- raise LoadTimeoutError(msg)
+ raise LoadTimeoutError(self.timeout)
 
  def run(self) -> int:
  count, dur = self._insert_all_batches()
@@ -168,7 +209,9 @@ class SerialSearchRunner:
  self.ground_truth = ground_truth
 
  def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]:
- log.info(f"{mp.current_process().name:14} start search the entire test_data to get recall and latency")
+ log.info(
+ f"{mp.current_process().name:14} start search the entire test_data to get recall and latency",
+ )
  with self.db.init():
  test_data, ground_truth = args
  ideal_dcg = get_ideal_dcg(self.k)
@@ -193,13 +236,15 @@ class SerialSearchRunner:
 
  latencies.append(time.perf_counter() - s)
 
- gt = ground_truth['neighbors_id'][idx]
- recalls.append(calc_recall(self.k, gt[:self.k], results))
- ndcgs.append(calc_ndcg(gt[:self.k], results, ideal_dcg))
-
+ gt = ground_truth["neighbors_id"][idx]
+ recalls.append(calc_recall(self.k, gt[: self.k], results))
+ ndcgs.append(calc_ndcg(gt[: self.k], results, ideal_dcg))
 
  if len(latencies) % 100 == 0:
- log.debug(f"({mp.current_process().name:14}) search_count={len(latencies):3}, latest_latency={latencies[-1]}, latest recall={recalls[-1]}")
+ log.debug(
+ f"({mp.current_process().name:14}) search_count={len(latencies):3}, ",
+ f"latest_latency={latencies[-1]}, latest recall={recalls[-1]}",
+ )
 
  avg_latency = round(np.mean(latencies), 4)
  avg_recall = round(np.mean(recalls), 4)
@@ -213,16 +258,14 @@ class SerialSearchRunner:
  f"avg_recall={avg_recall}, "
  f"avg_ndcg={avg_ndcg},"
  f"avg_latency={avg_latency}, "
- f"p99={p99}"
- )
+ f"p99={p99}",
+ )
  return (avg_recall, avg_ndcg, p99)
 
-
  def _run_in_subprocess(self) -> tuple[float, float]:
  with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
  future = executor.submit(self.search, (self.test_data, self.ground_truth))
- result = future.result()
- return result
+ return future.result()
 
  @utils.time_it
  def run(self) -> tuple[float, float, float]:

vectordb_bench/backend/runner/util.py

@@ -1,13 +1,14 @@
  import logging
 
- from pandas import DataFrame
  import numpy as np
+ from pandas import DataFrame
 
  log = logging.getLogger(__name__)
 
+
  def get_data(data_df: DataFrame, normalize: bool) -> tuple[list[list[float]], list[str]]:
- all_metadata = data_df['id'].tolist()
- emb_np = np.stack(data_df['emb'])
+ all_metadata = data_df["id"].tolist()
+ emb_np = np.stack(data_df["emb"])
  if normalize:
  log.debug("normalize the 100k train data")
  all_embeddings = (emb_np / np.linalg.norm(emb_np, axis=1)[:, np.newaxis]).tolist()