vectordb-bench 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. vectordb_bench/__init__.py +49 -24
  2. vectordb_bench/__main__.py +4 -3
  3. vectordb_bench/backend/assembler.py +12 -13
  4. vectordb_bench/backend/cases.py +55 -45
  5. vectordb_bench/backend/clients/__init__.py +85 -14
  6. vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +1 -2
  7. vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +3 -4
  8. vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +112 -77
  9. vectordb_bench/backend/clients/aliyun_opensearch/config.py +6 -7
  10. vectordb_bench/backend/clients/alloydb/alloydb.py +59 -84
  11. vectordb_bench/backend/clients/alloydb/cli.py +51 -34
  12. vectordb_bench/backend/clients/alloydb/config.py +30 -30
  13. vectordb_bench/backend/clients/api.py +13 -24
  14. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +50 -54
  15. vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
  16. vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
  17. vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
  18. vectordb_bench/backend/clients/chroma/chroma.py +39 -40
  19. vectordb_bench/backend/clients/chroma/config.py +4 -2
  20. vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
  21. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +24 -26
  22. vectordb_bench/backend/clients/memorydb/cli.py +8 -8
  23. vectordb_bench/backend/clients/memorydb/config.py +2 -2
  24. vectordb_bench/backend/clients/memorydb/memorydb.py +67 -58
  25. vectordb_bench/backend/clients/milvus/cli.py +41 -83
  26. vectordb_bench/backend/clients/milvus/config.py +18 -8
  27. vectordb_bench/backend/clients/milvus/milvus.py +19 -39
  28. vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
  29. vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
  30. vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +56 -77
  31. vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
  32. vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
  33. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +34 -43
  34. vectordb_bench/backend/clients/pgvector/cli.py +40 -31
  35. vectordb_bench/backend/clients/pgvector/config.py +63 -73
  36. vectordb_bench/backend/clients/pgvector/pgvector.py +98 -104
  37. vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
  38. vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
  39. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +39 -49
  40. vectordb_bench/backend/clients/pinecone/config.py +1 -0
  41. vectordb_bench/backend/clients/pinecone/pinecone.py +15 -25
  42. vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
  43. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +41 -35
  44. vectordb_bench/backend/clients/redis/cli.py +6 -12
  45. vectordb_bench/backend/clients/redis/config.py +7 -5
  46. vectordb_bench/backend/clients/redis/redis.py +95 -62
  47. vectordb_bench/backend/clients/test/cli.py +2 -3
  48. vectordb_bench/backend/clients/test/config.py +2 -2
  49. vectordb_bench/backend/clients/test/test.py +5 -9
  50. vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
  51. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
  52. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +37 -26
  53. vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
  54. vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
  55. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
  56. vectordb_bench/backend/data_source.py +18 -14
  57. vectordb_bench/backend/dataset.py +47 -27
  58. vectordb_bench/backend/result_collector.py +2 -3
  59. vectordb_bench/backend/runner/__init__.py +4 -6
  60. vectordb_bench/backend/runner/mp_runner.py +56 -23
  61. vectordb_bench/backend/runner/rate_runner.py +30 -19
  62. vectordb_bench/backend/runner/read_write_runner.py +46 -22
  63. vectordb_bench/backend/runner/serial_runner.py +81 -46
  64. vectordb_bench/backend/runner/util.py +4 -3
  65. vectordb_bench/backend/task_runner.py +92 -92
  66. vectordb_bench/backend/utils.py +17 -10
  67. vectordb_bench/base.py +0 -1
  68. vectordb_bench/cli/cli.py +65 -60
  69. vectordb_bench/cli/vectordbbench.py +6 -7
  70. vectordb_bench/frontend/components/check_results/charts.py +8 -19
  71. vectordb_bench/frontend/components/check_results/data.py +4 -16
  72. vectordb_bench/frontend/components/check_results/filters.py +8 -16
  73. vectordb_bench/frontend/components/check_results/nav.py +4 -4
  74. vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
  75. vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
  76. vectordb_bench/frontend/components/concurrent/charts.py +12 -12
  77. vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
  78. vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
  79. vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
  80. vectordb_bench/frontend/components/custom/initStyle.py +1 -1
  81. vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
  82. vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
  83. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
  84. vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
  85. vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
  86. vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
  87. vectordb_bench/frontend/components/tables/data.py +3 -6
  88. vectordb_bench/frontend/config/dbCaseConfigs.py +51 -84
  89. vectordb_bench/frontend/pages/concurrent.py +3 -5
  90. vectordb_bench/frontend/pages/custom.py +30 -9
  91. vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
  92. vectordb_bench/frontend/pages/run_test.py +3 -7
  93. vectordb_bench/frontend/utils.py +1 -1
  94. vectordb_bench/frontend/vdb_benchmark.py +4 -6
  95. vectordb_bench/interface.py +45 -24
  96. vectordb_bench/log_util.py +59 -64
  97. vectordb_bench/metric.py +10 -11
  98. vectordb_bench/models.py +26 -43
  99. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/METADATA +22 -15
  100. vectordb_bench-0.0.21.dist-info/RECORD +135 -0
  101. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/WHEEL +1 -1
  102. vectordb_bench-0.0.19.dist-info/RECORD +0 -135
  103. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/LICENSE +0 -0
  104. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/entry_points.txt +0 -0
  105. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/top_level.txt +0 -0
--- a/vectordb_bench/backend/clients/chroma/chroma.py
+++ b/vectordb_bench/backend/clients/chroma/chroma.py
@@ -1,55 +1,55 @@
-import chromadb
-import logging
+import logging
 from contextlib import contextmanager
 from typing import Any
-from ..api import VectorDB, DBCaseConfig
+
+import chromadb
+
+from ..api import DBCaseConfig, VectorDB
 
 log = logging.getLogger(__name__)
+
+
 class ChromaClient(VectorDB):
-    """Chroma client for VectorDB.
+    """Chroma client for VectorDB.
     To set up Chroma in docker, see https://docs.trychroma.com/usage-guide
     or the instructions in tests/test_chroma.py
 
     To change to running in process, modify the HttpClient() in __init__() and init().
-    """
+    """
 
     def __init__(
-        self,
-        dim: int,
-        db_config: dict,
-        db_case_config: DBCaseConfig,
-        drop_old: bool = False,
-
-        **kwargs
-    ):
-
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: DBCaseConfig,
+        drop_old: bool = False,
+        **kwargs,
+    ):
         self.db_config = db_config
         self.case_config = db_case_config
-        self.collection_name = 'example2'
+        self.collection_name = "example2"
 
-        client = chromadb.HttpClient(host=self.db_config["host"],
-                                     port=self.db_config["port"])
+        client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"])
         assert client.heartbeat() is not None
         if drop_old:
             try:
-                client.reset() # Reset the database
-            except:
+                client.reset()  # Reset the database
+            except Exception:
                 drop_old = False
                 log.info(f"Chroma client drop_old collection: {self.collection_name}")
 
     @contextmanager
     def init(self) -> None:
-        """ create and destory connections to database.
+        """create and destory connections to database.
 
         Examples:
             >>> with self.init():
             >>>     self.insert_embeddings()
         """
-        #create connection
-        self.client = chromadb.HttpClient(host=self.db_config["host"],
-                                          port=self.db_config["port"])
-
-        self.collection = self.client.get_or_create_collection('example2')
+        # create connection
+        self.client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"])
+
+        self.collection = self.client.get_or_create_collection("example2")
         yield
         self.client = None
         self.collection = None
@@ -57,10 +57,7 @@ class ChromaClient(VectorDB):
     def ready_to_search(self) -> bool:
         pass
 
-    def ready_to_load(self) -> bool:
-        pass
-
-    def optimize(self) -> None:
+    def optimize(self, data_size: int | None = None):
         pass
 
     def insert_embeddings(
@@ -79,12 +76,12 @@ class ChromaClient(VectorDB):
         Returns:
             (int, Exception): number of embeddings inserted and exception if any
         """
-        ids=[str(i) for i in metadata]
-        metadata = [{"id": int(i)} for i in metadata]
+        ids = [str(i) for i in metadata]
+        metadata = [{"id": int(i)} for i in metadata]
         if len(embeddings) > 0:
             self.collection.add(embeddings=embeddings, ids=ids, metadatas=metadata)
         return len(embeddings), None
-
+
     def search_embedding(
         self,
         query: list[float],
@@ -100,17 +97,19 @@ class ChromaClient(VectorDB):
             kwargs: other arguments
 
         Returns:
-            Dict {ids: list[list[int]],
-                  embedding: list[list[float]]
+            Dict {ids: list[list[int]],
+                  embedding: list[list[float]]
                   distance: list[list[float]]}
         """
         if filters:
             # assumes benchmark test filters of format: {'metadata': '>=10000', 'id': 10000}
             id_value = filters.get("id")
-            results = self.collection.query(query_embeddings=query, n_results=k,
-                                            where={"id": {"$gt": id_value}})
-            #return list of id's in results
-            return [int(i) for i in results.get('ids')[0]]
+            results = self.collection.query(
+                query_embeddings=query,
+                n_results=k,
+                where={"id": {"$gt": id_value}},
+            )
+            # return list of id's in results
+            return [int(i) for i in results.get("ids")[0]]
         results = self.collection.query(query_embeddings=query, n_results=k)
-        return [int(i) for i in results.get('ids')[0]]
-
+        return [int(i) for i in results.get("ids")[0]]
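The reflowed filter branch above is ordinary chromadb `where`-clause querying, just split across lines. For reference, a minimal self-contained sketch of the same call pattern; the host, port, collection name, and dimension are illustrative and assume a Chroma server is already running:

    import chromadb

    client = chromadb.HttpClient(host="localhost", port=8000)  # assumed endpoint
    collection = client.get_or_create_collection("example2")

    k = 10
    query = [0.1] * 768  # query embedding; dimension depends on the dataset

    # ids strictly greater than 10000, matching the benchmark's {'id': 10000} filter
    results = collection.query(
        query_embeddings=[query],
        n_results=k,
        where={"id": {"$gt": 10000}},
    )
    ids = [int(i) for i in results.get("ids")[0]]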
--- a/vectordb_bench/backend/clients/chroma/config.py
+++ b/vectordb_bench/backend/clients/chroma/config.py
@@ -1,14 +1,16 @@
 from pydantic import SecretStr
+
 from ..api import DBConfig
 
+
 class ChromaConfig(DBConfig):
     password: SecretStr
     host: SecretStr
-    port: int
+    port: int
 
     def to_dict(self) -> dict:
         return {
             "host": self.host.get_secret_value(),
             "port": self.port,
             "password": self.password.get_secret_value(),
-        }
+        }
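The `SecretStr` fields above only reveal their values through `get_secret_value()`, so `to_dict()` is the one place credentials are unwrapped. A small standalone sketch of that behavior (mirroring, not importing, the class above):

    from pydantic import BaseModel, SecretStr

    class ChromaConfigSketch(BaseModel):  # stand-in for ChromaConfig
        password: SecretStr
        host: SecretStr
        port: int

    cfg = ChromaConfigSketch(password="pw", host="localhost", port=8000)
    print(cfg.host)                      # prints '**********', safe for logs
    print(cfg.host.get_secret_value())   # prints 'localhost' only when unwrapped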
--- a/vectordb_bench/backend/clients/elastic_cloud/config.py
+++ b/vectordb_bench/backend/clients/elastic_cloud/config.py
@@ -1,7 +1,8 @@
 from enum import Enum
-from pydantic import SecretStr, BaseModel
 
-from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
 
 
 class ElasticCloudConfig(DBConfig, BaseModel):
@@ -32,12 +33,12 @@ class ElasticCloudIndexConfig(BaseModel, DBCaseConfig):
     def parse_metric(self) -> str:
         if self.metric_type == MetricType.L2:
             return "l2_norm"
-        elif self.metric_type == MetricType.IP:
+        if self.metric_type == MetricType.IP:
             return "dot_product"
         return "cosine"
 
     def index_param(self) -> dict:
-        params = {
+        return {
             "type": "dense_vector",
             "index": True,
             "element_type": self.element_type.value,
@@ -48,7 +49,6 @@ class ElasticCloudIndexConfig(BaseModel, DBCaseConfig):
                 "ef_construction": self.efConstruction,
             },
         }
-        return params
 
     def search_param(self) -> dict:
         return {
--- a/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py
+++ b/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py
@@ -1,17 +1,22 @@
 import logging
 import time
+from collections.abc import Iterable
 from contextlib import contextmanager
-from typing import Iterable
-from ..api import VectorDB
-from .config import ElasticCloudIndexConfig
+
 from elasticsearch.helpers import bulk
 
+from ..api import VectorDB
+from .config import ElasticCloudIndexConfig
 
 for logger in ("elasticsearch", "elastic_transport"):
     logging.getLogger(logger).setLevel(logging.WARNING)
 
 log = logging.getLogger(__name__)
 
+
+SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
+
+
 class ElasticCloud(VectorDB):
     def __init__(
         self,
@@ -46,14 +51,14 @@ class ElasticCloud(VectorDB):
     def init(self) -> None:
         """connect to elasticsearch"""
         from elasticsearch import Elasticsearch
+
         self.client = Elasticsearch(**self.db_config, request_timeout=180)
 
         yield
-        # self.client.transport.close()
         self.client = None
-        del(self.client)
+        del self.client
 
-    def _create_indice(self, client) -> None:
+    def _create_indice(self, client: any) -> None:
         mappings = {
             "_source": {"excludes": [self.vector_col_name]},
             "properties": {
@@ -62,13 +67,13 @@ class ElasticCloud(VectorDB):
                     "dims": self.dim,
                     **self.case_config.index_param(),
                 },
-            }
+            },
         }
 
         try:
             client.indices.create(index=self.indice, mappings=mappings)
         except Exception as e:
-            log.warning(f"Failed to create indice: {self.indice} error: {str(e)}")
+            log.warning(f"Failed to create indice: {self.indice} error: {e!s}")
             raise e from None
 
     def insert_embeddings(
@@ -94,7 +99,7 @@ class ElasticCloud(VectorDB):
             bulk_insert_res = bulk(self.client, insert_data)
             return (bulk_insert_res[0], None)
         except Exception as e:
-            log.warning(f"Failed to insert data: {self.indice} error: {str(e)}")
+            log.warning(f"Failed to insert data: {self.indice} error: {e!s}")
             return (0, e)
 
     def search_embedding(
@@ -114,16 +119,12 @@ class ElasticCloud(VectorDB):
             list[tuple[int, float]]: list of k most similar embeddings in (id, score) tuple to the query embedding.
         """
         assert self.client is not None, "should self.init() first"
-        # is_existed_res = self.client.indices.exists(index=self.indice)
-        # assert is_existed_res.raw == True, "should self.init() first"
 
         knn = {
             "field": self.vector_col_name,
             "k": k,
             "num_candidates": self.case_config.num_candidates,
-            "filter": [{"range": {self.id_col_name: {"gt": filters["id"]}}}]
-            if filters
-            else [],
+            "filter": [{"range": {self.id_col_name: {"gt": filters["id"]}}}] if filters else [],
             "query_vector": query,
         }
         size = k
@@ -137,26 +138,23 @@ class ElasticCloud(VectorDB):
                 stored_fields="_none_",
                 filter_path=[f"hits.hits.fields.{self.id_col_name}"],
             )
-            res = [h["fields"][self.id_col_name][0] for h in res["hits"]["hits"]]
-
-            return res
+            return [h["fields"][self.id_col_name][0] for h in res["hits"]["hits"]]
         except Exception as e:
-            log.warning(f"Failed to search: {self.indice} error: {str(e)}")
+            log.warning(f"Failed to search: {self.indice} error: {e!s}")
             raise e from None
 
-    def optimize(self):
+    def optimize(self, data_size: int | None = None):
         """optimize will be called between insertion and search in performance cases."""
         assert self.client is not None, "should self.init() first"
         self.client.indices.refresh(index=self.indice)
-        force_merge_task_id = self.client.indices.forcemerge(index=self.indice, max_num_segments=1, wait_for_completion=False)['task']
+        force_merge_task_id = self.client.indices.forcemerge(
+            index=self.indice,
+            max_num_segments=1,
+            wait_for_completion=False,
+        )["task"]
         log.info(f"Elasticsearch force merge task id: {force_merge_task_id}")
-        SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
         while True:
             time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
             task_status = self.client.tasks.get(task_id=force_merge_task_id)
-            if task_status['completed']:
+            if task_status["completed"]:
                 return
-
-    def ready_to_load(self):
-        """ready_to_load will be called before load in load cases."""
-        pass
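The rewritten `optimize()` above starts the force merge as a background task and polls the tasks API instead of holding one long HTTP request open. A self-contained sketch of that pattern with the elasticsearch client; the endpoint and index name are illustrative:

    import time
    from elasticsearch import Elasticsearch

    POLL_INTERVAL_SEC = 30  # mirrors SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC

    es = Elasticsearch("http://localhost:9200")  # assumed endpoint
    es.indices.refresh(index="test_index")

    # Returns immediately with a task id rather than blocking on the merge.
    task_id = es.indices.forcemerge(
        index="test_index",
        max_num_segments=1,
        wait_for_completion=False,
    )["task"]

    # Poll until the task API reports the merge finished.
    while not es.tasks.get(task_id=task_id)["completed"]:
        time.sleep(POLL_INTERVAL_SEC)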
--- a/vectordb_bench/backend/clients/memorydb/cli.py
+++ b/vectordb_bench/backend/clients/memorydb/cli.py
@@ -14,9 +14,7 @@ from .. import DB
 
 
 class MemoryDBTypedDict(TypedDict):
-    host: Annotated[
-        str, click.option("--host", type=str, help="Db host", required=True)
-    ]
+    host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
     password: Annotated[str, click.option("--password", type=str, help="Db password")]
     port: Annotated[int, click.option("--port", type=int, default=6379, help="Db Port")]
     ssl: Annotated[
@@ -44,7 +42,10 @@ class MemoryDBTypedDict(TypedDict):
             is_flag=True,
             show_default=True,
             default=False,
-            help="Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance. In production, MemoryDB only supports cluster mode (CME)",
+            help=(
+                "Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance."
+                " In production, MemoryDB only supports cluster mode (CME)"
+            ),
         ),
     ]
     insert_batch_size: Annotated[
@@ -58,8 +59,7 @@ class MemoryDBTypedDict(TypedDict):
     ]
 
 
-class MemoryDBHNSWTypedDict(CommonTypedDict, MemoryDBTypedDict, HNSWFlavor2):
-    ...
+class MemoryDBHNSWTypedDict(CommonTypedDict, MemoryDBTypedDict, HNSWFlavor2): ...
 
 
 @cli.command()
@@ -82,7 +82,7 @@ def MemoryDB(**parameters: Unpack[MemoryDBHNSWTypedDict]):
             M=parameters["m"],
             ef_construction=parameters["ef_construction"],
             ef_runtime=parameters["ef_runtime"],
-            insert_batch_size=parameters["insert_batch_size"]
+            insert_batch_size=parameters["insert_batch_size"],
         ),
         **parameters,
-    )
+    )
--- a/vectordb_bench/backend/clients/memorydb/config.py
+++ b/vectordb_bench/backend/clients/memorydb/config.py
@@ -29,7 +29,7 @@ class MemoryDBIndexConfig(BaseModel, DBCaseConfig):
     def parse_metric(self) -> str:
         if self.metric_type == MetricType.L2:
             return "l2"
-        elif self.metric_type == MetricType.IP:
+        if self.metric_type == MetricType.IP:
             return "ip"
         return "cosine"
 
@@ -51,4 +51,4 @@ class MemoryDBHNSWConfig(MemoryDBIndexConfig):
     def search_param(self) -> dict:
         return {
             "ef_runtime": self.ef_runtime,
-        }
+        }
--- a/vectordb_bench/backend/clients/memorydb/memorydb.py
+++ b/vectordb_bench/backend/clients/memorydb/memorydb.py
@@ -1,30 +1,33 @@
-import logging, time
+import logging
+import time
+from collections.abc import Generator
 from contextlib import contextmanager
-from typing import Any, Generator, Optional, Tuple, Type
-from ..api import VectorDB, DBCaseConfig, IndexType
-from .config import MemoryDBIndexConfig
+from typing import Any
+
+import numpy as np
 import redis
 from redis import Redis
 from redis.cluster import RedisCluster
-from redis.commands.search.field import TagField, VectorField, NumericField
-from redis.commands.search.indexDefinition import IndexDefinition, IndexType
+from redis.commands.search.field import NumericField, TagField, VectorField
+from redis.commands.search.indexDefinition import IndexDefinition
 from redis.commands.search.query import Query
-import numpy as np
 
+from ..api import IndexType, VectorDB
+from .config import MemoryDBIndexConfig
 
 log = logging.getLogger(__name__)
-INDEX_NAME = "index" # Vector Index Name
+INDEX_NAME = "index"  # Vector Index Name
+
 
 class MemoryDB(VectorDB):
     def __init__(
-        self,
-        dim: int,
-        db_config: dict,
-        db_case_config: MemoryDBIndexConfig,
-        drop_old: bool = False,
-        **kwargs
-    ):
-
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: MemoryDBIndexConfig,
+        drop_old: bool = False,
+        **kwargs,
+    ):
         self.db_config = db_config
         self.case_config = db_case_config
         self.collection_name = INDEX_NAME
@@ -44,10 +47,10 @@ class MemoryDB(VectorDB):
                 info = conn.ft(INDEX_NAME).info()
                 log.info(f"Index info: {info}")
             except redis.exceptions.ResponseError as e:
-                log.error(e)
+                log.warning(e)
                 drop_old = False
                 log.info(f"MemoryDB client drop_old collection: {self.collection_name}")
-
+
             log.info("Executing FLUSHALL")
             conn.flushall()
 
@@ -59,7 +62,7 @@ class MemoryDB(VectorDB):
                 self.wait_until(self.wait_for_empty_db, 3, "", rc)
                 log.debug(f"Flushall done in the host: {host}")
                 rc.close()
-
+
             self.make_index(dim, conn)
             conn.close()
             conn = None
@@ -69,7 +72,7 @@ class MemoryDB(VectorDB):
             # check to see if index exists
             conn.ft(INDEX_NAME).info()
         except Exception as e:
-            log.warn(f"Error getting info for index '{INDEX_NAME}': {e}")
+            log.warning(f"Error getting info for index '{INDEX_NAME}': {e}")
         index_param = self.case_config.index_param()
         search_param = self.case_config.search_param()
         vector_parameters = {  # Vector Index Type: FLAT or HNSW
@@ -85,17 +88,19 @@ class MemoryDB(VectorDB):
             vector_parameters["EF_RUNTIME"] = search_param["ef_runtime"]
 
         schema = (
-            TagField("id"),
-            NumericField("metadata"),
-            VectorField("vector", # Vector Field Name
-                "HNSW", vector_parameters
+            TagField("id"),
+            NumericField("metadata"),
+            VectorField(
+                "vector",  # Vector Field Name
+                "HNSW",
+                vector_parameters,
             ),
         )
 
         definition = IndexDefinition(index_type=IndexType.HASH)
         rs = conn.ft(INDEX_NAME)
         rs.create_index(schema, definition=definition)
-
+
     def get_client(self, **kwargs):
         """
         Gets either cluster connection or normal connection based on `cmd` flag.
@@ -143,7 +148,7 @@ class MemoryDB(VectorDB):
 
     @contextmanager
     def init(self) -> Generator[None, None, None]:
-        """ create and destory connections to database.
+        """create and destory connections to database.
 
         Examples:
             >>> with self.init():
@@ -152,17 +157,14 @@ class MemoryDB(VectorDB):
         self.conn = self.get_client()
         search_param = self.case_config.search_param()
         if search_param["ef_runtime"]:
-            self.ef_runtime_str = f'EF_RUNTIME {search_param["ef_runtime"]}'
+            self.ef_runtime_str = f"EF_RUNTIME {search_param['ef_runtime']}"
         else:
            self.ef_runtime_str = ""
        yield
        self.conn.close()
        self.conn = None
 
-    def ready_to_load(self) -> bool:
-        pass
-
-    def optimize(self) -> None:
+    def optimize(self, data_size: int | None = None):
        self._post_insert()
 
    def insert_embeddings(
@@ -170,7 +172,7 @@ class MemoryDB(VectorDB):
         embeddings: list[list[float]],
         metadata: list[int],
         **kwargs: Any,
-    ) -> Tuple[int, Optional[Exception]]:
+    ) -> tuple[int, Exception | None]:
         """Insert embeddings into the database.
         Should call self.init() first.
         """
@@ -178,12 +180,15 @@ class MemoryDB(VectorDB):
         try:
             with self.conn.pipeline(transaction=False) as pipe:
                 for i, embedding in enumerate(embeddings):
-                    embedding = np.array(embedding).astype(np.float32)
-                    pipe.hset(metadata[i], mapping = {
-                        "id": str(metadata[i]),
-                        "metadata": metadata[i],
-                        "vector": embedding.tobytes(),
-                    })
+                    ndarr_emb = np.array(embedding).astype(np.float32)
+                    pipe.hset(
+                        metadata[i],
+                        mapping={
+                            "id": str(metadata[i]),
+                            "metadata": metadata[i],
+                            "vector": ndarr_emb.tobytes(),
+                        },
+                    )
                     # Execute the pipe so we don't keep too much in memory at once
                     if (i + 1) % self.insert_batch_size == 0:
                         pipe.execute()
@@ -192,9 +197,9 @@ class MemoryDB(VectorDB):
                 result_len = i + 1
         except Exception as e:
             return 0, e
-
+
         return result_len, None
-
+
     def _post_insert(self):
         """Wait for indexing to finish"""
         client = self.get_client(primary=True)
@@ -208,21 +213,17 @@ class MemoryDB(VectorDB):
                 self.wait_until(*args)
                 log.debug(f"Background indexing completed in the host: {host_name}")
                 rc.close()
-
-    def wait_until(
-        self, condition, interval=5, message="Operation took too long", *args
-    ):
+
+    def wait_until(self, condition: any, interval: int = 5, message: str = "Operation took too long", *args):
         while not condition(*args):
             time.sleep(interval)
-
+
     def wait_for_no_activity(self, client: redis.RedisCluster | redis.Redis):
-        return (
-            client.info("search")["search_background_indexing_status"] == "NO_ACTIVITY"
-        )
-
+        return client.info("search")["search_background_indexing_status"] == "NO_ACTIVITY"
+
     def wait_for_empty_db(self, client: redis.RedisCluster | redis.Redis):
         return client.execute_command("DBSIZE") == 0
-
+
     def search_embedding(
         self,
         query: list[float],
@@ -230,13 +231,13 @@ class MemoryDB(VectorDB):
         filters: dict | None = None,
         timeout: int | None = None,
         **kwargs: Any,
-    ) -> (list[int]):
+    ) -> list[int]:
         assert self.conn is not None
-
+
         query_vector = np.array(query).astype(np.float32).tobytes()
         query_obj = Query(f"*=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
         query_params = {"vec": query_vector}
-
+
         if filters:
             # benchmark test filters of format: {'metadata': '>=10000', 'id': 10000}
             # gets exact match for id, and range for metadata if they exist in filters
@@ -244,11 +245,19 @@ class MemoryDB(VectorDB):
             # Removing '>=' from the id_value: '>=10000'
             metadata_value = filters.get("metadata")[2:]
             if id_value and metadata_value:
-                query_obj = Query(f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
+                query_obj = (
+                    Query(
+                        f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} @vector $vec]",
+                    )
+                    .return_fields("id")
+                    .paging(0, k)
+                )
             elif id_value:
-                #gets exact match for id
+                # gets exact match for id
                 query_obj = Query(f"@id:{ {id_value} }=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
-            else: #metadata only case, greater than or equal to metadata value
-                query_obj = Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
+            else:  # metadata only case, greater than or equal to metadata value
+                query_obj = (
+                    Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
+                )
         res = self.conn.ft(INDEX_NAME).search(query_obj, query_params)
-        return [int(doc["id"]) for doc in res.docs]
+        return [int(doc["id"]) for doc in res.docs]
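All three filter branches above build the same RediSearch KNN query; only the prefix before `=>[KNN ...]` changes (`*` for no filter, `@id:{...}` for an exact tag match, `@metadata:[lo +inf]` for a range). A self-contained sketch of the unfiltered case with redis-py; the host, index name, and dimension are illustrative and assume an index shaped like the one make_index() creates:

    import numpy as np
    import redis
    from redis.commands.search.query import Query

    r = redis.Redis(host="localhost", port=6379)  # assumed endpoint

    k = 10
    vec = np.random.rand(768).astype(np.float32).tobytes()  # dimension is dataset-dependent

    # A hybrid query would swap "*" for e.g. "@metadata:[10000 +inf]".
    q = Query(f"*=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
    res = r.ft("index").search(q, query_params={"vec": vec})
    ids = [int(doc["id"]) for doc in res.docs]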