vectordb-bench 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. vectordb_bench/backend/cases.py +1 -1
  2. vectordb_bench/backend/clients/__init__.py +39 -0
  3. vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +27 -0
  4. vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +19 -0
  5. vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +304 -0
  6. vectordb_bench/backend/clients/aliyun_opensearch/config.py +48 -0
  7. vectordb_bench/backend/clients/alloydb/alloydb.py +372 -0
  8. vectordb_bench/backend/clients/alloydb/cli.py +147 -0
  9. vectordb_bench/backend/clients/alloydb/config.py +168 -0
  10. vectordb_bench/backend/clients/api.py +5 -0
  11. vectordb_bench/backend/clients/milvus/cli.py +25 -1
  12. vectordb_bench/backend/clients/milvus/config.py +16 -2
  13. vectordb_bench/backend/clients/milvus/milvus.py +4 -6
  14. vectordb_bench/backend/runner/rate_runner.py +32 -15
  15. vectordb_bench/backend/runner/read_write_runner.py +102 -36
  16. vectordb_bench/backend/runner/serial_runner.py +8 -2
  17. vectordb_bench/backend/runner/util.py +0 -16
  18. vectordb_bench/backend/task_runner.py +4 -3
  19. vectordb_bench/backend/utils.py +1 -0
  20. vectordb_bench/cli/vectordbbench.py +2 -0
  21. vectordb_bench/frontend/config/dbCaseConfigs.py +224 -0
  22. vectordb_bench/models.py +9 -0
  23. {vectordb_bench-0.0.17.dist-info → vectordb_bench-0.0.19.dist-info}/METADATA +13 -23
  24. {vectordb_bench-0.0.17.dist-info → vectordb_bench-0.0.19.dist-info}/RECORD +28 -21
  25. {vectordb_bench-0.0.17.dist-info → vectordb_bench-0.0.19.dist-info}/LICENSE +0 -0
  26. {vectordb_bench-0.0.17.dist-info → vectordb_bench-0.0.19.dist-info}/WHEEL +0 -0
  27. {vectordb_bench-0.0.17.dist-info → vectordb_bench-0.0.19.dist-info}/entry_points.txt +0 -0
  28. {vectordb_bench-0.0.17.dist-info → vectordb_bench-0.0.19.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/api.py

@@ -10,6 +10,7 @@ class MetricType(str, Enum):
     L2 = "L2"
     COSINE = "COSINE"
     IP = "IP"
+    DP = "DP"
     HAMMING = "HAMMING"
     JACCARD = "JACCARD"
 
@@ -27,6 +28,7 @@ class IndexType(str, Enum):
     GPU_IVF_FLAT = "GPU_IVF_FLAT"
     GPU_IVF_PQ = "GPU_IVF_PQ"
     GPU_CAGRA = "GPU_CAGRA"
+    SCANN = "scann"
 
 
 class DBConfig(ABC, BaseModel):
@@ -202,6 +204,9 @@ class VectorDB(ABC):
         """
         raise NotImplementedError
 
+    def optimize_with_size(self, data_size: int):
+        self.optimize()
+
     # TODO: remove
     @abstractmethod
     def ready_to_load(self):
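
Note on the api.py change: optimize_with_size(data_size) is a new non-abstract hook whose base implementation simply forwards to optimize(), and task_runner.py (below) now calls it with the dataset size. A minimal sketch of how a client could use the hook; ExampleClient is hypothetical and not part of the package:

    # Hypothetical client, shown only to illustrate the new hook.
    # Clients that ignore data_size just inherit the base implementation,
    # which forwards to optimize().
    class ExampleClient:
        def __init__(self):
            self.nlist = 1024

        def optimize(self):
            ...  # build / refresh the index as before

        def optimize_with_size(self, data_size: int):
            # derive an index parameter from the dataset size, then optimize
            self.nlist = max(1024, data_size // 1_000)
            self.optimize()

    ExampleClient().optimize_with_size(data_size=1_000_000)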
vectordb_bench/backend/clients/milvus/cli.py

@@ -1,4 +1,4 @@
-from typing import Annotated, TypedDict, Unpack
+from typing import Annotated, TypedDict, Unpack, Optional
 
 import click
 from pydantic import SecretStr
@@ -21,6 +21,12 @@ class MilvusTypedDict(TypedDict):
     uri: Annotated[
         str, click.option("--uri", type=str, help="uri connection string", required=True)
     ]
+    user_name: Annotated[
+        Optional[str], click.option("--user-name", type=str, help="Db username", required=False)
+    ]
+    password: Annotated[
+        Optional[str], click.option("--password", type=str, help="Db password", required=False)
+    ]
 
 
 class MilvusAutoIndexTypedDict(CommonTypedDict, MilvusTypedDict):
@@ -37,6 +43,8 @@ def MilvusAutoIndex(**parameters: Unpack[MilvusAutoIndexTypedDict]):
         db_config=MilvusConfig(
             db_label=parameters["db_label"],
             uri=SecretStr(parameters["uri"]),
+            user=parameters["user_name"],
+            password=SecretStr(parameters["password"]),
         ),
         db_case_config=AutoIndexConfig(),
         **parameters,
@@ -53,6 +61,8 @@ def MilvusFlat(**parameters: Unpack[MilvusAutoIndexTypedDict]):
         db_config=MilvusConfig(
             db_label=parameters["db_label"],
             uri=SecretStr(parameters["uri"]),
+            user=parameters["user_name"],
+            password=SecretStr(parameters["password"]),
         ),
         db_case_config=FLATConfig(),
         **parameters,
@@ -73,6 +83,8 @@ def MilvusHNSW(**parameters: Unpack[MilvusHNSWTypedDict]):
         db_config=MilvusConfig(
             db_label=parameters["db_label"],
             uri=SecretStr(parameters["uri"]),
+            user=parameters["user_name"],
+            password=SecretStr(parameters["password"]) if parameters["password"] else None,
         ),
         db_case_config=HNSWConfig(
             M=parameters["m"],
@@ -97,6 +109,8 @@ def MilvusIVFFlat(**parameters: Unpack[MilvusIVFFlatTypedDict]):
         db_config=MilvusConfig(
             db_label=parameters["db_label"],
             uri=SecretStr(parameters["uri"]),
+            user=parameters["user_name"],
+            password=SecretStr(parameters["password"]),
         ),
         db_case_config=IVFFlatConfig(
             nlist=parameters["nlist"],
@@ -116,6 +130,8 @@ def MilvusIVFSQ8(**parameters: Unpack[MilvusIVFFlatTypedDict]):
         db_config=MilvusConfig(
             db_label=parameters["db_label"],
             uri=SecretStr(parameters["uri"]),
+            user=parameters["user_name"],
+            password=SecretStr(parameters["password"]),
         ),
         db_case_config=IVFSQ8Config(
             nlist=parameters["nlist"],
@@ -143,6 +159,8 @@ def MilvusDISKANN(**parameters: Unpack[MilvusDISKANNTypedDict]):
         db_config=MilvusConfig(
             db_label=parameters["db_label"],
             uri=SecretStr(parameters["uri"]),
+            user=parameters["user_name"],
+            password=SecretStr(parameters["password"]),
         ),
         db_case_config=DISKANNConfig(
             search_list=parameters["search_list"],
@@ -174,6 +192,8 @@ def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]):
         db_config=MilvusConfig(
             db_label=parameters["db_label"],
             uri=SecretStr(parameters["uri"]),
+            user=parameters["user_name"],
+            password=SecretStr(parameters["password"]),
         ),
         db_case_config=GPUIVFFlatConfig(
             nlist=parameters["nlist"],
@@ -208,6 +228,8 @@ def MilvusGPUIVFPQ(**parameters: Unpack[MilvusGPUIVFPQTypedDict]):
         db_config=MilvusConfig(
             db_label=parameters["db_label"],
             uri=SecretStr(parameters["uri"]),
+            user=parameters["user_name"],
+            password=SecretStr(parameters["password"]),
         ),
         db_case_config=GPUIVFPQConfig(
             nlist=parameters["nlist"],
@@ -274,6 +296,8 @@ def MilvusGPUCAGRA(**parameters: Unpack[MilvusGPUCAGRATypedDict]):
         db_config=MilvusConfig(
             db_label=parameters["db_label"],
             uri=SecretStr(parameters["uri"]),
+            user=parameters["user_name"],
+            password=SecretStr(parameters["password"]),
         ),
         db_case_config=GPUCAGRAConfig(
             intermediate_graph_degree=parameters["intermediate_graph_degree"],
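
All Milvus subcommands now accept optional --user-name and --password and forward them to MilvusConfig; note that only MilvusHNSW guards the SecretStr wrap so an unset password stays None. A minimal sketch of that guarded construction (parameter values are illustrative, not from the package):

    from pydantic import SecretStr

    def build_db_config_kwargs(parameters: dict) -> dict:
        # mirrors the guarded pattern used by MilvusHNSW above
        return dict(
            db_label=parameters["db_label"],
            uri=SecretStr(parameters["uri"]),
            user=parameters["user_name"],
            password=SecretStr(parameters["password"]) if parameters["password"] else None,
        )

    with_auth = {"db_label": "test", "uri": "http://localhost:19530",
                 "user_name": "root", "password": "Milvus"}
    no_auth = {**with_auth, "user_name": None, "password": None}

    print(build_db_config_kwargs(with_auth)["password"])  # SecretStr, value hidden
    print(build_db_config_kwargs(no_auth)["password"])    # None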
vectordb_bench/backend/clients/milvus/config.py

@@ -1,12 +1,26 @@
-from pydantic import BaseModel, SecretStr
+from pydantic import BaseModel, SecretStr, validator
 from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
 
 
 class MilvusConfig(DBConfig):
     uri: SecretStr = "http://localhost:19530"
+    user: str | None = None
+    password: SecretStr | None = None
 
     def to_dict(self) -> dict:
-        return {"uri": self.uri.get_secret_value()}
+        return {
+            "uri": self.uri.get_secret_value(),
+            "user": self.user if self.user else None,
+            "password": self.password.get_secret_value() if self.password else None,
+        }
+
+    @validator("*")
+    def not_empty_field(cls, v, field):
+        if field.name in cls.common_short_configs() or field.name in cls.common_long_configs() or field.name in ["user", "password"]:
+            return v
+        if isinstance(v, (str, SecretStr)) and len(v) == 0:
+            raise ValueError("Empty string!")
+        return v
 
 
 class MilvusIndexConfig(BaseModel):
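
With the new fields, to_dict() always emits user/password keys (None when unset), and the "*" validator deliberately skips the empty-string check for them. A small sketch of the expected behaviour, assuming DBConfig needs no other required fields beyond db_label:

    from pydantic import SecretStr
    from vectordb_bench.backend.clients.milvus.config import MilvusConfig

    cfg = MilvusConfig(db_label="test", uri=SecretStr("http://localhost:19530"))
    print(cfg.to_dict())
    # expected: {'uri': 'http://localhost:19530', 'user': None, 'password': None}

    cfg = MilvusConfig(
        db_label="test",
        uri=SecretStr("http://localhost:19530"),
        user="root",
        password=SecretStr("Milvus"),
    )
    print(cfg.to_dict()["password"])  # expected: 'Milvus'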
vectordb_bench/backend/clients/milvus/milvus.py

@@ -8,7 +8,7 @@ from typing import Iterable
 from pymilvus import Collection, utility
 from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusException
 
-from ..api import VectorDB, IndexType
+from ..api import VectorDB
 from .config import MilvusIndexConfig
 
 
@@ -66,8 +66,7 @@ class Milvus(VectorDB):
                 self.case_config.index_param(),
                 index_name=self._index_name,
             )
-            if kwargs.get("pre_load") is True:
-                self._pre_load(col)
+            col.load()
 
             connections.disconnect("default")
 
@@ -90,8 +89,8 @@ class Milvus(VectorDB):
         connections.disconnect("default")
 
     def _optimize(self):
-        self._post_insert()
         log.info(f"{self.name} optimizing before search")
+        self._post_insert()
         try:
             self.col.load(refresh=True)
         except Exception as e:
@@ -99,7 +98,6 @@ class Milvus(VectorDB):
             raise e from None
 
     def _post_insert(self):
-        log.info(f"{self.name} post insert before optimize")
         try:
             self.col.flush()
             # wait for index done and load refresh
@@ -130,7 +128,7 @@ class Milvus(VectorDB):
                 log.warning(f"{self.name} compact error: {e}")
                 if hasattr(e, 'code'):
                     if e.code().name == 'PERMISSION_DENIED':
-                        log.warning(f"Skip compact due to permission denied.")
+                        log.warning("Skip compact due to permission denied.")
                         pass
                     else:
                         raise e
vectordb_bench/backend/runner/rate_runner.py

@@ -1,5 +1,6 @@
 import logging
 import time
+import concurrent
 from concurrent.futures import ThreadPoolExecutor
 import multiprocessing as mp
 
@@ -9,7 +10,7 @@ from vectordb_bench.backend.dataset import DataSetIterator
 from vectordb_bench.backend.utils import time_it
 from vectordb_bench import config
 
-from .util import get_data, is_futures_completed, get_future_exceptions
+from .util import get_data
 log = logging.getLogger(__name__)
 
 
@@ -54,26 +55,42 @@ class RatedMultiThreadingInsertRunner:
                 start_time = time.perf_counter()
                 finished, elapsed_time = submit_by_rate()
                 if finished is True:
-                    q.put(None, block=True)
+                    q.put(True, block=True)
                     log.info(f"End of dataset, left unfinished={len(executing_futures)}")
-                    return
+                    break
 
-                q.put(True, block=False)
+                q.put(False, block=False)
                 wait_interval = 1 - elapsed_time if elapsed_time < 1 else 0.001
 
-                e, completed = is_futures_completed(executing_futures, wait_interval)
-                if completed is True:
-                    ex = get_future_exceptions(executing_futures)
-                    if ex is not None:
-                        log.warn(f"task error, terminating, err={ex}")
-                        q.put(None)
-                        executor.shutdown(wait=True, cancel_futures=True)
-                        raise ex
+                try:
+                    done, not_done = concurrent.futures.wait(
+                        executing_futures,
+                        timeout=wait_interval,
+                        return_when=concurrent.futures.FIRST_EXCEPTION)
+
+                    if len(not_done) > 0:
+                        log.warning(f"Failed to finish all tasks in 1s, [{len(not_done)}/{len(executing_futures)}] tasks are not done, waited={wait_interval:.2f}, trying to wait in the next round")
+                        executing_futures = list(not_done)
                     else:
                         log.debug(f"Finished {len(executing_futures)} insert-{config.NUM_PER_BATCH} task in 1s, wait_interval={wait_interval:.2f}")
-                        executing_futures = []
-                else:
-                    log.warning(f"Failed to finish tasks in 1s, {e}, waited={wait_interval:.2f}, try to check the next round")
+                        executing_futures = []
+                except Exception as e:
+                    log.warn(f"task error, terminating, err={e}")
+                    q.put(None, block=True)
+                    executor.shutdown(wait=True, cancel_futures=True)
+                    raise e
+
                 dur = time.perf_counter() - start_time
                 if dur < 1:
                     time.sleep(1 - dur)
+
+            # wait for all tasks in executing_futures to complete
+            if len(executing_futures) > 0:
+                try:
+                    done, _ = concurrent.futures.wait(executing_futures,
+                                                      return_when=concurrent.futures.FIRST_EXCEPTION)
+                except Exception as e:
+                    log.warn(f"task error, terminating, err={e}")
+                    q.put(None, block=True)
+                    executor.shutdown(wait=True, cancel_futures=True)
+                    raise e
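
The helpers is_futures_completed/get_future_exceptions are gone; the runner now calls concurrent.futures.wait(..., return_when=FIRST_EXCEPTION) directly and carries unfinished futures into the next one-second round. A standalone sketch of that pattern (simplified, not the runner itself):

    import concurrent.futures
    import time

    def run_round(executor, tasks, pending, wait_interval=1.0):
        """Submit one round of tasks, wait up to wait_interval seconds,
        and return whatever is still unfinished for the next round."""
        pending = list(pending) + [executor.submit(t) for t in tasks]
        done, not_done = concurrent.futures.wait(
            pending,
            timeout=wait_interval,
            return_when=concurrent.futures.FIRST_EXCEPTION,
        )
        for f in done:
            f.result()  # re-raise the first task error, as the runner does
        return list(not_done)

    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as ex:
        leftover = run_round(ex, [lambda: time.sleep(0.1)] * 8, [])
        # after the loop, wait out anything that missed the per-round budget
        concurrent.futures.wait(leftover, return_when=concurrent.futures.FIRST_EXCEPTION)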
vectordb_bench/backend/runner/read_write_runner.py

@@ -24,7 +24,7 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunner):
         k: int = 100,
         filters: dict | None = None,
         concurrencies: Iterable[int] = (1, 15, 50),
-        search_stage: Iterable[float] = (0.5, 0.6, 0.7, 0.8, 0.9, 1.0),  # search in any insert portion, 0.0 means search from the start
+        search_stage: Iterable[float] = (0.5, 0.6, 0.7, 0.8, 0.9),  # search from insert portion, 0.0 means search from the start
         read_dur_after_write: int = 300,  # seconds, search duration when insertion is done
         timeout: float | None = None,
     ):
@@ -32,7 +32,7 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunner):
         self.data_volume = dataset.data.size
 
         for stage in search_stage:
-            assert 0.0 <= stage <= 1.0, "each search stage should be in [0.0, 1.0]"
+            assert 0.0 <= stage < 1.0, "each search stage should be in [0.0, 1.0)"
         self.search_stage = sorted(search_stage)
         self.read_dur_after_write = read_dur_after_write
 
@@ -65,48 +65,114 @@ class ReadWriteRunner(MultiProcessingSearchRunner, RatedMultiThreadingInsertRunner):
             k=k,
         )
 
+    def run_optimize(self):
+        """Optimize needs to run in differenct process for pymilvus schema recursion problem"""
+        with self.db.init():
+            log.info("Search after write - Optimize start")
+            self.db.optimize()
+            log.info("Search after write - Optimize finished")
+
+    def run_search(self):
+        log.info("Search after write - Serial search start")
+        res, ssearch_dur = self.serial_search_runner.run()
+        recall, ndcg, p99_latency = res
+        log.info(f"Search after write - Serial search - recall={recall}, ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur:.4f}")
+        log.info(f"Search after wirte - Conc search start, dur for each conc={self.read_dur_after_write}")
+        max_qps = self.run_by_dur(self.read_dur_after_write)
+        log.info(f"Search after wirte - Conc search finished, max_qps={max_qps}")
+
+        return (max_qps, recall, ndcg, p99_latency)
+
     def run_read_write(self):
-        futures = []
         with mp.Manager() as m:
             q = m.Queue()
             with concurrent.futures.ProcessPoolExecutor(mp_context=mp.get_context("spawn"), max_workers=2) as executor:
-                futures.append(executor.submit(self.run_with_rate, q))
-                futures.append(executor.submit(self.run_search_by_sig, q))
-
-                for future in concurrent.futures.as_completed(futures):
-                    res = future.result()
-                    log.info(f"Result = {res}")
-
+                read_write_futures = []
+                read_write_futures.append(executor.submit(self.run_with_rate, q))
+                read_write_futures.append(executor.submit(self.run_search_by_sig, q))
+
+                try:
+                    for f in concurrent.futures.as_completed(read_write_futures):
+                        res = f.result()
+                        log.info(f"Result = {res}")
+
+                    # Wait for read_write_futures finishing and do optimize and search
+                    op_future = executor.submit(self.run_optimize)
+                    op_future.result()
+
+                    search_future = executor.submit(self.run_search)
+                    last_res = search_future.result()
+
+                    log.info(f"Max QPS after optimze and search: {last_res}")
+                except Exception as e:
+                    log.warning(f"Read and write error: {e}")
+                    executor.shutdown(wait=True, cancel_futures=True)
+                    raise e
         log.info("Concurrent read write all done")
 
-
     def run_search_by_sig(self, q):
-        res = []
+        """
+        Args:
+            q: multiprocessing queue
+                (None) means abnormal exit
+                (False) means updating progress
+                (True) means normal exit
+        """
+        result, start_batch = [], 0
         total_batch = math.ceil(self.data_volume / self.insert_rate)
-        batch = 0
-        recall = 'x'
+        recall, ndcg, p99_latency = None, None, None
+
+        def wait_next_target(start, target_batch) -> bool:
+            """Return False when receive True or None"""
+            while start < target_batch:
+                sig = q.get(block=True)
+
+                if sig is None or sig is True:
+                    return False
+                else:
+                    start += 1
+            return True
 
         for idx, stage in enumerate(self.search_stage):
             target_batch = int(total_batch * stage)
-            while q.get(block=True):
-                batch += 1
-                if batch >= target_batch:
-                    perc = int(stage * 100)
-                    log.info(f"Insert {perc}% done, total batch={total_batch}")
-                    log.info(f"[{batch}/{total_batch}] Serial search - {perc}% start")
-                    recall, ndcg, p99 =self.serial_search_runner.run()
-
-                    if idx < len(self.search_stage) - 1:
-                        stage_search_dur = (self.data_volume * (self.search_stage[idx + 1] - stage) // self.insert_rate) // len(self.concurrencies)
-                        if stage_search_dur < 30:
-                            log.warning(f"Search duration too short, please reduce concurrency count or insert rate, or increase dataset volume: dur={stage_search_dur}, concurrencies={len(self.concurrencies)}, insert_rate={self.insert_rate}")
-                        log.info(f"[{batch}/{total_batch}] Conc search - {perc}% start, dur for each conc={stage_search_dur}s")
-                    else:
-                        last_search_dur = self.data_volume * (1.0 - stage) // self.insert_rate
-                        stage_search_dur = last_search_dur + self.read_dur_after_write
-                        log.info(f"[{batch}/{total_batch}] Last conc search - {perc}% start, [read_until_write|read_after_write|total] =[{last_search_dur}s|{self.read_dur_after_write}s|{stage_search_dur}s]")
-
-                    max_qps = self.run_by_dur(stage_search_dur)
-                    res.append((perc, max_qps, recall))
-                    break
-        return res
+            perc = int(stage * 100)
+
+            got = wait_next_target(start_batch, target_batch)
+            if got is False:
+                log.warning(f"Abnormal exit, target_batch={target_batch}, start_batch={start_batch}")
+                return
+
+            log.info(f"Insert {perc}% done, total batch={total_batch}")
+            log.info(f"[{target_batch}/{total_batch}] Serial search - {perc}% start")
+            res, ssearch_dur = self.serial_search_runner.run()
+            recall, ndcg, p99_latency = res
+            log.info(f"[{target_batch}/{total_batch}] Serial search - {perc}% done, recall={recall}, ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur:.4f}")
+
+            # Search duration for non-last search stage is carefully calculated.
+            # If duration for each concurrency is less than 30s, runner will raise error.
+            if idx < len(self.search_stage) - 1:
+                total_dur_between_stages = self.data_volume * (self.search_stage[idx + 1] - stage) // self.insert_rate
+                csearch_dur = total_dur_between_stages - ssearch_dur
+
+                # Try to leave room for init process executors
+                csearch_dur = csearch_dur - 30 if csearch_dur > 60 else csearch_dur
+
+                each_conc_search_dur = csearch_dur / len(self.concurrencies)
+                if each_conc_search_dur < 30:
+                    warning_msg = f"Results might be inaccurate, duration[{csearch_dur:.4f}] left for conc-search is too short, total available dur={total_dur_between_stages}, serial_search_cost={ssearch_dur}."
+                    log.warning(warning_msg)
+
+            # The last stage
+            else:
+                each_conc_search_dur = 60
+
+            log.info(f"[{target_batch}/{total_batch}] Concurrent search - {perc}% start, dur={each_conc_search_dur:.4f}")
+            max_qps = self.run_by_dur(each_conc_search_dur)
+            result.append((perc, max_qps, recall, ndcg, p99_latency))
+
+            start_batch = target_batch
+
+        # Drain the queue
+        while q.empty() is False:
+            q.get(block=True)
+        return result
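
The insert and search processes now talk over the queue with three signals: False for one more inserted batch, True for normal end of dataset, None for an abnormal exit. A small self-contained sketch of that protocol (simplified, not the runner itself):

    import multiprocessing as mp

    def producer(q, batches=5):
        for _ in range(batches):
            q.put(False)   # progress: one batch inserted
        q.put(True)        # normal exit: dataset finished

    def consumer(q, target_batch=3):
        seen = 0
        while seen < target_batch:
            sig = q.get(block=True)
            if sig is None or sig is True:
                return False   # insert side stopped before the target
            seen += 1
        return True            # enough batches inserted, start this search stage

    if __name__ == "__main__":
        with mp.Manager() as m:
            q = m.Queue()
            producer(q)
            print(consumer(q, target_batch=3))  # True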
vectordb_bench/backend/runner/serial_runner.py

@@ -167,7 +167,7 @@ class SerialSearchRunner:
         self.test_data = test_data
         self.ground_truth = ground_truth
 
-    def search(self, args: tuple[list, pd.DataFrame]):
+    def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]:
         log.info(f"{mp.current_process().name:14} start search the entire test_data to get recall and latency")
         with self.db.init():
             test_data, ground_truth = args
@@ -224,5 +224,11 @@ class SerialSearchRunner:
             result = future.result()
         return result
 
-    def run(self) -> tuple[float, float]:
+    @utils.time_it
+    def run(self) -> tuple[float, float, float]:
+        """
+        Returns:
+            tuple[tuple[float, float, float], float]: (avg_recall, avg_ndcg, p99_latency), cost
+
+        """
         return self._run_in_subprocess()
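
Because run() is now wrapped by utils.time_it, callers receive ((recall, ndcg, p99_latency), cost) and unpack twice, as read_write_runner.py and task_runner.py do. Caller-side sketch (runner is any object whose run() follows that contract):

    def report_serial_search(runner):
        # runner.run() -> ((recall, ndcg, p99_latency), elapsed_seconds)
        (recall, ndcg, p99_latency), cost = runner.run()
        return recall, ndcg, p99_latency, cost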
vectordb_bench/backend/runner/util.py

@@ -1,6 +1,4 @@
 import logging
-import concurrent
-from typing import Iterable
 
 from pandas import DataFrame
 import numpy as np
@@ -16,17 +14,3 @@ def get_data(data_df: DataFrame, normalize: bool) -> tuple[list[list[float]], li
     else:
         all_embeddings = emb_np.tolist()
     return all_embeddings, all_metadata
-
-def is_futures_completed(futures: Iterable[concurrent.futures.Future], interval) -> (Exception, bool):
-    try:
-        list(concurrent.futures.as_completed(futures, timeout=interval))
-    except TimeoutError as e:
-        return e, False
-    return None, True
-
-
-def get_future_exceptions(futures: Iterable[concurrent.futures.Future]) -> BaseException | None:
-    for f in futures:
-        if f.exception() is not None:
-            return f.exception()
-    return
vectordb_bench/backend/task_runner.py

@@ -206,7 +206,7 @@ class CaseRunner(BaseModel):
         finally:
             runner = None
 
-    def _serial_search(self) -> tuple[float, float]:
+    def _serial_search(self) -> tuple[float, float, float]:
         """Performance serial tests, search the entire test data once,
         calculate the recall, serial_latency_p99
 
@@ -214,7 +214,8 @@ class CaseRunner(BaseModel):
             tuple[float, float]: recall, serial_latency_p99
         """
         try:
-            return self.serial_search_runner.run()
+            results, _ = self.serial_search_runner.run()
+            return results
         except Exception as e:
             log.warning(f"search error: {str(e)}, {e}")
             self.stop()
@@ -238,7 +239,7 @@ class CaseRunner(BaseModel):
     @utils.time_it
     def _task(self) -> None:
         with self.db.init():
-            self.db.optimize()
+            self.db.optimize_with_size(data_size=self.ca.dataset.data.size)
 
     def _optimize(self) -> float:
         with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
vectordb_bench/backend/utils.py

@@ -35,6 +35,7 @@ def numerize(n) -> str:
 
 
 def time_it(func):
+    """ returns result and elapsed time"""
     @wraps(func)
     def inner(*args, **kwargs):
         pref = time.perf_counter()
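
The new docstring pins down the time_it contract: the wrapped function returns (result, elapsed seconds). A minimal sketch consistent with that contract (the decorator body here is an assumption; only the return shape is taken from the diff):

    import time
    from functools import wraps

    def time_it(func):
        """returns result and elapsed time"""
        @wraps(func)
        def inner(*args, **kwargs):
            pref = time.perf_counter()
            result = func(*args, **kwargs)
            return result, time.perf_counter() - pref
        return inner

    @time_it
    def add(a, b):
        return a + b

    value, cost = add(1, 2)  # value == 3, cost is the elapsed seconds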
vectordb_bench/cli/vectordbbench.py

@@ -9,6 +9,7 @@ from ..backend.clients.weaviate_cloud.cli import Weaviate
 from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex
 from ..backend.clients.milvus.cli import MilvusAutoIndex
 from ..backend.clients.aws_opensearch.cli import AWSOpenSearch
+from ..backend.clients.alloydb.cli import AlloyDBScaNN
 
 from .cli import cli
 
@@ -24,6 +25,7 @@ cli.add_command(MilvusAutoIndex)
 cli.add_command(AWSOpenSearch)
 cli.add_command(PgVectorScaleDiskAnn)
 cli.add_command(PgDiskAnn)
+cli.add_command(AlloyDBScaNN)
 
 
 if __name__ == "__main__":