vectordb-bench 0.0.29__py3-none-any.whl → 0.0.30__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (28)
  1. vectordb_bench/backend/clients/__init__.py +16 -0
  2. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +180 -15
  3. vectordb_bench/backend/clients/aws_opensearch/cli.py +51 -21
  4. vectordb_bench/backend/clients/aws_opensearch/config.py +37 -14
  5. vectordb_bench/backend/clients/lancedb/cli.py +62 -8
  6. vectordb_bench/backend/clients/lancedb/config.py +14 -1
  7. vectordb_bench/backend/clients/lancedb/lancedb.py +21 -9
  8. vectordb_bench/backend/clients/memorydb/memorydb.py +2 -2
  9. vectordb_bench/backend/clients/milvus/cli.py +30 -9
  10. vectordb_bench/backend/clients/milvus/config.py +2 -0
  11. vectordb_bench/backend/clients/milvus/milvus.py +7 -1
  12. vectordb_bench/backend/clients/qdrant_local/cli.py +60 -0
  13. vectordb_bench/backend/clients/qdrant_local/config.py +47 -0
  14. vectordb_bench/backend/clients/qdrant_local/qdrant_local.py +232 -0
  15. vectordb_bench/backend/clients/weaviate_cloud/cli.py +29 -3
  16. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -0
  17. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +5 -0
  18. vectordb_bench/cli/batch_cli.py +121 -0
  19. vectordb_bench/cli/vectordbbench.py +4 -0
  20. vectordb_bench/config-files/batch_sample_config.yml +17 -0
  21. vectordb_bench/frontend/config/dbCaseConfigs.py +113 -1
  22. vectordb_bench/models.py +7 -0
  23. {vectordb_bench-0.0.29.dist-info → vectordb_bench-0.0.30.dist-info}/METADATA +48 -2
  24. {vectordb_bench-0.0.29.dist-info → vectordb_bench-0.0.30.dist-info}/RECORD +28 -23
  25. {vectordb_bench-0.0.29.dist-info → vectordb_bench-0.0.30.dist-info}/WHEEL +1 -1
  26. {vectordb_bench-0.0.29.dist-info → vectordb_bench-0.0.30.dist-info}/entry_points.txt +0 -0
  27. {vectordb_bench-0.0.29.dist-info → vectordb_bench-0.0.30.dist-info}/licenses/LICENSE +0 -0
  28. {vectordb_bench-0.0.29.dist-info → vectordb_bench-0.0.30.dist-info}/top_level.txt +0 -0
@@ -27,6 +27,7 @@ class DB(Enum):
27
27
  Pinecone = "Pinecone"
28
28
  ElasticCloud = "ElasticCloud"
29
29
  QdrantCloud = "QdrantCloud"
30
+ QdrantLocal = "QdrantLocal"
30
31
  WeaviateCloud = "WeaviateCloud"
31
32
  PgVector = "PgVector"
32
33
  PgVectoRS = "PgVectoRS"
@@ -75,6 +76,11 @@ class DB(Enum):
75
76
 
76
77
  return QdrantCloud
77
78
 
79
+ if self == DB.QdrantLocal:
80
+ from .qdrant_local.qdrant_local import QdrantLocal
81
+
82
+ return QdrantLocal
83
+
78
84
  if self == DB.WeaviateCloud:
79
85
  from .weaviate_cloud.weaviate_cloud import WeaviateCloud
80
86
 
@@ -201,6 +207,11 @@ class DB(Enum):
201
207
 
202
208
  return QdrantConfig
203
209
 
210
+ if self == DB.QdrantLocal:
211
+ from .qdrant_local.config import QdrantLocalConfig
212
+
213
+ return QdrantLocalConfig
214
+
204
215
  if self == DB.WeaviateCloud:
205
216
  from .weaviate_cloud.config import WeaviateConfig
206
217
 
@@ -323,6 +334,11 @@ class DB(Enum):
323
334
 
324
335
  return QdrantIndexConfig
325
336
 
337
+ if self == DB.QdrantLocal:
338
+ from .qdrant_local.config import QdrantLocalIndexConfig
339
+
340
+ return QdrantLocalIndexConfig
341
+
326
342
  if self == DB.WeaviateCloud:
327
343
  from .weaviate_cloud.config import WeaviateIndexConfig
328
344
 
@@ -36,6 +36,7 @@ class AWSOpenSearch(VectorDB):
36
36
  self.vector_col_name = vector_col_name
37
37
 
38
38
  log.info(f"AWS_OpenSearch client config: {self.db_config}")
39
+ log.info(f"AWS_OpenSearch db case config : {self.case_config}")
39
40
  client = OpenSearch(**self.db_config)
40
41
  if drop_old:
41
42
  log.info(f"AWS_OpenSearch client drop old index: {self.index_name}")
@@ -43,6 +44,14 @@ class AWSOpenSearch(VectorDB):
43
44
  if is_existed:
44
45
  client.indices.delete(index=self.index_name)
45
46
  self._create_index(client)
47
+ else:
48
+ is_existed = client.indices.exists(index=self.index_name)
49
+ if not is_existed:
50
+ self._create_index(client)
51
+ log.info(f"AWS_OpenSearch client create index: {self.index_name}")
52
+
53
+ self._update_ef_search_before_search(client)
54
+ self._load_graphs_to_memory(client)
46
55
 
47
56
  @classmethod
48
57
  def config_cls(cls) -> AWSOpenSearchConfig:
@@ -52,7 +61,17 @@ class AWSOpenSearch(VectorDB):
52
61
  def case_config_cls(cls, index_type: IndexType | None = None) -> AWSOpenSearchIndexConfig:
53
62
  return AWSOpenSearchIndexConfig
54
63
 
55
- def _create_index(self, client: OpenSearch):
64
+ def _create_index(self, client: OpenSearch) -> None:
65
+ ef_search_value = (
66
+ self.case_config.ef_search if self.case_config.ef_search is not None else self.case_config.efSearch
67
+ )
68
+ log.info(f"Creating index with ef_search: {ef_search_value}")
69
+ log.info(f"Creating index with number_of_replicas: {self.case_config.number_of_replicas}")
70
+
71
+ log.info(f"Creating index with engine: {self.case_config.engine}")
72
+ log.info(f"Creating index with metric type: {self.case_config.metric_type_name}")
73
+ log.info(f"All case_config parameters: {self.case_config.__dict__}")
74
+
56
75
  cluster_settings_body = {
57
76
  "persistent": {
58
77
  "knn.algo_param.index_thread_qty": self.case_config.index_thread_qty,
@@ -64,18 +83,15 @@ class AWSOpenSearch(VectorDB):
64
83
  "index": {
65
84
  "knn": True,
66
85
  "number_of_shards": self.case_config.number_of_shards,
67
- "number_of_replicas": 0,
86
+ "number_of_replicas": self.case_config.number_of_replicas,
68
87
  "translog.flush_threshold_size": self.case_config.flush_threshold_size,
69
- # Setting trans log threshold to 5GB
70
- **(
71
- {"knn.algo_param.ef_search": self.case_config.ef_search}
72
- if self.case_config.engine == AWSOS_Engine.nmslib
73
- else {}
74
- ),
88
+ "knn.advanced.approximate_threshold": "-1",
75
89
  },
76
90
  "refresh_interval": self.case_config.refresh_interval,
77
91
  }
92
+ settings["index"]["knn.algo_param.ef_search"] = ef_search_value
78
93
  mappings = {
94
+ "_source": {"excludes": [self.vector_col_name], "recovery_source_excludes": [self.vector_col_name]},
79
95
  "properties": {
80
96
  **{categoryCol: {"type": "keyword"} for categoryCol in self.category_col_names},
81
97
  self.vector_col_name: {
@@ -86,6 +102,8 @@ class AWSOpenSearch(VectorDB):
86
102
  },
87
103
  }
88
104
  try:
105
+ log.info(f"Creating index with settings: {settings}")
106
+ log.info(f"Creating index with mappings: {mappings}")
89
107
  client.indices.create(
90
108
  index=self.index_name,
91
109
  body={"settings": settings, "mappings": mappings},
@@ -112,6 +130,18 @@ class AWSOpenSearch(VectorDB):
112
130
  """Insert the embeddings to the opensearch."""
113
131
  assert self.client is not None, "should self.init() first"
114
132
 
133
+ num_clients = self.case_config.number_of_indexing_clients or 1
134
+ log.info(f"Number of indexing clients from case_config: {num_clients}")
135
+
136
+ if num_clients <= 1:
137
+ log.info("Using single client for data insertion")
138
+ return self._insert_with_single_client(embeddings, metadata)
139
+ log.info(f"Using {num_clients} parallel clients for data insertion")
140
+ return self._insert_with_multiple_clients(embeddings, metadata, num_clients)
141
+
142
+ def _insert_with_single_client(
143
+ self, embeddings: Iterable[list[float]], metadata: list[int]
144
+ ) -> tuple[int, Exception]:
115
145
  insert_data = []
116
146
  for i in range(len(embeddings)):
117
147
  insert_data.append(
@@ -129,7 +159,108 @@ class AWSOpenSearch(VectorDB):
129
159
  except Exception as e:
130
160
  log.warning(f"Failed to insert data: {self.index_name} error: {e!s}")
131
161
  time.sleep(10)
132
- return self.insert_embeddings(embeddings, metadata)
162
+ return self._insert_with_single_client(embeddings, metadata)
163
+
164
+ def _insert_with_multiple_clients(
165
+ self, embeddings: Iterable[list[float]], metadata: list[int], num_clients: int
166
+ ) -> tuple[int, Exception]:
167
+ import concurrent.futures
168
+ from concurrent.futures import ThreadPoolExecutor
169
+
170
+ embeddings_list = list(embeddings)
171
+ chunk_size = max(1, len(embeddings_list) // num_clients)
172
+ chunks = []
173
+
174
+ for i in range(0, len(embeddings_list), chunk_size):
175
+ end = min(i + chunk_size, len(embeddings_list))
176
+ chunks.append((embeddings_list[i:end], metadata[i:end]))
177
+
178
+ clients = []
179
+ for _ in range(min(num_clients, len(chunks))):
180
+ client = OpenSearch(**self.db_config)
181
+ clients.append(client)
182
+
183
+ log.info(f"AWS_OpenSearch using {len(clients)} parallel clients for data insertion")
184
+
185
+ def insert_chunk(client_idx: int, chunk_idx: int):
186
+ chunk_embeddings, chunk_metadata = chunks[chunk_idx]
187
+ client = clients[client_idx]
188
+
189
+ insert_data = []
190
+ for i in range(len(chunk_embeddings)):
191
+ insert_data.append(
192
+ {"index": {"_index": self.index_name, self.id_col_name: chunk_metadata[i]}},
193
+ )
194
+ insert_data.append({self.vector_col_name: chunk_embeddings[i]})
195
+
196
+ try:
197
+ resp = client.bulk(insert_data)
198
+ log.info(f"Client {client_idx} added {len(resp['items'])} documents")
199
+ return len(chunk_embeddings), None
200
+ except Exception as e:
201
+ log.warning(f"Client {client_idx} failed to insert data: {e!s}")
202
+ return 0, e
203
+
204
+ results = []
205
+ with ThreadPoolExecutor(max_workers=len(clients)) as executor:
206
+ futures = []
207
+
208
+ for chunk_idx in range(len(chunks)):
209
+ client_idx = chunk_idx % len(clients)
210
+ futures.append(executor.submit(insert_chunk, client_idx, chunk_idx))
211
+
212
+ for future in concurrent.futures.as_completed(futures):
213
+ count, error = future.result()
214
+ results.append((count, error))
215
+
216
+ from contextlib import suppress
217
+
218
+ for client in clients:
219
+ with suppress(Exception):
220
+ client.close()
221
+
222
+ total_count = sum(count for count, _ in results)
223
+ errors = [error for _, error in results if error is not None]
224
+
225
+ if errors:
226
+ log.warning("Some clients failed to insert data, retrying with single client")
227
+ time.sleep(10)
228
+ return self._insert_with_single_client(embeddings, metadata)
229
+
230
+ resp = self.client.indices.stats(self.index_name)
231
+ log.info(
232
+ f"""Total document count in index after parallel insertion:
233
+ {resp['_all']['primaries']['indexing']['index_total']}""",
234
+ )
235
+
236
+ return (total_count, None)
237
+
238
+ def _update_ef_search_before_search(self, client: OpenSearch):
239
+ ef_search_value = (
240
+ self.case_config.ef_search if self.case_config.ef_search is not None else self.case_config.efSearch
241
+ )
242
+
243
+ try:
244
+ index_settings = client.indices.get_settings(index=self.index_name)
245
+ current_ef_search = (
246
+ index_settings.get(self.index_name, {})
247
+ .get("settings", {})
248
+ .get("index", {})
249
+ .get("knn.algo_param", {})
250
+ .get("ef_search")
251
+ )
252
+
253
+ if current_ef_search != str(ef_search_value):
254
+ log.info(f"Updating ef_search before search from {current_ef_search} to {ef_search_value}")
255
+ settings_body = {"index": {"knn.algo_param.ef_search": ef_search_value}}
256
+ client.indices.put_settings(index=self.index_name, body=settings_body)
257
+ log.info(f"Successfully updated ef_search to {ef_search_value} before search")
258
+
259
+ log.info(f"Current engine: {self.case_config.engine}")
260
+ log.info(f"Current metric_type: {self.case_config.metric_type_name}")
261
+
262
+ except Exception as e:
263
+ log.warning(f"Failed to update ef_search parameter before search: {e}")
133
264
 
134
265
  def search_embedding(
135
266
  self,
@@ -151,9 +282,18 @@ class AWSOpenSearch(VectorDB):
151
282
 
152
283
  body = {
153
284
  "size": k,
154
- "query": {"knn": {self.vector_col_name: {"vector": query, "k": k}}},
285
+ "query": {
286
+ "knn": {
287
+ self.vector_col_name: {
288
+ "vector": query,
289
+ "k": k,
290
+ "method_parameters": {"ef_search": self.case_config.efSearch},
291
+ }
292
+ }
293
+ },
155
294
  **({"filter": {"range": {self.id_col_name: {"gt": filters["id"]}}}} if filters else {}),
156
295
  }
296
+
157
297
  try:
158
298
  resp = self.client.search(
159
299
  index=self.index_name,
@@ -162,6 +302,7 @@ class AWSOpenSearch(VectorDB):
162
302
  _source=False,
163
303
  docvalue_fields=[self.id_col_name],
164
304
  stored_fields="_none_",
305
+ preference="_only_local" if self.case_config.number_of_shards == 1 else None,
165
306
  )
166
307
  log.debug(f"Search took: {resp['took']}")
167
308
  log.debug(f"Search shards: {resp['_shards']}")
@@ -173,6 +314,7 @@ class AWSOpenSearch(VectorDB):
173
314
 
174
315
  def optimize(self, data_size: int | None = None):
175
316
  """optimize will be called between insertion and search in performance cases."""
317
+ self._update_ef_search()
176
318
  # Call refresh first to ensure that all segments are created
177
319
  self._refresh_index()
178
320
  if self.case_config.force_merge_enabled:
@@ -182,7 +324,22 @@ class AWSOpenSearch(VectorDB):
182
324
  # Call refresh again to ensure that the index is ready after force merge.
183
325
  self._refresh_index()
184
326
  # ensure that all graphs are loaded in memory and ready for search
185
- self._load_graphs_to_memory()
327
+ self._load_graphs_to_memory(self.client)
328
+
329
+ def _update_ef_search(self):
330
+ ef_search_value = (
331
+ self.case_config.ef_search if self.case_config.ef_search is not None else self.case_config.efSearch
332
+ )
333
+ log.info(f"Updating ef_search parameter to: {ef_search_value}")
334
+
335
+ settings_body = {"index": {"knn.algo_param.ef_search": ef_search_value}}
336
+ try:
337
+ self.client.indices.put_settings(index=self.index_name, body=settings_body)
338
+ log.info(f"Successfully updated ef_search to {ef_search_value}")
339
+ log.info(f"Current engine: {self.case_config.engine}")
340
+ log.info(f"Current metric_type: {self.case_config.metric_type}")
341
+ except Exception as e:
342
+ log.warning(f"Failed to update ef_search parameter: {e}")
186
343
 
187
344
  def _update_replicas(self):
188
345
  index_settings = self.client.indices.get_settings(index=self.index_name)
@@ -200,7 +357,7 @@ class AWSOpenSearch(VectorDB):
200
357
  while True:
201
358
  res = self.client.cat.indices(index=self.index_name, h="health", format="json")
202
359
  health = res[0]["health"]
203
- if health != "green":
360
+ if health == "green":
204
361
  break
205
362
  log.info(f"The index {self.index_name} has health : {health} and is not green. Retrying")
206
363
  time.sleep(SECONDS_WAITING_FOR_REPLICAS_TO_BE_ENABLED_SEC)
@@ -228,8 +385,16 @@ class AWSOpenSearch(VectorDB):
228
385
  "persistent": {"knn.algo_param.index_thread_qty": self.case_config.index_thread_qty_during_force_merge}
229
386
  }
230
387
  self.client.cluster.put_settings(cluster_settings_body)
388
+
389
+ log.info("Updating the graph threshold to ensure that during merge we can do graph creation.")
390
+ output = self.client.indices.put_settings(
391
+ index=self.index_name, body={"index.knn.advanced.approximate_threshold": "0"}
392
+ )
393
+ log.info(f"response of updating setting is: {output}")
394
+
231
395
  log.debug(f"Starting force merge for index {self.index_name}")
232
- force_merge_endpoint = f"/{self.index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false"
396
+ segments = self.case_config.number_of_segments
397
+ force_merge_endpoint = f"/{self.index_name}/_forcemerge?max_num_segments={segments}&wait_for_completion=false"
233
398
  force_merge_task_id = self.client.transport.perform_request("POST", force_merge_endpoint)["task"]
234
399
  while True:
235
400
  time.sleep(WAITING_FOR_FORCE_MERGE_SEC)
@@ -238,8 +403,8 @@ class AWSOpenSearch(VectorDB):
238
403
  break
239
404
  log.debug(f"Completed force merge for index {self.index_name}")
240
405
 
241
- def _load_graphs_to_memory(self):
406
+ def _load_graphs_to_memory(self, client: OpenSearch):
242
407
  if self.case_config.engine != AWSOS_Engine.lucene:
243
408
  log.info("Calling warmup API to load graphs into memory")
244
409
  warmup_endpoint = f"/_plugins/_knn/warmup/{self.index_name}"
245
- self.client.transport.perform_request("GET", warmup_endpoint)
410
+ client.transport.perform_request("GET", warmup_endpoint)
@@ -1,3 +1,4 @@
1
+ import logging
1
2
  from typing import Annotated, TypedDict, Unpack
2
3
 
3
4
  import click
@@ -5,18 +6,21 @@ from pydantic import SecretStr
5
6
 
6
7
  from ....cli.cli import (
7
8
  CommonTypedDict,
8
- HNSWFlavor2,
9
+ HNSWFlavor1,
9
10
  cli,
10
11
  click_parameter_decorators_from_typed_dict,
11
12
  run,
12
13
  )
13
14
  from .. import DB
15
+ from .config import AWSOS_Engine, AWSOSQuantization
16
+
17
+ log = logging.getLogger(__name__)
14
18
 
15
19
 
16
20
  class AWSOpenSearchTypedDict(TypedDict):
17
21
  host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
18
- port: Annotated[int, click.option("--port", type=int, default=443, help="Db Port")]
19
- user: Annotated[str, click.option("--user", type=str, default="admin", help="Db User")]
22
+ port: Annotated[int, click.option("--port", type=int, default=80, help="Db Port")]
23
+ user: Annotated[str, click.option("--user", type=str, help="Db User")]
20
24
  password: Annotated[str, click.option("--password", type=str, help="Db password")]
21
25
  number_of_shards: Annotated[
22
26
  int,
@@ -38,23 +42,23 @@ class AWSOpenSearchTypedDict(TypedDict):
38
42
  ),
39
43
  ]
40
44
 
41
- index_thread_qty_during_force_merge: Annotated[
42
- int,
45
+ engine: Annotated[
46
+ str,
43
47
  click.option(
44
- "--index-thread-qty-during-force-merge",
45
- type=int,
46
- help="Thread count during force merge operations",
47
- default=4,
48
+ "--engine",
49
+ type=click.Choice(["nmslib", "faiss", "lucene"], case_sensitive=False),
50
+ help="HNSW algorithm implementation to use",
51
+ default="faiss",
48
52
  ),
49
53
  ]
50
54
 
51
- number_of_indexing_clients: Annotated[
52
- int,
55
+ metric_type: Annotated[
56
+ str,
53
57
  click.option(
54
- "--number-of-indexing-clients",
55
- type=int,
56
- help="Number of concurrent indexing clients",
57
- default=1,
58
+ "--metric-type",
59
+ type=click.Choice(["l2", "cosine", "ip"], case_sensitive=False),
60
+ help="Distance metric type for vector similarity",
61
+ default="l2",
58
62
  ),
59
63
  ]
60
64
 
@@ -64,26 +68,26 @@ class AWSOpenSearchTypedDict(TypedDict):
64
68
  ]
65
69
 
66
70
  refresh_interval: Annotated[
67
- int,
71
+ str,
68
72
  click.option(
69
73
  "--refresh-interval", type=str, help="How often to make new data available for search", default="60s"
70
74
  ),
71
75
  ]
72
76
 
73
77
  force_merge_enabled: Annotated[
74
- int,
78
+ bool,
75
79
  click.option("--force-merge-enabled", type=bool, help="Whether to perform force merge operation", default=True),
76
80
  ]
77
81
 
78
82
  flush_threshold_size: Annotated[
79
- int,
83
+ str,
80
84
  click.option(
81
85
  "--flush-threshold-size", type=str, help="Size threshold for flushing the transaction log", default="5120mb"
82
86
  ),
83
87
  ]
84
88
 
85
89
  cb_threshold: Annotated[
86
- int,
90
+ str,
87
91
  click.option(
88
92
  "--cb-threshold",
89
93
  type=str,
@@ -92,8 +96,30 @@ class AWSOpenSearchTypedDict(TypedDict):
92
96
  ),
93
97
  ]
94
98
 
99
+ quantization_type: Annotated[
100
+ str | None,
101
+ click.option(
102
+ "--quantization-type",
103
+ type=click.Choice(["fp32", "fp16"]),
104
+ help="quantization type for vectors (in index)",
105
+ default="fp32",
106
+ required=False,
107
+ ),
108
+ ]
109
+
110
+ engine: Annotated[
111
+ str | None,
112
+ click.option(
113
+ "--engine",
114
+ type=click.Choice(["faiss", "lucene"]),
115
+ help="quantization type for vectors (in index)",
116
+ default="faiss",
117
+ required=False,
118
+ ),
119
+ ]
120
+
95
121
 
96
- class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2): ...
122
+ class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor1): ...
97
123
 
98
124
 
99
125
  @cli.command()
@@ -117,9 +143,13 @@ def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]):
117
143
  refresh_interval=parameters["refresh_interval"],
118
144
  force_merge_enabled=parameters["force_merge_enabled"],
119
145
  flush_threshold_size=parameters["flush_threshold_size"],
120
- number_of_indexing_clients=parameters["number_of_indexing_clients"],
121
146
  index_thread_qty_during_force_merge=parameters["index_thread_qty_during_force_merge"],
122
147
  cb_threshold=parameters["cb_threshold"],
148
+ efConstruction=parameters["ef_construction"],
149
+ efSearch=parameters["ef_runtime"],
150
+ M=parameters["m"],
151
+ engine=AWSOS_Engine(parameters["engine"]),
152
+ quantization_type=AWSOSQuantization(parameters["quantization_type"]),
123
153
  ),
124
154
  **parameters,
125
155
  )
@@ -10,17 +10,21 @@ log = logging.getLogger(__name__)
10
10
 
11
11
  class AWSOpenSearchConfig(DBConfig, BaseModel):
12
12
  host: str = ""
13
- port: int = 443
13
+ port: int = 80
14
14
  user: str = ""
15
15
  password: SecretStr = ""
16
16
 
17
17
  def to_dict(self) -> dict:
18
+ use_ssl = self.port == 443
19
+ http_auth = (
20
+ (self.user, self.password.get_secret_value()) if len(self.user) != 0 and len(self.password) != 0 else ()
21
+ )
18
22
  return {
19
23
  "hosts": [{"host": self.host, "port": self.port}],
20
- "http_auth": (self.user, self.password.get_secret_value()),
21
- "use_ssl": True,
24
+ "http_auth": http_auth,
25
+ "use_ssl": use_ssl,
22
26
  "http_compress": True,
23
- "verify_certs": True,
27
+ "verify_certs": use_ssl,
24
28
  "ssl_assert_hostname": False,
25
29
  "ssl_show_warn": False,
26
30
  "timeout": 600,
@@ -28,16 +32,22 @@ class AWSOpenSearchConfig(DBConfig, BaseModel):
28
32
 
29
33
 
30
34
  class AWSOS_Engine(Enum):
31
- nmslib = "nmslib"
32
35
  faiss = "faiss"
33
- lucene = "Lucene"
36
+ lucene = "lucene"
37
+
38
+
39
+ class AWSOSQuantization(Enum):
40
+ fp32 = "fp32"
41
+ fp16 = "fp16"
34
42
 
35
43
 
36
44
  class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
37
45
  metric_type: MetricType = MetricType.L2
38
46
  engine: AWSOS_Engine = AWSOS_Engine.faiss
39
47
  efConstruction: int = 256
40
- efSearch: int = 256
48
+ ef_search: int = 200
49
+ engine_name: str | None = None
50
+ metric_type_name: str | None = None
41
51
  M: int = 16
42
52
  index_thread_qty: int | None = 4
43
53
  number_of_shards: int | None = 1
@@ -46,31 +56,44 @@ class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
46
56
  refresh_interval: str | None = "60s"
47
57
  force_merge_enabled: bool | None = True
48
58
  flush_threshold_size: str | None = "5120mb"
49
- number_of_indexing_clients: int | None = 1
50
59
  index_thread_qty_during_force_merge: int
51
60
  cb_threshold: str | None = "50%"
61
+ quantization_type: AWSOSQuantization = AWSOSQuantization.fp32
52
62
 
53
63
  def parse_metric(self) -> str:
64
+ log.info(f"User specified metric_type: {self.metric_type_name}")
65
+ self.metric_type = MetricType[self.metric_type_name.upper()]
54
66
  if self.metric_type == MetricType.IP:
55
67
  return "innerproduct"
56
68
  if self.metric_type == MetricType.COSINE:
57
- if self.engine == AWSOS_Engine.faiss:
58
- log.info(
59
- "Using innerproduct because faiss doesn't support cosine as metric type for Opensearch",
60
- )
61
- return "innerproduct"
62
69
  return "cosinesimil"
70
+ if self.metric_type == MetricType.L2:
71
+ log.info("Using l2 as specified by user")
72
+ return "l2"
63
73
  return "l2"
64
74
 
65
75
  def index_param(self) -> dict:
76
+ log.info(f"Using engine: {self.engine} for index creation")
77
+ log.info(f"Using metric_type: {self.metric_type_name} for index creation")
78
+ log.info(f"Resulting space_type: {self.parse_metric()} for index creation")
79
+
80
+ parameters = {"ef_construction": self.efConstruction, "m": self.M}
81
+
82
+ if self.engine == AWSOS_Engine.faiss and self.faiss_use_fp16:
83
+ parameters["encoder"] = {"name": "sq", "parameters": {"type": "fp16"}}
84
+
66
85
  return {
67
86
  "name": "hnsw",
68
- "space_type": self.parse_metric(),
69
87
  "engine": self.engine.value,
70
88
  "parameters": {
71
89
  "ef_construction": self.efConstruction,
72
90
  "m": self.M,
73
91
  "ef_search": self.efSearch,
92
+ **(
93
+ {"encoder": {"name": "sq", "parameters": {"type": self.quantization_type.fp16.value}}}
94
+ if self.quantization_type is not AWSOSQuantization.fp32
95
+ else {}
96
+ ),
74
97
  },
75
98
  }
76
99
 
@@ -58,10 +58,46 @@ def LanceDBAutoIndex(**parameters: Unpack[LanceDBTypedDict]):
58
58
  )
59
59
 
60
60
 
61
+ class LanceDBIVFPQTypedDict(CommonTypedDict, LanceDBTypedDict):
62
+ num_partitions: Annotated[
63
+ int,
64
+ click.option(
65
+ "--num-partitions",
66
+ type=int,
67
+ default=0,
68
+ help="Number of partitions for IVFPQ index, unset = use LanceDB default",
69
+ ),
70
+ ]
71
+ num_sub_vectors: Annotated[
72
+ int,
73
+ click.option(
74
+ "--num-sub-vectors",
75
+ type=int,
76
+ default=0,
77
+ help="Number of sub-vectors for IVFPQ index, unset = use LanceDB default",
78
+ ),
79
+ ]
80
+ nbits: Annotated[
81
+ int,
82
+ click.option(
83
+ "--nbits",
84
+ type=int,
85
+ default=8,
86
+ help="Number of bits for IVFPQ index (must be 4 or 8), unset = use LanceDB default",
87
+ ),
88
+ ]
89
+ nprobes: Annotated[
90
+ int,
91
+ click.option(
92
+ "--nprobes", type=int, default=0, help="Number of probes for IVFPQ search, unset = use LanceDB default"
93
+ ),
94
+ ]
95
+
96
+
61
97
  @cli.command()
62
- @click_parameter_decorators_from_typed_dict(LanceDBTypedDict)
63
- def LanceDBIVFPQ(**parameters: Unpack[LanceDBTypedDict]):
64
- from .config import LanceDBConfig, _lancedb_case_config
98
+ @click_parameter_decorators_from_typed_dict(LanceDBIVFPQTypedDict)
99
+ def LanceDBIVFPQ(**parameters: Unpack[LanceDBIVFPQTypedDict]):
100
+ from .config import LanceDBConfig, LanceDBIndexConfig
65
101
 
66
102
  run(
67
103
  db=DB.LanceDB,
@@ -70,15 +106,29 @@ def LanceDBIVFPQ(**parameters: Unpack[LanceDBTypedDict]):
70
106
  uri=parameters["uri"],
71
107
  token=SecretStr(parameters["token"]) if parameters.get("token") else None,
72
108
  ),
73
- db_case_config=_lancedb_case_config.get(IndexType.IVFPQ)(),
109
+ db_case_config=LanceDBIndexConfig(
110
+ index=IndexType.IVFPQ,
111
+ num_partitions=parameters["num_partitions"],
112
+ num_sub_vectors=parameters["num_sub_vectors"],
113
+ nbits=parameters["nbits"],
114
+ nprobes=parameters["nprobes"],
115
+ ),
74
116
  **parameters,
75
117
  )
76
118
 
77
119
 
120
+ class LanceDBHNSWTypedDict(CommonTypedDict, LanceDBTypedDict):
121
+ m: Annotated[int, click.option("--m", type=int, default=0, help="HNSW parameter m")]
122
+ ef_construction: Annotated[
123
+ int, click.option("--ef-construction", type=int, default=0, help="HNSW parameter ef_construction")
124
+ ]
125
+ ef: Annotated[int, click.option("--ef", type=int, default=0, help="HNSW search parameter ef")]
126
+
127
+
78
128
  @cli.command()
79
- @click_parameter_decorators_from_typed_dict(LanceDBTypedDict)
80
- def LanceDBHNSW(**parameters: Unpack[LanceDBTypedDict]):
81
- from .config import LanceDBConfig, _lancedb_case_config
129
+ @click_parameter_decorators_from_typed_dict(LanceDBHNSWTypedDict)
130
+ def LanceDBHNSW(**parameters: Unpack[LanceDBHNSWTypedDict]):
131
+ from .config import LanceDBConfig, LanceDBHNSWIndexConfig
82
132
 
83
133
  run(
84
134
  db=DB.LanceDB,
@@ -87,6 +137,10 @@ def LanceDBHNSW(**parameters: Unpack[LanceDBTypedDict]):
87
137
  uri=parameters["uri"],
88
138
  token=SecretStr(parameters["token"]) if parameters.get("token") else None,
89
139
  ),
90
- db_case_config=_lancedb_case_config.get(IndexType.HNSW)(),
140
+ db_case_config=LanceDBHNSWIndexConfig(
141
+ m=parameters["m"],
142
+ ef_construction=parameters["ef_construction"],
143
+ ef=parameters["ef"],
144
+ ),
91
145
  **parameters,
92
146
  )