surrealdb-orm 0.1.4__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- surreal_orm/__init__.py +72 -3
- surreal_orm/aggregations.py +164 -0
- surreal_orm/auth/__init__.py +15 -0
- surreal_orm/auth/access.py +167 -0
- surreal_orm/auth/mixins.py +302 -0
- surreal_orm/cli/__init__.py +15 -0
- surreal_orm/cli/commands.py +369 -0
- surreal_orm/connection_manager.py +58 -18
- surreal_orm/fields/__init__.py +36 -0
- surreal_orm/fields/encrypted.py +166 -0
- surreal_orm/fields/relation.py +465 -0
- surreal_orm/migrations/__init__.py +51 -0
- surreal_orm/migrations/executor.py +380 -0
- surreal_orm/migrations/generator.py +272 -0
- surreal_orm/migrations/introspector.py +305 -0
- surreal_orm/migrations/migration.py +188 -0
- surreal_orm/migrations/operations.py +531 -0
- surreal_orm/migrations/state.py +406 -0
- surreal_orm/model_base.py +530 -44
- surreal_orm/query_set.py +609 -33
- surreal_orm/relations.py +645 -0
- surreal_orm/surreal_function.py +95 -0
- surreal_orm/surreal_ql.py +113 -0
- surreal_orm/types.py +86 -0
- surreal_sdk/README.md +79 -0
- surreal_sdk/__init__.py +151 -0
- surreal_sdk/connection/__init__.py +17 -0
- surreal_sdk/connection/base.py +516 -0
- surreal_sdk/connection/http.py +421 -0
- surreal_sdk/connection/pool.py +244 -0
- surreal_sdk/connection/websocket.py +519 -0
- surreal_sdk/exceptions.py +71 -0
- surreal_sdk/functions.py +607 -0
- surreal_sdk/protocol/__init__.py +13 -0
- surreal_sdk/protocol/rpc.py +218 -0
- surreal_sdk/py.typed +0 -0
- surreal_sdk/pyproject.toml +49 -0
- surreal_sdk/streaming/__init__.py +31 -0
- surreal_sdk/streaming/change_feed.py +278 -0
- surreal_sdk/streaming/live_query.py +265 -0
- surreal_sdk/streaming/live_select.py +369 -0
- surreal_sdk/transaction.py +386 -0
- surreal_sdk/types.py +346 -0
- surrealdb_orm-0.5.0.dist-info/METADATA +465 -0
- surrealdb_orm-0.5.0.dist-info/RECORD +52 -0
- {surrealdb_orm-0.1.4.dist-info → surrealdb_orm-0.5.0.dist-info}/WHEEL +1 -1
- surrealdb_orm-0.5.0.dist-info/entry_points.txt +2 -0
- {surrealdb_orm-0.1.4.dist-info → surrealdb_orm-0.5.0.dist-info}/licenses/LICENSE +1 -1
- surrealdb_orm-0.1.4.dist-info/METADATA +0 -184
- surrealdb_orm-0.1.4.dist-info/RECORD +0 -12
|
@@ -0,0 +1,421 @@
|
|
|
1
|
+
"""
|
|
2
|
+
HTTP Connection Implementation for SurrealDB SDK.
|
|
3
|
+
|
|
4
|
+
Provides stateless HTTP-based connection, ideal for microservices and serverless.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import TYPE_CHECKING, Any
|
|
8
|
+
|
|
9
|
+
import httpx
|
|
10
|
+
|
|
11
|
+
from .base import BaseSurrealConnection
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from ..transaction import HTTPTransaction
|
|
15
|
+
from ..protocol.rpc import RPCRequest, RPCResponse
|
|
16
|
+
from ..types import AuthResponse
|
|
17
|
+
from ..exceptions import ConnectionError, QueryError
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class HTTPConnection(BaseSurrealConnection):
    """
    HTTP-based connection to SurrealDB.

    This connection is stateless - each request is independent.
    Ideal for microservices, serverless, and horizontally scaled applications.

    Authentication is performed via headers on each request: after a
    successful ``signin`` the returned token is sent as
    ``Authorization: Bearer <token>``.
    """

    def __init__(
        self,
        url: str,
        namespace: str,
        database: str,
        timeout: float = 30.0,
    ):
        """
        Initialize HTTP connection.

        The underlying ``httpx.AsyncClient`` is created lazily by ``connect()``.

        Args:
            url: SurrealDB HTTP URL (e.g., "http://localhost:8000"). A
                ``ws://`` / ``wss://`` URL is accepted and rewritten to
                ``http://`` / ``https://``.
            namespace: Target namespace (sent in the ``Surreal-NS`` header)
            database: Target database (sent in the ``Surreal-DB`` header)
            timeout: Request timeout in seconds
        """
        # Normalize a WebSocket URL to its HTTP equivalent (only the scheme changes).
        if url.startswith("ws://"):
            url = url.replace("ws://", "http://", 1)
        elif url.startswith("wss://"):
            url = url.replace("wss://", "https://", 1)

        super().__init__(url, namespace, database, timeout)
        self._client: httpx.AsyncClient | None = None
        self._request_id = 0

    @property
    def headers(self) -> dict[str, str]:
        """
        Build request headers.

        Includes the namespace/database selectors, JSON content negotiation,
        and a Bearer ``Authorization`` header once a token is held.
        """
        h = {
            "Surreal-NS": self.namespace,
            "Surreal-DB": self.database,
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        if self._token:
            h["Authorization"] = f"Bearer {self._token}"
        return h

    def _next_request_id(self) -> int:
        """Generate next request ID (monotonically increasing per connection)."""
        self._request_id += 1
        return self._request_id

    def _client_or_raise(self) -> httpx.AsyncClient:
        """
        Return the active HTTP client.

        Raises:
            ConnectionError: If ``connect()`` has not been called (or the
                connection was closed).
        """
        if self._client is None:
            raise ConnectionError("Not connected. Call connect() first.")
        return self._client

    @staticmethod
    def _key_path(table: str, record_id: str | None = None) -> str:
        """Build the REST ``/key/:table`` path, appending ``/:id`` when ``record_id`` is truthy."""
        return f"/key/{table}/{record_id}" if record_id else f"/key/{table}"

    async def connect(self) -> None:
        """Establish HTTP client connection (no-op if already connected)."""
        if self._connected:
            return

        self._client = httpx.AsyncClient(
            base_url=self.url,
            timeout=self.timeout,
            # Keep no idle keep-alive connections between requests: avoids
            # event loop binding issues in tests, at the cost of connection reuse.
            limits=httpx.Limits(max_keepalive_connections=0, max_connections=100),
        )
        self._connected = True

    async def close(self) -> None:
        """
        Close HTTP client.

        Connection state is reset even if closing the client raises, so the
        object never claims to be connected with a dead client.
        """
        client, self._client = self._client, None
        try:
            if client is not None:
                await client.aclose()
        finally:
            self._connected = False
            self._authenticated = False

    async def _send_rpc(self, request: RPCRequest) -> RPCResponse:
        """
        Send RPC request via HTTP POST to /rpc endpoint.

        Args:
            request: The RPC request to send (its ``id`` is overwritten with
                a fresh per-connection request ID)

        Returns:
            The RPC response

        Raises:
            ConnectionError: If not connected, or the request fails at the
                transport level
            QueryError: If the server answers with a non-success HTTP status
        """
        client = self._client_or_raise()
        request.id = self._next_request_id()

        try:
            response = await client.post(
                "/rpc",
                json=request.to_dict(),
                headers=self.headers,
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise QueryError(
                message=f"HTTP error: {e.response.status_code} - {e.response.text}",
                code=e.response.status_code,
            ) from e
        except httpx.RequestError as e:
            raise ConnectionError(f"Request failed: {e}") from e

        return RPCResponse.from_dict(response.json())

    async def signin(
        self,
        user: str | None = None,
        password: str | None = None,
        namespace: str | None = None,
        database: str | None = None,
        access: str | None = None,
        **credentials: Any,
    ) -> AuthResponse:
        """
        Authenticate with SurrealDB via HTTP.

        For HTTP connections, this obtains a JWT token that will be
        included in subsequent request headers.

        Args:
            user: Username (for root/namespace/database auth)
            password: Password. Sent as ``pass`` for root/namespace/database
                auth, or as ``password`` when ``access`` is given.
            namespace: Optional namespace scope
            database: Optional database scope
            access: Optional access method (for record access auth)
            **credentials: Additional credentials for record access (email, password, etc.)

        Returns:
            AuthResponse carrying the token and the raw response body.

        Raises:
            ConnectionError: If not connected.
            AuthenticationError: If the request fails, the server rejects the
                credentials, or the response does not contain a token.
        """
        client = self._client_or_raise()

        from ..exceptions import AuthenticationError

        payload: dict[str, Any] = {}
        if user:
            payload["user"] = user
        if namespace:
            payload["ns"] = namespace
        if database:
            payload["db"] = database
        if access:
            # Record access auth: password goes as 'password' in credentials
            payload["ac"] = access
            if password:
                payload["password"] = password
        elif password:
            # Root/namespace/database auth: password goes as 'pass'
            payload["pass"] = password
        # Add any additional credentials for record access
        payload.update(credentials)

        try:
            # Deliberately not self.headers: no stale Bearer token on signin.
            response = await client.post(
                "/signin",
                json=payload,
                headers={"Accept": "application/json", "Content-Type": "application/json"},
            )
        except httpx.RequestError as e:
            raise AuthenticationError(f"Authentication request failed: {e}") from e

        if response.status_code != 200:
            raise AuthenticationError(f"Authentication failed: {response.text}")

        try:
            data = response.json()
        except ValueError as e:
            raise AuthenticationError(f"Authentication failed: invalid JSON response: {response.text}") from e

        token = data.get("token") if isinstance(data, dict) else None
        if not token:
            # Without a token subsequent requests would silently be anonymous,
            # so do not report success.
            raise AuthenticationError(f"Authentication failed: no token in response: {response.text}")

        self._token = token
        self._authenticated = True
        return AuthResponse(token=token, success=True, raw=data)

    async def sql(self, query: str, vars: dict[str, Any] | None = None) -> list[dict[str, Any]]:
        """
        Execute raw SurrealQL via POST /sql endpoint.

        This is a direct SQL execution endpoint, alternative to RPC.

        Args:
            query: SurrealQL query string (sent as the raw request body)
            vars: Query variables (passed as query params)

        Returns:
            Query results (decoded JSON body)

        Raises:
            ConnectionError: If not connected, or the request fails at the
                transport level
            QueryError: If the server answers with a non-success HTTP status
        """
        client = self._client_or_raise()

        try:
            response = await client.post(
                "/sql",
                content=query,
                headers=self.headers,
                params=vars,
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise QueryError(
                message=f"SQL query failed: {e.response.text}",
                query=query,
                code=e.response.status_code,
            ) from e
        except httpx.RequestError as e:
            raise ConnectionError(f"Request failed: {e}") from e

        result: list[dict[str, Any]] = response.json()
        return result

    async def _probe(self, path: str) -> bool:
        """GET ``path`` and report whether the server answered 200. Never raises."""
        if self._client is None:
            return False

        try:
            response = await self._client.get(path)
        except Exception:
            # Best-effort probe: any failure is reported as "not healthy".
            return False
        return response.status_code == 200

    async def health(self) -> bool:
        """
        Check server health via GET /health endpoint.

        Returns:
            True if server is healthy; False if not connected or on any error
        """
        return await self._probe("/health")

    async def status(self) -> bool:
        """
        Check server status via GET /status endpoint.

        Returns:
            True if server is running; False if not connected or on any error
        """
        return await self._probe("/status")

    # REST-style CRUD endpoints (alternative to RPC).
    # These surface httpx.HTTPStatusError / httpx.RequestError unchanged.

    async def rest_select(self, table: str, record_id: str | None = None) -> list[dict[str, Any]]:
        """
        Select via REST GET /key/:table or /key/:table/:id.

        Args:
            table: Table name
            record_id: Optional record ID

        Returns:
            Records (a single-object body is wrapped in a one-element list)

        Raises:
            ConnectionError: If not connected
            httpx.HTTPStatusError: On a non-success HTTP status
        """
        client = self._client_or_raise()

        response = await client.get(self._key_path(table, record_id), headers=self.headers)
        response.raise_for_status()
        result = response.json()
        return result if isinstance(result, list) else [result]

    async def rest_create(
        self,
        table: str,
        data: dict[str, Any],
        record_id: str | None = None,
    ) -> dict[str, Any]:
        """
        Create via REST POST /key/:table or /key/:table/:id.

        Args:
            table: Table name
            data: Record data
            record_id: Optional record ID

        Returns:
            Created record

        Raises:
            ConnectionError: If not connected
            httpx.HTTPStatusError: On a non-success HTTP status
        """
        client = self._client_or_raise()

        response = await client.post(self._key_path(table, record_id), json=data, headers=self.headers)
        response.raise_for_status()
        result: dict[str, Any] = response.json()
        return result

    async def rest_update(
        self,
        table: str,
        record_id: str,
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Update via REST PUT /key/:table/:id.

        Args:
            table: Table name
            record_id: Record ID
            data: New record data

        Returns:
            Updated record

        Raises:
            ConnectionError: If not connected
            httpx.HTTPStatusError: On a non-success HTTP status
        """
        client = self._client_or_raise()

        response = await client.put(
            f"/key/{table}/{record_id}",
            json=data,
            headers=self.headers,
        )
        response.raise_for_status()
        result: dict[str, Any] = response.json()
        return result

    async def rest_patch(
        self,
        table: str,
        record_id: str,
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Patch via REST PATCH /key/:table/:id.

        Args:
            table: Table name
            record_id: Record ID
            data: Fields to update

        Returns:
            Updated record

        Raises:
            ConnectionError: If not connected
            httpx.HTTPStatusError: On a non-success HTTP status
        """
        client = self._client_or_raise()

        response = await client.patch(
            f"/key/{table}/{record_id}",
            json=data,
            headers=self.headers,
        )
        response.raise_for_status()
        result: dict[str, Any] = response.json()
        return result

    async def rest_delete(
        self,
        table: str,
        record_id: str | None = None,
    ) -> dict[str, Any] | list[dict[str, Any]]:
        """
        Delete via REST DELETE /key/:table or /key/:table/:id.

        Args:
            table: Table name
            record_id: Optional record ID (omit to target the whole table)

        Returns:
            Deleted record(s)

        Raises:
            ConnectionError: If not connected
            httpx.HTTPStatusError: On a non-success HTTP status
        """
        client = self._client_or_raise()

        response = await client.delete(self._key_path(table, record_id), headers=self.headers)
        response.raise_for_status()
        result: dict[str, Any] | list[dict[str, Any]] = response.json()
        return result

    # Transaction support

    def transaction(self) -> "HTTPTransaction":
        """
        Create a new HTTP transaction.

        HTTP transactions batch all statements and execute them atomically on commit.

        Usage:
            async with conn.transaction() as tx:
                await tx.create("users", {"name": "Alice"})
                await tx.create("orders", {"user": "users:alice"})
                # All statements executed atomically on exit

        Returns:
            HTTPTransaction context manager
        """
        # Imported lazily; at module level it is only imported under TYPE_CHECKING.
        from ..transaction import HTTPTransaction

        return HTTPTransaction(self)
|
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Connection Pool Implementation for SurrealDB SDK.
|
|
3
|
+
|
|
4
|
+
Provides connection pooling for both HTTP and WebSocket connections.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Any, AsyncGenerator, Self
|
|
8
|
+
import asyncio
|
|
9
|
+
from contextlib import asynccontextmanager
|
|
10
|
+
from collections import deque
|
|
11
|
+
|
|
12
|
+
from .base import BaseSurrealConnection
|
|
13
|
+
from .http import HTTPConnection
|
|
14
|
+
from .websocket import WebSocketConnection
|
|
15
|
+
from ..types import QueryResponse, RecordResponse, RecordsResponse, DeleteResponse
|
|
16
|
+
from ..exceptions import ConnectionError
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ConnectionPool:
    """
    Connection pool for SurrealDB connections.

    Manages a pool of reusable connections for improved performance
    in high-throughput scenarios. At most ``size`` connections are checked
    out at once; further ``acquire()`` callers wait until one is released.
    """

    def __init__(
        self,
        url: str,
        namespace: str,
        database: str,
        size: int = 10,
        connection_type: str = "http",
        timeout: float = 30.0,
        **kwargs: Any,
    ):
        """
        Initialize connection pool.

        No connections are opened here; they are created on demand by ``acquire()``.

        Args:
            url: SurrealDB server URL
            namespace: Target namespace
            database: Target database
            size: Maximum number of connections checked out at once
            connection_type: "http" or "websocket" (any other value falls back to HTTP)
            timeout: Connection timeout in seconds
            **kwargs: Additional connection arguments (only forwarded to WebSocket connections)

        Raises:
            ValueError: If ``size`` is less than 1 (such a pool could never hand out a connection).
        """
        if size < 1:
            raise ValueError(f"Pool size must be at least 1, got {size}")

        self.url = url
        self.namespace = namespace
        self.database = database
        self.size = size
        self.connection_type = connection_type
        self.timeout = timeout
        self.kwargs = kwargs

        self._pool: deque[BaseSurrealConnection] = deque()  # idle connections
        self._in_use: set[BaseSurrealConnection] = set()  # checked-out connections
        self._lock = asyncio.Lock()
        # Shares self._lock. Waiters sleep on it until a connection or
        # capacity slot is released, or the pool is closed.
        self._available = asyncio.Condition(self._lock)
        self._closed = False
        self._credentials: tuple[str, str] | None = None

    async def __aenter__(self) -> Self:
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Async context manager exit: close the pool."""
        await self.close()

    def _create_connection(self) -> BaseSurrealConnection:
        """
        Create a new (not yet connected) connection instance.

        ``self.kwargs`` is forwarded only to WebSocketConnection:
        HTTPConnection's constructor accepts just url/namespace/database/timeout.
        """
        if self.connection_type == "websocket":
            return WebSocketConnection(
                self.url,
                self.namespace,
                self.database,
                timeout=self.timeout,
                **self.kwargs,
            )
        return HTTPConnection(
            self.url,
            self.namespace,
            self.database,
            timeout=self.timeout,
        )

    @staticmethod
    async def _close_quietly(conn: BaseSurrealConnection) -> None:
        """Close ``conn``, ignoring errors (used only where it is being discarded)."""
        try:
            await conn.close()
        except Exception:
            # Best-effort cleanup: the connection is being thrown away anyway.
            pass

    async def _init_connection(self, conn: BaseSurrealConnection) -> None:
        """
        Connect ``conn`` and sign it in with the pool credentials, if any.

        On any failure (including cancellation) the connection is closed
        before the error propagates, so half-initialized connections do not leak.
        """
        try:
            await conn.connect()
            if self._credentials:
                user, password = self._credentials
                await conn.signin(user, password)
        except BaseException:
            await self._close_quietly(conn)
            raise

    async def set_credentials(self, user: str, password: str) -> None:
        """
        Set credentials for all pool connections.

        New connections sign in with these credentials. Idle connections are
        re-authenticated now; any that fail re-authentication are closed and
        dropped rather than kept with stale credentials. Connections
        currently checked out are not re-authenticated.

        Args:
            user: Username
            password: Password
        """
        self._credentials = (user, password)

        async with self._available:
            reauthenticated: deque[BaseSurrealConnection] = deque()
            while self._pool:
                conn = self._pool.popleft()
                try:
                    await conn.signin(user, password)
                except Exception:
                    await self._close_quietly(conn)
                else:
                    reauthenticated.append(conn)
            self._pool = reauthenticated

    async def _pop_live_idle(self) -> BaseSurrealConnection | None:
        """
        Pop the first idle connection that is still connected.

        Closes dead connections it encounters along the way. Caller must
        hold the pool lock.
        """
        while self._pool:
            conn = self._pool.popleft()
            if conn.is_connected:
                return conn
            await self._close_quietly(conn)
        return None

    async def _checkout(self) -> BaseSurrealConnection:
        """
        Reserve a connection: reuse a live idle one, open a new one if under
        capacity, otherwise wait for a release.

        Raises:
            ConnectionError: If the pool is (or becomes) closed.
        """
        async with self._available:
            while True:
                if self._closed:
                    raise ConnectionError("Pool is closed")

                conn = await self._pop_live_idle()
                # The idle deque is empty here, so in-use count == total count.
                if conn is None and len(self._in_use) < self.size:
                    conn = self._create_connection()
                    await self._init_connection(conn)

                if conn is not None:
                    self._in_use.add(conn)
                    return conn

                # At capacity: releases the lock while waiting, re-acquires after.
                await self._available.wait()

    async def _checkin(self, conn: BaseSurrealConnection) -> None:
        """Return ``conn`` to the idle pool, or close it if dead or the pool is closed."""
        async with self._available:
            self._in_use.discard(conn)
            if not self._closed and conn.is_connected:
                self._pool.append(conn)
            else:
                await self._close_quietly(conn)
            # A connection or a capacity slot was freed. Wake every waiter:
            # each re-checks under the lock, and no wake-up is lost if a
            # single notified waiter was cancelled.
            self._available.notify_all()

    @asynccontextmanager
    async def acquire(self) -> AsyncGenerator[BaseSurrealConnection, None]:
        """
        Acquire a connection from the pool.

        Usage:
            async with pool.acquire() as conn:
                result = await conn.query("SELECT * FROM users")

        Yields:
            A SurrealDB connection

        Raises:
            ConnectionError: If the pool is closed (before or while waiting).
        """
        conn = await self._checkout()
        try:
            yield conn
        finally:
            await self._checkin(conn)

    async def close(self) -> None:
        """
        Close all connections in the pool, including checked-out ones, and
        wake any waiting ``acquire()`` callers (they will raise ConnectionError).
        """
        self._closed = True

        async with self._available:
            # Close pooled connections
            while self._pool:
                await self._close_quietly(self._pool.popleft())

            # Close in-use connections (iterate a snapshot)
            for conn in list(self._in_use):
                await self._close_quietly(conn)
            self._in_use.clear()

            self._available.notify_all()

    @property
    def available(self) -> int:
        """Number of available connections in pool."""
        return len(self._pool)

    @property
    def in_use(self) -> int:
        """Number of connections currently in use."""
        return len(self._in_use)

    @property
    def total(self) -> int:
        """Total number of connections (available + in use)."""
        return len(self._pool) + len(self._in_use)

    # Convenience methods that acquire a connection

    async def query(self, sql: str, vars: dict[str, Any] | None = None) -> QueryResponse:
        """Execute a query using a pooled connection."""
        async with self.acquire() as conn:
            return await conn.query(sql, vars)

    async def select(self, thing: str) -> RecordsResponse:
        """Select records using a pooled connection."""
        async with self.acquire() as conn:
            return await conn.select(thing)

    async def create(self, thing: str, data: dict[str, Any] | None = None) -> RecordResponse:
        """Create a record using a pooled connection."""
        async with self.acquire() as conn:
            return await conn.create(thing, data)

    async def update(self, thing: str, data: dict[str, Any]) -> RecordsResponse:
        """Update record(s) using a pooled connection."""
        async with self.acquire() as conn:
            return await conn.update(thing, data)

    async def merge(self, thing: str, data: dict[str, Any]) -> RecordsResponse:
        """Merge data into record(s) using a pooled connection."""
        async with self.acquire() as conn:
            return await conn.merge(thing, data)

    async def delete(self, thing: str) -> DeleteResponse:
        """Delete record(s) using a pooled connection."""
        async with self.acquire() as conn:
            return await conn.delete(thing)
|