vectordb-bench 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in the public registry.
Files changed (105)
  1. vectordb_bench/__init__.py +49 -24
  2. vectordb_bench/__main__.py +4 -3
  3. vectordb_bench/backend/assembler.py +12 -13
  4. vectordb_bench/backend/cases.py +55 -45
  5. vectordb_bench/backend/clients/__init__.py +85 -14
  6. vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +1 -2
  7. vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +3 -4
  8. vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +112 -77
  9. vectordb_bench/backend/clients/aliyun_opensearch/config.py +6 -7
  10. vectordb_bench/backend/clients/alloydb/alloydb.py +59 -84
  11. vectordb_bench/backend/clients/alloydb/cli.py +51 -34
  12. vectordb_bench/backend/clients/alloydb/config.py +30 -30
  13. vectordb_bench/backend/clients/api.py +13 -24
  14. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +50 -54
  15. vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
  16. vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
  17. vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
  18. vectordb_bench/backend/clients/chroma/chroma.py +39 -40
  19. vectordb_bench/backend/clients/chroma/config.py +4 -2
  20. vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
  21. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +24 -26
  22. vectordb_bench/backend/clients/memorydb/cli.py +8 -8
  23. vectordb_bench/backend/clients/memorydb/config.py +2 -2
  24. vectordb_bench/backend/clients/memorydb/memorydb.py +67 -58
  25. vectordb_bench/backend/clients/milvus/cli.py +41 -83
  26. vectordb_bench/backend/clients/milvus/config.py +18 -8
  27. vectordb_bench/backend/clients/milvus/milvus.py +19 -39
  28. vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
  29. vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
  30. vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +56 -77
  31. vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
  32. vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
  33. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +34 -43
  34. vectordb_bench/backend/clients/pgvector/cli.py +40 -31
  35. vectordb_bench/backend/clients/pgvector/config.py +63 -73
  36. vectordb_bench/backend/clients/pgvector/pgvector.py +98 -104
  37. vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
  38. vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
  39. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +39 -49
  40. vectordb_bench/backend/clients/pinecone/config.py +1 -0
  41. vectordb_bench/backend/clients/pinecone/pinecone.py +15 -25
  42. vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
  43. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +41 -35
  44. vectordb_bench/backend/clients/redis/cli.py +6 -12
  45. vectordb_bench/backend/clients/redis/config.py +7 -5
  46. vectordb_bench/backend/clients/redis/redis.py +95 -62
  47. vectordb_bench/backend/clients/test/cli.py +2 -3
  48. vectordb_bench/backend/clients/test/config.py +2 -2
  49. vectordb_bench/backend/clients/test/test.py +5 -9
  50. vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
  51. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
  52. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +37 -26
  53. vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
  54. vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
  55. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
  56. vectordb_bench/backend/data_source.py +18 -14
  57. vectordb_bench/backend/dataset.py +47 -27
  58. vectordb_bench/backend/result_collector.py +2 -3
  59. vectordb_bench/backend/runner/__init__.py +4 -6
  60. vectordb_bench/backend/runner/mp_runner.py +56 -23
  61. vectordb_bench/backend/runner/rate_runner.py +30 -19
  62. vectordb_bench/backend/runner/read_write_runner.py +46 -22
  63. vectordb_bench/backend/runner/serial_runner.py +81 -46
  64. vectordb_bench/backend/runner/util.py +4 -3
  65. vectordb_bench/backend/task_runner.py +92 -92
  66. vectordb_bench/backend/utils.py +17 -10
  67. vectordb_bench/base.py +0 -1
  68. vectordb_bench/cli/cli.py +65 -60
  69. vectordb_bench/cli/vectordbbench.py +6 -7
  70. vectordb_bench/frontend/components/check_results/charts.py +8 -19
  71. vectordb_bench/frontend/components/check_results/data.py +4 -16
  72. vectordb_bench/frontend/components/check_results/filters.py +8 -16
  73. vectordb_bench/frontend/components/check_results/nav.py +4 -4
  74. vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
  75. vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
  76. vectordb_bench/frontend/components/concurrent/charts.py +12 -12
  77. vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
  78. vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
  79. vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
  80. vectordb_bench/frontend/components/custom/initStyle.py +1 -1
  81. vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
  82. vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
  83. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
  84. vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
  85. vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
  86. vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
  87. vectordb_bench/frontend/components/tables/data.py +3 -6
  88. vectordb_bench/frontend/config/dbCaseConfigs.py +51 -84
  89. vectordb_bench/frontend/pages/concurrent.py +3 -5
  90. vectordb_bench/frontend/pages/custom.py +30 -9
  91. vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
  92. vectordb_bench/frontend/pages/run_test.py +3 -7
  93. vectordb_bench/frontend/utils.py +1 -1
  94. vectordb_bench/frontend/vdb_benchmark.py +4 -6
  95. vectordb_bench/interface.py +45 -24
  96. vectordb_bench/log_util.py +59 -64
  97. vectordb_bench/metric.py +10 -11
  98. vectordb_bench/models.py +26 -43
  99. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/METADATA +22 -15
  100. vectordb_bench-0.0.21.dist-info/RECORD +135 -0
  101. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/WHEEL +1 -1
  102. vectordb_bench-0.0.19.dist-info/RECORD +0 -135
  103. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/LICENSE +0 -0
  104. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/entry_points.txt +0 -0
  105. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py

@@ -1,9 +1,9 @@
 """Wrapper around the Pgvectorscale vector database over VectorDB"""
 
 import logging
-import pprint
+from collections.abc import Generator
 from contextlib import contextmanager
-from typing import Any, Generator, Optional, Tuple
+from typing import Any
 
 import numpy as np
 import psycopg
@@ -44,20 +44,21 @@ class PgVectorScale(VectorDB):
         self._primary_field = "id"
         self._vector_field = "embedding"
 
-        self.conn, self.cursor = self._create_connection(**self.db_config)
+        self.conn, self.cursor = self._create_connection(**self.db_config)
 
         log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}")
         if not any(
             (
                 self.case_config.create_index_before_load,
                 self.case_config.create_index_after_load,
-            )
+            ),
         ):
-            err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
-            log.error(err)
-            raise RuntimeError(
-                f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
+            msg = (
+                f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
+                f"{self.name} config values: {self.db_config}\n{self.case_config}"
             )
+            log.error(msg)
+            raise RuntimeError(msg)
 
         if drop_old:
             self._drop_index()
@@ -72,7 +73,7 @@ class PgVectorScale(VectorDB):
         self.conn = None
 
     @staticmethod
-    def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
+    def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
         conn = psycopg.connect(**kwargs)
         conn.cursor().execute("CREATE EXTENSION IF NOT EXISTS vectorscale CASCADE")
         conn.commit()
@@ -101,25 +102,25 @@ class PgVectorScale(VectorDB):
         log.debug(command.as_string(self.cursor))
         self.cursor.execute(command)
         self.conn.commit()
-
+
         self._filtered_search = sql.Composed(
             [
                 sql.SQL("SELECT id FROM public.{} WHERE id >= %s ORDER BY embedding ").format(
                     sql.Identifier(self.table_name),
                 ),
                 sql.SQL(self.case_config.search_param()["metric_fun_op"]),
-                sql.SQL(" %s::vector LIMIT %s::int")
-            ]
+                sql.SQL(" %s::vector LIMIT %s::int"),
+            ],
         )
-
+
         self._unfiltered_search = sql.Composed(
             [
                 sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
-                    sql.Identifier(self.table_name)
+                    sql.Identifier(self.table_name),
                 ),
                 sql.SQL(self.case_config.search_param()["metric_fun_op"]),
                 sql.SQL(" %s::vector LIMIT %s::int"),
-            ]
+            ],
         )
 
         try:
@@ -137,15 +138,12 @@ class PgVectorScale(VectorDB):
 
         self.cursor.execute(
             sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
-                table_name=sql.Identifier(self.table_name)
-            )
+                table_name=sql.Identifier(self.table_name),
+            ),
         )
         self.conn.commit()
 
-    def ready_to_load(self):
-        pass
-
-    def optimize(self):
+    def optimize(self, data_size: int | None = None):
         self._post_insert()
 
     def _post_insert(self):
@@ -160,7 +158,7 @@ class PgVectorScale(VectorDB):
         log.info(f"{self.name} client drop index : {self._index_name}")
 
         drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
-            index_name=sql.Identifier(self._index_name)
+            index_name=sql.Identifier(self._index_name),
         )
         log.debug(drop_index_sql.as_string(self.cursor))
         self.cursor.execute(drop_index_sql)
@@ -180,36 +178,31 @@ class PgVectorScale(VectorDB):
                 sql.SQL("{option_name} = {val}").format(
                     option_name=sql.Identifier(option_name),
                     val=sql.Identifier(str(option_val)),
-                )
+                ),
             )
-
+
         num_bits_per_dimension = "2" if self.dim < 900 else "1"
         options.append(
             sql.SQL("{option_name} = {val}").format(
                 option_name=sql.Identifier("num_bits_per_dimension"),
                 val=sql.Identifier(num_bits_per_dimension),
-            )
+            ),
        )
 
-        if any(options):
-            with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
-        else:
-            with_clause = sql.Composed(())
+        with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
 
         index_create_sql = sql.SQL(
             """
-            CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
+            CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
             USING {index_type} (embedding {embedding_metric})
-            """
+            """,
         ).format(
             index_name=sql.Identifier(self._index_name),
            table_name=sql.Identifier(self.table_name),
            index_type=sql.Identifier(index_param["index_type"].lower()),
            embedding_metric=sql.Identifier(index_param["metric"]),
        )
-        index_create_sql_with_with_clause = (
-            index_create_sql + with_clause
-        ).join(" ")
+        index_create_sql_with_with_clause = (index_create_sql + with_clause).join(" ")
         log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
         self.cursor.execute(index_create_sql_with_with_clause)
         self.conn.commit()
@@ -223,14 +216,12 @@ class PgVectorScale(VectorDB):
 
             self.cursor.execute(
                 sql.SQL(
-                    "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
-                ).format(table_name=sql.Identifier(self.table_name), dim=dim)
+                    "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
+                ).format(table_name=sql.Identifier(self.table_name), dim=dim),
             )
             self.conn.commit()
         except Exception as e:
-            log.warning(
-                f"Failed to create pgvectorscale table: {self.table_name} error: {e}"
-            )
+            log.warning(f"Failed to create pgvectorscale table: {self.table_name} error: {e}")
             raise e from None
 
     def insert_embeddings(
@@ -238,7 +229,7 @@ class PgVectorScale(VectorDB):
         embeddings: list[list[float]],
         metadata: list[int],
         **kwargs: Any,
-    ) -> Tuple[int, Optional[Exception]]:
+    ) -> tuple[int, Exception | None]:
         assert self.conn is not None, "Connection is not initialized"
         assert self.cursor is not None, "Cursor is not initialized"
 
@@ -248,8 +239,8 @@ class PgVectorScale(VectorDB):
 
             with self.cursor.copy(
                 sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
-                    table_name=sql.Identifier(self.table_name)
-                )
+                    table_name=sql.Identifier(self.table_name),
+                ),
             ) as copy:
                 copy.set_types(["bigint", "vector"])
                 for i, row in enumerate(metadata_arr):
@@ -261,9 +252,7 @@ class PgVectorScale(VectorDB):
 
             return len(metadata), None
         except Exception as e:
-            log.warning(
-                f"Failed to insert data into pgvector table ({self.table_name}), error: {e}"
-            )
+            log.warning(f"Failed to insert data into pgvector table ({self.table_name}), error: {e}")
             return 0, e
 
     def search_embedding(
@@ -280,11 +269,12 @@ class PgVectorScale(VectorDB):
         if filters:
             gt = filters.get("id")
            result = self.cursor.execute(
-                self._filtered_search, (gt, q, k), prepare=True, binary=True
+                self._filtered_search,
+                (gt, q, k),
+                prepare=True,
+                binary=True,
             )
         else:
-            result = self.cursor.execute(
-                self._unfiltered_search, (q, k), prepare=True, binary=True
-            )
+            result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True)
 
         return [int(i[0]) for i in result.fetchall()]
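Across these client diffs (pgvectorscale above, pinecone and qdrant_cloud below), 0.0.21 makes one cross-cutting API change: the no-op ready_to_load() hook is deleted and optimize() gains an optional data_size parameter. A minimal sketch of the new client surface; the base-class side lives in vectordb_bench/backend/clients/api.py (+13 -24, hunks not shown here), so its exact signature is an assumption:

    # Sketch only, inferred from the pgvectorscale/pinecone/qdrant_cloud
    # hunks in this diff; the base-class signature is an assumption.
    from vectordb_bench.backend.clients.api import VectorDB

    class SketchClient(VectorDB):  # hypothetical client, for illustration
        def optimize(self, data_size: int | None = None) -> None:
            # 0.0.19 clients implemented ready_to_load() and optimize();
            # 0.0.21 drops the pre-load hook and passes the dataset size
            # to optimize() instead.
            ...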
vectordb_bench/backend/clients/pinecone/config.py

@@ -1,4 +1,5 @@
 from pydantic import SecretStr
+
 from ..api import DBConfig
 
 
vectordb_bench/backend/clients/pinecone/pinecone.py

@@ -2,11 +2,11 @@
 
 import logging
 from contextlib import contextmanager
-from typing import Type
+
 import pinecone
-from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
-from .config import PineconeConfig
 
+from ..api import DBCaseConfig, DBConfig, EmptyDBCaseConfig, IndexType, VectorDB
+from .config import PineconeConfig
 
 log = logging.getLogger(__name__)
 
@@ -17,7 +17,7 @@ PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024  # 2MB
 class Pinecone(VectorDB):
     def __init__(
         self,
-        dim,
+        dim: int,
         db_config: dict,
         db_case_config: DBCaseConfig,
         drop_old: bool = False,
@@ -27,7 +27,7 @@ class Pinecone(VectorDB):
         self.index_name = db_config.get("index_name", "")
         self.api_key = db_config.get("api_key", "")
         self.batch_size = int(
-            min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH)
+            min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH),
         )
 
         pc = pinecone.Pinecone(api_key=self.api_key)
@@ -37,9 +37,8 @@ class Pinecone(VectorDB):
             index_stats = index.describe_index_stats()
             index_dim = index_stats["dimension"]
             if index_dim != dim:
-                raise ValueError(
-                    f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}"
-                )
+                msg = f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}"
+                raise ValueError(msg)
             for namespace in index_stats["namespaces"]:
                 log.info(f"Pinecone index delete namespace: {namespace}")
                 index.delete(delete_all=True, namespace=namespace)
@@ -47,11 +46,11 @@ class Pinecone(VectorDB):
         self._metadata_key = "meta"
 
     @classmethod
-    def config_cls(cls) -> Type[DBConfig]:
+    def config_cls(cls) -> type[DBConfig]:
         return PineconeConfig
 
     @classmethod
-    def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
+    def case_config_cls(cls, index_type: IndexType | None = None) -> type[DBCaseConfig]:
         return EmptyDBCaseConfig
 
     @contextmanager
@@ -60,10 +59,7 @@ class Pinecone(VectorDB):
         self.index = pc.Index(self.index_name)
         yield
 
-    def ready_to_load(self):
-        pass
-
-    def optimize(self):
+    def optimize(self, data_size: int | None = None):
         pass
 
     def insert_embeddings(
@@ -76,9 +72,7 @@ class Pinecone(VectorDB):
         insert_count = 0
         try:
             for batch_start_offset in range(0, len(embeddings), self.batch_size):
-                batch_end_offset = min(
-                    batch_start_offset + self.batch_size, len(embeddings)
-                )
+                batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
                 insert_datas = []
                 for i in range(batch_start_offset, batch_end_offset):
                     insert_data = (
@@ -100,10 +94,7 @@ class Pinecone(VectorDB):
         filters: dict | None = None,
         timeout: int | None = None,
     ) -> list[int]:
-        if filters is None:
-            pinecone_filters = {}
-        else:
-            pinecone_filters = {self._metadata_key: {"$gte": filters["id"]}}
+        pinecone_filters = {} if filters is None else {self._metadata_key: {"$gte": filters["id"]}}
         try:
             res = self.index.query(
                 top_k=k,
@@ -111,7 +102,6 @@ class Pinecone(VectorDB):
                 filter=pinecone_filters,
             )["matches"]
         except Exception as e:
-            print(f"Error querying index: {e}")
-            raise e
-        id_res = [int(one_res["id"]) for one_res in res]
-        return id_res
+            log.warning(f"Error querying index: {e}")
+            raise e from e
+        return [int(one_res["id"]) for one_res in res]
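The batch sizing in the Pinecone hunks above is unchanged in substance (only a trailing comma was added): each upsert batch is capped both by a 2MB payload estimate and by a maximum vector count. A worked example, assuming PINECONE_MAX_NUM_PER_BATCH = 1000 since its definition sits outside these hunks:

    PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024  # 2MB, from the hunk header
    PINECONE_MAX_NUM_PER_BATCH = 1000  # assumed; not visible in this excerpt

    dim = 1536
    batch_size = int(min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH))
    print(batch_size)  # 273: at this dimension the byte cap binds, not the count cap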
vectordb_bench/backend/clients/qdrant_cloud/config.py

@@ -1,7 +1,7 @@
-from pydantic import BaseModel, SecretStr
+from pydantic import BaseModel, SecretStr, validator
+
+from ..api import DBCaseConfig, DBConfig, MetricType
 
-from ..api import DBConfig, DBCaseConfig, MetricType
-from pydantic import validator
 
 # Allowing `api_key` to be left empty, to ensure compatibility with the open-source Qdrant.
 class QdrantConfig(DBConfig):
@@ -16,17 +16,19 @@ class QdrantConfig(DBConfig):
                 "api_key": self.api_key.get_secret_value(),
                 "prefer_grpc": True,
             }
-        else:
-            return {"url": self.url.get_secret_value(),}
-
+        return {
+            "url": self.url.get_secret_value(),
+        }
+
     @validator("*")
-    def not_empty_field(cls, v, field):
+    def not_empty_field(cls, v: any, field: any):
         if field.name in ["api_key", "db_label"]:
             return v
-        if isinstance(v, (str, SecretStr)) and len(v) == 0:
+        if isinstance(v, str | SecretStr) and len(v) == 0:
             raise ValueError("Empty string!")
         return v
 
+
 class QdrantIndexConfig(BaseModel, DBCaseConfig):
     metric_type: MetricType | None = None
 
@@ -40,8 +42,7 @@ class QdrantIndexConfig(BaseModel, DBCaseConfig):
         return "Cosine"
 
     def index_param(self) -> dict:
-        params = {"distance": self.parse_metric()}
-        return params
+        return {"distance": self.parse_metric()}
 
     def search_param(self) -> dict:
         return {}
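One behavior-neutral detail in the validator above: on Python 3.10+, isinstance() accepts PEP 604 union types directly, so the new str | SecretStr form is equivalent to the old tuple form:

    from pydantic import SecretStr

    v = ""
    # Both checks are identical on Python >= 3.10; 0.0.21 uses the union form.
    assert isinstance(v, str | SecretStr) == isinstance(v, (str, SecretStr))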
vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py

@@ -4,23 +4,26 @@ import logging
 import time
 from contextlib import contextmanager
 
-from ..api import VectorDB, DBCaseConfig
+from qdrant_client import QdrantClient
 from qdrant_client.http.models import (
-    CollectionStatus,
-    VectorParams,
-    PayloadSchemaType,
     Batch,
-    Filter,
+    CollectionStatus,
     FieldCondition,
+    Filter,
+    PayloadSchemaType,
     Range,
+    VectorParams,
 )
 
-from qdrant_client import QdrantClient
-
+from ..api import DBCaseConfig, VectorDB
 
 log = logging.getLogger(__name__)
 
 
+SECONDS_WAITING_FOR_INDEXING_API_CALL = 5
+QDRANT_BATCH_SIZE = 500
+
+
 class QdrantCloud(VectorDB):
     def __init__(
         self,
@@ -57,16 +60,11 @@ class QdrantCloud(VectorDB):
         self.qdrant_client = QdrantClient(**self.db_config)
         yield
         self.qdrant_client = None
-        del(self.qdrant_client)
-
-    def ready_to_load(self):
-        pass
+        del self.qdrant_client
 
-
-    def optimize(self):
+    def optimize(self, data_size: int | None = None):
         assert self.qdrant_client, "Please call self.init() before"
         # wait for vectors to be fully indexed
-        SECONDS_WAITING_FOR_INDEXING_API_CALL = 5
         try:
             while True:
                 info = self.qdrant_client.get_collection(self.collection_name)
@@ -74,19 +72,26 @@ class QdrantCloud(VectorDB):
                 if info.status != CollectionStatus.GREEN:
                     continue
                 if info.status == CollectionStatus.GREEN:
-                    log.info(f"Stored vectors: {info.vectors_count}, Indexed vectors: {info.indexed_vectors_count}, Collection status: {info.indexed_vectors_count}")
+                    msg = (
+                        f"Stored vectors: {info.vectors_count}, Indexed vectors: {info.indexed_vectors_count}, "
+                        f"Collection status: {info.indexed_vectors_count}"
+                    )
+                    log.info(msg)
                     return
         except Exception as e:
             log.warning(f"QdrantCloud ready to search error: {e}")
             raise e from None
 
-    def _create_collection(self, dim, qdrant_client: int):
+    def _create_collection(self, dim: int, qdrant_client: QdrantClient):
         log.info(f"Create collection: {self.collection_name}")
 
         try:
             qdrant_client.create_collection(
                 collection_name=self.collection_name,
-                vectors_config=VectorParams(size=dim, distance=self.case_config.index_param()["distance"])
+                vectors_config=VectorParams(
+                    size=dim,
+                    distance=self.case_config.index_param()["distance"],
+                ),
             )
 
             qdrant_client.create_payload_index(
@@ -109,13 +114,12 @@ class QdrantCloud(VectorDB):
     ) -> (int, Exception):
         """Insert embeddings into Milvus. should call self.init() first"""
         assert self.qdrant_client is not None
-        QDRANT_BATCH_SIZE = 500
         try:
             # TODO: counts
             for offset in range(0, len(embeddings), QDRANT_BATCH_SIZE):
-                vectors = embeddings[offset: offset + QDRANT_BATCH_SIZE]
-                ids = metadata[offset: offset + QDRANT_BATCH_SIZE]
-                payloads=[{self._primary_field: v} for v in ids]
+                vectors = embeddings[offset : offset + QDRANT_BATCH_SIZE]
+                ids = metadata[offset : offset + QDRANT_BATCH_SIZE]
+                payloads = [{self._primary_field: v} for v in ids]
                 _ = self.qdrant_client.upsert(
                     collection_name=self.collection_name,
                     wait=True,
@@ -142,21 +146,23 @@ class QdrantCloud(VectorDB):
         f = None
         if filters:
             f = Filter(
-                must=[FieldCondition(
-                    key = self._primary_field,
-                    range = Range(
-                        gt=filters.get('id'),
+                must=[
+                    FieldCondition(
+                        key=self._primary_field,
+                        range=Range(
+                            gt=filters.get("id"),
+                        ),
                     ),
-                )]
+                ],
             )
 
-        res = self.qdrant_client.search(
-            collection_name=self.collection_name,
-            query_vector=query,
-            limit=k,
-            query_filter=f,
-            # with_payload=True,
-        ),
+        res = (
+            self.qdrant_client.search(
+                collection_name=self.collection_name,
+                query_vector=query,
+                limit=k,
+                query_filter=f,
+            ),
+        )
 
-        ret = [result.id for result in res[0]]
-        return ret
+        return [result.id for result in res[0]]
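Besides the optimize() signature change, the qdrant_cloud.py rewrite hoists two function-local magic numbers (SECONDS_WAITING_FOR_INDEXING_API_CALL, QDRANT_BATCH_SIZE) to module scope, so the upsert loop reduces to plain fixed-size slicing. A standalone sketch of that slicing pattern; the batches() helper is illustrative, not part of vectordb-bench:

    QDRANT_BATCH_SIZE = 500  # module-level constant, as in the hunk above

    def batches(items: list, size: int = QDRANT_BATCH_SIZE):
        """Yield consecutive fixed-size slices, mirroring the upsert loop."""
        for offset in range(0, len(items), size):
            yield items[offset : offset + size]

    assert [len(b) for b in batches(list(range(1200)))] == [500, 500, 200]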
vectordb_bench/backend/clients/redis/cli.py

@@ -3,9 +3,6 @@ from typing import Annotated, TypedDict, Unpack
 import click
 from pydantic import SecretStr
 
-from .config import RedisHNSWConfig
-
-
 from ....cli.cli import (
     CommonTypedDict,
     HNSWFlavor2,
@@ -14,12 +11,11 @@ from ....cli.cli import (
     run,
 )
 from .. import DB
+from .config import RedisHNSWConfig
 
 
 class RedisTypedDict(TypedDict):
-    host: Annotated[
-        str, click.option("--host", type=str, help="Db host", required=True)
-    ]
+    host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
     password: Annotated[str, click.option("--password", type=str, help="Db password")]
     port: Annotated[int, click.option("--port", type=int, default=6379, help="Db Port")]
     ssl: Annotated[
@@ -52,27 +48,25 @@ class RedisTypedDict(TypedDict):
     ]
 
 
-class RedisHNSWTypedDict(CommonTypedDict, RedisTypedDict, HNSWFlavor2):
-    ...
+class RedisHNSWTypedDict(CommonTypedDict, RedisTypedDict, HNSWFlavor2): ...
 
 
 @cli.command()
 @click_parameter_decorators_from_typed_dict(RedisHNSWTypedDict)
 def Redis(**parameters: Unpack[RedisHNSWTypedDict]):
     from .config import RedisConfig
+
     run(
         db=DB.Redis,
         db_config=RedisConfig(
             db_label=parameters["db_label"],
-            password=SecretStr(parameters["password"])
-            if parameters["password"]
-            else None,
+            password=SecretStr(parameters["password"]) if parameters["password"] else None,
             host=SecretStr(parameters["host"]),
             port=parameters["port"],
             ssl=parameters["ssl"],
             ssl_ca_certs=parameters["ssl_ca_certs"],
             cmd=parameters["cmd"],
-        ),
+        ),
         db_case_config=RedisHNSWConfig(
             M=parameters["m"],
             efConstruction=parameters["ef_construction"],
vectordb_bench/backend/clients/redis/config.py

@@ -1,10 +1,12 @@
-from pydantic import SecretStr, BaseModel
-from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
+
 
 class RedisConfig(DBConfig):
     password: SecretStr | None = None
     host: SecretStr
-    port: int | None = None
+    port: int | None = None
 
     def to_dict(self) -> dict:
         return {
@@ -12,7 +14,6 @@ class RedisConfig(DBConfig):
             "port": self.port,
             "password": self.password.get_secret_value() if self.password is not None else None,
         }
-
 
 
 class RedisIndexConfig(BaseModel):
@@ -24,7 +25,8 @@ class RedisIndexConfig(BaseModel):
         if not self.metric_type:
             return ""
         return self.metric_type.value
-
+
+
 class RedisHNSWConfig(RedisIndexConfig, DBCaseConfig):
     M: int
     efConstruction: int