vectordb-bench 0.0.29__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. vectordb_bench/__init__.py +14 -27
  2. vectordb_bench/backend/assembler.py +19 -6
  3. vectordb_bench/backend/cases.py +186 -23
  4. vectordb_bench/backend/clients/__init__.py +32 -0
  5. vectordb_bench/backend/clients/api.py +22 -1
  6. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +249 -43
  7. vectordb_bench/backend/clients/aws_opensearch/cli.py +51 -21
  8. vectordb_bench/backend/clients/aws_opensearch/config.py +58 -16
  9. vectordb_bench/backend/clients/chroma/chroma.py +6 -2
  10. vectordb_bench/backend/clients/elastic_cloud/config.py +19 -1
  11. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +133 -45
  12. vectordb_bench/backend/clients/lancedb/cli.py +62 -8
  13. vectordb_bench/backend/clients/lancedb/config.py +14 -1
  14. vectordb_bench/backend/clients/lancedb/lancedb.py +21 -9
  15. vectordb_bench/backend/clients/memorydb/memorydb.py +2 -2
  16. vectordb_bench/backend/clients/milvus/cli.py +30 -9
  17. vectordb_bench/backend/clients/milvus/config.py +3 -0
  18. vectordb_bench/backend/clients/milvus/milvus.py +81 -23
  19. vectordb_bench/backend/clients/oceanbase/cli.py +100 -0
  20. vectordb_bench/backend/clients/oceanbase/config.py +125 -0
  21. vectordb_bench/backend/clients/oceanbase/oceanbase.py +215 -0
  22. vectordb_bench/backend/clients/pinecone/pinecone.py +39 -25
  23. vectordb_bench/backend/clients/qdrant_cloud/config.py +59 -3
  24. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +100 -33
  25. vectordb_bench/backend/clients/qdrant_local/cli.py +60 -0
  26. vectordb_bench/backend/clients/qdrant_local/config.py +47 -0
  27. vectordb_bench/backend/clients/qdrant_local/qdrant_local.py +232 -0
  28. vectordb_bench/backend/clients/weaviate_cloud/cli.py +29 -3
  29. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -0
  30. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +5 -0
  31. vectordb_bench/backend/dataset.py +143 -27
  32. vectordb_bench/backend/filter.py +76 -0
  33. vectordb_bench/backend/runner/__init__.py +3 -3
  34. vectordb_bench/backend/runner/mp_runner.py +52 -39
  35. vectordb_bench/backend/runner/rate_runner.py +68 -52
  36. vectordb_bench/backend/runner/read_write_runner.py +125 -68
  37. vectordb_bench/backend/runner/serial_runner.py +56 -23
  38. vectordb_bench/backend/task_runner.py +48 -20
  39. vectordb_bench/cli/batch_cli.py +121 -0
  40. vectordb_bench/cli/cli.py +59 -1
  41. vectordb_bench/cli/vectordbbench.py +7 -0
  42. vectordb_bench/config-files/batch_sample_config.yml +17 -0
  43. vectordb_bench/frontend/components/check_results/data.py +16 -11
  44. vectordb_bench/frontend/components/check_results/filters.py +53 -25
  45. vectordb_bench/frontend/components/check_results/headerIcon.py +16 -13
  46. vectordb_bench/frontend/components/check_results/nav.py +20 -0
  47. vectordb_bench/frontend/components/custom/displayCustomCase.py +43 -8
  48. vectordb_bench/frontend/components/custom/displaypPrams.py +10 -5
  49. vectordb_bench/frontend/components/custom/getCustomConfig.py +10 -0
  50. vectordb_bench/frontend/components/label_filter/charts.py +60 -0
  51. vectordb_bench/frontend/components/run_test/caseSelector.py +48 -52
  52. vectordb_bench/frontend/components/run_test/dbSelector.py +9 -5
  53. vectordb_bench/frontend/components/run_test/inputWidget.py +48 -0
  54. vectordb_bench/frontend/components/run_test/submitTask.py +3 -1
  55. vectordb_bench/frontend/components/streaming/charts.py +253 -0
  56. vectordb_bench/frontend/components/streaming/data.py +62 -0
  57. vectordb_bench/frontend/components/tables/data.py +1 -1
  58. vectordb_bench/frontend/components/welcome/explainPrams.py +66 -0
  59. vectordb_bench/frontend/components/welcome/pagestyle.py +106 -0
  60. vectordb_bench/frontend/components/welcome/welcomePrams.py +147 -0
  61. vectordb_bench/frontend/config/dbCaseConfigs.py +420 -41
  62. vectordb_bench/frontend/config/styles.py +32 -2
  63. vectordb_bench/frontend/pages/concurrent.py +5 -1
  64. vectordb_bench/frontend/pages/custom.py +4 -0
  65. vectordb_bench/frontend/pages/label_filter.py +56 -0
  66. vectordb_bench/frontend/pages/quries_per_dollar.py +5 -1
  67. vectordb_bench/frontend/pages/results.py +60 -0
  68. vectordb_bench/frontend/pages/run_test.py +3 -3
  69. vectordb_bench/frontend/pages/streaming.py +135 -0
  70. vectordb_bench/frontend/pages/tables.py +4 -0
  71. vectordb_bench/frontend/vdb_benchmark.py +16 -41
  72. vectordb_bench/interface.py +6 -2
  73. vectordb_bench/metric.py +15 -1
  74. vectordb_bench/models.py +38 -11
  75. vectordb_bench/results/ElasticCloud/result_20250318_standard_elasticcloud.json +5890 -0
  76. vectordb_bench/results/Milvus/result_20250509_standard_milvus.json +6138 -0
  77. vectordb_bench/results/OpenSearch/result_20250224_standard_opensearch.json +7319 -0
  78. vectordb_bench/results/Pinecone/result_20250124_standard_pinecone.json +2365 -0
  79. vectordb_bench/results/QdrantCloud/result_20250602_standard_qdrantcloud.json +3556 -0
  80. vectordb_bench/results/ZillizCloud/result_20250613_standard_zillizcloud.json +6290 -0
  81. vectordb_bench/results/dbPrices.json +12 -4
  82. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/METADATA +131 -32
  83. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/RECORD +87 -65
  84. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/WHEEL +1 -1
  85. vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +0 -791
  86. vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +0 -679
  87. vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +0 -1352
  88. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/entry_points.txt +0 -0
  89. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/licenses/LICENSE +0 -0
  90. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/top_level.txt +0 -0
@@ -29,6 +29,17 @@ class MilvusTypedDict(TypedDict):
29
29
  str | None,
30
30
  click.option("--password", type=str, help="Db password", required=False),
31
31
  ]
32
+ num_shards: Annotated[
33
+ int,
34
+ click.option(
35
+ "--num-shards",
36
+ type=int,
37
+ help="Number of shards",
38
+ required=False,
39
+ default=1,
40
+ show_default=True,
41
+ ),
42
+ ]
32
43
 
33
44
 
34
45
  class MilvusAutoIndexTypedDict(CommonTypedDict, MilvusTypedDict): ...
@@ -45,7 +56,8 @@ def MilvusAutoIndex(**parameters: Unpack[MilvusAutoIndexTypedDict]):
45
56
  db_label=parameters["db_label"],
46
57
  uri=SecretStr(parameters["uri"]),
47
58
  user=parameters["user_name"],
48
- password=SecretStr(parameters["password"]),
59
+ password=SecretStr(parameters["password"]) if parameters["password"] else None,
60
+ num_shards=int(parameters["num_shards"]),
49
61
  ),
50
62
  db_case_config=AutoIndexConfig(),
51
63
  **parameters,
@@ -63,7 +75,8 @@ def MilvusFlat(**parameters: Unpack[MilvusAutoIndexTypedDict]):
63
75
  db_label=parameters["db_label"],
64
76
  uri=SecretStr(parameters["uri"]),
65
77
  user=parameters["user_name"],
66
- password=SecretStr(parameters["password"]),
78
+ password=SecretStr(parameters["password"]) if parameters["password"] else None,
79
+ num_shards=int(parameters["num_shards"]),
67
80
  ),
68
81
  db_case_config=FLATConfig(),
69
82
  **parameters,
@@ -85,6 +98,7 @@ def MilvusHNSW(**parameters: Unpack[MilvusHNSWTypedDict]):
85
98
  uri=SecretStr(parameters["uri"]),
86
99
  user=parameters["user_name"],
87
100
  password=SecretStr(parameters["password"]) if parameters["password"] else None,
101
+ num_shards=int(parameters["num_shards"]),
88
102
  ),
89
103
  db_case_config=HNSWConfig(
90
104
  M=parameters["m"],
@@ -109,7 +123,8 @@ def MilvusIVFFlat(**parameters: Unpack[MilvusIVFFlatTypedDict]):
109
123
  db_label=parameters["db_label"],
110
124
  uri=SecretStr(parameters["uri"]),
111
125
  user=parameters["user_name"],
112
- password=SecretStr(parameters["password"]),
126
+ password=SecretStr(parameters["password"]) if parameters["password"] else None,
127
+ num_shards=int(parameters["num_shards"]),
113
128
  ),
114
129
  db_case_config=IVFFlatConfig(
115
130
  nlist=parameters["nlist"],
@@ -130,7 +145,8 @@ def MilvusIVFSQ8(**parameters: Unpack[MilvusIVFFlatTypedDict]):
130
145
  db_label=parameters["db_label"],
131
146
  uri=SecretStr(parameters["uri"]),
132
147
  user=parameters["user_name"],
133
- password=SecretStr(parameters["password"]),
148
+ password=SecretStr(parameters["password"]) if parameters["password"] else None,
149
+ num_shards=int(parameters["num_shards"]),
134
150
  ),
135
151
  db_case_config=IVFSQ8Config(
136
152
  nlist=parameters["nlist"],
@@ -155,7 +171,8 @@ def MilvusDISKANN(**parameters: Unpack[MilvusDISKANNTypedDict]):
155
171
  db_label=parameters["db_label"],
156
172
  uri=SecretStr(parameters["uri"]),
157
173
  user=parameters["user_name"],
158
- password=SecretStr(parameters["password"]),
174
+ password=SecretStr(parameters["password"]) if parameters["password"] else None,
175
+ num_shards=int(parameters["num_shards"]),
159
176
  ),
160
177
  db_case_config=DISKANNConfig(
161
178
  search_list=parameters["search_list"],
@@ -183,7 +200,8 @@ def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]):
183
200
  db_label=parameters["db_label"],
184
201
  uri=SecretStr(parameters["uri"]),
185
202
  user=parameters["user_name"],
186
- password=SecretStr(parameters["password"]),
203
+ password=SecretStr(parameters["password"]) if parameters["password"] else None,
204
+ num_shards=int(parameters["num_shards"]),
187
205
  ),
188
206
  db_case_config=GPUIVFFlatConfig(
189
207
  nlist=parameters["nlist"],
@@ -217,7 +235,8 @@ def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]):
217
235
  db_label=parameters["db_label"],
218
236
  uri=SecretStr(parameters["uri"]),
219
237
  user=parameters["user_name"],
220
- password=SecretStr(parameters["password"]),
238
+ password=SecretStr(parameters["password"]) if parameters["password"] else None,
239
+ num_shards=int(parameters["num_shards"]),
221
240
  ),
222
241
  db_case_config=GPUBruteForceConfig(
223
242
  metric_type=parameters["metric_type"],
@@ -248,7 +267,8 @@ def MilvusGPUIVFPQ(**parameters: Unpack[MilvusGPUIVFPQTypedDict]):
248
267
  db_label=parameters["db_label"],
249
268
  uri=SecretStr(parameters["uri"]),
250
269
  user=parameters["user_name"],
251
- password=SecretStr(parameters["password"]),
270
+ password=SecretStr(parameters["password"]) if parameters["password"] else None,
271
+ num_shards=int(parameters["num_shards"]),
252
272
  ),
253
273
  db_case_config=GPUIVFPQConfig(
254
274
  nlist=parameters["nlist"],
@@ -287,7 +307,8 @@ def MilvusGPUCAGRA(**parameters: Unpack[MilvusGPUCAGRATypedDict]):
287
307
  db_label=parameters["db_label"],
288
308
  uri=SecretStr(parameters["uri"]),
289
309
  user=parameters["user_name"],
290
- password=SecretStr(parameters["password"]),
310
+ password=SecretStr(parameters["password"]) if parameters["password"] else None,
311
+ num_shards=int(parameters["num_shards"]),
291
312
  ),
292
313
  db_case_config=GPUCAGRAConfig(
293
314
  intermediate_graph_degree=parameters["intermediate_graph_degree"],
@@ -7,12 +7,14 @@ class MilvusConfig(DBConfig):
7
7
  uri: SecretStr = "http://localhost:19530"
8
8
  user: str | None = None
9
9
  password: SecretStr | None = None
10
+ num_shards: int = 1
10
11
 
11
12
  def to_dict(self) -> dict:
12
13
  return {
13
14
  "uri": self.uri.get_secret_value(),
14
15
  "user": self.user if self.user else None,
15
16
  "password": self.password.get_secret_value() if self.password else None,
17
+ "num_shards": self.num_shards,
16
18
  }
17
19
 
18
20
  @validator("*")
@@ -33,6 +35,7 @@ class MilvusIndexConfig(BaseModel):
33
35
 
34
36
  index: IndexType
35
37
  metric_type: MetricType | None = None
38
+ use_partition_key: bool = True # for label-filter
36
39
 
37
40
  @property
38
41
  def is_gpu_index(self) -> bool:
@@ -7,6 +7,8 @@ from contextlib import contextmanager
7
7
 
8
8
  from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusException, utility
9
9
 
10
+ from vectordb_bench.backend.filter import Filter, FilterOp
11
+
10
12
  from ..api import VectorDB
11
13
  from .config import MilvusIndexConfig
12
14
 
@@ -16,6 +18,12 @@ MILVUS_LOAD_REQS_SIZE = 1.5 * 1024 * 1024
16
18
 
17
19
 
18
20
  class Milvus(VectorDB):
21
+ supported_filter_types: list[FilterOp] = [
22
+ FilterOp.NonFilter,
23
+ FilterOp.NumGE,
24
+ FilterOp.StrEqual,
25
+ ]
26
+
19
27
  def __init__(
20
28
  self,
21
29
  dim: int,
@@ -24,6 +32,7 @@ class Milvus(VectorDB):
24
32
  collection_name: str = "VectorDBBenchCollection",
25
33
  drop_old: bool = False,
26
34
  name: str = "Milvus",
35
+ with_scalar_labels: bool = False,
27
36
  **kwargs,
28
37
  ):
29
38
  """Initialize wrapper around the milvus vector database."""
@@ -32,15 +41,24 @@ class Milvus(VectorDB):
32
41
  self.case_config = db_case_config
33
42
  self.collection_name = collection_name
34
43
  self.batch_size = int(MILVUS_LOAD_REQS_SIZE / (dim * 4))
44
+ self.with_scalar_labels = with_scalar_labels
35
45
 
36
46
  self._primary_field = "pk"
37
- self._scalar_field = "id"
47
+ self._scalar_id_field = "id"
48
+ self._scalar_label_field = "label"
38
49
  self._vector_field = "vector"
39
- self._index_name = "vector_idx"
50
+ self._vector_index_name = "vector_idx"
51
+ self._scalar_id_index_name = "id_sort_idx"
52
+ self._scalar_labels_index_name = "labels_idx"
40
53
 
41
54
  from pymilvus import connections
42
55
 
43
- connections.connect(**self.db_config, timeout=30)
56
+ connections.connect(
57
+ uri=self.db_config.get("uri"),
58
+ user=self.db_config.get("user"),
59
+ password=self.db_config.get("password"),
60
+ timeout=30,
61
+ )
44
62
  if drop_old and utility.has_collection(self.collection_name):
45
63
  log.info(f"{self.name} client drop_old collection: {self.collection_name}")
46
64
  utility.drop_collection(self.collection_name)
@@ -48,9 +66,20 @@ class Milvus(VectorDB):
48
66
  if not utility.has_collection(self.collection_name):
49
67
  fields = [
50
68
  FieldSchema(self._primary_field, DataType.INT64, is_primary=True),
51
- FieldSchema(self._scalar_field, DataType.INT64),
69
+ FieldSchema(self._scalar_id_field, DataType.INT64),
52
70
  FieldSchema(self._vector_field, DataType.FLOAT_VECTOR, dim=dim),
53
71
  ]
72
+ if self.with_scalar_labels:
73
+ is_partition_key = db_case_config.use_partition_key
74
+ log.info(f"with_scalar_labels, add a new varchar field, as partition_key: {is_partition_key}")
75
+ fields.append(
76
+ FieldSchema(
77
+ self._scalar_label_field,
78
+ DataType.VARCHAR,
79
+ max_length=256,
80
+ is_partition_key=is_partition_key,
81
+ )
82
+ )
54
83
 
55
84
  log.info(f"{self.name} create collection: {self.collection_name}")
56
85
 
@@ -59,18 +88,40 @@ class Milvus(VectorDB):
59
88
  name=self.collection_name,
60
89
  schema=CollectionSchema(fields),
61
90
  consistency_level="Session",
91
+ num_shards=self.db_config.get("num_shards"),
62
92
  )
63
93
 
64
- log.info(f"{self.name} create index: index_params: {self.case_config.index_param()}")
65
- col.create_index(
66
- self._vector_field,
67
- self.case_config.index_param(),
68
- index_name=self._index_name,
69
- )
94
+ self.create_index()
70
95
  col.load()
71
96
 
72
97
  connections.disconnect("default")
73
98
 
99
def create_index(self):
    """Create the vector index and the scalar indexes used by filtered search.

    Always builds the vector index and an STL_SORT index on the int id field.
    A BITMAP index on the label field is added only when
    ``self.with_scalar_labels`` is set.
    """
    collection = Collection(self.collection_name)
    collection.create_index(
        self._vector_field,
        self.case_config.index_param(),
        index_name=self._vector_index_name,
    )

    # (field, index_type, index_name); the id index serves range (int) filters,
    # the label index serves varchar equality (label) filters.
    scalar_specs = [(self._scalar_id_field, "STL_SORT", self._scalar_id_index_name)]
    if self.with_scalar_labels:
        scalar_specs.append((self._scalar_label_field, "BITMAP", self._scalar_labels_index_name))

    for field_name, scalar_index_type, scalar_index_name in scalar_specs:
        collection.create_index(
            field_name,
            index_params={"index_type": scalar_index_type},
            index_name=scalar_index_name,
        )
+
74
125
  @contextmanager
75
126
  def init(self):
76
127
  """
@@ -103,17 +154,13 @@ class Milvus(VectorDB):
103
154
  try:
104
155
  self.col.flush()
105
156
  # wait for index done and load refresh
106
- self.col.create_index(
107
- self._vector_field,
108
- self.case_config.index_param(),
109
- index_name=self._index_name,
110
- )
157
+ self.create_index()
111
158
 
112
- utility.wait_for_index_building_complete(self.collection_name)
159
+ utility.wait_for_index_building_complete(self.collection_name, index_name=self._vector_index_name)
113
160
 
114
161
  def wait_index():
115
162
  while True:
116
- progress = utility.index_building_progress(self.collection_name)
163
+ progress = utility.index_building_progress(self.collection_name, index_name=self._vector_index_name)
117
164
  if progress.get("pending_index_rows", -1) == 0:
118
165
  break
119
166
  time.sleep(5)
@@ -156,6 +203,7 @@ class Milvus(VectorDB):
156
203
  self,
157
204
  embeddings: Iterable[list[float]],
158
205
  metadata: list[int],
206
+ labels_data: list[str] | None = None,
159
207
  **kwargs,
160
208
  ) -> tuple[int, Exception]:
161
209
  """Insert embeddings into Milvus. should call self.init() first"""
@@ -171,32 +219,42 @@ class Milvus(VectorDB):
171
219
  metadata[batch_start_offset:batch_end_offset],
172
220
  embeddings[batch_start_offset:batch_end_offset],
173
221
  ]
222
+ if self.with_scalar_labels:
223
+ insert_data.append(labels_data[batch_start_offset:batch_end_offset])
174
224
  res = self.col.insert(insert_data)
175
225
  insert_count += len(res.primary_keys)
176
226
  except MilvusException as e:
177
227
  log.info(f"Failed to insert data: {e}")
178
- return (insert_count, e)
179
- return (insert_count, None)
228
+ return insert_count, e
229
+ return insert_count, None
230
+
231
+ def prepare_filter(self, filters: Filter):
232
+ if filters.type == FilterOp.NonFilter:
233
+ self.expr = ""
234
+ elif filters.type == FilterOp.NumGE:
235
+ self.expr = f"{self._scalar_id_field} >= {filters.int_value}"
236
+ elif filters.type == FilterOp.StrEqual:
237
+ self.expr = f"{self._scalar_label_field} == '{filters.label_value}'"
238
+ else:
239
+ msg = f"Not support Filter for Milvus - {filters}"
240
+ raise ValueError(msg)
180
241
 
181
242
  def search_embedding(
182
243
  self,
183
244
  query: list[float],
184
245
  k: int = 100,
185
- filters: dict | None = None,
186
246
  timeout: int | None = None,
187
247
  ) -> list[int]:
188
248
  """Perform a search on a query embedding and return results."""
189
249
  assert self.col is not None
190
250
 
191
- expr = f"{self._scalar_field} {filters.get('metadata')}" if filters else ""
192
-
193
251
  # Perform the search.
194
252
  res = self.col.search(
195
253
  data=[query],
196
254
  anns_field=self._vector_field,
197
255
  param=self.case_config.search_param(),
198
256
  limit=k,
199
- expr=expr,
257
+ expr=self.expr,
200
258
  )
201
259
 
202
260
  # Organize results.
@@ -0,0 +1,100 @@
1
+ import os
2
+ from typing import Annotated, Unpack
3
+
4
+ import click
5
+ from pydantic import SecretStr
6
+
7
+ from vectordb_bench.backend.clients import DB
8
+ from vectordb_bench.cli.cli import (
9
+ CommonTypedDict,
10
+ HNSWFlavor4,
11
+ OceanBaseIVFTypedDict,
12
+ cli,
13
+ click_parameter_decorators_from_typed_dict,
14
+ run,
15
+ )
16
+
17
+ from ..api import IndexType
18
+
19
+
20
class OceanBaseTypedDict(CommonTypedDict):
    """CLI connection options shared by the OceanBase commands."""

    host: Annotated[
        str,
        click.option("--host", type=str, help="OceanBase host", default=""),
    ]
    user: Annotated[
        str,
        click.option("--user", type=str, help="OceanBase username", required=True),
    ]
    # The callable default is resolved by click only when --password is omitted,
    # falling back to the OB_PASSWORD environment variable (or "").
    password: Annotated[
        str,
        click.option(
            "--password",
            type=str,
            help="OceanBase database password",
            default=lambda: os.environ.get("OB_PASSWORD", ""),
        ),
    ]
    database: Annotated[
        str,
        click.option("--database", type=str, help="DataBase name", required=True),
    ]
    port: Annotated[
        int,
        click.option("--port", type=int, help="OceanBase port", required=True),
    ]
34
+
35
+
36
class OceanBaseHNSWTypedDict(CommonTypedDict, OceanBaseTypedDict, HNSWFlavor4):
    """CLI options for ``OceanBaseHNSW``: common, OceanBase connection and ``HNSWFlavor4`` options."""
37
+
38
+
39
@cli.command()
@click_parameter_decorators_from_typed_dict(OceanBaseHNSWTypedDict)
def OceanBaseHNSW(**parameters: Unpack[OceanBaseHNSWTypedDict]):
    """Run the benchmark against OceanBase using an HNSW-family vector index."""
    from .config import OceanBaseConfig, OceanBaseHNSWConfig

    connection_config = OceanBaseConfig(
        db_label=parameters["db_label"],
        user=SecretStr(parameters["user"]),
        password=SecretStr(parameters["password"]),
        host=parameters["host"],
        port=parameters["port"],
        database=parameters["database"],
    )
    # The CLI option ``ef_construction`` maps onto the model field ``efConstruction``.
    hnsw_case_config = OceanBaseHNSWConfig(
        m=parameters["m"],
        efConstruction=parameters["ef_construction"],
        ef_search=parameters["ef_search"],
        index=parameters["index_type"],
    )
    run(
        db=DB.OceanBase,
        db_config=connection_config,
        db_case_config=hnsw_case_config,
        **parameters,
    )
62
+
63
+
64
class OceanBaseIVFTypedDict(CommonTypedDict, OceanBaseTypedDict, OceanBaseIVFTypedDict):
    """CLI options for ``OceanBaseIVF``.

    NOTE: the third base is the ``OceanBaseIVFTypedDict`` imported from
    ``vectordb_bench.cli.cli``; this statement then rebinds that name to this
    extended TypedDict within the module.
    """
65
+
66
+
67
@cli.command()
@click_parameter_decorators_from_typed_dict(OceanBaseIVFTypedDict)
def OceanBaseIVF(**parameters: Unpack[OceanBaseIVFTypedDict]):
    """Run the benchmark against OceanBase using an IVF-family vector index.

    ``index_type`` must be one of ``IVF_FLAT``, ``IVF_PQ`` or ``IVF_SQ8``.
    ``m`` is only meaningful for IVF_PQ; when it is not given, 0 is used.

    Raises:
        click.BadParameter: if ``index_type`` is not a supported IVF variant.
    """
    from .config import OceanBaseConfig, OceanBaseIVFConfig

    ivf_index_types = {
        "IVF_FLAT": IndexType.IVFFlat,
        "IVF_PQ": IndexType.IVFPQ,
        "IVF_SQ8": IndexType.IVFSQ8,
    }
    type_str = parameters["index_type"]
    input_index_type = ivf_index_types.get(type_str)
    if input_index_type is None:
        # Without this guard an unknown value fell through the if/elif chain and
        # crashed later with UnboundLocalError instead of a usable CLI error.
        msg = f"Unsupported OceanBase IVF index type {type_str!r}; expected one of {sorted(ivf_index_types)}"
        raise click.BadParameter(msg)

    input_m = 0 if parameters["m"] is None else parameters["m"]

    run(
        db=DB.OceanBase,
        db_config=OceanBaseConfig(
            db_label=parameters["db_label"],
            user=SecretStr(parameters["user"]),
            password=SecretStr(parameters["password"]),
            host=parameters["host"],
            port=parameters["port"],
            database=parameters["database"],
        ),
        db_case_config=OceanBaseIVFConfig(
            m=input_m,
            nlist=parameters["nlist"],
            sample_per_nlist=parameters["sample_per_nlist"],
            index=input_index_type,
            ivf_nprobes=parameters["ivf_nprobes"],
        ),
        **parameters,
    )
@@ -0,0 +1,125 @@
1
+ from typing import TypedDict
2
+
3
+ from pydantic import BaseModel, SecretStr, validator
4
+
5
+ from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
6
+
7
+
8
class OceanBaseConfigDict(TypedDict):
    """Connection settings as returned by ``OceanBaseConfig.to_dict``."""

    user: str
    host: str
    # ``OceanBaseConfig.port`` is an ``int`` and ``to_dict`` passes it through
    # unchanged, so the annotation must be ``int`` (it was ``str``).
    port: int
    password: str
    database: str
14
+
15
+
16
class OceanBaseConfig(DBConfig):
    """Connection configuration for an OceanBase instance."""

    user: SecretStr = SecretStr("root@perf")
    password: SecretStr
    host: str
    port: int
    database: str

    def to_dict(self) -> OceanBaseConfigDict:
        """Return the connection settings with secret values unwrapped."""
        return OceanBaseConfigDict(
            user=self.user.get_secret_value(),
            host=self.host,
            port=self.port,
            password=self.password.get_secret_value(),
            database=self.database,
        )

    @validator("*")
    def not_empty_field(cls, v: any, field: any):
        """Reject empty strings/secrets, except for fields allowed to be empty."""
        may_be_empty = field.name in ["password", "host", "db_label"]
        if not may_be_empty and isinstance(v, str | SecretStr) and len(v) == 0:
            raise ValueError("Empty string!")
        return v
41
+
42
+
43
class OceanBaseIndexConfig(BaseModel):
    """Settings shared by every OceanBase vector-index configuration."""

    index: IndexType
    metric_type: MetricType | None = None
    lib: str = "vsag"

    def _metric_family(self) -> str:
        """Classify the configured metric as ``"l2"``, ``"ip"`` or ``"cosine"``.

        L2, and COSINE on an HNSW_BQ index, both map to ``"l2"``. IP maps to
        ``"ip"``. Anything else, including ``None``, maps to ``"cosine"``.
        """
        bq_with_cosine = self.index == IndexType.HNSW_BQ and self.metric_type == MetricType.COSINE
        if self.metric_type == MetricType.L2 or bq_with_cosine:
            return "l2"
        if self.metric_type == MetricType.IP:
            return "ip"
        return "cosine"

    def parse_metric(self) -> str:
        """Return the metric name used in the index-build parameters."""
        names = {"l2": "l2", "ip": "inner_product", "cosine": "cosine"}
        return names[self._metric_family()]

    def parse_metric_func_str(self) -> str:
        """Return the distance-function name matching the configured metric."""
        func_names = {"l2": "l2_distance", "ip": "negative_inner_product", "cosine": "cosine_distance"}
        return func_names[self._metric_family()]
65
+
66
+
67
class OceanBaseHNSWConfig(OceanBaseIndexConfig, DBCaseConfig):
    """HNSW-family (HNSW / HNSW_SQ / HNSW_BQ) index settings for OceanBase."""

    m: int
    efConstruction: int
    ef_search: int | None = None
    index: IndexType

    def index_param(self) -> dict:
        """Return the index-build parameters (library, metric, type, m / ef_construction)."""
        build_params = {"m": self.m, "ef_construction": self.efConstruction}
        return dict(
            lib=self.lib,
            metric_type=self.parse_metric(),
            index_type=self.index.value,
            params=build_params,
        )

    def search_param(self) -> dict:
        """Return search-time parameters: distance-function name and ef_search."""
        search_params = {"ef_search": self.ef_search}
        return dict(metric_type=self.parse_metric_func_str(), params=search_params)
83
+
84
+
85
class OceanBaseIVFConfig(OceanBaseIndexConfig, DBCaseConfig):
    """IVF-family (IVF_FLAT / IVF_PQ / IVF_SQ8) index settings for OceanBase.

    ``m`` is only included in the build parameters for IVF_PQ.
    """

    m: int
    sample_per_nlist: int
    nlist: int
    index: IndexType
    ivf_nprobes: int | None = None

    def index_param(self) -> dict:
        """Return the index-build parameters for this IVF variant."""
        params = {"sample_per_nlist": self.sample_per_nlist, "nlist": self.nlist}
        if self.index == IndexType.IVFPQ:
            # The field is ``m``; the previous ``self.M`` raised AttributeError.
            # "m" is kept as the first key, matching the original layout.
            params = {"m": self.m, **params}
        return {
            "lib": "OB",
            "metric_type": self.parse_metric(),
            "index_type": self.index.value,
            "params": params,
        }

    def search_param(self) -> dict:
        """Return search-time parameters: distance-function name and ivf_nprobes.

        ``metric_type`` is the distance-function name (e.g. ``l2_distance``),
        matching ``OceanBaseHNSWConfig.search_param``. The raw ``MetricType``
        enum was returned here previously.
        """
        return {"metric_type": self.parse_metric_func_str(), "params": {"ivf_nprobes": self.ivf_nprobes}}
116
+
117
+
118
# Index type -> case-config model used to describe it for OceanBase.
_oceanbase_case_config = {
    **dict.fromkeys((IndexType.HNSW_SQ, IndexType.HNSW, IndexType.HNSW_BQ), OceanBaseHNSWConfig),
    **dict.fromkeys((IndexType.IVFFlat, IndexType.IVFPQ, IndexType.IVFSQ8), OceanBaseIVFConfig),
}
+ }