vectordb-bench 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. vectordb_bench/backend/clients/__init__.py +22 -0
  2. vectordb_bench/backend/clients/api.py +21 -1
  3. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +47 -6
  4. vectordb_bench/backend/clients/aws_opensearch/config.py +12 -6
  5. vectordb_bench/backend/clients/aws_opensearch/run.py +34 -3
  6. vectordb_bench/backend/clients/memorydb/cli.py +88 -0
  7. vectordb_bench/backend/clients/memorydb/config.py +54 -0
  8. vectordb_bench/backend/clients/memorydb/memorydb.py +254 -0
  9. vectordb_bench/backend/clients/pgvecto_rs/cli.py +154 -0
  10. vectordb_bench/backend/clients/pgvecto_rs/config.py +108 -73
  11. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +159 -59
  12. vectordb_bench/backend/clients/pgvector/cli.py +17 -2
  13. vectordb_bench/backend/clients/pgvector/config.py +20 -5
  14. vectordb_bench/backend/clients/pgvector/pgvector.py +95 -25
  15. vectordb_bench/backend/clients/pgvectorscale/cli.py +108 -0
  16. vectordb_bench/backend/clients/pgvectorscale/config.py +111 -0
  17. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +290 -0
  18. vectordb_bench/backend/clients/pinecone/config.py +0 -2
  19. vectordb_bench/backend/clients/pinecone/pinecone.py +34 -36
  20. vectordb_bench/backend/clients/redis/cli.py +8 -0
  21. vectordb_bench/backend/clients/redis/config.py +37 -6
  22. vectordb_bench/backend/runner/mp_runner.py +2 -1
  23. vectordb_bench/cli/cli.py +137 -0
  24. vectordb_bench/cli/vectordbbench.py +7 -1
  25. vectordb_bench/frontend/components/check_results/charts.py +9 -6
  26. vectordb_bench/frontend/components/check_results/data.py +13 -6
  27. vectordb_bench/frontend/components/concurrent/charts.py +3 -6
  28. vectordb_bench/frontend/components/run_test/caseSelector.py +10 -0
  29. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +37 -15
  30. vectordb_bench/frontend/components/run_test/initStyle.py +3 -1
  31. vectordb_bench/frontend/config/dbCaseConfigs.py +230 -9
  32. vectordb_bench/frontend/pages/quries_per_dollar.py +13 -5
  33. vectordb_bench/frontend/vdb_benchmark.py +11 -3
  34. vectordb_bench/models.py +25 -9
  35. vectordb_bench/results/Milvus/result_20230727_standard_milvus.json +53 -1
  36. vectordb_bench/results/Milvus/result_20230808_standard_milvus.json +48 -0
  37. vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +29 -1
  38. vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +24 -0
  39. vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +98 -49
  40. vectordb_bench/results/getLeaderboardData.py +17 -7
  41. vectordb_bench/results/leaderboard.json +1 -1
  42. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/METADATA +64 -31
  43. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/RECORD +47 -40
  44. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/WHEEL +1 -1
  45. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/LICENSE +0 -0
  46. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/entry_points.txt +0 -0
  47. {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/memorydb/memorydb.py (added)
@@ -0,0 +1,254 @@
+ import logging, time
+ from contextlib import contextmanager
+ from typing import Any, Generator, Optional, Tuple, Type
+ from ..api import VectorDB, DBCaseConfig, IndexType
+ from .config import MemoryDBIndexConfig
+ import redis
+ from redis import Redis
+ from redis.cluster import RedisCluster
+ from redis.commands.search.field import TagField, VectorField, NumericField
+ from redis.commands.search.indexDefinition import IndexDefinition, IndexType
+ from redis.commands.search.query import Query
+ import numpy as np
+
+
+ log = logging.getLogger(__name__)
+ INDEX_NAME = "index"  # Vector Index Name
+
+ class MemoryDB(VectorDB):
+     def __init__(
+         self,
+         dim: int,
+         db_config: dict,
+         db_case_config: MemoryDBIndexConfig,
+         drop_old: bool = False,
+         **kwargs
+     ):
+
+         self.db_config = db_config
+         self.case_config = db_case_config
+         self.collection_name = INDEX_NAME
+         self.target_nodes = RedisCluster.RANDOM if not self.db_config["cmd"] else None
+         self.insert_batch_size = db_case_config.insert_batch_size
+         self.dbsize = kwargs.get("num_rows")
+
+         # Create a MemoryDB connection, if db has password configured, add it to the connection here and in init():
+         log.info(f"Establishing connection to: {self.db_config}")
+         conn = self.get_client(primary=True)
+         log.info(f"Connection established: {conn}")
+         log.info(conn.execute_command("INFO server"))
+
+         if drop_old:
+             try:
+                 log.info(f"MemoryDB client getting info for: {INDEX_NAME}")
+                 info = conn.ft(INDEX_NAME).info()
+                 log.info(f"Index info: {info}")
+             except redis.exceptions.ResponseError as e:
+                 log.error(e)
+                 drop_old = False
+                 log.info(f"MemoryDB client drop_old collection: {self.collection_name}")
+
+             log.info("Executing FLUSHALL")
+             conn.flushall()
+
+             # Since the default behaviour of FLUSHALL is asynchronous, wait for db to be empty
+             self.wait_until(self.wait_for_empty_db, 3, "", conn)
+             if not self.db_config["cmd"]:
+                 replica_clients = self.get_client(replicas=True)
+                 for rc, host in replica_clients:
+                     self.wait_until(self.wait_for_empty_db, 3, "", rc)
+                     log.debug(f"Flushall done in the host: {host}")
+                     rc.close()
+
+         self.make_index(dim, conn)
+         conn.close()
+         conn = None
+
+     def make_index(self, vector_dimensions: int, conn: redis.Redis):
+         try:
+             # check to see if index exists
+             conn.ft(INDEX_NAME).info()
+         except Exception as e:
+             log.warn(f"Error getting info for index '{INDEX_NAME}': {e}")
+             index_param = self.case_config.index_param()
+             search_param = self.case_config.search_param()
+             vector_parameters = {  # Vector Index Type: FLAT or HNSW
+                 "TYPE": "FLOAT32",
+                 "DIM": vector_dimensions,  # Number of Vector Dimensions
+                 "DISTANCE_METRIC": index_param["metric"],  # Vector Search Distance Metric
+             }
+             if index_param["m"]:
+                 vector_parameters["M"] = index_param["m"]
+             if index_param["ef_construction"]:
+                 vector_parameters["EF_CONSTRUCTION"] = index_param["ef_construction"]
+             if search_param["ef_runtime"]:
+                 vector_parameters["EF_RUNTIME"] = search_param["ef_runtime"]
+
+             schema = (
+                 TagField("id"),
+                 NumericField("metadata"),
+                 VectorField("vector",  # Vector Field Name
+                     "HNSW", vector_parameters
+                 ),
+             )
+
+             definition = IndexDefinition(index_type=IndexType.HASH)
+             rs = conn.ft(INDEX_NAME)
+             rs.create_index(schema, definition=definition)
+
+     def get_client(self, **kwargs):
+         """
+         Gets either cluster connection or normal connection based on `cmd` flag.
+         CMD stands for Cluster Mode Disabled and is a "mode".
+         """
+         if not self.db_config["cmd"]:
+             # Cluster mode enabled
+
+             client = RedisCluster(
+                 host=self.db_config["host"],
+                 port=self.db_config["port"],
+                 ssl=self.db_config["ssl"],
+                 password=self.db_config["password"],
+                 ssl_ca_certs=self.db_config["ssl_ca_certs"],
+                 ssl_cert_reqs=None,
+             )
+
+             # Ping all nodes to create a connection
+             client.execute_command("PING", target_nodes=RedisCluster.ALL_NODES)
+             replicas = client.get_replicas()
+
+             if len(replicas) > 0:
+                 # FT.SEARCH is a keyless command, use READONLY for replica connections
+                 client.execute_command("READONLY", target_nodes=RedisCluster.REPLICAS)
+
+             if kwargs.get("primary", False):
+                 client = client.get_primaries()[0].redis_connection
+
+             if kwargs.get("replicas", False):
+                 # Return client and host name for each replica
+                 return [(c.redis_connection, c.host) for c in replicas]
+
+         else:
+             client = Redis(
+                 host=self.db_config["host"],
+                 port=self.db_config["port"],
+                 db=0,
+                 ssl=self.db_config["ssl"],
+                 password=self.db_config["password"],
+                 ssl_ca_certs=self.db_config["ssl_ca_certs"],
+                 ssl_cert_reqs=None,
+             )
+             client.execute_command("PING")
+         return client
+
+     @contextmanager
+     def init(self) -> Generator[None, None, None]:
+         """ create and destory connections to database.
+
+         Examples:
+             >>> with self.init():
+             >>>     self.insert_embeddings()
+         """
+         self.conn = self.get_client()
+         search_param = self.case_config.search_param()
+         if search_param["ef_runtime"]:
+             self.ef_runtime_str = f'EF_RUNTIME {search_param["ef_runtime"]}'
+         else:
+             self.ef_runtime_str = ""
+         yield
+         self.conn.close()
+         self.conn = None
+
+     def ready_to_load(self) -> bool:
+         pass
+
+     def optimize(self) -> None:
+         self._post_insert()
+
+     def insert_embeddings(
+         self,
+         embeddings: list[list[float]],
+         metadata: list[int],
+         **kwargs: Any,
+     ) -> Tuple[int, Optional[Exception]]:
+         """Insert embeddings into the database.
+         Should call self.init() first.
+         """
+
+         try:
+             with self.conn.pipeline(transaction=False) as pipe:
+                 for i, embedding in enumerate(embeddings):
+                     embedding = np.array(embedding).astype(np.float32)
+                     pipe.hset(metadata[i], mapping = {
+                         "id": str(metadata[i]),
+                         "metadata": metadata[i],
+                         "vector": embedding.tobytes(),
+                     })
+                     # Execute the pipe so we don't keep too much in memory at once
+                     if (i + 1) % self.insert_batch_size == 0:
+                         pipe.execute()
+
+                 pipe.execute()
+                 result_len = i + 1
+         except Exception as e:
+             return 0, e
+
+         return result_len, None
+
+     def _post_insert(self):
+         """Wait for indexing to finish"""
+         client = self.get_client(primary=True)
+         log.info("Waiting for background indexing to finish")
+         args = (self.wait_for_no_activity, 5, "", client)
+         self.wait_until(*args)
+         if not self.db_config["cmd"]:
+             replica_clients = self.get_client(replicas=True)
+             for rc, host_name in replica_clients:
+                 args = (self.wait_for_no_activity, 5, "", rc)
+                 self.wait_until(*args)
+                 log.debug(f"Background indexing completed in the host: {host_name}")
+                 rc.close()
+
+     def wait_until(
+         self, condition, interval=5, message="Operation took too long", *args
+     ):
+         while not condition(*args):
+             time.sleep(interval)
+
+     def wait_for_no_activity(self, client: redis.RedisCluster | redis.Redis):
+         return (
+             client.info("search")["search_background_indexing_status"] == "NO_ACTIVITY"
+         )
+
+     def wait_for_empty_db(self, client: redis.RedisCluster | redis.Redis):
+         return client.execute_command("DBSIZE") == 0
+
+     def search_embedding(
+         self,
+         query: list[float],
+         k: int = 10,
+         filters: dict | None = None,
+         timeout: int | None = None,
+         **kwargs: Any,
+     ) -> (list[int]):
+         assert self.conn is not None
+
+         query_vector = np.array(query).astype(np.float32).tobytes()
+         query_obj = Query(f"*=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
+         query_params = {"vec": query_vector}
+
+         if filters:
+             # benchmark test filters of format: {'metadata': '>=10000', 'id': 10000}
+             # gets exact match for id, and range for metadata if they exist in filters
+             id_value = filters.get("id")
+             # Removing '>=' from the id_value: '>=10000'
+             metadata_value = filters.get("metadata")[2:]
+             if id_value and metadata_value:
+                 query_obj = Query(f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
+             elif id_value:
+                 # gets exact match for id
+                 query_obj = Query(f"@id:{ {id_value} }=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
+             else:  # metadata only case, greater than or equal to metadata value
+                 query_obj = Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
+         res = self.conn.ft(INDEX_NAME).search(query_obj, query_params)
+         return [int(doc["id"]) for doc in res.docs]
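Note: the new MemoryDB client implements the same VectorDB contract as the other backends (init(), insert_embeddings(), optimize(), search_embedding()), so a load-then-query cycle looks roughly like the sketch below. This is illustrative only and not part of the diff: the db_config keys are the ones the class reads above, while the MemoryDBIndexConfig arguments and defaults are assumptions, since memorydb/config.py (+54 lines in this release) is not reproduced in this view.

# Illustrative sketch, not part of the diff: exercising the new MemoryDB client.
# The db_config keys below are the ones __init__/get_client read; the
# MemoryDBIndexConfig arguments/defaults are assumed (see memorydb/config.py).
from vectordb_bench.backend.clients.memorydb.config import MemoryDBIndexConfig
from vectordb_bench.backend.clients.memorydb.memorydb import MemoryDB

db = MemoryDB(
    dim=768,
    db_config={
        "host": "example-cluster.memorydb.amazonaws.com",  # hypothetical endpoint
        "port": 6379,
        "ssl": True,
        "password": None,
        "ssl_ca_certs": None,
        "cmd": False,  # False = cluster mode enabled (see get_client above)
    },
    db_case_config=MemoryDBIndexConfig(),  # field names/defaults assumed
    drop_old=True,
)

with db.init():
    # One hash per vector: id (TAG), metadata (NUMERIC), vector (HNSW FLOAT32 field)
    count, err = db.insert_embeddings([[0.1] * 768, [0.2] * 768], metadata=[1, 2])
    db.optimize()  # blocks until search_background_indexing_status == NO_ACTIVITY
    neighbors = db.search_embedding([0.1] * 768, k=10)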
vectordb_bench/backend/clients/pgvecto_rs/cli.py (added)
@@ -0,0 +1,154 @@
+ from typing import Annotated, Optional, Unpack
+
+ import click
+ import os
+ from pydantic import SecretStr
+
+ from ....cli.cli import (
+     CommonTypedDict,
+     HNSWFlavor1,
+     IVFFlatTypedDict,
+     cli,
+     click_parameter_decorators_from_typed_dict,
+     run,
+ )
+ from vectordb_bench.backend.clients import DB
+
+
+ class PgVectoRSTypedDict(CommonTypedDict):
+     user_name: Annotated[
+         str, click.option("--user-name", type=str, help="Db username", required=True)
+     ]
+     password: Annotated[
+         str,
+         click.option(
+             "--password",
+             type=str,
+             help="Postgres database password",
+             default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
+             show_default="$POSTGRES_PASSWORD",
+         ),
+     ]
+
+     host: Annotated[
+         str, click.option("--host", type=str, help="Db host", required=True)
+     ]
+     db_name: Annotated[
+         str, click.option("--db-name", type=str, help="Db name", required=True)
+     ]
+     max_parallel_workers: Annotated[
+         Optional[int],
+         click.option(
+             "--max-parallel-workers",
+             type=int,
+             help="Sets the maximum number of parallel processes per maintenance operation (index creation)",
+             required=False,
+         ),
+     ]
+     quantization_type: Annotated[
+         str,
+         click.option(
+             "--quantization-type",
+             type=click.Choice(["trivial", "scalar", "product"]),
+             help="quantization type for vectors",
+             required=False,
+         ),
+     ]
+     quantization_ratio: Annotated[
+         str,
+         click.option(
+             "--quantization-ratio",
+             type=click.Choice(["x4", "x8", "x16", "x32", "x64"]),
+             help="quantization ratio(for product quantization)",
+             required=False,
+         ),
+     ]
+
+
+ class PgVectoRSFlatTypedDict(PgVectoRSTypedDict, IVFFlatTypedDict): ...
+
+
+ @cli.command()
+ @click_parameter_decorators_from_typed_dict(PgVectoRSFlatTypedDict)
+ def PgVectoRSFlat(
+     **parameters: Unpack[PgVectoRSFlatTypedDict],
+ ):
+     from .config import PgVectoRSConfig, PgVectoRSFLATConfig
+
+     run(
+         db=DB.PgVectoRS,
+         db_config=PgVectoRSConfig(
+             db_label=parameters["db_label"],
+             user_name=SecretStr(parameters["user_name"]),
+             password=SecretStr(parameters["password"]),
+             host=parameters["host"],
+             db_name=parameters["db_name"],
+         ),
+         db_case_config=PgVectoRSFLATConfig(
+             max_parallel_workers=parameters["max_parallel_workers"],
+             quantization_type=parameters["quantization_type"],
+             quantization_ratio=parameters["quantization_ratio"],
+         ),
+         **parameters,
+     )
+
+
+ class PgVectoRSIVFFlatTypedDict(PgVectoRSTypedDict, IVFFlatTypedDict): ...
+
+
+ @cli.command()
+ @click_parameter_decorators_from_typed_dict(PgVectoRSIVFFlatTypedDict)
+ def PgVectoRSIVFFlat(
+     **parameters: Unpack[PgVectoRSIVFFlatTypedDict],
+ ):
+     from .config import PgVectoRSConfig, PgVectoRSIVFFlatConfig
+
+     run(
+         db=DB.PgVectoRS,
+         db_config=PgVectoRSConfig(
+             db_label=parameters["db_label"],
+             user_name=SecretStr(parameters["user_name"]),
+             password=SecretStr(parameters["password"]),
+             host=parameters["host"],
+             db_name=parameters["db_name"],
+         ),
+         db_case_config=PgVectoRSIVFFlatConfig(
+             max_parallel_workers=parameters["max_parallel_workers"],
+             quantization_type=parameters["quantization_type"],
+             quantization_ratio=parameters["quantization_ratio"],
+             probes=parameters["probes"],
+             lists=parameters["lists"],
+         ),
+         **parameters,
+     )
+
+
+ class PgVectoRSHNSWTypedDict(PgVectoRSTypedDict, HNSWFlavor1): ...
+
+
+ @cli.command()
+ @click_parameter_decorators_from_typed_dict(PgVectoRSHNSWTypedDict)
+ def PgVectoRSHNSW(
+     **parameters: Unpack[PgVectoRSHNSWTypedDict],
+ ):
+     from .config import PgVectoRSConfig, PgVectoRSHNSWConfig
+
+     run(
+         db=DB.PgVectoRS,
+         db_config=PgVectoRSConfig(
+             db_label=parameters["db_label"],
+             user_name=SecretStr(parameters["user_name"]),
+             password=SecretStr(parameters["password"]),
+             host=parameters["host"],
+             db_name=parameters["db_name"],
+         ),
+         db_case_config=PgVectoRSHNSWConfig(
+             max_parallel_workers=parameters["max_parallel_workers"],
+             quantization_type=parameters["quantization_type"],
+             quantization_ratio=parameters["quantization_ratio"],
+             m=parameters["m"],
+             ef_construction=parameters["ef_construction"],
+             ef_search=parameters["ef_search"],
+         ),
+         **parameters,
+     )
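Note: each decorated function above is registered on the shared cli group from vectordb_bench/cli/cli.py and surfaces through the vectordbbench entry point once vectordb_bench/cli/vectordbbench.py (also updated in this release) imports it. A hedged sketch of a programmatic invocation follows; the subcommand name assumes click's default lowercasing of the function name, and the --m/--ef-construction/--ef-search spellings come from the HNSWFlavor1 mixin, which is not shown in this diff.

# Illustrative sketch, not part of the diff: invoking the new pgvecto.rs HNSW
# subcommand through click's test runner. The subcommand name and the HNSW option
# spellings are assumptions; common options such as case/dataset selection come
# from CommonTypedDict and are omitted here.
from click.testing import CliRunner

import vectordb_bench.cli.vectordbbench  # noqa: F401  # assumed to register the subcommands
from vectordb_bench.cli.cli import cli

result = CliRunner().invoke(
    cli,
    [
        "pgvectorshnsw",             # assumed: click lowercases PgVectoRSHNSW
        "--user-name", "postgres",
        "--password", "example",
        "--host", "localhost",
        "--db-name", "vectordb",
        "--m", "16",                 # assumed spelling from HNSWFlavor1
        "--ef-construction", "128",
        "--ef-search", "96",
    ],
)
print(result.output)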
vectordb_bench/backend/clients/pgvecto_rs/config.py (modified)
@@ -1,30 +1,53 @@
- from typing import Literal
+ from abc import abstractmethod
+ from typing import TypedDict
+
  from pydantic import BaseModel, SecretStr
- from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+ from pgvecto_rs.types import IndexOption, Ivf, Hnsw, Flat, Quantization
+ from pgvecto_rs.types.index import QuantizationType, QuantizationRatio
+
+ from ..api import DBConfig, DBCaseConfig, IndexType, MetricType

  POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"


+ class PgVectorRSConfigDict(TypedDict):
+     """These keys will be directly used as kwargs in psycopg connection string,
+     so the names must match exactly psycopg API"""
+
+     user: str
+     password: str
+     host: str
+     port: int
+     dbname: str
+
+
  class PgVectoRSConfig(DBConfig):
-     user_name: SecretStr = "postgres"
+     user_name: str = "postgres"
      password: SecretStr
      host: str = "localhost"
      port: int = 5432
      db_name: str

      def to_dict(self) -> dict:
-         user_str = self.user_name.get_secret_value()
+         user_str = self.user_name
          pwd_str = self.password.get_secret_value()
          return {
              "host": self.host,
              "port": self.port,
              "dbname": self.db_name,
              "user": user_str,
-             "password": pwd_str
+             "password": pwd_str,
          }

+
  class PgVectoRSIndexConfig(BaseModel, DBCaseConfig):
      metric_type: MetricType | None = None
+     create_index_before_load: bool = False
+     create_index_after_load: bool = True
+
+     max_parallel_workers: int | None = None
+     quantization_type: QuantizationType | None = None
+     quantization_ratio: QuantizationRatio | None = None

      def parse_metric(self) -> str:
          if self.metric_type == MetricType.L2:
@@ -40,88 +63,100 @@ class PgVectoRSIndexConfig(BaseModel, DBCaseConfig):
              return "<#>"
          return "<=>"

- class PgVectoRSQuantConfig(PgVectoRSIndexConfig):
-     quantizationType: Literal["trivial", "scalar", "product"]
-     quantizationRatio: None | Literal["x4", "x8", "x16", "x32", "x64"]
-
-     def parse_quantization(self) -> str:
-         if self.quantizationType == "trivial":
-             return "quantization = { trivial = { } }"
-         elif self.quantizationType == "scalar":
-             return "quantization = { scalar = { } }"
-         else:
-             return f'quantization = {{ product = {{ ratio = "{self.quantizationRatio}" }} }}'
-
+     def search_param(self) -> dict:
+         return {
+             "metric_fun_op": self.parse_metric_fun_op(),
+         }

- class HNSWConfig(PgVectoRSQuantConfig):
-     M: int
-     efConstruction: int
-     index: IndexType = IndexType.HNSW
+     @abstractmethod
+     def index_param(self) -> dict[str, str]: ...

-     def index_param(self) -> dict:
-         options = f"""
-         [indexing.hnsw]
-         m = {self.M}
-         ef_construction = {self.efConstruction}
-         {self.parse_quantization()}
-         """
-         return {"options": options, "metric": self.parse_metric()}
+     @abstractmethod
+     def session_param(self) -> dict[str, str | int]: ...

-     def search_param(self) -> dict:
-         return {"metrics_op": self.parse_metric_fun_op()}

+ class PgVectoRSHNSWConfig(PgVectoRSIndexConfig):
+     index: IndexType = IndexType.HNSW
+     m: int | None = None
+     ef_search: int | None
+     ef_construction: int | None = None

- class IVFFlatConfig(PgVectoRSQuantConfig):
-     nlist: int
-     nprobe: int | None = None
+     def index_param(self) -> dict[str, str]:
+         if self.quantization_type is None:
+             quantization = None
+         else:
+             quantization = Quantization(
+                 typ=self.quantization_type, ratio=self.quantization_ratio
+             )
+
+         option = IndexOption(
+             index=Hnsw(
+                 m=self.m,
+                 ef_construction=self.ef_construction,
+                 quantization=quantization,
+             ),
+             threads=self.max_parallel_workers,
+         )
+         return {"options": option.dumps(), "metric": self.parse_metric()}
+
+     def session_param(self) -> dict[str, str | int]:
+         session_parameters = {}
+         if self.ef_search is not None:
+             session_parameters["vectors.hnsw_ef_search"] = str(self.ef_search)
+         return session_parameters
+
+
+ class PgVectoRSIVFFlatConfig(PgVectoRSIndexConfig):
      index: IndexType = IndexType.IVFFlat
+     probes: int | None
+     lists: int | None
+
+     def index_param(self) -> dict[str, str]:
+         if self.quantization_type is None:
+             quantization = None
+         else:
+             quantization = Quantization(
+                 typ=self.quantization_type, ratio=self.quantization_ratio
+             )

-     def index_param(self) -> dict:
-         options = f"""
-         [indexing.ivf]
-         nlist = {self.nlist}
-         nsample = {self.nprobe if self.nprobe else 10}
-         {self.parse_quantization()}
-         """
-         return {"options": options, "metric": self.parse_metric()}
+         option = IndexOption(
+             index=Ivf(nlist=self.lists, quantization=quantization),
+             threads=self.max_parallel_workers,
+         )
+         return {"options": option.dumps(), "metric": self.parse_metric()}

-     def search_param(self) -> dict:
-         return {"metrics_op": self.parse_metric_fun_op()}
-
- class IVFFlatSQ8Config(PgVectoRSIndexConfig):
-     nlist: int
-     nprobe: int | None = None
-     index: IndexType = IndexType.IVFSQ8
-
-     def index_param(self) -> dict:
-         options = f"""
-         [indexing.ivf]
-         nlist = {self.nlist}
-         nsample = {self.nprobe if self.nprobe else 10}
-         quantization = {{ scalar = {{ }} }}
-         """
-         return {"options": options, "metric": self.parse_metric()}
+     def session_param(self) -> dict[str, str | int]:
+         session_parameters = {}
+         if self.probes is not None:
+             session_parameters["vectors.ivf_nprobe"] = str(self.probes)
+         return session_parameters

-     def search_param(self) -> dict:
-         return {"metrics_op": self.parse_metric_fun_op()}

- class FLATConfig(PgVectoRSQuantConfig):
+ class PgVectoRSFLATConfig(PgVectoRSIndexConfig):
      index: IndexType = IndexType.Flat

-     def index_param(self) -> dict:
-         options = f"""
-         [indexing.flat]
-         {self.parse_quantization()}
-         """
-         return {"options": options, "metric": self.parse_metric()}
+     def index_param(self) -> dict[str, str]:
+         if self.quantization_type is None:
+             quantization = None
+         else:
+             quantization = Quantization(
+                 typ=self.quantization_type, ratio=self.quantization_ratio
+             )

-     def search_param(self) -> dict:
-         return {"metrics_op": self.parse_metric_fun_op()}
+         option = IndexOption(
+             index=Flat(
+                 quantization=quantization,
+             ),
+             threads=self.max_parallel_workers,
+         )
+         return {"options": option.dumps(), "metric": self.parse_metric()}
+
+     def session_param(self) -> dict[str, str | int]:
+         return {}


  _pgvecto_rs_case_config = {
-     IndexType.HNSW: HNSWConfig,
-     IndexType.IVFFlat: IVFFlatConfig,
-     IndexType.IVFSQ8: IVFFlatSQ8Config,
-     IndexType.Flat: FLATConfig,
+     IndexType.HNSW: PgVectoRSHNSWConfig,
+     IndexType.IVFFlat: PgVectoRSIVFFlatConfig,
+     IndexType.Flat: PgVectoRSFLATConfig,
  }
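Note: with this rewrite the config classes no longer hand-build TOML strings; index_param() returns the options rendered by the pgvecto_rs SDK (IndexOption(...).dumps()) plus the metric from parse_metric(), while query-time knobs move into session_param(), presumably applied per session by the pgvecto_rs client (also changed in this release). A small sketch of what the HNSW variant yields, outside the diff; the field values and MetricType.COSINE are illustrative inputs, not defaults:

# Illustrative sketch, not part of the diff: inspecting the new PgVectoRSHNSWConfig.
from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.backend.clients.pgvecto_rs.config import PgVectoRSHNSWConfig

case_config = PgVectoRSHNSWConfig(
    metric_type=MetricType.COSINE,
    m=16,
    ef_search=128,
    ef_construction=200,
    quantization_type="product",   # validated against pgvecto_rs' QuantizationType
    quantization_ratio="x16",
    max_parallel_workers=8,
)

index_param = case_config.index_param()
# index_param["options"] is the TOML produced by IndexOption.dumps(); its exact
# layout is owned by the pgvecto_rs SDK. index_param["metric"] comes from parse_metric().
print(index_param["options"])

# Per the code above, ef_search is surfaced as a session setting rather than an
# index build option:
assert case_config.session_param() == {"vectors.hnsw_ef_search": "128"}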