vectordb-bench 0.0.10__py3-none-any.whl → 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. vectordb_bench/__init__.py +18 -5
  2. vectordb_bench/backend/cases.py +32 -12
  3. vectordb_bench/backend/clients/__init__.py +1 -0
  4. vectordb_bench/backend/clients/api.py +1 -1
  5. vectordb_bench/backend/clients/milvus/cli.py +291 -0
  6. vectordb_bench/backend/clients/milvus/milvus.py +13 -6
  7. vectordb_bench/backend/clients/pgvector/cli.py +116 -0
  8. vectordb_bench/backend/clients/pgvector/config.py +1 -1
  9. vectordb_bench/backend/clients/pgvector/pgvector.py +7 -4
  10. vectordb_bench/backend/clients/redis/cli.py +74 -0
  11. vectordb_bench/backend/clients/test/cli.py +25 -0
  12. vectordb_bench/backend/clients/test/config.py +18 -0
  13. vectordb_bench/backend/clients/test/test.py +62 -0
  14. vectordb_bench/backend/clients/weaviate_cloud/cli.py +41 -0
  15. vectordb_bench/backend/clients/zilliz_cloud/cli.py +55 -0
  16. vectordb_bench/backend/runner/mp_runner.py +14 -3
  17. vectordb_bench/backend/runner/serial_runner.py +7 -3
  18. vectordb_bench/backend/task_runner.py +76 -26
  19. vectordb_bench/cli/__init__.py +0 -0
  20. vectordb_bench/cli/cli.py +362 -0
  21. vectordb_bench/cli/vectordbbench.py +20 -0
  22. vectordb_bench/config-files/sample_config.yml +17 -0
  23. vectordb_bench/frontend/components/check_results/data.py +11 -8
  24. vectordb_bench/frontend/components/concurrent/charts.py +82 -0
  25. vectordb_bench/frontend/components/run_test/dbSelector.py +7 -1
  26. vectordb_bench/frontend/components/run_test/submitTask.py +12 -4
  27. vectordb_bench/frontend/components/tables/data.py +44 -0
  28. vectordb_bench/frontend/const/dbCaseConfigs.py +2 -1
  29. vectordb_bench/frontend/pages/concurrent.py +72 -0
  30. vectordb_bench/frontend/pages/tables.py +24 -0
  31. vectordb_bench/interface.py +21 -25
  32. vectordb_bench/metric.py +23 -1
  33. vectordb_bench/models.py +45 -5
  34. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.11.dist-info}/METADATA +193 -2
  35. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.11.dist-info}/RECORD +39 -23
  36. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.11.dist-info}/WHEEL +1 -1
  37. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.11.dist-info}/entry_points.txt +1 -0
  38. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.11.dist-info}/LICENSE +0 -0
  39. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.11.dist-info}/top_level.txt +0 -0

vectordb_bench/backend/clients/redis/cli.py (new file)
@@ -0,0 +1,74 @@
+from typing import Annotated, TypedDict, Unpack
+
+import click
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    HNSWFlavor2,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from .. import DB
+
+
+class RedisTypedDict(TypedDict):
+    host: Annotated[
+        str, click.option("--host", type=str, help="Db host", required=True)
+    ]
+    password: Annotated[str, click.option("--password", type=str, help="Db password")]
+    port: Annotated[int, click.option("--port", type=int, default=6379, help="Db Port")]
+    ssl: Annotated[
+        bool,
+        click.option(
+            "--ssl/--no-ssl",
+            is_flag=True,
+            show_default=True,
+            default=True,
+            help="Enable or disable SSL for Redis",
+        ),
+    ]
+    ssl_ca_certs: Annotated[
+        str,
+        click.option(
+            "--ssl-ca-certs",
+            show_default=True,
+            help="Path to certificate authority file to use for SSL",
+        ),
+    ]
+    cmd: Annotated[
+        bool,
+        click.option(
+            "--cmd",
+            is_flag=True,
+            show_default=True,
+            default=False,
+            help="Cluster Mode Disabled (CMD) for Redis doesn't use Cluster conn",
+        ),
+    ]
+
+
+class RedisHNSWTypedDict(CommonTypedDict, RedisTypedDict, HNSWFlavor2):
+    ...
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(RedisHNSWTypedDict)
+def Redis(**parameters: Unpack[RedisHNSWTypedDict]):
+    from .config import RedisConfig
+    run(
+        db=DB.Redis,
+        db_config=RedisConfig(
+            db_label=parameters["db_label"],
+            password=SecretStr(parameters["password"])
+            if parameters["password"]
+            else None,
+            host=SecretStr(parameters["host"]),
+            port=parameters["port"],
+            ssl=parameters["ssl"],
+            ssl_ca_certs=parameters["ssl_ca_certs"],
+            cmd=parameters["cmd"],
+        ),
+        **parameters,
+    )
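
All of the new per-client CLI modules in this release follow the same pattern: a TypedDict whose fields are Annotated with ready-made click.option(...) decorators, and a click_parameter_decorators_from_typed_dict helper that applies those decorators to the command function. The helper itself lives in the new vectordb_bench/cli/cli.py (+362 lines, not shown in this section), so the sketch below is only a plausible reconstruction of the idea, not the package's actual implementation:

    # Minimal sketch, assuming the Annotated metadata holds click.option decorators.
    # Names and details are illustrative; the real helper is in vectordb_bench/cli/cli.py.
    from typing import get_type_hints

    def click_parameter_decorators_from_typed_dict(typed_dict: type):
        """Apply every click.option(...) found in the TypedDict's Annotated metadata."""
        decorators = []
        for hint in get_type_hints(typed_dict, include_extras=True).values():
            # Annotated[str, click.option(...)] exposes its extras via __metadata__.
            decorators.extend(getattr(hint, "__metadata__", ()))

        def wrapper(command_func):
            for decorator in reversed(decorators):
                command_func = decorator(command_func)
            return command_func

        return wrapper
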

vectordb_bench/backend/clients/test/cli.py (new file)
@@ -0,0 +1,25 @@
+from typing import Unpack
+
+from ....cli.cli import (
+    CommonTypedDict,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from .. import DB
+from ..test.config import TestConfig, TestIndexConfig
+
+
+class TestTypedDict(CommonTypedDict):
+    ...
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(TestTypedDict)
+def Test(**parameters: Unpack[TestTypedDict]):
+    run(
+        db=DB.NewClient,
+        db_config=TestConfig(db_label=parameters["db_label"]),
+        db_case_config=TestIndexConfig(),
+        **parameters,
+    )
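
The Test command and its siblings become subcommands of the vectordbbench console script added to entry_points.txt in this release. The new vectordb_bench/cli/vectordbbench.py (+20 lines) is not shown here; one plausible sketch of the wiring is simply importing each client's cli module so its @cli.command() decorator registers the command on the shared group. The import list and structure below are assumptions, not the package's actual file:

    # Hypothetical sketch of the entry-point module (real contents not shown in this diff).
    from ..backend.clients.redis.cli import Redis                   # noqa: F401
    from ..backend.clients.test.cli import Test                     # noqa: F401
    from ..backend.clients.weaviate_cloud.cli import Weaviate       # noqa: F401
    from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex  # noqa: F401
    from .cli import cli

    if __name__ == "__main__":
        cli()  # click derives command names from the function names (lowercased by default)
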

vectordb_bench/backend/clients/test/config.py (new file)
@@ -0,0 +1,18 @@
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
+
+
+class TestConfig(DBConfig):
+    def to_dict(self) -> dict:
+        return {"db_label": self.db_label}
+
+
+class TestIndexConfig(BaseModel, DBCaseConfig):
+    metric_type: MetricType | None = None
+
+    def index_param(self) -> dict:
+        return {}
+
+    def search_param(self) -> dict:
+        return {}

vectordb_bench/backend/clients/test/test.py (new file)
@@ -0,0 +1,62 @@
+import logging
+from contextlib import contextmanager
+from typing import Any, Generator, Optional, Tuple
+
+from ..api import DBCaseConfig, VectorDB
+
+log = logging.getLogger(__name__)
+
+
+class Test(VectorDB):
+    def __init__(
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: DBCaseConfig,
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.db_config = db_config
+        self.case_config = db_case_config
+
+        log.info("Starting Test DB")
+
+    @contextmanager
+    def init(self) -> Generator[None, None, None]:
+        """create and destroy connections to database.
+
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+        """
+
+        yield
+
+    def ready_to_load(self) -> bool:
+        return True
+
+    def optimize(self) -> None:
+        pass
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> Tuple[int, Optional[Exception]]:
+        """Insert embeddings into the database.
+        Should call self.init() first.
+        """
+        raise RuntimeError("Not implemented")
+        return len(metadata), None
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+        **kwargs: Any,
+    ) -> list[int]:
+        raise NotImplementedError
+        return [i for i in range(k)]
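
The Test client above is a do-nothing skeleton: insert_embeddings and search_embedding raise before their return statements, so it only exercises the runner plumbing. For illustration only (not part of the package), the same two methods filled in with a naive in-memory brute-force client show what the benchmark expects each call to return:

    # Illustration only: a brute-force stand-in with the same method signatures as the
    # VectorDB interface shown above; nothing here is package code.
    from typing import Any, Optional, Tuple

    import numpy as np

    class InMemoryBruteForce:
        def __init__(self, dim: int, **kwargs):
            self._vectors = np.empty((0, dim), dtype=np.float32)
            self._ids: list[int] = []

        def insert_embeddings(
            self, embeddings: list[list[float]], metadata: list[int], **kwargs: Any
        ) -> Tuple[int, Optional[Exception]]:
            # The runner expects (inserted_count, error).
            try:
                self._vectors = np.vstack([self._vectors, np.asarray(embeddings, dtype=np.float32)])
                self._ids.extend(metadata)
                return len(metadata), None
            except Exception as e:
                return 0, e

        def search_embedding(
            self, query: list[float], k: int = 100, **kwargs: Any
        ) -> list[int]:
            # The runner expects the ids of the k nearest vectors.
            dists = np.linalg.norm(self._vectors - np.asarray(query, dtype=np.float32), axis=1)
            return [self._ids[i] for i in np.argsort(dists)[:k]]
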

vectordb_bench/backend/clients/weaviate_cloud/cli.py (new file)
@@ -0,0 +1,41 @@
+from typing import Annotated, Unpack
+
+import click
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from .. import DB
+
+
+class WeaviateTypedDict(CommonTypedDict):
+    api_key: Annotated[
+        str, click.option("--api-key", type=str, help="Weaviate api key", required=True)
+    ]
+    url: Annotated[
+        str,
+        click.option("--url", type=str, help="Weaviate url", required=True),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(WeaviateTypedDict)
+def Weaviate(**parameters: Unpack[WeaviateTypedDict]):
+    from .config import WeaviateConfig, WeaviateIndexConfig
+
+    run(
+        db=DB.WeaviateCloud,
+        db_config=WeaviateConfig(
+            db_label=parameters["db_label"],
+            api_key=SecretStr(parameters["api_key"]),
+            url=SecretStr(parameters["url"]),
+        ),
+        db_case_config=WeaviateIndexConfig(
+            ef=256, efConstruction=256, maxConnections=16
+        ),
+        **parameters,
+    )

vectordb_bench/backend/clients/zilliz_cloud/cli.py (new file)
@@ -0,0 +1,55 @@
+from typing import Annotated, Unpack
+
+import click
+import os
+from pydantic import SecretStr
+
+from vectordb_bench.cli.cli import (
+    CommonTypedDict,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from vectordb_bench.backend.clients import DB
+
+
+class ZillizTypedDict(CommonTypedDict):
+    uri: Annotated[
+        str, click.option("--uri", type=str, help="uri connection string", required=True)
+    ]
+    user_name: Annotated[
+        str, click.option("--user-name", type=str, help="Db username", required=True)
+    ]
+    password: Annotated[
+        str,
+        click.option("--password",
+                     type=str,
+                     help="Zilliz password",
+                     default=lambda: os.environ.get("ZILLIZ_PASSWORD", ""),
+                     show_default="$ZILLIZ_PASSWORD",
+                     ),
+    ]
+    level: Annotated[
+        str,
+        click.option("--level", type=str, help="Zilliz index level", required=False),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(ZillizTypedDict)
+def ZillizAutoIndex(**parameters: Unpack[ZillizTypedDict]):
+    from .config import ZillizCloudConfig, AutoIndexConfig
+
+    run(
+        db=DB.ZillizCloud,
+        db_config=ZillizCloudConfig(
+            db_label=parameters["db_label"],
+            uri=SecretStr(parameters["uri"]),
+            user=parameters["user_name"],
+            password=SecretStr(parameters["password"]),
+        ),
+        db_case_config=AutoIndexConfig(
+            params={parameters["level"]},
+        ),
+        **parameters,
+    )

vectordb_bench/backend/runner/mp_runner.py
@@ -4,6 +4,7 @@ import concurrent
 import multiprocessing as mp
 import logging
 from typing import Iterable
+import numpy as np
 from ..clients import api
 from ... import config
 
@@ -49,6 +50,7 @@ class MultiProcessingSearchRunner:
 
        start_time = time.perf_counter()
        count = 0
+       latencies = []
        while time.perf_counter() < start_time + self.duration:
            s = time.perf_counter()
            try:
@@ -61,7 +63,8 @@
                log.warning(f"VectorDB search_embedding error: {e}")
                traceback.print_exc(chain=True)
                raise e from None
-
+
+           latencies.append(time.perf_counter() - s)
            count += 1
            # loop through the test data
            idx = idx + 1 if idx < num - 1 else 0
@@ -75,7 +78,7 @@
            f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}"
        )
 
-       return (count, total_dur)
+       return (count, total_dur, latencies)
 
    @staticmethod
    def get_mp_context():
@@ -85,6 +88,9 @@
 
    def _run_all_concurrencies_mem_efficient(self) -> float:
        max_qps = 0
+       conc_num_list = []
+       conc_qps_list = []
+       conc_latency_p99_list = []
        try:
            for conc in self.concurrencies:
                with mp.Manager() as m:
@@ -103,9 +109,14 @@
 
                    start = time.perf_counter()
                    all_count = sum([r.result()[0] for r in future_iter])
+                   latencies = sum([r.result()[2] for r in future_iter], start=[])
+                   latency_p99 = np.percentile(latencies, 0.99)
                    cost = time.perf_counter() - start
 
                    qps = round(all_count / cost, 4)
+                   conc_num_list.append(conc)
+                   conc_qps_list.append(qps)
+                   conc_latency_p99_list.append(latency_p99)
                    log.info(f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}")
 
                    if qps > max_qps:
@@ -122,7 +133,7 @@
 
        finally:
            self.stop()
-       return max_qps
+       return max_qps, conc_num_list, conc_qps_list, conc_latency_p99_list
 
    def run(self) -> float:
        """

vectordb_bench/backend/runner/serial_runner.py
@@ -10,7 +10,7 @@ import numpy as np
 import pandas as pd
 
 from ..clients import api
-from ...metric import calc_recall
+from ...metric import calc_ndcg, calc_recall, get_ideal_dcg
 from ...models import LoadTimeoutError, PerformanceTimeoutError
 from .. import utils
 from ... import config
@@ -171,11 +171,12 @@
        log.info(f"{mp.current_process().name:14} start search the entire test_data to get recall and latency")
        with self.db.init():
            test_data, ground_truth = args
+           ideal_dcg = get_ideal_dcg(self.k)
 
            log.debug(f"test dataset size: {len(test_data)}")
            log.debug(f"ground truth size: {ground_truth.columns}, shape: {ground_truth.shape}")
 
-           latencies, recalls = [], []
+           latencies, recalls, ndcgs = [], [], []
            for idx, emb in enumerate(test_data):
                s = time.perf_counter()
                try:
@@ -194,6 +195,7 @@
 
                gt = ground_truth['neighbors_id'][idx]
                recalls.append(calc_recall(self.k, gt[:self.k], results))
+               ndcgs.append(calc_ndcg(gt[:self.k], results, ideal_dcg))
 
 
                if len(latencies) % 100 == 0:
@@ -201,6 +203,7 @@
 
            avg_latency = round(np.mean(latencies), 4)
            avg_recall = round(np.mean(recalls), 4)
+           avg_ndcg = round(np.mean(ndcgs), 4)
            cost = round(np.sum(latencies), 4)
            p99 = round(np.percentile(latencies, 99), 4)
            log.info(
@@ -208,10 +211,11 @@
                f"cost={cost}s, "
                f"queries={len(latencies)}, "
                f"avg_recall={avg_recall}, "
+               f"avg_ndcg={avg_ndcg},"
                f"avg_latency={avg_latency}, "
                f"p99={p99}"
            )
-           return (avg_recall, p99)
+           return (avg_recall, avg_ndcg, p99)
 
 
    def _run_in_subprocess(self) -> tuple[float, float]:
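
The serial runner now reports NDCG alongside recall. calc_ndcg and get_ideal_dcg come from vectordb_bench/metric.py (+23 -1 in the file list above), whose body is not shown here; the standard NDCG@k formulas they presumably implement are sketched below. DCG credits a ground-truth hit at rank i (1-based) with 1/log2(i + 1), the ideal DCG assumes all top-k results are hits, and NDCG is their ratio:

    # Hedged sketch of the metric helpers; the real implementations live in
    # vectordb_bench/metric.py and may differ in detail.
    import numpy as np

    def get_ideal_dcg(k: int) -> float:
        # All k results are hits: ranks 1..k each contribute 1 / log2(rank + 1).
        return float(sum(1.0 / np.log2(rank + 1) for rank in range(1, k + 1)))

    def calc_ndcg(ground_truth: list[int], results: list[int], ideal_dcg: float) -> float:
        gt = set(ground_truth)
        dcg = sum(
            1.0 / np.log2(rank + 1)
            for rank, result_id in enumerate(results, start=1)
            if result_id in gt
        )
        return dcg / ideal_dcg
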

vectordb_bench/backend/task_runner.py
@@ -8,7 +8,7 @@ from enum import Enum, auto
 from . import utils
 from .cases import Case, CaseLabel
 from ..base import BaseModel
-from ..models import TaskConfig, PerformanceTimeoutError
+from ..models import TaskConfig, PerformanceTimeoutError, TaskStage
 
 from .clients import (
     api,
@@ -29,7 +29,7 @@ class RunningStatus(Enum):
 
 
 class CaseRunner(BaseModel):
-    """ DataSet, filter_rate, db_class with db config
+    """DataSet, filter_rate, db_class with db config
 
     Fields:
         run_id(str): run_id of this case runner,
@@ -49,8 +49,9 @@
 
     db: api.VectorDB | None = None
     test_emb: list[list[float]] | None = None
-    search_runner: MultiProcessingSearchRunner | None = None
     serial_search_runner: SerialSearchRunner | None = None
+    search_runner: MultiProcessingSearchRunner | None = None
+    final_search_runner: MultiProcessingSearchRunner | None = None
 
     def __eq__(self, obj):
         if isinstance(obj, CaseRunner):
@@ -58,7 +59,7 @@
                 self.config.db == obj.config.db and \
                 self.config.db_case_config == obj.config.db_case_config and \
                 self.ca.dataset == obj.ca.dataset
-        return False
+        return False
 
     def display(self) -> dict:
         c_dict = self.ca.dict(include={'label':True, 'filters': True,'dataset':{'data': {'name': True, 'size': True, 'dim': True, 'metric_type': True, 'label': True}} })
@@ -79,20 +80,25 @@
             db_config=self.config.db_config.to_dict(),
             db_case_config=self.config.db_case_config,
             drop_old=drop_old,
-        )
+        ) # type:ignore
+
 
 
     def _pre_run(self, drop_old: bool = True):
         try:
             self.init_db(drop_old)
             self.ca.dataset.prepare(self.dataset_source, filters=self.ca.filter_rate)
         except ModuleNotFoundError as e:
-            log.warning(f"pre run case error: please install client for db: {self.config.db}, error={e}")
+            log.warning(
+                f"pre run case error: please install client for db: {self.config.db}, error={e}"
+            )
             raise e from None
         except Exception as e:
             log.warning(f"pre run case error: {e}")
             raise e from None
 
     def run(self, drop_old: bool = True) -> Metric:
+        log.info("Starting run")
+
         self._pre_run(drop_old)
 
         if self.ca.label == CaseLabel.Load:
@@ -105,31 +111,35 @@
             raise ValueError(msg)
 
     def _run_capacity_case(self) -> Metric:
-        """ run capacity cases
+        """run capacity cases
 
         Returns:
             Metric: the max load count
         """
+        assert self.db is not None
         log.info("Start capacity case")
         try:
-            runner = SerialInsertRunner(self.db, self.ca.dataset, self.normalize, self.ca.load_timeout)
+            runner = SerialInsertRunner(
+                self.db, self.ca.dataset, self.normalize, self.ca.load_timeout
+            )
             count = runner.run_endlessness()
         except Exception as e:
             log.warning(f"Failed to run capacity case, reason = {e}")
             raise e from None
         else:
-            log.info(f"Capacity case loading dataset reaches VectorDB's limit: max capacity = {count}")
+            log.info(
+                f"Capacity case loading dataset reaches VectorDB's limit: max capacity = {count}"
+            )
             return Metric(max_load_count=count)
 
     def _run_perf_case(self, drop_old: bool = True) -> Metric:
-        """ run performance cases
+        """run performance cases
 
         Returns:
             Metric: load_duration, recall, serial_latency_p99, and, qps
         """
-        try:
-            m = Metric()
-            if drop_old:
+        '''
+        if drop_old:
            _, load_dur = self._load_train_data()
            build_dur = self._optimize()
            m.load_duration = round(load_dur+build_dur, 4)
@@ -140,8 +150,43 @@
            )
 
            self._init_search_runner()
-           m.qps = self._conc_search()
+
+           m.qps, m.conc_num_list, m.conc_qps_list, m.conc_latency_p99_list = self._conc_search()
            m.recall, m.serial_latency_p99 = self._serial_search()
+        '''
+
+        log.info("Start performance case")
+        try:
+            m = Metric()
+            if drop_old:
+                if TaskStage.LOAD in self.config.stages:
+                    # self._load_train_data()
+                    _, load_dur = self._load_train_data()
+                    build_dur = self._optimize()
+                    m.load_duration = round(load_dur + build_dur, 4)
+                    log.info(
+                        f"Finish loading the entire dataset into VectorDB,"
+                        f" insert_duration={load_dur}, optimize_duration={build_dur}"
+                        f" load_duration(insert + optimize) = {m.load_duration}"
+                    )
+                else:
+                    log.info("Data loading skipped")
+            if (
+                TaskStage.SEARCH_SERIAL in self.config.stages
+                or TaskStage.SEARCH_CONCURRENT in self.config.stages
+            ):
+                self._init_search_runner()
+                if TaskStage.SEARCH_SERIAL in self.config.stages:
+                    search_results = self._serial_search()
+                    '''
+                    m.recall = search_results.recall
+                    m.serial_latencies = search_results.serial_latencies
+                    '''
+                    m.recall, m.ndcg, m.serial_latency_p99 = search_results
+                if TaskStage.SEARCH_CONCURRENT in self.config.stages:
+                    search_results = self._conc_search()
+                    m.qps, m.conc_num_list, m.conc_qps_list, m.conc_latency_p99_list = search_results
+
         except Exception as e:
             log.warning(f"Failed to run performance case, reason = {e}")
             traceback.print_exc()
@@ -217,18 +262,23 @@
 
         gt_df = self.ca.dataset.gt_data
 
-        self.serial_search_runner = SerialSearchRunner(
-            db=self.db,
-            test_data=self.test_emb,
-            ground_truth=gt_df,
-            filters=self.ca.filters,
-        )
-
-        self.search_runner = MultiProcessingSearchRunner(
-            db=self.db,
-            test_data=self.test_emb,
-            filters=self.ca.filters,
-        )
+        if TaskStage.SEARCH_SERIAL in self.config.stages:
+            self.serial_search_runner = SerialSearchRunner(
+                db=self.db,
+                test_data=self.test_emb,
+                ground_truth=gt_df,
+                filters=self.ca.filters,
+                k=self.config.case_config.k,
+            )
+        if TaskStage.SEARCH_CONCURRENT in self.config.stages:
+            self.search_runner = MultiProcessingSearchRunner(
+                db=self.db,
+                test_data=self.test_emb,
+                filters=self.ca.filters,
+                concurrencies=self.config.case_config.concurrency_search_config.num_concurrency,
+                duration=self.config.case_config.concurrency_search_config.concurrency_duration,
+                k=self.config.case_config.k,
+            )
 
     def stop(self):
         if self.search_runner:
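
The refactored CaseRunner gates loading, serial search, and concurrent search on TaskStage members imported from vectordb_bench/models.py (+45 -5 above, not shown here). Below is a hedged sketch of what such an enum and the gating could look like; only the member names come from the diff, the values and surrounding code are assumptions:

    # Assumed shape of the TaskStage enum; member values are illustrative only.
    from enum import Enum

    class TaskStage(Enum):
        LOAD = "load"
        SEARCH_SERIAL = "search_serial"
        SEARCH_CONCURRENT = "search_concurrent"

    # Selecting a subset of stages lets a task, for example, re-run searches
    # against data that is already loaded:
    stages = {TaskStage.SEARCH_SERIAL, TaskStage.SEARCH_CONCURRENT}
    if TaskStage.LOAD not in stages:
        print("Data loading skipped")
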