vectordb-bench 0.0.29__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. vectordb_bench/__init__.py +14 -27
  2. vectordb_bench/backend/assembler.py +19 -6
  3. vectordb_bench/backend/cases.py +186 -23
  4. vectordb_bench/backend/clients/__init__.py +32 -0
  5. vectordb_bench/backend/clients/api.py +22 -1
  6. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +249 -43
  7. vectordb_bench/backend/clients/aws_opensearch/cli.py +51 -21
  8. vectordb_bench/backend/clients/aws_opensearch/config.py +58 -16
  9. vectordb_bench/backend/clients/chroma/chroma.py +6 -2
  10. vectordb_bench/backend/clients/elastic_cloud/config.py +19 -1
  11. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +133 -45
  12. vectordb_bench/backend/clients/lancedb/cli.py +62 -8
  13. vectordb_bench/backend/clients/lancedb/config.py +14 -1
  14. vectordb_bench/backend/clients/lancedb/lancedb.py +21 -9
  15. vectordb_bench/backend/clients/memorydb/memorydb.py +2 -2
  16. vectordb_bench/backend/clients/milvus/cli.py +30 -9
  17. vectordb_bench/backend/clients/milvus/config.py +3 -0
  18. vectordb_bench/backend/clients/milvus/milvus.py +81 -23
  19. vectordb_bench/backend/clients/oceanbase/cli.py +100 -0
  20. vectordb_bench/backend/clients/oceanbase/config.py +125 -0
  21. vectordb_bench/backend/clients/oceanbase/oceanbase.py +215 -0
  22. vectordb_bench/backend/clients/pinecone/pinecone.py +39 -25
  23. vectordb_bench/backend/clients/qdrant_cloud/config.py +59 -3
  24. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +100 -33
  25. vectordb_bench/backend/clients/qdrant_local/cli.py +60 -0
  26. vectordb_bench/backend/clients/qdrant_local/config.py +47 -0
  27. vectordb_bench/backend/clients/qdrant_local/qdrant_local.py +232 -0
  28. vectordb_bench/backend/clients/weaviate_cloud/cli.py +29 -3
  29. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -0
  30. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +5 -0
  31. vectordb_bench/backend/dataset.py +143 -27
  32. vectordb_bench/backend/filter.py +76 -0
  33. vectordb_bench/backend/runner/__init__.py +3 -3
  34. vectordb_bench/backend/runner/mp_runner.py +52 -39
  35. vectordb_bench/backend/runner/rate_runner.py +68 -52
  36. vectordb_bench/backend/runner/read_write_runner.py +125 -68
  37. vectordb_bench/backend/runner/serial_runner.py +56 -23
  38. vectordb_bench/backend/task_runner.py +48 -20
  39. vectordb_bench/cli/batch_cli.py +121 -0
  40. vectordb_bench/cli/cli.py +59 -1
  41. vectordb_bench/cli/vectordbbench.py +7 -0
  42. vectordb_bench/config-files/batch_sample_config.yml +17 -0
  43. vectordb_bench/frontend/components/check_results/data.py +16 -11
  44. vectordb_bench/frontend/components/check_results/filters.py +53 -25
  45. vectordb_bench/frontend/components/check_results/headerIcon.py +16 -13
  46. vectordb_bench/frontend/components/check_results/nav.py +20 -0
  47. vectordb_bench/frontend/components/custom/displayCustomCase.py +43 -8
  48. vectordb_bench/frontend/components/custom/displaypPrams.py +10 -5
  49. vectordb_bench/frontend/components/custom/getCustomConfig.py +10 -0
  50. vectordb_bench/frontend/components/label_filter/charts.py +60 -0
  51. vectordb_bench/frontend/components/run_test/caseSelector.py +48 -52
  52. vectordb_bench/frontend/components/run_test/dbSelector.py +9 -5
  53. vectordb_bench/frontend/components/run_test/inputWidget.py +48 -0
  54. vectordb_bench/frontend/components/run_test/submitTask.py +3 -1
  55. vectordb_bench/frontend/components/streaming/charts.py +253 -0
  56. vectordb_bench/frontend/components/streaming/data.py +62 -0
  57. vectordb_bench/frontend/components/tables/data.py +1 -1
  58. vectordb_bench/frontend/components/welcome/explainPrams.py +66 -0
  59. vectordb_bench/frontend/components/welcome/pagestyle.py +106 -0
  60. vectordb_bench/frontend/components/welcome/welcomePrams.py +147 -0
  61. vectordb_bench/frontend/config/dbCaseConfigs.py +420 -41
  62. vectordb_bench/frontend/config/styles.py +32 -2
  63. vectordb_bench/frontend/pages/concurrent.py +5 -1
  64. vectordb_bench/frontend/pages/custom.py +4 -0
  65. vectordb_bench/frontend/pages/label_filter.py +56 -0
  66. vectordb_bench/frontend/pages/quries_per_dollar.py +5 -1
  67. vectordb_bench/frontend/pages/results.py +60 -0
  68. vectordb_bench/frontend/pages/run_test.py +3 -3
  69. vectordb_bench/frontend/pages/streaming.py +135 -0
  70. vectordb_bench/frontend/pages/tables.py +4 -0
  71. vectordb_bench/frontend/vdb_benchmark.py +16 -41
  72. vectordb_bench/interface.py +6 -2
  73. vectordb_bench/metric.py +15 -1
  74. vectordb_bench/models.py +38 -11
  75. vectordb_bench/results/ElasticCloud/result_20250318_standard_elasticcloud.json +5890 -0
  76. vectordb_bench/results/Milvus/result_20250509_standard_milvus.json +6138 -0
  77. vectordb_bench/results/OpenSearch/result_20250224_standard_opensearch.json +7319 -0
  78. vectordb_bench/results/Pinecone/result_20250124_standard_pinecone.json +2365 -0
  79. vectordb_bench/results/QdrantCloud/result_20250602_standard_qdrantcloud.json +3556 -0
  80. vectordb_bench/results/ZillizCloud/result_20250613_standard_zillizcloud.json +6290 -0
  81. vectordb_bench/results/dbPrices.json +12 -4
  82. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/METADATA +131 -32
  83. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/RECORD +87 -65
  84. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/WHEEL +1 -1
  85. vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +0 -791
  86. vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +0 -679
  87. vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +0 -1352
  88. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/entry_points.txt +0 -0
  89. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/licenses/LICENSE +0 -0
  90. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/aws_opensearch/config.py:

@@ -10,17 +10,21 @@ log = logging.getLogger(__name__)
 
 class AWSOpenSearchConfig(DBConfig, BaseModel):
     host: str = ""
-    port: int = 443
+    port: int = 80
     user: str = ""
     password: SecretStr = ""
 
     def to_dict(self) -> dict:
+        use_ssl = self.port == 443
+        http_auth = (
+            (self.user, self.password.get_secret_value()) if len(self.user) != 0 and len(self.password) != 0 else ()
+        )
         return {
             "hosts": [{"host": self.host, "port": self.port}],
-            "http_auth": (self.user, self.password.get_secret_value()),
-            "use_ssl": True,
+            "http_auth": http_auth,
+            "use_ssl": use_ssl,
             "http_compress": True,
-            "verify_certs": True,
+            "verify_certs": use_ssl,
             "ssl_assert_hostname": False,
             "ssl_show_warn": False,
             "timeout": 600,
@@ -28,16 +32,22 @@ class AWSOpenSearchConfig(DBConfig, BaseModel):
 
 
 class AWSOS_Engine(Enum):
-    nmslib = "nmslib"
     faiss = "faiss"
-    lucene = "Lucene"
+    lucene = "lucene"
+
+
+class AWSOSQuantization(Enum):
+    fp32 = "fp32"
+    fp16 = "fp16"
 
 
 class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
     metric_type: MetricType = MetricType.L2
     engine: AWSOS_Engine = AWSOS_Engine.faiss
     efConstruction: int = 256
-    efSearch: int = 256
+    efSearch: int = 100
+    engine_name: str | None = None
+    metric_type_name: str | None = None
     M: int = 16
     index_thread_qty: int | None = 4
     number_of_shards: int | None = 1
@@ -46,33 +56,65 @@ class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
     refresh_interval: str | None = "60s"
     force_merge_enabled: bool | None = True
     flush_threshold_size: str | None = "5120mb"
-    number_of_indexing_clients: int | None = 1
-    index_thread_qty_during_force_merge: int
+    index_thread_qty_during_force_merge: int = 8
     cb_threshold: str | None = "50%"
+    number_of_indexing_clients: int | None = 1
+    use_routing: bool = False  # for label-filter cases
+    oversample_factor: float = 1.0
+    quantization_type: AWSOSQuantization = AWSOSQuantization.fp32
+
+    def __eq__(self, obj: any):
+        return (
+            self.engine == obj.engine
+            and self.M == obj.M
+            and self.efConstruction == obj.efConstruction
+            and self.number_of_shards == obj.number_of_shards
+            and self.number_of_replicas == obj.number_of_replicas
+            and self.number_of_segments == obj.number_of_segments
+            and self.use_routing == obj.use_routing
+            and self.quantization_type == obj.quantization_type
+        )
 
     def parse_metric(self) -> str:
+        log.info(f"User specified metric_type: {self.metric_type_name}")
+        self.metric_type = MetricType[self.metric_type_name.upper()]
         if self.metric_type == MetricType.IP:
             return "innerproduct"
         if self.metric_type == MetricType.COSINE:
-            if self.engine == AWSOS_Engine.faiss:
-                log.info(
-                    "Using innerproduct because faiss doesn't support cosine as metric type for Opensearch",
-                )
-                return "innerproduct"
             return "cosinesimil"
+        if self.metric_type == MetricType.L2:
+            log.info("Using l2 as specified by user")
+            return "l2"
         return "l2"
 
+    @property
+    def use_quant(self) -> bool:
+        return self.quantization_type is not AWSOSQuantization.fp32
+
     def index_param(self) -> dict:
+        log.info(f"Using engine: {self.engine} for index creation")
+        log.info(f"Using metric_type: {self.metric_type_name} for index creation")
+        log.info(f"Resulting space_type: {self.parse_metric()} for index creation")
+
+        parameters = {"ef_construction": self.efConstruction, "m": self.M}
+
+        if self.engine == AWSOS_Engine.faiss and self.faiss_use_fp16:
+            parameters["encoder"] = {"name": "sq", "parameters": {"type": "fp16"}}
+
         return {
             "name": "hnsw",
-            "space_type": self.parse_metric(),
             "engine": self.engine.value,
             "parameters": {
                 "ef_construction": self.efConstruction,
                 "m": self.M,
                 "ef_search": self.efSearch,
+                **(
+                    {"encoder": {"name": "sq", "parameters": {"type": self.quantization_type.fp16.value}}}
+                    if self.use_quant
+                    else {}
+                ),
             },
         }
 
     def search_param(self) -> dict:
-        return {}
+        return {"ef_search": self.efSearch}
vectordb_bench/backend/clients/chroma/chroma.py:

@@ -78,8 +78,12 @@ class ChromaClient(VectorDB):
         """
         ids = [str(i) for i in metadata]
         metadata = [{"id": int(i)} for i in metadata]
-        if len(embeddings) > 0:
-            self.collection.add(embeddings=embeddings, ids=ids, metadatas=metadata)
+        try:
+            if len(embeddings) > 0:
+                self.collection.add(embeddings=embeddings, ids=ids, metadatas=metadata)
+        except Exception as e:
+            log.warning(f"Failed to insert data: error: {e!s}")
+            return 0, e
         return len(embeddings), None
 
     def search_embedding(
vectordb_bench/backend/clients/elastic_cloud/config.py:

@@ -23,13 +23,31 @@ class ESElementType(str, Enum):
 
 class ElasticCloudIndexConfig(BaseModel, DBCaseConfig):
     element_type: ESElementType = ESElementType.float
-    index: IndexType = IndexType.ES_HNSW  # ES only support 'hnsw'
+    index: IndexType = IndexType.ES_HNSW
+    number_of_shards: int = 1
+    number_of_replicas: int = 0
+    refresh_interval: str = "30s"
+    merge_max_thread_count: int = 8
+    use_rescore: bool = False
+    oversample_ratio: float = 2.0
+    use_routing: bool = False
+    use_force_merge: bool = True
 
     metric_type: MetricType | None = None
     efConstruction: int | None = None
     M: int | None = None
     num_candidates: int | None = None
 
+    def __eq__(self, obj: any):
+        return (
+            self.index == obj.index
+            and self.number_of_shards == obj.number_of_shards
+            and self.number_of_replicas == obj.number_of_replicas
+            and self.use_routing == obj.use_routing
+            and self.efConstruction == obj.efConstruction
+            and self.M == obj.M
+        )
+
     def parse_metric(self) -> str:
         if self.metric_type == MetricType.L2:
             return "l2_norm"
vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py:

@@ -5,6 +5,8 @@ from contextlib import contextmanager
 
 from elasticsearch.helpers import bulk
 
+from vectordb_bench.backend.filter import Filter, FilterOp
+
 from ..api import VectorDB
 from .config import ElasticCloudIndexConfig
 
@@ -18,6 +20,12 @@ SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
 
 
 class ElasticCloud(VectorDB):
+    supported_filter_types: list[FilterOp] = [
+        FilterOp.NonFilter,
+        FilterOp.NumGE,
+        FilterOp.StrEqual,
+    ]
+
     def __init__(
         self,
         dim: int,
@@ -25,8 +33,10 @@ class ElasticCloud(VectorDB):
         db_case_config: ElasticCloudIndexConfig,
         indice: str = "vdb_bench_indice",  # must be lowercase
         id_col_name: str = "id",
+        label_col_name: str = "label",
         vector_col_name: str = "vector",
         drop_old: bool = False,
+        with_scalar_labels: bool = False,
         **kwargs,
     ):
         self.dim = dim
@@ -34,7 +44,9 @@ class ElasticCloud(VectorDB):
         self.case_config = db_case_config
         self.indice = indice
         self.id_col_name = id_col_name
+        self.label_col_name = label_col_name
         self.vector_col_name = vector_col_name
+        self.with_scalar_labels = with_scalar_labels
 
         from elasticsearch import Elasticsearch
@@ -69,9 +81,17 @@ class ElasticCloud(VectorDB):
                 },
             },
         }
+        settings = {
+            "index": {
+                "number_of_shards": self.case_config.number_of_shards,
+                "number_of_replicas": self.case_config.number_of_replicas,
+                "refresh_interval": self.case_config.refresh_interval,
+                "merge.scheduler.max_thread_count": self.case_config.merge_max_thread_count,
+            }
+        }
 
         try:
-            client.indices.create(index=self.indice, mappings=mappings)
+            client.indices.create(index=self.indice, mappings=mappings, settings=settings)
         except Exception as e:
             log.warning(f"Failed to create indice: {self.indice} error: {e!s}")
             raise e from None
@@ -80,21 +100,48 @@ class ElasticCloud(VectorDB):
         self,
         embeddings: Iterable[list[float]],
         metadata: list[int],
+        labels_data: list[str] | None = None,
         **kwargs,
     ) -> tuple[int, Exception]:
         """Insert the embeddings to the elasticsearch."""
         assert self.client is not None, "should self.init() first"
 
-        insert_data = [
-            {
-                "_index": self.indice,
-                "_source": {
-                    self.id_col_name: metadata[i],
-                    self.vector_col_name: embeddings[i],
-                },
-            }
-            for i in range(len(embeddings))
-        ]
+        insert_data = (
+            [
+                (
+                    {
+                        "_index": self.indice,
+                        "_source": {
+                            self.id_col_name: metadata[i],
+                            self.label_col_name: labels_data[i],
+                            self.vector_col_name: embeddings[i],
+                        },
+                        "_routing": labels_data[i],
+                    }
+                    if self.case_config.use_routing
+                    else {
+                        "_index": self.indice,
+                        "_source": {
+                            self.id_col_name: metadata[i],
+                            self.label_col_name: labels_data[i],
+                            self.vector_col_name: embeddings[i],
+                        },
+                    }
+                )
+                for i in range(len(embeddings))
+            ]
+            if self.with_scalar_labels
+            else [
+                {
+                    "_index": self.indice,
+                    "_source": {
+                        self.id_col_name: metadata[i],
+                        self.vector_col_name: embeddings[i],
+                    },
+                }
+                for i in range(len(embeddings))
+            ]
+        )
         try:
             bulk_insert_res = bulk(self.client, insert_data)
             return (bulk_insert_res[0], None)
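
A minimal sketch of one bulk action the labeled branch above emits when use_routing is enabled; the field names are the defaults (id/label/vector) and the values are hypothetical:

    action = {
        "_index": "vdb_bench_indice",
        "_source": {"id": 42, "label": "color_red", "vector": [0.1, 0.2, 0.3]},
        "_routing": "color_red",  # co-locates all documents sharing a label on one shard
    }

Routing by label means a later label-filtered query can be served from a single shard instead of fanning out to all of them.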
@@ -102,59 +149,100 @@ class ElasticCloud(VectorDB):
             log.warning(f"Failed to insert data: {self.indice} error: {e!s}")
             return (0, e)
 
+    def prepare_filter(self, filters: Filter):
+        self.routing_key = None
+        if filters.type == FilterOp.NonFilter:
+            self.filter = []
+        elif filters.type == FilterOp.NumGE:
+            self.filter = {"range": {self.id_col_name: {"gt": filters.int_value}}}
+        elif filters.type == FilterOp.StrEqual:
+            self.filter = {"term": {self.label_col_name: filters.label_value}}
+            if self.case_config.use_routing:
+                self.routing_key = filters.label_value
+        else:
+            msg = f"Not support Filter for Milvus - {filters}"
+            raise ValueError(msg)
+
     def search_embedding(
         self,
         query: list[float],
         k: int = 100,
-        filters: dict | None = None,
+        **kwargs,
     ) -> list[int]:
         """Get k most similar embeddings to query vector.
 
         Args:
             query(list[float]): query embedding to look up documents similar to.
             k(int): Number of most similar embeddings to return. Defaults to 100.
-            filters(dict, optional): filtering expression to filter the data while searching.
 
         Returns:
             list[tuple[int, float]]: list of k most similar embeddings in (id, score) tuple to the query embedding.
         """
         assert self.client is not None, "should self.init() first"
 
-        knn = {
-            "field": self.vector_col_name,
-            "k": k,
-            "num_candidates": self.case_config.num_candidates,
-            "filter": [{"range": {self.id_col_name: {"gt": filters["id"]}}}] if filters else [],
-            "query_vector": query,
-        }
+        if self.case_config.use_rescore:
+            oversample_k = int(k * self.case_config.oversample_ratio)
+            oversample_num_candidates = int(self.case_config.num_candidates * self.case_config.oversample_ratio)
+            knn = {
+                "field": self.vector_col_name,
+                "k": oversample_k,
+                "num_candidates": oversample_num_candidates,
+                "filter": self.filter,
+                "query_vector": query,
+            }
+            rescore = {
+                "window_size": oversample_k,
+                "query": {
+                    "rescore_query": {
+                        "script_score": {
+                            "query": {"match_all": {}},
+                            "script": {
+                                "source": f"cosineSimilarity(params.queryVector, '{self.vector_col_name}')",
+                                "params": {"queryVector": query},
+                            },
+                        }
+                    },
+                    "query_weight": 0,
+                    "rescore_query_weight": 1,
+                },
+            }
+        else:
+            knn = {
+                "field": self.vector_col_name,
+                "k": k,
+                "num_candidates": self.case_config.num_candidates,
+                "filter": self.filter,
+                "query_vector": query,
+            }
+            rescore = None
         size = k
-        try:
-            res = self.client.search(
-                index=self.indice,
-                knn=knn,
-                size=size,
-                _source=False,
-                docvalue_fields=[self.id_col_name],
-                stored_fields="_none_",
-                filter_path=[f"hits.hits.fields.{self.id_col_name}"],
-            )
-            return [h["fields"][self.id_col_name][0] for h in res["hits"]["hits"]]
-        except Exception as e:
-            log.warning(f"Failed to search: {self.indice} error: {e!s}")
-            raise e from None
+
+        res = self.client.search(
+            index=self.indice,
+            knn=knn,
+            routing=self.routing_key,
+            rescore=rescore,
+            size=size,
+            _source=False,
+            docvalue_fields=[self.id_col_name],
+            stored_fields="_none_",
+            filter_path=[f"hits.hits.fields.{self.id_col_name}"],
+        )
+        return [h["fields"][self.id_col_name][0] for h in res["hits"]["hits"]]
 
     def optimize(self, data_size: int | None = None):
         """optimize will be called between insertion and search in performance cases."""
         assert self.client is not None, "should self.init() first"
         self.client.indices.refresh(index=self.indice)
-        force_merge_task_id = self.client.indices.forcemerge(
-            index=self.indice,
-            max_num_segments=1,
-            wait_for_completion=False,
-        )["task"]
-        log.info(f"Elasticsearch force merge task id: {force_merge_task_id}")
-        while True:
-            time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
-            task_status = self.client.tasks.get(task_id=force_merge_task_id)
-            if task_status["completed"]:
-                return
+        if self.case_config.use_force_merge:
+            force_merge_task_id = self.client.indices.forcemerge(
+                index=self.indice,
+                max_num_segments=1,
+                wait_for_completion=False,
+            )["task"]
+            log.info(f"Elasticsearch force merge task id: {force_merge_task_id}")
+            while True:
+                time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
+                task_status = self.client.tasks.get(task_id=force_merge_task_id)
+                if task_status["completed"]:
+                    return
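
The rescore branch trades latency for recall by oversampling the ANN stage and re-ranking exactly. A worked sketch with illustrative numbers (k=10, oversample_ratio=2.0, num_candidates=100):

    k, oversample_ratio, num_candidates = 10, 2.0, 100
    oversample_k = int(k * oversample_ratio)  # fetch 20 ANN hits instead of 10
    oversample_num_candidates = int(num_candidates * oversample_ratio)  # examine 200 candidates per shard
    # The rescore clause re-scores those 20 hits with an exact cosineSimilarity
    # script; query_weight=0 and rescore_query_weight=1 mean only the exact score
    # determines the final order, and size=k still trims the response to the top 10.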
vectordb_bench/backend/clients/lancedb/cli.py:

@@ -58,10 +58,46 @@ def LanceDBAutoIndex(**parameters: Unpack[LanceDBTypedDict]):
     )
 
 
+class LanceDBIVFPQTypedDict(CommonTypedDict, LanceDBTypedDict):
+    num_partitions: Annotated[
+        int,
+        click.option(
+            "--num-partitions",
+            type=int,
+            default=0,
+            help="Number of partitions for IVFPQ index, unset = use LanceDB default",
+        ),
+    ]
+    num_sub_vectors: Annotated[
+        int,
+        click.option(
+            "--num-sub-vectors",
+            type=int,
+            default=0,
+            help="Number of sub-vectors for IVFPQ index, unset = use LanceDB default",
+        ),
+    ]
+    nbits: Annotated[
+        int,
+        click.option(
+            "--nbits",
+            type=int,
+            default=8,
+            help="Number of bits for IVFPQ index (must be 4 or 8), unset = use LanceDB default",
+        ),
+    ]
+    nprobes: Annotated[
+        int,
+        click.option(
+            "--nprobes", type=int, default=0, help="Number of probes for IVFPQ search, unset = use LanceDB default"
+        ),
+    ]
+
+
 @cli.command()
-@click_parameter_decorators_from_typed_dict(LanceDBTypedDict)
-def LanceDBIVFPQ(**parameters: Unpack[LanceDBTypedDict]):
-    from .config import LanceDBConfig, _lancedb_case_config
+@click_parameter_decorators_from_typed_dict(LanceDBIVFPQTypedDict)
+def LanceDBIVFPQ(**parameters: Unpack[LanceDBIVFPQTypedDict]):
+    from .config import LanceDBConfig, LanceDBIndexConfig
 
     run(
         db=DB.LanceDB,
@@ -70,15 +106,29 @@ def LanceDBIVFPQ(**parameters: Unpack[LanceDBTypedDict]):
             uri=parameters["uri"],
             token=SecretStr(parameters["token"]) if parameters.get("token") else None,
         ),
-        db_case_config=_lancedb_case_config.get(IndexType.IVFPQ)(),
+        db_case_config=LanceDBIndexConfig(
+            index=IndexType.IVFPQ,
+            num_partitions=parameters["num_partitions"],
+            num_sub_vectors=parameters["num_sub_vectors"],
+            nbits=parameters["nbits"],
+            nprobes=parameters["nprobes"],
+        ),
         **parameters,
     )
 
 
+class LanceDBHNSWTypedDict(CommonTypedDict, LanceDBTypedDict):
+    m: Annotated[int, click.option("--m", type=int, default=0, help="HNSW parameter m")]
+    ef_construction: Annotated[
+        int, click.option("--ef-construction", type=int, default=0, help="HNSW parameter ef_construction")
+    ]
+    ef: Annotated[int, click.option("--ef", type=int, default=0, help="HNSW search parameter ef")]
+
+
 @cli.command()
-@click_parameter_decorators_from_typed_dict(LanceDBTypedDict)
-def LanceDBHNSW(**parameters: Unpack[LanceDBTypedDict]):
-    from .config import LanceDBConfig, _lancedb_case_config
+@click_parameter_decorators_from_typed_dict(LanceDBHNSWTypedDict)
+def LanceDBHNSW(**parameters: Unpack[LanceDBHNSWTypedDict]):
+    from .config import LanceDBConfig, LanceDBHNSWIndexConfig
 
     run(
         db=DB.LanceDB,
@@ -87,6 +137,10 @@ def LanceDBHNSW(**parameters: Unpack[LanceDBTypedDict]):
             uri=parameters["uri"],
             token=SecretStr(parameters["token"]) if parameters.get("token") else None,
         ),
-        db_case_config=_lancedb_case_config.get(IndexType.HNSW)(),
+        db_case_config=LanceDBHNSWIndexConfig(
+            m=parameters["m"],
+            ef_construction=parameters["ef_construction"],
+            ef=parameters["ef"],
+        ),
         **parameters,
    )
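
A sketch of the case config the new LanceDBIVFPQ command assembles from its flags; the flag values are illustrative and the import paths are assumed from the file list above:

    from vectordb_bench.backend.clients.api import IndexType
    from vectordb_bench.backend.clients.lancedb.config import LanceDBIndexConfig

    # e.g. --num-partitions 256 --num-sub-vectors 16 --nbits 8 --nprobes 32
    db_case_config = LanceDBIndexConfig(
        index=IndexType.IVFPQ,
        num_partitions=256,  # 0 (the flag default) means "use LanceDB's default"
        num_sub_vectors=16,
        nbits=8,
        nprobes=32,
    )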
vectordb_bench/backend/clients/lancedb/config.py:

@@ -25,6 +25,7 @@ class LanceDBIndexConfig(BaseModel, DBCaseConfig):
     nbits: int = 8  # Must be 4 or 8
     sample_rate: int = 256
     max_iterations: int = 50
+    nprobes: int = 0
 
     def index_param(self) -> dict:
         if self.index not in [
@@ -52,7 +53,11 @@ class LanceDBIndexConfig(BaseModel, DBCaseConfig):
         return params
 
     def search_param(self) -> dict:
-        pass
+        params = {}
+        if self.nprobes > 0:
+            params["nprobes"] = self.nprobes
+
+        return params
 
     def parse_metric(self) -> str:
         if self.metric_type in [MetricType.L2, MetricType.COSINE]:
@@ -81,6 +86,7 @@ class LanceDBHNSWIndexConfig(LanceDBIndexConfig):
     index: IndexType = IndexType.HNSW
     m: int = 0
     ef_construction: int = 0
+    ef: int = 0
 
     def index_param(self) -> dict:
         params = LanceDBIndexConfig.index_param(self)
@@ -94,6 +100,13 @@ class LanceDBHNSWIndexConfig(LanceDBIndexConfig):
 
         return params
 
+    def search_param(self) -> dict:
+        params = {}
+        if self.ef != 0:
+            params = {"ef": self.ef}
+
+        return params
+
 
 _lancedb_case_config = {
     IndexType.IVFPQ: LanceDBIndexConfig,
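
A quick sketch of what the two search_param() implementations above return, assuming illustrative values and no other required constructor fields:

    ivfpq = LanceDBIndexConfig(index=IndexType.IVFPQ, nprobes=32)
    assert ivfpq.search_param() == {"nprobes": 32}  # {} when nprobes stays 0

    hnsw = LanceDBHNSWIndexConfig(m=16, ef_construction=200, ef=64)
    assert hnsw.search_param() == {"ef": 64}  # {} when ef stays 0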
vectordb_bench/backend/clients/lancedb/lancedb.py:

@@ -32,6 +32,10 @@ class LanceDB(VectorDB):
         self.table_name = collection_name
         self.dim = dim
         self.uri = db_config["uri"]
+        # avoid the search_param being called every time during the search process
+        self.search_config = db_case_config.search_param()
+
+        log.info(f"Search config: {self.search_config}")
 
         db = lancedb.connect(self.uri)
 
@@ -45,7 +49,7 @@ class LanceDB(VectorDB):
             db.open_table(self.table_name)
         except Exception:
             schema = pa.schema(
-                [pa.field("id", pa.int64()), pa.field("vector", pa.list_(pa.float64(), list_size=self.dim))]
+                [pa.field("id", pa.int64()), pa.field("vector", pa.list_(pa.float32(), list_size=self.dim))]
             )
             db.create_table(self.table_name, schema=schema, mode="overwrite")
 
@@ -77,20 +81,28 @@ class LanceDB(VectorDB):
         filters: dict | None = None,
     ) -> list[int]:
         if filters:
-            results = (
-                self.table.search(query)
-                .select(["id"])
-                .where(f"id >= {filters['id']}", prefilter=True)
-                .limit(k)
-                .to_list()
-            )
+            results = self.table.search(query).select(["id"]).where(f"id >= {filters['id']}", prefilter=True).limit(k)
+            if self.case_config.index == IndexType.IVFPQ and "nprobes" in self.search_config:
+                results = results.nprobes(self.search_config["nprobes"]).to_list()
+            elif self.case_config.index == IndexType.HNSW and "ef" in self.search_config:
+                results = results.ef(self.search_config["ef"]).to_list()
+            else:
+                results = results.to_list()
         else:
-            results = self.table.search(query).select(["id"]).limit(k).to_list()
+            results = self.table.search(query).select(["id"]).limit(k)
+            if self.case_config.index == IndexType.IVFPQ and "nprobes" in self.search_config:
+                results = results.nprobes(self.search_config["nprobes"]).to_list()
+            elif self.case_config.index == IndexType.HNSW and "ef" in self.search_config:
+                results = results.ef(self.search_config["ef"]).to_list()
+            else:
+                results = results.to_list()
+
         return [int(result["id"]) for result in results]
 
     def optimize(self, data_size: int | None = None):
         if self.table and hasattr(self, "case_config") and self.case_config.index != IndexType.NONE:
             log.info(f"Creating index for LanceDB table ({self.table_name})")
+            log.info(f"Index parameters: {self.case_config.index_param()}")
             self.table.create_index(**self.case_config.index_param())
             # Better recall with IVF_PQ (though still bad) but breaks HNSW: https://github.com/lancedb/lancedb/issues/2369
             if self.case_config.index in (IndexType.IVFPQ, IndexType.AUTOINDEX):
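
Condensed sketch of how the cached search_config steers the query builder at search time; the .nprobes()/.ef() builder calls are the same ones used above, while table, query_vector, k, and search_config stand in for the instance state:

    q = table.search(query_vector).select(["id"]).limit(k)
    if "nprobes" in search_config:  # IVFPQ path
        q = q.nprobes(search_config["nprobes"])
    elif "ef" in search_config:  # HNSW path
        q = q.ef(search_config["ef"])
    ids = [int(r["id"]) for r in q.to_list()]

Caching search_param() once in __init__ avoids rebuilding this dict on every query.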
vectordb_bench/backend/clients/memorydb/memorydb.py:

@@ -9,10 +9,10 @@ import redis
 from redis import Redis
 from redis.cluster import RedisCluster
 from redis.commands.search.field import NumericField, TagField, VectorField
-from redis.commands.search.indexDefinition import IndexDefinition
+from redis.commands.search.indexDefinition import IndexDefinition, IndexType
 from redis.commands.search.query import Query
 
-from ..api import IndexType, VectorDB
+from ..api import VectorDB
 from .config import MemoryDBIndexConfig
 
 log = logging.getLogger(__name__)