cognee-community-vector-adapter-redis 0.0.3__tar.gz → 0.1.0__tar.gz

This diff shows the changes between publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -1,10 +1,12 @@
  Metadata-Version: 2.4
  Name: cognee-community-vector-adapter-redis
- Version: 0.0.3
+ Version: 0.1.0
  Summary: Redis vector database adapter for cognee
  Requires-Python: <=3.13,>=3.11
- Requires-Dist: cognee>=0.2.0.dev0
+ Requires-Dist: cognee==0.5.2
+ Requires-Dist: instructor>=1.11
  Requires-Dist: redisvl<=1.0.0,>=0.6.0
+ Requires-Dist: starlette>=0.48.0
  Description-Content-Type: text/markdown
 
  <div align="center" dir="auto">
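0.1.0 tightens the dependency set: cognee is pinned to an exact release (==0.5.2, previously the loose >=0.2.0.dev0), and instructor and starlette become direct dependencies. A minimal standard-library sketch to confirm what actually resolved after upgrading (package names taken from the metadata above):

    from importlib.metadata import version

    # Print the installed version of each dependency named in the 0.1.0 metadata;
    # version() raises PackageNotFoundError for anything missing.
    for pkg in ("cognee", "instructor", "redisvl", "starlette"):
        print(pkg, version(pkg))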
@@ -1,23 +1,22 @@
- import json
  import asyncio
- from typing import Dict, List, Optional, Any
+ import json
+ from typing import Any
  from uuid import UUID
 
- from redisvl.index import AsyncSearchIndex
- from redisvl.schema import IndexSchema
- from redisvl.query import VectorQuery
- # from redisvl.query import VectorDistanceMetric
-
- from cognee.shared.logging_utils import get_logger
-
- from cognee.infrastructure.engine import DataPoint
- from cognee.infrastructure.engine.utils import parse_id
  from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
  from cognee.infrastructure.databases.vector import VectorDBInterface
- from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
  from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import (
      EmbeddingEngine,
  )
+ from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
+ from cognee.infrastructure.engine import DataPoint
+ from cognee.infrastructure.engine.utils import parse_id
+
+ # from redisvl.query import VectorDistanceMetric
+ from cognee.shared.logging_utils import get_logger
+ from redisvl.index import AsyncSearchIndex
+ from redisvl.query import VectorQuery
+ from redisvl.schema import IndexSchema
 
  logger = get_logger("RedisAdapter")
 
@@ -73,15 +72,16 @@ class RedisAdapter(VectorDBInterface):
      """
 
      name = "Redis"
-     url: Optional[str]
-     api_key: Optional[str] = None
-     embedding_engine: Optional[EmbeddingEngine] = None
+     url: str | None
+     api_key: str | None = None
+     embedding_engine: EmbeddingEngine | None = None
 
      def __init__(
          self,
          url: str,
-         api_key: Optional[str] = None,
-         embedding_engine: Optional[EmbeddingEngine] = None,
+         database_name: str = "cognee",
+         api_key: str | None = None,
+         embedding_engine: EmbeddingEngine | None = None,
      ) -> None:
          """Initialize the Redis adapter.
 
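The constructor gains a database_name parameter (defaulting to "cognee") and the Optional[...] annotations move to PEP 604 unions. A minimal construction sketch; the import path and the embedding-engine value are assumptions for illustration, not something this diff specifies:

    from cognee_community_vector_adapter_redis import RedisAdapter  # assumed export

    adapter = RedisAdapter(
        url="redis://localhost:6379",
        database_name="cognee",      # new in 0.1.0
        embedding_engine=my_engine,  # placeholder: any cognee EmbeddingEngine instance
    )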
@@ -99,11 +99,12 @@ class RedisAdapter(VectorDBInterface):
              raise VectorEngineInitializationError("Embedding engine is required!")
 
          self.url = url
+         self.database_name = database_name
          self.embedding_engine = embedding_engine
          self._indices = {}
          self.VECTOR_DB_LOCK = asyncio.Lock()
 
-     async def embed_data(self, data: List[str]) -> List[List[float]]:
+     async def embed_data(self, data: list[str]) -> list[list[float]]:
          """Embed text data using the embedding engine.
 
          Args:
@@ -179,7 +180,6 @@ class RedisAdapter(VectorDBInterface):
          try:
              index = self._get_index(collection_name)
              result = await index.exists()
-             await index.disconnect()
              return result
          except Exception:
              return False
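This is the first of several hunks dropping await index.disconnect(): the same cleanup call also disappears from create_collection, create_data_points, retrieve, search, delete_data_points, and prune below. The AsyncSearchIndex objects cached in self._indices are now reused across operations instead of being torn down after every call.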
@@ -187,7 +187,7 @@ class RedisAdapter(VectorDBInterface):
      async def create_collection(
          self,
          collection_name: str,
-         payload_schema: Optional[Any] = None,
+         payload_schema: Any | None = None,
      ) -> None:
          """Create a new collection (Redis index) with vector search capabilities.
 
@@ -205,7 +205,6 @@ class RedisAdapter(VectorDBInterface):
                  logger.info(f"Collection {collection_name} already exists")
                  return
 
-             index = self._get_index(collection_name)
              await index.create(overwrite=False)
 
              logger.info(f"Created collection {collection_name}")
@@ -213,12 +212,8 @@ class RedisAdapter(VectorDBInterface):
          except Exception as e:
              logger.error(f"Error creating collection {collection_name}: {str(e)}")
              raise e
-         finally:
-             await index.disconnect()
 
-     async def create_data_points(
-         self, collection_name: str, data_points: List[DataPoint]
-     ) -> None:
+     async def create_data_points(self, collection_name: str, data_points: list[DataPoint]) -> None:
          """Create data points in the collection.
 
          Args:
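Together with create_collection above, the 0.1.0 write path looks roughly like the sketch below (the adapter instance and the data_points list are assumed to exist; DataPoint subclass details are outside this diff):

    # Create the index, then embed and load the points in one call.
    await adapter.create_collection("my_collection")
    await adapter.create_data_points("my_collection", data_points)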
@@ -232,21 +227,16 @@ class RedisAdapter(VectorDBInterface):
          index = self._get_index(collection_name)
          try:
              if not await self.has_collection(collection_name):
-                 raise CollectionNotFoundError(
-                     f"Collection {collection_name} not found!"
-                 )
+                 raise CollectionNotFoundError(f"Collection {collection_name} not found!")
 
              # Embed the data points
              data_vectors = await self.embed_data(
-                 [
-                     DataPoint.get_embeddable_data(data_point)
-                     for data_point in data_points
-                 ]
+                 [DataPoint.get_embeddable_data(data_point) for data_point in data_points]
              )
 
              # Prepare documents for RedisVL
              documents = []
-             for data_point, embedding in zip(data_points, data_vectors):
+             for data_point, embedding in zip(data_points, data_vectors, strict=False):
                  # Serialize the payload to handle UUIDs and other non-JSON types
                  payload = serialize_for_json(data_point.model_dump())
 
@@ -265,19 +255,13 @@ class RedisAdapter(VectorDBInterface):
              # Load using RedisVL
              await index.load(documents, id_field="id")
 
-             logger.info(
-                 f"Created {len(data_points)} data points in collection {collection_name}"
-             )
+             logger.info(f"Created {len(data_points)} data points in collection {collection_name}")
 
          except Exception as e:
              logger.error(f"Error creating data points: {str(e)}")
              raise e
-         finally:
-             await index.disconnect()
 
-     async def create_vector_index(
-         self, index_name: str, index_property_name: str
-     ) -> None:
+     async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
          """Create a vector index for a specific property.
 
          Args:
@@ -301,17 +285,15 @@ class RedisAdapter(VectorDBInterface):
              [
                  RedisDataPoint(
                      id=data_point.id,
-                     text=getattr(
-                         data_point, data_point.metadata.get("index_fields", ["text"])[0]
-                     ),
+                     text=getattr(data_point, data_point.metadata.get("index_fields", ["text"])[0]),
                  )
                  for data_point in data_points
              ],
          )
 
      async def retrieve(
-         self, collection_name: str, data_point_ids: List[str]
-     ) -> List[Dict[str, Any]]:
+         self, collection_name: str, data_point_ids: list[str]
+     ) -> list[dict[str, Any]]:
          """Retrieve data points by their IDs.
 
          Args:
@@ -342,17 +324,16 @@ class RedisAdapter(VectorDBInterface):
          except Exception as e:
              logger.error(f"Error retrieving data points: {str(e)}")
              return []
-         finally:
-             await index.disconnect()
 
      async def search(
          self,
          collection_name: str,
-         query_text: Optional[str] = None,
-         query_vector: Optional[List[float]] = None,
-         limit: int = 15,
+         query_text: str | None = None,
+         query_vector: list[float] | None = None,
+         limit: int | None = 15,
          with_vector: bool = False,
-     ) -> List[ScoredResult]:
+         include_payload: bool = True,
+     ) -> list[ScoredResult]:
          """Search for similar vectors in the collection.
 
          Args:
@@ -361,6 +342,7 @@ class RedisAdapter(VectorDBInterface):
              query_vector: Pre-computed query vector.
              limit: Maximum number of results to return.
              with_vector: Whether to include vectors in results.
+             include_payload: Whether to include payloads in results.
 
          Returns:
              List of ScoredResult objects sorted by similarity.
@@ -380,11 +362,11 @@ class RedisAdapter(VectorDBInterface):
 
          index = self._get_index(collection_name)
 
-         if limit == 0:
+         if limit is None:
              info = await index.info()
              limit = info["num_docs"]
 
-         if limit == 0:
+         if limit <= 0:
              return []
 
          try:
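The sentinel for "return everything" changes here: 0.0.3 used limit == 0 to size the query to the index's num_docs, while 0.1.0 reserves limit=None for that and treats any non-positive limit as an empty result. A hedged usage sketch, reusing the adapter from the construction sketch above:

    # None sizes the limit to the index's document count; 0 or negative now
    # short-circuits to an empty list.
    everything = await adapter.search("my_collection", query_text="sandwich", limit=None)
    nothing = await adapter.search("my_collection", query_text="sandwich", limit=0)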
@@ -402,7 +384,10 @@ class RedisAdapter(VectorDBInterface):
              )
 
              # Set return fields
-             return_fields = ["id", "text", "payload_data"]
+             if include_payload:
+                 return_fields = ["id", "text", "payload_data"]
+             else:
+                 return_fields = ["id", "text"]
              if with_vector:
                  return_fields.append("vector")
              vector_query = vector_query.return_fields(*return_fields)
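The new include_payload flag decides whether the stored payload_data field is fetched from RediSearch; turning it off trims the response to the id and text fields. Sketch:

    # Lighter-weight query when only IDs, text, and scores matter.
    hits = await adapter.search(
        "my_collection",
        query_text="sandwich",
        limit=5,
        include_payload=False,  # new in 0.1.0; return_fields become ["id", "text"]
    )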
@@ -424,9 +409,7 @@ class RedisAdapter(VectorDBInterface):
                      ScoredResult(
                          id=parse_id(doc["id"].split(":", 1)[1]),
                          payload=payload,
-                         score=float(
-                             doc.get("vector_distance", 0.0)
-                         ),  # RedisVL returns distance
+                         score=float(doc.get("vector_distance", 0.0)),  # RedisVL returns distance
                      )
                  )
              return scored_results
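As the retained comment notes, ScoredResult.score carries RedisVL's vector distance rather than a similarity, so lower values mean closer matches; the < 0.1 filter in batch_search below depends on that orientation.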
@@ -434,16 +417,15 @@ class RedisAdapter(VectorDBInterface):
          except Exception as e:
              logger.error(f"Error during search: {str(e)}")
              raise e
-         finally:
-             await index.disconnect()
 
      async def batch_search(
          self,
          collection_name: str,
-         query_texts: List[str],
-         limit: Optional[int] = None,
+         query_texts: list[str],
+         limit: int | None,
          with_vectors: bool = False,
-     ) -> List[List[ScoredResult]]:
+         include_payload: bool = True,
+     ) -> list[list[ScoredResult]]:
          """Perform batch search for multiple queries.
 
          Args:
@@ -451,6 +433,7 @@ class RedisAdapter(VectorDBInterface):
              query_texts: List of text queries to search for.
              limit: Maximum number of results per query.
              with_vectors: Whether to include vectors in results.
+             include_payload: Whether to include payloads in results.
 
          Returns:
              List of search results for each query, filtered by score threshold.
@@ -466,6 +449,7 @@ class RedisAdapter(VectorDBInterface):
                  query_vector=vector,
                  limit=limit,
                  with_vector=with_vectors,
+                 include_payload=include_payload,
              )
              for vector in vectors
          ]
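Note that batch_search's limit parameter loses its = None default in this release, so callers must now pass it explicitly. A hedged sketch:

    # Fan-out search over several queries; include_payload is forwarded to each
    # underlying search() call.
    grouped = await adapter.batch_search(
        "my_collection",
        query_texts=["sandwich toppings", "NLP"],
        limit=10,
        include_payload=False,
    )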
@@ -474,13 +458,12 @@ class RedisAdapter(VectorDBInterface):
 
          # Filter results by score threshold (Redis uses distance, so lower is better)
          return [
-             [result for result in result_group if result.score < 0.1]
-             for result_group in results
+             [result for result in result_group if result.score < 0.1] for result_group in results
          ]
 
      async def delete_data_points(
-         self, collection_name: str, data_point_ids: List[str]
-     ) -> Dict[str, int]:
+         self, collection_name: str, data_point_ids: list[str]
+     ) -> dict[str, int]:
          """Delete data points by their IDs.
 
          Args:
@@ -496,15 +479,11 @@ class RedisAdapter(VectorDBInterface):
          index = self._get_index(collection_name)
          try:
              deleted_count = await index.drop_documents(data_point_ids)
-             logger.info(
-                 f"Deleted {deleted_count} data points from collection {collection_name}"
-             )
+             logger.info(f"Deleted {deleted_count} data points from collection {collection_name}")
              return {"deleted": deleted_count}
          except Exception as e:
              logger.error(f"Error deleting data points: {str(e)}")
              raise e
-         finally:
-             await index.disconnect()
 
      async def prune(self) -> None:
          """Remove all collections and data from Redis.
@@ -521,10 +500,8 @@ class RedisAdapter(VectorDBInterface):
                  if await index.exists():
                      await index.delete(drop=True)
                      logger.info(f"Dropped index {collection_name}")
-                 await index.disconnect()
              except Exception as e:
                  logger.warning(f"Failed to drop index {collection_name}: {str(e)}")
-                 await index.disconnect()
 
          # Clear the indices cache
          self._indices.clear()
@@ -534,3 +511,12 @@ class RedisAdapter(VectorDBInterface):
          except Exception as e:
              logger.error(f"Error during prune: {str(e)}")
              raise e
+
+     async def get_collection_names(self):
+         """
+         Get names of all collections in the database.
+
+         Returns:
+             List of collection names. In this case of Redis, the return type is a dict.
+         """
+         return self._indices.keys()
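The new get_collection_names returns self._indices.keys(), a view over the collections this adapter instance has touched, not a server-side listing. Sketch:

    names = await adapter.get_collection_names()
    print(list(names))  # materialize the dict_keys view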
@@ -1,5 +1,5 @@
- import os
  import asyncio
+ import os
  import pathlib
  from os import path
 
@@ -8,10 +8,11 @@ os.environ.setdefault("LLM_API_KEY", "your-api-key")
 
 
  async def main():
-     from cognee import config, prune, add, cognify, search, SearchType
+     from cognee import SearchType, add, cognify, config, prune, search
 
      # NOTE: Importing the register module we let cognee know it can use the Redis vector adapter
-     from cognee_community_vector_adapter_redis import register
+     # NOTE: The "noqa: F401" mark is to make sure the linter doesn't flag this as an unused import
+     from cognee_community_vector_adapter_redis import register  # noqa: F401
 
      system_path = pathlib.Path(__file__).parent
      config.system_root_directory(path.join(system_path, ".cognee-system"))
@@ -36,16 +37,14 @@ async def main():
 
      await add("""
      Sandwhiches are best served toasted with cheese, ham, mayo,
-     lettuce, mustard, and salt & pepper.
+     lettuce, mustard, and salt & pepper.
      """)
 
      await cognify()
 
      query_text = "Tell me about NLP"
 
-     search_results = await search(
-         query_type=SearchType.GRAPH_COMPLETION, query_text=query_text
-     )
+     search_results = await search(query_type=SearchType.GRAPH_COMPLETION, query_text=query_text)
 
      for result_text in search_results:
          print("\nSearch result: \n" + result_text)
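The example's entry point sits outside the shown hunks; running it presumably follows the usual asyncio pattern:

    if __name__ == "__main__":
        asyncio.run(main())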