vectordb-bench 0.0.19__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. vectordb_bench/__init__.py +49 -24
  2. vectordb_bench/__main__.py +4 -3
  3. vectordb_bench/backend/assembler.py +12 -13
  4. vectordb_bench/backend/cases.py +55 -45
  5. vectordb_bench/backend/clients/__init__.py +75 -14
  6. vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +1 -2
  7. vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +3 -4
  8. vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +111 -70
  9. vectordb_bench/backend/clients/aliyun_opensearch/config.py +6 -7
  10. vectordb_bench/backend/clients/alloydb/alloydb.py +58 -80
  11. vectordb_bench/backend/clients/alloydb/cli.py +51 -34
  12. vectordb_bench/backend/clients/alloydb/config.py +30 -30
  13. vectordb_bench/backend/clients/api.py +5 -9
  14. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +46 -47
  15. vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
  16. vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
  17. vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
  18. vectordb_bench/backend/clients/chroma/chroma.py +38 -36
  19. vectordb_bench/backend/clients/chroma/config.py +4 -2
  20. vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
  21. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +23 -22
  22. vectordb_bench/backend/clients/memorydb/cli.py +8 -8
  23. vectordb_bench/backend/clients/memorydb/config.py +2 -2
  24. vectordb_bench/backend/clients/memorydb/memorydb.py +65 -53
  25. vectordb_bench/backend/clients/milvus/cli.py +41 -83
  26. vectordb_bench/backend/clients/milvus/config.py +18 -8
  27. vectordb_bench/backend/clients/milvus/milvus.py +18 -19
  28. vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
  29. vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
  30. vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +55 -73
  31. vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
  32. vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
  33. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +33 -34
  34. vectordb_bench/backend/clients/pgvector/cli.py +40 -31
  35. vectordb_bench/backend/clients/pgvector/config.py +63 -73
  36. vectordb_bench/backend/clients/pgvector/pgvector.py +97 -98
  37. vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
  38. vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
  39. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +38 -43
  40. vectordb_bench/backend/clients/pinecone/config.py +1 -0
  41. vectordb_bench/backend/clients/pinecone/pinecone.py +14 -21
  42. vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
  43. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +40 -31
  44. vectordb_bench/backend/clients/redis/cli.py +6 -12
  45. vectordb_bench/backend/clients/redis/config.py +7 -5
  46. vectordb_bench/backend/clients/redis/redis.py +94 -58
  47. vectordb_bench/backend/clients/test/cli.py +1 -2
  48. vectordb_bench/backend/clients/test/config.py +2 -2
  49. vectordb_bench/backend/clients/test/test.py +4 -5
  50. vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
  51. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
  52. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +36 -22
  53. vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
  54. vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
  55. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
  56. vectordb_bench/backend/data_source.py +30 -18
  57. vectordb_bench/backend/dataset.py +47 -27
  58. vectordb_bench/backend/result_collector.py +2 -3
  59. vectordb_bench/backend/runner/__init__.py +4 -6
  60. vectordb_bench/backend/runner/mp_runner.py +85 -34
  61. vectordb_bench/backend/runner/rate_runner.py +30 -19
  62. vectordb_bench/backend/runner/read_write_runner.py +51 -23
  63. vectordb_bench/backend/runner/serial_runner.py +91 -48
  64. vectordb_bench/backend/runner/util.py +4 -3
  65. vectordb_bench/backend/task_runner.py +92 -72
  66. vectordb_bench/backend/utils.py +17 -10
  67. vectordb_bench/base.py +0 -1
  68. vectordb_bench/cli/cli.py +65 -60
  69. vectordb_bench/cli/vectordbbench.py +6 -7
  70. vectordb_bench/frontend/components/check_results/charts.py +8 -19
  71. vectordb_bench/frontend/components/check_results/data.py +4 -16
  72. vectordb_bench/frontend/components/check_results/filters.py +8 -16
  73. vectordb_bench/frontend/components/check_results/nav.py +4 -4
  74. vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
  75. vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
  76. vectordb_bench/frontend/components/concurrent/charts.py +12 -12
  77. vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
  78. vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
  79. vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
  80. vectordb_bench/frontend/components/custom/initStyle.py +1 -1
  81. vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
  82. vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
  83. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
  84. vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
  85. vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
  86. vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
  87. vectordb_bench/frontend/components/tables/data.py +3 -6
  88. vectordb_bench/frontend/config/dbCaseConfigs.py +51 -84
  89. vectordb_bench/frontend/pages/concurrent.py +3 -5
  90. vectordb_bench/frontend/pages/custom.py +30 -9
  91. vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
  92. vectordb_bench/frontend/pages/run_test.py +3 -7
  93. vectordb_bench/frontend/utils.py +1 -1
  94. vectordb_bench/frontend/vdb_benchmark.py +4 -6
  95. vectordb_bench/interface.py +56 -26
  96. vectordb_bench/log_util.py +59 -64
  97. vectordb_bench/metric.py +10 -11
  98. vectordb_bench/models.py +26 -43
  99. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/METADATA +22 -15
  100. vectordb_bench-0.0.20.dist-info/RECORD +135 -0
  101. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/WHEEL +1 -1
  102. vectordb_bench-0.0.19.dist-info/RECORD +0 -135
  103. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/LICENSE +0 -0
  104. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/entry_points.txt +0 -0
  105. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,9 @@
1
1
  """Wrapper around the Pgvectorscale vector database over VectorDB"""
2
2
 
3
3
  import logging
4
- import pprint
4
+ from collections.abc import Generator
5
5
  from contextlib import contextmanager
6
- from typing import Any, Generator, Optional, Tuple
6
+ from typing import Any
7
7
 
8
8
  import numpy as np
9
9
  import psycopg
@@ -44,20 +44,21 @@ class PgVectorScale(VectorDB):
44
44
  self._primary_field = "id"
45
45
  self._vector_field = "embedding"
46
46
 
47
- self.conn, self.cursor = self._create_connection(**self.db_config)
47
+ self.conn, self.cursor = self._create_connection(**self.db_config)
48
48
 
49
49
  log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}")
50
50
  if not any(
51
51
  (
52
52
  self.case_config.create_index_before_load,
53
53
  self.case_config.create_index_after_load,
54
- )
54
+ ),
55
55
  ):
56
- err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
57
- log.error(err)
58
- raise RuntimeError(
59
- f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
56
+ msg = (
57
+ f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
58
+ f"{self.name} config values: {self.db_config}\n{self.case_config}"
60
59
  )
60
+ log.error(msg)
61
+ raise RuntimeError(msg)
61
62
 
62
63
  if drop_old:
63
64
  self._drop_index()
@@ -72,7 +73,7 @@ class PgVectorScale(VectorDB):
72
73
  self.conn = None
73
74
 
74
75
  @staticmethod
75
- def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
76
+ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
76
77
  conn = psycopg.connect(**kwargs)
77
78
  conn.cursor().execute("CREATE EXTENSION IF NOT EXISTS vectorscale CASCADE")
78
79
  conn.commit()
@@ -101,25 +102,25 @@ class PgVectorScale(VectorDB):
101
102
  log.debug(command.as_string(self.cursor))
102
103
  self.cursor.execute(command)
103
104
  self.conn.commit()
104
-
105
+
105
106
  self._filtered_search = sql.Composed(
106
107
  [
107
108
  sql.SQL("SELECT id FROM public.{} WHERE id >= %s ORDER BY embedding ").format(
108
109
  sql.Identifier(self.table_name),
109
110
  ),
110
111
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
111
- sql.SQL(" %s::vector LIMIT %s::int")
112
- ]
112
+ sql.SQL(" %s::vector LIMIT %s::int"),
113
+ ],
113
114
  )
114
-
115
+
115
116
  self._unfiltered_search = sql.Composed(
116
117
  [
117
118
  sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
118
- sql.Identifier(self.table_name)
119
+ sql.Identifier(self.table_name),
119
120
  ),
120
121
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
121
122
  sql.SQL(" %s::vector LIMIT %s::int"),
122
- ]
123
+ ],
123
124
  )
124
125
 
125
126
  try:
@@ -137,8 +138,8 @@ class PgVectorScale(VectorDB):
137
138
 
138
139
  self.cursor.execute(
139
140
  sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
140
- table_name=sql.Identifier(self.table_name)
141
- )
141
+ table_name=sql.Identifier(self.table_name),
142
+ ),
142
143
  )
143
144
  self.conn.commit()
144
145
 
@@ -160,7 +161,7 @@ class PgVectorScale(VectorDB):
160
161
  log.info(f"{self.name} client drop index : {self._index_name}")
161
162
 
162
163
  drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
163
- index_name=sql.Identifier(self._index_name)
164
+ index_name=sql.Identifier(self._index_name),
164
165
  )
165
166
  log.debug(drop_index_sql.as_string(self.cursor))
166
167
  self.cursor.execute(drop_index_sql)
@@ -180,36 +181,31 @@ class PgVectorScale(VectorDB):
180
181
  sql.SQL("{option_name} = {val}").format(
181
182
  option_name=sql.Identifier(option_name),
182
183
  val=sql.Identifier(str(option_val)),
183
- )
184
+ ),
184
185
  )
185
-
186
+
186
187
  num_bits_per_dimension = "2" if self.dim < 900 else "1"
187
188
  options.append(
188
189
  sql.SQL("{option_name} = {val}").format(
189
190
  option_name=sql.Identifier("num_bits_per_dimension"),
190
191
  val=sql.Identifier(num_bits_per_dimension),
191
- )
192
+ ),
192
193
  )
193
194
 
194
- if any(options):
195
- with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
196
- else:
197
- with_clause = sql.Composed(())
195
+ with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
198
196
 
199
197
  index_create_sql = sql.SQL(
200
198
  """
201
- CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
199
+ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
202
200
  USING {index_type} (embedding {embedding_metric})
203
- """
201
+ """,
204
202
  ).format(
205
203
  index_name=sql.Identifier(self._index_name),
206
204
  table_name=sql.Identifier(self.table_name),
207
205
  index_type=sql.Identifier(index_param["index_type"].lower()),
208
206
  embedding_metric=sql.Identifier(index_param["metric"]),
209
207
  )
210
- index_create_sql_with_with_clause = (
211
- index_create_sql + with_clause
212
- ).join(" ")
208
+ index_create_sql_with_with_clause = (index_create_sql + with_clause).join(" ")
213
209
  log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
214
210
  self.cursor.execute(index_create_sql_with_with_clause)
215
211
  self.conn.commit()
@@ -223,14 +219,12 @@ class PgVectorScale(VectorDB):
223
219
 
224
220
  self.cursor.execute(
225
221
  sql.SQL(
226
- "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
227
- ).format(table_name=sql.Identifier(self.table_name), dim=dim)
222
+ "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
223
+ ).format(table_name=sql.Identifier(self.table_name), dim=dim),
228
224
  )
229
225
  self.conn.commit()
230
226
  except Exception as e:
231
- log.warning(
232
- f"Failed to create pgvectorscale table: {self.table_name} error: {e}"
233
- )
227
+ log.warning(f"Failed to create pgvectorscale table: {self.table_name} error: {e}")
234
228
  raise e from None
235
229
 
236
230
  def insert_embeddings(
@@ -238,7 +232,7 @@ class PgVectorScale(VectorDB):
238
232
  embeddings: list[list[float]],
239
233
  metadata: list[int],
240
234
  **kwargs: Any,
241
- ) -> Tuple[int, Optional[Exception]]:
235
+ ) -> tuple[int, Exception | None]:
242
236
  assert self.conn is not None, "Connection is not initialized"
243
237
  assert self.cursor is not None, "Cursor is not initialized"
244
238
 
@@ -248,8 +242,8 @@ class PgVectorScale(VectorDB):
248
242
 
249
243
  with self.cursor.copy(
250
244
  sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
251
- table_name=sql.Identifier(self.table_name)
252
- )
245
+ table_name=sql.Identifier(self.table_name),
246
+ ),
253
247
  ) as copy:
254
248
  copy.set_types(["bigint", "vector"])
255
249
  for i, row in enumerate(metadata_arr):
@@ -262,7 +256,7 @@ class PgVectorScale(VectorDB):
262
256
  return len(metadata), None
263
257
  except Exception as e:
264
258
  log.warning(
265
- f"Failed to insert data into pgvector table ({self.table_name}), error: {e}"
259
+ f"Failed to insert data into pgvector table ({self.table_name}), error: {e}",
266
260
  )
267
261
  return 0, e
268
262
 
@@ -280,11 +274,12 @@ class PgVectorScale(VectorDB):
280
274
  if filters:
281
275
  gt = filters.get("id")
282
276
  result = self.cursor.execute(
283
- self._filtered_search, (gt, q, k), prepare=True, binary=True
277
+ self._filtered_search,
278
+ (gt, q, k),
279
+ prepare=True,
280
+ binary=True,
284
281
  )
285
282
  else:
286
- result = self.cursor.execute(
287
- self._unfiltered_search, (q, k), prepare=True, binary=True
288
- )
283
+ result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True)
289
284
 
290
285
  return [int(i[0]) for i in result.fetchall()]
@@ -1,4 +1,5 @@
1
1
  from pydantic import SecretStr
2
+
2
3
  from ..api import DBConfig
3
4
 
4
5
 
@@ -2,11 +2,11 @@
2
2
 
3
3
  import logging
4
4
  from contextlib import contextmanager
5
- from typing import Type
5
+
6
6
  import pinecone
7
- from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
8
- from .config import PineconeConfig
9
7
 
8
+ from ..api import DBCaseConfig, DBConfig, EmptyDBCaseConfig, IndexType, VectorDB
9
+ from .config import PineconeConfig
10
10
 
11
11
  log = logging.getLogger(__name__)
12
12
 
@@ -17,7 +17,7 @@ PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024 # 2MB
17
17
  class Pinecone(VectorDB):
18
18
  def __init__(
19
19
  self,
20
- dim,
20
+ dim: int,
21
21
  db_config: dict,
22
22
  db_case_config: DBCaseConfig,
23
23
  drop_old: bool = False,
@@ -27,7 +27,7 @@ class Pinecone(VectorDB):
27
27
  self.index_name = db_config.get("index_name", "")
28
28
  self.api_key = db_config.get("api_key", "")
29
29
  self.batch_size = int(
30
- min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH)
30
+ min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH),
31
31
  )
32
32
 
33
33
  pc = pinecone.Pinecone(api_key=self.api_key)
@@ -37,9 +37,8 @@ class Pinecone(VectorDB):
37
37
  index_stats = index.describe_index_stats()
38
38
  index_dim = index_stats["dimension"]
39
39
  if index_dim != dim:
40
- raise ValueError(
41
- f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}"
42
- )
40
+ msg = f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}"
41
+ raise ValueError(msg)
43
42
  for namespace in index_stats["namespaces"]:
44
43
  log.info(f"Pinecone index delete namespace: {namespace}")
45
44
  index.delete(delete_all=True, namespace=namespace)
@@ -47,11 +46,11 @@ class Pinecone(VectorDB):
47
46
  self._metadata_key = "meta"
48
47
 
49
48
  @classmethod
50
- def config_cls(cls) -> Type[DBConfig]:
49
+ def config_cls(cls) -> type[DBConfig]:
51
50
  return PineconeConfig
52
51
 
53
52
  @classmethod
54
- def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
53
+ def case_config_cls(cls, index_type: IndexType | None = None) -> type[DBCaseConfig]:
55
54
  return EmptyDBCaseConfig
56
55
 
57
56
  @contextmanager
@@ -76,9 +75,7 @@ class Pinecone(VectorDB):
76
75
  insert_count = 0
77
76
  try:
78
77
  for batch_start_offset in range(0, len(embeddings), self.batch_size):
79
- batch_end_offset = min(
80
- batch_start_offset + self.batch_size, len(embeddings)
81
- )
78
+ batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
82
79
  insert_datas = []
83
80
  for i in range(batch_start_offset, batch_end_offset):
84
81
  insert_data = (
@@ -100,10 +97,7 @@ class Pinecone(VectorDB):
100
97
  filters: dict | None = None,
101
98
  timeout: int | None = None,
102
99
  ) -> list[int]:
103
- if filters is None:
104
- pinecone_filters = {}
105
- else:
106
- pinecone_filters = {self._metadata_key: {"$gte": filters["id"]}}
100
+ pinecone_filters = {} if filters is None else {self._metadata_key: {"$gte": filters["id"]}}
107
101
  try:
108
102
  res = self.index.query(
109
103
  top_k=k,
@@ -111,7 +105,6 @@ class Pinecone(VectorDB):
111
105
  filter=pinecone_filters,
112
106
  )["matches"]
113
107
  except Exception as e:
114
- print(f"Error querying index: {e}")
115
- raise e
116
- id_res = [int(one_res["id"]) for one_res in res]
117
- return id_res
108
+ log.warning(f"Error querying index: {e}")
109
+ raise e from e
110
+ return [int(one_res["id"]) for one_res in res]
@@ -1,7 +1,7 @@
1
- from pydantic import BaseModel, SecretStr
1
+ from pydantic import BaseModel, SecretStr, validator
2
+
3
+ from ..api import DBCaseConfig, DBConfig, MetricType
2
4
 
3
- from ..api import DBConfig, DBCaseConfig, MetricType
4
- from pydantic import validator
5
5
 
6
6
  # Allowing `api_key` to be left empty, to ensure compatibility with the open-source Qdrant.
7
7
  class QdrantConfig(DBConfig):
@@ -16,17 +16,19 @@ class QdrantConfig(DBConfig):
16
16
  "api_key": self.api_key.get_secret_value(),
17
17
  "prefer_grpc": True,
18
18
  }
19
- else:
20
- return {"url": self.url.get_secret_value(),}
21
-
19
+ return {
20
+ "url": self.url.get_secret_value(),
21
+ }
22
+
22
23
  @validator("*")
23
- def not_empty_field(cls, v, field):
24
+ def not_empty_field(cls, v: any, field: any):
24
25
  if field.name in ["api_key", "db_label"]:
25
26
  return v
26
- if isinstance(v, (str, SecretStr)) and len(v) == 0:
27
+ if isinstance(v, str | SecretStr) and len(v) == 0:
27
28
  raise ValueError("Empty string!")
28
29
  return v
29
30
 
31
+
30
32
  class QdrantIndexConfig(BaseModel, DBCaseConfig):
31
33
  metric_type: MetricType | None = None
32
34
 
@@ -40,8 +42,7 @@ class QdrantIndexConfig(BaseModel, DBCaseConfig):
40
42
  return "Cosine"
41
43
 
42
44
  def index_param(self) -> dict:
43
- params = {"distance": self.parse_metric()}
44
- return params
45
+ return {"distance": self.parse_metric()}
45
46
 
46
47
  def search_param(self) -> dict:
47
48
  return {}
@@ -4,23 +4,26 @@ import logging
4
4
  import time
5
5
  from contextlib import contextmanager
6
6
 
7
- from ..api import VectorDB, DBCaseConfig
7
+ from qdrant_client import QdrantClient
8
8
  from qdrant_client.http.models import (
9
- CollectionStatus,
10
- VectorParams,
11
- PayloadSchemaType,
12
9
  Batch,
13
- Filter,
10
+ CollectionStatus,
14
11
  FieldCondition,
12
+ Filter,
13
+ PayloadSchemaType,
15
14
  Range,
15
+ VectorParams,
16
16
  )
17
17
 
18
- from qdrant_client import QdrantClient
19
-
18
+ from ..api import DBCaseConfig, VectorDB
20
19
 
21
20
  log = logging.getLogger(__name__)
22
21
 
23
22
 
23
+ SECONDS_WAITING_FOR_INDEXING_API_CALL = 5
24
+ QDRANT_BATCH_SIZE = 500
25
+
26
+
24
27
  class QdrantCloud(VectorDB):
25
28
  def __init__(
26
29
  self,
@@ -57,16 +60,14 @@ class QdrantCloud(VectorDB):
57
60
  self.qdrant_client = QdrantClient(**self.db_config)
58
61
  yield
59
62
  self.qdrant_client = None
60
- del(self.qdrant_client)
63
+ del self.qdrant_client
61
64
 
62
65
  def ready_to_load(self):
63
66
  pass
64
67
 
65
-
66
68
  def optimize(self):
67
69
  assert self.qdrant_client, "Please call self.init() before"
68
70
  # wait for vectors to be fully indexed
69
- SECONDS_WAITING_FOR_INDEXING_API_CALL = 5
70
71
  try:
71
72
  while True:
72
73
  info = self.qdrant_client.get_collection(self.collection_name)
@@ -74,19 +75,26 @@ class QdrantCloud(VectorDB):
74
75
  if info.status != CollectionStatus.GREEN:
75
76
  continue
76
77
  if info.status == CollectionStatus.GREEN:
77
- log.info(f"Stored vectors: {info.vectors_count}, Indexed vectors: {info.indexed_vectors_count}, Collection status: {info.indexed_vectors_count}")
78
+ msg = (
79
+ f"Stored vectors: {info.vectors_count}, Indexed vectors: {info.indexed_vectors_count}, ",
80
+ f"Collection status: {info.indexed_vectors_count}",
81
+ )
82
+ log.info(msg)
78
83
  return
79
84
  except Exception as e:
80
85
  log.warning(f"QdrantCloud ready to search error: {e}")
81
86
  raise e from None
82
87
 
83
- def _create_collection(self, dim, qdrant_client: int):
88
+ def _create_collection(self, dim: int, qdrant_client: QdrantClient):
84
89
  log.info(f"Create collection: {self.collection_name}")
85
90
 
86
91
  try:
87
92
  qdrant_client.create_collection(
88
93
  collection_name=self.collection_name,
89
- vectors_config=VectorParams(size=dim, distance=self.case_config.index_param()["distance"])
94
+ vectors_config=VectorParams(
95
+ size=dim,
96
+ distance=self.case_config.index_param()["distance"],
97
+ ),
90
98
  )
91
99
 
92
100
  qdrant_client.create_payload_index(
@@ -109,13 +117,12 @@ class QdrantCloud(VectorDB):
109
117
  ) -> (int, Exception):
110
118
  """Insert embeddings into Milvus. should call self.init() first"""
111
119
  assert self.qdrant_client is not None
112
- QDRANT_BATCH_SIZE = 500
113
120
  try:
114
121
  # TODO: counts
115
122
  for offset in range(0, len(embeddings), QDRANT_BATCH_SIZE):
116
- vectors = embeddings[offset: offset + QDRANT_BATCH_SIZE]
117
- ids = metadata[offset: offset + QDRANT_BATCH_SIZE]
118
- payloads=[{self._primary_field: v} for v in ids]
123
+ vectors = embeddings[offset : offset + QDRANT_BATCH_SIZE]
124
+ ids = metadata[offset : offset + QDRANT_BATCH_SIZE]
125
+ payloads = [{self._primary_field: v} for v in ids]
119
126
  _ = self.qdrant_client.upsert(
120
127
  collection_name=self.collection_name,
121
128
  wait=True,
@@ -142,21 +149,23 @@ class QdrantCloud(VectorDB):
142
149
  f = None
143
150
  if filters:
144
151
  f = Filter(
145
- must=[FieldCondition(
146
- key = self._primary_field,
147
- range = Range(
148
- gt=filters.get('id'),
152
+ must=[
153
+ FieldCondition(
154
+ key=self._primary_field,
155
+ range=Range(
156
+ gt=filters.get("id"),
157
+ ),
149
158
  ),
150
- )]
159
+ ],
151
160
  )
152
161
 
153
- res = self.qdrant_client.search(
154
- collection_name=self.collection_name,
155
- query_vector=query,
156
- limit=k,
157
- query_filter=f,
158
- # with_payload=True,
159
- ),
162
+ res = (
163
+ self.qdrant_client.search(
164
+ collection_name=self.collection_name,
165
+ query_vector=query,
166
+ limit=k,
167
+ query_filter=f,
168
+ ),
169
+ )
160
170
 
161
- ret = [result.id for result in res[0]]
162
- return ret
171
+ return [result.id for result in res[0]]
@@ -3,9 +3,6 @@ from typing import Annotated, TypedDict, Unpack
3
3
  import click
4
4
  from pydantic import SecretStr
5
5
 
6
- from .config import RedisHNSWConfig
7
-
8
-
9
6
  from ....cli.cli import (
10
7
  CommonTypedDict,
11
8
  HNSWFlavor2,
@@ -14,12 +11,11 @@ from ....cli.cli import (
14
11
  run,
15
12
  )
16
13
  from .. import DB
14
+ from .config import RedisHNSWConfig
17
15
 
18
16
 
19
17
  class RedisTypedDict(TypedDict):
20
- host: Annotated[
21
- str, click.option("--host", type=str, help="Db host", required=True)
22
- ]
18
+ host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
23
19
  password: Annotated[str, click.option("--password", type=str, help="Db password")]
24
20
  port: Annotated[int, click.option("--port", type=int, default=6379, help="Db Port")]
25
21
  ssl: Annotated[
@@ -52,27 +48,25 @@ class RedisTypedDict(TypedDict):
52
48
  ]
53
49
 
54
50
 
55
- class RedisHNSWTypedDict(CommonTypedDict, RedisTypedDict, HNSWFlavor2):
56
- ...
51
+ class RedisHNSWTypedDict(CommonTypedDict, RedisTypedDict, HNSWFlavor2): ...
57
52
 
58
53
 
59
54
  @cli.command()
60
55
  @click_parameter_decorators_from_typed_dict(RedisHNSWTypedDict)
61
56
  def Redis(**parameters: Unpack[RedisHNSWTypedDict]):
62
57
  from .config import RedisConfig
58
+
63
59
  run(
64
60
  db=DB.Redis,
65
61
  db_config=RedisConfig(
66
62
  db_label=parameters["db_label"],
67
- password=SecretStr(parameters["password"])
68
- if parameters["password"]
69
- else None,
63
+ password=SecretStr(parameters["password"]) if parameters["password"] else None,
70
64
  host=SecretStr(parameters["host"]),
71
65
  port=parameters["port"],
72
66
  ssl=parameters["ssl"],
73
67
  ssl_ca_certs=parameters["ssl_ca_certs"],
74
68
  cmd=parameters["cmd"],
75
- ),
69
+ ),
76
70
  db_case_config=RedisHNSWConfig(
77
71
  M=parameters["m"],
78
72
  efConstruction=parameters["ef_construction"],
@@ -1,10 +1,12 @@
1
- from pydantic import SecretStr, BaseModel
2
- from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
1
+ from pydantic import BaseModel, SecretStr
2
+
3
+ from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
4
+
3
5
 
4
6
  class RedisConfig(DBConfig):
5
7
  password: SecretStr | None = None
6
8
  host: SecretStr
7
- port: int | None = None
9
+ port: int | None = None
8
10
 
9
11
  def to_dict(self) -> dict:
10
12
  return {
@@ -12,7 +14,6 @@ class RedisConfig(DBConfig):
12
14
  "port": self.port,
13
15
  "password": self.password.get_secret_value() if self.password is not None else None,
14
16
  }
15
-
16
17
 
17
18
 
18
19
  class RedisIndexConfig(BaseModel):
@@ -24,7 +25,8 @@ class RedisIndexConfig(BaseModel):
24
25
  if not self.metric_type:
25
26
  return ""
26
27
  return self.metric_type.value
27
-
28
+
29
+
28
30
  class RedisHNSWConfig(RedisIndexConfig, DBCaseConfig):
29
31
  M: int
30
32
  efConstruction: int