vectordb-bench 0.0.1__1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. vectordb_bench/__init__.py +30 -0
  2. vectordb_bench/__main__.py +39 -0
  3. vectordb_bench/backend/__init__.py +0 -0
  4. vectordb_bench/backend/assembler.py +57 -0
  5. vectordb_bench/backend/cases.py +124 -0
  6. vectordb_bench/backend/clients/__init__.py +57 -0
  7. vectordb_bench/backend/clients/api.py +179 -0
  8. vectordb_bench/backend/clients/elastic_cloud/config.py +56 -0
  9. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +152 -0
  10. vectordb_bench/backend/clients/milvus/config.py +123 -0
  11. vectordb_bench/backend/clients/milvus/milvus.py +182 -0
  12. vectordb_bench/backend/clients/pinecone/config.py +15 -0
  13. vectordb_bench/backend/clients/pinecone/pinecone.py +113 -0
  14. vectordb_bench/backend/clients/qdrant_cloud/config.py +16 -0
  15. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +169 -0
  16. vectordb_bench/backend/clients/weaviate_cloud/config.py +45 -0
  17. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +151 -0
  18. vectordb_bench/backend/clients/zilliz_cloud/config.py +34 -0
  19. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +35 -0
  20. vectordb_bench/backend/dataset.py +393 -0
  21. vectordb_bench/backend/result_collector.py +15 -0
  22. vectordb_bench/backend/runner/__init__.py +12 -0
  23. vectordb_bench/backend/runner/mp_runner.py +124 -0
  24. vectordb_bench/backend/runner/serial_runner.py +164 -0
  25. vectordb_bench/backend/task_runner.py +290 -0
  26. vectordb_bench/backend/utils.py +85 -0
  27. vectordb_bench/base.py +6 -0
  28. vectordb_bench/frontend/components/check_results/charts.py +175 -0
  29. vectordb_bench/frontend/components/check_results/data.py +86 -0
  30. vectordb_bench/frontend/components/check_results/filters.py +97 -0
  31. vectordb_bench/frontend/components/check_results/headerIcon.py +18 -0
  32. vectordb_bench/frontend/components/check_results/nav.py +21 -0
  33. vectordb_bench/frontend/components/check_results/priceTable.py +48 -0
  34. vectordb_bench/frontend/components/run_test/autoRefresh.py +10 -0
  35. vectordb_bench/frontend/components/run_test/caseSelector.py +87 -0
  36. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +47 -0
  37. vectordb_bench/frontend/components/run_test/dbSelector.py +36 -0
  38. vectordb_bench/frontend/components/run_test/generateTasks.py +21 -0
  39. vectordb_bench/frontend/components/run_test/hideSidebar.py +10 -0
  40. vectordb_bench/frontend/components/run_test/submitTask.py +69 -0
  41. vectordb_bench/frontend/const.py +391 -0
  42. vectordb_bench/frontend/pages/qps_with_price.py +60 -0
  43. vectordb_bench/frontend/pages/run_test.py +59 -0
  44. vectordb_bench/frontend/utils.py +6 -0
  45. vectordb_bench/frontend/vdb_benchmark.py +42 -0
  46. vectordb_bench/interface.py +239 -0
  47. vectordb_bench/log_util.py +103 -0
  48. vectordb_bench/metric.py +53 -0
  49. vectordb_bench/models.py +234 -0
  50. vectordb_bench/results/result_20230609_standard.json +3228 -0
  51. vectordb_bench-0.0.1.dist-info/LICENSE +21 -0
  52. vectordb_bench-0.0.1.dist-info/METADATA +226 -0
  53. vectordb_bench-0.0.1.dist-info/RECORD +56 -0
  54. vectordb_bench-0.0.1.dist-info/WHEEL +5 -0
  55. vectordb_bench-0.0.1.dist-info/entry_points.txt +2 -0
  56. vectordb_bench-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,124 @@
1
+ import time
2
+ import traceback
3
+ import concurrent
4
+ import multiprocessing as mp
5
+ import logging
6
+ from typing import Iterable
7
+ import numpy as np
8
+ from ..clients import api
9
+ from .. import utils
10
+ from ... import config
11
+
12
+
13
NUM_PER_BATCH = config.NUM_PER_BATCH
log = logging.getLogger(__name__)


class MultiProcessingSearchRunner:
    """Multiprocessing search runner that measures the largest QPS of a VectorDB.

    For each concurrency level ``conc``, ``conc`` worker processes each search
    the test data in a loop for ``duration`` seconds. Their search counts are
    summed into a QPS figure, and the largest QPS across all levels is returned
    by ``run``.

    Args:
        db(api.VectorDB): database client; every worker enters ``db.init()``.
        test_data(np.ndarray): query vectors, copied once into shared memory.
        k(int): search topk, default to 100
        filters(dict | None): filters forwarded to ``search_embedding``.
        concurrencies(Iterable[int]): concurrencies, default (1, 5, 10, 15, 20, 25, 30, 35)
        duration(int): duration in seconds for each concurrency, default to 30s
    """

    def __init__(
        self,
        db: api.VectorDB,
        test_data: np.ndarray,
        k: int = 100,
        filters: dict | None = None,
        concurrencies: Iterable[int] = (1, 5, 10, 15, 20, 25, 30, 35),
        duration: int = 30,
    ):
        self.db = db
        self.k = k
        self.filters = filters
        self.concurrencies = concurrencies
        self.duration = duration

        # Workers receive this wrapper instead of the raw array; they read the
        # queries back through SharedNumpyArray.read().
        self.test_data = utils.SharedNumpyArray(test_data)
        log.debug(f"test dataset size: {len(test_data)}")

    def search(self, test_np: utils.SharedNumpyArray) -> tuple[int, float]:
        """Search the shared test data repeatedly for ``self.duration`` seconds.

        Executed inside a worker process.

        Args:
            test_np(utils.SharedNumpyArray): the shared query vectors.

        Returns:
            tuple[int, float]: (number of searches performed, actual elapsed seconds)

        Raises:
            Exception: any error raised by ``search_embedding`` is logged and re-raised.
        """
        with self.db.init():
            test_data = test_np.read().tolist()
            num, idx = len(test_data), 0

            start_time = time.perf_counter()
            count = 0
            while time.perf_counter() < start_time + self.duration:
                s = time.perf_counter()
                try:
                    self.db.search_embedding(
                        test_data[idx],
                        self.k,
                        self.filters,
                    )
                except Exception as e:
                    log.warning(f"VectorDB search_embedding error: {e}")
                    traceback.print_exc(chain=True)
                    raise e from None

                count += 1
                # loop through the test data, wrapping around at the end
                idx = idx + 1 if idx < num - 1 else 0

                if count % 500 == 0:
                    log.debug(f"({mp.current_process().name:16}) search_count: {count}, latest_latency={time.perf_counter()-s}")

            total_dur = round(time.perf_counter() - start_time, 4)
            # total_dur can round to 0.0 when duration is 0; avoid ZeroDivisionError
            qps = round(count / total_dur, 4) if total_dur > 0 else 0.0
            log.info(
                f"{mp.current_process().name:16} search {self.duration}s: "
                f"actual_dur={total_dur}s, count={count}, qps in this process: {qps:3}"
            )

        return (count, total_dur)

    @staticmethod
    def get_mp_context():
        """Return a multiprocessing context: "forkserver" when available, else "spawn"."""
        mp_start_method = "forkserver" if "forkserver" in mp.get_all_start_methods() else "spawn"
        log.debug(f"MultiProcessingSearchRunner get multiprocessing start method: {mp_start_method}")
        return mp.get_context(mp_start_method)

    def _run_all_concurrencies_mem_efficient(self) -> float:
        """Run ``search`` at every concurrency level and return the largest QPS.

        Failures after at least one successful level are logged, and the best QPS
        so far is returned. If no level produced a QPS, the error is re-raised.
        The shared test data is always released via ``stop``.
        """
        # `import concurrent` alone does not load the `futures` submodule, so
        # `concurrent.futures` would only resolve if something else imported it.
        import concurrent.futures

        max_qps = 0
        try:
            for conc in self.concurrencies:
                with concurrent.futures.ProcessPoolExecutor(mp_context=self.get_mp_context(), max_workers=conc) as executor:
                    start = time.perf_counter()
                    log.info(f"start search {self.duration}s in concurrency {conc}, filters: {self.filters}")
                    future_iter = executor.map(self.search, [self.test_data] * conc)
                    all_count = sum(r[0] for r in future_iter)

                    cost = time.perf_counter() - start
                    qps = round(all_count / cost, 4)
                    log.info(f"end search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}")

                if qps > max_qps:
                    max_qps = qps
                    log.info(f"update largest qps with concurrency {conc}: current max_qps={max_qps}")
        except Exception as e:
            log.warning(f"fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}")
            traceback.print_exc()

            # No results available, raise exception
            if max_qps == 0.0:
                raise e from None

        finally:
            self.stop()

        return max_qps

    def run(self) -> float:
        """
        Returns:
            float: largest qps
        """
        return self._run_all_concurrencies_mem_efficient()

    def stop(self) -> None:
        """Release the shared-memory copy of the test data (safe to call repeatedly)."""
        if self.test_data:
            self.test_data.unlink()
            self.test_data = None
@@ -0,0 +1,164 @@
1
+ import time
2
+ import logging
3
+ import traceback
4
+ import concurrent
5
+ import multiprocessing as mp
6
+ import math
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+ from ..clients import api
11
+ from ...metric import calc_recall
12
+ from ...models import LoadTimeoutError
13
+ from .. import utils
14
+ from ... import config
15
+
16
NUM_PER_BATCH = config.NUM_PER_BATCH
LOAD_TIMEOUT = 24 * 60 * 60

log = logging.getLogger(__name__)


class SerialInsertRunner:
    """Insert a dataset into a VectorDB, ``NUM_PER_BATCH`` rows per call.

    Args:
        db(api.VectorDB): target database client.
        train_emb(list[list[float]]): embeddings to insert.
        train_id(list[int]): ids paired position-by-position with ``train_emb``.
    """

    def __init__(self, db: api.VectorDB, train_emb: list[list[float]], train_id: list[int]):
        log.debug(f"Dataset shape: {len(train_emb)}")
        self.db = db
        self.shared_emb = train_emb
        self.train_id = train_id

        # number of NUM_PER_BATCH-sized batches needed to cover train_emb
        self.seq_batches = math.ceil(len(train_emb)/NUM_PER_BATCH)

    def insert_data(self, left_id: int = 0) -> int:
        """Insert the whole dataset once, batch by batch.

        Args:
            left_id(int): offset added to every id, so repeated loads
                (see ``run_endlessness``) insert unique ids.

        Returns:
            int: number of rows inserted.

        Raises:
            RuntimeError: if the DB reports a different insert count for a batch.
        """
        with self.db.init():
            all_embeddings = self.shared_emb

            # unique id for endlessness insertion
            all_metadata = [i + left_id for i in self.train_id]

            log.info(f"({mp.current_process().name:16}) Start inserting {len(all_embeddings)} embeddings in batch {NUM_PER_BATCH}")
            count = 0
            for batch_id in range(self.seq_batches):
                lo, hi = batch_id * NUM_PER_BATCH, (batch_id + 1) * NUM_PER_BATCH
                metadata = all_metadata[lo:hi]
                embeddings = all_embeddings[lo:hi]

                log.debug(f"({mp.current_process().name:16}) batch [{batch_id:3}/{self.seq_batches}], Start inserting {len(metadata)} embeddings")
                insert_count = self.db.insert_embeddings(
                    embeddings=embeddings,
                    metadata=metadata,
                )
                log.debug(f"({mp.current_process().name:16}) batch [{batch_id:3}/{self.seq_batches}], Finish inserting {len(metadata)} embeddings")

                # Explicit check instead of `assert`, which `python -O` strips.
                if insert_count != len(metadata):
                    raise RuntimeError(
                        f"batch {batch_id}: expected {len(metadata)} inserted embeddings, db reported {insert_count}"
                    )
                count += insert_count
            log.info(f"({mp.current_process().name:16}) Finish inserting {len(all_embeddings)} embeddings in batch {NUM_PER_BATCH}")
        return count

    @utils.time_it
    def _insert_all_batches(self) -> int:
        """Performance case only: run ``insert_data`` once in a spawned subprocess.

        ``utils.time_it`` makes the call return ``(count, seconds)``.
        """
        # `import concurrent` alone does not load the `futures` submodule.
        import concurrent.futures

        with concurrent.futures.ProcessPoolExecutor(mp_context=mp.get_context('spawn'), max_workers=1) as executor:
            future = executor.submit(self.insert_data)
            count = future.result()
        return count

    def run_endlessness(self) -> int:
        """Insert the dataset repeatedly until the DB raises an exception or crashes.

        Returns:
            int: total rows inserted before the DB failed.

        Raises:
            LoadTimeoutError: if config.CASE_TIMEOUT_IN_SECOND elapses first.
        """
        start_time = time.perf_counter()
        max_load_count, times = 0, 0
        try:
            with self.db.init():
                self.db.ready_to_load()
                while time.perf_counter() - start_time < config.CASE_TIMEOUT_IN_SECOND:
                    count = self.insert_data(left_id=max_load_count)
                    max_load_count += count
                    times += 1
                    log.info(f"Loaded {times} entire dataset, current max load counts={utils.numerize(max_load_count)}, {max_load_count}")
                raise LoadTimeoutError("capacity case load timeout and stop")
        except LoadTimeoutError as e:
            log.info("load timetout, stop the load case")
            raise e from None
        except Exception as e:
            # any other DB error is treated as "capacity limit reached"
            log.info(f"Capacity case load reach limit, insertion counts={utils.numerize(max_load_count)}, {max_load_count}, err={e}")
            traceback.print_exc()
            return max_load_count

    def run(self) -> int:
        """Insert the dataset once in a subprocess and return the inserted count."""
        count, _ = self._insert_all_batches()
        return count
89
+
90
+
91
class SerialSearchRunner:
    """Search every test query once, serially, to measure recall and p99 latency.

    Args:
        db(api.VectorDB): database client; the search enters ``db.init()``.
        test_data(list[list[float]]): query vectors; numpy rows are converted to lists.
        ground_truth(pd.DataFrame): must have a ``neighbors_id`` column whose row
            ``idx`` holds the expected neighbors of query ``idx``.
        k(int): search topk, default to 100.
        filters(dict | None): filters forwarded to ``search_embedding``.
    """

    def __init__(
        self,
        db: api.VectorDB,
        test_data: list[list[float]],
        ground_truth: pd.DataFrame,
        k: int = 100,
        filters: dict | None = None,
    ):
        self.db = db
        self.k = k
        self.filters = filters

        # store queries as plain Python lists
        if isinstance(test_data[0], np.ndarray):
            self.test_data = [query.tolist() for query in test_data]
        else:
            self.test_data = test_data
        self.ground_truth = ground_truth

    def search(self, args: tuple[list, pd.DataFrame]):
        """Search each query once and compute average recall and p99 latency.

        Executed in a worker process (see ``_run_in_subprocess``).

        Args:
            args: ``(test_data, ground_truth)``.

        Returns:
            tuple[float, float]: (average recall, p99 latency in seconds), rounded to 4 digits.

        Raises:
            Exception: any error raised by ``search_embedding`` is logged and re-raised.
        """
        log.info(f"{mp.current_process().name:14} start search the entire test_data to get recall and latency")
        with self.db.init():
            test_data, ground_truth = args

            log.debug(f"test dataset size: {len(test_data)}")
            log.info(f"ground truth size: {ground_truth.columns}, shape: {ground_truth.shape}")

            latencies, recalls = [], []
            for idx, emb in enumerate(test_data):
                s = time.perf_counter()
                try:
                    results = self.db.search_embedding(
                        emb,
                        self.k,
                        self.filters,
                    )
                except Exception as e:
                    log.warning(f"VectorDB search_embedding error: {e}")
                    traceback.print_exc(chain=True)
                    raise e from None

                latencies.append(time.perf_counter() - s)

                # compare against the first k ground-truth neighbors of this query
                gt = ground_truth['neighbors_id'][idx]
                recalls.append(calc_recall(self.k, gt[:self.k], results))

                if len(latencies) % 100 == 0:
                    log.debug(f"({mp.current_process().name:14}) search_count={len(latencies):3}, latest_latency={latencies[-1]}, latest recall={recalls[-1]}")

            avg_latency = round(np.mean(latencies), 4)
            avg_recall = round(np.mean(recalls), 4)
            cost = round(np.sum(latencies), 4)
            p99 = round(np.percentile(latencies, 99), 4)
            log.info(
                f"{mp.current_process().name:14} search entire test_data: "
                f"cost={cost}s, "
                f"queries={len(latencies)}, "
                f"avg_recall={avg_recall}, "
                f"avg_latency={avg_latency}, "
                f"p99={p99}"
            )
            return (avg_recall, p99)

    def _run_in_subprocess(self) -> tuple[float, float]:
        """Run ``search`` in a single-worker process pool and return its result."""
        # `import concurrent` alone does not load the `futures` submodule.
        import concurrent.futures

        with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
            future = executor.submit(self.search, (self.test_data, self.ground_truth))
            result = future.result()
        return result

    def run(self) -> tuple[float, float]:
        """Return ``(avg_recall, p99_latency)`` for the whole test set."""
        return self._run_in_subprocess()
@@ -0,0 +1,290 @@
1
+ import logging
2
+ import traceback
3
+ import concurrent
4
+ import numpy as np
5
+ from enum import Enum, auto
6
+
7
+ from . import utils
8
+ from .cases import Case, CaseLabel
9
+ from ..base import BaseModel
10
+ from ..models import TaskConfig
11
+
12
+ from .clients import (
13
+ api,
14
+ ZillizCloud,
15
+ Milvus,
16
+ MetricType
17
+ )
18
+ from ..metric import Metric
19
+ from .runner import MultiProcessingSearchRunner
20
+ from .runner import SerialSearchRunner, SerialInsertRunner
21
+
22
+
23
log = logging.getLogger(__name__)


class RunningStatus(Enum):
    """Progress state of a CaseRunner: waiting to run, or done."""

    # explicit values equal to what auto() assigns for a plain Enum
    PENDING = 1
    FINISHED = 2
29
+
30
+
31
class CaseRunner(BaseModel):
    """Run one benchmark case (dataset, filter_rate, db_class with db config).

    Fields:
        run_id(str): run_id of this case runner,
            indicating which task this case belongs to.
        config(TaskConfig): task configs of this case runner.
        ca(Case): case for this case runner.
        status(RunningStatus): RunningStatus of this case runner.

        db(api.VectorDB): the vector database client, created by ``init_db``.
        test_emb(np.ndarray): query embeddings, L2-normalized when ``normalize`` is True.
        search_runner(MultiProcessingSearchRunner): concurrent QPS runner.
        serial_search_runner(SerialSearchRunner): serial recall / latency runner.
    """

    run_id: str
    config: TaskConfig
    ca: Case
    status: RunningStatus

    db: api.VectorDB | None = None
    test_emb: np.ndarray | None = None
    search_runner: MultiProcessingSearchRunner | None = None
    serial_search_runner: SerialSearchRunner | None = None

    def __eq__(self, obj):
        # Equal when this runner is a performance case on the same db, db case
        # config and dataset as `obj` (only this runner's label is checked).
        if isinstance(obj, CaseRunner):
            return (
                self.ca.label == CaseLabel.Performance
                and self.config.db == obj.config.db
                and self.config.db_case_config == obj.config.db_case_config
                and self.ca.dataset == obj.ca.dataset
            )
        return False

    def display(self) -> dict:
        """Return a summary dict of the case (label, filters, dataset data) plus the db name."""
        c_dict = self.ca.dict(include={'label': True, 'filters': True, 'dataset': {'data': True}})
        c_dict['db'] = self.config.db_name
        return c_dict

    @property
    def normalize(self) -> bool:
        """True when embeddings must be L2-normalized before insert/search:
        Milvus or ZillizCloud clients on a COSINE-metric dataset."""
        assert self.db
        return isinstance(self.db, (Milvus, ZillizCloud)) and \
            self.ca.dataset.data.metric_type == MetricType.COSINE

    def init_db(self, drop_old: bool = True) -> None:
        """Instantiate the configured VectorDB client for this case's dimension."""
        db_cls = self.config.db.init_cls

        self.db = db_cls(
            dim=self.ca.dataset.data.dim,
            db_config=self.config.db_config.to_dict(),
            db_case_config=self.config.db_case_config,
            drop_old=drop_old,
        )

    def _pre_run(self, drop_old: bool = True):
        """Prepare the dataset and create the DB client; errors are logged and re-raised."""
        try:
            self.ca.dataset.prepare()
            self.init_db(drop_old)
        except Exception as e:
            log.warning(f"pre run case error: {e}")
            raise e from None

    def run(self, drop_old: bool = True) -> Metric:
        """Run the case according to its label.

        Raises:
            ValueError: if the case label is neither Load nor Performance.
        """
        self._pre_run(drop_old)

        if self.ca.label == CaseLabel.Load:
            return self._run_load_case()
        elif self.ca.label == CaseLabel.Performance:
            return self._run_perf_case(drop_old)
        else:
            log.warning(f"unknown case type: {self.ca.label}")
            raise ValueError(f"Unknown case type: {self.ca.label}")

    def _run_load_case(self) -> Metric:
        """Run a load (capacity) case.

        Returns:
            Metric: the max load count
        """
        log.info("Start capacity case")
        # datasets for load tests are quite small, can fit into memory
        # only 1 file
        data_df = list(self.ca.dataset)[0]

        all_embeddings, all_metadata = np.stack(data_df["emb"]).tolist(), data_df['id'].tolist()
        runner = SerialInsertRunner(self.db, all_embeddings, all_metadata)
        try:
            count = runner.run_endlessness()
            log.info(f"load reach limit: insertion counts={count}")
            return Metric(max_load_count=count)
        except Exception as e:
            log.warning(f"run capacity case error: {e}")
            raise e from None
        finally:
            # previously placed after return/raise, where it could never run
            log.info("End capacity case")

    def _run_perf_case(self, drop_old: bool = True) -> Metric:
        """Run a performance case: optional load + build, then serial and concurrent search."""
        try:
            m = Metric()
            if drop_old:
                _, load_dur = self._load_train_data()
                build_dur = self._optimize()
                m.load_duration = round(load_dur + build_dur, 4)

            self._init_search_runner()
            m.recall, m.serial_latency_p99 = self._serial_search()
            m.qps = self._conc_search()

            log.info(f"got results: {m}")
            return m
        except Exception as e:
            log.warning(f"performance case run error: {e}")
            traceback.print_exc()
            raise e

    @utils.time_it
    def _load_train_data(self):
        """Insert every train data file; ``time_it`` makes this return ``(None, insert_duration)``."""
        for data_df in self.ca.dataset:
            all_metadata = data_df['id'].tolist()

            emb_np = np.stack(data_df['emb'])
            if self.normalize:
                log.debug("normalize the 100k train data")
                # Normalize each row, then convert the *result* to lists. The old code
                # called .tolist() on the norm column only, so an ndarray leaked through.
                norms = np.linalg.norm(emb_np, axis=1)[:, np.newaxis]
                all_embeddings = (emb_np / norms).tolist()
            else:
                all_embeddings = emb_np.tolist()

            del emb_np
            log.debug(f"normalized size: {len(all_embeddings)}, {len(all_metadata)}")

            runner = SerialInsertRunner(self.db, all_embeddings, all_metadata)
            runner.run()

    def _serial_search(self) -> tuple[float, float]:
        """Performance serial tests, search the entire test data once,
        calculate the recall, serial_latency_p99

        Returns:
            tuple[float, float]: recall, serial_latency_p99
        """
        try:
            return self.serial_search_runner.run()
        except Exception as e:
            log.warning(f"search error: {str(e)}, {e}")
            self.stop()
            raise e from None

    def _conc_search(self):
        """Performance concurrency tests, search the test data endlessly
        for 30s in several concurrencies

        Returns:
            float: the largest qps in all concurrencies
        """
        try:
            return self.search_runner.run()
        except Exception as e:
            log.warning(f"search error: {str(e)}, {e}")
            raise e from None
        finally:
            self.stop()

    @utils.time_it
    def _task(self) -> None:
        """Call ``db.ready_to_search()``; ``time_it`` makes this return ``(None, seconds)``."""
        with self.db.init():
            self.db.ready_to_search()

    def _optimize(self) -> float:
        """Run ``_task`` in a subprocess and return its duration in seconds."""
        # `import concurrent` alone does not load the `futures` submodule.
        import concurrent.futures

        with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
            future = executor.submit(self._task)
            try:
                return future.result()[1]
            except Exception as e:
                log.warning(f"VectorDB ready_to_search error: {e}")
                raise e from None

    def _init_search_runner(self):
        """Build the query embeddings and the serial and concurrent search runners."""
        test_emb = np.stack(self.ca.dataset.test_data["emb"])
        if self.normalize:
            test_emb = test_emb / np.linalg.norm(test_emb, axis=1)[:, np.newaxis]
        self.test_emb = test_emb

        gt_df = self.ca.dataset.get_ground_truth(self.ca.filter_rate)

        self.serial_search_runner = SerialSearchRunner(
            db=self.db,
            test_data=self.test_emb.tolist(),
            ground_truth=gt_df,
            filters=self.ca.filters,
        )

        self.search_runner = MultiProcessingSearchRunner(
            db=self.db,
            test_data=self.test_emb,
            filters=self.ca.filters,
        )

    def stop(self):
        """Stop the concurrent search runner, if one was created."""
        if self.search_runner:
            self.search_runner.stop()
238
+
239
+
240
class TaskRunner(BaseModel):
    """A group of CaseRunners that share one run_id and task_label."""

    run_id: str
    task_label: str
    case_runners: list[CaseRunner]

    def num_cases(self) -> int:
        """Total number of case runners in this task."""
        return len(self.case_runners)

    def num_finished(self) -> int:
        """Number of case runners whose status is FINISHED."""
        return self._get_num_by_status(RunningStatus.FINISHED)

    def set_finished(self, idx: int) -> None:
        """Mark the case runner at position ``idx`` as FINISHED."""
        self.case_runners[idx].status = RunningStatus.FINISHED

    def _get_num_by_status(self, status: RunningStatus) -> int:
        """Count the case runners currently in ``status``."""
        return len([runner for runner in self.case_runners if runner.status == status])

    def display(self) -> None:
        """Log a table of the cases via the "no_color" logger."""
        row_fmt = " %-14s | %-12s %-20s %7s | %-10s"
        rows = [
            row_fmt % ("DB", "CaseType", "Dataset", "Filter", "task_label"),
            row_fmt % ("-" * 11, "-" * 12, "-" * 20, "-" * 7, "-" * 7),
        ]

        for runner in self.case_runners:
            case = runner.ca
            # show filter_rate if set, otherwise filter_size if set, otherwise "None"
            if case.filter_rate != 0.0:
                filters = case.filter_rate
            elif case.filter_size != 0:
                filters = case.filter_size
            else:
                filters = "None"

            data = case.dataset.data
            dataset_desc = f"{data.name}-{data.label}-{utils.numerize(data.size)}"
            rows.append(row_fmt % (
                runner.config.db_name,
                case.label.name,
                dataset_desc,
                filters,
                self.task_label,
            ))

        plain_logger = logging.getLogger("no_color")
        for row in rows:
            plain_logger.info(row)
@@ -0,0 +1,85 @@
1
+ import time
2
+ from functools import wraps
3
+ from multiprocessing.shared_memory import SharedMemory
4
+
5
+ import numpy as np
6
+
7
+
8
def numerize(n) -> str:
    """Format a positive number compactly with a K/M/B suffix (truncating).

    Values of 1000B and above keep the 'B' suffix.

    Examples:
        >>> numerize(1_000)
        '1K'
        >>> numerize(1_000_000_000)
        '1B'
    """
    for suffix, upper in (("", 1e3), ("K", 1e6), ("M", 1e9)):
        if n < upper:
            return f"{int(n / (upper / 1e3))}{suffix}"
    return f"{int(n / 1e9)}B"
38
+
39
+
40
def time_it(func):
    """Decorate ``func`` so that it returns ``(result, elapsed_seconds)``.

    Elapsed time is measured with ``time.perf_counter``.
    """
    @wraps(func)
    def timed(*args, **kwargs):
        started = time.perf_counter()
        outcome = func(*args, **kwargs)
        return outcome, time.perf_counter() - started

    return timed
48
+
49
+
50
class SharedNumpyArray:
    """Wrap a numpy array in a shared-memory block so that it can be shared
    quickly among processes, avoiding unnecessary copying and (de)serializing.

    Call ``unlink`` once the data is no longer needed. Any arrays returned by
    ``read`` must be dropped first, because they reference the shared buffer.
    """

    def __init__(self, array: np.ndarray):
        """Create the shared memory and copy ``array`` into it.

        Zero-size and 0-d arrays are supported.
        """
        # SharedMemory(create=True, size=0) raises ValueError, so reserve at
        # least one byte; read() only views the first array.nbytes bytes.
        self._shared = SharedMemory(create=True, size=max(array.nbytes, 1))

        # save data type and shape, necessary to read the data correctly
        self._dtype, self._shape = array.dtype, array.shape

        # a numpy array backed by the shared buffer, then filled from `array`
        res = np.ndarray(
            self._shape, dtype=self._dtype, buffer=self._shared.buf
        )
        # `[...]` also works for 0-d arrays, which `[:]` cannot slice
        res[...] = array

    def read(self) -> np.ndarray:
        """Return an array view on the shared memory, without copying."""
        return np.ndarray(self._shape, self._dtype, buffer=self._shared.buf)

    def unlink(self) -> None:
        """Release the allocated memory.

        Call this when finished using the data, or when the data was copied
        somewhere else.
        """
        self._shared.close()
        self._shared.unlink()
vectordb_bench/base.py ADDED
@@ -0,0 +1,6 @@
1
+ from pydantic import BaseModel as PydanticBaseModel
2
+
3
+
4
class BaseModel(PydanticBaseModel, arbitrary_types_allowed=True):
    """Project-wide pydantic base model.

    ``arbitrary_types_allowed=True`` lets subclasses declare fields of types
    pydantic has no validator for (for example ``np.ndarray``).
    """
6
+