vectordb-bench 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. vectordb_bench/__init__.py +19 -5
  2. vectordb_bench/backend/assembler.py +1 -1
  3. vectordb_bench/backend/cases.py +93 -27
  4. vectordb_bench/backend/clients/__init__.py +14 -0
  5. vectordb_bench/backend/clients/api.py +1 -1
  6. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +159 -0
  7. vectordb_bench/backend/clients/aws_opensearch/cli.py +44 -0
  8. vectordb_bench/backend/clients/aws_opensearch/config.py +58 -0
  9. vectordb_bench/backend/clients/aws_opensearch/run.py +125 -0
  10. vectordb_bench/backend/clients/milvus/cli.py +291 -0
  11. vectordb_bench/backend/clients/milvus/milvus.py +13 -6
  12. vectordb_bench/backend/clients/pgvector/cli.py +116 -0
  13. vectordb_bench/backend/clients/pgvector/config.py +1 -1
  14. vectordb_bench/backend/clients/pgvector/pgvector.py +7 -4
  15. vectordb_bench/backend/clients/redis/cli.py +74 -0
  16. vectordb_bench/backend/clients/test/cli.py +25 -0
  17. vectordb_bench/backend/clients/test/config.py +18 -0
  18. vectordb_bench/backend/clients/test/test.py +62 -0
  19. vectordb_bench/backend/clients/weaviate_cloud/cli.py +41 -0
  20. vectordb_bench/backend/clients/zilliz_cloud/cli.py +55 -0
  21. vectordb_bench/backend/dataset.py +27 -5
  22. vectordb_bench/backend/runner/mp_runner.py +14 -3
  23. vectordb_bench/backend/runner/serial_runner.py +7 -3
  24. vectordb_bench/backend/task_runner.py +76 -26
  25. vectordb_bench/cli/__init__.py +0 -0
  26. vectordb_bench/cli/cli.py +362 -0
  27. vectordb_bench/cli/vectordbbench.py +22 -0
  28. vectordb_bench/config-files/sample_config.yml +17 -0
  29. vectordb_bench/custom/custom_case.json +18 -0
  30. vectordb_bench/frontend/components/check_results/charts.py +6 -6
  31. vectordb_bench/frontend/components/check_results/data.py +23 -20
  32. vectordb_bench/frontend/components/check_results/expanderStyle.py +1 -1
  33. vectordb_bench/frontend/components/check_results/filters.py +20 -13
  34. vectordb_bench/frontend/components/check_results/headerIcon.py +1 -1
  35. vectordb_bench/frontend/components/check_results/priceTable.py +1 -1
  36. vectordb_bench/frontend/components/check_results/stPageConfig.py +1 -1
  37. vectordb_bench/frontend/components/concurrent/charts.py +79 -0
  38. vectordb_bench/frontend/components/custom/displayCustomCase.py +31 -0
  39. vectordb_bench/frontend/components/custom/displaypPrams.py +11 -0
  40. vectordb_bench/frontend/components/custom/getCustomConfig.py +40 -0
  41. vectordb_bench/frontend/components/custom/initStyle.py +15 -0
  42. vectordb_bench/frontend/components/run_test/autoRefresh.py +1 -1
  43. vectordb_bench/frontend/components/run_test/caseSelector.py +40 -28
  44. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -5
  45. vectordb_bench/frontend/components/run_test/dbSelector.py +8 -14
  46. vectordb_bench/frontend/components/run_test/generateTasks.py +3 -5
  47. vectordb_bench/frontend/components/run_test/initStyle.py +14 -0
  48. vectordb_bench/frontend/components/run_test/submitTask.py +13 -5
  49. vectordb_bench/frontend/components/tables/data.py +44 -0
  50. vectordb_bench/frontend/{const → config}/dbCaseConfigs.py +140 -32
  51. vectordb_bench/frontend/{const → config}/styles.py +2 -0
  52. vectordb_bench/frontend/pages/concurrent.py +65 -0
  53. vectordb_bench/frontend/pages/custom.py +64 -0
  54. vectordb_bench/frontend/pages/quries_per_dollar.py +5 -5
  55. vectordb_bench/frontend/pages/run_test.py +4 -0
  56. vectordb_bench/frontend/pages/tables.py +24 -0
  57. vectordb_bench/frontend/utils.py +17 -1
  58. vectordb_bench/frontend/vdb_benchmark.py +3 -3
  59. vectordb_bench/interface.py +21 -25
  60. vectordb_bench/metric.py +23 -1
  61. vectordb_bench/models.py +45 -1
  62. vectordb_bench/results/getLeaderboardData.py +1 -1
  63. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/METADATA +228 -14
  64. vectordb_bench-0.0.12.dist-info/RECORD +115 -0
  65. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/WHEEL +1 -1
  66. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/entry_points.txt +1 -0
  67. vectordb_bench-0.0.10.dist-info/RECORD +0 -88
  68. /vectordb_bench/frontend/{const → config}/dbPrices.py +0 -0
  69. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/LICENSE +0 -0
  70. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,125 @@
1
import os
import random
import time

from opensearch_dsl import Document, Keyword, Search, Text
from opensearchpy import OpenSearch
4
+
5
# Connection settings for the AWS OpenSearch domain. They are read from the
# environment so real credentials never need to be committed; the literal
# values below are only placeholder fallbacks for local experimentation.
_HOST = os.environ.get("AWS_OPENSEARCH_HOST", 'xxxxxx.us-west-2.es.amazonaws.com')
_PORT = int(os.environ.get("AWS_OPENSEARCH_PORT", "443"))
_AUTH = (
    os.environ.get("AWS_OPENSEARCH_USER", 'admin'),
    os.environ.get("AWS_OPENSEARCH_PASSWORD", 'xxxxxx'),
)

# Demo index / workload parameters.
_INDEX_NAME = 'my-dsl-index'
_BATCH = 100  # documents per bulk request
_ROWS = 100   # total documents inserted
_DIM = 128    # vector dimension
_TOPK = 10    # neighbours requested per query
14
+
15
+
16
def create_client():
    """Return an OpenSearch client for the configured domain.

    Connects to ``_HOST``:``_PORT`` over SSL with ``_AUTH`` credentials,
    verifies certificates (hostname assertion disabled, SSL warnings muted)
    and gzip-compresses request bodies.
    """
    connection_options = dict(
        hosts=[{'host': _HOST, 'port': _PORT}],
        http_compress=True,  # gzip request bodies
        http_auth=_AUTH,
        use_ssl=True,
        verify_certs=True,
        ssl_assert_hostname=False,
        ssl_show_warn=False,
    )
    return OpenSearch(**connection_options)
27
+
28
+
29
def create_index(client, index_name):
    """Create ``index_name`` as a k-NN index and print the server response.

    The index has one shard, a 5s refresh interval and a single
    ``embedding`` field of type ``knn_vector`` (``_DIM`` dimensions) indexed
    with nmslib HNSW in L2 space (ef_construction=128, m=24).
    """
    hnsw_method = {
        "engine": "nmslib",
        "name": "hnsw",
        "space_type": "l2",
        "parameters": {"ef_construction": 128, "m": 24},
    }
    index_body = {
        "settings": {
            "index": {
                "knn": True,
                "number_of_shards": 1,
                "refresh_interval": "5s",
            }
        },
        "mappings": {
            "properties": {
                "embedding": {
                    "type": "knn_vector",
                    "dimension": _DIM,
                    "method": hnsw_method,
                }
            }
        },
    }

    created = client.indices.create(index=index_name, body=index_body)
    print('\nCreating index:')
    print(created)
58
+
59
+
60
def delete_index(client, index_name):
    """Drop ``index_name`` and print whatever the server returns."""
    outcome = client.indices.delete(index=index_name)
    for line in ('\nDeleting index:', outcome):
        print(line)
64
+
65
+
66
def bulk_insert(client, index_name, rows=None, batch=None, dim=None):
    """Bulk-index random vectors into ``index_name`` and print indexing stats.

    Args:
        client: OpenSearch client.
        index_name: target index.
        rows: total documents to insert (defaults to ``_ROWS``).
        batch: documents per bulk request (defaults to ``_BATCH``).
        dim: dimension of each random vector (defaults to ``_DIM``).

    Document ids are ``0 .. rows - 1``. When ``rows`` is not a multiple of
    ``batch`` the final request simply carries the remaining documents.

    Raises:
        ValueError: if ``batch`` is not positive.
    """
    rows = _ROWS if rows is None else rows
    batch = _BATCH if batch is None else batch
    dim = _DIM if dim is None else dim
    if batch <= 0:
        raise ValueError(f"batch must be positive, got {batch}")

    for start in range(0, rows, batch):
        docs = []
        # min() keeps the last batch inside [0, rows) instead of indexing
        # past the end when rows % batch != 0.
        for doc_id in range(start, min(start + batch, rows)):
            docs.append({"index": {"_index": index_name, "_id": doc_id}})
            docs.append({"embedding": [random.random() for _ in range(dim)]})
        response = client.bulk(docs)
        print('\nAdding documents:', len(response['items']), response['errors'])

    response = client.indices.stats(index_name)
    # NOTE(review): index_total looks like a counter of indexing operations on
    # primaries rather than a live document count -- confirm before relying on it.
    print('\nTotal document count in index:', response['_all']['primaries']['indexing']['index_total'])
81
+
82
+
83
def search(client, index_name, topk=None, dim=None, max_attempts=None, poll_interval=1.0):
    """Run one random k-NN query against ``index_name``, polling until hits appear.

    An empty result is retried rather than treated as final, presumably
    because freshly inserted documents only become visible after a refresh
    (``create_index`` sets ``refresh_interval`` to 5s).

    Args:
        client: OpenSearch client.
        index_name: index to query.
        topk: neighbours requested (defaults to ``_TOPK``).
        dim: dimension of the random query vector (defaults to ``_DIM``).
        max_attempts: stop after this many queries; ``None`` (default)
            keeps polling indefinitely.
        poll_interval: seconds to sleep between empty responses.

    Returns:
        The list of hits from the first non-empty response.

    Raises:
        TimeoutError: if ``max_attempts`` queries all returned no hits.
    """
    topk = _TOPK if topk is None else topk
    dim = _DIM if dim is None else dim
    search_body = {
        "size": topk,
        "query": {
            "knn": {
                "embedding": {
                    "vector": [random.random() for _ in range(dim)],
                    "k": topk,
                }
            }
        },
    }

    attempt = 0
    while True:
        attempt += 1
        response = client.search(index=index_name, body=search_body)
        print(f'\nSearch took: {response["took"]}')
        print(f'\nSearch shards: {response["_shards"]}')
        print(f'\nSearch hits total: {response["hits"]["total"]}')
        hits = response["hits"]["hits"]
        if hits:
            print('\nSearch results:')
            for hit in hits:
                print(hit["_id"], hit["_score"])
            return hits
        if max_attempts is not None and attempt >= max_attempts:
            raise TimeoutError(
                f"no search hits from {index_name!r} after {attempt} attempts"
            )
        print(f'\nSearch not ready, sleep {poll_interval:g}s')
        time.sleep(poll_interval)
110
+
111
+
112
def main():
    """Create the demo index, load it, query it, and always drop it afterwards."""
    client = create_client()
    try:
        create_index(client, _INDEX_NAME)
        bulk_insert(client, _INDEX_NAME)
        search(client, _INDEX_NAME)
    except Exception as e:
        print(e)
    finally:
        # Cleanup runs exactly once on both paths. A failure here (e.g. the
        # index was never created) is reported, not allowed to crash the
        # script or mask the original error.
        try:
            delete_index(client, _INDEX_NAME)
        except Exception as e:
            print(e)
        client.close()


if __name__ == '__main__':
    main()
@@ -0,0 +1,291 @@
1
+ from typing import Annotated, TypedDict, Unpack
2
+
3
+ import click
4
+ from pydantic import SecretStr
5
+
6
+ from vectordb_bench.cli.cli import (
7
+ CommonTypedDict,
8
+ HNSWFlavor3,
9
+ IVFFlatTypedDictN,
10
+ cli,
11
+ click_parameter_decorators_from_typed_dict,
12
+ run,
13
+
14
+ )
15
+ from vectordb_bench.backend.clients import DB
16
+
17
# Every command in this module targets the same backend.
DBTYPE = DB.Milvus


class MilvusTypedDict(TypedDict):
    # Connection option shared by all Milvus commands.
    uri: Annotated[
        str,
        click.option("--uri", type=str, help="uri connection string", required=True),
    ]


class MilvusAutoIndexTypedDict(CommonTypedDict, MilvusTypedDict):
    # Common + connection options only; used by the AUTOINDEX and FLAT
    # commands, which take no index-specific parameters.
    pass
28
+
29
+
30
@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusAutoIndexTypedDict)
def MilvusAutoIndex(**parameters: Unpack[MilvusAutoIndexTypedDict]):
    # Benchmark Milvus with an AUTOINDEX index (no tunable index parameters).
    from .config import AutoIndexConfig, MilvusConfig

    connection_config = MilvusConfig(
        db_label=parameters["db_label"],
        uri=SecretStr(parameters["uri"]),
    )
    index_config = AutoIndexConfig()
    run(
        db=DBTYPE,
        db_config=connection_config,
        db_case_config=index_config,
        **parameters,
    )
44
+
45
+
46
@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusAutoIndexTypedDict)
def MilvusFlat(**parameters: Unpack[MilvusAutoIndexTypedDict]):
    # Benchmark Milvus with a FLAT index (no tunable index parameters).
    from .config import FLATConfig, MilvusConfig

    connection_config = MilvusConfig(
        db_label=parameters["db_label"],
        uri=SecretStr(parameters["uri"]),
    )
    index_config = FLATConfig()
    run(
        db=DBTYPE,
        db_config=connection_config,
        db_case_config=index_config,
        **parameters,
    )
60
+
61
+
62
class MilvusHNSWTypedDict(CommonTypedDict, MilvusTypedDict, HNSWFlavor3):
    # Common + connection options plus the HNSW options (m, ef_construction,
    # ef_search) that MilvusHNSW reads; presumably contributed by HNSWFlavor3.
    pass


@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusHNSWTypedDict)
def MilvusHNSW(**parameters: Unpack[MilvusHNSWTypedDict]):
    # Benchmark Milvus with an HNSW index.
    from .config import HNSWConfig, MilvusConfig

    connection_config = MilvusConfig(
        db_label=parameters["db_label"],
        uri=SecretStr(parameters["uri"]),
    )
    index_config = HNSWConfig(
        M=parameters["m"],
        efConstruction=parameters["ef_construction"],
        ef=parameters["ef_search"],
    )
    run(
        db=DBTYPE,
        db_config=connection_config,
        db_case_config=index_config,
        **parameters,
    )
84
+
85
+
86
class MilvusIVFFlatTypedDict(CommonTypedDict, MilvusTypedDict, IVFFlatTypedDictN):
    # Common + connection options plus the IVF options (nlist, nprobe) read by
    # the IVF commands; presumably contributed by IVFFlatTypedDictN.
    pass


@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusIVFFlatTypedDict)
def MilvusIVFFlat(**parameters: Unpack[MilvusIVFFlatTypedDict]):
    # Benchmark Milvus with an IVF_FLAT index.
    from .config import IVFFlatConfig, MilvusConfig

    connection_config = MilvusConfig(
        db_label=parameters["db_label"],
        uri=SecretStr(parameters["uri"]),
    )
    index_config = IVFFlatConfig(
        nlist=parameters["nlist"],
        nprobe=parameters["nprobe"],
    )
    run(
        db=DBTYPE,
        db_config=connection_config,
        db_case_config=index_config,
        **parameters,
    )
107
+
108
+
109
@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusIVFFlatTypedDict)
def MilvusIVFSQ8(**parameters: Unpack[MilvusIVFFlatTypedDict]):
    # Benchmark Milvus with an IVF_SQ8 index; takes the same nlist/nprobe
    # options as the IVF_FLAT command.
    from .config import IVFSQ8Config, MilvusConfig

    connection_config = MilvusConfig(
        db_label=parameters["db_label"],
        uri=SecretStr(parameters["uri"]),
    )
    index_config = IVFSQ8Config(
        nlist=parameters["nlist"],
        nprobe=parameters["nprobe"],
    )
    run(
        db=DBTYPE,
        db_config=connection_config,
        db_case_config=index_config,
        **parameters,
    )
126
+
127
+
128
class MilvusDISKANNTypedDict(CommonTypedDict, MilvusTypedDict):
    # Common + connection options plus the DiskANN search_list parameter.
    search_list: Annotated[
        int,  # click parses this option with type=int
        click.option("--search-list", type=int, required=True),
    ]


@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusDISKANNTypedDict)
def MilvusDISKANN(**parameters: Unpack[MilvusDISKANNTypedDict]):
    # Benchmark Milvus with a DISKANN index.
    from .config import DISKANNConfig, MilvusConfig

    run(
        db=DBTYPE,
        db_config=MilvusConfig(
            db_label=parameters["db_label"],
            uri=SecretStr(parameters["uri"]),
        ),
        db_case_config=DISKANNConfig(
            search_list=parameters["search_list"],
        ),
        **parameters,
    )
152
+
153
+
154
class MilvusGPUIVFTypedDict(MilvusIVFFlatTypedDict):
    # IVF options (MilvusIVFFlatTypedDict already includes the common and
    # connection options, so listing them again as bases was redundant) plus
    # the settings shared by the Milvus GPU index commands.
    cache_dataset_on_device: Annotated[
        str,
        click.option("--cache-dataset-on-device", type=str, required=True),
    ]
    refine_ratio: Annotated[
        float,  # click parses this option with type=float
        click.option("--refine-ratio", type=float, required=True),
    ]


@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusGPUIVFTypedDict)
def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]):
    # Benchmark Milvus with a GPU_IVF_FLAT index.
    from .config import GPUIVFFlatConfig, MilvusConfig

    run(
        db=DBTYPE,
        db_config=MilvusConfig(
            db_label=parameters["db_label"],
            uri=SecretStr(parameters["uri"]),
        ),
        db_case_config=GPUIVFFlatConfig(
            nlist=parameters["nlist"],
            nprobe=parameters["nprobe"],
            cache_dataset_on_device=parameters["cache_dataset_on_device"],
            # --refine-ratio is required, so index directly like the other
            # GPU commands do.
            refine_ratio=parameters["refine_ratio"],
        ),
        **parameters,
    )
186
+
187
+
188
class MilvusGPUIVFPQTypedDict(MilvusGPUIVFTypedDict):
    # GPU IVF options (which already include the common, connection and IVF
    # options) plus the product-quantization parameters.
    m: Annotated[
        int,  # click parses this option with type=int
        click.option(
            "--m",
            type=int,
            help="IVF_PQ m: number of product-quantization sub-vectors",
            required=True,
        ),
    ]
    nbits: Annotated[
        int,  # click parses this option with type=int
        click.option("--nbits", type=int, required=True),
    ]


@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusGPUIVFPQTypedDict)
def MilvusGPUIVFPQ(**parameters: Unpack[MilvusGPUIVFPQTypedDict]):
    # Benchmark Milvus with a GPU_IVF_PQ index.
    from .config import GPUIVFPQConfig, MilvusConfig

    run(
        db=DBTYPE,
        db_config=MilvusConfig(
            db_label=parameters["db_label"],
            uri=SecretStr(parameters["uri"]),
        ),
        db_case_config=GPUIVFPQConfig(
            nlist=parameters["nlist"],
            nprobe=parameters["nprobe"],
            m=parameters["m"],
            nbits=parameters["nbits"],
            cache_dataset_on_device=parameters["cache_dataset_on_device"],
            refine_ratio=parameters["refine_ratio"],
        ),
        **parameters,
    )
222
+
223
+
224
class MilvusGPUCAGRATypedDict(MilvusGPUIVFTypedDict):
    # GPU options (which already include the common, connection and IVF
    # options) plus the CAGRA graph build/search parameters.
    # NOTE(review): inheriting from MilvusGPUIVFTypedDict makes --nlist and
    # --nprobe required although MilvusGPUCAGRA never passes them to
    # GPUCAGRAConfig; kept so existing command lines keep working.
    intermediate_graph_degree: Annotated[
        int,
        click.option("--intermediate-graph-degree", type=int, required=True),
    ]
    graph_degree: Annotated[
        int,
        click.option("--graph-degree", type=int, required=True),
    ]
    build_algo: Annotated[
        str,
        click.option(
            # Hyphenated spelling matches every other option; the original
            # underscore spelling is still accepted for existing scripts.
            "--build-algo",
            "--build_algo",
            "build_algo",
            type=str,
            required=True,
        ),
    ]
    team_size: Annotated[
        int,
        click.option("--team-size", type=int, required=True),
    ]
    search_width: Annotated[
        int,
        click.option("--search-width", type=int, required=True),
    ]
    itopk_size: Annotated[
        int,
        click.option("--itopk-size", type=int, required=True),
    ]
    min_iterations: Annotated[
        int,
        click.option("--min-iterations", type=int, required=True),
    ]
    max_iterations: Annotated[
        int,
        click.option("--max-iterations", type=int, required=True),
    ]


@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusGPUCAGRATypedDict)
def MilvusGPUCAGRA(**parameters: Unpack[MilvusGPUCAGRATypedDict]):
    # Benchmark Milvus with a GPU_CAGRA index.
    from .config import GPUCAGRAConfig, MilvusConfig

    run(
        db=DBTYPE,
        db_config=MilvusConfig(
            db_label=parameters["db_label"],
            uri=SecretStr(parameters["uri"]),
        ),
        db_case_config=GPUCAGRAConfig(
            intermediate_graph_degree=parameters["intermediate_graph_degree"],
            graph_degree=parameters["graph_degree"],
            itopk_size=parameters["itopk_size"],
            team_size=parameters["team_size"],
            search_width=parameters["search_width"],
            min_iterations=parameters["min_iterations"],
            max_iterations=parameters["max_iterations"],
            build_algo=parameters["build_algo"],
            cache_dataset_on_device=parameters["cache_dataset_on_device"],
            refine_ratio=parameters["refine_ratio"],
        ),
        **parameters,
    )
@@ -8,7 +8,7 @@ from typing import Iterable
8
8
  from pymilvus import Collection, utility
9
9
  from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusException
10
10
 
11
- from ..api import VectorDB
11
+ from ..api import VectorDB, IndexType
12
12
  from .config import MilvusIndexConfig
13
13
 
14
14
 
@@ -122,10 +122,18 @@ class Milvus(VectorDB):
122
122
  if self.case_config.is_gpu_index:
123
123
  log.debug("skip compaction for gpu index type.")
124
124
  else :
125
- self.col.compact()
126
- self.col.wait_for_compaction_completed()
125
+ try:
126
+ self.col.compact()
127
+ self.col.wait_for_compaction_completed()
128
+ except Exception as e:
129
+ log.warning(f"{self.name} compact error: {e}")
130
+ if hasattr(e, 'code'):
131
+ if e.code().name == 'PERMISSION_DENIED':
132
+ log.warning(f"Skip compact due to permission denied.")
133
+ pass
134
+ else:
135
+ raise e
127
136
  wait_index()
128
-
129
137
  except Exception as e:
130
138
  log.warning(f"{self.name} optimize error: {e}")
131
139
  raise e from None
@@ -143,7 +151,6 @@ class Milvus(VectorDB):
143
151
  self.case_config.index_param(),
144
152
  index_name=self._index_name,
145
153
  )
146
-
147
154
  coll.load()
148
155
  log.info(f"{self.name} load")
149
156
  except Exception as e:
@@ -160,7 +167,7 @@ class Milvus(VectorDB):
160
167
  if self.case_config.is_gpu_index:
161
168
  log.info(f"current gpu_index only supports IP / L2, cosine dataset need normalize.")
162
169
  return True
163
-
170
+
164
171
  return False
165
172
 
166
173
  def insert_embeddings(
@@ -0,0 +1,116 @@
1
+ from typing import Annotated, Optional, TypedDict, Unpack
2
+
3
+ import click
4
+ import os
5
+ from pydantic import SecretStr
6
+
7
+ from ....cli.cli import (
8
+ CommonTypedDict,
9
+ HNSWFlavor1,
10
+ IVFFlatTypedDict,
11
+ cli,
12
+ click_parameter_decorators_from_typed_dict,
13
+ run,
14
+ )
15
+ from vectordb_bench.backend.clients import DB
16
+
17
+
18
class PgVectorTypedDict(CommonTypedDict):
    # Connection and index-build tuning options shared by the pgvector commands.
    user_name: Annotated[
        str, click.option("--user-name", type=str, help="Db username", required=True)
    ]
    password: Annotated[
        str,
        click.option(
            "--password",
            type=str,
            help="Postgres database password",
            # A callable default is evaluated when the command runs, so the
            # environment variable is read at invocation time.
            default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
            show_default="$POSTGRES_PASSWORD",
        ),
    ]

    host: Annotated[
        str, click.option("--host", type=str, help="Db host", required=True)
    ]
    db_name: Annotated[
        str, click.option("--db-name", type=str, help="Db name", required=True)
    ]
    maintenance_work_mem: Annotated[
        Optional[str],
        click.option(
            "--maintenance-work-mem",
            type=str,
            # Adjacent literals are concatenated, so each sentence needs its
            # own trailing/leading space (previously rendered as "KB.This").
            # NOTE(review): the last sentence reads like it describes
            # --max-parallel-workers rather than this option -- confirm.
            help="Sets the maximum memory to be used for maintenance operations (index creation). "
            "Can be entered as string with unit like '64GB' or as an integer number of KB. "
            "This will set the parameters: max_parallel_maintenance_workers,"
            " max_parallel_workers & table(parallel_workers)",
            required=False,
        ),
    ]
    max_parallel_workers: Annotated[
        Optional[int],
        click.option(
            "--max-parallel-workers",
            type=int,
            help="Sets the maximum number of parallel processes per maintenance operation (index creation)",
            required=False,
        ),
    ]
59
+
60
+
61
class PgVectorIVFFlatTypedDict(PgVectorTypedDict, IVFFlatTypedDict):
    # pgvector connection/tuning options plus the IVFFlat options (lists,
    # probes) read by PgVectorIVFFlat.
    ...


@cli.command()
@click_parameter_decorators_from_typed_dict(PgVectorIVFFlatTypedDict)
def PgVectorIVFFlat(
    **parameters: Unpack[PgVectorIVFFlatTypedDict],
):
    # Benchmark pgvector with an IVFFlat index.
    from .config import PgVectorConfig, PgVectorIVFFlatConfig

    run(
        db=DB.PgVector,
        db_config=PgVectorConfig(
            db_label=parameters["db_label"],
            user_name=SecretStr(parameters["user_name"]),
            password=SecretStr(parameters["password"]),
            host=parameters["host"],
            db_name=parameters["db_name"],
        ),
        db_case_config=PgVectorIVFFlatConfig(
            metric_type=None,
            lists=parameters["lists"],
            probes=parameters["probes"],
            # --maintenance-work-mem / --max-parallel-workers are declared on
            # PgVectorTypedDict and forwarded by PgVectorHNSW, but were
            # silently dropped here.
            # NOTE(review): assumes PgVectorIVFFlatConfig defines these fields
            # like PgVectorHNSWConfig does -- confirm in config.py.
            maintenance_work_mem=parameters["maintenance_work_mem"],
            max_parallel_workers=parameters["max_parallel_workers"],
        ),
        **parameters,
    )
86
+
87
+
88
class PgVectorHNSWTypedDict(PgVectorTypedDict, HNSWFlavor1):
    # pgvector connection/tuning options plus the HNSW options (m,
    # ef_construction, ef_search) read by PgVectorHNSW.
    pass


@cli.command()
@click_parameter_decorators_from_typed_dict(PgVectorHNSWTypedDict)
def PgVectorHNSW(
    **parameters: Unpack[PgVectorHNSWTypedDict],
):
    # Benchmark pgvector with an HNSW index.
    from .config import PgVectorConfig, PgVectorHNSWConfig

    connection_config = PgVectorConfig(
        db_label=parameters["db_label"],
        user_name=SecretStr(parameters["user_name"]),
        password=SecretStr(parameters["password"]),
        host=parameters["host"],
        db_name=parameters["db_name"],
    )
    index_config = PgVectorHNSWConfig(
        m=parameters["m"],
        ef_construction=parameters["ef_construction"],
        ef_search=parameters["ef_search"],
        maintenance_work_mem=parameters["maintenance_work_mem"],
        max_parallel_workers=parameters["max_parallel_workers"],
    )
    run(
        db=DB.PgVector,
        db_config=connection_config,
        db_case_config=index_config,
        **parameters,
    )
@@ -109,7 +109,7 @@ class PgVectorIndexConfig(BaseModel, DBCaseConfig):
109
109
  def _optionally_build_set_options(
110
110
  set_mapping: Mapping[str, Any]
111
111
  ) -> Sequence[dict[str, Any]]:
112
- """Walk through options, creating 'SET 'key1 = "value1";' commands"""
112
+ """Walk through options, creating 'SET 'key1 = "value1";' list"""
113
113
  session_options = []
114
114
  for setting_name, value in set_mapping.items():
115
115
  if value:
@@ -58,14 +58,13 @@ class PgVector(VectorDB):
58
58
  self.case_config.create_index_after_load,
59
59
  )
60
60
  ):
61
- err = f"{self.name} config must create an index using create_index_before_load and/or create_index_after_load"
61
+ err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
62
62
  log.error(err)
63
63
  raise RuntimeError(
64
64
  f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
65
65
  )
66
66
 
67
67
  if drop_old:
68
- # self.pg_table.drop(pg_engine, checkfirst=True)
69
68
  self._drop_index()
70
69
  self._drop_table()
71
70
  self._create_table(dim)
@@ -257,7 +256,10 @@ class PgVector(VectorDB):
257
256
  with_clause = sql.Composed(())
258
257
 
259
258
  index_create_sql = sql.SQL(
260
- "CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} USING {index_type} (embedding {embedding_metric})"
259
+ """
260
+ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
261
+ USING {index_type} (embedding {embedding_metric})
262
+ """
261
263
  ).format(
262
264
  index_name=sql.Identifier(self._index_name),
263
265
  table_name=sql.Identifier(self.table_name),
@@ -339,9 +341,10 @@ class PgVector(VectorDB):
339
341
  assert self.conn is not None, "Connection is not initialized"
340
342
  assert self.cursor is not None, "Cursor is not initialized"
341
343
 
344
+ q = np.asarray(query)
342
345
  # TODO add filters support
343
346
  result = self.cursor.execute(
344
- self._unfiltered_search, (query, k), prepare=True, binary=True
347
+ self._unfiltered_search, (q, k), prepare=True, binary=True
345
348
  )
346
349
 
347
350
  return [int(i[0]) for i in result.fetchall()]
@@ -0,0 +1,74 @@
1
+ from typing import Annotated, TypedDict, Unpack
2
+
3
+ import click
4
+ from pydantic import SecretStr
5
+
6
+ from ....cli.cli import (
7
+ CommonTypedDict,
8
+ HNSWFlavor2,
9
+ cli,
10
+ click_parameter_decorators_from_typed_dict,
11
+ run,
12
+ )
13
+ from .. import DB
14
+
15
+
16
class RedisTypedDict(TypedDict):
    # Connection options for the Redis command.
    host: Annotated[
        str, click.option("--host", type=str, help="Db host", required=True)
    ]
    # Optional with no default: click supplies None when omitted, and the
    # Redis command only wraps the password in SecretStr when it is truthy.
    password: Annotated[
        str | None, click.option("--password", type=str, help="Db password")
    ]
    port: Annotated[int, click.option("--port", type=int, default=6379, help="Db Port")]
    ssl: Annotated[
        bool,
        click.option(
            "--ssl/--no-ssl",
            is_flag=True,
            show_default=True,
            default=True,
            help="Enable or disable SSL for Redis",
        ),
    ]
    # Also optional: None when --ssl-ca-certs is not given.
    ssl_ca_certs: Annotated[
        str | None,
        click.option(
            "--ssl-ca-certs",
            show_default=True,
            help="Path to certificate authority file to use for SSL",
        ),
    ]
    cmd: Annotated[
        bool,
        click.option(
            "--cmd",
            is_flag=True,
            show_default=True,
            default=False,
            help="Cluster Mode Disabled (CMD) for Redis doesn't use Cluster conn",
        ),
    ]
50
+
51
+
52
class RedisHNSWTypedDict(CommonTypedDict, RedisTypedDict, HNSWFlavor2):
    # Common benchmark options, Redis connection options and whatever HNSW
    # options HNSWFlavor2 contributes.
    pass


@cli.command()
@click_parameter_decorators_from_typed_dict(RedisHNSWTypedDict)
def Redis(**parameters: Unpack[RedisHNSWTypedDict]):
    # Benchmark Redis.
    # NOTE(review): no db_case_config is passed and none of the HNSWFlavor2
    # options are read here -- confirm that `run` accepts this and whether a
    # Redis HNSW case config should be built from those options.
    from .config import RedisConfig

    raw_password = parameters["password"]
    redis_config = RedisConfig(
        db_label=parameters["db_label"],
        password=SecretStr(raw_password) if raw_password else None,
        host=SecretStr(parameters["host"]),
        port=parameters["port"],
        ssl=parameters["ssl"],
        ssl_ca_certs=parameters["ssl_ca_certs"],
        cmd=parameters["cmd"],
    )
    run(
        db=DB.Redis,
        db_config=redis_config,
        **parameters,
    )