beaver-db 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of beaver-db might be problematic. Click here for more details.
- beaver/__init__.py +2 -1
- beaver/collections.py +327 -0
- beaver/core.py +77 -503
- beaver/lists.py +166 -0
- beaver/subscribers.py +54 -0
- beaver_db-0.5.0.dist-info/METADATA +171 -0
- beaver_db-0.5.0.dist-info/RECORD +9 -0
- beaver_db-0.4.0.dist-info/METADATA +0 -129
- beaver_db-0.4.0.dist-info/RECORD +0 -6
- {beaver_db-0.4.0.dist-info → beaver_db-0.5.0.dist-info}/WHEEL +0 -0
- {beaver_db-0.4.0.dist-info → beaver_db-0.5.0.dist-info}/top_level.txt +0 -0
beaver/core.py
CHANGED
|
@@ -1,21 +1,22 @@
|
|
|
1
|
-
import asyncio
|
|
2
|
-
import uuid
|
|
3
|
-
import numpy as np
|
|
4
1
|
import json
|
|
5
2
|
import sqlite3
|
|
6
3
|
import time
|
|
7
|
-
from typing import Any
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from .lists import ListWrapper
|
|
7
|
+
from .subscribers import SubWrapper
|
|
8
|
+
from .collections import CollectionWrapper
|
|
8
9
|
|
|
9
10
|
|
|
10
11
|
class BeaverDB:
|
|
11
12
|
"""
|
|
12
13
|
An embedded, multi-modal database in a single SQLite file.
|
|
13
|
-
|
|
14
|
+
This class manages the database connection and table schemas.
|
|
14
15
|
"""
|
|
15
16
|
|
|
16
17
|
def __init__(self, db_path: str):
|
|
17
18
|
"""
|
|
18
|
-
Initializes the database connection and creates necessary tables.
|
|
19
|
+
Initializes the database connection and creates all necessary tables.
|
|
19
20
|
|
|
20
21
|
Args:
|
|
21
22
|
db_path: The path to the SQLite database file.
|
|
@@ -25,29 +26,32 @@ class BeaverDB:
|
|
|
25
26
|
self._conn = sqlite3.connect(self._db_path, check_same_thread=False)
|
|
26
27
|
self._conn.execute("PRAGMA journal_mode=WAL;")
|
|
27
28
|
self._conn.row_factory = sqlite3.Row
|
|
28
|
-
self.
|
|
29
|
+
self._create_all_tables()
|
|
30
|
+
|
|
31
|
+
def _create_all_tables(self):
|
|
32
|
+
"""Initializes all required tables in the database file."""
|
|
29
33
|
self._create_kv_table()
|
|
34
|
+
self._create_pubsub_table()
|
|
30
35
|
self._create_list_table()
|
|
31
36
|
self._create_collections_table()
|
|
32
|
-
self._create_fts_table()
|
|
37
|
+
self._create_fts_table()
|
|
38
|
+
self._create_edges_table()
|
|
39
|
+
self._create_versions_table()
|
|
33
40
|
|
|
34
|
-
def
|
|
35
|
-
"""Creates the
|
|
41
|
+
def _create_kv_table(self):
|
|
42
|
+
"""Creates the key-value store table."""
|
|
36
43
|
with self._conn:
|
|
37
44
|
self._conn.execute(
|
|
38
45
|
"""
|
|
39
|
-
CREATE
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
field_path,
|
|
43
|
-
field_content,
|
|
44
|
-
tokenize = 'porter'
|
|
46
|
+
CREATE TABLE IF NOT EXISTS _beaver_kv_store (
|
|
47
|
+
key TEXT PRIMARY KEY,
|
|
48
|
+
value TEXT NOT NULL
|
|
45
49
|
)
|
|
46
50
|
"""
|
|
47
51
|
)
|
|
48
52
|
|
|
49
53
|
def _create_pubsub_table(self):
|
|
50
|
-
"""Creates the pub/sub log table
|
|
54
|
+
"""Creates the pub/sub log table."""
|
|
51
55
|
with self._conn:
|
|
52
56
|
self._conn.execute(
|
|
53
57
|
"""
|
|
@@ -65,20 +69,8 @@ class BeaverDB:
|
|
|
65
69
|
"""
|
|
66
70
|
)
|
|
67
71
|
|
|
68
|
-
def _create_kv_table(self):
|
|
69
|
-
"""Creates the key-value store table if it doesn't exist."""
|
|
70
|
-
with self._conn:
|
|
71
|
-
self._conn.execute(
|
|
72
|
-
"""
|
|
73
|
-
CREATE TABLE IF NOT EXISTS _beaver_kv_store (
|
|
74
|
-
key TEXT PRIMARY KEY,
|
|
75
|
-
value TEXT NOT NULL
|
|
76
|
-
)
|
|
77
|
-
"""
|
|
78
|
-
)
|
|
79
|
-
|
|
80
72
|
def _create_list_table(self):
|
|
81
|
-
"""Creates the lists table
|
|
73
|
+
"""Creates the lists table."""
|
|
82
74
|
with self._conn:
|
|
83
75
|
self._conn.execute(
|
|
84
76
|
"""
|
|
@@ -92,7 +84,7 @@ class BeaverDB:
|
|
|
92
84
|
)
|
|
93
85
|
|
|
94
86
|
def _create_collections_table(self):
|
|
95
|
-
"""Creates the
|
|
87
|
+
"""Creates the main table for storing documents and vectors."""
|
|
96
88
|
with self._conn:
|
|
97
89
|
self._conn.execute(
|
|
98
90
|
"""
|
|
@@ -106,24 +98,60 @@ class BeaverDB:
|
|
|
106
98
|
"""
|
|
107
99
|
)
|
|
108
100
|
|
|
101
|
+
def _create_fts_table(self):
|
|
102
|
+
"""Creates the virtual FTS table for full-text search."""
|
|
103
|
+
with self._conn:
|
|
104
|
+
self._conn.execute(
|
|
105
|
+
"""
|
|
106
|
+
CREATE VIRTUAL TABLE IF NOT EXISTS beaver_fts_index USING fts5(
|
|
107
|
+
collection,
|
|
108
|
+
item_id,
|
|
109
|
+
field_path,
|
|
110
|
+
field_content,
|
|
111
|
+
tokenize = 'porter'
|
|
112
|
+
)
|
|
113
|
+
"""
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
def _create_edges_table(self):
|
|
117
|
+
"""Creates the table for storing relationships between documents."""
|
|
118
|
+
with self._conn:
|
|
119
|
+
self._conn.execute(
|
|
120
|
+
"""
|
|
121
|
+
CREATE TABLE IF NOT EXISTS beaver_edges (
|
|
122
|
+
collection TEXT NOT NULL,
|
|
123
|
+
source_item_id TEXT NOT NULL,
|
|
124
|
+
target_item_id TEXT NOT NULL,
|
|
125
|
+
label TEXT NOT NULL,
|
|
126
|
+
metadata TEXT,
|
|
127
|
+
PRIMARY KEY (collection, source_item_id, target_item_id, label)
|
|
128
|
+
)
|
|
129
|
+
"""
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
def _create_versions_table(self):
|
|
133
|
+
"""Creates a table to track the version of each collection for caching."""
|
|
134
|
+
with self._conn:
|
|
135
|
+
self._conn.execute(
|
|
136
|
+
"""
|
|
137
|
+
CREATE TABLE IF NOT EXISTS beaver_collection_versions (
|
|
138
|
+
collection_name TEXT PRIMARY KEY,
|
|
139
|
+
version INTEGER NOT NULL DEFAULT 0
|
|
140
|
+
)
|
|
141
|
+
"""
|
|
142
|
+
)
|
|
143
|
+
|
|
109
144
|
def close(self):
|
|
110
145
|
"""Closes the database connection."""
|
|
111
146
|
if self._conn:
|
|
112
147
|
self._conn.close()
|
|
113
148
|
|
|
114
|
-
# ---
|
|
149
|
+
# --- Factory and Passthrough Methods ---
|
|
115
150
|
|
|
116
151
|
def set(self, key: str, value: Any):
|
|
117
152
|
"""
|
|
118
153
|
Stores a JSON-serializable value for a given key.
|
|
119
154
|
This operation is synchronous.
|
|
120
|
-
|
|
121
|
-
Args:
|
|
122
|
-
key: The unique string identifier for the value.
|
|
123
|
-
value: A JSON-serializable Python object (dict, list, str, int, etc.).
|
|
124
|
-
|
|
125
|
-
Raises:
|
|
126
|
-
TypeError: If the key is not a string or the value is not JSON-serializable.
|
|
127
155
|
"""
|
|
128
156
|
if not isinstance(key, str):
|
|
129
157
|
raise TypeError("Key must be a string.")
|
|
@@ -143,15 +171,6 @@ class BeaverDB:
|
|
|
143
171
|
"""
|
|
144
172
|
Retrieves a value for a given key.
|
|
145
173
|
This operation is synchronous.
|
|
146
|
-
|
|
147
|
-
Args:
|
|
148
|
-
key: The string identifier for the value.
|
|
149
|
-
|
|
150
|
-
Returns:
|
|
151
|
-
The deserialized Python object, or None if the key is not found.
|
|
152
|
-
|
|
153
|
-
Raises:
|
|
154
|
-
TypeError: If the key is not a string.
|
|
155
174
|
"""
|
|
156
175
|
if not isinstance(key, str):
|
|
157
176
|
raise TypeError("Key must be a string.")
|
|
@@ -161,37 +180,20 @@ class BeaverDB:
|
|
|
161
180
|
result = cursor.fetchone()
|
|
162
181
|
cursor.close()
|
|
163
182
|
|
|
164
|
-
if result
|
|
165
|
-
return json.loads(result["value"])
|
|
166
|
-
return None
|
|
167
|
-
|
|
168
|
-
# --- List Methods ---
|
|
169
|
-
|
|
170
|
-
def list(self, name: str) -> "ListWrapper":
|
|
171
|
-
"""
|
|
172
|
-
Returns a wrapper object for interacting with a specific list.
|
|
173
|
-
|
|
174
|
-
Args:
|
|
175
|
-
name: The name of the list.
|
|
183
|
+
return json.loads(result["value"]) if result else None
|
|
176
184
|
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
"""
|
|
185
|
+
def list(self, name: str) -> ListWrapper:
|
|
186
|
+
"""Returns a wrapper object for interacting with a named list."""
|
|
180
187
|
if not isinstance(name, str) or not name:
|
|
181
188
|
raise TypeError("List name must be a non-empty string.")
|
|
182
189
|
return ListWrapper(name, self._conn)
|
|
183
190
|
|
|
184
|
-
def collection(self, name: str) ->
|
|
185
|
-
"""Returns a wrapper for interacting with a
|
|
191
|
+
def collection(self, name: str) -> CollectionWrapper:
|
|
192
|
+
"""Returns a wrapper for interacting with a document collection."""
|
|
186
193
|
return CollectionWrapper(name, self._conn)
|
|
187
194
|
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
async def publish(self, channel_name: str, payload: Any):
|
|
191
|
-
"""
|
|
192
|
-
Publishes a JSON-serializable message to a channel.
|
|
193
|
-
This operation is asynchronous.
|
|
194
|
-
"""
|
|
195
|
+
def publish(self, channel_name: str, payload: Any):
|
|
196
|
+
"""Publishes a JSON-serializable message to a channel. This is synchronous."""
|
|
195
197
|
if not isinstance(channel_name, str) or not channel_name:
|
|
196
198
|
raise ValueError("Channel name must be a non-empty string.")
|
|
197
199
|
try:
|
|
@@ -199,440 +201,12 @@ class BeaverDB:
|
|
|
199
201
|
except TypeError as e:
|
|
200
202
|
raise TypeError("Message payload must be JSON-serializable.") from e
|
|
201
203
|
|
|
202
|
-
await asyncio.to_thread(self._write_publish_to_db, channel_name, json_payload)
|
|
203
|
-
|
|
204
|
-
def _write_publish_to_db(self, channel_name, json_payload):
|
|
205
|
-
"""The synchronous part of the publish operation."""
|
|
206
204
|
with self._conn:
|
|
207
205
|
self._conn.execute(
|
|
208
206
|
"INSERT INTO beaver_pubsub_log (timestamp, channel_name, message_payload) VALUES (?, ?, ?)",
|
|
209
207
|
(time.time(), channel_name, json_payload),
|
|
210
208
|
)
|
|
211
209
|
|
|
212
|
-
def subscribe(self, channel_name: str) ->
|
|
213
|
-
"""
|
|
214
|
-
|
|
215
|
-
"""
|
|
216
|
-
return Subscriber(self._conn, channel_name)
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
class ListWrapper:
|
|
220
|
-
"""A wrapper providing a Pythonic interface to a list in the database."""
|
|
221
|
-
|
|
222
|
-
def __init__(self, name: str, conn: sqlite3.Connection):
|
|
223
|
-
self._name = name
|
|
224
|
-
self._conn = conn
|
|
225
|
-
|
|
226
|
-
def __len__(self) -> int:
|
|
227
|
-
"""Returns the number of items in the list (e.g., `len(my_list)`)."""
|
|
228
|
-
cursor = self._conn.cursor()
|
|
229
|
-
cursor.execute(
|
|
230
|
-
"SELECT COUNT(*) FROM beaver_lists WHERE list_name = ?", (self._name,)
|
|
231
|
-
)
|
|
232
|
-
count = cursor.fetchone()[0]
|
|
233
|
-
cursor.close()
|
|
234
|
-
return count
|
|
235
|
-
|
|
236
|
-
def __getitem__(self, key: Union[int, slice]) -> Any:
|
|
237
|
-
"""
|
|
238
|
-
Retrieves an item or slice from the list (e.g., `my_list[0]`, `my_list[1:3]`).
|
|
239
|
-
"""
|
|
240
|
-
if isinstance(key, slice):
|
|
241
|
-
start, stop, step = key.indices(len(self))
|
|
242
|
-
if step != 1:
|
|
243
|
-
raise ValueError("Slicing with a step is not supported.")
|
|
244
|
-
|
|
245
|
-
limit = stop - start
|
|
246
|
-
if limit <= 0:
|
|
247
|
-
return []
|
|
248
|
-
|
|
249
|
-
cursor = self._conn.cursor()
|
|
250
|
-
cursor.execute(
|
|
251
|
-
"SELECT item_value FROM beaver_lists WHERE list_name = ? ORDER BY item_order ASC LIMIT ? OFFSET ?",
|
|
252
|
-
(self._name, limit, start),
|
|
253
|
-
)
|
|
254
|
-
results = [json.loads(row["item_value"]) for row in cursor.fetchall()]
|
|
255
|
-
cursor.close()
|
|
256
|
-
return results
|
|
257
|
-
|
|
258
|
-
elif isinstance(key, int):
|
|
259
|
-
list_len = len(self)
|
|
260
|
-
if key < -list_len or key >= list_len:
|
|
261
|
-
raise IndexError("List index out of range.")
|
|
262
|
-
|
|
263
|
-
offset = key if key >= 0 else list_len + key
|
|
264
|
-
|
|
265
|
-
cursor = self._conn.cursor()
|
|
266
|
-
cursor.execute(
|
|
267
|
-
"SELECT item_value FROM beaver_lists WHERE list_name = ? ORDER BY item_order ASC LIMIT 1 OFFSET ?",
|
|
268
|
-
(self._name, offset),
|
|
269
|
-
)
|
|
270
|
-
result = cursor.fetchone()
|
|
271
|
-
cursor.close()
|
|
272
|
-
return json.loads(result["item_value"]) if result else None
|
|
273
|
-
|
|
274
|
-
else:
|
|
275
|
-
raise TypeError("List indices must be integers or slices.")
|
|
276
|
-
|
|
277
|
-
def _get_order_at_index(self, index: int) -> float:
|
|
278
|
-
"""Helper to get the float `item_order` at a specific index."""
|
|
279
|
-
cursor = self._conn.cursor()
|
|
280
|
-
cursor.execute(
|
|
281
|
-
"SELECT item_order FROM beaver_lists WHERE list_name = ? ORDER BY item_order ASC LIMIT 1 OFFSET ?",
|
|
282
|
-
(self._name, index),
|
|
283
|
-
)
|
|
284
|
-
result = cursor.fetchone()
|
|
285
|
-
cursor.close()
|
|
286
|
-
|
|
287
|
-
if result:
|
|
288
|
-
return result[0]
|
|
289
|
-
|
|
290
|
-
raise IndexError(f"{index} out of range.")
|
|
291
|
-
|
|
292
|
-
def push(self, value: Any):
|
|
293
|
-
"""Pushes an item to the end of the list."""
|
|
294
|
-
with self._conn:
|
|
295
|
-
cursor = self._conn.cursor()
|
|
296
|
-
cursor.execute(
|
|
297
|
-
"SELECT MAX(item_order) FROM beaver_lists WHERE list_name = ?",
|
|
298
|
-
(self._name,),
|
|
299
|
-
)
|
|
300
|
-
max_order = cursor.fetchone()[0] or 0.0
|
|
301
|
-
new_order = max_order + 1.0
|
|
302
|
-
|
|
303
|
-
cursor.execute(
|
|
304
|
-
"INSERT INTO beaver_lists (list_name, item_order, item_value) VALUES (?, ?, ?)",
|
|
305
|
-
(self._name, new_order, json.dumps(value)),
|
|
306
|
-
)
|
|
307
|
-
|
|
308
|
-
def prepend(self, value: Any):
|
|
309
|
-
"""Prepends an item to the beginning of the list."""
|
|
310
|
-
with self._conn:
|
|
311
|
-
cursor = self._conn.cursor()
|
|
312
|
-
cursor.execute(
|
|
313
|
-
"SELECT MIN(item_order) FROM beaver_lists WHERE list_name = ?",
|
|
314
|
-
(self._name,),
|
|
315
|
-
)
|
|
316
|
-
min_order = cursor.fetchone()[0] or 0.0
|
|
317
|
-
new_order = min_order - 1.0
|
|
318
|
-
|
|
319
|
-
cursor.execute(
|
|
320
|
-
"INSERT INTO beaver_lists (list_name, item_order, item_value) VALUES (?, ?, ?)",
|
|
321
|
-
(self._name, new_order, json.dumps(value)),
|
|
322
|
-
)
|
|
323
|
-
|
|
324
|
-
def insert(self, index: int, value: Any):
|
|
325
|
-
"""Inserts an item at a specific index."""
|
|
326
|
-
list_len = len(self)
|
|
327
|
-
if index <= 0:
|
|
328
|
-
self.prepend(value)
|
|
329
|
-
return
|
|
330
|
-
if index >= list_len:
|
|
331
|
-
self.push(value)
|
|
332
|
-
return
|
|
333
|
-
|
|
334
|
-
# Midpoint insertion
|
|
335
|
-
order_before = self._get_order_at_index(index - 1)
|
|
336
|
-
order_after = self._get_order_at_index(index)
|
|
337
|
-
|
|
338
|
-
new_order = order_before + (order_after - order_before) / 2.0
|
|
339
|
-
|
|
340
|
-
with self._conn:
|
|
341
|
-
self._conn.execute(
|
|
342
|
-
"INSERT INTO beaver_lists (list_name, item_order, item_value) VALUES (?, ?, ?)",
|
|
343
|
-
(self._name, new_order, json.dumps(value)),
|
|
344
|
-
)
|
|
345
|
-
|
|
346
|
-
def pop(self) -> Any:
|
|
347
|
-
"""Removes and returns the last item from the list."""
|
|
348
|
-
with self._conn:
|
|
349
|
-
cursor = self._conn.cursor()
|
|
350
|
-
cursor.execute(
|
|
351
|
-
"SELECT rowid, item_value FROM beaver_lists WHERE list_name = ? ORDER BY item_order DESC LIMIT 1",
|
|
352
|
-
(self._name,),
|
|
353
|
-
)
|
|
354
|
-
result = cursor.fetchone()
|
|
355
|
-
if not result:
|
|
356
|
-
return None
|
|
357
|
-
|
|
358
|
-
rowid_to_delete, value_to_return = result
|
|
359
|
-
cursor.execute(
|
|
360
|
-
"DELETE FROM beaver_lists WHERE rowid = ?", (rowid_to_delete,)
|
|
361
|
-
)
|
|
362
|
-
return json.loads(value_to_return)
|
|
363
|
-
|
|
364
|
-
def deque(self) -> Any:
|
|
365
|
-
"""Removes and returns the first item from the list."""
|
|
366
|
-
with self._conn:
|
|
367
|
-
cursor = self._conn.cursor()
|
|
368
|
-
cursor.execute(
|
|
369
|
-
"SELECT rowid, item_value FROM beaver_lists WHERE list_name = ? ORDER BY item_order ASC LIMIT 1",
|
|
370
|
-
(self._name,),
|
|
371
|
-
)
|
|
372
|
-
result = cursor.fetchone()
|
|
373
|
-
if not result:
|
|
374
|
-
return None
|
|
375
|
-
|
|
376
|
-
rowid_to_delete, value_to_return = result
|
|
377
|
-
cursor.execute(
|
|
378
|
-
"DELETE FROM beaver_lists WHERE rowid = ?", (rowid_to_delete,)
|
|
379
|
-
)
|
|
380
|
-
return json.loads(value_to_return)
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
class Subscriber(AsyncIterator):
|
|
384
|
-
"""
|
|
385
|
-
An async iterator that polls a channel for new messages.
|
|
386
|
-
Designed to be used with 'async with'.
|
|
387
|
-
"""
|
|
388
|
-
|
|
389
|
-
def __init__(
|
|
390
|
-
self, conn: sqlite3.Connection, channel_name: str, poll_interval: float = 0.1
|
|
391
|
-
):
|
|
392
|
-
self._conn = conn
|
|
393
|
-
self._channel = channel_name
|
|
394
|
-
self._poll_interval = poll_interval
|
|
395
|
-
self._queue = asyncio.Queue()
|
|
396
|
-
self._last_seen_timestamp = time.time()
|
|
397
|
-
self._polling_task = None
|
|
398
|
-
|
|
399
|
-
async def _poll_for_messages(self):
|
|
400
|
-
"""Background task that polls the database for new messages."""
|
|
401
|
-
while True:
|
|
402
|
-
try:
|
|
403
|
-
new_messages = await asyncio.to_thread(self._fetch_new_messages_from_db)
|
|
404
|
-
if new_messages:
|
|
405
|
-
for msg in new_messages:
|
|
406
|
-
payload = json.loads(msg["message_payload"])
|
|
407
|
-
await self._queue.put(payload)
|
|
408
|
-
self._last_seen_timestamp = msg["timestamp"]
|
|
409
|
-
await asyncio.sleep(self._poll_interval)
|
|
410
|
-
except asyncio.CancelledError:
|
|
411
|
-
break
|
|
412
|
-
except Exception:
|
|
413
|
-
# In a real app, add more robust error logging
|
|
414
|
-
await asyncio.sleep(self._poll_interval * 5)
|
|
415
|
-
|
|
416
|
-
def _fetch_new_messages_from_db(self) -> list:
|
|
417
|
-
"""The actual synchronous database query."""
|
|
418
|
-
cursor = self._conn.cursor()
|
|
419
|
-
cursor.execute(
|
|
420
|
-
"SELECT timestamp, message_payload FROM beaver_pubsub_log WHERE channel_name = ? AND timestamp > ? ORDER BY timestamp ASC",
|
|
421
|
-
(self._channel, self._last_seen_timestamp),
|
|
422
|
-
)
|
|
423
|
-
results = cursor.fetchall()
|
|
424
|
-
cursor.close()
|
|
425
|
-
return results
|
|
426
|
-
|
|
427
|
-
async def __aenter__(self):
|
|
428
|
-
"""Starts the background task."""
|
|
429
|
-
self._polling_task = asyncio.create_task(self._poll_for_messages())
|
|
430
|
-
return self
|
|
431
|
-
|
|
432
|
-
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
433
|
-
"""Stops the background task."""
|
|
434
|
-
if self._polling_task:
|
|
435
|
-
self._polling_task.cancel()
|
|
436
|
-
await asyncio.gather(self._polling_task, return_exceptions=True)
|
|
437
|
-
|
|
438
|
-
def __aiter__(self):
|
|
439
|
-
return self
|
|
440
|
-
|
|
441
|
-
async def __anext__(self) -> Any:
|
|
442
|
-
"""Allows 'async for' to pull messages from the internal queue."""
|
|
443
|
-
return await self._queue.get()
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
class Document:
|
|
447
|
-
"""A data class for a vector and its metadata, with a unique ID."""
|
|
448
|
-
|
|
449
|
-
def __init__(
|
|
450
|
-
self, embedding: list[float] | None = None, id: str | None = None, **metadata
|
|
451
|
-
):
|
|
452
|
-
self.id = id or str(uuid.uuid4())
|
|
453
|
-
|
|
454
|
-
if embedding is None:
|
|
455
|
-
self.embedding = None
|
|
456
|
-
else:
|
|
457
|
-
if not isinstance(embedding, list) or not all(
|
|
458
|
-
isinstance(x, (int, float)) for x in embedding
|
|
459
|
-
):
|
|
460
|
-
raise TypeError("Embedding must be a list of numbers.")
|
|
461
|
-
|
|
462
|
-
self.embedding = np.array(embedding, dtype=np.float32)
|
|
463
|
-
|
|
464
|
-
for key, value in metadata.items():
|
|
465
|
-
setattr(self, key, value)
|
|
466
|
-
|
|
467
|
-
def to_dict(self) -> dict[str, Any]:
|
|
468
|
-
"""Serializes metadata to a dictionary."""
|
|
469
|
-
metadata = self.__dict__.copy()
|
|
470
|
-
# Exclude internal attributes from the metadata payload
|
|
471
|
-
metadata.pop("embedding", None)
|
|
472
|
-
metadata.pop("id", None)
|
|
473
|
-
return metadata
|
|
474
|
-
|
|
475
|
-
def __repr__(self):
|
|
476
|
-
metadata_str = ", ".join(f"{k}={v!r}" for k, v in self.to_dict().items())
|
|
477
|
-
return f"Document(id='{self.id}', {metadata_str})"
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
class CollectionWrapper:
|
|
481
|
-
"""A wrapper for vector collection operations with upsert logic."""
|
|
482
|
-
|
|
483
|
-
def __init__(self, name: str, conn: sqlite3.Connection):
|
|
484
|
-
self._name = name
|
|
485
|
-
self._conn = conn
|
|
486
|
-
|
|
487
|
-
# Dentro de la clase CollectionWrapper en beaver/core.py
|
|
488
|
-
|
|
489
|
-
def _flatten_metadata(self, metadata: dict, prefix: str = "") -> dict[str, str]:
|
|
490
|
-
"""
|
|
491
|
-
Aplana un diccionario anidado y filtra solo los valores de tipo string.
|
|
492
|
-
Ejemplo: {'a': {'b': 'c'}} -> {'a__b': 'c'}
|
|
493
|
-
"""
|
|
494
|
-
flat_dict = {}
|
|
495
|
-
for key, value in metadata.items():
|
|
496
|
-
new_key = f"{prefix}__{key}" if prefix else key
|
|
497
|
-
if isinstance(value, dict):
|
|
498
|
-
flat_dict.update(self._flatten_metadata(value, new_key))
|
|
499
|
-
elif isinstance(value, str):
|
|
500
|
-
flat_dict[new_key] = value
|
|
501
|
-
return flat_dict
|
|
502
|
-
|
|
503
|
-
def index(self, document: Document, *, fts: bool = True):
|
|
504
|
-
"""
|
|
505
|
-
Indexa un Document, realizando un upsert y actualizando el índice FTS.
|
|
506
|
-
"""
|
|
507
|
-
with self._conn:
|
|
508
|
-
if fts:
|
|
509
|
-
self._conn.execute(
|
|
510
|
-
"DELETE FROM beaver_fts_index WHERE collection = ? AND item_id = ?",
|
|
511
|
-
(self._name, document.id),
|
|
512
|
-
)
|
|
513
|
-
|
|
514
|
-
string_fields = self._flatten_metadata(document.to_dict())
|
|
515
|
-
|
|
516
|
-
if string_fields:
|
|
517
|
-
fts_data = [
|
|
518
|
-
(self._name, document.id, path, content)
|
|
519
|
-
for path, content in string_fields.items()
|
|
520
|
-
]
|
|
521
|
-
self._conn.executemany(
|
|
522
|
-
"INSERT INTO beaver_fts_index (collection, item_id, field_path, field_content) VALUES (?, ?, ?, ?)",
|
|
523
|
-
fts_data,
|
|
524
|
-
)
|
|
525
|
-
|
|
526
|
-
self._conn.execute(
|
|
527
|
-
"INSERT OR REPLACE INTO beaver_collections (collection, item_id, item_vector, metadata) VALUES (?, ?, ?, ?)",
|
|
528
|
-
(
|
|
529
|
-
self._name,
|
|
530
|
-
document.id,
|
|
531
|
-
document.embedding.tobytes() if document.embedding is not None else None,
|
|
532
|
-
json.dumps(document.to_dict()),
|
|
533
|
-
),
|
|
534
|
-
)
|
|
535
|
-
|
|
536
|
-
def search(
|
|
537
|
-
self, vector: list[float], top_k: int = 10
|
|
538
|
-
) -> list[tuple[Document, float]]:
|
|
539
|
-
"""
|
|
540
|
-
Performs a vector search and returns Document objects.
|
|
541
|
-
"""
|
|
542
|
-
query_vector = np.array(vector, dtype=np.float32)
|
|
543
|
-
|
|
544
|
-
cursor = self._conn.cursor()
|
|
545
|
-
cursor.execute(
|
|
546
|
-
"SELECT item_id, item_vector, metadata FROM beaver_collections WHERE collection = ?",
|
|
547
|
-
(self._name,),
|
|
548
|
-
)
|
|
549
|
-
|
|
550
|
-
all_docs_data = cursor.fetchall()
|
|
551
|
-
cursor.close()
|
|
552
|
-
|
|
553
|
-
if not all_docs_data:
|
|
554
|
-
return []
|
|
555
|
-
|
|
556
|
-
results = []
|
|
557
|
-
for row in all_docs_data:
|
|
558
|
-
if row["item_vector"] is None:
|
|
559
|
-
continue # Skip documents without embeddings
|
|
560
|
-
|
|
561
|
-
doc_id = row["item_id"]
|
|
562
|
-
embedding = np.frombuffer(row["item_vector"], dtype=np.float32).tolist()
|
|
563
|
-
metadata = json.loads(row["metadata"])
|
|
564
|
-
|
|
565
|
-
distance = np.linalg.norm(embedding - query_vector)
|
|
566
|
-
|
|
567
|
-
# Reconstruct the Document object with its original ID
|
|
568
|
-
doc = Document(id=doc_id, embedding=list(embedding), **metadata)
|
|
569
|
-
results.append((doc, float(distance)))
|
|
570
|
-
|
|
571
|
-
results.sort(key=lambda x: x[1])
|
|
572
|
-
return results[:top_k]
|
|
573
|
-
|
|
574
|
-
def match(
|
|
575
|
-
self, query: str, on_field: str | None = None, top_k: int = 10
|
|
576
|
-
) -> list[tuple[Document, float]]:
|
|
577
|
-
"""
|
|
578
|
-
Realiza una búsqueda de texto completo en los campos de metadatos indexados.
|
|
579
|
-
|
|
580
|
-
Args:
|
|
581
|
-
query: La expresión de búsqueda (ej. "gato", "perro OR conejo").
|
|
582
|
-
on_field: Opcional, el campo específico donde buscar (ej. "details__title").
|
|
583
|
-
top_k: El número máximo de resultados a devolver.
|
|
584
|
-
|
|
585
|
-
Returns:
|
|
586
|
-
Una lista de tuplas (Documento, puntuación_de_relevancia).
|
|
587
|
-
"""
|
|
588
|
-
cursor = self._conn.cursor()
|
|
589
|
-
|
|
590
|
-
sql_query = """
|
|
591
|
-
SELECT
|
|
592
|
-
t1.item_id, t1.item_vector, t1.metadata, fts.rank
|
|
593
|
-
FROM beaver_collections AS t1
|
|
594
|
-
JOIN (
|
|
595
|
-
SELECT DISTINCT item_id, rank
|
|
596
|
-
FROM beaver_fts_index
|
|
597
|
-
WHERE beaver_fts_index MATCH ?
|
|
598
|
-
ORDER BY rank
|
|
599
|
-
LIMIT ?
|
|
600
|
-
) AS fts ON t1.item_id = fts.item_id
|
|
601
|
-
WHERE t1.collection = ?
|
|
602
|
-
ORDER BY fts.rank
|
|
603
|
-
"""
|
|
604
|
-
|
|
605
|
-
params = []
|
|
606
|
-
field_filter_sql = ""
|
|
607
|
-
|
|
608
|
-
if on_field:
|
|
609
|
-
field_filter_sql = "AND field_path = ?"
|
|
610
|
-
params.append(on_field)
|
|
611
|
-
else:
|
|
612
|
-
# Búsqueda en todos los campos
|
|
613
|
-
params.append(query)
|
|
614
|
-
|
|
615
|
-
sql_query = sql_query.format(field_filter_sql)
|
|
616
|
-
params.extend([top_k, self._name])
|
|
617
|
-
|
|
618
|
-
cursor.execute(sql_query, tuple(params))
|
|
619
|
-
|
|
620
|
-
results = []
|
|
621
|
-
for row in cursor.fetchall():
|
|
622
|
-
doc_id = row["item_id"]
|
|
623
|
-
|
|
624
|
-
if row["item_vector"] is None:
|
|
625
|
-
embedding = None
|
|
626
|
-
else:
|
|
627
|
-
embedding = np.frombuffer(row["item_vector"], dtype=np.float32).tolist()
|
|
628
|
-
|
|
629
|
-
metadata = json.loads(row["metadata"])
|
|
630
|
-
rank = row["rank"]
|
|
631
|
-
|
|
632
|
-
doc = Document(id=doc_id, embedding=embedding, **metadata)
|
|
633
|
-
results.append((doc, rank))
|
|
634
|
-
|
|
635
|
-
results.sort(key=lambda x: x[1])
|
|
636
|
-
cursor.close()
|
|
637
|
-
|
|
638
|
-
return results
|
|
210
|
+
def subscribe(self, channel_name: str) -> SubWrapper:
|
|
211
|
+
"""Subscribes to a channel, returning a synchronous iterator."""
|
|
212
|
+
return SubWrapper(self._conn, channel_name)
|