surrealdb-orm 0.1.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. surreal_orm/__init__.py +78 -3
  2. surreal_orm/aggregations.py +164 -0
  3. surreal_orm/auth/__init__.py +15 -0
  4. surreal_orm/auth/access.py +167 -0
  5. surreal_orm/auth/mixins.py +302 -0
  6. surreal_orm/cli/__init__.py +15 -0
  7. surreal_orm/cli/commands.py +369 -0
  8. surreal_orm/connection_manager.py +58 -18
  9. surreal_orm/fields/__init__.py +36 -0
  10. surreal_orm/fields/encrypted.py +166 -0
  11. surreal_orm/fields/relation.py +465 -0
  12. surreal_orm/migrations/__init__.py +51 -0
  13. surreal_orm/migrations/executor.py +380 -0
  14. surreal_orm/migrations/generator.py +272 -0
  15. surreal_orm/migrations/introspector.py +305 -0
  16. surreal_orm/migrations/migration.py +188 -0
  17. surreal_orm/migrations/operations.py +531 -0
  18. surreal_orm/migrations/state.py +406 -0
  19. surreal_orm/model_base.py +594 -135
  20. surreal_orm/py.typed +0 -0
  21. surreal_orm/query_set.py +609 -34
  22. surreal_orm/relations.py +645 -0
  23. surreal_orm/surreal_function.py +95 -0
  24. surreal_orm/surreal_ql.py +113 -0
  25. surreal_orm/types.py +86 -0
  26. surreal_sdk/README.md +79 -0
  27. surreal_sdk/__init__.py +151 -0
  28. surreal_sdk/connection/__init__.py +17 -0
  29. surreal_sdk/connection/base.py +516 -0
  30. surreal_sdk/connection/http.py +421 -0
  31. surreal_sdk/connection/pool.py +244 -0
  32. surreal_sdk/connection/websocket.py +519 -0
  33. surreal_sdk/exceptions.py +71 -0
  34. surreal_sdk/functions.py +607 -0
  35. surreal_sdk/protocol/__init__.py +13 -0
  36. surreal_sdk/protocol/rpc.py +218 -0
  37. surreal_sdk/py.typed +0 -0
  38. surreal_sdk/pyproject.toml +49 -0
  39. surreal_sdk/streaming/__init__.py +31 -0
  40. surreal_sdk/streaming/change_feed.py +278 -0
  41. surreal_sdk/streaming/live_query.py +265 -0
  42. surreal_sdk/streaming/live_select.py +369 -0
  43. surreal_sdk/transaction.py +386 -0
  44. surreal_sdk/types.py +346 -0
  45. surrealdb_orm-0.5.0.dist-info/METADATA +465 -0
  46. surrealdb_orm-0.5.0.dist-info/RECORD +52 -0
  47. {surrealdb_orm-0.1.3.dist-info → surrealdb_orm-0.5.0.dist-info}/WHEEL +1 -1
  48. surrealdb_orm-0.5.0.dist-info/entry_points.txt +2 -0
  49. {surrealdb_orm-0.1.3.dist-info → surrealdb_orm-0.5.0.dist-info}/licenses/LICENSE +1 -1
  50. surrealdb_orm-0.1.3.dist-info/METADATA +0 -184
  51. surrealdb_orm-0.1.3.dist-info/RECORD +0 -11
@@ -0,0 +1,519 @@
1
+ """
2
+ WebSocket Connection Implementation for SurrealDB SDK.
3
+
4
+ Provides stateful WebSocket-based connection for real-time features.
5
+ """
6
+
7
+ from typing import TYPE_CHECKING, Any, Callable, Coroutine
8
+ import asyncio
9
+ import json
10
+
11
+ import aiohttp
12
+ from aiohttp import ClientWSTimeout
13
+
14
+ from .base import BaseSurrealConnection
15
+
16
+ if TYPE_CHECKING:
17
+ from ..transaction import WebSocketTransaction
18
+ from ..streaming.live_select import LiveSelectStream, LiveSubscriptionParams
19
+ from ..protocol.rpc import RPCRequest, RPCResponse
20
+ from ..exceptions import ConnectionError, LiveQueryError, TimeoutError
21
+
22
+
23
# Signature of a live-query notification handler: an async callable that
# receives each notification message (a dict) and returns nothing.
LiveCallback = Callable[
    [dict[str, Any]],
    Coroutine[Any, Any, None],
]
25
+
26
+
27
class WebSocketConnection(BaseSurrealConnection):
    """
    WebSocket-based connection to SurrealDB.

    This connection is stateful - session is maintained across requests.
    Required for Live Queries and session variables.
    Ideal for real-time applications, dashboards, and collaborative features.

    Disconnect handling: when the socket closes without ``close()`` having been
    called, in-flight requests are failed immediately (``ConnectionError``),
    the dead socket/session are released and, if ``auto_reconnect`` is enabled,
    a background task re-opens the socket, re-authenticates with the stored
    token, re-selects the namespace/database and re-subscribes registered live
    selects.
    """

    def __init__(
        self,
        url: str,
        namespace: str,
        database: str,
        timeout: float = 30.0,
        auto_reconnect: bool = True,
        reconnect_interval: float = 1.0,
        max_reconnect_attempts: int = 5,
    ):
        """
        Initialize WebSocket connection.

        Args:
            url: SurrealDB WebSocket URL (e.g., "ws://localhost:8000")
            namespace: Target namespace
            database: Target database
            timeout: Request timeout in seconds
            auto_reconnect: Whether to automatically reconnect on disconnect
            reconnect_interval: Seconds between reconnection attempts
            max_reconnect_attempts: Maximum reconnection attempts
        """
        # Normalize URL to WebSocket
        if url.startswith("http://"):
            url = url.replace("http://", "ws://", 1)
        elif url.startswith("https://"):
            url = url.replace("https://", "wss://", 1)

        # Ensure /rpc suffix
        if not url.endswith("/rpc"):
            url = url.rstrip("/") + "/rpc"

        super().__init__(url, namespace, database, timeout)

        self.auto_reconnect = auto_reconnect
        self.reconnect_interval = reconnect_interval
        self.max_reconnect_attempts = max_reconnect_attempts

        self._session: aiohttp.ClientSession | None = None
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._request_id = 0
        # Request id -> future resolved by the reader loop when the response arrives
        self._pending: dict[int, asyncio.Future[RPCResponse]] = {}
        # Live query id -> async callback fed with raw notification messages
        self._live_callbacks: dict[str, LiveCallback] = {}
        # Live query id -> parameters used to re-subscribe after a reconnect
        self._live_subscriptions: dict[str, "LiveSubscriptionParams"] = {}
        self._reader_task: asyncio.Task[None] | None = None
        self._reconnect_task: asyncio.Task[None] | None = None
        # Strong references to fire-and-forget tasks (callbacks); the event
        # loop only keeps weak references to tasks, so unreferenced tasks can
        # be garbage-collected mid-execution.
        self._background_tasks: set[asyncio.Task[Any]] = set()
        self._closing = False

    def _next_request_id(self) -> int:
        """Generate next request ID."""
        self._request_id += 1
        return self._request_id

    def _spawn(self, coro: Coroutine[Any, Any, Any]) -> "asyncio.Task[Any]":
        """Run ``coro`` as a background task, keeping a reference until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    # Connection lifecycle

    async def connect(self) -> None:
        """
        Establish WebSocket connection.

        Opens the socket, starts the reader task and selects the configured
        namespace/database. No-op if already connected. Any partially opened
        resources are released if connecting fails.

        Raises:
            ConnectionError: If the WebSocket handshake fails with an aiohttp
                client error. Other errors (e.g. raised by ``use()``) are
                re-raised unchanged after cleanup.
        """
        if self._connected:
            return

        self._closing = False

        try:
            await self._open_transport()
            # Set namespace and database
            await self.use(self.namespace, self.database)
        except aiohttp.ClientError as e:
            await self._cleanup()
            raise ConnectionError(f"WebSocket connection failed: {e}") from e
        except Exception:
            # Don't leave an open session or a stale "connected" flag behind
            await self._cleanup()
            raise

    async def close(self) -> None:
        """Close WebSocket connection."""
        self._closing = True
        await self._cleanup()

    async def _open_transport(self) -> None:
        """Open a new aiohttp session and WebSocket, mark connected, start the reader."""
        self._session = aiohttp.ClientSession()
        # Specify protocol format explicitly to avoid deprecation warning
        # SurrealDB 2.0+ requires explicit format specification (json or cbor)
        self._ws = await self._session.ws_connect(
            self.url,
            timeout=ClientWSTimeout(ws_close=self.timeout),
            protocols=["json"],  # Explicit JSON format for RPC
        )
        self._connected = True

        # Start message reader loop
        self._reader_task = asyncio.create_task(self._read_loop())

        # Yield to event loop to allow read loop to start
        await asyncio.sleep(0)

    async def _release_transport(self) -> None:
        """Close and forget the current WebSocket and session (if any)."""
        # Detach first so concurrent code never sees a half-closed transport
        # and a transport opened meanwhile is not closed by mistake.
        ws, self._ws = self._ws, None
        session, self._session = self._session, None
        try:
            if ws is not None:
                await ws.close()
        finally:
            if session is not None:
                await session.close()

    @staticmethod
    async def _cancel_and_wait(task: "asyncio.Task[Any] | None") -> None:
        """Cancel a still-running task and wait for it to finish."""
        if task is None or task.done():
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    def _fail_pending(self, message: str) -> None:
        """Fail every in-flight request with ``ConnectionError(message)``."""
        pending = list(self._pending.values())
        self._pending.clear()
        for future in pending:
            if not future.done():
                future.set_exception(ConnectionError(message))

    async def _cleanup(self) -> None:
        """
        Clean up connection resources.

        Stops the reader and reconnect tasks, fails pending requests with
        ``ConnectionError("Connection closed")`` and closes the socket and
        session. It never cancels the task it is running in, so it is safe to
        call from inside the reconnect loop.
        """
        self._connected = False
        self._authenticated = False

        current = asyncio.current_task()

        # Cancel reader task
        if self._reader_task is not current:
            await self._cancel_and_wait(self._reader_task)
            self._reader_task = None

        # Cancel reconnect task. When we *are* the reconnect task, keep the
        # reference so close() can still stop the loop later.
        if self._reconnect_task is not current:
            await self._cancel_and_wait(self._reconnect_task)
            self._reconnect_task = None

        # Fail all pending requests
        self._fail_pending("Connection closed")

        # Close WebSocket and session
        await self._release_transport()

    async def _read_loop(self) -> None:
        """
        Background task: dispatch incoming WebSocket frames until the socket closes.

        If the loop ends on its own (socket closed, transport error) rather
        than being cancelled by ``_cleanup()``, and ``close()`` has not been
        requested, the disconnect is handed to ``_handle_connection_lost()``.
        """
        ws = self._ws
        if ws is None:
            return

        cancelled = False
        try:
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    await self._handle_message(msg.data)
                elif msg.type == aiohttp.WSMsgType.BINARY:
                    # CBOR support could be added here
                    continue
                elif msg.type in (
                    aiohttp.WSMsgType.ERROR,
                    aiohttp.WSMsgType.CLOSED,
                    aiohttp.WSMsgType.CLOSE,
                ):
                    break
        except asyncio.CancelledError:
            # Cancelled by _cleanup(), which owns the teardown.
            cancelled = True
            raise
        except Exception:
            # Transport/handler failure: handled as a disconnect below.
            pass
        finally:
            if not cancelled and not self._closing:
                await self._handle_connection_lost()

    async def _handle_connection_lost(self) -> None:
        """
        React to the socket closing without ``close()`` having been called.

        Marks the connection as down, fails in-flight requests right away
        (this reader has stopped, so their responses can no longer be
        delivered), releases the dead socket/session and, if enabled, starts
        the reconnect loop unless one is already running.
        """
        self._connected = False
        self._fail_pending("Connection lost")
        try:
            await self._release_transport()
        except Exception:
            # Best effort: the socket is already gone.
            pass

        # close() or an explicit connect() may have happened meanwhile
        if self._closing or self._connected or not self.auto_reconnect:
            return
        if self._reconnect_task is None or self._reconnect_task.done():
            self._reconnect_task = asyncio.create_task(self._reconnect())

    async def _handle_message(self, data: str) -> None:
        """Route an incoming text frame to a pending request or a live-query callback."""
        try:
            message = json.loads(data)
        except json.JSONDecodeError:
            return
        if not isinstance(message, dict):
            # Valid JSON but not an RPC envelope; ignore instead of killing the reader
            return

        msg_id = message.get("id")

        # Check if this is a response to a pending request
        # Convert msg_id to int for comparison since SurrealDB may return string IDs
        if msg_id is not None:
            try:
                request_id: int | None = int(msg_id)
            except (ValueError, TypeError):
                request_id = None

            if request_id is not None and request_id in self._pending:
                future = self._pending.pop(request_id)
                try:
                    response = RPCResponse.from_dict(message)
                except Exception as e:
                    # Surface a malformed response to its caller only
                    if not future.done():
                        future.set_exception(e)
                    return
                if not future.done():
                    future.set_result(response)
                return

        # Check if this is a live query notification.
        # NOTE(review): only notifications carrying "action"/"id" at the top
        # level are dispatched; confirm this matches the server's RPC format
        # (some versions nest them under "result").
        if "action" in message:
            live_id = message.get("id")
            if isinstance(live_id, str):
                callback = self._live_callbacks.get(live_id)
                if callback is not None:
                    self._spawn(callback(message))

    async def _reconnect(self) -> None:
        """
        Attempt to reconnect after disconnection.

        Makes up to ``max_reconnect_attempts`` attempts, waiting
        ``reconnect_interval`` seconds before each. A successful attempt
        re-authenticates with the stored token (if any), re-selects the
        namespace/database and re-subscribes registered live selects. Stops
        early once ``close()`` has been requested or a connection was
        re-established by other means.
        """
        for _ in range(self.max_reconnect_attempts):
            await asyncio.sleep(self.reconnect_interval)
            if self._closing or self._connected:
                return

            try:
                await self._open_transport()

                # Re-authenticate if we had a token
                if self._token:
                    await self.rpc("authenticate", [self._token])

                # Set namespace and database
                await self.use(self.namespace, self.database)

                # Re-establish live queries with auto-resubscribe
                await self._resubscribe_all()
                return
            except Exception:
                # Releases this attempt's transport; never cancels this task.
                await self._cleanup()

    async def _resubscribe_all(self) -> None:
        """Re-establish all live subscriptions after reconnect (best effort)."""
        previous = dict(self._live_subscriptions)
        self._live_subscriptions.clear()
        # Callbacks are keyed by the old live ids; resubscribed queries get
        # their callback re-registered under the new id by _resubscribe_one.
        # NOTE(review): callbacks registered via live() (no subscription
        # params) are dropped here and not restored.
        self._live_callbacks.clear()

        for old_id, params in previous.items():
            try:
                new_id = await self._resubscribe_one(params)

                # Call reconnect callback if provided
                if params.on_reconnect:
                    self._spawn(params.on_reconnect(old_id, new_id))
            except Exception:
                # Failed to resubscribe, skip this one
                continue

    async def _resubscribe_one(self, params: "LiveSubscriptionParams") -> str:
        """
        Resubscribe a single live query and register it under its new id.

        Raises:
            LiveQueryError: If the server does not return a live query id.
        """
        # Set session variables for parameters
        for key, value in params.params.items():
            await self.let(key, value)

        response = await self.query(self._build_live_sql(params.table, params.where, params.diff))
        new_id = self._parse_live_id(response)

        # Re-register callback if provided
        if params.callback:
            self._live_callbacks[new_id] = params.callback  # type: ignore[assignment]

        # Store subscription params for future reconnects
        self._live_subscriptions[new_id] = params
        return new_id

    @staticmethod
    def _build_live_sql(table: str, where: str | None = None, diff: bool = False) -> str:
        """
        Build a ``LIVE SELECT`` statement.

        In SurrealQL ``DIFF`` takes the place of the field projection
        (``LIVE SELECT DIFF FROM table``); it is not a trailing clause.
        """
        projection = "DIFF" if diff else "*"
        sql = f"LIVE SELECT {projection} FROM {table}"
        if where:
            sql += f" WHERE {where}"
        return sql

    @staticmethod
    def _parse_live_id(response: Any) -> str:
        """
        Extract the live query id from the response to a ``LIVE SELECT`` query.

        The first result may be the id string itself or a dict with a
        ``"result"`` key.

        Raises:
            LiveQueryError: If there is no successful first result or it has
                an unexpected shape.
        """
        if not response.results:
            raise LiveQueryError("No live query ID returned")
        first_result = response.results[0]
        if not first_result.is_ok:
            raise LiveQueryError("No live query ID returned")

        payload = first_result.result
        if isinstance(payload, str):
            return payload
        if isinstance(payload, dict) and "result" in payload:
            return str(payload["result"])
        raise LiveQueryError("Invalid live query response")

    def _register_live_subscription(self, live_id: str, params: "LiveSubscriptionParams") -> None:
        """Register a live subscription for auto-resubscribe."""
        self._live_subscriptions[live_id] = params

    def _unregister_live_subscription(self, live_id: str) -> None:
        """Unregister a live subscription."""
        self._live_subscriptions.pop(live_id, None)

    async def _send_rpc(self, request: RPCRequest) -> RPCResponse:
        """
        Send RPC request via WebSocket.

        Args:
            request: The RPC request to send (its ``id`` is assigned here)

        Returns:
            The RPC response

        Raises:
            ConnectionError: If not connected, or sending/receiving fails
            TimeoutError: If request times out
        """
        ws = self._ws
        if ws is None or not self._connected:
            raise ConnectionError("Not connected. Call connect() first.")

        request_id = self._next_request_id()
        request.id = request_id

        # Create future for response; resolved by _handle_message
        future: asyncio.Future[RPCResponse] = asyncio.get_running_loop().create_future()
        self._pending[request_id] = future

        try:
            # Send request, then wait for response with timeout
            await ws.send_str(request.to_json())
            return await asyncio.wait_for(future, timeout=self.timeout)
        except asyncio.TimeoutError as e:
            raise TimeoutError(f"Request timed out after {self.timeout}s") from e
        except Exception as e:
            raise ConnectionError(f"Request failed: {e}") from e
        finally:
            # Also drops the entry when the caller is cancelled
            self._pending.pop(request_id, None)

    # WebSocket-specific methods

    async def live(
        self,
        table: str,
        callback: LiveCallback,
        diff: bool = False,
    ) -> str:
        """
        Start a live query subscription.

        Args:
            table: Table to watch
            callback: Async callback for change notifications
            diff: If True, receive only changed fields

        Returns:
            Live query UUID

        Raises:
            LiveQueryError: If live query fails to start
        """
        try:
            response = await self.query(self._build_live_sql(table, diff=diff))
            live_id = self._parse_live_id(response)
        except Exception as e:
            raise LiveQueryError(f"Failed to start live query: {e}") from e

        self._live_callbacks[live_id] = callback
        return live_id

    async def kill(self, live_id: str) -> None:
        """
        Stop a live query subscription.

        Local state (callback and auto-resubscribe entry) is dropped even if
        the ``kill`` RPC fails, so a killed query is never resurrected by a
        later reconnect.

        Args:
            live_id: Live query UUID to stop
        """
        try:
            await self.rpc("kill", [live_id])
        finally:
            self._live_callbacks.pop(live_id, None)
            self._live_subscriptions.pop(live_id, None)

    async def let(self, name: str, value: Any) -> None:
        """
        Set a session variable.

        Args:
            name: Variable name
            value: Variable value
        """
        await self.rpc("let", [name, value])

    async def unset(self, name: str) -> None:
        """
        Remove a session variable.

        Args:
            name: Variable name to remove
        """
        await self.rpc("unset", [name])

    @property
    def live_queries(self) -> list[str]:
        """Get list of active live query IDs."""
        return list(self._live_callbacks.keys())

    async def kill_all_live_queries(self) -> None:
        """Stop all active live queries, including subscriptions without a callback (best effort)."""
        live_ids = list(dict.fromkeys([*self._live_callbacks, *self._live_subscriptions]))
        for live_id in live_ids:
            try:
                await self.kill(live_id)
            except Exception:
                # Best effort: keep stopping the remaining queries
                continue

    # Transaction support

    def transaction(self) -> "WebSocketTransaction":
        """
        Create a new WebSocket transaction.

        WebSocket transactions use server-side state with BEGIN/COMMIT/ROLLBACK.
        Operations are executed immediately within the transaction context.

        Usage:
            async with conn.transaction() as tx:
                await tx.create("users", {"name": "Alice"})
                await tx.create("orders", {"user": "users:alice"})
                # Committed on successful exit, rolled back on exception

        Returns:
            WebSocketTransaction context manager
        """
        from ..transaction import WebSocketTransaction

        return WebSocketTransaction(self)

    # Live Select Stream API

    def live_select(
        self,
        table: str,
        where: str | None = None,
        params: dict[str, Any] | None = None,
        diff: bool = False,
        auto_resubscribe: bool = True,
        on_reconnect: Callable[[str, str], Coroutine[Any, Any, None]] | None = None,
    ) -> "LiveSelectStream":
        """
        Create a live select stream for real-time change notifications.

        This method returns an async iterator that yields LiveChange objects
        whenever records matching the query are created, updated, or deleted.

        Args:
            table: Table to watch (e.g., "players", "game_tables")
            where: Optional WHERE clause filter (e.g., "table_id = $id")
            params: Parameters for the WHERE clause (e.g., {"id": "game_tables:xyz"})
            diff: If True, receive only changed fields
            auto_resubscribe: If True, automatically resubscribe on reconnect
            on_reconnect: Optional callback when resubscribed (old_id, new_id)

        Returns:
            LiveSelectStream async context manager and iterator

        Usage:
            async with conn.live_select("players", where="table_id = $id", params={"id": table_id}) as stream:
                async for change in stream:
                    match change.action:
                        case LiveAction.CREATE:
                            print(f"New player: {change.result}")
                        case LiveAction.UPDATE:
                            print(f"Player updated: {change.record_id}")
                        case LiveAction.DELETE:
                            print(f"Player left: {change.record_id}")
        """
        from ..streaming.live_select import LiveSelectStream

        return LiveSelectStream(
            connection=self,
            table=table,
            where=where,
            params=params,
            diff=diff,
            auto_resubscribe=auto_resubscribe,
            on_reconnect=on_reconnect,
        )
@@ -0,0 +1,71 @@
1
+ """
2
+ SurrealDB SDK Exceptions.
3
+
4
+ Custom exception hierarchy for the SDK.
5
+ """
6
+
7
+
8
+ class SurrealDBError(Exception):
9
+ """Base exception for all SurrealDB SDK errors."""
10
+
11
+ def __init__(self, message: str, code: int | None = None):
12
+ self.message = message
13
+ self.code = code
14
+ super().__init__(message)
15
+
16
+
17
class ConnectionError(SurrealDBError):
    """Signals that talking to SurrealDB over the connection failed."""
21
+
22
+
23
class AuthenticationError(SurrealDBError):
    """Signals a failed authentication attempt."""
27
+
28
+
29
class QueryError(SurrealDBError):
    """
    Signals that executing a query failed.

    Attributes:
        query: The offending query text, when the raiser supplied it.
    """

    def __init__(self, message: str, query: str | None = None, code: int | None = None):
        super().__init__(message, code)
        self.query = query
35
+
36
+
37
class TimeoutError(SurrealDBError):
    """Signals that an operation did not complete within its time limit."""
41
+
42
+
43
class ValidationError(SurrealDBError):
    """Signals that data did not pass validation."""
47
+
48
+
49
class LiveQueryError(SurrealDBError):
    """Signals that starting or managing a live query failed."""
53
+
54
+
55
class ChangeFeedError(SurrealDBError):
    """Signals that a change feed operation failed."""
59
+
60
+
61
class TransactionError(SurrealDBError):
    """
    Signals that a transaction operation failed.

    Attributes:
        rollback_succeeded: Optional flag describing whether a rollback
            succeeded; ``None`` when the raiser did not supply it.
    """

    def __init__(
        self,
        message: str,
        code: int | None = None,
        rollback_succeeded: bool | None = None,
    ):
        super().__init__(message, code)
        self.rollback_succeeded = rollback_succeeded