beaver-db 0.4.0__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their public registry.

Potentially problematic release: this version of beaver-db might be problematic.

beaver/__init__.py CHANGED
@@ -1 +1,2 @@
- from .core import BeaverDB, Document
+ from .core import BeaverDB
+ from .collections import Document, WalkDirection
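
Note: Document now lives in the new beaver.collections module and is re-exported at the package root alongside the new WalkDirection enum, instead of coming from beaver.core. A minimal import sketch, assuming the wheel installs as the beaver package that these paths suggest:

    from beaver import BeaverDB, Document, WalkDirection
    # or, equivalently, import the moved names from their new module:
    from beaver.collections import Document, WalkDirection
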
beaver/collections.py ADDED
@@ -0,0 +1,327 @@
+ import json
+ import sqlite3
+ import uuid
+ from enum import Enum
+ from typing import Any, List, Literal, Set
+
+ import numpy as np
+ from scipy.spatial import cKDTree
+
+
+ class WalkDirection(Enum):
+     OUTGOING = "outgoing"
+     INCOMING = "incoming"
+
+
+ class Document:
+     """A data class representing a single item in a collection."""
+
+     def __init__(
+         self, embedding: list[float] | None = None, id: str | None = None, **metadata
+     ):
+         self.id = id or str(uuid.uuid4())
+
+         if embedding is None:
+             self.embedding = None
+         else:
+             if not isinstance(embedding, list) or not all(
+                 isinstance(x, (int, float)) for x in embedding
+             ):
+                 raise TypeError("Embedding must be a list of numbers.")
+             self.embedding = np.array(embedding, dtype=np.float32)
+
+         for key, value in metadata.items():
+             setattr(self, key, value)
+
+     def to_dict(self) -> dict[str, Any]:
+         """Serializes the document's metadata to a dictionary."""
+         metadata = self.__dict__.copy()
+         metadata.pop("embedding", None)
+         metadata.pop("id", None)
+         return metadata
+
+     def __repr__(self):
+         metadata_str = ", ".join(f"{k}={v!r}" for k, v in self.to_dict().items())
+         return f"Document(id='{self.id}', {metadata_str})"
+
+
+ class CollectionWrapper:
+     """
+     A wrapper for multi-modal collection operations with an in-memory ANN index,
+     FTS, and graph capabilities.
+     """
+
+     def __init__(self, name: str, conn: sqlite3.Connection):
+         self._name = name
+         self._conn = conn
+         self._kdtree: cKDTree | None = None
+         self._doc_ids: List[str] = []
+         self._local_index_version = -1  # Version of the in-memory index
+
+     def _flatten_metadata(self, metadata: dict, prefix: str = "") -> dict[str, str]:
+         """Flattens a nested dictionary and filters for string values."""
+         flat_dict = {}
+         for key, value in metadata.items():
+             new_key = f"{prefix}__{key}" if prefix else key
+             if isinstance(value, dict):
+                 flat_dict.update(self._flatten_metadata(value, new_key))
+             elif isinstance(value, str):
+                 flat_dict[new_key] = value
+         return flat_dict
+
+     def _get_db_version(self) -> int:
+         """Gets the current version of the collection from the database."""
+         cursor = self._conn.cursor()
+         cursor.execute(
+             "SELECT version FROM beaver_collection_versions WHERE collection_name = ?",
+             (self._name,),
+         )
+         result = cursor.fetchone()
+         return result[0] if result else 0
+
+     def _is_index_stale(self) -> bool:
+         """Checks if the in-memory index is out of sync with the DB."""
+         if self._local_index_version == -1:
+             return True
+         return self._local_index_version < self._get_db_version()
+
+     def index(self, document: Document, *, fts: bool = True):
+         """Indexes a Document, performing an upsert and updating the FTS index."""
+         with self._conn:
+             if fts:
+                 self._conn.execute(
+                     "DELETE FROM beaver_fts_index WHERE collection = ? AND item_id = ?",
+                     (self._name, document.id),
+                 )
+                 string_fields = self._flatten_metadata(document.to_dict())
+                 if string_fields:
+                     fts_data = [
+                         (self._name, document.id, path, content)
+                         for path, content in string_fields.items()
+                     ]
+                     self._conn.executemany(
+                         "INSERT INTO beaver_fts_index (collection, item_id, field_path, field_content) VALUES (?, ?, ?, ?)",
+                         fts_data,
+                     )
+
+             self._conn.execute(
+                 "INSERT OR REPLACE INTO beaver_collections (collection, item_id, item_vector, metadata) VALUES (?, ?, ?, ?)",
+                 (
+                     self._name,
+                     document.id,
+                     (
+                         document.embedding.tobytes()
+                         if document.embedding is not None
+                         else None
+                     ),
+                     json.dumps(document.to_dict()),
+                 ),
+             )
+             # Atomically increment the collection's version number
+             self._conn.execute(
+                 """
+                 INSERT INTO beaver_collection_versions (collection_name, version) VALUES (?, 1)
+                 ON CONFLICT(collection_name) DO UPDATE SET version = version + 1
+                 """,
+                 (self._name,),
+             )
+
+     def drop(self, document: Document):
+         """Removes a document and all its associated data from the collection."""
+         if not isinstance(document, Document):
+             raise TypeError("Item to drop must be a Document object.")
+         with self._conn:
+             self._conn.execute(
+                 "DELETE FROM beaver_collections WHERE collection = ? AND item_id = ?",
+                 (self._name, document.id),
+             )
+             self._conn.execute(
+                 "DELETE FROM beaver_fts_index WHERE collection = ? AND item_id = ?",
+                 (self._name, document.id),
+             )
+             self._conn.execute(
+                 "DELETE FROM beaver_edges WHERE collection = ? AND (source_item_id = ? OR target_item_id = ?)",
+                 (self._name, document.id, document.id),
+             )
+             self._conn.execute(
+                 """
+                 INSERT INTO beaver_collection_versions (collection_name, version) VALUES (?, 1)
+                 ON CONFLICT(collection_name) DO UPDATE SET version = version + 1
+                 """,
+                 (self._name,),
+             )
+
+     def refresh(self):
+         """Forces a rebuild of the in-memory ANN index from data in SQLite."""
+         cursor = self._conn.cursor()
+         cursor.execute(
+             "SELECT item_id, item_vector FROM beaver_collections WHERE collection = ? AND item_vector IS NOT NULL",
+             (self._name,),
+         )
+         vectors, self._doc_ids = [], []
+         for row in cursor.fetchall():
+             self._doc_ids.append(row["item_id"])
+             vectors.append(np.frombuffer(row["item_vector"], dtype=np.float32))
+
+         self._kdtree = cKDTree(vectors) if vectors else None
+         self._local_index_version = self._get_db_version()
+
+     def search(
+         self, vector: list[float], top_k: int = 10
+     ) -> list[tuple[Document, float]]:
+         """Performs a fast approximate nearest neighbor search."""
+         if self._is_index_stale():
+             self.refresh()
+         if not self._kdtree:
+             return []
+
+         distances, indices = self._kdtree.query(
+             np.array(vector, dtype=np.float32), k=top_k
+         )
+         if top_k == 1:
+             distances, indices = [distances], [indices]
+
+         result_ids = [self._doc_ids[i] for i in indices]
+         placeholders = ",".join("?" for _ in result_ids)
+         sql = f"SELECT item_id, item_vector, metadata FROM beaver_collections WHERE collection = ? AND item_id IN ({placeholders})"
+
+         cursor = self._conn.cursor()
+         rows = cursor.execute(sql, (self._name, *result_ids)).fetchall()
+         row_map = {row["item_id"]: row for row in rows}
+
+         results = []
+         for i, doc_id in enumerate(result_ids):
+             row = row_map.get(doc_id)
+             if row:
+                 embedding = np.frombuffer(row["item_vector"], dtype=np.float32).tolist()
+                 doc = Document(
+                     id=doc_id, embedding=embedding, **json.loads(row["metadata"])
+                 )
+                 results.append((doc, float(distances[i])))
+         return results
+
+     def match(
+         self, query: str, on_field: str | None = None, top_k: int = 10
+     ) -> list[tuple[Document, float]]:
+         """Performs a full-text search on indexed string fields."""
+         cursor = self._conn.cursor()
+         sql_query = """
+             SELECT t1.item_id, t1.item_vector, t1.metadata, fts.rank
+             FROM beaver_collections AS t1 JOIN (
+                 SELECT DISTINCT item_id, rank FROM beaver_fts_index
+                 WHERE beaver_fts_index MATCH ? {} ORDER BY rank LIMIT ?
+             ) AS fts ON t1.item_id = fts.item_id
+             WHERE t1.collection = ? ORDER BY fts.rank
+         """
+         params, field_filter_sql = [], ""
+         if on_field:
+             field_filter_sql = "AND field_path = ?"
+             params.extend([query, on_field])
+         else:
+             params.append(query)
+         params.extend([top_k, self._name])
+
+         rows = cursor.execute(
+             sql_query.format(field_filter_sql), tuple(params)
+         ).fetchall()
+         results = []
+         for row in rows:
+             embedding = (
+                 np.frombuffer(row["item_vector"], dtype=np.float32).tolist()
+                 if row["item_vector"]
+                 else None
+             )
+             doc = Document(
+                 id=row["item_id"], embedding=embedding, **json.loads(row["metadata"])
+             )
+             results.append((doc, row["rank"]))
+         return results
+
+     def connect(
+         self, source: Document, target: Document, label: str, metadata: dict = None
+     ):
+         """Creates a directed edge between two documents."""
+         if not isinstance(source, Document) or not isinstance(target, Document):
+             raise TypeError("Source and target must be Document objects.")
+         with self._conn:
+             self._conn.execute(
+                 "INSERT OR REPLACE INTO beaver_edges (collection, source_item_id, target_item_id, label, metadata) VALUES (?, ?, ?, ?, ?)",
+                 (
+                     self._name,
+                     source.id,
+                     target.id,
+                     label,
+                     json.dumps(metadata) if metadata else None,
+                 ),
+             )
+
+     def neighbors(self, doc: Document, label: str | None = None) -> list[Document]:
+         """Retrieves the neighboring documents connected to a given document."""
+         sql = "SELECT t1.item_id, t1.item_vector, t1.metadata FROM beaver_collections AS t1 JOIN beaver_edges AS t2 ON t1.item_id = t2.target_item_id AND t1.collection = t2.collection WHERE t2.collection = ? AND t2.source_item_id = ?"
+         params = [self._name, doc.id]
+         if label:
+             sql += " AND t2.label = ?"
+             params.append(label)
+
+         rows = self._conn.cursor().execute(sql, tuple(params)).fetchall()
+         return [
+             Document(
+                 id=row["item_id"],
+                 embedding=(
+                     np.frombuffer(row["item_vector"], dtype=np.float32).tolist()
+                     if row["item_vector"]
+                     else None
+                 ),
+                 **json.loads(row["metadata"]),
+             )
+             for row in rows
+         ]
+
+     def walk(
+         self,
+         source: Document,
+         labels: List[str],
+         depth: int,
+         *,
+         direction: Literal[
+             WalkDirection.OUTGOING, WalkDirection.INCOMING
+         ] = WalkDirection.OUTGOING,
+     ) -> List[Document]:
+         """Performs a graph traversal (BFS) from a starting document."""
+         if not isinstance(source, Document):
+             raise TypeError("The starting point must be a Document object.")
+         if depth <= 0:
+             return []
+
+         source_col, target_col = (
+             ("source_item_id", "target_item_id")
+             if direction == WalkDirection.OUTGOING
+             else ("target_item_id", "source_item_id")
+         )
+         sql = f"""
+             WITH RECURSIVE walk_bfs(item_id, current_depth) AS (
+                 SELECT ?, 0
+                 UNION ALL
+                 SELECT edges.{target_col}, bfs.current_depth + 1
+                 FROM beaver_edges AS edges JOIN walk_bfs AS bfs ON edges.{source_col} = bfs.item_id
+                 WHERE edges.collection = ? AND bfs.current_depth < ? AND edges.label IN ({','.join('?' for _ in labels)})
+             )
+             SELECT DISTINCT t1.item_id, t1.item_vector, t1.metadata
+             FROM beaver_collections AS t1 JOIN walk_bfs AS bfs ON t1.item_id = bfs.item_id
+             WHERE t1.collection = ? AND bfs.current_depth > 0
+         """
+         params = [source.id, self._name, depth] + labels + [self._name]
+
+         rows = self._conn.cursor().execute(sql, tuple(params)).fetchall()
+         return [
+             Document(
+                 id=row["item_id"],
+                 embedding=(
+                     np.frombuffer(row["item_vector"], dtype=np.float32).tolist()
+                     if row["item_vector"]
+                     else None
+                 ),
+                 **json.loads(row["metadata"]),
+             )
+             for row in rows
+         ]
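
For orientation only, a minimal usage sketch of the Document class and WalkDirection enum added above. It assumes the wheel installs as the beaver package, as the file paths suggest; the BeaverDB-side accessor that hands out a CollectionWrapper lives in beaver/core.py, which is not part of this diff, so collection-level calls (index, search, match, connect, neighbors, walk) are omitted here.

    from beaver.collections import Document, WalkDirection

    # Extra keyword arguments become metadata attributes; id defaults to a new uuid4.
    doc = Document(embedding=[0.1, 0.2, 0.3], title="hello", lang="en")
    print(doc.id)               # auto-generated UUID string
    print(doc.embedding.dtype)  # float32 (embeddings are stored as numpy float32 arrays)
    print(doc.to_dict())        # {'title': 'hello', 'lang': 'en'} -- embedding and id are excluded
    print(doc)                  # Document(id='...', title='hello', lang='en')
    print(WalkDirection.OUTGOING.value)  # 'outgoing' (used by CollectionWrapper.walk)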