vectordb-bench 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/backend/clients/__init__.py +22 -0
- vectordb_bench/backend/clients/api.py +21 -1
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +47 -6
- vectordb_bench/backend/clients/aws_opensearch/config.py +12 -6
- vectordb_bench/backend/clients/aws_opensearch/run.py +34 -3
- vectordb_bench/backend/clients/memorydb/cli.py +88 -0
- vectordb_bench/backend/clients/memorydb/config.py +54 -0
- vectordb_bench/backend/clients/memorydb/memorydb.py +254 -0
- vectordb_bench/backend/clients/pgvecto_rs/cli.py +154 -0
- vectordb_bench/backend/clients/pgvecto_rs/config.py +108 -73
- vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +159 -59
- vectordb_bench/backend/clients/pgvector/cli.py +17 -2
- vectordb_bench/backend/clients/pgvector/config.py +20 -5
- vectordb_bench/backend/clients/pgvector/pgvector.py +95 -25
- vectordb_bench/backend/clients/pgvectorscale/cli.py +108 -0
- vectordb_bench/backend/clients/pgvectorscale/config.py +111 -0
- vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +290 -0
- vectordb_bench/backend/clients/pinecone/config.py +0 -2
- vectordb_bench/backend/clients/pinecone/pinecone.py +34 -36
- vectordb_bench/backend/clients/redis/cli.py +8 -0
- vectordb_bench/backend/clients/redis/config.py +37 -6
- vectordb_bench/backend/runner/mp_runner.py +2 -1
- vectordb_bench/cli/cli.py +137 -0
- vectordb_bench/cli/vectordbbench.py +7 -1
- vectordb_bench/frontend/components/check_results/charts.py +9 -6
- vectordb_bench/frontend/components/check_results/data.py +13 -6
- vectordb_bench/frontend/components/concurrent/charts.py +3 -6
- vectordb_bench/frontend/components/run_test/caseSelector.py +10 -0
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +37 -15
- vectordb_bench/frontend/components/run_test/initStyle.py +3 -1
- vectordb_bench/frontend/config/dbCaseConfigs.py +230 -9
- vectordb_bench/frontend/pages/quries_per_dollar.py +13 -5
- vectordb_bench/frontend/vdb_benchmark.py +11 -3
- vectordb_bench/models.py +25 -9
- vectordb_bench/results/Milvus/result_20230727_standard_milvus.json +53 -1
- vectordb_bench/results/Milvus/result_20230808_standard_milvus.json +48 -0
- vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +29 -1
- vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +24 -0
- vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +98 -49
- vectordb_bench/results/getLeaderboardData.py +17 -7
- vectordb_bench/results/leaderboard.json +1 -1
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/METADATA +64 -31
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/RECORD +47 -40
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/pgvectorscale/cli.py (new file)

```diff
@@ -0,0 +1,108 @@
+import click
+import os
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from typing import Annotated, Unpack
+from vectordb_bench.backend.clients import DB
+
+
+class PgVectorScaleTypedDict(CommonTypedDict):
+    user_name: Annotated[
+        str, click.option("--user-name", type=str, help="Db username", required=True)
+    ]
+    password: Annotated[
+        str,
+        click.option("--password",
+                     type=str,
+                     help="Postgres database password",
+                     default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
+                     show_default="$POSTGRES_PASSWORD",
+                     ),
+    ]
+
+    host: Annotated[
+        str, click.option("--host", type=str, help="Db host", required=True)
+    ]
+    db_name: Annotated[
+        str, click.option("--db-name", type=str, help="Db name", required=True)
+    ]
+
+
+class PgVectorScaleDiskAnnTypedDict(PgVectorScaleTypedDict):
+    storage_layout: Annotated[
+        str,
+        click.option(
+            "--storage-layout", type=str, help="Streaming DiskANN storage layout",
+        ),
+    ]
+    num_neighbors: Annotated[
+        int,
+        click.option(
+            "--num-neighbors", type=int, help="Streaming DiskANN num neighbors",
+        ),
+    ]
+    search_list_size: Annotated[
+        int,
+        click.option(
+            "--search-list-size", type=int, help="Streaming DiskANN search list size",
+        ),
+    ]
+    max_alpha: Annotated[
+        float,
+        click.option(
+            "--max-alpha", type=float, help="Streaming DiskANN max alpha",
+        ),
+    ]
+    num_dimensions: Annotated[
+        int,
+        click.option(
+            "--num-dimensions", type=int, help="Streaming DiskANN num dimensions",
+        ),
+    ]
+    query_search_list_size: Annotated[
+        int,
+        click.option(
+            "--query-search-list-size", type=int, help="Streaming DiskANN query search list size",
+        ),
+    ]
+    query_rescore: Annotated[
+        int,
+        click.option(
+            "--query-rescore", type=int, help="Streaming DiskANN query rescore",
+        ),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(PgVectorScaleDiskAnnTypedDict)
+def PgVectorScaleDiskAnn(
+    **parameters: Unpack[PgVectorScaleDiskAnnTypedDict],
+):
+    from .config import PgVectorScaleConfig, PgVectorScaleStreamingDiskANNConfig
+
+    run(
+        db=DB.PgVectorScale,
+        db_config=PgVectorScaleConfig(
+            db_label=parameters["db_label"],
+            user_name=SecretStr(parameters["user_name"]),
+            password=SecretStr(parameters["password"]),
+            host=parameters["host"],
+            db_name=parameters["db_name"],
+        ),
+        db_case_config=PgVectorScaleStreamingDiskANNConfig(
+            storage_layout=parameters["storage_layout"],
+            num_neighbors=parameters["num_neighbors"],
+            search_list_size=parameters["search_list_size"],
+            max_alpha=parameters["max_alpha"],
+            num_dimensions=parameters["num_dimensions"],
+            query_search_list_size=parameters["query_search_list_size"],
+            query_rescore=parameters["query_rescore"],
+        ),
+        **parameters,
+    )
```
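The `Annotated`/TypedDict pattern above lets a single declaration drive both typing and the CLI flags. For orientation, here is a minimal, self-contained sketch of how such a decorator helper can work; it is an illustration of the pattern only, not vectordb-bench's actual `click_parameter_decorators_from_typed_dict` implementation, and `DemoOptions`/`demo` are hypothetical names:

```python
from typing import Annotated, TypedDict, get_type_hints

import click


class DemoOptions(TypedDict):
    host: Annotated[str, click.option("--host", type=str, required=True)]
    port: Annotated[int, click.option("--port", type=int, default=5432)]


def decorators_from_typed_dict(td):
    """Toy stand-in: apply each click.option stored in Annotated metadata."""
    def apply(f):
        hints = get_type_hints(td, include_extras=True)
        for hint in reversed(list(hints.values())):
            for meta in getattr(hint, "__metadata__", ()):
                f = meta(f)  # click.option(...) returns a decorator
        return f
    return apply


@click.command()
@decorators_from_typed_dict(DemoOptions)
def demo(**params):
    click.echo(params)  # e.g. {'host': ..., 'port': ...}
```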
vectordb_bench/backend/clients/pgvectorscale/config.py (new file)

```diff
@@ -0,0 +1,111 @@
+from abc import abstractmethod
+from typing import TypedDict
+from pydantic import BaseModel, SecretStr
+from typing_extensions import LiteralString
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
+
+POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
+
+
+class PgVectorScaleConfigDict(TypedDict):
+    """These keys will be directly used as kwargs in psycopg connection string,
+    so the names must match exactly psycopg API"""
+
+    user: str
+    password: str
+    host: str
+    port: int
+    dbname: str
+
+
+class PgVectorScaleConfig(DBConfig):
+    user_name: SecretStr = SecretStr("postgres")
+    password: SecretStr
+    host: str = "localhost"
+    port: int = 5432
+    db_name: str
+
+    def to_dict(self) -> PgVectorScaleConfigDict:
+        user_str = self.user_name.get_secret_value()
+        pwd_str = self.password.get_secret_value()
+        return {
+            "host": self.host,
+            "port": self.port,
+            "dbname": self.db_name,
+            "user": user_str,
+            "password": pwd_str,
+        }
+
+
+class PgVectorScaleIndexConfig(BaseModel, DBCaseConfig):
+    metric_type: MetricType | None = None
+    create_index_before_load: bool = False
+    create_index_after_load: bool = True
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.COSINE:
+            return "vector_cosine_ops"
+        return ""
+
+    def parse_metric_fun_op(self) -> LiteralString:
+        if self.metric_type == MetricType.COSINE:
+            return "<=>"
+        return ""
+
+    def parse_metric_fun_str(self) -> str:
+        if self.metric_type == MetricType.COSINE:
+            return "cosine_distance"
+        return ""
+
+    @abstractmethod
+    def index_param(self) -> dict:
+        ...
+
+    @abstractmethod
+    def search_param(self) -> dict:
+        ...
+
+    @abstractmethod
+    def session_param(self) -> dict:
+        ...
+
+
+class PgVectorScaleStreamingDiskANNConfig(PgVectorScaleIndexConfig):
+    index: IndexType = IndexType.STREAMING_DISKANN
+    storage_layout: str | None
+    num_neighbors: int | None
+    search_list_size: int | None
+    max_alpha: float | None
+    num_dimensions: int | None
+    num_bits_per_dimension: int | None
+    query_search_list_size: int | None
+    query_rescore: int | None
+
+    def index_param(self) -> dict:
+        return {
+            "metric": self.parse_metric(),
+            "index_type": self.index.value,
+            "options": {
+                "storage_layout": self.storage_layout,
+                "num_neighbors": self.num_neighbors,
+                "search_list_size": self.search_list_size,
+                "max_alpha": self.max_alpha,
+                "num_dimensions": self.num_dimensions,
+            },
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "metric": self.parse_metric(),
+            "metric_fun_op": self.parse_metric_fun_op(),
+        }
+
+    def session_param(self) -> dict:
+        return {
+            "diskann.query_search_list_size": self.query_search_list_size,
+            "diskann.query_rescore": self.query_rescore,
+        }
+
+_pgvectorscale_case_config = {
+    IndexType.STREAMING_DISKANN: PgVectorScaleStreamingDiskANNConfig,
+}
```
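Note the deliberate renames in `to_dict()`: `user_name` becomes `user` and `db_name` becomes `dbname`, so the result can be splatted straight into `psycopg.connect()`. A quick sketch with placeholder credentials, assuming the remaining fields inherited from `DBConfig` (e.g. `db_label`) have defaults:

```python
from pydantic import SecretStr

from vectordb_bench.backend.clients.pgvectorscale.config import PgVectorScaleConfig

cfg = PgVectorScaleConfig(
    user_name=SecretStr("postgres"),
    password=SecretStr("example"),  # placeholder, not a real credential
    host="localhost",
    db_name="vectordb",
)

# -> {'host': 'localhost', 'port': 5432, 'dbname': 'vectordb',
#     'user': 'postgres', 'password': 'example'}
print(cfg.to_dict())

# The dict is shaped so that psycopg.connect(**cfg.to_dict()) just works.
```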
vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py (new file)

```diff
@@ -0,0 +1,290 @@
+"""Wrapper around the Pgvectorscale vector database over VectorDB"""
+
+import logging
+import pprint
+from contextlib import contextmanager
+from typing import Any, Generator, Optional, Tuple
+
+import numpy as np
+import psycopg
+from pgvector.psycopg import register_vector
+from psycopg import Connection, Cursor, sql
+
+from ..api import VectorDB
+from .config import PgVectorScaleConfigDict, PgVectorScaleIndexConfig
+
+log = logging.getLogger(__name__)
+
+
+class PgVectorScale(VectorDB):
+    """Use psycopg instructions"""
+
+    conn: psycopg.Connection[Any] | None = None
+    cursor: psycopg.Cursor[Any] | None = None
+
+    _unfiltered_search: sql.Composed
+    _filtered_search: sql.Composed
+
+    def __init__(
+        self,
+        dim: int,
+        db_config: PgVectorScaleConfigDict,
+        db_case_config: PgVectorScaleIndexConfig,
+        collection_name: str = "pg_vectorscale_collection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.name = "PgVectorScale"
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.table_name = collection_name
+        self.dim = dim
+
+        self._index_name = "pgvectorscale_index"
+        self._primary_field = "id"
+        self._vector_field = "embedding"
+
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}")
+        if not any(
+            (
+                self.case_config.create_index_before_load,
+                self.case_config.create_index_after_load,
+            )
+        ):
+            err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
+            log.error(err)
+            raise RuntimeError(
+                f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
+            )
+
+        if drop_old:
+            self._drop_index()
+            self._drop_table()
+            self._create_table(dim)
+            if self.case_config.create_index_before_load:
+                self._create_index()
+
+        self.cursor.close()
+        self.conn.close()
+        self.cursor = None
+        self.conn = None
+
+    @staticmethod
+    def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
+        conn = psycopg.connect(**kwargs)
+        conn.cursor().execute("CREATE EXTENSION IF NOT EXISTS vectorscale CASCADE")
+        conn.commit()
+        register_vector(conn)
+        conn.autocommit = False
+        cursor = conn.cursor()
+
+        assert conn is not None, "Connection is not initialized"
+        assert cursor is not None, "Cursor is not initialized"
+
+        return conn, cursor
+
+    @contextmanager
+    def init(self) -> Generator[None, None, None]:
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        # index configuration may have commands defined that we should set during each client session
+        session_options: dict[str, Any] = self.case_config.session_param()
+
+        if len(session_options) > 0:
+            for setting_name, setting_val in session_options.items():
+                command = sql.SQL("SET {setting_name} " + "= {setting_val};").format(
+                    setting_name=sql.Identifier(setting_name),
+                    setting_val=sql.Identifier(str(setting_val)),
+                )
+                log.debug(command.as_string(self.cursor))
+                self.cursor.execute(command)
+            self.conn.commit()
+
+        self._filtered_search = sql.Composed(
+            [
+                sql.SQL("SELECT id FROM public.{} WHERE id >= %s ORDER BY embedding ").format(
+                    sql.Identifier(self.table_name),
+                ),
+                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                sql.SQL(" %s::vector LIMIT %s::int")
+            ]
+        )
+
+        self._unfiltered_search = sql.Composed(
+            [
+                sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
+                    sql.Identifier(self.table_name)
+                ),
+                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                sql.SQL(" %s::vector LIMIT %s::int"),
+            ]
+        )
+
+        try:
+            yield
+        finally:
+            self.cursor.close()
+            self.conn.close()
+            self.cursor = None
+            self.conn = None
+
+    def _drop_table(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop table : {self.table_name}")
+
+        self.cursor.execute(
+            sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
+                table_name=sql.Identifier(self.table_name)
+            )
+        )
+        self.conn.commit()
+
+    def ready_to_load(self):
+        pass
+
+    def optimize(self):
+        self._post_insert()
+
+    def _post_insert(self):
+        log.info(f"{self.name} post insert before optimize")
+        if self.case_config.create_index_after_load:
+            self._drop_index()
+            self._create_index()
+
+    def _drop_index(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop index : {self._index_name}")
+
+        drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
+            index_name=sql.Identifier(self._index_name)
+        )
+        log.debug(drop_index_sql.as_string(self.cursor))
+        self.cursor.execute(drop_index_sql)
+        self.conn.commit()
+
+    def _create_index(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client create index : {self._index_name}")
+
+        index_param: dict[str, Any] = self.case_config.index_param()
+
+        options = []
+        for option_name, option_val in index_param["options"].items():
+            if option_val is not None:
+                options.append(
+                    sql.SQL("{option_name} = {val}").format(
+                        option_name=sql.Identifier(option_name),
+                        val=sql.Identifier(str(option_val)),
+                    )
+                )
+
+        num_bits_per_dimension = "2" if self.dim < 900 else "1"
+        options.append(
+            sql.SQL("{option_name} = {val}").format(
+                option_name=sql.Identifier("num_bits_per_dimension"),
+                val=sql.Identifier(num_bits_per_dimension),
+            )
+        )
+
+        if any(options):
+            with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
+        else:
+            with_clause = sql.Composed(())
+
+        index_create_sql = sql.SQL(
+            """
+            CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
+            USING {index_type} (embedding {embedding_metric})
+            """
+        ).format(
+            index_name=sql.Identifier(self._index_name),
+            table_name=sql.Identifier(self.table_name),
+            index_type=sql.Identifier(index_param["index_type"].lower()),
+            embedding_metric=sql.Identifier(index_param["metric"]),
+        )
+        index_create_sql_with_with_clause = (
+            index_create_sql + with_clause
+        ).join(" ")
+        log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
+        self.cursor.execute(index_create_sql_with_with_clause)
+        self.conn.commit()
+
+    def _create_table(self, dim: int):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            log.info(f"{self.name} client create table : {self.table_name}")
+
+            self.cursor.execute(
+                sql.SQL(
+                    "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
+                ).format(table_name=sql.Identifier(self.table_name), dim=dim)
+            )
+            self.conn.commit()
+        except Exception as e:
+            log.warning(
+                f"Failed to create pgvectorscale table: {self.table_name} error: {e}"
+            )
+            raise e from None
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> Tuple[int, Optional[Exception]]:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            metadata_arr = np.array(metadata)
+            embeddings_arr = np.array(embeddings)
+
+            with self.cursor.copy(
+                sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
+                    table_name=sql.Identifier(self.table_name)
+                )
+            ) as copy:
+                copy.set_types(["bigint", "vector"])
+                for i, row in enumerate(metadata_arr):
+                    copy.write_row((row, embeddings_arr[i]))
+            self.conn.commit()
+
+            if kwargs.get("last_batch"):
+                self._post_insert()
+
+            return len(metadata), None
+        except Exception as e:
+            log.warning(
+                f"Failed to insert data into pgvector table ({self.table_name}), error: {e}"
+            )
+            return 0, e
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+    ) -> list[int]:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        q = np.asarray(query)
+        if filters:
+            gt = filters.get("id")
+            result = self.cursor.execute(
+                self._filtered_search, (gt, q, k), prepare=True, binary=True
+            )
+        else:
+            result = self.cursor.execute(
+                self._unfiltered_search, (q, k), prepare=True, binary=True
+            )
+
+        return [int(i[0]) for i in result.fetchall()]
```
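For `MetricType.COSINE`, the two composed statements built in `init()` render to ordinary parameterized SQL. A sketch of the unfiltered form, assuming the default table name and that psycopg 3's `as_string` accepts a `None` context:

```python
from psycopg import sql

table = sql.Identifier("pg_vectorscale_collection")
unfiltered = sql.Composed(
    [
        sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(table),
        sql.SQL("<=>"),  # parse_metric_fun_op() for COSINE
        sql.SQL(" %s::vector LIMIT %s::int"),
    ]
)

# SELECT id FROM public."pg_vectorscale_collection"
#     ORDER BY embedding <=> %s::vector LIMIT %s::int
print(unfiltered.as_string(None))
```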
vectordb_bench/backend/clients/pinecone/config.py

```diff
@@ -4,12 +4,10 @@ from ..api import DBConfig
 
 class PineconeConfig(DBConfig):
     api_key: SecretStr
-    environment: SecretStr
     index_name: str
 
     def to_dict(self) -> dict:
         return {
             "api_key": self.api_key.get_secret_value(),
-            "environment": self.environment.get_secret_value(),
             "index_name": self.index_name,
         }
```
vectordb_bench/backend/clients/pinecone/pinecone.py

```diff
@@ -3,7 +3,7 @@
 import logging
 from contextlib import contextmanager
 from typing import Type
-
+import pinecone
 from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
 from .config import PineconeConfig
 
```
```diff
@@ -11,7 +11,8 @@ from .config import PineconeConfig
 log = logging.getLogger(__name__)
 
 PINECONE_MAX_NUM_PER_BATCH = 1000
-PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024
+PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024  # 2MB
+
 
 class Pinecone(VectorDB):
     def __init__(
```
```diff
@@ -23,30 +24,25 @@ class Pinecone(VectorDB):
         **kwargs,
     ):
         """Initialize wrapper around the milvus vector database."""
-        self.index_name = db_config
-        self.api_key = db_config
-        self.
-
-
-
-
-
-
+        self.index_name = db_config.get("index_name", "")
+        self.api_key = db_config.get("api_key", "")
+        self.batch_size = int(
+            min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH)
+        )
+
+        pc = pinecone.Pinecone(api_key=self.api_key)
+        index = pc.Index(self.index_name)
+
         if drop_old:
-
-
-
-            index_dim = index.describe_index_stats()["dimension"]
-            if (index_dim != dim):
-                raise ValueError(
-                    f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}")
-            log.info(
-                f"Pinecone client delete old index: {self.index_name}")
-            index.delete(delete_all=True)
-            index.close()
-        else:
+            index_stats = index.describe_index_stats()
+            index_dim = index_stats["dimension"]
+            if index_dim != dim:
                 raise ValueError(
-                    f"Pinecone index {self.index_name}
+                    f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}"
+                )
+            for namespace in index_stats["namespaces"]:
+                log.info(f"Pinecone index delete namespace: {namespace}")
+                index.delete(delete_all=True, namespace=namespace)
 
         self._metadata_key = "meta"
 
```
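The new `batch_size` line budgets each upsert by an assumed ~5 bytes per vector dimension against the 2MB request cap, clamped at Pinecone's 1000-vector batch limit. A worked example of that arithmetic:

```python
PINECONE_MAX_NUM_PER_BATCH = 1000
PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024  # 2MB

for dim in (128, 768, 1536):
    batch_size = int(
        min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH)
    )
    print(dim, batch_size)  # 128 -> 1000 (count cap), 768 -> 546, 1536 -> 273
```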
```diff
@@ -59,13 +55,10 @@ class Pinecone(VectorDB):
         return EmptyDBCaseConfig
 
     @contextmanager
-    def init(self)
-
-
-            api_key=self.api_key, environment=self.environment)
-        self.index = pinecone.Index(self.index_name)
+    def init(self):
+        pc = pinecone.Pinecone(api_key=self.api_key)
+        self.index = pc.Index(self.index_name)
         yield
-        self.index.close()
 
     def ready_to_load(self):
         pass
```
```diff
@@ -83,11 +76,16 @@ class Pinecone(VectorDB):
         insert_count = 0
         try:
             for batch_start_offset in range(0, len(embeddings), self.batch_size):
-                batch_end_offset = min(
+                batch_end_offset = min(
+                    batch_start_offset + self.batch_size, len(embeddings)
+                )
                 insert_datas = []
                 for i in range(batch_start_offset, batch_end_offset):
-                    insert_data = (
-
+                    insert_data = (
+                        str(metadata[i]),
+                        embeddings[i],
+                        {self._metadata_key: metadata[i]},
+                    )
                     insert_datas.append(insert_data)
                 self.index.upsert(insert_datas)
                 insert_count += batch_end_offset - batch_start_offset
```
```diff
@@ -101,7 +99,7 @@ class Pinecone(VectorDB):
         k: int = 100,
         filters: dict | None = None,
         timeout: int | None = None,
-    ) -> list[
+    ) -> list[int]:
         if filters is None:
             pinecone_filters = {}
         else:
```
```diff
@@ -111,9 +109,9 @@ class Pinecone(VectorDB):
                 top_k=k,
                 vector=query,
                 filter=pinecone_filters,
-            )[
+            )["matches"]
         except Exception as e:
             print(f"Error querying index: {e}")
             raise e
-        id_res = [int(one_res[
+        id_res = [int(one_res["id"]) for one_res in res]
         return id_res
```
vectordb_bench/backend/clients/redis/cli.py

```diff
@@ -3,6 +3,9 @@ from typing import Annotated, TypedDict, Unpack
 import click
 from pydantic import SecretStr
 
+from .config import RedisHNSWConfig
+
+
 from ....cli.cli import (
     CommonTypedDict,
     HNSWFlavor2,
```
```diff
@@ -69,6 +72,11 @@ def Redis(**parameters: Unpack[RedisHNSWTypedDict]):
             ssl=parameters["ssl"],
             ssl_ca_certs=parameters["ssl_ca_certs"],
             cmd=parameters["cmd"],
+        ),
+        db_case_config=RedisHNSWConfig(
+            M=parameters["m"],
+            efConstruction=parameters["ef_construction"],
+            ef=parameters["ef_runtime"],
         ),
         **parameters,
     )
```