beaver-db 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of beaver-db might be problematic. Click here for more details.
- beaver/__init__.py +2 -1
- beaver/collections.py +327 -0
- beaver/core.py +103 -383
- beaver/lists.py +166 -0
- beaver/subscribers.py +54 -0
- beaver_db-0.5.0.dist-info/METADATA +171 -0
- beaver_db-0.5.0.dist-info/RECORD +9 -0
- beaver_db-0.3.0.dist-info/METADATA +0 -129
- beaver_db-0.3.0.dist-info/RECORD +0 -6
- {beaver_db-0.3.0.dist-info → beaver_db-0.5.0.dist-info}/WHEEL +0 -0
- {beaver_db-0.3.0.dist-info → beaver_db-0.5.0.dist-info}/top_level.txt +0 -0
beaver/core.py
CHANGED
|
@@ -1,21 +1,22 @@
|
|
|
1
|
-
import asyncio
|
|
2
|
-
import uuid
|
|
3
|
-
import numpy as np
|
|
4
1
|
import json
|
|
5
2
|
import sqlite3
|
|
6
3
|
import time
|
|
7
|
-
from typing import Any
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from .lists import ListWrapper
|
|
7
|
+
from .subscribers import SubWrapper
|
|
8
|
+
from .collections import CollectionWrapper
|
|
8
9
|
|
|
9
10
|
|
|
10
11
|
class BeaverDB:
|
|
11
12
|
"""
|
|
12
13
|
An embedded, multi-modal database in a single SQLite file.
|
|
13
|
-
|
|
14
|
+
This class manages the database connection and table schemas.
|
|
14
15
|
"""
|
|
15
16
|
|
|
16
17
|
def __init__(self, db_path: str):
|
|
17
18
|
"""
|
|
18
|
-
Initializes the database connection and creates necessary tables.
|
|
19
|
+
Initializes the database connection and creates all necessary tables.
|
|
19
20
|
|
|
20
21
|
Args:
|
|
21
22
|
db_path: The path to the SQLite database file.
|
|
@@ -25,79 +26,132 @@ class BeaverDB:
|
|
|
25
26
|
self._conn = sqlite3.connect(self._db_path, check_same_thread=False)
|
|
26
27
|
self._conn.execute("PRAGMA journal_mode=WAL;")
|
|
27
28
|
self._conn.row_factory = sqlite3.Row
|
|
28
|
-
self.
|
|
29
|
+
self._create_all_tables()
|
|
30
|
+
|
|
31
|
+
def _create_all_tables(self):
|
|
32
|
+
"""Initializes all required tables in the database file."""
|
|
29
33
|
self._create_kv_table()
|
|
34
|
+
self._create_pubsub_table()
|
|
30
35
|
self._create_list_table()
|
|
31
36
|
self._create_collections_table()
|
|
37
|
+
self._create_fts_table()
|
|
38
|
+
self._create_edges_table()
|
|
39
|
+
self._create_versions_table()
|
|
40
|
+
|
|
41
|
+
def _create_kv_table(self):
|
|
42
|
+
"""Creates the key-value store table."""
|
|
43
|
+
with self._conn:
|
|
44
|
+
self._conn.execute(
|
|
45
|
+
"""
|
|
46
|
+
CREATE TABLE IF NOT EXISTS _beaver_kv_store (
|
|
47
|
+
key TEXT PRIMARY KEY,
|
|
48
|
+
value TEXT NOT NULL
|
|
49
|
+
)
|
|
50
|
+
"""
|
|
51
|
+
)
|
|
32
52
|
|
|
33
53
|
def _create_pubsub_table(self):
|
|
34
|
-
"""Creates the pub/sub log table
|
|
54
|
+
"""Creates the pub/sub log table."""
|
|
35
55
|
with self._conn:
|
|
36
|
-
self._conn.execute(
|
|
56
|
+
self._conn.execute(
|
|
57
|
+
"""
|
|
37
58
|
CREATE TABLE IF NOT EXISTS beaver_pubsub_log (
|
|
38
59
|
timestamp REAL PRIMARY KEY,
|
|
39
60
|
channel_name TEXT NOT NULL,
|
|
40
61
|
message_payload TEXT NOT NULL
|
|
41
62
|
)
|
|
42
|
-
"""
|
|
43
|
-
|
|
63
|
+
"""
|
|
64
|
+
)
|
|
65
|
+
self._conn.execute(
|
|
66
|
+
"""
|
|
44
67
|
CREATE INDEX IF NOT EXISTS idx_pubsub_channel_timestamp
|
|
45
68
|
ON beaver_pubsub_log (channel_name, timestamp)
|
|
46
|
-
"""
|
|
47
|
-
|
|
48
|
-
def _create_kv_table(self):
|
|
49
|
-
"""Creates the key-value store table if it doesn't exist."""
|
|
50
|
-
with self._conn:
|
|
51
|
-
self._conn.execute("""
|
|
52
|
-
CREATE TABLE IF NOT EXISTS _beaver_kv_store (
|
|
53
|
-
key TEXT PRIMARY KEY,
|
|
54
|
-
value TEXT NOT NULL
|
|
55
|
-
)
|
|
56
|
-
""")
|
|
69
|
+
"""
|
|
70
|
+
)
|
|
57
71
|
|
|
58
72
|
def _create_list_table(self):
|
|
59
|
-
"""Creates the lists table
|
|
73
|
+
"""Creates the lists table."""
|
|
60
74
|
with self._conn:
|
|
61
|
-
self._conn.execute(
|
|
75
|
+
self._conn.execute(
|
|
76
|
+
"""
|
|
62
77
|
CREATE TABLE IF NOT EXISTS beaver_lists (
|
|
63
78
|
list_name TEXT NOT NULL,
|
|
64
79
|
item_order REAL NOT NULL,
|
|
65
80
|
item_value TEXT NOT NULL,
|
|
66
81
|
PRIMARY KEY (list_name, item_order)
|
|
67
82
|
)
|
|
68
|
-
"""
|
|
83
|
+
"""
|
|
84
|
+
)
|
|
69
85
|
|
|
70
86
|
def _create_collections_table(self):
|
|
71
|
-
"""Creates the
|
|
87
|
+
"""Creates the main table for storing documents and vectors."""
|
|
72
88
|
with self._conn:
|
|
73
|
-
self._conn.execute(
|
|
89
|
+
self._conn.execute(
|
|
90
|
+
"""
|
|
74
91
|
CREATE TABLE IF NOT EXISTS beaver_collections (
|
|
75
92
|
collection TEXT NOT NULL,
|
|
76
93
|
item_id TEXT NOT NULL,
|
|
77
|
-
item_vector BLOB
|
|
94
|
+
item_vector BLOB,
|
|
78
95
|
metadata TEXT,
|
|
79
96
|
PRIMARY KEY (collection, item_id)
|
|
80
97
|
)
|
|
81
|
-
"""
|
|
98
|
+
"""
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
def _create_fts_table(self):
|
|
102
|
+
"""Creates the virtual FTS table for full-text search."""
|
|
103
|
+
with self._conn:
|
|
104
|
+
self._conn.execute(
|
|
105
|
+
"""
|
|
106
|
+
CREATE VIRTUAL TABLE IF NOT EXISTS beaver_fts_index USING fts5(
|
|
107
|
+
collection,
|
|
108
|
+
item_id,
|
|
109
|
+
field_path,
|
|
110
|
+
field_content,
|
|
111
|
+
tokenize = 'porter'
|
|
112
|
+
)
|
|
113
|
+
"""
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
def _create_edges_table(self):
|
|
117
|
+
"""Creates the table for storing relationships between documents."""
|
|
118
|
+
with self._conn:
|
|
119
|
+
self._conn.execute(
|
|
120
|
+
"""
|
|
121
|
+
CREATE TABLE IF NOT EXISTS beaver_edges (
|
|
122
|
+
collection TEXT NOT NULL,
|
|
123
|
+
source_item_id TEXT NOT NULL,
|
|
124
|
+
target_item_id TEXT NOT NULL,
|
|
125
|
+
label TEXT NOT NULL,
|
|
126
|
+
metadata TEXT,
|
|
127
|
+
PRIMARY KEY (collection, source_item_id, target_item_id, label)
|
|
128
|
+
)
|
|
129
|
+
"""
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
def _create_versions_table(self):
|
|
133
|
+
"""Creates a table to track the version of each collection for caching."""
|
|
134
|
+
with self._conn:
|
|
135
|
+
self._conn.execute(
|
|
136
|
+
"""
|
|
137
|
+
CREATE TABLE IF NOT EXISTS beaver_collection_versions (
|
|
138
|
+
collection_name TEXT PRIMARY KEY,
|
|
139
|
+
version INTEGER NOT NULL DEFAULT 0
|
|
140
|
+
)
|
|
141
|
+
"""
|
|
142
|
+
)
|
|
82
143
|
|
|
83
144
|
def close(self):
|
|
84
145
|
"""Closes the database connection."""
|
|
85
146
|
if self._conn:
|
|
86
147
|
self._conn.close()
|
|
87
148
|
|
|
88
|
-
# ---
|
|
149
|
+
# --- Factory and Passthrough Methods ---
|
|
89
150
|
|
|
90
151
|
def set(self, key: str, value: Any):
|
|
91
152
|
"""
|
|
92
153
|
Stores a JSON-serializable value for a given key.
|
|
93
154
|
This operation is synchronous.
|
|
94
|
-
|
|
95
|
-
Args:
|
|
96
|
-
key: The unique string identifier for the value.
|
|
97
|
-
value: A JSON-serializable Python object (dict, list, str, int, etc.).
|
|
98
|
-
|
|
99
|
-
Raises:
|
|
100
|
-
TypeError: If the key is not a string or the value is not JSON-serializable.
|
|
101
155
|
"""
|
|
102
156
|
if not isinstance(key, str):
|
|
103
157
|
raise TypeError("Key must be a string.")
|
|
@@ -110,22 +164,13 @@ class BeaverDB:
|
|
|
110
164
|
with self._conn:
|
|
111
165
|
self._conn.execute(
|
|
112
166
|
"INSERT OR REPLACE INTO _beaver_kv_store (key, value) VALUES (?, ?)",
|
|
113
|
-
(key, json_value)
|
|
167
|
+
(key, json_value),
|
|
114
168
|
)
|
|
115
169
|
|
|
116
170
|
def get(self, key: str) -> Any:
|
|
117
171
|
"""
|
|
118
172
|
Retrieves a value for a given key.
|
|
119
173
|
This operation is synchronous.
|
|
120
|
-
|
|
121
|
-
Args:
|
|
122
|
-
key: The string identifier for the value.
|
|
123
|
-
|
|
124
|
-
Returns:
|
|
125
|
-
The deserialized Python object, or None if the key is not found.
|
|
126
|
-
|
|
127
|
-
Raises:
|
|
128
|
-
TypeError: If the key is not a string.
|
|
129
174
|
"""
|
|
130
175
|
if not isinstance(key, str):
|
|
131
176
|
raise TypeError("Key must be a string.")
|
|
@@ -135,37 +180,20 @@ class BeaverDB:
|
|
|
135
180
|
result = cursor.fetchone()
|
|
136
181
|
cursor.close()
|
|
137
182
|
|
|
138
|
-
if result
|
|
139
|
-
return json.loads(result['value'])
|
|
140
|
-
return None
|
|
183
|
+
return json.loads(result["value"]) if result else None
|
|
141
184
|
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
def list(self, name: str) -> "ListWrapper":
|
|
145
|
-
"""
|
|
146
|
-
Returns a wrapper object for interacting with a specific list.
|
|
147
|
-
|
|
148
|
-
Args:
|
|
149
|
-
name: The name of the list.
|
|
150
|
-
|
|
151
|
-
Returns:
|
|
152
|
-
A ListWrapper instance bound to the given list name.
|
|
153
|
-
"""
|
|
185
|
+
def list(self, name: str) -> ListWrapper:
|
|
186
|
+
"""Returns a wrapper object for interacting with a named list."""
|
|
154
187
|
if not isinstance(name, str) or not name:
|
|
155
188
|
raise TypeError("List name must be a non-empty string.")
|
|
156
189
|
return ListWrapper(name, self._conn)
|
|
157
190
|
|
|
158
|
-
def collection(self, name: str) ->
|
|
159
|
-
"""Returns a wrapper for interacting with a
|
|
191
|
+
def collection(self, name: str) -> CollectionWrapper:
|
|
192
|
+
"""Returns a wrapper for interacting with a document collection."""
|
|
160
193
|
return CollectionWrapper(name, self._conn)
|
|
161
194
|
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
async def publish(self, channel_name: str, payload: Any):
|
|
165
|
-
"""
|
|
166
|
-
Publishes a JSON-serializable message to a channel.
|
|
167
|
-
This operation is asynchronous.
|
|
168
|
-
"""
|
|
195
|
+
def publish(self, channel_name: str, payload: Any):
|
|
196
|
+
"""Publishes a JSON-serializable message to a channel. This is synchronous."""
|
|
169
197
|
if not isinstance(channel_name, str) or not channel_name:
|
|
170
198
|
raise ValueError("Channel name must be a non-empty string.")
|
|
171
199
|
try:
|
|
@@ -173,320 +201,12 @@ class BeaverDB:
|
|
|
173
201
|
except TypeError as e:
|
|
174
202
|
raise TypeError("Message payload must be JSON-serializable.") from e
|
|
175
203
|
|
|
176
|
-
await asyncio.to_thread(
|
|
177
|
-
self._write_publish_to_db, channel_name, json_payload
|
|
178
|
-
)
|
|
179
|
-
|
|
180
|
-
def _write_publish_to_db(self, channel_name, json_payload):
|
|
181
|
-
"""The synchronous part of the publish operation."""
|
|
182
204
|
with self._conn:
|
|
183
205
|
self._conn.execute(
|
|
184
206
|
"INSERT INTO beaver_pubsub_log (timestamp, channel_name, message_payload) VALUES (?, ?, ?)",
|
|
185
|
-
(time.time(), channel_name, json_payload)
|
|
186
|
-
)
|
|
187
|
-
|
|
188
|
-
def subscribe(self, channel_name: str) -> "Subscriber":
|
|
189
|
-
"""
|
|
190
|
-
Subscribes to a channel, returning an async iterator.
|
|
191
|
-
"""
|
|
192
|
-
return Subscriber(self._conn, channel_name)
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
class ListWrapper:
|
|
196
|
-
"""A wrapper providing a Pythonic interface to a list in the database."""
|
|
197
|
-
|
|
198
|
-
def __init__(self, name: str, conn: sqlite3.Connection):
|
|
199
|
-
self._name = name
|
|
200
|
-
self._conn = conn
|
|
201
|
-
|
|
202
|
-
def __len__(self) -> int:
|
|
203
|
-
"""Returns the number of items in the list (e.g., `len(my_list)`)."""
|
|
204
|
-
cursor = self._conn.cursor()
|
|
205
|
-
cursor.execute("SELECT COUNT(*) FROM beaver_lists WHERE list_name = ?", (self._name,))
|
|
206
|
-
count = cursor.fetchone()[0]
|
|
207
|
-
cursor.close()
|
|
208
|
-
return count
|
|
209
|
-
|
|
210
|
-
def __getitem__(self, key: Union[int, slice]) -> Any:
|
|
211
|
-
"""
|
|
212
|
-
Retrieves an item or slice from the list (e.g., `my_list[0]`, `my_list[1:3]`).
|
|
213
|
-
"""
|
|
214
|
-
if isinstance(key, slice):
|
|
215
|
-
start, stop, step = key.indices(len(self))
|
|
216
|
-
if step != 1:
|
|
217
|
-
raise ValueError("Slicing with a step is not supported.")
|
|
218
|
-
|
|
219
|
-
limit = stop - start
|
|
220
|
-
if limit <= 0:
|
|
221
|
-
return []
|
|
222
|
-
|
|
223
|
-
cursor = self._conn.cursor()
|
|
224
|
-
cursor.execute(
|
|
225
|
-
"SELECT item_value FROM beaver_lists WHERE list_name = ? ORDER BY item_order ASC LIMIT ? OFFSET ?",
|
|
226
|
-
(self._name, limit, start)
|
|
207
|
+
(time.time(), channel_name, json_payload),
|
|
227
208
|
)
|
|
228
|
-
results = [json.loads(row['item_value']) for row in cursor.fetchall()]
|
|
229
|
-
cursor.close()
|
|
230
|
-
return results
|
|
231
|
-
|
|
232
|
-
elif isinstance(key, int):
|
|
233
|
-
list_len = len(self)
|
|
234
|
-
if key < -list_len or key >= list_len:
|
|
235
|
-
raise IndexError("List index out of range.")
|
|
236
|
-
|
|
237
|
-
offset = key if key >= 0 else list_len + key
|
|
238
|
-
|
|
239
|
-
cursor = self._conn.cursor()
|
|
240
|
-
cursor.execute(
|
|
241
|
-
"SELECT item_value FROM beaver_lists WHERE list_name = ? ORDER BY item_order ASC LIMIT 1 OFFSET ?",
|
|
242
|
-
(self._name, offset)
|
|
243
|
-
)
|
|
244
|
-
result = cursor.fetchone()
|
|
245
|
-
cursor.close()
|
|
246
|
-
return json.loads(result['item_value']) if result else None
|
|
247
|
-
|
|
248
|
-
else:
|
|
249
|
-
raise TypeError("List indices must be integers or slices.")
|
|
250
|
-
|
|
251
|
-
def _get_order_at_index(self, index: int) -> float:
|
|
252
|
-
"""Helper to get the float `item_order` at a specific index."""
|
|
253
|
-
cursor = self._conn.cursor()
|
|
254
|
-
cursor.execute(
|
|
255
|
-
"SELECT item_order FROM beaver_lists WHERE list_name = ? ORDER BY item_order ASC LIMIT 1 OFFSET ?",
|
|
256
|
-
(self._name, index)
|
|
257
|
-
)
|
|
258
|
-
result = cursor.fetchone()
|
|
259
|
-
cursor.close()
|
|
260
|
-
|
|
261
|
-
if result:
|
|
262
|
-
return result[0]
|
|
263
|
-
|
|
264
|
-
raise IndexError(f"{index} out of range.")
|
|
265
|
-
|
|
266
|
-
def push(self, value: Any):
|
|
267
|
-
"""Pushes an item to the end of the list."""
|
|
268
|
-
with self._conn:
|
|
269
|
-
cursor = self._conn.cursor()
|
|
270
|
-
cursor.execute("SELECT MAX(item_order) FROM beaver_lists WHERE list_name = ?", (self._name,))
|
|
271
|
-
max_order = cursor.fetchone()[0] or 0.0
|
|
272
|
-
new_order = max_order + 1.0
|
|
273
|
-
|
|
274
|
-
cursor.execute(
|
|
275
|
-
"INSERT INTO beaver_lists (list_name, item_order, item_value) VALUES (?, ?, ?)",
|
|
276
|
-
(self._name, new_order, json.dumps(value))
|
|
277
|
-
)
|
|
278
|
-
|
|
279
|
-
def prepend(self, value: Any):
|
|
280
|
-
"""Prepends an item to the beginning of the list."""
|
|
281
|
-
with self._conn:
|
|
282
|
-
cursor = self._conn.cursor()
|
|
283
|
-
cursor.execute("SELECT MIN(item_order) FROM beaver_lists WHERE list_name = ?", (self._name,))
|
|
284
|
-
min_order = cursor.fetchone()[0] or 0.0
|
|
285
|
-
new_order = min_order - 1.0
|
|
286
|
-
|
|
287
|
-
cursor.execute(
|
|
288
|
-
"INSERT INTO beaver_lists (list_name, item_order, item_value) VALUES (?, ?, ?)",
|
|
289
|
-
(self._name, new_order, json.dumps(value))
|
|
290
|
-
)
|
|
291
|
-
|
|
292
|
-
def insert(self, index: int, value: Any):
|
|
293
|
-
"""Inserts an item at a specific index."""
|
|
294
|
-
list_len = len(self)
|
|
295
|
-
if index <= 0:
|
|
296
|
-
self.prepend(value)
|
|
297
|
-
return
|
|
298
|
-
if index >= list_len:
|
|
299
|
-
self.push(value)
|
|
300
|
-
return
|
|
301
|
-
|
|
302
|
-
# Midpoint insertion
|
|
303
|
-
order_before = self._get_order_at_index(index - 1)
|
|
304
|
-
order_after = self._get_order_at_index(index)
|
|
305
|
-
|
|
306
|
-
new_order = order_before + (order_after - order_before) / 2.0
|
|
307
|
-
|
|
308
|
-
with self._conn:
|
|
309
|
-
self._conn.execute(
|
|
310
|
-
"INSERT INTO beaver_lists (list_name, item_order, item_value) VALUES (?, ?, ?)",
|
|
311
|
-
(self._name, new_order, json.dumps(value))
|
|
312
|
-
)
|
|
313
|
-
|
|
314
|
-
def pop(self) -> Any:
|
|
315
|
-
"""Removes and returns the last item from the list."""
|
|
316
|
-
with self._conn:
|
|
317
|
-
cursor = self._conn.cursor()
|
|
318
|
-
cursor.execute(
|
|
319
|
-
"SELECT rowid, item_value FROM beaver_lists WHERE list_name = ? ORDER BY item_order DESC LIMIT 1",
|
|
320
|
-
(self._name,)
|
|
321
|
-
)
|
|
322
|
-
result = cursor.fetchone()
|
|
323
|
-
if not result:
|
|
324
|
-
return None
|
|
325
|
-
|
|
326
|
-
rowid_to_delete, value_to_return = result
|
|
327
|
-
cursor.execute("DELETE FROM beaver_lists WHERE rowid = ?", (rowid_to_delete,))
|
|
328
|
-
return json.loads(value_to_return)
|
|
329
|
-
|
|
330
|
-
def deque(self) -> Any:
|
|
331
|
-
"""Removes and returns the first item from the list."""
|
|
332
|
-
with self._conn:
|
|
333
|
-
cursor = self._conn.cursor()
|
|
334
|
-
cursor.execute(
|
|
335
|
-
"SELECT rowid, item_value FROM beaver_lists WHERE list_name = ? ORDER BY item_order ASC LIMIT 1",
|
|
336
|
-
(self._name,)
|
|
337
|
-
)
|
|
338
|
-
result = cursor.fetchone()
|
|
339
|
-
if not result:
|
|
340
|
-
return None
|
|
341
|
-
|
|
342
|
-
rowid_to_delete, value_to_return = result
|
|
343
|
-
cursor.execute("DELETE FROM beaver_lists WHERE rowid = ?", (rowid_to_delete,))
|
|
344
|
-
return json.loads(value_to_return)
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
class Subscriber(AsyncIterator):
|
|
348
|
-
"""
|
|
349
|
-
An async iterator that polls a channel for new messages.
|
|
350
|
-
Designed to be used with 'async with'.
|
|
351
|
-
"""
|
|
352
|
-
|
|
353
|
-
def __init__(self, conn: sqlite3.Connection, channel_name: str, poll_interval: float = 0.1):
|
|
354
|
-
self._conn = conn
|
|
355
|
-
self._channel = channel_name
|
|
356
|
-
self._poll_interval = poll_interval
|
|
357
|
-
self._queue = asyncio.Queue()
|
|
358
|
-
self._last_seen_timestamp = time.time()
|
|
359
|
-
self._polling_task = None
|
|
360
|
-
|
|
361
|
-
async def _poll_for_messages(self):
|
|
362
|
-
"""Background task that polls the database for new messages."""
|
|
363
|
-
while True:
|
|
364
|
-
try:
|
|
365
|
-
new_messages = await asyncio.to_thread(
|
|
366
|
-
self._fetch_new_messages_from_db
|
|
367
|
-
)
|
|
368
|
-
if new_messages:
|
|
369
|
-
for msg in new_messages:
|
|
370
|
-
payload = json.loads(msg["message_payload"])
|
|
371
|
-
await self._queue.put(payload)
|
|
372
|
-
self._last_seen_timestamp = msg["timestamp"]
|
|
373
|
-
await asyncio.sleep(self._poll_interval)
|
|
374
|
-
except asyncio.CancelledError:
|
|
375
|
-
break
|
|
376
|
-
except Exception:
|
|
377
|
-
# In a real app, add more robust error logging
|
|
378
|
-
await asyncio.sleep(self._poll_interval * 5)
|
|
379
|
-
|
|
380
|
-
def _fetch_new_messages_from_db(self) -> list:
|
|
381
|
-
"""The actual synchronous database query."""
|
|
382
|
-
cursor = self._conn.cursor()
|
|
383
|
-
cursor.execute(
|
|
384
|
-
"SELECT timestamp, message_payload FROM beaver_pubsub_log WHERE channel_name = ? AND timestamp > ? ORDER BY timestamp ASC",
|
|
385
|
-
(self._channel, self._last_seen_timestamp)
|
|
386
|
-
)
|
|
387
|
-
results = cursor.fetchall()
|
|
388
|
-
cursor.close()
|
|
389
|
-
return results
|
|
390
|
-
|
|
391
|
-
async def __aenter__(self):
|
|
392
|
-
"""Starts the background task."""
|
|
393
|
-
self._polling_task = asyncio.create_task(self._poll_for_messages())
|
|
394
|
-
return self
|
|
395
|
-
|
|
396
|
-
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
397
|
-
"""Stops the background task."""
|
|
398
|
-
if self._polling_task:
|
|
399
|
-
self._polling_task.cancel()
|
|
400
|
-
await asyncio.gather(self._polling_task, return_exceptions=True)
|
|
401
|
-
|
|
402
|
-
def __aiter__(self):
|
|
403
|
-
return self
|
|
404
|
-
|
|
405
|
-
async def __anext__(self) -> Any:
|
|
406
|
-
"""Allows 'async for' to pull messages from the internal queue."""
|
|
407
|
-
return await self._queue.get()
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
class Document:
|
|
411
|
-
"""A data class for a vector and its metadata, with a unique ID."""
|
|
412
|
-
def __init__(self, embedding: list[float], id: str|None = None, **metadata):
|
|
413
|
-
if not isinstance(embedding, list) or not all(isinstance(x, (int, float)) for x in embedding):
|
|
414
|
-
raise TypeError("Embedding must be a list of numbers.")
|
|
415
|
-
|
|
416
|
-
self.id = id or str(uuid.uuid4())
|
|
417
|
-
self.embedding = np.array(embedding, dtype=np.float32)
|
|
418
|
-
|
|
419
|
-
for key, value in metadata.items():
|
|
420
|
-
setattr(self, key, value)
|
|
421
|
-
|
|
422
|
-
def to_dict(self) -> dict[str, Any]:
|
|
423
|
-
"""Serializes metadata to a dictionary."""
|
|
424
|
-
metadata = self.__dict__.copy()
|
|
425
|
-
# Exclude internal attributes from the metadata payload
|
|
426
|
-
metadata.pop('embedding', None)
|
|
427
|
-
metadata.pop('id', None)
|
|
428
|
-
return metadata
|
|
429
|
-
|
|
430
|
-
def __repr__(self):
|
|
431
|
-
metadata_str = ', '.join(f"{k}={v!r}" for k, v in self.to_dict().items())
|
|
432
|
-
return f"Document(id='{self.id}', {metadata_str})"
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
class CollectionWrapper:
|
|
436
|
-
"""A wrapper for vector collection operations with upsert logic."""
|
|
437
|
-
def __init__(self, name: str, conn: sqlite3.Connection):
|
|
438
|
-
self._name = name
|
|
439
|
-
self._conn = conn
|
|
440
|
-
|
|
441
|
-
def index(self, document: Document):
|
|
442
|
-
"""
|
|
443
|
-
Indexes a Document, performing an upsert based on the document's ID.
|
|
444
|
-
If the ID exists, the record is replaced.
|
|
445
|
-
If the ID is new (or auto-generated), a new record is inserted.
|
|
446
|
-
|
|
447
|
-
Args:
|
|
448
|
-
document: The Document object to index.
|
|
449
|
-
"""
|
|
450
|
-
with self._conn:
|
|
451
|
-
self._conn.execute(
|
|
452
|
-
"INSERT OR REPLACE INTO beaver_collections (collection, item_id, item_vector, metadata) VALUES (?, ?, ?, ?)",
|
|
453
|
-
(
|
|
454
|
-
self._name,
|
|
455
|
-
document.id,
|
|
456
|
-
document.embedding.tobytes(),
|
|
457
|
-
json.dumps(document.to_dict())
|
|
458
|
-
)
|
|
459
|
-
)
|
|
460
|
-
|
|
461
|
-
def search(self, vector: list[float], top_k: int = 10) -> list[tuple[Document, float]]:
|
|
462
|
-
"""
|
|
463
|
-
Performs a vector search and returns Document objects.
|
|
464
|
-
"""
|
|
465
|
-
query_vector = np.array(vector, dtype=np.float32)
|
|
466
|
-
|
|
467
|
-
cursor = self._conn.cursor()
|
|
468
|
-
cursor.execute(
|
|
469
|
-
"SELECT item_id, item_vector, metadata FROM beaver_collections WHERE collection = ?",
|
|
470
|
-
(self._name,)
|
|
471
|
-
)
|
|
472
|
-
|
|
473
|
-
all_docs_data = cursor.fetchall()
|
|
474
|
-
cursor.close()
|
|
475
|
-
|
|
476
|
-
if not all_docs_data:
|
|
477
|
-
return []
|
|
478
|
-
|
|
479
|
-
results = []
|
|
480
|
-
for row in all_docs_data:
|
|
481
|
-
doc_id = row['item_id']
|
|
482
|
-
embedding = np.frombuffer(row['item_vector'], dtype=np.float32).tolist()
|
|
483
|
-
metadata = json.loads(row['metadata'])
|
|
484
|
-
|
|
485
|
-
distance = np.linalg.norm(embedding - query_vector)
|
|
486
|
-
|
|
487
|
-
# Reconstruct the Document object with its original ID
|
|
488
|
-
doc = Document(id=doc_id, embedding=list(embedding), **metadata)
|
|
489
|
-
results.append((doc, float(distance)))
|
|
490
209
|
|
|
491
|
-
|
|
492
|
-
|
|
210
|
+
def subscribe(self, channel_name: str) -> SubWrapper:
|
|
211
|
+
"""Subscribes to a channel, returning a synchronous iterator."""
|
|
212
|
+
return SubWrapper(self._conn, channel_name)
|