cognee-community-vector-adapter-qdrant 0.0.3__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,14 +1,16 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: cognee-community-vector-adapter-qdrant
3
- Version: 0.0.3
3
+ Version: 0.2.0
4
4
  Summary: Qdrant vector database adapter for cognee
5
5
  Requires-Python: >=3.11,<=3.13
6
6
  Classifier: Programming Language :: Python :: 3
7
7
  Classifier: Programming Language :: Python :: 3.11
8
8
  Classifier: Programming Language :: Python :: 3.12
9
9
  Classifier: Programming Language :: Python :: 3.13
10
- Requires-Dist: cognee (>=0.2.4)
11
- Requires-Dist: qdrant-client (>=1.14.2)
10
+ Requires-Dist: cognee (==0.5.1)
11
+ Requires-Dist: instructor (>=1.11)
12
+ Requires-Dist: qdrant-client (>=1.16.0)
13
+ Requires-Dist: starlette (>=0.48.0)
12
14
  Description-Content-Type: text/markdown
13
15
 
14
16
  # Cognee Qdrant Adapter
@@ -47,6 +49,11 @@ Import and register the adapter in your code:
47
49
  from cognee_community_vector_adapter_qdrant import register
48
50
  ```
49
51
 
52
+ Also, specify the dataset handler in the .env file:
53
+ ```dotenv
54
+ VECTOR_DATASET_DATABASE_HANDLER="qdrant"
55
+ ```
56
+
50
57
  ## Example
51
58
  See example in `example.py` file.
52
59
 
@@ -34,5 +34,10 @@ Import and register the adapter in your code:
34
34
  from cognee_community_vector_adapter_qdrant import register
35
35
  ```
36
36
 
37
+ Also, specify the dataset handler in the .env file:
38
+ ```dotenv
39
+ VECTOR_DATASET_DATABASE_HANDLER="qdrant"
40
+ ```
41
+
37
42
  ## Example
38
43
  See example in `example.py` file.
@@ -0,0 +1,39 @@
1
+ from typing import Optional
2
+ from uuid import UUID
3
+
4
+ from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
5
+ from cognee.infrastructure.databases.vector import get_vectordb_config
6
+ from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
7
+ from cognee.modules.users.models import DatasetDatabase, User
8
+
9
+
10
+ class QdrantDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
11
+ @classmethod
12
+ async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
13
+ vector_config = get_vectordb_config()
14
+
15
+ if vector_config.vector_db_provider != "qdrant":
16
+ raise ValueError(
17
+ "QdrantDatasetDatabaseHandler can only be used with the"
18
+ "Qdrant vector database provider."
19
+ )
20
+
21
+ vector_db_name = f"{dataset_id}"
22
+
23
+ return {
24
+ "vector_database_provider": vector_config.vector_db_provider,
25
+ "vector_database_url": vector_config.vector_db_url,
26
+ "vector_database_key": vector_config.vector_db_key,
27
+ "vector_database_name": vector_db_name,
28
+ "vector_dataset_database_handler": "qdrant",
29
+ }
30
+
31
+ @classmethod
32
+ async def delete_dataset(cls, dataset_database: DatasetDatabase) -> None:
33
+ vector_engine = create_vector_engine(
34
+ vector_db_provider=dataset_database.vector_database_provider,
35
+ vector_db_url=dataset_database.vector_database_url,
36
+ vector_db_key=dataset_database.vector_database_key,
37
+ vector_db_name=dataset_database.vector_database_name,
38
+ )
39
+ await vector_engine.prune()
@@ -1,18 +1,16 @@
1
1
  import asyncio
2
- from typing import Dict, List, Optional
3
- from qdrant_client import AsyncQdrantClient, models
4
-
5
- from cognee.shared.logging_utils import get_logger
6
2
 
7
- from cognee.infrastructure.engine import DataPoint
8
- from cognee.infrastructure.engine.utils import parse_id
9
3
  from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
10
4
  from cognee.infrastructure.databases.vector import VectorDBInterface
11
- from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
12
5
  from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import (
13
6
  EmbeddingEngine,
14
7
  )
15
8
  from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
9
+ from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
10
+ from cognee.infrastructure.engine import DataPoint
11
+ from cognee.infrastructure.engine.utils import parse_id
12
+ from cognee.shared.logging_utils import get_logger
13
+ from qdrant_client import AsyncQdrantClient, models
16
14
 
17
15
  logger = get_logger("QDrantAdapter")
18
16
 
@@ -23,19 +21,19 @@ class IndexSchema(DataPoint):
23
21
  metadata: dict = {"index_fields": ["text"]}
24
22
 
25
23
 
26
- def create_hnsw_config(hnsw_config: Dict):
24
+ def create_hnsw_config(hnsw_config: dict):
27
25
  if hnsw_config is not None:
28
26
  return models.HnswConfig()
29
27
  return None
30
28
 
31
29
 
32
- def create_optimizers_config(optimizers_config: Dict):
30
+ def create_optimizers_config(optimizers_config: dict):
33
31
  if optimizers_config is not None:
34
32
  return models.OptimizersConfig()
35
33
  return None
36
34
 
37
35
 
38
- def create_quantization_config(quantization_config: Dict):
36
+ def create_quantization_config(quantization_config: dict):
39
37
  if quantization_config is not None:
40
38
  return models.QuantizationConfig()
41
39
  return None
@@ -48,9 +46,15 @@ class QDrantAdapter(VectorDBInterface):
48
46
  qdrant_path: str = None
49
47
 
50
48
  def __init__(
51
- self, url, api_key, embedding_engine: EmbeddingEngine, qdrant_path=None
49
+ self,
50
+ url,
51
+ api_key,
52
+ embedding_engine: EmbeddingEngine,
53
+ qdrant_path=None,
54
+ database_name: str = "cognee_db",
52
55
  ):
53
56
  self.embedding_engine = embedding_engine
57
+ self.database_name = database_name
54
58
 
55
59
  if qdrant_path is not None:
56
60
  self.qdrant_path = qdrant_path
@@ -67,7 +71,7 @@ class QDrantAdapter(VectorDBInterface):
67
71
 
68
72
  return AsyncQdrantClient(location=":memory:")
69
73
 
70
- async def embed_data(self, data: List[str]) -> List[float]:
74
+ async def embed_data(self, data: list[str]) -> list[float]:
71
75
  return await self.embedding_engine.embed_text(data)
72
76
 
73
77
  async def has_collection(self, collection_name: str) -> bool:
@@ -90,16 +94,29 @@ class QDrantAdapter(VectorDBInterface):
90
94
  vectors_config={
91
95
  "text": models.VectorParams(
92
96
  size=self.embedding_engine.get_vector_size(),
93
- distance="Cosine",
97
+ distance=models.Distance.COSINE,
94
98
  )
95
99
  },
100
+ # With this config definition, we avoid creating a global index
101
+ hnsw_config=models.HnswConfigDiff(
102
+ payload_m=16,
103
+ m=0,
104
+ ),
105
+ )
106
+ # This index co-locates vectors from the same dataset together,
107
+ # which can improve performance
108
+ await client.create_payload_index(
109
+ collection_name=collection_name,
110
+ field_name="database_name",
111
+ field_schema=models.KeywordIndexParams(
112
+ type=models.KeywordIndexType.KEYWORD,
113
+ is_tenant=True,
114
+ ),
96
115
  )
97
116
 
98
117
  await client.close()
99
118
 
100
- async def create_data_points(
101
- self, collection_name: str, data_points: List[DataPoint]
102
- ):
119
+ async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
103
120
  from qdrant_client.http.exceptions import UnexpectedResponse
104
121
 
105
122
  client = self.get_qdrant_client()
@@ -111,14 +128,15 @@ class QDrantAdapter(VectorDBInterface):
111
128
  def convert_to_qdrant_point(data_point: DataPoint):
112
129
  return models.PointStruct(
113
130
  id=str(data_point.id),
114
- payload=data_point.model_dump(),
131
+ payload={**data_point.model_dump(), "database_name": self.database_name},
115
132
  vector={"text": data_vectors[data_points.index(data_point)]},
116
133
  )
117
134
 
118
135
  points = [convert_to_qdrant_point(point) for point in data_points]
119
136
 
120
137
  try:
121
- client.upload_points(collection_name=collection_name, points=points)
138
+ # Use upsert for AsyncQdrantClient (upload_points doesn't exist or is sync)
139
+ await client.upsert(collection_name=collection_name, points=points)
122
140
  except UnexpectedResponse as error:
123
141
  if "Collection not found" in str(error):
124
142
  raise CollectionNotFoundError(
@@ -151,22 +169,18 @@ class QDrantAdapter(VectorDBInterface):
151
169
 
152
170
  async def retrieve(self, collection_name: str, data_point_ids: list[str]):
153
171
  client = self.get_qdrant_client()
154
- results = await client.retrieve(
155
- collection_name, data_point_ids, with_payload=True
156
- )
172
+ results = await client.retrieve(collection_name, data_point_ids, with_payload=True)
157
173
  await client.close()
158
174
  return results
159
175
 
160
176
  async def search(
161
177
  self,
162
178
  collection_name: str,
163
- query_text: Optional[str] = None,
164
- query_vector: Optional[List[float]] = None,
165
- limit: int = 15,
179
+ query_text: str | None = None,
180
+ query_vector: list[float] | None = None,
181
+ limit: int | None = 15,
166
182
  with_vector: bool = False,
167
- ) -> List[ScoredResult]:
168
- from qdrant_client.http.exceptions import UnexpectedResponse
169
-
183
+ ) -> list[ScoredResult]:
170
184
  if query_text is None and query_vector is None:
171
185
  raise MissingQueryParameterError()
172
186
 
@@ -176,47 +190,63 @@ class QDrantAdapter(VectorDBInterface):
176
190
  if query_vector is None:
177
191
  query_vector = (await self.embed_data([query_text]))[0]
178
192
 
193
+ client = None
179
194
  try:
180
195
  client = self.get_qdrant_client()
181
- if limit == 0:
196
+ if limit is None:
182
197
  collection_size = await client.count(collection_name=collection_name)
183
198
  limit = collection_size.count
184
199
  if limit == 0:
200
+ await client.close()
185
201
  return []
186
202
 
187
- results = await client.search(
203
+ # Use query_points instead of search (API change in qdrant-client)
204
+ # query_points is the correct method for AsyncQdrantClient
205
+ query_result = await client.query_points(
188
206
  collection_name=collection_name,
189
- query_vector=models.NamedVector(
190
- name="text",
191
- vector=query_vector
192
- if query_vector is not None
193
- else (await self.embed_data([query_text]))[0],
207
+ query=query_vector,
208
+ query_filter=models.Filter(
209
+ must=[
210
+ models.FieldCondition(
211
+ key="database_name",
212
+ match=models.MatchValue(
213
+ value=self.database_name,
214
+ ),
215
+ )
216
+ ]
194
217
  ),
218
+ using="text",
195
219
  limit=limit,
196
220
  with_vectors=with_vector,
197
221
  )
198
222
 
199
223
  await client.close()
200
224
 
225
+ # Extract points from query_result
226
+ results = query_result.points
227
+
201
228
  return [
202
229
  ScoredResult(
203
- id=parse_id(result.id),
230
+ id=parse_id(str(result.id)),
204
231
  payload={
205
232
  **result.payload,
206
- "id": parse_id(result.id),
233
+ "id": parse_id(str(result.id)),
207
234
  },
208
- score=1 - result.score,
235
+ score=1 - result.score if hasattr(result, "score") else 1.0,
209
236
  )
210
237
  for result in results
211
238
  ]
212
- finally:
213
- await client.close()
239
+ except Exception as e:
240
+ logger.error(f"Error in Qdrant search: {e}", exc_info=True)
241
+ if client:
242
+ await client.close()
243
+ return []
214
244
 
215
245
  async def batch_search(
216
246
  self,
217
247
  collection_name: str,
218
- query_texts: List[str],
219
- limit: int = None,
248
+ query_texts: list[str],
249
+ limit: int | None = None,
220
250
  with_vectors: bool = False,
221
251
  ):
222
252
  """
@@ -226,37 +256,59 @@ class QDrantAdapter(VectorDBInterface):
226
256
  - collection_name (str): Name of the collection to search in.
227
257
  - query_texts (List[str]): List of query texts to search for.
228
258
  - limit (int): List of result limits for search requests.
229
- - with_vectors (bool, optional): Bool indicating whether to return vectors for search requests.
259
+ - with_vectors (bool, optional): Bool indicating whether to return
260
+ vectors for search requests.
230
261
 
231
262
  Returns:
232
263
  - results: The search results from Qdrant.
233
264
  """
234
265
 
235
- vectors = await self.embed_data(query_texts)
236
-
237
- # Generate dynamic search requests based on the provided embeddings
238
- requests = [
239
- models.SearchRequest(
240
- vector=models.NamedVector(name="text", vector=vector),
241
- limit=limit,
242
- with_vector=with_vectors,
243
- )
244
- for vector in vectors
245
- ]
266
+ client = self.get_qdrant_client()
267
+ if limit is None:
268
+ collection_size = await client.count(collection_name=collection_name)
269
+ limit = collection_size.count
270
+ if limit == 0:
271
+ await client.close()
272
+ return []
246
273
 
247
274
  client = self.get_qdrant_client()
248
275
 
249
- # Perform batch search with the dynamically generated requests
250
- results = await client.search_batch(
251
- collection_name=collection_name, requests=requests
252
- )
276
+ try:
277
+ # Use query_batch instead of search_batch (API change in qdrant-client)
278
+ # query_batch is the correct method for AsyncQdrantClient
279
+ query_results = await client.query_batch(
280
+ collection_name=collection_name,
281
+ query_texts=query_texts,
282
+ query_filter=models.Filter(
283
+ must=[
284
+ models.FieldCondition(
285
+ key="database_name",
286
+ match=models.MatchValue(
287
+ value=self.database_name,
288
+ ),
289
+ )
290
+ ]
291
+ ),
292
+ limit=limit,
293
+ with_vectors=with_vectors,
294
+ )
253
295
 
254
- await client.close()
296
+ await client.close()
255
297
 
256
- return [
257
- filter(lambda result: result.score > 0.9, result_group)
258
- for result_group in results
259
- ]
298
+ # Extract points from each query result and filter by score
299
+ filtered_results = []
300
+ for query_result in query_results:
301
+ points = query_result.points if hasattr(query_result, "points") else []
302
+ filtered_points = [
303
+ result for result in points if hasattr(result, "score") and result.score > 0.9
304
+ ]
305
+ filtered_results.append(filtered_points)
306
+
307
+ return filtered_results
308
+ except Exception as e:
309
+ logger.error(f"Error in Qdrant batch_search: {e}", exc_info=True)
310
+ await client.close()
311
+ return []
260
312
 
261
313
  async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
262
314
  client = self.get_qdrant_client()
@@ -269,6 +321,55 @@ class QDrantAdapter(VectorDBInterface):
269
321
  response = await client.get_collections()
270
322
 
271
323
  for collection in response.collections:
272
- await client.delete_collection(collection.name)
324
+ await client.delete(
325
+ collection.name,
326
+ points_selector=models.FilterSelector(
327
+ filter=models.Filter(
328
+ must=[
329
+ models.FieldCondition(
330
+ key="database_name",
331
+ match=models.MatchValue(value=self.database_name),
332
+ )
333
+ ]
334
+ )
335
+ ),
336
+ )
337
+ remaining_points = await client.count(collection_name=collection.name)
338
+ if remaining_points.count == 0:
339
+ await client.delete_collection(collection_name=collection.name)
273
340
 
274
341
  await client.close()
342
+
343
+ async def get_collection_names(self) -> list[str]:
344
+ """
345
+ Get names of all collections in the database.
346
+
347
+ Returns:
348
+ list[str]: List of collection names.
349
+ """
350
+
351
+ client = self.get_qdrant_client()
352
+
353
+ response = await client.get_collections()
354
+
355
+ # We do this filtering because one user could see another user's collections otherwise
356
+ result = []
357
+ for collection in response.collections:
358
+ relevant_count = await client.count(
359
+ collection_name=collection.name,
360
+ count_filter=models.Filter(
361
+ must=[
362
+ models.FieldCondition(
363
+ key="database_name", match=models.MatchValue(value=self.database_name)
364
+ )
365
+ ]
366
+ ),
367
+ exact=True,
368
+ )
369
+
370
+ if relevant_count.count > 0:
371
+ result.append(collection.name)
372
+
373
+ await client.close()
374
+
375
+ return result
@@ -0,0 +1,8 @@
1
+ from cognee.infrastructure.databases.dataset_database_handler import use_dataset_database_handler
2
+ from cognee.infrastructure.databases.vector import use_vector_adapter
3
+
4
+ from .qdrant_adapter import QDrantAdapter
5
+ from .QdrantDatasetDatabaseHandler import QdrantDatasetDatabaseHandler
6
+
7
+ use_vector_adapter("qdrant", QDrantAdapter)
8
+ use_dataset_database_handler("qdrant", QdrantDatasetDatabaseHandler, "qdrant")
@@ -1,10 +1,12 @@
1
1
  [project]
2
2
  name = "cognee-community-vector-adapter-qdrant"
3
- version = "0.0.3"
3
+ version = "0.2.0"
4
4
  description = "Qdrant vector database adapter for cognee"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.11,<=3.13"
7
7
  dependencies = [
8
- "qdrant-client>=1.14.2",
9
- "cognee>=0.2.4",
8
+ "qdrant-client>=1.16.0",
9
+ "cognee==0.5.1",
10
+ "starlette>=0.48.0",
11
+ "instructor>=1.11"
10
12
  ]
@@ -1,5 +0,0 @@
1
- from cognee.infrastructure.databases.vector import use_vector_adapter
2
-
3
- from .qdrant_adapter import QDrantAdapter
4
-
5
- use_vector_adapter("qdrant", QDrantAdapter)