qdrant-haystack 4.1.0__tar.gz → 4.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of qdrant-haystack might be problematic. Click here for more details.
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/CHANGELOG.md +17 -0
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/PKG-INFO +2 -2
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/pyproject.toml +1 -1
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/src/haystack_integrations/components/retrievers/qdrant/retriever.py +47 -5
- qdrant_haystack-4.1.2/src/haystack_integrations/document_stores/qdrant/filters.py +316 -0
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/tests/test_filters.py +106 -0
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/tests/test_retriever.py +105 -0
- qdrant_haystack-4.1.0/src/haystack_integrations/document_stores/qdrant/filters.py +0 -238
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/.gitignore +0 -0
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/LICENSE.txt +0 -0
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/README.md +0 -0
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/examples/embedding_retrieval.py +0 -0
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/pydoc/config.yml +0 -0
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/src/haystack_integrations/components/retrievers/qdrant/__init__.py +0 -0
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/src/haystack_integrations/document_stores/qdrant/__init__.py +0 -0
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/src/haystack_integrations/document_stores/qdrant/converters.py +0 -0
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/src/haystack_integrations/document_stores/qdrant/document_store.py +0 -0
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/src/haystack_integrations/document_stores/qdrant/migrate_to_sparse.py +0 -0
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/tests/__init__.py +0 -0
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/tests/conftest.py +0 -0
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/tests/test_converters.py +0 -0
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/tests/test_dict_converters.py +0 -0
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/tests/test_document_store.py +0 -0
- {qdrant_haystack-4.1.0 → qdrant_haystack-4.1.2}/tests/test_legacy_filters.py +0 -0
|
@@ -1,5 +1,22 @@
|
|
|
1
1
|
# Changelog
|
|
2
2
|
|
|
3
|
+
## [integrations/qdrant-v4.1.1] - 2024-07-10
|
|
4
|
+
|
|
5
|
+
### 🚀 Features
|
|
6
|
+
|
|
7
|
+
- Add filter_policy to qdrant integration (#819)
|
|
8
|
+
|
|
9
|
+
### 🐛 Bug Fixes
|
|
10
|
+
|
|
11
|
+
- Errors in convert_filters_to_qdrant (#870)
|
|
12
|
+
|
|
13
|
+
## [integrations/qdrant-v4.1.0] - 2024-07-03
|
|
14
|
+
|
|
15
|
+
### 🚀 Features
|
|
16
|
+
|
|
17
|
+
- Add `score_threshold` to Qdrant Retrievers (#860)
|
|
18
|
+
- Qdrant - add support for BM42 (#864)
|
|
19
|
+
|
|
3
20
|
## [integrations/qdrant-v4.0.0] - 2024-07-02
|
|
4
21
|
|
|
5
22
|
### 🚜 Refactor
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: qdrant-haystack
|
|
3
|
-
Version: 4.1.0
|
|
3
|
+
Version: 4.1.2
|
|
4
4
|
Summary: An integration of Qdrant ANN vector database backend with Haystack
|
|
5
5
|
Project-URL: Source, https://github.com/deepset-ai/haystack-core-integrations
|
|
6
6
|
Project-URL: Documentation, https://github.com/deepset-ai/haystack-core-integrations/blob/main/integrations/qdrant/README.md
|
|
@@ -18,7 +18,7 @@ Classifier: Programming Language :: Python :: 3.11
|
|
|
18
18
|
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
19
19
|
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
|
20
20
|
Requires-Python: >=3.8
|
|
21
|
-
Requires-Dist: haystack-ai
|
|
21
|
+
Requires-Dist: haystack-ai
|
|
22
22
|
Requires-Dist: qdrant-client>=1.10.0
|
|
23
23
|
Description-Content-Type: text/markdown
|
|
24
24
|
|
|
@@ -25,7 +25,7 @@ classifiers = [
|
|
|
25
25
|
"Programming Language :: Python :: Implementation :: CPython",
|
|
26
26
|
"Programming Language :: Python :: Implementation :: PyPy",
|
|
27
27
|
]
|
|
28
|
-
dependencies = ["haystack-ai
|
|
28
|
+
dependencies = ["haystack-ai", "qdrant-client>=1.10.0"]
|
|
29
29
|
|
|
30
30
|
[project.urls]
|
|
31
31
|
Source = "https://github.com/deepset-ai/haystack-core-integrations"
|
|
@@ -2,6 +2,8 @@ from typing import Any, Dict, List, Optional, Union
|
|
|
2
2
|
|
|
3
3
|
from haystack import Document, component, default_from_dict, default_to_dict
|
|
4
4
|
from haystack.dataclasses.sparse_embedding import SparseEmbedding
|
|
5
|
+
from haystack.document_stores.types import FilterPolicy
|
|
6
|
+
from haystack.document_stores.types.filter_policy import apply_filter_policy
|
|
5
7
|
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore
|
|
6
8
|
from qdrant_client.http import models
|
|
7
9
|
|
|
@@ -39,6 +41,7 @@ class QdrantEmbeddingRetriever:
|
|
|
39
41
|
top_k: int = 10,
|
|
40
42
|
scale_score: bool = False,
|
|
41
43
|
return_embedding: bool = False,
|
|
44
|
+
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
|
|
42
45
|
score_threshold: Optional[float] = None,
|
|
43
46
|
):
|
|
44
47
|
"""
|
|
@@ -49,6 +52,7 @@ class QdrantEmbeddingRetriever:
|
|
|
49
52
|
:param top_k: The maximum number of documents to retrieve.
|
|
50
53
|
:param scale_score: Whether to scale the scores of the retrieved documents or not.
|
|
51
54
|
:param return_embedding: Whether to return the embedding of the retrieved Documents.
|
|
55
|
+
:param filter_policy: Policy to determine how filters are applied.
|
|
52
56
|
:param score_threshold: A minimal score threshold for the result.
|
|
53
57
|
Score of the returned result might be higher or smaller than the threshold
|
|
54
58
|
depending on the `similarity` function specified in the Document Store.
|
|
@@ -66,6 +70,9 @@ class QdrantEmbeddingRetriever:
|
|
|
66
70
|
self._top_k = top_k
|
|
67
71
|
self._scale_score = scale_score
|
|
68
72
|
self._return_embedding = return_embedding
|
|
73
|
+
self._filter_policy = (
|
|
74
|
+
filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
|
|
75
|
+
)
|
|
69
76
|
self._score_threshold = score_threshold
|
|
70
77
|
|
|
71
78
|
def to_dict(self) -> Dict[str, Any]:
|
|
@@ -80,6 +87,7 @@ class QdrantEmbeddingRetriever:
|
|
|
80
87
|
document_store=self._document_store,
|
|
81
88
|
filters=self._filters,
|
|
82
89
|
top_k=self._top_k,
|
|
90
|
+
filter_policy=self._filter_policy.value,
|
|
83
91
|
scale_score=self._scale_score,
|
|
84
92
|
return_embedding=self._return_embedding,
|
|
85
93
|
score_threshold=self._score_threshold,
|
|
@@ -100,6 +108,10 @@ class QdrantEmbeddingRetriever:
|
|
|
100
108
|
"""
|
|
101
109
|
document_store = QdrantDocumentStore.from_dict(data["init_parameters"]["document_store"])
|
|
102
110
|
data["init_parameters"]["document_store"] = document_store
|
|
111
|
+
# Pipelines serialized with old versions of the component might not
|
|
112
|
+
# have the filter_policy field.
|
|
113
|
+
if filter_policy := data["init_parameters"].get("filter_policy"):
|
|
114
|
+
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)
|
|
103
115
|
return default_from_dict(cls, data)
|
|
104
116
|
|
|
105
117
|
@component.output_types(documents=List[Document])
|
|
@@ -125,9 +137,11 @@ class QdrantEmbeddingRetriever:
|
|
|
125
137
|
The retrieved documents.
|
|
126
138
|
|
|
127
139
|
"""
|
|
140
|
+
filters = apply_filter_policy(self._filter_policy, self._filters, filters)
|
|
141
|
+
|
|
128
142
|
docs = self._document_store._query_by_embedding(
|
|
129
143
|
query_embedding=query_embedding,
|
|
130
|
-
filters=filters
|
|
144
|
+
filters=filters,
|
|
131
145
|
top_k=top_k or self._top_k,
|
|
132
146
|
scale_score=scale_score or self._scale_score,
|
|
133
147
|
return_embedding=return_embedding or self._return_embedding,
|
|
@@ -171,6 +185,7 @@ class QdrantSparseEmbeddingRetriever:
|
|
|
171
185
|
top_k: int = 10,
|
|
172
186
|
scale_score: bool = False,
|
|
173
187
|
return_embedding: bool = False,
|
|
188
|
+
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
|
|
174
189
|
score_threshold: Optional[float] = None,
|
|
175
190
|
):
|
|
176
191
|
"""
|
|
@@ -181,6 +196,7 @@ class QdrantSparseEmbeddingRetriever:
|
|
|
181
196
|
:param top_k: The maximum number of documents to retrieve.
|
|
182
197
|
:param scale_score: Whether to scale the scores of the retrieved documents or not.
|
|
183
198
|
:param return_embedding: Whether to return the sparse embedding of the retrieved Documents.
|
|
199
|
+
:param filter_policy: Policy to determine how filters are applied. Defaults to "replace".
|
|
184
200
|
:param score_threshold: A minimal score threshold for the result.
|
|
185
201
|
Score of the returned result might be higher or smaller than the threshold
|
|
186
202
|
depending on the Distance function used.
|
|
@@ -198,6 +214,9 @@ class QdrantSparseEmbeddingRetriever:
|
|
|
198
214
|
self._top_k = top_k
|
|
199
215
|
self._scale_score = scale_score
|
|
200
216
|
self._return_embedding = return_embedding
|
|
217
|
+
self._filter_policy = (
|
|
218
|
+
filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
|
|
219
|
+
)
|
|
201
220
|
self._score_threshold = score_threshold
|
|
202
221
|
|
|
203
222
|
def to_dict(self) -> Dict[str, Any]:
|
|
@@ -213,6 +232,7 @@ class QdrantSparseEmbeddingRetriever:
|
|
|
213
232
|
filters=self._filters,
|
|
214
233
|
top_k=self._top_k,
|
|
215
234
|
scale_score=self._scale_score,
|
|
235
|
+
filter_policy=self._filter_policy.value,
|
|
216
236
|
return_embedding=self._return_embedding,
|
|
217
237
|
score_threshold=self._score_threshold,
|
|
218
238
|
)
|
|
@@ -232,6 +252,10 @@ class QdrantSparseEmbeddingRetriever:
|
|
|
232
252
|
"""
|
|
233
253
|
document_store = QdrantDocumentStore.from_dict(data["init_parameters"]["document_store"])
|
|
234
254
|
data["init_parameters"]["document_store"] = document_store
|
|
255
|
+
# Pipelines serialized with old versions of the component might not
|
|
256
|
+
# have the filter_policy field.
|
|
257
|
+
if filter_policy := data["init_parameters"].get("filter_policy"):
|
|
258
|
+
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)
|
|
235
259
|
return default_from_dict(cls, data)
|
|
236
260
|
|
|
237
261
|
@component.output_types(documents=List[Document])
|
|
@@ -248,7 +272,9 @@ class QdrantSparseEmbeddingRetriever:
|
|
|
248
272
|
Run the Sparse Embedding Retriever on the given input data.
|
|
249
273
|
|
|
250
274
|
:param query_sparse_embedding: Sparse Embedding of the query.
|
|
251
|
-
:param filters:
|
|
275
|
+
:param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
|
|
276
|
+
the `filter_policy` chosen at retriever initialization. See init method docstring for more
|
|
277
|
+
details.
|
|
252
278
|
:param top_k: The maximum number of documents to return.
|
|
253
279
|
:param scale_score: Whether to scale the scores of the retrieved documents or not.
|
|
254
280
|
:param return_embedding: Whether to return the embedding of the retrieved Documents.
|
|
@@ -260,9 +286,11 @@ class QdrantSparseEmbeddingRetriever:
|
|
|
260
286
|
The retrieved documents.
|
|
261
287
|
|
|
262
288
|
"""
|
|
289
|
+
filters = apply_filter_policy(self._filter_policy, self._filters, filters)
|
|
290
|
+
|
|
263
291
|
docs = self._document_store._query_by_sparse(
|
|
264
292
|
query_sparse_embedding=query_sparse_embedding,
|
|
265
|
-
filters=filters
|
|
293
|
+
filters=filters,
|
|
266
294
|
top_k=top_k or self._top_k,
|
|
267
295
|
scale_score=scale_score or self._scale_score,
|
|
268
296
|
return_embedding=return_embedding or self._return_embedding,
|
|
@@ -311,6 +339,7 @@ class QdrantHybridRetriever:
|
|
|
311
339
|
filters: Optional[Union[Dict[str, Any], models.Filter]] = None,
|
|
312
340
|
top_k: int = 10,
|
|
313
341
|
return_embedding: bool = False,
|
|
342
|
+
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
|
|
314
343
|
score_threshold: Optional[float] = None,
|
|
315
344
|
):
|
|
316
345
|
"""
|
|
@@ -320,6 +349,7 @@ class QdrantHybridRetriever:
|
|
|
320
349
|
:param filters: A dictionary with filters to narrow down the search space.
|
|
321
350
|
:param top_k: The maximum number of documents to retrieve.
|
|
322
351
|
:param return_embedding: Whether to return the embeddings of the retrieved Documents.
|
|
352
|
+
:param filter_policy: Policy to determine how filters are applied.
|
|
323
353
|
:param score_threshold: A minimal score threshold for the result.
|
|
324
354
|
Score of the returned result might be higher or smaller than the threshold
|
|
325
355
|
depending on the Distance function used.
|
|
@@ -336,6 +366,9 @@ class QdrantHybridRetriever:
|
|
|
336
366
|
self._filters = filters
|
|
337
367
|
self._top_k = top_k
|
|
338
368
|
self._return_embedding = return_embedding
|
|
369
|
+
self._filter_policy = (
|
|
370
|
+
filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
|
|
371
|
+
)
|
|
339
372
|
self._score_threshold = score_threshold
|
|
340
373
|
|
|
341
374
|
def to_dict(self) -> Dict[str, Any]:
|
|
@@ -350,6 +383,7 @@ class QdrantHybridRetriever:
|
|
|
350
383
|
document_store=self._document_store.to_dict(),
|
|
351
384
|
filters=self._filters,
|
|
352
385
|
top_k=self._top_k,
|
|
386
|
+
filter_policy=self._filter_policy.value,
|
|
353
387
|
return_embedding=self._return_embedding,
|
|
354
388
|
score_threshold=self._score_threshold,
|
|
355
389
|
)
|
|
@@ -366,6 +400,10 @@ class QdrantHybridRetriever:
|
|
|
366
400
|
"""
|
|
367
401
|
document_store = QdrantDocumentStore.from_dict(data["init_parameters"]["document_store"])
|
|
368
402
|
data["init_parameters"]["document_store"] = document_store
|
|
403
|
+
# Pipelines serialized with old versions of the component might not
|
|
404
|
+
# have the filter_policy field.
|
|
405
|
+
if filter_policy := data["init_parameters"].get("filter_policy"):
|
|
406
|
+
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)
|
|
369
407
|
return default_from_dict(cls, data)
|
|
370
408
|
|
|
371
409
|
@component.output_types(documents=List[Document])
|
|
@@ -383,7 +421,9 @@ class QdrantHybridRetriever:
|
|
|
383
421
|
|
|
384
422
|
:param query_embedding: Dense embedding of the query.
|
|
385
423
|
:param query_sparse_embedding: Sparse embedding of the query.
|
|
386
|
-
:param filters:
|
|
424
|
+
:param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
|
|
425
|
+
the `filter_policy` chosen at retriever initialization. See init method docstring for more
|
|
426
|
+
details.
|
|
387
427
|
:param top_k: The maximum number of documents to return.
|
|
388
428
|
:param return_embedding: Whether to return the embedding of the retrieved Documents.
|
|
389
429
|
:param score_threshold: A minimal score threshold for the result.
|
|
@@ -394,10 +434,12 @@ class QdrantHybridRetriever:
|
|
|
394
434
|
The retrieved documents.
|
|
395
435
|
|
|
396
436
|
"""
|
|
437
|
+
filters = apply_filter_policy(self._filter_policy, self._filters, filters)
|
|
438
|
+
|
|
397
439
|
docs = self._document_store._query_hybrid(
|
|
398
440
|
query_embedding=query_embedding,
|
|
399
441
|
query_sparse_embedding=query_sparse_embedding,
|
|
400
|
-
filters=filters
|
|
442
|
+
filters=filters,
|
|
401
443
|
top_k=top_k or self._top_k,
|
|
402
444
|
return_embedding=return_embedding or self._return_embedding,
|
|
403
445
|
score_threshold=score_threshold or self._score_threshold,
|
|
@@ -0,0 +1,316 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import List, Optional, Union
|
|
3
|
+
|
|
4
|
+
from haystack.utils.filters import COMPARISON_OPERATORS, LOGICAL_OPERATORS, FilterError
|
|
5
|
+
from qdrant_client.http import models
|
|
6
|
+
|
|
7
|
+
# Rebind the imported Haystack operator tables to just their keys (the operator
# names); this module only uses them for `operator in ...` membership checks.
COMPARISON_OPERATORS = COMPARISON_OPERATORS.keys()
LOGICAL_OPERATORS = LOGICAL_OPERATORS.keys()
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def convert_filters_to_qdrant(
    filter_term: Optional[Union[List[dict], dict, models.Filter]] = None, is_parent_call: bool = True
) -> Optional[Union[models.Filter, List[models.Filter], List[models.Condition]]]:
    """Converts Haystack filters to the format used by Qdrant.

    A `models.Filter` passed in is returned unchanged, and an empty/`None` filter yields `None`.
    A single dict is treated as a one-element list of filter items. Logical operators
    ("AND", "OR", "NOT") are converted by recursing into their "conditions"; comparison
    operators are delegated to `_parse_comparison_operation`.

    :param filter_term: the haystack filter to be converted to qdrant.
    :param is_parent_call: indicates if this is the top-level call to the function. If True, the function returns
        a single models.Filter object; if False, it may return a list of filters or conditions for further processing.

    :returns: a single Qdrant Filter in the parent call or a list of such Filters (and plain conditions)
        in recursive calls. `None` if `filter_term` is empty.

    :raises FilterError: If an item has no "operator", a logical operator has no "conditions",
        a comparison item has no "field" or its "value" is missing/`None`, or the operator is unknown.

    """

    # Already a native Qdrant filter: nothing to convert.
    if isinstance(filter_term, models.Filter):
        return filter_term
    if not filter_term:
        return None

    must_clauses: List[models.Filter] = []
    should_clauses: List[models.Filter] = []
    must_not_clauses: List[models.Filter] = []
    # Indicates if there are multiple same LOGICAL OPERATORS on each level
    # and prevents them from being combined
    # NOTE(review): the flag is reassigned for every item in the loop below, so the check
    # after the loop only sees the value computed for the LAST item - confirm this is intended.
    same_operator_flag = False
    # conditions: plain (non-Filter) conditions found on this level
    # qdrant_filter: Filter objects produced on this level (the recursive-call result)
    # current_level_operators: operators already seen on this level
    conditions, qdrant_filter, current_level_operators = (
        [],
        [],
        [],
    )

    # Normalize a single filter item to a list so the loop below handles both shapes.
    if isinstance(filter_term, dict):
        filter_term = [filter_term]

    # ======== IDENTIFY FILTER ITEMS ON EACH LEVEL ========

    for item in filter_term:
        operator = item.get("operator")

        # Check for repeated similar operators on each level
        same_operator_flag = operator in current_level_operators and operator in LOGICAL_OPERATORS
        if not same_operator_flag:
            current_level_operators.append(operator)

        if operator is None:
            msg = "Operator not found in filters"
            raise FilterError(msg)

        if operator in LOGICAL_OPERATORS and "conditions" not in item:
            msg = f"'conditions' not found for '{operator}'"
            raise FilterError(msg)

        if operator in LOGICAL_OPERATORS:
            # Recursively process nested conditions
            current_filter = convert_filters_to_qdrant(item.get("conditions", []), is_parent_call=False) or []

            # When same_operator_flag is set to True,
            # ensure each clause is appended as an independent list to avoid merging distinct clauses.
            if operator == "AND":
                must_clauses = [must_clauses, current_filter] if same_operator_flag else must_clauses + current_filter
            elif operator == "OR":
                should_clauses = (
                    [should_clauses, current_filter] if same_operator_flag else should_clauses + current_filter
                )
            elif operator == "NOT":
                must_not_clauses = (
                    [must_not_clauses, current_filter] if same_operator_flag else must_not_clauses + current_filter
                )

        elif operator in COMPARISON_OPERATORS:
            field = item.get("field")
            value = item.get("value")
            # A value of None is rejected here, the same as a missing "value" key.
            if field is None or value is None:
                msg = f"'field' or 'value' not found for '{operator}'"
                raise FilterError(msg)

            parsed_conditions = _parse_comparison_operation(comparison_operation=operator, key=field, value=value)

            # check if the parsed_conditions are models.Filter or models.Condition
            for condition in parsed_conditions:
                if isinstance(condition, models.Filter):
                    qdrant_filter.append(condition)
                else:
                    conditions.append(condition)

        else:
            msg = f"Unknown operator {operator} used in filters"
            raise FilterError(msg)

    # ======== PROCESS FILTER ITEMS ON EACH LEVEL ========

    # If same logical operators have separate clauses, create separate filters
    if same_operator_flag:
        qdrant_filter = build_filters_for_repeated_operators(
            must_clauses, should_clauses, must_not_clauses, qdrant_filter
        )

    # else append a single Filter for existing clauses
    elif must_clauses or should_clauses or must_not_clauses:
        qdrant_filter.append(
            models.Filter(
                must=must_clauses or None,
                should=should_clauses or None,
                must_not=must_not_clauses or None,
            )
        )

    # In case of parent call, a single Filter is returned
    if is_parent_call:
        # If qdrant_filter has just a single Filter in parent call,
        # then it might be returned instead.
        if len(qdrant_filter) == 1 and isinstance(qdrant_filter[0], models.Filter):
            return qdrant_filter[0]
        else:
            # Otherwise the plain conditions of this level are added to the must clauses
            # and everything is wrapped into one Filter.
            must_clauses.extend(conditions)
            return models.Filter(
                must=must_clauses or None,
                should=should_clauses or None,
                must_not=must_not_clauses or None,
            )

    # Store conditions of each level in output of the loop
    elif conditions:
        qdrant_filter.extend(conditions)

    return qdrant_filter
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def build_filters_for_repeated_operators(
    must_clauses,
    should_clauses,
    must_not_clauses,
    qdrant_filter,
) -> List[models.Filter]:
    """
    Flattens the nested lists of clauses by creating separate Filters for each clause of a logical operator.

    For each clause list that contains at least one nested list, one Filter is created per
    element of that list; the element fills the matching slot (must / should / must_not)
    and the other two clause lists are passed through as they are. Empty values are passed
    to Qdrant as `None`.

    :param must_clauses: a nested list of must clauses or an empty list.
    :param should_clauses: a nested list of should clauses or an empty list.
    :param must_not_clauses: a nested list of must_not clauses or an empty list.
    :param qdrant_filter: a list where the generated Filter objects will be appended.
        This list will be modified in-place.

    :returns: the modified `qdrant_filter` list with appended generated Filter objects.
    """

    if any(isinstance(i, list) for i in must_clauses):
        for i in must_clauses:
            qdrant_filter.append(
                models.Filter(
                    must=i or None,
                    should=should_clauses or None,
                    must_not=must_not_clauses or None,
                )
            )
    if any(isinstance(i, list) for i in should_clauses):
        for i in should_clauses:
            qdrant_filter.append(
                models.Filter(
                    must=must_clauses or None,
                    should=i or None,
                    must_not=must_not_clauses or None,
                )
            )
    if any(isinstance(i, list) for i in must_not_clauses):
        # Iterate the must_not groups themselves (the branch previously looped over
        # must_clauses, so repeated NOT groups were dropped or paired with must groups).
        for i in must_not_clauses:
            qdrant_filter.append(
                models.Filter(
                    must=must_clauses or None,
                    should=should_clauses or None,
                    must_not=i or None,
                )
            )

    return qdrant_filter
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _parse_comparison_operation(
    comparison_operation: str, key: str, value: Union[dict, List, str, float]
) -> List[models.Condition]:
    """
    Builds the Qdrant condition for a single comparison operator.

    :param comparison_operation: the comparison operator, e.g. "==" or "not in".
    :param key: the payload field the comparison applies to.
    :param value: the value to compare the field against.

    :returns: a one-element list holding the condition produced by the matching builder.

    :raises ValueError: if `comparison_operation` has no registered builder.
    """
    builders = {
        "==": _build_eq_condition,
        "in": _build_in_condition,
        "!=": _build_ne_condition,
        "not in": _build_nin_condition,
        ">": _build_gt_condition,
        ">=": _build_gte_condition,
        "<": _build_lt_condition,
        "<=": _build_lte_condition,
    }

    if comparison_operation not in builders:
        msg = f"Unknown operator {comparison_operation} used in filters"
        raise ValueError(msg)

    build = builders[comparison_operation]
    return [build(key, value)]
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def _build_eq_condition(key: str, value: models.ValueVariants) -> models.Condition:
    """
    Builds an equality condition on `key`.

    Strings containing a space are matched with `MatchText`; every other value
    is matched with `MatchValue`.

    :param key: the payload field to match.
    :param value: the value the field must equal.

    :returns: a `FieldCondition` for the field.
    """
    is_multi_word = isinstance(value, str) and " " in value
    matcher = models.MatchText(text=value) if is_multi_word else models.MatchValue(value=value)
    return models.FieldCondition(key=key, match=matcher)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def _build_in_condition(key: str, value: List[models.ValueVariants]) -> models.Condition:
    """
    Builds an "in" condition: the field must match at least one of the given values.

    Strings WITHOUT a space are matched with `MatchText`; every other item is
    matched with `MatchValue`.
    NOTE(review): `_build_eq_condition` uses `MatchText` for strings that DO contain a
    space - confirm the opposite rule here is intended.

    :param key: the payload field to match.
    :param value: the list of accepted values.

    :returns: a `Filter` whose `should` clauses hold one `FieldCondition` per item.

    :raises FilterError: if `value` is not a list.
    """
    if not isinstance(value, list):
        msg = f"Value {value} is not a list"
        raise FilterError(msg)

    alternatives = []
    for candidate in value:
        if isinstance(candidate, str) and " " not in candidate:
            matcher = models.MatchText(text=candidate)
        else:
            matcher = models.MatchValue(value=candidate)
        alternatives.append(models.FieldCondition(key=key, match=matcher))

    return models.Filter(should=alternatives)
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def _build_ne_condition(key: str, value: models.ValueVariants) -> models.Condition:
    """
    Builds a "not equal" condition on `key`.

    A string WITHOUT a space is matched with `MatchText`; any other value is
    matched with `MatchValue`.
    NOTE(review): `_build_eq_condition` uses `MatchText` for strings that DO contain a
    space - confirm the opposite rule here is intended.

    :param key: the payload field to match.
    :param value: the value the field must not equal.

    :returns: a `Filter` with the match as its single `must_not` clause.
    """
    if isinstance(value, str) and " " not in value:
        matcher = models.MatchText(text=value)
    else:
        matcher = models.MatchValue(value=value)

    excluded = models.FieldCondition(key=key, match=matcher)
    return models.Filter(must_not=[excluded])
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def _build_nin_condition(key: str, value: List[models.ValueVariants]) -> models.Condition:
    """
    Builds a "not in" condition: the field must match none of the given values.

    Strings containing a space are matched with `MatchText`; every other item is
    matched with `MatchValue`.

    :param key: the payload field to match.
    :param value: the list of rejected values.

    :returns: a `Filter` whose `must_not` clauses hold one `FieldCondition` per item.

    :raises FilterError: if `value` is not a list.
    """
    if not isinstance(value, list):
        msg = f"Value {value} is not a list"
        raise FilterError(msg)

    excluded = []
    for candidate in value:
        if isinstance(candidate, str) and " " in candidate:
            matcher = models.MatchText(text=candidate)
        else:
            matcher = models.MatchValue(value=candidate)
        excluded.append(models.FieldCondition(key=key, match=matcher))

    return models.Filter(must_not=excluded)
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def _build_lt_condition(key: str, value: Union[str, float, int]) -> models.Condition:
    """
    Builds a "less than" range condition on `key`.

    :param key: the payload field to compare.
    :param value: an int/float bound, or a string accepted by `is_datetime_string`.

    :returns: a `FieldCondition` using `Range` for numbers or `DatetimeRange` for datetime strings.

    :raises FilterError: if `value` is neither a number nor a datetime string.
    """
    if isinstance(value, (int, float)):
        bound = models.Range(lt=value)
    elif isinstance(value, str) and is_datetime_string(value):
        bound = models.DatetimeRange(lt=value)
    else:
        msg = f"Value {value} is not an int or float or datetime string"
        raise FilterError(msg)

    return models.FieldCondition(key=key, range=bound)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def _build_lte_condition(key: str, value: Union[str, float, int]) -> models.Condition:
    """
    Builds a "less than or equal" range condition on `key`.

    :param key: the payload field to compare.
    :param value: an int/float bound, or a string accepted by `is_datetime_string`.

    :returns: a `FieldCondition` using `Range` for numbers or `DatetimeRange` for datetime strings.

    :raises FilterError: if `value` is neither a number nor a datetime string.
    """
    if isinstance(value, (int, float)):
        bound = models.Range(lte=value)
    elif isinstance(value, str) and is_datetime_string(value):
        bound = models.DatetimeRange(lte=value)
    else:
        msg = f"Value {value} is not an int or float or datetime string"
        raise FilterError(msg)

    return models.FieldCondition(key=key, range=bound)
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def _build_gt_condition(key: str, value: Union[str, float, int]) -> models.Condition:
    """
    Builds a "greater than" range condition on `key`.

    :param key: the payload field to compare.
    :param value: an int/float bound, or a string accepted by `is_datetime_string`.

    :returns: a `FieldCondition` using `Range` for numbers or `DatetimeRange` for datetime strings.

    :raises FilterError: if `value` is neither a number nor a datetime string.
    """
    if isinstance(value, (int, float)):
        bound = models.Range(gt=value)
    elif isinstance(value, str) and is_datetime_string(value):
        bound = models.DatetimeRange(gt=value)
    else:
        msg = f"Value {value} is not an int or float or datetime string"
        raise FilterError(msg)

    return models.FieldCondition(key=key, range=bound)
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def _build_gte_condition(key: str, value: Union[str, float, int]) -> models.Condition:
    """
    Builds a "greater than or equal" range condition on `key`.

    :param key: the payload field to compare.
    :param value: an int/float bound, or a string accepted by `is_datetime_string`.

    :returns: a `FieldCondition` using `Range` for numbers or `DatetimeRange` for datetime strings.

    :raises FilterError: if `value` is neither a number nor a datetime string.
    """
    if isinstance(value, (int, float)):
        bound = models.Range(gte=value)
    elif isinstance(value, str) and is_datetime_string(value):
        bound = models.DatetimeRange(gte=value)
    else:
        msg = f"Value {value} is not an int or float or datetime string"
        raise FilterError(msg)

    return models.FieldCondition(key=key, range=bound)
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def is_datetime_string(value: str) -> bool:
    """
    Tells whether `value` can be parsed by `datetime.fromisoformat`.

    :param value: the string to check.

    :returns: True if parsing succeeds, False if it raises `ValueError`.
    """
    try:
        datetime.fromisoformat(value)
    except ValueError:
        return False
    else:
        return True
|
|
@@ -61,6 +61,112 @@ class TestQdrantStoreBaseTests(FilterDocumentsTest):
|
|
|
61
61
|
[d for d in filterable_docs if (d.meta.get("number") != 100 and d.meta.get("name") != "name_0")],
|
|
62
62
|
)
|
|
63
63
|
|
|
64
|
+
def test_filter_criteria(self, document_store):
    """An AND of an "in" filter and a nested OR of two ">=" filters returns exactly the matching documents."""
    documents = [
        Document(
            content="This is test document 1.",
            meta={"file_name": "file1", "classification": {"details": {"category1": 0.9, "category2": 0.3}}},
        ),
        Document(
            content="This is test document 2.",
            meta={"file_name": "file2", "classification": {"details": {"category1": 0.1, "category2": 0.7}}},
        ),
        Document(
            content="This is test document 3.",
            meta={"file_name": "file3", "classification": {"details": {"category1": 0.7, "category2": 0.9}}},
        ),
    ]
    document_store.write_documents(documents)

    # Build the nested filter from its two named parts.
    file_name_condition = {"field": "meta.file_name", "operator": "in", "value": ["file1", "file2"]}
    category_condition = {
        "operator": "OR",
        "conditions": [
            {"field": "meta.classification.details.category1", "operator": ">=", "value": 0.85},
            {"field": "meta.classification.details.category2", "operator": ">=", "value": 0.85},
        ],
    }
    filter_criteria = {"operator": "AND", "conditions": [file_name_condition, category_condition]}

    result = document_store.filter_documents(filter_criteria)

    # Recompute the expected documents in plain Python with the same predicate.
    expected = []
    for doc in documents:
        if doc.meta.get("file_name") not in ["file1", "file2"]:
            continue
        details = doc.meta.get("classification").get("details")
        if details.get("category1") >= 0.85 or details.get("category2") >= 0.85:
            expected.append(doc)

    self.assert_documents_are_equal(result, expected)
|
|
107
|
+
|
|
108
|
+
def test_complex_filter_criteria(self, document_store):
|
|
109
|
+
documents = [
|
|
110
|
+
Document(
|
|
111
|
+
content="This is test document 1.",
|
|
112
|
+
meta={
|
|
113
|
+
"file_name": "file1",
|
|
114
|
+
"classification": {"details": {"category1": 0.45, "category2": 0.5, "category3": 0.2}},
|
|
115
|
+
},
|
|
116
|
+
),
|
|
117
|
+
Document(
|
|
118
|
+
content="This is test document 2.",
|
|
119
|
+
meta={
|
|
120
|
+
"file_name": "file2",
|
|
121
|
+
"classification": {"details": {"category1": 0.95, "category2": 0.85, "category3": 0.4}},
|
|
122
|
+
},
|
|
123
|
+
),
|
|
124
|
+
Document(
|
|
125
|
+
content="This is test document 3.",
|
|
126
|
+
meta={
|
|
127
|
+
"file_name": "file3",
|
|
128
|
+
"classification": {"details": {"category1": 0.85, "category2": 0.7, "category3": 0.95}},
|
|
129
|
+
},
|
|
130
|
+
),
|
|
131
|
+
]
|
|
132
|
+
|
|
133
|
+
document_store.write_documents(documents)
|
|
134
|
+
filter_criteria = {
|
|
135
|
+
"operator": "AND",
|
|
136
|
+
"conditions": [
|
|
137
|
+
{"field": "meta.file_name", "operator": "in", "value": ["file1", "file2", "file3"]},
|
|
138
|
+
{
|
|
139
|
+
"operator": "AND",
|
|
140
|
+
"conditions": [
|
|
141
|
+
{"field": "meta.classification.details.category1", "operator": ">=", "value": 0.85},
|
|
142
|
+
{
|
|
143
|
+
"operator": "OR",
|
|
144
|
+
"conditions": [
|
|
145
|
+
{"field": "meta.classification.details.category2", "operator": ">=", "value": 0.8},
|
|
146
|
+
{"field": "meta.classification.details.category3", "operator": ">=", "value": 0.9},
|
|
147
|
+
],
|
|
148
|
+
},
|
|
149
|
+
],
|
|
150
|
+
},
|
|
151
|
+
],
|
|
152
|
+
}
|
|
153
|
+
result = document_store.filter_documents(filter_criteria)
|
|
154
|
+
self.assert_documents_are_equal(
|
|
155
|
+
result,
|
|
156
|
+
[
|
|
157
|
+
d
|
|
158
|
+
for d in documents
|
|
159
|
+
if (d.meta.get("file_name") in ["file1", "file2", "file3"])
|
|
160
|
+
and (
|
|
161
|
+
(d.meta.get("classification").get("details").get("category1") >= 0.85)
|
|
162
|
+
and (
|
|
163
|
+
(d.meta.get("classification").get("details").get("category2") >= 0.8)
|
|
164
|
+
or (d.meta.get("classification").get("details").get("category3") >= 0.9)
|
|
165
|
+
)
|
|
166
|
+
)
|
|
167
|
+
],
|
|
168
|
+
)
|
|
169
|
+
|
|
64
170
|
# ======== OVERRIDES FOR NONE VALUED FILTERS ========
|
|
65
171
|
|
|
66
172
|
def test_comparison_equal_with_none(self, document_store, filterable_docs):
|
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
from typing import List
|
|
2
2
|
from unittest.mock import Mock
|
|
3
3
|
|
|
4
|
+
import pytest
|
|
4
5
|
from haystack.dataclasses import Document, SparseEmbedding
|
|
6
|
+
from haystack.document_stores.types import FilterPolicy
|
|
5
7
|
from haystack.testing.document_store import (
|
|
6
8
|
FilterableDocsFixtureMixin,
|
|
7
9
|
_random_embeddings,
|
|
@@ -21,9 +23,16 @@ class TestQdrantRetriever(FilterableDocsFixtureMixin):
|
|
|
21
23
|
assert retriever._document_store == document_store
|
|
22
24
|
assert retriever._filters is None
|
|
23
25
|
assert retriever._top_k == 10
|
|
26
|
+
assert retriever._filter_policy == FilterPolicy.REPLACE
|
|
24
27
|
assert retriever._return_embedding is False
|
|
25
28
|
assert retriever._score_threshold is None
|
|
26
29
|
|
|
30
|
+
retriever = QdrantEmbeddingRetriever(document_store=document_store, filter_policy="replace")
|
|
31
|
+
assert retriever._filter_policy == FilterPolicy.REPLACE
|
|
32
|
+
|
|
33
|
+
with pytest.raises(ValueError):
|
|
34
|
+
QdrantEmbeddingRetriever(document_store=document_store, filter_policy="invalid")
|
|
35
|
+
|
|
27
36
|
def test_to_dict(self):
|
|
28
37
|
document_store = QdrantDocumentStore(location=":memory:", index="test", use_sparse_embeddings=False)
|
|
29
38
|
retriever = QdrantEmbeddingRetriever(document_store=document_store)
|
|
@@ -73,6 +82,7 @@ class TestQdrantRetriever(FilterableDocsFixtureMixin):
|
|
|
73
82
|
},
|
|
74
83
|
"filters": None,
|
|
75
84
|
"top_k": 10,
|
|
85
|
+
"filter_policy": "replace",
|
|
76
86
|
"scale_score": False,
|
|
77
87
|
"return_embedding": False,
|
|
78
88
|
"score_threshold": None,
|
|
@@ -89,6 +99,7 @@ class TestQdrantRetriever(FilterableDocsFixtureMixin):
|
|
|
89
99
|
},
|
|
90
100
|
"filters": None,
|
|
91
101
|
"top_k": 5,
|
|
102
|
+
"filter_policy": "replace",
|
|
92
103
|
"scale_score": False,
|
|
93
104
|
"return_embedding": True,
|
|
94
105
|
"score_threshold": None,
|
|
@@ -99,6 +110,7 @@ class TestQdrantRetriever(FilterableDocsFixtureMixin):
|
|
|
99
110
|
assert retriever._document_store.index == "test"
|
|
100
111
|
assert retriever._filters is None
|
|
101
112
|
assert retriever._top_k == 5
|
|
113
|
+
assert retriever._filter_policy == FilterPolicy.REPLACE
|
|
102
114
|
assert retriever._scale_score is False
|
|
103
115
|
assert retriever._return_embedding is True
|
|
104
116
|
assert retriever._score_threshold is None
|
|
@@ -119,6 +131,31 @@ class TestQdrantRetriever(FilterableDocsFixtureMixin):
|
|
|
119
131
|
for document in results:
|
|
120
132
|
assert document.embedding is None
|
|
121
133
|
|
|
134
|
+
def test_run_filters(self, filterable_docs: List[Document]):
|
|
135
|
+
document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=False)
|
|
136
|
+
|
|
137
|
+
document_store.write_documents(filterable_docs)
|
|
138
|
+
|
|
139
|
+
retriever = QdrantEmbeddingRetriever(
|
|
140
|
+
document_store=document_store,
|
|
141
|
+
filters={"field": "meta.name", "operator": "==", "value": "name_0"},
|
|
142
|
+
filter_policy=FilterPolicy.MERGE,
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
results: List[Document] = retriever.run(query_embedding=_random_embeddings(768))["documents"]
|
|
146
|
+
assert len(results) == 3
|
|
147
|
+
|
|
148
|
+
results = retriever.run(
|
|
149
|
+
query_embedding=_random_embeddings(768),
|
|
150
|
+
top_k=5,
|
|
151
|
+
filters={"field": "meta.chapter", "operator": "==", "value": "abstract"},
|
|
152
|
+
return_embedding=False,
|
|
153
|
+
)["documents"]
|
|
154
|
+
assert len(results) == 3
|
|
155
|
+
|
|
156
|
+
for document in results:
|
|
157
|
+
assert document.embedding is None
|
|
158
|
+
|
|
122
159
|
def test_run_with_score_threshold(self):
|
|
123
160
|
document_store = QdrantDocumentStore(
|
|
124
161
|
embedding_dim=4, location=":memory:", similarity="cosine", index="Boi", use_sparse_embeddings=False
|
|
@@ -167,9 +204,16 @@ class TestQdrantSparseEmbeddingRetriever(FilterableDocsFixtureMixin):
|
|
|
167
204
|
assert retriever._document_store == document_store
|
|
168
205
|
assert retriever._filters is None
|
|
169
206
|
assert retriever._top_k == 10
|
|
207
|
+
assert retriever._filter_policy == FilterPolicy.REPLACE
|
|
170
208
|
assert retriever._return_embedding is False
|
|
171
209
|
assert retriever._score_threshold is None
|
|
172
210
|
|
|
211
|
+
retriever = QdrantSparseEmbeddingRetriever(document_store=document_store, filter_policy="replace")
|
|
212
|
+
assert retriever._filter_policy == FilterPolicy.REPLACE
|
|
213
|
+
|
|
214
|
+
with pytest.raises(ValueError):
|
|
215
|
+
QdrantSparseEmbeddingRetriever(document_store=document_store, filter_policy="invalid")
|
|
216
|
+
|
|
173
217
|
def test_to_dict(self):
|
|
174
218
|
document_store = QdrantDocumentStore(location=":memory:", index="test")
|
|
175
219
|
retriever = QdrantSparseEmbeddingRetriever(document_store=document_store)
|
|
@@ -221,11 +265,38 @@ class TestQdrantSparseEmbeddingRetriever(FilterableDocsFixtureMixin):
|
|
|
221
265
|
"top_k": 10,
|
|
222
266
|
"scale_score": False,
|
|
223
267
|
"return_embedding": False,
|
|
268
|
+
"filter_policy": "replace",
|
|
224
269
|
"score_threshold": None,
|
|
225
270
|
},
|
|
226
271
|
}
|
|
227
272
|
|
|
228
273
|
def test_from_dict(self):
|
|
274
|
+
data = {
|
|
275
|
+
"type": "haystack_integrations.components.retrievers.qdrant.retriever.QdrantSparseEmbeddingRetriever",
|
|
276
|
+
"init_parameters": {
|
|
277
|
+
"document_store": {
|
|
278
|
+
"init_parameters": {"location": ":memory:", "index": "test"},
|
|
279
|
+
"type": "haystack_integrations.document_stores.qdrant.document_store.QdrantDocumentStore",
|
|
280
|
+
},
|
|
281
|
+
"filters": None,
|
|
282
|
+
"top_k": 5,
|
|
283
|
+
"scale_score": False,
|
|
284
|
+
"return_embedding": True,
|
|
285
|
+
"filter_policy": "replace",
|
|
286
|
+
"score_threshold": None,
|
|
287
|
+
},
|
|
288
|
+
}
|
|
289
|
+
retriever = QdrantSparseEmbeddingRetriever.from_dict(data)
|
|
290
|
+
assert isinstance(retriever._document_store, QdrantDocumentStore)
|
|
291
|
+
assert retriever._document_store.index == "test"
|
|
292
|
+
assert retriever._filters is None
|
|
293
|
+
assert retriever._top_k == 5
|
|
294
|
+
assert retriever._filter_policy == FilterPolicy.REPLACE
|
|
295
|
+
assert retriever._scale_score is False
|
|
296
|
+
assert retriever._return_embedding is True
|
|
297
|
+
assert retriever._score_threshold is None
|
|
298
|
+
|
|
299
|
+
def test_from_dict_no_filter_policy(self):
|
|
229
300
|
data = {
|
|
230
301
|
"type": "haystack_integrations.components.retrievers.qdrant.retriever.QdrantSparseEmbeddingRetriever",
|
|
231
302
|
"init_parameters": {
|
|
@@ -245,6 +316,7 @@ class TestQdrantSparseEmbeddingRetriever(FilterableDocsFixtureMixin):
|
|
|
245
316
|
assert retriever._document_store.index == "test"
|
|
246
317
|
assert retriever._filters is None
|
|
247
318
|
assert retriever._top_k == 5
|
|
319
|
+
assert retriever._filter_policy == FilterPolicy.REPLACE # defaults to REPLACE
|
|
248
320
|
assert retriever._scale_score is False
|
|
249
321
|
assert retriever._return_embedding is True
|
|
250
322
|
assert retriever._score_threshold is None
|
|
@@ -278,9 +350,16 @@ class TestQdrantHybridRetriever:
|
|
|
278
350
|
assert retriever._document_store == document_store
|
|
279
351
|
assert retriever._filters is None
|
|
280
352
|
assert retriever._top_k == 10
|
|
353
|
+
assert retriever._filter_policy == FilterPolicy.REPLACE
|
|
281
354
|
assert retriever._return_embedding is False
|
|
282
355
|
assert retriever._score_threshold is None
|
|
283
356
|
|
|
357
|
+
retriever = QdrantHybridRetriever(document_store=document_store, filter_policy="replace")
|
|
358
|
+
assert retriever._filter_policy == FilterPolicy.REPLACE
|
|
359
|
+
|
|
360
|
+
with pytest.raises(ValueError):
|
|
361
|
+
QdrantHybridRetriever(document_store=document_store, filter_policy="invalid")
|
|
362
|
+
|
|
284
363
|
def test_to_dict(self):
|
|
285
364
|
document_store = QdrantDocumentStore(location=":memory:", index="test")
|
|
286
365
|
retriever = QdrantHybridRetriever(document_store=document_store, top_k=5, return_embedding=True)
|
|
@@ -330,12 +409,37 @@ class TestQdrantHybridRetriever:
|
|
|
330
409
|
},
|
|
331
410
|
"filters": None,
|
|
332
411
|
"top_k": 5,
|
|
412
|
+
"filter_policy": "replace",
|
|
333
413
|
"return_embedding": True,
|
|
334
414
|
"score_threshold": None,
|
|
335
415
|
},
|
|
336
416
|
}
|
|
337
417
|
|
|
338
418
|
def test_from_dict(self):
|
|
419
|
+
data = {
|
|
420
|
+
"type": "haystack_integrations.components.retrievers.qdrant.retriever.QdrantHybridRetriever",
|
|
421
|
+
"init_parameters": {
|
|
422
|
+
"document_store": {
|
|
423
|
+
"init_parameters": {"location": ":memory:", "index": "test"},
|
|
424
|
+
"type": "haystack_integrations.document_stores.qdrant.document_store.QdrantDocumentStore",
|
|
425
|
+
},
|
|
426
|
+
"filters": None,
|
|
427
|
+
"top_k": 5,
|
|
428
|
+
"filter_policy": "replace",
|
|
429
|
+
"return_embedding": True,
|
|
430
|
+
"score_threshold": None,
|
|
431
|
+
},
|
|
432
|
+
}
|
|
433
|
+
retriever = QdrantHybridRetriever.from_dict(data)
|
|
434
|
+
assert isinstance(retriever._document_store, QdrantDocumentStore)
|
|
435
|
+
assert retriever._document_store.index == "test"
|
|
436
|
+
assert retriever._filters is None
|
|
437
|
+
assert retriever._top_k == 5
|
|
438
|
+
assert retriever._filter_policy == FilterPolicy.REPLACE
|
|
439
|
+
assert retriever._return_embedding
|
|
440
|
+
assert retriever._score_threshold is None
|
|
441
|
+
|
|
442
|
+
def test_from_dict_no_filter_policy(self):
|
|
339
443
|
data = {
|
|
340
444
|
"type": "haystack_integrations.components.retrievers.qdrant.retriever.QdrantHybridRetriever",
|
|
341
445
|
"init_parameters": {
|
|
@@ -354,6 +458,7 @@ class TestQdrantHybridRetriever:
|
|
|
354
458
|
assert retriever._document_store.index == "test"
|
|
355
459
|
assert retriever._filters is None
|
|
356
460
|
assert retriever._top_k == 5
|
|
461
|
+
assert retriever._filter_policy == FilterPolicy.REPLACE # defaults to REPLACE
|
|
357
462
|
assert retriever._return_embedding
|
|
358
463
|
assert retriever._score_threshold is None
|
|
359
464
|
|
|
@@ -1,238 +0,0 @@
|
|
|
1
|
-
from datetime import datetime
|
|
2
|
-
from typing import List, Optional, Union
|
|
3
|
-
|
|
4
|
-
from haystack.utils.filters import COMPARISON_OPERATORS, LOGICAL_OPERATORS, FilterError
|
|
5
|
-
from qdrant_client.http import models
|
|
6
|
-
|
|
7
|
-
from .converters import convert_id
|
|
8
|
-
|
|
9
|
-
COMPARISON_OPERATORS = COMPARISON_OPERATORS.keys()
|
|
10
|
-
LOGICAL_OPERATORS = LOGICAL_OPERATORS.keys()
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
def convert_filters_to_qdrant(
|
|
14
|
-
filter_term: Optional[Union[List[dict], dict, models.Filter]] = None,
|
|
15
|
-
) -> Optional[models.Filter]:
|
|
16
|
-
"""Converts Haystack filters to the format used by Qdrant."""
|
|
17
|
-
if isinstance(filter_term, models.Filter):
|
|
18
|
-
return filter_term
|
|
19
|
-
if not filter_term:
|
|
20
|
-
return None
|
|
21
|
-
|
|
22
|
-
must_clauses, should_clauses, must_not_clauses = [], [], []
|
|
23
|
-
|
|
24
|
-
if isinstance(filter_term, dict):
|
|
25
|
-
filter_term = [filter_term]
|
|
26
|
-
|
|
27
|
-
for item in filter_term:
|
|
28
|
-
operator = item.get("operator")
|
|
29
|
-
if operator is None:
|
|
30
|
-
msg = "Operator not found in filters"
|
|
31
|
-
raise FilterError(msg)
|
|
32
|
-
|
|
33
|
-
if operator in LOGICAL_OPERATORS and "conditions" not in item:
|
|
34
|
-
msg = f"'conditions' not found for '{operator}'"
|
|
35
|
-
raise FilterError(msg)
|
|
36
|
-
|
|
37
|
-
if operator == "AND":
|
|
38
|
-
must_clauses.append(convert_filters_to_qdrant(item.get("conditions", [])))
|
|
39
|
-
elif operator == "OR":
|
|
40
|
-
should_clauses.append(convert_filters_to_qdrant(item.get("conditions", [])))
|
|
41
|
-
elif operator == "NOT":
|
|
42
|
-
must_not_clauses.append(convert_filters_to_qdrant(item.get("conditions", [])))
|
|
43
|
-
elif operator in COMPARISON_OPERATORS:
|
|
44
|
-
field = item.get("field")
|
|
45
|
-
value = item.get("value")
|
|
46
|
-
if field is None or value is None:
|
|
47
|
-
msg = f"'field' or 'value' not found for '{operator}'"
|
|
48
|
-
raise FilterError(msg)
|
|
49
|
-
|
|
50
|
-
must_clauses.extend(_parse_comparison_operation(comparison_operation=operator, key=field, value=value))
|
|
51
|
-
else:
|
|
52
|
-
msg = f"Unknown operator {operator} used in filters"
|
|
53
|
-
raise FilterError(msg)
|
|
54
|
-
|
|
55
|
-
payload_filter = models.Filter(
|
|
56
|
-
must=must_clauses or None,
|
|
57
|
-
should=should_clauses or None,
|
|
58
|
-
must_not=must_not_clauses or None,
|
|
59
|
-
)
|
|
60
|
-
|
|
61
|
-
filter_result = _squeeze_filter(payload_filter)
|
|
62
|
-
|
|
63
|
-
return filter_result
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
def _parse_comparison_operation(
|
|
67
|
-
comparison_operation: str, key: str, value: Union[dict, List, str, float]
|
|
68
|
-
) -> List[models.Condition]:
|
|
69
|
-
conditions: List[models.Condition] = []
|
|
70
|
-
|
|
71
|
-
condition_builder_mapping = {
|
|
72
|
-
"==": _build_eq_condition,
|
|
73
|
-
"in": _build_in_condition,
|
|
74
|
-
"!=": _build_ne_condition,
|
|
75
|
-
"not in": _build_nin_condition,
|
|
76
|
-
">": _build_gt_condition,
|
|
77
|
-
">=": _build_gte_condition,
|
|
78
|
-
"<": _build_lt_condition,
|
|
79
|
-
"<=": _build_lte_condition,
|
|
80
|
-
}
|
|
81
|
-
|
|
82
|
-
condition_builder = condition_builder_mapping.get(comparison_operation)
|
|
83
|
-
|
|
84
|
-
if condition_builder is None:
|
|
85
|
-
msg = f"Unknown operator {comparison_operation} used in filters"
|
|
86
|
-
raise ValueError(msg)
|
|
87
|
-
|
|
88
|
-
conditions.append(condition_builder(key, value))
|
|
89
|
-
|
|
90
|
-
return conditions
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
def _build_eq_condition(key: str, value: models.ValueVariants) -> models.Condition:
|
|
94
|
-
if isinstance(value, str) and " " in value:
|
|
95
|
-
models.FieldCondition(key=key, match=models.MatchText(text=value))
|
|
96
|
-
return models.FieldCondition(key=key, match=models.MatchValue(value=value))
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
def _build_in_condition(key: str, value: List[models.ValueVariants]) -> models.Condition:
|
|
100
|
-
if not isinstance(value, list):
|
|
101
|
-
msg = f"Value {value} is not a list"
|
|
102
|
-
raise FilterError(msg)
|
|
103
|
-
return models.Filter(
|
|
104
|
-
should=[
|
|
105
|
-
(
|
|
106
|
-
models.FieldCondition(key=key, match=models.MatchText(text=item))
|
|
107
|
-
if isinstance(item, str) and " " not in item
|
|
108
|
-
else models.FieldCondition(key=key, match=models.MatchValue(value=item))
|
|
109
|
-
)
|
|
110
|
-
for item in value
|
|
111
|
-
]
|
|
112
|
-
)
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
def _build_ne_condition(key: str, value: models.ValueVariants) -> models.Condition:
|
|
116
|
-
return models.Filter(
|
|
117
|
-
must_not=[
|
|
118
|
-
(
|
|
119
|
-
models.FieldCondition(key=key, match=models.MatchText(text=value))
|
|
120
|
-
if isinstance(value, str) and " " not in value
|
|
121
|
-
else models.FieldCondition(key=key, match=models.MatchValue(value=value))
|
|
122
|
-
)
|
|
123
|
-
]
|
|
124
|
-
)
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
def _build_nin_condition(key: str, value: List[models.ValueVariants]) -> models.Condition:
|
|
128
|
-
if not isinstance(value, list):
|
|
129
|
-
msg = f"Value {value} is not a list"
|
|
130
|
-
raise FilterError(msg)
|
|
131
|
-
return models.Filter(
|
|
132
|
-
must_not=[
|
|
133
|
-
(
|
|
134
|
-
models.FieldCondition(key=key, match=models.MatchText(text=item))
|
|
135
|
-
if isinstance(item, str) and " " in item
|
|
136
|
-
else models.FieldCondition(key=key, match=models.MatchValue(value=item))
|
|
137
|
-
)
|
|
138
|
-
for item in value
|
|
139
|
-
]
|
|
140
|
-
)
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
def _build_lt_condition(key: str, value: Union[str, float, int]) -> models.Condition:
|
|
144
|
-
if isinstance(value, str) and is_datetime_string(value):
|
|
145
|
-
return models.FieldCondition(key=key, range=models.DatetimeRange(lt=value))
|
|
146
|
-
|
|
147
|
-
if isinstance(value, (int, float)):
|
|
148
|
-
return models.FieldCondition(key=key, range=models.Range(lt=value))
|
|
149
|
-
|
|
150
|
-
msg = f"Value {value} is not an int or float or datetime string"
|
|
151
|
-
raise FilterError(msg)
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
def _build_lte_condition(key: str, value: Union[str, float, int]) -> models.Condition:
|
|
155
|
-
if isinstance(value, str) and is_datetime_string(value):
|
|
156
|
-
return models.FieldCondition(key=key, range=models.DatetimeRange(lte=value))
|
|
157
|
-
|
|
158
|
-
if isinstance(value, (int, float)):
|
|
159
|
-
return models.FieldCondition(key=key, range=models.Range(lte=value))
|
|
160
|
-
|
|
161
|
-
msg = f"Value {value} is not an int or float or datetime string"
|
|
162
|
-
raise FilterError(msg)
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
def _build_gt_condition(key: str, value: Union[str, float, int]) -> models.Condition:
|
|
166
|
-
if isinstance(value, str) and is_datetime_string(value):
|
|
167
|
-
return models.FieldCondition(key=key, range=models.DatetimeRange(gt=value))
|
|
168
|
-
|
|
169
|
-
if isinstance(value, (int, float)):
|
|
170
|
-
return models.FieldCondition(key=key, range=models.Range(gt=value))
|
|
171
|
-
|
|
172
|
-
msg = f"Value {value} is not an int or float or datetime string"
|
|
173
|
-
raise FilterError(msg)
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
def _build_gte_condition(key: str, value: Union[str, float, int]) -> models.Condition:
|
|
177
|
-
if isinstance(value, str) and is_datetime_string(value):
|
|
178
|
-
return models.FieldCondition(key=key, range=models.DatetimeRange(gte=value))
|
|
179
|
-
|
|
180
|
-
if isinstance(value, (int, float)):
|
|
181
|
-
return models.FieldCondition(key=key, range=models.Range(gte=value))
|
|
182
|
-
|
|
183
|
-
msg = f"Value {value} is not an int or float or datetime string"
|
|
184
|
-
raise FilterError(msg)
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
def _build_has_id_condition(id_values: List[models.ExtendedPointId]) -> models.HasIdCondition:
|
|
188
|
-
return models.HasIdCondition(
|
|
189
|
-
has_id=[
|
|
190
|
-
# Ids are converted into their internal representation
|
|
191
|
-
convert_id(item)
|
|
192
|
-
for item in id_values
|
|
193
|
-
]
|
|
194
|
-
)
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
def _squeeze_filter(payload_filter: models.Filter) -> models.Filter:
|
|
198
|
-
"""
|
|
199
|
-
Simplify given payload filter, if the nested structure might be unnested.
|
|
200
|
-
That happens if there is a single clause in that filter.
|
|
201
|
-
:param payload_filter:
|
|
202
|
-
:returns:
|
|
203
|
-
"""
|
|
204
|
-
filter_parts = {
|
|
205
|
-
"must": payload_filter.must,
|
|
206
|
-
"should": payload_filter.should,
|
|
207
|
-
"must_not": payload_filter.must_not,
|
|
208
|
-
}
|
|
209
|
-
|
|
210
|
-
total_clauses = sum(len(x) for x in filter_parts.values() if x is not None)
|
|
211
|
-
if total_clauses == 0 or total_clauses > 1:
|
|
212
|
-
return payload_filter
|
|
213
|
-
|
|
214
|
-
# Payload filter has just a single clause provided (either must, should
|
|
215
|
-
# or must_not). If that single clause is also of a models.Filter type,
|
|
216
|
-
# then it might be returned instead.
|
|
217
|
-
for part_name, filter_part in filter_parts.items():
|
|
218
|
-
if not filter_part:
|
|
219
|
-
continue
|
|
220
|
-
|
|
221
|
-
subfilter = filter_part[0]
|
|
222
|
-
if not isinstance(subfilter, models.Filter):
|
|
223
|
-
# The inner statement is a simple condition like models.FieldCondition
|
|
224
|
-
# so it cannot be simplified.
|
|
225
|
-
continue
|
|
226
|
-
|
|
227
|
-
if subfilter.must:
|
|
228
|
-
return models.Filter(**{part_name: subfilter.must})
|
|
229
|
-
|
|
230
|
-
return payload_filter
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
def is_datetime_string(value: str) -> bool:
|
|
234
|
-
try:
|
|
235
|
-
datetime.fromisoformat(value)
|
|
236
|
-
return True
|
|
237
|
-
except ValueError:
|
|
238
|
-
return False
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|