beaver-db 2.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
beaver/server.py ADDED
@@ -0,0 +1,452 @@
1
+ try:
2
+ from typing import Any, Optional, List
3
+ import json
4
+ from datetime import datetime, timedelta, timezone
5
+ from fastapi import (
6
+ FastAPI,
7
+ HTTPException,
8
+ Body,
9
+ UploadFile,
10
+ File,
11
+ Form,
12
+ Response,
13
+ WebSocket,
14
+ WebSocketDisconnect,
15
+ )
16
+ import uvicorn
17
+ from pydantic import BaseModel, Field
18
+ except ImportError:
19
+ raise ImportError(
20
+ 'Please install server dependencies with: pip install "beaver-db[server]"'
21
+ )
22
+
23
+ from .core import BeaverDB
24
+ from .collections import Document, WalkDirection
25
+
26
+
27
+ # --- Pydantic Models for Collections ---
28
+
29
+
30
class IndexRequest(BaseModel):
    # Request body for POST /collections/{name}/index.
    # `id` and `embedding` are handed to Document(...) as-is; `metadata` is
    # expanded into additional Document keyword arguments by the endpoint.
    id: Optional[str] = Field(default=None)
    embedding: Optional[List[float]] = Field(default=None)
    metadata: dict = Field(default_factory=dict)
    # Flags forwarded to collection.index(doc, fts=..., fuzzy=...).
    fts: bool = Field(default=True)
    fuzzy: bool = Field(default=False)
36
+
37
+
38
class SearchRequest(BaseModel):
    # Request body for POST /collections/{name}/search.
    # Both fields are forwarded unchanged to collection.search(vector=, top_k=).
    vector: List[float] = Field(...)
    top_k: int = Field(default=10)
41
+
42
+
43
class MatchRequest(BaseModel):
    # Request body for POST /collections/{name}/match.
    # All four fields are forwarded unchanged to collection.match(...).
    query: str = Field(...)
    on: Optional[List[str]] = Field(default=None)
    top_k: int = Field(default=10)
    fuzziness: int = Field(default=0)
48
+
49
+
50
class ConnectRequest(BaseModel):
    # Request body for POST /collections/{name}/connect.
    # The endpoint wraps source_id/target_id in bare Document(id=...) objects
    # and passes label/metadata straight to collection.connect(...).
    source_id: str = Field(...)
    target_id: str = Field(...)
    label: str = Field(...)
    metadata: Optional[dict] = Field(default=None)
55
+
56
+
57
class WalkRequest(BaseModel):
    # Request body for POST /collections/{name}/{doc_id}/walk.
    # labels and depth are forwarded to collection.walk(...); direction
    # defaults to an outgoing traversal when the client omits it.
    labels: List[str] = Field(...)
    depth: int = Field(...)
    direction: WalkDirection = Field(default=WalkDirection.OUTGOING)
61
+
62
+
63
class CountResponse(BaseModel):
    # Response schema shared by every ".../count" endpoint: {"count": <int>}.
    count: int = Field(...)
65
+
66
+
67
def build(db: BeaverDB) -> FastAPI:
    """Constructs a FastAPI instance for a given BeaverDB.

    Every endpoint is a closure over ``db`` and looks up the named structure
    (dict, list, queue, blob store, log, channel or collection) on each
    request.

    Route registration order matters: routes are matched in the order they
    are added, so fixed path segments such as ``/count`` are registered
    before sibling routes whose last segment is a path parameter.

    Args:
        db: The database instance the endpoints operate on.

    Returns:
        The FastAPI application with all REST and WebSocket routes registered.
    """
    app = FastAPI(title="BeaverDB Server")

    # --- Dicts Endpoints ---

    # Must be registered before GET "/dicts/{name}/{key}": that route's
    # catch-all "{key}" segment would otherwise capture "count" and this
    # endpoint would never be reached. (Consequence: a dictionary key literally
    # named "count" cannot be read through GET; PUT/DELETE are unaffected.)
    @app.get("/dicts/{name}/count", tags=["Dicts"], response_model=CountResponse)
    def get_dict_count(name: str) -> dict:
        """Retrieves the number of key-value pairs in the dictionary."""
        d = db.dict(name)
        return {"count": len(d)}

    @app.get("/dicts/{name}/{key}", tags=["Dicts"])
    def get_dict_item(name: str, key: str) -> Any:
        """Retrieves the value for a specific key."""
        d = db.dict(name)
        value = d.get(key)
        # NOTE(review): a None result is treated as "missing"; presumably a
        # key stored with a null value would also yield 404 - confirm against
        # the dict manager's get() semantics.
        if value is None:
            raise HTTPException(
                status_code=404, detail=f"Key '{key}' not found in dictionary '{name}'"
            )
        return value

    @app.put("/dicts/{name}/{key}", tags=["Dicts"])
    def set_dict_item(name: str, key: str, value: Any = Body(...)):
        """Sets or updates the value for a specific key."""
        d = db.dict(name)
        d[key] = value
        return {"status": "ok"}

    @app.delete("/dicts/{name}/{key}", tags=["Dicts"])
    def delete_dict_item(name: str, key: str):
        """Deletes a key-value pair."""
        d = db.dict(name)
        try:
            del d[key]
            return {"status": "ok"}
        except KeyError:
            raise HTTPException(
                status_code=404, detail=f"Key '{key}' not found in dictionary '{name}'"
            )

    # --- Lists Endpoints ---

    @app.get("/lists/{name}", tags=["Lists"])
    def get_list(name: str) -> list:
        """Retrieves all items in the list."""
        l = db.list(name)
        return l[:]

    # Must be registered before GET "/lists/{name}/{index}": otherwise
    # "count" is matched as the index segment, fails integer validation, and
    # the client receives a 422 instead of the count.
    @app.get("/lists/{name}/count", tags=["Lists"], response_model=CountResponse)
    def get_list_count(name: str) -> dict:
        """Retrieves the number of items in the list."""
        l = db.list(name)
        return {"count": len(l)}

    @app.get("/lists/{name}/{index}", tags=["Lists"])
    def get_list_item(name: str, index: int) -> Any:
        """Retrieves the item at a specific index."""
        l = db.list(name)
        try:
            return l[index]
        except IndexError:
            raise HTTPException(
                status_code=404, detail=f"Index {index} out of bounds for list '{name}'"
            )

    @app.post("/lists/{name}", tags=["Lists"])
    def push_list_item(name: str, value: Any = Body(...)):
        """Adds an item to the end of the list."""
        l = db.list(name)
        l.push(value)
        return {"status": "ok"}

    @app.put("/lists/{name}/{index}", tags=["Lists"])
    def update_list_item(name: str, index: int, value: Any = Body(...)):
        """Updates the item at a specific index."""
        l = db.list(name)
        try:
            l[index] = value
            return {"status": "ok"}
        except IndexError:
            raise HTTPException(
                status_code=404, detail=f"Index {index} out of bounds for list '{name}'"
            )

    @app.delete("/lists/{name}/{index}", tags=["Lists"])
    def delete_list_item(name: str, index: int):
        """Deletes the item at a specific index."""
        l = db.list(name)
        try:
            del l[index]
            return {"status": "ok"}
        except IndexError:
            raise HTTPException(
                status_code=404, detail=f"Index {index} out of bounds for list '{name}'"
            )

    # --- Queues Endpoints ---

    @app.get("/queues/{name}/peek", tags=["Queues"])
    def peek_queue_item(name: str) -> Any:
        """Retrieves the highest-priority item from the queue without removing it."""
        q = db.queue(name)
        item = q.peek()
        if item is None:
            raise HTTPException(status_code=404, detail=f"Queue '{name}' is empty")
        return item

    @app.post("/queues/{name}/put", tags=["Queues"])
    def put_queue_item(name: str, data: Any = Body(...), priority: float = Body(...)):
        """Adds an item to the queue with a specific priority."""
        q = db.queue(name)
        q.put(data=data, priority=priority)
        return {"status": "ok"}

    @app.delete("/queues/{name}/get", tags=["Queues"])
    def get_queue_item(name: str, timeout: float = 5.0) -> Any:
        """
        Atomically retrieves and removes the highest-priority item from the queue,
        blocking until an item is available or the timeout is reached.
        """
        q = db.queue(name)
        try:
            item = q.get(block=True, timeout=timeout)
            return item
        except TimeoutError:
            raise HTTPException(
                status_code=408,
                detail=f"Request timed out after {timeout}s waiting for an item in queue '{name}'",
            )
        except IndexError:
            # This case is less likely with block=True but good to handle
            raise HTTPException(status_code=404, detail=f"Queue '{name}' is empty")

    @app.get("/queues/{name}/count", tags=["Queues"], response_model=CountResponse)
    def get_queue_count(name: str) -> dict:
        """Retrieves the number of items currently in the queue."""
        q = db.queue(name)
        return {"count": len(q)}

    # --- Blobs Endpoints ---

    # Must be registered before GET "/blobs/{name}/{key}" so that "count" is
    # not captured as a blob key. (Consequence: a blob stored under the key
    # "count" cannot be fetched through GET; PUT/DELETE are unaffected.)
    @app.get("/blobs/{name}/count", tags=["Blobs"], response_model=CountResponse)
    def get_blob_count(name: str) -> dict:
        """Retrieves the number of blobs in the store."""
        b = db.blob(name)
        return {"count": len(b)}

    @app.get("/blobs/{name}/{key}", response_class=Response, tags=["Blobs"])
    def get_blob(name: str, key: str):
        """Retrieves a blob as a binary file."""
        blobs = db.blob(name)
        blob = blobs.get(key)
        if blob is None:
            raise HTTPException(
                status_code=404,
                detail=f"Blob with key '{key}' not found in store '{name}'",
            )
        # Return the raw bytes with a generic binary content type
        return Response(content=blob.data, media_type="application/octet-stream")

    @app.put("/blobs/{name}/{key}", tags=["Blobs"])
    async def put_blob(
        name: str,
        key: str,
        data: UploadFile = File(...),
        metadata: Optional[str] = Form(None),
    ):
        """Stores a blob (binary file) with optional JSON metadata."""
        blobs = db.blob(name)
        file_bytes = await data.read()

        # Metadata arrives as a form field, so it is a JSON-encoded string
        # that has to be decoded before being stored alongside the bytes.
        meta_dict = None
        if metadata:
            try:
                meta_dict = json.loads(metadata)
            except json.JSONDecodeError:
                raise HTTPException(
                    status_code=400, detail="Invalid JSON format for metadata."
                )

        blobs.put(key=key, data=file_bytes, metadata=meta_dict)
        return {"status": "ok"}

    @app.delete("/blobs/{name}/{key}", tags=["Blobs"])
    def delete_blob(name: str, key: str):
        """Deletes a blob from the store."""
        blobs = db.blob(name)
        try:
            blobs.delete(key)
            return {"status": "ok"}
        except KeyError:
            raise HTTPException(
                status_code=404,
                detail=f"Blob with key '{key}' not found in store '{name}'",
            )

    # --- Logs Endpoints ---

    @app.post("/logs/{name}", tags=["Logs"])
    def create_log_entry(name: str, data: Any = Body(...)):
        """Adds a new entry to the log."""
        log = db.log(name)
        log.log(data)
        return {"status": "ok"}

    @app.get("/logs/{name}/range", tags=["Logs"])
    def get_log_range(name: str, start: datetime, end: datetime) -> list:
        """Retrieves log entries within a specific time window."""
        log = db.log(name)
        # Ensure datetimes are timezone-aware (UTC) for correct comparison:
        # aware values are converted to UTC, naive values are labelled as UTC.
        start_utc = (
            start.astimezone(timezone.utc)
            if start.tzinfo
            else start.replace(tzinfo=timezone.utc)
        )
        end_utc = (
            end.astimezone(timezone.utc)
            if end.tzinfo
            else end.replace(tzinfo=timezone.utc)
        )
        return log.range(start=start_utc, end=end_utc)

    @app.websocket("/logs/{name}/live", name="Logs")
    async def live_log_feed(
        websocket: WebSocket,
        name: str,
        window_seconds: int = 5,
        period_seconds: int = 1,
    ):
        """Streams live, aggregated log data over a WebSocket."""
        await websocket.accept()

        async_logs = db.log(name).as_async()

        # This simple aggregator function runs in the background and returns a
        # JSON-serializable summary of the data in the current window.
        def simple_aggregator(window):
            return {
                "count": len(window),
                "latest_timestamp": window[-1]["timestamp"] if window else None,
            }

        live_stream = async_logs.live(
            window=timedelta(seconds=window_seconds),
            period=timedelta(seconds=period_seconds),
            aggregator=simple_aggregator,
        )

        try:
            async for summary in live_stream:
                await websocket.send_json(summary)
        except WebSocketDisconnect:
            print(f"Client disconnected from log '{name}' live feed.")
        finally:
            # Cleanly close the underlying iterator and its background thread.
            live_stream.close()

    # --- Channels Endpoints ---

    @app.post("/channels/{name}/publish", tags=["Channels"])
    def publish_to_channel(name: str, payload: Any = Body(...)):
        """Publishes a message to the specified channel."""
        channel = db.channel(name)
        channel.publish(payload)
        return {"status": "ok"}

    @app.websocket("/channels/{name}/subscribe", name="Channels")
    async def subscribe_to_channel(websocket: WebSocket, name: str):
        """Subscribes to a channel and streams messages over a WebSocket."""
        await websocket.accept()

        async_channel = db.channel(name).as_async()

        try:
            # The subscription is scoped by the context manager, so it is
            # released when the client disconnects and the loop unwinds.
            async with async_channel.subscribe() as listener:
                async for message in listener.listen():
                    await websocket.send_json(message)
        except WebSocketDisconnect:
            print(f"Client disconnected from channel '{name}' subscription.")

    # --- Collections Endpoints ---

    @app.get("/collections/{name}", tags=["Collections"])
    def get_all_documents(name: str) -> List[dict]:
        """Retrieves all documents in the collection."""
        collection = db.collection(name)
        return [doc.to_dict(metadata_only=False) for doc in collection]

    @app.post("/collections/{name}/index", tags=["Collections"])
    def index_document(name: str, req: IndexRequest):
        """Indexes a document in the specified collection."""
        collection = db.collection(name)
        # Metadata is expanded into keyword arguments next to the explicit
        # id/embedding arguments; a metadata key with either name would raise
        # a duplicate-keyword TypeError and surface as a 500, so reject it as
        # a client error instead.
        reserved = sorted({"id", "embedding"} & req.metadata.keys())
        if reserved:
            raise HTTPException(
                status_code=400,
                detail=f"Metadata must not contain reserved keys: {reserved}",
            )
        doc = Document(id=req.id, embedding=req.embedding, **req.metadata)
        try:
            collection.index(doc, fts=req.fts, fuzzy=req.fuzzy)
            return {"status": "ok", "id": doc.id}
        except TypeError as e:
            # A TypeError mentioning "vector" is interpreted as the optional
            # vector backend being unavailable; anything else is re-raised.
            if "vector" in str(e):
                raise HTTPException(
                    status_code=501,
                    detail="Vector indexing requires the '[vector]' extra. Install with: pip install \"beaver-db[vector]\"",
                )
            raise

    @app.post("/collections/{name}/search", tags=["Collections"])
    def search_collection(name: str, req: SearchRequest) -> List[dict]:
        """Performs a vector search on the collection."""
        collection = db.collection(name)
        try:
            results = collection.search(vector=req.vector, top_k=req.top_k)
            return [
                {"document": doc.to_dict(metadata_only=False), "distance": dist}
                for doc, dist in results
            ]
        except TypeError as e:
            # Same convention as index_document: "vector" in the message is
            # taken to mean the optional vector backend is unavailable.
            if "vector" in str(e):
                raise HTTPException(
                    status_code=501,
                    detail="Vector search requires the '[vector]' extra. Install with: pip install \"beaver-db[vector]\"",
                )
            raise

    @app.post("/collections/{name}/match", tags=["Collections"])
    def match_collection(name: str, req: MatchRequest) -> List[dict]:
        """Performs a full-text or fuzzy search on the collection."""
        collection = db.collection(name)
        results = collection.match(
            query=req.query, on=req.on, top_k=req.top_k, fuzziness=req.fuzziness
        )
        return [
            {"document": doc.to_dict(metadata_only=False), "score": score}
            for doc, score in results
        ]

    @app.post("/collections/{name}/connect", tags=["Collections"])
    def connect_documents(name: str, req: ConnectRequest):
        """Creates a directed edge between two documents."""
        collection = db.collection(name)
        # Only the ids are known here, so bare Document objects act as handles.
        source_doc = Document(id=req.source_id)
        target_doc = Document(id=req.target_id)
        collection.connect(
            source=source_doc, target=target_doc, label=req.label, metadata=req.metadata
        )
        return {"status": "ok"}

    @app.get("/collections/{name}/{doc_id}/neighbors", tags=["Collections"])
    def get_neighbors(
        name: str, doc_id: str, label: Optional[str] = None
    ) -> List[dict]:
        """Retrieves the neighboring documents for a given document."""
        collection = db.collection(name)
        doc = Document(id=doc_id)
        neighbors = collection.neighbors(doc, label=label)
        return [n.to_dict(metadata_only=False) for n in neighbors]

    @app.post("/collections/{name}/{doc_id}/walk", tags=["Collections"])
    def walk_graph(name: str, doc_id: str, req: WalkRequest) -> List[dict]:
        """Performs a graph traversal (BFS) from a starting document."""
        collection = db.collection(name)
        source_doc = Document(id=doc_id)
        # NOTE(review): a WalkDirection value is passed under the keyword
        # "outgoing"; confirm this matches the parameter name and type that
        # collection.walk() actually accepts.
        results = collection.walk(
            source=source_doc,
            labels=req.labels,
            depth=req.depth,
            outgoing=req.direction,
        )
        return [doc.to_dict(metadata_only=False) for doc in results]

    # No reordering needed here: the other GET routes under
    # "/collections/{name}/..." have a different number of path segments.
    @app.get(
        "/collections/{name}/count", tags=["Collections"], response_model=CountResponse
    )
    def get_collection_count(name: str) -> dict:
        """Retrieves the number of documents in the collection."""
        c = db.collection(name)
        return {"count": len(c)}

    return app
446
+
447
+
448
def serve(db_path: str, host: str, port: int):
    """Initializes and runs the Uvicorn server.

    Opens the database at ``db_path``, builds the FastAPI application around
    it and hands that application to ``uvicorn.run`` on ``host``/``port``.
    """
    application = build(BeaverDB(db_path))
    uvicorn.run(application, host=host, port=port)