vectordb-bench 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +14 -13
- vectordb_bench/backend/clients/__init__.py +13 -0
- vectordb_bench/backend/clients/api.py +2 -0
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +47 -6
- vectordb_bench/backend/clients/aws_opensearch/config.py +12 -6
- vectordb_bench/backend/clients/aws_opensearch/run.py +34 -3
- vectordb_bench/backend/clients/pgdiskann/cli.py +99 -0
- vectordb_bench/backend/clients/pgdiskann/config.py +145 -0
- vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +350 -0
- vectordb_bench/backend/clients/pgvector/cli.py +62 -1
- vectordb_bench/backend/clients/pgvector/config.py +48 -10
- vectordb_bench/backend/clients/pgvector/pgvector.py +145 -26
- vectordb_bench/backend/clients/pgvectorscale/cli.py +108 -0
- vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +22 -4
- vectordb_bench/backend/clients/pinecone/config.py +0 -2
- vectordb_bench/backend/clients/pinecone/pinecone.py +34 -36
- vectordb_bench/backend/clients/redis/cli.py +8 -0
- vectordb_bench/backend/clients/redis/config.py +37 -6
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +1 -1
- vectordb_bench/backend/runner/mp_runner.py +2 -1
- vectordb_bench/cli/cli.py +137 -0
- vectordb_bench/cli/vectordbbench.py +4 -1
- vectordb_bench/frontend/components/check_results/charts.py +9 -6
- vectordb_bench/frontend/components/concurrent/charts.py +3 -6
- vectordb_bench/frontend/components/run_test/caseSelector.py +6 -0
- vectordb_bench/frontend/config/dbCaseConfigs.py +165 -1
- vectordb_bench/frontend/pages/quries_per_dollar.py +13 -5
- vectordb_bench/frontend/vdb_benchmark.py +11 -3
- vectordb_bench/models.py +13 -3
- vectordb_bench/results/Milvus/result_20230727_standard_milvus.json +53 -1
- vectordb_bench/results/Milvus/result_20230808_standard_milvus.json +48 -0
- vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +29 -1
- vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +24 -0
- vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +98 -49
- vectordb_bench/results/getLeaderboardData.py +17 -7
- vectordb_bench/results/leaderboard.json +1 -1
- {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/METADATA +65 -35
- {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/RECORD +42 -38
- {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/top_level.txt +0 -0
vectordb_bench/__init__.py
CHANGED
@@ -37,23 +37,24 @@ class config:
|
|
37
37
|
K_DEFAULT = 100 # default return top k nearest neighbors during search
|
38
38
|
CUSTOM_CONFIG_DIR = pathlib.Path(__file__).parent.joinpath("custom/custom_case.json")
|
39
39
|
|
40
|
-
CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600
|
41
|
-
LOAD_TIMEOUT_DEFAULT =
|
42
|
-
LOAD_TIMEOUT_768D_1M =
|
43
|
-
LOAD_TIMEOUT_768D_10M =
|
44
|
-
LOAD_TIMEOUT_768D_100M =
|
40
|
+
CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600 # 24h
|
41
|
+
LOAD_TIMEOUT_DEFAULT = 24 * 3600 # 24h
|
42
|
+
LOAD_TIMEOUT_768D_1M = 24 * 3600 # 24h
|
43
|
+
LOAD_TIMEOUT_768D_10M = 240 * 3600 # 10d
|
44
|
+
LOAD_TIMEOUT_768D_100M = 2400 * 3600 # 100d
|
45
45
|
|
46
|
-
LOAD_TIMEOUT_1536D_500K =
|
47
|
-
LOAD_TIMEOUT_1536D_5M =
|
46
|
+
LOAD_TIMEOUT_1536D_500K = 24 * 3600 # 24h
|
47
|
+
LOAD_TIMEOUT_1536D_5M = 240 * 3600 # 10d
|
48
48
|
|
49
|
-
OPTIMIZE_TIMEOUT_DEFAULT =
|
50
|
-
OPTIMIZE_TIMEOUT_768D_1M =
|
51
|
-
OPTIMIZE_TIMEOUT_768D_10M =
|
52
|
-
OPTIMIZE_TIMEOUT_768D_100M =
|
49
|
+
OPTIMIZE_TIMEOUT_DEFAULT = 24 * 3600 # 24h
|
50
|
+
OPTIMIZE_TIMEOUT_768D_1M = 24 * 3600 # 24h
|
51
|
+
OPTIMIZE_TIMEOUT_768D_10M = 240 * 3600 # 10d
|
52
|
+
OPTIMIZE_TIMEOUT_768D_100M = 2400 * 3600 # 100d
|
53
53
|
|
54
54
|
|
55
|
-
OPTIMIZE_TIMEOUT_1536D_500K =
|
56
|
-
OPTIMIZE_TIMEOUT_1536D_5M =
|
55
|
+
OPTIMIZE_TIMEOUT_1536D_500K = 24 * 3600 # 24h
|
56
|
+
OPTIMIZE_TIMEOUT_1536D_5M = 240 * 3600 # 10d
|
57
|
+
|
57
58
|
def display(self) -> str:
|
58
59
|
tmp = [
|
59
60
|
i for i in inspect.getmembers(self)
|
@@ -31,6 +31,7 @@ class DB(Enum):
|
|
31
31
|
PgVector = "PgVector"
|
32
32
|
PgVectoRS = "PgVectoRS"
|
33
33
|
PgVectorScale = "PgVectorScale"
|
34
|
+
PgDiskANN = "PgDiskANN"
|
34
35
|
Redis = "Redis"
|
35
36
|
MemoryDB = "MemoryDB"
|
36
37
|
Chroma = "Chroma"
|
@@ -77,6 +78,10 @@ class DB(Enum):
|
|
77
78
|
from .pgvectorscale.pgvectorscale import PgVectorScale
|
78
79
|
return PgVectorScale
|
79
80
|
|
81
|
+
if self == DB.PgDiskANN:
|
82
|
+
from .pgdiskann.pgdiskann import PgDiskANN
|
83
|
+
return PgDiskANN
|
84
|
+
|
80
85
|
if self == DB.Redis:
|
81
86
|
from .redis.redis import Redis
|
82
87
|
return Redis
|
@@ -132,6 +137,10 @@ class DB(Enum):
|
|
132
137
|
from .pgvectorscale.config import PgVectorScaleConfig
|
133
138
|
return PgVectorScaleConfig
|
134
139
|
|
140
|
+
if self == DB.PgDiskANN:
|
141
|
+
from .pgdiskann.config import PgDiskANNConfig
|
142
|
+
return PgDiskANNConfig
|
143
|
+
|
135
144
|
if self == DB.Redis:
|
136
145
|
from .redis.config import RedisConfig
|
137
146
|
return RedisConfig
|
@@ -185,6 +194,10 @@ class DB(Enum):
|
|
185
194
|
from .pgvectorscale.config import _pgvectorscale_case_config
|
186
195
|
return _pgvectorscale_case_config.get(index_type)
|
187
196
|
|
197
|
+
if self == DB.PgDiskANN:
|
198
|
+
from .pgdiskann.config import _pgdiskann_case_config
|
199
|
+
return _pgdiskann_case_config.get(index_type)
|
200
|
+
|
188
201
|
# DB.Pinecone, DB.Chroma, DB.Redis
|
189
202
|
return EmptyDBCaseConfig
|
190
203
|
|
@@ -3,7 +3,7 @@ from contextlib import contextmanager
|
|
3
3
|
import time
|
4
4
|
from typing import Iterable, Type
|
5
5
|
from ..api import VectorDB, DBCaseConfig, DBConfig, IndexType
|
6
|
-
from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig
|
6
|
+
from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig, AWSOS_Engine
|
7
7
|
from opensearchpy import OpenSearch
|
8
8
|
from opensearchpy.helpers import bulk
|
9
9
|
|
@@ -83,7 +83,7 @@ class AWSOpenSearch(VectorDB):
|
|
83
83
|
|
84
84
|
@contextmanager
|
85
85
|
def init(self) -> None:
|
86
|
-
"""connect to
|
86
|
+
"""connect to opensearch"""
|
87
87
|
self.client = OpenSearch(**self.db_config)
|
88
88
|
|
89
89
|
yield
|
@@ -97,7 +97,7 @@ class AWSOpenSearch(VectorDB):
|
|
97
97
|
metadata: list[int],
|
98
98
|
**kwargs,
|
99
99
|
) -> tuple[int, Exception]:
|
100
|
-
"""Insert the embeddings to the
|
100
|
+
"""Insert the embeddings to the opensearch."""
|
101
101
|
assert self.client is not None, "should self.init() first"
|
102
102
|
|
103
103
|
insert_data = []
|
@@ -136,13 +136,15 @@ class AWSOpenSearch(VectorDB):
|
|
136
136
|
body = {
|
137
137
|
"size": k,
|
138
138
|
"query": {"knn": {self.vector_col_name: {"vector": query, "k": k}}},
|
139
|
+
**({"filter": {"range": {self.id_col_name: {"gt": filters["id"]}}}} if filters else {})
|
139
140
|
}
|
140
141
|
try:
|
141
|
-
resp = self.client.search(index=self.index_name, body=body)
|
142
|
+
resp = self.client.search(index=self.index_name, body=body,size=k,_source=False,docvalue_fields=[self.id_col_name],stored_fields="_none_",filter_path=[f"hits.hits.fields.{self.id_col_name}"],)
|
142
143
|
log.info(f'Search took: {resp["took"]}')
|
143
144
|
log.info(f'Search shards: {resp["_shards"]}')
|
144
145
|
log.info(f'Search hits total: {resp["hits"]["total"]}')
|
145
|
-
result = [
|
146
|
+
result = [h["fields"][self.id_col_name][0] for h in resp["hits"]["hits"]]
|
147
|
+
#result = [int(d["_id"]) for d in resp["hits"]["hits"]]
|
146
148
|
# log.info(f'success! length={len(res)}')
|
147
149
|
|
148
150
|
return result
|
@@ -152,7 +154,46 @@ class AWSOpenSearch(VectorDB):
|
|
152
154
|
|
153
155
|
def optimize(self):
|
154
156
|
"""optimize will be called between insertion and search in performance cases."""
|
155
|
-
|
157
|
+
# Call refresh first to ensure that all segments are created
|
158
|
+
self._refresh_index()
|
159
|
+
self._do_force_merge()
|
160
|
+
# Call refresh again to ensure that the index is ready after force merge.
|
161
|
+
self._refresh_index()
|
162
|
+
# ensure that all graphs are loaded in memory and ready for search
|
163
|
+
self._load_graphs_to_memory()
|
164
|
+
|
165
|
+
def _refresh_index(self):
|
166
|
+
log.debug(f"Starting refresh for index {self.index_name}")
|
167
|
+
SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC = 30
|
168
|
+
while True:
|
169
|
+
try:
|
170
|
+
log.info(f"Starting the Refresh Index..")
|
171
|
+
self.client.indices.refresh(index=self.index_name)
|
172
|
+
break
|
173
|
+
except Exception as e:
|
174
|
+
log.info(
|
175
|
+
f"Refresh errored out. Sleeping for {SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC} sec and then Retrying : {e}")
|
176
|
+
time.sleep(SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC)
|
177
|
+
continue
|
178
|
+
log.debug(f"Completed refresh for index {self.index_name}")
|
179
|
+
|
180
|
+
def _do_force_merge(self):
|
181
|
+
log.debug(f"Starting force merge for index {self.index_name}")
|
182
|
+
force_merge_endpoint = f'/{self.index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false'
|
183
|
+
force_merge_task_id = self.client.transport.perform_request('POST', force_merge_endpoint)['task']
|
184
|
+
SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
|
185
|
+
while True:
|
186
|
+
time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
|
187
|
+
task_status = self.client.tasks.get(task_id=force_merge_task_id)
|
188
|
+
if task_status['completed']:
|
189
|
+
break
|
190
|
+
log.debug(f"Completed force merge for index {self.index_name}")
|
191
|
+
|
192
|
+
def _load_graphs_to_memory(self):
|
193
|
+
if self.case_config.engine != AWSOS_Engine.lucene:
|
194
|
+
log.info("Calling warmup API to load graphs into memory")
|
195
|
+
warmup_endpoint = f'/_plugins/_knn/warmup/{self.index_name}'
|
196
|
+
self.client.transport.perform_request('GET', warmup_endpoint)
|
156
197
|
|
157
198
|
def ready_to_load(self):
|
158
199
|
"""ready_to_load will be called before load in load cases."""
|
@@ -1,9 +1,10 @@
|
|
1
|
+
import logging
|
1
2
|
from enum import Enum
|
2
3
|
from pydantic import SecretStr, BaseModel
|
3
4
|
|
4
5
|
from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
|
5
6
|
|
6
|
-
|
7
|
+
log = logging.getLogger(__name__)
|
7
8
|
class AWSOpenSearchConfig(DBConfig, BaseModel):
|
8
9
|
host: str = ""
|
9
10
|
port: int = 443
|
@@ -31,14 +32,18 @@ class AWSOS_Engine(Enum):
|
|
31
32
|
|
32
33
|
class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
|
33
34
|
metric_type: MetricType = MetricType.L2
|
34
|
-
engine: AWSOS_Engine = AWSOS_Engine.
|
35
|
-
efConstruction: int =
|
36
|
-
|
35
|
+
engine: AWSOS_Engine = AWSOS_Engine.faiss
|
36
|
+
efConstruction: int = 256
|
37
|
+
efSearch: int = 256
|
38
|
+
M: int = 16
|
37
39
|
|
38
40
|
def parse_metric(self) -> str:
|
39
41
|
if self.metric_type == MetricType.IP:
|
40
|
-
return "innerproduct"
|
42
|
+
return "innerproduct"
|
41
43
|
elif self.metric_type == MetricType.COSINE:
|
44
|
+
if self.engine == AWSOS_Engine.faiss:
|
45
|
+
log.info(f"Using metric type as innerproduct because faiss doesn't support cosine as metric type for Opensearch")
|
46
|
+
return "innerproduct"
|
42
47
|
return "cosinesimil"
|
43
48
|
return "l2"
|
44
49
|
|
@@ -49,7 +54,8 @@ class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
|
|
49
54
|
"engine": self.engine.value,
|
50
55
|
"parameters": {
|
51
56
|
"ef_construction": self.efConstruction,
|
52
|
-
"m": self.M
|
57
|
+
"m": self.M,
|
58
|
+
"ef_search": self.efSearch
|
53
59
|
}
|
54
60
|
}
|
55
61
|
return params
|
@@ -40,12 +40,12 @@ def create_index(client, index_name):
|
|
40
40
|
"type": "knn_vector",
|
41
41
|
"dimension": _DIM,
|
42
42
|
"method": {
|
43
|
-
"engine": "
|
43
|
+
"engine": "faiss",
|
44
44
|
"name": "hnsw",
|
45
45
|
"space_type": "l2",
|
46
46
|
"parameters": {
|
47
|
-
"ef_construction":
|
48
|
-
"m":
|
47
|
+
"ef_construction": 256,
|
48
|
+
"m": 16,
|
49
49
|
}
|
50
50
|
}
|
51
51
|
}
|
@@ -108,12 +108,43 @@ def search(client, index_name):
|
|
108
108
|
print('\nSearch not ready, sleep 1s')
|
109
109
|
time.sleep(1)
|
110
110
|
|
111
|
+
def optimize_index(client, index_name):
|
112
|
+
print(f"Starting force merge for index {index_name}")
|
113
|
+
force_merge_endpoint = f'/{index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false'
|
114
|
+
force_merge_task_id = client.transport.perform_request('POST', force_merge_endpoint)['task']
|
115
|
+
SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
|
116
|
+
while True:
|
117
|
+
time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
|
118
|
+
task_status = client.tasks.get(task_id=force_merge_task_id)
|
119
|
+
if task_status['completed']:
|
120
|
+
break
|
121
|
+
print(f"Completed force merge for index {index_name}")
|
122
|
+
|
123
|
+
|
124
|
+
def refresh_index(client, index_name):
|
125
|
+
print(f"Starting refresh for index {index_name}")
|
126
|
+
SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC = 30
|
127
|
+
while True:
|
128
|
+
try:
|
129
|
+
print(f"Starting the Refresh Index..")
|
130
|
+
client.indices.refresh(index=index_name)
|
131
|
+
break
|
132
|
+
except Exception as e:
|
133
|
+
print(
|
134
|
+
f"Refresh errored out. Sleeping for {SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC} sec and then Retrying : {e}")
|
135
|
+
time.sleep(SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC)
|
136
|
+
continue
|
137
|
+
print(f"Completed refresh for index {index_name}")
|
138
|
+
|
139
|
+
|
111
140
|
|
112
141
|
def main():
|
113
142
|
client = create_client()
|
114
143
|
try:
|
115
144
|
create_index(client, _INDEX_NAME)
|
116
145
|
bulk_insert(client, _INDEX_NAME)
|
146
|
+
optimize_index(client, _INDEX_NAME)
|
147
|
+
refresh_index(client, _INDEX_NAME)
|
117
148
|
search(client, _INDEX_NAME)
|
118
149
|
delete_index(client, _INDEX_NAME)
|
119
150
|
except Exception as e:
|
@@ -0,0 +1,99 @@
|
|
1
|
+
import click
|
2
|
+
import os
|
3
|
+
from pydantic import SecretStr
|
4
|
+
|
5
|
+
from ....cli.cli import (
|
6
|
+
CommonTypedDict,
|
7
|
+
cli,
|
8
|
+
click_parameter_decorators_from_typed_dict,
|
9
|
+
run,
|
10
|
+
)
|
11
|
+
from typing import Annotated, Optional, Unpack
|
12
|
+
from vectordb_bench.backend.clients import DB
|
13
|
+
|
14
|
+
|
15
|
+
class PgDiskAnnTypedDict(CommonTypedDict):
|
16
|
+
user_name: Annotated[
|
17
|
+
str, click.option("--user-name", type=str, help="Db username", required=True)
|
18
|
+
]
|
19
|
+
password: Annotated[
|
20
|
+
str,
|
21
|
+
click.option("--password",
|
22
|
+
type=str,
|
23
|
+
help="Postgres database password",
|
24
|
+
default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
|
25
|
+
show_default="$POSTGRES_PASSWORD",
|
26
|
+
),
|
27
|
+
]
|
28
|
+
|
29
|
+
host: Annotated[
|
30
|
+
str, click.option("--host", type=str, help="Db host", required=True)
|
31
|
+
]
|
32
|
+
db_name: Annotated[
|
33
|
+
str, click.option("--db-name", type=str, help="Db name", required=True)
|
34
|
+
]
|
35
|
+
max_neighbors: Annotated[
|
36
|
+
int,
|
37
|
+
click.option(
|
38
|
+
"--max-neighbors", type=int, help="PgDiskAnn max neighbors",
|
39
|
+
),
|
40
|
+
]
|
41
|
+
l_value_ib: Annotated[
|
42
|
+
int,
|
43
|
+
click.option(
|
44
|
+
"--l-value-ib", type=int, help="PgDiskAnn l_value_ib",
|
45
|
+
),
|
46
|
+
]
|
47
|
+
l_value_is: Annotated[
|
48
|
+
float,
|
49
|
+
click.option(
|
50
|
+
"--l-value-is", type=float, help="PgDiskAnn l_value_is",
|
51
|
+
),
|
52
|
+
]
|
53
|
+
maintenance_work_mem: Annotated[
|
54
|
+
Optional[str],
|
55
|
+
click.option(
|
56
|
+
"--maintenance-work-mem",
|
57
|
+
type=str,
|
58
|
+
help="Sets the maximum memory to be used for maintenance operations (index creation). "
|
59
|
+
"Can be entered as string with unit like '64GB' or as an integer number of KB."
|
60
|
+
"This will set the parameters: max_parallel_maintenance_workers,"
|
61
|
+
" max_parallel_workers & table(parallel_workers)",
|
62
|
+
required=False,
|
63
|
+
),
|
64
|
+
]
|
65
|
+
max_parallel_workers: Annotated[
|
66
|
+
Optional[int],
|
67
|
+
click.option(
|
68
|
+
"--max-parallel-workers",
|
69
|
+
type=int,
|
70
|
+
help="Sets the maximum number of parallel processes per maintenance operation (index creation)",
|
71
|
+
required=False,
|
72
|
+
),
|
73
|
+
]
|
74
|
+
|
75
|
+
@cli.command()
|
76
|
+
@click_parameter_decorators_from_typed_dict(PgDiskAnnTypedDict)
|
77
|
+
def PgDiskAnn(
|
78
|
+
**parameters: Unpack[PgDiskAnnTypedDict],
|
79
|
+
):
|
80
|
+
from .config import PgDiskANNConfig, PgDiskANNImplConfig
|
81
|
+
|
82
|
+
run(
|
83
|
+
db=DB.PgDiskANN,
|
84
|
+
db_config=PgDiskANNConfig(
|
85
|
+
db_label=parameters["db_label"],
|
86
|
+
user_name=SecretStr(parameters["user_name"]),
|
87
|
+
password=SecretStr(parameters["password"]),
|
88
|
+
host=parameters["host"],
|
89
|
+
db_name=parameters["db_name"],
|
90
|
+
),
|
91
|
+
db_case_config=PgDiskANNImplConfig(
|
92
|
+
max_neighbors=parameters["max_neighbors"],
|
93
|
+
l_value_ib=parameters["l_value_ib"],
|
94
|
+
l_value_is=parameters["l_value_is"],
|
95
|
+
max_parallel_workers=parameters["max_parallel_workers"],
|
96
|
+
maintenance_work_mem=parameters["maintenance_work_mem"],
|
97
|
+
),
|
98
|
+
**parameters,
|
99
|
+
)
|
@@ -0,0 +1,145 @@
|
|
1
|
+
from abc import abstractmethod
|
2
|
+
from typing import Any, Mapping, Optional, Sequence, TypedDict
|
3
|
+
from pydantic import BaseModel, SecretStr
|
4
|
+
from typing_extensions import LiteralString
|
5
|
+
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
|
6
|
+
|
7
|
+
POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
|
8
|
+
|
9
|
+
|
10
|
+
class PgDiskANNConfigDict(TypedDict):
|
11
|
+
"""These keys will be directly used as kwargs in psycopg connection string,
|
12
|
+
so the names must match exactly psycopg API"""
|
13
|
+
|
14
|
+
user: str
|
15
|
+
password: str
|
16
|
+
host: str
|
17
|
+
port: int
|
18
|
+
dbname: str
|
19
|
+
|
20
|
+
|
21
|
+
class PgDiskANNConfig(DBConfig):
|
22
|
+
user_name: SecretStr = SecretStr("postgres")
|
23
|
+
password: SecretStr
|
24
|
+
host: str = "localhost"
|
25
|
+
port: int = 5432
|
26
|
+
db_name: str
|
27
|
+
|
28
|
+
def to_dict(self) -> PgDiskANNConfigDict:
|
29
|
+
user_str = self.user_name.get_secret_value()
|
30
|
+
pwd_str = self.password.get_secret_value()
|
31
|
+
return {
|
32
|
+
"host": self.host,
|
33
|
+
"port": self.port,
|
34
|
+
"dbname": self.db_name,
|
35
|
+
"user": user_str,
|
36
|
+
"password": pwd_str,
|
37
|
+
}
|
38
|
+
|
39
|
+
|
40
|
+
class PgDiskANNIndexConfig(BaseModel, DBCaseConfig):
|
41
|
+
metric_type: MetricType | None = None
|
42
|
+
create_index_before_load: bool = False
|
43
|
+
create_index_after_load: bool = True
|
44
|
+
maintenance_work_mem: Optional[str]
|
45
|
+
max_parallel_workers: Optional[int]
|
46
|
+
|
47
|
+
def parse_metric(self) -> str:
|
48
|
+
if self.metric_type == MetricType.L2:
|
49
|
+
return "vector_l2_ops"
|
50
|
+
elif self.metric_type == MetricType.IP:
|
51
|
+
return "vector_ip_ops"
|
52
|
+
return "vector_cosine_ops"
|
53
|
+
|
54
|
+
def parse_metric_fun_op(self) -> LiteralString:
|
55
|
+
if self.metric_type == MetricType.L2:
|
56
|
+
return "<->"
|
57
|
+
elif self.metric_type == MetricType.IP:
|
58
|
+
return "<#>"
|
59
|
+
return "<=>"
|
60
|
+
|
61
|
+
def parse_metric_fun_str(self) -> str:
|
62
|
+
if self.metric_type == MetricType.L2:
|
63
|
+
return "l2_distance"
|
64
|
+
elif self.metric_type == MetricType.IP:
|
65
|
+
return "max_inner_product"
|
66
|
+
return "cosine_distance"
|
67
|
+
|
68
|
+
@abstractmethod
|
69
|
+
def index_param(self) -> dict:
|
70
|
+
...
|
71
|
+
|
72
|
+
@abstractmethod
|
73
|
+
def search_param(self) -> dict:
|
74
|
+
...
|
75
|
+
|
76
|
+
@abstractmethod
|
77
|
+
def session_param(self) -> dict:
|
78
|
+
...
|
79
|
+
|
80
|
+
@staticmethod
|
81
|
+
def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[dict[str, Any]]:
|
82
|
+
"""Walk through mappings, creating a List of {key1 = value} pairs. That will be used to build a where clause"""
|
83
|
+
options = []
|
84
|
+
for option_name, value in with_options.items():
|
85
|
+
if value is not None:
|
86
|
+
options.append(
|
87
|
+
{
|
88
|
+
"option_name": option_name,
|
89
|
+
"val": str(value),
|
90
|
+
}
|
91
|
+
)
|
92
|
+
return options
|
93
|
+
|
94
|
+
@staticmethod
|
95
|
+
def _optionally_build_set_options(
|
96
|
+
set_mapping: Mapping[str, Any]
|
97
|
+
) -> Sequence[dict[str, Any]]:
|
98
|
+
"""Walk through options, creating 'SET 'key1 = "value1";' list"""
|
99
|
+
session_options = []
|
100
|
+
for setting_name, value in set_mapping.items():
|
101
|
+
if value:
|
102
|
+
session_options.append(
|
103
|
+
{"parameter": {
|
104
|
+
"setting_name": setting_name,
|
105
|
+
"val": str(value),
|
106
|
+
},
|
107
|
+
}
|
108
|
+
)
|
109
|
+
return session_options
|
110
|
+
|
111
|
+
|
112
|
+
class PgDiskANNImplConfig(PgDiskANNIndexConfig):
|
113
|
+
index: IndexType = IndexType.DISKANN
|
114
|
+
max_neighbors: int | None
|
115
|
+
l_value_ib: int | None
|
116
|
+
l_value_is: float | None
|
117
|
+
maintenance_work_mem: Optional[str] = None
|
118
|
+
max_parallel_workers: Optional[int] = None
|
119
|
+
|
120
|
+
def index_param(self) -> dict:
|
121
|
+
return {
|
122
|
+
"metric": self.parse_metric(),
|
123
|
+
"index_type": self.index.value,
|
124
|
+
"options": {
|
125
|
+
"max_neighbors": self.max_neighbors,
|
126
|
+
"l_value_ib": self.l_value_ib,
|
127
|
+
},
|
128
|
+
"maintenance_work_mem": self.maintenance_work_mem,
|
129
|
+
"max_parallel_workers": self.max_parallel_workers,
|
130
|
+
}
|
131
|
+
|
132
|
+
def search_param(self) -> dict:
|
133
|
+
return {
|
134
|
+
"metric": self.parse_metric(),
|
135
|
+
"metric_fun_op": self.parse_metric_fun_op(),
|
136
|
+
}
|
137
|
+
|
138
|
+
def session_param(self) -> dict:
|
139
|
+
return {
|
140
|
+
"diskann.l_value_is": self.l_value_is,
|
141
|
+
}
|
142
|
+
|
143
|
+
_pgdiskann_case_config = {
|
144
|
+
IndexType.DISKANN: PgDiskANNImplConfig,
|
145
|
+
}
|