surreal-orm-lite 0.2.2__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (19)
  1. {surreal_orm_lite-0.2.2 → surreal_orm_lite-0.3.0}/CHANGELOG.md +33 -0
  2. {surreal_orm_lite-0.2.2 → surreal_orm_lite-0.3.0}/PKG-INFO +1 -1
  3. {surreal_orm_lite-0.2.2 → surreal_orm_lite-0.3.0}/pyproject.toml +1 -1
  4. {surreal_orm_lite-0.2.2 → surreal_orm_lite-0.3.0}/src/surreal_orm_lite/__init__.py +14 -2
  5. surreal_orm_lite-0.3.0/src/surreal_orm_lite/aggregations.py +279 -0
  6. {surreal_orm_lite-0.2.2 → surreal_orm_lite-0.3.0}/src/surreal_orm_lite/model_base.py +60 -0
  7. {surreal_orm_lite-0.2.2 → surreal_orm_lite-0.3.0}/src/surreal_orm_lite/query_set.py +379 -12
  8. surreal_orm_lite-0.3.0/src/surreal_orm_lite/utils.py +53 -0
  9. surreal_orm_lite-0.2.2/src/surreal_orm_lite/utils.py +0 -6
  10. {surreal_orm_lite-0.2.2 → surreal_orm_lite-0.3.0}/.gitignore +0 -0
  11. {surreal_orm_lite-0.2.2 → surreal_orm_lite-0.3.0}/LICENSE +0 -0
  12. {surreal_orm_lite-0.2.2 → surreal_orm_lite-0.3.0}/Makefile +0 -0
  13. {surreal_orm_lite-0.2.2 → surreal_orm_lite-0.3.0}/README.md +0 -0
  14. {surreal_orm_lite-0.2.2 → surreal_orm_lite-0.3.0}/src/__init__.py +0 -0
  15. {surreal_orm_lite-0.2.2 → surreal_orm_lite-0.3.0}/src/surreal_orm_lite/connection_manager.py +0 -0
  16. {surreal_orm_lite-0.2.2 → surreal_orm_lite-0.3.0}/src/surreal_orm_lite/constants.py +0 -0
  17. {surreal_orm_lite-0.2.2 → surreal_orm_lite-0.3.0}/src/surreal_orm_lite/enum.py +0 -0
  18. {surreal_orm_lite-0.2.2 → surreal_orm_lite-0.3.0}/src/surreal_orm_lite/exceptions.py +0 -0
  19. {surreal_orm_lite-0.2.2 → surreal_orm_lite-0.3.0}/src/surreal_orm_lite/py.typed +0 -0
@@ -5,6 +5,39 @@ All notable changes to this project will be documented in this file.
5
5
  The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
6
6
  and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
7
7
 
8
+ ## [0.3.0] - 2026-02-05
9
+
10
+ ### Added
11
+
12
+ - **Aggregation Functions**: New aggregation classes for database calculations
13
+ - `Count()` - Count records
14
+ - `Sum(field)` - Sum numeric field values
15
+ - `Avg(field)` - Calculate average of numeric field
16
+ - `Min(field)` - Find minimum value
17
+ - `Max(field)` - Find maximum value
18
+
19
+ - **QuerySet Aggregation Methods**: Shortcut methods for common aggregations
20
+ - `count()` - Returns count as integer directly
21
+ - `sum(field)` - Returns sum as float/int
22
+ - `avg(field)` - Returns average as float
23
+ - `min(field)` - Returns minimum value
24
+ - `max(field)` - Returns maximum value
25
+
26
+ - **GROUP BY Support**: Django-style grouping with annotations
27
+ - `values(*fields)` - Specify fields for GROUP BY
28
+ - `annotate(**aggregations)` - Add aggregation annotations
29
+
30
+ - **exists() Method**: Efficiently check if records exist
31
+
32
+ - **raw_query() Class Method**: Execute arbitrary SurrealQL queries with variables
33
+
34
+ - New test file `tests/test_aggregations.py` with comprehensive unit and e2e tests
35
+
36
+ ### Changed
37
+
38
+ - QuerySet now tracks `_group_by_fields` and `_annotations` for GROUP BY queries
39
+ - `exec()` now returns a list of dicts instead of model instances when `annotate()` has been used (GROUP BY via `values()`, or GROUP ALL otherwise)
40
+
8
41
  ## [0.2.2] - 2026-02-05
9
42
 
10
43
  ### Fixed
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: surreal-orm-lite
3
- Version: 0.2.2
3
+ Version: 0.3.0
4
4
  Summary: Lightweight Django-style ORM for SurrealDB using the official Python SDK. Async support with Pydantic validation.
5
5
  Project-URL: Homepage, https://github.com/EulogySnowfall/SurrealDB-ORM-lite
6
6
  Project-URL: Documentation, https://github.com/EulogySnowfall/SurrealDB-ORM-lite
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "surreal-orm-lite"
3
- version = "0.2.2"
3
+ version = "0.3.0"
4
4
  description = "Lightweight Django-style ORM for SurrealDB using the official Python SDK. Async support with Pydantic validation."
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.11"
@@ -1,5 +1,6 @@
1
- __version__ = "0.2.2"
1
+ __version__ = "0.3.0"
2
2
 
3
+ from .aggregations import Aggregation, Avg, Count, Max, Min, Sum
3
4
  from .connection_manager import SurrealDBConnectionManager
4
5
  from .enum import OrderBy
5
6
  from .exceptions import (
@@ -14,11 +15,22 @@ from .query_set import QuerySet
14
15
 
15
16
  __all__ = [
16
17
  "__version__",
18
+ # Connection
17
19
  "SurrealDBConnectionManager",
20
+ # Model
18
21
  "BaseSurrealModel",
22
+ "SurrealConfigDict",
23
+ # QuerySet
19
24
  "QuerySet",
20
25
  "OrderBy",
21
- "SurrealConfigDict",
26
+ # Aggregations
27
+ "Aggregation",
28
+ "Count",
29
+ "Sum",
30
+ "Avg",
31
+ "Min",
32
+ "Max",
33
+ # Exceptions
22
34
  "SurrealORMError",
23
35
  "SurrealDbError",
24
36
  "SurrealDbConnectionError",
@@ -0,0 +1,279 @@
1
+ """
2
+ Aggregation classes for SurrealDB-ORM-lite.
3
+
4
+ This module provides Django-style aggregation functions that can be used
5
+ with QuerySet to perform aggregate calculations on database fields.
6
+
7
+ Example:
8
+ ```python
9
+ from surreal_orm_lite import Count, Sum, Avg, Min, Max
10
+
11
+ # Simple aggregations
12
+ count = await User.objects().count()
13
+ total = await Order.objects().sum("amount")
14
+
15
+ # With GROUP BY
16
+ results = await User.objects().values("status").annotate(count=Count()).exec()
17
+ ```
18
+ """
19
+
20
+ from abc import ABC, abstractmethod
21
+
22
+ from .utils import validate_field_name
23
+
24
+
25
+ class Aggregation(ABC):
26
+ """
27
+ Base class for all aggregation functions.
28
+
29
+ Aggregations are used to compute summary values from a set of records.
30
+ Each aggregation must implement the `to_sql()` method that returns
31
+ the SurrealDB SQL expression for the aggregation.
32
+ """
33
+
34
+ def __init__(self, field: str | None = None, alias: str | None = None) -> None:
35
+ """
36
+ Initialize an aggregation.
37
+
38
+ Args:
39
+ field: The field name to aggregate. Some aggregations (like Count)
40
+ don't require a field.
41
+ alias: Optional alias for the result. If not provided, a default
42
+ alias will be generated.
43
+ """
44
+ self.field = field
45
+ self.alias = alias
46
+
47
+ @abstractmethod
48
+ def to_sql(self) -> str:
49
+ """
50
+ Convert the aggregation to a SurrealDB SQL expression.
51
+
52
+ Returns:
53
+ str: The SQL expression for this aggregation.
54
+ """
55
+ pass # pragma: no cover
56
+
57
+ def get_alias(self) -> str:
58
+ """
59
+ Get the alias for this aggregation result.
60
+
61
+ Returns:
62
+ str: The alias name for the aggregation result.
63
+ """
64
+ if self.alias:
65
+ return self.alias
66
+ if self.field:
67
+ return f"{self.__class__.__name__.lower()}_{self.field}"
68
+ return self.__class__.__name__.lower()
69
+
70
+
71
+ class Count(Aggregation):
72
+ """
73
+ Count aggregation function.
74
+
75
+ Counts the number of records in a query result.
76
+
77
+ Example:
78
+ ```python
79
+ # Count all users
80
+ count = await User.objects().count()
81
+
82
+ # Count with filter
83
+ active_count = await User.objects().filter(status="active").count()
84
+
85
+ # Count with GROUP BY
86
+ results = await User.objects().values("status").annotate(count=Count()).exec()
87
+ ```
88
+ """
89
+
90
+ def __init__(self, field: str | None = None, alias: str | None = None) -> None:
91
+ """
92
+ Initialize a Count aggregation.
93
+
94
+ Args:
95
+ field: Optional field name. If not provided, counts all records.
96
+ alias: Optional alias for the result.
97
+ """
98
+ if field is not None:
99
+ validate_field_name(field, "Count field")
100
+ super().__init__(field, alias)
101
+
102
+ def to_sql(self) -> str:
103
+ """
104
+ Convert to SurrealDB SQL.
105
+
106
+ Returns:
107
+ str: "count()" or "count(field)" expression.
108
+ """
109
+ if self.field:
110
+ return f"count({self.field})"
111
+ return "count()"
112
+
113
+
114
+ class Sum(Aggregation):
115
+ """
116
+ Sum aggregation function.
117
+
118
+ Calculates the sum of a numeric field.
119
+
120
+ Example:
121
+ ```python
122
+ # Sum of all order amounts
123
+ total = await Order.objects().sum("amount")
124
+
125
+ # Sum with filter
126
+ total_completed = await Order.objects().filter(status="completed").sum("amount")
127
+
128
+ # Sum with GROUP BY
129
+ results = await Order.objects().values("customer_id").annotate(total=Sum("amount")).exec()
130
+ ```
131
+ """
132
+
133
+ def __init__(self, field: str, alias: str | None = None) -> None:
134
+ """
135
+ Initialize a Sum aggregation.
136
+
137
+ Args:
138
+ field: The numeric field to sum.
139
+ alias: Optional alias for the result.
140
+ """
141
+ if not field or not field.strip():
142
+ raise ValueError("Sum requires a field name")
143
+ validate_field_name(field, "Sum field")
144
+ super().__init__(field, alias)
145
+
146
+ def to_sql(self) -> str:
147
+ """
148
+ Convert to SurrealDB SQL.
149
+
150
+ Returns:
151
+ str: "math::sum(field)" expression.
152
+ """
153
+ return f"math::sum({self.field})"
154
+
155
+
156
+ class Avg(Aggregation):
157
+ """
158
+ Average aggregation function.
159
+
160
+ Calculates the average of a numeric field.
161
+
162
+ Example:
163
+ ```python
164
+ # Average age of all users
165
+ avg_age = await User.objects().avg("age")
166
+
167
+ # Average with filter
168
+ avg_active = await User.objects().filter(status="active").avg("age")
169
+
170
+ # Average with GROUP BY
171
+ results = await User.objects().values("department").annotate(avg_salary=Avg("salary")).exec()
172
+ ```
173
+ """
174
+
175
+ def __init__(self, field: str, alias: str | None = None) -> None:
176
+ """
177
+ Initialize an Avg aggregation.
178
+
179
+ Args:
180
+ field: The numeric field to average.
181
+ alias: Optional alias for the result.
182
+ """
183
+ if not field or not field.strip():
184
+ raise ValueError("Avg requires a field name")
185
+ validate_field_name(field, "Avg field")
186
+ super().__init__(field, alias)
187
+
188
+ def to_sql(self) -> str:
189
+ """
190
+ Convert to SurrealDB SQL.
191
+
192
+ Returns:
193
+ str: "math::mean(field)" expression.
194
+ """
195
+ return f"math::mean({self.field})"
196
+
197
+
198
+ class Min(Aggregation):
199
+ """
200
+ Minimum aggregation function.
201
+
202
+ Finds the minimum value of a field.
203
+
204
+ Example:
205
+ ```python
206
+ # Minimum price
207
+ min_price = await Product.objects().min("price")
208
+
209
+ # Minimum with filter
210
+ min_active = await Product.objects().filter(active=True).min("price")
211
+
212
+ # Minimum with GROUP BY
213
+ results = await Product.objects().values("category").annotate(min_price=Min("price")).exec()
214
+ ```
215
+ """
216
+
217
+ def __init__(self, field: str, alias: str | None = None) -> None:
218
+ """
219
+ Initialize a Min aggregation.
220
+
221
+ Args:
222
+ field: The field to find the minimum value of.
223
+ alias: Optional alias for the result.
224
+ """
225
+ if not field or not field.strip():
226
+ raise ValueError("Min requires a field name")
227
+ validate_field_name(field, "Min field")
228
+ super().__init__(field, alias)
229
+
230
+ def to_sql(self) -> str:
231
+ """
232
+ Convert to SurrealDB SQL.
233
+
234
+ Returns:
235
+ str: "math::min(field)" expression.
236
+ """
237
+ return f"math::min({self.field})"
238
+
239
+
240
+ class Max(Aggregation):
241
+ """
242
+ Maximum aggregation function.
243
+
244
+ Finds the maximum value of a field.
245
+
246
+ Example:
247
+ ```python
248
+ # Maximum price
249
+ max_price = await Product.objects().max("price")
250
+
251
+ # Maximum with filter
252
+ max_active = await Product.objects().filter(active=True).max("price")
253
+
254
+ # Maximum with GROUP BY
255
+ results = await Product.objects().values("category").annotate(max_price=Max("price")).exec()
256
+ ```
257
+ """
258
+
259
+ def __init__(self, field: str, alias: str | None = None) -> None:
260
+ """
261
+ Initialize a Max aggregation.
262
+
263
+ Args:
264
+ field: The field to find the maximum value of.
265
+ alias: Optional alias for the result.
266
+ """
267
+ if not field or not field.strip():
268
+ raise ValueError("Max requires a field name")
269
+ validate_field_name(field, "Max field")
270
+ super().__init__(field, alias)
271
+
272
+ def to_sql(self) -> str:
273
+ """
274
+ Convert to SurrealDB SQL.
275
+
276
+ Returns:
277
+ str: "math::max(field)" expression.
278
+ """
279
+ return f"math::max({self.field})"
@@ -2,6 +2,7 @@ import logging
2
2
  from typing import Any, Self
3
3
 
4
4
  from pydantic import BaseModel, ConfigDict, model_validator
5
+ from pydantic_core import ValidationError
5
6
  from surrealdb import RecordID
6
7
 
7
8
  from .connection_manager import SurrealDBConnectionManager
@@ -224,3 +225,62 @@ class BaseSurrealModel(BaseModel):
224
225
  from .query_set import QuerySet
225
226
 
226
227
  return QuerySet(cls)
228
+
229
+ @classmethod
230
+ async def raw_query(
231
+ cls,
232
+ query: str,
233
+ variables: dict[str, Any] | None = None,
234
+ ) -> list[Self] | list[dict[str, Any]]:
235
+ """
236
+ Execute a raw SurrealQL query and return the results.
237
+
238
+ This method allows executing arbitrary SurrealQL queries directly against
239
+ the database. It's useful for complex queries that can't be expressed
240
+ using the QuerySet API.
241
+
242
+ Args:
243
+ query: The SurrealQL query string to execute.
244
+ variables: Optional dictionary of variables to substitute into the query.
245
+ Use $variable_name syntax in the query string.
246
+
247
+ Returns:
248
+ list[Self] | list[dict]: A list of model instances if the results match
249
+ the model schema, otherwise a list of dictionaries.
250
+
251
+ Example:
252
+ ```python
253
+ # Simple query
254
+ users = await User.raw_query("SELECT * FROM User WHERE age > 21")
255
+
256
+ # With variables (safe from injection)
257
+ users = await User.raw_query(
258
+ "SELECT * FROM User WHERE status = $status AND age > $min_age",
259
+ variables={"status": "active", "min_age": 18}
260
+ )
261
+
262
+ # Complex graph query
263
+ results = await User.raw_query('''
264
+ SELECT *, ->purchased->Product AS products
265
+ FROM User
266
+ WHERE id = $user_id
267
+ ''', variables={"user_id": "user:123"})
268
+ ```
269
+ """
270
+ from .utils import remove_quotes_for_variables
271
+
272
+ client = await SurrealDBConnectionManager.get_client()
273
+ results = await client.query(
274
+ remove_quotes_for_variables(query),
275
+ variables or {},
276
+ )
277
+
278
+ # SDK 1.0.8 returns list directly from query()
279
+ if isinstance(results, list):
280
+ try:
281
+ return cls.from_db(results) # type: ignore
282
+ except (ValueError, TypeError, ValidationError):
283
+ # If validation fails, return raw dicts
284
+ return results
285
+
286
+ return []
@@ -1,5 +1,5 @@
1
1
  import logging
2
- from typing import Any, Self, cast
2
+ from typing import TYPE_CHECKING, Any, Self, cast
3
3
 
4
4
  from pydantic_core import ValidationError
5
5
 
@@ -7,7 +7,10 @@ from . import BaseSurrealModel, SurrealDBConnectionManager
7
7
  from .constants import LOOKUP_OPERATORS
8
8
  from .enum import OrderBy
9
9
  from .exceptions import SurrealDbError, SurrealDbNotFoundError
10
- from .utils import remove_quotes_for_variables
10
+ from .utils import remove_quotes_for_variables, validate_alias_name, validate_field_name
11
+
12
+ if TYPE_CHECKING:
13
+ from .aggregations import Aggregation
11
14
 
12
15
  logger = logging.getLogger(__name__)
13
16
 
@@ -57,6 +60,8 @@ class QuerySet:
57
60
  self._order_by: str | None = None
58
61
  self._model_table: str = getattr(model, "_table_name", model.__name__)
59
62
  self._variables: dict = {}
63
+ self._group_by_fields: list[str] = []
64
+ self._annotations: dict[str, Aggregation] = {}
60
65
 
61
66
  def select(self, *fields: str) -> Self:
62
67
  """
@@ -223,6 +228,23 @@ class QuerySet:
223
228
  self._order_by = f"{field_name} {type}"
224
229
  return self
225
230
 
231
+ def _build_where_clauses(self) -> list[str]:
232
+ """
233
+ Build WHERE clause conditions from the current filters.
234
+
235
+ Returns:
236
+ list[str]: A list of WHERE clause condition strings.
237
+ """
238
+ where_clauses = []
239
+ for field_name, lookup_name, value in self._filters:
240
+ op = LOOKUP_OPERATORS.get(lookup_name, "=")
241
+ if lookup_name == "in":
242
+ formatted_values = ", ".join(repr(v) for v in value)
243
+ where_clauses.append(f"{field_name} {op} [{formatted_values}]")
244
+ else:
245
+ where_clauses.append(f"{field_name} {op} {repr(value)}")
246
+ return where_clauses
247
+
226
248
  def _compile_query(self) -> str:
227
249
  """
228
250
  Compile the QuerySet parameters into a SQL query string.
@@ -240,15 +262,7 @@ class QuerySet:
240
262
  # "SELECT id, name FROM users WHERE age > 21 AND status = 'active' ORDER BY name ASC LIMIT 10 START 20;"
241
263
  ```
242
264
  """
243
- where_clauses = []
244
- for field_name, lookup_name, value in self._filters:
245
- op = LOOKUP_OPERATORS.get(lookup_name, "=")
246
- if lookup_name == "in":
247
- # Assuming value is iterable for 'IN' operations
248
- formatted_values = ", ".join(repr(v) for v in value)
249
- where_clauses.append(f"{field_name} {op} [{formatted_values}]")
250
- else:
251
- where_clauses.append(f"{field_name} {op} {repr(value)}")
265
+ where_clauses = self._build_where_clauses()
252
266
 
253
267
  # Construct the SELECT clause
254
268
  if self.select_item:
@@ -284,18 +298,34 @@ class QuerySet:
284
298
  the results. If the data conforms to the model schema, it returns a list of model instances;
285
299
  otherwise, it returns a list of dictionaries.
286
300
 
301
+ When `values()` and `annotate()` are used, returns a list of dictionaries with
302
+ the grouped fields and aggregation results.
303
+
287
304
  Returns:
288
305
  list[BaseSurrealModel] | list[dict]: A list of model instances if validation is successful,
289
- otherwise a list of dictionaries representing the raw data.
306
+ otherwise a list of dictionaries representing the raw data. For GROUP BY queries,
307
+ always returns a list of dictionaries.
290
308
 
291
309
  Raises:
292
310
  SurrealDbError: If there is an issue executing the query.
293
311
 
294
312
  Example:
295
313
  ```python
314
+ # Regular query
296
315
  results = await queryset.exec()
316
+
317
+ # GROUP BY query
318
+ results = await User.objects().values("status").annotate(count=Count()).exec()
319
+ # Returns: [{"status": "active", "count": 42}, ...]
297
320
  ```
298
321
  """
322
+ # Handle GROUP BY queries with annotations
323
+ if self._annotations:
324
+ query = self._compile_group_by_query()
325
+ results = await self._execute_query(query)
326
+ # GROUP BY queries always return dicts, not model instances
327
+ return results if isinstance(results, list) else []
328
+
299
329
  query = self._compile_query()
300
330
  results = await self._execute_query(query)
301
331
  try:
@@ -466,6 +496,343 @@ class QuerySet:
466
496
  await client.delete(self._model_table)
467
497
  return True
468
498
 
499
+ # ==================== Aggregation Methods ====================
500
+
501
+ def values(self, *fields: str) -> Self:
502
+ """
503
+ Specify fields for GROUP BY operations.
504
+
505
+ This method sets up the fields that will be used for grouping when
506
+ combined with `annotate()`. Similar to Django's `values()` method.
507
+
508
+ Args:
509
+ *fields: Field names to group by.
510
+
511
+ Returns:
512
+ Self: The current QuerySet instance for method chaining.
513
+
514
+ Example:
515
+ ```python
516
+ # Group users by status and count them
517
+ results = await User.objects().values("status").annotate(count=Count()).exec()
518
+ # Returns: [{"status": "active", "count": 42}, {"status": "inactive", "count": 8}]
519
+ ```
520
+ """
521
+ for field in fields:
522
+ validate_field_name(field, "GROUP BY field")
523
+ self._group_by_fields = list(fields)
524
+ return self
525
+
526
+ def annotate(self, **annotations: "Aggregation") -> Self:
527
+ """
528
+ Add aggregation annotations to the query.
529
+
530
+ This method adds aggregation functions to the query. When combined with
531
+ `values()`, it performs GROUP BY operations. Similar to Django's `annotate()`.
532
+
533
+ Args:
534
+ **annotations: Keyword arguments where the key is the alias and the value
535
+ is an Aggregation instance (Count, Sum, Avg, Min, Max).
536
+
537
+ Returns:
538
+ Self: The current QuerySet instance for method chaining.
539
+
540
+ Raises:
541
+ TypeError: If any annotation value is not an Aggregation instance.
542
+
543
+ Example:
544
+ ```python
545
+ from surreal_orm_lite import Count, Sum, Avg
546
+
547
+ # Group by status and count
548
+ results = await User.objects().values("status").annotate(count=Count()).exec()
549
+
550
+ # Multiple aggregations
551
+ results = await Order.objects().values("customer_id").annotate(
552
+ total=Sum("amount"),
553
+ avg_order=Avg("amount"),
554
+ order_count=Count()
555
+ ).exec()
556
+ ```
557
+ """
558
+ from .aggregations import Aggregation as AggregationClass
559
+
560
+ for alias, agg in annotations.items():
561
+ validate_alias_name(alias)
562
+ if not isinstance(agg, AggregationClass):
563
+ raise TypeError(f"annotate() argument '{alias}' must be an Aggregation instance, got {type(agg).__name__}")
564
+ self._annotations.update(annotations)
565
+ return self
566
+
567
+ async def count(self) -> int:
568
+ """
569
+ Count the number of records matching the query.
570
+
571
+ This is a shortcut method that returns the count as an integer directly.
572
+
573
+ Returns:
574
+ int: The number of matching records.
575
+
576
+ Example:
577
+ ```python
578
+ # Count all users
579
+ total = await User.objects().count()
580
+
581
+ # Count with filters
582
+ active = await User.objects().filter(status="active").count()
583
+ ```
584
+ """
585
+ query = self._compile_aggregation_query("count()")
586
+ results = await self._execute_query(query)
587
+
588
+ if isinstance(results, list) and len(results) > 0:
589
+ result = results[0]
590
+ if isinstance(result, dict):
591
+ # Handle GROUP ALL result format
592
+ return int(result.get("count", 0))
593
+ return int(result)
594
+ return 0
595
+
596
+ async def sum(self, field: str) -> float | int:
597
+ """
598
+ Calculate the sum of a numeric field.
599
+
600
+ Args:
601
+ field: The name of the numeric field to sum.
602
+
603
+ Returns:
604
+ float | int: The sum of the field values, or 0 if no records match.
605
+
606
+ Raises:
607
+ ValueError: If field name is empty or invalid.
608
+
609
+ Example:
610
+ ```python
611
+ # Sum of all order amounts
612
+ total = await Order.objects().sum("amount")
613
+
614
+ # Sum with filter
615
+ completed_total = await Order.objects().filter(status="completed").sum("amount")
616
+ ```
617
+ """
618
+ validate_field_name(field, "sum() field")
619
+ query = self._compile_aggregation_query(f"math::sum({field})", alias="sum")
620
+ results = await self._execute_query(query)
621
+
622
+ if isinstance(results, list) and len(results) > 0:
623
+ result = results[0]
624
+ if isinstance(result, dict):
625
+ value = result.get("sum", 0)
626
+ return value if value is not None else 0
627
+ return result if result is not None else 0
628
+ return 0
629
+
630
+ async def avg(self, field: str) -> float:
631
+ """
632
+ Calculate the average of a numeric field.
633
+
634
+ Args:
635
+ field: The name of the numeric field to average.
636
+
637
+ Returns:
638
+ float: The average of the field values, or 0.0 if no records match.
639
+
640
+ Raises:
641
+ ValueError: If field name is empty or invalid.
642
+
643
+ Example:
644
+ ```python
645
+ # Average age of all users
646
+ avg_age = await User.objects().avg("age")
647
+
648
+ # Average with filter
649
+ avg_active = await User.objects().filter(status="active").avg("age")
650
+ ```
651
+ """
652
+ validate_field_name(field, "avg() field")
653
+ query = self._compile_aggregation_query(f"math::mean({field})", alias="avg")
654
+ results = await self._execute_query(query)
655
+
656
+ if isinstance(results, list) and len(results) > 0:
657
+ result = results[0]
658
+ if isinstance(result, dict):
659
+ value = result.get("avg", 0.0)
660
+ return float(value) if value is not None else 0.0
661
+ return float(result) if result is not None else 0.0
662
+ return 0.0
663
+
664
+ async def min(self, field: str) -> Any:
665
+ """
666
+ Find the minimum value of a field.
667
+
668
+ Args:
669
+ field: The name of the field to find the minimum value of.
670
+
671
+ Returns:
672
+ Any: The minimum value, or None if no records match.
673
+
674
+ Raises:
675
+ ValueError: If field name is empty or invalid.
676
+
677
+ Example:
678
+ ```python
679
+ # Minimum price
680
+ min_price = await Product.objects().min("price")
681
+
682
+ # Minimum with filter
683
+ min_active = await Product.objects().filter(active=True).min("price")
684
+ ```
685
+ """
686
+ validate_field_name(field, "min() field")
687
+ query = self._compile_aggregation_query(f"math::min({field})", alias="min")
688
+ results = await self._execute_query(query)
689
+
690
+ if isinstance(results, list) and len(results) > 0:
691
+ result = results[0]
692
+ if isinstance(result, dict):
693
+ return result.get("min")
694
+ return result
695
+ return None
696
+
697
+ async def max(self, field: str) -> Any:
698
+ """
699
+ Find the maximum value of a field.
700
+
701
+ Args:
702
+ field: The name of the field to find the maximum value of.
703
+
704
+ Returns:
705
+ Any: The maximum value, or None if no records match.
706
+
707
+ Raises:
708
+ ValueError: If field name is empty or invalid.
709
+
710
+ Example:
711
+ ```python
712
+ # Maximum price
713
+ max_price = await Product.objects().max("price")
714
+
715
+ # Maximum with filter
716
+ max_active = await Product.objects().filter(active=True).max("price")
717
+ ```
718
+ """
719
+ validate_field_name(field, "max() field")
720
+ query = self._compile_aggregation_query(f"math::max({field})", alias="max")
721
+ results = await self._execute_query(query)
722
+
723
+ if isinstance(results, list) and len(results) > 0:
724
+ result = results[0]
725
+ if isinstance(result, dict):
726
+ return result.get("max")
727
+ return result
728
+ return None
729
+
730
+ async def exists(self) -> bool:
731
+ """
732
+ Check if any records match the current query.
733
+
734
+ This method is more efficient than `count()` when you only need to know
735
+ if any matching records exist.
736
+
737
+ Returns:
738
+ bool: True if at least one record matches, False otherwise.
739
+
740
+ Example:
741
+ ```python
742
+ # Check if any admins exist
743
+ has_admin = await User.objects().filter(role="admin").exists()
744
+
745
+ # Check if user exists
746
+ user_exists = await User.objects().filter(email="alice@example.com").exists()
747
+ ```
748
+ """
749
+ # Use LIMIT 1 for efficiency without permanently mutating the QuerySet
750
+ original_limit = self._limit
751
+ try:
752
+ self._limit = 1
753
+ query = self._compile_query()
754
+ results = await self._execute_query(query)
755
+
756
+ if isinstance(results, list):
757
+ return len(results) > 0
758
+ return False
759
+ finally:
760
+ self._limit = original_limit
761
+
762
+ def _compile_aggregation_query(self, aggregation_expr: str, alias: str | None = None) -> str:
763
+ """
764
+ Compile an aggregation query.
765
+
766
+ This internal method builds the SQL query for aggregation operations.
767
+
768
+ Args:
769
+ aggregation_expr: The aggregation expression (e.g., "count()", "math::sum(field)").
770
+ alias: Optional alias for the aggregation result.
771
+
772
+ Returns:
773
+ str: The compiled SQL query string.
774
+ """
775
+ where_clauses = self._build_where_clauses()
776
+
777
+ # Build SELECT clause with aggregation
778
+ if alias:
779
+ query = f"SELECT {aggregation_expr} AS {alias} FROM {self._model_table}"
780
+ else:
781
+ query = f"SELECT {aggregation_expr} FROM {self._model_table}"
782
+
783
+ # Append WHERE clauses
784
+ if where_clauses:
785
+ query += " WHERE " + " AND ".join(where_clauses)
786
+
787
+ # GROUP ALL is required for aggregations without GROUP BY in SurrealDB
788
+ query += " GROUP ALL"
789
+ query += ";"
790
+ return query
791
+
792
+ def _compile_group_by_query(self) -> str:
793
+ """
794
+ Compile a GROUP BY query with annotations.
795
+
796
+ This internal method builds the SQL query for GROUP BY operations
797
+ with aggregation annotations.
798
+
799
+ Returns:
800
+ str: The compiled SQL query string.
801
+ """
802
+ where_clauses = self._build_where_clauses()
803
+
804
+ # Build SELECT clause with group fields and annotations
805
+ select_parts = list(self._group_by_fields)
806
+ for alias, agg in self._annotations.items():
807
+ select_parts.append(f"{agg.to_sql()} AS {alias}")
808
+
809
+ query = f"SELECT {', '.join(select_parts)} FROM {self._model_table}"
810
+
811
+ # Append WHERE clauses
812
+ if where_clauses:
813
+ query += " WHERE " + " AND ".join(where_clauses)
814
+
815
+ # Append GROUP BY
816
+ if self._group_by_fields:
817
+ query += f" GROUP BY {', '.join(self._group_by_fields)}"
818
+ else:
819
+ query += " GROUP ALL"
820
+
821
+ # Append ORDER BY if set
822
+ if self._order_by:
823
+ query += f" ORDER BY {self._order_by}"
824
+
825
+ # Append LIMIT if set
826
+ if self._limit is not None:
827
+ query += f" LIMIT {self._limit}"
828
+
829
+ # Append OFFSET if set
830
+ if self._offset is not None:
831
+ query += f" START {self._offset}"
832
+
833
+ query += ";"
834
+ return query
835
+
469
836
  async def query(self, query: str, variables: dict[str, Any] | None = None) -> Any:
470
837
  """
471
838
  Execute a custom SQL query on the SurrealDB database.
@@ -0,0 +1,53 @@
1
+ import re
2
+
3
+ # Pattern for valid field names: alphanumeric, underscores, dots (for nested fields)
4
+ # Must start with a letter or underscore
5
+ VALID_FIELD_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)*$")
6
+
7
+ # Pattern for valid alias names: alphanumeric and underscores only
8
+ # Must start with a letter or underscore (like Python identifiers)
9
+ VALID_ALIAS_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
10
+
11
+
12
+ def remove_quotes_for_variables(query: str) -> str:
13
+ # Regex to remove single quotes around variables ($)
14
+ return re.sub(r"'(\$[a-zA-Z_]\w*)'", r"\1", query)
15
+
16
+
17
+ def validate_field_name(field: str, context: str = "field") -> None:
18
+ """
19
+ Validate a field name to prevent SQL injection.
20
+
21
+ Args:
22
+ field: The field name to validate.
23
+ context: Description of where the field is used (for error messages).
24
+
25
+ Raises:
26
+ ValueError: If the field name contains invalid characters.
27
+ """
28
+ if not field or not field.strip():
29
+ raise ValueError(f"{context} name cannot be empty")
30
+ if not VALID_FIELD_PATTERN.match(field):
31
+ raise ValueError(
32
+ f"Invalid {context} name '{field}': must contain only alphanumeric characters, "
33
+ "underscores, and dots (for nested fields), and start with a letter or underscore"
34
+ )
35
+
36
+
37
+ def validate_alias_name(alias: str) -> None:
38
+ """
39
+ Validate an alias name to prevent SQL injection.
40
+
41
+ Args:
42
+ alias: The alias name to validate.
43
+
44
+ Raises:
45
+ ValueError: If the alias name contains invalid characters.
46
+ """
47
+ if not alias or not alias.strip():
48
+ raise ValueError("alias name cannot be empty")
49
+ if not VALID_ALIAS_PATTERN.match(alias):
50
+ raise ValueError(
51
+ f"Invalid alias name '{alias}': must contain only alphanumeric characters "
52
+ "and underscores, and start with a letter or underscore"
53
+ )
@@ -1,6 +0,0 @@
1
- import re
2
-
3
-
4
- def remove_quotes_for_variables(query: str) -> str:
5
- # Regex to remove single quotes around variables ($)
6
- return re.sub(r"'(\$[a-zA-Z_]\w*)'", r"\1", query)