beaver-db 0.13.1__py3-none-any.whl → 0.14.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of beaver-db might be problematic.
- beaver/blobs.py +28 -9
- beaver/channels.py +39 -22
- beaver/collections.py +28 -27
- beaver/core.py +9 -7
- {beaver_db-0.13.1.dist-info → beaver_db-0.14.0.dist-info}/METADATA +6 -2
- beaver_db-0.14.0.dist-info/RECORD +15 -0
- beaver_db-0.13.1.dist-info/RECORD +0 -15
- {beaver_db-0.13.1.dist-info → beaver_db-0.14.0.dist-info}/WHEEL +0 -0
- {beaver_db-0.13.1.dist-info → beaver_db-0.14.0.dist-info}/licenses/LICENSE +0 -0
- {beaver_db-0.13.1.dist-info → beaver_db-0.14.0.dist-info}/top_level.txt +0 -0
beaver/blobs.py
CHANGED
@@ -1,24 +1,43 @@
 import json
 import sqlite3
-from typing import Any, Dict, Iterator, NamedTuple, Optional
+from typing import Any, Dict, Iterator, NamedTuple, Optional, Type, TypeVar

+from .types import JsonSerializable

-class Blob(NamedTuple):
+
+class Blob[M](NamedTuple):
     """A data class representing a single blob retrieved from the store."""

     key: str
     data: bytes
-    metadata:
+    metadata: M


-class BlobManager:
+class BlobManager[M]:
     """A wrapper providing a Pythonic interface to a blob store in the database."""

-    def __init__(self, name: str, conn: sqlite3.Connection):
+    def __init__(self, name: str, conn: sqlite3.Connection, model: Type[M] | None = None):
         self._name = name
         self._conn = conn
+        self._model = model
+
+    def _serialize(self, value: M) -> str | None:
+        """Serializes the given value to a JSON string."""
+        if value is None:
+            return None
+        if isinstance(value, JsonSerializable):
+            return value.model_dump_json()
+
+        return json.dumps(value)
+
+    def _deserialize(self, value: str) -> M:
+        """Deserializes a JSON string into the specified model or a generic object."""
+        if self._model:
+            return self._model.model_validate_json(value)
+
+        return json.loads(value)

-    def put(self, key: str, data: bytes, metadata: Optional[
+    def put(self, key: str, data: bytes, metadata: Optional[M] = None):
         """
         Stores or replaces a blob in the store.

@@ -30,7 +49,7 @@ class BlobManager:
         if not isinstance(data, bytes):
             raise TypeError("Blob data must be of type bytes.")

-        metadata_json =
+        metadata_json = self._serialize(metadata) if metadata else None

         with self._conn:
             self._conn.execute(
@@ -38,7 +57,7 @@ class BlobManager:
                 (self._name, key, data, metadata_json),
             )

-    def get(self, key: str) -> Optional[Blob]:
+    def get(self, key: str) -> Optional[Blob[M]]:
         """
         Retrieves a blob from the store.

@@ -60,7 +79,7 @@ class BlobManager:
             return None

         data, metadata_json = result
-        metadata =
+        metadata = self._deserialize(metadata_json) if metadata_json else None

         return Blob(key=key, data=data, metadata=metadata)
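Taken together, these changes let a blob store carry typed metadata: `put()` serializes it through the optional `model` and `get()` rebuilds it. The following is a minimal usage sketch, not code from the package; `ImageMeta`, the Pydantic base class, `photos.db`, and the top-level `BeaverDB` import are assumptions (the diff only requires a model exposing `model_dump_json`/`model_validate_json`).

```python
# Hypothetical usage sketch for the typed blob store in 0.14.0.
from pydantic import BaseModel

from beaver import BeaverDB  # assumes BeaverDB is exported at the package root


class ImageMeta(BaseModel):
    width: int
    height: int
    format: str


db = BeaverDB("photos.db")
attachments = db.blobs("attachments", model=ImageMeta)

# Metadata is serialized with model_dump_json() on the way in ...
attachments.put(
    "user_123_avatar.png",
    b"\x89PNG...",
    metadata=ImageMeta(width=64, height=64, format="png"),
)

# ... and rebuilt with ImageMeta.model_validate_json() on the way out.
blob = attachments.get("user_123_avatar.png")
if blob is not None and blob.metadata is not None:
    print(blob.metadata.width)  # typed access with editor autocompletion
```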
beaver/channels.py
CHANGED
@@ -4,19 +4,21 @@ import sqlite3
 import threading
 import time
 from queue import Empty, Queue
-from typing import Any, AsyncIterator, Iterator, Set
+from typing import Any, AsyncIterator, Generic, Iterator, Set, Type, TypeVar
+
+from .types import JsonSerializable

 # A special message object used to signal the listener to gracefully shut down.
 _SHUTDOWN_SENTINEL = object()


-class AsyncSubscriber:
+class AsyncSubscriber[T]:
     """A thread-safe async message receiver for a specific channel subscription."""

-    def __init__(self, subscriber: "Subscriber"):
+    def __init__(self, subscriber: "Subscriber[T]"):
         self._subscriber = subscriber

-    async def __aenter__(self) -> "AsyncSubscriber":
+    async def __aenter__(self) -> "AsyncSubscriber[T]":
         """Registers the listener's queue with the channel to start receiving messages."""
         await asyncio.to_thread(self._subscriber.__enter__)
         return self
@@ -25,7 +27,7 @@ class AsyncSubscriber:
         """Unregisters the listener's queue from the channel to stop receiving messages."""
         await asyncio.to_thread(self._subscriber.__exit__, exc_type, exc_val, exc_tb)

-    async def listen(self, timeout: float | None = None) -> AsyncIterator[
+    async def listen(self, timeout: float | None = None) -> AsyncIterator[T]:
         """
         Returns a blocking async iterator that yields messages as they arrive.
         """
@@ -39,7 +41,7 @@ class AsyncSubscriber:
             raise TimeoutError(f"Timeout {timeout}s expired.")


-class Subscriber:
+class Subscriber[T]:
     """
     A thread-safe message receiver for a specific channel subscription.

@@ -49,11 +51,11 @@ class Subscriber:
     impact others.
     """

-    def __init__(self, channel: "ChannelManager"):
+    def __init__(self, channel: "ChannelManager[T]"):
         self._channel = channel
         self._queue: Queue = Queue()

-    def __enter__(self) -> "Subscriber":
+    def __enter__(self) -> "Subscriber[T]":
         """Registers the listener's queue with the channel to start receiving messages."""
         self._channel._register(self._queue)
         return self
@@ -62,7 +64,7 @@ class Subscriber:
         """Unregisters the listener's queue from the channel to stop receiving messages."""
         self._channel._unregister(self._queue)

-    def listen(self, timeout: float | None = None) -> Iterator[
+    def listen(self, timeout: float | None = None) -> Iterator[T]:
         """
         Returns a blocking iterator that yields messages as they arrive.

@@ -84,29 +86,29 @@ class Subscriber:
         except Empty:
             raise TimeoutError(f"Timeout {timeout}s expired.")

-    def as_async(self) -> "AsyncSubscriber":
+    def as_async(self) -> "AsyncSubscriber[T]":
         """Returns an async version of the subscriber."""
         return AsyncSubscriber(self)


-class AsyncChannelManager:
+class AsyncChannelManager[T]:
     """The central async hub for a named pub/sub channel."""

-    def __init__(self, channel: "ChannelManager"):
+    def __init__(self, channel: "ChannelManager[T]"):
         self._channel = channel

-    async def publish(self, payload:
+    async def publish(self, payload: T):
         """
         Publishes a JSON-serializable message to the channel asynchronously.
         """
         await asyncio.to_thread(self._channel.publish, payload)

-    def subscribe(self) -> "AsyncSubscriber":
+    def subscribe(self) -> "AsyncSubscriber[T]":
         """Creates a new async subscription, returning an AsyncSubscriber context manager."""
         return self._channel.subscribe().as_async()


-class ChannelManager:
+class ChannelManager[T]:
     """
     The central hub for a named pub/sub channel.

@@ -121,16 +123,32 @@ class ChannelManager:
         conn: sqlite3.Connection,
         db_path: str,
         poll_interval: float = 0.1,
+        model: Type[T] | None = None,
     ):
         self._name = name
         self._conn = conn
         self._db_path = db_path
         self._poll_interval = poll_interval
+        self._model = model
         self._listeners: Set[Queue] = set()
         self._lock = threading.Lock()
         self._polling_thread: threading.Thread | None = None
         self._stop_event = threading.Event()

+    def _serialize(self, value: T) -> str:
+        """Serializes the given value to a JSON string."""
+        if isinstance(value, JsonSerializable):
+            return value.model_dump_json()
+
+        return json.dumps(value)
+
+    def _deserialize(self, value: str) -> T:
+        """Deserializes a JSON string into the specified model or a generic object."""
+        if self._model:
+            return self._model.model_validate_json(value)
+
+        return json.loads(value)
+
     def _register(self, queue: Queue):
         """Adds a listener's queue and starts the poller if it's the first one."""

@@ -186,7 +204,6 @@ class ChannelManager:
         # The poller starts listening for messages from this moment forward.
         last_seen_timestamp = time.time()

-
         while not self._stop_event.is_set():
             cursor = thread_conn.cursor()
             cursor.execute(
@@ -206,18 +223,18 @@ class ChannelManager:
             with self._lock:
                 for queue in self._listeners:
                     for row in messages:
-                        queue.put(
+                        queue.put(self._deserialize(row["message_payload"]))

             # Wait for the poll interval before checking for new messages again.
             time.sleep(self._poll_interval)

         thread_conn.close()

-    def subscribe(self) -> Subscriber:
+    def subscribe(self) -> Subscriber[T]:
         """Creates a new subscription, returning a Listener context manager."""
         return Subscriber(self)

-    def publish(self, payload:
+    def publish(self, payload: T):
         """
         Publishes a JSON-serializable message to the channel.

@@ -225,7 +242,7 @@ class ChannelManager:
         into the database's pub/sub log.
         """
         try:
-            json_payload =
+            json_payload = self._serialize(payload)
         except TypeError as e:
             raise TypeError("Message payload must be JSON-serializable.") from e

@@ -235,6 +252,6 @@ class ChannelManager:
             (time.time(), self._name, json_payload),
         )

-    def as_async(self) -> "AsyncChannelManager":
+    def as_async(self) -> "AsyncChannelManager[T]":
         """Returns an async version of the channel manager."""
-        return AsyncChannelManager(self)
+        return AsyncChannelManager(self)
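With the channel classes parameterized over `T` and the poller routing every row through `_deserialize`, a subscriber on a typed channel should receive model instances instead of raw JSON. A rough sketch of that flow; `Event`, `app.db`, and the top-level `BeaverDB` import are assumptions rather than content of the diff.

```python
# Hypothetical sketch of a typed pub/sub channel in 0.14.0.
from pydantic import BaseModel

from beaver import BeaverDB  # assumes the top-level export


class Event(BaseModel):
    kind: str
    user: str


db = BeaverDB("app.db")
events = db.channel("events", model=Event)

with events.subscribe() as subscriber:
    # publish() serializes via Event.model_dump_json(); on an untyped channel
    # plain JSON values go through json.dumps() instead.
    events.publish(Event(kind="signup", user="alice"))

    # The polling thread deserializes each row with Event.model_validate_json(),
    # so the iterator yields Event instances.
    for event in subscriber.listen(timeout=5.0):
        print(event.kind, event.user)
        break
```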
beaver/collections.py
CHANGED
@@ -3,10 +3,11 @@ import sqlite3
 import threading
 import uuid
 from enum import Enum
-from typing import Any, List, Literal, Tuple
+from typing import Any, Iterator, List, Literal, Tuple, Type, TypeVar

 import numpy as np

+from .types import Model
 from .vectors import VectorIndex


@@ -71,7 +72,7 @@ class WalkDirection(Enum):
     INCOMING = "incoming"


-class Document:
+class Document(Model):
     """A data class representing a single item in a collection."""

     def __init__(
@@ -88,8 +89,7 @@ class Document:
             raise TypeError("Embedding must be a list of numbers.")
         self.embedding = np.array(embedding, dtype=np.float32)

-
-            setattr(self, key, value)
+        super().__init__(**metadata)

     def to_dict(self) -> dict[str, Any]:
         """Serializes the document's metadata to a dictionary."""
@@ -103,15 +103,16 @@ class Document:
         return f"Document(id='{self.id}', {metadata_str})"


-class CollectionManager:
+class CollectionManager[D: Document]:
     """
     A wrapper for multi-modal collection operations, including document storage,
     FTS, fuzzy search, graph traversal, and persistent vector search.
     """

-    def __init__(self, name: str, conn: sqlite3.Connection):
+    def __init__(self, name: str, conn: sqlite3.Connection, model: Type[D] | None = None):
         self._name = name
         self._conn = conn
+        self._model = model or Document
         # All vector-related operations are now delegated to the VectorIndex class.
         self._vector_index = VectorIndex(name, conn)
         # A lock to ensure only one compaction thread runs at a time for this collection.
@@ -184,7 +185,7 @@ class CollectionManager:

     def index(
         self,
-        document:
+        document: D,
         *,
         fts: bool | list[str] = True,
         fuzzy: bool = False
@@ -266,7 +267,7 @@ class CollectionManager:
         if self._needs_compaction():
             self.compact()

-    def __iter__(self):
+    def __iter__(self) -> Iterator[D]:
         """Returns an iterator over all documents in the collection."""
         cursor = self._conn.cursor()
         cursor.execute(
@@ -279,14 +280,14 @@ class CollectionManager:
                 if row["item_vector"]
                 else None
             )
-            yield
+            yield self._model(
                 id=row["item_id"], embedding=embedding, **json.loads(row["metadata"])
             )
         cursor.close()

     def search(
         self, vector: list[float], top_k: int = 10
-    ) -> List[Tuple[
+    ) -> List[Tuple[D, float]]:
         """Performs a fast, persistent approximate nearest neighbor search."""
         if not isinstance(vector, list):
             raise TypeError("Search vector must be a list of floats.")
@@ -307,7 +308,7 @@ class CollectionManager:
         rows = cursor.execute(sql, (self._name, *result_ids)).fetchall()

         doc_map = {
-            row["item_id"]:
+            row["item_id"]: self._model(
                 id=row["item_id"],
                 embedding=(np.frombuffer(row["item_vector"], dtype=np.float32).tolist() if row["item_vector"] else None),
                 **json.loads(row["metadata"]),
@@ -331,7 +332,7 @@ class CollectionManager:
         on: str | list[str] | None = None,
         top_k: int = 10,
         fuzziness: int = 0
-    ) -> list[tuple[
+    ) -> list[tuple[D, float]]:
         """
         Performs a full-text or fuzzy search on indexed string fields.
         """
@@ -345,7 +346,7 @@ class CollectionManager:

     def _perform_fts_search(
         self, query: str, on: list[str] | None, top_k: int
-    ) -> list[tuple[
+    ) -> list[tuple[D, float]]:
         """Performs a standard FTS search."""
         cursor = self._conn.cursor()
         sql_query = """
@@ -371,7 +372,7 @@ class CollectionManager:
                 np.frombuffer(row["item_vector"], dtype=np.float32).tolist()
                 if row["item_vector"] else None
             )
-            doc =
+            doc = self._model(id=row["item_id"], embedding=embedding, **json.loads(row["metadata"]))
             results.append((doc, row["rank"]))
         return results

@@ -410,7 +411,7 @@ class CollectionManager:

     def _perform_fuzzy_search(
         self, query: str, on: list[str] | None, top_k: int, fuzziness: int
-    ) -> list[tuple[
+    ) -> list[tuple[D, float]]:
         """Performs a 3-stage fuzzy search: gather, score, and sort."""
         fts_results = self._perform_fts_search(query, on, top_k)
         fts_candidate_ids = {doc.id for doc, _ in fts_results}
@@ -462,7 +463,7 @@ class CollectionManager:
         id_placeholders = ",".join("?" for _ in top_ids)
         sql_docs = f"SELECT item_id, item_vector, metadata FROM beaver_collections WHERE collection = ? AND item_id IN ({id_placeholders})"
         cursor.execute(sql_docs, (self._name, *top_ids))
-        doc_map = {row["item_id"]:
+        doc_map = {row["item_id"]: self._model(id=row["item_id"], embedding=(np.frombuffer(row["item_vector"], dtype=np.float32).tolist() if row["item_vector"] else None), **json.loads(row["metadata"])) for row in cursor.fetchall()}

         final_results = []
         distance_map = {c["id"]: c["distance"] for c in scored_candidates}
@@ -489,7 +490,7 @@ class CollectionManager:
             ),
         )

-    def neighbors(self, doc:
+    def neighbors(self, doc: D, label: str | None = None) -> list[D]:
         """Retrieves the neighboring documents connected to a given document."""
         sql = "SELECT t1.item_id, t1.item_vector, t1.metadata FROM beaver_collections AS t1 JOIN beaver_edges AS t2 ON t1.item_id = t2.target_item_id AND t1.collection = t2.collection WHERE t2.collection = ? AND t2.source_item_id = ?"
         params = [self._name, doc.id]
@@ -499,7 +500,7 @@ class CollectionManager:

         rows = self._conn.cursor().execute(sql, tuple(params)).fetchall()
         return [
-
+            self._model(
                 id=row["item_id"],
                 embedding=(
                     np.frombuffer(row["item_vector"], dtype=np.float32).tolist()
@@ -513,16 +514,16 @@ class CollectionManager:

     def walk(
         self,
-        source:
+        source: D,
         labels: List[str],
         depth: int,
         *,
         direction: Literal[
             WalkDirection.OUTGOING, WalkDirection.INCOMING
         ] = WalkDirection.OUTGOING,
-    ) -> List[
+    ) -> List[D]:
         """Performs a graph traversal (BFS) from a starting document."""
-        if not isinstance(source,
+        if not isinstance(source, D):
             raise TypeError("The starting point must be a Document object.")
         if depth <= 0:
             return []
@@ -548,7 +549,7 @@ class CollectionManager:

         rows = self._conn.cursor().execute(sql, tuple(params)).fetchall()
         return [
-
+            self._model(
                 id=row["item_id"],
                 embedding=(
                     np.frombuffer(row["item_vector"], dtype=np.float32).tolist()
@@ -572,11 +573,11 @@ class CollectionManager:
         return count


-def rerank(
-    *results: list[
+def rerank[D: Document](
+    *results: list[D],
     weights: list[float] | None = None,
     k: int = 60
-) -> list[
+) -> list[D]:
     """
     Reranks documents from multiple search result lists using Reverse Rank Fusion (RRF).
     """
@@ -590,7 +591,7 @@ def rerank(
         raise ValueError("The number of result lists must match the number of weights.")

     rrf_scores: dict[str, float] = {}
-    doc_store: dict[str,
+    doc_store: dict[str, D] = {}

     for result_list, weight in zip(results, weights):
         for rank, doc in enumerate(result_list):
@@ -600,5 +601,5 @@ def rerank(
             score = weight * (1 / (k + rank))
             rrf_scores[doc_id] = rrf_scores.get(doc_id, 0.0) + score

-    sorted_doc_ids = sorted(rrf_scores.keys(), key=rrf_scores
+    sorted_doc_ids = sorted(rrf_scores.keys(), key=lambda k: rrf_scores[k], reverse=True)
     return [doc_store[doc_id] for doc_id in sorted_doc_ids]
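The collection API is now generic over a `Document` subclass, and the module-level `rerank` keeps that type through rank fusion. A sketch of how the pieces are expected to fit together; `Article`, its declared field, and `library.db` are illustrative assumptions (the diff only shows that `Document` derives from beaver's `Model` class and forwards extra keyword arguments to it).

```python
# Hypothetical sketch of a typed collection plus rank fusion in 0.14.0.
from beaver import BeaverDB  # assumes the top-level export
from beaver.collections import Document, rerank


class Article(Document):
    title: str  # assumes Model accepts declared metadata fields


db = BeaverDB("library.db")
articles = db.collection("articles", model=Article)

articles.index(Article(id="a1", title="Beaver ponds", embedding=[0.1, 0.2, 0.3]))
articles.index(Article(id="a2", title="Dam building", embedding=[0.3, 0.2, 0.1]))

# search() now returns (Article, score) pairs instead of untyped Documents.
hits_a = [doc for doc, _ in articles.search([0.1, 0.2, 0.3], top_k=5)]
hits_b = [doc for doc, _ in articles.search([0.3, 0.2, 0.1], top_k=5)]

# rerank() fuses the ranked lists: each document accumulates
# weight * 1 / (k + rank) for every list it appears in, and the fused
# list is sorted by that score, highest first.
for article in rerank(hits_a, hits_b, weights=[0.7, 0.3], k=60):
    print(article.id, article.title)
```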
beaver/core.py
CHANGED
@@ -1,10 +1,11 @@
 import sqlite3
 import threading
+from typing import Type

 from .types import JsonSerializable
 from .blobs import BlobManager
 from .channels import ChannelManager
-from .collections import CollectionManager
+from .collections import CollectionManager, Document
 from .dicts import DictManager
 from .lists import ListManager
 from .queues import QueueManager
@@ -306,7 +307,7 @@ class BeaverDB:

         return QueueManager(name, self._conn, model)

-    def collection(self, name: str) -> CollectionManager:
+    def collection[D: Document](self, name: str, model: Type[D] | None = None) -> CollectionManager[D]:
         """
         Returns a singleton CollectionManager instance for interacting with a
         document collection.
@@ -319,10 +320,11 @@ class BeaverDB:
         # of the vector index consistently.
         with self._collections_lock:
             if name not in self._collections:
-                self._collections[name] = CollectionManager(name, self._conn)
+                self._collections[name] = CollectionManager(name, self._conn, model=model)
+
         return self._collections[name]

-    def channel(self, name: str) -> ChannelManager:
+    def channel[T](self, name: str, model: type[T] | None = None) -> ChannelManager[T]:
         """
         Returns a singleton Channel instance for high-efficiency pub/sub.
         """
@@ -332,12 +334,12 @@ class BeaverDB:
         # Use a thread-safe lock to ensure only one Channel object is created per name.
         with self._channels_lock:
             if name not in self._channels:
-                self._channels[name] = ChannelManager(name, self._conn, self._db_path)
+                self._channels[name] = ChannelManager(name, self._conn, self._db_path, model=model)
         return self._channels[name]

-    def blobs(self, name: str) -> BlobManager:
+    def blobs[M](self, name: str, model: type[M] | None = None) -> BlobManager[M]:
         """Returns a wrapper object for interacting with a named blob store."""
         if not isinstance(name, str) or not name:
             raise TypeError("Blob store name must be a non-empty string.")

-        return BlobManager(name, self._conn)
+        return BlobManager(name, self._conn, model)
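On the `BeaverDB` side, `collection`, `channel`, and `blobs` simply forward the new optional `model` to their managers, which is what lets a type checker infer the manager's parameter at the call site. A compact, hypothetical overview; the model classes and database path below are illustrative, not from the package.

```python
# Hypothetical sketch of the model-aware accessors added to BeaverDB in 0.14.0.
from pydantic import BaseModel

from beaver import BeaverDB  # assumes the top-level export
from beaver.collections import Document


class Note(Document):
    pass  # collection models are bound to Document subclasses


class Ping(BaseModel):
    source: str


db = BeaverDB("typed.db")

notes = db.collection("notes", model=Note)   # inferred as CollectionManager[Note]
pings = db.channel("pings", model=Ping)      # inferred as ChannelManager[Ping]
files = db.blobs("files", model=Ping)        # inferred as BlobManager[Ping]

# Omitting `model` keeps the 0.13.x behaviour: payloads round-trip through
# json.dumps()/json.loads() and results come back untyped.
legacy = db.channel("legacy")
```

Note the asymmetry in the new signatures: `collection` constrains its model to a `Document` subclass, while `channel` and `blobs` accept any serializable model.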
{beaver_db-0.13.1.dist-info → beaver_db-0.14.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: beaver-db
-Version: 0.13.1
+Version: 0.14.0
 Summary: Fast, embedded, and multi-modal DB based on SQLite for AI-powered applications.
 Requires-Python: >=3.13
 Description-Content-Type: text/markdown
@@ -222,7 +222,7 @@ avatar = attachments.get("user_123_avatar.png")

 ## Type-Safe Data Models

-For enhanced data integrity and a better developer experience, BeaverDB supports type-safe operations for
+For enhanced data integrity and a better developer experience, BeaverDB supports type-safe operations for all modalities. By associating a model with these data structures, you get automatic serialization and deserialization, complete with autocompletion in your editor.

 This feature is designed to be flexible and works seamlessly with two kinds of models:

@@ -252,6 +252,10 @@ retrieved_user = users["alice"]
 print(f"Retrieved: {retrieved_user.name}")  # Your editor will provide autocompletion here
 ```

+In the same way, you can have typed message payloads in `db.channel`, typed metadata in `db.blobs`, and custom document types in `db.collection`, as well as custom types in lists and queues.
+
+Basically, everywhere you can store or get an object in BeaverDB, you can use a typed version by adding `model=MyClass` to the corresponding wrapper method on `BeaverDB` and enjoy first-class type safety and inference.
+
 ## More Examples

 For more in-depth examples, check out the scripts in the `examples/` directory:
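The new `_serialize`/`_deserialize` helpers choose between a model round-trip and plain `json` based on the `JsonSerializable` check, so the `model=` pattern is not tied to any particular library. As a hedged sketch, a plain class should qualify as long as it provides the two methods that check looks for; `Settings`, its fields, and the `prefs` store are illustrative, and this assumes `JsonSerializable` is a runtime-checkable structural protocol, which is what the `isinstance` test in blobs.py and channels.py suggests.

```python
# Hypothetical sketch: a Pydantic-free model exposing the
# model_dump_json()/model_validate_json() interface used by 0.14.0.
import json
from dataclasses import asdict, dataclass

from beaver import BeaverDB  # assumes the top-level export


@dataclass
class Settings:
    theme: str
    page_size: int

    def model_dump_json(self) -> str:  # called by the managers' _serialize()
        return json.dumps(asdict(self))

    @classmethod
    def model_validate_json(cls, raw: str) -> "Settings":  # called by _deserialize()
        return cls(**json.loads(raw))


db = BeaverDB("app.db")
prefs = db.blobs("prefs", model=Settings)

prefs.put("alice", b"", metadata=Settings(theme="dark", page_size=25))
stored = prefs.get("alice")
if stored is not None and stored.metadata is not None:
    print(stored.metadata.theme)  # "dark"
```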
beaver_db-0.14.0.dist-info/RECORD
ADDED
@@ -0,0 +1,15 @@
+beaver/__init__.py,sha256=qyEzF1Os7w4b4Hijgz0Y0R4zTrRBrHIGT1mEkZFl2YM,101
+beaver/blobs.py,sha256=5cmcvlJLY9jaftIRuNbdEryZxI47sw_pYpysYli23NY,3996
+beaver/channels.py,sha256=pCO8wFJAHdMzBLKvinI32L_XfU2B91H2qfsj1Tej-bc,9322
+beaver/collections.py,sha256=Uz241TSs0xRABPYeKenDYmkbaM0PKfvcBX5j0lMzMMA,24306
+beaver/core.py,sha256=zAPlym786_sOpRkP6LfKkd5BH2DXPwdOPTdAkSYojvQ,12469
+beaver/dicts.py,sha256=1BQ9A_cMkJ7l5ayWbDG-4Wi3WtQ-9BKd7Wj_CB7dGlU,5410
+beaver/lists.py,sha256=0LT2XjuHs8pDgvW48kk_lfVc-Y-Ulmym0gcVWRESPtA,9708
+beaver/queues.py,sha256=SFu2180ONotnZOcYp1Ld5d8kxzYxaOlgDdcr70ZoBL8,3641
+beaver/types.py,sha256=65rDdj97EegghEkKCNjI67bPYtTTI_jyB-leHdIypx4,1249
+beaver/vectors.py,sha256=j7RL2Y_xMAF2tPTi6E2LdJqZerSQXlnEQJOGZkefTsA,18358
+beaver_db-0.14.0.dist-info/licenses/LICENSE,sha256=1xrIY5JnMk_QDQzsqmVzPIIyCgZAkWCC8kF2Ddo1UT0,1071
+beaver_db-0.14.0.dist-info/METADATA,sha256=vnsXMoILckkEY7YPrppHQbkvK4MP4qfNQLTjvjcEGBU,15621
+beaver_db-0.14.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+beaver_db-0.14.0.dist-info/top_level.txt,sha256=FxA4XnX5Qm5VudEXCduFriqi4dQmDWpQ64d7g69VQKI,7
+beaver_db-0.14.0.dist-info/RECORD,,
beaver_db-0.13.1.dist-info/RECORD
DELETED
@@ -1,15 +0,0 @@
-beaver/__init__.py,sha256=qyEzF1Os7w4b4Hijgz0Y0R4zTrRBrHIGT1mEkZFl2YM,101
-beaver/blobs.py,sha256=5yVDzWyqi6Fur-2r0gaeIjEKj9fUPXb9hPulCTknJJI,3355
-beaver/channels.py,sha256=jKL1sVLOe_Q_pP0q1-iceZbPe8FOi0EwqJtOMOe96f4,8675
-beaver/collections.py,sha256=SZcaZnhcTpKb2OfpZOpFiVxh4-joYAJc6U98UeIhMuU,24247
-beaver/core.py,sha256=OdzXmAwBw12SwUsHBYvV3tFr5NHE3AHQ9HBfjZafDN0,12283
-beaver/dicts.py,sha256=1BQ9A_cMkJ7l5ayWbDG-4Wi3WtQ-9BKd7Wj_CB7dGlU,5410
-beaver/lists.py,sha256=0LT2XjuHs8pDgvW48kk_lfVc-Y-Ulmym0gcVWRESPtA,9708
-beaver/queues.py,sha256=SFu2180ONotnZOcYp1Ld5d8kxzYxaOlgDdcr70ZoBL8,3641
-beaver/types.py,sha256=65rDdj97EegghEkKCNjI67bPYtTTI_jyB-leHdIypx4,1249
-beaver/vectors.py,sha256=j7RL2Y_xMAF2tPTi6E2LdJqZerSQXlnEQJOGZkefTsA,18358
-beaver_db-0.13.1.dist-info/licenses/LICENSE,sha256=1xrIY5JnMk_QDQzsqmVzPIIyCgZAkWCC8kF2Ddo1UT0,1071
-beaver_db-0.13.1.dist-info/METADATA,sha256=DYVLkTRCAJf8LsS20DVWO_TWC_69RYAWIwCdrqNxS-4,15228
-beaver_db-0.13.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-beaver_db-0.13.1.dist-info/top_level.txt,sha256=FxA4XnX5Qm5VudEXCduFriqi4dQmDWpQ64d7g69VQKI,7
-beaver_db-0.13.1.dist-info/RECORD,,

{beaver_db-0.13.1.dist-info → beaver_db-0.14.0.dist-info}/WHEEL
File without changes

{beaver_db-0.13.1.dist-info → beaver_db-0.14.0.dist-info}/licenses/LICENSE
File without changes

{beaver_db-0.13.1.dist-info → beaver_db-0.14.0.dist-info}/top_level.txt
File without changes