beaver-db 0.9.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of beaver-db has been flagged as possibly problematic.
- beaver/collections.py +339 -149
- beaver/core.py +180 -99
- beaver/vectors.py +370 -0
- {beaver_db-0.9.2.dist-info → beaver_db-0.11.0.dist-info}/METADATA +33 -12
- beaver_db-0.11.0.dist-info/RECORD +13 -0
- beaver_db-0.9.2.dist-info/RECORD +0 -12
- {beaver_db-0.9.2.dist-info → beaver_db-0.11.0.dist-info}/WHEEL +0 -0
- {beaver_db-0.9.2.dist-info → beaver_db-0.11.0.dist-info}/licenses/LICENSE +0 -0
- {beaver_db-0.9.2.dist-info → beaver_db-0.11.0.dist-info}/top_level.txt +0 -0
beaver/collections.py
CHANGED
@@ -1,11 +1,69 @@
 import json
 import sqlite3
+import threading
 import uuid
 from enum import Enum
-from typing import Any, List, Literal
+from typing import Any, List, Literal, Tuple

 import numpy as np
-from scipy.spatial import cKDTree
+
+from .vectors import VectorIndex
+
+
+# --- Fuzzy Search Helper Functions ---
+
+def _levenshtein_distance(s1: str, s2: str) -> int:
+    """Calculates the Levenshtein distance between two strings."""
+    if len(s1) < len(s2):
+        return _levenshtein_distance(s2, s1)
+    if len(s2) == 0:
+        return len(s1)
+
+    previous_row = range(len(s2) + 1)
+    for i, c1 in enumerate(s1):
+        current_row = [i + 1]
+        for j, c2 in enumerate(s2):
+            insertions = previous_row[j + 1] + 1
+            deletions = current_row[j] + 1
+            substitutions = previous_row[j] + (c1 != c2)
+            current_row.append(min(insertions, deletions, substitutions))
+        previous_row = current_row
+    return previous_row[-1]
+
+
+def _get_trigrams(text: str) -> set[str]:
+    """Generates a set of 3-character trigrams from a string."""
+    if not text or len(text) < 3:
+        return set()
+    return {text[i:i+3] for i in range(len(text) - 2)}
+
+
+def _sliding_window_levenshtein(query: str, content: str, fuzziness: int) -> int:
+    """
+    Finds the best Levenshtein match for a query within a larger text
+    by comparing it against relevant substrings.
+    """
+    query_tokens = query.lower().split()
+    content_tokens = content.lower().split()
+    query_len = len(query_tokens)
+    if query_len == 0:
+        return 0
+
+    min_dist = float('inf')
+    query_norm = " ".join(query_tokens)
+
+    # The window size can be slightly smaller or larger than the query length
+    # to account for missing or extra words in a fuzzy match.
+    for window_size in range(max(1, query_len - fuzziness), query_len + fuzziness + 1):
+        if window_size > len(content_tokens):
+            continue
+        for i in range(len(content_tokens) - window_size + 1):
+            window_text = " ".join(content_tokens[i:i+window_size])
+            dist = _levenshtein_distance(query_norm, window_text)
+            if dist < min_dist:
+                min_dist = dist
+
+    return int(min_dist)


 class WalkDirection(Enum):
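The fuzzy-search helpers added above are pure functions, so their behavior can be sanity-checked in isolation. A minimal sketch (importing the private helpers directly is for illustration only; the expected values follow from the standard Levenshtein and trigram definitions):

    from beaver.collections import _levenshtein_distance, _get_trigrams

    # "kitten" -> "sitting" needs two substitutions and one insertion.
    assert _levenshtein_distance("kitten", "sitting") == 3
    # A 5-character string yields exactly three overlapping trigrams.
    assert _get_trigrams("hello") == {"hel", "ell", "llo"}
    # Strings shorter than 3 characters produce no trigrams.
    assert _get_trigrams("hi") == set()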
@@ -47,110 +105,167 @@ class Document:

 class CollectionManager:
     """
-    A wrapper for multi-modal collection operations
-    FTS,
+    A wrapper for multi-modal collection operations, including document storage,
+    FTS, fuzzy search, graph traversal, and persistent vector search.
     """

     def __init__(self, name: str, conn: sqlite3.Connection):
         self._name = name
         self._conn = conn
-
-        self.
-
-
-
-
+        # All vector-related operations are now delegated to the VectorIndex class.
+        self._vector_index = VectorIndex(name, conn)
+        # A lock to ensure only one compaction thread runs at a time for this collection.
+        self._compaction_lock = threading.Lock()
+        self._compaction_thread: threading.Thread | None = None
+
+    def _flatten_metadata(self, metadata: dict, prefix: str = "") -> dict[str, Any]:
+        """Flattens a nested dictionary for indexing."""
         flat_dict = {}
         for key, value in metadata.items():
-            new_key = f"{prefix}
+            new_key = f"{prefix}.{key}" if prefix else key
             if isinstance(value, dict):
                 flat_dict.update(self._flatten_metadata(value, new_key))
-
+            else:
                 flat_dict[new_key] = value
         return flat_dict

-    def
-        """
+    def _needs_compaction(self, threshold: int = 1000) -> bool:
+        """Checks if the total number of pending vector operations exceeds the threshold."""
         cursor = self._conn.cursor()
         cursor.execute(
-            "SELECT
-            (self._name,)
+            "SELECT COUNT(*) FROM _beaver_ann_pending_log WHERE collection_name = ?",
+            (self._name,)
         )
-
-
+        pending_count = cursor.fetchone()[0]
+        cursor.execute(
+            "SELECT COUNT(*) FROM _beaver_ann_deletions_log WHERE collection_name = ?",
+            (self._name,)
+        )
+        deletion_count = cursor.fetchone()[0]
+        return (pending_count + deletion_count) >= threshold
+
+    def _run_compaction_and_release_lock(self):
+        """
+        A target function for the background thread that runs the compaction
+        and ensures the lock is always released, even if errors occur.
+        """
+        try:
+            self._vector_index.compact()
+        finally:
+            self._compaction_lock.release()
+
+    def compact(self, block: bool = False):
+        """
+        Triggers a non-blocking background compaction of the vector index.

-
-
-        if self._local_index_version == -1:
-            return True
-        return self._local_index_version < self._get_db_version()
+        If a compaction is already running for this collection, this method returns
+        immediately without starting a new one.

-
-
-
-
-
-
-
+        Args:
+            block: If True, this method will wait for the compaction to complete
+                   before returning. Defaults to False (non-blocking).
+        """
+        # Use a non-blocking lock acquire to check if a compaction is already running.
+        if self._compaction_lock.acquire(blocking=False):
+            try:
+                # If we get the lock, start a new background thread.
+                self._compaction_thread = threading.Thread(
+                    target=self._run_compaction_and_release_lock,
+                    daemon=True  # Daemon threads don't block program exit.
                 )
-
-        if
-
-
-
-
-
-
-
+                self._compaction_thread.start()
+                if block:
+                    self._compaction_thread.join()
+            except Exception:
+                # If something goes wrong during thread creation, release the lock.
+                self._compaction_lock.release()
+                raise
+        # If acquire fails, it means another thread holds the lock, so we do nothing.
+
+    def index(
+        self,
+        document: Document,
+        *,
+        fts: bool | list[str] = True,
+        fuzzy: bool = False
+    ):
+        """
+        Indexes a Document, including vector, FTS, and fuzzy search data.
+        The entire operation is performed in a single atomic transaction.
+        """
+        if not isinstance(document, Document):
+            raise TypeError("Item to index must be a Document object.")

-
+        with self._conn:
+            cursor = self._conn.cursor()
+
+            # Step 1: Core Document and Vector Storage
+            cursor.execute(
                 "INSERT OR REPLACE INTO beaver_collections (collection, item_id, item_vector, metadata) VALUES (?, ?, ?, ?)",
                 (
                     self._name,
                     document.id,
-                    (
-                        document.embedding.tobytes()
-                        if document.embedding is not None
-                        else None
-                    ),
+                    document.embedding.tobytes() if document.embedding is not None else None,
                     json.dumps(document.to_dict()),
                 ),
             )
-
-
-
-
-
-
+
+            # Step 2: Delegate to the VectorIndex if an embedding exists.
+            if document.embedding is not None:
+                self._vector_index.index(document.id, document.embedding, cursor)
+
+            # Step 3: FTS and Fuzzy Indexing
+            cursor.execute("DELETE FROM beaver_fts_index WHERE collection = ? AND item_id = ?", (self._name, document.id))
+            cursor.execute("DELETE FROM beaver_trigrams WHERE collection = ? AND item_id = ?", (self._name, document.id))
+
+            flat_metadata = self._flatten_metadata(document.to_dict())
+            fields_to_index: dict[str, str] = {}
+            if isinstance(fts, list):
+                fields_to_index = {k: v for k, v in flat_metadata.items() if k in fts and isinstance(v, str)}
+            elif fts:
+                fields_to_index = {k: v for k, v in flat_metadata.items() if isinstance(v, str)}
+
+            if fields_to_index:
+                fts_data = [(self._name, document.id, path, content) for path, content in fields_to_index.items()]
+                cursor.executemany("INSERT INTO beaver_fts_index (collection, item_id, field_path, field_content) VALUES (?, ?, ?, ?)", fts_data)
+                if fuzzy:
+                    trigram_data = []
+                    for path, content in fields_to_index.items():
+                        for trigram in _get_trigrams(content.lower()):
+                            trigram_data.append((self._name, document.id, path, trigram))
+                    if trigram_data:
+                        cursor.executemany("INSERT INTO beaver_trigrams (collection, item_id, field_path, trigram) VALUES (?, ?, ?, ?)", trigram_data)
+
+            # Step 4: Update Collection Version to signal a change.
+            cursor.execute(
+                "INSERT INTO beaver_collection_versions (collection_name, version) VALUES (?, 1) ON CONFLICT(collection_name) DO UPDATE SET version = version + 1",
                 (self._name,),
             )

+        # After the transaction commits, check if auto-compaction is needed.
+        if self._needs_compaction():
+            self.compact()
+
     def drop(self, document: Document):
         """Removes a document and all its associated data from the collection."""
         if not isinstance(document, Document):
             raise TypeError("Item to drop must be a Document object.")
         with self._conn:
-            self._conn.
-
-
-            )
-            self.
-
-
-
-            self._conn.execute(
-                "DELETE FROM beaver_edges WHERE collection = ? AND (source_item_id = ? OR target_item_id = ?)",
-                (self._name, document.id, document.id),
-            )
-            self._conn.execute(
-                """
-                INSERT INTO beaver_collection_versions (collection_name, version) VALUES (?, 1)
-                ON CONFLICT(collection_name) DO UPDATE SET version = version + 1
-                """,
+            cursor = self._conn.cursor()
+            cursor.execute("DELETE FROM beaver_collections WHERE collection = ? AND item_id = ?", (self._name, document.id))
+            cursor.execute("DELETE FROM beaver_fts_index WHERE collection = ? AND item_id = ?", (self._name, document.id))
+            cursor.execute("DELETE FROM beaver_trigrams WHERE collection = ? AND item_id = ?", (self._name, document.id))
+            cursor.execute("DELETE FROM beaver_edges WHERE collection = ? AND (source_item_id = ? OR target_item_id = ?)", (self._name, document.id, document.id))
+            self._vector_index.drop(document.id, cursor)
+            cursor.execute(
+                "INSERT INTO beaver_collection_versions (collection_name, version) VALUES (?, 1) ON CONFLICT(collection_name) DO UPDATE SET version = version + 1",
                 (self._name,),
             )

+        # Check for auto-compaction after a drop as well.
+        if self._needs_compaction():
+            self.compact()
+
     def __iter__(self):
         """Returns an iterator over all documents in the collection."""
         cursor = self._conn.cursor()
|
|
|
169
284
|
)
|
|
170
285
|
cursor.close()
|
|
171
286
|
|
|
172
|
-
def refresh(self):
|
|
173
|
-
"""Forces a rebuild of the in-memory ANN index from data in SQLite."""
|
|
174
|
-
cursor = self._conn.cursor()
|
|
175
|
-
cursor.execute(
|
|
176
|
-
"SELECT item_id, item_vector FROM beaver_collections WHERE collection = ? AND item_vector IS NOT NULL",
|
|
177
|
-
(self._name,),
|
|
178
|
-
)
|
|
179
|
-
vectors, self._doc_ids = [], []
|
|
180
|
-
for row in cursor.fetchall():
|
|
181
|
-
self._doc_ids.append(row["item_id"])
|
|
182
|
-
vectors.append(np.frombuffer(row["item_vector"], dtype=np.float32))
|
|
183
|
-
|
|
184
|
-
self._kdtree = cKDTree(vectors) if vectors else None
|
|
185
|
-
self._local_index_version = self._get_db_version()
|
|
186
|
-
|
|
187
287
|
def search(
|
|
188
288
|
self, vector: list[float], top_k: int = 10
|
|
189
|
-
) ->
|
|
190
|
-
"""Performs a fast approximate nearest neighbor search."""
|
|
191
|
-
if
|
|
192
|
-
|
|
193
|
-
if not self._kdtree:
|
|
194
|
-
return []
|
|
289
|
+
) -> List[Tuple[Document, float]]:
|
|
290
|
+
"""Performs a fast, persistent approximate nearest neighbor search."""
|
|
291
|
+
if not isinstance(vector, list):
|
|
292
|
+
raise TypeError("Search vector must be a list of floats.")
|
|
195
293
|
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
distances, indices = self._kdtree.query(
|
|
200
|
-
np.array(vector, dtype=np.float32), k=top_k
|
|
294
|
+
search_results = self._vector_index.search(
|
|
295
|
+
np.array(vector, dtype=np.float32), top_k=top_k
|
|
201
296
|
)
|
|
202
|
-
if
|
|
203
|
-
|
|
297
|
+
if not search_results:
|
|
298
|
+
return []
|
|
299
|
+
|
|
300
|
+
result_ids = [item[0] for item in search_results]
|
|
301
|
+
distance_map = {item[0]: item[1] for item in search_results}
|
|
204
302
|
|
|
205
|
-
result_ids = [self._doc_ids[i] for i in indices]
|
|
206
303
|
placeholders = ",".join("?" for _ in result_ids)
|
|
207
304
|
sql = f"SELECT item_id, item_vector, metadata FROM beaver_collections WHERE collection = ? AND item_id IN ({placeholders})"
|
|
208
305
|
|
|
209
306
|
cursor = self._conn.cursor()
|
|
210
307
|
rows = cursor.execute(sql, (self._name, *result_ids)).fetchall()
|
|
211
|
-
row_map = {row["item_id"]: row for row in rows}
|
|
212
308
|
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
309
|
+
doc_map = {
|
|
310
|
+
row["item_id"]: Document(
|
|
311
|
+
id=row["item_id"],
|
|
312
|
+
embedding=(np.frombuffer(row["item_vector"], dtype=np.float32).tolist() if row["item_vector"] else None),
|
|
313
|
+
**json.loads(row["metadata"]),
|
|
314
|
+
)
|
|
315
|
+
for row in rows
|
|
316
|
+
}
|
|
317
|
+
|
|
318
|
+
final_results = []
|
|
319
|
+
for doc_id in result_ids:
|
|
320
|
+
if doc_id in doc_map:
|
|
321
|
+
doc = doc_map[doc_id]
|
|
322
|
+
distance = distance_map[doc_id]
|
|
323
|
+
final_results.append((doc, distance))
|
|
324
|
+
|
|
325
|
+
return final_results
|
|
223
326
|
|
|
224
327
|
def match(
|
|
225
|
-
self,
|
|
328
|
+
self,
|
|
329
|
+
query: str,
|
|
330
|
+
*,
|
|
331
|
+
on: str | list[str] | None = None,
|
|
332
|
+
top_k: int = 10,
|
|
333
|
+
fuzziness: int = 0
|
|
334
|
+
) -> list[tuple[Document, float]]:
|
|
335
|
+
"""
|
|
336
|
+
Performs a full-text or fuzzy search on indexed string fields.
|
|
337
|
+
"""
|
|
338
|
+
if isinstance(on, str):
|
|
339
|
+
on = [on]
|
|
340
|
+
|
|
341
|
+
if fuzziness == 0:
|
|
342
|
+
return self._perform_fts_search(query, on, top_k)
|
|
343
|
+
else:
|
|
344
|
+
return self._perform_fuzzy_search(query, on, top_k, fuzziness)
|
|
345
|
+
|
|
346
|
+
def _perform_fts_search(
|
|
347
|
+
self, query: str, on: list[str] | None, top_k: int
|
|
226
348
|
) -> list[tuple[Document, float]]:
|
|
227
|
-
"""Performs a
|
|
349
|
+
"""Performs a standard FTS search."""
|
|
228
350
|
cursor = self._conn.cursor()
|
|
229
351
|
sql_query = """
|
|
230
352
|
SELECT t1.item_id, t1.item_vector, t1.metadata, fts.rank
|
|
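With this rewrite, `search` returns `(Document, distance)` pairs ordered by the persistent vector index instead of results from an in-memory cKDTree. A usage sketch, assuming `collection` is a `CollectionManager` obtained elsewhere (how one is constructed is outside this diff):

    query_vector = [0.1, 0.2, 0.3, 0.4]  # must match the dimension of the indexed embeddings

    for doc, distance in collection.search(query_vector, top_k=5):
        # Pairs come back ordered by the index; a smaller distance means a closer match.
        print(doc.id, distance)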
@@ -234,30 +356,121 @@ class CollectionManager:
             ) AS fts ON t1.item_id = fts.item_id
             WHERE t1.collection = ? ORDER BY fts.rank
         """
-        params
-
-
-
-
-        params.
-        params.extend([top_k, self._name])
+        params: list[Any] = [query]
+        field_filter_sql = ""
+        if on:
+            placeholders = ",".join("?" for _ in on)
+            field_filter_sql = f"AND field_path IN ({placeholders})"
+            params.extend(on)

-
-
-        ).fetchall()
+        params.extend([top_k, self._name])
+        rows = cursor.execute(sql_query.format(field_filter_sql), tuple(params)).fetchall()
         results = []
         for row in rows:
             embedding = (
                 np.frombuffer(row["item_vector"], dtype=np.float32).tolist()
-                if row["item_vector"]
-                else None
-            )
-            doc = Document(
-                id=row["item_id"], embedding=embedding, **json.loads(row["metadata"])
+                if row["item_vector"] else None
             )
+            doc = Document(id=row["item_id"], embedding=embedding, **json.loads(row["metadata"]))
             results.append((doc, row["rank"]))
         return results

+    def _get_trigram_candidates(self, query: str, on: list[str] | None) -> set[str]:
+        """
+        Gets document IDs that meet a trigram similarity threshold with the query.
+        """
+        query_trigrams = _get_trigrams(query.lower())
+        if not query_trigrams:
+            return set()
+
+        similarity_threshold = int(len(query_trigrams) * 0.3)
+        if similarity_threshold == 0:
+            return set()
+
+        cursor = self._conn.cursor()
+        sql = """
+            SELECT item_id FROM beaver_trigrams
+            WHERE collection = ? AND trigram IN ({}) {}
+            GROUP BY item_id
+            HAVING COUNT(DISTINCT trigram) >= ?
+        """
+        params: list[Any] = [self._name]
+        trigram_placeholders = ",".join("?" for _ in query_trigrams)
+        params.extend(query_trigrams)
+
+        field_filter_sql = ""
+        if on:
+            field_placeholders = ",".join("?" for _ in on)
+            field_filter_sql = f"AND field_path IN ({field_placeholders})"
+            params.extend(on)
+
+        params.append(similarity_threshold)
+        cursor.execute(sql.format(trigram_placeholders, field_filter_sql), tuple(params))
+        return {row['item_id'] for row in cursor.fetchall()}
+
+    def _perform_fuzzy_search(
+        self, query: str, on: list[str] | None, top_k: int, fuzziness: int
+    ) -> list[tuple[Document, float]]:
+        """Performs a 3-stage fuzzy search: gather, score, and sort."""
+        fts_results = self._perform_fts_search(query, on, top_k)
+        fts_candidate_ids = {doc.id for doc, _ in fts_results}
+        trigram_candidate_ids = self._get_trigram_candidates(query, on)
+        candidate_ids = fts_candidate_ids.union(trigram_candidate_ids)
+        if not candidate_ids:
+            return []
+
+        cursor = self._conn.cursor()
+        id_placeholders = ",".join("?" for _ in candidate_ids)
+        sql_text = f"SELECT item_id, field_path, field_content FROM beaver_fts_index WHERE collection = ? AND item_id IN ({id_placeholders})"
+        params_text: list[Any] = [self._name]
+        params_text.extend(candidate_ids)
+        if on:
+            sql_text += f" AND field_path IN ({','.join('?' for _ in on)})"
+            params_text.extend(on)
+
+        cursor.execute(sql_text, tuple(params_text))
+        candidate_texts: dict[str, dict[str, str]] = {}
+        for row in cursor.fetchall():
+            item_id = row['item_id']
+            if item_id not in candidate_texts:
+                candidate_texts[item_id] = {}
+            candidate_texts[item_id][row['field_path']] = row['field_content']
+
+        scored_candidates = []
+        fts_rank_map = {doc.id: rank for doc, rank in fts_results}
+
+        for item_id in candidate_ids:
+            if item_id not in candidate_texts:
+                continue
+            min_dist = float('inf')
+            for content in candidate_texts[item_id].values():
+                dist = _sliding_window_levenshtein(query, content, fuzziness)
+                if dist < min_dist:
+                    min_dist = dist
+            if min_dist <= fuzziness:
+                scored_candidates.append({
+                    "id": item_id,
+                    "distance": min_dist,
+                    "fts_rank": fts_rank_map.get(item_id, 0)
+                })
+
+        scored_candidates.sort(key=lambda x: (x["distance"], x["fts_rank"]))
+        top_ids = [c["id"] for c in scored_candidates[:top_k]]
+        if not top_ids:
+            return []
+
+        id_placeholders = ",".join("?" for _ in top_ids)
+        sql_docs = f"SELECT item_id, item_vector, metadata FROM beaver_collections WHERE collection = ? AND item_id IN ({id_placeholders})"
+        cursor.execute(sql_docs, (self._name, *top_ids))
+        doc_map = {row["item_id"]: Document(id=row["item_id"], embedding=(np.frombuffer(row["item_vector"], dtype=np.float32).tolist() if row["item_vector"] else None), **json.loads(row["metadata"])) for row in cursor.fetchall()}
+
+        final_results = []
+        distance_map = {c["id"]: c["distance"] for c in scored_candidates}
+        for doc_id in top_ids:
+            if doc_id in doc_map:
+                final_results.append((doc_map[doc_id], float(distance_map[doc_id])))
+        return final_results
+
     def connect(
         self, source: Document, target: Document, label: str, metadata: dict = None
     ):
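The new `match` signature routes on `fuzziness`: zero goes straight to SQLite FTS, anything greater runs the trigram prefilter plus sliding-window Levenshtein path. A usage sketch (`collection` is again an assumed `CollectionManager`; the field name is illustrative):

    # Exact full-text search on one indexed field.
    exact_hits = collection.match("database", on="title", top_k=5)

    # Typo-tolerant search: documents within an edit distance of 2 can still match,
    # provided they were indexed with fuzzy=True so their trigrams are stored.
    fuzzy_hits = collection.match("databse", on="title", top_k=5, fuzziness=2)

    for doc, score in fuzzy_hits:
        # For fuzzy hits the score is the edit distance (lower is better);
        # for FTS hits it is the FTS rank.
        print(doc.id, score)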
@@ -355,49 +568,26 @@ def rerank(
 ) -> list[Document]:
     """
     Reranks documents from multiple search result lists using Reverse Rank Fusion (RRF).
-    This function is specifically designed to work with beaver.collections.Document objects.
-
-    Args:
-        results (sequence of list[Document]): A sequence of search result lists, where each
-            inner list contains Document objects.
-        weights (list[float], optional): A list of weights corresponding to each
-            result list. If None, all lists are weighted equally. Defaults to None.
-        k (int, optional): A constant used in the RRF formula. Defaults to 60.
-
-    Returns:
-        list[Document]: A single, reranked list of unique Document objects, sorted
-            by their fused rank score in descending order.
     """
     if not results:
         return []

-    # Assign a default weight of 1.0 if none are provided
     if weights is None:
         weights = [1.0] * len(results)

     if len(results) != len(weights):
         raise ValueError("The number of result lists must match the number of weights.")

-    # Use dictionaries to store scores and unique documents by their ID
     rrf_scores: dict[str, float] = {}
     doc_store: dict[str, Document] = {}

-    # Iterate through each list of Document objects and its weight
     for result_list, weight in zip(results, weights):
         for rank, doc in enumerate(result_list):
-            # Use the .id attribute from the Document object
             doc_id = doc.id
             if doc_id not in doc_store:
                 doc_store[doc_id] = doc
-
-            # Calculate the reciprocal rank score, scaled by the weight
             score = weight * (1 / (k + rank))
-
-            # Add the score to the document's running total
             rrf_scores[doc_id] = rrf_scores.get(doc_id, 0.0) + score

-    # Sort the document IDs by their final aggregated scores
     sorted_doc_ids = sorted(rrf_scores.keys(), key=rrf_scores.get, reverse=True)
-
-    # Return the final list of Document objects in the new, reranked order
     return [doc_store[doc_id] for doc_id in sorted_doc_ids]
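`rerank` scores each document as the weighted sum of 1 / (k + rank) over every list it appears in, so a document found by several searches outranks one found by a single search. A worked check with k = 60 and equal weights (list one ranks a then b; list two ranks b then c):

    k = 60
    scores = {
        "a": 1 / (k + 0),                # rank 0 in list one only
        "b": 1 / (k + 1) + 1 / (k + 0),  # rank 1 in list one, rank 0 in list two
        "c": 1 / (k + 1),                # rank 1 in list two only
    }
    # b ~= 0.0331 beats a ~= 0.0167 and c ~= 0.0164.
    print(sorted(scores, key=scores.get, reverse=True))  # ['b', 'a', 'c']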