cognee-community-vector-adapter-weaviate 0.0.2__tar.gz → 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,13 +1,15 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: cognee-community-vector-adapter-weaviate
3
- Version: 0.0.2
3
+ Version: 0.1.0
4
4
  Summary: Weaviate vector database adapter for cognee
5
5
  Requires-Python: >=3.11,<=3.13
6
6
  Classifier: Programming Language :: Python :: 3
7
7
  Classifier: Programming Language :: Python :: 3.11
8
8
  Classifier: Programming Language :: Python :: 3.12
9
9
  Classifier: Programming Language :: Python :: 3.13
10
- Requires-Dist: cognee (>=0.2.4)
10
+ Requires-Dist: cognee (==0.5.2)
11
+ Requires-Dist: instructor (>=1.11)
12
+ Requires-Dist: starlette (>=0.48.0)
11
13
  Requires-Dist: weaviate-client (>=4.9.6,<5.0.0)
12
14
  Description-Content-Type: text/markdown
13
15
 
@@ -55,7 +57,10 @@ self.client = weaviate.use_async_with_local(
55
57
 
56
58
  You can use the docker command provided by Weaviate (https://docs.weaviate.io/deploy/installation-guides/docker-installation)
57
59
  to run Weaviate with default settings. The command looks something like this, specifying the ports for connection:
58
- `docker run -p 8080:8080 -p 50051:50051 cr.weaviate.io/semitechnologies/weaviate:1.32.4`
60
+
61
+ ```bash
62
+ docker run -p 8080:8080 -p 50051:50051 cr.weaviate.io/semitechnologies/weaviate:1.32.4
63
+ ```
59
64
 
60
65
  ## Usage
61
66
 
@@ -42,7 +42,10 @@ self.client = weaviate.use_async_with_local(
42
42
 
43
43
  You can use the docker command provided by Weaviate (https://docs.weaviate.io/deploy/installation-guides/docker-installation)
44
44
  to run Weaviate with default settings. The command looks something like this, specifying the ports for connection:
45
- `docker run -p 8080:8080 -p 50051:50051 cr.weaviate.io/semitechnologies/weaviate:1.32.4`
45
+
46
+ ```bash
47
+ docker run -p 8080:8080 -p 50051:50051 cr.weaviate.io/semitechnologies/weaviate:1.32.4
48
+ ```
46
49
 
47
50
  ## Usage
48
51
 
@@ -1,26 +1,25 @@
1
1
  import asyncio
2
- from typing import List, Optional
3
2
 
4
- from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
5
-
6
- from cognee.shared.logging_utils import get_logger
7
- from cognee.infrastructure.engine import DataPoint
8
- from cognee.infrastructure.engine.utils import parse_id
9
3
  from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
10
- from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
11
4
  from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import (
12
5
  EmbeddingEngine,
13
6
  )
7
+ from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
14
8
  from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
15
9
  from cognee.infrastructure.databases.vector.vector_db_interface import VectorDBInterface
10
+ from cognee.infrastructure.engine import DataPoint
11
+ from cognee.infrastructure.engine.utils import parse_id
12
+ from cognee.shared.logging_utils import get_logger
13
+ from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
16
14
 
17
15
  logger = get_logger("WeaviateAdapter")
18
16
 
19
17
 
20
18
  def is_retryable_request(error):
21
- from weaviate.exceptions import UnexpectedStatusCodeException
22
19
  from requests.exceptions import RequestException
23
20
 
21
+ from weaviate.exceptions import UnexpectedStatusCodeException
22
+
24
23
  if isinstance(error, UnexpectedStatusCodeException):
25
24
  # Retry on conflict, service unavailable, internal error
26
25
  return error.status_code in {409, 503, 500}
@@ -72,12 +71,19 @@ class WeaviateAdapter(VectorDBInterface):
72
71
  api_key: str
73
72
  embedding_engine: EmbeddingEngine = None
74
73
 
75
- def __init__(self, url: str, api_key: str, embedding_engine: EmbeddingEngine):
74
+ def __init__(
75
+ self,
76
+ url: str,
77
+ api_key: str,
78
+ embedding_engine: EmbeddingEngine,
79
+ database_name: str = "cognee",
80
+ ):
76
81
  import weaviate
77
82
  import weaviate.classes as wvc
78
83
 
79
84
  self.url = url
80
85
  self.api_key = api_key
86
+ self.database_name = database_name
81
87
 
82
88
  self.embedding_engine = embedding_engine
83
89
  self.VECTOR_DB_LOCK = asyncio.Lock()
@@ -85,9 +91,7 @@ class WeaviateAdapter(VectorDBInterface):
85
91
  self.client = weaviate.use_async_with_weaviate_cloud(
86
92
  cluster_url=url,
87
93
  auth_credentials=weaviate.auth.AuthApiKey(api_key),
88
- additional_config=wvc.init.AdditionalConfig(
89
- timeout=wvc.init.Timeout(init=30)
90
- ),
94
+ additional_config=wvc.init.AdditionalConfig(timeout=wvc.init.Timeout(init=30)),
91
95
  )
92
96
 
93
97
  async def get_client(self):
@@ -105,7 +109,7 @@ class WeaviateAdapter(VectorDBInterface):
105
109
 
106
110
  return self.client
107
111
 
108
- async def embed_data(self, data: List[str]) -> List[float]:
112
+ async def embed_data(self, data: list[str]) -> list[float]:
109
113
  """
110
114
  Embed the given text data into vector representations.
111
115
 
@@ -187,7 +191,7 @@ class WeaviateAdapter(VectorDBInterface):
187
191
  )
188
192
  else:
189
193
  result = await self.get_collection(collection_name)
190
- await client.close()
194
+ # await client.close()
191
195
  return result
192
196
 
193
197
  async def get_collection(self, collection_name: str):
@@ -216,9 +220,7 @@ class WeaviateAdapter(VectorDBInterface):
216
220
  stop=stop_after_attempt(3),
217
221
  wait=wait_exponential(multiplier=2, min=1, max=6),
218
222
  )
219
- async def create_data_points(
220
- self, collection_name: str, data_points: List[DataPoint]
221
- ):
223
+ async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
222
224
  """
223
225
  Create or update data points in the specified collection in the Weaviate database.
224
226
 
@@ -269,9 +271,7 @@ class WeaviateAdapter(VectorDBInterface):
269
271
 
270
272
  return DataObject(uuid=data_point.id, properties=properties, vector=vector)
271
273
 
272
- data_points = [
273
- convert_to_weaviate_data_points(data_point) for data_point in data_points
274
- ]
274
+ data_points = [convert_to_weaviate_data_points(data_point) for data_point in data_points]
275
275
 
276
276
  await self.get_client()
277
277
  collection = await self.get_collection(collection_name)
@@ -298,8 +298,8 @@ class WeaviateAdapter(VectorDBInterface):
298
298
  except Exception as error:
299
299
  logger.error("Error creating data points: %s", str(error))
300
300
  raise error
301
- finally:
302
- await self.client.close()
301
+ # finally:
302
+ # await self.client.close()
303
303
 
304
304
  async def create_vector_index(self, index_name: str, index_property_name: str):
305
305
  """
@@ -381,22 +381,23 @@ class WeaviateAdapter(VectorDBInterface):
381
381
  data_point.id = data_point.uuid
382
382
  del data_point.properties
383
383
 
384
- await self.client.close()
384
+ # await self.client.close()
385
385
  return data_points.objects
386
386
 
387
387
  async def search(
388
388
  self,
389
389
  collection_name: str,
390
- query_text: Optional[str] = None,
391
- query_vector: Optional[List[float]] = None,
392
- limit: int = 15,
390
+ query_text: str | None = None,
391
+ query_vector: list[float] | None = None,
392
+ limit: int | None = 15,
393
393
  with_vector: bool = False,
394
+ include_payload: bool = False,
394
395
  ):
395
396
  """
396
397
  Perform a search on a collection using either a text query or a vector query.
397
398
 
398
- Return scored results based on the search criteria provided. Raise MissingQueryParameterError if
399
- no query is provided.
399
+ Return scored results based on the search criteria provided.
400
+ Raise MissingQueryParameterError if no query is provided.
400
401
 
401
402
  Parameters:
402
403
  -----------
@@ -408,6 +409,7 @@ class WeaviateAdapter(VectorDBInterface):
408
409
  searching. (default None)
409
410
  - limit (int): The maximum number of results to return. (default 15)
410
411
  - with_vector (bool): Include vector information in the results. (default False)
412
+ - include_payload (bool): Include payload information in the results. (default False)
411
413
 
412
414
  Returns:
413
415
  --------
@@ -423,28 +425,32 @@ class WeaviateAdapter(VectorDBInterface):
423
425
  if query_vector is None:
424
426
  query_vector = (await self.embed_data([query_text]))[0]
425
427
 
426
- # TODO: Creation of new client for every search call. This is VERY ugly, needs discussion. (Andrej's comment)
428
+ # TODO: Creation of new client for every search call. This is VERY ugly. Should change.
427
429
  async with weaviate.use_async_with_weaviate_cloud(
428
430
  cluster_url=self.url,
429
431
  auth_credentials=weaviate.auth.AuthApiKey(self.api_key),
430
- additional_config=wvc.init.AdditionalConfig(
431
- timeout=wvc.init.Timeout(init=30)
432
- ),
432
+ additional_config=wvc.init.AdditionalConfig(timeout=wvc.init.Timeout(init=30)),
433
433
  ) as client:
434
434
  if not await client.collections.exists(collection_name):
435
- raise CollectionNotFoundError(
436
- f"Collection '{collection_name}' not found."
437
- )
435
+ raise CollectionNotFoundError(f"Collection '{collection_name}' not found.")
438
436
 
439
437
  collection = client.collections.get(collection_name)
440
438
 
439
+ if limit is None:
440
+ result = await collection.aggregate.over_all(total_count=True)
441
+ limit = result.total_count
442
+
443
+ if limit == 0:
444
+ return []
445
+
441
446
  try:
442
447
  search_result = await collection.query.hybrid(
443
448
  query=None,
444
449
  vector=query_vector,
445
- limit=limit if limit > 0 else None,
450
+ limit=limit,
446
451
  include_vector=with_vector,
447
452
  return_metadata=wvc.query.MetadataQuery(score=True),
453
+ return_properties=include_payload,
448
454
  )
449
455
 
450
456
  return [
@@ -462,9 +468,10 @@ class WeaviateAdapter(VectorDBInterface):
462
468
  async def batch_search(
463
469
  self,
464
470
  collection_name: str,
465
- query_texts: List[str],
466
- limit: int,
471
+ query_texts: list[str],
472
+ limit: int | None,
467
473
  with_vectors: bool = False,
474
+ include_payload: bool = False,
468
475
  ):
469
476
  """
470
477
  Execute a batch search for multiple query texts in the specified collection.
@@ -479,6 +486,7 @@ class WeaviateAdapter(VectorDBInterface):
479
486
  - limit (int): The maximum number of results to return for each query.
480
487
  - with_vectors (bool): Indicate whether to include vector information in the
481
488
  results. (default False)
489
+ - include_payload (bool): Include payload information in the results.
482
490
 
483
491
  Returns:
484
492
  --------
@@ -508,11 +516,11 @@ class WeaviateAdapter(VectorDBInterface):
508
516
  query_vector=query_vector,
509
517
  limit=limit,
510
518
  with_vector=with_vectors,
519
+ include_payload=include_payload,
511
520
  )
512
521
 
513
522
  return [
514
- await query_search(query_vector)
515
- for query_vector in await self.embed_data(query_texts)
523
+ await query_search(query_vector) for query_vector in await self.embed_data(query_texts)
516
524
  ]
517
525
 
518
526
  async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
@@ -542,7 +550,7 @@ class WeaviateAdapter(VectorDBInterface):
542
550
  filters=Filter.by_id().contains_any(data_point_ids)
543
551
  )
544
552
 
545
- await self.client.close()
553
+ # await self.client.close()
546
554
  return result
547
555
 
548
556
  async def prune(self):
@@ -553,4 +561,15 @@ class WeaviateAdapter(VectorDBInterface):
553
561
  """
554
562
  client = await self.get_client()
555
563
  await client.collections.delete_all()
556
- await client.close()
564
+ # await client.close()
565
+
566
+ async def get_collection_names(self) -> list[str]:
567
+ """
568
+ Get names of all collections in the database.
569
+
570
+ Returns:
571
+ list[str]: List of collection names.
572
+ """
573
+
574
+ client = await self.get_client()
575
+ return await client.collections.list_all()
@@ -1,10 +1,12 @@
1
1
  [project]
2
2
  name = "cognee-community-vector-adapter-weaviate"
3
- version = "0.0.2"
3
+ version = "0.1.0"
4
4
  description = "Weaviate vector database adapter for cognee"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.11,<=3.13"
7
7
  dependencies = [
8
8
  "weaviate-client>=4.9.6,<5.0.0",
9
- "cognee>=0.2.4"
9
+ "cognee==0.5.2",
10
+ "starlette>=0.48.0",
11
+ "instructor>=1.11"
10
12
  ]