surrealdb-orm 0.1.4__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50):
  1. surreal_orm/__init__.py +72 -3
  2. surreal_orm/aggregations.py +164 -0
  3. surreal_orm/auth/__init__.py +15 -0
  4. surreal_orm/auth/access.py +167 -0
  5. surreal_orm/auth/mixins.py +302 -0
  6. surreal_orm/cli/__init__.py +15 -0
  7. surreal_orm/cli/commands.py +369 -0
  8. surreal_orm/connection_manager.py +58 -18
  9. surreal_orm/fields/__init__.py +36 -0
  10. surreal_orm/fields/encrypted.py +166 -0
  11. surreal_orm/fields/relation.py +465 -0
  12. surreal_orm/migrations/__init__.py +51 -0
  13. surreal_orm/migrations/executor.py +380 -0
  14. surreal_orm/migrations/generator.py +272 -0
  15. surreal_orm/migrations/introspector.py +305 -0
  16. surreal_orm/migrations/migration.py +188 -0
  17. surreal_orm/migrations/operations.py +531 -0
  18. surreal_orm/migrations/state.py +406 -0
  19. surreal_orm/model_base.py +530 -44
  20. surreal_orm/query_set.py +609 -33
  21. surreal_orm/relations.py +645 -0
  22. surreal_orm/surreal_function.py +95 -0
  23. surreal_orm/surreal_ql.py +113 -0
  24. surreal_orm/types.py +86 -0
  25. surreal_sdk/README.md +79 -0
  26. surreal_sdk/__init__.py +151 -0
  27. surreal_sdk/connection/__init__.py +17 -0
  28. surreal_sdk/connection/base.py +516 -0
  29. surreal_sdk/connection/http.py +421 -0
  30. surreal_sdk/connection/pool.py +244 -0
  31. surreal_sdk/connection/websocket.py +519 -0
  32. surreal_sdk/exceptions.py +71 -0
  33. surreal_sdk/functions.py +607 -0
  34. surreal_sdk/protocol/__init__.py +13 -0
  35. surreal_sdk/protocol/rpc.py +218 -0
  36. surreal_sdk/py.typed +0 -0
  37. surreal_sdk/pyproject.toml +49 -0
  38. surreal_sdk/streaming/__init__.py +31 -0
  39. surreal_sdk/streaming/change_feed.py +278 -0
  40. surreal_sdk/streaming/live_query.py +265 -0
  41. surreal_sdk/streaming/live_select.py +369 -0
  42. surreal_sdk/transaction.py +386 -0
  43. surreal_sdk/types.py +346 -0
  44. surrealdb_orm-0.5.0.dist-info/METADATA +465 -0
  45. surrealdb_orm-0.5.0.dist-info/RECORD +52 -0
  46. {surrealdb_orm-0.1.4.dist-info → surrealdb_orm-0.5.0.dist-info}/WHEEL +1 -1
  47. surrealdb_orm-0.5.0.dist-info/entry_points.txt +2 -0
  48. {surrealdb_orm-0.1.4.dist-info → surrealdb_orm-0.5.0.dist-info}/licenses/LICENSE +1 -1
  49. surrealdb_orm-0.1.4.dist-info/METADATA +0 -184
  50. surrealdb_orm-0.1.4.dist-info/RECORD +0 -12
surreal_orm/query_set.py CHANGED
@@ -1,14 +1,20 @@
1
1
  from .constants import LOOKUP_OPERATORS
2
2
  from .enum import OrderBy
3
3
  from .utils import remove_quotes_for_variables
4
- from surrealdb import QueryResponse, Table, AsyncSurrealDB
5
- from surrealdb.errors import SurrealDbError
6
4
  from . import BaseSurrealModel, SurrealDBConnectionManager
7
- from typing import Self, Any, cast
5
+ from .aggregations import Aggregation
6
+ from typing import Self, Any, Sequence, cast
8
7
  from pydantic_core import ValidationError
9
8
 
10
9
  import logging
11
10
 
11
+
12
+ class SurrealDbError(Exception):
13
+ """Error from SurrealDB operations."""
14
+
15
+ pass
16
+
17
+
12
18
  logger = logging.getLogger(__name__)
13
19
 
14
20
 
@@ -55,8 +61,14 @@ class QuerySet:
55
61
  self._limit: int | None = None
56
62
  self._offset: int | None = None
57
63
  self._order_by: str | None = None
58
- self._model_table: str = getattr(model, "_table_name", model.__name__)
64
+ self._model_table: str = model.__name__
59
65
  self._variables: dict = {}
66
+ self._group_by_fields: list[str] = []
67
+ self._annotations: dict[str, Aggregation] = {}
68
+ # Relation query options
69
+ self._select_related: list[str] = []
70
+ self._prefetch_related: list[str] = []
71
+ self._traversal_path: str | None = None
60
72
 
61
73
  def select(self, *fields: str) -> Self:
62
74
  """
@@ -201,7 +213,7 @@ class QuerySet:
201
213
  self._offset = value
202
214
  return self
203
215
 
204
- def order_by(self, field_name: str, type: OrderBy = OrderBy.ASC) -> Self:
216
+ def order_by(self, field_name: str, order_type: OrderBy = OrderBy.ASC) -> Self:
205
217
  """
206
218
  Set the field and direction to order the results by.
207
219
 
@@ -220,9 +232,248 @@ class QuerySet:
220
232
  queryset.order_by('name', OrderBy.DESC)
221
233
  ```
222
234
  """
223
- self._order_by = f"{field_name} {type}"
235
+ self._order_by = f"{field_name} {order_type}"
236
+ return self
237
+
238
+ def values(self, *fields: str) -> Self:
239
+ """
240
+ Specify the fields to group by for aggregation queries.
241
+
242
+ This method is used in conjunction with `annotate()` to perform GROUP BY operations.
243
+ The specified fields become the grouping keys, and aggregation functions are applied
244
+ to each group.
245
+
246
+ Args:
247
+ *fields (str): Variable length argument list of field names to group by.
248
+
249
+ Returns:
250
+ Self: The current instance of QuerySet to allow method chaining.
251
+
252
+ Example:
253
+ ```python
254
+ # Group orders by status and calculate statistics
255
+ stats = await Order.objects().values("status").annotate(
256
+ count=Count(),
257
+ total=Sum("amount"),
258
+ )
259
+ # Result: [{"status": "paid", "count": 42, "total": 5000}, ...]
260
+ ```
261
+ """
262
+ self._group_by_fields = list(fields)
263
+ return self
264
+
265
+ def annotate(self, **aggregations: Aggregation) -> Self:
266
+ """
267
+ Add aggregation functions to compute values for each group.
268
+
269
+ This method is used in conjunction with `values()` to perform GROUP BY operations.
270
+ Each keyword argument should be an Aggregation instance (Count, Sum, Avg, Min, Max).
271
+
272
+ Args:
273
+ **aggregations (Aggregation): Keyword arguments where keys are alias names
274
+ and values are Aggregation instances.
275
+
276
+ Returns:
277
+ Self: The current instance of QuerySet to allow method chaining.
278
+
279
+ Example:
280
+ ```python
281
+ from surreal_orm.aggregations import Count, Sum, Avg
282
+
283
+ # Calculate statistics per status
284
+ stats = await Order.objects().values("status").annotate(
285
+ count=Count(),
286
+ total=Sum("amount"),
287
+ avg_amount=Avg("amount"),
288
+ )
289
+ ```
290
+ """
291
+ self._annotations = aggregations
292
+ return self
293
+
294
+ # ==================== Relation Query Methods ====================
295
+
296
+ def select_related(self, *relations: str) -> Self:
297
+ """
298
+ Eagerly load related objects in the same query.
299
+
300
+ This method optimizes queries by loading related objects alongside
301
+ the main query results, avoiding N+1 query problems for forward relations.
302
+
303
+ Note: SurrealDB handles this through graph traversal syntax.
304
+ The actual loading happens during query execution.
305
+
306
+ Args:
307
+ *relations: Names of relations to load eagerly.
308
+
309
+ Returns:
310
+ Self: The current instance of QuerySet to allow method chaining.
311
+
312
+ Example:
313
+ ```python
314
+ # Load posts with their authors in one query
315
+ posts = await Post.objects().select_related("author").all()
316
+ for post in posts:
317
+ print(post.author.name) # No additional query
318
+ ```
319
+ """
320
+ self._select_related = list(relations)
321
+ return self
322
+
323
+ def prefetch_related(self, *relations: str) -> Self:
324
+ """
325
+ Prefetch related objects using separate optimized queries.
326
+
327
+ This method reduces N+1 query problems by fetching related objects
328
+ in batches after the main query completes. This is more efficient
329
+ than select_related for many-to-many and reverse relations.
330
+
331
+ Args:
332
+ *relations: Names of relations to prefetch.
333
+
334
+ Returns:
335
+ Self: The current instance of QuerySet to allow method chaining.
336
+
337
+ Example:
338
+ ```python
339
+ # Load users and prefetch their followers
340
+ users = await User.objects().prefetch_related("followers", "posts").all()
341
+ for user in users:
342
+ print(user.followers) # Already loaded
343
+ print(user.posts) # Already loaded
344
+ ```
345
+ """
346
+ self._prefetch_related = list(relations)
347
+ return self
348
+
349
+ def traverse(self, path: str) -> Self:
350
+ """
351
+ Add a graph traversal path to the query.
352
+
353
+ This method allows querying across graph relations using
354
+ SurrealDB's traversal syntax.
355
+
356
+ Args:
357
+ path: Graph traversal path (e.g., "->follows->users->likes->posts")
358
+
359
+ Returns:
360
+ Self: The current instance of QuerySet to allow method chaining.
361
+
362
+ Example:
363
+ ```python
364
+ # Get all posts liked by users that alice follows
365
+ posts = await User.objects().filter(id="alice").traverse(
366
+ "->follows->users->likes->posts"
367
+ ).all()
368
+ ```
369
+ """
370
+ self._traversal_path = path
224
371
  return self
225
372
 
373
+ async def graph_query(self, traversal: str, **variables: Any) -> list[dict[str, Any]]:
374
+ """
375
+ Execute a raw graph traversal query.
376
+
377
+ This method provides direct access to SurrealDB's graph capabilities
378
+ for complex traversal patterns that can't be expressed through
379
+ the standard QuerySet API.
380
+
381
+ Args:
382
+ traversal: Graph traversal expression (e.g., "->follows->User")
383
+ **variables: Variables to bind in the query
384
+
385
+ Returns:
386
+ list[dict[str, Any]]: Raw query results as dictionaries
387
+
388
+ Example:
389
+ ```python
390
+ # Find users that alice follows
391
+ result = await User.objects().filter(id="alice").graph_query("->follows->User")
392
+
393
+ # Multi-hop traversal
394
+ result = await User.objects().filter(id="alice").graph_query(
395
+ "->follows->User->follows->User"
396
+ )
397
+ ```
398
+ """
399
+ # Parse the traversal to determine edge and direction
400
+ # Expected format: "->edge->Table" or "<-edge<-Table"
401
+ # For now, support simple single-hop traversals
402
+
403
+ # Check if we have an id filter to use as starting point
404
+ source_id = None
405
+ for field_name, lookup_name, value in self._filters:
406
+ if field_name == "id" and lookup_name == "exact":
407
+ source_id = value
408
+ break
409
+
410
+ client = await SurrealDBConnectionManager.get_client()
411
+
412
+ if source_id:
413
+ # Use specific record as starting point
414
+ source_thing = f"{self._model_table}:{source_id}"
415
+
416
+ # Parse the traversal to get edge and direction
417
+ # Simple pattern: ->edge->Table or <-edge<-Table
418
+ if traversal.startswith("->"):
419
+ # Outgoing: get targets where source is 'in'
420
+ parts = traversal.split("->")
421
+ if len(parts) >= 3:
422
+ edge = parts[1]
423
+ query = f"SELECT out FROM {edge} WHERE in = {source_thing} FETCH out;"
424
+ result = await client.query(query, {**self._variables, **variables})
425
+ records: list[dict[str, Any]] = []
426
+ for row in result.all_records or []:
427
+ if isinstance(row.get("out"), dict):
428
+ records.append(row["out"])
429
+ return records
430
+ elif traversal.startswith("<-"):
431
+ # Incoming: get sources where target is 'out'
432
+ parts = traversal.split("<-")
433
+ if len(parts) >= 3:
434
+ edge = parts[1]
435
+ query = f"SELECT in FROM {edge} WHERE out = {source_thing} FETCH in;"
436
+ result = await client.query(query, {**self._variables, **variables})
437
+ records = []
438
+ for row in result.all_records or []:
439
+ if isinstance(row.get("in"), dict):
440
+ records.append(row["in"])
441
+ return records
442
+
443
+ # Fallback: return empty for unsupported patterns
444
+ return []
445
+
446
+ async def _execute_annotate(self) -> list[dict[str, Any]]:
447
+ """
448
+ Execute the GROUP BY query with annotations.
449
+
450
+ Returns:
451
+ list[dict[str, Any]]: A list of dictionaries containing grouped results.
452
+ """
453
+ # Build SELECT clause with group fields and aggregations
454
+ select_parts: list[str] = list(self._group_by_fields)
455
+
456
+ for alias, aggregation in self._annotations.items():
457
+ select_parts.append(aggregation.to_surql(alias))
458
+
459
+ select_clause = ", ".join(select_parts)
460
+
461
+ # Build WHERE clause
462
+ where_clause = self._compile_where_clause()
463
+
464
+ # Build GROUP BY clause
465
+ if self._group_by_fields:
466
+ group_clause = f" GROUP BY {', '.join(self._group_by_fields)}"
467
+ else:
468
+ group_clause = " GROUP ALL"
469
+
470
+ query = f"SELECT {select_clause} FROM {self._model_table}{where_clause}{group_clause};"
471
+
472
+ client = await SurrealDBConnectionManager.get_client()
473
+ result = await client.query(remove_quotes_for_variables(query), self._variables)
474
+
475
+ return cast(list[dict[str, Any]], result.all_records)
476
+
226
477
  def _compile_query(self) -> str:
227
478
  """
228
479
  Compile the QuerySet parameters into a SQL query string.
@@ -261,6 +512,10 @@ class QuerySet:
261
512
  if where_clauses:
262
513
  query += " WHERE " + " AND ".join(where_clauses)
263
514
 
515
+ # Append ORDER BY if set (must come before LIMIT/START in SurrealQL)
516
+ if self._order_by:
517
+ query += f" ORDER BY {self._order_by}"
518
+
264
519
  # Append LIMIT if set
265
520
  if self._limit is not None:
266
521
  query += f" LIMIT {self._limit}"
@@ -269,10 +524,6 @@ class QuerySet:
269
524
  if self._offset is not None:
270
525
  query += f" START {self._offset}"
271
526
 
272
- # Append ORDER BY if set
273
- if self._order_by:
274
- query += f" ORDER BY {self._order_by}"
275
-
276
527
  query += ";"
277
528
  return query
278
529
 
@@ -284,27 +535,41 @@ class QuerySet:
284
535
  the results. If the data conforms to the model schema, it returns a list of model instances;
285
536
  otherwise, it returns a list of dictionaries.
286
537
 
538
+ When `annotate()` has been called, this returns the aggregated results as dictionaries
539
+ instead of model instances.
540
+
287
541
  Returns:
288
542
  list[BaseSurrealModel] | list[dict]: A list of model instances if validation is successful,
289
- otherwise a list of dictionaries representing the raw data.
543
+ otherwise a list of dictionaries representing the raw data. For annotated queries,
544
+ always returns a list of dictionaries.
290
545
 
291
546
  Raises:
292
547
  SurrealDbError: If there is an issue executing the query.
293
548
 
294
549
  Example:
295
550
  ```python
551
+ # Regular query
296
552
  results = await queryset.exec()
553
+
554
+ # Aggregation query
555
+ stats = await Order.objects().values("status").annotate(
556
+ count=Count(),
557
+ total=Sum("amount"),
558
+ ).exec()
297
559
  ```
298
560
  """
299
- data: dict[str, Any] = {"result": []}
561
+ # If annotations are set, execute as GROUP BY query
562
+ if self._annotations:
563
+ return await self._execute_annotate()
564
+
300
565
  query = self._compile_query()
301
566
  results = await self._execute_query(query)
302
567
  try:
303
- data = cast(dict, results[0])
304
- return self.model.from_db(data["result"])
568
+ # surrealdb SDK 1.0.8 returns records directly, not wrapped in {"result": ...}
569
+ return self.model.from_db(cast(dict | list | None, results))
305
570
  except ValidationError as e:
306
571
  logger.info(f"Pydantic invalid format for the class, returning dict value: {e}")
307
- return data["result"]
572
+ return results
308
573
 
309
574
  async def first(self) -> Any:
310
575
  """
@@ -330,7 +595,7 @@ class QuerySet:
330
595
  if results:
331
596
  return results[0]
332
597
 
333
- raise SurrealDbError("No result found.")
598
+ raise self.model.DoesNotExist("Query returned no results.")
334
599
 
335
600
  async def get(self, id_item: Any = None) -> Any:
336
601
  """
@@ -356,15 +621,18 @@ class QuerySet:
356
621
  """
357
622
  if id_item:
358
623
  client = await SurrealDBConnectionManager.get_client()
359
- data = await client.select(f"{self._model_table}:{id_item}")
360
- return self.model.from_db(data)
624
+ result = await client.select(f"{self._model_table}:{id_item}")
625
+ # SDK returns RecordsResponse
626
+ if result.is_empty:
627
+ raise self.model.DoesNotExist("Record not found.")
628
+ return self.model.from_db(cast(dict | list | None, result.first))
361
629
  else:
362
630
  result = await self.exec()
363
631
  if len(result) > 1:
364
632
  raise SurrealDbError("More than one result found.")
365
633
 
366
634
  if len(result) == 0:
367
- raise SurrealDbError("No result found.")
635
+ raise self.model.DoesNotExist("Record not found.")
368
636
  return result[0]
369
637
 
370
638
  async def all(self) -> Any:
@@ -385,10 +653,168 @@ class QuerySet:
385
653
  ```
386
654
  """
387
655
  client = await SurrealDBConnectionManager.get_client()
388
- results = await client.select(Table(self._model_table))
389
- return self.model.from_db(results)
656
+ result = await client.select(self._model_table)
657
+ return self.model.from_db(cast(dict | list | None, result.records))
658
+
659
+ # ==================== Aggregation Methods ====================
390
660
 
391
- async def _execute_query(self, query: str) -> list[QueryResponse]:
661
+ def _compile_where_clause(self) -> str:
662
+ """
663
+ Compile the WHERE clause from filters.
664
+
665
+ Returns:
666
+ str: The WHERE clause string (including WHERE keyword) or empty string.
667
+ """
668
+ if not self._filters:
669
+ return ""
670
+
671
+ where_clauses = []
672
+ for field_name, lookup_name, value in self._filters:
673
+ op = LOOKUP_OPERATORS.get(lookup_name, "=")
674
+ if lookup_name == "in":
675
+ formatted_values = ", ".join(repr(v) for v in value)
676
+ where_clauses.append(f"{field_name} {op} [{formatted_values}]")
677
+ else:
678
+ where_clauses.append(f"{field_name} {op} {repr(value)}")
679
+
680
+ return " WHERE " + " AND ".join(where_clauses)
681
+
682
+ async def count(self) -> int:
683
+ """
684
+ Count the number of records matching the current filters.
685
+
686
+ Returns:
687
+ int: The number of matching records.
688
+
689
+ Example:
690
+ ```python
691
+ total = await User.objects().count()
692
+ active = await User.objects().filter(active=True).count()
693
+ ```
694
+ """
695
+ where_clause = self._compile_where_clause()
696
+ query = f"SELECT count() FROM {self._model_table}{where_clause} GROUP ALL;"
697
+
698
+ client = await SurrealDBConnectionManager.get_client()
699
+ result = await client.query(remove_quotes_for_variables(query), self._variables)
700
+
701
+ if result.all_records:
702
+ record = result.all_records[0]
703
+ if isinstance(record, dict) and "count" in record:
704
+ return int(record["count"])
705
+ return 0
706
+
707
+ async def sum(self, field: str) -> float | int:
708
+ """
709
+ Calculate the sum of a numeric field.
710
+
711
+ Args:
712
+ field: The field name to sum.
713
+
714
+ Returns:
715
+ float | int: The sum of the field values, or 0 if no records match.
716
+
717
+ Example:
718
+ ```python
719
+ total = await Order.objects().filter(status="paid").sum("amount")
720
+ ```
721
+ """
722
+ where_clause = self._compile_where_clause()
723
+ query = f"SELECT math::sum({field}) AS total FROM {self._model_table}{where_clause} GROUP ALL;"
724
+
725
+ client = await SurrealDBConnectionManager.get_client()
726
+ result = await client.query(remove_quotes_for_variables(query), self._variables)
727
+
728
+ if result.all_records:
729
+ record = result.all_records[0]
730
+ if isinstance(record, dict) and "total" in record:
731
+ value = record["total"]
732
+ return value if value is not None else 0
733
+ return 0
734
+
735
+ async def avg(self, field: str) -> float | None:
736
+ """
737
+ Calculate the average of a numeric field.
738
+
739
+ Args:
740
+ field: The field name to average.
741
+
742
+ Returns:
743
+ float | None: The average value, or None if no records match.
744
+
745
+ Example:
746
+ ```python
747
+ avg_age = await User.objects().filter(active=True).avg("age")
748
+ ```
749
+ """
750
+ where_clause = self._compile_where_clause()
751
+ query = f"SELECT math::mean({field}) AS average FROM {self._model_table}{where_clause} GROUP ALL;"
752
+
753
+ client = await SurrealDBConnectionManager.get_client()
754
+ result = await client.query(remove_quotes_for_variables(query), self._variables)
755
+
756
+ if result.all_records:
757
+ record = result.all_records[0]
758
+ if isinstance(record, dict) and "average" in record:
759
+ value = record["average"]
760
+ return float(value) if value is not None else None
761
+ return None
762
+
763
+ async def min(self, field: str) -> Any:
764
+ """
765
+ Get the minimum value of a field.
766
+
767
+ Args:
768
+ field: The field name to find the minimum of.
769
+
770
+ Returns:
771
+ Any: The minimum value, or None if no records match.
772
+
773
+ Example:
774
+ ```python
775
+ min_price = await Product.objects().min("price")
776
+ ```
777
+ """
778
+ where_clause = self._compile_where_clause()
779
+ query = f"SELECT math::min({field}) AS minimum FROM {self._model_table}{where_clause} GROUP ALL;"
780
+
781
+ client = await SurrealDBConnectionManager.get_client()
782
+ result = await client.query(remove_quotes_for_variables(query), self._variables)
783
+
784
+ if result.all_records:
785
+ record = result.all_records[0]
786
+ if isinstance(record, dict) and "minimum" in record:
787
+ return record["minimum"]
788
+ return None
789
+
790
+ async def max(self, field: str) -> Any:
791
+ """
792
+ Get the maximum value of a field.
793
+
794
+ Args:
795
+ field: The field name to find the maximum of.
796
+
797
+ Returns:
798
+ Any: The maximum value, or None if no records match.
799
+
800
+ Example:
801
+ ```python
802
+ max_price = await Product.objects().max("price")
803
+ ```
804
+ """
805
+ where_clause = self._compile_where_clause()
806
+ query = f"SELECT math::max({field}) AS maximum FROM {self._model_table}{where_clause} GROUP ALL;"
807
+
808
+ client = await SurrealDBConnectionManager.get_client()
809
+ result = await client.query(remove_quotes_for_variables(query), self._variables)
810
+
811
+ if result.all_records:
812
+ record = result.all_records[0]
813
+ if isinstance(record, dict) and "maximum" in record:
814
+ return record["maximum"]
815
+ return None
816
+
817
+ async def _execute_query(self, query: str) -> list[Any]:
392
818
  """
393
819
  Execute the given SQL query using the SurrealDB client.
394
820
 
@@ -399,7 +825,7 @@ class QuerySet:
399
825
  query (str): The SQL query string to execute.
400
826
 
401
827
  Returns:
402
- list[QueryResponse]: A list of `QueryResponse` objects containing the query results.
828
+ list[Any]: A list of query response objects containing the query results.
403
829
 
404
830
  Raises:
405
831
  SurrealDbError: If there is an issue executing the query.
@@ -412,7 +838,7 @@ class QuerySet:
412
838
  client = await SurrealDBConnectionManager.get_client()
413
839
  return await self._run_query_on_client(client, query)
414
840
 
415
- async def _run_query_on_client(self, client: AsyncSurrealDB, query: str) -> list[QueryResponse]:
841
+ async def _run_query_on_client(self, client: Any, query: str) -> list[Any]:
416
842
  """
417
843
  Run the SQL query on the provided SurrealDB client.
418
844
 
@@ -420,11 +846,11 @@ class QuerySet:
420
846
  and returns the raw query responses.
421
847
 
422
848
  Args:
423
- client (AsyncSurrealDB): The active SurrealDB client instance.
849
+ client: The active SurrealDB client instance.
424
850
  query (str): The SQL query string to execute.
425
851
 
426
852
  Returns:
427
- list[QueryResponse]: A list of `QueryResponse` objects containing the query results.
853
+ list[Any]: A list of query response objects containing the query results.
428
854
 
429
855
  Raises:
430
856
  SurrealDbError: If there is an issue executing the query.
@@ -434,7 +860,9 @@ class QuerySet:
434
860
  results = await self._run_query_on_client(client, "SELECT * FROM users;")
435
861
  ```
436
862
  """
437
- return await client.query(remove_quotes_for_variables(query), self._variables) # type: ignore
863
+ result = await client.query(remove_quotes_for_variables(query), self._variables)
864
+ # SDK returns QueryResponse, extract all records
865
+ return cast(list[Any], result.all_records)
438
866
 
439
867
  async def delete_table(self) -> bool:
440
868
  """
@@ -455,7 +883,7 @@ class QuerySet:
455
883
  ```
456
884
  """
457
885
  client = await SurrealDBConnectionManager.get_client()
458
- await client.delete(Table(self._model_table))
886
+ await client.delete(self._model_table)
459
887
  return True
460
888
 
461
889
  async def query(self, query: str, variables: dict[str, Any] = {}) -> Any:
@@ -483,9 +911,157 @@ class QuerySet:
483
911
  results = await queryset.query(custom_query, variables={'status': 'active'})
484
912
  ```
485
913
  """
486
- if f"FROM {self.model.__name__}" not in query:
487
- raise SurrealDbError(f"The query must include 'FROM {self.model.__name__}' to reference the correct table.")
914
+ if f"FROM {self._model_table}" not in query:
915
+ raise SurrealDbError(f"The query must include 'FROM {self._model_table}' to reference the correct table.")
916
+ client = await SurrealDBConnectionManager.get_client()
917
+ result = await client.query(remove_quotes_for_variables(query), variables)
918
+ # SDK returns QueryResponse, extract all records
919
+ return self.model.from_db(cast(dict | list | None, result.all_records))
920
+
921
+ # ==================== Bulk Operations ====================
922
+
923
+ async def bulk_create(
924
+ self,
925
+ instances: Sequence[BaseSurrealModel],
926
+ atomic: bool = False,
927
+ batch_size: int | None = None,
928
+ ) -> list[BaseSurrealModel]:
929
+ """
930
+ Create multiple model instances in the database efficiently.
931
+
932
+ Args:
933
+ instances: A sequence of model instances to create.
934
+ atomic: If True, all creates are wrapped in a transaction.
935
+ If any fails, all are rolled back.
936
+ batch_size: If specified, instances are created in batches of this size.
937
+ Useful for very large datasets to avoid memory issues.
938
+
939
+ Returns:
940
+ list[BaseSurrealModel]: The created instances.
941
+
942
+ Example:
943
+ ```python
944
+ users = [User(name=f"User{i}") for i in range(1000)]
945
+
946
+ # Simple bulk create
947
+ created = await User.objects().bulk_create(users)
948
+
949
+ # Atomic bulk create
950
+ created = await User.objects().bulk_create(users, atomic=True)
951
+
952
+ # With batch size
953
+ created = await User.objects().bulk_create(users, batch_size=100)
954
+ ```
955
+ """
956
+ if not instances:
957
+ return []
958
+
959
+ created: list[BaseSurrealModel] = []
960
+
961
+ if atomic:
962
+ # Use transaction for atomicity
963
+ async with await SurrealDBConnectionManager.transaction() as tx:
964
+ for instance in instances:
965
+ await instance.save(tx=tx)
966
+ created.append(instance)
967
+ elif batch_size:
968
+ # Process in batches
969
+ for i in range(0, len(instances), batch_size):
970
+ batch = instances[i : i + batch_size]
971
+ for instance in batch:
972
+ await instance.save()
973
+ created.append(instance)
974
+ else:
975
+ # Simple sequential create
976
+ for instance in instances:
977
+ await instance.save()
978
+ created.append(instance)
979
+
980
+ return created
981
+
982
+ async def bulk_update(
983
+ self,
984
+ data: dict[str, Any],
985
+ atomic: bool = False,
986
+ ) -> int:
987
+ """
988
+ Update all records matching the current filters.
989
+
990
+ Args:
991
+ data: A dictionary of field names and values to update.
992
+ atomic: If True, all updates are wrapped in a transaction.
993
+
994
+ Returns:
995
+ int: The number of records updated.
996
+
997
+ Example:
998
+ ```python
999
+ # Update all matching records
1000
+ updated = await User.objects().filter(
1001
+ last_login__lt="2025-01-01"
1002
+ ).bulk_update({"status": "inactive"})
1003
+
1004
+ # Atomic update
1005
+ updated = await User.objects().filter(role="guest").bulk_update(
1006
+ {"verified": True},
1007
+ atomic=True
1008
+ )
1009
+ ```
1010
+ """
1011
+ where_clause = self._compile_where_clause()
1012
+
1013
+ # Build SET clause
1014
+ set_parts = []
1015
+ for field, value in data.items():
1016
+ set_parts.append(f"{field} = {repr(value)}")
1017
+ set_clause = ", ".join(set_parts)
1018
+
1019
+ query = f"UPDATE {self._model_table} SET {set_clause}{where_clause};"
1020
+
1021
+ if atomic:
1022
+ # For atomic operations, count first then update in transaction
1023
+ current_count = await self.count()
1024
+ async with await SurrealDBConnectionManager.transaction() as tx:
1025
+ await tx.query(remove_quotes_for_variables(query), self._variables)
1026
+ return current_count
1027
+
1028
+ client = await SurrealDBConnectionManager.get_client()
1029
+ result = await client.query(remove_quotes_for_variables(query), self._variables)
1030
+ return len(result.all_records)
1031
+
1032
+ async def bulk_delete(self, atomic: bool = False) -> int:
1033
+ """
1034
+ Delete all records matching the current filters.
1035
+
1036
+ Args:
1037
+ atomic: If True, all deletes are wrapped in a transaction.
1038
+
1039
+ Returns:
1040
+ int: The number of records deleted.
1041
+
1042
+ Example:
1043
+ ```python
1044
+ # Delete all matching records
1045
+ deleted = await User.objects().filter(status="deleted").bulk_delete()
1046
+
1047
+ # Atomic delete
1048
+ deleted = await Order.objects().filter(
1049
+ created_at__lt="2024-01-01"
1050
+ ).bulk_delete(atomic=True)
1051
+ ```
1052
+ """
1053
+ where_clause = self._compile_where_clause()
1054
+ # Use RETURN BEFORE to get deleted records count
1055
+ query = f"DELETE FROM {self._model_table}{where_clause} RETURN BEFORE;"
1056
+
1057
+ if atomic:
1058
+ # For atomic operations, count first then delete in transaction
1059
+ current_count = await self.count()
1060
+ delete_query = f"DELETE FROM {self._model_table}{where_clause};"
1061
+ async with await SurrealDBConnectionManager.transaction() as tx:
1062
+ await tx.query(remove_quotes_for_variables(delete_query), self._variables)
1063
+ return current_count
1064
+
488
1065
  client = await SurrealDBConnectionManager.get_client()
489
- results = await client.query(remove_quotes_for_variables(query), variables)
490
- data = cast(dict, results[0])
491
- return self.model.from_db(data["result"])
1066
+ result = await client.query(remove_quotes_for_variables(query), self._variables)
1067
+ return len(result.all_records)