surrealdb-orm 0.1.3__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- surreal_orm/__init__.py +78 -3
- surreal_orm/aggregations.py +164 -0
- surreal_orm/auth/__init__.py +15 -0
- surreal_orm/auth/access.py +167 -0
- surreal_orm/auth/mixins.py +302 -0
- surreal_orm/cli/__init__.py +15 -0
- surreal_orm/cli/commands.py +369 -0
- surreal_orm/connection_manager.py +58 -18
- surreal_orm/fields/__init__.py +36 -0
- surreal_orm/fields/encrypted.py +166 -0
- surreal_orm/fields/relation.py +465 -0
- surreal_orm/migrations/__init__.py +51 -0
- surreal_orm/migrations/executor.py +380 -0
- surreal_orm/migrations/generator.py +272 -0
- surreal_orm/migrations/introspector.py +305 -0
- surreal_orm/migrations/migration.py +188 -0
- surreal_orm/migrations/operations.py +531 -0
- surreal_orm/migrations/state.py +406 -0
- surreal_orm/model_base.py +594 -135
- surreal_orm/py.typed +0 -0
- surreal_orm/query_set.py +609 -34
- surreal_orm/relations.py +645 -0
- surreal_orm/surreal_function.py +95 -0
- surreal_orm/surreal_ql.py +113 -0
- surreal_orm/types.py +86 -0
- surreal_sdk/README.md +79 -0
- surreal_sdk/__init__.py +151 -0
- surreal_sdk/connection/__init__.py +17 -0
- surreal_sdk/connection/base.py +516 -0
- surreal_sdk/connection/http.py +421 -0
- surreal_sdk/connection/pool.py +244 -0
- surreal_sdk/connection/websocket.py +519 -0
- surreal_sdk/exceptions.py +71 -0
- surreal_sdk/functions.py +607 -0
- surreal_sdk/protocol/__init__.py +13 -0
- surreal_sdk/protocol/rpc.py +218 -0
- surreal_sdk/py.typed +0 -0
- surreal_sdk/pyproject.toml +49 -0
- surreal_sdk/streaming/__init__.py +31 -0
- surreal_sdk/streaming/change_feed.py +278 -0
- surreal_sdk/streaming/live_query.py +265 -0
- surreal_sdk/streaming/live_select.py +369 -0
- surreal_sdk/transaction.py +386 -0
- surreal_sdk/types.py +346 -0
- surrealdb_orm-0.5.0.dist-info/METADATA +465 -0
- surrealdb_orm-0.5.0.dist-info/RECORD +52 -0
- {surrealdb_orm-0.1.3.dist-info → surrealdb_orm-0.5.0.dist-info}/WHEEL +1 -1
- surrealdb_orm-0.5.0.dist-info/entry_points.txt +2 -0
- {surrealdb_orm-0.1.3.dist-info → surrealdb_orm-0.5.0.dist-info}/licenses/LICENSE +1 -1
- surrealdb_orm-0.1.3.dist-info/METADATA +0 -184
- surrealdb_orm-0.1.3.dist-info/RECORD +0 -11
surreal_orm/query_set.py
CHANGED
|
@@ -1,14 +1,20 @@
|
|
|
1
1
|
from .constants import LOOKUP_OPERATORS
|
|
2
2
|
from .enum import OrderBy
|
|
3
3
|
from .utils import remove_quotes_for_variables
|
|
4
|
-
from surrealdb import QueryResponse, Table, AsyncSurrealDB
|
|
5
|
-
from surrealdb.errors import SurrealDbError
|
|
6
4
|
from . import BaseSurrealModel, SurrealDBConnectionManager
|
|
7
|
-
from
|
|
5
|
+
from .aggregations import Aggregation
|
|
6
|
+
from typing import Self, Any, Sequence, cast
|
|
8
7
|
from pydantic_core import ValidationError
|
|
9
8
|
|
|
10
9
|
import logging
|
|
11
10
|
|
|
11
|
+
|
|
12
|
+
class SurrealDbError(Exception):
|
|
13
|
+
"""Error from SurrealDB operations."""
|
|
14
|
+
|
|
15
|
+
pass
|
|
16
|
+
|
|
17
|
+
|
|
12
18
|
logger = logging.getLogger(__name__)
|
|
13
19
|
|
|
14
20
|
|
|
@@ -55,8 +61,14 @@ class QuerySet:
|
|
|
55
61
|
self._limit: int | None = None
|
|
56
62
|
self._offset: int | None = None
|
|
57
63
|
self._order_by: str | None = None
|
|
58
|
-
self._model_table: str =
|
|
64
|
+
self._model_table: str = model.__name__
|
|
59
65
|
self._variables: dict = {}
|
|
66
|
+
self._group_by_fields: list[str] = []
|
|
67
|
+
self._annotations: dict[str, Aggregation] = {}
|
|
68
|
+
# Relation query options
|
|
69
|
+
self._select_related: list[str] = []
|
|
70
|
+
self._prefetch_related: list[str] = []
|
|
71
|
+
self._traversal_path: str | None = None
|
|
60
72
|
|
|
61
73
|
def select(self, *fields: str) -> Self:
|
|
62
74
|
"""
|
|
@@ -201,7 +213,7 @@ class QuerySet:
|
|
|
201
213
|
self._offset = value
|
|
202
214
|
return self
|
|
203
215
|
|
|
204
|
-
def order_by(self, field_name: str,
|
|
216
|
+
def order_by(self, field_name: str, order_type: OrderBy = OrderBy.ASC) -> Self:
|
|
205
217
|
"""
|
|
206
218
|
Set the field and direction to order the results by.
|
|
207
219
|
|
|
@@ -220,9 +232,248 @@ class QuerySet:
|
|
|
220
232
|
queryset.order_by('name', OrderBy.DESC)
|
|
221
233
|
```
|
|
222
234
|
"""
|
|
223
|
-
self._order_by = f"{field_name} {
|
|
235
|
+
self._order_by = f"{field_name} {order_type}"
|
|
236
|
+
return self
|
|
237
|
+
|
|
238
|
+
def values(self, *fields: str) -> Self:
|
|
239
|
+
"""
|
|
240
|
+
Specify the fields to group by for aggregation queries.
|
|
241
|
+
|
|
242
|
+
This method is used in conjunction with `annotate()` to perform GROUP BY operations.
|
|
243
|
+
The specified fields become the grouping keys, and aggregation functions are applied
|
|
244
|
+
to each group.
|
|
245
|
+
|
|
246
|
+
Args:
|
|
247
|
+
*fields (str): Variable length argument list of field names to group by.
|
|
248
|
+
|
|
249
|
+
Returns:
|
|
250
|
+
Self: The current instance of QuerySet to allow method chaining.
|
|
251
|
+
|
|
252
|
+
Example:
|
|
253
|
+
```python
|
|
254
|
+
# Group orders by status and calculate statistics
|
|
255
|
+
stats = await Order.objects().values("status").annotate(
|
|
256
|
+
count=Count(),
|
|
257
|
+
total=Sum("amount"),
|
|
258
|
+
)
|
|
259
|
+
# Result: [{"status": "paid", "count": 42, "total": 5000}, ...]
|
|
260
|
+
```
|
|
261
|
+
"""
|
|
262
|
+
self._group_by_fields = list(fields)
|
|
263
|
+
return self
|
|
264
|
+
|
|
265
|
+
def annotate(self, **aggregations: Aggregation) -> Self:
|
|
266
|
+
"""
|
|
267
|
+
Add aggregation functions to compute values for each group.
|
|
268
|
+
|
|
269
|
+
This method is used in conjunction with `values()` to perform GROUP BY operations.
|
|
270
|
+
Each keyword argument should be an Aggregation instance (Count, Sum, Avg, Min, Max).
|
|
271
|
+
|
|
272
|
+
Args:
|
|
273
|
+
**aggregations (Aggregation): Keyword arguments where keys are alias names
|
|
274
|
+
and values are Aggregation instances.
|
|
275
|
+
|
|
276
|
+
Returns:
|
|
277
|
+
Self: The current instance of QuerySet to allow method chaining.
|
|
278
|
+
|
|
279
|
+
Example:
|
|
280
|
+
```python
|
|
281
|
+
from surreal_orm.aggregations import Count, Sum, Avg
|
|
282
|
+
|
|
283
|
+
# Calculate statistics per status
|
|
284
|
+
stats = await Order.objects().values("status").annotate(
|
|
285
|
+
count=Count(),
|
|
286
|
+
total=Sum("amount"),
|
|
287
|
+
avg_amount=Avg("amount"),
|
|
288
|
+
)
|
|
289
|
+
```
|
|
290
|
+
"""
|
|
291
|
+
self._annotations = aggregations
|
|
292
|
+
return self
|
|
293
|
+
|
|
294
|
+
# ==================== Relation Query Methods ====================
|
|
295
|
+
|
|
296
|
+
def select_related(self, *relations: str) -> Self:
|
|
297
|
+
"""
|
|
298
|
+
Eagerly load related objects in the same query.
|
|
299
|
+
|
|
300
|
+
This method optimizes queries by loading related objects alongside
|
|
301
|
+
the main query results, avoiding N+1 query problems for forward relations.
|
|
302
|
+
|
|
303
|
+
Note: SurrealDB handles this through graph traversal syntax.
|
|
304
|
+
The actual loading happens during query execution.
|
|
305
|
+
|
|
306
|
+
Args:
|
|
307
|
+
*relations: Names of relations to load eagerly.
|
|
308
|
+
|
|
309
|
+
Returns:
|
|
310
|
+
Self: The current instance of QuerySet to allow method chaining.
|
|
311
|
+
|
|
312
|
+
Example:
|
|
313
|
+
```python
|
|
314
|
+
# Load posts with their authors in one query
|
|
315
|
+
posts = await Post.objects().select_related("author").all()
|
|
316
|
+
for post in posts:
|
|
317
|
+
print(post.author.name) # No additional query
|
|
318
|
+
```
|
|
319
|
+
"""
|
|
320
|
+
self._select_related = list(relations)
|
|
321
|
+
return self
|
|
322
|
+
|
|
323
|
+
def prefetch_related(self, *relations: str) -> Self:
|
|
324
|
+
"""
|
|
325
|
+
Prefetch related objects using separate optimized queries.
|
|
326
|
+
|
|
327
|
+
This method reduces N+1 query problems by fetching related objects
|
|
328
|
+
in batches after the main query completes. This is more efficient
|
|
329
|
+
than select_related for many-to-many and reverse relations.
|
|
330
|
+
|
|
331
|
+
Args:
|
|
332
|
+
*relations: Names of relations to prefetch.
|
|
333
|
+
|
|
334
|
+
Returns:
|
|
335
|
+
Self: The current instance of QuerySet to allow method chaining.
|
|
336
|
+
|
|
337
|
+
Example:
|
|
338
|
+
```python
|
|
339
|
+
# Load users and prefetch their followers
|
|
340
|
+
users = await User.objects().prefetch_related("followers", "posts").all()
|
|
341
|
+
for user in users:
|
|
342
|
+
print(user.followers) # Already loaded
|
|
343
|
+
print(user.posts) # Already loaded
|
|
344
|
+
```
|
|
345
|
+
"""
|
|
346
|
+
self._prefetch_related = list(relations)
|
|
347
|
+
return self
|
|
348
|
+
|
|
349
|
+
def traverse(self, path: str) -> Self:
|
|
350
|
+
"""
|
|
351
|
+
Add a graph traversal path to the query.
|
|
352
|
+
|
|
353
|
+
This method allows querying across graph relations using
|
|
354
|
+
SurrealDB's traversal syntax.
|
|
355
|
+
|
|
356
|
+
Args:
|
|
357
|
+
path: Graph traversal path (e.g., "->follows->users->likes->posts")
|
|
358
|
+
|
|
359
|
+
Returns:
|
|
360
|
+
Self: The current instance of QuerySet to allow method chaining.
|
|
361
|
+
|
|
362
|
+
Example:
|
|
363
|
+
```python
|
|
364
|
+
# Get all posts liked by users that alice follows
|
|
365
|
+
posts = await User.objects().filter(id="alice").traverse(
|
|
366
|
+
"->follows->users->likes->posts"
|
|
367
|
+
).all()
|
|
368
|
+
```
|
|
369
|
+
"""
|
|
370
|
+
self._traversal_path = path
|
|
224
371
|
return self
|
|
225
372
|
|
|
373
|
+
async def graph_query(self, traversal: str, **variables: Any) -> list[dict[str, Any]]:
|
|
374
|
+
"""
|
|
375
|
+
Execute a raw graph traversal query.
|
|
376
|
+
|
|
377
|
+
This method provides direct access to SurrealDB's graph capabilities
|
|
378
|
+
for complex traversal patterns that can't be expressed through
|
|
379
|
+
the standard QuerySet API.
|
|
380
|
+
|
|
381
|
+
Args:
|
|
382
|
+
traversal: Graph traversal expression (e.g., "->follows->User")
|
|
383
|
+
**variables: Variables to bind in the query
|
|
384
|
+
|
|
385
|
+
Returns:
|
|
386
|
+
list[dict[str, Any]]: Raw query results as dictionaries
|
|
387
|
+
|
|
388
|
+
Example:
|
|
389
|
+
```python
|
|
390
|
+
# Find users that alice follows
|
|
391
|
+
result = await User.objects().filter(id="alice").graph_query("->follows->User")
|
|
392
|
+
|
|
393
|
+
# Multi-hop traversal
|
|
394
|
+
result = await User.objects().filter(id="alice").graph_query(
|
|
395
|
+
"->follows->User->follows->User"
|
|
396
|
+
)
|
|
397
|
+
```
|
|
398
|
+
"""
|
|
399
|
+
# Parse the traversal to determine edge and direction
|
|
400
|
+
# Expected format: "->edge->Table" or "<-edge<-Table"
|
|
401
|
+
# For now, support simple single-hop traversals
|
|
402
|
+
|
|
403
|
+
# Check if we have an id filter to use as starting point
|
|
404
|
+
source_id = None
|
|
405
|
+
for field_name, lookup_name, value in self._filters:
|
|
406
|
+
if field_name == "id" and lookup_name == "exact":
|
|
407
|
+
source_id = value
|
|
408
|
+
break
|
|
409
|
+
|
|
410
|
+
client = await SurrealDBConnectionManager.get_client()
|
|
411
|
+
|
|
412
|
+
if source_id:
|
|
413
|
+
# Use specific record as starting point
|
|
414
|
+
source_thing = f"{self._model_table}:{source_id}"
|
|
415
|
+
|
|
416
|
+
# Parse the traversal to get edge and direction
|
|
417
|
+
# Simple pattern: ->edge->Table or <-edge<-Table
|
|
418
|
+
if traversal.startswith("->"):
|
|
419
|
+
# Outgoing: get targets where source is 'in'
|
|
420
|
+
parts = traversal.split("->")
|
|
421
|
+
if len(parts) >= 3:
|
|
422
|
+
edge = parts[1]
|
|
423
|
+
query = f"SELECT out FROM {edge} WHERE in = {source_thing} FETCH out;"
|
|
424
|
+
result = await client.query(query, {**self._variables, **variables})
|
|
425
|
+
records: list[dict[str, Any]] = []
|
|
426
|
+
for row in result.all_records or []:
|
|
427
|
+
if isinstance(row.get("out"), dict):
|
|
428
|
+
records.append(row["out"])
|
|
429
|
+
return records
|
|
430
|
+
elif traversal.startswith("<-"):
|
|
431
|
+
# Incoming: get sources where target is 'out'
|
|
432
|
+
parts = traversal.split("<-")
|
|
433
|
+
if len(parts) >= 3:
|
|
434
|
+
edge = parts[1]
|
|
435
|
+
query = f"SELECT in FROM {edge} WHERE out = {source_thing} FETCH in;"
|
|
436
|
+
result = await client.query(query, {**self._variables, **variables})
|
|
437
|
+
records = []
|
|
438
|
+
for row in result.all_records or []:
|
|
439
|
+
if isinstance(row.get("in"), dict):
|
|
440
|
+
records.append(row["in"])
|
|
441
|
+
return records
|
|
442
|
+
|
|
443
|
+
# Fallback: return empty for unsupported patterns
|
|
444
|
+
return []
|
|
445
|
+
|
|
446
|
+
async def _execute_annotate(self) -> list[dict[str, Any]]:
|
|
447
|
+
"""
|
|
448
|
+
Execute the GROUP BY query with annotations.
|
|
449
|
+
|
|
450
|
+
Returns:
|
|
451
|
+
list[dict[str, Any]]: A list of dictionaries containing grouped results.
|
|
452
|
+
"""
|
|
453
|
+
# Build SELECT clause with group fields and aggregations
|
|
454
|
+
select_parts: list[str] = list(self._group_by_fields)
|
|
455
|
+
|
|
456
|
+
for alias, aggregation in self._annotations.items():
|
|
457
|
+
select_parts.append(aggregation.to_surql(alias))
|
|
458
|
+
|
|
459
|
+
select_clause = ", ".join(select_parts)
|
|
460
|
+
|
|
461
|
+
# Build WHERE clause
|
|
462
|
+
where_clause = self._compile_where_clause()
|
|
463
|
+
|
|
464
|
+
# Build GROUP BY clause
|
|
465
|
+
if self._group_by_fields:
|
|
466
|
+
group_clause = f" GROUP BY {', '.join(self._group_by_fields)}"
|
|
467
|
+
else:
|
|
468
|
+
group_clause = " GROUP ALL"
|
|
469
|
+
|
|
470
|
+
query = f"SELECT {select_clause} FROM {self._model_table}{where_clause}{group_clause};"
|
|
471
|
+
|
|
472
|
+
client = await SurrealDBConnectionManager.get_client()
|
|
473
|
+
result = await client.query(remove_quotes_for_variables(query), self._variables)
|
|
474
|
+
|
|
475
|
+
return cast(list[dict[str, Any]], result.all_records)
|
|
476
|
+
|
|
226
477
|
def _compile_query(self) -> str:
|
|
227
478
|
"""
|
|
228
479
|
Compile the QuerySet parameters into a SQL query string.
|
|
@@ -261,6 +512,10 @@ class QuerySet:
|
|
|
261
512
|
if where_clauses:
|
|
262
513
|
query += " WHERE " + " AND ".join(where_clauses)
|
|
263
514
|
|
|
515
|
+
# Append ORDER BY if set (must come before LIMIT/START in SurrealQL)
|
|
516
|
+
if self._order_by:
|
|
517
|
+
query += f" ORDER BY {self._order_by}"
|
|
518
|
+
|
|
264
519
|
# Append LIMIT if set
|
|
265
520
|
if self._limit is not None:
|
|
266
521
|
query += f" LIMIT {self._limit}"
|
|
@@ -269,10 +524,6 @@ class QuerySet:
|
|
|
269
524
|
if self._offset is not None:
|
|
270
525
|
query += f" START {self._offset}"
|
|
271
526
|
|
|
272
|
-
# Append ORDER BY if set
|
|
273
|
-
if self._order_by:
|
|
274
|
-
query += f" ORDER BY {self._order_by}"
|
|
275
|
-
|
|
276
527
|
query += ";"
|
|
277
528
|
return query
|
|
278
529
|
|
|
@@ -284,28 +535,41 @@ class QuerySet:
|
|
|
284
535
|
the results. If the data conforms to the model schema, it returns a list of model instances;
|
|
285
536
|
otherwise, it returns a list of dictionaries.
|
|
286
537
|
|
|
538
|
+
When `annotate()` has been called, this returns the aggregated results as dictionaries
|
|
539
|
+
instead of model instances.
|
|
540
|
+
|
|
287
541
|
Returns:
|
|
288
542
|
list[BaseSurrealModel] | list[dict]: A list of model instances if validation is successful,
|
|
289
|
-
otherwise a list of dictionaries representing the raw data.
|
|
543
|
+
otherwise a list of dictionaries representing the raw data. For annotated queries,
|
|
544
|
+
always returns a list of dictionaries.
|
|
290
545
|
|
|
291
546
|
Raises:
|
|
292
547
|
SurrealDbError: If there is an issue executing the query.
|
|
293
548
|
|
|
294
549
|
Example:
|
|
295
550
|
```python
|
|
551
|
+
# Regular query
|
|
296
552
|
results = await queryset.exec()
|
|
553
|
+
|
|
554
|
+
# Aggregation query
|
|
555
|
+
stats = await Order.objects().values("status").annotate(
|
|
556
|
+
count=Count(),
|
|
557
|
+
total=Sum("amount"),
|
|
558
|
+
).exec()
|
|
297
559
|
```
|
|
298
560
|
"""
|
|
299
|
-
|
|
561
|
+
# If annotations are set, execute as GROUP BY query
|
|
562
|
+
if self._annotations:
|
|
563
|
+
return await self._execute_annotate()
|
|
564
|
+
|
|
300
565
|
query = self._compile_query()
|
|
301
566
|
results = await self._execute_query(query)
|
|
302
567
|
try:
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
return self.model.from_db(data["result"])
|
|
568
|
+
# surrealdb SDK 1.0.8 returns records directly, not wrapped in {"result": ...}
|
|
569
|
+
return self.model.from_db(cast(dict | list | None, results))
|
|
306
570
|
except ValidationError as e:
|
|
307
571
|
logger.info(f"Pydantic invalid format for the class, returning dict value: {e}")
|
|
308
|
-
return
|
|
572
|
+
return results
|
|
309
573
|
|
|
310
574
|
async def first(self) -> Any:
|
|
311
575
|
"""
|
|
@@ -331,7 +595,7 @@ class QuerySet:
|
|
|
331
595
|
if results:
|
|
332
596
|
return results[0]
|
|
333
597
|
|
|
334
|
-
raise
|
|
598
|
+
raise self.model.DoesNotExist("Query returned no results.")
|
|
335
599
|
|
|
336
600
|
async def get(self, id_item: Any = None) -> Any:
|
|
337
601
|
"""
|
|
@@ -357,15 +621,18 @@ class QuerySet:
|
|
|
357
621
|
"""
|
|
358
622
|
if id_item:
|
|
359
623
|
client = await SurrealDBConnectionManager.get_client()
|
|
360
|
-
|
|
361
|
-
|
|
624
|
+
result = await client.select(f"{self._model_table}:{id_item}")
|
|
625
|
+
# SDK returns RecordsResponse
|
|
626
|
+
if result.is_empty:
|
|
627
|
+
raise self.model.DoesNotExist("Record not found.")
|
|
628
|
+
return self.model.from_db(cast(dict | list | None, result.first))
|
|
362
629
|
else:
|
|
363
630
|
result = await self.exec()
|
|
364
631
|
if len(result) > 1:
|
|
365
632
|
raise SurrealDbError("More than one result found.")
|
|
366
633
|
|
|
367
634
|
if len(result) == 0:
|
|
368
|
-
raise
|
|
635
|
+
raise self.model.DoesNotExist("Record not found.")
|
|
369
636
|
return result[0]
|
|
370
637
|
|
|
371
638
|
async def all(self) -> Any:
|
|
@@ -386,10 +653,168 @@ class QuerySet:
|
|
|
386
653
|
```
|
|
387
654
|
"""
|
|
388
655
|
client = await SurrealDBConnectionManager.get_client()
|
|
389
|
-
|
|
390
|
-
return self.model.from_db(
|
|
656
|
+
result = await client.select(self._model_table)
|
|
657
|
+
return self.model.from_db(cast(dict | list | None, result.records))
|
|
658
|
+
|
|
659
|
+
# ==================== Aggregation Methods ====================
|
|
660
|
+
|
|
661
|
+
def _compile_where_clause(self) -> str:
|
|
662
|
+
"""
|
|
663
|
+
Compile the WHERE clause from filters.
|
|
664
|
+
|
|
665
|
+
Returns:
|
|
666
|
+
str: The WHERE clause string (including WHERE keyword) or empty string.
|
|
667
|
+
"""
|
|
668
|
+
if not self._filters:
|
|
669
|
+
return ""
|
|
670
|
+
|
|
671
|
+
where_clauses = []
|
|
672
|
+
for field_name, lookup_name, value in self._filters:
|
|
673
|
+
op = LOOKUP_OPERATORS.get(lookup_name, "=")
|
|
674
|
+
if lookup_name == "in":
|
|
675
|
+
formatted_values = ", ".join(repr(v) for v in value)
|
|
676
|
+
where_clauses.append(f"{field_name} {op} [{formatted_values}]")
|
|
677
|
+
else:
|
|
678
|
+
where_clauses.append(f"{field_name} {op} {repr(value)}")
|
|
679
|
+
|
|
680
|
+
return " WHERE " + " AND ".join(where_clauses)
|
|
681
|
+
|
|
682
|
+
async def count(self) -> int:
|
|
683
|
+
"""
|
|
684
|
+
Count the number of records matching the current filters.
|
|
685
|
+
|
|
686
|
+
Returns:
|
|
687
|
+
int: The number of matching records.
|
|
688
|
+
|
|
689
|
+
Example:
|
|
690
|
+
```python
|
|
691
|
+
total = await User.objects().count()
|
|
692
|
+
active = await User.objects().filter(active=True).count()
|
|
693
|
+
```
|
|
694
|
+
"""
|
|
695
|
+
where_clause = self._compile_where_clause()
|
|
696
|
+
query = f"SELECT count() FROM {self._model_table}{where_clause} GROUP ALL;"
|
|
697
|
+
|
|
698
|
+
client = await SurrealDBConnectionManager.get_client()
|
|
699
|
+
result = await client.query(remove_quotes_for_variables(query), self._variables)
|
|
700
|
+
|
|
701
|
+
if result.all_records:
|
|
702
|
+
record = result.all_records[0]
|
|
703
|
+
if isinstance(record, dict) and "count" in record:
|
|
704
|
+
return int(record["count"])
|
|
705
|
+
return 0
|
|
706
|
+
|
|
707
|
+
async def sum(self, field: str) -> float | int:
|
|
708
|
+
"""
|
|
709
|
+
Calculate the sum of a numeric field.
|
|
710
|
+
|
|
711
|
+
Args:
|
|
712
|
+
field: The field name to sum.
|
|
713
|
+
|
|
714
|
+
Returns:
|
|
715
|
+
float | int: The sum of the field values, or 0 if no records match.
|
|
716
|
+
|
|
717
|
+
Example:
|
|
718
|
+
```python
|
|
719
|
+
total = await Order.objects().filter(status="paid").sum("amount")
|
|
720
|
+
```
|
|
721
|
+
"""
|
|
722
|
+
where_clause = self._compile_where_clause()
|
|
723
|
+
query = f"SELECT math::sum({field}) AS total FROM {self._model_table}{where_clause} GROUP ALL;"
|
|
724
|
+
|
|
725
|
+
client = await SurrealDBConnectionManager.get_client()
|
|
726
|
+
result = await client.query(remove_quotes_for_variables(query), self._variables)
|
|
727
|
+
|
|
728
|
+
if result.all_records:
|
|
729
|
+
record = result.all_records[0]
|
|
730
|
+
if isinstance(record, dict) and "total" in record:
|
|
731
|
+
value = record["total"]
|
|
732
|
+
return value if value is not None else 0
|
|
733
|
+
return 0
|
|
734
|
+
|
|
735
|
+
async def avg(self, field: str) -> float | None:
|
|
736
|
+
"""
|
|
737
|
+
Calculate the average of a numeric field.
|
|
738
|
+
|
|
739
|
+
Args:
|
|
740
|
+
field: The field name to average.
|
|
741
|
+
|
|
742
|
+
Returns:
|
|
743
|
+
float | None: The average value, or None if no records match.
|
|
744
|
+
|
|
745
|
+
Example:
|
|
746
|
+
```python
|
|
747
|
+
avg_age = await User.objects().filter(active=True).avg("age")
|
|
748
|
+
```
|
|
749
|
+
"""
|
|
750
|
+
where_clause = self._compile_where_clause()
|
|
751
|
+
query = f"SELECT math::mean({field}) AS average FROM {self._model_table}{where_clause} GROUP ALL;"
|
|
752
|
+
|
|
753
|
+
client = await SurrealDBConnectionManager.get_client()
|
|
754
|
+
result = await client.query(remove_quotes_for_variables(query), self._variables)
|
|
755
|
+
|
|
756
|
+
if result.all_records:
|
|
757
|
+
record = result.all_records[0]
|
|
758
|
+
if isinstance(record, dict) and "average" in record:
|
|
759
|
+
value = record["average"]
|
|
760
|
+
return float(value) if value is not None else None
|
|
761
|
+
return None
|
|
762
|
+
|
|
763
|
+
async def min(self, field: str) -> Any:
|
|
764
|
+
"""
|
|
765
|
+
Get the minimum value of a field.
|
|
766
|
+
|
|
767
|
+
Args:
|
|
768
|
+
field: The field name to find the minimum of.
|
|
769
|
+
|
|
770
|
+
Returns:
|
|
771
|
+
Any: The minimum value, or None if no records match.
|
|
772
|
+
|
|
773
|
+
Example:
|
|
774
|
+
```python
|
|
775
|
+
min_price = await Product.objects().min("price")
|
|
776
|
+
```
|
|
777
|
+
"""
|
|
778
|
+
where_clause = self._compile_where_clause()
|
|
779
|
+
query = f"SELECT math::min({field}) AS minimum FROM {self._model_table}{where_clause} GROUP ALL;"
|
|
780
|
+
|
|
781
|
+
client = await SurrealDBConnectionManager.get_client()
|
|
782
|
+
result = await client.query(remove_quotes_for_variables(query), self._variables)
|
|
783
|
+
|
|
784
|
+
if result.all_records:
|
|
785
|
+
record = result.all_records[0]
|
|
786
|
+
if isinstance(record, dict) and "minimum" in record:
|
|
787
|
+
return record["minimum"]
|
|
788
|
+
return None
|
|
789
|
+
|
|
790
|
+
async def max(self, field: str) -> Any:
|
|
791
|
+
"""
|
|
792
|
+
Get the maximum value of a field.
|
|
793
|
+
|
|
794
|
+
Args:
|
|
795
|
+
field: The field name to find the maximum of.
|
|
391
796
|
|
|
392
|
-
|
|
797
|
+
Returns:
|
|
798
|
+
Any: The maximum value, or None if no records match.
|
|
799
|
+
|
|
800
|
+
Example:
|
|
801
|
+
```python
|
|
802
|
+
max_price = await Product.objects().max("price")
|
|
803
|
+
```
|
|
804
|
+
"""
|
|
805
|
+
where_clause = self._compile_where_clause()
|
|
806
|
+
query = f"SELECT math::max({field}) AS maximum FROM {self._model_table}{where_clause} GROUP ALL;"
|
|
807
|
+
|
|
808
|
+
client = await SurrealDBConnectionManager.get_client()
|
|
809
|
+
result = await client.query(remove_quotes_for_variables(query), self._variables)
|
|
810
|
+
|
|
811
|
+
if result.all_records:
|
|
812
|
+
record = result.all_records[0]
|
|
813
|
+
if isinstance(record, dict) and "maximum" in record:
|
|
814
|
+
return record["maximum"]
|
|
815
|
+
return None
|
|
816
|
+
|
|
817
|
+
async def _execute_query(self, query: str) -> list[Any]:
|
|
393
818
|
"""
|
|
394
819
|
Execute the given SQL query using the SurrealDB client.
|
|
395
820
|
|
|
@@ -400,7 +825,7 @@ class QuerySet:
|
|
|
400
825
|
query (str): The SQL query string to execute.
|
|
401
826
|
|
|
402
827
|
Returns:
|
|
403
|
-
list[
|
|
828
|
+
list[Any]: A list of query response objects containing the query results.
|
|
404
829
|
|
|
405
830
|
Raises:
|
|
406
831
|
SurrealDbError: If there is an issue executing the query.
|
|
@@ -413,7 +838,7 @@ class QuerySet:
|
|
|
413
838
|
client = await SurrealDBConnectionManager.get_client()
|
|
414
839
|
return await self._run_query_on_client(client, query)
|
|
415
840
|
|
|
416
|
-
async def _run_query_on_client(self, client:
|
|
841
|
+
async def _run_query_on_client(self, client: Any, query: str) -> list[Any]:
|
|
417
842
|
"""
|
|
418
843
|
Run the SQL query on the provided SurrealDB client.
|
|
419
844
|
|
|
@@ -421,11 +846,11 @@ class QuerySet:
|
|
|
421
846
|
and returns the raw query responses.
|
|
422
847
|
|
|
423
848
|
Args:
|
|
424
|
-
client
|
|
849
|
+
client: The active SurrealDB client instance.
|
|
425
850
|
query (str): The SQL query string to execute.
|
|
426
851
|
|
|
427
852
|
Returns:
|
|
428
|
-
list[
|
|
853
|
+
list[Any]: A list of query response objects containing the query results.
|
|
429
854
|
|
|
430
855
|
Raises:
|
|
431
856
|
SurrealDbError: If there is an issue executing the query.
|
|
@@ -435,7 +860,9 @@ class QuerySet:
|
|
|
435
860
|
results = await self._run_query_on_client(client, "SELECT * FROM users;")
|
|
436
861
|
```
|
|
437
862
|
"""
|
|
438
|
-
|
|
863
|
+
result = await client.query(remove_quotes_for_variables(query), self._variables)
|
|
864
|
+
# SDK returns QueryResponse, extract all records
|
|
865
|
+
return cast(list[Any], result.all_records)
|
|
439
866
|
|
|
440
867
|
async def delete_table(self) -> bool:
|
|
441
868
|
"""
|
|
@@ -456,7 +883,7 @@ class QuerySet:
|
|
|
456
883
|
```
|
|
457
884
|
"""
|
|
458
885
|
client = await SurrealDBConnectionManager.get_client()
|
|
459
|
-
await client.delete(
|
|
886
|
+
await client.delete(self._model_table)
|
|
460
887
|
return True
|
|
461
888
|
|
|
462
889
|
async def query(self, query: str, variables: dict[str, Any] = {}) -> Any:
|
|
@@ -484,9 +911,157 @@ class QuerySet:
|
|
|
484
911
|
results = await queryset.query(custom_query, variables={'status': 'active'})
|
|
485
912
|
```
|
|
486
913
|
"""
|
|
487
|
-
if f"FROM {self.
|
|
488
|
-
raise SurrealDbError(f"The query must include 'FROM {self.
|
|
914
|
+
if f"FROM {self._model_table}" not in query:
|
|
915
|
+
raise SurrealDbError(f"The query must include 'FROM {self._model_table}' to reference the correct table.")
|
|
916
|
+
client = await SurrealDBConnectionManager.get_client()
|
|
917
|
+
result = await client.query(remove_quotes_for_variables(query), variables)
|
|
918
|
+
# SDK returns QueryResponse, extract all records
|
|
919
|
+
return self.model.from_db(cast(dict | list | None, result.all_records))
|
|
920
|
+
|
|
921
|
+
# ==================== Bulk Operations ====================
|
|
922
|
+
|
|
923
|
+
async def bulk_create(
|
|
924
|
+
self,
|
|
925
|
+
instances: Sequence[BaseSurrealModel],
|
|
926
|
+
atomic: bool = False,
|
|
927
|
+
batch_size: int | None = None,
|
|
928
|
+
) -> list[BaseSurrealModel]:
|
|
929
|
+
"""
|
|
930
|
+
Create multiple model instances in the database efficiently.
|
|
931
|
+
|
|
932
|
+
Args:
|
|
933
|
+
instances: A sequence of model instances to create.
|
|
934
|
+
atomic: If True, all creates are wrapped in a transaction.
|
|
935
|
+
If any fails, all are rolled back.
|
|
936
|
+
batch_size: If specified, instances are created in batches of this size.
|
|
937
|
+
Useful for very large datasets to avoid memory issues.
|
|
938
|
+
|
|
939
|
+
Returns:
|
|
940
|
+
list[BaseSurrealModel]: The created instances.
|
|
941
|
+
|
|
942
|
+
Example:
|
|
943
|
+
```python
|
|
944
|
+
users = [User(name=f"User{i}") for i in range(1000)]
|
|
945
|
+
|
|
946
|
+
# Simple bulk create
|
|
947
|
+
created = await User.objects().bulk_create(users)
|
|
948
|
+
|
|
949
|
+
# Atomic bulk create
|
|
950
|
+
created = await User.objects().bulk_create(users, atomic=True)
|
|
951
|
+
|
|
952
|
+
# With batch size
|
|
953
|
+
created = await User.objects().bulk_create(users, batch_size=100)
|
|
954
|
+
```
|
|
955
|
+
"""
|
|
956
|
+
if not instances:
|
|
957
|
+
return []
|
|
958
|
+
|
|
959
|
+
created: list[BaseSurrealModel] = []
|
|
960
|
+
|
|
961
|
+
if atomic:
|
|
962
|
+
# Use transaction for atomicity
|
|
963
|
+
async with await SurrealDBConnectionManager.transaction() as tx:
|
|
964
|
+
for instance in instances:
|
|
965
|
+
await instance.save(tx=tx)
|
|
966
|
+
created.append(instance)
|
|
967
|
+
elif batch_size:
|
|
968
|
+
# Process in batches
|
|
969
|
+
for i in range(0, len(instances), batch_size):
|
|
970
|
+
batch = instances[i : i + batch_size]
|
|
971
|
+
for instance in batch:
|
|
972
|
+
await instance.save()
|
|
973
|
+
created.append(instance)
|
|
974
|
+
else:
|
|
975
|
+
# Simple sequential create
|
|
976
|
+
for instance in instances:
|
|
977
|
+
await instance.save()
|
|
978
|
+
created.append(instance)
|
|
979
|
+
|
|
980
|
+
return created
|
|
981
|
+
|
|
982
|
+
async def bulk_update(
|
|
983
|
+
self,
|
|
984
|
+
data: dict[str, Any],
|
|
985
|
+
atomic: bool = False,
|
|
986
|
+
) -> int:
|
|
987
|
+
"""
|
|
988
|
+
Update all records matching the current filters.
|
|
989
|
+
|
|
990
|
+
Args:
|
|
991
|
+
data: A dictionary of field names and values to update.
|
|
992
|
+
atomic: If True, all updates are wrapped in a transaction.
|
|
993
|
+
|
|
994
|
+
Returns:
|
|
995
|
+
int: The number of records updated.
|
|
996
|
+
|
|
997
|
+
Example:
|
|
998
|
+
```python
|
|
999
|
+
# Update all matching records
|
|
1000
|
+
updated = await User.objects().filter(
|
|
1001
|
+
last_login__lt="2025-01-01"
|
|
1002
|
+
).bulk_update({"status": "inactive"})
|
|
1003
|
+
|
|
1004
|
+
# Atomic update
|
|
1005
|
+
updated = await User.objects().filter(role="guest").bulk_update(
|
|
1006
|
+
{"verified": True},
|
|
1007
|
+
atomic=True
|
|
1008
|
+
)
|
|
1009
|
+
```
|
|
1010
|
+
"""
|
|
1011
|
+
where_clause = self._compile_where_clause()
|
|
1012
|
+
|
|
1013
|
+
# Build SET clause
|
|
1014
|
+
set_parts = []
|
|
1015
|
+
for field, value in data.items():
|
|
1016
|
+
set_parts.append(f"{field} = {repr(value)}")
|
|
1017
|
+
set_clause = ", ".join(set_parts)
|
|
1018
|
+
|
|
1019
|
+
query = f"UPDATE {self._model_table} SET {set_clause}{where_clause};"
|
|
1020
|
+
|
|
1021
|
+
if atomic:
|
|
1022
|
+
# For atomic operations, count first then update in transaction
|
|
1023
|
+
current_count = await self.count()
|
|
1024
|
+
async with await SurrealDBConnectionManager.transaction() as tx:
|
|
1025
|
+
await tx.query(remove_quotes_for_variables(query), self._variables)
|
|
1026
|
+
return current_count
|
|
1027
|
+
|
|
1028
|
+
client = await SurrealDBConnectionManager.get_client()
|
|
1029
|
+
result = await client.query(remove_quotes_for_variables(query), self._variables)
|
|
1030
|
+
return len(result.all_records)
|
|
1031
|
+
|
|
1032
|
+
async def bulk_delete(self, atomic: bool = False) -> int:
|
|
1033
|
+
"""
|
|
1034
|
+
Delete all records matching the current filters.
|
|
1035
|
+
|
|
1036
|
+
Args:
|
|
1037
|
+
atomic: If True, all deletes are wrapped in a transaction.
|
|
1038
|
+
|
|
1039
|
+
Returns:
|
|
1040
|
+
int: The number of records deleted.
|
|
1041
|
+
|
|
1042
|
+
Example:
|
|
1043
|
+
```python
|
|
1044
|
+
# Delete all matching records
|
|
1045
|
+
deleted = await User.objects().filter(status="deleted").bulk_delete()
|
|
1046
|
+
|
|
1047
|
+
# Atomic delete
|
|
1048
|
+
deleted = await Order.objects().filter(
|
|
1049
|
+
created_at__lt="2024-01-01"
|
|
1050
|
+
).bulk_delete(atomic=True)
|
|
1051
|
+
```
|
|
1052
|
+
"""
|
|
1053
|
+
where_clause = self._compile_where_clause()
|
|
1054
|
+
# Use RETURN BEFORE to get deleted records count
|
|
1055
|
+
query = f"DELETE FROM {self._model_table}{where_clause} RETURN BEFORE;"
|
|
1056
|
+
|
|
1057
|
+
if atomic:
|
|
1058
|
+
# For atomic operations, count first then delete in transaction
|
|
1059
|
+
current_count = await self.count()
|
|
1060
|
+
delete_query = f"DELETE FROM {self._model_table}{where_clause};"
|
|
1061
|
+
async with await SurrealDBConnectionManager.transaction() as tx:
|
|
1062
|
+
await tx.query(remove_quotes_for_variables(delete_query), self._variables)
|
|
1063
|
+
return current_count
|
|
1064
|
+
|
|
489
1065
|
client = await SurrealDBConnectionManager.get_client()
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
return self.model.from_db(data["result"])
|
|
1066
|
+
result = await client.query(remove_quotes_for_variables(query), self._variables)
|
|
1067
|
+
return len(result.all_records)
|