beaver-db 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of beaver-db might be problematic. Click here for more details.
- beaver/__init__.py +2 -1
- beaver/collections.py +327 -0
- beaver/core.py +103 -383
- beaver/lists.py +166 -0
- beaver/subscribers.py +54 -0
- beaver_db-0.5.0.dist-info/METADATA +171 -0
- beaver_db-0.5.0.dist-info/RECORD +9 -0
- beaver_db-0.3.0.dist-info/METADATA +0 -129
- beaver_db-0.3.0.dist-info/RECORD +0 -6
- {beaver_db-0.3.0.dist-info → beaver_db-0.5.0.dist-info}/WHEEL +0 -0
- {beaver_db-0.3.0.dist-info → beaver_db-0.5.0.dist-info}/top_level.txt +0 -0
beaver/__init__.py
CHANGED
|
@@ -1 +1,2 @@
|
|
|
1
|
-
from .core import BeaverDB
|
|
1
|
+
from .core import BeaverDB
|
|
2
|
+
from .collections import Document, WalkDirection
|
beaver/collections.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import sqlite3
|
|
3
|
+
import uuid
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from typing import Any, List, Literal, Set
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from scipy.spatial import cKDTree
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class WalkDirection(Enum):
    """Direction in which ``CollectionWrapper.walk`` follows graph edges.

    The member values are the plain strings ``"outgoing"`` and ``"incoming"``.
    """

    # Follow edges from their source_item_id to their target_item_id.
    OUTGOING = "outgoing"
    # Follow edges backwards, from target_item_id to source_item_id.
    INCOMING = "incoming"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Document:
|
|
17
|
+
"""A data class representing a single item in a collection."""
|
|
18
|
+
|
|
19
|
+
def __init__(
|
|
20
|
+
self, embedding: list[float] | None = None, id: str | None = None, **metadata
|
|
21
|
+
):
|
|
22
|
+
self.id = id or str(uuid.uuid4())
|
|
23
|
+
|
|
24
|
+
if embedding is None:
|
|
25
|
+
self.embedding = None
|
|
26
|
+
else:
|
|
27
|
+
if not isinstance(embedding, list) or not all(
|
|
28
|
+
isinstance(x, (int, float)) for x in embedding
|
|
29
|
+
):
|
|
30
|
+
raise TypeError("Embedding must be a list of numbers.")
|
|
31
|
+
self.embedding = np.array(embedding, dtype=np.float32)
|
|
32
|
+
|
|
33
|
+
for key, value in metadata.items():
|
|
34
|
+
setattr(self, key, value)
|
|
35
|
+
|
|
36
|
+
def to_dict(self) -> dict[str, Any]:
|
|
37
|
+
"""Serializes the document's metadata to a dictionary."""
|
|
38
|
+
metadata = self.__dict__.copy()
|
|
39
|
+
metadata.pop("embedding", None)
|
|
40
|
+
metadata.pop("id", None)
|
|
41
|
+
return metadata
|
|
42
|
+
|
|
43
|
+
def __repr__(self):
|
|
44
|
+
metadata_str = ", ".join(f"{k}={v!r}" for k, v in self.to_dict().items())
|
|
45
|
+
return f"Document(id='{self.id}', {metadata_str})"
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class CollectionWrapper:
    """
    A wrapper for multi-modal collection operations with an in-memory ANN index,
    FTS, and graph capabilities.

    Rows live in shared SQLite tables (``beaver_collections``,
    ``beaver_fts_index``, ``beaver_edges``, ``beaver_collection_versions``)
    and every query is scoped to this collection by its name. The vector
    index is a ``cKDTree`` kept in memory and rebuilt lazily whenever the
    collection's version counter in the database has moved past the version
    the index was built from.
    """

    def __init__(self, name: str, conn: sqlite3.Connection):
        """
        Args:
            name: Collection name used to scope every query.
            conn: Open SQLite connection. Rows are read by column name
                (``row["item_id"]``), so the connection presumably uses
                ``sqlite3.Row`` as its row factory — NOTE(review): configured
                outside this file, confirm.
        """
        self._name = name
        self._conn = conn
        self._kdtree: cKDTree | None = None
        # _doc_ids[i] is the id of the i-th point stored in _kdtree.
        self._doc_ids: List[str] = []
        self._local_index_version = -1  # Version of the in-memory index; -1 = never built

    @staticmethod
    def _row_to_document(row) -> Document:
        """Builds a Document from a row with item_id, item_vector and metadata columns."""
        blob = row["item_vector"]
        embedding = (
            np.frombuffer(blob, dtype=np.float32).tolist() if blob is not None else None
        )
        return Document(
            id=row["item_id"], embedding=embedding, **json.loads(row["metadata"])
        )

    def _flatten_metadata(self, metadata: dict, prefix: str = "") -> dict[str, str]:
        """Flattens a nested dictionary and filters for string values.

        Nested keys are joined with ``"__"`` (``{"a": {"b": "x"}}`` becomes
        ``{"a__b": "x"}``). Values that are neither dicts nor strings are
        dropped.
        """
        flat_dict = {}
        for key, value in metadata.items():
            new_key = f"{prefix}__{key}" if prefix else key
            if isinstance(value, dict):
                flat_dict.update(self._flatten_metadata(value, new_key))
            elif isinstance(value, str):
                flat_dict[new_key] = value
        return flat_dict

    def _get_db_version(self) -> int:
        """Gets the current version of the collection from the database (0 if none)."""
        row = self._conn.execute(
            "SELECT version FROM beaver_collection_versions WHERE collection_name = ?",
            (self._name,),
        ).fetchone()
        return row[0] if row else 0

    def _is_index_stale(self) -> bool:
        """Checks if the in-memory index is out of sync with the DB."""
        if self._local_index_version == -1:
            return True  # never built; skip the DB round-trip
        return self._local_index_version < self._get_db_version()

    def _bump_version(self) -> None:
        """Increments this collection's version counter, creating it at 1.

        Callers run this inside their own ``with self._conn`` block so the
        data change and the version bump commit (or roll back) together.
        """
        self._conn.execute(
            """
            INSERT INTO beaver_collection_versions (collection_name, version) VALUES (?, 1)
            ON CONFLICT(collection_name) DO UPDATE SET version = version + 1
            """,
            (self._name,),
        )

    def index(self, document: Document, *, fts: bool = True):
        """Indexes a Document, performing an upsert and updating the FTS index.

        Args:
            document: Document to insert or replace (matched by ``id``).
            fts: When True, every string-valued metadata field (nested dicts
                flattened to ``parent__child`` paths) is written to the
                full-text index. When False, the document gets no FTS entries.
        """
        with self._conn:
            # Always clear earlier FTS rows; otherwise re-indexing with
            # fts=False would leave stale text that `match` could still hit.
            self._conn.execute(
                "DELETE FROM beaver_fts_index WHERE collection = ? AND item_id = ?",
                (self._name, document.id),
            )
            if fts:
                string_fields = self._flatten_metadata(document.to_dict())
                if string_fields:
                    self._conn.executemany(
                        "INSERT INTO beaver_fts_index (collection, item_id, field_path, field_content) VALUES (?, ?, ?, ?)",
                        [
                            (self._name, document.id, path, content)
                            for path, content in string_fields.items()
                        ],
                    )

            vector_blob = (
                document.embedding.tobytes()
                if document.embedding is not None
                else None
            )
            self._conn.execute(
                "INSERT OR REPLACE INTO beaver_collections (collection, item_id, item_vector, metadata) VALUES (?, ?, ?, ?)",
                (self._name, document.id, vector_blob, json.dumps(document.to_dict())),
            )
            self._bump_version()

    def drop(self, document: Document):
        """Removes a document and all its associated data from the collection.

        Deletes the stored row, its FTS entries and every edge in which the
        document is the source or the target.

        Raises:
            TypeError: If ``document`` is not a Document.
        """
        if not isinstance(document, Document):
            raise TypeError("Item to drop must be a Document object.")
        with self._conn:
            self._conn.execute(
                "DELETE FROM beaver_collections WHERE collection = ? AND item_id = ?",
                (self._name, document.id),
            )
            self._conn.execute(
                "DELETE FROM beaver_fts_index WHERE collection = ? AND item_id = ?",
                (self._name, document.id),
            )
            self._conn.execute(
                "DELETE FROM beaver_edges WHERE collection = ? AND (source_item_id = ? OR target_item_id = ?)",
                (self._name, document.id, document.id),
            )
            self._bump_version()

    def refresh(self):
        """Forces a rebuild of the in-memory ANN index from data in SQLite."""
        # Snapshot the version *before* reading vectors. If another writer
        # commits in between, the index is tagged with the older version and
        # is rebuilt on the next search instead of silently missing the write.
        version = self._get_db_version()
        rows = self._conn.execute(
            "SELECT item_id, item_vector FROM beaver_collections WHERE collection = ? AND item_vector IS NOT NULL",
            (self._name,),
        ).fetchall()
        doc_ids = [row["item_id"] for row in rows]
        vectors = [np.frombuffer(row["item_vector"], dtype=np.float32) for row in rows]
        # Build first, then assign together, so a failing cKDTree build (e.g.
        # mixed vector lengths) cannot leave _doc_ids out of step with _kdtree.
        kdtree = cKDTree(vectors) if vectors else None
        self._kdtree, self._doc_ids = kdtree, doc_ids
        self._local_index_version = version

    def search(
        self, vector: list[float], top_k: int = 10
    ) -> list[tuple[Document, float]]:
        """Performs a fast approximate nearest neighbor search.

        The in-memory index is rebuilt first if the collection changed.

        Returns:
            Up to ``top_k`` ``(document, distance)`` pairs, nearest first, where
            distance is the cKDTree query distance (Euclidean by default).
            An empty list when ``top_k <= 0`` or no document has an embedding.
        """
        if top_k <= 0:
            return []
        if self._is_index_stale():
            self.refresh()
        if self._kdtree is None:
            return []

        # cKDTree pads missing neighbours with index == n (and inf distance)
        # when k exceeds the number of points; clamp k so every index is valid.
        k = min(top_k, len(self._doc_ids))
        distances, indices = self._kdtree.query(
            np.asarray(vector, dtype=np.float32), k=k
        )
        # For k == 1, query() returns scalars rather than arrays.
        distances = np.atleast_1d(distances)
        indices = np.atleast_1d(indices)

        result_ids = [self._doc_ids[i] for i in indices]
        placeholders = ",".join("?" * len(result_ids))
        rows = self._conn.execute(
            f"SELECT item_id, item_vector, metadata FROM beaver_collections WHERE collection = ? AND item_id IN ({placeholders})",
            (self._name, *result_ids),
        ).fetchall()
        row_map = {row["item_id"]: row for row in rows}

        results = []
        for doc_id, distance in zip(result_ids, distances):
            row = row_map.get(doc_id)
            if row is not None:  # skip ids whose row is no longer present
                results.append((self._row_to_document(row), float(distance)))
        return results

    def match(
        self, query: str, on_field: str | None = None, top_k: int = 10
    ) -> list[tuple[Document, float]]:
        """Performs a full-text search on indexed string fields.

        Args:
            query: FTS MATCH expression.
            on_field: If given, only FTS entries whose flattened field path
                equals this value are considered (e.g. ``"author__name"``).
            top_k: Maximum number of distinct documents to return.

        Returns:
            Up to ``top_k`` ``(document, rank)`` pairs ordered by ascending FTS
            rank. A document that matches on several fields appears once, with
            its first-ranked entry.
        """
        if top_k <= 0:
            return []
        field_filter_sql = "AND field_path = ?" if on_field else ""
        # The collection filter sits inside the FTS subquery; otherwise hits
        # from other collections would compete for the result slots.
        sql = f"""
            SELECT t1.item_id, t1.item_vector, t1.metadata, fts.rank
            FROM beaver_collections AS t1 JOIN (
                SELECT DISTINCT item_id, rank FROM beaver_fts_index
                WHERE beaver_fts_index MATCH ? AND collection = ? {field_filter_sql}
            ) AS fts ON t1.item_id = fts.item_id
            WHERE t1.collection = ? ORDER BY fts.rank
        """
        params = [query, self._name]
        if on_field:
            params.append(on_field)
        params.append(self._name)

        results = []
        seen = set()
        cursor = self._conn.cursor()
        try:
            # Stream rows in rank order. One document can match on several
            # fields, so dedupe by id and stop once top_k documents are found.
            for row in cursor.execute(sql, tuple(params)):
                if row["item_id"] in seen:
                    continue
                seen.add(row["item_id"])
                results.append((self._row_to_document(row), row["rank"]))
                if len(results) >= top_k:
                    break
        finally:
            cursor.close()
        return results

    def connect(
        self, source: Document, target: Document, label: str, metadata: dict | None = None
    ):
        """Creates a directed edge between two documents.

        Args:
            source: Edge origin.
            target: Edge destination.
            label: Edge label, used by ``neighbors`` and ``walk`` filters.
            metadata: Optional edge metadata stored as JSON. None or an empty
                dict stores NULL.

        Raises:
            TypeError: If ``source`` or ``target`` is not a Document.
        """
        if not isinstance(source, Document) or not isinstance(target, Document):
            raise TypeError("Source and target must be Document objects.")
        with self._conn:
            self._conn.execute(
                "INSERT OR REPLACE INTO beaver_edges (collection, source_item_id, target_item_id, label, metadata) VALUES (?, ?, ?, ?, ?)",
                (
                    self._name,
                    source.id,
                    target.id,
                    label,
                    json.dumps(metadata) if metadata else None,
                ),
            )

    def neighbors(self, doc: Document, label: str | None = None) -> list[Document]:
        """Retrieves the neighboring documents connected to a given document.

        Follows outgoing edges only (``doc`` is the edge source), optionally
        restricted to one ``label``. A neighbor reachable through several
        edges is returned once.
        """
        sql = (
            "SELECT DISTINCT t1.item_id, t1.item_vector, t1.metadata "
            "FROM beaver_collections AS t1 JOIN beaver_edges AS t2 "
            "ON t1.item_id = t2.target_item_id AND t1.collection = t2.collection "
            "WHERE t2.collection = ? AND t2.source_item_id = ?"
        )
        params = [self._name, doc.id]
        if label:
            sql += " AND t2.label = ?"
            params.append(label)

        rows = self._conn.execute(sql, tuple(params)).fetchall()
        return [self._row_to_document(row) for row in rows]

    def walk(
        self,
        source: Document,
        labels: List[str],
        depth: int,
        *,
        direction: Literal[
            WalkDirection.OUTGOING, WalkDirection.INCOMING
        ] = WalkDirection.OUTGOING,
    ) -> List[Document]:
        """Performs a graph traversal (BFS) from a starting document.

        Args:
            source: Starting document (excluded unless reached again via a cycle).
            labels: Edge labels to follow. Any iterable of strings is
                accepted; a bare string is treated as a single label.
            depth: Maximum number of hops; ``depth <= 0`` returns [].
            direction: A ``WalkDirection`` or its string value
                (``"outgoing"``/``"incoming"``).

        Returns:
            The distinct documents reachable within ``depth`` hops.

        Raises:
            TypeError: If ``source`` is not a Document.
            ValueError: If ``direction`` is not a valid WalkDirection value.
        """
        if not isinstance(source, Document):
            raise TypeError("The starting point must be a Document object.")
        if depth <= 0:
            return []

        # Materialize once: labels feed both the placeholder list and params.
        labels = [labels] if isinstance(labels, str) else list(labels)
        # Coerce strings to the enum; an unknown value raises ValueError instead
        # of silently walking incoming edges.
        direction = WalkDirection(direction)

        source_col, target_col = (
            ("source_item_id", "target_item_id")
            if direction == WalkDirection.OUTGOING
            else ("target_item_id", "source_item_id")
        )
        placeholders = ",".join("?" * len(labels))
        # UNION (not UNION ALL) drops repeated (item_id, depth) rows, so dense
        # or cyclic graphs don't enumerate every path up to `depth`; the final
        # DISTINCT result is the same.
        sql = f"""
            WITH RECURSIVE walk_bfs(item_id, current_depth) AS (
                SELECT ?, 0
                UNION
                SELECT edges.{target_col}, bfs.current_depth + 1
                FROM beaver_edges AS edges JOIN walk_bfs AS bfs ON edges.{source_col} = bfs.item_id
                WHERE edges.collection = ? AND bfs.current_depth < ? AND edges.label IN ({placeholders})
            )
            SELECT DISTINCT t1.item_id, t1.item_vector, t1.metadata
            FROM beaver_collections AS t1 JOIN walk_bfs AS bfs ON t1.item_id = bfs.item_id
            WHERE t1.collection = ? AND bfs.current_depth > 0
        """
        params = (source.id, self._name, depth, *labels, self._name)

        rows = self._conn.execute(sql, params).fetchall()
        return [self._row_to_document(row) for row in rows]
|