beaver-db 0.9.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of beaver-db might be problematic.

beaver/collections.py CHANGED
@@ -1,11 +1,69 @@
  import json
  import sqlite3
+ import threading
  import uuid
  from enum import Enum
- from typing import Any, List, Literal, Set
+ from typing import Any, List, Literal, Tuple

  import numpy as np
- from scipy.spatial import cKDTree
+
+ from .vectors import VectorIndex
+
+
+ # --- Fuzzy Search Helper Functions ---
+
+ def _levenshtein_distance(s1: str, s2: str) -> int:
+     """Calculates the Levenshtein distance between two strings."""
+     if len(s1) < len(s2):
+         return _levenshtein_distance(s2, s1)
+     if len(s2) == 0:
+         return len(s1)
+
+     previous_row = range(len(s2) + 1)
+     for i, c1 in enumerate(s1):
+         current_row = [i + 1]
+         for j, c2 in enumerate(s2):
+             insertions = previous_row[j + 1] + 1
+             deletions = current_row[j] + 1
+             substitutions = previous_row[j] + (c1 != c2)
+             current_row.append(min(insertions, deletions, substitutions))
+         previous_row = current_row
+     return previous_row[-1]
+
+
+ def _get_trigrams(text: str) -> set[str]:
+     """Generates a set of 3-character trigrams from a string."""
+     if not text or len(text) < 3:
+         return set()
+     return {text[i:i+3] for i in range(len(text) - 2)}
+
+
+ def _sliding_window_levenshtein(query: str, content: str, fuzziness: int) -> int:
+     """
+     Finds the best Levenshtein match for a query within a larger text
+     by comparing it against relevant substrings.
+     """
+     query_tokens = query.lower().split()
+     content_tokens = content.lower().split()
+     query_len = len(query_tokens)
+     if query_len == 0:
+         return 0
+
+     min_dist = float('inf')
+     query_norm = " ".join(query_tokens)
+
+     # The window size can be slightly smaller or larger than the query length
+     # to account for missing or extra words in a fuzzy match.
+     for window_size in range(max(1, query_len - fuzziness), query_len + fuzziness + 1):
+         if window_size > len(content_tokens):
+             continue
+         for i in range(len(content_tokens) - window_size + 1):
+             window_text = " ".join(content_tokens[i:i+window_size])
+             dist = _levenshtein_distance(query_norm, window_text)
+             if dist < min_dist:
+                 min_dist = dist
+
+     return int(min_dist)


  class WalkDirection(Enum):
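
A quick sanity check on the helpers introduced above (a standalone sketch; the underscore-prefixed names are module-private to beaver.collections and imported here only for illustration):

    from beaver.collections import (
        _levenshtein_distance,
        _get_trigrams,
        _sliding_window_levenshtein,
    )

    assert _levenshtein_distance("kitten", "sitting") == 3   # the classic 3-edit example
    assert _get_trigrams("hello") == {"hel", "ell", "llo"}   # sliding 3-char windows
    assert _get_trigrams("hi") == set()                      # < 3 chars -> empty set

    # The sliding window compares the query against word windows of sizes
    # query_len - fuzziness .. query_len + fuzziness and keeps the minimum:
    # "quick brwn" vs the "quick brown" window needs one insertion -> 1.
    assert _sliding_window_levenshtein("quick brwn", "the quick brown fox", 2) == 1
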
@@ -47,110 +105,167 @@ class Document:

  class CollectionManager:
      """
-     A wrapper for multi-modal collection operations with an in-memory ANN index,
-     FTS, and graph capabilities.
+     A wrapper for multi-modal collection operations, including document storage,
+     FTS, fuzzy search, graph traversal, and persistent vector search.
      """

      def __init__(self, name: str, conn: sqlite3.Connection):
          self._name = name
          self._conn = conn
-         self._kdtree: cKDTree | None = None
-         self._doc_ids: List[str] = []
-         self._local_index_version = -1  # Version of the in-memory index
-
-     def _flatten_metadata(self, metadata: dict, prefix: str = "") -> dict[str, str]:
-         """Flattens a nested dictionary and filters for string values."""
+         # All vector-related operations are now delegated to the VectorIndex class.
+         self._vector_index = VectorIndex(name, conn)
+         # A lock to ensure only one compaction thread runs at a time for this collection.
+         self._compaction_lock = threading.Lock()
+         self._compaction_thread: threading.Thread | None = None
+
+     def _flatten_metadata(self, metadata: dict, prefix: str = "") -> dict[str, Any]:
+         """Flattens a nested dictionary for indexing."""
          flat_dict = {}
          for key, value in metadata.items():
-             new_key = f"{prefix}__{key}" if prefix else key
+             new_key = f"{prefix}.{key}" if prefix else key
              if isinstance(value, dict):
                  flat_dict.update(self._flatten_metadata(value, new_key))
-             elif isinstance(value, str):
+             else:
                  flat_dict[new_key] = value
          return flat_dict

-     def _get_db_version(self) -> int:
-         """Gets the current version of the collection from the database."""
+     def _needs_compaction(self, threshold: int = 1000) -> bool:
+         """Checks if the total number of pending vector operations exceeds the threshold."""
          cursor = self._conn.cursor()
          cursor.execute(
-             "SELECT version FROM beaver_collection_versions WHERE collection_name = ?",
-             (self._name,),
+             "SELECT COUNT(*) FROM _beaver_ann_pending_log WHERE collection_name = ?",
+             (self._name,)
          )
-         result = cursor.fetchone()
-         return result[0] if result else 0
+         pending_count = cursor.fetchone()[0]
+         cursor.execute(
+             "SELECT COUNT(*) FROM _beaver_ann_deletions_log WHERE collection_name = ?",
+             (self._name,)
+         )
+         deletion_count = cursor.fetchone()[0]
+         return (pending_count + deletion_count) >= threshold
+
+     def _run_compaction_and_release_lock(self):
+         """
+         A target function for the background thread that runs the compaction
+         and ensures the lock is always released, even if errors occur.
+         """
+         try:
+             self._vector_index.compact()
+         finally:
+             self._compaction_lock.release()
+
+     def compact(self, block: bool = False):
+         """
+         Triggers a non-blocking background compaction of the vector index.

-     def _is_index_stale(self) -> bool:
-         """Checks if the in-memory index is out of sync with the DB."""
-         if self._local_index_version == -1:
-             return True
-         return self._local_index_version < self._get_db_version()
+         If a compaction is already running for this collection, this method returns
+         immediately without starting a new one.

-     def index(self, document: Document, *, fts: bool = True):
-         """Indexes a Document, performing an upsert and updating the FTS index."""
-         with self._conn:
-             if fts:
-                 self._conn.execute(
-                     "DELETE FROM beaver_fts_index WHERE collection = ? AND item_id = ?",
-                     (self._name, document.id),
+         Args:
+             block: If True, this method will wait for the compaction to complete
+                 before returning. Defaults to False (non-blocking).
+         """
+         # Use a non-blocking lock acquire to check if a compaction is already running.
+         if self._compaction_lock.acquire(blocking=False):
+             try:
+                 # If we get the lock, start a new background thread.
+                 self._compaction_thread = threading.Thread(
+                     target=self._run_compaction_and_release_lock,
+                     daemon=True  # Daemon threads don't block program exit.
                  )
-             string_fields = self._flatten_metadata(document.to_dict())
-             if string_fields:
-                 fts_data = [
-                     (self._name, document.id, path, content)
-                     for path, content in string_fields.items()
-                 ]
-                 self._conn.executemany(
-                     "INSERT INTO beaver_fts_index (collection, item_id, field_path, field_content) VALUES (?, ?, ?, ?)",
-                     fts_data,
-                 )
+                 self._compaction_thread.start()
+                 if block:
+                     self._compaction_thread.join()
+             except Exception:
+                 # If something goes wrong during thread creation, release the lock.
+                 self._compaction_lock.release()
+                 raise
+         # If acquire fails, it means another thread holds the lock, so we do nothing.
+
+     def index(
+         self,
+         document: Document,
+         *,
+         fts: bool | list[str] = True,
+         fuzzy: bool = False
+     ):
+         """
+         Indexes a Document, including vector, FTS, and fuzzy search data.
+         The entire operation is performed in a single atomic transaction.
+         """
+         if not isinstance(document, Document):
+             raise TypeError("Item to index must be a Document object.")

-             self._conn.execute(
+         with self._conn:
+             cursor = self._conn.cursor()
+
+             # Step 1: Core Document and Vector Storage
+             cursor.execute(
                  "INSERT OR REPLACE INTO beaver_collections (collection, item_id, item_vector, metadata) VALUES (?, ?, ?, ?)",
                  (
                      self._name,
                      document.id,
-                     (
-                         document.embedding.tobytes()
-                         if document.embedding is not None
-                         else None
-                     ),
+                     document.embedding.tobytes() if document.embedding is not None else None,
                      json.dumps(document.to_dict()),
                  ),
              )
-             # Atomically increment the collection's version number
-             self._conn.execute(
-                 """
-                 INSERT INTO beaver_collection_versions (collection_name, version) VALUES (?, 1)
-                 ON CONFLICT(collection_name) DO UPDATE SET version = version + 1
-                 """,
+
+             # Step 2: Delegate to the VectorIndex if an embedding exists.
+             if document.embedding is not None:
+                 self._vector_index.index(document.id, document.embedding, cursor)
+
+             # Step 3: FTS and Fuzzy Indexing
+             cursor.execute("DELETE FROM beaver_fts_index WHERE collection = ? AND item_id = ?", (self._name, document.id))
+             cursor.execute("DELETE FROM beaver_trigrams WHERE collection = ? AND item_id = ?", (self._name, document.id))
+
+             flat_metadata = self._flatten_metadata(document.to_dict())
+             fields_to_index: dict[str, str] = {}
+             if isinstance(fts, list):
+                 fields_to_index = {k: v for k, v in flat_metadata.items() if k in fts and isinstance(v, str)}
+             elif fts:
+                 fields_to_index = {k: v for k, v in flat_metadata.items() if isinstance(v, str)}
+
+             if fields_to_index:
+                 fts_data = [(self._name, document.id, path, content) for path, content in fields_to_index.items()]
+                 cursor.executemany("INSERT INTO beaver_fts_index (collection, item_id, field_path, field_content) VALUES (?, ?, ?, ?)", fts_data)
+                 if fuzzy:
+                     trigram_data = []
+                     for path, content in fields_to_index.items():
+                         for trigram in _get_trigrams(content.lower()):
+                             trigram_data.append((self._name, document.id, path, trigram))
+                     if trigram_data:
+                         cursor.executemany("INSERT INTO beaver_trigrams (collection, item_id, field_path, trigram) VALUES (?, ?, ?, ?)", trigram_data)
+
+             # Step 4: Update Collection Version to signal a change.
+             cursor.execute(
+                 "INSERT INTO beaver_collection_versions (collection_name, version) VALUES (?, 1) ON CONFLICT(collection_name) DO UPDATE SET version = version + 1",
                  (self._name,),
              )

+         # After the transaction commits, check if auto-compaction is needed.
+         if self._needs_compaction():
+             self.compact()
+
      def drop(self, document: Document):
          """Removes a document and all its associated data from the collection."""
          if not isinstance(document, Document):
              raise TypeError("Item to drop must be a Document object.")
          with self._conn:
-             self._conn.execute(
-                 "DELETE FROM beaver_collections WHERE collection = ? AND item_id = ?",
-                 (self._name, document.id),
-             )
-             self._conn.execute(
-                 "DELETE FROM beaver_fts_index WHERE collection = ? AND item_id = ?",
-                 (self._name, document.id),
-             )
-             self._conn.execute(
-                 "DELETE FROM beaver_edges WHERE collection = ? AND (source_item_id = ? OR target_item_id = ?)",
-                 (self._name, document.id, document.id),
-             )
-             self._conn.execute(
-                 """
-                 INSERT INTO beaver_collection_versions (collection_name, version) VALUES (?, 1)
-                 ON CONFLICT(collection_name) DO UPDATE SET version = version + 1
-                 """,
+             cursor = self._conn.cursor()
+             cursor.execute("DELETE FROM beaver_collections WHERE collection = ? AND item_id = ?", (self._name, document.id))
+             cursor.execute("DELETE FROM beaver_fts_index WHERE collection = ? AND item_id = ?", (self._name, document.id))
+             cursor.execute("DELETE FROM beaver_trigrams WHERE collection = ? AND item_id = ?", (self._name, document.id))
+             cursor.execute("DELETE FROM beaver_edges WHERE collection = ? AND (source_item_id = ? OR target_item_id = ?)", (self._name, document.id, document.id))
+             self._vector_index.drop(document.id, cursor)
+             cursor.execute(
+                 "INSERT INTO beaver_collection_versions (collection_name, version) VALUES (?, 1) ON CONFLICT(collection_name) DO UPDATE SET version = version + 1",
                  (self._name,),
              )

+         # Check for auto-compaction after a drop as well.
+         if self._needs_compaction():
+             self.compact()
+
      def __iter__(self):
          """Returns an iterator over all documents in the collection."""
          cursor = self._conn.cursor()
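
A usage sketch of the new write path (hypothetical: the `collection` handle and the Document fields below are illustrative, not part of this diff):

    # Hypothetical usage of the new index()/drop()/compact() API.
    doc = Document(id="doc-1", embedding=None, title="The Quick Brown Fox")

    collection.index(doc, fts=["title"], fuzzy=True)  # fts may now be a field whitelist;
                                                      # fuzzy=True also fills beaver_trigrams
    collection.drop(doc)   # also clears FTS rows, trigrams, edges, and vector data

    # compact() is non-blocking by default: it runs VectorIndex.compact() on a
    # daemon thread, and a second call while one is in flight is a no-op.
    collection.compact(block=True)  # pass block=True to wait for the merge
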
@@ -169,62 +284,69 @@ class CollectionManager:
          )
          cursor.close()

-     def refresh(self):
-         """Forces a rebuild of the in-memory ANN index from data in SQLite."""
-         cursor = self._conn.cursor()
-         cursor.execute(
-             "SELECT item_id, item_vector FROM beaver_collections WHERE collection = ? AND item_vector IS NOT NULL",
-             (self._name,),
-         )
-         vectors, self._doc_ids = [], []
-         for row in cursor.fetchall():
-             self._doc_ids.append(row["item_id"])
-             vectors.append(np.frombuffer(row["item_vector"], dtype=np.float32))
-
-         self._kdtree = cKDTree(vectors) if vectors else None
-         self._local_index_version = self._get_db_version()
-
      def search(
          self, vector: list[float], top_k: int = 10
-     ) -> list[tuple[Document, float]]:
-         """Performs a fast approximate nearest neighbor search."""
-         if self._is_index_stale():
-             self.refresh()
-         if not self._kdtree:
-             return []
+     ) -> List[Tuple[Document, float]]:
+         """Performs a fast, persistent approximate nearest neighbor search."""
+         if not isinstance(vector, list):
+             raise TypeError("Search vector must be a list of floats.")

-         if top_k > len(self._doc_ids):
-             top_k = len(self._doc_ids)
-
-         distances, indices = self._kdtree.query(
-             np.array(vector, dtype=np.float32), k=top_k
+         search_results = self._vector_index.search(
+             np.array(vector, dtype=np.float32), top_k=top_k
          )
-         if top_k == 1:
-             distances, indices = [distances], [indices]
+         if not search_results:
+             return []
+
+         result_ids = [item[0] for item in search_results]
+         distance_map = {item[0]: item[1] for item in search_results}

-         result_ids = [self._doc_ids[i] for i in indices]
          placeholders = ",".join("?" for _ in result_ids)
          sql = f"SELECT item_id, item_vector, metadata FROM beaver_collections WHERE collection = ? AND item_id IN ({placeholders})"

          cursor = self._conn.cursor()
          rows = cursor.execute(sql, (self._name, *result_ids)).fetchall()
-         row_map = {row["item_id"]: row for row in rows}

-         results = []
-         for i, doc_id in enumerate(result_ids):
-             row = row_map.get(doc_id)
-             if row:
-                 embedding = np.frombuffer(row["item_vector"], dtype=np.float32).tolist()
-                 doc = Document(
-                     id=doc_id, embedding=embedding, **json.loads(row["metadata"])
-                 )
-                 results.append((doc, float(distances[i])))
-         return results
+         doc_map = {
+             row["item_id"]: Document(
+                 id=row["item_id"],
+                 embedding=(np.frombuffer(row["item_vector"], dtype=np.float32).tolist() if row["item_vector"] else None),
+                 **json.loads(row["metadata"]),
+             )
+             for row in rows
+         }
+
+         final_results = []
+         for doc_id in result_ids:
+             if doc_id in doc_map:
+                 doc = doc_map[doc_id]
+                 distance = distance_map[doc_id]
+                 final_results.append((doc, distance))
+
+         return final_results

      def match(
-         self, query: str, on_field: str | None = None, top_k: int = 10
+         self,
+         query: str,
+         *,
+         on: str | list[str] | None = None,
+         top_k: int = 10,
+         fuzziness: int = 0
+     ) -> list[tuple[Document, float]]:
+         """
+         Performs a full-text or fuzzy search on indexed string fields.
+         """
+         if isinstance(on, str):
+             on = [on]
+
+         if fuzziness == 0:
+             return self._perform_fts_search(query, on, top_k)
+         else:
+             return self._perform_fuzzy_search(query, on, top_k, fuzziness)
+
+     def _perform_fts_search(
+         self, query: str, on: list[str] | None, top_k: int
      ) -> list[tuple[Document, float]]:
-         """Performs a full-text search on indexed string fields."""
+         """Performs a standard FTS search."""
          cursor = self._conn.cursor()
          sql_query = """
              SELECT t1.item_id, t1.item_vector, t1.metadata, fts.rank
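
The rewritten search() replaces the in-memory cKDTree rebuild: the persistent VectorIndex returns (item_id, distance) pairs, and documents are hydrated from SQLite in that order. A hedged sketch of the read path (the `collection` handle and the vector size are illustrative):

    # Hypothetical read-path usage.
    hits = collection.search([0.1, 0.2, 0.3, 0.4], top_k=5)  # list[(Document, distance)]
    for doc, distance in hits:
        print(doc.id, distance)  # ordered by VectorIndex distance; missing rows are skipped

    # match() now routes on fuzziness (on/top_k/fuzziness are keyword-only):
    exact = collection.match("quick brown", on="title", top_k=5)      # FTS5 rank path
    fuzzy = collection.match("quick brwn", on="title", fuzziness=2)   # trigram + Levenshtein path
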
@@ -234,30 +356,121 @@ class CollectionManager:
              ) AS fts ON t1.item_id = fts.item_id
              WHERE t1.collection = ? ORDER BY fts.rank
          """
-         params, field_filter_sql = [], ""
-         if on_field:
-             field_filter_sql = "AND field_path = ?"
-             params.extend([query, on_field])
-         else:
-             params.append(query)
-         params.extend([top_k, self._name])
+         params: list[Any] = [query]
+         field_filter_sql = ""
+         if on:
+             placeholders = ",".join("?" for _ in on)
+             field_filter_sql = f"AND field_path IN ({placeholders})"
+             params.extend(on)

-         rows = cursor.execute(
-             sql_query.format(field_filter_sql), tuple(params)
-         ).fetchall()
+         params.extend([top_k, self._name])
+         rows = cursor.execute(sql_query.format(field_filter_sql), tuple(params)).fetchall()
          results = []
          for row in rows:
              embedding = (
                  np.frombuffer(row["item_vector"], dtype=np.float32).tolist()
-                 if row["item_vector"]
-                 else None
-             )
-             doc = Document(
-                 id=row["item_id"], embedding=embedding, **json.loads(row["metadata"])
+                 if row["item_vector"] else None
              )
+             doc = Document(id=row["item_id"], embedding=embedding, **json.loads(row["metadata"]))
              results.append((doc, row["rank"]))
          return results

+     def _get_trigram_candidates(self, query: str, on: list[str] | None) -> set[str]:
+         """
+         Gets document IDs that meet a trigram similarity threshold with the query.
+         """
+         query_trigrams = _get_trigrams(query.lower())
+         if not query_trigrams:
+             return set()
+
+         similarity_threshold = int(len(query_trigrams) * 0.3)
+         if similarity_threshold == 0:
+             return set()
+
+         cursor = self._conn.cursor()
+         sql = """
+             SELECT item_id FROM beaver_trigrams
+             WHERE collection = ? AND trigram IN ({}) {}
+             GROUP BY item_id
+             HAVING COUNT(DISTINCT trigram) >= ?
+         """
+         params: list[Any] = [self._name]
+         trigram_placeholders = ",".join("?" for _ in query_trigrams)
+         params.extend(query_trigrams)
+
+         field_filter_sql = ""
+         if on:
+             field_placeholders = ",".join("?" for _ in on)
+             field_filter_sql = f"AND field_path IN ({field_placeholders})"
+             params.extend(on)
+
+         params.append(similarity_threshold)
+         cursor.execute(sql.format(trigram_placeholders, field_filter_sql), tuple(params))
+         return {row['item_id'] for row in cursor.fetchall()}
+
+     def _perform_fuzzy_search(
+         self, query: str, on: list[str] | None, top_k: int, fuzziness: int
+     ) -> list[tuple[Document, float]]:
+         """Performs a 3-stage fuzzy search: gather, score, and sort."""
+         fts_results = self._perform_fts_search(query, on, top_k)
+         fts_candidate_ids = {doc.id for doc, _ in fts_results}
+         trigram_candidate_ids = self._get_trigram_candidates(query, on)
+         candidate_ids = fts_candidate_ids.union(trigram_candidate_ids)
+         if not candidate_ids:
+             return []
+
+         cursor = self._conn.cursor()
+         id_placeholders = ",".join("?" for _ in candidate_ids)
+         sql_text = f"SELECT item_id, field_path, field_content FROM beaver_fts_index WHERE collection = ? AND item_id IN ({id_placeholders})"
+         params_text: list[Any] = [self._name]
+         params_text.extend(candidate_ids)
+         if on:
+             sql_text += f" AND field_path IN ({','.join('?' for _ in on)})"
+             params_text.extend(on)
+
+         cursor.execute(sql_text, tuple(params_text))
+         candidate_texts: dict[str, dict[str, str]] = {}
+         for row in cursor.fetchall():
+             item_id = row['item_id']
+             if item_id not in candidate_texts:
+                 candidate_texts[item_id] = {}
+             candidate_texts[item_id][row['field_path']] = row['field_content']
+
+         scored_candidates = []
+         fts_rank_map = {doc.id: rank for doc, rank in fts_results}
+
+         for item_id in candidate_ids:
+             if item_id not in candidate_texts:
+                 continue
+             min_dist = float('inf')
+             for content in candidate_texts[item_id].values():
+                 dist = _sliding_window_levenshtein(query, content, fuzziness)
+                 if dist < min_dist:
+                     min_dist = dist
+             if min_dist <= fuzziness:
+                 scored_candidates.append({
+                     "id": item_id,
+                     "distance": min_dist,
+                     "fts_rank": fts_rank_map.get(item_id, 0)
+                 })
+
+         scored_candidates.sort(key=lambda x: (x["distance"], x["fts_rank"]))
+         top_ids = [c["id"] for c in scored_candidates[:top_k]]
+         if not top_ids:
+             return []
+
+         id_placeholders = ",".join("?" for _ in top_ids)
+         sql_docs = f"SELECT item_id, item_vector, metadata FROM beaver_collections WHERE collection = ? AND item_id IN ({id_placeholders})"
+         cursor.execute(sql_docs, (self._name, *top_ids))
+         doc_map = {row["item_id"]: Document(id=row["item_id"], embedding=(np.frombuffer(row["item_vector"], dtype=np.float32).tolist() if row["item_vector"] else None), **json.loads(row["metadata"])) for row in cursor.fetchall()}
+
+         final_results = []
+         distance_map = {c["id"]: c["distance"] for c in scored_candidates}
+         for doc_id in top_ids:
+             if doc_id in doc_map:
+                 final_results.append((doc_map[doc_id], float(distance_map[doc_id])))
+         return final_results
+
      def connect(
          self, source: Document, target: Document, label: str, metadata: dict = None
      ):
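
The trigram pre-filter in _get_trigram_candidates is deliberately loose; a candidate only needs to share roughly 30% of the query's trigrams before the exact Levenshtein pass. The threshold arithmetic, worked through as a standalone sketch:

    # Worked example of the candidate threshold used above.
    query_trigrams = _get_trigrams("quick brwn".lower())
    # -> {"qui", "uic", "ick", "ck ", "k b", " br", "brw", "rwn"}: 8 trigrams
    assert len(query_trigrams) == 8

    similarity_threshold = int(len(query_trigrams) * 0.3)  # int(2.4) == 2
    assert similarity_threshold == 2
    # A document sharing >= 2 distinct trigrams survives the GROUP BY/HAVING
    # filter and reaches exact scoring. Queries with 3 or fewer trigrams
    # (strings shorter than 6 chars) yield a threshold of 0, and the method
    # returns an empty set instead of scanning the whole table.
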
@@ -355,49 +568,26 @@ def rerank(
  ) -> list[Document]:
      """
      Reranks documents from multiple search result lists using Reverse Rank Fusion (RRF).
-     This function is specifically designed to work with beaver.collections.Document objects.
-
-     Args:
-         results (sequence of list[Document]): A sequence of search result lists, where each
-             inner list contains Document objects.
-         weights (list[float], optional): A list of weights corresponding to each
-             result list. If None, all lists are weighted equally. Defaults to None.
-         k (int, optional): A constant used in the RRF formula. Defaults to 60.
-
-     Returns:
-         list[Document]: A single, reranked list of unique Document objects, sorted
-             by their fused rank score in descending order.
      """
      if not results:
          return []

-     # Assign a default weight of 1.0 if none are provided
      if weights is None:
          weights = [1.0] * len(results)

      if len(results) != len(weights):
          raise ValueError("The number of result lists must match the number of weights.")

-     # Use dictionaries to store scores and unique documents by their ID
      rrf_scores: dict[str, float] = {}
      doc_store: dict[str, Document] = {}

-     # Iterate through each list of Document objects and its weight
      for result_list, weight in zip(results, weights):
          for rank, doc in enumerate(result_list):
-             # Use the .id attribute from the Document object
              doc_id = doc.id
              if doc_id not in doc_store:
                  doc_store[doc_id] = doc
-
-             # Calculate the reciprocal rank score, scaled by the weight
              score = weight * (1 / (k + rank))
-
-             # Add the score to the document's running total
              rrf_scores[doc_id] = rrf_scores.get(doc_id, 0.0) + score

-     # Sort the document IDs by their final aggregated scores
      sorted_doc_ids = sorted(rrf_scores.keys(), key=rrf_scores.get, reverse=True)
-
-     # Return the final list of Document objects in the new, reranked order
      return [doc_store[doc_id] for doc_id in sorted_doc_ids]
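
Despite the docstring's "Reverse Rank Fusion" wording, the formula is the standard reciprocal-rank form, weight * 1/(k + rank) with 0-based ranks and k = 60. A worked sketch of the fusion (the Document instances are hypothetical; only their ids matter to the scoring):

    # Hypothetical fusion of a vector hit list and an FTS hit list.
    a, b = Document(id="a"), Document(id="b")
    vector_hits = [a, b]   # a at rank 0, b at rank 1
    fts_hits = [b, a]      # mirrored ranks -> unweighted scores would tie

    fused = rerank([vector_hits, fts_hits], weights=[2.0, 1.0])
    # score(a) = 2/(60+0) + 1/(60+1) ~= 0.04973
    # score(b) = 2/(60+1) + 1/(60+0) ~= 0.04945
    assert [d.id for d in fused] == ["a", "b"]  # the weighted vector list breaks the tie
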