vectordb-bench 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. vectordb_bench/backend/clients/__init__.py +22 -0
  2. vectordb_bench/backend/clients/api.py +21 -1
  3. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +47 -6
  4. vectordb_bench/backend/clients/aws_opensearch/config.py +12 -6
  5. vectordb_bench/backend/clients/aws_opensearch/run.py +34 -3
  6. vectordb_bench/backend/clients/memorydb/cli.py +88 -0
  7. vectordb_bench/backend/clients/memorydb/config.py +54 -0
  8. vectordb_bench/backend/clients/memorydb/memorydb.py +254 -0
  9. vectordb_bench/backend/clients/pgvecto_rs/cli.py +154 -0
  10. vectordb_bench/backend/clients/pgvecto_rs/config.py +108 -73
  11. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +159 -59
  12. vectordb_bench/backend/clients/pgvector/cli.py +17 -2
  13. vectordb_bench/backend/clients/pgvector/config.py +20 -5
  14. vectordb_bench/backend/clients/pgvector/pgvector.py +95 -25
  15. vectordb_bench/backend/clients/pgvectorscale/cli.py +108 -0
  16. vectordb_bench/backend/clients/pgvectorscale/config.py +111 -0
  17. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +290 -0
  18. vectordb_bench/backend/clients/pinecone/config.py +0 -2
  19. vectordb_bench/backend/clients/pinecone/pinecone.py +34 -36
  20. vectordb_bench/backend/clients/redis/cli.py +8 -0
  21. vectordb_bench/backend/clients/redis/config.py +37 -6
  22. vectordb_bench/backend/runner/mp_runner.py +2 -1
  23. vectordb_bench/cli/cli.py +137 -0
  24. vectordb_bench/cli/vectordbbench.py +7 -1
  25. vectordb_bench/frontend/components/check_results/charts.py +9 -6
  26. vectordb_bench/frontend/components/check_results/data.py +13 -6
  27. vectordb_bench/frontend/components/concurrent/charts.py +3 -6
  28. vectordb_bench/frontend/components/run_test/caseSelector.py +10 -0
  29. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +37 -15
  30. vectordb_bench/frontend/components/run_test/initStyle.py +3 -1
  31. vectordb_bench/frontend/config/dbCaseConfigs.py +230 -9
  32. vectordb_bench/frontend/pages/quries_per_dollar.py +13 -5
  33. vectordb_bench/frontend/vdb_benchmark.py +11 -3
  34. vectordb_bench/models.py +25 -9
  35. vectordb_bench/results/Milvus/result_20230727_standard_milvus.json +53 -1
  36. vectordb_bench/results/Milvus/result_20230808_standard_milvus.json +48 -0
  37. vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +29 -1
  38. vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +24 -0
  39. vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +98 -49
  40. vectordb_bench/results/getLeaderboardData.py +17 -7
  41. vectordb_bench/results/leaderboard.json +1 -1
  42. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/METADATA +64 -31
  43. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/RECORD +47 -40
  44. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/WHEEL +1 -1
  45. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/LICENSE +0 -0
  46. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/entry_points.txt +0 -0
  47. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/pgvectorscale/cli.py (new file)
@@ -0,0 +1,108 @@
+import click
+import os
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from typing import Annotated, Unpack
+from vectordb_bench.backend.clients import DB
+
+
+class PgVectorScaleTypedDict(CommonTypedDict):
+    user_name: Annotated[
+        str, click.option("--user-name", type=str, help="Db username", required=True)
+    ]
+    password: Annotated[
+        str,
+        click.option(
+            "--password",
+            type=str,
+            help="Postgres database password",
+            default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
+            show_default="$POSTGRES_PASSWORD",
+        ),
+    ]
+
+    host: Annotated[
+        str, click.option("--host", type=str, help="Db host", required=True)
+    ]
+    db_name: Annotated[
+        str, click.option("--db-name", type=str, help="Db name", required=True)
+    ]
+
+
+class PgVectorScaleDiskAnnTypedDict(PgVectorScaleTypedDict):
+    storage_layout: Annotated[
+        str,
+        click.option(
+            "--storage-layout", type=str, help="Streaming DiskANN storage layout",
+        ),
+    ]
+    num_neighbors: Annotated[
+        int,
+        click.option(
+            "--num-neighbors", type=int, help="Streaming DiskANN num neighbors",
+        ),
+    ]
+    search_list_size: Annotated[
+        int,
+        click.option(
+            "--search-list-size", type=int, help="Streaming DiskANN search list size",
+        ),
+    ]
+    max_alpha: Annotated[
+        float,
+        click.option(
+            "--max-alpha", type=float, help="Streaming DiskANN max alpha",
+        ),
+    ]
+    num_dimensions: Annotated[
+        int,
+        click.option(
+            "--num-dimensions", type=int, help="Streaming DiskANN num dimensions",
+        ),
+    ]
+    query_search_list_size: Annotated[
+        int,
+        click.option(
+            "--query-search-list-size",
+            type=int,
+            help="Streaming DiskANN query search list size",
+        ),
+    ]
+    query_rescore: Annotated[
+        int,
+        click.option(
+            "--query-rescore", type=int, help="Streaming DiskANN query rescore",
+        ),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(PgVectorScaleDiskAnnTypedDict)
+def PgVectorScaleDiskAnn(
+    **parameters: Unpack[PgVectorScaleDiskAnnTypedDict],
+):
+    from .config import PgVectorScaleConfig, PgVectorScaleStreamingDiskANNConfig
+
+    run(
+        db=DB.PgVectorScale,
+        db_config=PgVectorScaleConfig(
+            db_label=parameters["db_label"],
+            user_name=SecretStr(parameters["user_name"]),
+            password=SecretStr(parameters["password"]),
+            host=parameters["host"],
+            db_name=parameters["db_name"],
+        ),
+        db_case_config=PgVectorScaleStreamingDiskANNConfig(
+            storage_layout=parameters["storage_layout"],
+            num_neighbors=parameters["num_neighbors"],
+            search_list_size=parameters["search_list_size"],
+            max_alpha=parameters["max_alpha"],
+            num_dimensions=parameters["num_dimensions"],
+            query_search_list_size=parameters["query_search_list_size"],
+            query_rescore=parameters["query_rescore"],
+        ),
+        **parameters,
+    )
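The `Annotated`/TypedDict pattern above keeps each click option next to the field it populates. For readers unfamiliar with it, here is a minimal sketch of how a helper like `click_parameter_decorators_from_typed_dict` can be built; this is an illustration of the pattern, not the project's actual implementation:

```python
# Sketch: pull the click.option decorators embedded in Annotated hints
# out of a TypedDict and stack them onto a command function.
from typing import Annotated, TypedDict, get_args, get_type_hints

import click


def click_parameter_decorators_from_typed_dict(typed_dict: type):
    decorators = []
    for hint in get_type_hints(typed_dict, include_extras=True).values():
        # Annotated[str, click.option(...)] -> (str, <option decorator>, ...)
        decorators.extend(get_args(hint)[1:])

    def decorate(f):
        for d in reversed(decorators):  # preserve declaration order in --help
            f = d(f)
        return f

    return decorate


class ExampleTypedDict(TypedDict):
    host: Annotated[str, click.option("--host", type=str, required=True)]


@click.command()
@click_parameter_decorators_from_typed_dict(ExampleTypedDict)
def example(**parameters):
    click.echo(parameters["host"])
```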
vectordb_bench/backend/clients/pgvectorscale/config.py (new file)
@@ -0,0 +1,111 @@
+from abc import abstractmethod
+from typing import TypedDict
+from pydantic import BaseModel, SecretStr
+from typing_extensions import LiteralString
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
+
+POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
+
+
+class PgVectorScaleConfigDict(TypedDict):
+    """These keys are passed directly as kwargs to the psycopg connection,
+    so the names must match the psycopg API exactly."""
+
+    user: str
+    password: str
+    host: str
+    port: int
+    dbname: str
+
+
+class PgVectorScaleConfig(DBConfig):
+    user_name: SecretStr = SecretStr("postgres")
+    password: SecretStr
+    host: str = "localhost"
+    port: int = 5432
+    db_name: str
+
+    def to_dict(self) -> PgVectorScaleConfigDict:
+        user_str = self.user_name.get_secret_value()
+        pwd_str = self.password.get_secret_value()
+        return {
+            "host": self.host,
+            "port": self.port,
+            "dbname": self.db_name,
+            "user": user_str,
+            "password": pwd_str,
+        }
+
+
+class PgVectorScaleIndexConfig(BaseModel, DBCaseConfig):
+    metric_type: MetricType | None = None
+    create_index_before_load: bool = False
+    create_index_after_load: bool = True
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.COSINE:
+            return "vector_cosine_ops"
+        return ""
+
+    def parse_metric_fun_op(self) -> LiteralString:
+        if self.metric_type == MetricType.COSINE:
+            return "<=>"
+        return ""
+
+    def parse_metric_fun_str(self) -> str:
+        if self.metric_type == MetricType.COSINE:
+            return "cosine_distance"
+        return ""
+
+    @abstractmethod
+    def index_param(self) -> dict:
+        ...
+
+    @abstractmethod
+    def search_param(self) -> dict:
+        ...
+
+    @abstractmethod
+    def session_param(self) -> dict:
+        ...
+
+
+class PgVectorScaleStreamingDiskANNConfig(PgVectorScaleIndexConfig):
+    index: IndexType = IndexType.STREAMING_DISKANN
+    storage_layout: str | None
+    num_neighbors: int | None
+    search_list_size: int | None
+    max_alpha: float | None
+    num_dimensions: int | None
+    num_bits_per_dimension: int | None
+    query_search_list_size: int | None
+    query_rescore: int | None
+
+    def index_param(self) -> dict:
+        return {
+            "metric": self.parse_metric(),
+            "index_type": self.index.value,
+            "options": {
+                "storage_layout": self.storage_layout,
+                "num_neighbors": self.num_neighbors,
+                "search_list_size": self.search_list_size,
+                "max_alpha": self.max_alpha,
+                "num_dimensions": self.num_dimensions,
+            },
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "metric": self.parse_metric(),
+            "metric_fun_op": self.parse_metric_fun_op(),
+        }
+
+    def session_param(self) -> dict:
+        return {
+            "diskann.query_search_list_size": self.query_search_list_size,
+            "diskann.query_rescore": self.query_rescore,
+        }
+
+_pgvectorscale_case_config = {
+    IndexType.STREAMING_DISKANN: PgVectorScaleStreamingDiskANNConfig,
+}
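To see how the three accessors fit together, the sketch below (values are illustrative, not benchmark defaults; module paths follow the file list above) builds the new DiskANN case config and prints the parameter dicts the client consumes:

```python
# Illustrative only: inspect the dicts that drive index creation,
# search SQL, and per-session settings.
from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.backend.clients.pgvectorscale.config import (
    PgVectorScaleStreamingDiskANNConfig,
)

cfg = PgVectorScaleStreamingDiskANNConfig(
    metric_type=MetricType.COSINE,
    storage_layout="memory_optimized",  # assumed valid layout value
    num_neighbors=50,
    search_list_size=100,
    max_alpha=1.2,
    num_dimensions=None,
    num_bits_per_dimension=None,
    query_search_list_size=100,
    query_rescore=50,
)

print(cfg.index_param()["options"])  # becomes the index's WITH (...) options
print(cfg.search_param())            # {'metric': 'vector_cosine_ops', 'metric_fun_op': '<=>'}
print(cfg.session_param())           # SET per session: diskann.query_search_list_size, diskann.query_rescore
```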
vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py (new file)
@@ -0,0 +1,290 @@
+"""Wrapper around the pgvectorscale vector database, implementing VectorDB"""
+
+import logging
+import pprint
+from contextlib import contextmanager
+from typing import Any, Generator, Optional, Tuple
+
+import numpy as np
+import psycopg
+from pgvector.psycopg import register_vector
+from psycopg import Connection, Cursor, sql
+
+from ..api import VectorDB
+from .config import PgVectorScaleConfigDict, PgVectorScaleIndexConfig
+
+log = logging.getLogger(__name__)
+
+
+class PgVectorScale(VectorDB):
+    """pgvectorscale client, implemented with psycopg"""
+
+    conn: psycopg.Connection[Any] | None = None
+    cursor: psycopg.Cursor[Any] | None = None
+
+    _unfiltered_search: sql.Composed
+    _filtered_search: sql.Composed
+
+    def __init__(
+        self,
+        dim: int,
+        db_config: PgVectorScaleConfigDict,
+        db_case_config: PgVectorScaleIndexConfig,
+        collection_name: str = "pg_vectorscale_collection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.name = "PgVectorScale"
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.table_name = collection_name
+        self.dim = dim
+
+        self._index_name = "pgvectorscale_index"
+        self._primary_field = "id"
+        self._vector_field = "embedding"
+
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}")
+        if not any(
+            (
+                self.case_config.create_index_before_load,
+                self.case_config.create_index_after_load,
+            )
+        ):
+            err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
+            log.error(err)
+            raise RuntimeError(
+                f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
+            )
+
+        if drop_old:
+            self._drop_index()
+            self._drop_table()
+            self._create_table(dim)
+            if self.case_config.create_index_before_load:
+                self._create_index()
+
+        self.cursor.close()
+        self.conn.close()
+        self.cursor = None
+        self.conn = None
+
+    @staticmethod
+    def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
+        conn = psycopg.connect(**kwargs)
+        conn.cursor().execute("CREATE EXTENSION IF NOT EXISTS vectorscale CASCADE")
+        conn.commit()
+        register_vector(conn)
+        conn.autocommit = False
+        cursor = conn.cursor()
+
+        assert conn is not None, "Connection is not initialized"
+        assert cursor is not None, "Cursor is not initialized"
+
+        return conn, cursor
+
+    @contextmanager
+    def init(self) -> Generator[None, None, None]:
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        # the index configuration may define settings that must be applied
+        # in each client session
+        session_options: dict[str, Any] = self.case_config.session_param()
+
+        if len(session_options) > 0:
+            for setting_name, setting_val in session_options.items():
+                command = sql.SQL("SET {setting_name} = {setting_val};").format(
+                    setting_name=sql.Identifier(setting_name),
+                    setting_val=sql.Identifier(str(setting_val)),
+                )
+                log.debug(command.as_string(self.cursor))
+                self.cursor.execute(command)
+            self.conn.commit()
+
+        self._filtered_search = sql.Composed(
+            [
+                sql.SQL(
+                    "SELECT id FROM public.{} WHERE id >= %s ORDER BY embedding "
+                ).format(sql.Identifier(self.table_name)),
+                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                sql.SQL(" %s::vector LIMIT %s::int"),
+            ]
+        )
+
+        self._unfiltered_search = sql.Composed(
+            [
+                sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
+                    sql.Identifier(self.table_name)
+                ),
+                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                sql.SQL(" %s::vector LIMIT %s::int"),
+            ]
+        )
+
+        try:
+            yield
+        finally:
+            self.cursor.close()
+            self.conn.close()
+            self.cursor = None
+            self.conn = None
+
+    def _drop_table(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop table : {self.table_name}")
+
+        self.cursor.execute(
+            sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
+                table_name=sql.Identifier(self.table_name)
+            )
+        )
+        self.conn.commit()
+
+    def ready_to_load(self):
+        pass
+
+    def optimize(self):
+        self._post_insert()
+
+    def _post_insert(self):
+        log.info(f"{self.name} post insert before optimize")
+        if self.case_config.create_index_after_load:
+            self._drop_index()
+            self._create_index()
+
+    def _drop_index(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop index : {self._index_name}")
+
+        drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
+            index_name=sql.Identifier(self._index_name)
+        )
+        log.debug(drop_index_sql.as_string(self.cursor))
+        self.cursor.execute(drop_index_sql)
+        self.conn.commit()
+
+    def _create_index(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client create index : {self._index_name}")
+
+        index_param: dict[str, Any] = self.case_config.index_param()
+
+        options = []
+        for option_name, option_val in index_param["options"].items():
+            if option_val is not None:
+                options.append(
+                    sql.SQL("{option_name} = {val}").format(
+                        option_name=sql.Identifier(option_name),
+                        val=sql.Identifier(str(option_val)),
+                    )
+                )
+
+        # use more quantization bits for lower-dimensional vectors
+        num_bits_per_dimension = "2" if self.dim < 900 else "1"
+        options.append(
+            sql.SQL("{option_name} = {val}").format(
+                option_name=sql.Identifier("num_bits_per_dimension"),
+                val=sql.Identifier(num_bits_per_dimension),
+            )
+        )
+
+        if any(options):
+            with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
+        else:
+            with_clause = sql.Composed(())
+
+        index_create_sql = sql.SQL(
+            """
+            CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
+            USING {index_type} (embedding {embedding_metric})
+            """
+        ).format(
+            index_name=sql.Identifier(self._index_name),
+            table_name=sql.Identifier(self.table_name),
+            index_type=sql.Identifier(index_param["index_type"].lower()),
+            embedding_metric=sql.Identifier(index_param["metric"]),
+        )
+        index_create_sql_with_with_clause = (index_create_sql + with_clause).join(" ")
+        log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
+        self.cursor.execute(index_create_sql_with_with_clause)
+        self.conn.commit()
+
+    def _create_table(self, dim: int):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            log.info(f"{self.name} client create table : {self.table_name}")
+
+            self.cursor.execute(
+                sql.SQL(
+                    "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
+                ).format(table_name=sql.Identifier(self.table_name), dim=dim)
+            )
+            self.conn.commit()
+        except Exception as e:
+            log.warning(
+                f"Failed to create pgvectorscale table: {self.table_name} error: {e}"
+            )
+            raise e from None
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> Tuple[int, Optional[Exception]]:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            metadata_arr = np.array(metadata)
+            embeddings_arr = np.array(embeddings)
+
+            with self.cursor.copy(
+                sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
+                    table_name=sql.Identifier(self.table_name)
+                )
+            ) as copy:
+                copy.set_types(["bigint", "vector"])
+                for i, row in enumerate(metadata_arr):
+                    copy.write_row((row, embeddings_arr[i]))
+            self.conn.commit()
+
+            if kwargs.get("last_batch"):
+                self._post_insert()
+
+            return len(metadata), None
+        except Exception as e:
+            log.warning(
+                f"Failed to insert data into pgvectorscale table ({self.table_name}), error: {e}"
+            )
+            return 0, e
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+    ) -> list[int]:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        q = np.asarray(query)
+        if filters:
+            gt = filters.get("id")
+            result = self.cursor.execute(
+                self._filtered_search, (gt, q, k), prepare=True, binary=True
+            )
+        else:
+            result = self.cursor.execute(
+                self._unfiltered_search, (q, k), prepare=True, binary=True
+            )
+
+        return [int(i[0]) for i in result.fetchall()]
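The insert path above relies on psycopg 3's binary COPY protocol together with pgvector's psycopg type adapter. Here is a standalone sketch of the same flow; the DSN, table name, and random sample vectors are placeholders:

```python
# Sketch: binary COPY bulk load into a vector column, mirroring the
# insert_embeddings code in the diff. Assumes a reachable Postgres with
# the vectorscale extension available.
import numpy as np
import psycopg
from pgvector.psycopg import register_vector
from psycopg import sql

conn = psycopg.connect("postgresql://postgres:postgres@localhost:5432/test")
conn.execute("CREATE EXTENSION IF NOT EXISTS vectorscale CASCADE")
register_vector(conn)  # teaches psycopg how to adapt the vector type
conn.execute(
    "CREATE TABLE IF NOT EXISTS items (id BIGINT PRIMARY KEY, embedding vector(4))"
)

with conn.cursor() as cur:
    with cur.copy(
        sql.SQL("COPY public.{} FROM STDIN (FORMAT BINARY)").format(
            sql.Identifier("items")
        )
    ) as copy:
        copy.set_types(["bigint", "vector"])  # binary COPY needs explicit types
        for i, vec in enumerate(np.random.rand(3, 4).astype(np.float32)):
            copy.write_row((i, vec))
conn.commit()
```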
vectordb_bench/backend/clients/pinecone/config.py
@@ -4,12 +4,10 @@ from ..api import DBConfig
 
 class PineconeConfig(DBConfig):
     api_key: SecretStr
-    environment: SecretStr
     index_name: str
 
     def to_dict(self) -> dict:
         return {
             "api_key": self.api_key.get_secret_value(),
-            "environment": self.environment.get_secret_value(),
             "index_name": self.index_name,
         }
vectordb_bench/backend/clients/pinecone/pinecone.py
@@ -3,7 +3,7 @@
 import logging
 from contextlib import contextmanager
 from typing import Type
-
+import pinecone
 from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
 from .config import PineconeConfig
 
@@ -11,7 +11,8 @@ from .config import PineconeConfig
 log = logging.getLogger(__name__)
 
 PINECONE_MAX_NUM_PER_BATCH = 1000
-PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024 # 2MB
+PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024  # 2MB
+
 
 class Pinecone(VectorDB):
     def __init__(
@@ -23,30 +24,25 @@ class Pinecone(VectorDB):
         **kwargs,
     ):
         """Initialize wrapper around the Pinecone vector database."""
-        self.index_name = db_config["index_name"]
-        self.api_key = db_config["api_key"]
-        self.environment = db_config["environment"]
-        self.batch_size = int(min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH))
-        # Pinecone connects to the server during import,
-        # so place the import here.
-        import pinecone
-        pinecone.init(
-            api_key=self.api_key, environment=self.environment)
+        self.index_name = db_config.get("index_name", "")
+        self.api_key = db_config.get("api_key", "")
+        self.batch_size = int(
+            min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH)
+        )
+
+        pc = pinecone.Pinecone(api_key=self.api_key)
+        index = pc.Index(self.index_name)
+
         if drop_old:
-            list_indexes = pinecone.list_indexes()
-            if self.index_name in list_indexes:
-                index = pinecone.Index(self.index_name)
-                index_dim = index.describe_index_stats()["dimension"]
-                if (index_dim != dim):
-                    raise ValueError(
-                        f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}")
-                log.info(
-                    f"Pinecone client delete old index: {self.index_name}")
-                index.delete(delete_all=True)
-                index.close()
-            else:
+            index_stats = index.describe_index_stats()
+            index_dim = index_stats["dimension"]
+            if index_dim != dim:
                 raise ValueError(
-                    f"Pinecone index {self.index_name} does not exist")
+                    f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}"
+                )
+            for namespace in index_stats["namespaces"]:
+                log.info(f"Pinecone index delete namespace: {namespace}")
+                index.delete(delete_all=True, namespace=namespace)
 
         self._metadata_key = "meta"
 
@@ -59,13 +55,10 @@ class Pinecone(VectorDB):
         return EmptyDBCaseConfig
 
     @contextmanager
-    def init(self) -> None:
-        import pinecone
-        pinecone.init(
-            api_key=self.api_key, environment=self.environment)
-        self.index = pinecone.Index(self.index_name)
+    def init(self):
+        pc = pinecone.Pinecone(api_key=self.api_key)
+        self.index = pc.Index(self.index_name)
         yield
-        self.index.close()
 
     def ready_to_load(self):
         pass
@@ -83,11 +76,16 @@ class Pinecone(VectorDB):
         insert_count = 0
         try:
             for batch_start_offset in range(0, len(embeddings), self.batch_size):
-                batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
+                batch_end_offset = min(
+                    batch_start_offset + self.batch_size, len(embeddings)
+                )
                 insert_datas = []
                 for i in range(batch_start_offset, batch_end_offset):
-                    insert_data = (str(metadata[i]), embeddings[i], {
-                        self._metadata_key: metadata[i]})
+                    insert_data = (
+                        str(metadata[i]),
+                        embeddings[i],
+                        {self._metadata_key: metadata[i]},
+                    )
                     insert_datas.append(insert_data)
                 self.index.upsert(insert_datas)
                 insert_count += batch_end_offset - batch_start_offset
@@ -101,7 +99,7 @@ class Pinecone(VectorDB):
         k: int = 100,
         filters: dict | None = None,
         timeout: int | None = None,
-    ) -> list[tuple[int, float]]:
+    ) -> list[int]:
         if filters is None:
             pinecone_filters = {}
         else:
@@ -111,9 +109,9 @@ class Pinecone(VectorDB):
                 top_k=k,
                 vector=query,
                 filter=pinecone_filters,
-            )['matches']
+            )["matches"]
         except Exception as e:
             print(f"Error querying index: {e}")
             raise e
-        id_res = [int(one_res['id']) for one_res in res]
+        id_res = [int(one_res["id"]) for one_res in res]
         return id_res
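Both Pinecone hunks move to the pinecone-client v3 API, where a `Pinecone` object replaces the module-level `pinecone.init(api_key, environment)` call; that is why the `environment` field could be dropped from `PineconeConfig`. A minimal sketch of the new connection flow (API key and index name are placeholders):

```python
# Sketch: pinecone-client v3 connection flow used by the updated wrapper.
import pinecone

pc = pinecone.Pinecone(api_key="YOUR_API_KEY")
index = pc.Index("vdb-bench-index")

stats = index.describe_index_stats()
print(stats["dimension"], list(stats["namespaces"]))
```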
vectordb_bench/backend/clients/redis/cli.py
@@ -3,6 +3,9 @@ from typing import Annotated, TypedDict, Unpack
 import click
 from pydantic import SecretStr
 
+from .config import RedisHNSWConfig
+
+
 from ....cli.cli import (
     CommonTypedDict,
     HNSWFlavor2,
@@ -69,6 +72,11 @@ def Redis(**parameters: Unpack[RedisHNSWTypedDict]):
             ssl=parameters["ssl"],
             ssl_ca_certs=parameters["ssl_ca_certs"],
             cmd=parameters["cmd"],
+        ),
+        db_case_config=RedisHNSWConfig(
+            M=parameters["m"],
+            efConstruction=parameters["ef_construction"],
+            ef=parameters["ef_runtime"],
         ),
         **parameters,
     )
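The practical effect of this hunk is that the Redis command's HNSW flags (`--m`, `--ef-construction`, `--ef-runtime`) are now wired into a `RedisHNSWConfig` instead of being dropped. A hypothetical check, assuming `RedisHNSWConfig` exposes the `index_param()`/`search_param()` accessors that `DBCaseConfig` requires of every case config:

```python
# Hypothetical: confirm the CLI's HNSW flags reach the case config.
from vectordb_bench.backend.clients.redis.config import RedisHNSWConfig

case_config = RedisHNSWConfig(M=16, efConstruction=200, ef=64)
print(case_config.index_param())   # HNSW build parameters handed to Redis
print(case_config.search_param())  # runtime ef applied at query time
```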