surrealdb-orm 0.1.4__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- surreal_orm/__init__.py +72 -3
- surreal_orm/aggregations.py +164 -0
- surreal_orm/auth/__init__.py +15 -0
- surreal_orm/auth/access.py +167 -0
- surreal_orm/auth/mixins.py +302 -0
- surreal_orm/cli/__init__.py +15 -0
- surreal_orm/cli/commands.py +369 -0
- surreal_orm/connection_manager.py +58 -18
- surreal_orm/fields/__init__.py +36 -0
- surreal_orm/fields/encrypted.py +166 -0
- surreal_orm/fields/relation.py +465 -0
- surreal_orm/migrations/__init__.py +51 -0
- surreal_orm/migrations/executor.py +380 -0
- surreal_orm/migrations/generator.py +272 -0
- surreal_orm/migrations/introspector.py +305 -0
- surreal_orm/migrations/migration.py +188 -0
- surreal_orm/migrations/operations.py +531 -0
- surreal_orm/migrations/state.py +406 -0
- surreal_orm/model_base.py +530 -44
- surreal_orm/query_set.py +609 -33
- surreal_orm/relations.py +645 -0
- surreal_orm/surreal_function.py +95 -0
- surreal_orm/surreal_ql.py +113 -0
- surreal_orm/types.py +86 -0
- surreal_sdk/README.md +79 -0
- surreal_sdk/__init__.py +151 -0
- surreal_sdk/connection/__init__.py +17 -0
- surreal_sdk/connection/base.py +516 -0
- surreal_sdk/connection/http.py +421 -0
- surreal_sdk/connection/pool.py +244 -0
- surreal_sdk/connection/websocket.py +519 -0
- surreal_sdk/exceptions.py +71 -0
- surreal_sdk/functions.py +607 -0
- surreal_sdk/protocol/__init__.py +13 -0
- surreal_sdk/protocol/rpc.py +218 -0
- surreal_sdk/py.typed +0 -0
- surreal_sdk/pyproject.toml +49 -0
- surreal_sdk/streaming/__init__.py +31 -0
- surreal_sdk/streaming/change_feed.py +278 -0
- surreal_sdk/streaming/live_query.py +265 -0
- surreal_sdk/streaming/live_select.py +369 -0
- surreal_sdk/transaction.py +386 -0
- surreal_sdk/types.py +346 -0
- surrealdb_orm-0.5.0.dist-info/METADATA +465 -0
- surrealdb_orm-0.5.0.dist-info/RECORD +52 -0
- {surrealdb_orm-0.1.4.dist-info → surrealdb_orm-0.5.0.dist-info}/WHEEL +1 -1
- surrealdb_orm-0.5.0.dist-info/entry_points.txt +2 -0
- {surrealdb_orm-0.1.4.dist-info → surrealdb_orm-0.5.0.dist-info}/licenses/LICENSE +1 -1
- surrealdb_orm-0.1.4.dist-info/METADATA +0 -184
- surrealdb_orm-0.1.4.dist-info/RECORD +0 -12
surreal_orm/query_set.py
CHANGED
|
@@ -1,14 +1,20 @@
|
|
|
1
1
|
from .constants import LOOKUP_OPERATORS
|
|
2
2
|
from .enum import OrderBy
|
|
3
3
|
from .utils import remove_quotes_for_variables
|
|
4
|
-
from surrealdb import QueryResponse, Table, AsyncSurrealDB
|
|
5
|
-
from surrealdb.errors import SurrealDbError
|
|
6
4
|
from . import BaseSurrealModel, SurrealDBConnectionManager
|
|
7
|
-
from
|
|
5
|
+
from .aggregations import Aggregation
|
|
6
|
+
from typing import Self, Any, Sequence, cast
|
|
8
7
|
from pydantic_core import ValidationError
|
|
9
8
|
|
|
10
9
|
import logging
|
|
11
10
|
|
|
11
|
+
|
|
12
|
+
class SurrealDbError(Exception):
    """
    Error raised by QuerySet operations when a query cannot be run as requested.

    Within this module it is raised, for example, when ``get()`` matches more
    than one record, or when ``query()`` receives SQL that does not reference
    the model's table.
    """
|
|
16
|
+
|
|
17
|
+
|
|
12
18
|
logger = logging.getLogger(__name__)
|
|
13
19
|
|
|
14
20
|
|
|
@@ -55,8 +61,14 @@ class QuerySet:
|
|
|
55
61
|
self._limit: int | None = None
|
|
56
62
|
self._offset: int | None = None
|
|
57
63
|
self._order_by: str | None = None
|
|
58
|
-
self._model_table: str =
|
|
64
|
+
self._model_table: str = model.__name__
|
|
59
65
|
self._variables: dict = {}
|
|
66
|
+
self._group_by_fields: list[str] = []
|
|
67
|
+
self._annotations: dict[str, Aggregation] = {}
|
|
68
|
+
# Relation query options
|
|
69
|
+
self._select_related: list[str] = []
|
|
70
|
+
self._prefetch_related: list[str] = []
|
|
71
|
+
self._traversal_path: str | None = None
|
|
60
72
|
|
|
61
73
|
def select(self, *fields: str) -> Self:
|
|
62
74
|
"""
|
|
@@ -201,7 +213,7 @@ class QuerySet:
|
|
|
201
213
|
self._offset = value
|
|
202
214
|
return self
|
|
203
215
|
|
|
204
|
-
def order_by(self, field_name: str,
|
|
216
|
+
def order_by(self, field_name: str, order_type: OrderBy = OrderBy.ASC) -> Self:
|
|
205
217
|
"""
|
|
206
218
|
Set the field and direction to order the results by.
|
|
207
219
|
|
|
@@ -220,9 +232,248 @@ class QuerySet:
|
|
|
220
232
|
queryset.order_by('name', OrderBy.DESC)
|
|
221
233
|
```
|
|
222
234
|
"""
|
|
223
|
-
self._order_by = f"{field_name} {
|
|
235
|
+
self._order_by = f"{field_name} {order_type}"
|
|
236
|
+
return self
|
|
237
|
+
|
|
238
|
+
def values(self, *fields: str) -> Self:
|
|
239
|
+
"""
|
|
240
|
+
Specify the fields to group by for aggregation queries.
|
|
241
|
+
|
|
242
|
+
This method is used in conjunction with `annotate()` to perform GROUP BY operations.
|
|
243
|
+
The specified fields become the grouping keys, and aggregation functions are applied
|
|
244
|
+
to each group.
|
|
245
|
+
|
|
246
|
+
Args:
|
|
247
|
+
*fields (str): Variable length argument list of field names to group by.
|
|
248
|
+
|
|
249
|
+
Returns:
|
|
250
|
+
Self: The current instance of QuerySet to allow method chaining.
|
|
251
|
+
|
|
252
|
+
Example:
|
|
253
|
+
```python
|
|
254
|
+
# Group orders by status and calculate statistics
|
|
255
|
+
stats = await Order.objects().values("status").annotate(
|
|
256
|
+
count=Count(),
|
|
257
|
+
total=Sum("amount"),
|
|
258
|
+
)
|
|
259
|
+
# Result: [{"status": "paid", "count": 42, "total": 5000}, ...]
|
|
260
|
+
```
|
|
261
|
+
"""
|
|
262
|
+
self._group_by_fields = list(fields)
|
|
263
|
+
return self
|
|
264
|
+
|
|
265
|
+
def annotate(self, **aggregations: Aggregation) -> Self:
    """
    Attach named aggregations to compute for each group.

    Each keyword becomes an output column. The key is the alias and the value an
    ``Aggregation`` instance (Count, Sum, Avg, Min, Max), rendered into SurrealQL
    via its ``to_surql(alias)`` method. Once any annotation is set, ``exec()``
    runs a grouped query and returns plain dictionaries instead of model
    instances. A later call replaces the stored annotations rather than adding
    to them.

    Args:
        **aggregations (Aggregation): Alias names mapped to aggregation instances.

    Returns:
        Self: This QuerySet, so calls can be chained.

    Example:
        ```python
        from surreal_orm.aggregations import Count, Sum, Avg

        # Calculate statistics per status
        stats = await Order.objects().values("status").annotate(
            count=Count(),
            total=Sum("amount"),
            avg_amount=Avg("amount"),
        ).exec()
        ```
    """
    self._annotations = dict(aggregations)
    return self
|
|
293
|
+
|
|
294
|
+
# ==================== Relation Query Methods ====================
|
|
295
|
+
|
|
296
|
+
def select_related(self, *relations: str) -> Self:
|
|
297
|
+
"""
|
|
298
|
+
Eagerly load related objects in the same query.
|
|
299
|
+
|
|
300
|
+
This method optimizes queries by loading related objects alongside
|
|
301
|
+
the main query results, avoiding N+1 query problems for forward relations.
|
|
302
|
+
|
|
303
|
+
Note: SurrealDB handles this through graph traversal syntax.
|
|
304
|
+
The actual loading happens during query execution.
|
|
305
|
+
|
|
306
|
+
Args:
|
|
307
|
+
*relations: Names of relations to load eagerly.
|
|
308
|
+
|
|
309
|
+
Returns:
|
|
310
|
+
Self: The current instance of QuerySet to allow method chaining.
|
|
311
|
+
|
|
312
|
+
Example:
|
|
313
|
+
```python
|
|
314
|
+
# Load posts with their authors in one query
|
|
315
|
+
posts = await Post.objects().select_related("author").all()
|
|
316
|
+
for post in posts:
|
|
317
|
+
print(post.author.name) # No additional query
|
|
318
|
+
```
|
|
319
|
+
"""
|
|
320
|
+
self._select_related = list(relations)
|
|
321
|
+
return self
|
|
322
|
+
|
|
323
|
+
def prefetch_related(self, *relations: str) -> Self:
|
|
324
|
+
"""
|
|
325
|
+
Prefetch related objects using separate optimized queries.
|
|
326
|
+
|
|
327
|
+
This method reduces N+1 query problems by fetching related objects
|
|
328
|
+
in batches after the main query completes. This is more efficient
|
|
329
|
+
than select_related for many-to-many and reverse relations.
|
|
330
|
+
|
|
331
|
+
Args:
|
|
332
|
+
*relations: Names of relations to prefetch.
|
|
333
|
+
|
|
334
|
+
Returns:
|
|
335
|
+
Self: The current instance of QuerySet to allow method chaining.
|
|
336
|
+
|
|
337
|
+
Example:
|
|
338
|
+
```python
|
|
339
|
+
# Load users and prefetch their followers
|
|
340
|
+
users = await User.objects().prefetch_related("followers", "posts").all()
|
|
341
|
+
for user in users:
|
|
342
|
+
print(user.followers) # Already loaded
|
|
343
|
+
print(user.posts) # Already loaded
|
|
344
|
+
```
|
|
345
|
+
"""
|
|
346
|
+
self._prefetch_related = list(relations)
|
|
347
|
+
return self
|
|
348
|
+
|
|
349
|
+
def traverse(self, path: str) -> Self:
|
|
350
|
+
"""
|
|
351
|
+
Add a graph traversal path to the query.
|
|
352
|
+
|
|
353
|
+
This method allows querying across graph relations using
|
|
354
|
+
SurrealDB's traversal syntax.
|
|
355
|
+
|
|
356
|
+
Args:
|
|
357
|
+
path: Graph traversal path (e.g., "->follows->users->likes->posts")
|
|
358
|
+
|
|
359
|
+
Returns:
|
|
360
|
+
Self: The current instance of QuerySet to allow method chaining.
|
|
361
|
+
|
|
362
|
+
Example:
|
|
363
|
+
```python
|
|
364
|
+
# Get all posts liked by users that alice follows
|
|
365
|
+
posts = await User.objects().filter(id="alice").traverse(
|
|
366
|
+
"->follows->users->likes->posts"
|
|
367
|
+
).all()
|
|
368
|
+
```
|
|
369
|
+
"""
|
|
370
|
+
self._traversal_path = path
|
|
224
371
|
return self
|
|
225
372
|
|
|
373
|
+
async def graph_query(self, traversal: str, **variables: Any) -> list[dict[str, Any]]:
    """
    Execute a single-hop graph traversal starting from the filtered record.

    The starting record comes from the first exact ``filter(id=...)`` on this
    QuerySet and is addressed as ``<model_table>:<id>``. The traversal must begin
    with an arrow and name an edge table:

    * ``"->edge->Target"``: selects rows of ``edge`` whose ``in`` is the start
      record and returns the fetched ``out`` records.
    * ``"<-edge<-Source"``: selects rows of ``edge`` whose ``out`` is the start
      record and returns the fetched ``in`` records.

    Only the edge name is used to build the query; the trailing table name is
    not checked. Multi-hop paths (e.g. ``"->follows->User->follows->User"``) are
    not supported: only the first hop is evaluated, and a warning is logged.

    The edge name and record id are interpolated into the query text verbatim,
    so they must come from trusted code, not from end-user input.

    Args:
        traversal: Graph traversal expression, e.g. ``"->follows->User"``.
        **variables: Extra query variables. They are merged over the QuerySet's
            own variables and sent with the query.

    Returns:
        list[dict[str, Any]]: The related records as dictionaries. The list is
        empty when no exact ``id`` filter is set, when the pattern is not
        recognised, or when nothing is related.

    Example:
        ```python
        # Users that alice follows
        result = await User.objects().filter(id="alice").graph_query("->follows->User")
        ```
    """
    # The first exact ``id`` filter selects the starting record. The check uses
    # ``is None`` rather than truthiness so that falsy ids such as ``0`` still work.
    source_id = next(
        (
            value
            for field_name, lookup_name, value in self._filters
            if field_name == "id" and lookup_name == "exact"
        ),
        None,
    )
    if source_id is None or source_id == "":
        return []

    traversal = traversal.strip()
    if traversal.startswith("->"):
        # Outgoing edge: the start record is the edge's ``in``; return ``out``.
        arrow, anchor, target = "->", "in", "out"
    elif traversal.startswith("<-"):
        # Incoming edge: the start record is the edge's ``out``; return ``in``.
        arrow, anchor, target = "<-", "out", "in"
    else:
        return []

    # Expected shape: ["", edge, table, ...]. Fewer than three pieces, or an empty
    # edge name, cannot form a valid query.
    parts = traversal.split(arrow)
    if len(parts) < 3 or not parts[1].strip():
        return []
    if len(parts) > 3:
        logger.warning(
            "graph_query() supports single-hop traversals only; evaluating the first hop of %r",
            traversal,
        )
    edge = parts[1].strip()

    source_thing = f"{self._model_table}:{source_id}"
    query = f"SELECT {target} FROM {edge} WHERE {anchor} = {source_thing} FETCH {target};"

    # Only connect once a query is actually going to be sent.
    client = await SurrealDBConnectionManager.get_client()
    result = await client.query(query, {**self._variables, **variables})
    # FETCH expands the linked record; keep only rows where it resolved to a dict.
    return [
        row[target]
        for row in result.all_records or []
        if isinstance(row.get(target), dict)
    ]
|
|
445
|
+
|
|
446
|
+
async def _execute_annotate(self) -> list[dict[str, Any]]:
    """
    Execute the grouped query built from ``values()`` and ``annotate()``.

    The SELECT list is the grouping fields followed by one expression per
    annotation, rendered by ``Aggregation.to_surql(alias)``. Filters are applied
    through ``_compile_where_clause()``. Without grouping fields, ``GROUP ALL``
    collapses the whole filtered table into one row.

    Ordering, limit and offset set on the QuerySet are appended after the
    grouping clause, in the order SurrealQL requires (ORDER BY, LIMIT, START).
    Previously these were silently ignored for aggregation queries.

    Returns:
        list[dict[str, Any]]: One dictionary per group, or an empty list when
        the query yields no records.
    """
    select_parts: list[str] = list(self._group_by_fields)
    select_parts.extend(
        aggregation.to_surql(alias) for alias, aggregation in self._annotations.items()
    )

    where_clause = self._compile_where_clause()
    query = f"SELECT {', '.join(select_parts)} FROM {self._model_table}{where_clause}"

    if self._group_by_fields:
        query += f" GROUP BY {', '.join(self._group_by_fields)}"
    else:
        query += " GROUP ALL"

    # Same clause order as _compile_query: ORDER BY before LIMIT/START.
    if self._order_by:
        query += f" ORDER BY {self._order_by}"
    if self._limit is not None:
        query += f" LIMIT {self._limit}"
    if self._offset is not None:
        query += f" START {self._offset}"
    query += ";"

    client = await SurrealDBConnectionManager.get_client()
    result = await client.query(remove_quotes_for_variables(query), self._variables)

    # Guard against a missing record list so callers always get a list.
    return cast(list[dict[str, Any]], list(result.all_records or []))
|
|
476
|
+
|
|
226
477
|
def _compile_query(self) -> str:
|
|
227
478
|
"""
|
|
228
479
|
Compile the QuerySet parameters into a SQL query string.
|
|
@@ -261,6 +512,10 @@ class QuerySet:
|
|
|
261
512
|
if where_clauses:
|
|
262
513
|
query += " WHERE " + " AND ".join(where_clauses)
|
|
263
514
|
|
|
515
|
+
# Append ORDER BY if set (must come before LIMIT/START in SurrealQL)
|
|
516
|
+
if self._order_by:
|
|
517
|
+
query += f" ORDER BY {self._order_by}"
|
|
518
|
+
|
|
264
519
|
# Append LIMIT if set
|
|
265
520
|
if self._limit is not None:
|
|
266
521
|
query += f" LIMIT {self._limit}"
|
|
@@ -269,10 +524,6 @@ class QuerySet:
|
|
|
269
524
|
if self._offset is not None:
|
|
270
525
|
query += f" START {self._offset}"
|
|
271
526
|
|
|
272
|
-
# Append ORDER BY if set
|
|
273
|
-
if self._order_by:
|
|
274
|
-
query += f" ORDER BY {self._order_by}"
|
|
275
|
-
|
|
276
527
|
query += ";"
|
|
277
528
|
return query
|
|
278
529
|
|
|
@@ -284,27 +535,41 @@ class QuerySet:
|
|
|
284
535
|
the results. If the data conforms to the model schema, it returns a list of model instances;
|
|
285
536
|
otherwise, it returns a list of dictionaries.
|
|
286
537
|
|
|
538
|
+
When `annotate()` has been called, this returns the aggregated results as dictionaries
|
|
539
|
+
instead of model instances.
|
|
540
|
+
|
|
287
541
|
Returns:
|
|
288
542
|
list[BaseSurrealModel] | list[dict]: A list of model instances if validation is successful,
|
|
289
|
-
otherwise a list of dictionaries representing the raw data.
|
|
543
|
+
otherwise a list of dictionaries representing the raw data. For annotated queries,
|
|
544
|
+
always returns a list of dictionaries.
|
|
290
545
|
|
|
291
546
|
Raises:
|
|
292
547
|
SurrealDbError: If there is an issue executing the query.
|
|
293
548
|
|
|
294
549
|
Example:
|
|
295
550
|
```python
|
|
551
|
+
# Regular query
|
|
296
552
|
results = await queryset.exec()
|
|
553
|
+
|
|
554
|
+
# Aggregation query
|
|
555
|
+
stats = await Order.objects().values("status").annotate(
|
|
556
|
+
count=Count(),
|
|
557
|
+
total=Sum("amount"),
|
|
558
|
+
).exec()
|
|
297
559
|
```
|
|
298
560
|
"""
|
|
299
|
-
|
|
561
|
+
# If annotations are set, execute as GROUP BY query
|
|
562
|
+
if self._annotations:
|
|
563
|
+
return await self._execute_annotate()
|
|
564
|
+
|
|
300
565
|
query = self._compile_query()
|
|
301
566
|
results = await self._execute_query(query)
|
|
302
567
|
try:
|
|
303
|
-
|
|
304
|
-
return self.model.from_db(
|
|
568
|
+
# surrealdb SDK 1.0.8 returns records directly, not wrapped in {"result": ...}
|
|
569
|
+
return self.model.from_db(cast(dict | list | None, results))
|
|
305
570
|
except ValidationError as e:
|
|
306
571
|
logger.info(f"Pydantic invalid format for the class, returning dict value: {e}")
|
|
307
|
-
return
|
|
572
|
+
return results
|
|
308
573
|
|
|
309
574
|
async def first(self) -> Any:
|
|
310
575
|
"""
|
|
@@ -330,7 +595,7 @@ class QuerySet:
|
|
|
330
595
|
if results:
|
|
331
596
|
return results[0]
|
|
332
597
|
|
|
333
|
-
raise
|
|
598
|
+
raise self.model.DoesNotExist("Query returned no results.")
|
|
334
599
|
|
|
335
600
|
async def get(self, id_item: Any = None) -> Any:
|
|
336
601
|
"""
|
|
@@ -356,15 +621,18 @@ class QuerySet:
|
|
|
356
621
|
"""
|
|
357
622
|
if id_item:
|
|
358
623
|
client = await SurrealDBConnectionManager.get_client()
|
|
359
|
-
|
|
360
|
-
|
|
624
|
+
result = await client.select(f"{self._model_table}:{id_item}")
|
|
625
|
+
# SDK returns RecordsResponse
|
|
626
|
+
if result.is_empty:
|
|
627
|
+
raise self.model.DoesNotExist("Record not found.")
|
|
628
|
+
return self.model.from_db(cast(dict | list | None, result.first))
|
|
361
629
|
else:
|
|
362
630
|
result = await self.exec()
|
|
363
631
|
if len(result) > 1:
|
|
364
632
|
raise SurrealDbError("More than one result found.")
|
|
365
633
|
|
|
366
634
|
if len(result) == 0:
|
|
367
|
-
raise
|
|
635
|
+
raise self.model.DoesNotExist("Record not found.")
|
|
368
636
|
return result[0]
|
|
369
637
|
|
|
370
638
|
async def all(self) -> Any:
|
|
@@ -385,10 +653,168 @@ class QuerySet:
|
|
|
385
653
|
```
|
|
386
654
|
"""
|
|
387
655
|
client = await SurrealDBConnectionManager.get_client()
|
|
388
|
-
|
|
389
|
-
return self.model.from_db(
|
|
656
|
+
result = await client.select(self._model_table)
|
|
657
|
+
return self.model.from_db(cast(dict | list | None, result.records))
|
|
658
|
+
|
|
659
|
+
# ==================== Aggregation Methods ====================
|
|
390
660
|
|
|
391
|
-
|
|
661
|
+
def _compile_where_clause(self) -> str:
    """
    Build the WHERE fragment for the stored filters.

    Each ``(field, lookup, value)`` triple in ``self._filters`` becomes
    ``field <op> <literal>``:

    * ``<op>`` is looked up in ``LOOKUP_OPERATORS``, with ``=`` for unknown lookups.
    * ``<literal>`` is ``repr(value)``. For the ``in`` lookup, the value is
      iterated and rendered as a ``[a, b, ...]`` array.

    Conditions are joined with ``AND``.

    NOTE(review): values are inlined via Python ``repr()`` rather than bound as
    query parameters, and field names are inserted verbatim. Callers pass the
    result through ``remove_quotes_for_variables()``, presumably so that quoted
    ``'$name'`` strings become parameter references. Values whose repr is not
    valid SurrealQL (e.g. ``datetime``, ``Decimal``) will produce a broken query.
    Confirm before filtering on such types.

    Returns:
        str: ``" WHERE ..."`` (with a leading space, ready to append after the
        table name), or ``""`` when no filters are set.
    """
    if not self._filters:
        return ""

    conditions: list[str] = []
    for field_name, lookup_name, value in self._filters:
        operator = LOOKUP_OPERATORS.get(lookup_name, "=")
        if lookup_name == "in":
            rendered = "[" + ", ".join(map(repr, value)) + "]"
        else:
            rendered = repr(value)
        conditions.append(f"{field_name} {operator} {rendered}")

    return " WHERE " + " AND ".join(conditions)
|
|
681
|
+
|
|
682
|
+
async def count(self) -> int:
    """
    Return how many records of the model's table match the current filters.

    Sends ``SELECT count() FROM <table> [WHERE ...] GROUP ALL`` with the
    QuerySet's variables and reads the ``count`` column of the first result row.

    Returns:
        int: The number of matching records. ``0`` when the query returns no
        rows or the first row has no ``count`` key.

    Example:
        ```python
        total = await User.objects().count()
        active = await User.objects().filter(active=True).count()
        ```
    """
    where_sql = self._compile_where_clause()
    statement = f"SELECT count() FROM {self._model_table}{where_sql} GROUP ALL;"

    connection = await SurrealDBConnectionManager.get_client()
    response = await connection.query(remove_quotes_for_variables(statement), self._variables)

    rows = response.all_records
    if not rows:
        return 0
    head = rows[0]
    if not isinstance(head, dict) or "count" not in head:
        return 0
    return int(head["count"])
|
|
706
|
+
|
|
707
|
+
async def sum(self, field: str) -> float | int:
    """
    Return the sum of a numeric field over the filtered records.

    Sends ``SELECT math::sum(<field>) AS total ... GROUP ALL``. The field name is
    inserted into the query verbatim.

    Args:
        field: The field name to sum.

    Returns:
        float | int: The database-computed total. ``0`` when there are no rows,
        the row lacks ``total``, or the total is ``None``.

    Example:
        ```python
        total = await Order.objects().filter(status="paid").sum("amount")
        ```
    """
    where_sql = self._compile_where_clause()
    statement = f"SELECT math::sum({field}) AS total FROM {self._model_table}{where_sql} GROUP ALL;"

    connection = await SurrealDBConnectionManager.get_client()
    response = await connection.query(remove_quotes_for_variables(statement), self._variables)

    rows = response.all_records
    if rows and isinstance(rows[0], dict) and "total" in rows[0]:
        total = rows[0]["total"]
        return 0 if total is None else total
    return 0
|
|
734
|
+
|
|
735
|
+
async def avg(self, field: str) -> float | None:
    """
    Return the mean of a numeric field over the filtered records.

    Sends ``SELECT math::mean(<field>) AS average ... GROUP ALL``. The field name
    is inserted into the query verbatim.

    Args:
        field: The field name to average.

    Returns:
        float | None: The mean converted to ``float``. ``None`` when there are no
        rows, the row lacks ``average``, or the average is ``None``.

    Example:
        ```python
        avg_age = await User.objects().filter(active=True).avg("age")
        ```
    """
    where_sql = self._compile_where_clause()
    statement = f"SELECT math::mean({field}) AS average FROM {self._model_table}{where_sql} GROUP ALL;"

    connection = await SurrealDBConnectionManager.get_client()
    response = await connection.query(remove_quotes_for_variables(statement), self._variables)

    rows = response.all_records
    if not rows:
        return None
    first_row = rows[0]
    if isinstance(first_row, dict) and "average" in first_row:
        mean = first_row["average"]
        return None if mean is None else float(mean)
    return None
|
|
762
|
+
|
|
763
|
+
async def min(self, field: str) -> Any:
    """
    Return the smallest value of a field over the filtered records.

    Sends ``SELECT math::min(<field>) AS minimum ... GROUP ALL``. The field name
    is inserted into the query verbatim.

    Args:
        field: The field name to find the minimum of.

    Returns:
        Any: The value stored under ``minimum`` in the first result row, or
        ``None`` when there are no rows or that key is absent.

    Example:
        ```python
        min_price = await Product.objects().min("price")
        ```
    """
    where_sql = self._compile_where_clause()
    statement = f"SELECT math::min({field}) AS minimum FROM {self._model_table}{where_sql} GROUP ALL;"

    connection = await SurrealDBConnectionManager.get_client()
    response = await connection.query(remove_quotes_for_variables(statement), self._variables)

    rows = response.all_records
    if not rows:
        return None
    first_row = rows[0]
    if isinstance(first_row, dict) and "minimum" in first_row:
        return first_row["minimum"]
    return None
|
|
789
|
+
|
|
790
|
+
async def max(self, field: str) -> Any:
    """
    Return the largest value of a field over the filtered records.

    Sends ``SELECT math::max(<field>) AS maximum ... GROUP ALL``. The field name
    is inserted into the query verbatim.

    Args:
        field: The field name to find the maximum of.

    Returns:
        Any: The value stored under ``maximum`` in the first result row, or
        ``None`` when there are no rows or that key is absent.

    Example:
        ```python
        max_price = await Product.objects().max("price")
        ```
    """
    where_sql = self._compile_where_clause()
    statement = f"SELECT math::max({field}) AS maximum FROM {self._model_table}{where_sql} GROUP ALL;"

    connection = await SurrealDBConnectionManager.get_client()
    response = await connection.query(remove_quotes_for_variables(statement), self._variables)

    rows = response.all_records
    if not rows:
        return None
    first_row = rows[0]
    if isinstance(first_row, dict) and "maximum" in first_row:
        return first_row["maximum"]
    return None
|
|
816
|
+
|
|
817
|
+
async def _execute_query(self, query: str) -> list[Any]:
|
|
392
818
|
"""
|
|
393
819
|
Execute the given SQL query using the SurrealDB client.
|
|
394
820
|
|
|
@@ -399,7 +825,7 @@ class QuerySet:
|
|
|
399
825
|
query (str): The SQL query string to execute.
|
|
400
826
|
|
|
401
827
|
Returns:
|
|
402
|
-
list[
|
|
828
|
+
list[Any]: A list of query response objects containing the query results.
|
|
403
829
|
|
|
404
830
|
Raises:
|
|
405
831
|
SurrealDbError: If there is an issue executing the query.
|
|
@@ -412,7 +838,7 @@ class QuerySet:
|
|
|
412
838
|
client = await SurrealDBConnectionManager.get_client()
|
|
413
839
|
return await self._run_query_on_client(client, query)
|
|
414
840
|
|
|
415
|
-
async def _run_query_on_client(self, client:
|
|
841
|
+
async def _run_query_on_client(self, client: Any, query: str) -> list[Any]:
|
|
416
842
|
"""
|
|
417
843
|
Run the SQL query on the provided SurrealDB client.
|
|
418
844
|
|
|
@@ -420,11 +846,11 @@ class QuerySet:
|
|
|
420
846
|
and returns the raw query responses.
|
|
421
847
|
|
|
422
848
|
Args:
|
|
423
|
-
client
|
|
849
|
+
client: The active SurrealDB client instance.
|
|
424
850
|
query (str): The SQL query string to execute.
|
|
425
851
|
|
|
426
852
|
Returns:
|
|
427
|
-
list[
|
|
853
|
+
list[Any]: A list of query response objects containing the query results.
|
|
428
854
|
|
|
429
855
|
Raises:
|
|
430
856
|
SurrealDbError: If there is an issue executing the query.
|
|
@@ -434,7 +860,9 @@ class QuerySet:
|
|
|
434
860
|
results = await self._run_query_on_client(client, "SELECT * FROM users;")
|
|
435
861
|
```
|
|
436
862
|
"""
|
|
437
|
-
|
|
863
|
+
result = await client.query(remove_quotes_for_variables(query), self._variables)
|
|
864
|
+
# SDK returns QueryResponse, extract all records
|
|
865
|
+
return cast(list[Any], result.all_records)
|
|
438
866
|
|
|
439
867
|
async def delete_table(self) -> bool:
|
|
440
868
|
"""
|
|
@@ -455,7 +883,7 @@ class QuerySet:
|
|
|
455
883
|
```
|
|
456
884
|
"""
|
|
457
885
|
client = await SurrealDBConnectionManager.get_client()
|
|
458
|
-
await client.delete(
|
|
886
|
+
await client.delete(self._model_table)
|
|
459
887
|
return True
|
|
460
888
|
|
|
461
889
|
async def query(self, query: str, variables: dict[str, Any] = {}) -> Any:
|
|
@@ -483,9 +911,157 @@ class QuerySet:
|
|
|
483
911
|
results = await queryset.query(custom_query, variables={'status': 'active'})
|
|
484
912
|
```
|
|
485
913
|
"""
|
|
486
|
-
if f"FROM {self.
|
|
487
|
-
raise SurrealDbError(f"The query must include 'FROM {self.
|
|
914
|
+
if f"FROM {self._model_table}" not in query:
|
|
915
|
+
raise SurrealDbError(f"The query must include 'FROM {self._model_table}' to reference the correct table.")
|
|
916
|
+
client = await SurrealDBConnectionManager.get_client()
|
|
917
|
+
result = await client.query(remove_quotes_for_variables(query), variables)
|
|
918
|
+
# SDK returns QueryResponse, extract all records
|
|
919
|
+
return self.model.from_db(cast(dict | list | None, result.all_records))
|
|
920
|
+
|
|
921
|
+
# ==================== Bulk Operations ====================
|
|
922
|
+
|
|
923
|
+
async def bulk_create(
    self,
    instances: Sequence[BaseSurrealModel],
    atomic: bool = False,
    batch_size: int | None = None,
) -> list[BaseSurrealModel]:
    """
    Save several model instances, in order, and return them.

    Args:
        instances: The model instances to create. Each one's ``save()`` is awaited.
        atomic: If True, every save is issued inside a single
            ``SurrealDBConnectionManager.transaction()`` (``save(tx=tx)``), so a
            failure aborts the whole batch.
        batch_size: Accepted for API compatibility. Instances are always saved
            one at a time in sequence whether or not it is given, so it does not
            change memory use or round-trips. ``None`` or ``0`` means "no
            batching". Ignored when ``atomic`` is True.

    Returns:
        list[BaseSurrealModel]: The saved instances, in input order (the same
        objects that were passed in).

    Raises:
        ValueError: If ``batch_size`` is negative. Previously a negative value
            made ``range()`` empty, so nothing was saved and ``[]`` was returned
            silently.

    Example:
        ```python
        users = [User(name=f"User{i}") for i in range(1000)]

        # Simple bulk create
        created = await User.objects().bulk_create(users)

        # Atomic bulk create
        created = await User.objects().bulk_create(users, atomic=True)
        ```
    """
    if batch_size is not None and batch_size < 0:
        raise ValueError(f"batch_size must be a non-negative integer or None, got {batch_size!r}")

    if not instances:
        return []

    created: list[BaseSurrealModel] = []

    if atomic:
        # All saves share one transaction for all-or-nothing semantics.
        async with await SurrealDBConnectionManager.transaction() as tx:
            for instance in instances:
                await instance.save(tx=tx)
                created.append(instance)
        return created

    # Splitting into batches would issue exactly the same sequential saves, so
    # there is a single loop regardless of batch_size.
    for instance in instances:
        await instance.save()
        created.append(instance)

    return created
|
|
981
|
+
|
|
982
|
+
async def bulk_update(
    self,
    data: dict[str, Any],
    atomic: bool = False,
) -> int:
    """
    Apply the same field assignments to every record matching the current filters.

    Builds ``UPDATE <table> SET f1 = v1, ... [WHERE ...]`` and runs it with the
    QuerySet's variables. Values are rendered with ``repr()``, the same way the
    WHERE clause renders them.

    NOTE(review): field names are inserted verbatim, and values whose repr is not
    valid SurrealQL (e.g. ``datetime``) will break the statement. Confirm before
    passing such types.

    Args:
        data: Field names mapped to the new values. Must not be empty.
        atomic: If True, the UPDATE runs inside a
            ``SurrealDBConnectionManager.transaction()``. The returned number is
            then ``count()`` taken before (and outside) the transaction, so it can
            differ from the rows actually updated under concurrent writes.

    Returns:
        int: The number of records updated. When not atomic, this is the number
        of records returned by the UPDATE statement.

    Raises:
        ValueError: If ``data`` is empty, since an UPDATE with no assignments is
            not a valid statement.

    Example:
        ```python
        # Update all matching records
        updated = await User.objects().filter(
            last_login__lt="2025-01-01"
        ).bulk_update({"status": "inactive"})

        # Atomic update
        updated = await User.objects().filter(role="guest").bulk_update(
            {"verified": True},
            atomic=True
        )
        ```
    """
    if not data:
        raise ValueError("bulk_update() requires at least one field to update.")

    where_clause = self._compile_where_clause()
    set_clause = ", ".join(f"{field} = {value!r}" for field, value in data.items())
    query = f"UPDATE {self._model_table} SET {set_clause}{where_clause};"

    if atomic:
        current_count = await self.count()
        async with await SurrealDBConnectionManager.transaction() as tx:
            await tx.query(remove_quotes_for_variables(query), self._variables)
        return current_count

    client = await SurrealDBConnectionManager.get_client()
    result = await client.query(remove_quotes_for_variables(query), self._variables)
    # Guard against a missing record list instead of failing on len(None).
    return len(result.all_records or [])
|
|
1031
|
+
|
|
1032
|
+
async def bulk_delete(self, atomic: bool = False) -> int:
    """
    Delete every record matching the current filters.

    Args:
        atomic: If True, a plain ``DELETE FROM <table> [WHERE ...]`` runs inside
            ``SurrealDBConnectionManager.transaction()``. The returned number is
            then ``count()`` taken before (and outside) the transaction. If False,
            ``DELETE ... RETURN BEFORE`` is sent, and the number of records it
            returns is used as the count.

    Returns:
        int: The number of records deleted (see ``atomic`` for how it is derived).

    Example:
        ```python
        # Delete all matching records
        deleted = await User.objects().filter(status="deleted").bulk_delete()

        # Atomic delete
        deleted = await Order.objects().filter(
            created_at__lt="2024-01-01"
        ).bulk_delete(atomic=True)
        ```
    """
    where_sql = self._compile_where_clause()

    if atomic:
        expected = await self.count()
        statement = f"DELETE FROM {self._model_table}{where_sql};"
        async with await SurrealDBConnectionManager.transaction() as tx:
            await tx.query(remove_quotes_for_variables(statement), self._variables)
        return expected

    # RETURN BEFORE makes the statement hand back each deleted record, which is
    # what gets counted.
    statement = f"DELETE FROM {self._model_table}{where_sql} RETURN BEFORE;"
    connection = await SurrealDBConnectionManager.get_client()
    response = await connection.query(remove_quotes_for_variables(statement), self._variables)
    return len(response.all_records)
|