surrealdb-orm 0.1.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51):
  1. surreal_orm/__init__.py +78 -3
  2. surreal_orm/aggregations.py +164 -0
  3. surreal_orm/auth/__init__.py +15 -0
  4. surreal_orm/auth/access.py +167 -0
  5. surreal_orm/auth/mixins.py +302 -0
  6. surreal_orm/cli/__init__.py +15 -0
  7. surreal_orm/cli/commands.py +369 -0
  8. surreal_orm/connection_manager.py +58 -18
  9. surreal_orm/fields/__init__.py +36 -0
  10. surreal_orm/fields/encrypted.py +166 -0
  11. surreal_orm/fields/relation.py +465 -0
  12. surreal_orm/migrations/__init__.py +51 -0
  13. surreal_orm/migrations/executor.py +380 -0
  14. surreal_orm/migrations/generator.py +272 -0
  15. surreal_orm/migrations/introspector.py +305 -0
  16. surreal_orm/migrations/migration.py +188 -0
  17. surreal_orm/migrations/operations.py +531 -0
  18. surreal_orm/migrations/state.py +406 -0
  19. surreal_orm/model_base.py +594 -135
  20. surreal_orm/py.typed +0 -0
  21. surreal_orm/query_set.py +609 -34
  22. surreal_orm/relations.py +645 -0
  23. surreal_orm/surreal_function.py +95 -0
  24. surreal_orm/surreal_ql.py +113 -0
  25. surreal_orm/types.py +86 -0
  26. surreal_sdk/README.md +79 -0
  27. surreal_sdk/__init__.py +151 -0
  28. surreal_sdk/connection/__init__.py +17 -0
  29. surreal_sdk/connection/base.py +516 -0
  30. surreal_sdk/connection/http.py +421 -0
  31. surreal_sdk/connection/pool.py +244 -0
  32. surreal_sdk/connection/websocket.py +519 -0
  33. surreal_sdk/exceptions.py +71 -0
  34. surreal_sdk/functions.py +607 -0
  35. surreal_sdk/protocol/__init__.py +13 -0
  36. surreal_sdk/protocol/rpc.py +218 -0
  37. surreal_sdk/py.typed +0 -0
  38. surreal_sdk/pyproject.toml +49 -0
  39. surreal_sdk/streaming/__init__.py +31 -0
  40. surreal_sdk/streaming/change_feed.py +278 -0
  41. surreal_sdk/streaming/live_query.py +265 -0
  42. surreal_sdk/streaming/live_select.py +369 -0
  43. surreal_sdk/transaction.py +386 -0
  44. surreal_sdk/types.py +346 -0
  45. surrealdb_orm-0.5.0.dist-info/METADATA +465 -0
  46. surrealdb_orm-0.5.0.dist-info/RECORD +52 -0
  47. {surrealdb_orm-0.1.3.dist-info → surrealdb_orm-0.5.0.dist-info}/WHEEL +1 -1
  48. surrealdb_orm-0.5.0.dist-info/entry_points.txt +2 -0
  49. {surrealdb_orm-0.1.3.dist-info → surrealdb_orm-0.5.0.dist-info}/licenses/LICENSE +1 -1
  50. surrealdb_orm-0.1.3.dist-info/METADATA +0 -184
  51. surrealdb_orm-0.1.3.dist-info/RECORD +0 -11
surreal_orm/model_base.py CHANGED
@@ -1,223 +1,452 @@
1
- from typing import Any, Type, Self
2
- from pydantic import BaseModel, create_model, ConfigDict
1
+ from typing import Any, Literal, Self, cast, TYPE_CHECKING
2
+
3
+ from pydantic import BaseModel, ConfigDict, model_validator
4
+
3
5
  from .connection_manager import SurrealDBConnectionManager
4
- from surrealdb import RecordID, SurrealDbError
6
+ from .types import SchemaMode, TableType
7
+
8
+ if TYPE_CHECKING:
9
+ from surreal_sdk.transaction import BaseTransaction, HTTPTransaction
5
10
 
6
- import warnings
7
11
  import logging
8
12
 
9
- warnings.filterwarnings("ignore", message="fields may not start with an underscore", category=RuntimeWarning)
13
+
14
+ class SurrealDbError(Exception):
15
+ """Error from SurrealDB operations."""
16
+
17
+ pass
18
+
10
19
 
11
20
  logger = logging.getLogger(__name__)
12
21
 
22
+ # Global registry of all SurrealDB models for migration introspection
23
+ _MODEL_REGISTRY: list[type["BaseSurrealModel"]] = []
24
+
25
+
26
+ def get_registered_models() -> list[type["BaseSurrealModel"]]:
27
+ """
28
+ Get all registered SurrealDB models.
29
+
30
+ Returns:
31
+ List of all model classes that inherit from BaseSurrealModel
32
+ """
33
+ return _MODEL_REGISTRY.copy()
34
+
35
+
36
+ def clear_model_registry() -> None:
37
+ """
38
+ Clear the model registry. Useful for testing.
39
+ """
40
+ _MODEL_REGISTRY.clear()
41
+
42
+
43
+ def _parse_record_id(record_id: Any) -> str | None:
44
+ """
45
+ Parse a record ID from various formats.
46
+ SurrealDB returns IDs as 'table:id' strings.
47
+ """
48
+ if record_id is None:
49
+ return None
50
+ record_str = str(record_id)
51
+ if ":" in record_str:
52
+ return record_str.split(":", 1)[1]
53
+ return record_str
54
+
55
+
56
+ class SurrealConfigDict(ConfigDict):
57
+ """
58
+ SurrealConfigDict is a configuration dictionary for SurrealDB models.
59
+
60
+ Extends Pydantic's ConfigDict with SurrealDB-specific options for
61
+ table types, schema modes, and authentication settings.
62
+
63
+ Attributes:
64
+ primary_key: The primary key field name for the model
65
+ table_name: Override the default table name (default: class name)
66
+ table_type: Table classification (NORMAL, USER, STREAM, HASH)
67
+ schema_mode: Schema enforcement mode (SCHEMAFULL, SCHEMALESS)
68
+ changefeed: Changefeed duration for STREAM tables (e.g., "7d")
69
+ permissions: Table-level permissions dict {"select": "...", "update": "..."}
70
+ identifier_field: Field used for signin (USER type, default: "email")
71
+ password_field: Field containing password (USER type, default: "password")
72
+ token_duration: JWT token duration (USER type, default: "15m")
73
+ session_duration: Session duration (USER type, default: "12h")
74
+ """
75
+
76
+ primary_key: str | None
77
+ table_name: str | None
78
+ table_type: TableType | None
79
+ schema_mode: SchemaMode | None
80
+ changefeed: str | None
81
+ permissions: dict[str, str] | None
82
+ identifier_field: str | None
83
+ password_field: str | None
84
+ token_duration: str | None
85
+ session_duration: str | None
86
+
13
87
 
14
88
  class BaseSurrealModel(BaseModel):
15
89
  """
16
90
  Base class for models interacting with SurrealDB.
91
+
92
+ All models that interact with SurrealDB should inherit from this class.
93
+ Models are automatically registered for migration introspection.
94
+
95
+ Example:
96
+ class User(BaseSurrealModel):
97
+ model_config = SurrealConfigDict(
98
+ table_type=TableType.USER,
99
+ schema_mode=SchemaMode.SCHEMAFULL,
100
+ )
101
+
102
+ id: str | None = None
103
+ email: str
104
+ password: Encrypted
17
105
  """
18
106
 
19
- __pydantic_model_cache__: Type[BaseModel] | None = None
107
+ def __init_subclass__(cls, **kwargs: Any) -> None:
108
+ """Register subclasses in the model registry for migration introspection."""
109
+ super().__init_subclass__(**kwargs)
110
+ # Only register concrete models, not intermediate base classes
111
+ if cls.__name__ != "BaseSurrealModel" and not cls.__name__.startswith("_"):
112
+ if cls not in _MODEL_REGISTRY:
113
+ _MODEL_REGISTRY.append(cls)
20
114
 
21
- def __init__(self, **data: Any):
22
- model_cls = self._init_model()
23
- instance = model_cls(**data)
24
- object.__setattr__(self, "_data", instance.model_dump())
25
- object.__setattr__(self, "_table_name", self.__class__.__name__)
115
+ @classmethod
116
+ def get_table_name(cls) -> str:
117
+ """
118
+ Get the table name for the model.
26
119
 
27
- def __getattr__(self, item: str) -> Any:
120
+ Returns the table_name from model_config if set,
121
+ otherwise returns the class name.
28
122
  """
29
- If the item is a field in _data, return it,
30
- otherwise, let the normal mechanism raise AttributeError.
123
+ if hasattr(cls, "model_config"):
124
+ table_name = cls.model_config.get("table_name", None)
125
+ if isinstance(table_name, str):
126
+ return table_name
127
+ return cls.__name__
128
+
129
+ @classmethod
130
+ def get_table_type(cls) -> TableType:
31
131
  """
32
- _data = object.__getattribute__(self, "_data")
33
- if item in _data:
34
- return _data[item]
35
- raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'.")
132
+ Get the table type classification for the model.
36
133
 
37
- def __str__(self) -> str:
38
- return f"{self._data}"
134
+ Returns:
135
+ TableType enum value (default: NORMAL)
136
+ """
137
+ if hasattr(cls, "model_config"):
138
+ table_type = cls.model_config.get("table_type", None)
139
+ if isinstance(table_type, TableType):
140
+ return table_type
141
+ return TableType.NORMAL
39
142
 
40
- def __setattr__(self, key: str, value: Any) -> None:
143
+ @classmethod
144
+ def get_schema_mode(cls) -> SchemaMode:
41
145
  """
42
- If we want to allow updates, reinstantiate a Pydantic model
43
- with the new value.
146
+ Get the schema mode for the model.
147
+
148
+ USER tables are always SCHEMAFULL.
149
+ HASH tables default to SCHEMALESS.
150
+ All others default to SCHEMAFULL.
151
+
152
+ Returns:
153
+ SchemaMode enum value
44
154
  """
45
- if key in ("_data",): # and other internal attributes
46
- object.__setattr__(self, key, value)
47
- else:
48
- # Update the dict, validate via Pydantic, etc.
49
- current_data = dict(object.__getattribute__(self, "_data"))
50
- current_data[key] = value
51
- instance = self._init_model()(**current_data)
52
- object.__setattr__(self, "_data", instance.model_dump())
155
+ table_type = cls.get_table_type()
156
+
157
+ # USER tables must be SCHEMAFULL
158
+ if table_type == TableType.USER:
159
+ return SchemaMode.SCHEMAFULL
160
+
161
+ if hasattr(cls, "model_config"):
162
+ schema_mode = cls.model_config.get("schema_mode", None)
163
+ if isinstance(schema_mode, SchemaMode):
164
+ return schema_mode
165
+
166
+ # HASH tables default to SCHEMALESS
167
+ if table_type == TableType.HASH:
168
+ return SchemaMode.SCHEMALESS
169
+
170
+ return SchemaMode.SCHEMAFULL
53
171
 
54
172
  @classmethod
55
- def from_db(cls, record: dict | list) -> Any:
173
+ def get_changefeed(cls) -> str | None:
56
174
  """
57
- Create an instance from a SurrealDB record.
175
+ Get the changefeed duration for the model.
176
+
177
+ Returns:
178
+ Changefeed duration string (e.g., "7d") or None
58
179
  """
59
- if isinstance(record, list):
60
- return [cls.from_db(rs) for rs in record]
180
+ if hasattr(cls, "model_config"):
181
+ changefeed = cls.model_config.get("changefeed", None)
182
+ return str(changefeed) if changefeed is not None else None
183
+ return None
61
184
 
62
- record = cls.__set_data(record)
185
+ @classmethod
186
+ def get_permissions(cls) -> dict[str, str]:
187
+ """
188
+ Get the table permissions for the model.
63
189
 
64
- return cls(**record)
190
+ Returns:
191
+ Dict of permission type to condition expression
192
+ """
193
+ if hasattr(cls, "model_config"):
194
+ permissions = cls.model_config.get("permissions", None)
195
+ if isinstance(permissions, dict):
196
+ return permissions
197
+ return {}
198
+
199
+ @classmethod
200
+ def get_identifier_field(cls) -> str:
201
+ """
202
+ Get the identifier field for USER type tables.
203
+
204
+ Returns:
205
+ Field name used for signin (default: "email")
206
+ """
207
+ if hasattr(cls, "model_config"):
208
+ field = cls.model_config.get("identifier_field", None)
209
+ if isinstance(field, str):
210
+ return field
211
+ return "email"
212
+
213
+ @classmethod
214
+ def get_password_field(cls) -> str:
215
+ """
216
+ Get the password field for USER type tables.
217
+
218
+ Returns:
219
+ Field name containing password (default: "password")
220
+ """
221
+ if hasattr(cls, "model_config"):
222
+ field = cls.model_config.get("password_field", None)
223
+ if isinstance(field, str):
224
+ return field
225
+ return "password"
226
+
227
+ @classmethod
228
+ def get_index_primary_key(cls) -> str | None:
229
+ """
230
+ Get the primary key field name for the model.
231
+ """
232
+ if hasattr(cls, "model_config"): # pragma: no cover
233
+ primary_key = cls.model_config.get("primary_key", None)
234
+ if isinstance(primary_key, str):
235
+ return primary_key
236
+
237
+ return None
65
238
 
66
- def to_db_dict(self) -> dict[str, Any]:
239
+ def get_id(self) -> str | None:
67
240
  """
68
- Return a dictionary ready to be inserted into the database.
241
+ Get the ID of the model instance.
69
242
  """
70
- data_set = {key: value for key, value in self._data.items() if not key.startswith("_") and key != "id"}
71
- return data_set
243
+ if hasattr(self, "id"):
244
+ id_value = getattr(self, "id")
245
+ return str(id_value) if id_value is not None else None
72
246
 
73
- def show_config(self) -> ConfigDict:
74
- # Accès depuis une méthode d'instance
75
- return type(self).model_config
247
+ if hasattr(self, "model_config"):
248
+ primary_key = self.model_config.get("primary_key", None)
249
+ if isinstance(primary_key, str) and hasattr(self, primary_key):
250
+ primary_key_value = getattr(self, primary_key)
251
+ return str(primary_key_value) if primary_key_value is not None else None
76
252
 
77
- def get_id(self) -> str | RecordID | None:
78
- if "id" in self._data:
79
- return self._data["id"]
253
+ return None # pragma: no cover
80
254
 
81
- config = self.show_config()
82
- pk_field = config.get("primary_key", "id")
83
- return self._data.get(pk_field, None)
255
+ @classmethod
256
+ def from_db(cls, record: dict | list | None) -> Self | list[Self]:
257
+ """
258
+ Create an instance from a SurrealDB record.
259
+ """
260
+ if record is None:
261
+ raise cls.DoesNotExist("Record not found.")
262
+
263
+ if isinstance(record, list):
264
+ return [cls.from_db(rs) for rs in record] # type: ignore
84
265
 
85
- @staticmethod
86
- def __set_data(data: Any) -> dict:
266
+ return cls(**record)
267
+
268
+ @model_validator(mode="before")
269
+ @classmethod
270
+ def set_data(cls, data: Any) -> Any:
87
271
  """
88
- Set the model instance data.
272
+ Parse the ID from SurrealDB format (table:id) to just id.
89
273
  """
90
274
  if isinstance(data, dict): # pragma: no cover
91
- if "id" in data and isinstance(data["id"], RecordID): # pragma: no cover
92
- data["id"] = str(data["id"]).split(":")[1]
275
+ if "id" in data:
276
+ data["id"] = _parse_record_id(data["id"])
93
277
  return data
94
278
 
95
- raise TypeError("Data must be a dictionary.") # pragma: no cover
96
-
97
279
  async def refresh(self) -> None:
98
280
  """
99
281
  Refresh the model instance from the database.
100
282
  """
283
+ if not self.get_id():
284
+ raise SurrealDbError("Can't refresh data, not recorded yet.") # pragma: no cover
285
+
101
286
  client = await SurrealDBConnectionManager.get_client()
102
- record = None
287
+ result = await client.select(f"{self.get_table_name()}:{self.get_id()}")
103
288
 
104
- id = self.get_id()
105
- record = await client.select(f"{self._table_name}:{id}")
289
+ # SDK returns RecordsResponse with .records list
290
+ if result.is_empty:
291
+ raise SurrealDbError("Can't refresh data, no record found.") # pragma: no cover
292
+
293
+ record = result.first
294
+ if record is None:
295
+ raise SurrealDbError("Can't refresh data, no record found.") # pragma: no cover
106
296
 
107
- self._data = self.__set_data(record)
297
+ # Update instance fields from the record
298
+ for key, value in record.items():
299
+ if key == "id":
300
+ value = _parse_record_id(value)
301
+ if hasattr(self, key):
302
+ setattr(self, key, value)
303
+ return None
108
304
 
109
- async def save(self) -> Self:
305
+ async def save(self, tx: "BaseTransaction | None" = None) -> Self:
110
306
  """
111
307
  Save the model instance to the database.
308
+
309
+ Args:
310
+ tx: Optional transaction to use for this operation.
311
+ If provided, the operation will be part of the transaction.
312
+
313
+ Example:
314
+ # Without transaction
315
+ await user.save()
316
+
317
+ # With transaction
318
+ async with SurrealDBConnectionManager.transaction() as tx:
319
+ await user.save(tx=tx)
112
320
  """
113
- client = await SurrealDBConnectionManager.get_client()
321
+ if tx is not None:
322
+ # Use transaction
323
+ data = self.model_dump(exclude={"id"})
324
+ id = self.get_id()
325
+ table = self.get_table_name()
326
+
327
+ if id is not None:
328
+ thing = f"{table}:{id}"
329
+ await tx.create(thing, data)
330
+ return self
331
+
332
+ # Auto-generate ID - create without specific ID
333
+ await tx.create(table, data)
334
+ return self
114
335
 
115
- data = self.to_db_dict()
336
+ # Original behavior without transaction
337
+ client = await SurrealDBConnectionManager.get_client()
338
+ data = self.model_dump(exclude={"id"})
116
339
  id = self.get_id()
117
- if id:
118
- thing = f"{self._table_name}:{id}"
340
+ table = self.get_table_name()
341
+
342
+ if id is not None:
343
+ thing = f"{table}:{id}"
119
344
  await client.create(thing, data)
120
345
  return self
346
+
121
347
  # Auto-generate the ID
122
- record = await client.create(self._table_name, data) # pragma: no cover
123
- if isinstance(record, dict): # pragma: no cover
124
- self._data = self.__set_data(record)
348
+ result = await client.create(table, data) # pragma: no cover
125
349
 
126
- return self
350
+ # SDK returns RecordResponse
351
+ if not result.exists:
352
+ raise SurrealDbError("Can't save data, no record returned.") # pragma: no cover
353
+
354
+ obj = self.from_db(cast(dict | list | None, result.record))
355
+ if isinstance(obj, type(self)):
356
+ self = obj
357
+ return self
127
358
 
128
- async def update(self) -> Any:
359
+ raise SurrealDbError("Can't save data, no record returned.") # pragma: no cover
360
+
361
+ async def update(self, tx: "BaseTransaction | None" = None) -> Any:
129
362
  """
130
363
  Update the model instance to the database.
131
- """
132
- client = await SurrealDBConnectionManager.get_client()
133
364
 
134
- data = self.to_db_dict()
365
+ Args:
366
+ tx: Optional transaction to use for this operation.
367
+ """
368
+ data = self.model_dump(exclude={"id"})
135
369
  id = self.get_id()
136
- if id:
137
- thing = f"{self._table_name}:{id}"
138
- return await client.update(thing, data)
139
370
 
140
- raise SurrealDbError("Can't update data, no id found.")
371
+ if id is None:
372
+ raise SurrealDbError("Can't update data, no id found.")
373
+
374
+ thing = f"{self.__class__.__name__}:{id}"
375
+
376
+ if tx is not None:
377
+ await tx.update(thing, data)
378
+ return None
379
+
380
+ client = await SurrealDBConnectionManager.get_client()
381
+ result = await client.update(thing, data)
382
+ return result.records
141
383
 
142
- async def merge(self, **data: Any) -> Any:
384
+ @classmethod
385
+ def get(cls, item: str) -> str:
143
386
  """
144
- Update the model instance to the database.
387
+ Get the table name for the model.
145
388
  """
389
+ return f"{cls.__name__}:{item}"
146
390
 
147
- client = await SurrealDBConnectionManager.get_client()
391
+ async def merge(self, tx: "BaseTransaction | None" = None, **data: Any) -> Any:
392
+ """
393
+ Merge (partial update) the model instance in the database.
394
+
395
+ Args:
396
+ tx: Optional transaction to use for this operation.
397
+ **data: Fields to update.
398
+ """
148
399
  data_set = {key: value for key, value in data.items()}
149
400
 
150
401
  id = self.get_id()
151
- if id:
152
- thing = f"{self._table_name}:{id}"
402
+ if not id:
403
+ raise SurrealDbError(f"No Id for the data to merge: {data}")
404
+
405
+ thing = f"{self.get_table_name()}:{id}"
153
406
 
154
- await client.merge(thing, data_set)
155
- await self.refresh()
407
+ if tx is not None:
408
+ await tx.merge(thing, data_set)
156
409
  return
157
410
 
158
- raise SurrealDbError(f"No Id for the data to merge: {data}")
411
+ client = await SurrealDBConnectionManager.get_client()
412
+ await client.merge(thing, data_set)
413
+ await self.refresh()
159
414
 
160
- async def delete(self) -> None:
415
+ async def delete(self, tx: "BaseTransaction | None" = None) -> None:
161
416
  """
162
417
  Delete the model instance from the database.
163
- """
164
-
165
- client = await SurrealDBConnectionManager.get_client()
166
418
 
419
+ Args:
420
+ tx: Optional transaction to use for this operation.
421
+ """
167
422
  id = self.get_id()
423
+ thing = f"{self.get_table_name()}:{id}"
168
424
 
169
- thing = f"{self._table_name}:{id}"
425
+ if tx is not None:
426
+ await tx.delete(thing)
427
+ logger.info(f"Record deleted (in transaction) -> {thing}.")
428
+ return
170
429
 
171
- deleted = await client.delete(thing)
430
+ client = await SurrealDBConnectionManager.get_client()
431
+ result = await client.delete(thing)
172
432
 
173
- if not deleted:
433
+ if not result.success:
174
434
  raise SurrealDbError(f"Can't delete Record id -> '{id}' not found!")
175
435
 
176
- logger.info(f"Record deleted -> {deleted}.")
177
- self._data = {}
178
- del self
436
+ logger.info(f"Record deleted -> {result.deleted!r}.")
179
437
 
180
- @classmethod
181
- def _init_model(cls) -> Any:
182
- """
183
- Generate a real Pydantic model only once (per subclass)
184
- from the fields annotated in the class inheriting from BaseSurrealModel.
185
- """
186
- if cls.__pydantic_model_cache__ is not None:
187
- return cls.__pydantic_model_cache__
188
-
189
- # Retrieve the annotations declared in the class (e.g., ModelTest)
190
- hints: dict[str, Any] = {}
191
- config_dict = None
192
- for base in reversed(cls.__mro__): # To capture all annotations
193
- hints.update(getattr(base, "__annotations__", {}))
194
- # Optionally, check if the class has 'model_config' to inject it
195
- if hasattr(base, "model_config"):
196
- config_dict = getattr(base, "model_config")
197
-
198
- # Create the Pydantic model (dynamically)
199
- fields = {}
200
- for field_name, field_type in hints.items():
201
- # Read the object already defined in the class (if Field(...))
202
- default_val = getattr(cls, field_name, ...)
203
- fields[field_name] = (field_type, default_val)
204
-
205
- # Create model
206
- if config_dict:
207
- pyd_model = create_model( # type: ignore
208
- f"{cls.__name__}PydModel",
209
- __config__=config_dict,
210
- **fields,
211
- )
212
- else:
213
- pyd_model = create_model( # type: ignore
214
- f"{cls.__name__}PydModel",
215
- __base__=BaseModel,
216
- **fields,
438
+ @model_validator(mode="after")
439
+ def check_config(self) -> Self:
440
+ """
441
+ Check the model configuration.
442
+ """
443
+
444
+ if not self.get_index_primary_key() and not hasattr(self, "id"):
445
+ raise SurrealDbError( # pragma: no cover
446
+ "Can't create model, the model needs either 'id' field or primary_key in 'model_config'."
217
447
  )
218
448
 
219
- cls.__pydantic_model_cache__ = pyd_model
220
- return pyd_model
449
+ return self
221
450
 
222
451
  @classmethod
223
452
  def objects(cls) -> Any:
@@ -227,3 +456,233 @@ class BaseSurrealModel(BaseModel):
227
456
  from .query_set import QuerySet
228
457
 
229
458
  return QuerySet(cls)
459
+
460
+ @classmethod
461
+ async def transaction(cls) -> "HTTPTransaction":
462
+ """
463
+ Create a transaction context manager for atomic operations.
464
+
465
+ This is a convenience method that delegates to SurrealDBConnectionManager.
466
+
467
+ Usage:
468
+ async with User.transaction() as tx:
469
+ user1 = User(id="1", name="Alice")
470
+ await user1.save(tx=tx)
471
+ user2 = User(id="2", name="Bob")
472
+ await user2.save(tx=tx)
473
+ # Auto-commit on success, auto-rollback on exception
474
+
475
+ Returns:
476
+ HTTPTransaction context manager
477
+ """
478
+ return await SurrealDBConnectionManager.transaction()
479
+
480
+ # ==================== Graph Relation Methods ====================
481
+
482
+ async def relate(
483
+ self,
484
+ relation: str,
485
+ to: "BaseSurrealModel",
486
+ tx: "BaseTransaction | None" = None,
487
+ **edge_data: Any,
488
+ ) -> dict[str, Any]:
489
+ """
490
+ Create a graph relation (edge) to another record.
491
+
492
+ This method creates a SurrealDB RELATE edge between this record
493
+ and the target record. Optional edge data can be stored on the relation.
494
+
495
+ Args:
496
+ relation: Name of the edge table (e.g., "follows", "likes")
497
+ to: Target model instance to relate to
498
+ tx: Optional transaction to use for this operation
499
+ **edge_data: Additional data to store on the edge record
500
+
501
+ Returns:
502
+ dict: The created edge record
503
+
504
+ Example:
505
+ # Simple relation
506
+ await alice.relate("follows", bob)
507
+
508
+ # With edge data
509
+ await alice.relate("follows", bob, since="2025-01-01", strength="strong")
510
+
511
+ # In a transaction
512
+ async with User.transaction() as tx:
513
+ await alice.relate("follows", bob, tx=tx)
514
+ await alice.relate("follows", charlie, tx=tx)
515
+
516
+ SurrealQL equivalent:
517
+ RELATE users:alice->follows->users:bob SET since = '2025-01-01';
518
+ """
519
+ source_id = self.get_id()
520
+ target_id = to.get_id()
521
+
522
+ if not source_id:
523
+ raise SurrealDbError("Cannot create relation from unsaved instance")
524
+ if not target_id:
525
+ raise SurrealDbError("Cannot create relation to unsaved instance")
526
+
527
+ source_table = self.get_table_name()
528
+ target_table = to.get_table_name()
529
+
530
+ from_thing = f"{source_table}:{source_id}"
531
+ to_thing = f"{target_table}:{target_id}"
532
+
533
+ if tx is not None:
534
+ await tx.relate(from_thing, relation, to_thing, edge_data if edge_data else None)
535
+ return {"in": from_thing, "out": to_thing, **edge_data}
536
+
537
+ client = await SurrealDBConnectionManager.get_client()
538
+ result = await client.relate(
539
+ from_thing,
540
+ relation,
541
+ to_thing,
542
+ edge_data if edge_data else None,
543
+ )
544
+
545
+ if result.exists and result.record:
546
+ return dict(result.record)
547
+ return {"in": from_thing, "out": to_thing, **edge_data}
548
+
549
+ async def remove_relation(
550
+ self,
551
+ relation: str,
552
+ to: "BaseSurrealModel",
553
+ tx: "BaseTransaction | None" = None,
554
+ ) -> None:
555
+ """
556
+ Remove a graph relation (edge) to another record.
557
+
558
+ This method deletes the edge record(s) between this record
559
+ and the target record.
560
+
561
+ Args:
562
+ relation: Name of the edge table (e.g., "follows", "likes")
563
+ to: Target model instance to unrelate
564
+ tx: Optional transaction to use for this operation
565
+
566
+ Example:
567
+ # Remove relation
568
+ await alice.remove_relation("follows", bob)
569
+
570
+ # In a transaction
571
+ async with User.transaction() as tx:
572
+ await alice.remove_relation("follows", bob, tx=tx)
573
+ await alice.remove_relation("follows", charlie, tx=tx)
574
+ """
575
+ source_id = self.get_id()
576
+ target_id = to.get_id()
577
+
578
+ if not source_id:
579
+ raise SurrealDbError("Cannot remove relation from unsaved instance")
580
+ if not target_id:
581
+ raise SurrealDbError("Cannot remove relation to unsaved instance")
582
+
583
+ source_table = self.get_table_name()
584
+ target_table = to.get_table_name()
585
+
586
+ # Delete edge where in=source and out=target
587
+ query = f"DELETE {relation} WHERE in = {source_table}:{source_id} AND out = {target_table}:{target_id};"
588
+
589
+ if tx is not None:
590
+ await tx.query(query)
591
+ return
592
+
593
+ client = await SurrealDBConnectionManager.get_client()
594
+ await client.query(query)
595
+
596
+ async def get_related(
597
+ self,
598
+ relation: str,
599
+ direction: Literal["out", "in", "both"] = "out",
600
+ model_class: type["BaseSurrealModel"] | None = None,
601
+ ) -> list["BaseSurrealModel"] | list[dict[str, Any]]:
602
+ """
603
+ Get records related through a graph relation.
604
+
605
+ This method queries SurrealDB's graph traversal capabilities
606
+ to find related records.
607
+
608
+ Args:
609
+ relation: Name of the edge table (e.g., "follows", "likes")
610
+ direction: Traversal direction
611
+ - "out": Outgoing edges (this record -> relation -> target)
612
+ - "in": Incoming edges (source -> relation -> this record)
613
+ - "both": Both directions
614
+ model_class: Optional model class to convert results to instances
615
+
616
+ Returns:
617
+ List of related model instances or dicts if model_class is None
618
+
619
+ Example:
620
+ # Get users this user follows
621
+ following = await alice.get_related("follows", direction="out")
622
+
623
+ # Get users who follow this user
624
+ followers = await alice.get_related("follows", direction="in")
625
+
626
+ # With model class for typed results
627
+ followers = await alice.get_related("follows", direction="in", model_class=User)
628
+
629
+ SurrealQL equivalent:
630
+ - out: SELECT out FROM follows WHERE in = users:alice FETCH out;
631
+ - in: SELECT in FROM follows WHERE out = users:alice FETCH in;
632
+ """
633
+ source_id = self.get_id()
634
+ if not source_id:
635
+ raise SurrealDbError("Cannot query relations from unsaved instance")
636
+
637
+ source_table = self.get_table_name()
638
+ source_thing = f"{source_table}:{source_id}"
639
+
640
+ client = await SurrealDBConnectionManager.get_client()
641
+ records: list[dict[str, Any]] = []
642
+
643
+ # Query edge table and fetch related records
644
+ # For outgoing: get 'out' field where 'in' matches source
645
+ # For incoming: get 'in' field where 'out' matches source
646
+ if direction == "out":
647
+ query = f"SELECT out FROM {relation} WHERE in = {source_thing} FETCH out;"
648
+ result = await client.query(query)
649
+ for row in result.all_records or []:
650
+ if isinstance(row.get("out"), dict):
651
+ records.append(row["out"])
652
+ elif direction == "in":
653
+ query = f"SELECT in FROM {relation} WHERE out = {source_thing} FETCH in;"
654
+ result = await client.query(query)
655
+ for row in result.all_records or []:
656
+ if isinstance(row.get("in"), dict):
657
+ records.append(row["in"])
658
+ else: # both
659
+ # Get both outgoing and incoming relations
660
+ query_out = f"SELECT out FROM {relation} WHERE in = {source_thing} FETCH out;"
661
+ query_in = f"SELECT in FROM {relation} WHERE out = {source_thing} FETCH in;"
662
+ result_out = await client.query(query_out)
663
+ result_in = await client.query(query_in)
664
+ for row in result_out.all_records or []:
665
+ if isinstance(row.get("out"), dict):
666
+ records.append(row["out"])
667
+ for row in result_in.all_records or []:
668
+ if isinstance(row.get("in"), dict):
669
+ records.append(row["in"])
670
+
671
+ if model_class is not None:
672
+ instances: list[BaseSurrealModel] = []
673
+ for record in records:
674
+ instance = model_class.from_db(record)
675
+ if isinstance(instance, list):
676
+ instances.extend(instance)
677
+ else:
678
+ instances.append(instance)
679
+ return instances
680
+
681
+ return records
682
+
683
+ class DoesNotExist(Exception):
684
+ """
685
+ Exception raised when a model instance does not exist.
686
+ """
687
+
688
+ pass