ragbits-core 0.16.0__py3-none-any.whl → 1.4.0.dev202512021005__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragbits/core/__init__.py +21 -2
- ragbits/core/audit/__init__.py +15 -157
- ragbits/core/audit/metrics/__init__.py +83 -0
- ragbits/core/audit/metrics/base.py +198 -0
- ragbits/core/audit/metrics/logfire.py +19 -0
- ragbits/core/audit/metrics/otel.py +65 -0
- ragbits/core/audit/traces/__init__.py +171 -0
- ragbits/core/audit/{base.py → traces/base.py} +9 -5
- ragbits/core/audit/{cli.py → traces/cli.py} +8 -4
- ragbits/core/audit/traces/logfire.py +18 -0
- ragbits/core/audit/{otel.py → traces/otel.py} +5 -8
- ragbits/core/config.py +15 -0
- ragbits/core/embeddings/__init__.py +2 -1
- ragbits/core/embeddings/base.py +19 -0
- ragbits/core/embeddings/dense/base.py +10 -1
- ragbits/core/embeddings/dense/fastembed.py +22 -1
- ragbits/core/embeddings/dense/litellm.py +37 -10
- ragbits/core/embeddings/dense/local.py +15 -1
- ragbits/core/embeddings/dense/noop.py +11 -1
- ragbits/core/embeddings/dense/vertex_multimodal.py +14 -1
- ragbits/core/embeddings/sparse/bag_of_tokens.py +47 -17
- ragbits/core/embeddings/sparse/base.py +10 -1
- ragbits/core/embeddings/sparse/fastembed.py +25 -2
- ragbits/core/llms/__init__.py +3 -3
- ragbits/core/llms/base.py +612 -88
- ragbits/core/llms/exceptions.py +27 -0
- ragbits/core/llms/litellm.py +408 -83
- ragbits/core/llms/local.py +180 -41
- ragbits/core/llms/mock.py +88 -23
- ragbits/core/prompt/__init__.py +2 -2
- ragbits/core/prompt/_cli.py +32 -19
- ragbits/core/prompt/base.py +105 -19
- ragbits/core/prompt/{discovery/prompt_discovery.py → discovery.py} +1 -1
- ragbits/core/prompt/exceptions.py +22 -6
- ragbits/core/prompt/prompt.py +180 -98
- ragbits/core/sources/__init__.py +2 -0
- ragbits/core/sources/azure.py +1 -1
- ragbits/core/sources/base.py +8 -1
- ragbits/core/sources/gcs.py +1 -1
- ragbits/core/sources/git.py +1 -1
- ragbits/core/sources/google_drive.py +595 -0
- ragbits/core/sources/hf.py +71 -31
- ragbits/core/sources/local.py +1 -1
- ragbits/core/sources/s3.py +1 -1
- ragbits/core/utils/config_handling.py +13 -2
- ragbits/core/utils/function_schema.py +220 -0
- ragbits/core/utils/helpers.py +22 -0
- ragbits/core/utils/lazy_litellm.py +44 -0
- ragbits/core/vector_stores/base.py +18 -1
- ragbits/core/vector_stores/chroma.py +28 -11
- ragbits/core/vector_stores/hybrid.py +1 -1
- ragbits/core/vector_stores/hybrid_strategies.py +21 -8
- ragbits/core/vector_stores/in_memory.py +13 -4
- ragbits/core/vector_stores/pgvector.py +123 -47
- ragbits/core/vector_stores/qdrant.py +15 -7
- ragbits/core/vector_stores/weaviate.py +440 -0
- {ragbits_core-0.16.0.dist-info → ragbits_core-1.4.0.dev202512021005.dist-info}/METADATA +22 -6
- ragbits_core-1.4.0.dev202512021005.dist-info/RECORD +79 -0
- {ragbits_core-0.16.0.dist-info → ragbits_core-1.4.0.dev202512021005.dist-info}/WHEEL +1 -1
- ragbits/core/prompt/discovery/__init__.py +0 -3
- ragbits/core/prompt/lab/__init__.py +0 -0
- ragbits/core/prompt/lab/app.py +0 -262
- ragbits_core-0.16.0.dist-info/RECORD +0 -72
|
@@ -1,10 +1,9 @@
|
|
|
1
|
+
import math
|
|
1
2
|
from itertools import islice
|
|
2
3
|
from typing import cast
|
|
3
4
|
from uuid import UUID
|
|
4
5
|
|
|
5
|
-
import
|
|
6
|
-
|
|
7
|
-
from ragbits.core.audit import trace, traceable
|
|
6
|
+
from ragbits.core.audit.traces import trace, traceable
|
|
8
7
|
from ragbits.core.embeddings import Embedder, SparseVector
|
|
9
8
|
from ragbits.core.vector_stores.base import (
|
|
10
9
|
EmbeddingType,
|
|
@@ -90,6 +89,14 @@ class InMemoryVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
|
|
|
90
89
|
results: list[VectorStoreResult] = []
|
|
91
90
|
|
|
92
91
|
for entry_id, vector in self._embeddings.items():
|
|
92
|
+
entry = self._entries[entry_id]
|
|
93
|
+
|
|
94
|
+
# Apply metadata filtering
|
|
95
|
+
if merged_options.where and not all(
|
|
96
|
+
entry.metadata.get(key) == value for key, value in merged_options.where.items()
|
|
97
|
+
):
|
|
98
|
+
continue
|
|
99
|
+
|
|
93
100
|
# Calculate score based on vector type
|
|
94
101
|
if isinstance(query_vector, SparseVector) and isinstance(vector, SparseVector):
|
|
95
102
|
# For sparse vectors, use dot product between query and document vectors
|
|
@@ -105,7 +112,9 @@ class InMemoryVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
|
|
|
105
112
|
# For dense vectors, use negative L2 distance
|
|
106
113
|
query_vector_dense = cast(list[float], query_vector)
|
|
107
114
|
vector_dense = cast(list[float], vector)
|
|
108
|
-
score =
|
|
115
|
+
score = -math.sqrt(
|
|
116
|
+
sum((a - b) ** 2 for a, b in zip(vector_dense, query_vector_dense, strict=False))
|
|
117
|
+
)
|
|
109
118
|
|
|
110
119
|
result = VectorStoreResult(entry=self._entries[entry_id], vector=vector, score=score)
|
|
111
120
|
if merged_options.score_threshold is None or result.score >= merged_options.score_threshold:
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import re
|
|
3
|
-
from typing import Any, NamedTuple
|
|
3
|
+
from typing import Any, NamedTuple
|
|
4
4
|
from uuid import UUID
|
|
5
5
|
|
|
6
6
|
import asyncpg
|
|
7
7
|
from pydantic.json import pydantic_encoder
|
|
8
8
|
|
|
9
|
-
from ragbits.core.audit import trace
|
|
10
|
-
from ragbits.core.embeddings.base import Embedder, SparseVector
|
|
9
|
+
from ragbits.core.audit.traces import trace
|
|
10
|
+
from ragbits.core.embeddings.base import Embedder, SparseVector, VectorSize
|
|
11
11
|
from ragbits.core.embeddings.sparse.base import SparseEmbedder
|
|
12
12
|
from ragbits.core.vector_stores.base import (
|
|
13
13
|
EmbeddingType,
|
|
@@ -33,14 +33,16 @@ class DistanceOp(NamedTuple):
|
|
|
33
33
|
DISTANCE_OPS = {
|
|
34
34
|
"cosine": DistanceOp("vector_cosine_ops", "<=>", "1 - distance"),
|
|
35
35
|
"l2": DistanceOp("vector_l2_ops", "<->", "distance * -1"),
|
|
36
|
+
"halfvec_l2": DistanceOp("halfvec_l2_ops", "<->", "distance * -1"),
|
|
36
37
|
"l1": DistanceOp("vector_l1_ops", "<+>", "distance * -1"),
|
|
37
38
|
"ip": DistanceOp("vector_ip_ops", "<#>", "distance * -1"),
|
|
38
39
|
"bit_hamming": DistanceOp("bit_hamming_ops", "<~>", "distance * -1"),
|
|
39
40
|
"bit_jaccard": DistanceOp("bit_jaccard_ops", "<%>", "distance * -1"),
|
|
40
41
|
"sparsevec_l2": DistanceOp("sparsevec_l2_ops", "<->", "distance * -1"),
|
|
41
|
-
"halfvec_l2": DistanceOp("halfvec_l2_ops", "<->", "distance * -1"),
|
|
42
42
|
}
|
|
43
43
|
|
|
44
|
+
MAX_VECTOR_SIZE = 2000
|
|
45
|
+
|
|
44
46
|
|
|
45
47
|
class PgVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
|
|
46
48
|
"""
|
|
@@ -53,11 +55,12 @@ class PgVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
|
|
|
53
55
|
self,
|
|
54
56
|
client: asyncpg.Pool,
|
|
55
57
|
table_name: str,
|
|
56
|
-
vector_size: int,
|
|
57
58
|
embedder: Embedder,
|
|
59
|
+
vector_size: int | None = None,
|
|
58
60
|
embedding_type: EmbeddingType = EmbeddingType.TEXT,
|
|
59
61
|
distance_method: str | None = None,
|
|
60
|
-
|
|
62
|
+
is_hnsw: bool = True,
|
|
63
|
+
params: dict | None = None,
|
|
61
64
|
default_options: VectorStoreOptions | None = None,
|
|
62
65
|
) -> None:
|
|
63
66
|
"""
|
|
@@ -66,12 +69,13 @@ class PgVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
|
|
|
66
69
|
Args:
|
|
67
70
|
client: The pgVector database connection pool.
|
|
68
71
|
table_name: The name of the table.
|
|
69
|
-
vector_size: The size of the vectors.
|
|
70
72
|
embedder: The embedder to use for converting entries to vectors.
|
|
73
|
+
vector_size: The size of the vectors. If None, will be determined automatically from the embedder.
|
|
71
74
|
embedding_type: Which part of the entry to embed, either text or image. The other part will be ignored.
|
|
72
75
|
distance_method: The distance method to use, default is "cosine" for dense vectors
|
|
73
76
|
and "sparsevec_l2" for sparse vectors.
|
|
74
|
-
|
|
77
|
+
is_hnsw: if hnsw or ivfflat indexing should be used
|
|
78
|
+
params: The parameters for the HNSW index. If None, the default parameters will be used.
|
|
75
79
|
default_options: The default options for querying the vector store.
|
|
76
80
|
"""
|
|
77
81
|
(
|
|
@@ -84,27 +88,34 @@ class PgVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
|
|
|
84
88
|
|
|
85
89
|
if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", table_name):
|
|
86
90
|
raise ValueError(f"Invalid table name: {table_name}")
|
|
87
|
-
if not isinstance(vector_size, int) or vector_size <= 0:
|
|
91
|
+
if vector_size is not None and (not isinstance(vector_size, int) or vector_size <= 0):
|
|
88
92
|
raise ValueError("Vector size must be a positive integer.")
|
|
89
93
|
|
|
90
|
-
if
|
|
91
|
-
|
|
92
|
-
elif not
|
|
93
|
-
|
|
94
|
-
elif
|
|
95
|
-
raise ValueError("
|
|
96
|
-
elif
|
|
97
|
-
raise ValueError("
|
|
98
|
-
elif not isinstance(
|
|
99
|
-
raise ValueError("
|
|
94
|
+
if params is None and is_hnsw:
|
|
95
|
+
params = {"m": 4, "ef_construction": 10}
|
|
96
|
+
elif params is None and not is_hnsw:
|
|
97
|
+
params = {"lists": 100}
|
|
98
|
+
elif not isinstance(params, dict):
|
|
99
|
+
raise ValueError("params must be a dictionary.")
|
|
100
|
+
elif "m" not in params or "ef_construction" not in params and is_hnsw:
|
|
101
|
+
raise ValueError("params must contain 'm' and 'ef_construction' keys for hnsw indexing.")
|
|
102
|
+
elif not isinstance(params["m"], int) or params["m"] <= 0 and is_hnsw:
|
|
103
|
+
raise ValueError("m must be a positive integer for hnsw indexing.")
|
|
104
|
+
elif not isinstance(params["ef_construction"], int) or params["ef_construction"] <= 0 and is_hnsw:
|
|
105
|
+
raise ValueError("ef_construction must be a positive integer for hnsw indexing.")
|
|
106
|
+
elif "lists" not in params and not is_hnsw:
|
|
107
|
+
raise ValueError("params must contain 'lists' key for IVFFlat indexing.")
|
|
108
|
+
elif not isinstance(params["lists"], int) or params["lists"] <= 0 and not is_hnsw:
|
|
109
|
+
raise ValueError("lists must be a positive integer for IVFFlat indexing.")
|
|
100
110
|
|
|
101
111
|
if distance_method is None:
|
|
102
112
|
distance_method = "sparsevec_l2" if isinstance(embedder, SparseEmbedder) else "cosine"
|
|
103
113
|
self._client = client
|
|
104
114
|
self._table_name = table_name
|
|
105
115
|
self._vector_size = vector_size
|
|
116
|
+
self._vector_size_info: VectorSize | None = None
|
|
106
117
|
self._distance_method = distance_method
|
|
107
|
-
self.
|
|
118
|
+
self._indexing_params = params
|
|
108
119
|
|
|
109
120
|
def __reduce__(self) -> tuple:
|
|
110
121
|
"""
|
|
@@ -113,6 +124,32 @@ class PgVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
|
|
|
113
124
|
# TODO: To be implemented. Required for Ray processing.
|
|
114
125
|
raise NotImplementedError
|
|
115
126
|
|
|
127
|
+
async def _get_vector_size_info(self) -> VectorSize:
|
|
128
|
+
"""
|
|
129
|
+
Get vector size information from the embedder if not already cached.
|
|
130
|
+
|
|
131
|
+
Returns:
|
|
132
|
+
VectorSize information including size and sparsity.
|
|
133
|
+
"""
|
|
134
|
+
if self._vector_size_info is None:
|
|
135
|
+
self._vector_size_info = await self._embedder.get_vector_size()
|
|
136
|
+
# Update _vector_size for backward compatibility if it wasn't provided
|
|
137
|
+
if self._vector_size is None:
|
|
138
|
+
self._vector_size = self._vector_size_info.size
|
|
139
|
+
return self._vector_size_info
|
|
140
|
+
|
|
141
|
+
async def _get_vector_size(self) -> int:
|
|
142
|
+
"""
|
|
143
|
+
Get the vector size, either from the constructor parameter or from the embedder.
|
|
144
|
+
|
|
145
|
+
Returns:
|
|
146
|
+
The vector size as an integer.
|
|
147
|
+
"""
|
|
148
|
+
if self._vector_size is not None:
|
|
149
|
+
return self._vector_size
|
|
150
|
+
vector_size_info = await self._get_vector_size_info()
|
|
151
|
+
return vector_size_info.size
|
|
152
|
+
|
|
116
153
|
def _vector_to_string(self, vector: list[float] | SparseVector) -> str:
|
|
117
154
|
"""
|
|
118
155
|
Converts a vector to a string representation.
|
|
@@ -124,8 +161,13 @@ class PgVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
|
|
|
124
161
|
str: The string representation of the vector.
|
|
125
162
|
"""
|
|
126
163
|
if isinstance(vector, SparseVector):
|
|
164
|
+
# For sparse vectors, we need the vector size to be available
|
|
165
|
+
# This will be resolved when this method is called from async context
|
|
166
|
+
vector_size = self._vector_size
|
|
167
|
+
if vector_size is None:
|
|
168
|
+
raise RuntimeError("Vector size must be determined before converting sparse vectors to string")
|
|
127
169
|
points_str = ",".join(f"{i}:{v}" for i, v in zip(vector.indices, vector.values, strict=False))
|
|
128
|
-
return f"{{{points_str}}}/{
|
|
170
|
+
return f"{{{points_str}}}/{vector_size}"
|
|
129
171
|
return json.dumps(vector)
|
|
130
172
|
|
|
131
173
|
@staticmethod
|
|
@@ -173,13 +215,19 @@ class PgVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
|
|
|
173
215
|
# _table_name has been validated in the class constructor, and it is a valid table name.
|
|
174
216
|
query = f"SELECT *, vector {distance_operator} $1 as distance, {score_formula} as score FROM {self._table_name}" # noqa S608
|
|
175
217
|
|
|
176
|
-
values: list[Any] = [
|
|
177
|
-
|
|
178
|
-
]
|
|
218
|
+
values: list[Any] = [self._vector_to_string(vector)]
|
|
219
|
+
where_clauses = []
|
|
179
220
|
|
|
180
221
|
if query_options.score_threshold is not None:
|
|
181
|
-
|
|
182
|
-
values.
|
|
222
|
+
where_clauses.append("score >= $" + str(len(values) + 1))
|
|
223
|
+
values.append(query_options.score_threshold)
|
|
224
|
+
|
|
225
|
+
if query_options.where:
|
|
226
|
+
where_clauses.append(f"metadata @> ${len(values) + 1}")
|
|
227
|
+
values.append(json.dumps(query_options.where))
|
|
228
|
+
|
|
229
|
+
if where_clauses:
|
|
230
|
+
query += " WHERE " + " AND ".join(where_clauses)
|
|
183
231
|
|
|
184
232
|
query += " ORDER BY distance"
|
|
185
233
|
|
|
@@ -226,32 +274,55 @@ class PgVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
|
|
|
226
274
|
|
|
227
275
|
async def create_table(self) -> None:
|
|
228
276
|
"""
|
|
229
|
-
Create a pgVector table with an HNSW index for given similarity.
|
|
277
|
+
Create a pgVector table with an HNSW/IVFFlat index for given similarity.
|
|
230
278
|
"""
|
|
279
|
+
vector_size = await self._get_vector_size()
|
|
280
|
+
|
|
231
281
|
with trace(
|
|
232
282
|
table_name=self._table_name,
|
|
233
283
|
distance_method=self._distance_method,
|
|
234
|
-
vector_size=
|
|
235
|
-
hnsw_index_parameters=self.
|
|
284
|
+
vector_size=vector_size,
|
|
285
|
+
hnsw_index_parameters=self._indexing_params,
|
|
236
286
|
):
|
|
237
287
|
distance = DISTANCE_OPS[self._distance_method].function_name
|
|
238
288
|
create_vector_extension = "CREATE EXTENSION IF NOT EXISTS vector;"
|
|
239
289
|
# _table_name and has been validated in the class constructor, and it is a valid table name.
|
|
240
|
-
#
|
|
290
|
+
# vector_size has been validated in the class constructor or obtained from embedder,
|
|
291
|
+
# and it is a valid vector size.
|
|
241
292
|
|
|
242
293
|
is_sparse = isinstance(self._embedder, SparseEmbedder)
|
|
243
|
-
|
|
294
|
+
|
|
295
|
+
# Check vector size
|
|
296
|
+
# if greater than 2000 then choose type HALFVEC
|
|
297
|
+
# More info: https://github.com/pgvector/pgvector
|
|
298
|
+
vector_func = (
|
|
299
|
+
"HALFVEC"
|
|
300
|
+
if vector_size > MAX_VECTOR_SIZE and re.search("halfvec", distance)
|
|
301
|
+
else "VECTOR"
|
|
302
|
+
if not is_sparse
|
|
303
|
+
else "SPARSEVEC"
|
|
304
|
+
)
|
|
244
305
|
|
|
245
306
|
create_table_query = f"""
|
|
246
307
|
CREATE TABLE {self._table_name}
|
|
247
|
-
(id UUID, text TEXT, image_bytes BYTEA, vector {vector_func}({
|
|
308
|
+
(id UUID, text TEXT, image_bytes BYTEA, vector {vector_func}({vector_size}), metadata JSONB);
|
|
248
309
|
"""
|
|
249
|
-
#
|
|
310
|
+
# _idexing_params has been validated in the class constructor, and it is valid dict[str,int].
|
|
311
|
+
if "lists" in self._indexing_params:
|
|
312
|
+
index_type = "ivfflat"
|
|
313
|
+
index_params = f"(lists = {self._indexing_params['lists']});"
|
|
314
|
+
else:
|
|
315
|
+
index_type = "hnsw"
|
|
316
|
+
index_params = (
|
|
317
|
+
f"(m = {self._indexing_params['m']}, ef_construction = {self._indexing_params['ef_construction']});"
|
|
318
|
+
)
|
|
319
|
+
|
|
250
320
|
create_index_query = f"""
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
321
|
+
CREATE INDEX {self._table_name + "_" + index_type + "_idx"} ON {self._table_name}
|
|
322
|
+
USING {index_type} (vector {distance})
|
|
323
|
+
WITH {index_params}
|
|
324
|
+
"""
|
|
325
|
+
|
|
255
326
|
if await self._check_table_exists():
|
|
256
327
|
print(f"Table {self._table_name} already exist!")
|
|
257
328
|
return
|
|
@@ -277,6 +348,10 @@ class PgVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
|
|
|
277
348
|
"""
|
|
278
349
|
if not entries:
|
|
279
350
|
return
|
|
351
|
+
|
|
352
|
+
# Ensure vector size is determined before processing
|
|
353
|
+
vector_size = await self._get_vector_size()
|
|
354
|
+
|
|
280
355
|
# _table_name has been validated in the class constructor, and it is a valid table name.
|
|
281
356
|
insert_query = f"""
|
|
282
357
|
INSERT INTO {self._table_name} (id, text, image_bytes, vector, metadata)
|
|
@@ -285,7 +360,7 @@ class PgVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
|
|
|
285
360
|
with trace(
|
|
286
361
|
table_name=self._table_name,
|
|
287
362
|
entries=entries,
|
|
288
|
-
vector_size=
|
|
363
|
+
vector_size=vector_size,
|
|
289
364
|
embedder=repr(self._embedder),
|
|
290
365
|
embedding_type=self._embedding_type,
|
|
291
366
|
):
|
|
@@ -351,25 +426,26 @@ class PgVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
|
|
|
351
426
|
Returns:
|
|
352
427
|
The retrieved entries.
|
|
353
428
|
"""
|
|
354
|
-
|
|
429
|
+
merged_options = (self.default_options | options) if options else self.default_options
|
|
430
|
+
|
|
431
|
+
# Ensure vector size is determined before processing
|
|
432
|
+
vector_size = await self._get_vector_size()
|
|
433
|
+
|
|
355
434
|
with trace(
|
|
356
435
|
text=text,
|
|
436
|
+
options=merged_options.dict(),
|
|
357
437
|
table_name=self._table_name,
|
|
358
|
-
|
|
359
|
-
vector_size=self._vector_size,
|
|
438
|
+
vector_size=vector_size,
|
|
360
439
|
distance_method=self._distance_method,
|
|
361
440
|
embedder=repr(self._embedder),
|
|
362
441
|
embedding_type=self._embedding_type,
|
|
363
442
|
) as outputs:
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
query_options = (self.default_options | options) if options else self.default_options
|
|
368
|
-
retrieve_query, values = self._create_retrieve_query(vector, query_options)
|
|
443
|
+
query_vector = (await self._embedder.embed_text([text]))[0]
|
|
444
|
+
query, values = self._create_retrieve_query(query_vector, merged_options)
|
|
369
445
|
|
|
370
446
|
try:
|
|
371
447
|
async with self._client.acquire() as conn:
|
|
372
|
-
results = await conn.fetch(
|
|
448
|
+
results = await conn.fetch(query, *values)
|
|
373
449
|
|
|
374
450
|
outputs.results = [
|
|
375
451
|
VectorStoreResult(
|
|
@@ -16,7 +16,7 @@ from qdrant_client.models import (
|
|
|
16
16
|
)
|
|
17
17
|
from typing_extensions import Self
|
|
18
18
|
|
|
19
|
-
from ragbits.core.audit import trace
|
|
19
|
+
from ragbits.core.audit.traces import trace
|
|
20
20
|
from ragbits.core.embeddings import Embedder, SparseEmbedder, SparseVector
|
|
21
21
|
from ragbits.core.utils.config_handling import ObjectConstructionConfig, import_by_path
|
|
22
22
|
from ragbits.core.utils.dict_transformations import flatten_dict
|
|
@@ -24,7 +24,6 @@ from ragbits.core.vector_stores.base import (
|
|
|
24
24
|
EmbeddingType,
|
|
25
25
|
VectorStoreEntry,
|
|
26
26
|
VectorStoreOptions,
|
|
27
|
-
VectorStoreOptionsT,
|
|
28
27
|
VectorStoreResult,
|
|
29
28
|
VectorStoreWithEmbedder,
|
|
30
29
|
WhereQuery,
|
|
@@ -202,7 +201,7 @@ class QdrantVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
|
|
|
202
201
|
models.PointStruct(
|
|
203
202
|
id=str(entry.id),
|
|
204
203
|
vector={self._vector_name: self._to_qdrant_vector(embeddings[entry.id])}, # type: ignore
|
|
205
|
-
payload=entry.model_dump(exclude_none=True),
|
|
204
|
+
payload=entry.model_dump(exclude_none=True, mode="json"),
|
|
206
205
|
)
|
|
207
206
|
for entry in entries
|
|
208
207
|
if entry.id in embeddings
|
|
@@ -214,7 +213,11 @@ class QdrantVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
|
|
|
214
213
|
wait=True,
|
|
215
214
|
)
|
|
216
215
|
|
|
217
|
-
async def retrieve(
|
|
216
|
+
async def retrieve(
|
|
217
|
+
self,
|
|
218
|
+
text: str,
|
|
219
|
+
options: VectorStoreOptions | None = None,
|
|
220
|
+
) -> list[VectorStoreResult]:
|
|
218
221
|
"""
|
|
219
222
|
Retrieves entries from the Qdrant collection based on vector similarity.
|
|
220
223
|
|
|
@@ -236,7 +239,7 @@ class QdrantVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
|
|
|
236
239
|
)
|
|
237
240
|
with trace(
|
|
238
241
|
text=text,
|
|
239
|
-
options=merged_options,
|
|
242
|
+
options=merged_options.dict(),
|
|
240
243
|
index_name=self._index_name,
|
|
241
244
|
distance_method=self._distance_method,
|
|
242
245
|
embedder=repr(self._embedder),
|
|
@@ -252,6 +255,7 @@ class QdrantVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
|
|
|
252
255
|
score_threshold=score_threshold,
|
|
253
256
|
with_payload=True,
|
|
254
257
|
with_vectors=True,
|
|
258
|
+
query_filter=self._create_qdrant_filter(merged_options.where),
|
|
255
259
|
)
|
|
256
260
|
|
|
257
261
|
outputs.results = []
|
|
@@ -290,16 +294,19 @@ class QdrantVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
|
|
|
290
294
|
)
|
|
291
295
|
|
|
292
296
|
@staticmethod
|
|
293
|
-
def _create_qdrant_filter(where: WhereQuery) -> Filter:
|
|
297
|
+
def _create_qdrant_filter(where: WhereQuery | None) -> Filter:
|
|
294
298
|
"""
|
|
295
299
|
Creates the QdrantFilter from the given WhereQuery.
|
|
296
300
|
|
|
297
301
|
Args:
|
|
298
|
-
where: The WhereQuery to filter.
|
|
302
|
+
where: The WhereQuery to filter. If None, returns an empty filter.
|
|
299
303
|
|
|
300
304
|
Returns:
|
|
301
305
|
The created filter.
|
|
302
306
|
"""
|
|
307
|
+
if where is None:
|
|
308
|
+
return Filter(must=[])
|
|
309
|
+
|
|
303
310
|
where = flatten_dict(where) # type: ignore
|
|
304
311
|
|
|
305
312
|
return Filter(
|
|
@@ -336,6 +343,7 @@ class QdrantVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
|
|
|
336
343
|
return []
|
|
337
344
|
|
|
338
345
|
limit = limit or (await self._client.count(collection_name=self._index_name)).count
|
|
346
|
+
limit = max(1, limit)
|
|
339
347
|
|
|
340
348
|
qdrant_filter = self._create_qdrant_filter(where) if where else None
|
|
341
349
|
|