vectordb-bench 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. vectordb_bench/__init__.py +14 -13
  2. vectordb_bench/backend/clients/__init__.py +13 -0
  3. vectordb_bench/backend/clients/api.py +2 -0
  4. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +47 -6
  5. vectordb_bench/backend/clients/aws_opensearch/config.py +12 -6
  6. vectordb_bench/backend/clients/aws_opensearch/run.py +34 -3
  7. vectordb_bench/backend/clients/pgdiskann/cli.py +99 -0
  8. vectordb_bench/backend/clients/pgdiskann/config.py +145 -0
  9. vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +350 -0
  10. vectordb_bench/backend/clients/pgvector/cli.py +62 -1
  11. vectordb_bench/backend/clients/pgvector/config.py +48 -10
  12. vectordb_bench/backend/clients/pgvector/pgvector.py +145 -26
  13. vectordb_bench/backend/clients/pgvectorscale/cli.py +108 -0
  14. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +22 -4
  15. vectordb_bench/backend/clients/pinecone/config.py +0 -2
  16. vectordb_bench/backend/clients/pinecone/pinecone.py +34 -36
  17. vectordb_bench/backend/clients/redis/cli.py +8 -0
  18. vectordb_bench/backend/clients/redis/config.py +37 -6
  19. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +1 -1
  20. vectordb_bench/backend/runner/mp_runner.py +2 -1
  21. vectordb_bench/cli/cli.py +137 -0
  22. vectordb_bench/cli/vectordbbench.py +4 -1
  23. vectordb_bench/frontend/components/check_results/charts.py +9 -6
  24. vectordb_bench/frontend/components/concurrent/charts.py +3 -6
  25. vectordb_bench/frontend/components/run_test/caseSelector.py +6 -0
  26. vectordb_bench/frontend/config/dbCaseConfigs.py +165 -1
  27. vectordb_bench/frontend/pages/quries_per_dollar.py +13 -5
  28. vectordb_bench/frontend/vdb_benchmark.py +11 -3
  29. vectordb_bench/models.py +13 -3
  30. vectordb_bench/results/Milvus/result_20230727_standard_milvus.json +53 -1
  31. vectordb_bench/results/Milvus/result_20230808_standard_milvus.json +48 -0
  32. vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +29 -1
  33. vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +24 -0
  34. vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +98 -49
  35. vectordb_bench/results/getLeaderboardData.py +17 -7
  36. vectordb_bench/results/leaderboard.json +1 -1
  37. {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/METADATA +65 -35
  38. {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/RECORD +42 -38
  39. {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/WHEEL +1 -1
  40. {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/LICENSE +0 -0
  41. {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/entry_points.txt +0 -0
  42. {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/top_level.txt +0 -0

vectordb_bench/backend/clients/pgvector/pgvector.py
@@ -11,7 +11,7 @@ from pgvector.psycopg import register_vector
 from psycopg import Connection, Cursor, sql
 
 from ..api import VectorDB
-from .config import PgVectorConfigDict, PgVectorIndexConfig
+from .config import PgVectorConfigDict, PgVectorIndexConfig, PgVectorHNSWConfig
 
 log = logging.getLogger(__name__)
 
@@ -22,7 +22,7 @@ class PgVector(VectorDB):
     conn: psycopg.Connection[Any] | None = None
     cursor: psycopg.Cursor[Any] | None = None
 
-    # TODO add filters support
+    _filtered_search: sql.Composed
     _unfiltered_search: sql.Composed
 
     def __init__(
@@ -87,6 +87,92 @@ class PgVector(VectorDB):
         assert cursor is not None, "Cursor is not initialized"
 
         return conn, cursor
+
+    def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
+        index_param = self.case_config.index_param()
+        reranking = self.case_config.search_param()["reranking"]
+        column_name = (
+            sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding"))
+            if index_param["quantization_type"] == "bit"
+            else sql.SQL("embedding")
+        )
+        search_vector = (
+            sql.SQL("binary_quantize({0})").format(sql.Placeholder())
+            if index_param["quantization_type"] == "bit"
+            else sql.Placeholder()
+        )
+
+        # The following sections assume that the quantization_type value matches the quantization function name
+        if index_param["quantization_type"] != None:
+            if index_param["quantization_type"] == "bit" and reranking:
+                # Embeddings needs to be passed to binary_quantize function if quantization_type is bit
+                search_query = sql.Composed(
+                    [
+                        sql.SQL(
+                            """
+                            SELECT i.id
+                            FROM (
+                                SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance
+                                FROM public.{table_name} {where_clause}
+                                ORDER BY {column_name}::{quantization_type}({dim})
+                            """
+                        ).format(
+                            table_name=sql.Identifier(self.table_name),
+                            column_name=column_name,
+                            reranking_metric_fun_op=sql.SQL(self.case_config.search_param()["reranking_metric_fun_op"]),
+                            quantization_type=sql.SQL(index_param["quantization_type"]),
+                            dim=sql.Literal(self.dim),
+                            where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
+                        ),
+                        sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                        sql.SQL(
+                            """
+                            {search_vector}
+                            LIMIT {quantized_fetch_limit}
+                            ) i
+                            ORDER BY i.distance
+                            LIMIT %s::int
+                            """
+                        ).format(
+                            search_vector=search_vector,
+                            quantized_fetch_limit=sql.Literal(
+                                self.case_config.search_param()["quantized_fetch_limit"]
+                            ),
+                        ),
+                    ]
+                )
+            else:
+                search_query = sql.Composed(
+                    [
+                        sql.SQL(
+                            "SELECT id FROM public.{table_name} {where_clause} ORDER BY {column_name}::{quantization_type}({dim}) "
+                        ).format(
+                            table_name=sql.Identifier(self.table_name),
+                            column_name=column_name,
+                            quantization_type=sql.SQL(index_param["quantization_type"]),
+                            dim=sql.Literal(self.dim),
+                            where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
+                        ),
+                        sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                        sql.SQL(" {search_vector} LIMIT %s::int").format(search_vector=search_vector),
+                    ]
+                )
+        else:
+            search_query = sql.Composed(
+                [
+                    sql.SQL(
+                        "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding "
+                    ).format(
+                        table_name=sql.Identifier(self.table_name),
+                        where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
+                    ),
+                    sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                    sql.SQL(" %s::vector LIMIT %s::int"),
+                ]
+            )
+
+        return search_query
+
 
     @contextmanager
     def init(self) -> Generator[None, None, None]:
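
For orientation, in the bit-quantization-with-reranking branch the composed statement renders to something like the following for a filtered search (a sketch only: the table name, dimension, fetch limit, and the `<=>`/`<~>` distance operators all come from the case config and are illustrative here):

    SELECT i.id
    FROM (
        SELECT id, embedding <=> %s::vector AS distance
        FROM public.test_table WHERE id >= %s
        ORDER BY binary_quantize("embedding")::bit(1536) <~> binary_quantize(%s)
        LIMIT 200
    ) i
    ORDER BY i.distance
    LIMIT %s::int

The inner query over-fetches `quantized_fetch_limit` candidates using the cheap bit-quantized ordering, and the outer query reranks those candidates by full-precision distance before applying the final `LIMIT`.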
@@ -112,15 +198,8 @@
             self.cursor.execute(command)
             self.conn.commit()
 
-        self._unfiltered_search = sql.Composed(
-            [
-                sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
-                    sql.Identifier(self.table_name)
-                ),
-                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
-                sql.SQL(" %s::vector LIMIT %s::int"),
-            ]
-        )
+        self._filtered_search = self._generate_search_query(filtered=True)
+        self._unfiltered_search = self._generate_search_query()
 
         try:
             yield
@@ -255,17 +334,39 @@
         else:
             with_clause = sql.Composed(())
 
-        index_create_sql = sql.SQL(
-            """
-            CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
-            USING {index_type} (embedding {embedding_metric})
-            """
-        ).format(
-            index_name=sql.Identifier(self._index_name),
-            table_name=sql.Identifier(self.table_name),
-            index_type=sql.Identifier(index_param["index_type"]),
-            embedding_metric=sql.Identifier(index_param["metric"]),
-        )
+        if index_param["quantization_type"] != None:
+            index_create_sql = sql.SQL(
+                """
+                CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
+                USING {index_type} (({column_name}::{quantization_type}({dim})) {embedding_metric})
+                """
+            ).format(
+                index_name=sql.Identifier(self._index_name),
+                table_name=sql.Identifier(self.table_name),
+                column_name=(
+                    sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding"))
+                    if index_param["quantization_type"] == "bit"
+                    else sql.Identifier("embedding")
+                ),
+                index_type=sql.Identifier(index_param["index_type"]),
+                # This assumes that the quantization_type value matches the quantization function name
+                quantization_type=sql.SQL(index_param["quantization_type"]),
+                dim=self.dim,
+                embedding_metric=sql.Identifier(index_param["metric"]),
+            )
+        else:
+            index_create_sql = sql.SQL(
+                """
+                CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
+                USING {index_type} (embedding {embedding_metric})
+                """
+            ).format(
+                index_name=sql.Identifier(self._index_name),
+                table_name=sql.Identifier(self.table_name),
+                index_type=sql.Identifier(index_param["index_type"]),
+                embedding_metric=sql.Identifier(index_param["metric"]),
+            )
+
         index_create_sql_with_with_clause = (
             index_create_sql + with_clause
         ).join(" ")
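
Rendered, the quantized branch creates an expression index along these lines (again a sketch; the index type, operator class, and dimension depend on the case config, and `bit_hamming_ops` is only an illustrative pgvector operator class for `bit` expressions):

    CREATE INDEX IF NOT EXISTS "hnsw_index" ON public."test_table"
    USING "hnsw" ((binary_quantize("embedding")::bit(1536)) "bit_hamming_ops")

Because the index is built over the same cast expression that `_generate_search_query` emits in its ORDER BY, the planner can match the two and use the index.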
@@ -341,10 +442,28 @@
         assert self.conn is not None, "Connection is not initialized"
         assert self.cursor is not None, "Cursor is not initialized"
 
+        index_param = self.case_config.index_param()
+        search_param = self.case_config.search_param()
         q = np.asarray(query)
-        # TODO add filters support
-        result = self.cursor.execute(
-            self._unfiltered_search, (q, k), prepare=True, binary=True
-        )
+        if filters:
+            gt = filters.get("id")
+            if index_param["quantization_type"] == "bit" and search_param["reranking"]:
+                result = self.cursor.execute(
+                    self._filtered_search, (q, gt, q, k), prepare=True, binary=True
+                )
+            else:
+                result = self.cursor.execute(
+                    self._filtered_search, (gt, q, k), prepare=True, binary=True
+                )
+
+        else:
+            if index_param["quantization_type"] == "bit" and search_param["reranking"]:
+                result = self.cursor.execute(
+                    self._unfiltered_search, (q, q, k), prepare=True, binary=True
+                )
+            else:
+                result = self.cursor.execute(
+                    self._unfiltered_search, (q, k), prepare=True, binary=True
+                )
 
         return [int(i[0]) for i in result.fetchall()]
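
The parameter tuples line up with the placeholders produced by `_generate_search_query`: the bit-quantized reranking path with a filter has four placeholders, in order the reranking vector, the `WHERE id >= %s` bound, the quantized search vector, and the final `LIMIT`, hence `(q, gt, q, k)`; the other paths drop the reranking vector and/or the filter bound, giving `(gt, q, k)`, `(q, q, k)`, or `(q, k)` respectively.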

vectordb_bench/backend/clients/pgvectorscale/cli.py (new file)
@@ -0,0 +1,108 @@
+import click
+import os
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from typing import Annotated, Unpack
+from vectordb_bench.backend.clients import DB
+
+
+class PgVectorScaleTypedDict(CommonTypedDict):
+    user_name: Annotated[
+        str, click.option("--user-name", type=str, help="Db username", required=True)
+    ]
+    password: Annotated[
+        str,
+        click.option("--password",
+                     type=str,
+                     help="Postgres database password",
+                     default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
+                     show_default="$POSTGRES_PASSWORD",
+                     ),
+    ]
+
+    host: Annotated[
+        str, click.option("--host", type=str, help="Db host", required=True)
+    ]
+    db_name: Annotated[
+        str, click.option("--db-name", type=str, help="Db name", required=True)
+    ]
+
+
+class PgVectorScaleDiskAnnTypedDict(PgVectorScaleTypedDict):
+    storage_layout: Annotated[
+        str,
+        click.option(
+            "--storage-layout", type=str, help="Streaming DiskANN storage layout",
+        ),
+    ]
+    num_neighbors: Annotated[
+        int,
+        click.option(
+            "--num-neighbors", type=int, help="Streaming DiskANN num neighbors",
+        ),
+    ]
+    search_list_size: Annotated[
+        int,
+        click.option(
+            "--search-list-size", type=int, help="Streaming DiskANN search list size",
+        ),
+    ]
+    max_alpha: Annotated[
+        float,
+        click.option(
+            "--max-alpha", type=float, help="Streaming DiskANN max alpha",
+        ),
+    ]
+    num_dimensions: Annotated[
+        int,
+        click.option(
+            "--num-dimensions", type=int, help="Streaming DiskANN num dimensions",
+        ),
+    ]
+    query_search_list_size: Annotated[
+        int,
+        click.option(
+            "--query-search-list-size", type=int, help="Streaming DiskANN query search list size",
+        ),
+    ]
+    query_rescore: Annotated[
+        int,
+        click.option(
+            "--query-rescore", type=int, help="Streaming DiskANN query rescore",
+        ),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(PgVectorScaleDiskAnnTypedDict)
+def PgVectorScaleDiskAnn(
+    **parameters: Unpack[PgVectorScaleDiskAnnTypedDict],
+):
+    from .config import PgVectorScaleConfig, PgVectorScaleStreamingDiskANNConfig
+
+    run(
+        db=DB.PgVectorScale,
+        db_config=PgVectorScaleConfig(
+            db_label=parameters["db_label"],
+            user_name=SecretStr(parameters["user_name"]),
+            password=SecretStr(parameters["password"]),
+            host=parameters["host"],
+            db_name=parameters["db_name"],
+        ),
+        db_case_config=PgVectorScaleStreamingDiskANNConfig(
+            storage_layout=parameters["storage_layout"],
+            num_neighbors=parameters["num_neighbors"],
+            search_list_size=parameters["search_list_size"],
+            max_alpha=parameters["max_alpha"],
+            num_dimensions=parameters["num_dimensions"],
+            query_search_list_size=parameters["query_search_list_size"],
+            query_rescore=parameters["query_rescore"],
+        ),
+        **parameters,
+    )
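
Assuming click's default command naming (the function name lowercased), the new subcommand should be invocable roughly like this; every flag value below is illustrative only, and `--password` falls back to `$POSTGRES_PASSWORD` when omitted:

    vectordbbench pgvectorscalediskann --user-name postgres --host localhost \
        --db-name vdb --storage-layout memory_optimized --num-neighbors 50 \
        --search-list-size 100 --max-alpha 1.2 --num-dimensions 0 \
        --query-search-list-size 100 --query-rescore 50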

vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py
@@ -22,6 +22,9 @@ class PgVectorScale(VectorDB):
     conn: psycopg.Connection[Any] | None = None
     coursor: psycopg.Cursor[Any] | None = None
 
+    _unfiltered_search: sql.Composed
+    _filtered_search: sql.Composed
+
     def __init__(
         self,
         dim: int,
@@ -99,6 +102,16 @@ class PgVectorScale(VectorDB):
             self.cursor.execute(command)
             self.conn.commit()
 
+        self._filtered_search = sql.Composed(
+            [
+                sql.SQL("SELECT id FROM public.{} WHERE id >= %s ORDER BY embedding ").format(
+                    sql.Identifier(self.table_name),
+                ),
+                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                sql.SQL(" %s::vector LIMIT %s::int")
+            ]
+        )
+
         self._unfiltered_search = sql.Composed(
             [
                 sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
@@ -264,9 +277,14 @@ class PgVectorScale(VectorDB):
         assert self.cursor is not None, "Cursor is not initialized"
 
         q = np.asarray(query)
-        # TODO add filters support
-        result = self.cursor.execute(
-            self._unfiltered_search, (q, k), prepare=True, binary=True
-        )
+        if filters:
+            gt = filters.get("id")
+            result = self.cursor.execute(
+                self._filtered_search, (gt, q, k), prepare=True, binary=True
+            )
+        else:
+            result = self.cursor.execute(
+                self._unfiltered_search, (q, k), prepare=True, binary=True
+            )
 
         return [int(i[0]) for i in result.fetchall()]
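
As in the pgvector client, filter support here is limited to the benchmark's id-range predicate: `filters.get("id")` supplies the lower bound that is bound to the `WHERE id >= %s` placeholder of `_filtered_search`, followed by the query vector and the result limit.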

vectordb_bench/backend/clients/pinecone/config.py
@@ -4,12 +4,10 @@ from ..api import DBConfig
 
 class PineconeConfig(DBConfig):
     api_key: SecretStr
-    environment: SecretStr
    index_name: str
 
    def to_dict(self) -> dict:
        return {
            "api_key": self.api_key.get_secret_value(),
-           "environment": self.environment.get_secret_value(),
            "index_name": self.index_name,
        }

vectordb_bench/backend/clients/pinecone/pinecone.py
@@ -3,7 +3,7 @@
 import logging
 from contextlib import contextmanager
 from typing import Type
-
+import pinecone
 from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
 from .config import PineconeConfig
 
@@ -11,7 +11,8 @@ from .config import PineconeConfig
 log = logging.getLogger(__name__)
 
 PINECONE_MAX_NUM_PER_BATCH = 1000
-PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024 # 2MB
+PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024  # 2MB
+
 
 class Pinecone(VectorDB):
     def __init__(
@@ -23,30 +24,25 @@ class Pinecone(VectorDB):
         **kwargs,
     ):
         """Initialize wrapper around the milvus vector database."""
-        self.index_name = db_config["index_name"]
-        self.api_key = db_config["api_key"]
-        self.environment = db_config["environment"]
-        self.batch_size = int(min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH))
-        # Pincone will make connections with server while import
-        # so place the import here.
-        import pinecone
-        pinecone.init(
-            api_key=self.api_key, environment=self.environment)
+        self.index_name = db_config.get("index_name", "")
+        self.api_key = db_config.get("api_key", "")
+        self.batch_size = int(
+            min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH)
+        )
+
+        pc = pinecone.Pinecone(api_key=self.api_key)
+        index = pc.Index(self.index_name)
+
         if drop_old:
-            list_indexes = pinecone.list_indexes()
-            if self.index_name in list_indexes:
-                index = pinecone.Index(self.index_name)
-                index_dim = index.describe_index_stats()["dimension"]
-                if (index_dim != dim):
-                    raise ValueError(
-                        f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}")
-                log.info(
-                    f"Pinecone client delete old index: {self.index_name}")
-                index.delete(delete_all=True)
-                index.close()
-            else:
+            index_stats = index.describe_index_stats()
+            index_dim = index_stats["dimension"]
+            if index_dim != dim:
                 raise ValueError(
-                    f"Pinecone index {self.index_name} does not exist")
+                    f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}"
+                )
+            for namespace in index_stats["namespaces"]:
+                log.info(f"Pinecone index delete namespace: {namespace}")
+                index.delete(delete_all=True, namespace=namespace)
 
         self._metadata_key = "meta"
 
@@ -59,13 +55,10 @@
         return EmptyDBCaseConfig
 
     @contextmanager
-    def init(self) -> None:
-        import pinecone
-        pinecone.init(
-            api_key=self.api_key, environment=self.environment)
-        self.index = pinecone.Index(self.index_name)
+    def init(self):
+        pc = pinecone.Pinecone(api_key=self.api_key)
+        self.index = pc.Index(self.index_name)
         yield
-        self.index.close()
 
     def ready_to_load(self):
         pass
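
Taken together, these changes appear to move the wrapper from the pinecone-client v2 style (module-level `pinecone.init()` / `pinecone.Index`, which required an `environment`) to the v3-style object API, where a `pinecone.Pinecone(api_key=...)` client hands out index handles. That would also explain the `environment` field dropped from `PineconeConfig` above and the removal of `index.close()`.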
@@ -83,11 +76,16 @@
         insert_count = 0
         try:
             for batch_start_offset in range(0, len(embeddings), self.batch_size):
-                batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
+                batch_end_offset = min(
+                    batch_start_offset + self.batch_size, len(embeddings)
+                )
                 insert_datas = []
                 for i in range(batch_start_offset, batch_end_offset):
-                    insert_data = (str(metadata[i]), embeddings[i], {
-                        self._metadata_key: metadata[i]})
+                    insert_data = (
+                        str(metadata[i]),
+                        embeddings[i],
+                        {self._metadata_key: metadata[i]},
+                    )
                     insert_datas.append(insert_data)
                 self.index.upsert(insert_datas)
                 insert_count += batch_end_offset - batch_start_offset
@@ -101,7 +99,7 @@
         k: int = 100,
         filters: dict | None = None,
         timeout: int | None = None,
-    ) -> list[tuple[int, float]]:
+    ) -> list[int]:
         if filters is None:
             pinecone_filters = {}
         else:
@@ -111,9 +109,9 @@
                 top_k=k,
                 vector=query,
                 filter=pinecone_filters,
-            )['matches']
+            )["matches"]
         except Exception as e:
             print(f"Error querying index: {e}")
             raise e
-        id_res = [int(one_res['id']) for one_res in res]
+        id_res = [int(one_res["id"]) for one_res in res]
         return id_res

vectordb_bench/backend/clients/redis/cli.py
@@ -3,6 +3,9 @@ from typing import Annotated, TypedDict, Unpack
 import click
 from pydantic import SecretStr
 
+from .config import RedisHNSWConfig
+
+
 from ....cli.cli import (
     CommonTypedDict,
     HNSWFlavor2,
@@ -69,6 +72,11 @@ def Redis(**parameters: Unpack[RedisHNSWTypedDict]):
             ssl=parameters["ssl"],
             ssl_ca_certs=parameters["ssl_ca_certs"],
             cmd=parameters["cmd"],
+        ),
+        db_case_config=RedisHNSWConfig(
+            M=parameters["m"],
+            efConstruction=parameters["ef_construction"],
+            ef=parameters["ef_runtime"],
         ),
         **parameters,
     )

vectordb_bench/backend/clients/redis/config.py
@@ -1,14 +1,45 @@
-from pydantic import SecretStr
-from ..api import DBConfig
+from pydantic import SecretStr, BaseModel
+from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
 
 class RedisConfig(DBConfig):
-    password: SecretStr
+    password: SecretStr | None = None
     host: SecretStr
-    port: int = None
+    port: int | None = None
 
     def to_dict(self) -> dict:
         return {
             "host": self.host.get_secret_value(),
             "port": self.port,
-            "password": self.password.get_secret_value(),
-        }
+            "password": self.password.get_secret_value() if self.password is not None else None,
+        }
+
+
+
+class RedisIndexConfig(BaseModel):
+    """Base config for milvus"""
+
+    metric_type: MetricType | None = None
+
+    def parse_metric(self) -> str:
+        if not self.metric_type:
+            return ""
+        return self.metric_type.value
+
+class RedisHNSWConfig(RedisIndexConfig, DBCaseConfig):
+    M: int
+    efConstruction: int
+    ef: int | None = None
+    index: IndexType = IndexType.HNSW
+
+    def index_param(self) -> dict:
+        return {
+            "metric_type": self.parse_metric(),
+            "index_type": self.index.value,
+            "params": {"M": self.M, "efConstruction": self.efConstruction},
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "metric_type": self.parse_metric(),
+            "params": {"ef": self.ef},
+        }
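
A quick sketch of what the new case config emits, runnable against the installed package (the enum string values "COSINE" and "HNSW" are assumptions; check MetricType and IndexType in clients/api.py):

    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.redis.config import RedisHNSWConfig

    # Parameter values here are illustrative only.
    cfg = RedisHNSWConfig(metric_type=MetricType.COSINE, M=16, efConstruction=200, ef=64)
    print(cfg.index_param())
    # e.g. {'metric_type': 'COSINE', 'index_type': 'HNSW',
    #       'params': {'M': 16, 'efConstruction': 200}}
    print(cfg.search_param())
    # e.g. {'metric_type': 'COSINE', 'params': {'ef': 64}}

Incidentally, the `"""Base config for milvus"""` docstring on `RedisIndexConfig` reads like a copy-paste leftover from the Milvus config.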

vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py
@@ -23,7 +23,7 @@ class WeaviateCloud(VectorDB):
         **kwargs,
     ):
         """Initialize wrapper around the weaviate vector database."""
-        db_config.update("auth_client_secret", weaviate.AuthApiKey(api_key=db_config.get("auth_client_secret")))
+        db_config.update({"auth_client_secret": weaviate.AuthApiKey(api_key=db_config.get("auth_client_secret"))})
         self.db_config = db_config
         self.case_config = db_case_config
         self.collection_name = collection_name

vectordb_bench/backend/runner/mp_runner.py
@@ -2,6 +2,7 @@ import time
 import traceback
 import concurrent
 import multiprocessing as mp
+import random
 import logging
 from typing import Iterable
 import numpy as np
@@ -46,7 +47,7 @@ class MultiProcessingSearchRunner:
                 cond.wait()
 
         with self.db.init():
-            num, idx = len(test_data), 0
+            num, idx = len(test_data), random.randint(0, len(test_data) - 1)
 
             start_time = time.perf_counter()
             count = 0
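
Starting each worker at a random offset into the shared query list, rather than at index 0, presumably keeps concurrent search processes from issuing the identical query sequence in lockstep, which could otherwise let server-side caching flatter the measured concurrency numbers.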